!48287 transpose dynamic test case using set_inputs

Merge pull request !48287 from liubuyu/transpose
This commit is contained in:
i-robot 2023-02-02 12:59:36 +00:00 committed by Gitee
commit e78ab6f508
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 31 additions and 32 deletions

View File

@ -22,75 +22,74 @@ import mindspore.ops as ops
from mindspore import Tensor
class TransposeDynNet(nn.Cell):
def __init__(self, axis=0):
super(TransposeDynNet, self).__init__()
self.unique = ops.Unique()
self.gather = ops.Gather()
class Net(nn.Cell):
def __init__(self, perm_in):
super(Net, self).__init__()
self.transpose = ops.Transpose()
self.axis = axis
self.perm = perm_in
def construct(self, x, perm, indices):
unique_indices, _ = self.unique(indices)
input_x = self.gather(x, unique_indices, self.axis)
return self.transpose(input_x, perm)
def construct(self, input_):
x = self.transpose(input_, self.perm)
return x
def dyn_case():
perm = (1, 0, 2)
x = np.arange(2 * 2 * 4).reshape(2, 2, 4).astype(np.float32)
indices = np.array([0, 1, 0], dtype=np.int32)
expect = np.array([[[[0, 1, 2, 3],
[8, 9, 10, 11]],
[[4, 5, 6, 7],
[12, 13, 14, 15]]]]).astype(np.float32)
in_shape = (2, 4, 8)
np_value = np.random.uniform(0, 20, size=in_shape).astype(np.float16)
real_input = Tensor(np_value)
net = TransposeDynNet()
output = net(Tensor(x), perm, Tensor(indices))
assert (output.asnumpy() == expect).all()
# dynamic transpose
dyn_transpose = Net(perm)
dyn_input = Tensor(shape=[None for _ in real_input.shape], dtype=real_input.dtype)
dyn_transpose.set_inputs(dyn_input)
dyn_out = dyn_transpose(real_input)
# static transpose
static_transpose = Net(perm)
static_out = static_transpose(real_input)
np.allclose(dyn_out.asnumpy(), static_out.asnumpy(), 1e-6, 1e-6)
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_transpose_dyn_cpu():
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_transpose_dyn_cpu(mode):
"""
Feature: test Transpose dynamic shape on CPU.
Description: inputs is dynamic shape.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
dyn_case()
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
context.set_context(mode=mode, device_target="CPU")
dyn_case()
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_transpose_dyn_gpu():
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_transpose_dyn_gpu(mode):
"""
Feature: test Transpose dynamic shape on GPU.
Description: inputs is dynamic shape.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
dyn_case()
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
context.set_context(mode=mode, device_target="GPU")
dyn_case()
@pytest.mark.level1
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_transpose_dyn_ascend():
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
def test_transpose_dyn_ascend(mode):
"""
Feature: test Transpose dynamic shape on Ascend.
Description: inputs is dynamic shape.
Expectation: the result match with numpy result
"""
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
dyn_case()
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
context.set_context(mode=mode, device_target="Ascend")
dyn_case()