!3543 Split unsupport transdata

Merge pull request !3543 from lianliguang/unify-primitive
This commit is contained in:
mindspore-ci-bot 2020-07-28 17:23:44 +08:00 committed by Gitee
commit 7cb567ebbe
12 changed files with 210 additions and 12 deletions

View File

@ -158,13 +158,13 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetProcessor(Processor processor)
std::shared_ptr<KernelBuildInfo> KernelBuildInfo::KernelBuildInfoBuilder::Build() { return kernel_build_info_; } std::shared_ptr<KernelBuildInfo> KernelBuildInfo::KernelBuildInfoBuilder::Build() { return kernel_build_info_; }
void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType( void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsReshapeType(
const std::vector<std::vector<Axis>> &input_reshape_type) { const std::vector<std::vector<Axis>> &input_reshape_type) {
MS_EXCEPTION_IF_NULL(kernel_build_info_); MS_EXCEPTION_IF_NULL(kernel_build_info_);
kernel_build_info_->input_reshape_type_ = input_reshape_type; kernel_build_info_->input_reshape_type_ = input_reshape_type;
} }
void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType( void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsReshapeType(
const std::vector<std::vector<Axis>> &output_reshape_type) { const std::vector<std::vector<Axis>> &output_reshape_type) {
MS_EXCEPTION_IF_NULL(kernel_build_info_); MS_EXCEPTION_IF_NULL(kernel_build_info_);
kernel_build_info_->output_reshape_type_ = output_reshape_type; kernel_build_info_->output_reshape_type_ = output_reshape_type;
@ -189,5 +189,36 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string
} }
kernel_build_info_->outputs_format_[index] = format; kernel_build_info_->outputs_format_[index] = format;
} }
void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(const std::vector<Axis> &input_reshape_type,
size_t index) {
if (index >= kernel_build_info_->input_reshape_type_.size()) {
MS_LOG(EXCEPTION) << "index outof range!";
}
std::copy(input_reshape_type.begin(), input_reshape_type.end(),
std::back_inserter(kernel_build_info_->input_reshape_type_[index]));
}
void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(const std::vector<Axis> &output_reshape_type,
size_t index) {
if (index >= kernel_build_info_->output_reshape_type_.size()) {
MS_LOG(EXCEPTION) << "index outof range!";
}
std::copy(output_reshape_type.begin(), output_reshape_type.end(),
std::back_inserter(kernel_build_info_->output_reshape_type_[index]));
}
void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputDeviceType(const TypeId &output_device_type, size_t index) {
if (index >= kernel_build_info_->outputs_device_type_.size()) {
MS_LOG(EXCEPTION) << "index outof range!";
}
kernel_build_info_->outputs_device_type_[index] = output_device_type;
}
void KernelBuildInfo::KernelBuildInfoBuilder::SetInputDeviceType(const TypeId &input_device_type, size_t index) {
if (index >= kernel_build_info_->inputs_device_type_.size()) {
MS_LOG(EXCEPTION) << "index outof range!";
}
kernel_build_info_->inputs_device_type_[index] = input_device_type;
}
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore

View File

@ -71,6 +71,10 @@ class KernelBuildInfo {
std::vector<TypeId> GetAllOutputDeviceTypes() const; std::vector<TypeId> GetAllOutputDeviceTypes() const;
std::vector<std::vector<Axis>> GetAllOutputReshapeType() const;
std::vector<std::vector<Axis>> GetAllInputReshapeType() const;
OpPattern op_pattern() const { return op_pattern_; } OpPattern op_pattern() const { return op_pattern_; }
FusionType fusion_type() const { return fusion_type_; } FusionType fusion_type() const { return fusion_type_; }
@ -109,7 +113,22 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
KernelBuildInfoBuilder() { kernel_build_info_ = std::make_shared<KernelBuildInfo>(); } KernelBuildInfoBuilder() { kernel_build_info_ = std::make_shared<KernelBuildInfo>(); }
explicit KernelBuildInfoBuilder(std::shared_ptr<KernelBuildInfo> kernel_build_info) explicit KernelBuildInfoBuilder(std::shared_ptr<KernelBuildInfo> kernel_build_info)
: kernel_build_info_(std::move(kernel_build_info)) {} : kernel_build_info_(std::make_shared<KernelBuildInfo>()) {
SetKernelType(kernel_build_info->kernel_type());
SetFusionType(kernel_build_info->fusion_type());
SetProcessor(kernel_build_info->processor());
OpPattern(kernel_build_info->op_pattern());
for (size_t index = 0; index < kernel_build_info->GetInputNum(); ++index) {
kernel_build_info_->inputs_device_type_.emplace_back(kernel_build_info->GetInputDeviceType(index));
kernel_build_info_->inputs_format_.emplace_back(kernel_build_info->GetInputFormat(index));
kernel_build_info_->input_reshape_type_.emplace_back(kernel_build_info->GetInputReshapeType(index));
}
for (size_t index = 0; index < kernel_build_info->GetOutputNum(); ++index) {
kernel_build_info_->outputs_device_type_.emplace_back(kernel_build_info->GetOutputDeviceType(index));
kernel_build_info_->outputs_format_.emplace_back(kernel_build_info->GetOutputFormat(index));
kernel_build_info_->output_reshape_type_.emplace_back(kernel_build_info->GetOutputReshapeType(index));
}
}
~KernelBuildInfoBuilder() = default; ~KernelBuildInfoBuilder() = default;
@ -123,9 +142,9 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
void SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type); void SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type);
void SetInputReshapeType(const std::vector<std::vector<Axis>> &input_reshape_type); void SetInputsReshapeType(const std::vector<std::vector<Axis>> &input_reshape_type);
void SetOutputReshapeType(const std::vector<std::vector<Axis>> &output_reshape_type); void SetOutputsReshapeType(const std::vector<std::vector<Axis>> &output_reshape_type);
void SetFusionType(FusionType fusion_type); void SetFusionType(FusionType fusion_type);
@ -137,6 +156,14 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
void SetOutputFormat(const std::string &format, size_t index); void SetOutputFormat(const std::string &format, size_t index);
void SetInputReshapeType(const std::vector<Axis> &input_reshape_type, size_t index);
void SetOutputReshapeType(const std::vector<Axis> &output_reshape_type, size_t index);
void SetInputDeviceType(const TypeId &input_device_type, size_t index);
void SetOutputDeviceType(const TypeId &output_device_type, size_t index);
std::shared_ptr<KernelBuildInfo> Build(); std::shared_ptr<KernelBuildInfo> Build();
private: private:

View File

@ -118,7 +118,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
} }
builder.SetInputsDeviceType(inputs_device_type); builder.SetInputsDeviceType(inputs_device_type);
builder.SetInputsFormat(inputs_format); builder.SetInputsFormat(inputs_format);
builder.SetInputReshapeType(inputs_reshape_type); builder.SetInputsReshapeType(inputs_reshape_type);
// output // output
std::vector<std::string> outputs_format; std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_device_type; std::vector<TypeId> outputs_device_type;
@ -129,7 +129,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
} }
builder.SetOutputsDeviceType(outputs_device_type); builder.SetOutputsDeviceType(outputs_device_type);
builder.SetOutputsFormat(outputs_format); builder.SetOutputsFormat(outputs_format);
builder.SetOutputReshapeType(outputs_reshape_type); builder.SetOutputsReshapeType(outputs_reshape_type);
kernel_info_list_->emplace_back(builder.Build()); kernel_info_list_->emplace_back(builder.Build());
} }
MS_LOG(INFO) << "end."; MS_LOG(INFO) << "end.";

View File

@ -59,6 +59,7 @@
#include "backend/optimizer/ascend/format_type/insert_trans_op.h" #include "backend/optimizer/ascend/format_type/insert_trans_op.h"
#include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h" #include "backend/optimizer/ascend/format_type/rectify_do_mask_kernel_info.h"
#include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h" #include "backend/optimizer/ascend/format_type/chang_axis_of_reduce_kernel.h"
#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
#include "backend/optimizer/pass/getitem_tuple.h" #include "backend/optimizer/pass/getitem_tuple.h"
#include "backend/optimizer/pass/optimize_dependence.h" #include "backend/optimizer/pass/optimize_dependence.h"
#include "backend/optimizer/pass/erase_visit_attr.h" #include "backend/optimizer/pass/erase_visit_attr.h"
@ -228,6 +229,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>()); mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>()); mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>());
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>()); mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
mixed_precision_pm->AddPass(std::make_shared<SplitUnsupportedTransData>());
mixed_precision_pm->AddPass(std::make_shared<ConvertUnSupportNodeToAICPU>()); mixed_precision_pm->AddPass(std::make_shared<ConvertUnSupportNodeToAICPU>());
mixed_precision_pm->AddPass(std::make_shared<RemoveInternalOutputCast>()); mixed_precision_pm->AddPass(std::make_shared<RemoveInternalOutputCast>());
optimizer->AddPassManager(mixed_precision_pm); optimizer->AddPassManager(mixed_precision_pm);

View File

@ -174,8 +174,8 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string &
MS_EXCEPTION_IF_NULL(ori_build_info); MS_EXCEPTION_IF_NULL(ori_build_info);
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(ori_build_info); auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(ori_build_info);
builder->SetInputsFormat({input_format}); builder->SetInputsFormat({input_format});
builder->SetInputReshapeType({reshape_type}); builder->SetInputsReshapeType({reshape_type});
builder->SetOutputReshapeType({reshape_type}); builder->SetOutputsReshapeType({reshape_type});
builder->SetOutputsFormat({output_format}); builder->SetOutputsFormat({output_format});
if (type_id != kTypeUnknown) { if (type_id != kTypeUnknown) {
builder->SetOutputsDeviceType({type_id}); builder->SetOutputsDeviceType({type_id});

View File

@ -0,0 +1,65 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
#include <vector>
#include <memory>
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace opt {
const BaseRef SplitUnsupportedTransData::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
return VectorRef({prim::KPrimTransData, X});
}
const AnfNodePtr SplitUnsupportedTransData::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) {
return nullptr;
}
auto ori_trans_data = node->cast<CNodePtr>();
if (AnfAlgo::GetCNodeName(ori_trans_data) != prim::KPrimTransData->name()) {
return nullptr;
}
auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(ori_trans_data);
MS_EXCEPTION_IF_NULL(kernel_info);
if (kernel_info->GetInputNum() != 1 || kernel_info->GetOutputNum() != 1) {
MS_LOG(EXCEPTION) << "Transdata node's kernel info's input and output format size is not 1"
<< ori_trans_data->DebugString();
}
return SplitTransData(func_graph, ori_trans_data);
}
AnfNodePtr SplitUnsupportedTransData::SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const {
auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(trans_node);
if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() ||
kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) {
return trans_node;
}
auto builder_info_to_default = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
auto builder_info_to_special_foramt = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT});
builder_info_to_special_foramt->SetInputsFormat({kOpFormat_DEFAULT});
std::vector<AnfNodePtr> next_trans_node_inputs = {
NewValueNode(std::make_shared<Primitive>(prim::KPrimTransData->name())), trans_node};
auto next_trans_node = func_graph->NewCNode(next_trans_node_inputs);
next_trans_node->set_abstract(trans_node->abstract());
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), trans_node.get());
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_foramt->Build(), next_trans_node.get());
return next_trans_node;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class SplitUnsupportedTransData : public PatternProcessPass {
public:
explicit SplitUnsupportedTransData(bool multigraph = true)
: PatternProcessPass("split_unsupported_transdata", multigraph) {}
~SplitUnsupportedTransData() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
AnfNodePtr SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H

View File

@ -50,6 +50,8 @@ class TestHWInsertTransOp : public BackendCommon {
KernelBuildInfoBuilder builder; KernelBuildInfoBuilder builder;
builder.SetInputsFormat({format, format}); builder.SetInputsFormat({format, format});
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()}); builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
builder.SetInputsReshapeType({{},{}});
builder.SetOutputsReshapeType({});
builder.SetOutputsFormat({format}); builder.SetOutputsFormat({format});
builder.SetOutputsDeviceType({kFloat16->type_id()}); builder.SetOutputsDeviceType({kFloat16->type_id()});
add->set_kernel_info(std::make_shared<device::KernelInfo>()); add->set_kernel_info(std::make_shared<device::KernelInfo>());
@ -70,6 +72,8 @@ class TestHWInsertTransOp : public BackendCommon {
EXPECT_NE(ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->input(1), nullptr); EXPECT_NE(ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->input(1), nullptr);
auto max_pool = ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->input(1); auto max_pool = ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->input(1);
KernelBuildInfoBuilder builder; KernelBuildInfoBuilder builder;
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{},{}});
builder.SetInputsFormat({kOpFormat_DEFAULT}); builder.SetInputsFormat({kOpFormat_DEFAULT});
builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({format, format}); builder.SetOutputsFormat({format, format});
@ -88,6 +92,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
~MockInsertTransOpKernelSelectTrans4Dto5D() override = default; ~MockInsertTransOpKernelSelectTrans4Dto5D() override = default;
void SelectKernel(const CNodePtr &cnode) override { void SelectKernel(const CNodePtr &cnode) override {
KernelBuildInfoBuilder builder; KernelBuildInfoBuilder builder;
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
builder.SetInputsFormat({"NCHW"}); builder.SetInputsFormat({"NCHW"});
builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"}); builder.SetOutputsFormat({"NC1HWC0"});

View File

@ -53,6 +53,8 @@ class TestHWRemoveInternalOutput : public BackendCommon {
KernelBuildInfoBuilder builder; KernelBuildInfoBuilder builder;
builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT}); builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()}); builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
builder.SetInputsReshapeType({{}, {}});
builder.SetOutputsReshapeType({{}});
builder.SetOutputsFormat({kOpFormat_NC1HWC0}); builder.SetOutputsFormat({kOpFormat_NC1HWC0});
builder.SetOutputsDeviceType({kFloat16->type_id()}); builder.SetOutputsDeviceType({kFloat16->type_id()});
add->set_kernel_info(std::make_shared<device::KernelInfo>()); add->set_kernel_info(std::make_shared<device::KernelInfo>());
@ -78,6 +80,8 @@ class TestHWRemoveInternalOutput : public BackendCommon {
kg->AddInternalOutput(tuple_getitem1, max_pool); kg->AddInternalOutput(tuple_getitem1, max_pool);
kg->AddInternalOutput(tuple_getitem2, max_pool); kg->AddInternalOutput(tuple_getitem2, max_pool);
KernelBuildInfoBuilder builder; KernelBuildInfoBuilder builder;
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}, {}});
builder.SetInputsFormat({kOpFormat_DEFAULT}); builder.SetInputsFormat({kOpFormat_DEFAULT});
builder.SetInputsDeviceType({kFloat32->type_id()}); builder.SetInputsDeviceType({kFloat32->type_id()});
builder.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0}); builder.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
@ -99,6 +103,8 @@ class MockRemoveInternalOutputTransOpKernelSelect : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({kOpFormat_DEFAULT}); builder.SetOutputsFormat({kOpFormat_DEFAULT});
builder.SetOutputsDeviceType({kFloat32->type_id()}); builder.SetOutputsDeviceType({kFloat32->type_id()});
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
} }
}; };

View File

@ -51,6 +51,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"}); builder.SetOutputsFormat({"NC1HWC0"});
builder.SetOutputsDeviceType({kFloat16->type_id()}); builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({});
builder.SetOutputsReshapeType({});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
} else { } else {
KernelBuildInfoBuilder builder; KernelBuildInfoBuilder builder;
@ -58,7 +60,10 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"}); builder.SetOutputsFormat({"NC1HWC0"});
builder.SetOutputsDeviceType({kFloat16->type_id()}); builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({});
builder.SetOutputsReshapeType({});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
} }
} }
}; };
@ -74,6 +79,8 @@ class MockTransdataSplitKernelSelect : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NCHW"}); builder.SetOutputsFormat({"NCHW"});
builder.SetOutputsDeviceType({kFloat16->type_id()}); builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
} else { } else {
KernelBuildInfoBuilder builder; KernelBuildInfoBuilder builder;
@ -81,6 +88,8 @@ class MockTransdataSplitKernelSelect : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NCHW"}); builder.SetOutputsFormat({"NCHW"});
builder.SetOutputsDeviceType({kFloat16->type_id()}); builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
} }
} }
@ -116,6 +125,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) {
builder.SetKernelType(KernelType::TBE_KERNEL); builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE); builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetProcessor(kernel::Processor::AICORE); builder.SetProcessor(kernel::Processor::AICORE);
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
auto kernel_info = std::make_shared<device::KernelInfo>(); auto kernel_info = std::make_shared<device::KernelInfo>();
kernel_info->set_select_kernel_build_info(builder.Build()); kernel_info->set_select_kernel_build_info(builder.Build());
transpose->set_kernel_info(kernel_info); transpose->set_kernel_info(kernel_info);
@ -162,6 +173,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_nchw_fraz) {
builder.SetKernelType(KernelType::TBE_KERNEL); builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE); builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetProcessor(kernel::Processor::AICORE); builder.SetProcessor(kernel::Processor::AICORE);
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
auto kernel_info = std::make_shared<device::KernelInfo>(); auto kernel_info = std::make_shared<device::KernelInfo>();
kernel_info->set_select_kernel_build_info(builder.Build()); kernel_info->set_select_kernel_build_info(builder.Build());
transpose->set_kernel_info(kernel_info); transpose->set_kernel_info(kernel_info);

View File

@ -58,6 +58,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"}); builder.SetOutputsFormat({"NC1HWC0"});
builder.SetOutputsDeviceType({kFloat16->type_id()}); builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({});
builder.SetOutputsReshapeType({});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
} else { } else {
KernelBuildInfoBuilder builder; KernelBuildInfoBuilder builder;
@ -65,6 +67,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"}); builder.SetOutputsFormat({"NC1HWC0"});
builder.SetOutputsDeviceType({kFloat16->type_id()}); builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({});
builder.SetOutputsReshapeType({});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get()); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
} }
} }
@ -93,6 +97,8 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
EXPECT_NE(transpose, nullptr); EXPECT_NE(transpose, nullptr);
KernelBuildInfoBuilder builder; KernelBuildInfoBuilder builder;
builder.SetInputsReshapeType({});
builder.SetOutputsReshapeType({});
builder.SetInputsFormat({"NCHW"}); builder.SetInputsFormat({"NCHW"});
builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"}); builder.SetOutputsFormat({"NC1HWC0"});

View File

@ -56,6 +56,8 @@ class MockEliminate5To4And4To5KernelSelect : public KernelSelect {
~MockEliminate5To4And4To5KernelSelect() override = default; ~MockEliminate5To4And4To5KernelSelect() override = default;
void SelectKernel(const CNodePtr &cnode) override { void SelectKernel(const CNodePtr &cnode) override {
KernelBuildInfoBuilder builder; KernelBuildInfoBuilder builder;
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
builder.SetInputsFormat({"NCHW"}); builder.SetInputsFormat({"NCHW"});
builder.SetInputsDeviceType({kFloat16->type_id()}); builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"}); builder.SetOutputsFormat({"NC1HWC0"});
@ -102,7 +104,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_5to4_4to5) {
builder.SetOutputsFormat({"NC1HWC0"}); builder.SetOutputsFormat({"NC1HWC0"});
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()}); builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
builder.SetOutputsDeviceType({kFloat16->type_id()}); builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({{}, {}});
builder.SetOutputsReshapeType({{}});
sub->set_kernel_info(std::make_shared<device::KernelInfo>()); sub->set_kernel_info(std::make_shared<device::KernelInfo>());
add->set_kernel_info(std::make_shared<device::KernelInfo>()); add->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get()); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());
@ -168,7 +171,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast) {
builder.SetOutputsFormat({"NC1HWC0"}); builder.SetOutputsFormat({"NC1HWC0"});
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()}); builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
builder.SetOutputsDeviceType({kFloat16->type_id()}); builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({{}, {}});
builder.SetOutputsReshapeType({{}, {}});
sub->set_kernel_info(std::make_shared<device::KernelInfo>()); sub->set_kernel_info(std::make_shared<device::KernelInfo>());
add->set_kernel_info(std::make_shared<device::KernelInfo>()); add->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get()); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());
@ -244,7 +248,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast_depend_cast) {
builder.SetOutputsFormat({"NC1HWC0"}); builder.SetOutputsFormat({"NC1HWC0"});
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()}); builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
builder.SetOutputsDeviceType({kFloat16->type_id()}); builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({{}, {}});
builder.SetOutputsReshapeType({{}});
sub->set_kernel_info(std::make_shared<device::KernelInfo>()); sub->set_kernel_info(std::make_shared<device::KernelInfo>());
add->set_kernel_info(std::make_shared<device::KernelInfo>()); add->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get()); AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());