diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc
index 68392d18716..38af0f87eba 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.cc
@@ -158,13 +158,13 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetProcessor(Processor processor)
 
 std::shared_ptr<KernelBuildInfo> KernelBuildInfo::KernelBuildInfoBuilder::Build() { return kernel_build_info_; }
 
-void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(
+void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsReshapeType(
   const std::vector<std::vector<Axis>> &input_reshape_type) {
   MS_EXCEPTION_IF_NULL(kernel_build_info_);
   kernel_build_info_->input_reshape_type_ = input_reshape_type;
 }
 
-void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(
+void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsReshapeType(
   const std::vector<std::vector<Axis>> &output_reshape_type) {
   MS_EXCEPTION_IF_NULL(kernel_build_info_);
   kernel_build_info_->output_reshape_type_ = output_reshape_type;
@@ -189,5 +189,36 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string
   }
   kernel_build_info_->outputs_format_[index] = format;
 }
+void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(const std::vector<Axis> &input_reshape_type,
+                                                                  size_t index) {
+  if (index >= kernel_build_info_->input_reshape_type_.size()) {
+    MS_LOG(EXCEPTION) << "index out of range!";
+  }
+  std::copy(input_reshape_type.begin(), input_reshape_type.end(),
+            std::back_inserter(kernel_build_info_->input_reshape_type_[index]));
+}
+
+void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(const std::vector<Axis> &output_reshape_type,
+                                                                   size_t index) {
+  if (index >= kernel_build_info_->output_reshape_type_.size()) {
+    MS_LOG(EXCEPTION) << "index out of range!";
+  }
+  std::copy(output_reshape_type.begin(), output_reshape_type.end(),
+            std::back_inserter(kernel_build_info_->output_reshape_type_[index]));
+}
+
+void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputDeviceType(const TypeId &output_device_type, size_t index) {
+  if (index >= kernel_build_info_->outputs_device_type_.size()) {
+    MS_LOG(EXCEPTION) << "index out of range!";
+  }
+  kernel_build_info_->outputs_device_type_[index] = output_device_type;
+}
+
+void KernelBuildInfo::KernelBuildInfoBuilder::SetInputDeviceType(const TypeId &input_device_type, size_t index) {
+  if (index >= kernel_build_info_->inputs_device_type_.size()) {
+    MS_LOG(EXCEPTION) << "index out of range!";
+  }
+  kernel_build_info_->inputs_device_type_[index] = input_device_type;
+}
 }  // namespace kernel
 }  // namespace mindspore
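For reference, a minimal usage sketch of the new per-index setters (not part of the patch; the function name and the formats/types chosen here are illustrative). The bulk setters fix the number of slots first; the per-index device-type and format setters then overwrite a single existing slot and raise an exception when the index is out of range, while the per-index reshape-type setters append axes to the selected slot.

// Illustrative sketch only, assuming the usual MindSpore headers and that this
// code lives inside namespace mindspore.
#include <memory>
#include "backend/kernel_compiler/kernel_build_info.h"
#include "utils/utils.h"

namespace mindspore {
std::shared_ptr<kernel::KernelBuildInfo> BuildExampleKernelInfo() {
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  // Bulk setters size the slots: two inputs, one output.
  builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
  builder.SetInputsDeviceType({kNumberTypeFloat16, kNumberTypeFloat16});
  builder.SetInputsReshapeType({{}, {}});
  builder.SetOutputsFormat({kOpFormat_NC1HWC0});
  builder.SetOutputsDeviceType({kNumberTypeFloat16});
  builder.SetOutputsReshapeType({{}});
  // Per-index setters patch a single slot that already exists.
  builder.SetInputDeviceType(kNumberTypeFloat32, 0);
  builder.SetOutputFormat(kOpFormat_DEFAULT, 0);
  return builder.Build();
}
}  // namespace mindspore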
diff --git a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h
index 32debbbfabe..6303c7cdfc5 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/kernel_build_info.h
@@ -71,6 +71,10 @@ class KernelBuildInfo {
 
   std::vector<TypeId> GetAllOutputDeviceTypes() const;
 
+  std::vector<std::vector<Axis>> GetAllOutputReshapeType() const;
+
+  std::vector<std::vector<Axis>> GetAllInputReshapeType() const;
+
   OpPattern op_pattern() const { return op_pattern_; }
 
   FusionType fusion_type() const { return fusion_type_; }
@@ -109,7 +113,22 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
   KernelBuildInfoBuilder() { kernel_build_info_ = std::make_shared<KernelBuildInfo>(); }
 
   explicit KernelBuildInfoBuilder(std::shared_ptr<KernelBuildInfo> kernel_build_info)
-      : kernel_build_info_(std::move(kernel_build_info)) {}
+      : kernel_build_info_(std::make_shared<KernelBuildInfo>()) {
+    SetKernelType(kernel_build_info->kernel_type());
+    SetFusionType(kernel_build_info->fusion_type());
+    SetProcessor(kernel_build_info->processor());
+    SetOpPattern(kernel_build_info->op_pattern());
+    for (size_t index = 0; index < kernel_build_info->GetInputNum(); ++index) {
+      kernel_build_info_->inputs_device_type_.emplace_back(kernel_build_info->GetInputDeviceType(index));
+      kernel_build_info_->inputs_format_.emplace_back(kernel_build_info->GetInputFormat(index));
+      kernel_build_info_->input_reshape_type_.emplace_back(kernel_build_info->GetInputReshapeType(index));
+    }
+    for (size_t index = 0; index < kernel_build_info->GetOutputNum(); ++index) {
+      kernel_build_info_->outputs_device_type_.emplace_back(kernel_build_info->GetOutputDeviceType(index));
+      kernel_build_info_->outputs_format_.emplace_back(kernel_build_info->GetOutputFormat(index));
+      kernel_build_info_->output_reshape_type_.emplace_back(kernel_build_info->GetOutputReshapeType(index));
+    }
+  }
 
   ~KernelBuildInfoBuilder() = default;
 
@@ -123,9 +142,9 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
 
   void SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type);
 
-  void SetInputReshapeType(const std::vector<std::vector<Axis>> &input_reshape_type);
+  void SetInputsReshapeType(const std::vector<std::vector<Axis>> &input_reshape_type);
 
-  void SetOutputReshapeType(const std::vector<std::vector<Axis>> &output_reshape_type);
+  void SetOutputsReshapeType(const std::vector<std::vector<Axis>> &output_reshape_type);
 
   void SetFusionType(FusionType fusion_type);
 
@@ -137,6 +156,14 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
 
   void SetOutputFormat(const std::string &format, size_t index);
 
+  void SetInputReshapeType(const std::vector<Axis> &input_reshape_type, size_t index);
+
+  void SetOutputReshapeType(const std::vector<Axis> &output_reshape_type, size_t index);
+
+  void SetInputDeviceType(const TypeId &input_device_type, size_t index);
+
+  void SetOutputDeviceType(const TypeId &output_device_type, size_t index);
+
   std::shared_ptr<KernelBuildInfo> Build();
 
  private:
diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc
index e83502a7c6f..bc7cd6b42c3 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_select/tbe_kernel_select.cc
@@ -118,7 +118,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
     }
     builder.SetInputsDeviceType(inputs_device_type);
     builder.SetInputsFormat(inputs_format);
-    builder.SetInputReshapeType(inputs_reshape_type);
+    builder.SetInputsReshapeType(inputs_reshape_type);
     // output
     std::vector<std::string> outputs_format;
     std::vector<TypeId> outputs_device_type;
@@ -129,7 +129,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
     }
     builder.SetOutputsDeviceType(outputs_device_type);
     builder.SetOutputsFormat(outputs_format);
-    builder.SetOutputReshapeType(outputs_reshape_type);
+    builder.SetOutputsReshapeType(outputs_reshape_type);
     kernel_info_list_->emplace_back(builder.Build());
   }
   MS_LOG(INFO) << "end.";
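A sketch (not from this patch) of the pattern the new copying constructor enables: derive a builder from an existing KernelBuildInfo, adjust only the slots that change, and build a fresh object. This is how RefreshKernelBuildInfo and the new TransData split pass below use it. The helper name is hypothetical.

// Hypothetical helper, assumed to live inside namespace mindspore.
#include <memory>
#include "backend/kernel_compiler/kernel_build_info.h"
#include "utils/utils.h"

namespace mindspore {
std::shared_ptr<kernel::KernelBuildInfo> WithDefaultOutputFormat(
    const std::shared_ptr<kernel::KernelBuildInfo> &ori_build_info) {
  // The copying constructor copies kernel type, fusion type, processor, op
  // pattern and every per-slot format/device type/reshape type, so the
  // per-index edit below does not touch ori_build_info.
  auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(ori_build_info);
  builder->SetOutputFormat(kOpFormat_DEFAULT, 0);
  return builder->Build();
}
}  // namespace mindspore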
"backend/optimizer/ascend/format_type/insert_trans_op.h" #include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h" #include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h" +#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h" #include "backend/optimizer/pass/getitem_tuple.h" #include "backend/optimizer/pass/optimize_dependence.h" #include "backend/optimizer/pass/erase_visit_attr.h" @@ -228,6 +229,7 @@ void AscendMixPrecision(const std::shared_ptr &kernel_grap mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); + mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); mixed_precision_pm->AddPass(std::make_shared()); optimizer->AddPassManager(mixed_precision_pm); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc index db986f03f74..74f4f249e55 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ascend_helper.cc @@ -174,8 +174,8 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string & MS_EXCEPTION_IF_NULL(ori_build_info); auto builder = std::make_shared(ori_build_info); builder->SetInputsFormat({input_format}); - builder->SetInputReshapeType({reshape_type}); - builder->SetOutputReshapeType({reshape_type}); + builder->SetInputsReshapeType({reshape_type}); + builder->SetOutputsReshapeType({reshape_type}); builder->SetOutputsFormat({output_format}); if (type_id != kTypeUnknown) { builder->SetOutputsDeviceType({type_id}); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.cc new file mode 100644 index 00000000000..92f6c5799b3 --- /dev/null +++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.cc @@ -0,0 +1,65 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.cc b/mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.cc
new file mode 100644
index 00000000000..92f6c5799b3
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.cc
@@ -0,0 +1,65 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
+#include <vector>
+#include <memory>
+#include "backend/session/anf_runtime_algorithm.h"
+
+namespace mindspore {
+namespace opt {
+const BaseRef SplitUnsupportedTransData::DefinePattern() const {
+  VarPtr X = std::make_shared<Var>();
+  return VectorRef({prim::kPrimTransData, X});
+}
+
+const AnfNodePtr SplitUnsupportedTransData::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
+                                                    const EquivPtr &) const {
+  if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) {
+    return nullptr;
+  }
+  auto ori_trans_data = node->cast<CNodePtr>();
+  if (AnfAlgo::GetCNodeName(ori_trans_data) != prim::kPrimTransData->name()) {
+    return nullptr;
+  }
+  auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(ori_trans_data);
+  MS_EXCEPTION_IF_NULL(kernel_info);
+  if (kernel_info->GetInputNum() != 1 || kernel_info->GetOutputNum() != 1) {
+    MS_LOG(EXCEPTION) << "TransData node's kernel info should have exactly one input and one output: "
+                      << ori_trans_data->DebugString();
+  }
+  return SplitTransData(func_graph, ori_trans_data);
+}
+AnfNodePtr SplitUnsupportedTransData::SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const {
+  auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(trans_node);
+  if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() ||
+      kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) {
+    return trans_node;
+  }
+  auto builder_info_to_default = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
+  auto builder_info_to_special_format = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
+  builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT});
+  builder_info_to_special_format->SetInputsFormat({kOpFormat_DEFAULT});
+  std::vector<AnfNodePtr> next_trans_node_inputs = {
+    NewValueNode(std::make_shared<Primitive>(prim::kPrimTransData->name())), trans_node};
+  auto next_trans_node = func_graph->NewCNode(next_trans_node_inputs);
+  next_trans_node->set_abstract(trans_node->abstract());
+  AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), trans_node.get());
+  AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_format->Build(), next_trans_node.get());
+  return next_trans_node;
+}
+}  // namespace opt
+}  // namespace mindspore
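The new pass rewrites a TransData whose input and output formats are both hardware-specific (both in kHWSpecialFormatSet) into two chained TransData nodes that pass through the default format, since a single special-to-special transition is not supported. A hedged sketch of driving only this pass over a graph, mirroring the GraphOptimizer/PassManager pattern used in the unit tests; the function name is illustrative, not part of the patch.

#include <memory>
#include "backend/optimizer/common/optimizer.h"
#include "backend/optimizer/common/pass_manager.h"
#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"

namespace mindspore {
// Intended effect, schematically:
//   x[FRACTAL_Z] -> TransData -> y[NC1HWC0]
// becomes
//   x[FRACTAL_Z] -> TransData -> [DEFAULT] -> TransData -> y[NC1HWC0]
FuncGraphPtr RunSplitUnsupportedTransData(const FuncGraphPtr &graph) {
  auto optimizer = std::make_shared<opt::GraphOptimizer>();
  auto pm = std::make_shared<opt::PassManager>();
  pm->AddPass(std::make_shared<opt::SplitUnsupportedTransData>());
  optimizer->AddPassManager(pm);
  return optimizer->Optimize(graph);
}
}  // namespace mindspore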
diff --git a/mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.h b/mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.h
new file mode 100644
index 00000000000..d4df2b57a83
--- /dev/null
+++ b/mindspore/ccsrc/backend/optimizer/ascend/format_type/split_unsupported_transdata.h
@@ -0,0 +1,37 @@
+/**
+ * Copyright 2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_SPLIT_UNSUPPORTED_TRANSDATA_H
+#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_SPLIT_UNSUPPORTED_TRANSDATA_H
+
+#include "backend/optimizer/common/optimizer.h"
+
+namespace mindspore {
+namespace opt {
+class SplitUnsupportedTransData : public PatternProcessPass {
+ public:
+  explicit SplitUnsupportedTransData(bool multigraph = true)
+      : PatternProcessPass("split_unsupported_transdata", multigraph) {}
+  ~SplitUnsupportedTransData() override = default;
+  const BaseRef DefinePattern() const override;
+  const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
+
+ private:
+  AnfNodePtr SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const;
+};
+}  // namespace opt
+}  // namespace mindspore
+#endif  // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_SPLIT_UNSUPPORTED_TRANSDATA_H
diff --git a/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc b/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc
index 0a5cf3dd9e4..4e3a29bf30d 100644
--- a/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/format_type/insert_trans_op_test.cc
@@ -50,6 +50,8 @@ class TestHWInsertTransOp : public BackendCommon {
     KernelBuildInfoBuilder builder;
     builder.SetInputsFormat({format, format});
     builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
+    builder.SetInputsReshapeType({{}, {}});
+    builder.SetOutputsReshapeType({});
     builder.SetOutputsFormat({format});
     builder.SetOutputsDeviceType({kFloat16->type_id()});
     add->set_kernel_info(std::make_shared<device::KernelInfo>());
@@ -70,6 +72,8 @@ class TestHWInsertTransOp : public BackendCommon {
     EXPECT_NE(ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->input(1), nullptr);
     auto max_pool = ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->input(1);
     KernelBuildInfoBuilder builder;
+    builder.SetInputsReshapeType({{}});
+    builder.SetOutputsReshapeType({{}, {}});
     builder.SetInputsFormat({kOpFormat_DEFAULT});
     builder.SetInputsDeviceType({kFloat16->type_id()});
     builder.SetOutputsFormat({format, format});
@@ -88,6 +92,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
   ~MockInsertTransOpKernelSelectTrans4Dto5D() override = default;
   void SelectKernel(const CNodePtr &cnode) override {
     KernelBuildInfoBuilder builder;
+    builder.SetInputsReshapeType({{}});
+    builder.SetOutputsReshapeType({{}});
     builder.SetInputsFormat({"NCHW"});
     builder.SetInputsDeviceType({kFloat16->type_id()});
     builder.SetOutputsFormat({"NC1HWC0"});
diff --git a/tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc b/tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc
index 8c3ed29a0cb..d4a69e70a59 100644
--- a/tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/format_type/remove_internal_output_test.cc
@@ -53,6 +53,8 @@ class TestHWRemoveInternalOutput : public BackendCommon {
     KernelBuildInfoBuilder builder;
     builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
     builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
+    builder.SetInputsReshapeType({{}, {}});
+    builder.SetOutputsReshapeType({{}});
     builder.SetOutputsFormat({kOpFormat_NC1HWC0});
     builder.SetOutputsDeviceType({kFloat16->type_id()});
     add->set_kernel_info(std::make_shared<device::KernelInfo>());
@@ -78,6 +80,8 @@ class TestHWRemoveInternalOutput : public BackendCommon {
     kg->AddInternalOutput(tuple_getitem1, max_pool);
     kg->AddInternalOutput(tuple_getitem2, max_pool);
     KernelBuildInfoBuilder builder;
+    builder.SetInputsReshapeType({{}});
+    builder.SetOutputsReshapeType({{}, {}});
     builder.SetInputsFormat({kOpFormat_DEFAULT});
     builder.SetInputsDeviceType({kFloat32->type_id()});
     builder.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
@@ -99,6 +103,8 @@ class MockRemoveInternalOutputTransOpKernelSelect : public KernelSelect {
     builder.SetInputsDeviceType({kFloat16->type_id()});
     builder.SetOutputsFormat({kOpFormat_DEFAULT});
     builder.SetOutputsDeviceType({kFloat32->type_id()});
+    builder.SetInputsReshapeType({{}});
+    builder.SetOutputsReshapeType({{}});
     AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
   }
 };
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc
index 220e45f10a0..18ec4c41551 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fission/transdata_split_test.cc
@@ -51,6 +51,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
       builder.SetInputsDeviceType({kFloat16->type_id()});
       builder.SetOutputsFormat({"NC1HWC0"});
       builder.SetOutputsDeviceType({kFloat16->type_id()});
+      builder.SetInputsReshapeType({});
+      builder.SetOutputsReshapeType({});
       AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
     } else {
       KernelBuildInfoBuilder builder;
@@ -58,7 +60,10 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
       builder.SetInputsDeviceType({kFloat16->type_id()});
       builder.SetOutputsFormat({"NC1HWC0"});
       builder.SetOutputsDeviceType({kFloat16->type_id()});
+      builder.SetInputsReshapeType({});
+      builder.SetOutputsReshapeType({});
       AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
+    }
   }
 };
@@ -74,6 +79,8 @@ class MockTransdataSplitKernelSelect : public KernelSelect {
       builder.SetInputsDeviceType({kFloat16->type_id()});
       builder.SetOutputsFormat({"NCHW"});
       builder.SetOutputsDeviceType({kFloat16->type_id()});
+      builder.SetInputsReshapeType({{}});
+      builder.SetOutputsReshapeType({{}});
       AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
     } else {
       KernelBuildInfoBuilder builder;
@@ -81,6 +88,8 @@ class MockTransdataSplitKernelSelect : public KernelSelect {
       builder.SetInputsDeviceType({kFloat16->type_id()});
       builder.SetOutputsFormat({"NCHW"});
       builder.SetOutputsDeviceType({kFloat16->type_id()});
+      builder.SetInputsReshapeType({{}});
+      builder.SetOutputsReshapeType({{}});
       AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
     }
   }
@@ -116,6 +125,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) {
   builder.SetKernelType(KernelType::TBE_KERNEL);
   builder.SetFusionType(kernel::FusionType::ELEMWISE);
   builder.SetProcessor(kernel::Processor::AICORE);
+  builder.SetInputsReshapeType({{}});
+  builder.SetOutputsReshapeType({{}});
   auto kernel_info = std::make_shared<device::KernelInfo>();
   kernel_info->set_select_kernel_build_info(builder.Build());
   transpose->set_kernel_info(kernel_info);
@@ -162,6 +173,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_nchw_fraz) {
   builder.SetKernelType(KernelType::TBE_KERNEL);
   builder.SetFusionType(kernel::FusionType::ELEMWISE);
   builder.SetProcessor(kernel::Processor::AICORE);
+  builder.SetInputsReshapeType({{}});
+  builder.SetOutputsReshapeType({{}});
   auto kernel_info = std::make_shared<device::KernelInfo>();
   kernel_info->set_select_kernel_build_info(builder.Build());
   transpose->set_kernel_info(kernel_info);
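All of the test updates follow the same pattern: every KernelBuildInfoBuilder that a fixture or mock kernel-select fills in must now also set reshape types whose entry count matches the declared inputs and outputs. A condensed sketch of that pattern for a single-input, single-output node; the helper name and the chosen formats are placeholders, not taken from the tests.

#include "backend/kernel_compiler/kernel_build_info.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "utils/utils.h"

namespace mindspore {
// Hypothetical test helper: attach a complete build info, including the
// now-required reshape types, to a single-input, single-output CNode.
void SetMockKernelInfo(const CNodePtr &cnode) {
  kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
  builder.SetInputsFormat({kOpFormat_DEFAULT});
  builder.SetInputsDeviceType({kNumberTypeFloat16});
  builder.SetInputsReshapeType({{}});   // one (empty) entry per input
  builder.SetOutputsFormat({kOpFormat_NC1HWC0});
  builder.SetOutputsDeviceType({kNumberTypeFloat16});
  builder.SetOutputsReshapeType({{}});  // one (empty) entry per output
  AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
}
}  // namespace mindspore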
diff --git a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc
index d156959c4c5..23fa05f621a 100644
--- a/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc
+++ b/tests/ut/cpp/pre_activate/ascend/ir_fusion/transpose_transdata_fusion_test.cc
@@ -58,6 +58,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
       builder.SetInputsDeviceType({kFloat16->type_id()});
       builder.SetOutputsFormat({"NC1HWC0"});
       builder.SetOutputsDeviceType({kFloat16->type_id()});
+      builder.SetInputsReshapeType({});
+      builder.SetOutputsReshapeType({});
       AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
     } else {
       KernelBuildInfoBuilder builder;
@@ -65,6 +67,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
       builder.SetInputsDeviceType({kFloat16->type_id()});
       builder.SetOutputsFormat({"NC1HWC0"});
       builder.SetOutputsDeviceType({kFloat16->type_id()});
+      builder.SetInputsReshapeType({});
+      builder.SetOutputsReshapeType({});
       AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
     }
   }
@@ -93,6 +97,8 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
   EXPECT_NE(transpose, nullptr);
 
   KernelBuildInfoBuilder builder;
+  builder.SetInputsReshapeType({});
+  builder.SetOutputsReshapeType({});
   builder.SetInputsFormat({"NCHW"});
   builder.SetInputsDeviceType({kFloat16->type_id()});
   builder.SetOutputsFormat({"NC1HWC0"});
diff --git a/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc b/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc
index 07bef7a0421..3e543d7a63b 100644
--- a/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc
+++ b/tests/ut/cpp/pre_activate/pass/eliminate_redundant_op_test.cc
@@ -56,6 +56,8 @@ class MockEliminate5To4And4To5KernelSelect : public KernelSelect {
   ~MockEliminate5To4And4To5KernelSelect() override = default;
   void SelectKernel(const CNodePtr &cnode) override {
     KernelBuildInfoBuilder builder;
+    builder.SetInputsReshapeType({{}});
+    builder.SetOutputsReshapeType({{}});
     builder.SetInputsFormat({"NCHW"});
     builder.SetInputsDeviceType({kFloat16->type_id()});
     builder.SetOutputsFormat({"NC1HWC0"});
@@ -102,7 +104,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_5to4_4to5) {
   builder.SetOutputsFormat({"NC1HWC0"});
   builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
   builder.SetOutputsDeviceType({kFloat16->type_id()});
-
+  builder.SetInputsReshapeType({{}, {}});
+  builder.SetOutputsReshapeType({{}});
   sub->set_kernel_info(std::make_shared<device::KernelInfo>());
   add->set_kernel_info(std::make_shared<device::KernelInfo>());
   AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());
@@ -168,7 +171,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast) {
   builder.SetOutputsFormat({"NC1HWC0"});
   builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
   builder.SetOutputsDeviceType({kFloat16->type_id()});
-
+  builder.SetInputsReshapeType({{}, {}});
+  builder.SetOutputsReshapeType({{}, {}});
   sub->set_kernel_info(std::make_shared<device::KernelInfo>());
   add->set_kernel_info(std::make_shared<device::KernelInfo>());
   AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());
@@ -244,7 +248,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast_depend_cast) {
   builder.SetOutputsFormat({"NC1HWC0"});
   builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
   builder.SetOutputsDeviceType({kFloat16->type_id()});
-
+  builder.SetInputsReshapeType({{}, {}});
+  builder.SetOutputsReshapeType({{}});
   sub->set_kernel_info(std::make_shared<device::KernelInfo>());
   add->set_kernel_info(std::make_shared<device::KernelInfo>());
   AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());