diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc index a24650b6b80..8d650e7c066 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_backend_optimization.cc @@ -327,6 +327,7 @@ void RunOpAscendBackendIRFusionOptimization(const std::shared_ptr("ir_fusion_pm"); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); + ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); ir_fusion_pm->AddPass(std::make_shared()); diff --git a/mindspore/ccsrc/backend/session/ascend_session.cc b/mindspore/ccsrc/backend/session/ascend_session.cc index 9b43823d89e..6a12c810fd7 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.cc +++ b/mindspore/ccsrc/backend/session/ascend_session.cc @@ -22,6 +22,7 @@ #include #include "base/core_ops.h" +#include "base/base_ref_utils.h" #include "ir/tensor.h" #include "ir/anf.h" #include "common/trans.h" @@ -123,6 +124,284 @@ void InsertMakeTupleForOutput(NotNull root_graph) { {NewValueNode(std::make_shared(prim::kPrimMakeTuple->name())), root_graph->output()}); root_graph->set_output(make_tuple); } + +BaseRef CreateNodeOutputPlaceholder(const session::KernelWithIndex &node_output_pair, const KernelGraphPtr &graph, + const std::vector &input_tensors, + const std::vector &indexes, + std::map> *output_indexes) { + auto &node = node_output_pair.first; + MS_EXCEPTION_IF_NULL(node); + MS_EXCEPTION_IF_NULL(graph); + MS_EXCEPTION_IF_NULL(output_indexes); + MS_LOG(INFO) << "Create placeholder for output[" << node->DebugString() << "] index[" << node_output_pair.second + << "]"; + // if node is a value node, no need sync addr from device to host + if (node->isa()) { + auto value_node = node->cast(); + MS_EXCEPTION_IF_NULL(value_node); + return value_node->value(); + } + if (node->isa()) { + for (size_t input_idx = 0; input_idx < graph->inputs().size(); input_idx++) { + if (input_idx >= input_tensors.size()) { + MS_LOG(EXCEPTION) << "Input idx:" << input_idx << "out of range:" << input_tensors.size(); + } + if (graph->inputs()[input_idx] == node) { + return input_tensors[input_idx]; + } + } + MS_LOG(EXCEPTION) << "Parameter: " << node->DebugString() << " has no output addr"; + } + (*output_indexes)[node_output_pair] = indexes; + BaseRef output_placeholder = std::make_shared(); + return output_placeholder; +} + +BaseRef CreateNodeOutputPlaceholder(const AnfNodePtr &anf, const KernelGraphPtr &graph, + const std::vector &input_tensors, + const std::vector &indexes, + std::map> *output_indexes) { + MS_EXCEPTION_IF_NULL(anf); + MS_EXCEPTION_IF_NULL(output_indexes); + MS_LOG(INFO) << "Create placeholder for output[" << anf->DebugString() << "]"; + auto item_with_index = AnfAlgo::VisitKernelWithReturnType(anf, 0); + MS_EXCEPTION_IF_NULL(item_with_index.first); + MS_LOG(INFO) << "Create placeholder for output after visit:" << item_with_index.first->DebugString(); + // special handle for maketuple + if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) { + auto cnode = item_with_index.first->cast(); + MS_EXCEPTION_IF_NULL(cnode); + VectorRef ret; + for (size_t i = 1; i < cnode->inputs().size(); ++i) { + std::vector cur_index = indexes; + cur_index.emplace_back(i - 1); + auto out = CreateNodeOutputPlaceholder(cnode->input(i), graph, input_tensors, cur_index, output_indexes); + ret.push_back(out); + } + return ret; + } + // if is graph return nothing ,the function should return a null anylist + size_t size = AnfAlgo::GetOutputTensorNum(item_with_index.first); + if (size == 0) { + return VectorRef(); + } + return CreateNodeOutputPlaceholder(item_with_index, graph, input_tensors, indexes, output_indexes); +} + +void CreateOutputPlaceholder(const KernelGraphPtr &kernel_graph, const std::vector &input_tensors, + VectorRef *outputs, std::map> *output_indexes) { + MS_EXCEPTION_IF_NULL(kernel_graph); + MS_EXCEPTION_IF_NULL(outputs); + MS_EXCEPTION_IF_NULL(output_indexes); + auto anf_outputs = kernel_graph->outputs(); + size_t index = 0; + for (auto &item : anf_outputs) { + MS_EXCEPTION_IF_NULL(item); + MS_LOG(INFO) << "Create node output placeholder[" << item->DebugString() << "]"; + std::vector indexes{index++}; + outputs->emplace_back(CreateNodeOutputPlaceholder(item, kernel_graph, input_tensors, indexes, output_indexes)); + } +} + +void GetRefCount(KernelGraph *graph, std::map *ref_count) { + MS_EXCEPTION_IF_NULL(graph); + for (const auto &kernel : graph->execution_order()) { + for (size_t i = 1; i < kernel->inputs().size(); i += 1) { + const auto &input = kernel->input(i); + auto kernel_with_index = AnfAlgo::VisitKernel(input, 0); + const auto &node = kernel_with_index.first; + if (node->isa()) { + (*ref_count)[kernel_with_index] += 1; + } + } + } +} + +void GetParameterIndex(KernelGraph *graph, const std::vector &inputs, + std::map *parameter_index) { + size_t index = 0; + for (const auto &input_node : graph->inputs()) { + auto params = AnfAlgo::GetAllOutput(input_node); + for (const auto ¶m : params) { + if (index >= inputs.size()) { + MS_LOG(EXCEPTION) << "Parameter size out of range. Parameter index: " << index + << ", input size: " << inputs.size(); + } + const auto &input = inputs[index]; + // Check shape of input and parameter + const auto &input_shape = input->shape(); + const auto ¶m_shape = AnfAlgo::GetOutputInferShape(param, 0); + if (input_shape.size() != param_shape.size()) { + MS_LOG(EXCEPTION) << "Shapes of input and parameter are different, input index: " << index + << ", parameter: " << param->fullname_with_scope(); + } + for (size_t i = 0; i < input_shape.size(); i += 1) { + if (input_shape[i] < 0 || static_cast(input_shape[i]) != param_shape[i]) { + MS_LOG(EXCEPTION) << "Shapes of input and parameter are different, input index: " << index + << ", parameter: " << param->fullname_with_scope(); + } + } + parameter_index->emplace(param, index++); + } + } +} + +void GetOpInputTensors(const CNodePtr &cnode, const std::map &op_output, + const std::map ¶meter_index, + const std::vector &graph_inputs, InputTensorInfo *input_tensor_info) { + MS_EXCEPTION_IF_NULL(cnode); + for (size_t i = 1; i < cnode->inputs().size(); i += 1) { + const auto &input = cnode->input(i); + auto kernel_with_index = AnfAlgo::VisitKernel(input, 0); + auto real_input = kernel_with_index.first; + MS_EXCEPTION_IF_NULL(real_input); + tensor::TensorPtr tensor = nullptr; + if (real_input->isa()) { + auto value_node = input->cast(); + MS_EXCEPTION_IF_NULL(value_node); + auto value = GetValueNode(value_node); + MS_EXCEPTION_IF_NULL(value_node); + if (value->isa()) { + auto value_tuple = value->cast(); + MS_EXCEPTION_IF_NULL(value_tuple); + if (kernel_with_index.second >= value_tuple->size()) { + MS_LOG(EXCEPTION) << "Index " << kernel_with_index.second << "is out of value tuple range"; + } + auto tensor_value = value_tuple->value()[kernel_with_index.second]; + if (tensor_value->isa()) { + tensor = tensor_value->cast(); + } + } else if (value->isa()) { + if (kernel_with_index.second != 0) { + MS_LOG(EXCEPTION) << "Index should be 0 for Tensor ValueNode, but is " << kernel_with_index.second; + } + tensor = GetValueNode(value_node); + } + } else if (real_input->isa()) { + const auto &iter = parameter_index.find(real_input); + if (iter == parameter_index.end()) { + MS_LOG(EXCEPTION) << "Can not find parameter input of cnode, node = " << cnode->DebugString(); + } + const size_t index = iter->second; + if (index >= graph_inputs.size()) { + MS_LOG(EXCEPTION) << "Parameter index is greater than size of graph's input tensor, parameter index = " + << cnode->DebugString() << "input tensor size = " << graph_inputs.size(); + } + tensor = graph_inputs[index]; + } else if (real_input->isa()) { + const auto &iter = op_output.find(kernel_with_index); + if (iter == op_output.end()) { + MS_LOG(EXCEPTION) << "Can not find output tensor of cnode, node = " << real_input->DebugString(); + } + tensor = iter->second; + input_tensor_info->input_kernel.insert(kernel_with_index); + } else { + MS_LOG(EXCEPTION) << "Invalid input node, node = " << real_input->DebugString(); + } + MS_EXCEPTION_IF_NULL(tensor); + MS_LOG(DEBUG) << "Get" << i << "th input tensor of " << cnode->fullname_with_scope() << " from " + << real_input->fullname_with_scope() << "-" << kernel_with_index.second; + input_tensor_info->input_tensors_mask.emplace_back(tensor->is_parameter() ? kParameterWeightTensorMask + : kParameterDataTensorMask); + input_tensor_info->input_tensors.emplace_back(tensor); + } +} + +void HandleOpInputs(const std::set &input_kernel, std::map *ref_count, + std::map *op_output_map) { + for (auto &kernel_with_index : input_kernel) { + if (!kernel_with_index.first->isa()) { + continue; + } + auto ref_iter = ref_count->find(kernel_with_index); + if (ref_iter == ref_count->end()) { + MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in cnode reference count map, input cnode = " + << kernel_with_index.first->DebugString() << ", index = " << kernel_with_index.second; + } + ref_iter->second -= 1; + if (ref_iter->second != 0) { + continue; + } + ref_count->erase(ref_iter); + auto output_iter = op_output_map->find(kernel_with_index); + if (output_iter == op_output_map->end()) { + MS_LOG(EXCEPTION) << "Can not find input KernelWithIndex in op_output map, input cnode = " + << kernel_with_index.first->DebugString() << ", index = " << kernel_with_index.second; + } + op_output_map->erase(output_iter); + } +} + +void HandleOpOutputs(const AnfNodePtr &kernel, const VectorRef &op_outputs, + const std::map> &output_indexes, + const std::map &ref_count, + std::map *op_output_map, VectorRef *outputs) { + auto output_tensors = TransformVectorRefToMultiTensor(op_outputs); + if (output_tensors.size() != op_outputs.size()) { + MS_LOG(EXCEPTION) << "Op output contains tuple, node = " << kernel->DebugString(); + } + size_t out_index = 0; + for (const auto &output_tensor : output_tensors) { + auto kernel_with_index = make_pair(kernel, out_index++); + if (ref_count.find(kernel_with_index) != ref_count.end()) { + (*op_output_map)[kernel_with_index] = output_tensor; + } + const auto &iter = output_indexes.find(kernel_with_index); + if (iter == output_indexes.end()) { + continue; + } + const std::vector &ref_indexes = iter->second; + size_t n = 0; + const VectorRef *cur_vector_ref = outputs; + while (n != ref_indexes.size() - 1) { + size_t index = ref_indexes.at(n++); + const BaseRef &base_ref = (*cur_vector_ref)[index]; + if (!utils::isa(base_ref)) { + MS_LOG(EXCEPTION) << "Get none VectorRef by ref index, indexes: " << ref_indexes << "cur n: " << n - 1; + } + cur_vector_ref = &utils::cast(base_ref); + } + BaseRef &tensor_ref = (*const_cast(cur_vector_ref))[ref_indexes.at(n)]; + tensor_ref = output_tensor; + } +} + +void GetSingleOpRunInfo(const CNodePtr cnode, OpRunInfo *run_info) { + MS_EXCEPTION_IF_NULL(cnode); + MS_EXCEPTION_IF_NULL(run_info); + auto primitive = AnfAlgo::GetCNodePrimitive(cnode); + run_info->primitive = primitive; + run_info->op_name = primitive->name(); + if (cnode->abstract() == nullptr) { + MS_LOG(EXCEPTION) << "Abstract is nullptr, node = " << cnode->DebugString(); + } + run_info->abstract = cnode->abstract(); +} + +GraphInfo GetSingleOpGraphInfo(const PrimitivePtr &prim, const std::vector &input_tensors) { + MS_EXCEPTION_IF_NULL(prim); + GraphInfo graph_info; + // get input tensor info + for (const auto &tensor : input_tensors) { + MS_EXCEPTION_IF_NULL(tensor); + auto tensor_shape = tensor->shape(); + (void)std::for_each(tensor_shape.begin(), tensor_shape.end(), + [&](const auto &dim) { (void)graph_info.append(std::to_string(dim) + "_"); }); + (void)graph_info.append(std::to_string(tensor->data_type()) + "_"); + if (tensor->device_address() != nullptr) { + const auto type_id = std::dynamic_pointer_cast(tensor->device_address())->type_id(); + (void)graph_info.append(std::to_string(type_id) + "_"); + const auto format = std::dynamic_pointer_cast(tensor->device_address())->format(); + (void)graph_info.append(format + "_"); + } + } + // get attr info + const auto &attr_map = prim->evaluate_added_attrs(); + (void)std::for_each(attr_map.begin(), attr_map.end(), + [&](const auto &element) { (void)graph_info.append(element.second->ToString() + "_"); }); + graph_info.append(prim->id()); + return graph_info; +} } // namespace void AscendSession::Init(uint32_t device_id) { @@ -417,7 +696,7 @@ void AscendSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &gra MS_EXCEPTION_IF_NULL(graph); MS_LOG(INFO) << "Run op " << op_run_info.op_name << " start!"; // malloc mem - RunOpMemoryAlloc(op_run_info.value, input_tensors, graph.get()); + RunOpMemoryAlloc(input_tensors, graph.get()); // Build dynamic kernel if (op_run_info.is_dynamic_shape) { BuildDynamicKernel(graph); @@ -432,6 +711,39 @@ void AscendSession::RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &gra MS_LOG(INFO) << "Run op " << op_run_info.op_name << " finish!"; } +void AscendSession::RunOpsInGraphImpl(const GraphId &graph_id, const std::vector &inputs, + VectorRef *outputs) { + MS_LOG(INFO) << "Start"; + auto kernel_graph = GetGraph(graph_id); + std::map parameter_index; + GetParameterIndex(kernel_graph.get(), inputs, ¶meter_index); + std::map> output_indexes; + CreateOutputPlaceholder(kernel_graph, inputs, outputs, &output_indexes); + std::map cnode_ref; + GetRefCount(kernel_graph.get(), &cnode_ref); + + std::map op_output_map; + for (const auto &kernel : kernel_graph->execution_order()) { + // Generate input tensors, tensor masks and input kernel with index + InputTensorInfo input_tensor_info; + GetOpInputTensors(kernel, op_output_map, parameter_index, inputs, &input_tensor_info); + + // Get OpRunInfo and GraphInfo + OpRunInfo run_info; + GetSingleOpRunInfo(kernel, &run_info); + GraphInfo graph_info = GetSingleOpGraphInfo(run_info.primitive, input_tensor_info.input_tensors); + + // Build and run current single op + BuildOpImpl(run_info, graph_info, input_tensor_info.input_tensors, input_tensor_info.input_tensors_mask); + VectorRef op_outputs; + RunOpImpl(run_info, graph_info, input_tensor_info.input_tensors, &op_outputs); + + // Handle inputs and outputs of current op + HandleOpInputs(input_tensor_info.input_kernel, &cnode_ref, &op_output_map); + HandleOpOutputs(kernel, op_outputs, output_indexes, cnode_ref, &op_output_map, outputs); + } +} + // compile graph steps void AscendSession::SelectKernel(const KernelGraph &kernel_graph) const { MS_LOG(INFO) << "Start!"; @@ -591,15 +903,14 @@ void AscendSession::MemoryAlloc(KernelGraph *kernel_graph) const { MS_LOG(INFO) << "Finish!"; } -void AscendSession::RunOpMemoryAlloc(const ValuePtr &pre_output_value, - const std::vector &input_tensors, +void AscendSession::RunOpMemoryAlloc(const std::vector &input_tensors, KernelGraph *kernel_graph) const { MS_LOG(INFO) << "Start memory alloc!"; MS_EXCEPTION_IF_NULL(kernel_graph); opt::RemoveNopNode(kernel_graph); auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_); MS_EXCEPTION_IF_NULL(runtime_instance); - runtime_instance->RunOpAssignMemory(pre_output_value, input_tensors, kernel_graph); + runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph); MS_LOG(INFO) << "Finish!"; } diff --git a/mindspore/ccsrc/backend/session/ascend_session.h b/mindspore/ccsrc/backend/session/ascend_session.h index 4886ade5904..e46a98a6c69 100644 --- a/mindspore/ccsrc/backend/session/ascend_session.h +++ b/mindspore/ccsrc/backend/session/ascend_session.h @@ -35,6 +35,11 @@ namespace mindspore { namespace session { enum GraphType : int { COMMON_GRAPH = 0, CONDITION_GRAPH = 1, BRANCH_START = 2, BRANCH_END = 3 }; +struct InputTensorInfo { + std::vector input_tensors; + std::vector input_tensors_mask; + std::set input_kernel; +}; class AscendSession : public SessionBasic { public: @@ -56,6 +61,8 @@ class AscendSession : public SessionBasic { const std::vector &tensors_mask) override; void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, const std::vector &input_tensors, VectorRef *outputs) override; + void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector &inputs, + VectorRef *outputs) override; private: // compile child graph when session have multiple child graphs @@ -72,8 +79,7 @@ class AscendSession : public SessionBasic { void BuildKernel(const std::shared_ptr &kernel_graph) const; void BuildDynamicKernel(const std::shared_ptr &kernel_graph) const; void MemoryAlloc(KernelGraph *kernel_graph) const; - void RunOpMemoryAlloc(const ValuePtr &pre_output_value, const std::vector &input_tensors, - KernelGraph *kernel_graph) const; + void RunOpMemoryAlloc(const std::vector &input_tensors, KernelGraph *kernel_graph) const; void RunOpMemoryClear(const KernelGraph *kernel_graph) const; void Load(const std::shared_ptr &kernel_graph) const; void Execute(const std::shared_ptr &kernel_graph, bool is_task) const; diff --git a/mindspore/ccsrc/backend/session/executor.cc b/mindspore/ccsrc/backend/session/executor.cc index 570a9dced49..7477cee6f0c 100644 --- a/mindspore/ccsrc/backend/session/executor.cc +++ b/mindspore/ccsrc/backend/session/executor.cc @@ -16,8 +16,9 @@ #include "backend/session/executor.h" #include #include -#include "runtime/device/kernel_runtime_manager.h" + #include "backend/session/executor_manager.h" +#include "runtime/device/kernel_runtime_manager.h" #include "utils/comm_manager.h" #include "utils/scoped_long_running.h" @@ -134,6 +135,11 @@ void RunOpTask::Run() { session_->RunOpImpl(*op_run_info_, graph_info_, input_tensors_, &outputs_); } +void RunOpsInGraphTask::Run() { + MS_EXCEPTION_IF_NULL(session_); + session_->RunOpsInGraphImpl(graph_id_, input_tensors_, &outputs_); +} + void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); } void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); } @@ -361,6 +367,18 @@ void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const Gr *outputs = task->outputs_; } +void Executor::RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id, + const std::vector &inputs, VectorRef *outputs) { + MS_EXCEPTION_IF_NULL(session); + MS_EXCEPTION_IF_NULL(outputs); + auto task = std::make_shared(); + task->session_ = session; + task->graph_id_ = graph_id; + task->input_tensors_ = inputs; + SyncRunTask(task); + *outputs = task->outputs_; +} + bool Executor::CreateCommGroup(const std::string &group_name, std::vector ranks) { auto task = std::make_shared(); task->group_name_ = group_name; diff --git a/mindspore/ccsrc/backend/session/executor.h b/mindspore/ccsrc/backend/session/executor.h index 48fc7553e80..cdd66f50c38 100644 --- a/mindspore/ccsrc/backend/session/executor.h +++ b/mindspore/ccsrc/backend/session/executor.h @@ -16,22 +16,23 @@ #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H #define MINDSPORE_CCSRC_BACKEND_SESSION_EXECUTOR_H -#include -#include -#include -#include -#include -#include -#include -#include -#include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include + #include "backend/session/session_basic.h" #include "ir/anf.h" #include "ir/tensor.h" #include "utils/any.h" -#include "utils/contract.h" #include "utils/comm_manager.h" +#include "utils/contract.h" namespace mindspore { namespace session { @@ -45,7 +46,8 @@ enum TaskType { kRunGraph, kRunOp, kCreateCommGroup, - kDestroyCommGroup + kDestroyCommGroup, + kRunOpsInGraph }; class Task { @@ -98,6 +100,16 @@ class RunGraphTask : public Task { std::map tensor_to_node_; }; +class RunOpsInGraphTask : public Task { + public: + RunOpsInGraphTask() { type_ = kRunOpsInGraph; } + ~RunOpsInGraphTask() override = default; + void Run() override; + std::vector input_tensors_; + VectorRef outputs_; + GraphId graph_id_{0}; +}; + class BuildOpTask : public Task { public: BuildOpTask() { type_ = kBuildOp; } @@ -162,6 +174,8 @@ class Executor { const std::vector &input_tensors, const std::vector &tensors_mask); void RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info, const std::vector &input_tensors, VectorRef *outputs); + void RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id, const std::vector &inputs, + VectorRef *outputs); void OnRunGraphFinished(); bool CreateCommGroup(const std::string &group_name, std::vector ranks); bool DestroyCommGroup(const std::string &group_name); diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index f16957a5117..5313ba338d1 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -198,13 +198,12 @@ void GPUSession::AllocateMemory(KernelGraph *kernel_graph) const { runtime_instance->AssignMemory(kernel_graph); } -void GPUSession::RunOpAllocateMemory(const ValuePtr &pre_output_value, - const std::vector &input_tensors, +void GPUSession::RunOpAllocateMemory(const std::vector &input_tensors, KernelGraph *kernel_graph) const { MS_EXCEPTION_IF_NULL(kernel_graph); auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); MS_EXCEPTION_IF_NULL(runtime_instance); - runtime_instance->RunOpAssignMemory(pre_output_value, input_tensors, kernel_graph); + runtime_instance->RunOpAssignMemory(input_tensors, kernel_graph); } void GPUSession::RunOpClearMemory(KernelGraph *kernel_graph) const { @@ -351,6 +350,8 @@ void GPUSession::RunGraphImpl(const GraphId &graph_id, const std::vector &kernel_graph TensorLoader *tensor_loader = debug_services->tensor_loader(); tensor_loader->EmptyPrevTensor(); } + +void GPUSession::SyncValueNodeDeviceAddr(const std::shared_ptr &kernel_graph) const { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (context_ptr->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode) { + return; + } + auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + runtime_instance->SyncValueNodeDeviceAddr(kernel_graph.get()); +} + +void GPUSession::CleanValueNodeDeviceAddr(const std::shared_ptr &kernel_graph) const { + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + if (context_ptr->get_param(MS_CTX_EXECUTION_MODE) != kPynativeMode) { + return; + } + auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_); + MS_EXCEPTION_IF_NULL(runtime_instance); + runtime_instance->CleanValueNodeDeviceAddr(kernel_graph.get()); +} } // namespace gpu } // namespace session } // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/gpu_session.h b/mindspore/ccsrc/backend/session/gpu_session.h index 544aa543693..3ac31ccbd01 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.h +++ b/mindspore/ccsrc/backend/session/gpu_session.h @@ -61,8 +61,7 @@ class GPUSession : public SessionBasic { void AllocateMemory(KernelGraph *kernel_graph) const; - void RunOpAllocateMemory(const ValuePtr &pre_output_value, const std::vector &input_tensors, - KernelGraph *kernel_graph) const; + void RunOpAllocateMemory(const std::vector &input_tensors, KernelGraph *kernel_graph) const; void RunOpClearMemory(KernelGraph *kernel_graph) const; @@ -82,6 +81,10 @@ class GPUSession : public SessionBasic { void PreLoadTensor(const std::shared_ptr &kernel_graph) const; void PostLoadTensor(const std::shared_ptr &kernel_graph) const; + + void SyncValueNodeDeviceAddr(const std::shared_ptr &kernel_graph) const; + + void CleanValueNodeDeviceAddr(const std::shared_ptr &kernel_graph) const; }; using GPUSessionPtr = std::shared_ptr; MS_REG_SESSION(kGPUDevice, GPUSession); diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc index 8385e45a35d..dc04191523d 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -14,9 +14,11 @@ * limitations under the License. */ #include "backend/session/session_basic.h" -#include + #include +#include #include +#include #include "c_ops/primitive_c.h" #include "ir/manager.h" @@ -1606,6 +1608,12 @@ void SessionBasic::RunOp(OpRunInfo *op_run_info, const GraphInfo &graph_info, executor_->RunOp(shared_from_this(), op_run_info, graph_info, input_tensors, outputs); } +void SessionBasic::RunOpsInGraph(const GraphId &graph_id, const std::vector &inputs, + VectorRef *outputs) { + MS_EXCEPTION_IF_NULL(executor_); + executor_->RunOpsInGraph(shared_from_this(), graph_id, inputs, outputs); +} + void SessionBasic::RunGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs) { MS_EXCEPTION_IF_NULL(executor_); executor_->RunGraph(shared_from_this(), graph_id, inputs, outputs); diff --git a/mindspore/ccsrc/backend/session/session_basic.h b/mindspore/ccsrc/backend/session/session_basic.h index 57f2c3ea530..2037e971bbc 100644 --- a/mindspore/ccsrc/backend/session/session_basic.h +++ b/mindspore/ccsrc/backend/session/session_basic.h @@ -22,6 +22,7 @@ #include #include #include +#include #include "backend/session/session_context.h" #include "backend/session/kernel_graph.h" #include "backend/session/anf_runtime_algorithm.h" @@ -49,7 +50,6 @@ struct OpRunInfo { std::string op_name; PrimitivePtr primitive; AbstractBasePtr abstract; - ValuePtr value = nullptr; bool is_dynamic_shape = false; bool is_auto_mixed_precision = false; std::string next_op_name = ""; @@ -79,6 +79,7 @@ class SessionBasic : public std::enable_shared_from_this { void BuildOp(OpRunInfo *, const GraphInfo &, const std::vector &input_tensors, const std::vector &tensors_mask); void RunOp(OpRunInfo *, const GraphInfo &, const std::vector &input_tensors, VectorRef *outputs); + void RunOpsInGraph(const GraphId &graph_id, const std::vector &inputs, VectorRef *outputs); virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback); @@ -138,6 +139,7 @@ class SessionBasic : public std::enable_shared_from_this { friend class RunGraphTask; friend class BuildOpTask; friend class RunOpTask; + friend class RunOpsInGraphTask; virtual bool IsSupportSummary() { return true; } virtual void CreateOutputTensors(const GraphId &graph_id, const std::vector &input_tensors, VectorRef *outputs, @@ -155,6 +157,8 @@ class SessionBasic : public std::enable_shared_from_this { const std::vector &tensors_mask) {} virtual void RunOpImpl(const OpRunInfo &op_run_info, const GraphInfo &graph_info, const std::vector &input_tensors, VectorRef *outputs) {} + virtual void RunOpsInGraphImpl(const GraphId &graph_id, const std::vector &inputs, + VectorRef *outputs) {} void RunInfer(NotNull func_graph, const std::vector &inputs); virtual void SetSummaryNodes(KernelGraph *graph); diff --git a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc index db0ee7a896d..43f92a71b65 100644 --- a/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc +++ b/mindspore/ccsrc/frontend/optimizer/ad/dfunctor.cc @@ -281,24 +281,6 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) { return node_adjoint; } -void TensorSetAddress(const ValuePtr &value, std::map *tuple_tensors) { - MS_LOG(DEBUG) << "Start set tensor address" << value->ToString() << value->isa(); - if (value->isa()) { - auto tnode = value->cast(); - if (tuple_tensors->find(tnode->id()) != tuple_tensors->end()) { - MS_LOG(DEBUG) << "Set tensor" << tnode->device_address(); - (*tuple_tensors)[tnode->id()]->set_device_address(tnode->device_address()); - } - } - if (value->isa()) { - auto tuple = value->cast(); - for (size_t i = 0; i < tuple->size(); i++) { - MS_LOG(DEBUG) << "Set tuple tensor" << (*tuple)[i]->ToString(); - TensorSetAddress((*tuple)[i], tuple_tensors); - } - } -} - ValuePtr GenNewTensorInner(const ValuePtr &value) { std::vector value_list; if (value->isa()) { @@ -328,7 +310,6 @@ ValuePtr GenNewTensor(const FuncGraphManagerPtr &mng, const AnfNodePtr &node, co void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_morph) { auto forward = cnode_morph->forward().first; - auto forward_id = cnode_morph->forward().second; if (forward == nullptr) { return; } @@ -337,6 +318,7 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor return; } auto fg = GetValueNode(input); + // {prim::maketuple, forward_output, bprop_graph} auto output = fg->output(); if (!output->isa()) { return; @@ -350,25 +332,22 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor if (!IsValueNode(input_fg)) { return; } - std::map tuple_tensors; + // replace forward output with value node auto equivdout = cnode_input->cast(); + MS_EXCEPTION_IF_NULL(equivdout); auto func_graph = GetValueNode(input_fg); + MS_EXCEPTION_IF_NULL(func_graph); auto manager = Manage({fg, func_graph}, false); - auto ref_size = manager->node_users()[equivdout].size(); - auto forward_value = forward; - if (!forward_id.empty() && ref_size > 1) { - auto inst = pynative::PynativeExecutor::GetInstance(); - inst->SaveOpForwardValue(forward_id, forward_value, &tuple_tensors); - } - forward_value = GenNewTensor(manager, equivdout, forward); + auto forward_value = GenNewTensor(manager, equivdout, forward); MS_LOG(DEBUG) << "Replace: " << equivdout->ToString() << " with " << forward; auto value_node = NewValueNode(forward_value); value_node->set_has_new_value(true); manager->Replace(equivdout, value_node); + // replace input object with value node auto paras = fg->parameters(); auto inputs_value = cnode_morph->inputs_value(); - if (inputs_value.size() == 0) { + if (inputs_value.empty()) { return; } if (inputs_value.size() != paras.size()) { @@ -379,10 +358,6 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor auto input_value = inputs_value[i]; if (para_ref_size > 0 && input_value.first != nullptr) { MS_LOG(DEBUG) << "Replace: " << paras[i]->ToString() << " with " << input_value.first; - auto inst = pynative::PynativeExecutor::GetInstance(); - if (!input_value.second.empty()) { - inst->SaveOpForwardValue(input_value.second, input_value.first, &tuple_tensors); - } auto input_value_node = NewValueNode(input_value.first); input_value_node->set_has_new_value(true); manager->Replace(paras[i], input_value_node); @@ -394,30 +369,19 @@ void DFunctor::ReplaceEquivdout(const CNodePtr &cnode, const CNodePtr &cnode_mor res->set_func_graph(fg); PynativeElimOpt(res); auto out = fg->output()->cast(); + MS_EXCEPTION_IF_NULL(out); auto c_input = out->input(1); + MS_EXCEPTION_IF_NULL(c_input); if (!c_input->isa()) { return; } - auto out_node = c_input->cast(); + MS_EXCEPTION_IF_NULL(out_node); out_node->set_value(GenNewTensor(manager, out_node, out_node->value())); - + // clear resource cnode_morph->clear_inputs_value(); - - if (tuple_tensors.size() != 0) { - MS_LOG(DEBUG) << "Start tuple out" << fg->output()->DebugString(4); - for (auto &g : manager->func_graphs()) { - for (auto &node : g->value_nodes()) { - MS_LOG(DEBUG) << "Set Tensor addr" << node.first->ToString(); - auto vnode = node.first->cast()->value(); - TensorSetAddress(vnode, &tuple_tensors); - } - } - } - fg->ClearAllManagerInfo(); func_graph->ClearAllManagerInfo(); - return; } bool DFunctor::IsFreeMorphism(const AnfNodePtr &node) { diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h index a7fbdd56e85..00fcef715b3 100644 --- a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h +++ b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h @@ -298,14 +298,29 @@ class PynativeEliminater : public OptimizerCaller { return out; } + void OnlySaveAbstractInfo(const ValueNodePtr &value_node) { + MS_EXCEPTION_IF_NULL(value_node); + auto &value = value_node->value(); + MS_EXCEPTION_IF_NULL(value); + if (value->isa()) { + auto tensor = value->cast(); + MS_EXCEPTION_IF_NULL(tensor); + auto new_tensor = std::make_shared(tensor->Dtype()->type_id(), tensor->shape()); + value_node->set_value(MakeValue(new_tensor)); + } + } + public: AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override { MS_LOG(DEBUG) << "Start replace node " << node->DebugString(4); - PatternNode symbol_str_vnode, c_vnode, zeros_like_vnode, getitem_vnode, arg, arg1; + PatternNode symbol_str_vnode; + PatternNode c_vnode; + PatternNode zeros_like_vnode; + PatternNode arg; auto resolve = PPrimitive(prim::kPrimResolve, symbol_str_vnode, c_vnode); auto getattr = PPrimitive(prim::kPrimGetAttr, resolve, zeros_like_vnode); auto pattern = PCNode(getattr, arg); - + // {{prim:getattr, {prim::resolve, SymbolStr, C}, zeros_like}, Xy} ->Tensor(0, shape(Xy)) if ((pattern).TryCapture(node) && (CheckNameSpaceVNode(symbol_str_vnode.GetNode(node), "SymbolStr") && CheckSymbolVNode(c_vnode.GetNode(node), "C") && CheckStrVNode(zeros_like_vnode.GetNode(node), "zeros_like"))) { @@ -320,8 +335,8 @@ class PynativeEliminater : public OptimizerCaller { } } } - MS_LOG(DEBUG) << "End replace 1 " << node->DebugString(4); + // {prim:getattr, {prim::resolve, SymbolStr, zeros_like}, Xy} ->Tensor(0, shape(Xy)) auto resolve1 = PPrimitive(prim::kPrimResolve, symbol_str_vnode, zeros_like_vnode); auto pattern1 = PCNode(resolve1, arg); @@ -338,7 +353,13 @@ class PynativeEliminater : public OptimizerCaller { } } } - + // {prim:getattr, {prim::resolve, SymbolStr, binop_grad_common}, x, y, out, dout} -> {shape(x), shape(y), out, dout} + PatternNode binop_grad_common; + PatternNode getitem_vnode; + PatternNode arg1; + PatternNode arg2; + PatternNode arg3; + PatternNode arg4; // resolve(CommonOPS, getitem)((tensors), 3) auto resolve2 = PPrimitive(prim::kPrimResolve, symbol_str_vnode, getitem_vnode); auto pattern2 = PCNode(resolve2, arg, arg1); diff --git a/mindspore/ccsrc/pipeline/pynative/base.h b/mindspore/ccsrc/pipeline/pynative/base.h index 81fd4b6aa07..130ace395b8 100644 --- a/mindspore/ccsrc/pipeline/pynative/base.h +++ b/mindspore/ccsrc/pipeline/pynative/base.h @@ -51,21 +51,19 @@ enum RunOpArgsEnum { PY_PRIM = 0, PY_NAME, PY_INPUTS, PY_ARGS_NUM }; struct OpExecInfo { std::string op_name; + std::string op_index; std::string prim_id; PrimitivePyPtr py_primitive; AbstractBasePtr abstract; - bool is_dynamic_shape = false; - ValuePtr value = nullptr; py::list op_inputs; - py::dict op_attrs; std::vector inputs_mask; + bool is_dynamic_shape = false; std::string next_op_name = ""; bool is_mixed_precision_cast = false; size_t next_input_index = 0; }; using OpExecInfoPtr = std::shared_ptr; -OpExecInfoPtr GenerateOpExecInfo(const py::args &args); const std::set ignore_infer_prim = {"make_ref", "mixed_precision_cast"}; const std::set force_infer_prim = {"TopK", "DropoutGenMask"}; diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc index 6c1e9dfe31c..12c7b654faf 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.cc @@ -149,12 +149,6 @@ static std::string GetId(const py::object &obj) { return py::cast(ret); } -static std::string GetOpId(const OpExecInfoPtr &op_exec_info) { - auto id = GetId(op_exec_info->py_primitive->GetPyObj()); - op_exec_info->prim_id = id; - return id; -} - std::map> GetTypeIndex(const std::vector &dtypes) { std::map> type_indexes; for (size_t i = 0; i < dtypes.size(); ++i) { @@ -260,24 +254,6 @@ void PynativeInfer(const PrimitivePyPtr &prim, const py::list &py_args, OpExecIn MS_LOG(DEBUG) << "Prim " << prim->name() << " infer result " << op_exec_info->abstract->ToString(); } -OpExecInfoPtr GenerateOpExecInfo(const py::args &args) { - if (args.size() != PY_ARGS_NUM) { - MS_LOG(ERROR) << "Three args are needed by RunOp"; - return nullptr; - } - auto op_exec_info = std::make_shared(); - MS_EXCEPTION_IF_NULL(op_exec_info); - op_exec_info->op_name = py::cast(args[PY_NAME]); - auto prim = py::cast(args[PY_PRIM]); - if (!prim->HasPyObj()) { - MS_LOG(EXCEPTION) << "Pyobj is empty"; - } - op_exec_info->py_primitive = prim; - op_exec_info->op_attrs = py::getattr(args[PY_PRIM], "attrs"); - op_exec_info->op_inputs = args[PY_INPUTS]; - return op_exec_info; -} - std::string GetSingleOpGraphInfo(const OpExecInfoPtr &op_exec_info, const std::vector &input_tensors) { MS_EXCEPTION_IF_NULL(op_exec_info); @@ -580,7 +556,7 @@ py::tuple RunOp(const py::args &args) { auto executor = PynativeExecutor::GetInstance(); MS_EXCEPTION_IF_NULL(executor); MS_LOG(DEBUG) << "RunOp start " << args.size(); - OpExecInfoPtr op_exec_info = GenerateOpExecInfo(args); + OpExecInfoPtr op_exec_info = executor->GenerateOpExecInfo(args); try { return executor->RunOpInner(op_exec_info); } catch (const py::error_already_set &ex) { @@ -608,16 +584,17 @@ py::tuple RunOp(const py::args &args) { } py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { - auto prim = op_exec_info->py_primitive; - auto name = op_exec_info->op_name; if (op_exec_info->op_name == prim::kPrimMixedPrecisionCast->name()) { return RunOpWithInitBackendPolicy(op_exec_info); } - + // make cnode for building grad graph if grad flag is set. abstract::AbstractBasePtrList args_spec_list; std::vector op_masks; - auto cnode = PynativeExecutor::GetInstance()->MakeCNode(op_exec_info, &op_masks, &args_spec_list); + auto cnode = MakeCNode(op_exec_info, &op_masks, &args_spec_list); + op_exec_info->inputs_mask = op_masks; + // get output abstract info bool is_find = false; + auto prim = op_exec_info->py_primitive; if (prim_abs_list_.find(prim->id()) != prim_abs_list_.end()) { auto abs_list = prim_abs_list_[prim->id()]; MS_LOG(DEBUG) << "Match prim input args " << op_exec_info->op_name << mindspore::ToString(args_spec_list); @@ -629,7 +606,6 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { is_find = true; } } - if (op_exec_info->abstract == nullptr || force_infer_prim.find(op_exec_info->op_name) != force_infer_prim.end()) { // use python infer method if (ignore_infer_prim.find(op_exec_info->op_name) == ignore_infer_prim.end()) { @@ -648,11 +624,10 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { if (cnode != nullptr) { cnode->set_abstract(op_exec_info->abstract); } - - op_exec_info->inputs_mask = op_masks; + // infer output value for const prim MS_EXCEPTION_IF_NULL(op_exec_info); if (op_exec_info->abstract != nullptr) { - MS_LOG(DEBUG) << "Run op infer " << name << " " << op_exec_info->abstract->ToString(); + MS_LOG(DEBUG) << "Run op infer " << op_exec_info->op_name << " " << op_exec_info->abstract->ToString(); py::dict output = abstract::ConvertAbstractToPython(op_exec_info->abstract); if (!output["value"].is_none()) { py::tuple value_ret(1); @@ -665,7 +640,7 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { return value_ret; } } - + // add output abstract info into cache if (!is_find) { // const_value need infer every step auto &out = prim_abs_list_[prim->id()]; @@ -674,13 +649,7 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { out[args_spec_list].attrs = prim->evaluate_added_attrs(); MS_LOG(DEBUG) << "Set prim " << op_exec_info->op_name << mindspore::ToString(args_spec_list); } - - if (PynativeExecutor::GetInstance()->grad_flag()) { - op_exec_info->value = PynativeExecutor::GetInstance()->GetForwardValue(op_exec_info); - } else { - (void)GetOpId(op_exec_info); - } - + // run op with selected backend auto result = RunOpWithInitBackendPolicy(op_exec_info); py::object out_real = result; if (result.size() == 1) { @@ -689,13 +658,38 @@ py::tuple PynativeExecutor::RunOpInner(const OpExecInfoPtr &op_exec_info) { } std::string obj_id = GetId(out_real); node_abs_map_[obj_id] = op_exec_info->abstract; - PynativeExecutor::GetInstance()->SaveOutputNodeMap(obj_id, out_real, cnode); - if (cnode != nullptr) { - PynativeExecutor::GetInstance()->SaveAllResult(op_exec_info, cnode->cast(), result); - } + SaveOutputNodeMap(obj_id, out_real, cnode); + SaveAllResult(op_exec_info, cnode, out_real); + // Update the abstract and device address of value node with tensor in grad graph + UpdateAbstractAndDeviceAddress(op_exec_info, out_real); return result; } +OpExecInfoPtr PynativeExecutor::GenerateOpExecInfo(const py::args &args) { + if (args.size() != PY_ARGS_NUM) { + MS_LOG(ERROR) << "Three args are needed by RunOp"; + return nullptr; + } + auto op_exec_info = std::make_shared(); + auto op_name = py::cast(args[PY_NAME]); + op_exec_info->op_name = op_name; + if (grad_flag_) { + MS_EXCEPTION_IF_NULL(resource_); + int64_t graph_id = resource_->results()[pipeline::kPynativeGraphId].cast(); + op_exec_info->op_index = std::to_string(graph_id) + op_name + std::to_string(op_index_map_[op_name]); + op_index_map_[op_name]++; + } + auto prim = py::cast(args[PY_PRIM]); + MS_EXCEPTION_IF_NULL(prim); + if (!prim->HasPyObj()) { + MS_LOG(EXCEPTION) << "Pyobj is empty"; + } + op_exec_info->prim_id = GetId(prim->GetPyObj()); + op_exec_info->py_primitive = prim; + op_exec_info->op_inputs = args[PY_INPUTS]; + return op_exec_info; +} + AnfNodePtr PynativeExecutor::MakeCNode(const OpExecInfoPtr &op_exec_info, std::vector *op_masks, abstract::AbstractBasePtrList *args_spec_list) { MS_EXCEPTION_IF_NULL(op_masks); @@ -997,6 +991,56 @@ AnfNodePtr PynativeExecutor::GetInput(const py::object &obj, bool op_mask) { return node; } +void PynativeExecutor::UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real) { + MS_EXCEPTION_IF_NULL(op_exec_info); + if (!grad_flag_) { + return; + } + auto op_index = op_exec_info->op_index; + auto output_value = PyAttrValue(out_real); + MS_EXCEPTION_IF_NULL(output_value); + std::vector output_tensors; + TensorValueToTensor(output_value, &output_tensors); + if (op_index_with_tensor_id_.find(op_index) == op_index_with_tensor_id_.end()) { + // first step + std::for_each(output_tensors.begin(), output_tensors.end(), [&](const tensor::TensorPtr &tensor) { + op_index_with_tensor_id_[op_index].emplace_back(tensor->id()); + }); + return; + } + const auto &tensor_id_list = op_index_with_tensor_id_[op_index]; + for (size_t i = 0; i < tensor_id_list.size(); ++i) { + auto tensor_id = tensor_id_list[i]; + if (tensor_id_with_tensor_.find(tensor_id) != tensor_id_with_tensor_.end()) { + auto &new_tensor = output_tensors[i]; + auto &tensors_in_value_node = tensor_id_with_tensor_[tensor_id]; + std::for_each(tensors_in_value_node.begin(), tensors_in_value_node.end(), [&](tensor::TensorPtr &tensor) { + tensor->set_shape(new_tensor->shape()); + tensor->set_data_type(new_tensor->data_type()); + tensor->set_device_address(new_tensor->device_address()); + }); + } + } +} + +void PynativeExecutor::SaveTensorsInValueNode(const ResourcePtr &resource) { + MS_EXCEPTION_IF_NULL(resource); + tensor_id_with_tensor_.clear(); + const auto &func_graph = resource->func_graph(); + const auto &value_node_list = func_graph->value_nodes(); + for (const auto &elem : value_node_list) { + auto value_node = elem.first->cast(); + MS_EXCEPTION_IF_NULL(value_node); + std::vector tensors; + TensorValueToTensor(value_node->value(), &tensors); + for (const auto &tensor : tensors) { + if (tensor->device_address() != nullptr) { + tensor_id_with_tensor_[tensor->id()].emplace_back(tensor); + } + } + } +} + AnfNodePtr PynativeExecutor::GetObjNode(const py::object &obj, const std::string &obj_id) { auto &out = graph_info_map_[curr_g_].node_map[obj_id]; if (out.second.size() == 1 && out.second[0] == -1) { @@ -1054,23 +1098,6 @@ AnfNodePtr PynativeExecutor::MakeValueNode(const py::object &obj, const std::str return node; } -ValuePtr PynativeExecutor::GetForwardValue(const OpExecInfoPtr &op_exec_info) { - auto id = GetOpId(op_exec_info); - int64_t graph_id = resource_->results()[pipeline::kPynativeGraphId].cast(); - auto op = std::to_string(graph_id) + id; - op.append(std::to_string(op_id_map_[id])); - auto iter = op_forward_map_.find(op); - if (iter != op_forward_map_.end()) { - ++op_id_map_[id]; - MS_LOG(DEBUG) << "Get: " << op_exec_info->op_name << "(" << op << "), " << iter->second; - return iter->second; - } - if (!first_grad_step_) { - ++op_id_map_[id]; - } - return nullptr; -} - void PynativeExecutor::SaveOutputNodeMap(const std::string &obj_id, const py::object &out_real, const AnfNodePtr &cnode) { if (!grad_flag_ || graph_info_map_.empty()) { @@ -1093,16 +1120,16 @@ void PynativeExecutor::SaveOutputNodeMap(const std::string &obj_id, const py::ob SetPyObjInGraphInfoMap(curr_g_, obj_id); } -void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out) { - if (!grad_flag_ || op_exec_info->value != nullptr || cnode == nullptr) { +void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node, + const py::object &out_real) { + if (!grad_flag_ || node == nullptr) { return; } - py::object out_real = out; - if (out.size() == 1) { - out_real = out[0]; - } - auto value = PyAttrValue(out_real); + MS_EXCEPTION_IF_NULL(op_exec_info); + auto cnode = node->cast(); + MS_EXCEPTION_IF_NULL(cnode); + // save input object size_t size = op_exec_info->op_inputs.size(); for (size_t i = 0; i < size; i++) { auto obj = op_exec_info->op_inputs[i]; @@ -1113,59 +1140,19 @@ void PynativeExecutor::SaveAllResult(const OpExecInfoPtr &op_exec_info, const CN cnode->add_input_value(nullptr, ""); } } - std::string id = GetOpId(op_exec_info); - int64_t graph_id = resource_->results()[pipeline::kPynativeGraphId].cast(); - auto op_id = std::to_string(graph_id) + id; - op_id.append(std::to_string(op_id_map_[id])); - cnode->set_forward(value, op_id); - ++op_id_map_[id]; + // save output object + auto output_value = PyAttrValue(out_real); + MS_EXCEPTION_IF_NULL(output_value); + cnode->set_forward(output_value, op_exec_info->op_index); auto out_id = GetId(out_real); if (py::isinstance(out_real)) { auto tuple_item = py::cast(out_real); for (size_t i = 0; i < tuple_item.size(); i++) { auto tuple_item_id = GetId(tuple_item[i]); - obj_to_forward_id_[tuple_item_id] = op_id; + obj_to_forward_id_[tuple_item_id] = op_exec_info->op_index; } - SaveOpForwardValue(op_id, value, nullptr); } - obj_to_forward_id_[out_id] = op_id; -} - -void PynativeExecutor::SaveOpForwardValue(const std::string &id, const ValuePtr &value, - std::map *t_map) { - if (op_forward_map_.find(id) != op_forward_map_.end()) { - // for one op have multi outputs but save only one tensor - if (op_forward_map_[id]->isa() && value->isa()) { - auto tuple = op_forward_map_[id]->cast(); - auto value_t = value->cast(); - for (size_t i = 0; i < tuple->size(); i++) { - if ((*tuple)[i]->isa()) { - auto tuple_t = (*tuple)[i]->cast(); - if (value_t->id() == tuple_t->id()) { - tuple_t->set_device_address(value_t->device_address()); - MS_LOG(DEBUG) << "After Saveop " << tuple_t->ToString(); - break; - } - } - } - } - - if (value->isa() && t_map != nullptr) { - GenTupleMap(op_forward_map_[id]->cast(), t_map); - } - MS_LOG(DEBUG) << "Save op forward value: " - << "(" << id << "), " << op_forward_map_[id]->ToString(); - return; - } - - if (value->isa() && t_map == nullptr) { - // make cnode gen all tuple node and set device_address be null - op_forward_map_[id] = CleanTupleAddr(value->cast()); - } else { - op_forward_map_[id] = value; - } - MS_LOG(DEBUG) << "Save op forward value: " - << "(" << id << "), " << value->ToString(); + obj_to_forward_id_[out_id] = op_exec_info->op_index; } void PynativeExecutor::GenTupleMap(const ValueTuplePtr &tuple, std::map *t_map) { @@ -1307,10 +1294,13 @@ py::object PynativeExecutor::RunOpInMs(const OpExecInfoPtr &op_exec_info, Pynati ConstructInputTensor(op_exec_info, &tensors_mask, &input_tensors); // get graph info for checking it whether existing in the cache std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors); - session::OpRunInfo op_run_info = {op_exec_info->op_name, op_exec_info->py_primitive, - op_exec_info->abstract, op_exec_info->value, - op_exec_info->is_dynamic_shape, op_exec_info->is_mixed_precision_cast, - op_exec_info->next_op_name, op_exec_info->next_input_index}; + session::OpRunInfo op_run_info = {op_exec_info->op_name, + op_exec_info->py_primitive, + op_exec_info->abstract, + op_exec_info->is_dynamic_shape, + op_exec_info->is_mixed_precision_cast, + op_exec_info->next_op_name, + op_exec_info->next_input_index}; session->BuildOp(&op_run_info, graph_info, input_tensors, tensors_mask); EraseValueNodeTensor(tensors_mask, &input_tensors); VectorRef outputs; @@ -1524,6 +1514,7 @@ void PynativeExecutor::NewGraphInner(const py::object &cell, const py::args &arg if (it != cell_resource_map_.end()) { resource_ = it->second; MS_EXCEPTION_IF_NULL(resource_); + op_index_map_.clear(); } MS_LOG(DEBUG) << "Graph already compiled"; return; @@ -1571,7 +1562,8 @@ void PynativeExecutor::MakeNewTopGraph(const string &cell_id, const py::args &ar resource_->results()[pipeline::kPynativeGraphId] = graph_id_++; cell_resource_map_[cell_id] = resource_; MS_LOG(DEBUG) << "New top graph for " << cell_id; - first_grad_step_ = true; + op_index_map_.clear(); + op_index_with_tensor_id_.clear(); top_graph_cells_.emplace(cell_id); } @@ -1770,6 +1762,7 @@ void PynativeExecutor::GradNetInner(const GradOperationPtr &grad, const py::obje MS_LOG(DEBUG) << "Start opt"; PynativeOptimizeAction(resource_); + SaveTensorsInValueNode(resource_); TaskEmitAction(resource_); ExecuteAction(resource_); cell_graph_map_[cell_id].second = true; @@ -2021,7 +2014,6 @@ void PynativeExecutor::Clear(const std::string &flag) { } ConfigManager::GetInstance().ResetIterNum(); if (top_graph_cells_.find(flag) != top_graph_cells_.end()) { - op_forward_map_.clear(); Clean(); } node_abs_map_.clear(); @@ -2033,9 +2025,7 @@ void PynativeExecutor::Clear(const std::string &flag) { top_g_ = nullptr; df_builder_ = nullptr; curr_g_ = nullptr; - first_grad_step_ = false; graph_info_map_.clear(); - op_id_map_.clear(); obj_to_forward_id_.clear(); node_abs_map_.clear(); std::stack().swap(graph_stack_); diff --git a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h index b37902888b8..bb9c7e5ffc7 100644 --- a/mindspore/ccsrc/pipeline/pynative/pynative_execute.h +++ b/mindspore/ccsrc/pipeline/pynative/pynative_execute.h @@ -83,13 +83,12 @@ class PynativeExecutor : public std::enable_shared_from_this { void set_grad_flag(bool flag) { grad_flag_ = flag; } py::tuple RunOpInner(const OpExecInfoPtr &op_exec_info); + OpExecInfoPtr GenerateOpExecInfo(const py::args &args); void NewGraph(const py::object &cell, const py::args &args); py::object Run(const py::tuple &args, const py::object &phase); py::object CheckGraph(const py::object &cell, const py::args &args); void EndGraph(const py::object &cell, const py::object &out, const py::args &args); void GradNet(const GradOperationPtr &grad, const py::object &cell, const py::object &weights, const py::args &args); - void SaveOpForwardValue(const std::string &id, const ValuePtr &value, - std::map *t_map); // Call by python void Clear(const std::string &flag = ""); @@ -134,9 +133,11 @@ class PynativeExecutor : public std::enable_shared_from_this { // replace for grad graph ValuePtr CleanTupleAddr(const ValueTuplePtr &tuple); - ValuePtr GetForwardValue(const OpExecInfoPtr &op_exec_info); void GenTupleMap(const ValueTuplePtr &tuple, std::map *t_map); - void SaveAllResult(const OpExecInfoPtr &op_exec_info, const CNodePtr &cnode, const py::tuple &out); + void SaveAllResult(const OpExecInfoPtr &op_exec_info, const AnfNodePtr &node, const py::object &out_real); + // Update the abstract and device address info of value node and tensors in bprop graph + void UpdateAbstractAndDeviceAddress(const OpExecInfoPtr &op_exec_info, const py::object &out_real); + void SaveTensorsInValueNode(const ResourcePtr &resource); // construct grad graph void PushCurrentGraphToStack(); @@ -175,7 +176,6 @@ class PynativeExecutor : public std::enable_shared_from_this { static int64_t graph_id_; bool grad_flag_{false}; bool dynamic_cell_{false}; - bool first_grad_step_{false}; bool grad_is_running{false}; // Used for construct grad graph @@ -199,9 +199,10 @@ class PynativeExecutor : public std::enable_shared_from_this { std::unordered_map> df_builder_map_; // used for runop and replace forward result of grad graph - std::unordered_map op_forward_map_; - std::unordered_map op_id_map_; + std::unordered_map op_index_map_; std::unordered_map obj_to_forward_id_; + std::unordered_map> op_index_with_tensor_id_; + std::unordered_map> tensor_id_with_tensor_; std::unordered_map node_abs_map_; std::unordered_map prim_abs_list_; const inline static std::string kOpsFunctionModelName = "mindspore.ops.functional"; diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index 4768dd0f5e5..d0cb8d6938b 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -81,15 +81,13 @@ void KernelRuntime::AssignMemory(session::KernelGraph *graph) { UpdateRefNodeOutputMem(graph); } -void KernelRuntime::RunOpAssignMemory(const ValuePtr &pre_output_value, - const std::vector &input_tensors, +void KernelRuntime::RunOpAssignMemory(const std::vector &input_tensors, session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(mem_manager_); mem_manager_->ResetDynamicMemory(); RunOpAssignInputMemory(input_tensors, graph); AssignStaticMemoryValueNode(graph); - RunOpAssignOutputNodeMemory(pre_output_value, graph); for (const auto &cnode : graph->execution_order()) { RunOpAssignOutputMemory(cnode); RunOpAssignWorkSpaceMemory(cnode); @@ -680,6 +678,52 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) { MS_LOG(INFO) << "AssignStaticMemoryValueNode end"; } +void KernelRuntime::SyncValueNodeDeviceAddr(session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + MS_LOG(INFO) << "SyncValueNodeDeviceAddr start"; + for (auto &value_node : graph->graph_value_nodes()) { + MS_EXCEPTION_IF_NULL(value_node); + auto &node_value = value_node->value(); + MS_EXCEPTION_IF_NULL(node_value); + if (!node_value->isa() && !node_value->isa()) { + continue; + } + std::vector tensors; + TensorValueToTensor(node_value, &tensors); + for (size_t index = 0; index < tensors.size(); index += 1) { + const auto &tensor = tensors[index]; + if (tensor->device_address() != nullptr) { + AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast(tensor->device_address()), index, + value_node.get()); + } else { + MS_LOG(INFO) << "Tensor of ValueNode[" << value_node->fullname_with_scope() << "]'s device address is nullptr."; + } + } + } + MS_LOG(INFO) << "SyncValueNodeDeviceAddr end"; +} + +void KernelRuntime::CleanValueNodeDeviceAddr(session::KernelGraph *graph) { + MS_EXCEPTION_IF_NULL(graph); + MS_LOG(INFO) << "CleanValueNodeDeviceAddr start"; + for (auto &value_node : graph->graph_value_nodes()) { + MS_EXCEPTION_IF_NULL(value_node); + auto &node_value = value_node->value(); + MS_EXCEPTION_IF_NULL(node_value); + if (!node_value->isa() && !node_value->isa()) { + continue; + } + std::vector tensors; + TensorValueToTensor(node_value, &tensors); + for (size_t index = 0; index < tensors.size(); index += 1) { + if (tensors[index]->device_address() != nullptr) { + AnfAlgo::SetOutputAddr(nullptr, index, value_node.get()); + } + } + } + MS_LOG(INFO) << "CleanValueNodeDeviceAddr end"; +} + void KernelRuntime::AssignDynamicMemory(session::KernelGraph *graph) { MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(mem_manager_); diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h index 212f0b12280..8ecca3cf266 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h @@ -51,8 +51,7 @@ class KernelRuntime { virtual ~KernelRuntime(); virtual bool Init() = 0; virtual void AssignMemory(session::KernelGraph *graph); - void RunOpAssignMemory(const ValuePtr &pre_output_value, const std::vector &input_tensors, - session::KernelGraph *graph); + void RunOpAssignMemory(const std::vector &input_tensors, session::KernelGraph *graph); void RunOpClearMemory(const session::KernelGraph *graph); static bool DumpDataEnabled(); static bool DumpDataEnabledIteration(); @@ -67,6 +66,8 @@ class KernelRuntime { const AddressPtrList &kernel_workspaces) const; virtual void AssignStaticMemoryInput(const session::KernelGraph *graph); virtual void AssignStaticMemoryValueNode(session::KernelGraph *graph); + virtual void SyncValueNodeDeviceAddr(session::KernelGraph *graph); + virtual void CleanValueNodeDeviceAddr(session::KernelGraph *graph); virtual void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector &inputs, const std::unordered_set &value_nodes, const std::vector &execution_order); diff --git a/mindspore/ccsrc/vm/backend.cc b/mindspore/ccsrc/vm/backend.cc index 8ec9e6d8791..03647727ce8 100644 --- a/mindspore/ccsrc/vm/backend.cc +++ b/mindspore/ccsrc/vm/backend.cc @@ -18,13 +18,13 @@ #include #include -#include "utils/log_adapter.h" +#include "backend/session/session_factory.h" #include "ir/anf.h" +#include "pybind_api/ir/base_ref_py.h" #include "utils/callbacks.h" #include "utils/convert_utils.h" -#include "backend/session/session_factory.h" +#include "utils/log_adapter.h" #include "utils/ms_utils.h" -#include "pybind_api/ir/base_ref_py.h" #ifdef ENABLE_GE #include "utils/callbacks_ge.h" #endif @@ -83,10 +83,14 @@ LinConvertResult MsBackend::MsConvert(const GraphSegmentPtr &segment, const std: MS_LOG(INFO) << "PrecompileOnly, stop run graph"; return result; } - if (target != target_device_ && !target.empty()) { - other_sess_->BuildGraph(graph_id); - } else if (!is_multi_graph_sink_) { - target_sess_->BuildGraph(graph_id); + auto ms_context = MsContext::GetInstance(); + const bool pynative_mode = (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode); + if (!pynative_mode || target != "Ascend") { + if (target != target_device_ && !target.empty()) { + other_sess_->BuildGraph(graph_id); + } else if (!is_multi_graph_sink_) { + target_sess_->BuildGraph(graph_id); + } } result.run = std::make_shared( [graph_id, target, this](const VectorRef &args) -> VectorRef { return MsRunGraph(graph_id, args, target); }); @@ -154,12 +158,19 @@ VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const s PushInputTensor(arg, &inputs); } + auto ms_context = MsContext::GetInstance(); + const bool pynative_mode = (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode); + VectorRef outputs; // call ms rungraph (graphId, input ,output) if (target != target_device_ && !target.empty()) { other_sess_->RunGraphAsync(g, inputs, &outputs); } else { - target_sess_->RunGraphAsync(g, inputs, &outputs); + if (pynative_mode && target == "Ascend") { + target_sess_->RunOpsInGraph(g, inputs, &outputs); + } else { + target_sess_->RunGraphAsync(g, inputs, &outputs); + } } MS_LOG(DEBUG) << "RunGraph finished:" << outputs.size(); diff --git a/tests/st/pynative/test_pynative_hook.py b/tests/st/pynative/test_pynative_hook.py index e5cc6240144..11ead25f903 100644 --- a/tests/st/pynative/test_pynative_hook.py +++ b/tests/st/pynative/test_pynative_hook.py @@ -134,7 +134,6 @@ class MulAdd(nn.Cell): assert dout.asnumpy() == 1.0 return dout, y - class Ms_Cell(nn.Cell): def __init__(self): super(Ms_Cell, self).__init__() @@ -143,6 +142,19 @@ class Ms_Cell(nn.Cell): def construct(self, x): return self.relu(x) + def bprop(self, x, out, dout): + dout = Tensor(np.float32(0.0)) + assert dout.shape == () + return dout + +class Ms_Cell_Change_Shape(nn.Cell): + def __init__(self): + super(Ms_Cell_Change_Shape, self).__init__() + self.relu = P.ReLU() + + def construct(self, x): + return self.relu(x) + def bprop(self, x, out, dout): dout = Tensor(np.ones([5, 5]).astype(np.float32)) assert dout.shape == (5, 5) @@ -186,6 +198,19 @@ def test_pynative_custom_bprop_and_Cell_MulAdd(): (Tensor(1.0, mstype.float32), Tensor(2.0, mstype.float32)) +@pytest.mark.level0 +@pytest.mark.platform_arm_ascend_training +@pytest.mark.platform_x86_ascend_training +@pytest.mark.env_onecard +def test_pynative_custom_bprop_and_Cell_Ms_Cell_Change_Shape(): + custom_cell = test_custom_cell_base() + ms_Cell = custom_cell.test_custom_cell_function(Ms_Cell_Change_Shape()) + ms_Cell.bprop_debug = True + with pytest.raises(RuntimeError) as ex: + grad_all(ms_Cell)(Tensor(1, mstype.float32)) + assert "Shapes of input and parameter are different, input index" in str(ex.value) + + @pytest.mark.level0 @pytest.mark.platform_arm_ascend_training @pytest.mark.platform_x86_ascend_training @@ -194,5 +219,5 @@ def test_pynative_custom_bprop_and_Cell_Ms_Cell(): custom_cell = test_custom_cell_base() ms_Cell = custom_cell.test_custom_cell_function(Ms_Cell()) ms_Cell.bprop_debug = True - assert grad_all(ms_Cell)(Tensor(1, mstype.float32)) == (Tensor(1.0, mstype.float32),) + assert grad_all(ms_Cell)(Tensor(1, mstype.float32)) == (Tensor(0.0, mstype.float32),) \ No newline at end of file diff --git a/tests/ut/cpp/pynative/pynative_execute_test.cc b/tests/ut/cpp/pynative/pynative_execute_test.cc index e14935ab988..106de71a3d0 100644 --- a/tests/ut/cpp/pynative/pynative_execute_test.cc +++ b/tests/ut/cpp/pynative/pynative_execute_test.cc @@ -65,7 +65,7 @@ OpExecInfoPtr ConstructOpExecInfo() { py::none py_none; py::args args = py::make_tuple(conv_obj, op_name, op_inputs); py::list args_input = args[PY_INPUTS]; - return GenerateOpExecInfo(args); + return PynativeExecutor::GetInstance()->GenerateOpExecInfo(args); } TEST_F(TestPynativeExecute, TestCreateContext) {