forked from mindspore-Ecosystem/mindspore
remove enum fusion type
This commit is contained in:
parent
ca92df14bd
commit
82f04825f3
|
@ -669,7 +669,7 @@ KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) {
|
|||
return build_info->kernel_type();
|
||||
}
|
||||
|
||||
void AnfRuntimeAlgorithm::SetFusionType(const AnfNodePtr &node, const kernel::FusionType &type) {
|
||||
void AnfRuntimeAlgorithm::SetFusionType(const AnfNodePtr &node, const std::string &type) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto builder =
|
||||
std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>(AnfAlgo::GetSelectKernelBuildInfo(node));
|
||||
|
@ -718,13 +718,13 @@ kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) {
|
|||
return build_info->processor();
|
||||
}
|
||||
|
||||
kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) {
|
||||
std::string AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) {
|
||||
MS_EXCEPTION_IF_NULL(node);
|
||||
auto kernel_info = dynamic_cast<device::KernelInfo *>(node->kernel_info());
|
||||
MS_EXCEPTION_IF_NULL(kernel_info);
|
||||
auto build_info = kernel_info->select_kernel_build_info();
|
||||
if (build_info == nullptr) {
|
||||
return kernel::FusionType::UNKNOWN_FUSION_TYPE;
|
||||
return kernel::kPatternUnknown;
|
||||
}
|
||||
return build_info->fusion_type();
|
||||
}
|
||||
|
|
|
@ -129,8 +129,8 @@ class BACKEND_EXPORT AnfRuntimeAlgorithm {
|
|||
// get processor type:AICORE,AICPU...
|
||||
static kernel::Processor GetProcessor(const AnfNodePtr &node);
|
||||
// get fusion type (pattern name): Opaque, ElemWise, CommReduce...
|
||||
static kernel::FusionType GetFusionType(const AnfNodePtr &node);
|
||||
static void SetFusionType(const AnfNodePtr &node, const kernel::FusionType &type);
|
||||
static std::string GetFusionType(const AnfNodePtr &node);
|
||||
static void SetFusionType(const AnfNodePtr &node, const std::string &type);
|
||||
static void SetOutputDataDesc(const AnfNodePtr &node, const std::vector<nlohmann::json> &desc);
|
||||
static std::vector<nlohmann::json> GetOutputDataDesc(const AnfNodePtr &node);
|
||||
// core type
|
||||
|
|
|
@ -135,7 +135,7 @@ void CallbackImpl::SetGraphKernelNodeKernelInfo(const AnfNodePtr &node) {
|
|||
kernel::KernelBuildInfo::KernelBuildInfoBuilder graph_info_builder;
|
||||
graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
|
||||
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
graph_info_builder.SetFusionType(kernel::kPatternOpaque);
|
||||
graph_info_builder.SetInputsFormat(graph_input_format);
|
||||
graph_info_builder.SetInputsDeviceType(graph_input_type);
|
||||
graph_info_builder.SetOutputsFormat(graph_output_format);
|
||||
|
@ -189,7 +189,7 @@ void CallbackImpl::SetBasicNodeKernelInfo(const AnfNodePtr &node, const std::vec
|
|||
info_builder.SetOutputsDeviceType(output_types);
|
||||
info_builder.SetProcessor(kernel::GetProcessorFromContext());
|
||||
info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
info_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
info_builder.SetFusionType(kernel::kPatternOpaque);
|
||||
auto selected_info = info_builder.Build();
|
||||
AnfAlgo::SetSelectKernelBuildInfo(selected_info, node.get());
|
||||
}
|
||||
|
|
|
@ -41,7 +41,7 @@ CNodePtr AddCastCNode(const FuncGraphPtr &func_graph, const AnfNodePtr &input, c
|
|||
builder.SetOutputsFormat({format});
|
||||
builder.SetInputsDeviceType({input_type});
|
||||
builder.SetOutputsDeviceType({output_type});
|
||||
builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
builder.SetFusionType(kernel::kPatternOpaque);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
if (cast->kernel_info() == nullptr) {
|
||||
|
@ -193,7 +193,7 @@ bool DecreaseComputePrecision::Process(const FuncGraphPtr &func_graph) const {
|
|||
graph_info_builder.SetOutputsDeviceType(cnode_output_type);
|
||||
graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
|
||||
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
graph_info_builder.SetFusionType(kernel::kPatternOpaque);
|
||||
auto info_1 = graph_info_builder.Build();
|
||||
AnfAlgo::SetSelectKernelBuildInfo(info_1, cnode1.get());
|
||||
if (is_output) {
|
||||
|
|
|
@ -168,7 +168,7 @@ bool DecreaseTransferPrecision::ProcessFather(const FuncGraphPtr &, const AnfNod
|
|||
graph_info_builder.SetOutputsDeviceType(cnode_output_type);
|
||||
graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
|
||||
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
graph_info_builder.SetFusionType(kernel::kPatternOpaque);
|
||||
auto info_1 = graph_info_builder.Build();
|
||||
AnfAlgo::SetSelectKernelBuildInfo(info_1, cnode.get());
|
||||
return cnode;
|
||||
|
@ -281,7 +281,7 @@ bool DecreaseTransferPrecision::ProcessSon(const FuncGraphPtr &, const AnfNodePt
|
|||
node_info_builder.SetOutputsDeviceType(cnode_output_type);
|
||||
node_info_builder.SetProcessor(kernel::GetProcessorFromContext());
|
||||
node_info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
node_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
node_info_builder.SetFusionType(kernel::kPatternOpaque);
|
||||
auto info_1 = node_info_builder.Build();
|
||||
AnfAlgo::SetSelectKernelBuildInfo(info_1, cnode.get());
|
||||
(void)mng->Replace(old_input, cnode);
|
||||
|
|
|
@ -113,7 +113,7 @@ kernel::KernelBuildInfoPtr BuildSelectKernelBuildInfo(const std::vector<std::str
|
|||
graph_info_builder.SetOutputsDeviceType(output_types);
|
||||
graph_info_builder.SetProcessor(processor);
|
||||
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
graph_info_builder.SetFusionType(kernel::kPatternOpaque);
|
||||
return graph_info_builder.Build();
|
||||
}
|
||||
|
||||
|
@ -168,7 +168,7 @@ void SetNewKernelInfo(const AnfNodePtr &new_node, const FuncGraphPtr &fg, const
|
|||
graph_info_builder.SetOutputsDeviceType(graph_output_type);
|
||||
graph_info_builder.SetProcessor(kernel::GetProcessorFromContext());
|
||||
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
graph_info_builder.SetFusionType(kernel::kPatternOpaque);
|
||||
auto graph_selected_info = graph_info_builder.Build();
|
||||
AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, new_node.get());
|
||||
}
|
||||
|
@ -394,7 +394,7 @@ CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &
|
|||
info_builder.SetOutputsDeviceType(output_types);
|
||||
info_builder.SetProcessor(kernel::GetProcessorFromContext());
|
||||
info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
info_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
info_builder.SetFusionType(kernel::kPatternOpaque);
|
||||
auto selected_info = info_builder.Build();
|
||||
AnfAlgo::SetSelectKernelBuildInfo(selected_info, cnode.get());
|
||||
|
||||
|
|
|
@ -56,7 +56,7 @@ AnfNodePtr RaiseReductionPrecision::CreateReduceSum(const AnfNodePtr &node, cons
|
|||
info_builder.SetOutputsDeviceType({kFloat32->type_id()});
|
||||
info_builder.SetProcessor(AnfAlgo::GetProcessor(node));
|
||||
info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
info_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
info_builder.SetFusionType(kernel::kPatternOpaque);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(info_builder.Build(), cnode.get());
|
||||
return node;
|
||||
}
|
||||
|
|
|
@ -247,7 +247,7 @@ class CNodeDecoder {
|
|||
builder->SetOutputsDeviceType(output_types_);
|
||||
builder->SetProcessor(processor);
|
||||
builder->SetKernelType(KernelType::AKG_KERNEL);
|
||||
builder->SetFusionType(kernel::FusionType::OPAQUE);
|
||||
builder->SetFusionType(kernel::kPatternOpaque);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), cnode_.get());
|
||||
}
|
||||
|
||||
|
|
|
@ -48,60 +48,6 @@ constexpr auto kStridedSliceMaxDims = 8;
|
|||
constexpr auto kQuad = 4;
|
||||
constexpr size_t kInputFirstIndex = 0;
|
||||
constexpr char kOperatorOriginFormat[] = "operator_origin_format";
|
||||
|
||||
// Define all patterns here for different schedule
|
||||
const std::unordered_map<FusionType, std::string> fusion_type_name_maps = {
|
||||
{FusionType::BN_UPDATE_GRAD, "bn_update_grad"},
|
||||
{FusionType::BN_GRAD_REDUCE, "bn_grad_reduce"},
|
||||
{FusionType::LAYER_NORM_GRAD, "layer_norm_grad"},
|
||||
{FusionType::L2LOSS_MUL_ADDN, "l2loss_mul_addn"},
|
||||
{FusionType::ELEMWISE, "ElemWise"},
|
||||
{FusionType::PURE_BROADCAST, "PureBroadcast"},
|
||||
{FusionType::COMMREDUCE, "CommReduce"},
|
||||
{FusionType::SEGMENT, "Segment"},
|
||||
{FusionType::INPLACE, "Inplace"},
|
||||
{FusionType::MATMUL, "Matmul"},
|
||||
{FusionType::MATMUL_V2, "Matmul_v2"},
|
||||
{FusionType::GEMM, "GEMM"},
|
||||
{FusionType::CONV, "Convolution"},
|
||||
{FusionType::CONV2D_BACKPROP_INPUT, "Conv2d_backprop_input"},
|
||||
{FusionType::CONV2D_BACKPROP_FILTER, "Conv2d_backprop_filter"},
|
||||
{FusionType::CONV3D_BACKPROP_INPUT, "Conv3d_backprop_input"},
|
||||
{FusionType::CONV3D_BACKPROP_FILTER, "Conv3d_backprop_filter"},
|
||||
{FusionType::CUBE_LAYER_NORM, "cube_layer_norm"},
|
||||
{FusionType::OPAQUE, "Opaque"},
|
||||
{FusionType::BN_REDUCE, "bn_reduce"},
|
||||
{FusionType::BN_UPDATE, "bn_update"},
|
||||
{FusionType::SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, "softmax_cross_entropy_with_logits"},
|
||||
{FusionType::L2_NORMALIZE, "l2_normalize"},
|
||||
{FusionType::SOFTMAX, "softmax_pattern"},
|
||||
{FusionType::L2_LOSS, "l2_loss"},
|
||||
{FusionType::ASCEND_QUANT, "quant"},
|
||||
{FusionType::ASCEND_DEQUANT, "dequant"},
|
||||
{FusionType::ASCEND_ANTI_QUANT, "anti_quant"},
|
||||
{FusionType::STRIDED_READ, "strided_read"},
|
||||
{FusionType::STRIDED_WRITE, "strided_write"},
|
||||
{FusionType::ASCEND_DEQUANT_S16, "dequant_s16"},
|
||||
{FusionType::ASCEND_REQUANT, "requant"},
|
||||
{FusionType::ASCEND_REQUANT_S16, "requant_s16"},
|
||||
{FusionType::MAX_POOL, "MaxPool"},
|
||||
{FusionType::DEPTHWISECONV, "DepthwiseConvolution"},
|
||||
{FusionType::CONV3D, "Conv3d"},
|
||||
{FusionType::POOL2D, "Pool2d"},
|
||||
{FusionType::POOL3D, "Pool3d"},
|
||||
{FusionType::READ_SELECT, "read_select"},
|
||||
{FusionType::WRITE_SELECT, "write_select"},
|
||||
{FusionType::COSINE_EMBEDDING_LOSS, "cosine_embedding_loss"},
|
||||
{FusionType::DILATION_PATTERN, "dilation"},
|
||||
{FusionType::BROAD_CAST, "Broadcast"},
|
||||
{FusionType::BATCH_MATMUL, "BatchMatmul"},
|
||||
{FusionType::CONFUSION_TRANSPOSE, "confusiontranspose"},
|
||||
{FusionType::DROPOUT_DOMASKV3D, "DropOutDoMaskV3D"},
|
||||
{FusionType::TRANSDATA, "Transdata"},
|
||||
{FusionType::NORM, "Norm"},
|
||||
{FusionType::TRANSPOSE, "Transpose"},
|
||||
{FusionType::UNKNOWN_FUSION_TYPE, ""}};
|
||||
|
||||
abstract::BaseShapePtr GetValidShapeFromAbstract(const abstract::AbstractBasePtr &abs) {
|
||||
// Other abstract classes, such as AbstractCSRTensor and AbstractCOOTensor, are converted to AbstractTensor at an earlier stage.
|
||||
abstract::BaseShapePtr res_shape;
|
||||
|
@ -239,29 +185,6 @@ int CalDiagOffset(int diag_index, int max_diag_len, int inner_rows, int inner_co
|
|||
return offset;
|
||||
}
|
||||
|
||||
std::string GetFusionNameByType(const kernel::FusionType &type) {
|
||||
auto iter = fusion_type_name_maps.find(type);
|
||||
if (iter == fusion_type_name_maps.end()) {
|
||||
MS_LOG(EXCEPTION) << "Illegal fusion type: " << type;
|
||||
}
|
||||
return iter->second;
|
||||
}
|
||||
|
||||
FusionType GetFusionTypeByName(const std::string &name) {
|
||||
std::string fusion_name_upper = name;
|
||||
transform(fusion_name_upper.begin(), fusion_name_upper.end(), fusion_name_upper.begin(), ::toupper);
|
||||
auto iter =
|
||||
std::find_if(fusion_type_name_maps.begin(), fusion_type_name_maps.end(), [&fusion_name_upper](const auto &it) {
|
||||
std::string name_upper = it.second;
|
||||
transform(name_upper.begin(), name_upper.end(), name_upper.begin(), ::toupper);
|
||||
return fusion_name_upper == name_upper;
|
||||
});
|
||||
if (iter == fusion_type_name_maps.end()) {
|
||||
MS_LOG(EXCEPTION) << "Illegal fusion name: " << name;
|
||||
}
|
||||
return iter->first;
|
||||
}
|
||||
|
||||
std::string GetCompilerCachePath() { return Common::GetUserDefineCachePath(); }
|
||||
|
||||
void KernelMeta::Initialize() {
|
||||
|
|
|
@ -180,8 +180,6 @@ inline float Scaler(const size_t x, const float scale, bool half_pixel_centers)
|
|||
}
|
||||
}
|
||||
float ScaleGrid(const int x, const float scale);
|
||||
BACKEND_EXPORT FusionType GetFusionTypeByName(const std::string &name);
|
||||
BACKEND_EXPORT std::string GetFusionNameByType(const kernel::FusionType &type);
|
||||
BACKEND_EXPORT std::vector<bool> Dec2Bin(const int64_t &mask);
|
||||
BACKEND_EXPORT void FillEmptyDims(const BaseOperatorPtr &base_operator, std::vector<int64_t> *begin,
|
||||
std::vector<int64_t> *end, std::vector<int64_t> *stride, ShapeVector *input_shape);
|
||||
|
|
|
@ -56,59 +56,56 @@ enum KernelType : int {
|
|||
ACL_KERNEL,
|
||||
};
|
||||
namespace kernel {
|
||||
// Supported fusion type
|
||||
enum FusionType {
|
||||
CONV = 0,
|
||||
ELEMWISE,
|
||||
COMMREDUCE,
|
||||
SEGMENT,
|
||||
OPAQUE,
|
||||
BN_UPDATE_GRAD,
|
||||
BN_GRAD_REDUCE,
|
||||
LAYER_NORM_GRAD,
|
||||
L2LOSS_MUL_ADDN,
|
||||
PURE_BROADCAST,
|
||||
INPLACE,
|
||||
MATMUL,
|
||||
MATMUL_V2,
|
||||
GEMM,
|
||||
CONV2D_BACKPROP_INPUT,
|
||||
CONV2D_BACKPROP_FILTER,
|
||||
CONV3D_BACKPROP_INPUT,
|
||||
CONV3D_BACKPROP_FILTER,
|
||||
CUBE_LAYER_NORM,
|
||||
BN_REDUCE,
|
||||
BN_UPDATE,
|
||||
SOFTMAX_CROSS_ENTROPY_WITH_LOGITS,
|
||||
L2_NORMALIZE,
|
||||
SOFTMAX,
|
||||
L2_LOSS,
|
||||
ASCEND_QUANT,
|
||||
ASCEND_DEQUANT,
|
||||
ASCEND_ANTI_QUANT,
|
||||
STRIDED_READ,
|
||||
STRIDED_WRITE,
|
||||
ASCEND_DEQUANT_S16,
|
||||
ASCEND_REQUANT,
|
||||
ASCEND_REQUANT_S16,
|
||||
MAX_POOL,
|
||||
DEPTHWISECONV,
|
||||
CONV3D,
|
||||
POOL2D,
|
||||
POOL3D,
|
||||
READ_SELECT,
|
||||
WRITE_SELECT,
|
||||
COSINE_EMBEDDING_LOSS,
|
||||
DILATION_PATTERN,
|
||||
BROAD_CAST,
|
||||
BATCH_MATMUL,
|
||||
CONFUSION_TRANSPOSE,
|
||||
DROPOUT_DOMASKV3D,
|
||||
TRANSDATA,
|
||||
NORM,
|
||||
TRANSPOSE,
|
||||
UNKNOWN_FUSION_TYPE = -1,
|
||||
};
|
||||
// Supported fusion patterns, identified by name (kPatternUnknown is the empty string).
constexpr auto kPatternAntiQuant = "anti_quant";
|
||||
constexpr auto kPatternDequant = "dequant";
|
||||
constexpr auto kPatternDequantS16 = "dequant_s16";
|
||||
constexpr auto kPatternQuant = "quant";
|
||||
constexpr auto kPatternRequant = "requant";
|
||||
constexpr auto kPatternRequant_s16 = "requant_s16";
|
||||
constexpr auto kPatternBatchMatmul = "BatchMatmul";
|
||||
constexpr auto kPatternBnGradReduce = "bn_grad_reduce";
|
||||
constexpr auto kPatternBnReduce = "bn_reduce";
|
||||
constexpr auto kPatternBnUpdate = "bn_update";
|
||||
constexpr auto kPatternBnUpdate_grad = "bn_update_grad";
|
||||
constexpr auto kPatternBroadcast = "Broadcast";
|
||||
constexpr auto kPatternCommReduce = "CommReduce";
|
||||
constexpr auto kPatternConfusiontranspose = "confusiontranspose";
|
||||
constexpr auto kPatternConvolution = "Convolution";
|
||||
constexpr auto kPatternConv2dBackpropFilter = "Conv2d_backprop_filter";
|
||||
constexpr auto kPatternConv2dBackpropInput = "Conv2d_backprop_input";
|
||||
constexpr auto kPatternConv3d = "Conv3d";
|
||||
constexpr auto kPatternConv3dBackpropFilter = "Conv3d_backprop_filter";
|
||||
constexpr auto kPatternConv3dBackpropInput = "Conv3d_backprop_input";
|
||||
constexpr auto kPatternCosineEmbeddingLoss = "cosine_embedding_loss";
|
||||
constexpr auto kPatternCubeLayerNorm = "cube_layer_norm";
|
||||
constexpr auto kPatternDepthwiseConvolution = "DepthwiseConvolution";
|
||||
constexpr auto kPatternDilation = "dilation";
|
||||
constexpr auto kPatternDropOutDoMaskV3D = "DropOutDoMaskV3D";
|
||||
constexpr auto kPatternElemWise = "ElemWise";
|
||||
constexpr auto kPatternGEMM = "GEMM";
|
||||
constexpr auto kPatternInplace = "Inplace";
|
||||
constexpr auto kPatternL2Loss = "l2_loss";
|
||||
constexpr auto kPatternL2Normalize = "l2_normalize";
|
||||
constexpr auto kPatternL2lossMulAddn = "l2loss_mul_addn";
|
||||
constexpr auto kPatternLayerNormGrad = "layer_norm_grad";
|
||||
constexpr auto kPatternMatmul = "Matmul";
|
||||
constexpr auto kPatternMatmulV2 = "Matmul_v2";
|
||||
constexpr auto kPatternMaxPool = "MaxPool";
|
||||
constexpr auto kPatternNorm = "Norm";
|
||||
constexpr auto kPatternOpaque = "Opaque";
|
||||
constexpr auto kPatternPool2d = "Pool2d";
|
||||
constexpr auto kPatternPool3d = "Pool3d";
|
||||
constexpr auto kPatternPureBroadcast = "PureBroadcast";
|
||||
constexpr auto kPatternread_select = "read_select";
|
||||
constexpr auto kPatternSegment = "Segment";
|
||||
constexpr auto kPatternSoftmaxPattern = "softmax_pattern";
|
||||
constexpr auto kPatternSoftmaxCrossEntropyWithLogits = "softmax_cross_entropy_with_logits";
|
||||
constexpr auto kPatternStridedRead = "strided_read";
|
||||
constexpr auto kPatternStridedWrite = "strided_write";
|
||||
constexpr auto kPatternTransdata = "Transdata";
|
||||
constexpr auto kPatternTranspose = "Transpose";
|
||||
constexpr auto kPatternUnknown = "";
|
||||
constexpr auto kPatternWriteSelect = "write_select";
|
||||
|
||||
// Backend processor
|
||||
enum Processor {
|
||||
|
|
|
@ -209,7 +209,7 @@ void KernelBuildInfo::KernelBuildInfoBuilder::SetOutputsDeviceType(const std::ve
|
|||
kernel_build_info_->outputs_device_type_ = outputs_device_type;
|
||||
}
|
||||
|
||||
void KernelBuildInfo::KernelBuildInfoBuilder::SetFusionType(FusionType fusion_type) {
|
||||
void KernelBuildInfo::KernelBuildInfoBuilder::SetFusionType(const std::string &fusion_type) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_build_info_);
|
||||
kernel_build_info_->fusion_type_ = fusion_type;
|
||||
}
|
||||
|
|
|
@ -102,7 +102,7 @@ class BACKEND_EXPORT KernelBuildInfo {
|
|||
|
||||
std::vector<nlohmann::json> output_data_desc() const { return output_data_desc_; }
|
||||
|
||||
FusionType fusion_type() const { return fusion_type_; }
|
||||
std::string fusion_type() const { return fusion_type_; }
|
||||
|
||||
Processor processor() const { return processor_; }
|
||||
|
||||
|
@ -136,7 +136,7 @@ class BACKEND_EXPORT KernelBuildInfo {
|
|||
std::vector<KernelObjectType> inputs_kernel_object_type_;
|
||||
std::vector<KernelObjectType> outputs_kernel_object_type_;
|
||||
std::vector<nlohmann::json> output_data_desc_;
|
||||
FusionType fusion_type_{kernel::FusionType::OPAQUE};
|
||||
std::string fusion_type_{kernel::kPatternOpaque};
|
||||
Processor processor_{AICORE};
|
||||
};
|
||||
using KernelBuildInfoPtr = std::shared_ptr<KernelBuildInfo>;
|
||||
|
@ -189,7 +189,7 @@ class BACKEND_EXPORT KernelBuildInfo::KernelBuildInfoBuilder {
|
|||
|
||||
void SetCoreType(const std::string &core_type);
|
||||
|
||||
void SetFusionType(FusionType fusion_type);
|
||||
void SetFusionType(const std::string &fusion_type);
|
||||
// save prebuild result
|
||||
void SetOutputDataDesc(const std::vector<nlohmann::json> &data_desc);
|
||||
|
||||
|
|
|
@ -532,7 +532,7 @@ kernel::KernelBuildInfo::KernelBuildInfoBuilder KernelAdjust::CreateMngKernelBui
|
|||
selected_kernel_builder.SetInputsFormat(formats);
|
||||
selected_kernel_builder.SetInputsDeviceType(type_ids);
|
||||
|
||||
selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
selected_kernel_builder.SetFusionType(kernel::kPatternOpaque);
|
||||
selected_kernel_builder.SetProcessor(kernel::Processor::AICORE);
|
||||
selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL);
|
||||
return selected_kernel_builder;
|
||||
|
@ -626,7 +626,7 @@ CNodePtr KernelAdjust::CreateEndOfSequenceOP(const std::shared_ptr<session::Kern
|
|||
selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT});
|
||||
selected_kernel_builder.SetInputsDeviceType({kNumberTypeUInt8});
|
||||
|
||||
selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
selected_kernel_builder.SetFusionType(kernel::kPatternOpaque);
|
||||
selected_kernel_builder.SetProcessor(kernel::Processor::AICPU);
|
||||
selected_kernel_builder.SetKernelType(KernelType::AICPU_KERNEL);
|
||||
|
||||
|
@ -748,7 +748,7 @@ CNodePtr KernelAdjust::CreateNPUGetFloatStatus(const std::shared_ptr<session::Ke
|
|||
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
|
||||
selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT});
|
||||
selected_kernel_builder.SetInputsDeviceType({kNumberTypeFloat32});
|
||||
selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
selected_kernel_builder.SetFusionType(kernel::kPatternOpaque);
|
||||
selected_kernel_builder.SetProcessor(kernel::Processor::AICORE);
|
||||
selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
|
||||
|
@ -771,7 +771,7 @@ CNodePtr KernelAdjust::CreateNPUClearStatus(const std::shared_ptr<session::Kerne
|
|||
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
|
||||
selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT});
|
||||
selected_kernel_builder.SetInputsDeviceType({kNumberTypeFloat32});
|
||||
selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
selected_kernel_builder.SetFusionType(kernel::kPatternOpaque);
|
||||
selected_kernel_builder.SetProcessor(kernel::Processor::AICORE);
|
||||
selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
|
||||
|
@ -793,7 +793,7 @@ CNodePtr KernelAdjust::CreateNPUAllocStatus(const std::shared_ptr<session::Kerne
|
|||
common::AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32}, {npu_output_shape}, npu_alloc_cnode.get());
|
||||
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
|
||||
selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
selected_kernel_builder.SetFusionType(kernel::kPatternOpaque);
|
||||
selected_kernel_builder.SetProcessor(kernel::Processor::AICORE);
|
||||
selected_kernel_builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
selected_kernel_builder.SetOutputsFormat({kOpFormat_DEFAULT});
|
||||
|
|
|
@ -455,7 +455,7 @@ KernelSelectStatus SelectCustomKernelInfo(const CNodePtr &kernel_node, KernelTyp
|
|||
auto builder = std::make_shared<kernel::KernelBuildInfo::KernelBuildInfoBuilder>();
|
||||
builder->SetKernelType(*kernel_type);
|
||||
builder->SetProcessor(kernel::Processor::AICORE);
|
||||
builder->SetFusionType(kernel::FusionType::OPAQUE);
|
||||
builder->SetFusionType(kernel::kPatternOpaque);
|
||||
builder->SetOpPattern(kernel::OpPattern::kCommonPattern);
|
||||
// set inputs info
|
||||
std::vector<TypeId> inputs_device_type;
|
||||
|
|
|
@ -456,7 +456,7 @@ void SetGraphKernelInfo(const CNodePtr &kernel_node, const std::vector<std::pair
|
|||
graph_info_builder.SetOutputsDeviceType(graph_output_type);
|
||||
graph_info_builder.SetProcessor(kernel::Processor::AICORE);
|
||||
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
graph_info_builder.SetFusionType(kernel::kPatternOpaque);
|
||||
auto graph_selected_info = graph_info_builder.Build();
|
||||
MS_EXCEPTION_IF_NULL(graph_selected_info);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get());
|
||||
|
|
|
@ -246,7 +246,7 @@ NotNull<CNodePtr> ProfilingUtils::CreateProfilingCNode(const ProfilingContent &p
|
|||
kernel::KernelBuildInfo::KernelBuildInfoBuilder selected_kernel_builder;
|
||||
selected_kernel_builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
|
||||
selected_kernel_builder.SetInputsDeviceType({TypeId::kNumberTypeInt32, TypeId::kNumberTypeInt32});
|
||||
selected_kernel_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
selected_kernel_builder.SetFusionType(kernel::kPatternOpaque);
|
||||
selected_kernel_builder.SetProcessor(kernel::Processor::AICORE);
|
||||
selected_kernel_builder.SetKernelType(KernelType::RT_KERNEL);
|
||||
abstract::AbstractBasePtr type_none_abstract = std::make_shared<abstract::AbstractNone>();
|
||||
|
|
|
@ -81,7 +81,7 @@ void AicpuMetadataInfoForSpecialNodes(const CNodePtr &kernel_node,
|
|||
builder.SetOutputsDeviceType(outputs_type);
|
||||
builder.SetProcessor(AICPU);
|
||||
builder.SetKernelType(AICPU_KERNEL);
|
||||
builder.SetFusionType(OPAQUE);
|
||||
builder.SetFusionType(kPatternOpaque);
|
||||
(void)kernel_info_list->emplace_back(builder.Build());
|
||||
return;
|
||||
}
|
||||
|
|
|
@ -82,7 +82,7 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> GetNextDynamicDesc::GetKer
|
|||
builder.SetOutputsDeviceType(output_type);
|
||||
builder.SetProcessor(AICORE);
|
||||
builder.SetKernelType(RT_KERNEL);
|
||||
builder.SetFusionType(OPAQUE);
|
||||
builder.SetFusionType(kPatternOpaque);
|
||||
get_next_dynamic_build_info.emplace_back(builder.Build());
|
||||
return get_next_dynamic_build_info;
|
||||
}
|
||||
|
|
|
@ -84,7 +84,7 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> LabelSwitchDesc::GetKernel
|
|||
builder.SetInputsDeviceType({input_type[i]});
|
||||
builder.SetProcessor(AICORE);
|
||||
builder.SetKernelType(RT_KERNEL);
|
||||
builder.SetFusionType(OPAQUE);
|
||||
builder.SetFusionType(kPatternOpaque);
|
||||
// LabelSwitch always return UMonad.
|
||||
builder.SetOutputsFormat({kOpFormat_DEFAULT});
|
||||
builder.SetOutputsDeviceType({TypeId::kObjectTypeUMonad});
|
||||
|
|
|
@ -159,7 +159,7 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> MemCpyAsyncDesc::GetKernel
|
|||
builder.SetOutputsDeviceType(output_type);
|
||||
builder.SetProcessor(AICORE);
|
||||
builder.SetKernelType(RT_KERNEL);
|
||||
builder.SetFusionType(OPAQUE);
|
||||
builder.SetFusionType(kPatternOpaque);
|
||||
memcpy_build_info.emplace_back(builder.Build());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -81,7 +81,7 @@ void GetRtKelInfo(const CNodePtr &kernel_node,
|
|||
kernel_build_info_builder->SetOutputsFormat({kOpFormat_DEFAULT});
|
||||
kernel_build_info_builder->SetOutputsDeviceType({TypeId::kObjectTypeUMonad});
|
||||
// set other info
|
||||
kernel_build_info_builder->SetFusionType(kernel::FusionType::OPAQUE);
|
||||
kernel_build_info_builder->SetFusionType(kernel::kPatternOpaque);
|
||||
kernel_build_info_builder->SetProcessor(kernel::Processor::AICORE);
|
||||
kernel_build_info_builder->SetKernelType(KernelType::RT_KERNEL);
|
||||
kernel_info_list->push_back(kernel_build_info_builder->Build());
|
||||
|
|
|
@ -185,7 +185,7 @@ std::vector<std::shared_ptr<kernel::KernelBuildInfo>> TensorCopySlicesDesc::GetK
|
|||
builder.SetOutputsDeviceType(output_type);
|
||||
builder.SetProcessor(AICORE);
|
||||
builder.SetKernelType(RT_KERNEL);
|
||||
builder.SetFusionType(OPAQUE);
|
||||
builder.SetFusionType(kPatternOpaque);
|
||||
tensor_copy_slices_build_info.emplace_back(builder.Build());
|
||||
}
|
||||
}
|
||||
|
|
|
@ -446,7 +446,7 @@ void TbeJsonCreator::GenComputeCommonJson(const AnfNodePtr &anf_node, nlohmann::
|
|||
(*compute_json)[kJIsDynamicImpl] = is_dynamic_impl;
|
||||
(*compute_json)[kJInt64Mode] = false;
|
||||
(*compute_json)[kJName] = cnode->fullname_with_scope();
|
||||
(*compute_json)[kJPattern] = kernel::GetFusionNameByType(AnfAlgo::GetFusionType(cnode));
|
||||
(*compute_json)[kJPattern] = AnfAlgo::GetFusionType(cnode);
|
||||
(*compute_json)[kJModuleName] = kJModuleNamePrefix + func_name;
|
||||
}
|
||||
|
||||
|
|
|
@ -373,13 +373,12 @@ void TbeKernelCompileManager::SavePreBuildResult(const std::string &json_name, c
|
|||
return;
|
||||
}
|
||||
auto op_pattern = GetJsonValue<std::string>(result, "op_pattern");
|
||||
auto fusion_type = kernel::GetFusionTypeByName(op_pattern);
|
||||
auto output_data_desc = GetJsonValue<nlohmann::json>(result, "op_params");
|
||||
auto core_type = GetJsonValue<nlohmann::json>(result, "core_type");
|
||||
// save pre build result
|
||||
struct PreBuildResult pre_res;
|
||||
pre_res.json_name = json_name;
|
||||
pre_res.fusion_type = fusion_type;
|
||||
// NOTE(review): op_pattern is stored as returned by prebuild. The former GetFusionTypeByName lookup matched
// names case-insensitively and raised on unknown names; consumers now compare this string exactly against the
// kernel::kPattern* constants — confirm prebuild always reports the canonical spelling.
pre_res.fusion_type = op_pattern;
|
||||
pre_res.core_type = core_type;
|
||||
pre_res.output_data_desc = output_data_desc;
|
||||
prebuild_res_map_[json_name] = pre_res;
|
||||
|
@ -594,12 +593,11 @@ void TbeKernelCompileManager::UpdateFusionTypeAndOutputDataDesc(const std::vecto
|
|||
}
|
||||
auto pre_res = prebuild_res_map_[kernel_name];
|
||||
auto fusion_type = pre_res.fusion_type;
|
||||
auto fusion_name = GetFusionNameByType(fusion_type);
|
||||
auto output_data_desc = pre_res.output_data_desc;
|
||||
auto core_type = pre_res.core_type;
|
||||
AnfAlgo::SetCoreType(node, core_type);
|
||||
AnfAlgo::SetFusionType(node, fusion_type);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrTbeFusionType, MakeValue(fusion_name), node);
|
||||
common::AnfAlgo::SetNodeAttr(kAttrTbeFusionType, MakeValue(fusion_type), node);
|
||||
AnfAlgo::SetOutputDataDesc(node, {output_data_desc});
|
||||
}
|
||||
MS_LOG(INFO) << "End update fusion type after pre build";
|
||||
|
|
|
@ -51,7 +51,7 @@ struct TaskInfo {
|
|||
struct PreBuildResult {
|
||||
std::string core_type;
|
||||
std::string json_name;
|
||||
kernel::FusionType fusion_type;
|
||||
std::string fusion_type;
|
||||
nlohmann::json output_data_desc;
|
||||
};
|
||||
|
||||
|
|
|
@ -480,7 +480,7 @@ CNodePtr AddCastOpNodeToGraph(const FuncGraphPtr &func_graph, const AnfNodePtr &
|
|||
builder.SetOutputsReshapeType({reshape_type});
|
||||
builder.SetInputsDeviceType({input_type});
|
||||
builder.SetOutputsDeviceType({output_type});
|
||||
builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
builder.SetFusionType(kernel::kPatternOpaque);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
if (kernel::OpLib::FindOp(prim::kPrimCast->name(), kernel::kImplyTBE) != nullptr) {
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
|
|
|
@ -28,7 +28,7 @@ namespace opt {
|
|||
namespace {
|
||||
constexpr auto kAttrNoFusion = "no_fusion";
|
||||
|
||||
CNodePtr FindInputNode(const CNodePtr &cnode, const string &node_type, const kernel::FusionType &fusion_type) {
|
||||
CNodePtr FindInputNode(const CNodePtr &cnode, const string &node_type, const std::string &fusion_type) {
|
||||
auto input_num = common::AnfAlgo::GetInputTensorNum(cnode);
|
||||
for (size_t i = 1; i <= input_num; ++i) {
|
||||
auto input = cnode->input(i);
|
||||
|
@ -71,7 +71,7 @@ bool BatchMatmulEltwiseFusionPass::MatchPattern2(const CNodePtr &eltwise,
|
|||
return false;
|
||||
}
|
||||
|
||||
CNodePtr bmm = FindInputNode(eltwise, kBatchMatMulOpName, kernel::FusionType::BATCH_MATMUL);
|
||||
CNodePtr bmm = FindInputNode(eltwise, kBatchMatMulOpName, kernel::kPatternBatchMatmul);
|
||||
if (bmm == nullptr || common::AnfAlgo::IsDynamicShape(bmm) || common::AnfAlgo::GetBooleanAttr(bmm, kAttrNoFusion)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -88,17 +88,17 @@ bool BatchMatmulEltwiseFusionPass::MatchPattern3(const CNodePtr &eltwise,
|
|||
return false;
|
||||
}
|
||||
|
||||
CNodePtr eltwise2 = FindInputNode(eltwise, kSigmoidOpName, kernel::FusionType::ELEMWISE);
|
||||
CNodePtr eltwise2 = FindInputNode(eltwise, kSigmoidOpName, kernel::kPatternElemWise);
|
||||
if (eltwise2 == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
CNodePtr eltwise1 = FindInputNode(eltwise2, kMulOpName, kernel::FusionType::ELEMWISE);
|
||||
CNodePtr eltwise1 = FindInputNode(eltwise2, kMulOpName, kernel::kPatternElemWise);
|
||||
if (eltwise1 == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
CNodePtr bmm = FindInputNode(eltwise1, kBatchMatMulOpName, kernel::FusionType::BATCH_MATMUL);
|
||||
// The producer searched for here is a BatchMatMul node, so its fusion pattern is "BatchMatmul"
// (kPatternBatchMatmul), not "Matmul"; this keeps the check identical to the former
// FusionType::BATCH_MATMUL comparison and to MatchPattern2.
CNodePtr bmm = FindInputNode(eltwise1, kBatchMatMulOpName, kernel::kPatternBatchMatmul);
|
||||
if (bmm == nullptr || common::AnfAlgo::IsDynamicShape(bmm)) {
|
||||
return false;
|
||||
}
|
||||
|
@ -124,8 +124,8 @@ void BatchMatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::Kerne
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
||||
(AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE ||
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::BROAD_CAST)) {
|
||||
(AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise ||
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::kPatternBroadcast)) {
|
||||
mindspore::HashSet<AnfNodePtr> record;
|
||||
if (MatchPattern1(cnode, &record) || MatchPattern2(cnode, &record) || MatchPattern3(cnode, &record)) {
|
||||
candidate_fusion->push_back(record);
|
||||
|
|
|
@ -34,7 +34,7 @@ void BatchMatmulReduceSumFusionPass::MatchBatchMatmulReduceSum(const CNodePtr &r
|
|||
auto batch_matmul = reduce_sum->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(batch_matmul);
|
||||
const PrimitiveSet batch_matmul_prims{prim::kPrimBatchMatMul, prim::kPrimBatchMatMulV2};
|
||||
if (!batch_matmul->isa<CNode>() || AnfAlgo::GetFusionType(batch_matmul) != kernel::FusionType::BATCH_MATMUL ||
|
||||
if (!batch_matmul->isa<CNode>() || AnfAlgo::GetFusionType(batch_matmul) != kernel::kPatternBatchMatmul ||
|
||||
!IsOneOfPrimitiveCNode(batch_matmul, batch_matmul_prims)) {
|
||||
return;
|
||||
}
|
||||
|
@ -84,7 +84,7 @@ void BatchMatmulReduceSumFusionPass::MatchSingleFusionPattern(const session::Ker
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
|
||||
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::COMMREDUCE &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::kPatternCommReduce &&
|
||||
common::AnfAlgo::GetCNodeName(cnode) == kReduceSumDOpName) {
|
||||
MatchBatchMatmulReduceSum(cnode, kernel_graph, candidate_fusion);
|
||||
}
|
||||
|
|
|
@ -65,7 +65,7 @@ void BnupdateEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const session::K
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise &&
|
||||
common::AnfAlgo::GetOutputTensorNum(cnode) == ELTWISE_DOUBLE_OUTPUT_SIZE &&
|
||||
common::AnfAlgo::GetInputTensorNum(cnode) == (ELTWISE_INPUT_SIZE - 1)) {
|
||||
auto eltwise_input = cnode->input(kIndex1);
|
||||
|
|
|
@ -54,7 +54,7 @@ void BnupdateEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGr
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise &&
|
||||
common::AnfAlgo::GetOutputTensorNum(cnode) == ELTWISE_DOUBLE_OUTPUT_SIZE) {
|
||||
auto eltwise_input = cnode->input(1);
|
||||
MS_EXCEPTION_IF_NULL(eltwise_input);
|
||||
|
|
|
@ -94,7 +94,7 @@ void Conv2DBackpropEltwiseEltwiseFusionPass::MatchSingleFusionPattern(const sess
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise &&
|
||||
common::AnfAlgo::GetCNodeName(cnode) == kReluGradV2OpName) {
|
||||
MatchConv2DBackpropInputEltwiseEltwise(cnode, kernel_graph, candidate_fusion);
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ void ConvDoubleInFusionPass::MatchConvDoubleInEltwise(const CNodePtr &cnode, con
|
|||
return;
|
||||
}
|
||||
if (AnfAlgo::GetKernelType(double_in_eltwise_input) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(double_in_eltwise_input) == kernel::FusionType::CONV) {
|
||||
AnfAlgo::GetFusionType(double_in_eltwise_input) == kernel::kPatternConvolution) {
|
||||
(void)record.insert(double_in_eltwise_input);
|
||||
candidate_fusion->push_back(record);
|
||||
SetRecordFusionId(record);
|
||||
|
@ -65,7 +65,7 @@ void ConvDoubleInFusionPass::MatchSingleFusionPattern(const session::KernelGraph
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise && cnode->inputs().size() == ELTWISE_INPUT_SIZE &&
|
||||
!common::AnfAlgo::CheckPrimitiveType(node, prim::kPrimReluV2)) {
|
||||
MatchConvDoubleInEltwise(cnode, kernel_graph, candidate_fusion);
|
||||
}
|
||||
|
|
|
@ -47,7 +47,7 @@ void ConvSingleInFusionPass::MatchConvSingleInEltwise(const CNodePtr &cnode, con
|
|||
return;
|
||||
}
|
||||
if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::CONV) {
|
||||
AnfAlgo::GetFusionType(eltwise_input) == kernel::kPatternConvolution) {
|
||||
(void)record.insert(eltwise_input);
|
||||
candidate_fusion->push_back(record);
|
||||
SetRecordFusionId(record);
|
||||
|
@ -67,7 +67,7 @@ void ConvSingleInFusionPass::MatchSingleFusionPattern(const session::KernelGraph
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) {
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise && cnode->inputs().size() == ELTWISE_INPUT_SIZE) {
|
||||
MatchConvSingleInEltwise(cnode, kernel_graph, candidate_fusion);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -63,7 +63,7 @@ void DepthwiseConvEltwiseFusionPass::MatchSingleFusionPattern(const session::Ker
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE) {
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise) {
|
||||
auto eltwise_input = cnode->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(eltwise_input);
|
||||
if (eltwise_input->isa<CNode>() &&
|
||||
|
|
|
@ -64,8 +64,8 @@ void EltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGraph &ker
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
||||
(AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE ||
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::BROAD_CAST) &&
|
||||
(AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise ||
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::kPatternBroadcast) &&
|
||||
cnode->inputs().size() == ELTWISE_INPUT_SIZE) {
|
||||
MatchEltwise(cnode, kernel_graph, candidate_fusion);
|
||||
}
|
||||
|
@ -83,8 +83,8 @@ bool EltwiseFusionPass::CheckEltWiseOrBroadCastNode(const session::KernelGraph &
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
size_t not_updatestate_nums = GetNotUpdateStateUserNums(kernel_graph, node);
|
||||
return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
|
||||
(AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE ||
|
||||
AnfAlgo::GetFusionType(node) == kernel::FusionType::BROAD_CAST) &&
|
||||
(AnfAlgo::GetFusionType(node) == kernel::kPatternElemWise ||
|
||||
AnfAlgo::GetFusionType(node) == kernel::kPatternBroadcast) &&
|
||||
not_updatestate_nums == ELTWISE_USE && cnode->inputs().size() == ELTWISE_INPUT_SIZE;
|
||||
}
|
||||
|
||||
|
@ -100,8 +100,8 @@ bool EltwiseFusionPass::CheckDoubleInEltWiseOrBroadCastNode(const session::Kerne
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
size_t not_updatestate_nums = GetNotUpdateStateUserNums(kernel_graph, node);
|
||||
return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
|
||||
(AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE ||
|
||||
AnfAlgo::GetFusionType(node) == kernel::FusionType::BROAD_CAST) &&
|
||||
(AnfAlgo::GetFusionType(node) == kernel::kPatternElemWise ||
|
||||
AnfAlgo::GetFusionType(node) == kernel::kPatternBroadcast) &&
|
||||
not_updatestate_nums == ELTWISE_USE && cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE;
|
||||
}
|
||||
} // namespace opt
|
||||
|
|
|
@ -33,7 +33,7 @@ bool FusionBasePass::CheckEltWiseNode(const session::KernelGraph &kernel_graph,
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
size_t not_updatestate_nums = GetNotUpdateStateUserNums(kernel_graph, node);
|
||||
return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && not_updatestate_nums == ELTWISE_USE &&
|
||||
AnfAlgo::GetFusionType(node) == kernel::kPatternElemWise && not_updatestate_nums == ELTWISE_USE &&
|
||||
cnode->inputs().size() == ELTWISE_INPUT_SIZE;
|
||||
}
|
||||
|
||||
|
@ -48,7 +48,7 @@ bool FusionBasePass::CheckDoubleInEltWiseNode(const session::KernelGraph &kernel
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
size_t not_updatestate_nums = GetNotUpdateStateUserNums(kernel_graph, node);
|
||||
return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && not_updatestate_nums == ELTWISE_USE &&
|
||||
AnfAlgo::GetFusionType(node) == kernel::kPatternElemWise && not_updatestate_nums == ELTWISE_USE &&
|
||||
cnode->inputs().size() == ELTWISE_DOUBLE_IN_INPUT_SIZE;
|
||||
}
|
||||
|
||||
|
@ -63,7 +63,7 @@ bool FusionBasePass::CheckMultiOutputEltWiseNode(const session::KernelGraph &ker
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
size_t not_updatestate_nums = GetNotUpdateStateUserNums(kernel_graph, node);
|
||||
return AnfAlgo::GetKernelType(node) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(node) == kernel::FusionType::ELEMWISE && not_updatestate_nums == ELTWISE_MULTI_USE &&
|
||||
AnfAlgo::GetFusionType(node) == kernel::kPatternElemWise && not_updatestate_nums == ELTWISE_MULTI_USE &&
|
||||
cnode->inputs().size() == ELTWISE_INPUT_SIZE;
|
||||
}
|
||||
|
||||
|
|
|
@ -48,8 +48,8 @@ void MatmulEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGrap
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
||||
(AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE ||
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::BROAD_CAST) &&
|
||||
(AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise ||
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::kPatternBroadcast) &&
|
||||
common::AnfAlgo::GetOutputTensorNum(cnode) == ELTWISE_SINGLE_OUTPUT_SIZE) {
|
||||
auto eltwise_input = cnode->input(kIndex1);
|
||||
MS_EXCEPTION_IF_NULL(eltwise_input);
|
||||
|
|
|
@ -68,7 +68,7 @@ void MultiOutputFusionPass::MatchSingleFusionPattern(const session::KernelGraph
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) {
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise && cnode->inputs().size() == ELTWISE_INPUT_SIZE) {
|
||||
MatchMultiOutputEltwise(cnode, kernel_graph, candidate_fusion);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -47,7 +47,7 @@ void ReduceEltwiseFusionPass::MatchReduceEltwise(const CNodePtr &cnode, const se
|
|||
return;
|
||||
}
|
||||
if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::COMMREDUCE &&
|
||||
AnfAlgo::GetFusionType(eltwise_input) == kernel::kPatternCommReduce &&
|
||||
GetNodeOutputTotalUsedNum(kernel_graph, eltwise_input) == 1) {
|
||||
(void)record.insert(eltwise_input);
|
||||
auto previous_input_cnode = eltwise_input->cast<CNodePtr>();
|
||||
|
@ -82,7 +82,7 @@ void ReduceEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGrap
|
|||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
// Fusion squaresumv1 and sqrt will get worse performance in bert
|
||||
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise && cnode->inputs().size() == ELTWISE_INPUT_SIZE &&
|
||||
common::AnfAlgo::GetCNodeName(cnode) != kCastOpName && common::AnfAlgo::GetCNodeName(cnode) != kSqrtOpName) {
|
||||
MatchReduceEltwise(cnode, kernel_graph, candidate_fusion);
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ void SegmentEltwiseFusionPass::MatchSegmentEltwise(const CNodePtr &cnode, const
|
|||
return;
|
||||
}
|
||||
if (AnfAlgo::GetKernelType(eltwise_input) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(eltwise_input) == kernel::FusionType::SEGMENT) {
|
||||
AnfAlgo::GetFusionType(eltwise_input) == kernel::kPatternSegment) {
|
||||
(void)record.insert(eltwise_input);
|
||||
auto previous_input_cnode = eltwise_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(previous_input_cnode);
|
||||
|
@ -79,7 +79,7 @@ void SegmentEltwiseFusionPass::MatchSingleFusionPattern(const session::KernelGra
|
|||
auto cnode = node->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(cnode);
|
||||
if (AnfAlgo::GetKernelType(cnode) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::FusionType::ELEMWISE && cnode->inputs().size() == ELTWISE_INPUT_SIZE) {
|
||||
AnfAlgo::GetFusionType(cnode) == kernel::kPatternElemWise && cnode->inputs().size() == ELTWISE_INPUT_SIZE) {
|
||||
MatchSegmentEltwise(cnode, kernel_graph, candidate_fusion);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -44,7 +44,7 @@ void StridedReadConvStridedWriteFusionPass::MatchStridedReadConvStridedWrite(con
|
|||
auto conv_cnode = write_input->cast<CNodePtr>();
|
||||
MS_EXCEPTION_IF_NULL(conv_cnode);
|
||||
if (AnfAlgo::GetKernelType(conv_cnode) == KernelType::TBE_KERNEL &&
|
||||
AnfAlgo::GetFusionType(conv_cnode) == kernel::FusionType::CONV &&
|
||||
AnfAlgo::GetFusionType(conv_cnode) == kernel::kPatternConvolution &&
|
||||
conv_cnode->inputs().size() >= CONV_DOUBLE_IN_INPUT_SIZE &&
|
||||
conv_cnode->inputs().size() <= CONV_QUART_IN_INPUT_SIZE) {
|
||||
(void)record.insert(write_input);
|
||||
|
|
|
@ -131,7 +131,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AnfNodePtr &concat, con
|
|||
(void)outputs_device_format.emplace_back(*cmp_format);
|
||||
(void)outputs_device_type.emplace_back(*cmp_dtype);
|
||||
|
||||
builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
builder.SetFusionType(kernel::kPatternOpaque);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetKernelType(TBE_KERNEL);
|
||||
builder.SetInputsFormat(inputs_device_format);
|
||||
|
|
|
@ -124,7 +124,7 @@ ValueNodePtr CreateAssistNode(const std::vector<int64_t> &input_shape, int32_t k
|
|||
kernel::KernelBuildInfoPtr CreateKernelBuildInfo() {
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
builder.SetKernelType(TBE_KERNEL);
|
||||
builder.SetFusionType(kernel::OPAQUE);
|
||||
builder.SetFusionType(kernel::kPatternOpaque);
|
||||
builder.SetProcessor(kernel::AICORE);
|
||||
builder.SetInputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
|
||||
builder.SetOutputsFormat({kOpFormat_DEFAULT, kOpFormat_DEFAULT});
|
||||
|
|
|
@ -364,11 +364,7 @@ REG_ASCEND_VM_OP_ADAPTATION_INFO(kResizeNearestNeighborGradOpName)
|
|||
.set_need_tbe_check_supported(true)
|
||||
.set_input_attr_info(1, "listInt");
|
||||
|
||||
REG_ASCEND_VM_OP_ADAPTATION_INFO(kResizeNearestNeighborOpName)
|
||||
.set_backend_op_name(kResizeNearestNeighborV2OpName)
|
||||
.set_target_op_name(kResizeNearestNeighborV2DOpName)
|
||||
.set_need_tbe_check_supported(true)
|
||||
.set_input_attr_info(1, "listInt");
|
||||
REG_ASCEND_VM_OP_ADAPTATION_INFO(kResizeNearestNeighborOpName).set_backend_op_name(kResizeNearestNeighborV2DOpName);
|
||||
|
||||
REG_ASCEND_VM_OP_ADAPTATION_INFO(kReverseV2OpName)
|
||||
.set_target_op_name(kReverseV2DOpName)
|
||||
|
|
|
@ -454,7 +454,7 @@ void SetGraphKernelInfo(const CNodePtr &kernel_node, const FuncGraphPtr &func_gr
|
|||
graph_info_builder.SetOutputsDeviceType(graph_output_type);
|
||||
graph_info_builder.SetProcessor(kernel::Processor::CUDA);
|
||||
graph_info_builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
graph_info_builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
graph_info_builder.SetFusionType(kernel::kPatternOpaque);
|
||||
auto graph_selected_info = graph_info_builder.Build();
|
||||
MS_EXCEPTION_IF_NULL(graph_selected_info);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(graph_selected_info, kernel_node.get());
|
||||
|
|
|
@ -85,7 +85,7 @@ kernel::KernelBuildInfoPtr GenerateKernelBuildInfo(const AnfNodePtr &concat, con
|
|||
outputs_device_format.emplace_back(*cmp_format);
|
||||
outputs_device_type.emplace_back(*cmp_dtype);
|
||||
|
||||
builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
builder.SetFusionType(kernel::kPatternOpaque);
|
||||
builder.SetInputsFormat(inputs_device_format);
|
||||
builder.SetOutputsFormat(outputs_device_format);
|
||||
builder.SetInputsDeviceType(inputs_device_type);
|
||||
|
|
|
@ -42,7 +42,7 @@ class TestHcclAdapter : public UT::Common {
|
|||
void SetOutputs(const CNodePtr &cnode, const std::vector<ShapeVector> &shape, const std::vector<TypeId> &data_type) {
|
||||
common::AnfAlgo::SetOutputInferTypeAndShape(data_type, shape, cnode.get());
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
builder.SetFusionType(kernel::kPatternOpaque);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetKernelType(TBE_KERNEL);
|
||||
builder.SetInputsFormat(std::vector<std::string>(cnode->size() - 1, format_));
|
||||
|
@ -65,7 +65,7 @@ class TestHcclAdapter : public UT::Common {
|
|||
common::AnfAlgo::SetOutputInferTypeAndShape(std::vector<TypeId>{data_type[i]}, std::vector<ShapeVector>{shape[i]},
|
||||
node.get());
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder;
|
||||
builder.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
builder.SetFusionType(kernel::kPatternOpaque);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetKernelType(TBE_KERNEL);
|
||||
builder.SetInputsFormat({format_});
|
||||
|
|
|
@ -73,7 +73,7 @@ TEST_F(TestHWBufferFusion, test_tbe_eltwise_fusion_1) {
|
|||
builder.SetInputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetOutputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
builder.SetFusionType(kernel::kPatternElemWise);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
|
||||
|
@ -88,7 +88,7 @@ TEST_F(TestHWBufferFusion, test_tbe_eltwise_fusion_1) {
|
|||
builder1.SetInputsDeviceType({kFloat32->type_id()});
|
||||
builder1.SetOutputsDeviceType({kFloat16->type_id()});
|
||||
builder1.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder1.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
builder1.SetFusionType(kernel::kPatternOpaque);
|
||||
builder1.SetProcessor(kernel::Processor::AICORE);
|
||||
builder1.SetKernelType(KernelType::TBE_KERNEL);
|
||||
|
||||
|
@ -144,7 +144,7 @@ TEST_F(TestHWBufferFusion, test_tbe_eltwise_fusion_2) {
|
|||
builder.SetInputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetOutputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
builder.SetFusionType(kernel::kPatternElemWise);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
|
||||
|
@ -167,7 +167,7 @@ TEST_F(TestHWBufferFusion, test_tbe_eltwise_fusion_2) {
|
|||
builder1.SetInputsDeviceType({kFloat32->type_id()});
|
||||
builder1.SetOutputsDeviceType({kFloat16->type_id()});
|
||||
builder1.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder1.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
builder1.SetFusionType(kernel::kPatternOpaque);
|
||||
builder1.SetProcessor(kernel::Processor::AICORE);
|
||||
builder1.SetKernelType(KernelType::TBE_KERNEL);
|
||||
|
||||
|
@ -180,7 +180,7 @@ TEST_F(TestHWBufferFusion, test_tbe_eltwise_fusion_2) {
|
|||
builder2.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
|
||||
builder2.SetOutputsDeviceType({kFloat32->type_id()});
|
||||
builder2.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder2.SetFusionType(kernel::FusionType::COMMREDUCE);
|
||||
builder2.SetFusionType(kernel::kPatternCommReduce);
|
||||
builder2.SetProcessor(kernel::Processor::AICORE);
|
||||
builder2.SetKernelType(KernelType::TBE_KERNEL);
|
||||
|
||||
|
@ -234,7 +234,7 @@ TEST_F(TestHWBufferFusion, test_tbe_reduce_eltwise_fusion) {
|
|||
builder.SetInputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetOutputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
builder.SetFusionType(kernel::kPatternElemWise);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
|
||||
|
@ -257,7 +257,7 @@ TEST_F(TestHWBufferFusion, test_tbe_reduce_eltwise_fusion) {
|
|||
builder1.SetInputsDeviceType({kFloat32->type_id()});
|
||||
builder1.SetOutputsDeviceType({kFloat16->type_id()});
|
||||
builder1.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder1.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
builder1.SetFusionType(kernel::kPatternOpaque);
|
||||
builder1.SetProcessor(kernel::Processor::AICORE);
|
||||
builder1.SetKernelType(KernelType::TBE_KERNEL);
|
||||
|
||||
|
@ -270,7 +270,7 @@ TEST_F(TestHWBufferFusion, test_tbe_reduce_eltwise_fusion) {
|
|||
builder2.SetInputsDeviceType({kFloat32->type_id()});
|
||||
builder2.SetOutputsDeviceType({kFloat32->type_id()});
|
||||
builder2.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder2.SetFusionType(kernel::FusionType::COMMREDUCE);
|
||||
builder2.SetFusionType(kernel::kPatternCommReduce);
|
||||
builder2.SetProcessor(kernel::Processor::AICORE);
|
||||
builder2.SetKernelType(KernelType::TBE_KERNEL);
|
||||
|
||||
|
@ -316,7 +316,7 @@ TEST_F(TestHWBufferFusion, test_tbe_matmul_eltwise_fusion) {
|
|||
builder.SetInputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetOutputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
builder.SetFusionType(kernel::kPatternElemWise);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
relu->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
|
@ -328,7 +328,7 @@ TEST_F(TestHWBufferFusion, test_tbe_matmul_eltwise_fusion) {
|
|||
builder2.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
|
||||
builder2.SetOutputsDeviceType({kFloat32->type_id()});
|
||||
builder2.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder2.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
builder2.SetFusionType(kernel::kPatternOpaque);
|
||||
builder2.SetProcessor(kernel::Processor::AICORE);
|
||||
builder2.SetKernelType(KernelType::TBE_KERNEL);
|
||||
matmul->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
|
@ -340,7 +340,7 @@ TEST_F(TestHWBufferFusion, test_tbe_matmul_eltwise_fusion) {
|
|||
builder1.SetInputsDeviceType({kFloat32->type_id()});
|
||||
builder1.SetOutputsDeviceType({kFloat16->type_id()});
|
||||
builder1.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder1.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
builder1.SetFusionType(kernel::kPatternOpaque);
|
||||
builder1.SetProcessor(kernel::Processor::AICORE);
|
||||
builder1.SetKernelType(KernelType::TBE_KERNEL);
|
||||
cast->set_kernel_info(std::make_shared<device::KernelInfo>());
|
||||
|
|
|
@ -58,7 +58,7 @@ TEST_F(TestHWInsertCast, test_insert_cast_op_for_single_output) {
|
|||
builder.SetOutputsFormat({"NC1HWC0"});
|
||||
builder.SetInputsDeviceType({kFloat16->type_id(), kFloat16->type_id()});
|
||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
builder.SetFusionType(kernel::kPatternElemWise);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
kernel::KernelBuildInfo::KernelBuildInfoBuilder builder1;
|
||||
|
@ -66,7 +66,7 @@ TEST_F(TestHWInsertCast, test_insert_cast_op_for_single_output) {
|
|||
builder1.SetInputsDeviceType({kFloat32->type_id()});
|
||||
builder1.SetOutputsFormat({"NC1HWC0"});
|
||||
builder1.SetOutputsDeviceType({kFloat32->type_id()});
|
||||
builder1.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
builder1.SetFusionType(kernel::kPatternElemWise);
|
||||
builder1.SetProcessor(kernel::Processor::AICORE);
|
||||
builder1.SetKernelType(KernelType::AKG_KERNEL);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
|
@ -120,7 +120,7 @@ TEST_F(TestHWInsertCast, test_insert_cast_op_for_multiple_output) {
|
|||
builder1.SetInputsDeviceType({kFloat32->type_id()});
|
||||
builder1.SetOutputsFormat({"DefaultFormat"});
|
||||
builder1.SetOutputsDeviceType({kFloat32->type_id()});
|
||||
builder1.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
builder1.SetFusionType(kernel::kPatternElemWise);
|
||||
builder1.SetProcessor(kernel::Processor::AICORE);
|
||||
builder1.SetKernelType(KernelType::AKG_KERNEL);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
|
|
|
@ -124,7 +124,7 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_fraz_nchw) {
|
|||
builder.SetOutputsFormat({kOpFormat_C1HWNCoC0});
|
||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
builder.SetFusionType(kernel::kPatternElemWise);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetInputsReshapeType({""});
|
||||
builder.SetOutputsReshapeType({""});
|
||||
|
@ -172,7 +172,7 @@ TEST_F(TestHWTransdataSplit, test_transdata_split_nchw_fraz) {
|
|||
builder.SetOutputsFormat({"NCHW"});
|
||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
builder.SetFusionType(kernel::kPatternElemWise);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetInputsReshapeType({""});
|
||||
builder.SetOutputsReshapeType({""});
|
||||
|
|
|
@ -70,7 +70,7 @@ TEST_F(TestHWTransposeTransdataFusion, test_transpose_transdata_fusion) {
|
|||
builder.SetOutputsFormat({"NC1HWC0"});
|
||||
builder.SetOutputsDeviceType({kFloat16->type_id()});
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
builder.SetFusionType(kernel::kPatternElemWise);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
auto kernel_info = std::make_shared<device::KernelInfo>();
|
||||
kernel_info->set_select_kernel_build_info(builder.Build());
|
||||
|
|
|
@ -97,7 +97,7 @@ static KernelGraphPtr CreateKernelGraph() {
|
|||
builder.SetOutputsFormat({kOpFormat_NCHW});
|
||||
builder.SetOutputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder.SetFusionType(mindspore::kernel::CONV);
|
||||
builder.SetFusionType(mindspore::kernel::kPatternConvolution);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), kernelptr_first.get());
|
||||
|
||||
|
@ -127,7 +127,7 @@ static KernelGraphPtr CreateKernelGraph() {
|
|||
relu_builder.SetInputsDeviceType({kFloat32->type_id()});
|
||||
relu_builder.SetOutputsDeviceType({kFloat32->type_id()});
|
||||
relu_builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
relu_builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
relu_builder.SetFusionType(kernel::kPatternElemWise);
|
||||
relu_builder.SetProcessor(kernel::Processor::AICORE);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder.Build(), kernelptr_floor.get());
|
||||
next_cnode_ptr = kernelptr_floor;
|
||||
|
|
|
@ -55,7 +55,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_all) {
|
|||
builder.SetOutputsFormat({"NC1HWC0"});
|
||||
builder.SetInputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetOutputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
builder.SetFusionType(kernel::kPatternElemWise);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
|
@ -96,7 +96,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_group) {
|
|||
builder.SetOutputsFormat({"NC1HWC0"});
|
||||
builder.SetInputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetOutputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
builder.SetFusionType(kernel::kPatternElemWise);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
|
@ -137,7 +137,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_op) {
|
|||
builder.SetOutputsFormat({"NC1HWC0"});
|
||||
builder.SetInputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetOutputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
builder.SetFusionType(kernel::kPatternElemWise);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
|
@ -194,7 +194,7 @@ TEST_F(TestHWAllReduceFusion, test_fusion_sorted) {
|
|||
builder.SetOutputsFormat({"NC1HWC0"});
|
||||
builder.SetInputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetOutputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
builder.SetFusionType(kernel::kPatternElemWise);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetKernelType(KernelType::AKG_KERNEL);
|
||||
auto node_list = TopoSort(func_graph->get_return());
|
||||
|
|
|
@ -676,9 +676,9 @@ TEST_F(AnfRuntimeAlgorithmTest, GetFusionType) {
|
|||
auto d_kernel_info = dynamic_cast<KernelInfo *>(add->kernel_info());
|
||||
MS_EXCEPTION_IF_NULL(d_kernel_info);
|
||||
KernelBuildInfoBuilder builder;
|
||||
builder.SetFusionType(kernel::CONV);
|
||||
builder.SetFusionType(kernel::kPatternConvolution);
|
||||
d_kernel_info->set_select_kernel_build_info(builder.Build());
|
||||
EXPECT_EQ(AnfAlgo::GetFusionType(add), kernel::CONV);
|
||||
EXPECT_EQ(AnfAlgo::GetFusionType(add), kernel::kPatternConvolution);
|
||||
EXPECT_THROW(AnfAlgo::GetFusionType(nullptr), std::runtime_error);
|
||||
}
|
||||
|
||||
|
@ -688,10 +688,10 @@ TEST_F(AnfRuntimeAlgorithmTest, SetSelectKernelBuildInfo) {
|
|||
inputs.push_back(NewValueNode(prim::kPrimAdd));
|
||||
auto add = kernel_graph->NewCNode(inputs);
|
||||
std::shared_ptr<KernelBuildInfoBuilder> builder = std::make_shared<KernelBuildInfoBuilder>();
|
||||
builder->SetFusionType(kernel::CONV);
|
||||
builder->SetFusionType(kernel::kPatternConvolution);
|
||||
AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), add.get());
|
||||
EXPECT_THROW(AnfAlgo::SetSelectKernelBuildInfo(builder->Build(), nullptr), std::runtime_error);
|
||||
EXPECT_EQ(AnfAlgo::GetFusionType(add), kernel::CONV);
|
||||
EXPECT_EQ(AnfAlgo::GetFusionType(add), kernel::kPatternConvolution);
|
||||
}
|
||||
|
||||
TEST_F(AnfRuntimeAlgorithmTest, GetKernelMod) {
|
||||
|
|
|
@ -65,7 +65,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_common) {
|
|||
builder.SetInputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetOutputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
builder.SetFusionType(kernel::kPatternElemWise);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
|
||||
|
@ -107,7 +107,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_conv2d_backprop_filter) {
|
|||
builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
|
||||
builder.SetOutputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
builder.SetFusionType(kernel::kPatternElemWise);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
|
||||
|
@ -166,7 +166,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_dynamic_rnn) {
|
|||
builder.SetOutputsDeviceType({kFloat16->type_id(), kFloat16->type_id(), kFloat16->type_id(), kFloat16->type_id(),
|
||||
kFloat16->type_id(), kFloat16->type_id(), kFloat16->type_id(), kFloat16->type_id()});
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
builder.SetFusionType(kernel::kPatternElemWise);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
|
||||
|
@ -219,7 +219,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_single_layer_norm) {
|
|||
builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id(), kFloat32->type_id()});
|
||||
builder.SetOutputsDeviceType({kFloat32->type_id(), kFloat32->type_id(), kFloat32->type_id()});
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
builder.SetFusionType(kernel::kPatternElemWise);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
|
||||
|
@ -263,7 +263,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_fusion_common) {
|
|||
builder.SetInputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetOutputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
builder.SetFusionType(kernel::kPatternElemWise);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
|
||||
|
@ -278,7 +278,7 @@ TEST_F(TestHWTBEJsonCreator, test_tbe_fusion_common) {
|
|||
builder1.SetInputsDeviceType({kFloat32->type_id()});
|
||||
builder1.SetOutputsDeviceType({kFloat16->type_id()});
|
||||
builder1.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder1.SetFusionType(kernel::FusionType::OPAQUE);
|
||||
builder1.SetFusionType(kernel::kPatternOpaque);
|
||||
builder1.SetProcessor(kernel::Processor::AICORE);
|
||||
builder1.SetKernelType(KernelType::TBE_KERNEL);
|
||||
|
||||
|
@ -337,7 +337,7 @@ TEST_F(TestHWTBEJsonCreator, test_fusion_add_conv2d) {
|
|||
builder.SetInputsDeviceType({kFloat32->type_id(), kFloat32->type_id()});
|
||||
builder.SetOutputsDeviceType({kFloat32->type_id()});
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
builder.SetFusionType(kernel::FusionType::ELEMWISE);
|
||||
builder.SetFusionType(kernel::kPatternElemWise);
|
||||
builder.SetProcessor(kernel::Processor::AICORE);
|
||||
builder.SetKernelType(KernelType::TBE_KERNEL);
|
||||
|
||||
|
|
Loading…
Reference in New Issue