forked from mindspore-Ecosystem/mindspore
add float64 support to matmul ops
This commit is contained in:
parent ced5575387
commit 4c18e0894e
@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -49,8 +49,8 @@ static std::map<std::string, cudnnDataType_t> kCudnnDtypeMap = {
   {"kNumberTypeBool", CUDNN_DATA_INT8}, {"kNumberTypeInt8", CUDNN_DATA_INT8},
   {"kNumberTypeUInt8", CUDNN_DATA_UINT8}};
 // Used by mixprecision, cuda dtype select
-static std::map<std::string, cudaDataType_t> kCudaDtypeMap = {{"kNumberTypeFloat32", CUDA_R_32F},
-                                                              {"kNumberTypeFloat16", CUDA_R_16F}};
+static std::map<std::string, cudaDataType_t> kCudaDtypeMap = {
+  {"kNumberTypeFloat64", CUDA_R_64F}, {"kNumberTypeFloat32", CUDA_R_32F}, {"kNumberTypeFloat16", CUDA_R_16F}};
 } // namespace kernel
 } // namespace mindspore

@@ -1,5 +1,5 @@
 /**
- * Copyright 2019 Huawei Technologies Co., Ltd
+ * Copyright 2019-2021 Huawei Technologies Co., Ltd
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -18,6 +18,10 @@

 namespace mindspore {
 namespace kernel {
+MS_REG_GPU_KERNEL_ONE(
+  MatMul,
+  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+  MatMulGpuKernel, double)
 MS_REG_GPU_KERNEL_ONE(
   MatMul,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
@@ -26,6 +30,10 @@ MS_REG_GPU_KERNEL_ONE(
   MatMul,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   MatMulGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  BatchMatMul,
+  KernelAttr().AddInputAttr(kNumberTypeFloat64).AddInputAttr(kNumberTypeFloat64).AddOutputAttr(kNumberTypeFloat64),
+  MatMulGpuKernel, double)
 MS_REG_GPU_KERNEL_ONE(
   BatchMatMul,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
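
With the float64 registrations above in place, a GPU MatMul on double tensors dispatches to MatMulGpuKernel<double>. A minimal usage sketch, modeled on the tests added at the end of this commit (shapes and values here are only illustrative):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P


class MatMulNet(nn.Cell):
    def __init__(self):
        super(MatMulNet, self).__init__()
        self.matmul = P.MatMul()

    def construct(self, x, y):
        return self.matmul(x, y)


context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
x = np.arange(12).reshape(4, 3).astype(np.float64)   # float64 inputs now resolve to the double kernel
y = np.arange(18).reshape(3, 6).astype(np.float64)
output = MatMulNet()(Tensor(x), Tensor(y))
np.testing.assert_array_almost_equal(output.asnumpy(), np.matmul(x, y))
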
@@ -14,8 +14,8 @@
  * limitations under the License.
  */

-#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H
-#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H
+#ifndef MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H_
+#define MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H_

 #include <cublas_v2.h>
 #include <cuda_runtime_api.h>
@@ -47,8 +47,10 @@ class MatMulGpuKernel : public GpuKernel {
     auto input2_addr = GetDeviceAddress<T>(inputs, 1);
     auto output_addr = GetDeviceAddress<T>(outputs, 0);

-    const float alpha = 1;
-    const float beta = 0;
+    T alpha = static_cast<T>(1.0f);
+    T beta = static_cast<T>(0.0f);
+    cudaDataType_t compute_type = (dtype_a_ == CUDA_R_64F) ? CUDA_R_64F : CUDA_R_32F;
+
     const int lda = (transpose_x1_ == CUBLAS_OP_T) ? SizeToInt(m_) : SizeToInt(k_);
     const int ldb = (transpose_x2_ == CUBLAS_OP_T) ? SizeToInt(k_) : SizeToInt(n_);
     const int ldc = n_;
@@ -58,20 +60,44 @@ class MatMulGpuKernel : public GpuKernel {
     auto stride_c = SizeToInt(m_ * n_);

     try {
-      // Use cublasGemmEx to get high performance when batch_ is 1
-      if (batch_ == 1) {
-        CHECK_CUBLAS_RET_WITH_EXCEPT(kernel_node_,
-                                     cublasGemmEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_),
-                                                  SizeToInt(k_), &alpha, input2_addr, dtype_b_, ldb, input1_addr,
-                                                  dtype_a_, lda, &beta, output_addr, dtype_c_, ldc, CUDA_R_32F, algo_),
-                                     "cublasSgemm Call Fail");
+      if (dtype_a_ == CUDA_R_16F) {
+        const float alphaf = 1.0f;
+        const float betaf = 0.0f;
+        // Use cublasGemmEx to get high performance when batch_ is 1
+        if (batch_ == 1) {
+          CHECK_CUBLAS_RET_WITH_EXCEPT(
+            kernel_node_,
+            cublasGemmEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_), &alphaf,
+                         input2_addr, dtype_b_, ldb, input1_addr, dtype_a_, lda, &betaf, output_addr, dtype_c_, ldc,
+                         compute_type, algo_),
+            "cublasGemmEx failed");
+        } else {
+          CHECK_CUBLAS_RET_WITH_EXCEPT(
+            kernel_node_,
+            cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_),
+                                       SizeToInt(k_), &alphaf, input2_addr, dtype_b_, ldb, stride_b, input1_addr,
+                                       dtype_a_, lda, stride_a, &betaf, output_addr, dtype_c_, ldc, stride_c, batch_,
+                                       compute_type, algo_),
+            "cublasGemmStridedBatchedEx failed");
+        }
       } else {
-        CHECK_CUBLAS_RET_WITH_EXCEPT(
-          kernel_node_,
-          cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_),
-                                     &alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a,
-                                     &beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_),
-          "cublasGemmStridedBatchedEx Call Fail");
+        // Use cublasGemmEx to get high performance when batch_ is 1
+        if (batch_ == 1) {
+          CHECK_CUBLAS_RET_WITH_EXCEPT(
+            kernel_node_,
+            cublasGemmEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_), &alpha,
+                         input2_addr, dtype_b_, ldb, input1_addr, dtype_a_, lda, &beta, output_addr, dtype_c_, ldc,
+                         compute_type, algo_),
+            "cublasGemmEx failed");
+        } else {
+          CHECK_CUBLAS_RET_WITH_EXCEPT(
+            kernel_node_,
+            cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_),
+                                       SizeToInt(k_), &alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr,
+                                       dtype_a_, lda, stride_a, &beta, output_addr, dtype_c_, ldc, stride_c, batch_,
+                                       compute_type, algo_),
+            "cublasGemmStridedBatchedEx failed");
+        }
       }
     } catch (const std::exception &e) {
       MS_LOG(EXCEPTION) << "Encountered an exception: " << e.what() << " when invoke cublas "
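
The branch structure above follows from the compute type chosen in the previous hunk: for double inputs compute_type is CUDA_R_64F, so the generic branch can pass the T-typed alpha/beta straight through, while the float16 branch keeps separate float scalars because it still accumulates in CUDA_R_32F (as I read the cublasGemmEx contract, the alpha/beta scalars must match the compute type). For a rough sense of what the CUDA_R_64F path buys, a plain NumPy sketch of the rounding error a single-precision GEMM accumulates (illustrative only, not part of the patch):

import numpy as np

rng = np.random.default_rng(0)
a = rng.random((512, 512))
b = rng.random((512, 512))

ref64 = a @ b                                                   # double-precision reference
via32 = (a.astype(np.float32) @ b.astype(np.float32)).astype(np.float64)

# The float32 route is off by far more than double-precision round-off,
# which is the gap a true float64 compute type closes.
print(np.abs(ref64 - via32).max())
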
@@ -85,6 +111,10 @@ class MatMulGpuKernel : public GpuKernel {
     dtype_a_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0)));
     dtype_b_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1)));
     dtype_c_ = GetCudaDataType(TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0)));
+    auto node_name = AnfAlgo::GetCNodeName(kernel_node);
+    if (dtype_a_ != dtype_b_ || dtype_a_ != dtype_c_) {
+      MS_LOG(EXCEPTION) << "input and output types are not the same in " << node_name;
+    }
     if (dtype_a_ == CUDA_R_16F && dtype_b_ == CUDA_R_16F && dtype_c_ == CUDA_R_16F) {
       MS_LOG(INFO) << "input and output type is float16, allow to use Tensor Core operations if possible";
       algo_ = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
@@ -174,4 +204,4 @@ class MatMulGpuKernel : public GpuKernel {
 } // namespace kernel
 } // namespace mindspore

-#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H
+#endif  // MINDSPORE_CCSRC_BACKEND_KERNEL_COMPILER_GPU_MATH_MATMUL_GPU_KERNEL_H_

@@ -1,4 +1,4 @@
-# Copyright 2020 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -54,6 +54,28 @@ def test_4d():
     assert (output.asnumpy() == expect).all()


+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_4d_float64():
+    input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float64)
+    input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float64)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = BatchMatMulNet()
+    output = net(input_x, input_y)
+    expect = [[[[20, 23, 26, 29]],
+               [[200, 212, 224, 236]],
+               [[596, 617, 638, 659]],
+               [[1208, 1238, 1268, 1298]]],
+
+              [[[2036, 2075, 2114, 2153]],
+               [[3080, 3128, 3176, 3224]],
+               [[4340, 4397, 4454, 4511]],
+               [[5816, 5882, 5948, 6014]]]]
+    assert (output.asnumpy() == expect).all()
+
+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard

@@ -22,6 +22,15 @@ from mindspore import Tensor
 from mindspore.ops import operations as P
 from mindspore.ops.operations import _inner_ops as inner

+class MatMulNet(nn.Cell):
+    def __init__(self):
+        super(MatMulNet, self).__init__()
+        self.matmul = P.MatMul()
+
+    def construct(self, x, y):
+        return self.matmul(x, y)
+
+
 class MatMul_d(nn.Cell):
     def __init__(self):
         super(MatMul_d, self).__init__()
@@ -33,6 +42,7 @@ class MatMul_d(nn.Cell):
         y = self.test_dynamic(y)
         return self.matmul(x, y)

+
 @pytest.mark.level0
 @pytest.mark.platform_x86_gpu_training
 @pytest.mark.env_onecard
@@ -52,3 +62,18 @@ def test_MatMul_dynamic():
     output2 = net(Tensor(x2), Tensor(y2))
     expect2 = np.matmul(x2, y2)
     np.testing.assert_array_almost_equal(output2.asnumpy(), expect2)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_matmul_float64():
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = MatMulNet()
+
+    x = np.arange(102).reshape(34, 3).astype(np.float64)
+    y = np.arange(18).reshape(3, 6).astype(np.float64)
+    output = net(Tensor(x), Tensor(y))
+    expect = np.matmul(x, y)
+    np.testing.assert_array_almost_equal(output.asnumpy(), expect)