filter out matmul if it is fp16->fp32 with a fp32 bias

This commit is contained in:
xulei 2023-01-04 09:33:33 +08:00
parent 1dd7553c02
commit 11399a873f
2 changed files with 21 additions and 0 deletions

View File

@ -22,6 +22,7 @@
#include <utility>
#include <vector>
#include <iterator>
#include <algorithm>
#include "kernel/common_utils.h"
#include "plugin/device/ascend/kernel/tbe/tbe_convert_utils.h"
#include "plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.h"
@ -45,6 +46,7 @@ constexpr char kParamTypeDynamic[] = "dynamic";
constexpr char kParamTypeRequre[] = "required";
constexpr char kParamTypeOptional[] = "optional";
constexpr int64_t kDynamicInvalidNum = -1;
constexpr size_t kMatMulInputSize = 3;
void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
auto tbe_selector = TbeKernelSelect(kernel_node, kernel_info_list);
@ -340,6 +342,9 @@ void TbeKernelSelect::FilterInvalidKernelInfo() {
if (!FilterInvalidShape(kernel_build_info, !dynamic_inputs.empty())) {
continue;
}
if (!FilterUnsupportedMatMul(kernel_build_info)) {
continue;
}
if (!TbeCheckSupported(kernel_build_info)) {
continue;
}
@ -351,6 +356,21 @@ void TbeKernelSelect::FilterInvalidKernelInfo() {
(*kernel_info_list_).swap(kernel_info_list);
}
bool TbeKernelSelect::FilterUnsupportedMatMul(const KernelBuildInfoPtr &kernel_build_info) {
// A MatMul op is unsupported if it has a bias and bias is fp32
// we need to filter it out or it will cause compile error.
if (common::AnfAlgo::GetCNodeName(cnode_ptr_) != prim::kPrimMatMul->name() ||
!common::AnfAlgo::IsDynamicShape(cnode_ptr_)) {
return true;
}
const auto &input_dtypes = kernel_build_info->GetAllInputDeviceTypes();
if (input_dtypes.size() < kMatMulInputSize) {
return true;
}
const auto bias_dtype = input_dtypes[kMatMulInputSize - 1];
return !(bias_dtype == TypeId::kNumberTypeFloat32 || bias_dtype == TypeId::kNumberTypeFloat);
}
bool TbeKernelSelect::FilterInvalidShape(const KernelBuildInfoPtr &kernel_build_info, bool is_dynamic_input) {
MS_EXCEPTION_IF_NULL(kernel_build_info);
const auto &kernel_build_info_inputs_format = kernel_build_info->GetAllInputFormats();

View File

@ -47,6 +47,7 @@ class TbeKernelSelect {
void GetReducePatternKernelInfo(const OpInfo &op_info);
void FilterInvalidKernelInfo();
bool FilterInvalidShape(const KernelBuildInfoPtr &kernel_build_info, bool is_dynamic_input);
bool FilterUnsupportedMatMul(const KernelBuildInfoPtr &kernel_build_info);
bool IsShapeMatchFormat(const ShapeVector &shape, const std::string &format);
bool IsShapeMatchFormatRNN(const ShapeVector &shape, const std::string &format);
bool TbeCheckSupported(const KernelBuildInfoPtr &kernel_build_info);