!10895 add timedistributed pynative mode

From: @dinglongwei
Reviewed-by: @wuxuejian,@c_34
Signed-off-by: @wuxuejian,@c_34
This commit is contained in:
mindspore-ci-bot 2021-01-04 14:12:41 +08:00 committed by Gitee
commit efd22c96ad
2 changed files with 26 additions and 3 deletions

View File

@ -17,6 +17,7 @@
from mindspore.ops.primitive import constexpr, Primitive
from mindspore.ops import Reshape, Transpose, Pack, Unpack
from mindspore.common.dtype import tensor
from mindspore.common import Tensor
from ..cell import Cell
__all__ = ['TimeDistributed']
@ -104,7 +105,9 @@ class TimeDistributed(Cell):
self.reshape = Reshape()
def construct(self, inputs):
_check_data(isinstance(inputs, tensor))
is_capital_tensor = isinstance(inputs, Tensor)
is_tensor = True if is_capital_tensor else isinstance(inputs, tensor)
_check_data(is_tensor)
_check_inputs_dim(inputs.shape)
time_axis = self.time_axis % len(inputs.shape)
if self.reshape_with_axis is not None:
@ -119,7 +122,9 @@ class TimeDistributed(Cell):
inputs_shape_new = inputs.shape
inputs = self.reshape(inputs, inputs_shape_new[: reshape_pos] + (-1,) + inputs_shape_new[reshape_pos + 2:])
outputs = self.layer(inputs)
_check_data(isinstance(outputs, tensor))
is_capital_tensor = isinstance(outputs, Tensor)
is_tensor = True if is_capital_tensor else isinstance(outputs, tensor)
_check_data(is_tensor)
_check_reshape_pos(reshape_pos, inputs.shape, outputs.shape)
outputs_shape_new = outputs.shape[:reshape_pos] + inputs_shape_new[reshape_pos: reshape_pos + 2]
if reshape_pos + 1 < len(outputs.shape):
@ -131,7 +136,9 @@ class TimeDistributed(Cell):
y = ()
for item in inputs:
outputs = self.layer(item)
_check_data(isinstance(outputs, tensor))
is_capital_tensor = isinstance(outputs, Tensor)
is_tensor = True if is_capital_tensor else isinstance(outputs, tensor)
_check_data(is_tensor)
_check_expand_dims_axis(time_axis, outputs.ndim)
y += (outputs,)
y = Pack(time_axis)(y)

View File

@ -78,6 +78,22 @@ def test_time_distributed_dense():
print("Dense layer wrapped successful")
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard
def test_time_distributed_dense_pynative():
context.set_context(mode=context.PYNATIVE_MODE, device_target='CPU')
inputs = np.random.randint(0, 10, [32, 10])
dense = nn.Dense(10, 6)
output_expect = dense(Tensor(inputs, mindspore.float32)).asnumpy()
inputs = inputs.reshape([32, 1, 10]).repeat(6, axis=1)
time_distributed = TestTimeDistributed(dense, time_axis=1, reshape_with_axis=0)
output = time_distributed(Tensor(inputs, mindspore.float32)).asnumpy()
for i in range(output.shape[1]):
assert np.all(output[:, i, :] == output_expect)
print("Dense layer with pynative mode wrapped successful")
@pytest.mark.level0
@pytest.mark.platform_x86_cpu
@pytest.mark.env_onecard