forked from mindspore-Ecosystem/mindspore
!26812 add dump flag for transdata which is inserted for multi-output node
Merge pull request !26812 from yuchaojie/ir_fusion
This commit is contained in:
commit
900c9ef88b
|
@ -212,24 +212,30 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
|
|||
std::string trans_opname = (spec_format == kOpFormat_FRACTAL_ZN_RNN || spec_format == kOpFormat_ND_RNN_BIAS)
|
||||
? prim::kPrimTransDataRNN->name()
|
||||
: prim::kPrimTransData->name();
|
||||
auto orig_node = node;
|
||||
if (!is_insert_input && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == prim::kTupleGetItem) {
|
||||
auto cnode = node->cast<CNodePtr>();
|
||||
orig_node = cnode->input(kRealInputNodeIndexInTupleGetItem);
|
||||
}
|
||||
|
||||
if (!need_padding) {
|
||||
// don't need padding insert transdata only
|
||||
trans_data = NewTransOpNode(func_graph, input_node, node, kernel_select, need_padding, trans_opname);
|
||||
trans_data = NewTransOpNode(func_graph, input_node, orig_node, kernel_select, need_padding, trans_opname);
|
||||
trans_node = trans_data;
|
||||
} else if (is_insert_input) {
|
||||
// if need padding & is input need insert a transdata
|
||||
// reshape[padding shape] -> transdata[padding shape] -> node
|
||||
auto padding_shape = trans::PaddingShape(input_node_out_shape, AnfAlgo::GetInputFormat(node, insert_index),
|
||||
AnfAlgo::GetInputReshapeType(node, insert_index));
|
||||
auto reshape_node = CreateReshapeNode(func_graph, input_node, node, kernel_select, padding_shape);
|
||||
trans_data = NewTransOpNode(func_graph, reshape_node, node, kernel_select, need_padding, trans_opname);
|
||||
auto reshape_node = CreateReshapeNode(func_graph, input_node, orig_node, kernel_select, padding_shape);
|
||||
trans_data = NewTransOpNode(func_graph, reshape_node, orig_node, kernel_select, need_padding, trans_opname);
|
||||
trans_node = trans_data;
|
||||
trans_data->set_abstract(input_node->abstract());
|
||||
} else {
|
||||
// if need padding & is output need insert a transdata
|
||||
// node -> transdata[padding shape] -> reshape[ori_shape]
|
||||
trans_data = NewTransOpNode(func_graph, input_node, node, kernel_select, need_padding, trans_opname);
|
||||
auto reshape_node = CreateReshapeNode(func_graph, trans_data, node, kernel_select, input_node_out_shape);
|
||||
trans_data = NewTransOpNode(func_graph, input_node, orig_node, kernel_select, need_padding, trans_opname);
|
||||
auto reshape_node = CreateReshapeNode(func_graph, trans_data, orig_node, kernel_select, input_node_out_shape);
|
||||
trans_node = reshape_node;
|
||||
}
|
||||
if (trans_opname == prim::kPrimTransDataRNN->name()) {
|
||||
|
|
Loading…
Reference in New Issue