diff --git a/mindspore/ccsrc/kernel/gpu/math/matmul_gpu_kernel.cc b/mindspore/ccsrc/kernel/gpu/math/matmul_gpu_kernel.cc
index a96cede94e7..808d5998533 100644
--- a/mindspore/ccsrc/kernel/gpu/math/matmul_gpu_kernel.cc
+++ b/mindspore/ccsrc/kernel/gpu/math/matmul_gpu_kernel.cc
@@ -26,5 +26,13 @@ MS_REG_GPU_KERNEL_ONE(
   MatMul,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   MatMulGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  BatchMatMul,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  MatMulGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  BatchMatMul,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  MatMulGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/kernel/gpu/math/matmul_gpu_kernel.h b/mindspore/ccsrc/kernel/gpu/math/matmul_gpu_kernel.h
index 36f4272c683..e2c0a965102 100644
--- a/mindspore/ccsrc/kernel/gpu/math/matmul_gpu_kernel.h
+++ b/mindspore/ccsrc/kernel/gpu/math/matmul_gpu_kernel.h
@@ -38,7 +38,10 @@ class MatMulGpuKernel : public GpuKernel {
         transpose_x1_(CUBLAS_OP_N),
         transpose_x2_(CUBLAS_OP_N),
         handle_(nullptr),
-        cudaDataType_(CUDA_R_32F) {}
+        dtype_a_(CUDA_R_32F),
+        dtype_b_(CUDA_R_32F),
+        dtype_c_(CUDA_R_32F),
+        algo_(CUBLAS_GEMM_DEFAULT_TENSOR_OP) {}
   ~MatMulGpuKernel() = default;
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
   const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
@@ -54,24 +57,25 @@
     const float alpha = 1;
     const float beta = 0;
-    const int lda = (transpose_x2_ == CUBLAS_OP_T) ? SizeToInt(k_) : SizeToInt(n_);
-    const int ldb = (transpose_x1_ == CUBLAS_OP_T) ? SizeToInt(m_) : SizeToInt(k_);
+    const int lda = (transpose_x1_ == CUBLAS_OP_T) ? SizeToInt(m_) : SizeToInt(k_);
+    const int ldb = (transpose_x2_ == CUBLAS_OP_T) ? SizeToInt(k_) : SizeToInt(n_);
+    const int ldc = n_;
 
-    for (size_t i = 0; i < batch_; i++) {
-      auto input1_slice = input1_addr + i * m_ * k_;
-      auto input2_slice = input2_addr + i * k_ * n_;
-      auto output_slice = output_addr + i * m_ * n_;
-
-      CHECK_CUBLAS_RET_WITH_EXCEPT(cublasSgemmEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_),
-                                                 SizeToInt(k_), &alpha, input2_slice, cudaDataType_, lda, input1_slice,
-                                                 cudaDataType_, ldb, &beta, output_slice, cudaDataType_, SizeToInt(n_)),
-                                   "cublasSgemm Call Fail");
-    }
+    auto stride_a = SizeToInt(m_ * k_);
+    auto stride_b = SizeToInt(k_ * n_);
+    auto stride_c = SizeToInt(m_ * n_);
+    CHECK_CUBLAS_RET_WITH_EXCEPT(
+      cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_),
+                                 &alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a,
+                                 &beta, output_addr, dtype_c_, ldc, stride_c, batch_, dtype_c_, algo_),
+      "cublasSgemm Call Fail");
     return true;
   }
 
   bool Init(const CNodePtr &kernel_node) override {
     handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
-    cudaDataType_ = kCudaDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))];
+    dtype_a_ = kCudaDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))];
+    dtype_b_ = kCudaDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1))];
+    dtype_c_ = kCudaDtypeMap[TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0))];
     auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
     auto dims = output_shape.size();
     if (dims < 2) {
@@ -119,9 +123,12 @@
   cublasOperation_t transpose_x1_;
   cublasOperation_t transpose_x2_;
-
   cublasHandle_t handle_;
-  cudaDataType_t cudaDataType_;
+  cudaDataType_t dtype_a_;
+  cudaDataType_t dtype_b_;
+  cudaDataType_t dtype_c_;
+  cublasGemmAlgo_t algo_;
+
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
 
diff --git a/tests/st/ops/gpu/test_batch_matmul.py b/tests/st/ops/gpu/test_batch_matmul.py
new file mode 100644
index 00000000000..4e357095c57
--- /dev/null
+++ b/tests/st/ops/gpu/test_batch_matmul.py
@@ -0,0 +1,120 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+import pytest
+import numpy as np
+from mindspore import Tensor
+from mindspore.ops import operations as P
+from mindspore.common.api import ms_function
+from mindspore.common.initializer import initializer
+from mindspore.common.parameter import Parameter
+import mindspore.nn as nn
+import mindspore.context as context
+from mindspore.common import dtype as mstype
+
+class BatchMatMulNet(nn.Cell):
+    def __init__(self, transpose_a=False, transpose_b=False):
+        super(BatchMatMulNet, self).__init__()
+        self.batch_matmul = P.BatchMatMul(transpose_a, transpose_b)
+
+    def construct(self, x, y):
+        return self.batch_matmul(x, y)
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_4D():
+    input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float32)
+    input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float32)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = BatchMatMulNet()
+    output = net(input_x, input_y)
+    expect = [[[[ 20, 23, 26, 29]],
+               [[ 200, 212, 224, 236]],
+               [[ 596, 617, 638, 659]],
+               [[1208, 1238, 1268, 1298]]],
+
+              [[[2036, 2075, 2114, 2153]],
+               [[3080, 3128, 3176, 3224]],
+               [[4340, 4397, 4454, 4511]],
+               [[5816, 5882, 5948, 6014]]]]
+    assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_4D_transpose_a():
+    input_x = Tensor(np.arange(2 * 4 * 3 * 1).reshape(2, 4, 3, 1), mstype.float32)
+    input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float32)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = BatchMatMulNet(transpose_a=True)
+    output = net(input_x, input_y)
+    expect = [[[[ 20, 23, 26, 29]],
+               [[ 200, 212, 224, 236]],
+               [[ 596, 617, 638, 659]],
+               [[1208, 1238, 1268, 1298]]],
+
+              [[[2036, 2075, 2114, 2153]],
+               [[3080, 3128, 3176, 3224]],
+               [[4340, 4397, 4454, 4511]],
+               [[5816, 5882, 5948, 6014]]]]
+    assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_4D_transpose_b():
+    input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float32)
+    input_y = Tensor(np.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3), mstype.float32)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = BatchMatMulNet(transpose_b=True)
+    output = net(input_x, input_y)
+    expect = [[[[ 5, 14, 23, 32]],
+               [[ 158, 194, 230, 266]],
+               [[ 527, 590, 653, 716]],
+               [[1112, 1202, 1292, 1382]]],
+
+              [[[1913, 2030, 2147, 2264]],
+               [[2930, 3074, 3218, 3362]],
+               [[4163, 4334, 4505, 4676]],
+               [[5612, 5810, 6008, 6206]]]]
+    assert (output.asnumpy() == expect).all()
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_4D_transpose_ab():
+    input_x = Tensor(np.arange(2 * 4 * 3 * 1).reshape(2, 4, 3, 1), mstype.float32)
+    input_y = Tensor(np.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3), mstype.float32)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = BatchMatMulNet(transpose_a=True, transpose_b=True)
+    output = net(input_x, input_y)
+    expect = [[[[ 5, 14, 23, 32]],
+               [[ 158, 194, 230, 266]],
+               [[ 527, 590, 653, 716]],
+               [[1112, 1202, 1292, 1382]]],
+
+              [[[1913, 2030, 2147, 2264]],
+               [[2930, 3074, 3218, 3362]],
+               [[4163, 4334, 4505, 4676]],
+               [[5612, 5810, 6008, 6206]]]]
+    assert (output.asnumpy() == expect).all()
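
For reference, the expect arrays in the new tests follow ordinary batched-matmul semantics, so
they can be reproduced with NumPy. A minimal sketch (not part of the patch; it assumes only
standard np.matmul broadcasting over the leading batch dimensions, and the variable names are
illustrative):

    import numpy as np

    # Same inputs as test_4D: batch dims (2, 4), multiplying (1, 3) by (3, 4) matrices.
    x = np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3).astype(np.float32)
    y = np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4).astype(np.float32)
    expect = np.matmul(x, y)  # shape (2, 4, 1, 4); first row is [20, 23, 26, 29]

    # transpose_a / transpose_b amount to swapping the last two axes before multiplying,
    # e.g. for test_4D_transpose_a:
    xt = np.arange(2 * 4 * 3 * 1).reshape(2, 4, 3, 1).astype(np.float32)
    expect_ta = np.matmul(xt.swapaxes(-1, -2), y)  # same values as expect above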