!12793 [auto-monad] Refactor ascend_auto_monad

From: @hwhewei
This commit is contained in:
mindspore-ci-bot 2021-03-09 09:20:21 +08:00 committed by Gitee
commit 750d7e6e2a
4 changed files with 96 additions and 174 deletions

View File

@ -118,62 +118,6 @@ void DumpExecuteOrder(NotNull<KernelGraphPtr> kg) {
// ParameterPool caches parameters by their abstract, so that we can reuse
// a parameter with the same abstract to store return values.
class ParameterPool {
explicit ParameterPool(const KernelGraphPtr &top_graph) : top_graph_(top_graph) {}
~ParameterPool() = default;
// Create or get a parameter from pool with the given abstract.
AnfNodePtr GetParameter(const abstract::AbstractBasePtr &abs) {
// Find parameter in pool by the given abstract.
auto iter = std::find_if(paras_.begin(), paras_.end(), [&abs](auto &para) {
auto para_abs = para->abstract();
// Reuse output parameter with compatible abstract.
return IsCompatible(abs, para_abs);
// Return the parameter if found.
if (iter != paras_.end()) {
return *iter;
// If parameter not found with the given abstract, create a new one.
auto para = top_graph_->NewParameter(abs);
auto out_para = top_graph_->TransTupleToMakeTuple(para);
// This is required, so that device memory can be allocated for it.
// Save new para to pool.
return out_para;
// Check if one abstract is compatible with another abstract.
static bool IsCompatible(const abstract::AbstractBasePtr &a1, const abstract::AbstractBasePtr &a2) {
if (a1 == nullptr || a2 == nullptr) {
return false;
if (a1->isa<abstract::AbstractTensor>() && a2->isa<abstract::AbstractTensor>()) {
// This makes AbstractRef compatible with AbstractTensor.
auto &t1 = static_cast<abstract::AbstractTensor &>(*a1);
auto &t2 = static_cast<abstract::AbstractTensor &>(*a2);
return t1 == t2;
return *a1 == *a2;
// The top graph.
const KernelGraphPtr &top_graph_;
// Cached parameters.
std::vector<AnfNodePtr> paras_;
using ParameterPoolPtr = std::shared_ptr<ParameterPool>;
class BaseContext {
void MarkVisited(const KernelGraphPtr &kg) { visited_graphs_.insert(kg); }
@ -200,13 +144,38 @@ class AscendAutoMonadContext : public BaseContext {
// Current label id, also the number of label ids we currently used.
uint32_t CurrentLabel() const { return label_id_; }
// Create a new parameter pool.
ParameterPoolPtr NewParameterPool() { return std::make_shared<ParameterPool>(top_graph_); }
// Create or get a parameter for output of the kernel graph.
AnfNodePtr GetOutputParameter(const KernelGraphPtr &kg) {
// Find output parameter by kernel graph.
auto iter = kg_out_param_.find(kg);
if (iter != kg_out_param_.end()) {
// Return output parameter if found.
return iter->second;
// Create a new one if not found.
// Output parameters are all created on top graph.
auto para = top_graph_->NewParameter(kg->output()->abstract());
auto out_para = top_graph_->TransTupleToMakeTuple(para);
// This is required, so that device memory can be allocated for it.
// Save new para as the output parameter of the kg.
kg_out_param_.emplace(kg, out_para);
return out_para;
// Set output parameter for a kernel graph.
void SetOutputParameter(const KernelGraphPtr &kg, const AnfNodePtr &out_para) {
// Save new para as the output parameter of the kg.
kg_out_param_.emplace(kg, out_para);
// The top graph.
const KernelGraphPtr &top_graph_;
// Map kernel_graph to its output parameter.
std::unordered_map<KernelGraphPtr, AnfNodePtr> kg_out_param_;
// Current label id.
uint32_t label_id_ = 1;
@ -254,6 +223,7 @@ class AscendAutoMonadConverter {
// Prepare information for control flow processing.
void Prepare() {
recursive_ = kernel_graph_->has_flag(kFuncGraphFlagRecursive);
AnfNodePtr last_monad = nullptr;
auto nodes = TopoSort(kernel_graph_->output());
for (auto &node : nodes) {
@ -291,26 +261,25 @@ class AscendAutoMonadConverter {
for (auto &cnode : call_switch_nodes_) {
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) ||
AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
} else {
MS_LOG(EXCEPTION) << "Not a call/switch/switchlayer node: " << cnode->DebugString();
// If no tail call, assign output value to output parameter,
// and then goto the return label if set.
if (tail_call_node_ == nullptr) {
if (tail_call_node_ == nullptr || recursive_) {
if (output_parameter_) {
auto assign_output = AssignAll(output_parameter_, kernel_graph_->output());
monad_ = UpdateState(GetMonad(), assign_output);
if (return_label_ != kNoLabel) {
} else {
// Clear end goto if return label not set.
// Insert label_goto for return.
auto return_goto = LabelGoto(return_label_);
AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_goto);
@ -348,33 +317,37 @@ class AscendAutoMonadConverter {
// as 'select kernel' can handle sub graphs.
SetChildGrapAttr(goto_node, {graph});
// Setup return label if this is not a tail call.
// Setup return label if this is not a tail call or it is a recursive call.
const bool is_tail_call = (cnode == tail_call_node_);
const bool need_return = !is_tail_call;
auto [para_pool, output_para, return_label] = MakeReturn(cnode, need_return);
const bool need_return = (!is_tail_call || recursive_);
if (!need_return) {
// Set as end_goto if no return required.
auto [output_para, return_label] = MakeReturn(cnode, {graph}, need_return);
// Handle sub-graph recursively.
HandleSubGraph(graph, para_pool, output_para, return_label);
HandleSubGraph(graph, output_para, return_label);
// Convert switch node:
// Convert switch/switchlayer node:
// branch1 = Partial(graph1, arg)
// branch2 = Partial(graph2, arg)
// out = Switch(cond, branch1, branch2)
// out = Switch/SwitchLayer(cond/index, branch1, branch2)
// to:
// r = link_args(graph1, arg)
// c = UpdateState(c, r)
// r = link_args(graph2, arg)
// c = UpdateState(c, r)
// c = LabelSwitch(cond, c) : L1, L2
// c = LabelSwitch(cond/index, c) : L1, L2
// c = LabelSet(c) : <return label>
void HandleSwitch(const CNodePtr &cnode) {
// Update last_monad_.
last_monad_ = monad_map_[cnode];
// Get both branches of the switch, true branch first.
// Get branches of the switch or switchlayer, true or 0 branch first.
auto branches = GetSwitchBranches(cnode);
// Link arguments and generate labels for branches.
@ -394,63 +367,12 @@ class AscendAutoMonadConverter {
// Since true/false branches are reversed in kernel LabelSwitch,
// we reverse graphes and labels to make the false branch first.
std::reverse(graphes.begin(), graphes.end());
std::reverse(labels.begin(), labels.end());
// Add LabelSwitch node.
auto switch_node = LabelSwitch(cnode->input(1), labels);
// Set child graph attribute for switch node.
SetChildGrapAttr(switch_node, graphes);
// Setup return label if required.
const bool is_tail_call = (cnode == tail_call_node_);
const bool need_return = (return_label_ == kNoLabel || !is_tail_call);
auto [para_pool, output_para, return_label] = MakeReturn(cnode, need_return);
// Handle sub-graphs recursively.
for (auto &graph : graphes) {
HandleSubGraph(graph, para_pool, output_para, return_label);
// Convert switch node:
// branch1 = Partial(graph1, arg)
// branch2 = Partial(graph2, arg)
// out = SwitchLayer(index, branch1, branch2)
// to:
// r = link_args(graph1, arg)
// c = UpdateState(c, r)
// r = link_args(graph2, arg)
// c = UpdateState(c, r)
// c = LabelSwitch(index, c) : L1, L2
// c = LabelSet(c) : <return label>
void HandleSwitchLayer(const CNodePtr &cnode) {
// Update last_monad_.
last_monad_ = monad_map_[cnode];
// Get both branches of the switch, true branch first.
auto branches = GetSwitchBranches(cnode);
// Link arguments and generate labels for branches.
std::vector<KernelGraphPtr> graphes;
std::vector<uint32_t> labels;
for (auto &[graph, args] : branches) {
if (graph == nullptr) {
MS_LOG(EXCEPTION) << "Invalid switch: " << cnode->DebugString();
auto linked_args = LinkArguments(args, graph);
if (linked_args != nullptr) {
monad_ = UpdateState(GetMonad(), linked_args);
const bool is_switch = AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch);
if (is_switch) {
// For Switch, we reverse the graphes and labels, so that the false branch
// is the first one, since for kernel LabelSwitch, false is the first branch.
std::reverse(graphes.begin(), graphes.end());
std::reverse(labels.begin(), labels.end());
// Add LabelSwitch node.
@ -459,41 +381,42 @@ class AscendAutoMonadConverter {
// Set child graph attribute for switch node.
SetChildGrapAttr(switch_node, graphes);
if (!is_switch) {
// Mark that the switch node is for 'switch_layer'.
AnfAlgo::SetNodeAttr(kAttrSwitchLayer, prim::kValueOne, switch_node);
// Setup return label if required.
const bool is_tail_call = (cnode == tail_call_node_);
const bool need_return = (return_label_ == kNoLabel || !is_tail_call);
auto [para_pool, output_para, return_label] = MakeReturn(cnode, need_return);
const bool need_return = (return_label_ == kNoLabel || !is_tail_call || recursive_);
auto [output_para, return_label] = MakeReturn(cnode, graphes, need_return);
// Handle sub-graphs recursively.
for (auto &graph : graphes) {
HandleSubGraph(graph, para_pool, output_para, return_label);
HandleSubGraph(graph, output_para, return_label);
ParameterPoolPtr GetParameterPool(bool is_last_call) {
if (!is_last_call) {
// There are multiple calls in this graph, use a new parameter pool
// for each of them except the last one.
return context_.NewParameterPool();
AnfNodePtr GetOutputParameter(const CNodePtr &cnode, const std::vector<KernelGraphPtr> &branches) {
const bool is_tail_call = (cnode == tail_call_node_);
if (is_tail_call && output_parameter_ != nullptr) {
return output_parameter_;
// For last call, try reuse parameter pool from the caller.
if (para_pool_ == nullptr) {
para_pool_ = context_.NewParameterPool();
return para_pool_;
return context_.GetOutputParameter(branches.front());
// Make return part of a call for the LabelGoto/LabelSwitch node.
std::tuple<ParameterPoolPtr, AnfNodePtr, uint32_t> MakeReturn(const CNodePtr &cnode, bool need_return) {
// Find a parameter pool for output parameter.
const bool is_last_call = (cnode == call_switch_nodes_.back());
auto para_pool = GetParameterPool(is_last_call);
// Prepare return label and output parameter.
std::tuple<AnfNodePtr, uint32_t> MakeReturn(const CNodePtr &cnode, const std::vector<KernelGraphPtr> &branches,
bool need_return) {
// Prepare return label.
uint32_t return_label = return_label_;
auto output_para = para_pool->GetParameter(cnode->abstract());
// Prepare output parameter.
auto output_para = GetOutputParameter(cnode, branches);
// Use same output parameter for all branches.
for (auto &branch : branches) {
context_.SetOutputParameter(branch, output_para);
auto output = output_para;
// Setup return label if return is required.
if (need_return) {
// Set a new label at return point.
@ -504,16 +427,14 @@ class AscendAutoMonadConverter {
output = MakeDepend(output, label_node);
// Replace the switch node with the output.
// Replace the call/switch node with the output.
kernel_graph_->ReplaceNode(NOT_NULL(cnode), NOT_NULL(output));
return {para_pool, output_para, return_label};
return {output_para, return_label};
// Handle sub-graphs recursively.
void HandleSubGraph(const KernelGraphPtr &graph, const ParameterPoolPtr &para_pool, const AnfNodePtr &out_para,
uint32_t return_label) {
void HandleSubGraph(const KernelGraphPtr &graph, const AnfNodePtr &out_para, uint32_t return_label) {
AscendAutoMonadConverter converter(&context_, graph);
converter.para_pool_ = para_pool;
converter.output_parameter_ = out_para;
converter.return_label_ = return_label;
@ -717,7 +638,6 @@ class AscendAutoMonadConverter {
auto cnode = kernel_graph_->NewCNode({label_goto, monad});
AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(label_id), cnode);
kernel_graph_->set_end_goto(cnode); // make 'goto' the last one in execute order.
monad_ = cnode;
return cnode;
@ -794,11 +714,11 @@ class AscendAutoMonadConverter {
// Parameter to store the return value.
AnfNodePtr output_parameter_;
// Parameter pool for output parameter allocation.
ParameterPoolPtr para_pool_;
// The return label id.
uint32_t return_label_ = kNoLabel;
// Whether this graph includes recursive calls.
bool recursive_ = false;
constexpr size_t kAssignTargetIndex = 1;
@ -851,20 +771,22 @@ class ExecuteOrderGenerator {
std::vector<CNodePtr> execution_order;
const auto &cnodes = graph_->execution_order();
for (auto cnode : cnodes) {
for (auto &cnode : cnodes) {
// Push current node to execution order list.
// For cnode with sub-graphs, such as LabelSwitch, LabelGoto,
// Generate execute order for these sub-graphs,
// and then append them to current execution order list.
if (HasSubGraphs(cnode)) {
// We use reversed order to generate sub-graph's execution order,
// because the true branch of LabelSwitch is the second one, but
// we want to make true branch ahead of false branch in the generated
// execution order.
auto sub_graphs = GetSubGraphs(cnode);
for (auto iter = sub_graphs.rbegin(); iter != sub_graphs.rend(); iter++) {
auto &sub_graph = *iter;
if (!AnfAlgo::HasNodeAttr(kAttrSwitchLayer, cnode)) {
// For Switch, we use reversed order to generate sub-graph's execution order,
// because the true branch of LabelSwitch is the second one, but
// we want to make true branch ahead of false branch in the generated
// execution order.
std::reverse(sub_graphs.begin(), sub_graphs.end());
for (auto &sub_graph : sub_graphs) {
if (context_.IsVisited(sub_graph)) {
// Skip visited sub-graphs.

View File

@ -398,6 +398,8 @@ constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute";
constexpr auto kAttrParallelDimInfo = "parallel_dim_info";
constexpr auto kAttrStitch = "stitch";
constexpr auto kAttrTopoSortRhsFirst = "topo_sort_rhs_first";
// Node attribute set on the LabelSwitch node that is generated for a SwitchLayer
// (as opposed to a plain Switch); the execute-order generator checks for it to
// decide whether the sub-graph order must be reversed.
constexpr auto kAttrSwitchLayer = "switch_layer";
// Node attribute set on the LabelGoto node that jumps to the return label.
constexpr auto kAttrReturn = "return";
// attr value
constexpr auto kValueTargetSwitch = "target_switch";

View File

@ -86,6 +86,7 @@ const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
const char kFuncGraphFlagUndetermined[] = "Undeterminate";
const char kFuncGraphFlagBackPropEntry[] = "BackPropEntry";
const char kFuncGraphFlagReAutoMonad[] = "ReAutoMonad";
// Graph flag name; the Ascend auto-monad converter queries it with has_flag()
// on a kernel graph to decide whether the graph is treated as recursive.
const char kFuncGraphFlagRecursive[] = "Recursive";
namespace abstract {
class AbstractKeywordArg;

View File

@ -24,11 +24,12 @@ from mindspore.common import dtype as mstype
class CaseNet(nn.Cell):
def __init__(self):
super(CaseNet, self).__init__()
self.conv = nn.Conv2d(1, 3, 3)
self.conv = nn.Conv2d(1, 1, 3)
self.relu = nn.ReLU()
self.relu1 = nn.ReLU()
self.softmax = nn.Softmax()
self.layers1 = (self.relu, self.softmax)
self.layers2 = (self.conv, self.relu)
self.layers2 = (self.conv, self.relu1)
def construct(self, x, index1, index2):
x = self.layers1[index1](x)
@ -50,7 +51,3 @@ def test_switch_layer():
true_value = relu(data)
ret = np.allclose(value.asnumpy(), true_value.asnumpy())
assert ret
idx3 = Tensor(3, mstype.int32)
with pytest.raises(IndexError):
value = net(data, idx3, idx2)