diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc index 43784550ccc..2b98bcf8f1a 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_kernel_select.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include "kernel/common_utils.h" #include "plugin/device/ascend/kernel/tbe/tbe_convert_utils.h" #include "plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.h" @@ -45,6 +46,7 @@ constexpr char kParamTypeDynamic[] = "dynamic"; constexpr char kParamTypeRequre[] = "required"; constexpr char kParamTypeOptional[] = "optional"; constexpr int64_t kDynamicInvalidNum = -1; +constexpr size_t kMatMulInputSize = 3; void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector> *kernel_info_list) { auto tbe_selector = TbeKernelSelect(kernel_node, kernel_info_list); @@ -340,6 +342,9 @@ void TbeKernelSelect::FilterInvalidKernelInfo() { if (!FilterInvalidShape(kernel_build_info, !dynamic_inputs.empty())) { continue; } + if (!FilterUnsupportedMatMul(kernel_build_info)) { + continue; + } if (!TbeCheckSupported(kernel_build_info)) { continue; } @@ -351,6 +356,21 @@ void TbeKernelSelect::FilterInvalidKernelInfo() { (*kernel_info_list_).swap(kernel_info_list); } +bool TbeKernelSelect::FilterUnsupportedMatMul(const KernelBuildInfoPtr &kernel_build_info) { + // A MatMul op is unsupported if it has a bias and bias is fp32 + // we need to filter it out or it will cause compile error. + if (common::AnfAlgo::GetCNodeName(cnode_ptr_) != prim::kPrimMatMul->name() || + !common::AnfAlgo::IsDynamicShape(cnode_ptr_)) { + return true; + } + const auto &input_dtypes = kernel_build_info->GetAllInputDeviceTypes(); + if (input_dtypes.size() < kMatMulInputSize) { + return true; + } + const auto bias_dtype = input_dtypes[kMatMulInputSize - 1]; + return !(bias_dtype == TypeId::kNumberTypeFloat32 || bias_dtype == TypeId::kNumberTypeFloat); +} + bool TbeKernelSelect::FilterInvalidShape(const KernelBuildInfoPtr &kernel_build_info, bool is_dynamic_input) { MS_EXCEPTION_IF_NULL(kernel_build_info); const auto &kernel_build_info_inputs_format = kernel_build_info->GetAllInputFormats(); diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_kernel_select.h b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_kernel_select.h index 07277bb5bc1..9b9ea292904 100644 --- a/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_kernel_select.h +++ b/mindspore/ccsrc/plugin/device/ascend/kernel/tbe/tbe_kernel_select/tbe_kernel_select.h @@ -47,6 +47,7 @@ class TbeKernelSelect { void GetReducePatternKernelInfo(const OpInfo &op_info); void FilterInvalidKernelInfo(); bool FilterInvalidShape(const KernelBuildInfoPtr &kernel_build_info, bool is_dynamic_input); + bool FilterUnsupportedMatMul(const KernelBuildInfoPtr &kernel_build_info); bool IsShapeMatchFormat(const ShapeVector &shape, const std::string &format); bool IsShapeMatchFormatRNN(const ShapeVector &shape, const std::string &format); bool TbeCheckSupported(const KernelBuildInfoPtr &kernel_build_info);