[auto-monad] Fix issue where output parameter is overwritten by multiple calls

He Wei 2021-03-18 10:49:06 +08:00
parent 0bd1e34a4d
commit 01eaaed85f
1 changed file with 95 additions and 22 deletions

@@ -1,5 +1,5 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
* Copyright 2020-2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -166,6 +166,63 @@ struct CallInfo {
AnfNodePtr label_param = nullptr;
};
//
// ParameterPool caches parameters by their abstract, so that a parameter
// with the same abstract can be reused to store return values.
//
class ParameterPool {
public:
explicit ParameterPool(const KernelGraphPtr &top_graph) : top_graph_(top_graph) {}
~ParameterPool() = default;
// Create or get a parameter from pool with the given abstract.
AnfNodePtr GetParameter(const abstract::AbstractBasePtr &abs) {
// Find parameter in pool by the given abstract.
auto iter = std::find_if(paras_.begin(), paras_.end(), [&abs](auto &para) {
auto para_abs = para->abstract();
// Reuse output parameter with compatible abstract.
return IsCompatible(abs, para_abs);
});
// Return the parameter if found.
if (iter != paras_.end()) {
return *iter;
}
// If no parameter with the given abstract is found, create a new one.
auto para = top_graph_->NewParameter(abs);
auto out_para = top_graph_->TransTupleToMakeTuple(para);
// This is required, so that device memory can be allocated for it.
top_graph_->AddChildGraphResult(out_para);
// Save new para to pool.
paras_.push_back(out_para);
return out_para;
}
protected:
// Check if one abstract is compatible with another abstract.
static bool IsCompatible(const abstract::AbstractBasePtr &a1, const abstract::AbstractBasePtr &a2) {
if (a1 == nullptr || a2 == nullptr) {
return false;
}
if (a1->isa<abstract::AbstractTensor>() && a2->isa<abstract::AbstractTensor>()) {
// This makes AbstractRef compatible with AbstractTensor.
auto &t1 = static_cast<abstract::AbstractTensor &>(*a1);
auto &t2 = static_cast<abstract::AbstractTensor &>(*a2);
return t1 == t2;
}
return *a1 == *a2;
}
private:
// The top graph.
const KernelGraphPtr &top_graph_;
// Cached parameters.
std::vector<AnfNodePtr> paras_;
};
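As a minimal standalone sketch of the pooling idea, with hypothetical stand-in types (Shape, Param) in place of the real AbstractBasePtr and AnfNodePtr:

#include <algorithm>
#include <memory>
#include <vector>

// Hypothetical stand-ins: Shape plays the role of an abstract,
// Param the role of a graph parameter.
struct Shape {
  std::vector<int> dims;
  bool operator==(const Shape &other) const { return dims == other.dims; }
};
struct Param {
  Shape shape;
};

class Pool {
 public:
  // Return a cached parameter with a compatible shape, or create a new one.
  std::shared_ptr<Param> Get(const Shape &shape) {
    auto iter = std::find_if(pool_.begin(), pool_.end(),
                             [&shape](const auto &p) { return p->shape == shape; });
    if (iter != pool_.end()) {
      return *iter;  // Reuse: a compatible abstract means compatible storage.
    }
    auto para = std::make_shared<Param>(Param{shape});
    pool_.push_back(para);  // Cache the new parameter for later calls.
    return para;
  }

 private:
  std::vector<std::shared_ptr<Param>> pool_;
};

int main() {
  Pool pool;
  auto a = pool.Get(Shape{{2, 3}});
  auto b = pool.Get(Shape{{2, 3}});  // Same shape, so the same parameter is reused.
  return (a == b) ? 0 : 1;
}

In the real pass, "compatible" is slightly looser, as IsCompatible above shows: an AbstractRef is treated as compatible with an AbstractTensor.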
//
// Base class for context.
//
class BaseContext {
public:
void MarkVisited(const KernelGraphPtr &kg) { visited_graphs_.insert(kg); }
@@ -185,7 +242,7 @@ class BaseContext {
//
class AscendAutoMonadContext : public BaseContext {
public:
explicit AscendAutoMonadContext(const KernelGraphPtr &kg) : top_graph_(kg) {}
explicit AscendAutoMonadContext(const KernelGraphPtr &kg) : top_graph_(kg), param_pool_(kg) {}
~AscendAutoMonadContext() = default;
// Label ids start from 1 and increase by 1 for each new id.
@@ -204,6 +261,9 @@ class AscendAutoMonadContext : public BaseContext {
return out_para;
}
// Get or create a temporary parameter for the given abstract.
AnfNodePtr GetTempParameter(const AbstractBasePtr &abs) { return param_pool_.GetParameter(abs); }
const KernelGraphPtr &TopGraph() const { return top_graph_; }
// Map kernel_graph to its call info.
@@ -213,8 +273,8 @@ class AscendAutoMonadContext : public BaseContext {
// The top graph.
const KernelGraphPtr &top_graph_;
// Map kernel_graph to its output parameter.
std::unordered_map<KernelGraphPtr, AnfNodePtr> kg_out_param_;
// The parameter pool that caches parameters for return values.
ParameterPool param_pool_;
// Current label id.
uint32_t label_id_ = 1;
@@ -521,9 +581,18 @@ class AscendAutoMonadConverter {
auto label_node = LabelSet(call_site.return_label);
AnfNodePtr output = call_site.out_param;
MS_EXCEPTION_IF_NULL(output);
// Let output depend on the label node; this ensures the
// return label is set before output is used.
output = MakeDepend(output, label_node);
const bool is_single_call = call_site.label_indexes.empty();
if (is_single_call) {
// For a single call, let output depend on the label node;
// this ensures the return label is set before output is used.
output = MakeDepend(output, label_node);
} else {
// For a multi-return call, assign the result from the temp parameter to the
// output parameter; this prevents the result from being overwritten by the next call.
auto tmp_param = context_.GetTempParameter(output->abstract());
output = AssignAll(output, tmp_param);
monad_ = UpdateState(GetMonad(), output);
}
// Replace the call/switch node with the output.
ReplaceNode(cnode, output);
return;
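The hazard this branch fixes can be modeled in plain C++ (a hypothetical sketch, not MindSpore code): the callee always writes its result into one shared location, so each call site must copy that location into its own slot before the next call runs.

#include <cassert>

int temp_param = 0;  // Plays the role of the pooled temp parameter.

// Callee writes its result into the shared temp location.
void Callee(int v) { temp_param = v; }

int main() {
  Callee(1);
  int site1_out = temp_param;  // Copy at call site 1 (the AssignAll above).
  Callee(2);                   // Without that copy, this call would clobber site 1's result.
  int site2_out = temp_param;
  assert(site1_out == 1 && site2_out == 2);
  return 0;
}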
@@ -603,12 +672,12 @@ class AscendAutoMonadConverter {
if (return_points.empty()) {
return;
}
// Assign output according to the return points.
AssignOutput(return_points);
// Single return point.
if (return_points.size() == 1) {
// Insert Assign for output parameter.
auto &return_point = return_points.front();
AssignOutput(return_point);
// Insert label_goto for return.
auto &return_point = return_points.front();
auto return_goto = LabelGoto(return_point.call_site->return_label);
AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_goto);
kernel_graph_->set_end_goto(return_goto);
@@ -617,12 +686,9 @@ class AscendAutoMonadConverter {
// Multi return points.
std::vector<uint32_t> return_labels;
return_labels.reserve(return_points.size());
for (auto &return_point : return_points) {
// Assign output to out_params of each return point.
AssignOutput(return_point);
// Get return labels.
return_labels.emplace_back(return_point.call_site->return_label);
}
// Get return labels from return points.
std::transform(return_points.begin(), return_points.end(), std::back_inserter(return_labels),
[](const ReturnPoint &return_point) { return return_point.call_site->return_label; });
// Insert label_switch for multi return points.
auto &label_param = call_info_.label_param;
MS_EXCEPTION_IF_NULL(label_param);
@@ -631,11 +697,18 @@ class AscendAutoMonadConverter {
kernel_graph_->set_end_goto(return_switch);
}
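Roughly, the generated graph tail looks as follows in the two cases (hypothetical IR spelling, not an actual dump):

// Single return point:
//   %out  = Assign(out_param, graph_output)     // AssignOutput
//   %goto = LabelGoto(return_label)             // marked kAttrReturn, set as end_goto
//
// Multiple return points:
//   %out    = Assign(temp_param, graph_output)  // AssignOutput writes the pooled temp
//   %switch = LabelSwitch(label_param, [return_label_1, ..., return_label_n])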
// Assign graph output to the output parameter for a return point.
void AssignOutput(const ReturnPoint &return_point) {
auto call_site = return_point.call_site;
// Assign graph output to the output parameter.
void AssignOutput(const std::vector<ReturnPoint> &return_points) {
// For a single call: directly assign the output to the output parameter of the call site;
// for multiple calls: assign the output to a temp parameter, and let each caller assign
// the temp parameter to its own output parameter after the call returns.
auto call_site = return_points.front().call_site;
MS_EXCEPTION_IF_NULL(call_site);
auto assign_output = AssignAll(call_site->out_param, kernel_graph_->output());
const bool is_single_call = (return_points.size() == 1 && call_site->label_indexes.empty());
AnfNodePtr out_param =
(is_single_call ? call_site->out_param : context_.GetTempParameter(kernel_graph_->output()->abstract()));
MS_EXCEPTION_IF_NULL(out_param);
auto assign_output = AssignAll(out_param, kernel_graph_->output());
monad_ = UpdateState(GetMonad(), assign_output);
}
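Putting the callee side and the call-site side together for a graph invoked from two sites (hypothetical labels and names):

// Callee tail:  Assign(temp_param, graph_output); LabelSwitch(label_param, [L1, L2])
// Call site 1:  LabelSet(L1); Assign(out_param_1, temp_param)
// Call site 2:  LabelSet(L2); Assign(out_param_2, temp_param)
//
// temp_param may be shared across graphs via ParameterPool, but each site
// copies it into its own out_param before any later call can overwrite it.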
@@ -699,7 +772,7 @@ class AscendAutoMonadConverter {
// For some cnodes, attributes may be set on the primitive instance, so we create a new primitive instance for each cnode.
AnfNodePtr NewPrimitive(const PrimitivePtr &prim) { return NewValueNode(std::make_shared<Primitive>(prim->name())); }
AnfNodePtr GetAssignMonad() {
AnfNodePtr GetLinkMonad() {
if (last_monad_ != nullptr) {
return last_monad_;
}
@@ -708,7 +781,7 @@ class AscendAutoMonadConverter {
// Make an Assign cnode.
CNodePtr Assign(const AnfNodePtr &target, const AnfNodePtr &source, bool is_link = false) {
auto monad = GetAssignMonad();
auto monad = (is_link ? GetLinkMonad() : GetMonad());
auto assign_prim = std::make_shared<Primitive>(prim::kPrimAssign->name());
if (is_link) {
// Mark that this assign links a real argument to a formal argument.