forked from mindspore-Ecosystem/mindspore
!48287 transpose dynamic test case using set_inputs
Merge pull request !48287 from liubuyu/transpose
This commit is contained in:
commit
e78ab6f508
|
@ -22,75 +22,74 @@ import mindspore.ops as ops
|
|||
from mindspore import Tensor
|
||||
|
||||
|
||||
class TransposeDynNet(nn.Cell):
|
||||
def __init__(self, axis=0):
|
||||
super(TransposeDynNet, self).__init__()
|
||||
self.unique = ops.Unique()
|
||||
self.gather = ops.Gather()
|
||||
class Net(nn.Cell):
|
||||
def __init__(self, perm_in):
|
||||
super(Net, self).__init__()
|
||||
self.transpose = ops.Transpose()
|
||||
self.axis = axis
|
||||
self.perm = perm_in
|
||||
|
||||
def construct(self, x, perm, indices):
|
||||
unique_indices, _ = self.unique(indices)
|
||||
input_x = self.gather(x, unique_indices, self.axis)
|
||||
return self.transpose(input_x, perm)
|
||||
def construct(self, input_):
|
||||
x = self.transpose(input_, self.perm)
|
||||
return x
|
||||
|
||||
|
||||
def dyn_case():
|
||||
perm = (1, 0, 2)
|
||||
x = np.arange(2 * 2 * 4).reshape(2, 2, 4).astype(np.float32)
|
||||
indices = np.array([0, 1, 0], dtype=np.int32)
|
||||
expect = np.array([[[[0, 1, 2, 3],
|
||||
[8, 9, 10, 11]],
|
||||
[[4, 5, 6, 7],
|
||||
[12, 13, 14, 15]]]]).astype(np.float32)
|
||||
in_shape = (2, 4, 8)
|
||||
np_value = np.random.uniform(0, 20, size=in_shape).astype(np.float16)
|
||||
real_input = Tensor(np_value)
|
||||
|
||||
net = TransposeDynNet()
|
||||
output = net(Tensor(x), perm, Tensor(indices))
|
||||
assert (output.asnumpy() == expect).all()
|
||||
# dynamic transpose
|
||||
dyn_transpose = Net(perm)
|
||||
dyn_input = Tensor(shape=[None for _ in real_input.shape], dtype=real_input.dtype)
|
||||
dyn_transpose.set_inputs(dyn_input)
|
||||
dyn_out = dyn_transpose(real_input)
|
||||
|
||||
# static transpose
|
||||
static_transpose = Net(perm)
|
||||
static_out = static_transpose(real_input)
|
||||
|
||||
np.allclose(dyn_out.asnumpy(), static_out.asnumpy(), 1e-6, 1e-6)
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_cpu
|
||||
@pytest.mark.env_onecard
|
||||
def test_transpose_dyn_cpu():
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_transpose_dyn_cpu(mode):
|
||||
"""
|
||||
Feature: test Transpose dynamic shape on CPU.
|
||||
Description: inputs is dynamic shape.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="CPU")
|
||||
dyn_case()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="CPU")
|
||||
context.set_context(mode=mode, device_target="CPU")
|
||||
dyn_case()
|
||||
|
||||
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_x86_gpu_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_transpose_dyn_gpu():
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_transpose_dyn_gpu(mode):
|
||||
"""
|
||||
Feature: test Transpose dynamic shape on GPU.
|
||||
Description: inputs is dynamic shape.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")
|
||||
dyn_case()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
|
||||
context.set_context(mode=mode, device_target="GPU")
|
||||
dyn_case()
|
||||
|
||||
|
||||
@pytest.mark.level1
|
||||
@pytest.mark.level0
|
||||
@pytest.mark.platform_arm_ascend_training
|
||||
@pytest.mark.platform_x86_ascend_training
|
||||
@pytest.mark.env_onecard
|
||||
def test_transpose_dyn_ascend():
|
||||
@pytest.mark.parametrize('mode', [context.GRAPH_MODE, context.PYNATIVE_MODE])
|
||||
def test_transpose_dyn_ascend(mode):
|
||||
"""
|
||||
Feature: test Transpose dynamic shape on Ascend.
|
||||
Description: inputs is dynamic shape.
|
||||
Expectation: the result match with numpy result
|
||||
"""
|
||||
context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
|
||||
dyn_case()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")
|
||||
context.set_context(mode=mode, device_target="Ascend")
|
||||
dyn_case()
|
||||
|
|
Loading…
Reference in New Issue