From bd6ee1445f80ec7b1048e7b5aec4924582e6565a Mon Sep 17 00:00:00 2001 From: gaoyong10 Date: Fri, 28 Jan 2022 16:07:55 +0800 Subject: [PATCH] Add monad parameter. --- .../framework/actor/data_prepare_actor.cc | 2 + .../runtime/framework/control_node_parser.cc | 110 +++++++----------- .../runtime/framework/control_node_parser.h | 1 + .../framework/control_node_scheduler.cc | 3 - 4 files changed, 48 insertions(+), 68 deletions(-) diff --git a/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc b/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc index 6d7b9677f7a..6c1ca5d0912 100644 --- a/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc +++ b/mindspore/ccsrc/runtime/framework/actor/data_prepare_actor.cc @@ -144,6 +144,8 @@ void PrepareDataForValue(const ValuePtr &value, const KernelWithIndex &node_with } else if (value->isa()) { type = kNumberTypeInt32; (reinterpret_cast(host_addr.get()))[0] = GetValue(value); + } else if (value->isa()) { + return; } else { std::string error_info = "Invalid value:" + value->ToString(); SET_OPCONTEXT_FAIL_RET_WITH_ERROR((*context), error_info); diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.cc b/mindspore/ccsrc/runtime/framework/control_node_parser.cc index 62aae58edaa..ee79365c334 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.cc +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.cc @@ -28,26 +28,12 @@ namespace { // Check if node is a value node need to create a device tensor. bool IsFrontValueNode(const KernelWithIndex &node_with_index) { const auto &node = node_with_index.first; - size_t index = node_with_index.second; MS_EXCEPTION_IF_NULL(node); if (!node->isa() || IsValueNode(node) || IsValueNode(node)) { return false; } - if (!IsValueNode(node)) { - return !HasAbstractMonad(node); - } - - const auto &abstract = node->abstract(); - MS_EXCEPTION_IF_NULL(abstract); - auto tuple_abstract = abstract->cast(); - MS_EXCEPTION_IF_NULL(tuple_abstract); - const auto &sub_abstracts = tuple_abstract->elements(); - if (sub_abstracts.size() <= index) { - MS_LOG(EXCEPTION) << "Invalid index:" << index << " for tuple value node:" << node->DebugString(); - } - MS_EXCEPTION_IF_NULL(sub_abstracts[index]); - return !sub_abstracts[index]->isa(); + return true; } // Fetch real input node in maketuple. @@ -421,11 +407,22 @@ void FetchAllExecutionFunction(const FuncGraphPtr &func_graph, std::setisa() || node->isa() || AnfAlgo::IsCallNode(node); +} + // Fetch all inputs of node. std::vector FetchInputNodeByNode(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); if (HasAbstractMonad(node)) { - return {}; + const auto &real_node_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0); + const auto &real_node = real_node_with_index.first; + MS_EXCEPTION_IF_NULL(real_node); + if (isValidMonadNode(real_node)) { + return {real_node_with_index}; + } + MS_LOG(EXCEPTION) << "Invalid monad node:" << real_node->DebugString(); } // The node is divided into the following types: @@ -436,10 +433,11 @@ std::vector FetchInputNodeByNode(const AnfNodePtr &node) { size_t real_index = node_with_index.second; MS_EXCEPTION_IF_NULL(real_node); std::vector results; - // 2. MakeTuple. - if (AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeTuple) || - AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeCSRTensor) || - AnfAlgo::CheckPrimitiveType(real_node, prim::kPrimMakeCOOTensor)) { + + // 2. Tuple node. + const PrimitiveSet expand_prims{prim::kPrimMakeTuple, prim::kPrimMakeCSRTensor, prim::kPrimMakeCOOTensor}; + // The MakeTuple/MakeSparse node need expand and recurse. + if (IsOneOfPrimitiveCNode(real_node, expand_prims)) { const auto &cnode = real_node->cast(); const auto &inputs = cnode->inputs(); for (size_t i = kMakeTupleInputStartPos; i < inputs.size(); ++i) { @@ -518,17 +516,6 @@ std::vector FetchInputNodeByNode(const AnfNodePtr &node) { for (size_t i = 0; i < output_num; ++i) { (void)results.emplace_back(real_node, i); } - if (abstract->isa()) { - auto tuple_abstract = abstract->cast(); - MS_EXCEPTION_IF_NULL(tuple_abstract); - const auto &sub_abstracts = tuple_abstract->elements(); - for (const auto &sub_abstract : sub_abstracts) { - MS_EXCEPTION_IF_NULL(sub_abstract); - if (sub_abstract->isa()) { - (void)results.pop_back(); - } - } - } return results; } @@ -692,10 +679,6 @@ std::vector FetchInputNodeByCNode(const AnfNodePtr &node) { for (size_t i = input_start_pos; i < inputs.size(); ++i) { MS_EXCEPTION_IF_NULL(inputs[i]); - // skip monad node. - if (HasAbstractMonad(inputs[i])) { - continue; - } const auto &sub_results = FetchInputNodeByNode(inputs[i]); (void)results.insert(results.end(), sub_results.begin(), sub_results.end()); } @@ -1288,9 +1271,9 @@ NodeWithContext ControlNodeParser::FetchBackendParameterWithContextByFrontParame if (AnfAlgo::GetOutputTensorMemSize(node_with_context.first, 0) != 0) { return node_with_context; } - MS_LOG(WARNING) << "Backend node:" << node_with_context.first->DebugString() - << " for front node:" << front_parameter_with_index.first->DebugString() - << " index:" << front_parameter_with_index.second << " output size is 0."; + MS_LOG(DEBUG) << "Backend node:" << node_with_context.first->DebugString() + << " for front node:" << front_parameter_with_index.first->DebugString() + << " index:" << front_parameter_with_index.second << " output size is 0."; } return {}; } @@ -1366,9 +1349,6 @@ void ControlNodeParser::ParseFormalToRealParameter(const std::vector for (int i = SizeToInt(inputs.size()) - 1, j = SizeToInt(parameters.size()) - 1; i >= 1 && j >= 0; --i, --j) { MS_EXCEPTION_IF_NULL(inputs[IntToSize(i)]); MS_EXCEPTION_IF_NULL(parameters[IntToSize(j)]); - if (HasAbstractMonad(inputs[IntToSize(i)])) { - continue; - } AddFormalToRealParameter(parameters[IntToSize(j)], inputs[IntToSize(i)], call_node_to_func_graphs_, &formal_to_real_parameters); } @@ -1400,9 +1380,6 @@ void ControlNodeParser::ParseFormalToRealParameter(const std::vector for (size_t i = kPartialInputStartPos; i < inputs.size(); ++i) { MS_EXCEPTION_IF_NULL(inputs[i]); MS_EXCEPTION_IF_NULL(parameters[i - kPartialInputStartPos]); - if (HasAbstractMonad(inputs[i])) { - continue; - } AddFormalToRealParameter(parameters[i - kPartialInputStartPos], inputs[i], call_node_to_func_graphs_, &formal_to_real_parameters); } @@ -1521,9 +1498,6 @@ void ControlNodeParser::ParseFrontToBackendParameter(const std::vectorinput_nodes()) { - if (HasAbstractMonad(parameter)) { - continue; - } const auto &front_node = graph->GetFrontAnfByBackendAnf(parameter); const auto &front_node_with_index = graph->GetFrontNodeByInternalParameter(parameter); const auto &front_tuple_parameter_with_index = graph->GetElementInTupleBackendFrontIndexMap(parameter); @@ -1798,28 +1772,33 @@ void ControlNodeParser::ParseUnRecursionCallNode() { } } +bool ControlNodeParser::IsCallNodeNeedStack(const AnfNodePtr &node) { + MS_EXCEPTION_IF_NULL(node); + + auto input_with_indexs = FetchInputNodeByCNode(node); + for (const auto &input_with_index : input_with_indexs) { + MS_EXCEPTION_IF_NULL(input_with_index.first); + // If the call node has call or recursion graph input, a stack created for the call node is required. + if (!AnfAlgo::IsCallNode(input_with_index.first)) { + if (!input_with_index.first->isa()) { + continue; + } + const auto &graph = FetchKernelGraphByFrontNode(input_with_index.first); + if (graph == nullptr || (!IsRecursionKernelGraph(graph))) { + continue; + } + } + return true; + } + return false; +} + void ControlNodeParser::ParseNeedStackControlNode(const std::vector &control_nodes) { for (const auto &control_node : control_nodes) { MS_EXCEPTION_IF_NULL(control_node); - if (!AnfAlgo::IsCallNode(control_node)) { - continue; - } - auto input_with_indexs = FetchInputNodeByCNode(control_node); - for (const auto &input_with_index : input_with_indexs) { - MS_EXCEPTION_IF_NULL(input_with_index.first); - // If the call node has call or recursion graph input, a stack created for the call node is required. - if (!AnfAlgo::IsCallNode(input_with_index.first)) { - if (!input_with_index.first->isa()) { - continue; - } - const auto &graph = FetchKernelGraphByFrontNode(input_with_index.first); - if (graph == nullptr || (!IsRecursionKernelGraph(graph))) { - continue; - } - } + if (AnfAlgo::IsCallNode(control_node) && IsCallNodeNeedStack(control_node)) { (void)need_stack_control_nodes_.emplace(control_node); MS_LOG(DEBUG) << "Add need stack control node:" << control_node->DebugString(); - break; } } @@ -1841,7 +1820,8 @@ void ControlNodeParser::ParseNeedStackControlNode(const std::vector MS_LOG(EXCEPTION) << "Invalid return node:" << control_node->DebugString(); } - if (call_input_num != 0 && (AnfAlgo::CheckPrimitiveType(inputs[kReturnInputPos], prim::kPrimDepend))) { + if ((!IsInputInSameLevel(control_node)) || + (call_input_num != 0 && (AnfAlgo::CheckPrimitiveType(inputs[kReturnInputPos], prim::kPrimDepend)))) { (void)need_stack_control_nodes_.emplace(control_node); } } else if (AnfAlgo::CheckPrimitiveType(control_node, prim::kPrimPartial) || diff --git a/mindspore/ccsrc/runtime/framework/control_node_parser.h b/mindspore/ccsrc/runtime/framework/control_node_parser.h index ee5eedca4e2..6ae42e89a98 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_parser.h +++ b/mindspore/ccsrc/runtime/framework/control_node_parser.h @@ -227,6 +227,7 @@ class ControlNodeParser { // Get the control nodes and kernel graphs which need to add a stack actor for them. // When a control node or kernel graph has input that is a call node, you need to add a stack actor for it. void ParseNeedStackControlNode(const std::vector &control_nodes); + bool IsCallNodeNeedStack(const AnfNodePtr &node); void ParseNeedStackKernelGraph(const KernelGraphToDeviceContext &kernel_graph_to_device_contexts); // Parse the level of inputs and outputs of graphs and all control nodes. void ParseNodeLevel(const std::vector &control_nodes); diff --git a/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc b/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc index ab87f3c15d8..9d4d7be2286 100644 --- a/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc +++ b/mindspore/ccsrc/runtime/framework/control_node_scheduler.cc @@ -218,9 +218,6 @@ std::vector ControlNodeScheduler::BuildEntranceActor(const Gra // The entrance actor has two parts of node members : // 1. The formal parameters of the subgraph are used to connect the actor's output arrows. for (const auto ¶meter : func_graph->parameters()) { - if (HasAbstractMonad(parameter)) { - continue; - } const auto &abstract = parameter->abstract(); MS_EXCEPTION_IF_NULL(abstract); size_t output_num = AnfAlgo::GetOutputNumByAbstract(abstract);