!372 GPU support for BatchMatMul kernel
Merge pull request !372 from chenweifeng/batchmatmul
Commit 378a7122a5

@@ -26,5 +26,13 @@ MS_REG_GPU_KERNEL_ONE(
   MatMul,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
   MatMulGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
+  BatchMatMul,
+  KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
+  MatMulGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
+  BatchMatMul,
+  KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
+  MatMulGpuKernel, half)
 }  // namespace kernel
 }  // namespace mindspore

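Note: the two registrations added above expose BatchMatMul on GPU for float32 and float16 by reusing the existing MatMulGpuKernel template. A minimal usage sketch (illustrative only, mirroring the pattern of the test file added below; the Net wrapper name is not part of this change):

import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P


class Net(nn.Cell):
    # Thin wrapper so the op runs through the GPU kernel in graph mode.
    def __init__(self):
        super(Net, self).__init__()
        self.batch_matmul = P.BatchMatMul()

    def construct(self, x, y):
        return self.batch_matmul(x, y)


context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
x = Tensor(np.ones((2, 4, 1, 3)), mstype.float16)  # float16 inputs hit the new half registration
y = Tensor(np.ones((2, 4, 3, 4)), mstype.float16)
print(Net()(x, y).shape)  # (2, 4, 1, 4)
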
@@ -38,7 +38,10 @@ class MatMulGpuKernel : public GpuKernel {
         transpose_x1_(CUBLAS_OP_N),
         transpose_x2_(CUBLAS_OP_N),
         handle_(nullptr),
-        cudaDataType_(CUDA_R_32F) {}
+        dtype_a_(CUDA_R_32F),
+        dtype_b_(CUDA_R_32F),
+        dtype_c_(CUDA_R_32F),
+        algo_(CUBLAS_GEMM_DEFAULT_TENSOR_OP) {}
   ~MatMulGpuKernel() = default;
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
   const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }

@@ -54,24 +57,25 @@ class MatMulGpuKernel : public GpuKernel {
 
     const float alpha = 1;
     const float beta = 0;
-    const int lda = (transpose_x2_ == CUBLAS_OP_T) ? SizeToInt(k_) : SizeToInt(n_);
-    const int ldb = (transpose_x1_ == CUBLAS_OP_T) ? SizeToInt(m_) : SizeToInt(k_);
+    const int lda = (transpose_x1_ == CUBLAS_OP_T) ? SizeToInt(m_) : SizeToInt(k_);
+    const int ldb = (transpose_x2_ == CUBLAS_OP_T) ? SizeToInt(k_) : SizeToInt(n_);
+    const int ldc = n_;
 
-    for (size_t i = 0; i < batch_; i++) {
-      auto input1_slice = input1_addr + i * m_ * k_;
-      auto input2_slice = input2_addr + i * k_ * n_;
-      auto output_slice = output_addr + i * m_ * n_;
-
-      CHECK_CUBLAS_RET_WITH_EXCEPT(cublasSgemmEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_),
-                                                 SizeToInt(k_), &alpha, input2_slice, cudaDataType_, lda, input1_slice,
-                                                 cudaDataType_, ldb, &beta, output_slice, cudaDataType_, SizeToInt(n_)),
-                                   "cublasSgemm Call Fail");
-    }
+    auto stride_a = SizeToInt(m_ * k_);
+    auto stride_b = SizeToInt(k_ * n_);
+    auto stride_c = SizeToInt(m_ * n_);
+    CHECK_CUBLAS_RET_WITH_EXCEPT(
+      cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_),
+                                 &alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a,
+                                 &beta, output_addr, dtype_c_, ldc, stride_c, batch_, dtype_c_, algo_),
+      "cublasSgemm Call Fail");
     return true;
   }
   bool Init(const CNodePtr &kernel_node) override {
     handle_ = device::gpu::GPUDeviceManager::GetInstance().GetCublasHandle();
-    cudaDataType_ = kCudaDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))];
+    dtype_a_ = kCudaDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 0))];
+    dtype_b_ = kCudaDtypeMap[TypeIdLabel(AnfAlgo::GetInputDeviceDataType(kernel_node, 1))];
+    dtype_c_ = kCudaDtypeMap[TypeIdLabel(AnfAlgo::GetOutputDeviceDataType(kernel_node, 0))];
     auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
     auto dims = output_shape.size();
     if (dims < 2) {

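Note on the operand order in the cublasGemmStridedBatchedEx call above: cuBLAS works on column-major matrices while MindSpore buffers are row-major, so the kernel passes input2 as the first operand, input1 as the second, and the sizes as (n_, m_, k_). In column-major terms that computes x2^T * x1^T = (x1 * x2)^T, and reading the output buffer back row-major yields x1 * x2; lda and ldb are accordingly the row-major row strides of x1 and x2. A small NumPy sketch of the same identity (illustrative only, not part of the change):

import numpy as np

m, k, n = 2, 3, 4
a = np.arange(m * k, dtype=np.float32).reshape(m, k)  # row-major "input1", m x k
b = np.arange(k * n, dtype=np.float32).reshape(k, n)  # row-major "input2", k x n

# A row-major (m, k) buffer read column-major is the (k, m) transpose, so a
# column-major GEMM over the swapped operands produces (a @ b) transposed,
# which reads back row-major as exactly a @ b.
c = (b.T @ a.T).T
assert np.array_equal(c, a @ b)
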
@@ -119,9 +123,12 @@ class MatMulGpuKernel : public GpuKernel {
   cublasOperation_t transpose_x1_;
   cublasOperation_t transpose_x2_;
 
   cublasHandle_t handle_;
-  cudaDataType_t cudaDataType_;
+  cudaDataType_t dtype_a_;
+  cudaDataType_t dtype_b_;
+  cudaDataType_t dtype_c_;
+  cublasGemmAlgo_t algo_;
 
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;

@@ -0,0 +1,120 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import pytest
import numpy as np
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common.api import ms_function
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
import mindspore.nn as nn
import mindspore.context as context
from mindspore.common import dtype as mstype


class BatchMatMulNet(nn.Cell):
    def __init__(self, transpose_a=False, transpose_b=False):
        super(BatchMatMulNet, self).__init__()
        self.batch_matmul = P.BatchMatMul(transpose_a, transpose_b)

    def construct(self, x, y):
        return self.batch_matmul(x, y)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_4D():
    input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float32)
    input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float32)

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    net = BatchMatMulNet()
    output = net(input_x, input_y)
    expect = [[[[  20,   23,   26,   29]],
               [[ 200,  212,  224,  236]],
               [[ 596,  617,  638,  659]],
               [[1208, 1238, 1268, 1298]]],

              [[[2036, 2075, 2114, 2153]],
               [[3080, 3128, 3176, 3224]],
               [[4340, 4397, 4454, 4511]],
               [[5816, 5882, 5948, 6014]]]]
    assert (output.asnumpy() == expect).all()

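The hard-coded expect array above can be reproduced with NumPy, since np.matmul broadcasts over the two leading batch axes in the same way as P.BatchMatMul (a quick sanity check, not part of the test file):

import numpy as np

x = np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3).astype(np.float32)
y = np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4).astype(np.float32)
print(np.matmul(x, y))  # matches the expect array of test_4D
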
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_4D_transpose_a():
    input_x = Tensor(np.arange(2 * 4 * 3 * 1).reshape(2, 4, 3, 1), mstype.float32)
    input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float32)

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    net = BatchMatMulNet(transpose_a=True)
    output = net(input_x, input_y)
    expect = [[[[  20,   23,   26,   29]],
               [[ 200,  212,  224,  236]],
               [[ 596,  617,  638,  659]],
               [[1208, 1238, 1268, 1298]]],

              [[[2036, 2075, 2114, 2153]],
               [[3080, 3128, 3176, 3224]],
               [[4340, 4397, 4454, 4511]],
               [[5816, 5882, 5948, 6014]]]]
    assert (output.asnumpy() == expect).all()

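With transpose_a=True the kernel transposes the two trailing axes of x before multiplying; because the (2, 4, 3, 1) arange input transposes into exactly the (2, 4, 1, 3) input used by test_4D, the expected output is identical. A NumPy sketch of that equivalence (illustrative only):

import numpy as np

x = np.arange(2 * 4 * 3 * 1).reshape(2, 4, 3, 1).astype(np.float32)
y = np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4).astype(np.float32)
# swapaxes(-1, -2) plays the role of transpose_a; the result equals test_4D's expect.
print(np.matmul(x.swapaxes(-1, -2), y))
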
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_4D_transpose_b():
    input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float32)
    input_y = Tensor(np.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3), mstype.float32)

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    net = BatchMatMulNet(transpose_b=True)
    output = net(input_x, input_y)
    expect = [[[[   5,   14,   23,   32]],
               [[ 158,  194,  230,  266]],
               [[ 527,  590,  653,  716]],
               [[1112, 1202, 1292, 1382]]],

              [[[1913, 2030, 2147, 2264]],
               [[2930, 3074, 3218, 3362]],
               [[4163, 4334, 4505, 4676]],
               [[5612, 5810, 6008, 6206]]]]
    assert (output.asnumpy() == expect).all()


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_4D_transpose_ab():
    input_x = Tensor(np.arange(2 * 4 * 3 * 1).reshape(2, 4, 3, 1), mstype.float32)
    input_y = Tensor(np.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3), mstype.float32)

    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    net = BatchMatMulNet(transpose_a=True, transpose_b=True)
    output = net(input_x, input_y)
    expect = [[[[   5,   14,   23,   32]],
               [[ 158,  194,  230,  266]],
               [[ 527,  590,  653,  716]],
               [[1112, 1202, 1292, 1382]]],

              [[[1913, 2030, 2147, 2264]],
               [[2930, 3074, 3218, 3362]],
               [[4163, 4334, 4505, 4676]],
               [[5612, 5810, 6008, 6206]]]]
    assert (output.asnumpy() == expect).all()