!26812 add dump flag for transdata which is inserted for multi-output node

Merge pull request !26812 from yuchaojie/ir_fusion
This commit is contained in:
i-robot 2021-11-26 07:03:36 +00:00 committed by Gitee
commit 900c9ef88b
1 changed files with 11 additions and 5 deletions

View File

@ -212,24 +212,30 @@ AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePt
std::string trans_opname = (spec_format == kOpFormat_FRACTAL_ZN_RNN || spec_format == kOpFormat_ND_RNN_BIAS)
? prim::kPrimTransDataRNN->name()
: prim::kPrimTransData->name();
auto orig_node = node;
if (!is_insert_input && node->isa<CNode>() && AnfAlgo::GetCNodeName(node) == prim::kTupleGetItem) {
auto cnode = node->cast<CNodePtr>();
orig_node = cnode->input(kRealInputNodeIndexInTupleGetItem);
}
if (!need_padding) {
// don't need padding insert transdata only
trans_data = NewTransOpNode(func_graph, input_node, node, kernel_select, need_padding, trans_opname);
trans_data = NewTransOpNode(func_graph, input_node, orig_node, kernel_select, need_padding, trans_opname);
trans_node = trans_data;
} else if (is_insert_input) {
// if need padding & is input need insert a transdata
// reshape[padding shape] -> transdata[padding shape] -> node
auto padding_shape = trans::PaddingShape(input_node_out_shape, AnfAlgo::GetInputFormat(node, insert_index),
AnfAlgo::GetInputReshapeType(node, insert_index));
auto reshape_node = CreateReshapeNode(func_graph, input_node, node, kernel_select, padding_shape);
trans_data = NewTransOpNode(func_graph, reshape_node, node, kernel_select, need_padding, trans_opname);
auto reshape_node = CreateReshapeNode(func_graph, input_node, orig_node, kernel_select, padding_shape);
trans_data = NewTransOpNode(func_graph, reshape_node, orig_node, kernel_select, need_padding, trans_opname);
trans_node = trans_data;
trans_data->set_abstract(input_node->abstract());
} else {
// if need padding & is output need insert a transdata
// node -> transdata[padding shape] -> reshape[ori_shape]
trans_data = NewTransOpNode(func_graph, input_node, node, kernel_select, need_padding, trans_opname);
auto reshape_node = CreateReshapeNode(func_graph, trans_data, node, kernel_select, input_node_out_shape);
trans_data = NewTransOpNode(func_graph, input_node, orig_node, kernel_select, need_padding, trans_opname);
auto reshape_node = CreateReshapeNode(func_graph, trans_data, orig_node, kernel_select, input_node_out_shape);
trans_node = reshape_node;
}
if (trans_opname == prim::kPrimTransDataRNN->name()) {