forked from mindspore-Ecosystem/mindspore

matmul support fp16

commit: b56572bb89
parent: dd2062bf8d
@@ -67,7 +67,7 @@ class MatMulGpuKernel : public GpuKernel {
     CHECK_CUBLAS_RET_WITH_EXCEPT(
       cublasGemmStridedBatchedEx(handle_, transpose_x2_, transpose_x1_, SizeToInt(n_), SizeToInt(m_), SizeToInt(k_),
                                  &alpha, input2_addr, dtype_b_, ldb, stride_b, input1_addr, dtype_a_, lda, stride_a,
-                                 &beta, output_addr, dtype_c_, ldc, stride_c, batch_, dtype_c_, algo_),
+                                 &beta, output_addr, dtype_c_, ldc, stride_c, batch_, CUDA_R_32F, algo_),
       "cublasSgemm Call Fail");
     return true;
   }
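The one-line kernel change above swaps the computeType argument of cublasGemmStridedBatchedEx from dtype_c_ to CUDA_R_32F, so half-precision operands are accumulated in single precision rather than fp16. A minimal NumPy sketch of that numeric contract (an illustration of the intended behavior, not the kernel code itself; the shapes mirror the new test added below):

import numpy as np

# fp16 operands, shaped like the inputs of test_4D_fp16 below
a = np.arange(2 * 4 * 1 * 3, dtype=np.float16).reshape(2, 4, 1, 3)
b = np.arange(2 * 4 * 3 * 4, dtype=np.float16).reshape(2, 4, 3, 4)

# Accumulate in fp32 (what a CUDA_R_32F compute type buys), then cast
# the product back to the fp16 output dtype.
out = (a.astype(np.float32) @ b.astype(np.float32)).astype(np.float16)
print(out.dtype, out.shape)  # float16 (2, 4, 1, 4)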
@@ -60,7 +60,7 @@ def test_4D():
 def test_4D_transpose_a():
     input_x = Tensor(np.arange(2*4*3*1).reshape(2,4,3,1), mstype.float32)
     input_y = Tensor(np.arange(2*4*3*4).reshape(2,4,3,4), mstype.float32)
 
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
     net = BatchMatMulNet(transpose_a=True)
     output = net(input_x, input_y)
@@ -82,7 +82,7 @@ def test_4D_transpose_a():
 def test_4D_transpose_b():
     input_x = Tensor(np.arange(2*4*1*3).reshape(2,4,1,3), mstype.float32)
     input_y = Tensor(np.arange(2*4*4*3).reshape(2,4,4,3), mstype.float32)
 
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
     net = BatchMatMulNet(transpose_b=True)
     output = net(input_x, input_y)
@@ -104,7 +104,7 @@ def test_4D_transpose_b():
 def test_4D_transpose_ab():
     input_x = Tensor(np.arange(2*4*3*1).reshape(2,4,3,1), mstype.float32)
     input_y = Tensor(np.arange(2*4*4*3).reshape(2,4,4,3), mstype.float32)
 
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
     net = BatchMatMulNet(transpose_a=True, transpose_b=True)
     output = net(input_x, input_y)
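The three transpose variants above differ only in which operand has its two trailing axes swapped before the batched product. A quick NumPy cross-check for the transpose_a=True, transpose_b=True case (assuming, as the test shapes imply, that the flags affect only the last two dimensions):

import numpy as np

x = np.arange(2 * 4 * 3 * 1).reshape(2, 4, 3, 1).astype(np.float32)
y = np.arange(2 * 4 * 4 * 3).reshape(2, 4, 4, 3).astype(np.float32)

# Swap the trailing axes of both operands, then batch-multiply:
# (2,4,1,3) @ (2,4,3,4) -> (2,4,1,4)
ref = np.swapaxes(x, -1, -2) @ np.swapaxes(y, -1, -2)
print(ref.shape)  # (2, 4, 1, 4)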
@@ -118,3 +118,29 @@ def test_4D_transpose_ab():
                        [[4163, 4334, 4505, 4676]],
                        [[5612, 5810, 6008, 6206]]]]
     assert (output.asnumpy() == expect).all()
+
+class BatchMatMulNet(nn.Cell):
+    def __init__(self, transpose_a=False, transpose_b=False):
+        super(BatchMatMulNet, self).__init__()
+        self.batch_matmul = P.BatchMatMul(transpose_a, transpose_b)
+
+    def construct(self, x, y):
+        return self.batch_matmul(x, y)
+
+def test_4D_fp16():
+    input_x = Tensor(np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3), mstype.float16)
+    input_y = Tensor(np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4), mstype.float16)
+
+    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
+    net = BatchMatMulNet()
+    output = net(input_x, input_y)
+    expect = [[[[  20,   23,   26,   29]],
+               [[ 200,  212,  224,  236]],
+               [[ 596,  617,  638,  659]],
+               [[1208, 1238, 1268, 1298]]],
+
+              [[[2036, 2075, 2114, 2153]],
+               [[3080, 3128, 3176, 3224]],
+               [[4340, 4397, 4454, 4511]],
+               [[5816, 5882, 5948, 6014]]]]
+    assert (output.asnumpy() == expect).all()
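For reference, the expect tensor in the new test_4D_fp16 can be reproduced exactly with an fp32 matmul over the same integer-valued inputs (a sanity check on the constants, not part of the commit):

import numpy as np

x = np.arange(2 * 4 * 1 * 3).reshape(2, 4, 1, 3).astype(np.float32)
y = np.arange(2 * 4 * 3 * 4).reshape(2, 4, 3, 4).astype(np.float32)

expect = x @ y
print(expect[0, 0])  # [[20. 23. 26. 29.]]
print(expect[1, 3])  # [[5816. 5882. 5948. 6014.]]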