diff --git a/mindspore/ccsrc/common/trans.cc b/mindspore/ccsrc/common/trans.cc
index 731c71c7811..e53b8bdf283 100644
--- a/mindspore/ccsrc/common/trans.cc
+++ b/mindspore/ccsrc/common/trans.cc
@@ -575,7 +575,7 @@ bool NchwToFracZ(const FormatArgs &args, void *result) {
         auto src_ni = hfi * kCubeSize + col;
         auto src_idx = src_row_offset + chw * col;
         auto dst_idx = gfi * fractal_ele_cnt + col * c0 + row;
-        auto pad_zero = src_ni >= n || src_idx >= nchw || src_ci >= c;
+        auto pad_zero = (src_ni >= n || src_idx >= nchw || src_ci >= c) ? 1 : 0;
         SetDataBysize(size, pad_zero);
       }
     }
@@ -770,9 +770,11 @@ bool NchwToFracNz(const FormatArgs &args, void *result) {
       auto h1h0_head = times_head + h1h0_idx * w0;
       auto src_h_head = src_times_head + h1h0_idx * w;
       for (size_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
-        size_t dst_idx = h1h0_head + w1_idx * h1h0w0;
-        size_t src_idx = src_h_head + w1_idx * w0;
-        SetDataBysize(size, 0);
+        for (size_t i = 0; i < w0; ++i) {
+          size_t src_idx = src_h_head + w1_idx * w0 + i;
+          size_t dst_idx = h1h0_head + w1_idx * h1h0w0 + i;
+          SetDataBysize(size, 0);
+        }
       }
       auto w1_head = num_w1 * w0;
       for (size_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
@@ -830,9 +832,11 @@ bool FracNzToNchw(const FormatArgs &args, void *result) {
       auto h1h0_head = times_head + h1h0_idx * w0;
       auto src_h_head = src_times_head + h1h0_idx * w;
       for (size_t w1_idx = 0; w1_idx < num_w1; w1_idx++) {
-        size_t src_idx = h1h0_head + w1_idx * h1h0w0;
-        size_t dst_idx = src_h_head + w1_idx * w0;
-        SetDataBysize(size, 0);
+        for (size_t i = 0; i < w0; ++i) {
+          size_t src_idx = h1h0_head + w1_idx * h1h0w0 + i;
+          size_t dst_idx = src_h_head + w1_idx * w0 + i;
+          SetDataBysize(size, 0);
+        }
       }
       auto w1_head = num_w1 * w0;
       for (size_t w0_idx = 0; w1_head + w0_idx < w; w0_idx++) {
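Note on the two trans.cc hunks above: the old loop computed src_idx/dst_idx once per w1 block, so only the first element of each 16-wide fragment was copied; the fix adds an inner loop over all w0 elements. Below is a minimal standalone sketch of the corrected packing pattern, not the MindSpore code — all names are hypothetical, and only the 16-element block size (kCubeSize) is taken from the patch.

```cpp
// Pack an h x w matrix into FRACTAL_NZ-style (w1, h1, h0, w0) blocks, zero-padded.
#include <cstddef>
#include <vector>

std::vector<float> PackFracNz(const std::vector<float> &src, size_t h, size_t w) {
  const size_t c0 = 16;                     // h0 = w0 = 16, mirroring kCubeSize
  const size_t num_h1 = (h + c0 - 1) / c0;  // blocks along h
  const size_t num_w1 = (w + c0 - 1) / c0;  // blocks along w
  std::vector<float> dst(num_w1 * num_h1 * c0 * c0, 0.0f);  // padding starts as zero
  for (size_t w1 = 0; w1 < num_w1; ++w1) {
    for (size_t h1 = 0; h1 < num_h1; ++h1) {
      for (size_t h0 = 0; h0 < c0; ++h0) {
        // the bug fixed above: without this innermost loop, only one element
        // of each 16-wide row fragment was ever copied
        for (size_t w0 = 0; w0 < c0; ++w0) {
          size_t src_h = h1 * c0 + h0;
          size_t src_w = w1 * c0 + w0;
          if (src_h < h && src_w < w) {
            size_t dst_idx = ((w1 * num_h1 + h1) * c0 + h0) * c0 + w0;
            dst[dst_idx] = src[src_h * w + src_w];
          }
        }
      }
    }
  }
  return dst;
}
```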
diff --git a/mindspore/ccsrc/device/kernel_runtime.cc b/mindspore/ccsrc/device/kernel_runtime.cc
index 00c557f2d5f..8daf309cba8 100644
--- a/mindspore/ccsrc/device/kernel_runtime.cc
+++ b/mindspore/ccsrc/device/kernel_runtime.cc
@@ -159,6 +159,11 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors,
                                            const session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(mem_manager_);
+  if (input_tensors.size() != graph->inputs().size()) {
+    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors.size()
+                      << " should be equal to graph input parameter size " << graph->inputs().size();
+  }
+
   for (size_t input_index = 0; input_index < graph->inputs().size(); ++input_index) {
     auto item = graph->inputs()[input_index];
     MS_EXCEPTION_IF_NULL(item);
@@ -416,6 +421,8 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value,
   MS_EXCEPTION_IF_NULL(value_node);
   MS_EXCEPTION_IF_NULL(node_value);
   MS_EXCEPTION_IF_NULL(mem_manager_);
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
   auto tensor = node_value->cast<tensor::TensorPtr>();
   if (tensor == nullptr) {
     MS_LOG(WARNING) << "Tensor is null";
@@ -423,13 +430,23 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value,
   }
   size_t tensor_size = tensor->data().nbytes();
   auto node_size = CountNodeDeviceMemorySize(value_node, output_idx);
-  auto ptr = mem_manager_->MallocMem(kStaticMem, node_size);
   TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
   if (output_type_id == kTypeUnknown) {
     output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx);
   }
-  auto address = CreateDeviceAddress(ptr, node_size, AnfAlgo::GetOutputFormat(value_node, output_idx), output_type_id);
-  MS_EXCEPTION_IF_NULL(address);
+  auto output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
+  DeviceAddressPtr address = nullptr;
+  if (ms_context->enable_pynative_infer()) {
+    address = CreateDeviceAddress(nullptr, node_size, output_format, output_type_id);
+    MS_EXCEPTION_IF_NULL(address);
+    if (!mem_manager_->MallocMemFromMemPool(address, node_size)) {
+      MS_LOG(EXCEPTION) << "Malloc value node device memory failed !";
+    }
+  } else {
+    auto ptr = mem_manager_->MallocMem(kStaticMem, node_size);
+    address = CreateDeviceAddress(ptr, node_size, output_format, output_type_id);
+    MS_EXCEPTION_IF_NULL(address);
+  }
   AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
   if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(),
                                  tensor->data_c(false))) {
@@ -442,6 +459,8 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const ValuePtr &node_value,
 void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(mem_manager_);
+  auto ms_context = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(ms_context);
   for (auto &value_node : graph->graph_value_nodes()) {
     MS_EXCEPTION_IF_NULL(value_node);
     if (AnfAlgo::OutputAddrExist(value_node, 0)) {
@@ -455,9 +474,18 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
     } else if (node_value->isa<StringImm>()) {
       auto value = GetValue<std::string>(node_value);
       size_t tensor_size = value.size();
-      auto ptr = mem_manager_->MallocMem(kStaticMem, tensor_size);
-      auto address = CreateDeviceAddress(ptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
-      MS_EXCEPTION_IF_NULL(address);
+      DeviceAddressPtr address = nullptr;
+      if (ms_context->enable_pynative_infer()) {
+        address = CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
+        MS_EXCEPTION_IF_NULL(address);
+        if (!mem_manager_->MallocMemFromMemPool(address, tensor_size)) {
+          MS_LOG(EXCEPTION) << "Malloc value node device memory failed !";
+        }
+      } else {
+        auto ptr = mem_manager_->MallocMem(kStaticMem, tensor_size);
+        address = CreateDeviceAddress(ptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
+        MS_EXCEPTION_IF_NULL(address);
+      }
       AnfAlgo::SetOutputAddr(address, 0, value_node.get());
       std::vector<int> shape = {1, SizeToInt(tensor_size)};
       if (!address->SyncHostToDevice(shape, tensor_size, kNumberTypeUInt8, value.data())) {
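The kernel_runtime.cc hunks above route value-node memory through the dynamic memory pool when PyNative inference is enabled, instead of the static arena. A minimal sketch of the two allocation paths follows — the types and functions below are simplified stand-ins, not the real MemManager/DeviceAddress API.

```cpp
#include <cstddef>
#include <cstdlib>
#include <memory>
#include <stdexcept>

struct DeviceAddress {
  void *ptr = nullptr;
  size_t size = 0;
};
using DeviceAddressPtr = std::shared_ptr<DeviceAddress>;

// static arena: lives for the whole graph, pointer known up front
void *MallocStatic(size_t size) { return std::malloc(size); }

// pool path: the address object is handed in and filled on success, so the
// pool can later reclaim the block when the single-op graph is cleared
bool MallocFromPool(const DeviceAddressPtr &addr, size_t size) {
  addr->ptr = std::malloc(size);  // a real pool would reuse cached blocks
  addr->size = size;
  return addr->ptr != nullptr;
}

DeviceAddressPtr AllocValueNodeMem(size_t size, bool pynative_infer) {
  auto address = std::make_shared<DeviceAddress>();
  if (pynative_infer) {
    // mirrors CreateDeviceAddress(nullptr, ...) + MallocMemFromMemPool
    if (!MallocFromPool(address, size)) {
      throw std::runtime_error("Malloc value node device memory failed");
    }
  } else {
    // mirrors MallocMem(kStaticMem, ...) + CreateDeviceAddress(ptr, ...)
    address->ptr = MallocStatic(size);
    address->size = size;
  }
  return address;
}
```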
diff --git a/mindspore/ccsrc/pynative/pynative_execute.cc b/mindspore/ccsrc/pynative/pynative_execute.cc
index 32ffdde34eb..54865be0614 100644
--- a/mindspore/ccsrc/pynative/pynative_execute.cc
+++ b/mindspore/ccsrc/pynative/pynative_execute.cc
@@ -213,22 +213,22 @@ bool RunOpConvertConstInputToAttr(const py::object &input_object, size_t input_index,
 }
 
 void PlantTensorTupleToVector(const py::tuple &tuple_inputs, const PrimitivePtr &op_prim,
-                              std::vector<tensor::TensorPtr> *input_tensor) {
+                              std::vector<tensor::TensorPtr> *input_tensors) {
   MS_EXCEPTION_IF_NULL(op_prim);
-  MS_EXCEPTION_IF_NULL(input_tensor);
+  MS_EXCEPTION_IF_NULL(input_tensors);
   for (const auto &input_object : tuple_inputs) {
     if (!py::isinstance<tensor::Tensor>(input_object)) {
       MS_LOG(EXCEPTION) << "The input object is not a tensor!";
     }
     auto tensor = py::cast<tensor::TensorPtr>(input_object);
     MS_EXCEPTION_IF_NULL(tensor);
-    input_tensor->push_back(tensor);
+    input_tensors->push_back(tensor);
   }
   op_prim->set_attr(kAttrDynInputSizes, MakeValue(std::vector<int>{SizeToInt(tuple_inputs.size())}));
 }
 
-void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *input_tensor) {
-  MS_EXCEPTION_IF_NULL(input_tensor);
+void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *input_tensors) {
+  MS_EXCEPTION_IF_NULL(input_tensors);
   ValuePtr input_value = parse::data_converter::PyDataToValue(input_object);
   MS_EXCEPTION_IF_NULL(input_value);
   if (!input_value->isa<ValueTuple>()) {
@@ -238,20 +238,23 @@ void ConvertValueTupleToTensor(const py::object &input_object, std::vector<tensor::TensorPtr> *input_tensors) {
-  input_tensor->push_back(tensor_ptr);
+  input_tensors->push_back(tensor_ptr);
 }
 
 void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
-                             std::vector<tensor::TensorPtr> *input_tensor) {
+                             std::vector<tensor::TensorPtr> *input_tensors, int *tensor_mask) {
   MS_EXCEPTION_IF_NULL(op_prim);
-  MS_EXCEPTION_IF_NULL(input_tensor);
+  MS_EXCEPTION_IF_NULL(input_tensors);
+  MS_EXCEPTION_IF_NULL(tensor_mask);
   tensor::TensorPtr tensor_ptr = nullptr;
   if (py::isinstance<tensor::Tensor>(input_object)) {
     tensor_ptr = py::cast<tensor::TensorPtr>(input_object);
   } else if (py::isinstance<py::float_>(input_object)) {
     tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::float_>(input_object), kFloat32);
+    *tensor_mask = kValueNodeTensorMask;
   } else if (py::isinstance<py::int_>(input_object)) {
-    tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::int_>(input_object), nullptr);
+    tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::int_>(input_object), kInt32);
+    *tensor_mask = kValueNodeTensorMask;
   } else if (py::isinstance<py::list>(input_object)) {
     tensor_ptr = std::make_shared<tensor::Tensor>(py::cast<py::list>(input_object), nullptr);
   } else if (py::isinstance<py::array>(input_object)) {
@@ -261,20 +264,22 @@ void ConvertPyObjectToTensor(const py::object &input_object, const PrimitivePtr &op_prim,
   } else if (py::isinstance<py::tuple>(input_object)) {
     auto tuple_inputs = py::cast<py::tuple>(input_object);
     if (py::isinstance<tensor::Tensor>(tuple_inputs[0])) {
-      PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensor);
+      PlantTensorTupleToVector(tuple_inputs, op_prim, input_tensors);
     } else {
-      ConvertValueTupleToTensor(input_object, input_tensor);
+      ConvertValueTupleToTensor(input_object, input_tensors);
+      *tensor_mask = kValueNodeTensorMask;
     }
     return;
   } else {
     MS_LOG(EXCEPTION) << "Run op inputs type is invalid!";
   }
   MS_EXCEPTION_IF_NULL(tensor_ptr);
-  input_tensor->push_back(tensor_ptr);
+  input_tensors->push_back(tensor_ptr);
 }
 
-void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<bool> *tensors_mask,
+void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int> *tensors_mask,
                           std::vector<tensor::TensorPtr> *input_tensors) {
+  MS_EXCEPTION_IF_NULL(op_run_info);
   MS_EXCEPTION_IF_NULL(tensors_mask);
   MS_EXCEPTION_IF_NULL(input_tensors);
   PrimitivePtr op_prim = op_run_info->py_primitive;
@@ -295,14 +300,29 @@ void ConstructInputTensor(const OpExecInfoPtr &op_run_info, std::vector<int> *tensors_mask,
       continue;
     }
     // convert const and tuple input to tensor
-    ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors);
-    // make tensors, weight : 1, data : 0
-    std::vector<bool> new_mask(input_tensors->size() - tensors_mask->size(),
-                               py::cast<bool>(op_run_info->inputs_mask[index]));
+    int tensor_mask = py::cast<int>(op_run_info->inputs_mask[index]);
+    ConvertPyObjectToTensor(op_run_info->op_inputs[index], op_prim, input_tensors, &tensor_mask);
+    // mark tensors, data : 0, weight : 1, valuenode: 2
+    std::vector<int> new_mask(input_tensors->size() - tensors_mask->size(), tensor_mask);
     tensors_mask->insert(tensors_mask->end(), new_mask.begin(), new_mask.end());
   }
 }
 
+void EraseValueNodeTensor(const std::vector<int> &tensors_mask, std::vector<tensor::TensorPtr> *input_tensors) {
+  MS_EXCEPTION_IF_NULL(input_tensors);
+  if (input_tensors->size() != tensors_mask.size()) {
+    MS_LOG(EXCEPTION) << "Input tensors size " << input_tensors->size() << " should be equal to tensors mask size "
+                      << tensors_mask.size();
+  }
+  std::vector<tensor::TensorPtr> new_input_tensors;
+  for (size_t index = 0; index < tensors_mask.size(); ++index) {
+    if (tensors_mask[index] != kValueNodeTensorMask) {
+      new_input_tensors.push_back(input_tensors->at(index));
+    }
+  }
+  *input_tensors = new_input_tensors;
+}
+
 py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
   MS_EXCEPTION_IF_NULL(op_exec_info);
   MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms";
@@ -323,9 +343,10 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
 
   std::string graph_info = GetSingleOpGraphInfo(op_exec_info);
   std::vector<tensor::TensorPtr> input_tensors;
-  std::vector<bool> tensors_mask;
+  std::vector<int> tensors_mask;
   ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors);
   session->BuildOp(*op_exec_info, graph_info, input_tensors, tensors_mask);
+  EraseValueNodeTensor(tensors_mask, &input_tensors);
   py::tuple result = session->RunOp(*op_exec_info, graph_info, input_tensors);
   ms_context->set_enable_pynative_infer(false);
   *status = PYNATIVE_SUCCESS;
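The pynative_execute.cc hunks widen the per-input mask from bool to int: 0 marks data, 1 marks weights, and the new kValueNodeTensorMask (2) marks constants that become value nodes. EraseValueNodeTensor then drops the mask-2 tensors before RunOp, since they are already baked into the graph. A standalone sketch of that convention, with illustrative names (only kValueNodeTensorMask == 2 matches the patch):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

constexpr int kDataTensorMask = 0;       // graph input fed at runtime
constexpr int kWeightTensorMask = 1;     // parameter with a default value
constexpr int kValueNodeTensorMask = 2;  // constant folded into the graph

// keep only the tensors that still need to be fed to RunOp
template <typename T>
std::vector<T> EraseValueNodeTensor(const std::vector<int> &mask, const std::vector<T> &tensors) {
  assert(mask.size() == tensors.size());
  std::vector<T> kept;
  for (size_t i = 0; i < mask.size(); ++i) {
    if (mask[i] != kValueNodeTensorMask) {
      kept.push_back(tensors[i]);
    }
  }
  return kept;
}

// usage: scalars converted from py::float_/py::int_ receive mask 2, so after
// BuildOp they are erased and RunOp sees only data/weight tensors:
//   auto run_tensors = EraseValueNodeTensor(tensors_mask, input_tensors);
```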
diff --git a/mindspore/ccsrc/session/ascend_session.cc b/mindspore/ccsrc/session/ascend_session.cc
index 5a6dc792c8d..fce65b28a09 100644
--- a/mindspore/ccsrc/session/ascend_session.cc
+++ b/mindspore/ccsrc/session/ascend_session.cc
@@ -127,6 +127,34 @@ std::vector<BaseRef> GetRealArgs(const KernelGraphPtr graph, const VectorRef &args) {
   }
   return real_args;
 }
+
+void ClearRunOpMemoryResource(const KernelGraphPtr &kernel_graph) {
+  MS_EXCEPTION_IF_NULL(kernel_graph);
+  // clear input parameter memory resource
+  for (const auto &input_node : kernel_graph->inputs()) {
+    MS_EXCEPTION_IF_NULL(input_node);
+    AnfAlgo::SetOutputAddr(nullptr, 0, input_node.get());
+  }
+  // clear input value node memory resource
+  for (const auto &value_node : kernel_graph->graph_value_nodes()) {
+    MS_EXCEPTION_IF_NULL(value_node);
+    AnfAlgo::SetOutputAddr(nullptr, 0, value_node.get());
+  }
+  for (const auto &cnode : kernel_graph->execution_order()) {
+    MS_EXCEPTION_IF_NULL(cnode);
+    // clear output memory resource
+    for (size_t index = 0; index < AnfAlgo::GetOutputTensorNum(cnode); ++index) {
+      AnfAlgo::SetOutputAddr(nullptr, index, cnode.get());
+    }
+    // clear workspace memory resource
+    auto kernel_mod = AnfAlgo::GetKernelMod(cnode);
+    MS_EXCEPTION_IF_NULL(kernel_mod);
+    auto workspace_lists = kernel_mod->GetWorkspaceSizeList();
+    for (size_t index = 0; index < workspace_lists.size(); ++index) {
+      AnfAlgo::SetWorkspaceAddr(nullptr, index, cnode.get());
+    }
+  }
+}
 }  // namespace
 
 GraphId AscendSession::CompileGraph(const AnfNodePtrList &lst, const AnfNodePtrList &outputs) {
@@ -308,8 +336,7 @@ bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const {
 }
 
 void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
-                            const std::vector<tensor::TensorPtr> &input_tensors,
-                            const std::vector<bool> &tensors_mask) {
+                            const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) {
   MS_LOG(INFO) << "Build op " << op_run_info.op_name << " start !";
   if (GraphCacheExist(graph_info)) {
     MS_LOG(INFO) << "Build op " << op_run_info.op_name << " finish !";
@@ -355,6 +382,7 @@ py::tuple AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
   }
   py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_;
   py::tuple tuple_tensors = py::cast<py::tuple>(tuple_obj);
+  ClearRunOpMemoryResource(graph);
   MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!";
   return tuple_tensors;
 }
diff --git a/mindspore/ccsrc/session/ascend_session.h b/mindspore/ccsrc/session/ascend_session.h
index 45662250d39..7752e2bbdc8 100755
--- a/mindspore/ccsrc/session/ascend_session.h
+++ b/mindspore/ccsrc/session/ascend_session.h
@@ -46,7 +46,7 @@ class AscendSession : public SessionBasic {
   void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
   void BuildGraph(GraphId) override;
   void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
-               const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<bool> &tensors_mask) override;
+               const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) override;
   py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                   const std::vector<tensor::TensorPtr> &input_tensors) override;
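ClearRunOpMemoryResource above nulls out every stored device address after a single-op run, keeping the compiled graph cached while letting the memory pool reclaim the blocks. The sketch below illustrates why dropping the stored handle is enough to free the memory — it assumes, as the pool path in this patch does, that device blocks are held through shared_ptr-style handles; PoolBlock and the names around it are hypothetical.

```cpp
#include <iostream>
#include <memory>
#include <vector>

struct PoolBlock {
  ~PoolBlock() { std::cout << "block returned to pool\n"; }
};

int main() {
  std::vector<std::shared_ptr<PoolBlock>> node_output_addrs;
  node_output_addrs.push_back(std::make_shared<PoolBlock>());  // RunOp allocates
  // ... kernel launch and output read-back would happen here ...
  node_output_addrs[0] = nullptr;  // ClearRunOpMemoryResource-style release:
                                   // last reference gone, block is reclaimed
  std::cout << "graph cache kept, memory released\n";
  return 0;
}
```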
diff --git a/mindspore/ccsrc/session/gpu_session.cc b/mindspore/ccsrc/session/gpu_session.cc
index 3a80382e9b8..0e32231f78f 100644
--- a/mindspore/ccsrc/session/gpu_session.cc
+++ b/mindspore/ccsrc/session/gpu_session.cc
@@ -133,7 +133,7 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
 }
 
 void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
-                         const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<bool> &tensors_mask) {
+                         const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) {
   // Prepare the graph
   auto kernel_graph = ConstructSingleOpGraph(op_run_info, input_tensors, tensors_mask);
   MS_EXCEPTION_IF_NULL(kernel_graph);
diff --git a/mindspore/ccsrc/session/gpu_session.h b/mindspore/ccsrc/session/gpu_session.h
index 2a3cc04b09c..db320a3c884 100644
--- a/mindspore/ccsrc/session/gpu_session.h
+++ b/mindspore/ccsrc/session/gpu_session.h
@@ -40,7 +40,7 @@ class GPUSession : public SessionBasic {
   void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) override;
 
   void BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
-               const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<bool> &tensors_mask) override;
+               const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask) override;
   py::tuple RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_info,
                   const std::vector<tensor::TensorPtr> &input_tensors) override;
diff --git a/mindspore/ccsrc/session/session_basic.cc b/mindspore/ccsrc/session/session_basic.cc
index b3267db6c7d..66df0fc251c 100755
--- a/mindspore/ccsrc/session/session_basic.cc
+++ b/mindspore/ccsrc/session/session_basic.cc
@@ -312,11 +312,22 @@ size_t LoadCtrlInputTensor(const std::shared_ptr<Context> &context, std::vector<tensor::TensorPtr> *inputs) {
   return inputs_params->size();
 }
 
+ValueNodePtr ConstructRunOpValueNode(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor) {
+  MS_EXCEPTION_IF_NULL(graph);
+  MS_EXCEPTION_IF_NULL(input_tensor);
+  auto abstract = std::make_shared<abstract::AbstractTensor>(input_tensor);
+  auto value_node = std::make_shared<ValueNode>(input_tensor);
+  value_node->set_abstract(abstract);
+  auto input_value_node = graph->NewValueNode(value_node);
+  graph->AddValueNodeToGraph(input_value_node);
+  return input_value_node;
+}
+
 ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph, const tensor::TensorPtr &input_tensor,
-                                     bool is_weight) {
+                                     int tensor_mask) {
   auto param = graph->NewParameter();
   MS_EXCEPTION_IF_NULL(param);
-  if (is_weight) {
+  if (tensor_mask == 1) {
     py::object obj;
     param->set_default_param(obj);
   }
@@ -675,7 +686,7 @@ void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr<KernelGraph> &graph) {
 
 std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info,
                                                                   const std::vector<tensor::TensorPtr> &input_tensors,
-                                                                  const std::vector<bool> &tensors_mask) {
+                                                                  const std::vector<int> &tensors_mask) {
   auto graph = std::make_shared<KernelGraph>();
   std::vector<AnfNodePtr> inputs;
   // set input[0]
@@ -689,7 +700,12 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInfo &op_run_info,
                       << tensors_mask.size();
   }
   for (size_t i = 0; i < input_tensors.size(); ++i) {
-    auto parameter = ConstructRunOpParameter(graph, input_tensors.at(i), tensors_mask[i]);
+    if (tensors_mask[i] == kValueNodeTensorMask) {
+      auto value_node = ConstructRunOpValueNode(graph, input_tensors[i]);
+      inputs.push_back(value_node);
+      continue;
+    }
+    auto parameter = ConstructRunOpParameter(graph, input_tensors[i], tensors_mask[i]);
     inputs.push_back(parameter);
     graph->MutableInputs()->push_back(parameter);
   }
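The session_basic.cc hunk above decides, per input, whether to materialize a tensor as a graph constant or as a feedable parameter. A hedged sketch of that dispatch rule — node kinds are illustrative strings rather than real AnfNodePtr subclasses; only the mask values (0 data, 1 weight, 2 value node) come from the patch:

```cpp
#include <string>
#include <vector>

constexpr int kValueNodeTensorMask = 2;

struct Node {
  std::string kind;
  bool feed_at_runtime;  // must the caller pass a tensor for this input?
};

std::vector<Node> BuildSingleOpInputs(const std::vector<int> &tensors_mask) {
  std::vector<Node> inputs;
  for (int mask : tensors_mask) {
    if (mask == kValueNodeTensorMask) {
      // constants are baked into the graph; EraseValueNodeTensor later drops
      // the matching tensor so RunOp receives only feedable inputs
      inputs.push_back({"ValueNode", false});
    } else {
      // data (0) stays a plain parameter; weight (1) also gets a default
      // value so later passes can tell the two apart
      inputs.push_back({"Parameter", true});
    }
  }
  return inputs;
}
```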
diff --git a/mindspore/ccsrc/session/session_basic.h b/mindspore/ccsrc/session/session_basic.h
index efd27743f70..12711a46f13 100755
--- a/mindspore/ccsrc/session/session_basic.h
+++ b/mindspore/ccsrc/session/session_basic.h
@@ -65,7 +65,7 @@ class SessionBasic {
   virtual void RunGraph(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) = 0;
 
   virtual void BuildOp(const OpRunInfo &, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors,
-                       const std::vector<bool> &tensors_mask) {}
+                       const std::vector<int> &tensors_mask) {}
 
   virtual py::tuple RunOp(const OpRunInfo &, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors) {
     return py::tuple();
@@ -105,7 +105,7 @@ class SessionBasic {
   // create a single run op graph
   std::shared_ptr<KernelGraph> ConstructSingleOpGraph(const OpRunInfo &op_run_info,
                                                       const std::vector<tensor::TensorPtr> &input_tensors,
-                                                      const std::vector<bool> &tensors_mask);
+                                                      const std::vector<int> &tensors_mask);
   // trans BaseRef list to py::tuple
   BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref);
   // create a new kernel graph and update the graph sum
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index a0eb2740d43..d537c1dec9e 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -193,6 +193,7 @@ const size_t kShape4dDims = 4;
 const size_t kShape5dDims = 5;
 const size_t kCubeSize = 16;
 const size_t kMemAlignSize = 512;
+const int kValueNodeTensorMask = 2;
 
 // define special index in special node
 constexpr auto kAnfPrimitiveIndex = 0;
diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py
index 2fb9f883dad..ee4551549e4 100644
--- a/mindspore/common/parameter.py
+++ b/mindspore/common/parameter.py
@@ -172,7 +172,8 @@ class Parameter:
 
     def __mul__(self, other):
         res = deepcopy(self)
-        res.default_input = res.default_input * other
+        default_input = res.default_input * other
+        res.default_input = Tensor(default_input.asnumpy().copy())
        return res
 
     def __truediv__(self, other):