forked from mindspore-Ecosystem/mindspore
!12732 fix switch layer
From: @youui Reviewed-by: @guoqi1024,@zhoufeng54 Signed-off-by: @guoqi1024
This commit is contained in:
commit
a5af03f8ca
|
@ -270,8 +270,9 @@ class AscendAutoMonadConverter {
|
|||
MS_LOG(EXCEPTION) << "Invalid CNode: " << cnode->DebugString() << std::endl;
|
||||
}
|
||||
if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimCall) ||
|
||||
AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
|
||||
// Found call/switch node, set it as the tail call node.
|
||||
AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch) ||
|
||||
AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
|
||||
// Found call/switch/switchlayer node, set it as the tail call node.
|
||||
tail_call_node_ = cnode;
|
||||
call_switch_nodes_.emplace_back(cnode);
|
||||
monad_map_.emplace(cnode, last_monad);
|
||||
|
@ -292,8 +293,10 @@ class AscendAutoMonadConverter {
|
|||
HandleCall(cnode);
|
||||
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitch)) {
|
||||
HandleSwitch(cnode);
|
||||
} else if (AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimSwitchLayer)) {
|
||||
HandleSwitchLayer(cnode);
|
||||
} else {
|
||||
MS_LOG(EXCEPTION) << "Not a call/switch node: " << cnode->DebugString();
|
||||
MS_LOG(EXCEPTION) << "Not a call/switch/switchlayer node: " << cnode->DebugString();
|
||||
}
|
||||
}
|
||||
// If no tail call, assign output value to output parameter,
|
||||
|
@ -413,6 +416,60 @@ class AscendAutoMonadConverter {
|
|||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Convert switch node:
|
||||
// branch1 = Partial(graph1, arg)
|
||||
// branch2 = Partial(graph2, arg)
|
||||
// out = SwitchLayer(index, branch1, branch2)
|
||||
// to:
|
||||
// r = link_args(graph1, arg)
|
||||
// c = UpdateState(c, r)
|
||||
// r = link_args(graph2, arg)
|
||||
// c = UpdateState(c, r)
|
||||
// c = LabelSwitch(index, c) : L1, L2
|
||||
// c = LabelSet(c) : <return label>
|
||||
//
|
||||
void HandleSwitchLayer(const CNodePtr &cnode) {
|
||||
// Update last_monad_.
|
||||
last_monad_ = monad_map_[cnode];
|
||||
|
||||
// Get both branches of the switch, true branch first.
|
||||
auto branches = GetSwitchBranches(cnode);
|
||||
|
||||
// Link arguments and generate labels for branches.
|
||||
std::vector<KernelGraphPtr> graphes;
|
||||
std::vector<uint32_t> labels;
|
||||
graphes.reserve(branches.size());
|
||||
labels.reserve(graphes.size());
|
||||
for (auto &[graph, args] : branches) {
|
||||
if (graph == nullptr) {
|
||||
MS_LOG(EXCEPTION) << "Invalid switch: " << cnode->DebugString();
|
||||
}
|
||||
auto linked_args = LinkArguments(args, graph);
|
||||
if (linked_args != nullptr) {
|
||||
monad_ = UpdateState(GetMonad(), linked_args);
|
||||
}
|
||||
graphes.push_back(graph);
|
||||
labels.push_back(GetOrCreateGraphLabel(graph));
|
||||
}
|
||||
|
||||
// Add LabelSwith node.
|
||||
auto switch_node = LabelSwitch(cnode->input(1), labels);
|
||||
|
||||
// Set child graph attribute for switch node.
|
||||
SetChildGrapAttr(switch_node, graphes);
|
||||
|
||||
// Setup return label if required.
|
||||
const bool is_tail_call = (cnode == tail_call_node_);
|
||||
const bool need_return = (return_label_ == kNoLabel || !is_tail_call);
|
||||
auto [para_pool, output_para, return_label] = MakeReturn(cnode, need_return);
|
||||
|
||||
// Handle sub-graphs recursively.
|
||||
for (auto &graph : graphes) {
|
||||
HandleSubGraph(graph, para_pool, output_para, return_label);
|
||||
}
|
||||
}
|
||||
|
||||
ParameterPoolPtr GetParameterPool(bool is_last_call) {
|
||||
if (!is_last_call) {
|
||||
// There are multiple calls in this graph, use a new parameter pool
|
||||
|
@ -483,10 +540,13 @@ class AscendAutoMonadConverter {
|
|||
}
|
||||
|
||||
std::vector<GraphArgPair> GetSwitchBranches(const CNodePtr &cnode) {
|
||||
constexpr size_t true_index = 2;
|
||||
constexpr size_t false_index = 3;
|
||||
// True branch first, then false branch.
|
||||
return {GetSwitchBranch(cnode, true_index), GetSwitchBranch(cnode, false_index)};
|
||||
constexpr size_t cond_start_index = 2;
|
||||
// switch branches
|
||||
std::vector<GraphArgPair> switch_branches;
|
||||
for (size_t index = cond_start_index; index < cnode->inputs().size(); ++index) {
|
||||
switch_branches.emplace_back(GetSwitchBranch(cnode, index));
|
||||
}
|
||||
return switch_branches;
|
||||
}
|
||||
|
||||
//
|
||||
|
|
|
@ -928,9 +928,8 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchInputs(const CNodePtr &cno
|
|||
return cnode_inputs;
|
||||
}
|
||||
|
||||
void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, const AnfNodePtr &real_input) {
|
||||
void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, const std::vector<AnfNodePtr> &real_inputs) {
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
MS_EXCEPTION_IF_NULL(real_input);
|
||||
if (!(AnfAlgo::CheckPrimitiveType(cnode, prim::kPrimPartial))) {
|
||||
MS_LOG(EXCEPTION) << "Node: " << cnode->DebugString() << "is not a partial node.";
|
||||
}
|
||||
|
@ -940,24 +939,37 @@ void SessionBasic::CreateCallNodeReturnFunction(const CNodePtr &cnode, const Anf
|
|||
auto ret = partial_kernel_graph->get_return();
|
||||
MS_EXCEPTION_IF_NULL(ret);
|
||||
auto return_input = ret->input(kFirstDataInputIndex);
|
||||
// if kernel graph return node is a function
|
||||
// return node is a function
|
||||
std::vector<AnfNodePtr> call_inputs = {
|
||||
partial_kernel_graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
|
||||
AnfNodePtr real_kernel_graph;
|
||||
if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial)) {
|
||||
std::vector<AnfNodePtr> call_inputs = {
|
||||
partial_kernel_graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimCall->name())))};
|
||||
auto return_input_cnode = return_input->cast<CNodePtr>();
|
||||
|
||||
auto partial_inputs = return_input_cnode->inputs();
|
||||
call_inputs.insert(call_inputs.end(), partial_inputs.begin() + kFirstDataInputIndex, partial_inputs.end());
|
||||
real_kernel_graph = partial_inputs[kFirstDataInputIndex];
|
||||
} else { // return node is kernel graph
|
||||
call_inputs.emplace_back(return_input);
|
||||
real_kernel_graph = return_input;
|
||||
}
|
||||
|
||||
// new call node inputs
|
||||
for (auto real_input : real_inputs) {
|
||||
auto parameter_for_input = CreateNewParameterFromCNode(real_input, partial_kernel_graph.get());
|
||||
call_inputs.emplace_back(parameter_for_input);
|
||||
auto call_node = partial_kernel_graph->NewCNode(call_inputs);
|
||||
// update abstract
|
||||
KernelGraphPtr sub_partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_inputs[kFirstDataInputIndex]);
|
||||
}
|
||||
|
||||
auto call_node = partial_kernel_graph->NewCNode(call_inputs);
|
||||
// update abstract
|
||||
MS_EXCEPTION_IF_NULL(real_kernel_graph);
|
||||
if (real_kernel_graph->isa<ValueNode>() && IsValueNode<FuncGraph>(real_kernel_graph)) {
|
||||
KernelGraphPtr sub_partial_kernel_graph = GetValueNode<KernelGraphPtr>(real_kernel_graph);
|
||||
MS_EXCEPTION_IF_NULL(sub_partial_kernel_graph);
|
||||
auto ret_partial = sub_partial_kernel_graph->get_return();
|
||||
call_node->set_abstract(ret_partial->abstract());
|
||||
// update return input
|
||||
ret->set_input(kFirstDataInputIndex, call_node);
|
||||
}
|
||||
// update return input
|
||||
ret->set_input(kFirstDataInputIndex, call_node);
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph) {
|
||||
|
@ -977,9 +989,11 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr
|
|||
auto node = make_tuple_node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto make_tuple_inputs = node->inputs();
|
||||
// there is real input in call, should put it to make_tuple in switch_layer
|
||||
auto real_input = cnode->input(kFirstDataInputIndex);
|
||||
auto real_input_back = graph->GetBackendAnfByFrontAnf(real_input);
|
||||
// there are real inputs in call, should put it to make_tuple in switch_layer
|
||||
std::vector<AnfNodePtr> real_inputs;
|
||||
for (size_t idx = kFirstDataInputIndex; idx < cnode->inputs().size(); ++idx) {
|
||||
real_inputs.emplace_back(graph->GetBackendAnfByFrontAnf(cnode->input(idx)));
|
||||
}
|
||||
std::vector<AnfNodePtr> new_make_tuple_inputs = {
|
||||
graph->NewValueNode(NewValueNode(std::make_shared<Primitive>(prim::kPrimMakeTuple->name())))};
|
||||
for (size_t idx = kFirstDataInputIndex; idx < make_tuple_inputs.size(); idx++) {
|
||||
|
@ -990,10 +1004,18 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr
|
|||
auto partial_node = partial_idx->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(partial_node);
|
||||
// update kernel graph when switch_layer node return function
|
||||
CreateCallNodeReturnFunction(partial_node, real_input_back);
|
||||
auto partial_input = partial_node->input(kFirstDataInputIndex);
|
||||
KernelGraphPtr partial_kernel_graph = GetValueNode<KernelGraphPtr>(partial_input);
|
||||
MS_EXCEPTION_IF_NULL(partial_kernel_graph);
|
||||
auto ret = partial_kernel_graph->get_return();
|
||||
MS_EXCEPTION_IF_NULL(ret);
|
||||
auto return_input = ret->input(kFirstDataInputIndex);
|
||||
if (AnfAlgo::CheckPrimitiveType(return_input, prim::kPrimPartial) || IsValueNode<KernelGraph>(return_input)) {
|
||||
CreateCallNodeReturnFunction(partial_node, real_inputs);
|
||||
}
|
||||
|
||||
std::vector<AnfNodePtr> new_partial_inputs = partial_node->inputs();
|
||||
new_partial_inputs.emplace_back(real_input_back);
|
||||
new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end());
|
||||
auto new_partial = graph->NewCNode(new_partial_inputs);
|
||||
new_make_tuple_inputs.emplace_back(new_partial);
|
||||
}
|
||||
|
@ -1003,7 +1025,7 @@ std::vector<AnfNodePtr> SessionBasic::CreateCallSwitchLayerInputs(const CNodePtr
|
|||
std::vector<AnfNodePtr> new_partial_inputs;
|
||||
new_partial_inputs.emplace_back(NewValueNode(std::make_shared<Primitive>(prim::kPrimPartial->name())));
|
||||
new_partial_inputs.emplace_back(partial_idx);
|
||||
new_partial_inputs.emplace_back(real_input_back);
|
||||
new_partial_inputs.insert(new_partial_inputs.end(), real_inputs.begin(), real_inputs.end());
|
||||
auto new_partial = graph->NewCNode(new_partial_inputs);
|
||||
new_make_tuple_inputs.emplace_back(new_partial);
|
||||
}
|
||||
|
|
|
@ -147,7 +147,7 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
|
|||
void GetNewCNodeInputs(const CNodePtr &cnode, KernelGraph *graph, std::vector<AnfNodePtr> *cnode_inputs,
|
||||
std::unordered_map<AnfNodePtr, AnfNodePtr> *other_graph_cnode);
|
||||
std::vector<AnfNodePtr> CreateCallSwitchLayerInputs(const CNodePtr &cnode, KernelGraph *graph);
|
||||
void CreateCallNodeReturnFunction(const CNodePtr &cnode, const AnfNodePtr &real_input);
|
||||
void CreateCallNodeReturnFunction(const CNodePtr &cnode, const std::vector<AnfNodePtr> &real_inputs);
|
||||
|
||||
protected:
|
||||
friend class Executor;
|
||||
|
|
Loading…
Reference in New Issue