!11095 Modify dot opt to support pynative mode

From: @xutianming1985
Reviewed-by: @wuxuejian,@liangchenghui
Signed-off-by: @liangchenghui
This commit is contained in:
mindspore-ci-bot 2021-01-11 09:21:42 +08:00 committed by Gitee
commit 256a1ef5c8
2 changed files with 25 additions and 2 deletions

View File

@ -273,6 +273,13 @@ def _check_invalid_input(x1_shape, x2_shape):
+ f'while x1 is ({len(x1_shape)}) and x2 is ({len(x2_shape)}).')
@constexpr
def _get_transpose_shape(x2_shape):
x2_shape_range = tuple(range(len(x2_shape)))
x2_shape_transpose = x2_shape_range[-2:-1] + x2_shape_range[:-2] + x2_shape_range[-1:]
return x2_shape_transpose
def dot(x1, x2):
"""
Computation a dot product between samples in two tensors.
@ -304,8 +311,7 @@ def dot(x1, x2):
_check_invalid_input(x1_shape, x2_shape)
if len(x1_shape) > 2 or len(x2_shape) > 2:
x2_shape_range = range(len(x2_shape))
x2_shape_transpose = x2_shape_range[-2:-1] + x2_shape_range[:-2] + x2_shape_range[-1:]
x2_shape_transpose = _get_transpose_shape(x2_shape)
x2_transpose = transpose_op(x2, x2_shape_transpose)
x1_reshape = reshape_op(x1, (-1, x1_shape[-1]))
x2_reshape = reshape_op(x2_transpose, (x2_shape[-2], -1))

View File

@ -207,3 +207,20 @@ def test_dot_010():
[[3., 3.]]]).astype(np.float32)
assert (ms_result_np.asnumpy() == expect_result).all()
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_dot_011():
# for document
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
input_x1 = Tensor(np.array(np.ones(shape=[2, 3])).astype(np.float32))
input_x2 = Tensor(np.array(np.ones(shape=[1, 3, 2])).astype(np.float32))
network = NetDot()
ms_result_np = network(input_x1, input_x2)
expect_result = np.array([[[3., 3.]],
[[3., 3.]]]).astype(np.float32)
assert (ms_result_np.asnumpy() == expect_result).all()