forked from mindspore-Ecosystem/mindspore
filter out matmul if it is fp16->fp32 with a fp32 bias
This commit is contained in:
parent
1dd7553c02
commit
11399a873f
|
@ -22,6 +22,7 @@
|
|||
#include <utility>
|
||||
#include <vector>
|
||||
#include <iterator>
|
||||
#include <algorithm>
|
||||
#include "kernel/common_utils.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_convert_utils.h"
|
||||
#include "plugin/device/ascend/kernel/tbe/tbe_dynamic_shape_util.h"
|
||||
|
@ -45,6 +46,7 @@ constexpr char kParamTypeDynamic[] = "dynamic";
|
|||
constexpr char kParamTypeRequre[] = "required";
|
||||
constexpr char kParamTypeOptional[] = "optional";
|
||||
constexpr int64_t kDynamicInvalidNum = -1;
|
||||
constexpr size_t kMatMulInputSize = 3;
|
||||
|
||||
void TbeMetadataInfo(const CNodePtr &kernel_node, std::vector<std::shared_ptr<KernelBuildInfo>> *kernel_info_list) {
|
||||
auto tbe_selector = TbeKernelSelect(kernel_node, kernel_info_list);
|
||||
|
@ -340,6 +342,9 @@ void TbeKernelSelect::FilterInvalidKernelInfo() {
|
|||
if (!FilterInvalidShape(kernel_build_info, !dynamic_inputs.empty())) {
|
||||
continue;
|
||||
}
|
||||
if (!FilterUnsupportedMatMul(kernel_build_info)) {
|
||||
continue;
|
||||
}
|
||||
if (!TbeCheckSupported(kernel_build_info)) {
|
||||
continue;
|
||||
}
|
||||
|
@ -351,6 +356,21 @@ void TbeKernelSelect::FilterInvalidKernelInfo() {
|
|||
(*kernel_info_list_).swap(kernel_info_list);
|
||||
}
|
||||
|
||||
bool TbeKernelSelect::FilterUnsupportedMatMul(const KernelBuildInfoPtr &kernel_build_info) {
|
||||
// A MatMul op is unsupported if it has a bias and bias is fp32
|
||||
// we need to filter it out or it will cause compile error.
|
||||
if (common::AnfAlgo::GetCNodeName(cnode_ptr_) != prim::kPrimMatMul->name() ||
|
||||
!common::AnfAlgo::IsDynamicShape(cnode_ptr_)) {
|
||||
return true;
|
||||
}
|
||||
const auto &input_dtypes = kernel_build_info->GetAllInputDeviceTypes();
|
||||
if (input_dtypes.size() < kMatMulInputSize) {
|
||||
return true;
|
||||
}
|
||||
const auto bias_dtype = input_dtypes[kMatMulInputSize - 1];
|
||||
return !(bias_dtype == TypeId::kNumberTypeFloat32 || bias_dtype == TypeId::kNumberTypeFloat);
|
||||
}
|
||||
|
||||
bool TbeKernelSelect::FilterInvalidShape(const KernelBuildInfoPtr &kernel_build_info, bool is_dynamic_input) {
|
||||
MS_EXCEPTION_IF_NULL(kernel_build_info);
|
||||
const auto &kernel_build_info_inputs_format = kernel_build_info->GetAllInputFormats();
|
||||
|
|
|
@ -47,6 +47,7 @@ class TbeKernelSelect {
|
|||
void GetReducePatternKernelInfo(const OpInfo &op_info);
|
||||
void FilterInvalidKernelInfo();
|
||||
bool FilterInvalidShape(const KernelBuildInfoPtr &kernel_build_info, bool is_dynamic_input);
|
||||
bool FilterUnsupportedMatMul(const KernelBuildInfoPtr &kernel_build_info);
|
||||
bool IsShapeMatchFormat(const ShapeVector &shape, const std::string &format);
|
||||
bool IsShapeMatchFormatRNN(const ShapeVector &shape, const std::string &format);
|
||||
bool TbeCheckSupported(const KernelBuildInfoPtr &kernel_build_info);
|
||||
|
|
Loading…
Reference in New Issue