!25748 [MS][LITE]support switch layer

Merge pull request !25748 from mengyuanli/support_switch_layer
This commit is contained in:
i-robot 2021-11-02 01:06:53 +00:00 committed by Gitee
commit 4fd9d5739a
1 changed files with 23 additions and 7 deletions

View File

@ -297,17 +297,33 @@ int MindIRControlFlowAdjust::MoveCallInputsToPartialFusionInputs(const std::set<
auto make_tuple_op = switch_layer_cnode_inputs[kSwitchLayerMakeTupleIndex]->cast<CNodePtr>();
auto make_tuple_op_intpus = make_tuple_op->inputs();
for (size_t i = 1; i < make_tuple_op_intpus.size(); i++) {
if (!IsPartialFusion(make_tuple_op_intpus[i])) {
MS_LOG(ERROR) << "switch layer op make tuple inputs not is partial fusion op, not support now.";
if (IsPartialFusion(make_tuple_op_intpus[i])) {
auto partial_node = make_tuple_op_intpus[i]->cast<CNodePtr>();
auto partial_node_inputs = partial_node->inputs();
std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(), std::back_inserter(partial_node_inputs));
partial_node->set_inputs(partial_node_inputs);
continue;
}
if (!utils::isa<ValueNodePtr>(make_tuple_op_intpus[i])) {
MS_LOG(ERROR)
<< "switch layer op make tuple inputs not is partial fusion op or function graph, not support now.";
return RET_NOT_SUPPORT;
}
auto partial_node = make_tuple_op_intpus[i]->cast<CNodePtr>();
auto partial_node_inputs = partial_node->inputs();
std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(), std::back_inserter(partial_node_inputs));
partial_node->set_inputs(partial_node_inputs);
auto make_tuple_op_value_input = make_tuple_op_intpus[i]->cast<ValueNodePtr>();
if (GetValueNode<FuncGraphPtr>(make_tuple_op_value_input) == nullptr) {
MS_LOG(ERROR)
<< "switch layer op make tuple inputs not is partial fusion op or function graph, not support now.";
return RET_NOT_SUPPORT;
}
std::vector<AnfNodePtr> partial_cnode_inputs = {lite::GetPartialFusionPrim(), make_tuple_op_value_input};
std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(), std::back_inserter(partial_cnode_inputs));
auto partial_cnode = graph->NewCNode(partial_cnode_inputs);
MS_CHECK_TRUE_MSG(partial_cnode != nullptr, RET_NULL_PTR, "Failed to create C node.");
partial_cnode->set_fullname_with_scope("partial_" + make_tuple_op->fullname_with_scope() + "_" +
std::to_string(i));
make_tuple_op->set_input(i, partial_cnode);
}
}
call_cnode->set_inputs({call_first_input_cnode});
}
}