forked from mindspore-Ecosystem/mindspore
!12793 [auto-monad] Refactor ascend_auto_monad
From: @hwhewei
Commit: 750d7e6e2a
@@ -118,62 +118,6 @@ void DumpExecuteOrder(NotNull<KernelGraphPtr> kg) {
  fout.close();
}

//
// ParameterPool caches parameters by their abstract, so that we can reuse
// parameters with the same abstract to store return values.
//
class ParameterPool {
 public:
  explicit ParameterPool(const KernelGraphPtr &top_graph) : top_graph_(top_graph) {}
  ~ParameterPool() = default;

  // Create or get a parameter from pool with the given abstract.
  AnfNodePtr GetParameter(const abstract::AbstractBasePtr &abs) {
    // Find parameter in pool by the given abstract.
    auto iter = std::find_if(paras_.begin(), paras_.end(), [&abs](auto &para) {
      auto para_abs = para->abstract();
      // Reuse output parameter with compatible abstract.
      return IsCompatible(abs, para_abs);
    });
    // Return the parameter if found.
    if (iter != paras_.end()) {
      return *iter;
    }
    // If parameter not found with the given abstract, create a new one.
    auto para = top_graph_->NewParameter(abs);
    auto out_para = top_graph_->TransTupleToMakeTuple(para);
    // This is required, so that device memory can be allocated for it.
    top_graph_->AddChildGraphResult(out_para);
    // Save new para to pool.
    paras_.push_back(out_para);
    return out_para;
  }

 protected:
  // Check if one abstract is compatible with another abstract.
  static bool IsCompatible(const abstract::AbstractBasePtr &a1, const abstract::AbstractBasePtr &a2) {
    if (a1 == nullptr || a2 == nullptr) {
      return false;
    }
    if (a1->isa<abstract::AbstractTensor>() && a2->isa<abstract::AbstractTensor>()) {
      // This makes AbstractRef compatible with AbstractTensor.
      auto &t1 = static_cast<abstract::AbstractTensor &>(*a1);
      auto &t2 = static_cast<abstract::AbstractTensor &>(*a2);
      return t1 == t2;
    }
    return *a1 == *a2;
  }

 private:
  // The top graph.
  const KernelGraphPtr &top_graph_;

  // Cached parameters.
  std::vector<AnfNodePtr> paras_;
};

using ParameterPoolPtr = std::shared_ptr<ParameterPool>;

class BaseContext {
 public:
  void MarkVisited(const KernelGraphPtr &kg) { visited_graphs_.insert(kg); }
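The pool removed above is a find-or-create keyed by abstract compatibility: look up an existing parameter whose abstract matches, otherwise create one and remember it. A minimal standalone sketch of the same pattern, with a plain string standing in for the abstract type (all names here are illustrative, not the MindSpore API):

#include <algorithm>
#include <cassert>
#include <memory>
#include <string>
#include <vector>

// Stand-in for a graph parameter; 'abstract' plays the role of
// abstract::AbstractBasePtr in the real code.
struct Param {
  std::string abstract;
};
using ParamPtr = std::shared_ptr<Param>;

class SimpleParameterPool {
 public:
  // Return a pooled parameter with a compatible abstract, or create one.
  ParamPtr GetParameter(const std::string &abs) {
    auto iter = std::find_if(paras_.begin(), paras_.end(),
                             [&abs](const ParamPtr &p) { return p->abstract == abs; });
    if (iter != paras_.end()) {
      return *iter;  // reuse an existing slot for the return value
    }
    auto para = std::make_shared<Param>(Param{abs});
    paras_.push_back(para);
    return para;
  }

 private:
  std::vector<ParamPtr> paras_;
};

int main() {
  SimpleParameterPool pool;
  auto p1 = pool.GetParameter("Tensor[f32, (2, 2)]");
  auto p2 = pool.GetParameter("Tensor[f32, (2, 2)]");
  assert(p1 == p2);  // same abstract -> same pooled parameter
  return 0;
}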
@@ -200,13 +144,38 @@ class AscendAutoMonadContext : public BaseContext {
  // Current label id, also the number of label ids currently used.
  uint32_t CurrentLabel() const { return label_id_; }

  // Create a new parameter pool.
  ParameterPoolPtr NewParameterPool() { return std::make_shared<ParameterPool>(top_graph_); }
  // Create or get a parameter for output of the kernel graph.
  AnfNodePtr GetOutputParameter(const KernelGraphPtr &kg) {
    // Find output parameter by kernel graph.
    auto iter = kg_out_param_.find(kg);
    if (iter != kg_out_param_.end()) {
      // Return output parameter if found.
      return iter->second;
    }
    // Create a new one if not found.
    // Output parameters are all created on the top graph.
    auto para = top_graph_->NewParameter(kg->output()->abstract());
    auto out_para = top_graph_->TransTupleToMakeTuple(para);
    // This is required, so that device memory can be allocated for it.
    top_graph_->AddChildGraphResult(out_para);
    // Save new para as the output parameter of the kg.
    kg_out_param_.emplace(kg, out_para);
    return out_para;
  }

  // Set output parameter for a kernel graph.
  void SetOutputParameter(const KernelGraphPtr &kg, const AnfNodePtr &out_para) {
    // Save new para as the output parameter of the kg.
    kg_out_param_.emplace(kg, out_para);
  }

 private:
  // The top graph.
  const KernelGraphPtr &top_graph_;

  // Map kernel_graph to its output parameter.
  std::unordered_map<KernelGraphPtr, AnfNodePtr> kg_out_param_;

  // Current label id.
  uint32_t label_id_ = 1;
};
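GetOutputParameter above memoizes one output slot per kernel graph, so every call site of a graph reads its result from the same parameter. A hypothetical reduction of that lookup, with toy types in place of KernelGraphPtr and AnfNodePtr:

#include <cassert>
#include <memory>
#include <string>
#include <unordered_map>

struct Graph {};                               // stand-in for KernelGraph
using GraphPtr = std::shared_ptr<Graph>;
using NodePtr = std::shared_ptr<std::string>;  // stand-in for AnfNodePtr

class OutputParamCache {
 public:
  // One output parameter per graph, created on first request.
  NodePtr GetOutputParameter(const GraphPtr &kg) {
    auto iter = kg_out_param_.find(kg);
    if (iter != kg_out_param_.end()) {
      return iter->second;
    }
    auto out_para = std::make_shared<std::string>("out_para");
    kg_out_param_.emplace(kg, out_para);
    return out_para;
  }

 private:
  std::unordered_map<GraphPtr, NodePtr> kg_out_param_;
};

int main() {
  OutputParamCache cache;
  auto g = std::make_shared<Graph>();
  // Two callers of the same graph share one output slot.
  assert(cache.GetOutputParameter(g) == cache.GetOutputParameter(g));
  return 0;
}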
@@ -254,6 +223,7 @@ class AscendAutoMonadConverter {
  // Prepare information for control flow processing.
  //
  void Prepare() {
    recursive_ = kernel_graph_->has_flag(kFuncGraphFlagRecursive);
    AnfNodePtr last_monad = nullptr;
    auto nodes = TopoSort(kernel_graph_->output());
    for (auto &node : nodes) {
@@ -291,26 +261,25 @@ class AscendAutoMonadConverter {
    for (auto &cnode : call_switch_nodes_) {
      if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall)) {
        HandleCall(cnode);
      } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
      } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) ||
                 AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
        HandleSwitch(cnode);
      } else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
        HandleSwitchLayer(cnode);
      } else {
        MS_LOG(EXCEPTION) << "Not a call/switch/switchlayer node: " << cnode->DebugString();
      }
    }
    // If no tail call, assign output value to output parameter,
    // and then goto the return label if set.
    if (tail_call_node_ == nullptr) {
    if (tail_call_node_ == nullptr || recursive_) {
      if (output_parameter_) {
        auto assign_output = AssignAll(output_parameter_, kernel_graph_->output());
        monad_ = UpdateState(GetMonad(), assign_output);
      }
      if (return_label_ != kNoLabel) {
        (void)LabelGoto(return_label_);
      } else {
        // Clear end goto if return label not set.
        kernel_graph_->set_end_goto(nullptr);
        // Insert label_goto for return.
        auto return_goto = LabelGoto(return_label_);
        AnfAlgo::SetNodeAttr(kAttrReturn, prim::kValueOne, return_goto);
        kernel_graph_->set_end_goto(return_goto);
      }
    }
  }
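The epilogue rule in Run() above: a graph with no tail call (and, after this change, any recursive graph) must copy its result into the shared output parameter and jump back through the return label, because no callee is left to do that on its behalf. A toy sketch of just that decision; kNoLabel and the emitted "ops" are illustrative:

#include <cstdint>
#include <iostream>

constexpr uint32_t kNoLabel = UINT32_MAX;

// Emit the graph epilogue: assign the result, then return via label.
void EmitEpilogue(bool has_tail_call, bool recursive, uint32_t return_label) {
  if (!has_tail_call || recursive) {
    std::cout << "Assign(output_parameter, graph_output)\n";
    if (return_label != kNoLabel) {
      std::cout << "LabelGoto(" << return_label << ")\n";
    }
  }
  // A non-recursive tail call needs neither: the callee jumps onward
  // and eventually writes the same shared output parameter itself.
}

int main() {
  EmitEpilogue(/*has_tail_call=*/false, /*recursive=*/false, /*return_label=*/3);
  EmitEpilogue(/*has_tail_call=*/true, /*recursive=*/true, /*return_label=*/3);
  return 0;
}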
@@ -348,33 +317,37 @@ class AscendAutoMonadConverter {
    // as 'select kernel' can handle sub-graphs.
    SetChildGrapAttr(goto_node, {graph});

    // Setup return label if this is not a tail call.
    // Setup return label if this is not a tail call or it is a recursive call.
    const bool is_tail_call = (cnode == tail_call_node_);
    const bool need_return = !is_tail_call;
    auto [para_pool, output_para, return_label] = MakeReturn(cnode, need_return);
    const bool need_return = (!is_tail_call || recursive_);
    if (!need_return) {
      // Set as end_goto if no return required.
      kernel_graph_->set_end_goto(goto_node);
    }
    auto [output_para, return_label] = MakeReturn(cnode, {graph}, need_return);

    // Handle sub-graph recursively.
    HandleSubGraph(graph, para_pool, output_para, return_label);
    HandleSubGraph(graph, output_para, return_label);
  }

  //
  // Convert switch node:
  // Convert switch/switchlayer node:
  //   branch1 = Partial(graph1, arg)
  //   branch2 = Partial(graph2, arg)
  //   out = Switch(cond, branch1, branch2)
  //   out = Switch/SwitchLayer(cond/index, branch1, branch2)
  // to:
  //   r = link_args(graph1, arg)
  //   c = UpdateState(c, r)
  //   r = link_args(graph2, arg)
  //   c = UpdateState(c, r)
  //   c = LabelSwitch(cond, c) : L1, L2
  //   c = LabelSwitch(cond/index, c) : L1, L2
  //   c = LabelSet(c) : <return label>
  //
  void HandleSwitch(const CNodePtr &cnode) {
    // Update last_monad_.
    last_monad_ = monad_map_[cnode];

    // Get both branches of the switch, true branch first.
    // Get branches of the switch or switchlayer, true or 0 branch first.
    auto branches = GetSwitchBranches(cnode);

    // Link arguments and generate labels for branches.
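The comment block above shows the core of the conversion: each branch's link_args is chained into one monad value c via UpdateState, so argument passing stays ordered before the LabelSwitch even though the branches have no data edges between them. A small illustrative graph-builder; the toy Node type stands in for the real ANF node API:

#include <iostream>
#include <memory>
#include <string>
#include <utility>
#include <vector>

struct Node {
  std::string op;
  std::vector<std::shared_ptr<Node>> inputs;
};
using NodePtr = std::shared_ptr<Node>;

NodePtr MakeNode(std::string op, std::vector<NodePtr> inputs = {}) {
  return std::make_shared<Node>(Node{std::move(op), std::move(inputs)});
}

int main() {
  NodePtr c = MakeNode("U");  // the initial monad
  // r = link_args(graphN, arg); c = UpdateState(c, r) -- once per branch.
  for (const auto &g : {"graph1", "graph2"}) {
    NodePtr r = MakeNode(std::string("link_args_") + g);
    c = MakeNode("UpdateState", {c, r});
  }
  // c = LabelSwitch(cond, c) : L1, L2
  c = MakeNode("LabelSwitch", {MakeNode("cond"), c});
  std::cout << c->op << " depends on both link_args through the monad chain\n";
  return 0;
}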
@@ -394,63 +367,12 @@ class AscendAutoMonadConverter {
      labels.push_back(GetOrCreateGraphLabel(graph));
    }

    // Since true/false branches are reversed in the kernel LabelSwitch,
    // we reverse graphs and labels to make the false branch first.
    std::reverse(graphes.begin(), graphes.end());
    std::reverse(labels.begin(), labels.end());

    // Add LabelSwitch node.
    auto switch_node = LabelSwitch(cnode->input(1), labels);

    // Set child graph attribute for switch node.
    SetChildGrapAttr(switch_node, graphes);

    // Setup return label if required.
    const bool is_tail_call = (cnode == tail_call_node_);
    const bool need_return = (return_label_ == kNoLabel || !is_tail_call);
    auto [para_pool, output_para, return_label] = MakeReturn(cnode, need_return);

    // Handle sub-graphs recursively.
    for (auto &graph : graphes) {
      HandleSubGraph(graph, para_pool, output_para, return_label);
    }
  }

  //
  // Convert switchlayer node:
  //   branch1 = Partial(graph1, arg)
  //   branch2 = Partial(graph2, arg)
  //   out = SwitchLayer(index, branch1, branch2)
  // to:
  //   r = link_args(graph1, arg)
  //   c = UpdateState(c, r)
  //   r = link_args(graph2, arg)
  //   c = UpdateState(c, r)
  //   c = LabelSwitch(index, c) : L1, L2
  //   c = LabelSet(c) : <return label>
  //
  void HandleSwitchLayer(const CNodePtr &cnode) {
    // Update last_monad_.
    last_monad_ = monad_map_[cnode];

    // Get both branches of the switch, true branch first.
    auto branches = GetSwitchBranches(cnode);

    // Link arguments and generate labels for branches.
    std::vector<KernelGraphPtr> graphes;
    std::vector<uint32_t> labels;
    graphes.reserve(branches.size());
    labels.reserve(graphes.size());
    for (auto &[graph, args] : branches) {
      if (graph == nullptr) {
        MS_LOG(EXCEPTION) << "Invalid switch: " << cnode->DebugString();
      }
      auto linked_args = LinkArguments(args, graph);
      if (linked_args != nullptr) {
        monad_ = UpdateState(GetMonad(), linked_args);
      }
      graphes.push_back(graph);
      labels.push_back(GetOrCreateGraphLabel(graph));
    const bool is_switch = AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch);
    if (is_switch) {
      // For Switch, we reverse the graphs and labels so that the false branch
      // is the first one, since for the kernel LabelSwitch, false is the first branch.
      std::reverse(graphes.begin(), graphes.end());
      std::reverse(labels.begin(), labels.end());
    }

    // Add LabelSwitch node.
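The is_switch branch above encodes a kernel quirk: the Switch primitive lists the true branch first, but the LabelSwitch kernel expects the false branch at index 0, so graphs and labels are reversed in lockstep; SwitchLayer branches are already index-ordered and are left alone. A self-contained illustration, with strings in place of KernelGraphPtr:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <string>
#include <vector>

int main() {
  // Branches as extracted from Switch: true branch first.
  std::vector<std::string> graphes = {"true_graph", "false_graph"};
  std::vector<uint32_t> labels = {10, 20};

  const bool is_switch = true;  // false for SwitchLayer
  if (is_switch) {
    // Reverse both vectors together so graph/label pairs stay aligned.
    std::reverse(graphes.begin(), graphes.end());
    std::reverse(labels.begin(), labels.end());
  }

  // The kernel's branch 0 is now the false branch, as LabelSwitch expects.
  assert(graphes.front() == "false_graph" && labels.front() == 20);
  return 0;
}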
@@ -459,41 +381,42 @@ class AscendAutoMonadConverter {
    // Set child graph attribute for switch node.
    SetChildGrapAttr(switch_node, graphes);

    if (!is_switch) {
      // Mark that the switch node is for 'switch_layer'.
      AnfAlgo::SetNodeAttr(kAttrSwitchLayer, prim::kValueOne, switch_node);
    }

    // Setup return label if required.
    const bool is_tail_call = (cnode == tail_call_node_);
    const bool need_return = (return_label_ == kNoLabel || !is_tail_call);
    auto [para_pool, output_para, return_label] = MakeReturn(cnode, need_return);
    const bool need_return = (return_label_ == kNoLabel || !is_tail_call || recursive_);
    auto [output_para, return_label] = MakeReturn(cnode, graphes, need_return);

    // Handle sub-graphs recursively.
    for (auto &graph : graphes) {
      HandleSubGraph(graph, para_pool, output_para, return_label);
      HandleSubGraph(graph, output_para, return_label);
    }
  }

  ParameterPoolPtr GetParameterPool(bool is_last_call) {
    if (!is_last_call) {
      // There are multiple calls in this graph, use a new parameter pool
      // for each of them except the last one.
      return context_.NewParameterPool();
  AnfNodePtr GetOutputParameter(const CNodePtr &cnode, const std::vector<KernelGraphPtr> &branches) {
    const bool is_tail_call = (cnode == tail_call_node_);
    if (is_tail_call && output_parameter_ != nullptr) {
      return output_parameter_;
    }
    // For the last call, try to reuse the parameter pool from the caller.
    if (para_pool_ == nullptr) {
      para_pool_ = context_.NewParameterPool();
    }
    return para_pool_;
    return context_.GetOutputParameter(branches.front());
  }

  // Make return part of a call for the LabelGoto/LabelSwitch node.
  std::tuple<ParameterPoolPtr, AnfNodePtr, uint32_t> MakeReturn(const CNodePtr &cnode, bool need_return) {
    // Find a parameter pool for output parameter.
    const bool is_last_call = (cnode == call_switch_nodes_.back());
    auto para_pool = GetParameterPool(is_last_call);

    // Prepare return label and output parameter.
  std::tuple<AnfNodePtr, uint32_t> MakeReturn(const CNodePtr &cnode, const std::vector<KernelGraphPtr> &branches,
                                              bool need_return) {
    // Prepare return label.
    uint32_t return_label = return_label_;
    auto output_para = para_pool->GetParameter(cnode->abstract());
    // Prepare output parameter.
    auto output_para = GetOutputParameter(cnode, branches);
    // Use the same output parameter for all branches.
    for (auto &branch : branches) {
      context_.SetOutputParameter(branch, output_para);
    }
    auto output = output_para;

    // Setup return label if return is required.
    if (need_return) {
      // Set a new label at return point.
|
@ -504,16 +427,14 @@ class AscendAutoMonadConverter {
|
|||
output = MakeDepend(output, label_node);
|
||||
}
|
||||
|
||||
// Replace the the switch node with the output.
|
||||
// Replace the the call/switch node with the output.
|
||||
kernel_graph_->ReplaceNode(NOT_NULL(cnode), NOT_NULL(output));
|
||||
return {para_pool, output_para, return_label};
|
||||
return {output_para, return_label};
|
||||
}
|
||||
|
||||
// Handle sub-graphs recursively.
|
||||
void HandleSubGraph(const KernelGraphPtr &graph, const ParameterPoolPtr ¶_pool, const AnfNodePtr &out_para,
|
||||
uint32_t return_label) {
|
||||
void HandleSubGraph(const KernelGraphPtr &graph, const AnfNodePtr &out_para, uint32_t return_label) {
|
||||
AscendAutoMonadConverter converter(&context_, graph);
|
||||
converter.para_pool_ = para_pool;
|
||||
converter.output_parameter_ = out_para;
|
||||
converter.return_label_ = return_label;
|
||||
converter.Run();
|
||||
|
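HandleSubGraph above drives the whole pass: a fresh converter is built per sub-graph and seeded with the output parameter and return label chosen by its caller before Run() recurses. A schematic of that descent, using toy Graph/Converter types and hypothetical names:

#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

struct Graph {
  std::string name;
  std::vector<std::shared_ptr<Graph>> subgraphs;
};
using GraphPtr = std::shared_ptr<Graph>;

class Converter {
 public:
  explicit Converter(GraphPtr g) : graph_(std::move(g)) {}

  void Run() {
    std::cout << "convert " << graph_->name << " (return label " << return_label_ << ")\n";
    for (const auto &sub : graph_->subgraphs) {
      Converter child(sub);
      // Seed per-call state, mirroring HandleSubGraph above.
      child.output_parameter_ = output_parameter_;
      child.return_label_ = return_label_ + 1;  // toy label allocation
      child.Run();
    }
  }

  int output_parameter_ = 0;   // stand-in for AnfNodePtr
  uint32_t return_label_ = 0;  // stand-in for the real label id

 private:
  GraphPtr graph_;
};

int main() {
  auto leaf = std::make_shared<Graph>(Graph{"leaf", {}});
  auto top = std::make_shared<Graph>(Graph{"top", {leaf}});
  Converter(top).Run();
  return 0;
}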
@@ -717,7 +638,6 @@ class AscendAutoMonadConverter {
    auto cnode = kernel_graph_->NewCNode({label_goto, monad});
    AnfAlgo::SetNodeAttr(kAttrLabelIndex, MakeValue(label_id), cnode);
    cnode->set_abstract(monad->abstract());
    kernel_graph_->set_end_goto(cnode);  // make 'goto' the last one in execute order.
    monad_ = cnode;
    return cnode;
  }
@@ -794,11 +714,11 @@ class AscendAutoMonadConverter {
  // Parameter to store the return value.
  AnfNodePtr output_parameter_;

  // Parameter pool for output parameter allocation.
  ParameterPoolPtr para_pool_;

  // The return label id.
  uint32_t return_label_ = kNoLabel;

  // Whether this graph includes recursive calls.
  bool recursive_ = false;
};

constexpr size_t kAssignTargetIndex = 1;
@@ -851,20 +771,22 @@ class ExecuteOrderGenerator {

    std::vector<CNodePtr> execution_order;
    const auto &cnodes = graph_->execution_order();
    for (auto cnode : cnodes) {
    for (auto &cnode : cnodes) {
      // Push current node to execution order list.
      execution_order.push_back(cnode);
      // For a cnode with sub-graphs, such as LabelSwitch or LabelGoto,
      // generate the execution order of these sub-graphs,
      // and then append them to the current execution order list.
      if (HasSubGraphs(cnode)) {
        // We use reversed order to generate the sub-graph's execution order,
        // because the true branch of LabelSwitch is the second one, but
        // we want to make the true branch ahead of the false branch in the
        // generated execution order.
        auto sub_graphs = GetSubGraphs(cnode);
        for (auto iter = sub_graphs.rbegin(); iter != sub_graphs.rend(); iter++) {
          auto &sub_graph = *iter;
        if (!AnfAlgo::HasNodeAttr(kAttrSwitchLayer, cnode)) {
          // For Switch, we use reversed order to generate the sub-graph's
          // execution order, because the true branch of LabelSwitch is the
          // second one, but we want to make the true branch ahead of the
          // false branch in the generated execution order.
          std::reverse(sub_graphs.begin(), sub_graphs.end());
        }
        for (auto &sub_graph : sub_graphs) {
          if (context_.IsVisited(sub_graph)) {
            // Skip visited sub-graphs.
            continue;
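The generator above appends each sub-graph's order right after its LabelSwitch/LabelGoto node, reversing the stored sub-graph list for Switch (whose storage is false-branch-first) but leaving SwitchLayer's index-ordered list alone. A compact sketch of that ordering rule, with strings standing in for graphs:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

// Append sub-graph names after the switch node, true branch first for Switch.
std::vector<std::string> GenerateOrder(bool is_switch_layer) {
  std::vector<std::string> order = {"LabelSwitch"};
  // As stored on the node: Switch keeps {false, true}; SwitchLayer keeps
  // branches in index order already.
  std::vector<std::string> sub_graphs =
      is_switch_layer ? std::vector<std::string>{"layer0", "layer1"}
                      : std::vector<std::string>{"false_graph", "true_graph"};
  if (!is_switch_layer) {
    std::reverse(sub_graphs.begin(), sub_graphs.end());  // true branch first
  }
  order.insert(order.end(), sub_graphs.begin(), sub_graphs.end());
  return order;
}

int main() {
  for (const auto &g : GenerateOrder(/*is_switch_layer=*/false)) {
    std::cout << g << "\n";  // LabelSwitch, true_graph, false_graph
  }
  return 0;
}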
@@ -398,6 +398,8 @@ constexpr auto kAttrNeedCseAfterRecompute = "need_cse_after_recompute";
constexpr auto kAttrParallelDimInfo = "parallel_dim_info";
constexpr auto kAttrStitch = "stitch";
constexpr auto kAttrTopoSortRhsFirst = "topo_sort_rhs_first";
constexpr auto kAttrSwitchLayer = "switch_layer";
constexpr auto kAttrReturn = "return";

// attr value
constexpr auto kValueTargetSwitch = "target_switch";
@@ -86,6 +86,7 @@ const char FUNC_GRAPH_FLAG_SPECIALIZE_PARAMETER[] = "spec_param";
const char kFuncGraphFlagUndetermined[] = "Undeterminate";
const char kFuncGraphFlagBackPropEntry[] = "BackPropEntry";
const char kFuncGraphFlagReAutoMonad[] = "ReAutoMonad";
const char kFuncGraphFlagRecursive[] = "Recursive";

namespace abstract {
class AbstractKeywordArg;
@@ -24,11 +24,12 @@ from mindspore.common import dtype as mstype
class CaseNet(nn.Cell):
    def __init__(self):
        super(CaseNet, self).__init__()
        self.conv = nn.Conv2d(1, 3, 3)
        self.conv = nn.Conv2d(1, 1, 3)
        self.relu = nn.ReLU()
        self.relu1 = nn.ReLU()
        self.softmax = nn.Softmax()
        self.layers1 = (self.relu, self.softmax)
        self.layers2 = (self.conv, self.relu)
        self.layers2 = (self.conv, self.relu1)

    def construct(self, x, index1, index2):
        x = self.layers1[index1](x)
@@ -50,7 +51,3 @@ def test_switch_layer():
    true_value = relu(data)
    ret = np.allclose(value.asnumpy(), true_value.asnumpy())
    assert ret

    idx3 = Tensor(3, mstype.int32)
    with pytest.raises(IndexError):
        value = net(data, idx3, idx2)