remove enum fusion type

This commit is contained in:
jjfeing 2022-12-07 11:16:56 +08:00
parent ca92df14bd
commit 82f04825f3
56 changed files with 164 additions and 252 deletions

View File

@ -669,7 +669,7 @@ KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) {
return build_info->kernel_type();
}
void AnfRuntimeAlgorithm::SetFusionType(const AnfNodePtr &node, const kernel::FusionType &type) {
void AnfRuntimeAlgorithm::SetFusionType(const AnfNodePtr &node, const std::string &type) {
MS_EXCEPTION_IF_NULL(node);
auto builder =
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
@ -718,13 +718,13 @@ kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) {
return build_info->processor();
}
kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) {
// Returns the fusion pattern name selected for `node`'s kernel.
// Falls back to kernel::kPatternUnknown (the empty pattern) when no kernel
// build info has been selected for the node yet.
std::string AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  if (build_info == nullptr) {
    // No kernel selected yet: report "unknown" instead of failing.  The stale
    // enum-based `return kernel::FusionType::UNKNOWN_FUSION_TYPE;` left over
    // from before the enum removal is dropped — it no longer matches the
    // std::string return type and shadowed the string return below.
    return kernel::kPatternUnknown;
  }
  return build_info->fusion_type();
}

View File

@ -129,8 +129,8 @@ class BACKEND_EXPORT AnfRuntimeAlgorithm {
// get processor type:AICORE,AICPU...
static kernel::Processor GetProcessor(const AnfNodePtr &node);
// get fusion type:AICORE,AICPU...
static kernel::FusionType GetFusionType(const AnfNodePtr &node);
static void SetFusionType(const AnfNodePtr &node, const kernel::FusionType &type);
static std::string GetFusionType(const AnfNodePtr &node);
static void SetFusionType(const AnfNodePtr &node, const std::string &type);
static void SetOutputDataDesc(const AnfNodePtr &node, const std::vector<nlohmann::json> &desc);
static std::vector<nlohmann::json> GetOutputDataDesc(const AnfNodePtr &node);
// core type

View File

@ -135,7 +135,7 @@ void CallbackImpl::SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) {
kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
graph_info_builder.SetFusionType(kernel::kPatternOpaque);
graph_info_builder.SetInputsFormat(graph_input_format);
graph_info_builder.SetInputsDeviceType(graph_input_type);
graph_info_builder.SetOutputsFormat(graph_output_format);
@ -189,7 +189,7 @@ void CallbackImpl::SetBasicNodeKernelInfo(const AnfNodePtr &node, const std::vec
info_builder.SetOutputsDeviceType(output_types);
info_builder.SetProcessor(kernel::GetProcessorFromContext());
info_builder.SetKernelType(KernelType::AKG_KERNEL);
info_builder.SetFusionType(kernel::FusionType::OPAQUE);
info_builder.SetFusionType(kernel::kPatternOpaque);
auto selected_info = info_builder.Build();
AnfAlgo::SetSelectKernelBuildInfo(selected_info, node.get());
}

View File

@ -41,7 +41,7 @@ CNodePtr AddCastCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, c
builder.SetOutputsFormat({format});
builder.SetInputsDeviceType({input_type});
builder.SetOutputsDeviceType({output_type});
builder.SetFusionType(kernel::FusionType::OPAQUE);
builder.SetFusionType(kernel::kPatternOpaque);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::AKG_KERNEL);
if (cast->kernel_info() == nullptr) {
@ -193,7 +193,7 @@ bool DecreaseComputePrecision::Process(const FuncGraphPtr &func_graph) const {
graph_info_builder.SetOutputsDeviceType(cnode_output_type);
graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
graph_info_builder.SetFusionType(kernel::kPatternOpaque);
auto info_1 = graph_info_builder.Build();
AnfAlgo::SetSelectKernelBuildInfo(info_1, cnode1.get());
if (is_output) {

View File

@ -168,7 +168,7 @@ bool DecreaseTransferPrecision::ProcessFather(const FuncGraphPtr &, const AnfNod
graph_info_builder.SetOutputsDeviceType(cnode_output_type);
graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
graph_info_builder.SetFusionType(kernel::kPatternOpaque);
auto info_1 = graph_info_builder.Build();
AnfAlgo::SetSelectKernelBuildInfo(info_1, cnode.get());
return cnode;
@ -281,7 +281,7 @@ bool DecreaseTransferPrecision::ProcessSon(const FuncGraphPtr &, const AnfNodePt
node_info_builder.SetOutputsDeviceType(cnode_output_type);
node_info_builder.SetProcessor(kernel::GetProcessorFromContext());
node_info_builder.SetKernelType(KernelType::AKG_KERNEL);
node_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
node_info_builder.SetFusionType(kernel::kPatternOpaque);
auto info_1 = node_info_builder.Build();
AnfAlgo::SetSelectKernelBuildInfo(info_1, cnode.get());
(void)mng->Replace(old_input, cnode);

View File

@ -113,7 +113,7 @@ kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::str
graph_info_builder.SetOutputsDeviceType(output_types);
graph_info_builder.SetProcessor(processor);
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
graph_info_builder.SetFusionType(kernel::kPatternOpaque);
return graph_info_builder.Build();
}
@ -168,7 +168,7 @@ void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const
graph_info_builder.SetOutputsDeviceType(graph_output_type);
graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
graph_info_builder.SetFusionType(kernel::kPatternOpaque);
auto graph_selected_info = graph_info_builder.Build();
AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, new_node.get());
}
@ -394,7 +394,7 @@ CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &
info_builder.SetOutputsDeviceType(output_types);
info_builder.SetProcessor(kernel::GetProcessorFromContext());
info_builder.SetKernelType(KernelType::AKG_KERNEL);
info_builder.SetFusionType(kernel::FusionType::OPAQUE);
info_builder.SetFusionType(kernel::kPatternOpaque);
auto selected_info = info_builder.Build();
AnfAlgo::SetSelectKernelBuildInfo(selected_info, cnode.get());

View File

@ -56,7 +56,7 @@ AnfNodePtr RaiseReductionPrecision::CreateReduceSum(const AnfNodePtr &node, cons
info_builder.SetOutputsDeviceType({kFloat32->type_id()});
info_builder.SetProcessor(AnfAlgo::GetProcessor(node));
info_builder.SetKernelType(KernelType::AKG_KERNEL);
info_builder.SetFusionType(kernel::FusionType::OPAQUE);
info_builder.SetFusionType(kernel::kPatternOpaque);
AnfAlgo::SetSelectKernelBuildInfo(info_builder.Build(), cnode.get());
return node;
}

View File

@ -247,7 +247,7 @@ class CNodeDecoder {
builder->SetOutputsDeviceType(output_types_);
builder->SetProcessor(processor);
builder->SetKernelType(KernelType::AKG_KERNEL);
builder->SetFusionType(kernel::FusionType::OPAQUE);
builder->SetFusionType(kernel::kPatternOpaque);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), cnode_.get());
}

View File

@ -48,60 +48,6 @@ constexpr auto kStridedSliceMaxDims = 8;
constexpr auto kQuad = 4;
constexpr size_t kInputFirstIndex = 0;
constexpr char kOperatorOriginFormat[] = "operator_origin_format";
// Define all patterns here for different schedule.
// Maps each FusionType enum value to the pattern-name string used by the TBE
// compiler ("pattern" field of the kernel-compile JSON).  UNKNOWN_FUSION_TYPE
// intentionally maps to the empty string.  GetFusionNameByType /
// GetFusionTypeByName below do the forward and (case-insensitive) reverse
// lookups; every FusionType value must therefore appear exactly once here.
const std::unordered_map<FusionType, std::string> fusion_type_name_maps = {
{FusionType::BN_UPDATE_GRAD, "bn_update_grad"},
{FusionType::BN_GRAD_REDUCE, "bn_grad_reduce"},
{FusionType::LAYER_NORM_GRAD, "layer_norm_grad"},
{FusionType::L2LOSS_MUL_ADDN, "l2loss_mul_addn"},
{FusionType::ELEMWISE, "ElemWise"},
{FusionType::PURE_BROADCAST, "PureBroadcast"},
{FusionType::COMMREDUCE, "CommReduce"},
{FusionType::SEGMENT, "Segment"},
{FusionType::INPLACE, "Inplace"},
{FusionType::MATMUL, "Matmul"},
{FusionType::MATMUL_V2, "Matmul_v2"},
{FusionType::GEMM, "GEMM"},
{FusionType::CONV, "Convolution"},
{FusionType::CONV2D_BACKPROP_INPUT, "Conv2d_backprop_input"},
{FusionType::CONV2D_BACKPROP_FILTER, "Conv2d_backprop_filter"},
{FusionType::CONV3D_BACKPROP_INPUT, "Conv3d_backprop_input"},
{FusionType::CONV3D_BACKPROP_FILTER, "Conv3d_backprop_filter"},
{FusionType::CUBE_LAYER_NORM, "cube_layer_norm"},
{FusionType::OPAQUE, "Opaque"},
{FusionType::BN_REDUCE, "bn_reduce"},
{FusionType::BN_UPDATE, "bn_update"},
{FusionType::SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, "softmax_cross_entropy_with_logits"},
{FusionType::L2_NORMALIZE, "l2_normalize"},
{FusionType::SOFTMAX, "softmax_pattern"},
{FusionType::L2_LOSS, "l2_loss"},
{FusionType::ASCEND_QUANT, "quant"},
{FusionType::ASCEND_DEQUANT, "dequant"},
{FusionType::ASCEND_ANTI_QUANT, "anti_quant"},
{FusionType::STRIDED_READ, "strided_read"},
{FusionType::STRIDED_WRITE, "strided_write"},
{FusionType::ASCEND_DEQUANT_S16, "dequant_s16"},
{FusionType::ASCEND_REQUANT, "requant"},
{FusionType::ASCEND_REQUANT_S16, "requant_s16"},
{FusionType::MAX_POOL, "MaxPool"},
{FusionType::DEPTHWISECONV, "DepthwiseConvolution"},
{FusionType::CONV3D, "Conv3d"},
{FusionType::POOL2D, "Pool2d"},
{FusionType::POOL3D, "Pool3d"},
{FusionType::READ_SELECT, "read_select"},
{FusionType::WRITE_SELECT, "write_select"},
{FusionType::COSINE_EMBEDDING_LOSS, "cosine_embedding_loss"},
{FusionType::DILATION_PATTERN, "dilation"},
{FusionType::BROAD_CAST, "Broadcast"},
{FusionType::BATCH_MATMUL, "BatchMatmul"},
{FusionType::CONFUSION_TRANSPOSE, "confusiontranspose"},
{FusionType::DROPOUT_DOMASKV3D, "DropOutDoMaskV3D"},
{FusionType::TRANSDATA, "Transdata"},
{FusionType::NORM, "Norm"},
{FusionType::TRANSPOSE, "Transpose"},
{FusionType::UNKNOWN_FUSION_TYPE, ""}};
abstract::BaseShapePtr GetValidShapeFromAbstract(const abstract::AbstractBasePtr &abs) {
// Other abstract class, such as AbstractCSRTensor and AbstractCOOTensor, is converted to AbstractTensor early time.
abstract::BaseShapePtr res_shape;
@ -239,29 +185,6 @@ int CalDiagOffset(int diag_index, int max_diag_len, int inner_rows, int inner_co
return offset;
}
// Looks up the TBE pattern-name string registered for `type` in
// fusion_type_name_maps.  Raises via MS_LOG(EXCEPTION) when the fusion type
// has no entry in the table.
std::string GetFusionNameByType(const kernel::FusionType &type) {
  const auto entry = fusion_type_name_maps.find(type);
  if (entry != fusion_type_name_maps.end()) {
    return entry->second;
  }
  MS_LOG(EXCEPTION) << "Illegal fusion type: " << type;
}
// Reverse lookup of fusion_type_name_maps: maps a pattern-name string back to
// its FusionType enum value.  The comparison is case-insensitive (both the
// query and each table value are uppercased before comparing).  Raises via
// MS_LOG(EXCEPTION) when no pattern name matches.
FusionType GetFusionTypeByName(const std::string &name) {
  const auto to_upper = [](std::string s) {
    (void)transform(s.begin(), s.end(), s.begin(), ::toupper);
    return s;
  };
  const std::string wanted = to_upper(name);
  for (const auto &entry : fusion_type_name_maps) {
    if (to_upper(entry.second) == wanted) {
      return entry.first;
    }
  }
  MS_LOG(EXCEPTION) << "Illegal fusion name: " << name;
}
std::string GetCompilerCachePath() { return Common::GetUserDefineCachePath(); }
void KernelMeta::Initialize() {

View File

@ -180,8 +180,6 @@ inline float Scaler(const size_t x, const float scale, bool half_pixel_centers)
}
}
float ScaleGrid(const int x, const float scale);
BACKEND_EXPORT FusionType GetFusionTypeByName(const std::string &name);
BACKEND_EXPORT std::string GetFusionNameByType(const kernel::FusionType &type);
BACKEND_EXPORT std::vector<bool> Dec2Bin(const int64_t &mask);
BACKEND_EXPORT void FillEmptyDims(const BaseOperatorPtr &base_operator, std::vector<int64_t> *begin,
std::vector<int64_t> *end, std::vector<int64_t> *stride, ShapeVector *input_shape);

View File

@ -56,59 +56,56 @@ enum KernelType : int {
ACL_KERNEL,
};
namespace kernel {
// Supported fusion type.
// Schedule fusion patterns recognized by the Ascend/TBE kernel selection.
// The numeric values form the serialized contract with
// fusion_type_name_maps, so entries must not be reordered or renumbered;
// append new patterns before UNKNOWN_FUSION_TYPE instead.
enum FusionType {
CONV = 0,
ELEMWISE,
COMMREDUCE,
SEGMENT,
OPAQUE,
BN_UPDATE_GRAD,
BN_GRAD_REDUCE,
LAYER_NORM_GRAD,
L2LOSS_MUL_ADDN,
PURE_BROADCAST,
INPLACE,
MATMUL,
MATMUL_V2,
GEMM,
CONV2D_BACKPROP_INPUT,
CONV2D_BACKPROP_FILTER,
CONV3D_BACKPROP_INPUT,
CONV3D_BACKPROP_FILTER,
CUBE_LAYER_NORM,
BN_REDUCE,
BN_UPDATE,
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS,
L2_NORMALIZE,
SOFTMAX,
L2_LOSS,
ASCEND_QUANT,
ASCEND_DEQUANT,
ASCEND_ANTI_QUANT,
STRIDED_READ,
STRIDED_WRITE,
ASCEND_DEQUANT_S16,
ASCEND_REQUANT,
ASCEND_REQUANT_S16,
MAX_POOL,
DEPTHWISECONV,
CONV3D,
POOL2D,
POOL3D,
READ_SELECT,
WRITE_SELECT,
COSINE_EMBEDDING_LOSS,
DILATION_PATTERN,
BROAD_CAST,
BATCH_MATMUL,
CONFUSION_TRANSPOSE,
DROPOUT_DOMASKV3D,
TRANSDATA,
NORM,
TRANSPOSE,
// Sentinel for "no fusion pattern selected"; deliberately -1 so it never
// collides with the 0-based pattern values above.
UNKNOWN_FUSION_TYPE = -1,
};
// String names of the schedule fusion patterns.  These replace the removed
// FusionType enum: the string values are what the TBE compiler reports in the
// "op_pattern" field, so they must be compared/stored verbatim (matching is
// case-sensitive at this level).  Do not change a string value without
// confirming it against the TBE pattern names.
constexpr auto kPatternAntiQuant = "anti_quant";
constexpr auto kPatternDequant = "dequant";
constexpr auto kPatternDequantS16 = "dequant_s16";
constexpr auto kPatternQuant = "quant";
constexpr auto kPatternRequant = "requant";
// NOTE(review): identifier is inconsistent with kPatternDequantS16
// (expected kPatternRequantS16) — confirm no other file references it before
// renaming.
constexpr auto kPatternRequant_s16 = "requant_s16";
constexpr auto kPatternBatchMatmul = "BatchMatmul";
constexpr auto kPatternBnGradReduce = "bn_grad_reduce";
constexpr auto kPatternBnReduce = "bn_reduce";
constexpr auto kPatternBnUpdate = "bn_update";
// NOTE(review): identifier is inconsistent with kPatternBnGradReduce-style
// CamelCase (expected kPatternBnUpdateGrad) — confirm callers before renaming.
constexpr auto kPatternBnUpdate_grad = "bn_update_grad";
constexpr auto kPatternBroadcast = "Broadcast";
constexpr auto kPatternCommReduce = "CommReduce";
constexpr auto kPatternConfusiontranspose = "confusiontranspose";
constexpr auto kPatternConvolution = "Convolution";
constexpr auto kPatternConv2dBackpropFilter = "Conv2d_backprop_filter";
constexpr auto kPatternConv2dBackpropInput = "Conv2d_backprop_input";
constexpr auto kPatternConv3d = "Conv3d";
constexpr auto kPatternConv3dBackpropFilter = "Conv3d_backprop_filter";
constexpr auto kPatternConv3dBackpropInput = "Conv3d_backprop_input";
constexpr auto kPatternCosineEmbeddingLoss = "cosine_embedding_loss";
constexpr auto kPatternCubeLayerNorm = "cube_layer_norm";
constexpr auto kPatternDepthwiseConvolution = "DepthwiseConvolution";
constexpr auto kPatternDilation = "dilation";
constexpr auto kPatternDropOutDoMaskV3D = "DropOutDoMaskV3D";
constexpr auto kPatternElemWise = "ElemWise";
constexpr auto kPatternGEMM = "GEMM";
constexpr auto kPatternInplace = "Inplace";
constexpr auto kPatternL2Loss = "l2_loss";
constexpr auto kPatternL2Normalize = "l2_normalize";
constexpr auto kPatternL2lossMulAddn = "l2loss_mul_addn";
constexpr auto kPatternLayerNormGrad = "layer_norm_grad";
constexpr auto kPatternMatmul = "Matmul";
constexpr auto kPatternMatmulV2 = "Matmul_v2";
constexpr auto kPatternMaxPool = "MaxPool";
constexpr auto kPatternNorm = "Norm";
constexpr auto kPatternOpaque = "Opaque";
constexpr auto kPatternPool2d = "Pool2d";
constexpr auto kPatternPool3d = "Pool3d";
constexpr auto kPatternPureBroadcast = "PureBroadcast";
// NOTE(review): identifier is inconsistent with kPatternWriteSelect
// (expected kPatternReadSelect) — confirm callers before renaming.
constexpr auto kPatternread_select = "read_select";
constexpr auto kPatternSegment = "Segment";
constexpr auto kPatternSoftmaxPattern = "softmax_pattern";
constexpr auto kPatternSoftmaxCrossEntropyWithLogits = "softmax_cross_entropy_with_logits";
constexpr auto kPatternStridedRead = "strided_read";
constexpr auto kPatternStridedWrite = "strided_write";
constexpr auto kPatternTransdata = "Transdata";
constexpr auto kPatternTranspose = "Transpose";
// Empty string means "no / unknown pattern" (successor of
// FusionType::UNKNOWN_FUSION_TYPE).
constexpr auto kPatternUnknown = "";
constexpr auto kPatternWriteSelect = "write_select";
// Backend processor
enum Processor {

View File

@ -209,7 +209,7 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsDeviceType(const std::ve
kernel_build_info_->outputs_device_type_ = outputs_device_type;
}
void KernelBuildInfo::KernelBuildInfoBuilder::SetFusionType(FusionType fusion_type) {
// Stores the fusion pattern-name string on the KernelBuildInfo under
// construction (e.g. kernel::kPatternOpaque, kernel::kPatternElemWise).
void KernelBuildInfo::KernelBuildInfoBuilder::SetFusionType(const std::string &fusion_type) {
  MS_EXCEPTION_IF_NULL(kernel_build_info_);
  kernel_build_info_->fusion_type_ = fusion_type;
}

View File

@ -102,7 +102,7 @@ class BACKEND_EXPORT KernelBuildInfo {
std::vector<nlohmann::json> output_data_desc() const { return output_data_desc_; }
FusionType fusion_type() const { return fusion_type_; }
std::string fusion_type() const { return fusion_type_; }
Processor processor() const { return processor_; }
@ -136,7 +136,7 @@ class BACKEND_EXPORT KernelBuildInfo {
std::vector<KernelObjectType> inputs_kernel_object_type_;
std::vector<KernelObjectType> outputs_kernel_object_type_;
std::vector<nlohmann::json> output_data_desc_;
FusionType fusion_type_{kernel::FusionType::OPAQUE};
std::string fusion_type_{kernel::kPatternOpaque};
Processor processor_{AICORE};
};
using KernelBuildInfoPtr = std::shared_ptr<KernelBuildInfo>;
@ -189,7 +189,7 @@ class BACKEND_EXPORT KernelBuildInfo::KernelBuildInfoBuilder {
void SetCoreType(const std::string &core_type);
void SetFusionType(FusionType fusion_type);
void SetFusionType(const std::string &fusion_type);
// save prebuild result
void SetOutputDataDesc(const std::vector<nlohmann::json> &data_desc);

View File

@ -532,7 +532,7 @@ kernel::KernelBuildInfo::KernelBuildInfoBuilder KernelAdjust::CreateMngKernelBui
selected_kernel_builder.SetInputsFormat(formats);
selected_kernel_builder.SetInputsDeviceType(type_ids);
selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE);
selected_kernel_builder.SetFusionType(kernel::kPatternOpaque);
selected_kernel_builder.SetProcessor(kernel::Processor::AICORE);
selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL);
return selected_kernel_builder;
@ -626,7 +626,7 @@ CNodePtr KernelAdjust::CreateEndOfSequenceOP(const std::shared_ptr<session::Kern
selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT});
selected_kernel_builder.SetInputsDeviceType({kNumberTypeUInt8});
selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE);
selected_kernel_builder.SetFusionType(kernel::kPatternOpaque);
selected_kernel_builder.SetProcessor(kernel::Processor::AICPU);
selected_kernel_builder.SetKernelType(KernelType::AICPU_KERNEL);
@ -748,7 +748,7 @@ CNodePtr KernelAdjust::CreateNPUGetFloatStatus(const std::shared_ptr<session::Ke
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT});
selected_kernel_builder.SetInputsDeviceType({kNumberTypeFloat32});
selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE);
selected_kernel_builder.SetFusionType(kernel::kPatternOpaque);
selected_kernel_builder.SetProcessor(kernel::Processor::AICORE);
selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
@ -771,7 +771,7 @@ CNodePtr KernelAdjust::CreateNPUClearStatus(const std::shared_ptr<session::Kerne
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT});
selected_kernel_builder.SetInputsDeviceType({kNumberTypeFloat32});
selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE);
selected_kernel_builder.SetFusionType(kernel::kPatternOpaque);
selected_kernel_builder.SetProcessor(kernel::Processor::AICORE);
selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
@ -793,7 +793,7 @@ CNodePtr KernelAdjust::CreateNPUAllocStatus(const std::shared_ptr<session::Kerne
common::AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {npu_output_shape}, npu_alloc_cnode.get());
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE);
selected_kernel_builder.SetFusionType(kernel::kPatternOpaque);
selected_kernel_builder.SetProcessor(kernel::Processor::AICORE);
selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});

View File

@ -455,7 +455,7 @@ KernelSelectStatus SelectCustomKernelInfo(const CNodePtr &kernel_node, KernelTyp
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
builder->SetKernelType(*kernel_type);
builder->SetProcessor(kernel::Processor::AICORE);
builder->SetFusionType(kernel::FusionType::OPAQUE);
builder->SetFusionType(kernel::kPatternOpaque);
builder->SetOpPattern(kernel::OpPattern::kCommonPattern);
// set inputs info
std::vector<TypeId> inputs_device_type;

View File

@ -456,7 +456,7 @@ void SetGraphKernelInfo(const CNodePtr &kernel_node, const std::vector<std::pair
graph_info_builder.SetOutputsDeviceType(graph_output_type);
graph_info_builder.SetProcessor(kernel::Processor::AICORE);
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
graph_info_builder.SetFusionType(kernel::kPatternOpaque);
auto graph_selected_info = graph_info_builder.Build();
MS_EXCEPTION_IF_NULL(graph_selected_info);
AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get());

View File

@ -246,7 +246,7 @@ NotNull<CNodePtr> ProfilingUtils::CreateProfilingCNode(const ProfilingContent &p
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
selected_kernel_builder.SetInputsDeviceType({TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32});
selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE);
selected_kernel_builder.SetFusionType(kernel::kPatternOpaque);
selected_kernel_builder.SetProcessor(kernel::Processor::AICORE);
selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL);
abstract::AbstractBasePtr type_none_abstract = std::make_shared<abstract::AbstractNone>();

View File

@ -81,7 +81,7 @@ void AicpuMetadataInfoForSpecialNodes(const CNodePtr &kernel_node,
builder.SetOutputsDeviceType(outputs_type);
builder.SetProcessor(AICPU);
builder.SetKernelType(AICPU_KERNEL);
builder.SetFusionType(OPAQUE);
builder.SetFusionType(kPatternOpaque);
(void)kernel_info_list->emplace_back(builder.Build());
return;
}

View File

@ -82,7 +82,7 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> GetNextDynamicDesc::GetKer
builder.SetOutputsDeviceType(output_type);
builder.SetProcessor(AICORE);
builder.SetKernelType(RT_KERNEL);
builder.SetFusionType(OPAQUE);
builder.SetFusionType(kPatternOpaque);
get_next_dynamic_build_info.emplace_back(builder.Build());
return get_next_dynamic_build_info;
}

View File

@ -84,7 +84,7 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> LabelSwitchDesc::GetKernel
builder.SetInputsDeviceType({input_type[i]});
builder.SetProcessor(AICORE);
builder.SetKernelType(RT_KERNEL);
builder.SetFusionType(OPAQUE);
builder.SetFusionType(kPatternOpaque);
// LabelSwitch always return UMonad.
builder.SetOutputsFormat({kOpFormat_DEFAULT});
builder.SetOutputsDeviceType({TypeId::kObjectTypeUMonad});

View File

@ -159,7 +159,7 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> MemCpyAsyncDesc::GetKernel
builder.SetOutputsDeviceType(output_type);
builder.SetProcessor(AICORE);
builder.SetKernelType(RT_KERNEL);
builder.SetFusionType(OPAQUE);
builder.SetFusionType(kPatternOpaque);
memcpy_build_info.emplace_back(builder.Build());
}
}

View File

@ -81,7 +81,7 @@ void GetRtKelInfo(const CNodePtr &kernel_node,
kernel_build_info_builder->SetOutputsFormat({kOpFormat_DEFAULT});
kernel_build_info_builder->SetOutputsDeviceType({TypeId::kObjectTypeUMonad});
// set other info
kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE);
kernel_build_info_builder->SetFusionType(kernel::kPatternOpaque);
kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE);
kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
kernel_info_list->push_back(kernel_build_info_builder->Build());

View File

@ -185,7 +185,7 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> TensorCopySlicesDesc::GetK
builder.SetOutputsDeviceType(output_type);
builder.SetProcessor(AICORE);
builder.SetKernelType(RT_KERNEL);
builder.SetFusionType(OPAQUE);
builder.SetFusionType(kPatternOpaque);
tensor_copy_slices_build_info.emplace_back(builder.Build());
}
}

View File

@ -446,7 +446,7 @@ void TbeJsonCreator::GenComputeCommonJson(const AnfNodePtr &anf_node, nlohmann::
(*compute_json)[kJIsDynamicImpl] = is_dynamic_impl;
(*compute_json)[kJInt64Mode] = false;
(*compute_json)[kJName] = cnode->fullname_with_scope();
(*compute_json)[kJPattern] = kernel::GetFusionNameByType(AnfAlgo::GetFusionType(cnode));
(*compute_json)[kJPattern] = AnfAlgo::GetFusionType(cnode);
(*compute_json)[kJModuleName] = kJModuleNamePrefix + func_name;
}

View File

@ -373,13 +373,12 @@ void TbeKernelCompileManager::SavePreBuildResult(const std::string &json_name, c
return;
}
auto op_pattern = GetJsonValue<std::string>(result, "op_pattern");
auto fusion_type = kernel::GetFusionTypeByName(op_pattern);
auto output_data_desc = GetJsonValue<nlohmann::json>(result, "op_params");
auto core_type = GetJsonValue<nlohmann::json>(result, "core_type");
// save pre build result
struct PreBuildResult pre_res;
pre_res.json_name = json_name;
pre_res.fusion_type = fusion_type;
pre_res.fusion_type = op_pattern;
pre_res.core_type = core_type;
pre_res.output_data_desc = output_data_desc;
prebuild_res_map_[json_name] = pre_res;
@ -594,12 +593,11 @@ void TbeKernelCompileManager::UpdateFusionTypeAndOutputDataDesc(const std::vecto
}
auto pre_res = prebuild_res_map_[kernel_name];
auto fusion_type = pre_res.fusion_type;
auto fusion_name = GetFusionNameByType(fusion_type);
auto output_data_desc = pre_res.output_data_desc;
auto core_type = pre_res.core_type;
AnfAlgo::SetCoreType(node, core_type);
AnfAlgo::SetFusionType(node, fusion_type);
common::AnfAlgo::SetNodeAttr(kAttrTbeFusionType, MakeValue(fusion_name), node);
common::AnfAlgo::SetNodeAttr(kAttrTbeFusionType, MakeValue(fusion_type), node);
AnfAlgo::SetOutputDataDesc(node, {output_data_desc});
}
MS_LOG(INFO) << "End update fusion type after pre build";

View File

@ -51,7 +51,7 @@ struct TaskInfo {
struct PreBuildResult {
std::string core_type;
std::string json_name;
kernel::FusionType fusion_type;
std::string fusion_type;
nlohmann::json output_data_desc;
};

View File

@ -480,7 +480,7 @@ CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &
builder.SetOutputsReshapeType({reshape_type});
builder.SetInputsDeviceType({input_type});
builder.SetOutputsDeviceType({output_type});
builder.SetFusionType(kernel::FusionType::OPAQUE);
builder.SetFusionType(kernel::kPatternOpaque);
builder.SetProcessor(kernel::Processor::AICORE);
if (kernel::OpLib::FindOp(prim::kPrimCast->name(), kernel::kImplyTBE) != nullptr) {
builder.SetKernelType(KernelType::TBE_KERNEL);

View File

@ -28,7 +28,7 @@ namespace opt {
namespace {
constexpr auto kAttrNoFusion = "no_fusion";
CNodePtr FindInputNode(const CNodePtr &cnode, const string &node_type, const kernel::FusionType &fusion_type) {
CNodePtr FindInputNode(const CNodePtr &cnode, const string &node_type, const std::string &fusion_type) {
auto input_num = common::AnfAlgo::GetInputTensorNum(cnode);
for (size_t i = 1; i <= input_num; ++i) {
auto input = cnode->input(i);
@ -71,7 +71,7 @@ bool BatchMatmulEltwiseFusionPass::MatchPattern2(const CNodePtr &eltwise,
return false;
}
CNodePtr bmm = FindInputNode(eltwise, kBatchMatMulOpName, kernel::FusionType::BATCH_MATMUL);
CNodePtr bmm = FindInputNode(eltwise, kBatchMatMulOpName, kernel::kPatternBatchMatmul);
if (bmm == nullptr || common::AnfAlgo::IsDynamicShape(bmm) || common::AnfAlgo::GetBooleanAttr(bmm, kAttrNoFusion)) {
return false;
}
@ -88,17 +88,17 @@ bool BatchMatmulEltwiseFusionPass::MatchPattern3(const CNodePtr &eltwise,
return false;
}
CNodePtr eltwise2 = FindInputNode(eltwise, kSigmoidOpName, kernel::FusionType::ELEMWISE);
CNodePtr eltwise2 = FindInputNode(eltwise, kSigmoidOpName, kernel::kPatternElemWise);
if (eltwise2 == nullptr) {
return false;
}
CNodePtr eltwise1 = FindInputNode(eltwise2, kMulOpName, kernel::FusionType::ELEMWISE);
CNodePtr eltwise1 = FindInputNode(eltwise2, kMulOpName, kernel::kPatternElemWise);
if (eltwise1 == nullptr) {
return false;
}
CNodePtr bmm = FindInputNode(eltwise1, kBatchMatMulOpName, kernel::FusionType::BATCH_MATMUL);
CNodePtr bmm = FindInputNode(eltwise1, kBatchMatMulOpName, kernel::kPatternBatchMatmul);
if (bmm == nullptr || common::AnfAlgo::IsDynamicShape(bmm)) {
return false;
}
@ -124,8 +124,8 @@ void BatchMatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::Kerne
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
(AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE ||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::BROAD_CAST)) {
(AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise ||
AnfAlgo::GetFusionType(cnode) == kernel::kPatternBroadcast)) {
mindspore::HashSet<AnfNodePtr> record;
if (MatchPattern1(cnode, &record) || MatchPattern2(cnode, &record) || MatchPattern3(cnode, &record)) {
candidate_fusion->push_back(record);

View File

@ -34,7 +34,7 @@ void BatchMatmulReduceSumFusionPass::MatchBatchMatmulReduceSum(const CNodePtr &r
auto batch_matmul = reduce_sum->input(kIndex1);
MS_EXCEPTION_IF_NULL(batch_matmul);
const PrimitiveSet batch_matmul_prims{prim::kPrimBatchMatMul, prim::kPrimBatchMatMulV2};
if (!batch_matmul->isa<CNode>() || AnfAlgo::GetFusionType(batch_matmul) != kernel::FusionType::BATCH_MATMUL ||
if (!batch_matmul->isa<CNode>() || AnfAlgo::GetFusionType(batch_matmul) != kernel::kPatternBatchMatmul ||
!IsOneOfPrimitiveCNode(batch_matmul, batch_matmul_prims)) {
return;
}
@ -84,7 +84,7 @@ void BatchMatmulReduceSumFusionPass::MatchSingleFusionPattern(const session::Ker
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::COMMREDUCE &&
AnfAlgo::GetFusionType(cnode) == kernel::kPatternCommReduce &&
common::AnfAlgo::GetCNodeName(cnode) == kReduceSumDOpName) {
MatchBatchMatmulReduceSum(cnode, kernel_graph, candidate_fusion);
}

View File

@ -65,7 +65,7 @@ void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::K
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&
AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise &&
common::AnfAlgo::GetOutputTensorNum(cnode) == ELTWISE_DOUBLE_OUTPUT_SIZE &&
common::AnfAlgo::GetInputTensorNum(cnode) == (ELTWISE_INPUT_SIZE - 1)) {
auto eltwise_input = cnode->input(kIndex1);

View File

@ -54,7 +54,7 @@ void BnupdateEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGr
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&
AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise &&
common::AnfAlgo::GetOutputTensorNum(cnode) == ELTWISE_DOUBLE_OUTPUT_SIZE) {
auto eltwise_input = cnode->input(1);
MS_EXCEPTION_IF_NULL(eltwise_input);

View File

@ -94,7 +94,7 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const sess
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&
AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise &&
common::AnfAlgo::GetCNodeName(cnode) == kReluGradV2OpName) {
MatchConv2DBackpropInputEltwiseEltwise(cnode, kernel_graph, candidate_fusion);
}

View File

@ -45,7 +45,7 @@ void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, con
return;
}
if (AnfAlgo::GetKernelType(double_in_eltwise_input) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(double_in_eltwise_input) == kernel::FusionType::CONV) {
AnfAlgo::GetFusionType(double_in_eltwise_input) == kernel::kPatternConvolution) {
(void)record.insert(double_in_eltwise_input);
candidate_fusion->push_back(record);
SetRecordFusionId(record);
@ -65,7 +65,7 @@ void ConvDoubleInFusionPass::MatchSingleFusionPattern(const session::KernelGraph
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE &&
AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise && cnode->inputs().size() == ELTWISE_INPUT_SIZE &&
!common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimReluV2)) {
MatchConvDoubleInEltwise(cnode, kernel_graph, candidate_fusion);
}

View File

@ -47,7 +47,7 @@ void ConvSingleInFusionPass::MatchConvSingleInEltwise(const CNodePtr &cnode, con
return;
}
if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::CONV) {
AnfAlgo::GetFusionType(eltwise_input) == kernel::kPatternConvolution) {
(void)record.insert(eltwise_input);
candidate_fusion->push_back(record);
SetRecordFusionId(record);
@ -67,7 +67,7 @@ void ConvSingleInFusionPass::MatchSingleFusionPattern(const session::KernelGraph
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) {
AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise && cnode->inputs().size() == ELTWISE_INPUT_SIZE) {
MatchConvSingleInEltwise(cnode, kernel_graph, candidate_fusion);
}
}

View File

@ -63,7 +63,7 @@ void DepthwiseConvEltwiseFusionPass::MatchSingleFusionPattern(const session::Ker
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise) {
auto eltwise_input = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(eltwise_input);
if (eltwise_input->isa<CNode>() &&

View File

@ -64,8 +64,8 @@ void EltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &ker
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
(AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE ||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::BROAD_CAST) &&
(AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise ||
AnfAlgo::GetFusionType(cnode) == kernel::kPatternBroadcast) &&
cnode->inputs().size() == ELTWISE_INPUT_SIZE) {
MatchEltwise(cnode, kernel_graph, candidate_fusion);
}
@ -83,8 +83,8 @@ bool EltwiseFusionPass::CheckEltWiseOrBroadCastNode(const session::KernelGraph &
MS_EXCEPTION_IF_NULL(cnode);
size_t not_updatestate_nums = GetNotUpdateStateUserNums(kernel_graph, node);
return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
(AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE ||
AnfAlgo::GetFusionType(node) == kernel::FusionType::BROAD_CAST) &&
(AnfAlgo::GetFusionType(node) == kernel::kPatternElemWise ||
AnfAlgo::GetFusionType(node) == kernel::kPatternBroadcast) &&
not_updatestate_nums == ELTWISE_USE && cnode->inputs().size() == ELTWISE_INPUT_SIZE;
}
@ -100,8 +100,8 @@ bool EltwiseFusionPass::CheckDoubleInEltWiseOrBroadCastNode(const session::Kerne
MS_EXCEPTION_IF_NULL(cnode);
size_t not_updatestate_nums = GetNotUpdateStateUserNums(kernel_graph, node);
return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
(AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE ||
AnfAlgo::GetFusionType(node) == kernel::FusionType::BROAD_CAST) &&
(AnfAlgo::GetFusionType(node) == kernel::kPatternElemWise ||
AnfAlgo::GetFusionType(node) == kernel::kPatternBroadcast) &&
not_updatestate_nums == ELTWISE_USE && cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE;
}
} // namespace opt

View File

@ -33,7 +33,7 @@ bool FusionBasePass::CheckEltWiseNode(const session::KernelGraph &kernel_graph,
MS_EXCEPTION_IF_NULL(cnode);
size_t not_updatestate_nums = GetNotUpdateStateUserNums(kernel_graph, node);
return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && not_updatestate_nums == ELTWISE_USE &&
AnfAlgo::GetFusionType(node) == kernel::kPatternElemWise && not_updatestate_nums == ELTWISE_USE &&
cnode->inputs().size() == ELTWISE_INPUT_SIZE;
}
@ -48,7 +48,7 @@ bool FusionBasePass::CheckDoubleInEltWiseNode(const session::KernelGraph &kernel
MS_EXCEPTION_IF_NULL(cnode);
size_t not_updatestate_nums = GetNotUpdateStateUserNums(kernel_graph, node);
return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && not_updatestate_nums == ELTWISE_USE &&
AnfAlgo::GetFusionType(node) == kernel::kPatternElemWise && not_updatestate_nums == ELTWISE_USE &&
cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE;
}
@ -63,7 +63,7 @@ bool FusionBasePass::CheckMultiOutputEltWiseNode(const session::KernelGraph &ker
MS_EXCEPTION_IF_NULL(cnode);
size_t not_updatestate_nums = GetNotUpdateStateUserNums(kernel_graph, node);
return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && not_updatestate_nums == ELTWISE_MULTI_USE &&
AnfAlgo::GetFusionType(node) == kernel::kPatternElemWise && not_updatestate_nums == ELTWISE_MULTI_USE &&
cnode->inputs().size() == ELTWISE_INPUT_SIZE;
}

View File

@ -48,8 +48,8 @@ void MatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGrap
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
(AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE ||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::BROAD_CAST) &&
(AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise ||
AnfAlgo::GetFusionType(cnode) == kernel::kPatternBroadcast) &&
common::AnfAlgo::GetOutputTensorNum(cnode) == ELTWISE_SINGLE_OUTPUT_SIZE) {
auto eltwise_input = cnode->input(kIndex1);
MS_EXCEPTION_IF_NULL(eltwise_input);

View File

@ -68,7 +68,7 @@ void MultiOutputFusionPass::MatchSingleFusionPattern(const session::KernelGraph
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) {
AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise && cnode->inputs().size() == ELTWISE_INPUT_SIZE) {
MatchMultiOutputEltwise(cnode, kernel_graph, candidate_fusion);
}
}

View File

@ -47,7 +47,7 @@ void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const se
return;
}
if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::COMMREDUCE &&
AnfAlgo::GetFusionType(eltwise_input) == kernel::kPatternCommReduce &&
GetNodeOutputTotalUsedNum(kernel_graph, eltwise_input) == 1) {
(void)record.insert(eltwise_input);
auto previous_input_cnode = eltwise_input->cast<CNodePtr>();
@ -82,7 +82,7 @@ void ReduceEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGrap
MS_EXCEPTION_IF_NULL(cnode);
// Fusion squaresumv1 and sqrt will get worse performance in bert
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE &&
AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise && cnode->inputs().size() == ELTWISE_INPUT_SIZE &&
common::AnfAlgo::GetCNodeName(cnode) != kCastOpName && common::AnfAlgo::GetCNodeName(cnode) != kSqrtOpName) {
MatchReduceEltwise(cnode, kernel_graph, candidate_fusion);
}

View File

@ -45,7 +45,7 @@ void SegmentEltwiseFusionPass::MatchSegmentEltwise(const CNodePtr &cnode, const
return;
}
if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::SEGMENT) {
AnfAlgo::GetFusionType(eltwise_input) == kernel::kPatternSegment) {
(void)record.insert(eltwise_input);
auto previous_input_cnode = eltwise_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(previous_input_cnode);
@ -79,7 +79,7 @@ void SegmentEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGra
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) {
AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise && cnode->inputs().size() == ELTWISE_INPUT_SIZE) {
MatchSegmentEltwise(cnode, kernel_graph, candidate_fusion);
}
}

View File

@ -44,7 +44,7 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
auto conv_cnode = write_input->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(conv_cnode);
if (AnfAlgo::GetKernelType(conv_cnode) == KernelType::TBE_KERNEL &&
AnfAlgo::GetFusionType(conv_cnode) == kernel::FusionType::CONV &&
AnfAlgo::GetFusionType(conv_cnode) == kernel::kPatternConvolution &&
conv_cnode->inputs().size() >= CONV_DOUBLE_IN_INPUT_SIZE &&
conv_cnode->inputs().size() <= CONV_QUART_IN_INPUT_SIZE) {
(void)record.insert(write_input);

View File

@ -131,7 +131,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AnfNodePtr &concat, con
(void)outputs_device_format.emplace_back(*cmp_format);
(void)outputs_device_type.emplace_back(*cmp_dtype);
builder.SetFusionType(kernel::FusionType::OPAQUE);
builder.SetFusionType(kernel::kPatternOpaque);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(TBE_KERNEL);
builder.SetInputsFormat(inputs_device_format);

View File

@ -124,7 +124,7 @@ ValueNodePtr CreateAssistNode(const std::vector<int64_t> &input_shape, int32_t k
kernel::KernelBuildInfoPtr CreateKernelBuildInfo() {
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetKernelType(TBE_KERNEL);
builder.SetFusionType(kernel::OPAQUE);
builder.SetFusionType(kernel::kPatternOpaque);
builder.SetProcessor(kernel::AICORE);
builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
builder.SetOutputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});

View File

@ -364,11 +364,7 @@ REG_ASCEND_VM_OP_ADAPTATION_INFO(kResizeNearestNeighborGradOpName)
.set_need_tbe_check_supported(true)
.set_input_attr_info(1, "listInt");
REG_ASCEND_VM_OP_ADAPTATION_INFO(kResizeNearestNeighborOpName)
.set_backend_op_name(kResizeNearestNeighborV2OpName)
.set_target_op_name(kResizeNearestNeighborV2DOpName)
.set_need_tbe_check_supported(true)
.set_input_attr_info(1, "listInt");
REG_ASCEND_VM_OP_ADAPTATION_INFO(kResizeNearestNeighborOpName).set_backend_op_name(kResizeNearestNeighborV2DOpName);
REG_ASCEND_VM_OP_ADAPTATION_INFO(kReverseV2OpName)
.set_target_op_name(kReverseV2DOpName)

View File

@ -454,7 +454,7 @@ void SetGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_gr
graph_info_builder.SetOutputsDeviceType(graph_output_type);
graph_info_builder.SetProcessor(kernel::Processor::CUDA);
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
graph_info_builder.SetFusionType(kernel::kPatternOpaque);
auto graph_selected_info = graph_info_builder.Build();
MS_EXCEPTION_IF_NULL(graph_selected_info);
AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get());

View File

@ -85,7 +85,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AnfNodePtr &concat, con
outputs_device_format.emplace_back(*cmp_format);
outputs_device_type.emplace_back(*cmp_dtype);
builder.SetFusionType(kernel::FusionType::OPAQUE);
builder.SetFusionType(kernel::kPatternOpaque);
builder.SetInputsFormat(inputs_device_format);
builder.SetOutputsFormat(outputs_device_format);
builder.SetInputsDeviceType(inputs_device_type);

View File

@ -42,7 +42,7 @@ class TestHcclAdapter : public UT::Common {
void SetOutputs(const CNodePtr &cnode, const std::vector<ShapeVector> &shape, const std::vector<TypeId> &data_type) {
common::AnfAlgo::SetOutputInferTypeAndShape(data_type, shape, cnode.get());
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetFusionType(kernel::FusionType::OPAQUE);
builder.SetFusionType(kernel::kPatternOpaque);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(TBE_KERNEL);
builder.SetInputsFormat(std::vector<std::string>(cnode->size() - 1, format_));
@ -65,7 +65,7 @@ class TestHcclAdapter : public UT::Common {
common::AnfAlgo::SetOutputInferTypeAndShape(std::vector<TypeId>{data_type[i]}, std::vector<ShapeVector>{shape[i]},
node.get());
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
builder.SetFusionType(kernel::FusionType::OPAQUE);
builder.SetFusionType(kernel::kPatternOpaque);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(TBE_KERNEL);
builder.SetInputsFormat({format_});

View File

@ -73,7 +73,7 @@ TEST_F(TestHWBufferFusion, test_tbe_eltwise_fusion_1) {
builder.SetInputsDeviceType({kFloat32->type_id()});
builder.SetOutputsDeviceType({kFloat32->type_id()});
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetFusionType(kernel::kPatternElemWise);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::TBE_KERNEL);
@ -88,7 +88,7 @@ TEST_F(TestHWBufferFusion, test_tbe_eltwise_fusion_1) {
builder1.SetInputsDeviceType({kFloat32->type_id()});
builder1.SetOutputsDeviceType({kFloat16->type_id()});
builder1.SetKernelType(KernelType::TBE_KERNEL);
builder1.SetFusionType(kernel::FusionType::OPAQUE);
builder1.SetFusionType(kernel::kPatternOpaque);
builder1.SetProcessor(kernel::Processor::AICORE);
builder1.SetKernelType(KernelType::TBE_KERNEL);
@ -144,7 +144,7 @@ TEST_F(TestHWBufferFusion, test_tbe_eltwise_fusion_2) {
builder.SetInputsDeviceType({kFloat32->type_id()});
builder.SetOutputsDeviceType({kFloat32->type_id()});
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetFusionType(kernel::kPatternElemWise);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::TBE_KERNEL);
@ -167,7 +167,7 @@ TEST_F(TestHWBufferFusion, test_tbe_eltwise_fusion_2) {
builder1.SetInputsDeviceType({kFloat32->type_id()});
builder1.SetOutputsDeviceType({kFloat16->type_id()});
builder1.SetKernelType(KernelType::TBE_KERNEL);
builder1.SetFusionType(kernel::FusionType::OPAQUE);
builder1.SetFusionType(kernel::kPatternOpaque);
builder1.SetProcessor(kernel::Processor::AICORE);
builder1.SetKernelType(KernelType::TBE_KERNEL);
@ -180,7 +180,7 @@ TEST_F(TestHWBufferFusion, test_tbe_eltwise_fusion_2) {
builder2.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
builder2.SetOutputsDeviceType({kFloat32->type_id()});
builder2.SetKernelType(KernelType::TBE_KERNEL);
builder2.SetFusionType(kernel::FusionType::COMMREDUCE);
builder2.SetFusionType(kernel::kPatternCommReduce);
builder2.SetProcessor(kernel::Processor::AICORE);
builder2.SetKernelType(KernelType::TBE_KERNEL);
@ -234,7 +234,7 @@ TEST_F(TestHWBufferFusion, test_tbe_reduce_eltwise_fusion) {
builder.SetInputsDeviceType({kFloat32->type_id()});
builder.SetOutputsDeviceType({kFloat32->type_id()});
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetFusionType(kernel::kPatternElemWise);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::TBE_KERNEL);
@ -257,7 +257,7 @@ TEST_F(TestHWBufferFusion, test_tbe_reduce_eltwise_fusion) {
builder1.SetInputsDeviceType({kFloat32->type_id()});
builder1.SetOutputsDeviceType({kFloat16->type_id()});
builder1.SetKernelType(KernelType::TBE_KERNEL);
builder1.SetFusionType(kernel::FusionType::OPAQUE);
builder1.SetFusionType(kernel::kPatternOpaque);
builder1.SetProcessor(kernel::Processor::AICORE);
builder1.SetKernelType(KernelType::TBE_KERNEL);
@ -270,7 +270,7 @@ TEST_F(TestHWBufferFusion, test_tbe_reduce_eltwise_fusion) {
builder2.SetInputsDeviceType({kFloat32->type_id()});
builder2.SetOutputsDeviceType({kFloat32->type_id()});
builder2.SetKernelType(KernelType::TBE_KERNEL);
builder2.SetFusionType(kernel::FusionType::COMMREDUCE);
builder2.SetFusionType(kernel::kPatternCommReduce);
builder2.SetProcessor(kernel::Processor::AICORE);
builder2.SetKernelType(KernelType::TBE_KERNEL);
@ -316,7 +316,7 @@ TEST_F(TestHWBufferFusion, test_tbe_matmul_eltwise_fusion) {
builder.SetInputsDeviceType({kFloat32->type_id()});
builder.SetOutputsDeviceType({kFloat32->type_id()});
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetFusionType(kernel::kPatternElemWise);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::TBE_KERNEL);
relu->set_kernel_info(std::make_shared<device::KernelInfo>());
@ -328,7 +328,7 @@ TEST_F(TestHWBufferFusion, test_tbe_matmul_eltwise_fusion) {
builder2.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
builder2.SetOutputsDeviceType({kFloat32->type_id()});
builder2.SetKernelType(KernelType::TBE_KERNEL);
builder2.SetFusionType(kernel::FusionType::OPAQUE);
builder2.SetFusionType(kernel::kPatternOpaque);
builder2.SetProcessor(kernel::Processor::AICORE);
builder2.SetKernelType(KernelType::TBE_KERNEL);
matmul->set_kernel_info(std::make_shared<device::KernelInfo>());
@ -340,7 +340,7 @@ TEST_F(TestHWBufferFusion, test_tbe_matmul_eltwise_fusion) {
builder1.SetInputsDeviceType({kFloat32->type_id()});
builder1.SetOutputsDeviceType({kFloat16->type_id()});
builder1.SetKernelType(KernelType::TBE_KERNEL);
builder1.SetFusionType(kernel::FusionType::OPAQUE);
builder1.SetFusionType(kernel::kPatternOpaque);
builder1.SetProcessor(kernel::Processor::AICORE);
builder1.SetKernelType(KernelType::TBE_KERNEL);
cast->set_kernel_info(std::make_shared<device::KernelInfo>());

View File

@ -58,7 +58,7 @@ TEST_F(TestHWInsertCast, test_insert_cast_op_for_single_output) {
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetFusionType(kernel::kPatternElemWise);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::AKG_KERNEL);
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1;
@ -66,7 +66,7 @@ TEST_F(TestHWInsertCast, test_insert_cast_op_for_single_output) {
builder1.SetInputsDeviceType({kFloat32->type_id()});
builder1.SetOutputsFormat({"NC1HWC0"});
builder1.SetOutputsDeviceType({kFloat32->type_id()});
builder1.SetFusionType(kernel::FusionType::ELEMWISE);
builder1.SetFusionType(kernel::kPatternElemWise);
builder1.SetProcessor(kernel::Processor::AICORE);
builder1.SetKernelType(KernelType::AKG_KERNEL);
auto node_list = TopoSort(func_graph->get_return());
@ -120,7 +120,7 @@ TEST_F(TestHWInsertCast, test_insert_cast_op_for_multiple_output) {
builder1.SetInputsDeviceType({kFloat32->type_id()});
builder1.SetOutputsFormat({"DefaultFormat"});
builder1.SetOutputsDeviceType({kFloat32->type_id()});
builder1.SetFusionType(kernel::FusionType::ELEMWISE);
builder1.SetFusionType(kernel::kPatternElemWise);
builder1.SetProcessor(kernel::Processor::AICORE);
builder1.SetKernelType(KernelType::AKG_KERNEL);
auto node_list = TopoSort(func_graph->get_return());

View File

@ -124,7 +124,7 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) {
builder.SetOutputsFormat({kOpFormat_C1HWNCoC0});
builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetFusionType(kernel::kPatternElemWise);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetInputsReshapeType({""});
builder.SetOutputsReshapeType({""});
@ -172,7 +172,7 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_nchw_fraz) {
builder.SetOutputsFormat({"NCHW"});
builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetFusionType(kernel::kPatternElemWise);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetInputsReshapeType({""});
builder.SetOutputsReshapeType({""});

View File

@ -70,7 +70,7 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetOutputsDeviceType({kFloat16->type_id()});
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetFusionType(kernel::kPatternElemWise);
builder.SetProcessor(kernel::Processor::AICORE);
auto kernel_info = std::make_shared<device::KernelInfo>();
kernel_info->set_select_kernel_build_info(builder.Build());

View File

@ -97,7 +97,7 @@ static KernelGraphPtr CreateKernelGraph() {
builder.SetOutputsFormat({kOpFormat_NCHW});
builder.SetOutputsDeviceType({kFloat32->type_id()});
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(mindspore::kernel::CONV);
builder.SetFusionType(mindspore::kernel::kPatternConvolution);
builder.SetProcessor(kernel::Processor::AICORE);
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), kernelptr_first.get());
@ -127,7 +127,7 @@ static KernelGraphPtr CreateKernelGraph() {
relu_builder.SetInputsDeviceType({kFloat32->type_id()});
relu_builder.SetOutputsDeviceType({kFloat32->type_id()});
relu_builder.SetKernelType(KernelType::TBE_KERNEL);
relu_builder.SetFusionType(kernel::FusionType::ELEMWISE);
relu_builder.SetFusionType(kernel::kPatternElemWise);
relu_builder.SetProcessor(kernel::Processor::AICORE);
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), kernelptr_floor.get());
next_cnode_ptr = kernelptr_floor;

View File

@ -55,7 +55,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_all) {
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetInputsDeviceType({kFloat32->type_id()});
builder.SetOutputsDeviceType({kFloat32->type_id()});
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetFusionType(kernel::kPatternElemWise);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::AKG_KERNEL);
auto node_list = TopoSort(func_graph->get_return());
@ -96,7 +96,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_group) {
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetInputsDeviceType({kFloat32->type_id()});
builder.SetOutputsDeviceType({kFloat32->type_id()});
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetFusionType(kernel::kPatternElemWise);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::AKG_KERNEL);
auto node_list = TopoSort(func_graph->get_return());
@ -137,7 +137,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_op) {
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetInputsDeviceType({kFloat32->type_id()});
builder.SetOutputsDeviceType({kFloat32->type_id()});
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetFusionType(kernel::kPatternElemWise);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::AKG_KERNEL);
auto node_list = TopoSort(func_graph->get_return());
@ -194,7 +194,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_sorted) {
builder.SetOutputsFormat({"NC1HWC0"});
builder.SetInputsDeviceType({kFloat32->type_id()});
builder.SetOutputsDeviceType({kFloat32->type_id()});
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetFusionType(kernel::kPatternElemWise);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::AKG_KERNEL);
auto node_list = TopoSort(func_graph->get_return());

View File

@ -676,9 +676,9 @@ TEST_F(AnfRuntimeAlgorithmTest, GetFusionType) {
auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info());
MS_EXCEPTION_IF_NULL(d_kernel_info);
KernelBuildInfoBuilder builder;
builder.SetFusionType(kernel::CONV);
builder.SetFusionType(kernel::kPatternConvolution);
d_kernel_info->set_select_kernel_build_info(builder.Build());
EXPECT_EQ(AnfAlgo::GetFusionType(add), kernel::CONV);
EXPECT_EQ(AnfAlgo::GetFusionType(add), kernel::kPatternConvolution);
EXPECT_THROW(AnfAlgo::GetFusionType(nullptr), std::runtime_error);
}
@ -688,10 +688,10 @@ TEST_F(AnfRuntimeAlgorithmTest, SetSelectKernelBuildInfo) {
inputs.push_back(NewValueNode(prim::kPrimAdd));
auto add = kernel_graph->NewCNode(inputs);
std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
builder->SetFusionType(kernel::CONV);
builder->SetFusionType(kernel::kPatternConvolution);
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), add.get());
EXPECT_THROW(AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), nullptr), std::runtime_error);
EXPECT_EQ(AnfAlgo::GetFusionType(add), kernel::CONV);
EXPECT_EQ(AnfAlgo::GetFusionType(add), kernel::kPatternConvolution);
}
TEST_F(AnfRuntimeAlgorithmTest, GetKernelMod) {

View File

@ -65,7 +65,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_common) {
builder.SetInputsDeviceType({kFloat32->type_id()});
builder.SetOutputsDeviceType({kFloat32->type_id()});
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetFusionType(kernel::kPatternElemWise);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::TBE_KERNEL);
@ -107,7 +107,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_conv2d_backprop_filter) {
builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
builder.SetOutputsDeviceType({kFloat32->type_id()});
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetFusionType(kernel::kPatternElemWise);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::TBE_KERNEL);
@ -166,7 +166,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_dynamic_rnn) {
builder.SetOutputsDeviceType({kFloat16->type_id(), kFloat16->type_id(), kFloat16->type_id(), kFloat16->type_id(),
kFloat16->type_id(), kFloat16->type_id(), kFloat16->type_id(), kFloat16->type_id()});
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetFusionType(kernel::kPatternElemWise);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::TBE_KERNEL);
@ -219,7 +219,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_layer_norm) {
builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id(), kFloat32->type_id()});
builder.SetOutputsDeviceType({kFloat32->type_id(), kFloat32->type_id(), kFloat32->type_id()});
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetFusionType(kernel::kPatternElemWise);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::TBE_KERNEL);
@ -263,7 +263,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_fusion_common) {
builder.SetInputsDeviceType({kFloat32->type_id()});
builder.SetOutputsDeviceType({kFloat32->type_id()});
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetFusionType(kernel::kPatternElemWise);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::TBE_KERNEL);
@ -278,7 +278,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_fusion_common) {
builder1.SetInputsDeviceType({kFloat32->type_id()});
builder1.SetOutputsDeviceType({kFloat16->type_id()});
builder1.SetKernelType(KernelType::TBE_KERNEL);
builder1.SetFusionType(kernel::FusionType::OPAQUE);
builder1.SetFusionType(kernel::kPatternOpaque);
builder1.SetProcessor(kernel::Processor::AICORE);
builder1.SetKernelType(KernelType::TBE_KERNEL);
@ -337,7 +337,7 @@ TEST_F(TestHWTBEJsonCreator, test_fusion_add_conv2d) {
builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
builder.SetOutputsDeviceType({kFloat32->type_id()});
builder.SetKernelType(KernelType::TBE_KERNEL);
builder.SetFusionType(kernel::FusionType::ELEMWISE);
builder.SetFusionType(kernel::kPatternElemWise);
builder.SetProcessor(kernel::Processor::AICORE);
builder.SetKernelType(KernelType::TBE_KERNEL);