forked from mindspore-Ecosystem/mindspore
dynamic_rnn for new backend opensource.
This commit is contained in:
parent ce0c7a66cd
commit dd5e4c9f9a
@@ -805,6 +805,22 @@ def get_bprop_lstm(self):
    return bprop


@bprop_getters.register(inner.DynamicRNN)
def get_bprop_dynamic_rnn(self):
    """Grad definition for `DynamicRNN` operation."""
    dynamic_rnn_grad = G.DynamicRNNGrad(forget_bias=self.forget_bias)

    def bprop(x, w, b, seq_length, init_h, init_c, out, dout):
        dy, dh, dc, _, _, _, _, _ = dout
        dh = dh[-1]
        dc = dc[-1]
        y, h, c, i, j, f, o, tanhct = out
        dw, db, dx, dh_prev, dc_prev = dynamic_rnn_grad(x, w, b, y, init_h[0], init_c[0], h,
                                                        c, dy, dh, dc, i, j, f, o, tanhct)
        return dx, dw, db, (0), dh_prev, dc_prev
    return bprop


@bprop_getters.register(P.SigmoidCrossEntropyWithLogits)
def get_bprop_sigmoid_crossentropy_with_logits(self):
    """Grad definition for `SigmoidCrossEntropyWithLogits` operation."""
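Note (not part of the change): a quick shape sketch of why the bprop above slices dh[-1] and dc[-1]. The forward op emits per-step h and c of shape (num_step, batch_size, hidden_size), while DynamicRNNGrad consumes only the last step's gradient of shape (batch_size, hidden_size), matching the dh_shape/dc_shape checks in DynamicRNNGrad.infer_shape further down. Plain NumPy with assumed sizes:

# Shape sketch only -- NumPy stand-ins with hypothetical sizes, not MindSpore tensors.
import numpy as np

num_step, batch_size, hidden_size = 2, 16, 32
dh_all = np.zeros((num_step, batch_size, hidden_size), dtype=np.float16)  # gradient w.r.t. output_h, one slice per step
dh_last = dh_all[-1]                                                      # what `dh = dh[-1]` keeps
assert dh_last.shape == (batch_size, hidden_size)                         # the shape DynamicRNNGrad expects for dh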
@@ -274,6 +274,8 @@ from .basic_lstm_cell import _basic_lstm_cell_tbe
from .basic_lstm_cell_c_state_grad import _basic_lstm_cell_c_state_grad_tbe
from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe
from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe
from .dynamic_rnn import _dynamic_rnn_tbe
from .lstm_input_grad import _lstm_input_grad_tbe
from .confusion_matrix import _confusion_matrix_tbe
from .broadcast_to import _broadcast_to_tbe
from .strided_read import _strided_read_tbe
@@ -0,0 +1,70 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""DynamicRNN op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

dynamic_rnn_op_info = TBERegOp("DynamicRNN") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("dynamic_rnn.so") \
    .compute_cost(10) \
    .kernel_name("dynamic_rnn") \
    .attr("cell_type", "optional", "str", "all", "LSTM") \
    .attr("direction", "optional", "str", "all", "UNIDIRECTIONAL") \
    .attr("cell_depth", "optional", "int", "all", "1") \
    .attr("use_peephole", "optional", "bool", "all", "false") \
    .attr("keep_prob", "optional", "float", "all", "1") \
    .attr("cell_clip", "optional", "float", "all", "-1") \
    .attr("num_proj", "optional", "int", "all", "0") \
    .attr("time_major", "optional", "bool", "all", "false") \
    .attr("forget_bias", "optional", "float", "all", "0") \
    .attr("is_training", "optional", "bool", "all", "true") \
    .partial_flag(True) \
    .input(0, "x", False, "required", "all") \
    .input(1, "w", False, "required", "all", reshape_type="CN") \
    .input(2, "b", False, "required", "all") \
    .input(3, "seq_length", False, "optional", "all") \
    .input(4, "init_h", False, "optional", "all") \
    .input(5, "init_c", False, "optional", "all") \
    .input(6, "wci", False, "optional", "all") \
    .input(7, "wcf", False, "optional", "all") \
    .input(8, "wco", False, "optional", "all") \
    .input(9, "mask", False, "optional", "all") \
    .output(0, "y", False, "required", "all") \
    .output(1, "output_h", False, "required", "all") \
    .output(2, "output_c", False, "required", "all") \
    .output(3, "i", False, "required", "all") \
    .output(4, "j", False, "required", "all") \
    .output(5, "f", False, "required", "all") \
    .output(6, "o", False, "required", "all") \
    .output(7, "tanhc", False, "required", "all") \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F32_Default, DataType.I32_Default,
                  DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
                  DataType.F16_FracNZ, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ,
                  DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
                  DataType.F32_FracNZ, DataType.F32_FracNZ) \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F16_Default, DataType.I32_Default,
                  DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
                  DataType.F16_FracNZ, DataType.U8_Default, DataType.F16_FracNZ, DataType.F16_FracNZ,
                  DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
                  DataType.F16_FracNZ, DataType.F16_FracNZ) \
    .get_op_info()


@op_info_register(dynamic_rnn_op_info)
def _dynamic_rnn_tbe():
    """DynamicRNN TBE register"""
    return
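Reading note on the registration above: each .dtype_format(...) row appears to supply one DataType per declared port, in declaration order -- the 10 inputs (x through mask) followed by the 8 outputs (y through tanhc) -- which is why both rows carry 18 entries. A small self-check in plain Python (illustration only):

# Port order assumed from the .input()/.output() declarations above.
ports = ["x", "w", "b", "seq_length", "init_h", "init_c", "wci", "wcf", "wco", "mask",  # inputs 0-9
         "y", "output_h", "output_c", "i", "j", "f", "o", "tanhc"]                      # outputs 0-7
assert len(ports) == 18  # one DataType per port in each dtype_format(...) row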
@@ -0,0 +1,51 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""LSTMInputGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

lstm_input_grad_op_info = TBERegOp("LSTMInputGrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("lstm_input_grad.so") \
    .compute_cost(10) \
    .kernel_name("lstm_input_grad") \
    .partial_flag(True) \
    .input(0, "w", False, "required", "all") \
    .input(1, "init_c", False, "required", "all") \
    .input(2, "c", False, "required", "all") \
    .input(3, "dy", False, "required", "all") \
    .input(4, "dh", False, "required", "all") \
    .input(5, "dc", False, "required", "all") \
    .input(6, "i", False, "required", "all") \
    .input(7, "j", False, "required", "all") \
    .input(8, "f", False, "required", "all") \
    .input(9, "o", False, "required", "all") \
    .input(10, "tanhct", False, "optional", "all") \
    .output(0, "dx", False, "required", "all") \
    .output(1, "dh_prev", False, "required", "all") \
    .output(2, "dc_prev", False, "required", "all") \
    .output(3, "dgate", False, "required", "all") \
    .dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
                  DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
                  DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
                  DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
    .get_op_info()


@op_info_register(lstm_input_grad_op_info)
def _lstm_input_grad_tbe():
    """LSTMInputGrad TBE register"""
    return
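The same reading applies to LSTMInputGrad: 11 inputs plus 4 outputs account for the 15 entries in its single dtype_format row, all float16 in FRACTAL_NZ format. Illustration only:

inputs = ["w", "init_c", "c", "dy", "dh", "dc", "i", "j", "f", "o", "tanhct"]
outputs = ["dx", "dh_prev", "dc_prev", "dgate"]
assert len(inputs) + len(outputs) == 15  # entries in the dtype_format(...) row above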
@@ -1014,6 +1014,53 @@ class LSTMGrad(PrimitiveWithInfer):
        return (dy_dtype, dy_dtype, dy_dtype, hx_dtype)


class DynamicRNNGrad(PrimitiveWithInfer):
    """Computes the input gradients of DynamicRNN."""

    @prim_attr_register
    def __init__(self,
                 cell_type='LSTM',
                 direction='UNIDIRECTIONAL',
                 cell_depth=0,
                 use_peephole=False,
                 keep_prob=-1.0,
                 cell_clip=-1.0,
                 num_proj=0,
                 time_major=False,
                 forget_bias=0.0):
        self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
        self.add_prim_attr("io_format", "ND")

    def infer_shape(self, x_shape, w_shape, b_shape, y_shape, init_h_shape, init_c_shape, h_shape,
                    c_shape, dy_shape, dh_shape, dc_shape, i_shape, j_shape, f_shape, o_shape, tanhc_shape):
        validator.check_integer("x_shape", len(x_shape), 3, Rel.EQ, self.name)
        num_step, batch_size, input_size = x_shape
        hidden_size = w_shape[-1] // 4
        if w_shape[-1] % 4 != 0:
            raise ValueError(f"For {self.name}, w_shape[-1] should be a multiple of 4.")
        validator.check("w_shape[0]", w_shape[0], "input_size + hidden_size",
                        input_size + hidden_size, Rel.EQ, self.name)
        valid_shape = [num_step, batch_size, hidden_size]
        validator.check("b_shape[0]", b_shape[0], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
        validator.check("y_shape", y_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("h_shape", h_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("c_shape", c_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("i_shape", i_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("j_shape", j_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("f_shape", f_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("o_shape", o_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("tanhc_shape", tanhc_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("dy_shape", dy_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("dh_shape", dh_shape, "expected shape", [batch_size, hidden_size], Rel.EQ, self.name)
        validator.check("dc_shape", dc_shape, "expected shape", [batch_size, hidden_size], Rel.EQ, self.name)

        return w_shape, (w_shape[1],), x_shape, dh_shape, dc_shape

    def infer_dtype(self, x_dtype, w_dtype, b_dtype, y_dtype, init_h_dtype, init_c_dtype, h_dtype,
                    c_dtype, dy_dtype, dh_dtype, dc_dtype, i_dtype, j_dtype, f_dtype, o_dtype, tanhc_dtype):
        return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype


class PReLUGrad(PrimitiveWithInfer):
    r"""
    Gradients of PReLU operation.
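For concreteness, the weight and bias constraints checked above work out as follows with the sizes used in the DynamicRNN docstring example below (input_size=64, hidden_size=32). A sketch of the arithmetic, not part of the change:

input_size, hidden_size = 64, 32
w_shape = (input_size + hidden_size, 4 * hidden_size)  # (96, 128): rows = input_size + hidden_size
b_shape = (4 * hidden_size,)                           # (128,): must equal w_shape[1]
assert w_shape[-1] % 4 == 0                            # hidden_size is recovered as w_shape[-1] // 4
assert w_shape[0] == input_size + hidden_size
assert b_shape[0] == w_shape[1]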
@@ -573,3 +573,111 @@ class MatrixSetDiag(PrimitiveWithInfer):
                        x_shape[:-2] + x_shape[-1:], Rel.EQ, self.name)

        return assist_shape


class DynamicRNN(PrimitiveWithInfer):
    r"""
    DynamicRNN Operator.

    Args:
        cell_type (str): A string identifying the cell type in the op. Default: 'LSTM'.
            Only 'LSTM' is currently supported.
        direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'.
            Only 'UNIDIRECTIONAL' is currently supported.
        cell_depth (int): An integer identifying the cell depth in the op. Default: 1.
        use_peephole (bool): A bool identifying whether to use peephole in the op. Default: False.
        keep_prob (float): A float identifying the keep prob in the op. Default: 1.0.
        cell_clip (float): A float identifying the cell clip in the op. Default: -1.0.
        num_proj (int): An integer identifying the num proj in the op. Default: 0.
        time_major (bool): A bool identifying the time major in the op. Default: False.
        forget_bias (float): A float identifying the forget bias in the op. Default: 0.0.
        is_training (bool): A bool identifying whether the op is in training mode. Default: True.

    Inputs:
        - **x** (Tensor) - Current words. Tensor of shape :math:`(num_step, batch_size, input_size)`.
          The data type must be float16 or float32.
        - **w** (Tensor) - Weight. Tensor of shape :math:`(input_size + hidden_size, 4 x hidden_size)`.
          The data type must be float16 or float32.
        - **b** (Tensor) - Bias. Tensor of shape :math:`(4 x hidden_size)`.
          The data type must be float16 or float32.
        - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(batch_size)`.
          Only `None` is currently supported.
        - **init_h** (Tensor) - Hidden state of initial time. Tensor of shape :math:`(1, batch_size, hidden_size)`.
        - **init_c** (Tensor) - Cell state of initial time. Tensor of shape :math:`(1, batch_size, hidden_size)`.

    Outputs:
        - **y** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          Has the same type as input `b`.
        - **output_h** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          With data type of float16.
        - **output_c** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          Has the same type as input `b`.
        - **i** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          Has the same type as input `b`.
        - **j** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          Has the same type as input `b`.
        - **f** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          Has the same type as input `b`.
        - **o** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          Has the same type as input `b`.
        - **tanhct** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          Has the same type as input `b`.

    Examples:
        >>> x = Tensor(np.random.rand(2, 16, 64).astype(np.float16))
        >>> w = Tensor(np.random.rand(96, 128).astype(np.float16))
        >>> b = Tensor(np.random.rand(128).astype(np.float16))
        >>> init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
        >>> init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
        >>> dynamic_rnn = P.DynamicRNN()
        >>> output = dynamic_rnn(x, w, b, None, init_h, init_c)
    """

    @prim_attr_register
    def __init__(self,
                 cell_type='LSTM',
                 direction='UNIDIRECTIONAL',
                 cell_depth=1,
                 use_peephole=False,
                 keep_prob=1.0,
                 cell_clip=-1.0,
                 num_proj=0,
                 time_major=False,
                 forget_bias=0.0,
                 is_training=True):
        self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
        self.add_prim_attr("io_format", "ND")

    def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape):
        validator.check_integer("x_shape", len(x_shape), 3, Rel.EQ, self.name)
        validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name)
        validator.check_integer("b rank", len(b_shape), 1, Rel.EQ, self.name)
        validator.check_integer("h_shape", len(h_shape), 3, Rel.EQ, self.name)
        validator.check_integer("c_shape", len(c_shape), 3, Rel.EQ, self.name)
        if seq_shape is not None:
            raise ValueError(f"For {self.name}, seq_shape should be None.")

        num_step, batch_size, input_size = x_shape
        hidden_size = w_shape[-1] // 4

        validator.check("b_shape[-1]", b_shape[-1], "w_shape[-1]", w_shape[-1], Rel.EQ, self.name)
        if w_shape[-1] % 4 != 0:
            raise ValueError(f"For {self.name}, w_shape[-1] should be a multiple of 4.")
        validator.check("w_shape[0]", w_shape[0], "input_size + hidden_size",
                        input_size + hidden_size, Rel.EQ, self.name)
        validator.check("b_shape[0]", b_shape[0], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
        validator.check_integer("h_shape[0]", h_shape[0], 1, Rel.EQ, self.name)
        validator.check("h_shape[1]", h_shape[1], "batch_size", batch_size, Rel.EQ, self.name)
        validator.check("h_shape[2]", h_shape[2], "hidden_size", hidden_size, Rel.EQ, self.name)
        validator.check("c_shape", c_shape, "h_shape", h_shape, Rel.EQ, self.name)

        y_shape = (num_step, batch_size, hidden_size)
        return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape

    def infer_dtype(self, x_dtype, w_dtype, b_dtype, seq_dtype, h_dtype, c_dtype):
        validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float32, mstype.float16), self.name)
        validator.check_tensor_type_same({"w dtype": w_dtype}, (mstype.float32, mstype.float16), self.name)
        validator.check_tensor_type_same({"b dtype": b_dtype}, (mstype.float32, mstype.float16), self.name)
        validator.check_tensor_type_same({"h dtype": h_dtype}, (mstype.float32, mstype.float16), self.name)
        validator.check_tensor_type_same({"c dtype": c_dtype}, (mstype.float32, mstype.float16), self.name)
        return b_dtype, x_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype
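One detail worth noting in infer_dtype above: the second return value (output_h) follows x's dtype while the other seven outputs follow b's, consistent with the docstring marking output_h as float16 and the rest as having the same type as input `b`. A small mapping sketch, illustration only:

outputs = ["y", "output_h", "output_c", "i", "j", "f", "o", "tanhct"]
# mirrors: return b_dtype, x_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype
dtype_source = dict(zip(outputs, ["b", "x", "b", "b", "b", "b", "b", "b"]))
assert dtype_source["output_h"] == "x"  # only output_h inherits x's dtype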