diff --git a/mindspore/ccsrc/backend/graph_compiler/backend.cc b/mindspore/ccsrc/backend/graph_compiler/backend.cc
index 47d9ca2bbca..3f24aa11172 100644
--- a/mindspore/ccsrc/backend/graph_compiler/backend.cc
+++ b/mindspore/ccsrc/backend/graph_compiler/backend.cc
@@ -197,9 +197,66 @@ void PushInputTensor(const BaseRef &arg, std::vector *inputs,
   }
 }
 
+// Move these functions to the anonymous namespace
+void FlatValueTupleValue(const ValuePtrList &value, ValuePtrList *flatted_value) {
+  MS_EXCEPTION_IF_NULL(flatted_value);
+  for (auto value_element : value) {
+    MS_EXCEPTION_IF_NULL(value_element);
+    if (utils::isa(value_element)) {
+      (void)flatted_value->emplace_back(value_element);
+    } else if (utils::isa(value_element)) {
+      auto value_tuple_element = value_element->cast();
+      MS_EXCEPTION_IF_NULL(value_tuple_element);
+      FlatValueTupleValue(value_tuple_element->value(), flatted_value);
+    } else {
+      MS_LOG(EXCEPTION) << "The value input to FlatValueTupleValue should only contain Tensor and ValueTuple.";
+    }
+  }
+}
+
+void FlattenValue(const BaseRef &arg, ValuePtrList *flatted_value) {
+  MS_EXCEPTION_IF_NULL(flatted_value);
+  if (utils::isa(arg)) {
+    auto value_sequence = utils::cast(arg);
+    MS_EXCEPTION_IF_NULL(value_sequence);
+    auto sequence_value = value_sequence->value();
+    for (auto &value : sequence_value) {
+      MS_EXCEPTION_IF_NULL(value);
+      if (value->isa()) {
+        (void)flatted_value->emplace_back(value);
+      } else {
+        FlattenValue(value, flatted_value);
+      }
+    }
+  } else if (utils::isa(arg)) {
+    auto value_dict = utils::cast(arg);
+    MS_EXCEPTION_IF_NULL(value_dict);
+    auto dict_value = value_dict->value();
+    for (auto &iter : dict_value) {
+      auto value = iter.second;
+      MS_EXCEPTION_IF_NULL(value);
+      if (value->isa()) {
+        (void)flatted_value->emplace_back(value);
+      } else {
+        FlattenValue(value, flatted_value);
+      }
+    }
+  } else if (utils::isa(arg)) {
+    auto csr_tensor = utils::cast(arg);
+    MS_EXCEPTION_IF_NULL(csr_tensor);
+    for (size_t i = 0; i < csr_tensor->GetTensorLength(); ++i) {
+      (void)flatted_value->emplace_back(csr_tensor->GetTensorAt(i));
+    }
+  } else {
+    MS_LOG(EXCEPTION) << "The value input to flatten should only be a sequence or dictionary, but it is "
+                      << arg.ToString();
+  }
+}
+
 // Insert the front_node related tensor in the input_tensor.
 void PushTensor(const VectorRef &args, const std::vector &parameters, const KernelWithIndex &front_node,
                 std::vector *input_tensor) {
+  MS_EXCEPTION_IF_NULL(input_tensor);
   const auto &iter = std::find(parameters.begin(), parameters.end(), front_node.first);
   if (iter == parameters.end()) {
     (void)((*input_tensor).emplace_back(nullptr));
@@ -209,6 +266,31 @@ void PushTensor(const VectorRef &args, const std::vector &parameters
   PushInputTensor(args[position], input_tensor, front_node.second);
 }
 
+void PushTupleTensor(const VectorRef &args, const std::vector &parameters, const AnfNodePtr &front_node,
+                     size_t index, std::vector *input_tensor) {
+  MS_EXCEPTION_IF_NULL(input_tensor);
+  const auto &iter = std::find(parameters.begin(), parameters.end(), front_node);
+  const size_t position = iter - parameters.begin();
+  // If the parameter is not found in the parameters of the root graph, it means that it is the input of the subgraph,
+  // and there is no need to input a tensor.
+  if (position >= args.size()) {
+    MS_LOG(INFO) << "Position out of args range, position value is " << position << " and args size is " << args.size()
+                 << ".";
+    (void)input_tensor->emplace_back(nullptr);
+    return;
+  }
+  ValuePtrList flatted_value_tuple_value;
+  FlattenValue(args[position], &flatted_value_tuple_value);
+  if (index >= flatted_value_tuple_value.size()) {
+    MS_LOG(EXCEPTION) << "Index out of flatted_value_tuple_value range, index value is " << index
+                      << " and flatted_value_tuple_value size is " << flatted_value_tuple_value.size() << ".";
+  }
+  auto input = flatted_value_tuple_value[index];
+  MS_EXCEPTION_IF_NULL(input);
+  auto tensor_input = input->cast();
+  input_tensor->push_back(tensor_input);
+}
+
 void UpdateOutputAbstract(const KernelGraphPtr &kernel_graph, OpRunInfo *op_run_info) {
   MS_EXCEPTION_IF_NULL(kernel_graph);
   MS_EXCEPTION_IF_NULL(op_run_info);
@@ -334,6 +416,38 @@ bool EnablePyNativeSyncRunning() {
   MS_EXCEPTION_IF_NULL(ms_context);
   return ms_context->get_param(MS_CTX_ENABLE_PYNATIVE_SYNCHRONIZE);
 }
+
+std::vector> GetRunGraphInputs(const GraphCompilerInfo &graph_compiler_info,
+                               const VectorRef &args) {
+  const auto &origin_parameters = graph_compiler_info.origin_parameters_order_;
+  std::vector> input_tensors;
+  for (const auto &kernel_graph : graph_compiler_info.graphs_) {
+    std::vector input_tensor;
+    MS_EXCEPTION_IF_NULL(kernel_graph);
+    for (const auto &input_node : kernel_graph->input_nodes()) {
+      auto element_pair = kernel_graph->GetElementInTupleBackendFrontIndexMap(input_node);
+      if (element_pair.first) {
+        PushTupleTensor(args, origin_parameters, element_pair.first, element_pair.second, &input_tensor);
+      } else {
+        const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(input_node);
+        PushTensor(args, origin_parameters, {front_node, 0}, &input_tensor);
+      }
+    }
+    (void)input_tensors.emplace_back(input_tensor);
+  }
+
+  // Input tensors of the control node.
+  std::vector input_tensor;
+  MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_);
+  // Get inputs of control node which come from the host actor.
+  const auto &control_node_parameters = graph_compiler_info.control_node_parser_->control_node_parameters();
+  for (const auto &parameter : control_node_parameters) {
+    PushTensor(args, origin_parameters, parameter, &input_tensor);
+  }
+  (void)input_tensors.emplace_back(input_tensor);
+
+  return input_tensors;
+}
 } // namespace
 
 VectorRef MsBackend::MsRunGraph(const GraphId &g, const VectorRef &args, const std::string &target) {
@@ -470,7 +584,7 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
   PROF_START(compile_func_graph);
   auto root_graph = WrapPrimitives(func_graph);
   MS_EXCEPTION_IF_NULL(root_graph);
-  root_graph_ = root_graph.get();
+  root_graph_ = root_graph;
 
   // Register a summary callback function, which is called in the final stages of summary.
   graph_compiler_->RegisterSummaryCallBackFunc(callbacks::SummarySaveCallback);
@@ -504,8 +618,9 @@ const ActorInfo &MindRTBackend::CompileGraphs(const FuncGraphPtr &func_graph) {
   // Construct the graph compiler info.
   auto graph_compiler_info = ConstructGraphCompilerInfo(root_graph);
   MS_EXCEPTION_IF_NULL(graph_compiler_info);
-  if (real_execution_mode_ == kGraphMode && graph_compiler_info->graphs_.size() != 0) {
+  if (real_execution_mode_ == kGraphMode && !graph_compiler_info->graphs_.empty()) {
     // Transform graph to actor DAG, and schedule the actor DAG.
+ ParseControlNodes(*graph_compiler_info); const auto &actor_set = runtime::GraphScheduler::GetInstance().Transform(*graph_compiler_info); runtime::GraphScheduler::GetInstance().Schedule(actor_set); } @@ -815,84 +930,7 @@ bool IsGraphOutputValueNodeOrParameter(const AnfNodePtr &graph_output, const Vec } } // namespace -void FlatValueTupleValue(const ValuePtrList &value, ValuePtrList *flatted_value) { - for (size_t i = 0; i < value.size(); ++i) { - auto value_element = value[i]; - MS_EXCEPTION_IF_NULL(value_element); - if (utils::isa(value_element)) { - (void)flatted_value->emplace_back(value_element); - } else if (utils::isa(value_element)) { - auto value_tuple_element = value_element->cast(); - MS_EXCEPTION_IF_NULL(value_tuple_element); - FlatValueTupleValue(value_tuple_element->value(), flatted_value); - } else { - MS_LOG(EXCEPTION) << "The value input to FlatValueTupleValue should only contains Tensor and ValueTuple."; - } - } -} - -void FlattenValue(const BaseRef &arg, ValuePtrList *flatted_value) { - if (utils::isa(arg)) { - auto value_sequence = utils::cast(arg); - MS_EXCEPTION_IF_NULL(value_sequence); - auto sequence_value = value_sequence->value(); - for (auto &value : sequence_value) { - MS_EXCEPTION_IF_NULL(value); - if (value->isa()) { - (void)flatted_value->emplace_back(value); - } else { - FlattenValue(value, flatted_value); - } - } - } else if (utils::isa(arg)) { - auto value_dict = utils::cast(arg); - MS_EXCEPTION_IF_NULL(value_dict); - auto dict_value = value_dict->value(); - for (auto &iter : dict_value) { - auto value = iter.second; - MS_EXCEPTION_IF_NULL(value); - if (value->isa()) { - (void)flatted_value->emplace_back(value); - } else { - FlattenValue(value, flatted_value); - } - } - } else if (utils::isa(arg)) { - auto csr_tensor = utils::cast(arg); - for (size_t i = 0; i < csr_tensor->GetTensorLength(); ++i) { - (void)flatted_value->emplace_back(csr_tensor->GetTensorAt(i)); - } - } else { - MS_LOG(EXCEPTION) << "The value input to flatten should only contains be sequence or dictionary, but it is " - << arg.ToString(); - } -} - -void PushTupleTensor(const VectorRef &args, const std::vector ¶meters, const AnfNodePtr &front_node, - size_t index, std::vector *input_tensor) { - const auto &iter = std::find(parameters.begin(), parameters.end(), front_node); - const size_t position = iter - parameters.begin(); - // If the parameter is not found in the parameters of the root graph, it means that it is the input of the subgraph, - // and there is no need to input a tensor. 
- if (position >= args.size()) { - MS_LOG(INFO) << "Position out of args range, position value is " << position << " and args size is " << args.size() - << "."; - (void)input_tensor->emplace_back(nullptr); - return; - } - ValuePtrList flatted_value_tuple_value; - FlattenValue(args[position], &flatted_value_tuple_value); - if (index >= flatted_value_tuple_value.size()) { - MS_LOG(EXCEPTION) << "Index out of flatted_value_tuple_value range, index value is " << index - << " and flatted_value_tuple_value size is " << flatted_value_tuple_value.size() << "."; - } - auto input = flatted_value_tuple_value[index]; - MS_EXCEPTION_IF_NULL(input); - auto tensor_input = input->cast(); - input_tensor->push_back(tensor_input); -} - -void MindRTBackend::ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, FuncGraph *root_graph) { +void MindRTBackend::ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, const FuncGraphPtr &root_graph) { bool need_contruct_output = !(distributed::recovery::RecoveryContext::GetInstance()->enable_recovery() && distributed::recovery::RecoveryContext::GetInstance()->need_reset()); if (need_contruct_output) { @@ -911,8 +949,9 @@ void MindRTBackend::ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *ou } void MindRTBackend::RunGraphByActors(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info, - const std::vector> &inputs, VectorRef *outputs) { + const VectorRef &args, VectorRef *outputs) { WaitTaskFinish(); + auto inputs = GetRunGraphInputs(graph_compiler_info, args); MS_EXCEPTION_IF_NULL(graph_compiler_); auto graphs = graph_compiler_info.graphs_; if (graphs.size() > inputs.size()) { @@ -927,15 +966,22 @@ void MindRTBackend::RunGraphByActors(const ActorInfo &actor_info, const GraphCom const auto &graph = graphs[i]; MS_EXCEPTION_IF_NULL(graph); graph->set_flag(kFlagPyNativeRunInGraph, true); - // Get real format from input tensor before compile graph - pynative::GraphAdapter::ReplaceBpropGraphParameter(graph, inputs[i]); - graph_compiler_->CompileGraphImpl(graph, graph_compiler_info.device_contexts_.front()); + + // The size of control_nodes is at least 1 since there is return node in the graph. + if (control_nodes_.size() == 1) { + MS_LOG(INFO) << "Replace parameter format"; + // Input tensor is null if there is control-flow in graph. + // Need to get tensor after ParseControlNodes. + pynative::GraphAdapter::ReplaceBpropGraphParameter(graph, inputs.at(i)); + } + graph_compiler_->CompileGraphImpl(graph, graph_compiler_info.device_contexts_.at(i)); pynative::GraphAdapter::RemoveUnusedValueNodes(graph); graph->CacheGraphOutputToFrontNodeWithIndex({graph->output()}, graph->front_outputs()); // Clear front outputs after the outputs is cached. graph->set_front_outputs({}); } + ParseControlNodes(graph_compiler_info); actor_set = runtime::GraphScheduler::GetInstance().Transform(graph_compiler_info); MS_EXCEPTION_IF_NULL(actor_set); // Multithreading can cause spikes in memory usage and performance fluctuations @@ -952,9 +998,11 @@ void MindRTBackend::RunGraphByActors(const ActorInfo &actor_info, const GraphCom pynative::GraphAdapter::UpdateForwardOutputInBpropGraph(graph); } + std::vector> input_tensors = GetRunGraphInputs(graph_compiler_info, args); + // Release GIL and run actor DAG. 
mindspore::ScopedLongRunning long_running; - runtime::GraphScheduler::GetInstance().Run(actor_set, graph_compiler_info.device_contexts_, inputs); + runtime::GraphScheduler::GetInstance().Run(actor_set, graph_compiler_info.device_contexts_, input_tensors); MS_EXCEPTION_IF_NULL(graph_compiler_); graph_compiler_->Summary(graph_compiler_info.graphs_); @@ -977,12 +1025,15 @@ void MindRTBackend::RunGraphByActors(const ActorInfo &actor_info, const GraphCom MS_LOG(INFO) << "Status record: end run actor: " << actor_info; } -void MindRTBackend::RunGraphBySingleOp(const std::vector &graphs, - const std::vector> &inputs, VectorRef *outputs) { +void MindRTBackend::RunGraphBySingleOp(const GraphCompilerInfo &graph_compiler_info, const VectorRef &args, + VectorRef *outputs) { WaitTaskFinish(); - MS_EXCEPTION_IF_NULL(graph_compiler_); auto &op_executor = runtime::OpExecutor::GetInstance(); op_executor.Register([this]() { BatchBuildCallback(); }); + + MS_EXCEPTION_IF_NULL(graph_compiler_); + const auto &graphs = graph_compiler_info.graphs_; + auto inputs = GetRunGraphInputs(graph_compiler_info, args); for (size_t graph_index = 0; graph_index < graphs.size(); ++graph_index) { const auto &graph = graphs[graph_index]; MS_EXCEPTION_IF_NULL(graph); @@ -1043,13 +1094,12 @@ void MindRTBackend::RunGraphBySingleOp(const std::vector &graphs } void MindRTBackend::RunGraphByCondition(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info, - const std::vector> &input_tensors, - VectorRef *outputs) { + const VectorRef &args, VectorRef *outputs) { bool contain_cut_graph = std::any_of(graph_compiler_info.graphs_.begin(), graph_compiler_info.graphs_.end(), [](const KernelGraphPtr &graph) { return graph->has_flag(kFlagsIsCutGraph); }); if (contain_cut_graph) { // Python API will be called in cut_graph, so we cannot release gil here. - RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs); + RunGraphBySingleOp(graph_compiler_info, args, outputs); } else { bool is_dynamic_shape = std::any_of(graph_compiler_info.graphs_.begin(), graph_compiler_info.graphs_.end(), [](const KernelGraphPtr &graph) { return graph->is_dynamic_shape(); }); @@ -1061,10 +1111,12 @@ void MindRTBackend::RunGraphByCondition(const ActorInfo &actor_info, const Graph // TODO(caifubi): Need to support condition: 1. Dynamic shape. 2. AutoParallel. MS_EXCEPTION_IF_NULL(root_graph_); - if (root_graph_->has_flag(kFlagIsDynamicStructure) || is_dynamic_shape || is_parallel) { - RunGraphBySingleOp(graph_compiler_info.graphs_, input_tensors, outputs); + if (root_graph_->has_flag(kFlagIsDynamicStructure) || + // `ms_function + dynamic_shape` is already supported, so there is no need to execute RunGraphBySingleOp. + (is_dynamic_shape && root_graph_->has_flag(kFlagIsPynativeBpropGraph)) || is_parallel) { + RunGraphBySingleOp(graph_compiler_info, args, outputs); } else { - RunGraphByActors(actor_info, graph_compiler_info, input_tensors, outputs); + RunGraphByActors(actor_info, graph_compiler_info, args, outputs); } } MS_LOG(INFO) << "Status record: end run actor: " << actor_info; @@ -1094,46 +1146,18 @@ void MindRTBackend::RunGraph(const ActorInfo &actor_info, const VectorRef &args, } MS_EXCEPTION_IF_NULL(graph_iter->second); const auto &graph_compiler_info = *(graph_iter->second); - const auto &origin_parameters = graph_compiler_info.origin_parameters_order_; - // For pynative and graph mix execution. WaitTaskFinish(); - // Transform args to input tensors. - // Input tensors of the graph. 
- std::vector> input_tensors; - for (const auto &kernel_graph : graph_compiler_info.graphs_) { - std::vector input_tensor; - MS_EXCEPTION_IF_NULL(kernel_graph); - for (const auto &input_node : kernel_graph->input_nodes()) { - auto element_pair = kernel_graph->GetElementInTupleBackendFrontIndexMap(input_node); - if (element_pair.first) { - PushTupleTensor(args, origin_parameters, element_pair.first, element_pair.second, &input_tensor); - } else { - const auto &front_node = kernel_graph->GetFrontAnfByBackendAnf(input_node); - PushTensor(args, origin_parameters, {front_node, 0}, &input_tensor); - } - } - (void)input_tensors.emplace_back(input_tensor); - } - - // Input tensors of the control node. - std::vector input_tensor; - MS_EXCEPTION_IF_NULL(graph_compiler_info.control_node_parser_); - // Get inputs of control node which come from the host actor. - const auto &control_node_parameters = graph_compiler_info.control_node_parser_->control_node_parameters(); - for (const auto ¶meter : control_node_parameters) { - PushTensor(args, origin_parameters, parameter, &input_tensor); - } - (void)input_tensors.emplace_back(input_tensor); // Run in the pynative mode. MS_EXCEPTION_IF_NULL(outputs); // There will be more than one kernel graph in heterogeneous scenario in a ms function of PyNative Mode. if (real_execution_mode_ == kPynativeMode) { - RunGraphByCondition(actor_info, graph_compiler_info, input_tensors, outputs); + RunGraphByCondition(actor_info, graph_compiler_info, args, outputs); return; } + auto input_tensors = GetRunGraphInputs(graph_compiler_info, args); // Release python gil. mindspore::ScopedLongRunning long_running; // Run actor DAG. @@ -1298,22 +1322,7 @@ std::unique_ptr MindRTBackend::ConstructGraphCompilerInfo(con ++graph_index; } - FuncGraphToKernelGraphGroup func_graph_to_kernel_graphs; - for (const auto &func_graph_to_kernel_graph_ids : func_graph_to_kernel_graph_ids_) { - const auto &func_graph = func_graph_to_kernel_graph_ids.first; - for (const auto &sub_kernel_graphs_ids : func_graph_to_kernel_graph_ids.second) { - std::vector kernel_graphs; - for (const auto &graph_id : sub_kernel_graphs_ids) { - const auto &kernel_graph = graph_compiler_->Fetch(graph_id); - MS_EXCEPTION_IF_NULL(kernel_graph); - (void)kernel_graphs.emplace_back(kernel_graph); - } - (void)func_graph_to_kernel_graphs[func_graph].emplace_back(kernel_graphs); - } - } - auto parser = std::make_shared(); - parser->Parse(control_nodes_, graphs, device_contexts, root_graph, func_graph_to_kernel_graphs); runtime::KernelMapPosition outputs_order; const auto &root_output = @@ -1635,5 +1644,25 @@ void MindRTBackend::UpdateOutput(const std::vector &ou outputs->emplace_back(output_tensor); } } + +void MindRTBackend::ParseControlNodes(const GraphCompilerInfo &graph_compile_info) { + FuncGraphToKernelGraphGroup func_graph_to_kernel_graphs; + for (const auto &func_graph_to_kernel_graph_ids : func_graph_to_kernel_graph_ids_) { + const auto &func_graph = func_graph_to_kernel_graph_ids.first; + for (const auto &sub_kernel_graphs_ids : func_graph_to_kernel_graph_ids.second) { + std::vector kernel_graphs; + for (const auto &graph_id : sub_kernel_graphs_ids) { + const auto &kernel_graph = graph_compiler_->Fetch(graph_id); + MS_EXCEPTION_IF_NULL(kernel_graph); + (void)kernel_graphs.emplace_back(kernel_graph); + } + (void)func_graph_to_kernel_graphs[func_graph].emplace_back(kernel_graphs); + } + } + + graph_compile_info.control_node_parser_->Parse(control_nodes_, graph_compile_info.graphs_, + graph_compile_info.device_contexts_, 
root_graph_, + func_graph_to_kernel_graphs); +} } // namespace compile } // namespace mindspore diff --git a/mindspore/ccsrc/backend/graph_compiler/backend.h b/mindspore/ccsrc/backend/graph_compiler/backend.h index 9b61e2d8180..137b52ab2cf 100644 --- a/mindspore/ccsrc/backend/graph_compiler/backend.h +++ b/mindspore/ccsrc/backend/graph_compiler/backend.h @@ -143,7 +143,7 @@ class BACKEND_EXPORT MindRTBackend : public Backend { // Get saved OpBuildTask in OpExecutor and build all the kernels together in PyNative mode. void CompileSingleOpGraphs(const std::vector> &build_tasks); - void ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, FuncGraph *root_graph); + void ConstructOutputs(runtime::ActorSet *actor_set, VectorRef *outputs, const FuncGraphPtr &root_graph); // Restore the outputs tuple by the origin funcGraph output node and output tensors. void ConstructOutputs(const AnfNodePtr &output_node, const std::vector &output_tensors, @@ -160,6 +160,8 @@ class BACKEND_EXPORT MindRTBackend : public Backend { const std::vector *input_tensors, bool need_erase); + void ParseControlNodes(const GraphCompilerInfo &graph_compile_info); + // In PyNative mode, the size of single op cache list will be increasing, which lead to memory cost increasing, // so the latest single op cache should be erased when cache list size exceeds threshold value. void EraseSingleOpCache(const ActorInfo &actor_info, const std::string &graph_info, const KernelGraphPtr &graph); @@ -176,14 +178,13 @@ class BACKEND_EXPORT MindRTBackend : public Backend { OpRunInfo *op_run_info); void RunGraphByCondition(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info, - const std::vector> &input_tensors, VectorRef *outputs); + const VectorRef &args, VectorRef *outputs); // Split complete kernel graph to single op graph in PyNative back // propagation, then compile and run single op graph. - void RunGraphBySingleOp(const std::vector &graphs, - const std::vector> &inputs, VectorRef *outputs); + void RunGraphBySingleOp(const GraphCompilerInfo &graph_compiler_info, const VectorRef &args, VectorRef *outputs); void RunGraphByActors(const ActorInfo &actor_info, const GraphCompilerInfo &graph_compiler_info, - const std::vector> &inputs, VectorRef *outputs); + const VectorRef &args, VectorRef *outputs); void UpdateOutput(const std::vector &output_nodes, VectorRef *const outputs); @@ -209,7 +210,7 @@ class BACKEND_EXPORT MindRTBackend : public Backend { // Cache forward op output value node tensor ref count of kernels for back propagation graph in PyNative mode. std::map forward_op_output_tensor_id_; - FuncGraph *root_graph_; + FuncGraphPtr root_graph_; GraphPartitionPtr graph_partition_; std::shared_ptr graph_compiler_; std::string device_name_; diff --git a/mindspore/ccsrc/pipeline/jit/action.cc b/mindspore/ccsrc/pipeline/jit/action.cc index db0f68d4dfd..d9a5c628113 100644 --- a/mindspore/ccsrc/pipeline/jit/action.cc +++ b/mindspore/ccsrc/pipeline/jit/action.cc @@ -955,16 +955,11 @@ void OriginSetRunMode(const ResourcePtr &resource) { } } -void SetRunMode(const ResourcePtr &resource, bool pynative_switch_to_graph_mode) { +void SetRunMode(const ResourcePtr &resource) { MS_EXCEPTION_IF_NULL(resource); auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); if (context_ptr->get_param(MS_CTX_ENABLE_MINDRT) && common::GetEnv("DISABLE_ASCEND_MINDRT") != "1") { - // Run in GRAPH_MODE if the func_graph is ms_function or the func_graph contain multi-subgraph. 
- if (pynative_switch_to_graph_mode) { - context_ptr->set_param(MS_CTX_EXECUTION_MODE, kGraphMode); - MS_LOG(INFO) << "PyNative graph Compile and Run in GRAPH_MODE"; - } SetRunMode(resource->func_graph(), resource->GetResult(kBackend).cast().get()); } else { OriginSetRunMode(resource); @@ -995,22 +990,14 @@ bool TaskEmitAction(const ResourcePtr &resource) { MS_LOG(WARNING) << "Multi device target is detected, CPU data is dumped in rank_0 directory"; } DisableMindRT(resource); - auto parallel_mode = parallel::ParallelContext::GetInstance()->parallel_mode(); - auto is_parallel = (parallel_mode == parallel::kSemiAutoParallel || parallel_mode == parallel::kAutoParallel); - bool pynative_switch_to_graph_mode = - context_ptr->get_param(MS_CTX_EXECUTION_MODE) == kPynativeMode && - (!func_graph->has_flag(kFlagIsPynativeBpropGraph) || func_graph->manager()->func_graphs().size() > 1) && - !is_parallel; - SetRunMode(resource, pynative_switch_to_graph_mode); + + SetRunMode(resource); auto bc_ptr = resource->GetResult(kBackend).cast(); MS_EXCEPTION_IF_NULL(bc_ptr); std::string backend = context_ptr->backend_policy(); // The graph compiling of mindRT. if ((backend == kMsConvert) && context_ptr->get_param(MS_CTX_ENABLE_MINDRT)) { TaskEmitActionForMindRT(resource); - if (pynative_switch_to_graph_mode) { - context_ptr->set_param(MS_CTX_EXECUTION_MODE, kPynativeMode); - } return true; } diff --git a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_context.cc b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_context.cc index 3d8f352dd65..d5e780816e1 100644 --- a/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_context.cc +++ b/mindspore/ccsrc/plugin/device/ascend/hal/hardware/ascend_device_context.cc @@ -455,7 +455,7 @@ void AscendDeviceContext::PreprocessBeforeRunGraph(const KernelGraphPtr &graph) AllocateGraphMemory(NOT_NULL(graph)); LoadModel(NOT_NULL(graph)); AssignOutputNopNodeDeviceAddress(graph); - } else if (graph->is_dynamic_shape() && IsGraphMode()) { + } else if (graph->is_dynamic_shape() && (IsGraphMode() || graph->has_flag(kFlagPyNativeRunInGraph))) { device::ascend::InsertAtomicCleanOps(graph->execution_order(), &node_atomics_); SetAtomicCleanToNodes(graph, node_atomics_); // graph mode may can do it too, instead of update execorder opt::DynamicShapeConvertPass(graph); diff --git a/mindspore/ccsrc/plugin/device/cpu/hal/hardware/cpu_device_context.cc b/mindspore/ccsrc/plugin/device/cpu/hal/hardware/cpu_device_context.cc index f0d904ce10c..9dfcb80ee25 100644 --- a/mindspore/ccsrc/plugin/device/cpu/hal/hardware/cpu_device_context.cc +++ b/mindspore/ccsrc/plugin/device/cpu/hal/hardware/cpu_device_context.cc @@ -325,7 +325,8 @@ void CPUDeviceContext::PreprocessBeforeRun(const FuncGraphPtr &graph) const { auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - if (kernel_graph->is_dynamic_shape() && ms_context->get_param(MS_CTX_EXECUTION_MODE) == kGraphMode) { + if (kernel_graph->is_dynamic_shape() && (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kGraphMode || + kernel_graph->has_flag(kFlagPyNativeRunInGraph))) { opt::DynamicShapeConvertPass(kernel_graph); } } diff --git a/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.cc b/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.cc index 3838080677a..8477c62316f 100644 --- a/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.cc +++ b/mindspore/ccsrc/plugin/device/gpu/hal/hardware/gpu_device_context.cc @@ -225,7 +225,8 @@ void 
GPUDeviceContext::PreprocessBeforeRun(const FuncGraphPtr &graph) const { if (!kernel_graph->is_from_single_op()) { auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); - if (kernel_graph->is_dynamic_shape() && ms_context->get_param(MS_CTX_EXECUTION_MODE) == kGraphMode) { + if (kernel_graph->is_dynamic_shape() && (ms_context->get_param(MS_CTX_EXECUTION_MODE) == kGraphMode || + kernel_graph->has_flag(kFlagPyNativeRunInGraph))) { opt::DynamicShapeConvertPass(kernel_graph); } } diff --git a/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc b/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc index e0b5d78997c..996565c65c0 100644 --- a/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc +++ b/mindspore/ccsrc/runtime/graph_scheduler/graph_compiler.cc @@ -562,10 +562,14 @@ GraphId GraphCompiler::CompileGraph(const GraphSegmentPtr &segment, const AnfNod session_->DumpGraphs({graph}); - // Cache the backend graph output nodes to front nodes with output index. - auto backend_node = graph->output(); - MS_EXCEPTION_IF_NULL(backend_node); - graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, outputs); + // The graph is not compiled yet in PyNative Mode. + // Need to cache output latter when the graph is compiled. + if (!run_in_pynative) { + // Cache the backend graph output nodes to front nodes with output index. + auto backend_node = graph->output(); + MS_EXCEPTION_IF_NULL(backend_node); + graph->CacheGraphOutputToFrontNodeWithIndex({backend_node}, outputs); + } auto ms_context = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(ms_context); std::string device_target = ms_context->get_param(MS_CTX_DEVICE_TARGET); diff --git a/tests/st/ops/ascend/test_tbe_ops/test_adaptive_max_pool2d.py b/tests/st/ops/ascend/test_tbe_ops/test_adaptive_max_pool2d.py index 9ce66c37b50..5d92447ebe8 100644 --- a/tests/st/ops/ascend/test_tbe_ops/test_adaptive_max_pool2d.py +++ b/tests/st/ops/ascend/test_tbe_ops/test_adaptive_max_pool2d.py @@ -44,13 +44,12 @@ def test_adaptive_max_pool2d(): x = np.random.randn(32, 64, 128, 128).astype(np.float32) context.set_context(mode=context.GRAPH_MODE, device_target="Ascend") adaptive_max_pool2d = Net(64, True) - output1, argmax1 = adaptive_max_pool2d(Tensor(x)) + output1, _ = adaptive_max_pool2d(Tensor(x)) context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") adaptive_max_pool2d = Net(64, True) - output2, argmax2 = adaptive_max_pool2d(Tensor(x)) + output2, _ = adaptive_max_pool2d(Tensor(x)) assert (output1.asnumpy() == output2.asnumpy()).all() - assert (argmax1.asnumpy() == argmax2.asnumpy()).all() @pytest.mark.level0 diff --git a/tests/st/pynative/hook/test_pynative_forward_hook.py b/tests/st/pynative/hook/test_pynative_forward_hook.py index 9253adf3dbd..6e9ade5484d 100644 --- a/tests/st/pynative/hook/test_pynative_forward_hook.py +++ b/tests/st/pynative/hook/test_pynative_forward_hook.py @@ -20,6 +20,7 @@ from mindspore import context from mindspore.ops import GradOperation from mindspore.common import ParameterTuple from mindspore.common.api import ms_function +from mindspore import ops as P def forward_pre_hook_fn_bn(cell_id, inp): @@ -49,7 +50,7 @@ def forward_pre_hook_fn_multi_add(cell_id, inp): def forward_hook_fn_conv(cell_id, inp, outp): - out = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid")(outp) + out = P.Log()(outp) return out @@ -101,6 +102,10 @@ class SingleNet(nn.Cell): class SingleNetInConstruct(nn.Cell): def __init__(self): super(SingleNetInConstruct, 
self).__init__() + self.handle1 = None + self.handle2 = None + self.handle3 = None + self.handle4 = None self.conv = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid") self.relu = nn.ReLU() @@ -262,13 +267,14 @@ class CompareMultiNet1(nn.Cell): self.conv = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid") self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init="ones") self.relu = nn.ReLU() + self.log = P.Log() def construct(self, x, y): x = self.mul(x + x, x * y) - x = self.conv(x) + x = self.log(x) x = self.bn(x) - x = self.relu(x) - x = self.mul(x, x) + y = self.relu(x) + x = self.mul(y, x) x = x + x return x @@ -280,10 +286,11 @@ class CompareMultiNet2(nn.Cell): self.conv = nn.Conv2d(2, 2, kernel_size=2, stride=1, padding=0, weight_init="ones", pad_mode="valid") self.bn = nn.BatchNorm2d(2, momentum=0.99, eps=0.00001, gamma_init="ones") self.relu = nn.ReLU() + self.log = P.Log() def construct(self, x, y): x = self.mul(x, y) - x = self.conv(x) + x = self.log(x) x = self.bn(x) x = self.mul(x, x) return x
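
Review notes with illustrative sketches follow. They are not part of the patch.

Note on the FlatValueTupleValue/FlattenValue helpers moved into the anonymous namespace of backend.cc: they perform a depth-first flattening of nested tuple/dict/CSRTensor inputs so that PushTupleTensor can index the result with a flat offset. A minimal, self-contained sketch of the same recursion, using a hypothetical Node type instead of MindSpore's Value classes:

```cpp
#include <iostream>
#include <memory>
#include <vector>

// Hypothetical stand-in for the Value hierarchy: a node is either a leaf
// ("tensor") or a tuple of child nodes.
struct Node {
  int tensor_id = -1;                           // valid when this node is a leaf
  std::vector<std::shared_ptr<Node>> children;  // non-empty when this node is a tuple
  bool IsLeaf() const { return children.empty(); }
};

// Depth-first flattening in the spirit of FlatValueTupleValue: leaves are
// appended in order, tuples are expanded recursively.
void Flatten(const std::shared_ptr<Node> &node, std::vector<int> *flat) {
  if (node == nullptr || flat == nullptr) {
    return;
  }
  if (node->IsLeaf()) {
    flat->push_back(node->tensor_id);
    return;
  }
  for (const auto &child : node->children) {
    Flatten(child, flat);
  }
}

int main() {
  // ((0, 1), 2) flattens to [0, 1, 2]; an out-of-range index into the flat
  // list is the error case PushTupleTensor reports with MS_LOG(EXCEPTION).
  auto leaf = [](int id) {
    auto n = std::make_shared<Node>();
    n->tensor_id = id;
    return n;
  };
  auto inner = std::make_shared<Node>();
  inner->children = {leaf(0), leaf(1)};
  auto root = std::make_shared<Node>();
  root->children = {inner, leaf(2)};

  std::vector<int> flat;
  Flatten(root, &flat);
  for (int id : flat) std::cout << id << " ";  // prints: 0 1 2
  std::cout << "\n";
  return 0;
}
```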
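Note on the root_graph_ member: backend.h changes it from a raw FuncGraph * (previously assigned via root_graph.get()) to a FuncGraphPtr, and ConstructOutputs now takes a const FuncGraphPtr &. A small sketch, with stand-in types rather than the real FuncGraph/MindRTBackend, of the lifetime problem that shared ownership avoids:

```cpp
#include <iostream>
#include <memory>

// Hypothetical stand-ins: only the ownership relationship matters here.
struct FuncGraphLike {
  int flags = 0;
};

struct BackendLike {
  // Holding a shared_ptr keeps the graph alive for as long as the backend
  // needs it; a raw pointer obtained from .get() would dangle once the last
  // owning shared_ptr released the graph.
  std::shared_ptr<FuncGraphLike> root_graph_;
};

int main() {
  BackendLike backend;
  {
    auto graph = std::make_shared<FuncGraphLike>();
    graph->flags = 42;
    backend.root_graph_ = graph;  // shared ownership, use_count becomes 2
  }                               // local owner goes away, graph stays alive
  std::cout << backend.root_graph_->flags << "\n";  // safe: prints 42
  return 0;
}
```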
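Note on RunGraphByActors: it keeps the existing pattern of constructing mindspore::ScopedLongRunning right before runtime::GraphScheduler::Run so the Python GIL is released while the actor DAG executes. A generic RAII sketch of that idea; the guard class here is hypothetical, not the MindSpore one:

```cpp
#include <iostream>

// Hypothetical RAII guard in the spirit of ScopedLongRunning: release a lock
// (e.g. the Python GIL) on construction and re-acquire it on destruction, so
// long-running native work does not block Python threads.
class ScopedRelease {
 public:
  ScopedRelease() { std::cout << "release\n"; }
  ~ScopedRelease() { std::cout << "re-acquire\n"; }
  ScopedRelease(const ScopedRelease &) = delete;
  ScopedRelease &operator=(const ScopedRelease &) = delete;
};

int main() {
  {
    ScopedRelease guard;             // released here
    std::cout << "run actor DAG\n";  // long-running work
  }                                  // re-acquired here, even on early return or exception
  return 0;
}
```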
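Note on RunGraphByCondition: it still chooses between the single-op path and the actor path by scanning the compiled graphs (kFlagsIsCutGraph, dynamic shape, and now dynamic shape only in combination with a PyNative bprop root graph). A stand-alone sketch of the std::any_of flag scan, with a hypothetical GraphLike type and illustrative flag names:

```cpp
#include <algorithm>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for KernelGraph: only the flag set matters for this sketch.
struct GraphLike {
  std::unordered_set<std::string> flags;
  bool has_flag(const std::string &flag) const { return flags.count(flag) > 0; }
};

int main() {
  GraphLike cut_graph;
  cut_graph.flags.insert("kFlagsIsCutGraph");
  GraphLike plain_graph;
  std::vector<GraphLike> graphs = {cut_graph, plain_graph};

  // Same shape as the checks in RunGraphByCondition: fall back to the
  // single-op path when any compiled graph carries the relevant flag.
  bool contain_cut_graph = std::any_of(graphs.begin(), graphs.end(),
                                       [](const GraphLike &g) { return g.has_flag("kFlagsIsCutGraph"); });
  bool is_dynamic_shape = std::any_of(graphs.begin(), graphs.end(),
                                      [](const GraphLike &g) { return g.has_flag("kFlagIsDynamicShape"); });

  std::cout << std::boolalpha << contain_cut_graph << " " << is_dynamic_shape << "\n";  // prints: true false
  return 0;
}
```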
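Note on ParseControlNodes: the logic moved out of ConstructGraphCompilerInfo rebuilds the FuncGraphToKernelGraphGroup mapping (front-end func graph to groups of kernel graphs resolved by graph id) before handing it to the control node parser. An illustrative sketch of that nested grouping using plain standard-library types in place of the MindSpore classes:

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Hypothetical mirror of func_graph_to_kernel_graph_ids_: each front-end
  // func graph maps to groups of kernel-graph ids produced by partitioning.
  std::map<std::string, std::vector<std::vector<int>>> func_graph_to_ids = {
      {"root", {{0, 1}}},
      {"sub", {{2}, {3, 4}}},
  };

  // Same nesting as ParseControlNodes: ids are resolved (here, simply copied)
  // into grouped kernel graphs before being passed to the control-node parser.
  std::map<std::string, std::vector<std::vector<int>>> func_graph_to_kernel_graphs;
  for (const auto &entry : func_graph_to_ids) {
    for (const auto &group : entry.second) {
      std::vector<int> kernel_graphs;
      for (int graph_id : group) {
        kernel_graphs.push_back(graph_id);  // stands in for graph_compiler_->Fetch(graph_id)
      }
      func_graph_to_kernel_graphs[entry.first].push_back(kernel_graphs);
    }
  }

  std::cout << func_graph_to_kernel_graphs["sub"].size() << "\n";  // prints 2
  return 0;
}
```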
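Note on the device-context changes: in the Ascend, CPU, and GPU PreprocessBeforeRun hooks, DynamicShapeConvertPass now also runs when a graph carries kFlagPyNativeRunInGraph, not only in graph mode. A minimal sketch of that predicate with stand-in types; the enum and flag string are illustrative, not the real MsContext API:

```cpp
#include <iostream>
#include <string>
#include <unordered_set>

// Hypothetical stand-ins for the execution mode and the per-graph flag set.
enum class ExecMode { kGraphMode, kPynativeMode };

struct GraphLike {
  bool is_dynamic_shape = false;
  std::unordered_set<std::string> flags;
  bool has_flag(const std::string &flag) const { return flags.count(flag) > 0; }
};

// Mirrors the updated condition: convert dynamic-shape graphs either in graph
// mode or when a PyNative bprop graph is executed through the graph path.
bool NeedDynamicShapeConvert(const GraphLike &graph, ExecMode mode) {
  return graph.is_dynamic_shape &&
         (mode == ExecMode::kGraphMode || graph.has_flag("kFlagPyNativeRunInGraph"));
}

int main() {
  GraphLike graph;
  graph.is_dynamic_shape = true;
  graph.flags.insert("kFlagPyNativeRunInGraph");
  std::cout << std::boolalpha << NeedDynamicShapeConvert(graph, ExecMode::kPynativeMode) << "\n";  // prints: true
  return 0;
}
```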