!13600 [auto-monad] Prepare to support recursive calls

From: @hwhewei
Reviewed-by: @ginfung,@zh_qh
Signed-off-by: @zh_qh
This commit is contained in:
mindspore-ci-bot 2021-03-22 15:06:27 +08:00 committed by Gitee
commit 9359983123
1 changed files with 119 additions and 18 deletions

View File

@ -44,6 +44,13 @@ constexpr uint32_t kNoLabel = 0;
// Primitive attribute for argument link assign.
const char LINK[] = "link";
// Attribute to indicate that the node should not be eliminated.
// Used to keep argument Assign nodes for recursive graphs.
const char KEEP[] = "keep";
// Attribute to indicate that this is an assign for output.
const char OUTPUT[] = "output";
bool IsSaveGraph() {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
@ -152,6 +159,9 @@ struct CallSite {
// Label param to index map.
std::map<AnfNodePtr, uint32_t> label_indexes;
// True if this is a recursive call.
bool recursive = false;
// True if this is a tail call.
bool tail = false;
};
@ -161,9 +171,18 @@ struct ReturnPoint {
};
struct CallInfo {
// Call sites in current graph.
std::vector<CallSite> call_sites;
// Return points of current graph.
std::vector<ReturnPoint> return_points;
// Parameter to store label index, if there are
// multi return points, this should be set.
AnfNodePtr label_param = nullptr;
// True if current graph is involved with recursive calls.
bool recursive = false;
};
//
@ -296,6 +315,8 @@ class CallInfoFinder {
void Run() {
FindCallSites();
FindRecursiveCalls();
DisableTailCalls();
FindCallReturns();
}
@ -340,11 +361,30 @@ class CallInfoFinder {
}
}
// Find recursive non-tail calls.
void FindRecursiveCalls() {
for (auto &[caller, call_info] : context_.call_info_map) {
for (auto &call_site : call_info.call_sites) {
if (!call_site.tail) {
SearchRecursiveCall(caller, &call_site);
}
}
}
}
// Disable tail call optimization for recursive call graphs.
void DisableTailCalls() {
for (auto &entry : context_.call_info_map) {
auto &call_info = entry.second;
if (call_info.recursive && !call_info.call_sites.empty()) {
call_info.call_sites.back().tail = false;
}
}
}
// Find call-return pairs.
void FindCallReturns() {
for (auto &entry : context_.call_info_map) {
auto &caller = entry.first;
auto &call_info = entry.second;
for (auto &[caller, call_info] : context_.call_info_map) {
for (auto &call_site : call_info.call_sites) {
for (auto &callee : call_site.callees) {
MakeGraphLabel(callee.graph);
@ -396,6 +436,54 @@ class CallInfoFinder {
}
}
struct SearchRecursiveContext {
const KernelGraphPtr &start_caller;
CallSite *start_site;
std::set<KernelGraphPtr> visited;
std::vector<KernelGraphPtr> call_path;
};
// Search recursive call from a call-site.
void SearchRecursiveCall(const KernelGraphPtr &start_caller, CallSite *start_site) {
SearchRecursiveContext context{.start_caller = start_caller, .start_site = start_site};
DoSearchRecursiveCall(start_caller, start_site, &context);
}
void DoSearchRecursiveCall(const KernelGraphPtr &graph, CallSite *call_site, SearchRecursiveContext *ctx) {
// Record call path.
ctx->call_path.push_back(graph);
// Handle callee graphs.
for (auto &callee : call_site->callees) {
auto &sub_graph = callee.graph;
if (sub_graph == ctx->start_caller) {
// Find a recursive call path.
for (auto &g : ctx->call_path) {
// Mark recursive for all graphs in call path.
context_.call_info_map[g].recursive = true;
}
// Mark recursive for the start call-site.
ctx->start_site->recursive = true;
continue;
}
if (ctx->visited.find(sub_graph) != ctx->visited.end()) {
// Skip visited graphs.
continue;
}
// Mark visited.
ctx->visited.emplace(sub_graph);
// Check call sites in the sub-graph.
auto &call_info = context_.call_info_map[sub_graph];
auto &sites = call_info.call_sites;
for (auto &site : sites) {
if (!site.callees.empty()) {
DoSearchRecursiveCall(sub_graph, &site, ctx);
}
}
}
// Don't forget this.
ctx->call_path.pop_back();
}
// Handle a call-return relation.
void HandleCallReturn(const KernelGraphPtr &caller, CallSite *call_site, const KernelGraphPtr &callee) {
// Create a label for the return point.
@ -590,7 +678,7 @@ class AscendAutoMonadConverter {
// For multi-return call, assign result from temp parameter to
// output parameter, this prevent result be overwritten by next call.
auto tmp_param = context_.GetTempParameter(output->abstract());
output = AssignAll(output, tmp_param);
output = AssignAll(output, tmp_param, false, false, true);
monad_ = UpdateState(GetMonad(), output);
}
// Replace the the call/switch node with the output.
@ -611,7 +699,7 @@ class AscendAutoMonadConverter {
void AssignLabelIndexes(const CallSite &call_site) {
for (auto &[label_param, label_index] : call_site.label_indexes) {
auto index_value = GetIndexValueNode(label_index);
auto assign = Assign(label_param, index_value);
auto assign = Assign(label_param, index_value, false, false, false);
monad_ = UpdateState(GetMonad(), assign);
}
}
@ -708,7 +796,7 @@ class AscendAutoMonadConverter {
AnfNodePtr out_param =
(is_single_call ? call_site->out_param : context_.GetTempParameter(kernel_graph_->output()->abstract()));
MS_EXCEPTION_IF_NULL(out_param);
auto assign_output = AssignAll(out_param, kernel_graph_->output());
auto assign_output = AssignAll(out_param, kernel_graph_->output(), false, false, true);
monad_ = UpdateState(GetMonad(), assign_output);
}
@ -739,6 +827,8 @@ class AscendAutoMonadConverter {
if (args.empty()) {
return nullptr;
}
// We do not eliminate argument Assign for recursive graphs.
const bool keep = IsRecursive(graph);
// Single argument.
if (args.size() == 1) {
auto &value = args.front();
@ -746,7 +836,7 @@ class AscendAutoMonadConverter {
// No assign for single monad argument, return it.
return value;
}
return AssignAll(paras.front(), value, true);
return AssignAll(paras.front(), value, true, keep, false);
}
// Multi arguments.
AnfNodePtrList tuple_inputs;
@ -764,11 +854,14 @@ class AscendAutoMonadConverter {
if (target == value) {
continue;
}
tuple_inputs.emplace_back(AssignAll(target, value, true));
tuple_inputs.emplace_back(AssignAll(target, value, true, keep, false));
}
return kernel_graph_->NewCNode(tuple_inputs);
}
// Return true if the graph is involved with recursive calls.
bool IsRecursive(const KernelGraphPtr &kg) { return context_.call_info_map[kg].recursive; }
// For some cnode, attributes may set to primitive instance, so we create a new prim instance for each cnode.
AnfNodePtr NewPrimitive(const PrimitivePtr &prim) { return NewValueNode(std::make_shared<Primitive>(prim->name())); }
@ -780,13 +873,21 @@ class AscendAutoMonadConverter {
}
// Make a assign cnode.
CNodePtr Assign(const AnfNodePtr &target, const AnfNodePtr &source, bool is_link = false) {
auto monad = (is_link ? GetLinkMonad() : GetMonad());
CNodePtr Assign(const AnfNodePtr &target, const AnfNodePtr &source, bool link, bool keep, bool output) {
auto monad = (link ? GetLinkMonad() : GetMonad());
auto assign_prim = std::make_shared<Primitive>(prim::kPrimAssign->name());
if (is_link) {
if (link) {
// Mark this assign is to link real argument to formal argument.
assign_prim->set_attr(LINK, prim::kValueOne);
}
if (keep) {
// Mark this assign should not be eliminated.
assign_prim->set_attr(KEEP, prim::kValueOne);
}
if (output) {
// Mark this assign is used for output parameter.
assign_prim->set_attr(OUTPUT, prim::kValueOne);
}
auto assign = NewValueNode(assign_prim);
auto cnode = kernel_graph_->NewCNode({assign, target, source, monad});
cnode->set_abstract(target->abstract());
@ -794,10 +895,10 @@ class AscendAutoMonadConverter {
}
// AissgnAll support tuple to tuple assign.
AnfNodePtr AssignAll(const AnfNodePtr &target, const AnfNodePtr &source, bool is_link = false) {
AnfNodePtr AssignAll(const AnfNodePtr &target, const AnfNodePtr &source, bool link, bool keep, bool output) {
if (!AnfAlgo::CheckPrimitiveType(target, prim::kPrimMakeTuple)) {
// Assign single value.
return Assign(target, source, is_link);
return Assign(target, source, link, keep, output);
}
// Assign tuple.
std::vector<AnfNodePtr> targets = AnfAlgo::GetAllOutput(target, {prim::kPrimTupleGetItem});
@ -809,7 +910,7 @@ class AscendAutoMonadConverter {
tuple_inputs.reserve(targets.size() + 1);
tuple_inputs.emplace_back(NewValueNode(prim::kPrimMakeTuple));
for (size_t i = 0; i < targets.size(); ++i) {
tuple_inputs.emplace_back(Assign(targets[i], sources[i], is_link));
tuple_inputs.emplace_back(Assign(targets[i], sources[i], link, keep, output));
}
return kernel_graph_->NewCNode(tuple_inputs);
}
@ -1079,7 +1180,7 @@ class ExecuteOrderGenerator {
auto &node = *iter;
// We only try to erase argument link assign nodes,
// other assign nodes are skipped.
if (IsLinkAssign(node)) {
if (IsOptimizableAssign(node)) {
auto &target = node->inputs().at(kAssignTargetIndex);
MS_EXCEPTION_IF_NULL(target);
auto para = param_write_times.find(target);
@ -1174,8 +1275,8 @@ class ExecuteOrderGenerator {
return param_write_times;
}
// Check if a node is an assign for argument link.
bool IsLinkAssign(const AnfNodePtr &node) {
// Check if a node is an assign for argument link and can be optimized.
bool IsOptimizableAssign(const AnfNodePtr &node) {
auto cnode = dyn_cast<CNode>(node);
if (cnode == nullptr) {
return false;
@ -1184,7 +1285,7 @@ class ExecuteOrderGenerator {
if (!IsPrimitiveEquals(prim, prim::kPrimAssign)) {
return false;
}
return prim->GetAttr(LINK) == prim::kValueOne;
return (prim->GetAttr(LINK) == prim::kValueOne) && (prim->GetAttr(KEEP) != prim::kValueOne);
}
// Erase LabelGoto and LabelSet