!2902 move weight data copy to warmup stage
Merge pull request !2902 from dinghao/master
This commit is contained in:
commit
7304f02410
|
@ -32,7 +32,6 @@ using mindspore::tensor::TensorPy;
|
||||||
namespace mindspore {
|
namespace mindspore {
|
||||||
namespace session {
|
namespace session {
|
||||||
namespace {
|
namespace {
|
||||||
std::set<AnfNodePtr> weight_infos;
|
|
||||||
static TypeId GetDataType(const py::buffer_info &buf) {
|
static TypeId GetDataType(const py::buffer_info &buf) {
|
||||||
if (buf.format.size() == 1) {
|
if (buf.format.size() == 1) {
|
||||||
switch (buf.format.front()) {
|
switch (buf.format.front()) {
|
||||||
|
@ -105,10 +104,33 @@ void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &k
|
||||||
MS_EXCEPTION_IF_NULL(pk_node);
|
MS_EXCEPTION_IF_NULL(pk_node);
|
||||||
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
|
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
|
||||||
MS_EXCEPTION_IF_NULL(device_address);
|
MS_EXCEPTION_IF_NULL(device_address);
|
||||||
if (AnfAlgo::IsParameterWeight(pk_node)) {
|
if (!AnfAlgo::IsParameterWeight(pk_node)) {
|
||||||
if (weight_infos.count(pk_node) != 0) {
|
tensor = inputs[no_weight_input++];
|
||||||
continue;
|
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
|
||||||
|
LongToSize(tensor->data().nbytes()), tensor->data_type(),
|
||||||
|
tensor->data_c())) {
|
||||||
|
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
GraphId AscendInferenceSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
||||||
|
auto graph_id = AscendSession::CompileGraph(func_graph);
|
||||||
|
auto kernel_graph = GetGraph(graph_id);
|
||||||
|
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||||
|
// load weight data to device
|
||||||
|
auto input_nodes = kernel_graph->inputs();
|
||||||
|
for (size_t i = 0; i < input_nodes.size(); ++i) {
|
||||||
|
if (!input_nodes[i]->isa<Parameter>()) {
|
||||||
|
MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter";
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
auto pk_node = input_nodes[i]->cast<ParameterPtr>();
|
||||||
|
MS_EXCEPTION_IF_NULL(pk_node);
|
||||||
|
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
|
||||||
|
MS_EXCEPTION_IF_NULL(device_address);
|
||||||
|
if (AnfAlgo::IsParameterWeight(pk_node)) {
|
||||||
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(pk_node->default_param());
|
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(pk_node->default_param());
|
||||||
MS_EXCEPTION_IF_NULL(param_value);
|
MS_EXCEPTION_IF_NULL(param_value);
|
||||||
auto py_param = param_value->value();
|
auto py_param = param_value->value();
|
||||||
|
@ -120,16 +142,9 @@ void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &k
|
||||||
LongToSize(buf.size * buf.itemsize), buf_type, buf.ptr)) {
|
LongToSize(buf.size * buf.itemsize), buf_type, buf.ptr)) {
|
||||||
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
|
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
|
||||||
}
|
}
|
||||||
weight_infos.insert(pk_node);
|
|
||||||
} else {
|
|
||||||
tensor = inputs[no_weight_input++];
|
|
||||||
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
|
|
||||||
LongToSize(tensor->data().nbytes()), tensor->data_type(),
|
|
||||||
tensor->data_c())) {
|
|
||||||
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return graph_id;
|
||||||
}
|
}
|
||||||
} // namespace session
|
} // namespace session
|
||||||
} // namespace mindspore
|
} // namespace mindspore
|
||||||
|
|
|
@ -38,6 +38,7 @@ class AscendInferenceSession : public AscendSession {
|
||||||
~AscendInferenceSession() = default;
|
~AscendInferenceSession() = default;
|
||||||
void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||||
const std::vector<tensor::TensorPtr> &inputs_const) const;
|
const std::vector<tensor::TensorPtr> &inputs_const) const;
|
||||||
|
GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override;
|
||||||
};
|
};
|
||||||
MS_REG_SESSION(kDavinciInferenceDevice, AscendInferenceSession);
|
MS_REG_SESSION(kDavinciInferenceDevice, AscendInferenceSession);
|
||||||
} // namespace session
|
} // namespace session
|
||||||
|
|
Loading…
Reference in New Issue