From c21ffc0317b0be7393d4aa8aab44541253d395f2 Mon Sep 17 00:00:00 2001 From: qujianwei Date: Mon, 17 Aug 2020 17:08:45 +0800 Subject: [PATCH] fix gpu matmul fp32 accuracy --- .../backend/kernel_compiler/gpu/math/matmul_gpu_kernel.h | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.h index 7888d442c92..d6d547113c5 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.h +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/matmul_gpu_kernel.h @@ -42,7 +42,7 @@ class MatMulGpuKernel : public GpuKernel { dtype_a_(CUDA_R_32F), dtype_b_(CUDA_R_32F), dtype_c_(CUDA_R_32F), - algo_(CUBLAS_GEMM_DEFAULT_TENSOR_OP) {} + algo_(CUBLAS_GEMM_DEFAULT) {} ~MatMulGpuKernel() = default; const std::vector &GetInputSizeList() const override { return input_size_list_; } const std::vector &GetOutputSizeList() const override { return output_size_list_; } @@ -85,6 +85,10 @@ class MatMulGpuKernel : public GpuKernel { dtype_a_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))); dtype_b_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1))); dtype_c_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0))); + if (dtype_a_ == CUDA_R_16F && dtype_b_ == CUDA_R_16F && dtype_c_ == CUDA_R_16F) { + MS_LOG(WARNING) << "input and output type is float16, allow to use Tensor Core operations if possible"; + algo_ = CUBLAS_GEMM_DEFAULT_TENSOR_OP; + } auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0); is_null_input_ = CHECK_NULL_INPUT(output_shape); if (is_null_input_) {