From d9d53a77100fdc98e0a1d5db26872d944660f242 Mon Sep 17 00:00:00 2001 From: dinglongwei Date: Thu, 7 Jan 2021 12:31:13 +0800 Subject: [PATCH] unified Tensor type of graph mode and pynative mode --- mindspore/nn/layer/timedistributed.py | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/mindspore/nn/layer/timedistributed.py b/mindspore/nn/layer/timedistributed.py index 0b1541ef859..2b40934bb19 100644 --- a/mindspore/nn/layer/timedistributed.py +++ b/mindspore/nn/layer/timedistributed.py @@ -16,7 +16,6 @@ from mindspore.ops.primitive import constexpr, Primitive from mindspore.ops import Reshape, Transpose, Pack, Unpack -from mindspore.common.dtype import tensor from mindspore.common import Tensor from ..cell import Cell @@ -105,9 +104,7 @@ class TimeDistributed(Cell): self.reshape = Reshape() def construct(self, inputs): - is_capital_tensor = isinstance(inputs, Tensor) - is_tensor = True if is_capital_tensor else isinstance(inputs, tensor) - _check_data(is_tensor) + _check_data(isinstance(inputs, Tensor)) _check_inputs_dim(inputs.shape) time_axis = self.time_axis % len(inputs.shape) if self.reshape_with_axis is not None: @@ -122,9 +119,7 @@ class TimeDistributed(Cell): inputs_shape_new = inputs.shape inputs = self.reshape(inputs, inputs_shape_new[: reshape_pos] + (-1,) + inputs_shape_new[reshape_pos + 2:]) outputs = self.layer(inputs) - is_capital_tensor = isinstance(outputs, Tensor) - is_tensor = True if is_capital_tensor else isinstance(outputs, tensor) - _check_data(is_tensor) + _check_data(isinstance(outputs, Tensor)) _check_reshape_pos(reshape_pos, inputs.shape, outputs.shape) outputs_shape_new = outputs.shape[:reshape_pos] + inputs_shape_new[reshape_pos: reshape_pos + 2] if reshape_pos + 1 < len(outputs.shape): @@ -136,9 +131,7 @@ class TimeDistributed(Cell): y = () for item in inputs: outputs = self.layer(item) - is_capital_tensor = isinstance(outputs, Tensor) - is_tensor = True if is_capital_tensor else isinstance(outputs, tensor) - _check_data(is_tensor) + _check_data(isinstance(outputs, Tensor)) _check_expand_dims_axis(time_axis, outputs.ndim) y += (outputs,) y = Pack(time_axis)(y)