!2902 move weight data copy to warmup stage
Merge pull request !2902 from dinghao/master
This commit is contained in:
commit
7304f02410
|
@ -32,7 +32,6 @@ using mindspore::tensor::TensorPy;
|
|||
namespace mindspore {
|
||||
namespace session {
|
||||
namespace {
|
||||
std::set<AnfNodePtr> weight_infos;
|
||||
static TypeId GetDataType(const py::buffer_info &buf) {
|
||||
if (buf.format.size() == 1) {
|
||||
switch (buf.format.front()) {
|
||||
|
@ -105,10 +104,33 @@ void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &k
|
|||
MS_EXCEPTION_IF_NULL(pk_node);
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
|
||||
MS_EXCEPTION_IF_NULL(device_address);
|
||||
if (AnfAlgo::IsParameterWeight(pk_node)) {
|
||||
if (weight_infos.count(pk_node) != 0) {
|
||||
continue;
|
||||
if (!AnfAlgo::IsParameterWeight(pk_node)) {
|
||||
tensor = inputs[no_weight_input++];
|
||||
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
|
||||
LongToSize(tensor->data().nbytes()), tensor->data_type(),
|
||||
tensor->data_c())) {
|
||||
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
GraphId AscendInferenceSession::CompileGraph(NotNull<FuncGraphPtr> func_graph) {
|
||||
auto graph_id = AscendSession::CompileGraph(func_graph);
|
||||
auto kernel_graph = GetGraph(graph_id);
|
||||
MS_EXCEPTION_IF_NULL(kernel_graph);
|
||||
// load weight data to device
|
||||
auto input_nodes = kernel_graph->inputs();
|
||||
for (size_t i = 0; i < input_nodes.size(); ++i) {
|
||||
if (!input_nodes[i]->isa<Parameter>()) {
|
||||
MS_LOG(ERROR) << "Kernel graph inputs have anfnode which is not Parameter";
|
||||
continue;
|
||||
}
|
||||
auto pk_node = input_nodes[i]->cast<ParameterPtr>();
|
||||
MS_EXCEPTION_IF_NULL(pk_node);
|
||||
auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0);
|
||||
MS_EXCEPTION_IF_NULL(device_address);
|
||||
if (AnfAlgo::IsParameterWeight(pk_node)) {
|
||||
auto param_value = std::dynamic_pointer_cast<ParamValuePy>(pk_node->default_param());
|
||||
MS_EXCEPTION_IF_NULL(param_value);
|
||||
auto py_param = param_value->value();
|
||||
|
@ -120,16 +142,9 @@ void AscendInferenceSession::LoadInputData(const std::shared_ptr<KernelGraph> &k
|
|||
LongToSize(buf.size * buf.itemsize), buf_type, buf.ptr)) {
|
||||
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
|
||||
}
|
||||
weight_infos.insert(pk_node);
|
||||
} else {
|
||||
tensor = inputs[no_weight_input++];
|
||||
if (!device_address->SyncHostToDevice(trans::GetRuntimePaddingShape(pk_node, 0),
|
||||
LongToSize(tensor->data().nbytes()), tensor->data_type(),
|
||||
tensor->data_c())) {
|
||||
MS_LOG(EXCEPTION) << "SyncHostToDevice failed.";
|
||||
}
|
||||
}
|
||||
}
|
||||
return graph_id;
|
||||
}
|
||||
} // namespace session
|
||||
} // namespace mindspore
|
||||
|
|
|
@ -38,6 +38,7 @@ class AscendInferenceSession : public AscendSession {
|
|||
~AscendInferenceSession() = default;
|
||||
void LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
|
||||
const std::vector<tensor::TensorPtr> &inputs_const) const;
|
||||
GraphId CompileGraph(NotNull<FuncGraphPtr> func_graph) override;
|
||||
};
|
||||
MS_REG_SESSION(kDavinciInferenceDevice, AscendInferenceSession);
|
||||
} // namespace session
|
||||
|
|
Loading…
Reference in New Issue