split unsupported transdata

This commit is contained in:
WilliamLian 2020-07-29 10:35:40 +08:00
parent 4f1e586ee3
commit edba641ddb
12 changed files with 211 additions and 13 deletions

View File

@ -158,13 +158,13 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetProcessor(Processor processor)
std::shared_ptr<KernelBuildInfo> KernelBuildInfo::KernelBuildInfoBuilder::Build() { return kernel_build_info_; }
void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(
void KernelBuildInfo::KernelBuildInfoBuilder::SetInputsReshapeType(
const std::vector<std::vector<Axis>> &input_reshape_type) {
MS_EXCEPTION_IF_NULL(kernel_build_info_);
kernel_build_info_->input_reshape_type_ = input_reshape_type;
}
void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(
void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsReshapeType(
const std::vector<std::vector<Axis>> &output_reshape_type) {
MS_EXCEPTION_IF_NULL(kernel_build_info_);
kernel_build_info_->output_reshape_type_ = output_reshape_type;
@ -189,5 +189,37 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputFormat(const std::string
}
kernel_build_info_->outputs_format_[index] = format;
}
void KernelBuildInfo::KernelBuildInfoBuilder::SetInputReshapeType(const std::vector<Axis> &input_reshape_type,
size_t index) {
if (index >= kernel_build_info_->input_reshape_type_.size()) {
MS_LOG(EXCEPTION) << "index outof range!";
}
std::copy(input_reshape_type.begin(), input_reshape_type.end(),
std::back_inserter(kernel_build_info_->input_reshape_type_[index]));
}
void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputReshapeType(const std::vector<Axis> &output_reshape_type,
size_t index) {
if (index >= kernel_build_info_->output_reshape_type_.size()) {
MS_LOG(EXCEPTION) << "index outof range!";
}
std::copy(output_reshape_type.begin(), output_reshape_type.end(),
std::back_inserter(kernel_build_info_->output_reshape_type_[index]));
}
void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputDeviceType(const TypeId &output_device_type, size_t index) {
if (index >= kernel_build_info_->outputs_device_type_.size()) {
MS_LOG(EXCEPTION) << "index outof range!";
}
kernel_build_info_->outputs_device_type_[index] = output_device_type;
}
void KernelBuildInfo::KernelBuildInfoBuilder::SetInputDeviceType(const TypeId &input_device_type, size_t index) {
if (index >= kernel_build_info_->inputs_device_type_.size()) {
MS_LOG(EXCEPTION) << "index outof range!";
}
kernel_build_info_->inputs_device_type_[index] = input_device_type;
}
} // namespace kernel
} // namespace mindspore

View File

@ -71,6 +71,10 @@ class KernelBuildInfo {
std::vector<TypeId> GetAllOutputDeviceTypes() const;
std::vector<std::vector<Axis>> GetAllOutputReshapeType() const;
std::vector<std::vector<Axis>> GetAllInputReshapeType() const;
OpPattern op_pattern() const { return op_pattern_; }
FusionType fusion_type() const { return fusion_type_; }
@ -108,8 +112,23 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
public:
KernelBuildInfoBuilder() { kernel_build_info_ = std::make_shared<KernelBuildInfo>(); }
explicit KernelBuildInfoBuilder(std::shared_ptr<KernelBuildInfo> kernel_build_info)
: kernel_build_info_(std::move(kernel_build_info)) {}
explicit KernelBuildInfoBuilder(const std::shared_ptr<KernelBuildInfo> &kernel_build_info)
: kernel_build_info_(std::make_shared<KernelBuildInfo>()) {
SetKernelType(kernel_build_info->kernel_type());
SetFusionType(kernel_build_info->fusion_type());
SetProcessor(kernel_build_info->processor());
OpPattern(kernel_build_info->op_pattern());
for (size_t index = 0; index < kernel_build_info->GetInputNum(); ++index) {
kernel_build_info_->inputs_device_type_.emplace_back(kernel_build_info->GetInputDeviceType(index));
kernel_build_info_->inputs_format_.emplace_back(kernel_build_info->GetInputFormat(index));
kernel_build_info_->input_reshape_type_.emplace_back(kernel_build_info->GetInputReshapeType(index));
}
for (size_t index = 0; index < kernel_build_info->GetOutputNum(); ++index) {
kernel_build_info_->outputs_device_type_.emplace_back(kernel_build_info->GetOutputDeviceType(index));
kernel_build_info_->outputs_format_.emplace_back(kernel_build_info->GetOutputFormat(index));
kernel_build_info_->output_reshape_type_.emplace_back(kernel_build_info->GetOutputReshapeType(index));
}
}
~KernelBuildInfoBuilder() = default;
@ -123,9 +142,9 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
void SetOutputsDeviceType(const std::vector<TypeId> &outputs_device_type);
void SetInputReshapeType(const std::vector<std::vector<Axis>> &input_reshape_type);
void SetInputsReshapeType(const std::vector<std::vector<Axis>> &input_reshape_type);
void SetOutputReshapeType(const std::vector<std::vector<Axis>> &output_reshape_type);
void SetOutputsReshapeType(const std::vector<std::vector<Axis>> &output_reshape_type);
void SetFusionType(FusionType fusion_type);
@ -137,6 +156,14 @@ class KernelBuildInfo::KernelBuildInfoBuilder {
void SetOutputFormat(const std::string &format, size_t index);
void SetInputReshapeType(const std::vector<Axis> &input_reshape_type, size_t index);
void SetOutputReshapeType(const std::vector<Axis> &output_reshape_type, size_t index);
void SetInputDeviceType(const TypeId &input_device_type, size_t index);
void SetOutputDeviceType(const TypeId &output_device_type, size_t index);
std::shared_ptr<KernelBuildInfo> Build();
private:

View File

@ -118,7 +118,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
}
builder.SetInputsDeviceType(inputs_device_type);
builder.SetInputsFormat(inputs_format);
builder.SetInputReshapeType(inputs_reshape_type);
builder.SetInputsReshapeType(inputs_reshape_type);
// output
std::vector<std::string> outputs_format;
std::vector<TypeId> outputs_device_type;
@ -129,7 +129,7 @@ void TbeKernelSelect::GetCommonPatternKernelInfo(const OpInfo &op_info) {
}
builder.SetOutputsDeviceType(outputs_device_type);
builder.SetOutputsFormat(outputs_format);
builder.SetOutputReshapeType(outputs_reshape_type);
builder.SetOutputsReshapeType(outputs_reshape_type);
kernel_info_list_->emplace_back(builder.Build());
}
MS_LOG(INFO) << "end.";

View File

@ -47,6 +47,7 @@
#include "backend/optimizer/ascend/ir_fission/transdata_split.h"
#include "backend/optimizer/ascend/ir_fission/topk_split.h"
#include "backend/optimizer/ascend/ir_fusion/momentum_lossscale_fusion.h"
#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
#include "backend/optimizer/ascend/ir_fusion/mul_add_fusion.h"
#include "backend/optimizer/ascend/ir_fusion/mul_addn_fusion.h"
#include "backend/optimizer/ascend/ir_fusion/matmul_biasadd_fusion.h"
@ -228,6 +229,7 @@ void AscendMixPrecision(const std::shared_ptr<session::KernelGraph> &kernel_grap
mixed_precision_pm->AddPass(std::make_shared<MergeCastToOp>());
mixed_precision_pm->AddPass(std::make_shared<LayerNormBetaGammaBackpropFusion>());
mixed_precision_pm->AddPass(std::make_shared<EraseVisitAttr>());
mixed_precision_pm->AddPass(std::make_shared<SplitUnsupportedTransData>());
mixed_precision_pm->AddPass(std::make_shared<ConvertUnSupportNodeToAICPU>());
mixed_precision_pm->AddPass(std::make_shared<RemoveInternalOutputCast>());
optimizer->AddPassManager(mixed_precision_pm);

View File

@ -174,8 +174,8 @@ void RefreshKernelBuildInfo(const std::string &input_format, const std::string &
MS_EXCEPTION_IF_NULL(ori_build_info);
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(ori_build_info);
builder->SetInputsFormat({input_format});
builder->SetInputReshapeType({reshape_type});
builder->SetOutputReshapeType({reshape_type});
builder->SetInputsReshapeType({reshape_type});
builder->SetOutputsReshapeType({reshape_type});
builder->SetOutputsFormat({output_format});
if (type_id != kTypeUnknown) {
builder->SetOutputsDeviceType({type_id});

View File

@ -0,0 +1,65 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/optimizer/ascend/format_type/split_unsupported_transdata.h"
#include <vector>
#include <memory>
#include "backend/session/anf_runtime_algorithm.h"
namespace mindspore {
namespace opt {
const BaseRef SplitUnsupportedTransData::DefinePattern() const {
VarPtr X = std::make_shared<Var>();
return VectorRef({prim::KPrimTransData, X});
}
const AnfNodePtr SplitUnsupportedTransData::Process(const FuncGraphPtr &func_graph, const AnfNodePtr &node,
const EquivPtr &) const {
if (node == nullptr || !node->isa<CNode>() || !AnfAlgo::IsRealKernel(node)) {
return nullptr;
}
auto ori_trans_data = node->cast<CNodePtr>();
if (AnfAlgo::GetCNodeName(ori_trans_data) != prim::KPrimTransData->name()) {
return nullptr;
}
auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(ori_trans_data);
MS_EXCEPTION_IF_NULL(kernel_info);
if (kernel_info->GetInputNum() != 1 || kernel_info->GetOutputNum() != 1) {
MS_LOG(EXCEPTION) << "Transdata node's kernel info's input and output format size is not 1"
<< ori_trans_data->DebugString();
}
return SplitTransData(func_graph, ori_trans_data);
}
AnfNodePtr SplitUnsupportedTransData::SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const {
auto kernel_info = AnfAlgo::GetSelectKernelBuildInfo(trans_node);
if (kHWSpecialFormatSet.find(kernel_info->GetInputFormat(0)) == kHWSpecialFormatSet.end() ||
kHWSpecialFormatSet.find(kernel_info->GetOutputFormat(0)) == kHWSpecialFormatSet.end()) {
return trans_node;
}
auto builder_info_to_default = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
auto builder_info_to_special_foramt = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(kernel_info);
builder_info_to_default->SetOutputsFormat({kOpFormat_DEFAULT});
builder_info_to_special_foramt->SetInputsFormat({kOpFormat_DEFAULT});
std::vector<AnfNodePtr> next_trans_node_inputs = {
NewValueNode(std::make_shared<Primitive>(prim::KPrimTransData->name())), trans_node};
auto next_trans_node = func_graph->NewCNode(next_trans_node_inputs);
next_trans_node->set_abstract(trans_node->abstract());
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_default->Build(), trans_node.get());
AnfAlgo::SetSelectKernelBuildInfo(builder_info_to_special_foramt->Build(), next_trans_node.get());
return next_trans_node;
}
} // namespace opt
} // namespace mindspore

View File

@ -0,0 +1,37 @@
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H
#define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H
#include "backend/optimizer/common/optimizer.h"
namespace mindspore {
namespace opt {
class SplitUnsupportedTransData : public PatternProcessPass {
public:
explicit SplitUnsupportedTransData(bool multigraph = true)
: PatternProcessPass("split_unsupported_transdata", multigraph) {}
~SplitUnsupportedTransData() override = default;
const BaseRef DefinePattern() const override;
const AnfNodePtr Process(const FuncGraphPtr &, const AnfNodePtr &, const EquivPtr &) const override;
private:
AnfNodePtr SplitTransData(const FuncGraphPtr &func_graph, const CNodePtr &trans_node) const;
};
} // namespace opt
} // namespace mindspore
#endif // MINDSPORE_CCSRC_BACKEND_OPTIMIZER_ASCEND_FORMAT_TYPE_REMOVE_TRANSDATA_SPILT_H

View File

@ -51,6 +51,8 @@ class TestHWInsertTransOp : public BackendCommon {
builder.SetInputsFormat({format, format});
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
builder.SetOutputsFormat({format});
builder.SetInputsReshapeType({{},{}});
builder.SetOutputsReshapeType({});
builder.SetOutputsDeviceType({kFloat16->type_id()});
add->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), add.get());
@ -70,6 +72,8 @@ class TestHWInsertTransOp : public BackendCommon {
EXPECT_NE(ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->input(1), nullptr);
auto max_pool = ret->input(1)->cast<CNodePtr>()->input(1)->cast<CNodePtr>()->input(1);
KernelBuildInfoBuilder builder;
builder.SetInputsReshapeType({{},{}});
builder.SetOutputsReshapeType({});
builder.SetInputsFormat({kOpFormat_DEFAULT});
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({format, format});
@ -88,6 +92,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
~MockInsertTransOpKernelSelectTrans4Dto5D() override = default;
void SelectKernel(const CNodePtr &cnode) override {
KernelBuildInfoBuilder builder;
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
builder.SetInputsFormat({"NCHW"});
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"});

View File

@ -52,6 +52,8 @@ class TestHWRemoveInternalOutput : public BackendCommon {
kg->AddInternalOutput(add, add);
KernelBuildInfoBuilder builder;
builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
builder.SetInputsReshapeType({{}, {}});
builder.SetOutputsReshapeType({{}});
builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
builder.SetOutputsFormat({kOpFormat_NC1HWC0});
builder.SetOutputsDeviceType({kFloat16->type_id()});
@ -78,6 +80,8 @@ class TestHWRemoveInternalOutput : public BackendCommon {
kg->AddInternalOutput(tuple_getitem1, max_pool);
kg->AddInternalOutput(tuple_getitem2, max_pool);
KernelBuildInfoBuilder builder;
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}, {}});
builder.SetInputsFormat({kOpFormat_DEFAULT});
builder.SetInputsDeviceType({kFloat32->type_id()});
builder.SetOutputsFormat({kOpFormat_NC1HWC0, kOpFormat_NC1HWC0});
@ -95,6 +99,8 @@ class MockRemoveInternalOutputTransOpKernelSelect : public KernelSelect {
~MockRemoveInternalOutputTransOpKernelSelect() override = default;
void SelectKernel(const CNodePtr &cnode) override {
KernelBuildInfoBuilder builder;
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
builder.SetInputsFormat({kOpFormat_NC1HWC0});
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({kOpFormat_DEFAULT});

View File

@ -51,6 +51,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
} else {
KernelBuildInfoBuilder builder;
@ -58,6 +60,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
}
}
@ -74,10 +78,14 @@ class MockTransdataSplitKernelSelect : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NCHW"});
builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
} else {
KernelBuildInfoBuilder builder;
builder.SetInputsFormat({"NCHW"});
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NCHW"});
builder.SetOutputsDeviceType({kFloat16->type_id()});
@ -116,6 +124,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) {
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
auto kernel_info = std::make_shared<device::KernelInfo>();
kernel_info->set_select_kernel_build_info(builder.Build());
transpose->set_kernel_info(kernel_info);
@ -162,6 +172,8 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_nchw_fraz) {
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
auto kernel_info = std::make_shared<device::KernelInfo>();
kernel_info->set_select_kernel_build_info(builder.Build());
transpose->set_kernel_info(kernel_info);

View File

@ -58,6 +58,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({});
builder.SetOutputsReshapeType({});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
} else {
KernelBuildInfoBuilder builder;
@ -65,6 +67,8 @@ class MockInsertTransOpKernelSelectTrans4Dto5D : public KernelSelect {
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({});
builder.SetOutputsReshapeType({});
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), cnode.get());
}
}
@ -97,6 +101,8 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({});
builder.SetOutputsReshapeType({});
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetProcessor(kernel::Processor::AICORE);

View File

@ -56,6 +56,8 @@ class MockEliminate5To4And4To5KernelSelect : public KernelSelect {
~MockEliminate5To4And4To5KernelSelect() override = default;
void SelectKernel(const CNodePtr &cnode) override {
KernelBuildInfoBuilder builder;
builder.SetInputsReshapeType({{}});
builder.SetOutputsReshapeType({{}});
builder.SetInputsFormat({"NCHW"});
builder.SetInputsDeviceType({kFloat16->type_id()});
builder.SetOutputsFormat({"NC1HWC0"});
@ -102,7 +104,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_5to4_4to5) {
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({{}, {}});
builder.SetOutputsReshapeType({{}});
sub->set_kernel_info(std::make_shared<device::KernelInfo>());
add->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());
@ -168,7 +171,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast) {
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({{}, {}});
builder.SetOutputsReshapeType({{}});
sub->set_kernel_info(std::make_shared<device::KernelInfo>());
add->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());
@ -244,7 +248,8 @@ TEST_F(TestHWEliminateRedundantOp, test_eliminate_cast_depend_cast) {
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetInputsReshapeType({{}, {}});
builder.SetOutputsReshapeType({{}});
sub->set_kernel_info(std::make_shared<device::KernelInfo>());
add->set_kernel_info(std::make_shared<device::KernelInfo>());
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), sub.get());