forked from mindspore-Ecosystem/mindspore
add fp16 support for RNNs
This commit is contained in:
parent 26cb1fbcee
commit 399aa75b45
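For context, the change lets the RNN-family cells accept float16 input in addition to float32. A minimal usage sketch of what the diff below is meant to enable — the layer sizes, shapes, and random data here are invented for illustration and are not part of the commit:

import numpy as np
from mindspore import Tensor, nn

# Hypothetical example: with fp16 support, a float16 input (and a matching
# float16 initial state) should be accepted; previously only float32 was allowed.
net = nn.RNN(input_size=16, hidden_size=32, num_layers=1, batch_first=True)
x = Tensor(np.random.randn(4, 10, 16).astype(np.float16))   # (batch_size, seq_len, input_size)
h0 = Tensor(np.zeros((1, 4, 32), dtype=np.float16))         # (num_directions * num_layers, batch_size, hidden_size)
output, hn = net(x, h0)                                      # per the diff, outputs keep x's dtype (float16)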
@@ -50,6 +50,12 @@ def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
     validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
 
 
+@constexpr
+def _check_input_dtype_same_and_valid(args_name, args_value, valid_values, cls_name):
+    args = {args_name[i]: args_value[i] for i in range(len(args_value))}
+    validator.check_types_same_and_valid(args, valid_values, cls_name)
+
+
 @constexpr
 def _check_is_tensor(param_name, input_data, cls_name):
     """Internal function, used to check whether the input data is Tensor."""
@@ -173,9 +179,12 @@ class _DynamicRNNBase(Cell):
         return outputs, state_t
 
     def construct(self, x, h, seq_length, w_ih, w_hh, b_ih, b_hh):
+        x_dtype = x.dtype
         if seq_length is None:
-            return self.recurrent(x, h, w_ih, w_hh, b_ih, b_hh)
-        return self.variable_recurrent(x, h, seq_length, w_ih, w_hh, b_ih, b_hh)
+            return self.recurrent(x, h, w_ih.astype(x_dtype), w_hh.astype(x_dtype), \
+                                  b_ih.astype(x_dtype), b_hh.astype(x_dtype))
+        return self.variable_recurrent(x, h, seq_length, w_ih.astype(x_dtype), w_hh.astype(x_dtype), \
+                                       b_ih.astype(x_dtype), b_hh.astype(x_dtype))
 
 
 class _DynamicRNNRelu(_DynamicRNNBase):
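The hunk above is the core of the change: instead of assuming float32, the cell records the input's dtype and casts the weight and bias parameters to that dtype before running the recurrent computation. A rough standalone sketch of the same pattern (plain NumPy, not the actual MindSpore cell; names and shapes are invented):

import numpy as np

def rnn_tanh_cell(x, h, w_ih, w_hh, b_ih, b_hh):
    # Mirror the diff: cast the parameters to the input's dtype so the whole
    # step runs in float16 when x is float16, and in float32 otherwise.
    x_dtype = x.dtype
    w_ih, w_hh = w_ih.astype(x_dtype), w_hh.astype(x_dtype)
    b_ih, b_hh = b_ih.astype(x_dtype), b_hh.astype(x_dtype)
    return np.tanh(x @ w_ih.T + b_ih + h @ w_hh.T + b_hh)

# Parameters stay in float32; the input decides the compute dtype.
x = np.random.randn(4, 16).astype(np.float16)
h = np.zeros((4, 32), dtype=np.float16)
w_ih = np.random.randn(32, 16).astype(np.float32)
w_hh = np.random.randn(32, 32).astype(np.float32)
b_ih = np.zeros(32, dtype=np.float32)
b_hh = np.zeros(32, dtype=np.float32)
assert rnn_tanh_cell(x, h, w_ih, w_hh, b_ih, b_hh).dtype == np.float16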
@@ -270,7 +279,7 @@ class _DynamicLSTMCPUGPU(Cell):
                 x,
                 P.ExpandDims()(h_0[0], 0),
                 P.ExpandDims()(h_0[1], 0),
-                weights
+                weights.astype(x.dtype)
             )
             return output, (h_n, c_n)
 
@@ -493,35 +502,36 @@ class _RNNBase(Cell):
         max_batch_size = x.shape[0] if self.batch_first else x.shape[1]
         num_directions = 2 if self.bidirectional else 1
         _check_is_tensor("x", x, self.cls_name)
-        _check_input_dtype(x.dtype, "x", [mstype.float32], self.cls_name)
+        x_dtype = x.dtype
         if hx is not None:
             if not self.is_lstm:
                 _check_is_tensor("h", hx, self.cls_name)
-                _check_input_dtype(hx.dtype, "hx", [mstype.float32], self.cls_name)
+                _check_input_dtype_same_and_valid(['x', 'hx'], [x_dtype, hx.dtype], \
+                                                  [mstype.float32, mstype.float16], self.cls_name)
             else:
                 _check_is_tuple('hx', hx, self.cls_name)
                 _check_tuple_length('hx', hx, 2, self.cls_name)
                 _check_is_tensor('hx[0]', hx[0], self.cls_name)
                 _check_is_tensor('hx[1]', hx[1], self.cls_name)
-                _check_input_dtype(hx[0].dtype, "hx[0]", [mstype.float32], self.cls_name)
-                _check_input_dtype(hx[1].dtype, "hx[1]", [mstype.float32], self.cls_name)
+                _check_input_dtype_same_and_valid(['x', 'hx[0]', 'hx[1]'], [x_dtype, hx[0].dtype, hx[1].dtype], \
+                                                  [mstype.float32, mstype.float16], self.cls_name)
+        else:
+            hx = _init_state((self.num_layers * num_directions, max_batch_size, self.hidden_size), \
+                             x_dtype, self.is_lstm)
         if seq_length is not None:
             _check_input_dtype(seq_length.dtype, "seq_length", [mstype.int32, mstype.int64], self.cls_name)
             _check_seq_length_size(max_batch_size, seq_length.shape[0], self.cls_name)
-        if hx is None:
-            hx = _init_state((self.num_layers * num_directions, max_batch_size, self.hidden_size), \
-                             x.dtype, self.is_lstm)
         if self.batch_first:
             x = P.Transpose()(x, (1, 0, 2))
         if self.bidirectional:
-            x, h = self._stacked_bi_dynamic_rnn(x, hx, seq_length)
+            x_n, hx_n = self._stacked_bi_dynamic_rnn(x, hx, seq_length)
         else:
-            x, h = self._stacked_dynamic_rnn(x, hx, seq_length)
+            x_n, hx_n = self._stacked_dynamic_rnn(x, hx, seq_length)
         if self.batch_first:
-            x = P.Transpose()(x, (1, 0, 2))
+            x_n = P.Transpose()(x_n, (1, 0, 2))
         if not self.is_lstm:
-            return x.astype(mstype.float32), h.astype(mstype.float32)
-        return x.astype(mstype.float32), (h[0].astype(mstype.float32), h[1].astype(mstype.float32))
+            return x_n.astype(x_dtype), hx_n.astype(x_dtype)
+        return x_n.astype(x_dtype), (hx_n[0].astype(x_dtype), hx_n[1].astype(x_dtype))
 
 
 class RNN(_RNNBase):
@@ -553,9 +563,9 @@ class RNN(_RNNBase):
             num_directions=2 if bidirectional=True otherwise 1. Default: False.
 
     Inputs:
-        - **x** (Tensor) - Tensor of data type mindspore.float32 and
+        - **x** (Tensor) - Tensor of data type mindspore.float32 or mindspore.float16 and
           shape (seq_len, batch_size, `input_size`) or (batch_size, seq_len, `input_size`).
-        - **hx** (Tensor) - Tensor of data type mindspore.float32 and
+        - **hx** (Tensor) - Tensor of data type mindspore.float32 or mindspore.float16 and
           shape (num_directions * `num_layers`, batch_size, `hidden_size`). The data type of `hx` must be the same as
           `x`.
         - **seq_length** (Tensor) - The length of each sequence in an input batch.
@@ -649,9 +659,9 @@ class GRU(_RNNBase):
             num_directions=2 if bidirectional=True otherwise 1. Default: False.
 
     Inputs:
-        - **x** (Tensor) - Tensor of data type mindspore.float32 and
+        - **x** (Tensor) - Tensor of data type mindspore.float32 or mindspore.float16 and
          shape (seq_len, batch_size, `input_size`) or (batch_size, seq_len, `input_size`).
-        - **hx** (Tensor) - Tensor of data type mindspore.float32 and
+        - **hx** (Tensor) - Tensor of data type mindspore.float32 or mindspore.float16 and
           shape (num_directions * `num_layers`, batch_size, `hidden_size`). The data type of `hx` must be the same as
           `x`.
         - **seq_length** (Tensor) - The length of each sequence in an input batch.
@@ -738,10 +748,10 @@ class LSTM(_RNNBase):
             num_directions=2 if bidirectional=True otherwise 1. Default: False.
 
     Inputs:
-        - **x** (Tensor) - (Tensor) - Tensor of data type mindspore.float32 and
+        - **x** (Tensor) - (Tensor) - Tensor of data type mindspore.float32 or mindspore.float16 and
           shape (seq_len, batch_size, `input_size`) or (batch_size, seq_len, `input_size`).
         - **hx** (tuple) - A tuple of two Tensors (h_0, c_0) both of data type mindspore.float32
-          and shape (num_directions * `num_layers`, batch_size, `hidden_size`).
+          or mindspore.float16 and shape (num_directions * `num_layers`, batch_size, `hidden_size`).
           The data type of `hx` must be the same as `x`.
         - **seq_length** (Tensor) - The length of each sequence in an input batch.
           Tensor of shape :math:`(\text{batch_size})`. Default: None.
@@ -759,7 +769,7 @@
     Raises:
         TypeError: If `input_size`, `hidden_size` or `num_layers` is not an int.
         TypeError: If `has_bias`, `batch_first` or `bidirectional` is not a bool.
-        TypeError: If `dropout` is neither a float.
+        TypeError: If `dropout` is not a float.
         ValueError: If `dropout` is not in range [0.0, 1.0].
 
     Supported Platforms:
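The docstring hunks above extend the documented dtypes accordingly; for LSTM, `hx` is a tuple (h_0, c_0) whose dtype must match `x`. A hedged usage sketch of the documented interface — sizes and data are invented for illustration:

import numpy as np
from mindspore import Tensor, nn

# Hypothetical float16 LSTM call mirroring the updated docstring.
net = nn.LSTM(input_size=16, hidden_size=32, num_layers=1, batch_first=True)
x = Tensor(np.random.randn(4, 10, 16).astype(np.float16))
h0 = Tensor(np.zeros((1, 4, 32), dtype=np.float16))
c0 = Tensor(np.zeros((1, 4, 32), dtype=np.float16))
output, (hn, cn) = net(x, (h0, c0))   # outputs follow x's dtype, as in the _RNNBase hunk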