!2902 move weight data copy to warmup stage

Merge pull request !2902 from dinghao/master
This commit is contained in:
mindspore-ci-bot 2020-07-07 15:05:10 +08:00 committed by Gitee
commit 7304f02410
2 changed files with 28 additions and 12 deletions

View File

@ -32,7 +32,6 @@ using mindspore::tensor::TensorPy;
namespace mindspore {
namespace session {
namespace {
std::set<AnfNodePtr> weight_infos;
static TypeId GetDataType(const py::buffer_info &buf) {
if (buf.format.size() == 1) {
switch (buf.format.front()) {
@ -105,10 +104,33 @@ void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &k
MS_EXCEPTION_IF_NULL(pk_node);
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
MS_EXCEPTION_IF_NULL(device_address);
if (AnfAlgo::IsParameterWeight(pk_node)) {
if (weight_infos.count(pk_node) != 0) {
if (!AnfAlgo::IsParameterWeight(pk_node)) {
tensor = inputs[no_weight_input++];
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c())) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
}
}
}
}
GraphId AscendInferenceSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
auto graph_id = AscendSession::CompileGraph(func_graph);
auto kernel_graph = GetGraph(graph_id);
MS_EXCEPTION_IF_NULL(kernel_graph);
// load weight data to device
auto input_nodes = kernel_graph->inputs();
for (size_t i = 0; i < input_nodes.size(); ++i) {
if (!input_nodes[i]->isa<Parameter>()) {
MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter";
continue;
}
auto pk_node = input_nodes[i]->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(pk_node);
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
MS_EXCEPTION_IF_NULL(device_address);
if (AnfAlgo::IsParameterWeight(pk_node)) {
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(pk_node->default_param());
MS_EXCEPTION_IF_NULL(param_value);
auto py_param = param_value->value();
@ -120,16 +142,9 @@ void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &k
LongToSize(buf.size * buf.itemsize), buf_type, buf.ptr)) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
}
weight_infos.insert(pk_node);
} else {
tensor = inputs[no_weight_input++];
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
LongToSize(tensor->data().nbytes()), tensor->data_type(),
tensor->data_c())) {
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
}
}
}
return graph_id;
}
} // namespace session
} // namespace mindspore

View File

@ -38,6 +38,7 @@ class AscendInferenceSession : public AscendSession {
~AscendInferenceSession() = default;
void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
const std::vector<tensor::TensorPtr> &inputs_const) const;
GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override;
};
MS_REG_SESSION(kDavinciInferenceDevice, AscendInferenceSession);
} // namespace session