diff --git a/mindspore/ccsrc/backend/session/ascend_auto_monad.cc b/mindspore/ccsrc/backend/session/ascend_auto_monad.cc
index 1d2630f28b0..a96aa30aae1 100644
--- a/mindspore/ccsrc/backend/session/ascend_auto_monad.cc
+++ b/mindspore/ccsrc/backend/session/ascend_auto_monad.cc
@@ -118,62 +118,6 @@ void DumpExecuteOrder(NotNull<KernelGraphPtr> kg) {
   fout.close();
 }
 
-//
-// ParameterPool cache parameters by its abstract, so that we can reuse
-// parameter with same abstract to store return values.
-//
-class ParameterPool {
- public:
-  explicit ParameterPool(const KernelGraphPtr &top_graph) : top_graph_(top_graph) {}
-  ~ParameterPool() = default;
-
-  // Create or get a parameter from pool with the given abstract.
-  AnfNodePtr GetParameter(const abstract::AbstractBasePtr &abs) {
-    // Find parameter in pool by the given abstract.
-    auto iter = std::find_if(paras_.begin(), paras_.end(), [&abs](auto &para) {
-      auto para_abs = para->abstract();
-      // Reuse output parameter with compatible abstract.
-      return IsCompatible(abs, para_abs);
-    });
-    // Return the parameter if found.
-    if (iter != paras_.end()) {
-      return *iter;
-    }
-    // If parameter not found with the given abstract, create a new one.
-    auto para = top_graph_->NewParameter(abs);
-    auto out_para = top_graph_->TransTupleToMakeTuple(para);
-    // This is required, so that device memory can be allocated for it.
-    top_graph_->AddChildGraphResult(out_para);
-    // Save new para to pool.
-    paras_.push_back(out_para);
-    return out_para;
-  }
-
- protected:
-  // Check if one abstract is compatible with another abstract.
-  static bool IsCompatible(const abstract::AbstractBasePtr &a1, const abstract::AbstractBasePtr &a2) {
-    if (a1 == nullptr || a2 == nullptr) {
-      return false;
-    }
-    if (a1->isa<abstract::AbstractTensor>() && a2->isa<abstract::AbstractTensor>()) {
-      // This make AbstractRef compatible with AbstractTensor.
-      auto &t1 = static_cast<abstract::AbstractTensor &>(*a1);
-      auto &t2 = static_cast<abstract::AbstractTensor &>(*a2);
-      return t1 == t2;
-    }
-    return *a1 == *a2;
-  }
-
- private:
-  // The top graph.
-  const KernelGraphPtr &top_graph_;
-
-  // Cached parameters.
-  std::vector<AnfNodePtr> paras_;
-};
-
-using ParameterPoolPtr = std::shared_ptr<ParameterPool>;
-
 class BaseContext {
  public:
   void MarkVisited(const KernelGraphPtr &kg) { visited_graphs_.insert(kg); }
@@ -200,13 +144,38 @@ class AscendAutoMonadContext : public BaseContext {
   // Current label id, also the number of label ids we currently used.
   uint32_t CurrentLabel() const { return label_id_; }
 
-  // Create a new parameter pool.
-  ParameterPoolPtr NewParameterPool() { return std::make_shared<ParameterPool>(top_graph_); }
+  // Create or get a parameter for output of the kernel graph.
+  AnfNodePtr GetOutputParameter(const KernelGraphPtr &kg) {
+    // Find output parameter by kernel graph.
+    auto iter = kg_out_param_.find(kg);
+    if (iter != kg_out_param_.end()) {
+      // Return output parameter if found.
+      return iter->second;
+    }
+    // Create a new one if not found.
+    // Output parameters are all created on top graph.
+    auto para = top_graph_->NewParameter(kg->output()->abstract());
+    auto out_para = top_graph_->TransTupleToMakeTuple(para);
+    // This is required, so that device memory can be allocated for it.
+    top_graph_->AddChildGraphResult(out_para);
+    // Save new para as the output parameter of the kg.
+    kg_out_param_.emplace(kg, out_para);
+    return out_para;
+  }
+
+  // Set output parameter for a kernel graph.
+  void SetOutputParameter(const KernelGraphPtr &kg, const AnfNodePtr &out_para) {
+    // Save new para as the output parameter of the kg.
+    kg_out_param_.emplace(kg, out_para);
+  }
 
  private:
   // The top graph.
   const KernelGraphPtr &top_graph_;
 
+  // Map kernel_graph to its output parameter.
+  std::unordered_map<KernelGraphPtr, AnfNodePtr> kg_out_param_;
+
   // Current label id.
   uint32_t label_id_ = 1;
 };
@@ -254,6 +223,7 @@ class AscendAutoMonadConverter {
   // Prepare information for control flow processing.
   //
   void Prepare() {
+    recursive_ = kernel_graph_->has_flag(kFuncGraphFlagRecursive);
     AnfNodePtr last_monad = nullptr;
     auto nodes = TopoSort(kernel_graph_->output());
     for (auto &node : nodes) {
@@ -291,26 +261,25 @@ class AscendAutoMonadConverter {
     for (auto &cnode : call_switch_nodes_) {
       if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
         HandleCall(cnode);
-      } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
+      } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) ||
+                 AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
         HandleSwitch(cnode);
-      } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
-        HandleSwitchLayer(cnode);
       } else {
         MS_LOG(EXCEPTION) << "Not a call/switch/switchlayer node: " << cnode->DebugString();
       }
     }
     // If no tail call, assign output value to output parameter,
     // and then goto the return label if set.
-    if (tail_call_node_ == nullptr) {
+    if (tail_call_node_ == nullptr || recursive_) {
       if (output_parameter_) {
         auto assign_output = AssignAll(output_parameter_, kernel_graph_->output());
         monad_ = UpdateState(GetMonad(), assign_output);
       }
       if (return_label_ != kNoLabel) {
-        (void)LabelGoto(return_label_);
-      } else {
-        // Clear end goto if return label not set.
-        kernel_graph_->set_end_goto(nullptr);
+        // Insert label_goto for return.
+        auto return_goto = LabelGoto(return_label_);
+        AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_goto);
+        kernel_graph_->set_end_goto(return_goto);
       }
     }
   }
@@ -348,33 +317,37 @@ class AscendAutoMonadConverter {
     // as 'select kernel' can handle sub graphs.
     SetChildGrapAttr(goto_node, {graph});
 
-    // Setup return label if this is not a tail call.
+    // Setup return label if this is not a tail call or it is a recursive call.
     const bool is_tail_call = (cnode == tail_call_node_);
-    const bool need_return = !is_tail_call;
-    auto [para_pool, output_para, return_label] = MakeReturn(cnode, need_return);
+    const bool need_return = (!is_tail_call || recursive_);
+    if (!need_return) {
+      // Set as end_goto if no return required.
+      kernel_graph_->set_end_goto(goto_node);
+    }
+    auto [output_para, return_label] = MakeReturn(cnode, {graph}, need_return);
 
     // Handle sub-graph recursively.
-    HandleSubGraph(graph, para_pool, output_para, return_label);
+    HandleSubGraph(graph, output_para, return_label);
   }
 
   //
-  // Convert switch node:
+  // Convert switch/switchlayer node:
   //   branch1 = Partial(graph1, arg)
   //   branch2 = Partial(graph2, arg)
-  //   out = Switch(cond, branch1, branch2)
+  //   out = Switch/SwitchLayer(cond/index, branch1, branch2)
   // to:
   //   r = link_args(graph1, arg)
   //   c = UpdateState(c, r)
   //   r = link_args(graph2, arg)
   //   c = UpdateState(c, r)
-  //   c = LabelSwitch(cond, c) : L1, L2
+  //   c = LabelSwitch(cond/index, c) : L1, L2
   //   c = LabelSet(c) : <return label>
   //
   void HandleSwitch(const CNodePtr &cnode) {
     // Update last_monad_.
     last_monad_ = monad_map_[cnode];
 
-    // Get both branches of the switch, true branch first.
+    // Get branches of the switch or switchlayer, true or 0 branch first.
     auto branches = GetSwitchBranches(cnode);
 
     // Link arguments and generate labels for branches.
@@ -394,63 +367,12 @@ class AscendAutoMonadConverter {
       labels.push_back(GetOrCreateGraphLabel(graph));
     }
 
-    // Since true/false branches is reversed in kernel LabelSwitch,
-    // We reverse graphes and labels to make false branch first.
-    std::reverse(graphes.begin(), graphes.end());
-    std::reverse(labels.begin(), labels.end());
-
-    // Add LabelSwith node.
-    auto switch_node = LabelSwitch(cnode->input(1), labels);
-
-    // Set child graph attribute for switch node.
-    SetChildGrapAttr(switch_node, graphes);
-
-    // Setup return label if required.
-    const bool is_tail_call = (cnode == tail_call_node_);
-    const bool need_return = (return_label_ == kNoLabel || !is_tail_call);
-    auto [para_pool, output_para, return_label] = MakeReturn(cnode, need_return);
-
-    // Handle sub-graphs recursively.
-    for (auto &graph : graphes) {
-      HandleSubGraph(graph, para_pool, output_para, return_label);
-    }
-  }
-
-  //
-  // Convert switch node:
-  //   branch1 = Partial(graph1, arg)
-  //   branch2 = Partial(graph2, arg)
-  //   out = SwitchLayer(index, branch1, branch2)
-  // to:
-  //   r = link_args(graph1, arg)
-  //   c = UpdateState(c, r)
-  //   r = link_args(graph2, arg)
-  //   c = UpdateState(c, r)
-  //   c = LabelSwitch(index, c) : L1, L2
-  //   c = LabelSet(c) : <return label>
-  //
-  void HandleSwitchLayer(const CNodePtr &cnode) {
-    // Update last_monad_.
-    last_monad_ = monad_map_[cnode];
-
-    // Get both branches of the switch, true branch first.
-    auto branches = GetSwitchBranches(cnode);
-
-    // Link arguments and generate labels for branches.
-    std::vector<KernelGraphPtr> graphes;
-    std::vector<uint32_t> labels;
-    graphes.reserve(branches.size());
-    labels.reserve(graphes.size());
-    for (auto &[graph, args] : branches) {
-      if (graph == nullptr) {
-        MS_LOG(EXCEPTION) << "Invalid switch: " << cnode->DebugString();
-      }
-      auto linked_args = LinkArguments(args, graph);
-      if (linked_args != nullptr) {
-        monad_ = UpdateState(GetMonad(), linked_args);
-      }
-      graphes.push_back(graph);
-      labels.push_back(GetOrCreateGraphLabel(graph));
+    const bool is_switch = AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch);
+    if (is_switch) {
+      // For Switch, we reverse the graphes and labels, so that the false branch
+      // is the first one, since for kernel LabelSwitch, false is the first branch.
+      std::reverse(graphes.begin(), graphes.end());
+      std::reverse(labels.begin(), labels.end());
     }
 
     // Add LabelSwith node.
@@ -459,41 +381,42 @@ class AscendAutoMonadConverter {
     // Set child graph attribute for switch node.
     SetChildGrapAttr(switch_node, graphes);
 
+    if (!is_switch) {
+      // Mark that the switch node is for 'switch_layer'.
+      AnfAlgo::SetNodeAttr(kAttrSwitchLayer, prim::kValueOne, switch_node);
+    }
+
     // Setup return label if required.
     const bool is_tail_call = (cnode == tail_call_node_);
-    const bool need_return = (return_label_ == kNoLabel || !is_tail_call);
-    auto [para_pool, output_para, return_label] = MakeReturn(cnode, need_return);
+    const bool need_return = (return_label_ == kNoLabel || !is_tail_call || recursive_);
+    auto [output_para, return_label] = MakeReturn(cnode, graphes, need_return);
 
     // Handle sub-graphs recursively.
     for (auto &graph : graphes) {
-      HandleSubGraph(graph, para_pool, output_para, return_label);
+      HandleSubGraph(graph, output_para, return_label);
     }
   }
 
-  ParameterPoolPtr GetParameterPool(bool is_last_call) {
-    if (!is_last_call) {
-      // There are multiple calls in this graph, use a new parameter pool
-      // for each of them except the last one.
-      return context_.NewParameterPool();
+  AnfNodePtr GetOutputParameter(const CNodePtr &cnode, const std::vector<KernelGraphPtr> &branches) {
+    const bool is_tail_call = (cnode == tail_call_node_);
+    if (is_tail_call && output_parameter_ != nullptr) {
+      return output_parameter_;
     }
-    // For last call, try reuse parameter pool from the caller.
-    if (para_pool_ == nullptr) {
-      para_pool_ = context_.NewParameterPool();
-    }
-    return para_pool_;
+    return context_.GetOutputParameter(branches.front());
   }
 
   // Make return part of a call for the LabelGoto/LabelSwitch node.
-  std::tuple<ParameterPoolPtr, AnfNodePtr, uint32_t> MakeReturn(const CNodePtr &cnode, bool need_return) {
-    // Find a parameter pool for output parameter.
-    const bool is_last_call = (cnode == call_switch_nodes_.back());
-    auto para_pool = GetParameterPool(is_last_call);
-
-    // Prepare return label and output parameter.
+  std::tuple<AnfNodePtr, uint32_t> MakeReturn(const CNodePtr &cnode, const std::vector<KernelGraphPtr> &branches,
+                                              bool need_return) {
+    // Prepare return label.
     uint32_t return_label = return_label_;
-    auto output_para = para_pool->GetParameter(cnode->abstract());
+    // Prepare output parameter.
+    auto output_para = GetOutputParameter(cnode, branches);
+    // Use same output parameter for all branches.
+    for (auto &branch : branches) {
+      context_.SetOutputParameter(branch, output_para);
+    }
     auto output = output_para;
-
     // Setup return label if return is required.
     if (need_return) {
       // Set a new label at return point.
@@ -504,16 +427,14 @@ class AscendAutoMonadConverter {
       output = MakeDepend(output, label_node);
     }
-    // Replace the the switch node with the output.
+    // Replace the call/switch node with the output.
     kernel_graph_->ReplaceNode(NOT_NULL(cnode), NOT_NULL(output));
-    return {para_pool, output_para, return_label};
+    return {output_para, return_label};
   }
 
   // Handle sub-graphs recursively.
-  void HandleSubGraph(const KernelGraphPtr &graph, const ParameterPoolPtr &para_pool, const AnfNodePtr &out_para,
-                      uint32_t return_label) {
+  void HandleSubGraph(const KernelGraphPtr &graph, const AnfNodePtr &out_para, uint32_t return_label) {
     AscendAutoMonadConverter converter(&context_, graph);
-    converter.para_pool_ = para_pool;
     converter.output_parameter_ = out_para;
     converter.return_label_ = return_label;
     converter.Run();
   }
@@ -717,7 +638,6 @@ class AscendAutoMonadConverter {
     auto cnode = kernel_graph_->NewCNode({label_goto, monad});
     AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(label_id), cnode);
     cnode->set_abstract(monad->abstract());
-    kernel_graph_->set_end_goto(cnode);  // make 'goto' the last one in execute order.
     monad_ = cnode;
     return cnode;
   }
@@ -794,11 +714,11 @@ class AscendAutoMonadConverter {
   // Parameter to store the return value.
   AnfNodePtr output_parameter_;
 
-  // Parameter pool for output parameter allocation.
-  ParameterPoolPtr para_pool_;
-
   // The return label id.
   uint32_t return_label_ = kNoLabel;
+
+  // Whether this graph includes recursive calls.
+  bool recursive_ = false;
 };
 
 constexpr size_t kAssignTargetIndex = 1;
@@ -851,20 +771,22 @@ class ExecuteOrderGenerator {
 
     std::vector<CNodePtr> execution_order;
     const auto &cnodes = graph_->execution_order();
-    for (auto cnode : cnodes) {
+    for (auto &cnode : cnodes) {
      // Push current node to execution order list.
      execution_order.push_back(cnode);
      // For cnode with sub-graphs, such as LabelSwitch, LabelGoto,
      // Generate execute order for these sub-graphs,
      // and then append them to current execution order list.
       if (HasSubGraphs(cnode)) {
-        // We use reversed order to generate sub-graph's execution order,
-        // because the true branch of LabelSwitch is the second one, but
-        // we want to make true branch ahead of false branch in the generated
-        // execution order.
         auto sub_graphs = GetSubGraphs(cnode);
-        for (auto iter = sub_graphs.rbegin(); iter != sub_graphs.rend(); iter++) {
-          auto &sub_graph = *iter;
+        if (!AnfAlgo::HasNodeAttr(kAttrSwitchLayer, cnode)) {
+          // For Switch, we use reversed order to generate sub-graph's execution order,
+          // because the true branch of LabelSwitch is the second one, but
+          // we want to make true branch ahead of false branch in the generated
+          // execution order.
+          std::reverse(sub_graphs.begin(), sub_graphs.end());
+        }
+        for (auto &sub_graph : sub_graphs) {
           if (context_.IsVisited(sub_graph)) {
             // Skip visited sub-graphs.
             continue;
diff --git a/mindspore/ccsrc/utils/utils.h b/mindspore/ccsrc/utils/utils.h
index 37804d3848f..c95e24e8584 100644
--- a/mindspore/ccsrc/utils/utils.h
+++ b/mindspore/ccsrc/utils/utils.h
@@ -398,6 +398,8 @@ constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute";
 constexpr auto kAttrParallelDimInfo = "parallel_dim_info";
 constexpr auto kAttrStitch = "stitch";
 constexpr auto kAttrTopoSortRhsFirst = "topo_sort_rhs_first";
+constexpr auto kAttrSwitchLayer = "switch_layer";
+constexpr auto kAttrReturn = "return";
 
 // attr value
 constexpr auto kValueTargetSwitch = "target_switch";
diff --git a/mindspore/core/ir/func_graph.h b/mindspore/core/ir/func_graph.h
index f82618dd297..2bf50017301 100644
--- a/mindspore/core/ir/func_graph.h
+++ b/mindspore/core/ir/func_graph.h
@@ -86,6 +86,7 @@ const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
 const char kFuncGraphFlagUndetermined[] = "Undeterminate";
 const char kFuncGraphFlagBackPropEntry[] = "BackPropEntry";
 const char kFuncGraphFlagReAutoMonad[] = "ReAutoMonad";
+const char kFuncGraphFlagRecursive[] = "Recursive";
 
 namespace abstract {
 class AbstractKeywordArg;
diff --git a/tests/st/control/test_switch_layer.py b/tests/st/control/test_switch_layer.py
index c6af14343b7..e62c0584d44 100644
--- a/tests/st/control/test_switch_layer.py
+++ b/tests/st/control/test_switch_layer.py
@@ -24,11 +24,12 @@ from mindspore.common import dtype as mstype
 class CaseNet(nn.Cell):
     def __init__(self):
         super(CaseNet, self).__init__()
-        self.conv = nn.Conv2d(1, 3, 3)
+        self.conv = nn.Conv2d(1, 1, 3)
         self.relu = nn.ReLU()
+        self.relu1 = nn.ReLU()
         self.softmax = nn.Softmax()
         self.layers1 = (self.relu, self.softmax)
-        self.layers2 = (self.conv, self.relu)
+        self.layers2 = (self.conv, self.relu1)
 
     def construct(self, x, index1, index2):
         x = self.layers1[index1](x)
@@ -50,7 +51,3 @@ def test_switch_layer():
     true_value = relu(data)
     ret = np.allclose(value.asnumpy(), true_value.asnumpy())
     assert ret
-
-    idx3 = Tensor(3, mstype.int32)
-    with pytest.raises(IndexError):
-        value = net(data, idx3, idx2)
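Reviewer note, not part of the patch: a minimal sketch, assuming the standard MindSpore Python API, of the switch_layer pattern that test_switch_layer above exercises. TinyCaseNet is a hypothetical name that mirrors CaseNet; indexing a tuple of cells with a Tensor index inside construct() is what produces the SwitchLayer node that HandleSwitch now lowers to a LabelSwitch tagged with kAttrSwitchLayer.

# Minimal sketch (assumption: mirrors the test above, not code from this patch).
import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor
from mindspore.common import dtype as mstype


class TinyCaseNet(nn.Cell):  # hypothetical name, mirrors CaseNet in the test
    def __init__(self):
        super(TinyCaseNet, self).__init__()
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax()
        self.layers = (self.relu, self.softmax)

    def construct(self, x, index):
        # Indexing a tuple of cells with a Tensor index becomes a SwitchLayer
        # node in the compiled graph, which the backend lowers to LabelSwitch.
        return self.layers[index](x)


if __name__ == "__main__":
    context.set_context(mode=context.GRAPH_MODE)
    net = TinyCaseNet()
    data = Tensor(np.ones((1, 1, 4, 4)), mstype.float32)
    out = net(data, Tensor(0, mstype.int32))  # index 0 selects self.relu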