From dc8ec9d87fc499a3b5ace00aaa48e6c56ea9587c Mon Sep 17 00:00:00 2001
From: He Wei
Date: Fri, 19 Mar 2021 11:06:31 +0800
Subject: [PATCH] [auto-monad] Prepare to support recursive calls

1. Find and mark recursive calls and graphs;
2. Do not eliminate argument Assign for recursive graphs;
3. Disable tail call optimization for recursive graphs;
4. Add attribute to output parameter Assign to disable stack push.
---
 .../backend/session/ascend_auto_monad.cc      | 137 +++++++++++++++---
 1 file changed, 119 insertions(+), 18 deletions(-)

diff --git a/mindspore/ccsrc/backend/session/ascend_auto_monad.cc b/mindspore/ccsrc/backend/session/ascend_auto_monad.cc
index e09c3d7bda8..f46e2cbe077 100644
--- a/mindspore/ccsrc/backend/session/ascend_auto_monad.cc
+++ b/mindspore/ccsrc/backend/session/ascend_auto_monad.cc
@@ -44,6 +44,13 @@ constexpr uint32_t kNoLabel = 0;
 // Primitive attribute for argument link assign.
 const char LINK[] = "link";
 
+// Attribute to indicate that the node should not be eliminated.
+// Used to keep argument Assign nodes for recursive graphs.
+const char KEEP[] = "keep";
+
+// Attribute to indicate that this is an assign for an output parameter.
+const char OUTPUT[] = "output";
+
 bool IsSaveGraph() {
   auto context_ptr = MsContext::GetInstance();
   MS_EXCEPTION_IF_NULL(context_ptr);
@@ -152,6 +159,9 @@ struct CallSite {
   // Label param to index map.
   std::map<AnfNodePtr, uint32_t> label_indexes;
 
+  // True if this is a recursive call.
+  bool recursive = false;
+
   // True if this is a tail call.
   bool tail = false;
 };
@@ -161,9 +171,18 @@ struct ReturnPoint {
 };
 
 struct CallInfo {
+  // Call sites in the current graph.
   std::vector<CallSite> call_sites;
+
+  // Return points of the current graph.
   std::vector<ReturnPoint> return_points;
+
+  // Parameter to store the label index; if there are
+  // multiple return points, this should be set.
   AnfNodePtr label_param = nullptr;
+
+  // True if the current graph is involved in recursive calls.
+  bool recursive = false;
 };
 
 //
@@ -296,6 +315,8 @@ class CallInfoFinder {
   void Run() {
     FindCallSites();
+    FindRecursiveCalls();
+    DisableTailCalls();
     FindCallReturns();
   }
 
@@ -340,11 +361,30 @@ class CallInfoFinder {
     }
   }
 
+  // Find recursive non-tail calls.
+  void FindRecursiveCalls() {
+    for (auto &[caller, call_info] : context_.call_info_map) {
+      for (auto &call_site : call_info.call_sites) {
+        if (!call_site.tail) {
+          SearchRecursiveCall(caller, &call_site);
+        }
+      }
+    }
+  }
+
+  // Disable tail call optimization for recursive call graphs.
+  void DisableTailCalls() {
+    for (auto &entry : context_.call_info_map) {
+      auto &call_info = entry.second;
+      if (call_info.recursive && !call_info.call_sites.empty()) {
+        call_info.call_sites.back().tail = false;
+      }
    }
+  }
+
   // Find call-return pairs.
   void FindCallReturns() {
-    for (auto &entry : context_.call_info_map) {
-      auto &caller = entry.first;
-      auto &call_info = entry.second;
+    for (auto &[caller, call_info] : context_.call_info_map) {
       for (auto &call_site : call_info.call_sites) {
         for (auto &callee : call_site.callees) {
           MakeGraphLabel(callee.graph);
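
FindRecursiveCalls above walks every non-tail call site; the search itself is added in the next hunk. As a reading aid, here is the same cycle-detection idea as a minimal, self-contained sketch; Graph, CallSite, and CallInfo below are simplified stand-ins for the kernel-graph structures, not the real MindSpore classes.

    #include <map>
    #include <set>
    #include <vector>

    // Simplified stand-ins (illustration only, not the MindSpore types).
    struct Graph;
    struct CallSite {
      std::vector<Graph *> callees;
      bool recursive = false;
    };
    struct CallInfo {
      std::vector<CallSite> call_sites;
      bool recursive = false;
    };
    std::map<Graph *, CallInfo> call_info_map;

    // Depth-first search: if some chain of calls starting at 'start' reaches
    // 'start' again, every graph on that chain belongs to a recursion.
    void Search(Graph *start, CallSite *start_site, Graph *current, CallSite *site,
                std::set<Graph *> *visited, std::vector<Graph *> *path) {
      path->push_back(current);  // Record the call path.
      for (Graph *callee : site->callees) {
        if (callee == start) {
          // A call chain leads back to the start graph: mark the whole cycle.
          for (Graph *g : *path) {
            call_info_map[g].recursive = true;
          }
          start_site->recursive = true;
          continue;
        }
        if (!visited->insert(callee).second) {
          continue;  // Already explored this graph.
        }
        for (CallSite &s : call_info_map[callee].call_sites) {
          if (!s.callees.empty()) {
            Search(start, start_site, callee, &s, visited, path);
          }
        }
      }
      path->pop_back();  // Unwind before returning to the caller frame.
    }

FindRecursiveCalls corresponds to running this search once per non-tail call site; DisableTailCalls then forces the last call site of every graph marked recursive to be treated as non-tail.
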
@@ -396,6 +436,54 @@ class CallInfoFinder {
     }
   }
 
+  struct SearchRecursiveContext {
+    const KernelGraphPtr &start_caller;
+    CallSite *start_site;
+    std::set<KernelGraphPtr> visited;
+    std::vector<KernelGraphPtr> call_path;
+  };
+
+  // Search for recursive calls starting from a call-site.
+  void SearchRecursiveCall(const KernelGraphPtr &start_caller, CallSite *start_site) {
+    SearchRecursiveContext context{.start_caller = start_caller, .start_site = start_site};
+    DoSearchRecursiveCall(start_caller, start_site, &context);
+  }
+
+  void DoSearchRecursiveCall(const KernelGraphPtr &graph, CallSite *call_site, SearchRecursiveContext *ctx) {
+    // Record the call path.
+    ctx->call_path.push_back(graph);
+    // Handle callee graphs.
+    for (auto &callee : call_site->callees) {
+      auto &sub_graph = callee.graph;
+      if (sub_graph == ctx->start_caller) {
+        // Found a recursive call path.
+        for (auto &g : ctx->call_path) {
+          // Mark all graphs in the call path as recursive.
+          context_.call_info_map[g].recursive = true;
+        }
+        // Mark the start call-site as recursive.
+        ctx->start_site->recursive = true;
+        continue;
+      }
+      if (ctx->visited.find(sub_graph) != ctx->visited.end()) {
+        // Skip visited graphs.
+        continue;
+      }
+      // Mark visited.
+      ctx->visited.emplace(sub_graph);
+      // Check call sites in the sub-graph.
+      auto &call_info = context_.call_info_map[sub_graph];
+      auto &sites = call_info.call_sites;
+      for (auto &site : sites) {
+        if (!site.callees.empty()) {
+          DoSearchRecursiveCall(sub_graph, &site, ctx);
+        }
+      }
+    }
+    // Pop the current graph off the call path before returning.
+    ctx->call_path.pop_back();
+  }
+
   // Handle a call-return relation.
   void HandleCallReturn(const KernelGraphPtr &caller, CallSite *call_site, const KernelGraphPtr &callee) {
     // Create a label for the return point.
@@ -590,7 +678,7 @@ class AscendAutoMonadConverter {
       // For a multi-return call, assign the result from the temp parameter to the
      // output parameter; this prevents the result from being overwritten by the next call.
       auto tmp_param = context_.GetTempParameter(output->abstract());
-      output = AssignAll(output, tmp_param);
+      output = AssignAll(output, tmp_param, false, false, true);
       monad_ = UpdateState(GetMonad(), output);
     }
     // Replace the call/switch node with the output.
@@ -611,7 +699,7 @@ class AscendAutoMonadConverter {
   void AssignLabelIndexes(const CallSite &call_site) {
     for (auto &[label_param, label_index] : call_site.label_indexes) {
       auto index_value = GetIndexValueNode(label_index);
-      auto assign = Assign(label_param, index_value);
+      auto assign = Assign(label_param, index_value, false, false, false);
       monad_ = UpdateState(GetMonad(), assign);
     }
   }
@@ -708,7 +796,7 @@ class AscendAutoMonadConverter {
     AnfNodePtr out_param =
       (is_single_call ? call_site->out_param : context_.GetTempParameter(kernel_graph_->output()->abstract()));
     MS_EXCEPTION_IF_NULL(out_param);
-    auto assign_output = AssignAll(out_param, kernel_graph_->output());
+    auto assign_output = AssignAll(out_param, kernel_graph_->output(), false, false, true);
     monad_ = UpdateState(GetMonad(), assign_output);
   }
 
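
Both AssignAll changes above pass output = true. The hazard motivating the temp-parameter copy is the one stated in the comment: a callee writes its result to a single out-parameter, so a second call through the same graph would clobber the first result unless it is copied out immediately. Reduced to a few lines (a hypothetical model with one shared out-parameter per callee, not the real node machinery):

    #include <cassert>

    // Hypothetical model of a callee graph with one shared out-parameter.
    struct Callee {
      int out_param = 0;
    };

    int main() {
      Callee callee;
      callee.out_param = 4;               // First call writes its result.
      const int temp = callee.out_param;  // The added Assign: copy to a temp parameter.
      callee.out_param = 9;               // The next call reuses the same out-parameter.
      assert(temp == 4);                  // The first result survives only in the temp.
      return 0;
    }

The next hunks compute keep = IsRecursive(graph) once per graph and thread it through AssignAll into Assign, which is where the attributes are actually attached.
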
@@ -739,6 +827,8 @@ class AscendAutoMonadConverter {
     if (args.empty()) {
       return nullptr;
     }
+    // We do not eliminate argument Assign for recursive graphs.
+    const bool keep = IsRecursive(graph);
     // Single argument.
     if (args.size() == 1) {
       auto &value = args.front();
@@ -746,7 +836,7 @@ class AscendAutoMonadConverter {
         // No assign for single monad argument, return it.
         return value;
       }
-      return AssignAll(paras.front(), value, true);
+      return AssignAll(paras.front(), value, true, keep, false);
     }
     // Multi arguments.
     AnfNodePtrList tuple_inputs;
@@ -764,11 +854,14 @@ class AscendAutoMonadConverter {
       if (target == value) {
         continue;
       }
-      tuple_inputs.emplace_back(AssignAll(target, value, true));
+      tuple_inputs.emplace_back(AssignAll(target, value, true, keep, false));
     }
     return kernel_graph_->NewCNode(tuple_inputs);
   }
 
+  // Return true if the graph is involved in recursive calls.
+  bool IsRecursive(const KernelGraphPtr &kg) { return context_.call_info_map[kg].recursive; }
+
   // For some cnodes, attributes may be set on the primitive instance, so we create a new primitive instance for each cnode.
   AnfNodePtr NewPrimitive(const PrimitivePtr &prim) { return NewValueNode(std::make_shared<Primitive>(prim->name())); }
 
@@ -780,13 +873,21 @@ class AscendAutoMonadConverter {
   // Make an assign cnode.
-  CNodePtr Assign(const AnfNodePtr &target, const AnfNodePtr &source, bool is_link = false) {
-    auto monad = (is_link ? GetLinkMonad() : GetMonad());
+  CNodePtr Assign(const AnfNodePtr &target, const AnfNodePtr &source, bool link, bool keep, bool output) {
+    auto monad = (link ? GetLinkMonad() : GetMonad());
     auto assign_prim = std::make_shared<Primitive>(prim::kPrimAssign->name());
-    if (is_link) {
+    if (link) {
       // Mark that this assign links a real argument to a formal argument.
       assign_prim->set_attr(LINK, prim::kValueOne);
     }
+    if (keep) {
+      // Mark that this assign should not be eliminated.
+      assign_prim->set_attr(KEEP, prim::kValueOne);
+    }
+    if (output) {
+      // Mark that this assign is used for an output parameter.
+      assign_prim->set_attr(OUTPUT, prim::kValueOne);
+    }
     auto assign = NewValueNode(assign_prim);
     auto cnode = kernel_graph_->NewCNode({assign, target, source, monad});
     cnode->set_abstract(target->abstract());
     return cnode;
   }
@@ -794,10 +895,10 @@ class AscendAutoMonadConverter {
   // AssignAll supports tuple-to-tuple assign.
-  AnfNodePtr AssignAll(const AnfNodePtr &target, const AnfNodePtr &source, bool is_link = false) {
+  AnfNodePtr AssignAll(const AnfNodePtr &target, const AnfNodePtr &source, bool link, bool keep, bool output) {
     if (!AnfAlgo::CheckPrimitiveType(target, prim::kPrimMakeTuple)) {
       // Assign single value.
-      return Assign(target, source, is_link);
+      return Assign(target, source, link, keep, output);
     }
     // Assign tuple.
     std::vector<AnfNodePtr> targets = AnfAlgo::GetAllOutput(target, {prim::kPrimTupleGetItem});
     std::vector<AnfNodePtr> sources = AnfAlgo::GetAllOutput(source, {prim::kPrimTupleGetItem});
@@ -809,7 +910,7 @@ class AscendAutoMonadConverter {
     tuple_inputs.reserve(targets.size() + 1);
     tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
     for (size_t i = 0; i < targets.size(); ++i) {
-      tuple_inputs.emplace_back(Assign(targets[i], sources[i], is_link));
+      tuple_inputs.emplace_back(Assign(targets[i], sources[i], link, keep, output));
     }
     return kernel_graph_->NewCNode(tuple_inputs);
   }
@@ -1079,7 +1180,7 @@ class ExecuteOrderGenerator {
       auto &node = *iter;
       // We only try to erase argument link assign nodes;
       // other assign nodes are skipped.
-      if (IsLinkAssign(node)) {
+      if (IsOptimizableAssign(node)) {
         auto &target = node->inputs().at(kAssignTargetIndex);
         MS_EXCEPTION_IF_NULL(target);
         auto para = param_write_times.find(target);
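
The erase loop above pairs the attribute check with per-parameter write counts (param_write_times). The erase branch itself is elided in this excerpt; the sketch below condenses the decision under the assumption that only singly-written parameters are folded, using hypothetical Param and AssignNode stand-ins rather than the real node types.

    #include <map>

    // Hypothetical minimal model of the elimination decision.
    struct Param;
    struct AssignNode {
      Param *target;
      bool link = false;  // LINK attribute present.
      bool keep = false;  // KEEP attribute present.
    };

    // A link assign is erasable only if it is not protected by KEEP and
    // (assumed here) its target parameter has exactly one writer in the
    // execution order.
    bool CanErase(const AssignNode &node, const std::map<Param *, int> &param_write_times) {
      if (!node.link || node.keep) {
        return false;  // Not a link assign, or pinned for a recursive graph.
      }
      auto iter = param_write_times.find(node.target);
      return iter != param_write_times.end() && iter->second == 1;
    }
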
@@ -1174,8 +1275,8 @@ class ExecuteOrderGenerator {
     return param_write_times;
   }
-  // Check if a node is an assign for argument link.
-  bool IsLinkAssign(const AnfNodePtr &node) {
+  // Check if a node is an assign for argument link and can be optimized.
+  bool IsOptimizableAssign(const AnfNodePtr &node) {
     auto cnode = dyn_cast<CNode>(node);
     if (cnode == nullptr) {
       return false;
     }
@@ -1184,7 +1285,7 @@ class ExecuteOrderGenerator {
     if (!IsPrimitiveEquals(prim, prim::kPrimAssign)) {
       return false;
     }
-    return prim->GetAttr(LINK) == prim::kValueOne;
+    return (prim->GetAttr(LINK) == prim::kValueOne) && (prim->GetAttr(KEEP) != prim::kValueOne);
   }
 
   // Erase LabelGoto and LabelSet
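
Finally, the converter/eliminator round trip can be checked in miniature: an argument-link assign created for a recursive graph carries KEEP and must survive the IsOptimizableAssign filter, while an ordinary link assign stays a folding candidate. A self-contained sketch, with Prim as a hypothetical stand-in for the real primitive-attribute API:

    #include <cassert>
    #include <map>
    #include <string>

    // Hypothetical stand-in for a primitive carrying boolean attributes.
    struct Prim {
      std::map<std::string, bool> attrs;
      void SetAttr(const std::string &name) { attrs[name] = true; }
      bool HasAttr(const std::string &name) const { return attrs.count(name) != 0; }
    };

    // Mirrors how Assign() tags the generated node.
    void TagAssign(Prim *p, bool link, bool keep, bool output) {
      if (link) p->SetAttr("link");
      if (keep) p->SetAttr("keep");
      if (output) p->SetAttr("output");
    }

    // Mirrors IsOptimizableAssign: a link assign is a candidate for
    // elimination only when it is not protected by the keep attribute.
    bool IsOptimizable(const Prim &p) { return p.HasAttr("link") && !p.HasAttr("keep"); }

    int main() {
      Prim plain, recursive_graph;
      TagAssign(&plain, /*link=*/true, /*keep=*/false, /*output=*/false);
      TagAssign(&recursive_graph, /*link=*/true, /*keep=*/true, /*output=*/false);
      assert(IsOptimizable(plain));             // Still a folding candidate.
      assert(!IsOptimizable(recursive_graph));  // KEEP pins it in place.
      return 0;
    }

The OUTPUT attribute set alongside these is consumed outside this patch; per the commit message it lets a later stage disable the stack push for output-parameter assigns.
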