forked from mindspore-Ecosystem/mindspore
add notice for LSTMCell and fix narrow description
This commit is contained in:
parent
54c9971ad6
commit
2ca23cc3bd
|
@ -846,13 +846,13 @@ class Tensor(Tensor_):
|
|||
>>> x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mindspore.int32)
|
||||
>>> output = x.narrow(0, 0, 2)
|
||||
>>> print(output)
|
||||
[[ 1, 2, 3],
|
||||
[ 4, 5, 6]]
|
||||
[[ 1 2 3],
|
||||
[ 4 5 6]]
|
||||
>>> output = x.narrow(1, 1, 2)
|
||||
>>> print(output)
|
||||
[[ 2, 3],
|
||||
[ 5, 6],
|
||||
[ 8, 9]]
|
||||
[[ 2 3],
|
||||
[ 5 6],
|
||||
[ 8 9]]
|
||||
"""
|
||||
self._init_check()
|
||||
return tensor_operator_registry.get('narrow')(self, axis, start, length)
|
||||
|
|
|
@ -17,6 +17,7 @@ import math
|
|||
import numpy as np
|
||||
import mindspore.ops as P
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import log as logger
|
||||
from mindspore.common.tensor import Tensor
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.common.initializer import initializer, Uniform
|
||||
|
@ -42,7 +43,7 @@ def _check_is_tensor(param_name, input_data, cls_name):
|
|||
@constexpr
|
||||
def _check_is_tuple(param_name, input_data, cls_name):
|
||||
"""Internal function, used to check whether the input data is Tensor."""
|
||||
if input_data is not None and not isinstance({P.typeof(input_data)}, mstype.Tuple):
|
||||
if input_data is not None and not isinstance(P.typeof(input_data), mstype.Tuple):
|
||||
raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.Tuple}', "
|
||||
f"but got '{P.typeof(input_data)}'")
|
||||
|
||||
|
@ -59,6 +60,17 @@ def _check_batch_size_equal(batch_size_x, batch_size_hx, cls_name):
|
|||
raise ValueError(f"For '{cls_name}' batch size of x and hx should be equal, but got {batch_size_x} of x "
|
||||
f"and {batch_size_hx} of hx.")
|
||||
|
||||
def _check_lstmcell_init(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
logger.warning(f"LSTMCell has been changed from 'single LSTM layer' to 'single LSTM cell', "
|
||||
f"if you still need use single LSTM layer, please use `nn.LSTM` instead.")
|
||||
if len(args) > 4 or 'batch_size' in kwargs or \
|
||||
'dropout' in kwargs or 'bidirectional' in kwargs:
|
||||
raise ValueError(f"The arguments of `nn.LSTMCell` from old MindSpore version(<1.6) are detected, "
|
||||
f"if you still need use single LSTM layer, please use `nn.LSTM` instead.")
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
|
||||
def _rnn_tanh_cell(inputs, hidden, w_ih, w_hh, b_ih, b_hh):
|
||||
'''RNN cell function with tanh activation'''
|
||||
if b_ih is None:
|
||||
|
@ -188,8 +200,8 @@ class RNNCell(RNNCellBase):
|
|||
"""
|
||||
_non_linearity = ['tanh', 'relu']
|
||||
|
||||
def __init__(self, input_size: int, hidden_size: int, bias: bool = True, nonlinearity: str = "tanh"):
|
||||
super().__init__(input_size, hidden_size, bias, num_chunks=1)
|
||||
def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True, nonlinearity: str = "tanh"):
|
||||
super().__init__(input_size, hidden_size, has_bias, num_chunks=1)
|
||||
if nonlinearity not in self._non_linearity:
|
||||
raise ValueError("Unknown nonlinearity: {}".format(nonlinearity))
|
||||
self.nonlinearity = nonlinearity
|
||||
|
@ -261,8 +273,9 @@ class LSTMCell(RNNCellBase):
|
|||
>>> print(output[0][0].shape)
|
||||
(3, 16)
|
||||
"""
|
||||
def __init__(self, input_size: int, hidden_size: int, bias: bool = True):
|
||||
super().__init__(input_size, hidden_size, bias, num_chunks=4)
|
||||
@_check_lstmcell_init
|
||||
def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True):
|
||||
super().__init__(input_size, hidden_size, has_bias, num_chunks=4)
|
||||
self.support_non_tensor_inputs = True
|
||||
|
||||
def construct(self, x, hx):
|
||||
|
@ -278,6 +291,15 @@ class LSTMCell(RNNCellBase):
|
|||
_check_batch_size_equal(x.shape[0], hx[1].shape[0], self.cls_name)
|
||||
return _lstm_cell(x, hx, self.weight_ih, self.weight_hh, self.bias_ih, self.bias_hh)
|
||||
|
||||
def _check_construct_args(self, *inputs, **kwargs):
|
||||
if len(inputs) == 4:
|
||||
raise ValueError(f"The number of input args of construct is {len(inputs)}, if you are using "
|
||||
f"the implementation of `nn.LSTMCell` from old MindSpore version(<1.6), "
|
||||
f"please notice that: LSTMCell has been changed from 'single LSTM layer' to "
|
||||
f"'single LSTM cell', if you still need use single LSTM layer, "
|
||||
f"please use `nn.LSTM` instead.")
|
||||
return super()._check_construct_args(*inputs, **kwargs)
|
||||
|
||||
class GRUCell(RNNCellBase):
|
||||
r"""
|
||||
A GRU(Gated Recurrent Unit) cell.
|
||||
|
@ -329,8 +351,8 @@ class GRUCell(RNNCellBase):
|
|||
>>> print(output[0].shape)
|
||||
(3, 16)
|
||||
"""
|
||||
def __init__(self, input_size: int, hidden_size: int, bias: bool = True):
|
||||
super().__init__(input_size, hidden_size, bias, num_chunks=3)
|
||||
def __init__(self, input_size: int, hidden_size: int, has_bias: bool = True):
|
||||
super().__init__(input_size, hidden_size, has_bias, num_chunks=3)
|
||||
|
||||
def construct(self, x, hx):
|
||||
_check_is_tensor('x', x, self.cls_name)
|
||||
|
|
|
@ -373,13 +373,13 @@ def narrow(inputs, axis, start, length):
|
|||
>>> x = Tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], mindspore.int32)
|
||||
>>> output = F.narrow(x, 0, 0, 2)
|
||||
>>> print(output)
|
||||
[[ 1, 2, 3],
|
||||
[ 4, 5, 6]]
|
||||
[[ 1 2 3],
|
||||
[ 4 5 6]]
|
||||
>>> output = F.narrow(x, 1, 1, 2)
|
||||
>>> print(output)
|
||||
[[ 2, 3],
|
||||
[ 5, 6],
|
||||
[ 8, 9]]
|
||||
[[ 2 3],
|
||||
[ 5 6],
|
||||
[ 8 9]]
|
||||
"""
|
||||
validator.check_axis_in_range(axis, inputs.ndim)
|
||||
validator.check_int_range(start, 0, inputs.shape[axis], Rel.INC_LEFT)
|
||||
|
|
Loading…
Reference in New Issue