add fp16 support for RNNs

lvyufeng 2022-01-05 11:35:46 +08:00
parent 26cb1fbcee
commit 399aa75b45
1 changed file with 32 additions and 22 deletions


@@ -50,6 +50,12 @@ def _check_input_dtype(input_dtype, param_name, allow_dtypes, cls_name):
     validator.check_type_name(param_name, input_dtype, allow_dtypes, cls_name)
 
 
+@constexpr
+def _check_input_dtype_same_and_valid(args_name, args_value, valid_values, cls_name):
+    args = {args_name[i]: args_value[i] for i in range(len(args_value))}
+    validator.check_types_same_and_valid(args, valid_values, cls_name)
+
+
 @constexpr
 def _check_is_tensor(param_name, input_data, cls_name):
     """Internal function, used to check whether the input data is Tensor."""
@@ -173,9 +179,12 @@ class _DynamicRNNBase(Cell):
         return outputs, state_t
 
     def construct(self, x, h, seq_length, w_ih, w_hh, b_ih, b_hh):
+        x_dtype = x.dtype
         if seq_length is None:
-            return self.recurrent(x, h, w_ih, w_hh, b_ih, b_hh)
-        return self.variable_recurrent(x, h, seq_length, w_ih, w_hh, b_ih, b_hh)
+            return self.recurrent(x, h, w_ih.astype(x_dtype), w_hh.astype(x_dtype), \
+                                  b_ih.astype(x_dtype), b_hh.astype(x_dtype))
+        return self.variable_recurrent(x, h, seq_length, w_ih.astype(x_dtype), w_hh.astype(x_dtype), \
+                                       b_ih.astype(x_dtype), b_hh.astype(x_dtype))
 
 
 class _DynamicRNNRelu(_DynamicRNNBase):
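The fp16-enabling change in `construct` is the `astype(x_dtype)` cast: the parameters stay stored in float32 but are cast to the input's dtype before the recurrence, so a float16 input runs the whole loop in float16. A hedged NumPy sketch of the same pattern (a hypothetical single-layer tanh cell, not the MindSpore kernels):

```python
import numpy as np

def recurrent(x, h, w_ih, w_hh, b_ih, b_hh):
    """Apply a tanh RNN cell over time; x has shape (seq_len, batch, input_size)."""
    outputs = []
    for t in range(x.shape[0]):
        h = np.tanh(x[t] @ w_ih.T + b_ih + h @ w_hh.T + b_hh)
        outputs.append(h)
    return np.stack(outputs), h

def construct(x, h, w_ih, w_hh, b_ih, b_hh):
    # As in the diff: cast float32 parameters to x's dtype so the whole
    # recurrence runs in float16 when x is float16. h is assumed to match
    # x already; the validators above enforce that.
    dt = x.dtype
    return recurrent(x, h, w_ih.astype(dt), w_hh.astype(dt),
                     b_ih.astype(dt), b_hh.astype(dt))
```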
@@ -270,7 +279,7 @@ class _DynamicLSTMCPUGPU(Cell):
                 x,
                 P.ExpandDims()(h_0[0], 0),
                 P.ExpandDims()(h_0[1], 0),
-                weights
+                weights.astype(x.dtype)
             )
         return output, (h_n, c_n)
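For the fused CPU/GPU LSTM path, all layer weights live in one flat buffer, so fp16 support is a single cast at the call site. A sketch of that call shape, with `lstm_op` standing in for the fused primitive (hypothetical parameter, not the actual `P.LSTM` wiring):

```python
def dynamic_lstm(lstm_op, x, h_0, c_0, weights):
    # One cast covers every gate's weights and biases, because the fused
    # kernel consumes a single flattened parameter tensor.
    output, h_n, c_n, *_ = lstm_op(x, h_0, c_0, weights.astype(x.dtype))
    return output, (h_n, c_n)
```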
@@ -493,35 +502,36 @@ class _RNNBase(Cell):
         max_batch_size = x.shape[0] if self.batch_first else x.shape[1]
         num_directions = 2 if self.bidirectional else 1
         _check_is_tensor("x", x, self.cls_name)
-        _check_input_dtype(x.dtype, "x", [mstype.float32], self.cls_name)
+        x_dtype = x.dtype
         if hx is not None:
             if not self.is_lstm:
                 _check_is_tensor("h", hx, self.cls_name)
-                _check_input_dtype(hx.dtype, "hx", [mstype.float32], self.cls_name)
+                _check_input_dtype_same_and_valid(['x', 'hx'], [x_dtype, hx.dtype], \
+                                                  [mstype.float32, mstype.float16], self.cls_name)
             else:
                 _check_is_tuple('hx', hx, self.cls_name)
                 _check_tuple_length('hx', hx, 2, self.cls_name)
                 _check_is_tensor('hx[0]', hx[0], self.cls_name)
                 _check_is_tensor('hx[1]', hx[1], self.cls_name)
-                _check_input_dtype(hx[0].dtype, "hx[0]", [mstype.float32], self.cls_name)
-                _check_input_dtype(hx[1].dtype, "hx[1]", [mstype.float32], self.cls_name)
+                _check_input_dtype_same_and_valid(['x', 'hx[0]', 'hx[1]'], [x_dtype, hx[0].dtype, hx[1].dtype], \
+                                                  [mstype.float32, mstype.float16], self.cls_name)
+        else:
+            hx = _init_state((self.num_layers * num_directions, max_batch_size, self.hidden_size), \
+                             x_dtype, self.is_lstm)
         if seq_length is not None:
             _check_input_dtype(seq_length.dtype, "seq_length", [mstype.int32, mstype.int64], self.cls_name)
             _check_seq_length_size(max_batch_size, seq_length.shape[0], self.cls_name)
-        if hx is None:
-            hx = _init_state((self.num_layers * num_directions, max_batch_size, self.hidden_size), \
-                             x.dtype, self.is_lstm)
         if self.batch_first:
             x = P.Transpose()(x, (1, 0, 2))
         if self.bidirectional:
-            x, h = self._stacked_bi_dynamic_rnn(x, hx, seq_length)
+            x_n, hx_n = self._stacked_bi_dynamic_rnn(x, hx, seq_length)
         else:
-            x, h = self._stacked_dynamic_rnn(x, hx, seq_length)
+            x_n, hx_n = self._stacked_dynamic_rnn(x, hx, seq_length)
         if self.batch_first:
-            x = P.Transpose()(x, (1, 0, 2))
+            x_n = P.Transpose()(x_n, (1, 0, 2))
         if not self.is_lstm:
-            return x.astype(mstype.float32), h.astype(mstype.float32)
-        return x.astype(mstype.float32), (h[0].astype(mstype.float32), h[1].astype(mstype.float32))
+            return x_n.astype(x_dtype), hx_n.astype(x_dtype)
+        return x_n.astype(x_dtype), (hx_n[0].astype(x_dtype), hx_n[1].astype(x_dtype))
 
 
 class RNN(_RNNBase):
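Taken together, this hunk changes three things: `x` and `hx` are now validated jointly against {float32, float16}, a default `hx` is created in `x`'s dtype, and the outputs are cast back to the input dtype instead of being forced to float32. A condensed plain-Python sketch of that flow (`stacked_rnn` and `hidden_size` are stand-ins for the real attributes):

```python
import numpy as np

VALID_DTYPES = (np.float32, np.float16)

def rnn_forward(x, hx, stacked_rnn, hidden_size=16):
    """Condensed sketch of the reworked dtype handling in _RNNBase.construct."""
    x_dtype = x.dtype
    if hx is not None:
        # Joint same-and-valid check replaces the old float32-only checks.
        if hx.dtype != x_dtype or x_dtype not in VALID_DTYPES:
            raise TypeError(f"x and hx must share a dtype in {VALID_DTYPES}")
    else:
        # Default state follows the input dtype: float16 in, float16 state.
        hx = np.zeros((1, x.shape[1], hidden_size), dtype=x_dtype)
    x_n, hx_n = stacked_rnn(x, hx)
    # Outputs track the input dtype rather than being cast to float32.
    return x_n.astype(x_dtype), hx_n.astype(x_dtype)
```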
@@ -553,9 +563,9 @@ class RNN(_RNNBase):
             num_directions=2 if bidirectional=True otherwise 1. Default: False.
 
     Inputs:
-        - **x** (Tensor) - Tensor of data type mindspore.float32 and
+        - **x** (Tensor) - Tensor of data type mindspore.float32 or mindspore.float16 and
           shape (seq_len, batch_size, `input_size`) or (batch_size, seq_len, `input_size`).
-        - **hx** (Tensor) - Tensor of data type mindspore.float32 and
+        - **hx** (Tensor) - Tensor of data type mindspore.float32 or mindspore.float16 and
           shape (num_directions * `num_layers`, batch_size, `hidden_size`). The data type of `hx` must be the same as
           `x`.
         - **seq_length** (Tensor) - The length of each sequence in an input batch.
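With the widened check, a float16 forward pass should round-trip in float16 end to end. A minimal usage sketch, assuming a MindSpore build that includes this commit (shapes are illustrative):

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

net = nn.RNN(input_size=8, hidden_size=16, batch_first=True)
x = Tensor(np.random.randn(4, 5, 8).astype(np.float16))   # (batch, seq_len, input_size)
hx = Tensor(np.zeros((1, 4, 16), dtype=np.float16))       # (dirs * layers, batch, hidden_size)
y, hn = net(x, hx)
assert y.dtype == ms.float16 and hn.dtype == ms.float16
```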
@@ -649,9 +659,9 @@ class GRU(_RNNBase):
             num_directions=2 if bidirectional=True otherwise 1. Default: False.
 
     Inputs:
-        - **x** (Tensor) - Tensor of data type mindspore.float32 and
+        - **x** (Tensor) - Tensor of data type mindspore.float32 or mindspore.float16 and
           shape (seq_len, batch_size, `input_size`) or (batch_size, seq_len, `input_size`).
-        - **hx** (Tensor) - Tensor of data type mindspore.float32 and
+        - **hx** (Tensor) - Tensor of data type mindspore.float32 or mindspore.float16 and
           shape (num_directions * `num_layers`, batch_size, `hidden_size`). The data type of `hx` must be the same as
           `x`.
         - **seq_length** (Tensor) - The length of each sequence in an input batch.
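The same applies to GRU, and `seq_length` keeps its int32/int64 requirement alongside float16 data. A hedged sketch; `hx` is passed as None so the default state is created in `x`'s dtype (PyNative mode assumed for the positional call):

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

net = nn.GRU(input_size=16, hidden_size=32, batch_first=True)
x = Tensor(np.random.randn(3, 7, 16).astype(np.float16))   # (batch, seq_len, input_size)
seq_length = Tensor(np.array([7, 5, 2], dtype=np.int32))   # valid length per sample
y, hn = net(x, None, seq_length)   # hx=None: zero state created in x's dtype
assert y.dtype == ms.float16
```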
@@ -738,10 +748,10 @@ class LSTM(_RNNBase):
             num_directions=2 if bidirectional=True otherwise 1. Default: False.
 
     Inputs:
-        - **x** (Tensor) - Tensor of data type mindspore.float32 and
+        - **x** (Tensor) - Tensor of data type mindspore.float32 or mindspore.float16 and
           shape (seq_len, batch_size, `input_size`) or (batch_size, seq_len, `input_size`).
         - **hx** (tuple) - A tuple of two Tensors (h_0, c_0) both of data type mindspore.float32
-          and shape (num_directions * `num_layers`, batch_size, `hidden_size`).
+          or mindspore.float16 and shape (num_directions * `num_layers`, batch_size, `hidden_size`).
           The data type of `hx` must be the same as `x`.
         - **seq_length** (Tensor) - The length of each sequence in an input batch.
           Tensor of shape :math:`(\text{batch_size})`. Default: None.
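For LSTM the state is the `(h_0, c_0)` tuple, and both members must now match `x`'s dtype. A minimal float16 sketch under the same assumptions:

```python
import numpy as np
import mindspore as ms
from mindspore import Tensor, nn

net = nn.LSTM(input_size=10, hidden_size=12, batch_first=True)
x = Tensor(np.random.randn(2, 6, 10).astype(np.float16))
h0 = Tensor(np.zeros((1, 2, 12), dtype=np.float16))
c0 = Tensor(np.zeros((1, 2, 12), dtype=np.float16))
y, (hn, cn) = net(x, (h0, c0))
assert y.dtype == ms.float16 and hn.dtype == ms.float16 and cn.dtype == ms.float16
```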
@@ -759,7 +769,7 @@
     Raises:
         TypeError: If `input_size`, `hidden_size` or `num_layers` is not an int.
         TypeError: If `has_bias`, `batch_first` or `bidirectional` is not a bool.
-        TypeError: If `dropout` is neither a float.
+        TypeError: If `dropout` is not a float.
         ValueError: If `dropout` is not in range [0.0, 1.0].
 
     Supported Platforms: