diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc index f95882994e7..05d36fd4f22 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.cc @@ -31,156 +31,18 @@ namespace mindspore { namespace opt { using KernelBuildInfoBuilder = kernel::KernelBuildInfo::KernelBuildInfoBuilder; namespace { -AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, - const KernelSelectPtr &kernel_select, const std::vector &dst_shape) { - std::vector trans_inputs; - auto prim = std::make_shared(prim::kPrimReshape->name()); - trans_inputs.emplace_back(NewValueNode(prim)); - trans_inputs.emplace_back(input_node); - auto reshape = func_graph->NewCNode(trans_inputs); - AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {dst_shape}, reshape.get()); - AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape); - AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(dst_shape), reshape); - reshape->set_scope(input_node->scope()); - kernel_select->SelectKernel(reshape); - return reshape; -} - -AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select, size_t insert_index, bool is_insert_input) { - AnfNodePtr trans_node = nullptr; - AnfNodePtr input_node = node; - CNodePtr trans_data = nullptr; - std::string input_format = is_insert_input ? kOpFormat_DEFAULT : AnfAlgo::GetOutputFormat(node, 0); - std::string dst_format = is_insert_input ? AnfAlgo::GetInputFormat(node, 0) : kOpFormat_DEFAULT; - TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0); - std::vector padding_axis = AnfAlgo::GetOutputReshapeType(node, 0); - MS_EXCEPTION_IF_NULL(node); - // if insert transdata for input we need to change the input - if (is_insert_input) { - if (!node->isa()) { - MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode"; - } - auto cnode = node->cast(); - dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index); - dst_format = AnfAlgo::GetInputFormat(cnode, insert_index); - input_node = AnfAlgo::GetInputNode(cnode, insert_index); - padding_axis = AnfAlgo::GetInputReshapeType(node, 0); - } - bool need_padding = false; - if (is_insert_input) { - need_padding = (trans::IsNeedPadding(dst_format, AnfAlgo::GetOutputInferShape(input_node, 0).size())); - } else { - need_padding = (trans::IsNeedPadding(input_format, AnfAlgo::GetOutputInferShape(input_node, 0).size())); - } - if (!need_padding) { - // don't need padding insert transdata only - trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name()); - trans_node = trans_data; - } else if (is_insert_input) { - // if need padding & is input need insert a transdata - // reshape[padding shape] -> transdata[padding shape] -> node - auto padding_shape = - trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0)); - auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape); - trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, prim::KPrimTransData->name()); - trans_node = trans_data; - } else { - // if need padding & is output need insert a transdata - // node -> transdata[padding shape] -> reshape[ori_shape] - trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, prim::KPrimTransData->name()); - auto reshape_node = - CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0)); - trans_node = reshape_node; - } - // refresh the transdata's format to ori format & dst format - RefreshKernelBuildInfo(input_format, dst_format, dtype, trans_data, padding_axis); - return trans_node; -} - -AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index, - const KernelSelectPtr &kernel_select) { - MS_EXCEPTION_IF_NULL(node); - auto input_node = AnfAlgo::GetInputNode(node, index); - auto node_with_index = AnfAlgo::VisitKernel(input_node, 0); - MS_EXCEPTION_IF_NULL(node_with_index.first); - auto real_input = node_with_index.first; - if (real_input->isa() || real_input->isa()) { - input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select); - MS_EXCEPTION_IF_NULL(input_node); - AnfAlgo::SetNodeInput(node, input_node, index); - } - if (AnfAlgo::GetInputFormat(node, index) == kOpFormat_NC1KHKWHWC0) { - MS_LOG(EXCEPTION) << "got the format " << AnfAlgo::GetInputFormat(node, index) - << "when inserting the transdata node " << node->DebugString(); - } - std::vector origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index); - std::string dest_format = AnfAlgo::GetInputFormat(node, index); - if (kNeedTransFormatSet.find(dest_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) { - MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index) - << " To DefaultFormat , index: " << index; - return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, true); - } - return input_node; -} - -AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select) { - MS_EXCEPTION_IF_NULL(node); - std::string output_format = AnfAlgo::GetOutputFormat(node, 0); - std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, 0); - if (output_format == kOpFormat_NC1KHKWHWC0) { - MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node " - << node->DebugString(); - } - if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) { - MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0"; - return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, false); - } - return node; -} - -AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, - const KernelSelectPtr &kernel_select) { - MS_EXCEPTION_IF_NULL(func_graph); - MS_EXCEPTION_IF_NULL(node); - std::vector make_tuple_inputs; - make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); - for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(node); ++output_idx) { - std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx); - if (output_format == kOpFormat_NC1KHKWHWC0) { - MS_LOG(EXCEPTION) << "Got the special format" << output_format << " when insert the transdata node " - << node->DebugString(); - } - auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx); - std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); - if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) { - make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, false)); - } else { - // No need insert trans op. - make_tuple_inputs.push_back(tuple_getitem); - } - } - AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); - return make_tuple; -} -} // namespace -void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type, - const AnfNodePtr &trans_data, const std::vector &reshape_type) { - MS_EXCEPTION_IF_NULL(trans_data); - MS_EXCEPTION_IF_NULL(trans_data->kernel_info()); - auto ori_build_info = trans_data->kernel_info()->select_kernel_build_info(); +kernel::KernelBuildInfoPtr RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, + const AnfNodePtr &node, const TypeId device_type, + const kernel::KernelBuildInfo &ori_build_info) { KernelBuildInfoBuilder builder; builder.SetInputsFormat({input_format}); - builder.SetInputReshapeType({reshape_type}); - builder.SetInputReshapeType({reshape_type}); builder.SetOutputsFormat({output_format}); builder.SetInputsDeviceType({device_type}); builder.SetOutputsDeviceType({device_type}); - builder.SetKernelType(ori_build_info->kernel_type()); - builder.SetFusionType(ori_build_info->fusion_type()); - builder.SetProcessor(ori_build_info->processor()); - AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), trans_data.get()); + builder.SetKernelType(ori_build_info.kernel_type()); + builder.SetFusionType(ori_build_info.fusion_type()); + builder.SetProcessor(ori_build_info.processor()); + return builder.Build(); } CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, @@ -193,7 +55,8 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, trans_inputs.push_back(input); CNodePtr trans_node = func_graph->NewCNode(trans_inputs); MS_EXCEPTION_IF_NULL(trans_node); - auto padding_axis = AnfAlgo::GetOutputReshapeType(input, 0); + std::vector padding_axis; + padding_axis = AnfAlgo::GetOutputReshapeType(input, 0); if (need_padding) { // if need padding we should set the transdata node's shape to the padding shape AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input, 0)}, @@ -216,6 +79,154 @@ CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, return trans_node; } +AnfNodePtr CreateReshapeNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input_node, + const KernelSelectPtr &kernel_select, const std::vector &dst_shape) { + std::vector trans_inputs; + auto prim = std::make_shared(prim::kPrimReshape->name()); + trans_inputs.emplace_back(NewValueNode(prim)); + trans_inputs.emplace_back(input_node); + auto reshape = func_graph->NewCNode(trans_inputs); + AnfAlgo::SetOutputInferTypeAndShape({AnfAlgo::GetOutputInferDataType(input_node, 0)}, {dst_shape}, reshape.get()); + AnfAlgo::SetNodeAttr(kAttrVisited, MakeValue(true), reshape); + AnfAlgo::SetNodeAttr(kAttrShape, MakeValue(dst_shape), reshape); + reshape->set_scope(input_node->scope()); + kernel_select->SelectKernel(reshape); + return reshape; +} + +AnfNodePtr GetTransInputNodePtr(const FuncGraphPtr &func_graph, const CNodePtr &node, size_t index, + const KernelSelectPtr &kernel_select) { + MS_EXCEPTION_IF_NULL(node); + auto input_node = AnfAlgo::GetInputNode(node, index); + auto node_with_index = AnfAlgo::VisitKernel(input_node, 0); + MS_EXCEPTION_IF_NULL(node_with_index.first); + auto real_input = node_with_index.first; + if (real_input->isa() || real_input->isa()) { + input_node = InsertTransOpForOutput(func_graph, input_node, kernel_select); + MS_EXCEPTION_IF_NULL(input_node); + AnfAlgo::SetNodeInput(node, input_node, index); + } + if (AnfAlgo::GetInputFormat(node, index) == kOpFormat_NC1KHKWHWC0) { + MS_LOG(EXCEPTION) << "got the format " << AnfAlgo::GetInputFormat(node, index) + << "when inserting the transdata node " << node->DebugString(); + } + std::vector origin_shape = AnfAlgo::GetPrevNodeOutputInferShape(node, index); + std::string origin_format = kOpFormat_DEFAULT; + std::string dest_format = AnfAlgo::GetInputFormat(node, index); + if (kNeedTransFormatSet.find(dest_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) { + MS_LOG(DEBUG) << node->DebugString() << "Insert transdata " << AnfAlgo::GetInputFormat(node, index) + << " To DefaultFormat , index: " << index; + return AddTransOpNodeToGraph(func_graph, node, kernel_select, index, origin_format, dest_format, kTransDataOpName, + true); + } + return input_node; +} + +AnfNodePtr InsertTransOpForSingleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const KernelSelectPtr &kernel_select) { + MS_EXCEPTION_IF_NULL(node); + std::string output_format = AnfAlgo::GetOutputFormat(node, 0); + std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, 0); + if (output_format == kOpFormat_NC1KHKWHWC0) { + MS_LOG(EXCEPTION) << "got the hw format " << output_format << "when insert the transdata node " + << node->DebugString(); + } + std::string origin_format = output_format; + std::string dest_format = kOpFormat_DEFAULT; + if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) { + MS_LOG(DEBUG) << "Inserted Transdata " << output_format << " To default , index :0"; + return AddTransOpNodeToGraph(func_graph, node, kernel_select, 0, origin_format, dest_format, kTransDataOpName, + false); + } + return node; +} + +AnfNodePtr InsertTransOpForMultipleOutput(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const KernelSelectPtr &kernel_select) { + MS_EXCEPTION_IF_NULL(func_graph); + MS_EXCEPTION_IF_NULL(node); + std::vector make_tuple_inputs; + make_tuple_inputs.push_back(NewValueNode(prim::kPrimMakeTuple)); + for (size_t output_idx = 0; output_idx < AnfAlgo::GetOutputTensorNum(node); ++output_idx) { + std::string output_format = AnfAlgo::GetOutputFormat(node, output_idx); + if (output_format == kOpFormat_NC1KHKWHWC0) { + MS_LOG(EXCEPTION) << "Got the special format" << output_format << " when insert the transdata node " + << node->DebugString(); + } + auto tuple_getitem = CreatTupleGetItemNode(func_graph, node, output_idx); + std::vector origin_shape = AnfAlgo::GetOutputInferShape(node, output_idx); + std::string dest_format = kOpFormat_DEFAULT; + if (kNeedTransFormatSet.find(output_format) != kNeedTransFormatSet.end() && origin_shape.size() > 1) { + make_tuple_inputs.emplace_back(AddTransOpNodeToGraph(func_graph, tuple_getitem, kernel_select, 0, output_format, + dest_format, kTransDataOpName, false)); + } else { + // No need insert trans op. + make_tuple_inputs.push_back(tuple_getitem); + } + } + AnfNodePtr make_tuple = func_graph->NewCNode(make_tuple_inputs); + return make_tuple; +} +} // namespace +AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const KernelSelectPtr &kernel_select, size_t insert_index, + const std::string &origin_format, const std::string &dest_format, + const std::string &op_name, bool is_insert_input) { + AnfNodePtr trans_node = nullptr; + AnfNodePtr input_node = node; + AnfNodePtr trans_data = nullptr; + TypeId dtype = AnfAlgo::GetOutputDeviceDataType(node, 0); + MS_EXCEPTION_IF_NULL(node); + if (origin_format.empty() || dest_format.empty()) { + MS_LOG(EXCEPTION) << "trans op format is error, origin = " << origin_format << ", dest " << origin_format; + } + // if insert transdata for input we need to change the input + if (is_insert_input) { + if (!node->isa()) { + MS_LOG(EXCEPTION) << "cannot insert a transdata node to a node's input which the node is not a cnode"; + } + auto cnode = node->cast(); + dtype = AnfAlgo::GetInputDeviceDataType(cnode, insert_index); + MS_EXCEPTION_IF_NULL(cnode); + input_node = AnfAlgo::GetInputNode(cnode, insert_index); + } + bool need_padding = false; + if (is_insert_input) { + need_padding = (trans::IsNeedPadding(dest_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) && + op_name == kTransDataOpName); + } else { + need_padding = (trans::IsNeedPadding(origin_format, AnfAlgo::GetOutputInferShape(input_node, 0).size()) && + op_name == kTransDataOpName); + } + if (!need_padding) { + // don't need padding insert transdata only + trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name); + trans_node = trans_data; + } else if (is_insert_input) { + // if need padding & is input need insert a transdata + // reshape[padding shape] -> transdata[padding shape] -> node + auto padding_shape = + trans::PaddingShapeTo4d(AnfAlgo::GetOutputInferShape(input_node, 0), AnfAlgo::GetInputReshapeType(node, 0)); + auto reshape_node = CreateReshapeNode(func_graph, input_node, kernel_select, padding_shape); + trans_data = NewTransOpNode(func_graph, reshape_node, kernel_select, need_padding, op_name); + trans_node = trans_data; + } else { + // if need padding & is output need insert a transdata + // node -> transdata[padding shape] -> reshape[ori_shape] + trans_data = NewTransOpNode(func_graph, input_node, kernel_select, need_padding, op_name); + auto reshape_node = + CreateReshapeNode(func_graph, trans_data, kernel_select, AnfAlgo::GetOutputInferShape(input_node, 0)); + trans_node = reshape_node; + } + // refresh the transdata's format to ori format & dst format + MS_EXCEPTION_IF_NULL(trans_data); + MS_EXCEPTION_IF_NULL(trans_data->kernel_info()); + auto trans_ori_build_info = trans_data->kernel_info()->select_kernel_build_info(); + auto kernel_build_info = RefreshKernelBuildInfo(origin_format, dest_format, input_node, dtype, *trans_ori_build_info); + AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info, trans_data.get()); + return trans_node; +} + AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, const TypeId &input_type, const TypeId &output_type, const std::vector &origin_shape, const TypeId &origin_type) { diff --git a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h index 66e3f2ad330..a5463131b47 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h +++ b/mindspore/ccsrc/pre_activate/ascend/ascend_helper.h @@ -58,11 +58,11 @@ class KernelQuery { } }; using KernelQueryPtr = std::shared_ptr; -void RefreshKernelBuildInfo(const std::string &input_format, const std::string &output_format, const TypeId device_type, - const AnfNodePtr &trans_data, const std::vector &reshape_type = {}); -CNodePtr NewTransOpNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const KernelSelectPtr &kernel_select, - const bool need_padding, const std::string &op_name); +AnfNodePtr AddTransOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &node, + const KernelSelectPtr &kernel_select, size_t insert_index, + const std::string &origin_format, const std::string &dest_format, + const std::string &op_name, bool is_insert_input); AnfNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &input, const std::string &format, const TypeId &input_type, const TypeId &output_type, diff --git a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc index 43857dddfd8..a9196c5c428 100644 --- a/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc +++ b/mindspore/ccsrc/pre_activate/ascend/format_type/deal_ref_trans_and_cast.cc @@ -105,8 +105,8 @@ AnfNodePtr AddAdditionalToRefOutput(const FuncGraphPtr &func_graph, const CNodeP // insert trans if (origin_format != cur_format && cur_shape.size() > 1) { auto kernel_select = std::make_shared(); - final_node = NewTransOpNode(func_graph, final_node, kernel_select, false, prim::KPrimTransData->name()); - RefreshKernelBuildInfo(cur_format, origin_format, origin_type, final_node); + final_node = AddTransOpNodeToGraph(func_graph, final_node, kernel_select, 0, cur_format, origin_format, + kTransDataOpName, false); final_index = 0; MS_EXCEPTION_IF_NULL(final_node); MS_LOG(INFO) << "DealRefTransAndCast add trans op, op debug info is " << final_node->DebugString(); diff --git a/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.cc b/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.cc index 0305104f5b5..2c77794b145 100644 --- a/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.cc +++ b/mindspore/ccsrc/pre_activate/ascend/ir_fission/transdata_split.cc @@ -67,30 +67,22 @@ bool TransDataSplit::DoSplit(const FuncGraphPtr &func_graph, const AnfNodePtr &n // if output_format=default transdata need split transdata->transpose else transpose->transdata if (output_format == kOpFormat_DEFAULT || output_format == kOpFormat_NCHW) { // trans input_format to hwcn - new_transdata_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast(), 0), kernel_select_, - false, prim::KPrimTransData->name()); - RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, AnfAlgo::GetOutputDeviceDataType(new_transdata_node, 0), - new_transdata_node); + new_transdata_node = + AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN, kTransDataOpName, true); // trans hwcn to default_format - new_transpose_node = - NewTransOpNode(func_graph, new_transdata_node, kernel_select_, false, prim::kPrimTranspose->name()); - RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, AnfAlgo::GetOutputDeviceDataType(new_transpose_node, 0), - new_transpose_node); + new_transpose_node = AddTransOpNodeToGraph(func_graph, new_transdata_node, kernel_select_, 0, kOpFormat_HWCN, + output_format, prim::kPrimTranspose->name(), false); AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector{3, 2, 0, 1}), new_transpose_node); new_replace_node = new_transpose_node; } else { // trans default to hwcn - new_transpose_node = NewTransOpNode(func_graph, AnfAlgo::GetInputNode(node->cast(), 0), kernel_select_, - false, prim::kPrimTranspose->name()); + new_transpose_node = AddTransOpNodeToGraph(func_graph, node, kernel_select_, 0, input_format, kOpFormat_HWCN, + prim::kPrimTranspose->name(), true); AnfAlgo::SetNodeAttr(kAttrPerm, MakeValue(std::vector{2, 3, 1, 0}), new_transpose_node); - RefreshKernelBuildInfo(input_format, kOpFormat_HWCN, AnfAlgo::GetOutputDeviceDataType(new_transpose_node, 0), - new_transpose_node); // trans hwcn to output_format - new_transdata_node = - NewTransOpNode(func_graph, new_transpose_node, kernel_select_, false, prim::KPrimTransData->name()); - RefreshKernelBuildInfo(kOpFormat_HWCN, output_format, AnfAlgo::GetOutputDeviceDataType(new_transdata_node, 0), - new_transdata_node); + new_transdata_node = AddTransOpNodeToGraph(func_graph, new_transpose_node, kernel_select_, 0, kOpFormat_HWCN, + output_format, kTransDataOpName, false); new_replace_node = new_transdata_node; } FuncGraphManagerPtr manager = func_graph->manager();