add notice for LSTMCell and fix narrow description

This commit is contained in:
lvyufeng 2021-12-30 15:59:01 +08:00
parent 54c9971ad6
commit 2ca23cc3bd
3 changed files with 39 additions and 17 deletions

View File

@ -846,13 +846,13 @@ class Tensor(Tensor_):
>>> x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mindspore.int32)
>>> output = x.narrow(0, 0, 2)
>>> print(output)
[[ 1, 2, 3],
[ 4, 5, 6]]
[[ 1 2 3],
[ 4 5 6]]
>>> output = x.narrow(1, 1, 2)
>>> print(output)
[[ 2, 3],
[ 5, 6],
[ 8, 9]]
[[ 2 3],
[ 5 6],
[ 8 9]]
"""
self._init_check()
return tensor_operator_registry.get('narrow')(self, axis, start, length)

View File

@ -17,6 +17,7 @@ import math
import numpy as np
import mindspore.ops as P
import mindspore.common.dtype as mstype
from mindspore import log as logger
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.common.initializer import initializer, Uniform
@ -42,7 +43,7 @@ def _check_is_tensor(param_name, input_data, cls_name):
@constexpr
def _check_is_tuple(param_name, input_data, cls_name):
"""Internal function, used to check whether the input data is Tensor."""
if input_data is not None and not isinstance({P.typeof(input_data)}, mstype.Tuple):
if input_data is not None and not isinstance(P.typeof(input_data), mstype.Tuple):
raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.Tuple}', "
f"but got '{P.typeof(input_data)}'")
@ -59,6 +60,17 @@ def _check_batch_size_equal(batch_size_x, batch_size_hx, cls_name):
raise ValueError(f"For '{cls_name}' batch size of x and hx should be equal, but got {batch_size_x} of x "
f"and {batch_size_hx} of hx.")
def _check_lstmcell_init(func):
def wrapper(*args, **kwargs):
logger.warning(f"LSTMCell has been changed from 'single LSTM layer' to 'single LSTM cell', "
f"if you still need use single LSTM layer, please use `nn.LSTM` instead.")
if len(args) > 4 or 'batch_size' in kwargs or \
'dropout' in kwargs or 'bidirectional' in kwargs:
raise ValueError(f"The arguments of `nn.LSTMCell` from old MindSpore version(<1.6) are detected, "
f"if you still need use single LSTM layer, please use `nn.LSTM` instead.")
return func(*args, **kwargs)
return wrapper
def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
'''RNN cell function with tanh activation'''
if b_ih is None:
@ -188,8 +200,8 @@ class RNNCell(RNNCellBase):
"""
_non_linearity = ['tanh', 'relu']
def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh"):
super().__init__(input_size, hidden_size, bias, num_chunks=1)
def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True, nonlinearity: str = "tanh"):
super().__init__(input_size, hidden_size, has_bias, num_chunks=1)
if nonlinearity not in self._non_linearity:
raise ValueError("Unknown nonlinearity: {}".format(nonlinearity))
self.nonlinearity = nonlinearity
@ -261,8 +273,9 @@ class LSTMCell(RNNCellBase):
>>> print(output[0][0].shape)
(3, 16)
"""
def __init__(self, input_size: int, hidden_size: int, bias: bool = True):
super().__init__(input_size, hidden_size, bias, num_chunks=4)
@_check_lstmcell_init
def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True):
super().__init__(input_size, hidden_size, has_bias, num_chunks=4)
self.support_non_tensor_inputs = True
def construct(self, x, hx):
@ -278,6 +291,15 @@ class LSTMCell(RNNCellBase):
_check_batch_size_equal(x.shape[0], hx[1].shape[0], self.cls_name)
return _lstm_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
def _check_construct_args(self, *inputs, **kwargs):
if len(inputs) == 4:
raise ValueError(f"The number of input args of construct is {len(inputs)}, if you are using "
f"the implementation of `nn.LSTMCell` from old MindSpore version(<1.6), "
f"please notice that: LSTMCell has been changed from 'single LSTM layer' to "
f"'single LSTM cell', if you still need use single LSTM layer, "
f"please use `nn.LSTM` instead.")
return super()._check_construct_args(*inputs, **kwargs)
class GRUCell(RNNCellBase):
r"""
A GRU(Gated Recurrent Unit) cell.
@ -329,8 +351,8 @@ class GRUCell(RNNCellBase):
>>> print(output[0].shape)
(3, 16)
"""
def __init__(self, input_size: int, hidden_size: int, bias: bool = True):
super().__init__(input_size, hidden_size, bias, num_chunks=3)
def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True):
super().__init__(input_size, hidden_size, has_bias, num_chunks=3)
def construct(self, x, hx):
_check_is_tensor('x', x, self.cls_name)

View File

@ -373,13 +373,13 @@ def narrow(inputs, axis, start, length):
>>> x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mindspore.int32)
>>> output = F.narrow(x, 0, 0, 2)
>>> print(output)
[[ 1, 2, 3],
[ 4, 5, 6]]
[[ 1 2 3],
[ 4 5 6]]
>>> output = F.narrow(x, 1, 1, 2)
>>> print(output)
[[ 2, 3],
[ 5, 6],
[ 8, 9]]
[[ 2 3],
[ 5 6],
[ 8 9]]
"""
validator.check_axis_in_range(axis, inputs.ndim)
validator.check_int_range(start, 0, inputs.shape[axis], Rel.INC_LEFT)