diff --git a/mindspore/ops/composite/math_ops.py b/mindspore/ops/composite/math_ops.py index 0d57c06a10e..64165d0a541 100644 --- a/mindspore/ops/composite/math_ops.py +++ b/mindspore/ops/composite/math_ops.py @@ -273,6 +273,13 @@ def _check_invalid_input(x1_shape, x2_shape): + f'while x1 is ({len(x1_shape)}) and x2 is ({len(x2_shape)}).') +@constexpr +def _get_transpose_shape(x2_shape): + x2_shape_range = tuple(range(len(x2_shape))) + x2_shape_transpose = x2_shape_range[-2:-1] + x2_shape_range[:-2] + x2_shape_range[-1:] + return x2_shape_transpose + + def dot(x1, x2): """ Computation a dot product between samples in two tensors. @@ -304,8 +311,7 @@ def dot(x1, x2): _check_invalid_input(x1_shape, x2_shape) if len(x1_shape) > 2 or len(x2_shape) > 2: - x2_shape_range = range(len(x2_shape)) - x2_shape_transpose = x2_shape_range[-2:-1] + x2_shape_range[:-2] + x2_shape_range[-1:] + x2_shape_transpose = _get_transpose_shape(x2_shape) x2_transpose = transpose_op(x2, x2_shape_transpose) x1_reshape = reshape_op(x1, (-1, x1_shape[-1])) x2_reshape = reshape_op(x2_transpose, (x2_shape[-2], -1)) diff --git a/tests/st/ops/cpu/test_dot_op.py b/tests/st/ops/cpu/test_dot_op.py index a3c68889a7d..1efb5a8200c 100644 --- a/tests/st/ops/cpu/test_dot_op.py +++ b/tests/st/ops/cpu/test_dot_op.py @@ -207,3 +207,20 @@ def test_dot_010(): [[3., 3.]]]).astype(np.float32) assert (ms_result_np.asnumpy() == expect_result).all() + + +@pytest.mark.level0 +@pytest.mark.platform_x86_cpu +@pytest.mark.env_onecard +def test_dot_011(): + # for document + context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU") + input_x1 = Tensor(np.array(np.ones(shape=[2, 3])).astype(np.float32)) + input_x2 = Tensor(np.array(np.ones(shape=[1, 3, 2])).astype(np.float32)) + + network = NetDot() + ms_result_np = network(input_x1, input_x2) + expect_result = np.array([[[3., 3.]], + [[3., 3.]]]).astype(np.float32) + + assert (ms_result_np.asnumpy() == expect_result).all()