fix gpu matmul fp32 accuracy

This commit is contained in:
qujianwei 2020-08-17 17:08:45 +08:00
parent f41ca6b5c6
commit c21ffc0317
1 changed files with 5 additions and 1 deletions

View File

@ -42,7 +42,7 @@ class MatMulGpuKernel : public GpuKernel {
dtype_a_(CUDA_R_32F),
dtype_b_(CUDA_R_32F),
dtype_c_(CUDA_R_32F),
algo_(CUBLAS_GEMM_DEFAULT_TENSOR_OP) {}
algo_(CUBLAS_GEMM_DEFAULT) {}
~MatMulGpuKernel() = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
@ -85,6 +85,10 @@ class MatMulGpuKernel : public GpuKernel {
dtype_a_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
dtype_b_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1)));
dtype_c_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0)));
if (dtype_a_ == CUDA_R_16F && dtype_b_ == CUDA_R_16F && dtype_c_ == CUDA_R_16F) {
MS_LOG(WARNING) << "input and output type is float16, allow to use Tensor Core operations if possible";
algo_ = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
}
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
is_null_input_ = CHECK_NULL_INPUT(output_shape);
if (is_null_input_) {