forked from mindspore-Ecosystem/mindspore
!13600 [auto-monad] Prepare to support recursive calls
From: @hwhewei Reviewed-by: @ginfung,@zh_qh Signed-off-by: @zh_qh
This commit is contained in:
commit
9359983123
|
@ -44,6 +44,13 @@ constexpr uint32_t kNoLabel = 0;
|
|||
// Primitive attribute name: marks an Assign that links a real argument to a formal argument.
constexpr char LINK[] = "link";

// Primitive attribute name: marks a node that must not be eliminated.
// It is used to keep argument Assign nodes for recursive graphs.
constexpr char KEEP[] = "keep";

// Primitive attribute name: marks an Assign that writes an output parameter.
constexpr char OUTPUT[] = "output";
|
||||
|
||||
bool IsSaveGraph() {
|
||||
auto context_ptr = MsContext::GetInstance();
|
||||
MS_EXCEPTION_IF_NULL(context_ptr);
|
||||
|
@ -152,6 +159,9 @@ struct CallSite {
|
|||
// Label param to index map.
|
||||
std::map<AnfNodePtr, uint32_t> label_indexes;
|
||||
|
||||
// True if this is a recursive call.
|
||||
bool recursive = false;
|
||||
|
||||
// True if this is a tail call.
|
||||
bool tail = false;
|
||||
};
|
||||
|
@ -161,9 +171,18 @@ struct ReturnPoint {
|
|||
};
|
||||
|
||||
struct CallInfo {
|
||||
// Call sites in current graph.
|
||||
std::vector<CallSite> call_sites;
|
||||
|
||||
// Return points of current graph.
|
||||
std::vector<ReturnPoint> return_points;
|
||||
|
||||
// Parameter to store label index, if there are
|
||||
// multi return points, this should be set.
|
||||
AnfNodePtr label_param = nullptr;
|
||||
|
||||
// True if current graph is involved with recursive calls.
|
||||
bool recursive = false;
|
||||
};
|
||||
|
||||
//
|
||||
|
@ -296,6 +315,8 @@ class CallInfoFinder {
|
|||
|
||||
// Run the analysis phases; the order of the calls below is significant.
void Run() {
  // Phase 1: FindCallSites() is defined outside this view; presumably it fills the
  // per-graph call sites that the following phases iterate over.
  FindCallSites();
  // Phase 2: search for recursive calls, starting only from call sites
  // that are not flagged as tail calls.
  FindRecursiveCalls();
  // Phase 3: reads the per-graph `recursive` flag set in phase 2 and clears
  // the `tail` flag on the last call site of each recursive graph.
  DisableTailCalls();
  // Phase 4: find call-return pairs.
  FindCallReturns();
}
|
||||
|
||||
|
@ -340,11 +361,30 @@ class CallInfoFinder {
|
|||
}
|
||||
}
|
||||
|
||||
// Find recursive non-tail calls.
|
||||
void FindRecursiveCalls() {
|
||||
for (auto &[caller, call_info] : context_.call_info_map) {
|
||||
for (auto &call_site : call_info.call_sites) {
|
||||
if (!call_site.tail) {
|
||||
SearchRecursiveCall(caller, &call_site);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Disable tail call optimization for recursive call graphs.
|
||||
void DisableTailCalls() {
|
||||
for (auto &entry : context_.call_info_map) {
|
||||
auto &call_info = entry.second;
|
||||
if (call_info.recursive && !call_info.call_sites.empty()) {
|
||||
call_info.call_sites.back().tail = false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Find call-return pairs.
|
||||
void FindCallReturns() {
|
||||
for (auto &entry : context_.call_info_map) {
|
||||
auto &caller = entry.first;
|
||||
auto &call_info = entry.second;
|
||||
for (auto &[caller, call_info] : context_.call_info_map) {
|
||||
for (auto &call_site : call_info.call_sites) {
|
||||
for (auto &callee : call_site.callees) {
|
||||
MakeGraphLabel(callee.graph);
|
||||
|
@ -396,6 +436,54 @@ class CallInfoFinder {
|
|||
}
|
||||
}
|
||||
|
||||
struct SearchRecursiveContext {
|
||||
const KernelGraphPtr &start_caller;
|
||||
CallSite *start_site;
|
||||
std::set<KernelGraphPtr> visited;
|
||||
std::vector<KernelGraphPtr> call_path;
|
||||
};
|
||||
|
||||
// Search recursive call from a call-site.
|
||||
void SearchRecursiveCall(const KernelGraphPtr &start_caller, CallSite *start_site) {
|
||||
SearchRecursiveContext context{.start_caller = start_caller, .start_site = start_site};
|
||||
DoSearchRecursiveCall(start_caller, start_site, &context);
|
||||
}
|
||||
|
||||
void DoSearchRecursiveCall(const KernelGraphPtr &graph, CallSite *call_site, SearchRecursiveContext *ctx) {
|
||||
// Record call path.
|
||||
ctx->call_path.push_back(graph);
|
||||
// Handle callee graphs.
|
||||
for (auto &callee : call_site->callees) {
|
||||
auto &sub_graph = callee.graph;
|
||||
if (sub_graph == ctx->start_caller) {
|
||||
// Find a recursive call path.
|
||||
for (auto &g : ctx->call_path) {
|
||||
// Mark recursive for all graphs in call path.
|
||||
context_.call_info_map[g].recursive = true;
|
||||
}
|
||||
// Mark recursive for the start call-site.
|
||||
ctx->start_site->recursive = true;
|
||||
continue;
|
||||
}
|
||||
if (ctx->visited.find(sub_graph) != ctx->visited.end()) {
|
||||
// Skip visited graphs.
|
||||
continue;
|
||||
}
|
||||
// Mark visited.
|
||||
ctx->visited.emplace(sub_graph);
|
||||
// Check call sites in the sub-graph.
|
||||
auto &call_info = context_.call_info_map[sub_graph];
|
||||
auto &sites = call_info.call_sites;
|
||||
for (auto &site : sites) {
|
||||
if (!site.callees.empty()) {
|
||||
DoSearchRecursiveCall(sub_graph, &site, ctx);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Don't forget this.
|
||||
ctx->call_path.pop_back();
|
||||
}
|
||||
|
||||
// Handle a call-return relation.
|
||||
void HandleCallReturn(const KernelGraphPtr &caller, CallSite *call_site, const KernelGraphPtr &callee) {
|
||||
// Create a label for the return point.
|
||||
|
@ -590,7 +678,7 @@ class AscendAutoMonadConverter {
|
|||
// For multi-return call, assign result from temp parameter to
|
||||
// output parameter; this prevents the result from being overwritten by the next call.
|
||||
auto tmp_param = context_.GetTempParameter(output->abstract());
|
||||
output = AssignAll(output, tmp_param);
|
||||
output = AssignAll(output, tmp_param, false, false, true);
|
||||
monad_ = UpdateState(GetMonad(), output);
|
||||
}
|
||||
// Replace the call/switch node with the output.
|
||||
|
@ -611,7 +699,7 @@ class AscendAutoMonadConverter {
|
|||
void AssignLabelIndexes(const CallSite &call_site) {
|
||||
for (auto &[label_param, label_index] : call_site.label_indexes) {
|
||||
auto index_value = GetIndexValueNode(label_index);
|
||||
auto assign = Assign(label_param, index_value);
|
||||
auto assign = Assign(label_param, index_value, false, false, false);
|
||||
monad_ = UpdateState(GetMonad(), assign);
|
||||
}
|
||||
}
|
||||
|
@ -708,7 +796,7 @@ class AscendAutoMonadConverter {
|
|||
AnfNodePtr out_param =
|
||||
(is_single_call ? call_site->out_param : context_.GetTempParameter(kernel_graph_->output()->abstract()));
|
||||
MS_EXCEPTION_IF_NULL(out_param);
|
||||
auto assign_output = AssignAll(out_param, kernel_graph_->output());
|
||||
auto assign_output = AssignAll(out_param, kernel_graph_->output(), false, false, true);
|
||||
monad_ = UpdateState(GetMonad(), assign_output);
|
||||
}
|
||||
|
||||
|
@ -739,6 +827,8 @@ class AscendAutoMonadConverter {
|
|||
if (args.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
// We do not eliminate argument Assign for recursive graphs.
|
||||
const bool keep = IsRecursive(graph);
|
||||
// Single argument.
|
||||
if (args.size() == 1) {
|
||||
auto &value = args.front();
|
||||
|
@ -746,7 +836,7 @@ class AscendAutoMonadConverter {
|
|||
// No assign for single monad argument, return it.
|
||||
return value;
|
||||
}
|
||||
return AssignAll(paras.front(), value, true);
|
||||
return AssignAll(paras.front(), value, true, keep, false);
|
||||
}
|
||||
// Multi arguments.
|
||||
AnfNodePtrList tuple_inputs;
|
||||
|
@ -764,11 +854,14 @@ class AscendAutoMonadConverter {
|
|||
if (target == value) {
|
||||
continue;
|
||||
}
|
||||
tuple_inputs.emplace_back(AssignAll(target, value, true));
|
||||
tuple_inputs.emplace_back(AssignAll(target, value, true, keep, false));
|
||||
}
|
||||
return kernel_graph_->NewCNode(tuple_inputs);
|
||||
}
|
||||
|
||||
// Return true if the graph is involved with recursive calls.
|
||||
bool IsRecursive(const KernelGraphPtr &kg) { return context_.call_info_map[kg].recursive; }
|
||||
|
||||
// For some cnode, attributes may set to primitive instance, so we create a new prim instance for each cnode.
|
||||
AnfNodePtr NewPrimitive(const PrimitivePtr &prim) { return NewValueNode(std::make_shared<Primitive>(prim->name())); }
|
||||
|
||||
|
@ -780,13 +873,21 @@ class AscendAutoMonadConverter {
|
|||
}
|
||||
|
||||
// Make an assign cnode.
|
||||
CNodePtr Assign(const AnfNodePtr &target, const AnfNodePtr &source, bool is_link = false) {
|
||||
auto monad = (is_link ? GetLinkMonad() : GetMonad());
|
||||
CNodePtr Assign(const AnfNodePtr &target, const AnfNodePtr &source, bool link, bool keep, bool output) {
|
||||
auto monad = (link ? GetLinkMonad() : GetMonad());
|
||||
auto assign_prim = std::make_shared<Primitive>(prim::kPrimAssign->name());
|
||||
if (is_link) {
|
||||
if (link) {
|
||||
// Mark this assign is to link real argument to formal argument.
|
||||
assign_prim->set_attr(LINK, prim::kValueOne);
|
||||
}
|
||||
if (keep) {
|
||||
// Mark this assign should not be eliminated.
|
||||
assign_prim->set_attr(KEEP, prim::kValueOne);
|
||||
}
|
||||
if (output) {
|
||||
// Mark this assign is used for output parameter.
|
||||
assign_prim->set_attr(OUTPUT, prim::kValueOne);
|
||||
}
|
||||
auto assign = NewValueNode(assign_prim);
|
||||
auto cnode = kernel_graph_->NewCNode({assign, target, source, monad});
|
||||
cnode->set_abstract(target->abstract());
|
||||
|
@ -794,10 +895,10 @@ class AscendAutoMonadConverter {
|
|||
}
|
||||
|
||||
// AssignAll supports tuple-to-tuple assignment.
|
||||
AnfNodePtr AssignAll(const AnfNodePtr &target, const AnfNodePtr &source, bool is_link = false) {
|
||||
AnfNodePtr AssignAll(const AnfNodePtr &target, const AnfNodePtr &source, bool link, bool keep, bool output) {
|
||||
if (!AnfAlgo::CheckPrimitiveType(target, prim::kPrimMakeTuple)) {
|
||||
// Assign single value.
|
||||
return Assign(target, source, is_link);
|
||||
return Assign(target, source, link, keep, output);
|
||||
}
|
||||
// Assign tuple.
|
||||
std::vector<AnfNodePtr> targets = AnfAlgo::GetAllOutput(target, {prim::kPrimTupleGetItem});
|
||||
|
@ -809,7 +910,7 @@ class AscendAutoMonadConverter {
|
|||
tuple_inputs.reserve(targets.size() + 1);
|
||||
tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
|
||||
for (size_t i = 0; i < targets.size(); ++i) {
|
||||
tuple_inputs.emplace_back(Assign(targets[i], sources[i], is_link));
|
||||
tuple_inputs.emplace_back(Assign(targets[i], sources[i], link, keep, output));
|
||||
}
|
||||
return kernel_graph_->NewCNode(tuple_inputs);
|
||||
}
|
||||
|
@ -1079,7 +1180,7 @@ class ExecuteOrderGenerator {
|
|||
auto &node = *iter;
|
||||
// We only try to erase argument link assign nodes,
|
||||
// other assign nodes are skipped.
|
||||
if (IsLinkAssign(node)) {
|
||||
if (IsOptimizableAssign(node)) {
|
||||
auto &target = node->inputs().at(kAssignTargetIndex);
|
||||
MS_EXCEPTION_IF_NULL(target);
|
||||
auto para = param_write_times.find(target);
|
||||
|
@ -1174,8 +1275,8 @@ class ExecuteOrderGenerator {
|
|||
return param_write_times;
|
||||
}
|
||||
|
||||
// Check if a node is an assign for argument link.
|
||||
bool IsLinkAssign(const AnfNodePtr &node) {
|
||||
// Check if a node is an assign for argument link and can be optimized.
|
||||
bool IsOptimizableAssign(const AnfNodePtr &node) {
|
||||
auto cnode = dyn_cast<CNode>(node);
|
||||
if (cnode == nullptr) {
|
||||
return false;
|
||||
|
@ -1184,7 +1285,7 @@ class ExecuteOrderGenerator {
|
|||
if (!IsPrimitiveEquals(prim, prim::kPrimAssign)) {
|
||||
return false;
|
||||
}
|
||||
return prim->GetAttr(LINK) == prim::kValueOne;
|
||||
return (prim->GetAttr(LINK) == prim::kValueOne) && (prim->GetAttr(KEEP) != prim::kValueOne);
|
||||
}
|
||||
|
||||
// Erase LabelGoto and LabelSet
|
||||
|
|
Loading…
Reference in New Issue