diff --git a/mindspore/nn/layer/math.py b/mindspore/nn/layer/math.py
index 3f3c4fa2e3..ff59fe741b 100644
--- a/mindspore/nn/layer/math.py
+++ b/mindspore/nn/layer/math.py
@@ -25,7 +25,7 @@ from ..._checkparam import Validator as validator
 from ..._checkparam import Rel
 
 
-__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma']
+__all__ = ['ReduceLogSumExp', 'Range', 'LinSpace', 'LGamma', 'MatMul']
 
 
 class ReduceLogSumExp(Cell):
@@ -302,3 +302,138 @@ class LGamma(Cell):
         result = self.select(need_to_reflect, reflection, log_y)
 
         return self.select(self.isfinite(input_x), result, infinity)
+
+
+@constexpr
+def get_broadcast_matmul_shape(x_shape, y_shape):
+    """
+    Compute the shapes that `x1` and `x2` are broadcast to before the matrix multiplication.
+
+    Only the batch dimensions (all but the last two) are broadcast, the matrix dimensions are kept.
+
+    Args:
+        x_shape (tuple[int]): The shape of `x1`.
+        y_shape (tuple[int]): The shape of `x2`.
+
+    Returns:
+        Tuple of 2 tuples, the broadcast shape of `x1` and the broadcast shape of `x2`.
+
+    Raises:
+        ValueError: If the rank of `x_shape` or `y_shape` is less than 2.
+        ValueError: If the batch dimensions of `x_shape` and `y_shape` can not be broadcast.
+    """
+    if (len(x_shape) < 2) or (len(y_shape) < 2):
+        raise ValueError('For matmul, rank of x1 and x2 should be equal to or greater than 2, '
+                         f'but got {x_shape} and {y_shape}.')
+    if x_shape[:-2] == y_shape[:-2]:
+        return x_shape, y_shape
+    x_len = len(x_shape)
+    y_len = len(y_shape)
+    length = min(x_len, y_len)
+    # Batch dimensions that both shapes have are broadcast pairwise.
+    broadcast_shape_back = []
+    for x_dim, y_dim in zip(x_shape[-length:-2], y_shape[-length:-2]):
+        if x_dim == 1:
+            broadcast_shape_back.append(y_dim)
+        elif y_dim == 1 or x_dim == y_dim:
+            broadcast_shape_back.append(x_dim)
+        else:
+            raise ValueError(f"For MatMul, the x1_shape {x_shape} and x2_shape {y_shape} can not broadcast.")
+
+    # Leading batch dimensions that only the higher-rank shape has are taken over unchanged.
+    broadcast_shape_front = y_shape[0: y_len - length] if length == x_len else x_shape[0: x_len - length]
+    x_broadcast_shape = broadcast_shape_front + tuple(broadcast_shape_back) + x_shape[-2:]
+    y_broadcast_shape = broadcast_shape_front + tuple(broadcast_shape_back) + y_shape[-2:]
+    return x_broadcast_shape, y_broadcast_shape
+
+
+@constexpr
+def check_col_row_equal(x1_shape, x2_shape, transpose_x1, transpose_x2):
+    """
+    Check that the column of `x1` equals the row of `x2`, taking the transpose flags into account.
+
+    Args:
+        x1_shape (tuple[int]): The shape of `x1`, its rank is expected to be at least 2 (not checked here).
+        x2_shape (tuple[int]): The shape of `x2`, its rank is expected to be at least 2 (not checked here).
+        transpose_x1 (bool): Whether `x1` is transposed before multiplication.
+        transpose_x2 (bool): Whether `x2` is transposed before multiplication.
+
+    Raises:
+        ValueError: If the column of `x1` is not equal to the row of `x2`.
+    """
+    x1_last = x1_shape[-2:]
+    x2_last = x2_shape[-2:]
+    # A bool indexes a tuple as 0 or 1, so the flags pick the multiplied dimension directly.
+    x1_col = x1_last[not transpose_x1]  # x1_col = x1_last[1] if (not transpose_x1) else x1_last[0]
+    x2_row = x2_last[transpose_x2]  # x2_row = x2_last[0] if (not transpose_x2) else x2_last[1]
+    if x1_col != x2_row:
+        raise ValueError('The column of matrix dimensions of x1 should be equal to '
+                         f'the row of matrix dimensions of x2, but got {x1_col} and {x2_row}.')
+
+
+class MatMul(Cell):
+    """
+    Multiplies matrix `x1` by matrix `x2`.
+
+    The rank of input tensors must be not less than `2`. The non-matrix dimensions(batch) of inputs
+    will be broadcasted and must be broadcastable.
+
+    Args:
+        transpose_x1 (bool): If True, `x1` is transposed before multiplication. Default: False.
+        transpose_x2 (bool): If True, `x2` is transposed before multiplication. Default: False.
+
+    Inputs:
+        - **input_x1** (Tensor) - The first tensor to be multiplied. The shape of the tensor is :math:`(*A, N, C)`,
+          where :math:`*A` represents the batch size of `x1` which can be multidimensional.
+          If `transpose_x1` is True, its shape should be :math:`(*A, N, C)` after transposing.
+        - **input_x2** (Tensor) - The second tensor to be multiplied. The shape of the tensor is :math:`(*B, C, M)`,
+          where :math:`*B` represents the batch size of `x2` which can be multidimensional.
+          If `transpose_x2` is True, its shape should be :math:`(*B, C, M)` after transposing.
+
+    Outputs:
+        Tensor, the shape of the output tensor is :math:`(*L, N, M)`. :math:`*L` is the batch size after broadcasting.
+
+    Raises:
+        ValueError: If the rank of `input_x1` or `input_x2` is less than 2.
+        ValueError: If the batch dimensions of `input_x1` and `input_x2` can not be broadcast.
+        ValueError: If the column of `input_x1` is not equal to the row of `input_x2`.
+
+    Examples:
+        >>> net = nn.MatMul()
+        >>> input_x1 = Tensor(np.ones(shape=[3, 2, 3]), mindspore.float32)
+        >>> input_x2 = Tensor(np.ones(shape=[3, 4]), mindspore.float32)
+        >>> output = net(input_x1, input_x2)
+        >>> print(output.shape)
+        (3, 2, 4)
+    """
+
+    def __init__(self, transpose_x1=False, transpose_x2=False):
+        super(MatMul, self).__init__()
+
+        validator.check_value_type('transpose_x1', transpose_x1, [bool], self.cls_name)
+        validator.check_value_type('transpose_x2', transpose_x2, [bool], self.cls_name)
+        self.transpose_x1 = transpose_x1
+        self.transpose_x2 = transpose_x2
+        self.shape_op = P.Shape()
+        self.matmul_op = P.MatMul(self.transpose_x1, self.transpose_x2)
+        self.batch_matmul_op = P.BatchMatMul(self.transpose_x1, self.transpose_x2)
+
+    def construct(self, x1, x2):
+        x1_shape = self.shape_op(x1)
+        x2_shape = self.shape_op(x2)
+        # Compute the broadcast shapes first: this validates the ranks, so an input of rank less
+        # than 2 raises the intended ValueError instead of an IndexError in the column/row check.
+        x1_broadcast_shape, x2_broadcast_shape = get_broadcast_matmul_shape(x1_shape, x2_shape)
+        check_col_row_equal(x1_shape, x2_shape, self.transpose_x1, self.transpose_x2)
+
+        x1_broadcast_to = P.BroadcastTo(x1_broadcast_shape)
+        x2_broadcast_to = P.BroadcastTo(x2_broadcast_shape)
+        if x1_broadcast_shape != x1_shape:
+            x1 = x1_broadcast_to(x1)
+        if x2_broadcast_shape != x2_shape:
+            x2 = x2_broadcast_to(x2)
+        if len(x1_broadcast_shape) == 2:
+            matmul_broadcast = self.matmul_op(x1, x2)
+        else:
+            matmul_broadcast = self.batch_matmul_op(x1, x2)
+        return matmul_broadcast
diff --git a/tests/st/ops/ascend/test_nn_matmul.py b/tests/st/ops/ascend/test_nn_matmul.py
new file mode 100644
index 0000000000..280996d044
--- /dev/null
+++ b/tests/st/ops/ascend/test_nn_matmul.py
@@ -0,0 +1,56 @@
+import numpy as np
+
+import mindspore.context as context
+import mindspore.nn as nn
+from mindspore import Tensor
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
+
+
+class Net(nn.Cell):
+    def __init__(self, transpose_x1, transpose_x2):
+        super(Net, self).__init__()
+        self.matmul = nn.MatMul(transpose_x1, transpose_x2)
+
+    def construct(self, x1, x2):
+        return self.matmul(x1, x2)
+
+
+def test_x1_2D_x2_3D():
+    x1 = np.random.randn(16, 64).astype(np.float32)
+    x2 = np.random.randn(32, 64, 20).astype(np.float32)
+    transpose_x1 = False
+    transpose_x2 = False
+    net = Net(transpose_x1, transpose_x2)
+    output = net(Tensor(x1), Tensor(x2))
+    assert output.shape == (32, 16, 20)
+
+
+def test_x1_4D_x2_3D_transpose_x2_True():
+    x1 = np.random.randn(3, 2, 3, 4).astype(np.float32)
+    x2 = np.random.randn(1, 5, 4).astype(np.float32)
+    transpose_x1 = False
+    transpose_x2 = True
+    net = Net(transpose_x1, transpose_x2)
+    output = net(Tensor(x1), Tensor(x2))
+    assert output.shape == (3, 2, 3, 5)
+
+
+def test_x1_3D_transpose_x1_True_x2_2D():
+    x1 = np.random.randn(2, 3, 4).astype(np.float32)
+    x2 = np.random.randn(3, 4).astype(np.float32)
+    transpose_x1 = True
+    transpose_x2 = False
+    net = Net(transpose_x1, transpose_x2)
+    output = net(Tensor(x1), Tensor(x2))
+    assert output.shape == (2, 4, 4)
+
+
+def test_x1_3D_transpose_x1_True_x2_3D_transpose_x2_True():
+    x1 = np.random.randn(2, 5, 6).astype(np.float32)
+    x2 = np.random.randn(2, 4, 5).astype(np.float32)
+    transpose_x1 = True
+    transpose_x2 = True
+    net = Net(transpose_x1, transpose_x2)
+    output = net(Tensor(x1), Tensor(x2))
+    assert output.shape == (2, 6, 4)