forked from mindspore-Ecosystem/mindspore
[auto-monad] Fix multi-call output parameter be overwritten issue
This commit is contained in:
parent
0bd1e34a4d
commit
01eaaed85f
|
@ -1,5 +1,5 @@
|
|||
/**
|
||||
* Copyright 2020 Huawei Technologies Co., Ltd
|
||||
* Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
|
@ -166,6 +166,63 @@ struct CallInfo {
|
|||
AnfNodePtr label_param = nullptr;
|
||||
};
|
||||
|
||||
//
|
||||
// ParameterPool cache parameters by its abstract, so that we can reuse
|
||||
// parameter with same abstract to store return values.
|
||||
//
|
||||
class ParameterPool {
|
||||
public:
|
||||
explicit ParameterPool(const KernelGraphPtr &top_graph) : top_graph_(top_graph) {}
|
||||
~ParameterPool() = default;
|
||||
|
||||
// Create or get a parameter from pool with the given abstract.
|
||||
AnfNodePtr GetParameter(const abstract::AbstractBasePtr &abs) {
|
||||
// Find parameter in pool by the given abstract.
|
||||
auto iter = std::find_if(paras_.begin(), paras_.end(), [&abs](auto ¶) {
|
||||
auto para_abs = para->abstract();
|
||||
// Reuse output parameter with compatible abstract.
|
||||
return IsCompatible(abs, para_abs);
|
||||
});
|
||||
// Return the parameter if found.
|
||||
if (iter != paras_.end()) {
|
||||
return *iter;
|
||||
}
|
||||
// If parameter not found with the given abstract, create a new one.
|
||||
auto para = top_graph_->NewParameter(abs);
|
||||
auto out_para = top_graph_->TransTupleToMakeTuple(para);
|
||||
// This is required, so that device memory can be allocated for it.
|
||||
top_graph_->AddChildGraphResult(out_para);
|
||||
// Save new para to pool.
|
||||
paras_.push_back(out_para);
|
||||
return out_para;
|
||||
}
|
||||
|
||||
protected:
|
||||
// Check if one abstract is compatible with another abstract.
|
||||
static bool IsCompatible(const abstract::AbstractBasePtr &a1, const abstract::AbstractBasePtr &a2) {
|
||||
if (a1 == nullptr || a2 == nullptr) {
|
||||
return false;
|
||||
}
|
||||
if (a1->isa<abstract::AbstractTensor>() && a2->isa<abstract::AbstractTensor>()) {
|
||||
// This make AbstractRef compatible with AbstractTensor.
|
||||
auto &t1 = static_cast<abstract::AbstractTensor &>(*a1);
|
||||
auto &t2 = static_cast<abstract::AbstractTensor &>(*a2);
|
||||
return t1 == t2;
|
||||
}
|
||||
return *a1 == *a2;
|
||||
}
|
||||
|
||||
private:
|
||||
// The top graph.
|
||||
const KernelGraphPtr &top_graph_;
|
||||
|
||||
// Cached parameters.
|
||||
std::vector<AnfNodePtr> paras_;
|
||||
};
|
||||
|
||||
//
|
||||
// Base class for context.
|
||||
//
|
||||
class BaseContext {
|
||||
public:
|
||||
void MarkVisited(const KernelGraphPtr &kg) { visited_graphs_.insert(kg); }
|
||||
|
@ -185,7 +242,7 @@ class BaseContext {
|
|||
//
|
||||
class AscendAutoMonadContext : public BaseContext {
|
||||
public:
|
||||
explicit AscendAutoMonadContext(const KernelGraphPtr &kg) : top_graph_(kg) {}
|
||||
explicit AscendAutoMonadContext(const KernelGraphPtr &kg) : top_graph_(kg), param_pool_(kg) {}
|
||||
~AscendAutoMonadContext() = default;
|
||||
|
||||
// Label id start from 1, and increased by 1 for each new id.
|
||||
|
@ -204,6 +261,9 @@ class AscendAutoMonadContext : public BaseContext {
|
|||
return out_para;
|
||||
}
|
||||
|
||||
// Get or create a temporary parameter for the given abstract.
|
||||
AnfNodePtr GetTempParameter(const AbstractBasePtr &abs) { return param_pool_.GetParameter(abs); }
|
||||
|
||||
const KernelGraphPtr &TopGraph() const { return top_graph_; }
|
||||
|
||||
// Map kernel_graph to its call info.
|
||||
|
@ -213,8 +273,8 @@ class AscendAutoMonadContext : public BaseContext {
|
|||
// The top graph.
|
||||
const KernelGraphPtr &top_graph_;
|
||||
|
||||
// Map kernel_graph to its output parameter.
|
||||
std::unordered_map<KernelGraphPtr, AnfNodePtr> kg_out_param_;
|
||||
// The parameter pool that cache parameters for return value.
|
||||
ParameterPool param_pool_;
|
||||
|
||||
// Current label id.
|
||||
uint32_t label_id_ = 1;
|
||||
|
@ -521,9 +581,18 @@ class AscendAutoMonadConverter {
|
|||
auto label_node = LabelSet(call_site.return_label);
|
||||
AnfNodePtr output = call_site.out_param;
|
||||
MS_EXCEPTION_IF_NULL(output);
|
||||
// Let output depend on the label node, this ensures the
|
||||
// return label is set before output is used.
|
||||
output = MakeDepend(output, label_node);
|
||||
const bool is_single_call = call_site.label_indexes.empty();
|
||||
if (is_single_call) {
|
||||
// For single call, let output depend on the label node,
|
||||
// this ensures the return label is set before output is used.
|
||||
output = MakeDepend(output, label_node);
|
||||
} else {
|
||||
// For multi-return call, assign result from temp parameter to
|
||||
// output parameter, this prevent result be overwritten by next call.
|
||||
auto tmp_param = context_.GetTempParameter(output->abstract());
|
||||
output = AssignAll(output, tmp_param);
|
||||
monad_ = UpdateState(GetMonad(), output);
|
||||
}
|
||||
// Replace the the call/switch node with the output.
|
||||
ReplaceNode(cnode, output);
|
||||
return;
|
||||
|
@ -603,12 +672,12 @@ class AscendAutoMonadConverter {
|
|||
if (return_points.empty()) {
|
||||
return;
|
||||
}
|
||||
// Assign output according the return points.
|
||||
AssignOutput(return_points);
|
||||
// Single return point.
|
||||
if (return_points.size() == 1) {
|
||||
// Insert Assign for output parameter.
|
||||
auto &return_point = return_points.front();
|
||||
AssignOutput(return_point);
|
||||
// Insert label_goto for return.
|
||||
auto &return_point = return_points.front();
|
||||
auto return_goto = LabelGoto(return_point.call_site->return_label);
|
||||
AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_goto);
|
||||
kernel_graph_->set_end_goto(return_goto);
|
||||
|
@ -617,12 +686,9 @@ class AscendAutoMonadConverter {
|
|||
// Multi return points.
|
||||
std::vector<uint32_t> return_labels;
|
||||
return_labels.reserve(return_points.size());
|
||||
for (auto &return_point : return_points) {
|
||||
// Assign output to out_params of each return point.
|
||||
AssignOutput(return_point);
|
||||
// Get return labels.
|
||||
return_labels.emplace_back(return_point.call_site->return_label);
|
||||
}
|
||||
// Get return labels from return points.
|
||||
std::transform(return_points.begin(), return_points.end(), std::back_inserter(return_labels),
|
||||
[](const ReturnPoint &return_point) { return return_point.call_site->return_label; });
|
||||
// Insert label_switch for multi return points.
|
||||
auto &label_param = call_info_.label_param;
|
||||
MS_EXCEPTION_IF_NULL(label_param);
|
||||
|
@ -631,11 +697,18 @@ class AscendAutoMonadConverter {
|
|||
kernel_graph_->set_end_goto(return_switch);
|
||||
}
|
||||
|
||||
// Assign graph output to the output parameter for a return point.
|
||||
void AssignOutput(const ReturnPoint &return_point) {
|
||||
auto call_site = return_point.call_site;
|
||||
// Assign graph output to the output parameter.
|
||||
void AssignOutput(const std::vector<ReturnPoint> &return_points) {
|
||||
// For single call: we directly assign output to the output parameter of the call site;
|
||||
// For multi call: we assign output to a temp parameter, and let caller assign the
|
||||
// temp parameter to a output parameter after returned.
|
||||
auto call_site = return_points.front().call_site;
|
||||
MS_EXCEPTION_IF_NULL(call_site);
|
||||
auto assign_output = AssignAll(call_site->out_param, kernel_graph_->output());
|
||||
const bool is_single_call = (return_points.size() == 1 && call_site->label_indexes.empty());
|
||||
AnfNodePtr out_param =
|
||||
(is_single_call ? call_site->out_param : context_.GetTempParameter(kernel_graph_->output()->abstract()));
|
||||
MS_EXCEPTION_IF_NULL(out_param);
|
||||
auto assign_output = AssignAll(out_param, kernel_graph_->output());
|
||||
monad_ = UpdateState(GetMonad(), assign_output);
|
||||
}
|
||||
|
||||
|
@ -699,7 +772,7 @@ class AscendAutoMonadConverter {
|
|||
// For some cnode, attributes may set to primitive instance, so we create a new prim instance for each cnode.
|
||||
AnfNodePtr NewPrimitive(const PrimitivePtr &prim) { return NewValueNode(std::make_shared<Primitive>(prim->name())); }
|
||||
|
||||
AnfNodePtr GetAssignMonad() {
|
||||
AnfNodePtr GetLinkMonad() {
|
||||
if (last_monad_ != nullptr) {
|
||||
return last_monad_;
|
||||
}
|
||||
|
@ -708,7 +781,7 @@ class AscendAutoMonadConverter {
|
|||
|
||||
// Make a assign cnode.
|
||||
CNodePtr Assign(const AnfNodePtr &target, const AnfNodePtr &source, bool is_link = false) {
|
||||
auto monad = GetAssignMonad();
|
||||
auto monad = (is_link ? GetLinkMonad() : GetMonad());
|
||||
auto assign_prim = std::make_shared<Primitive>(prim::kPrimAssign->name());
|
||||
if (is_link) {
|
||||
// Mark this assign is to link real argument to formal argument.
|
||||
|
|
Loading…
Reference in New Issue