forked from mindspore-Ecosystem/mindspore
!11095 Modify dot opt to support pynative mode
From: @xutianming1985 Reviewed-by: @wuxuejian,@liangchenghui Signed-off-by: @liangchenghui
This commit is contained in:
commit
256a1ef5c8
|
@ -273,6 +273,13 @@ def _check_invalid_input(x1_shape, x2_shape):
|
|||
+ f'while x1 is ({len(x1_shape)}) and x2 is ({len(x2_shape)}).')
|
||||
|
||||
|
||||
@constexpr
|
||||
def _get_transpose_shape(x2_shape):
|
||||
x2_shape_range = tuple(range(len(x2_shape)))
|
||||
x2_shape_transpose = x2_shape_range[-2:-1] + x2_shape_range[:-2] + x2_shape_range[-1:]
|
||||
return x2_shape_transpose
|
||||
|
||||
|
||||
def dot(x1, x2):
|
||||
"""
|
||||
Computation a dot product between samples in two tensors.
|
||||
|
@ -304,8 +311,7 @@ def dot(x1, x2):
|
|||
_check_invalid_input(x1_shape, x2_shape)
|
||||
|
||||
if len(x1_shape) > 2 or len(x2_shape) > 2:
|
||||
x2_shape_range = range(len(x2_shape))
|
||||
x2_shape_transpose = x2_shape_range[-2:-1] + x2_shape_range[:-2] + x2_shape_range[-1:]
|
||||
x2_shape_transpose = _get_transpose_shape(x2_shape)
|
||||
x2_transpose = transpose_op(x2, x2_shape_transpose)
|
||||
x1_reshape = reshape_op(x1, (-1, x1_shape[-1]))
|
||||
x2_reshape = reshape_op(x2_transpose, (x2_shape[-2], -1))
|
||||
|
|
|
@ -207,3 +207,20 @@ def test_dot_010():
|
|||
[[3., 3.]]]).astype(np.float32)
|
||||
|
||||
assert (ms_result_np.asnumpy() == expect_result).all()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_dot_011():
|
||||
# for document
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
input_x1 = Tensor(np.array(np.ones(shape=[2, 3])).astype(np.float32))
|
||||
input_x2 = Tensor(np.array(np.ones(shape=[1, 3, 2])).astype(np.float32))
|
||||
|
||||
network = NetDot()
|
||||
ms_result_np = network(input_x1, input_x2)
|
||||
expect_result = np.array([[[3., 3.]],
|
||||
[[3., 3.]]]).astype(np.float32)
|
||||
|
||||
assert (ms_result_np.asnumpy() == expect_result).all()
|
||||
|
|
Loading…
Reference in New Issue