matmul support fp16

wilfChen committed 2020-05-11 16:09:32 +08:00
parent dd2062bf8d
commit b56572bb89
2 changed files with 30 additions and 4 deletions


@@ -67,7 +67,7 @@ class MatMulGpuKernel : public GpuKernel {
     CHECK_CUBLAS_RET_WITH_EXCEPT(
       cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_),
                                  &alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a,
-                                 &beta, output_addr, dtype_c_, ldc, stride_c, batch_, dtype_c_, algo_),
+                                 &beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_),
       "cublasSgemm Call Fail");
     return true;
   }
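The change above pins the computeType argument of cublasGemmStridedBatchedEx to CUDA_R_32F instead of dtype_c_, so half-precision inputs are accumulated in fp32 rather than fp16 (for fp16 tensors, dtype_c_ would presumably have selected CUDA_R_16F accumulation). A minimal numpy sketch of the numeric behaviour this selects; illustrative only, not MindSpore code:

    import numpy as np

    # fp16 operands, fp32 accumulation, fp16 result -- mirrors a batched GEMM
    # whose A/B/C types are CUDA_R_16F while computeType is CUDA_R_32F.
    a = np.random.rand(2, 4, 1, 3).astype(np.float16)
    b = np.random.rand(2, 4, 3, 4).astype(np.float16)
    out = np.matmul(a.astype(np.float32), b.astype(np.float32)).astype(np.float16)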


@@ -60,7 +60,7 @@ def test_4D():
def test_4D_transpose_a():
    input_x = Tensor(np.arange(2*4*3*1).reshape(2,4,3,1), mstype.float32)
    input_y = Tensor(np.arange(2*4*3*4).reshape(2,4,3,4), mstype.float32)
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    net = BatchMatMulNet(transpose_a=True)
    output = net(input_x, input_y)
@@ -82,7 +82,7 @@ def test_4D_transpose_a():
def test_4D_transpose_b():
    input_x = Tensor(np.arange(2*4*1*3).reshape(2,4,1,3), mstype.float32)
    input_y = Tensor(np.arange(2*4*4*3).reshape(2,4,4,3), mstype.float32)
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    net = BatchMatMulNet(transpose_b=True)
    output = net(input_x, input_y)
@@ -104,7 +104,7 @@ def test_4D_transpose_b():
def test_4D_transpose_ab():
    input_x = Tensor(np.arange(2*4*3*1).reshape(2,4,3,1), mstype.float32)
    input_y = Tensor(np.arange(2*4*4*3).reshape(2,4,4,3), mstype.float32)
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
    net = BatchMatMulNet(transpose_a=True, transpose_b=True)
    output = net(input_x, input_y)
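For reference, transpose_a and transpose_b transpose the two trailing axes of the corresponding operand before the batched product. A numpy equivalent of the transpose_ab case above (illustrative sketch):

    import numpy as np

    x = np.arange(2*4*3*1).reshape(2, 4, 3, 1).astype(np.float32)
    y = np.arange(2*4*4*3).reshape(2, 4, 4, 3).astype(np.float32)
    # transpose_a=True, transpose_b=True: transpose the last two axes of each input
    out = np.matmul(np.swapaxes(x, -1, -2), np.swapaxes(y, -1, -2))
    print(out.shape)  # (2, 4, 1, 4)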
@@ -118,3 +118,29 @@ def test_4D_transpose_ab():
                [[4163, 4334, 4505, 4676]],
                [[5612, 5810, 6008, 6206]]]]
     assert (output.asnumpy() == expect).all()
+
+
+class BatchMatMulNet(nn.Cell):
+    def __init__(self, transpose_a=False, transpose_b=False):
+        super(BatchMatMulNet, self).__init__()
+        self.batch_matmul = P.BatchMatMul(transpose_a, transpose_b)
+
+    def construct(self, x, y):
+        return self.batch_matmul(x, y)
+
+
+def test_4D_fp16():
+    input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float16)
+    input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float16)
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = BatchMatMulNet()
+    output = net(input_x, input_y)
+    expect = [[[[  20,   23,   26,   29]],
+               [[ 200,  212,  224,  236]],
+               [[ 596,  617,  638,  659]],
+               [[1208, 1238, 1268, 1298]]],
+              [[[2036, 2075, 2114, 2153]],
+               [[3080, 3128, 3176, 3224]],
+               [[4340, 4397, 4454, 4511]],
+               [[5816, 5882, 5948, 6014]]]]
+    assert (output.asnumpy() == expect).all()
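The diff does not show the test file's import header; the following is a sketch of the imports these tests presumably rely on (module paths follow the MindSpore layout of this period and are assumptions, not part of the diff):

    # Presumed header of the test file -- not shown in this diff.
    import numpy as np
    import pytest

    import mindspore.context as context
    import mindspore.nn as nn
    from mindspore import Tensor
    from mindspore.common import dtype as mstype
    from mindspore.ops import operations as P

With those imports in place, the new case can be run on a GPU machine with, for example, pytest -k test_4D_fp16.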