forked from mindspore-Ecosystem/mindspore
!25748 [MS][LITE]support switch layer
Merge pull request !25748 from mengyuanli/support_switch_layer
This commit is contained in:
commit
4fd9d5739a
|
@ -297,17 +297,33 @@ int MindIRControlFlowAdjust::MoveCallInputsToPartialFusionInputs(const std::set<
|
|||
auto make_tuple_op = switch_layer_cnode_inputs[kSwitchLayerMakeTupleIndex]->cast<CNodePtr>();
|
||||
auto make_tuple_op_intpus = make_tuple_op->inputs();
|
||||
for (size_t i = 1; i < make_tuple_op_intpus.size(); i++) {
|
||||
if (!IsPartialFusion(make_tuple_op_intpus[i])) {
|
||||
MS_LOG(ERROR) << "switch layer op make tuple inputs not is partial fusion op, not support now.";
|
||||
if (IsPartialFusion(make_tuple_op_intpus[i])) {
|
||||
auto partial_node = make_tuple_op_intpus[i]->cast<CNodePtr>();
|
||||
auto partial_node_inputs = partial_node->inputs();
|
||||
std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(), std::back_inserter(partial_node_inputs));
|
||||
partial_node->set_inputs(partial_node_inputs);
|
||||
continue;
|
||||
}
|
||||
if (!utils::isa<ValueNodePtr>(make_tuple_op_intpus[i])) {
|
||||
MS_LOG(ERROR)
|
||||
<< "switch layer op make tuple inputs not is partial fusion op or function graph, not support now.";
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
auto partial_node = make_tuple_op_intpus[i]->cast<CNodePtr>();
|
||||
auto partial_node_inputs = partial_node->inputs();
|
||||
std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(), std::back_inserter(partial_node_inputs));
|
||||
partial_node->set_inputs(partial_node_inputs);
|
||||
auto make_tuple_op_value_input = make_tuple_op_intpus[i]->cast<ValueNodePtr>();
|
||||
if (GetValueNode<FuncGraphPtr>(make_tuple_op_value_input) == nullptr) {
|
||||
MS_LOG(ERROR)
|
||||
<< "switch layer op make tuple inputs not is partial fusion op or function graph, not support now.";
|
||||
return RET_NOT_SUPPORT;
|
||||
}
|
||||
std::vector<AnfNodePtr> partial_cnode_inputs = {lite::GetPartialFusionPrim(), make_tuple_op_value_input};
|
||||
std::copy(call_cnode_inputs.begin() + 1, call_cnode_inputs.end(), std::back_inserter(partial_cnode_inputs));
|
||||
auto partial_cnode = graph->NewCNode(partial_cnode_inputs);
|
||||
MS_CHECK_TRUE_MSG(partial_cnode != nullptr, RET_NULL_PTR, "Failed to create C node.");
|
||||
partial_cnode->set_fullname_with_scope("partial_" + make_tuple_op->fullname_with_scope() + "_" +
|
||||
std::to_string(i));
|
||||
make_tuple_op->set_input(i, partial_cnode);
|
||||
}
|
||||
}
|
||||
|
||||
call_cnode->set_inputs({call_first_input_cnode});
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue