Add dynamic_rnn for the new open-source backend.

liuxiao93 2020-09-12 18:38:00 +08:00
parent ce0c7a66cd
commit dd5e4c9f9a
6 changed files with 294 additions and 0 deletions


@@ -805,6 +805,22 @@ def get_bprop_lstm(self):
    return bprop


@bprop_getters.register(inner.DynamicRNN)
def get_bprop_dynamic_rnn(self):
    """Grad definition for `DynamicRNN` operation."""
    dynamic_rnn_grad = G.DynamicRNNGrad(forget_bias=self.forget_bias)

    def bprop(x, w, b, seq_length, init_h, init_c, out, dout):
        dy, dh, dc, _, _, _, _, _ = dout
        # DynamicRNNGrad expects dh/dc from the last time step only
        # (see DynamicRNNGrad.infer_shape below).
        dh = dh[-1]
        dc = dc[-1]
        y, h, c, i, j, f, o, tanhct = out
        dw, db, dx, dh_prev, dc_prev = dynamic_rnn_grad(x, w, b, y, init_h[0], init_c[0], h,
                                                        c, dy, dh, dc, i, j, f, o, tanhct)
        # seq_length is not differentiable, so a zero placeholder is returned for it.
        return dx, dw, db, (0), dh_prev, dc_prev

    return bprop


@bprop_getters.register(P.SigmoidCrossEntropyWithLogits)
def get_bprop_sigmoid_crossentropy_with_logits(self):
    """Grad definition for `SigmoidCrossEntropyWithLogits` operation."""


@@ -274,6 +274,8 @@ from .basic_lstm_cell import _basic_lstm_cell_tbe
from .basic_lstm_cell_c_state_grad import _basic_lstm_cell_c_state_grad_tbe
from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe
from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe
from .dynamic_rnn import _dynamic_rnn_tbe
from .lstm_input_grad import _lstm_input_grad_tbe
from .confusion_matrix import _confusion_matrix_tbe
from .broadcast_to import _broadcast_to_tbe
from .strided_read import _strided_read_tbe


@@ -0,0 +1,70 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""DynamicRNN op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
dynamic_rnn_op_info = TBERegOp("DynamicRNN") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("dynamic_rnn.so") \
.compute_cost(10) \
.kernel_name("dynamic_rnn") \
.attr("cell_type", "optional", "str", "all", "LSTM") \
.attr("direction", "optional", "str", "all", "UNIDIRECTIONAL") \
.attr("cell_depth", "optional", "int", "all", "1") \
.attr("use_peephole", "optional", "bool", "all", "false") \
.attr("keep_prob", "optional", "float", "all", "1") \
.attr("cell_clip", "optional", "float", "all", "-1") \
.attr("num_proj", "optional", "int", "all", "0") \
.attr("time_major", "optional", "bool", "all", "false") \
.attr("forget_bias", "optional", "float", "all", "0") \
.attr("is_training", "optional", "bool", "all", "true") \
.partial_flag(True) \
.input(0, "x", False, "required", "all") \
.input(1, "w", False, "required", "all", reshape_type="CN") \
.input(2, "b", False, "required", "all") \
.input(3, "seq_length", False, "optional", "all") \
.input(4, "init_h", False, "optional", "all") \
.input(5, "init_c", False, "optional", "all") \
.input(6, "wci", False, "optional", "all") \
.input(7, "wcf", False, "optional", "all") \
.input(8, "wco", False, "optional", "all") \
.input(9, "mask", False, "optional", "all") \
.output(0, "y", False, "required", "all") \
.output(1, "output_h", False, "required", "all") \
.output(2, "output_c", False, "required", "all") \
.output(3, "i", False, "required", "all") \
.output(4, "j", False, "required", "all") \
.output(5, "f", False, "required", "all") \
.output(6, "o", False, "required", "all") \
.output(7, "tanhc", False, "required", "all") \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F32_Default, DataType.I32_Default,
DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ,
DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
DataType.F32_FracNZ, DataType.F32_FracNZ) \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F16_Default, DataType.I32_Default,
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.U8_Default, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.F16_FracNZ) \
.get_op_info()
@op_info_register(dynamic_rnn_op_info)
def _dynamic_rnn_tbe():
"""DynamicRNN TBE register"""
return
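
One note on reading the registration above: each `.dtype_format(...)` row supplies exactly one `DataType` per declared port, inputs first and outputs second. An illustrative check of that positional mapping (not part of the diff):

# Illustrative only: the first dtype_format row, zipped against the
# declared ports (10 inputs, then 8 outputs -- 18 entries in total).
from mindspore.ops.op_info_register import DataType

ports = ["x", "w", "b", "seq_length", "init_h", "init_c",
         "wci", "wcf", "wco", "mask",                               # inputs 0-9
         "y", "output_h", "output_c", "i", "j", "f", "o", "tanhc"]  # outputs 0-7
first_row = (DataType.F16_FracNZ, DataType.F16_FracZNLSTM, DataType.F32_Default,
             DataType.I32_Default, DataType.F16_FracNZ, DataType.F32_FracNZ,
             DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
             DataType.U8_Default, DataType.F32_FracNZ, DataType.F16_FracNZ,
             DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
             DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ)
assert len(first_row) == len(ports) == 18
for port, dtype in zip(ports, first_row):
    print(port, dtype)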


@@ -0,0 +1,51 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""LSTMInputGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
lstm_input_grad_op_info = TBERegOp("LSTMInputGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("lstm_input_grad.so") \
.compute_cost(10) \
.kernel_name("lstm_input_grad") \
.partial_flag(True) \
.input(0, "w", False, "required", "all") \
.input(1, "init_c", False, "required", "all") \
.input(2, "c", False, "required", "all") \
.input(3, "dy", False, "required", "all") \
.input(4, "dh", False, "required", "all") \
.input(5, "dc", False, "required", "all") \
.input(6, "i", False, "required", "all") \
.input(7, "j", False, "required", "all") \
.input(8, "f", False, "required", "all") \
.input(9, "o", False, "required", "all") \
.input(10, "tanhct", False, "optional", "all") \
.output(0, "dx", False, "required", "all") \
.output(1, "dh_prev", False, "required", "all") \
.output(2, "dc_prev", False, "required", "all") \
.output(3, "dgate", False, "required", "all") \
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
.get_op_info()
@op_info_register(lstm_input_grad_op_info)
def _lstm_input_grad_tbe():
"""LSTMInputGrad TBE register"""
return
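
The registration above only declares the kernel's interface. As a reference for the math it stands for, here is a NumPy sketch of a single-step LSTM input gradient under the standard LSTM equations, assuming gate order (i, j, f, o) with `j` the tanh candidate; the fused Ascend kernel's tiling and data layout differ.

# Illustrative only: single-step LSTM input-gradient math, under the
# standard LSTM equations with gate order (i, j, f, o).
import numpy as np

def lstm_input_grad_step(w, c_prev, dy, dh, dc, i, j, f, o, tanhct):
    dh = dh + dy                           # dy folds into the hidden grad
    do = dh * tanhct
    dct = dc + dh * o * (1.0 - tanhct ** 2)
    di = dct * j
    dj = dct * i
    df = dct * c_prev
    dc_prev = dct * f
    # back through the gate nonlinearities (sigmoid for i/f/o, tanh for j)
    dgate = np.concatenate([di * i * (1 - i), dj * (1 - j ** 2),
                            df * f * (1 - f), do * o * (1 - o)], axis=-1)
    dxh = dgate @ w.T                      # w: (input_size + hidden, 4 * hidden)
    input_size = w.shape[0] - dh.shape[-1]
    dx, dh_prev = dxh[..., :input_size], dxh[..., input_size:]
    return dx, dh_prev, dc_prev, dgate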


@@ -1014,6 +1014,53 @@ class LSTMGrad(PrimitiveWithInfer):
        return (dy_dtype, dy_dtype, dy_dtype, hx_dtype)


class DynamicRNNGrad(PrimitiveWithInfer):
    """Computes the input gradients of DynamicRNN."""

    @prim_attr_register
    def __init__(self,
                 cell_type='LSTM',
                 direction='UNIDIRECTIONAL',
                 cell_depth=0,
                 use_peephole=False,
                 keep_prob=-1.0,
                 cell_clip=-1.0,
                 num_proj=0,
                 time_major=False,
                 forget_bias=0.0):
        self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
        self.add_prim_attr("io_format", "ND")

    def infer_shape(self, x_shape, w_shape, b_shape, y_shape, init_h_shape, init_c_shape, h_shape,
                    c_shape, dy_shape, dh_shape, dc_shape, i_shape, j_shape, f_shape, o_shape, tanhc_shape):
        validator.check_integer("x_shape", len(x_shape), 3, Rel.EQ, self.name)
        num_step, batch_size, input_size = x_shape
        hidden_size = w_shape[-1] // 4
        if w_shape[-1] % 4 != 0:
            raise ValueError(f"For {self.name}, w_shape[-1] should be a multiple of 4.")
        validator.check("w_shape[0]", w_shape[0], "input_size + hidden_size",
                        input_size + hidden_size, Rel.EQ, self.name)
        valid_shape = [num_step, batch_size, hidden_size]
        validator.check("b_shape[0]", b_shape[0], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
        validator.check("y_shape", y_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("h_shape", h_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("c_shape", c_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("i_shape", i_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("j_shape", j_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("f_shape", f_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("o_shape", o_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("tanhc_shape", tanhc_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("dy_shape", dy_shape, "expected shape", valid_shape, Rel.EQ, self.name)
        validator.check("dh_shape", dh_shape, "expected shape", [batch_size, hidden_size], Rel.EQ, self.name)
        validator.check("dc_shape", dc_shape, "expected shape", [batch_size, hidden_size], Rel.EQ, self.name)
        # outputs: dw, db, dx, dh_prev, dc_prev
        return w_shape, (w_shape[1],), x_shape, dh_shape, dc_shape

    def infer_dtype(self, x_dtype, w_dtype, b_dtype, y_dtype, init_h_dtype, init_c_dtype, h_dtype,
                    c_dtype, dy_dtype, dh_dtype, dc_dtype, i_dtype, j_dtype, f_dtype, o_dtype, tanhc_dtype):
        return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype


class PReLUGrad(PrimitiveWithInfer):
    r"""
    Gradients of PReLU operation.

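To make `DynamicRNNGrad.infer_shape` concrete, an illustrative shape walk-through (values chosen to match the DynamicRNN docstring example below):

# Illustrative only: shapes DynamicRNNGrad expects and produces for
# num_step=2, batch_size=16, input_size=64, hidden_size=32.
num_step, batch_size, input_size, hidden_size = 2, 16, 64, 32
x_shape = (num_step, batch_size, input_size)           # (2, 16, 64)
w_shape = (input_size + hidden_size, 4 * hidden_size)  # (96, 128)
b_shape = (4 * hidden_size,)                           # (128,)
step_shape = (num_step, batch_size, hidden_size)       # y, h, c, i, j, f, o, tanhc, dy
last_shape = (batch_size, hidden_size)                 # dh, dc: last time step only
# returned: dw = w_shape, db = (w_shape[1],), dx = x_shape,
#           dh_prev = last_shape, dc_prev = last_shape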

@@ -573,3 +573,111 @@ class MatrixSetDiag(PrimitiveWithInfer):
                        x_shape[:-2] + x_shape[-1:], Rel.EQ, self.name)
        return assist_shape


class DynamicRNN(PrimitiveWithInfer):
    r"""
    DynamicRNN Operator.

    Args:
        cell_type (str): A string identifying the cell type in the op. Default: 'LSTM'.
            Only 'LSTM' is currently supported.
        direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'.
            Only 'UNIDIRECTIONAL' is currently supported.
        cell_depth (int): An integer identifying the cell depth in the op. Default: 1.
        use_peephole (bool): A bool identifying whether to use peephole in the op. Default: False.
        keep_prob (float): A float identifying the keep prob in the op. Default: 1.0.
        cell_clip (float): A float identifying the cell clip in the op. Default: -1.0.
        num_proj (int): An integer identifying the num proj in the op. Default: 0.
        time_major (bool): A bool identifying the time major in the op. Default: False.
        forget_bias (float): A float identifying the forget bias in the op. Default: 0.0.
        is_training (bool): A bool identifying whether the op is in training mode. Default: True.

    Inputs:
        - **x** (Tensor) - Current words. Tensor of shape :math:`(num_step, batch_size, input_size)`.
          The data type must be float16 or float32.
        - **w** (Tensor) - Weight. Tensor of shape :math:`(input_size + hidden_size, 4 x hidden_size)`.
          The data type must be float16 or float32.
        - **b** (Tensor) - Bias. Tensor of shape :math:`(4 x hidden_size)`.
          The data type must be float16 or float32.
        - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(batch_size)`.
          Only `None` is currently supported.
        - **init_h** (Tensor) - Hidden state of initial time. Tensor of shape :math:`(1, batch_size, hidden_size)`.
        - **init_c** (Tensor) - Cell state of initial time. Tensor of shape :math:`(1, batch_size, hidden_size)`.

    Outputs:
        - **y** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          Has the same type as input `b`.
        - **output_h** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          With data type of float16.
        - **output_c** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          Has the same type as input `b`.
        - **i** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          Has the same type as input `b`.
        - **j** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          Has the same type as input `b`.
        - **f** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          Has the same type as input `b`.
        - **o** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          Has the same type as input `b`.
        - **tanhct** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
          Has the same type as input `b`.

    Examples:
        >>> x = Tensor(np.random.rand(2, 16, 64).astype(np.float16))
        >>> w = Tensor(np.random.rand(96, 128).astype(np.float16))
        >>> b = Tensor(np.random.rand(128).astype(np.float16))
        >>> init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
        >>> init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
        >>> dynamic_rnn = P.DynamicRNN()
        >>> output = dynamic_rnn(x, w, b, None, init_h, init_c)
    """

    @prim_attr_register
    def __init__(self,
                 cell_type='LSTM',
                 direction='UNIDIRECTIONAL',
                 cell_depth=1,
                 use_peephole=False,
                 keep_prob=1.0,
                 cell_clip=-1.0,
                 num_proj=0,
                 time_major=False,
                 forget_bias=0.0,
                 is_training=True):
        self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
        self.add_prim_attr("io_format", "ND")

    def infer_shape(self, x_shape, w_shape, b_shape, seq_shape, h_shape, c_shape):
        validator.check_integer("x_shape", len(x_shape), 3, Rel.EQ, self.name)
        validator.check_integer("w rank", len(w_shape), 2, Rel.EQ, self.name)
        validator.check_integer("b rank", len(b_shape), 1, Rel.EQ, self.name)
        validator.check_integer("h_shape", len(h_shape), 3, Rel.EQ, self.name)
        validator.check_integer("c_shape", len(c_shape), 3, Rel.EQ, self.name)
        if seq_shape is not None:
            raise ValueError(f"For {self.name}, seq_shape should be None.")
        num_step, batch_size, input_size = x_shape
        hidden_size = w_shape[-1] // 4
        validator.check("b_shape[-1]", b_shape[-1], "w_shape[-1]", w_shape[-1], Rel.EQ, self.name)
        if w_shape[-1] % 4 != 0:
            raise ValueError(f"For {self.name}, w_shape[-1] should be a multiple of 4.")
        validator.check("w_shape[0]", w_shape[0], "input_size + hidden_size",
                        input_size + hidden_size, Rel.EQ, self.name)
        validator.check("b_shape[0]", b_shape[0], "w_shape[1]", w_shape[1], Rel.EQ, self.name)
        validator.check_integer("h_shape[0]", h_shape[0], 1, Rel.EQ, self.name)
        validator.check("h_shape[1]", h_shape[1], "batch_size", batch_size, Rel.EQ, self.name)
        validator.check("h_shape[2]", h_shape[2], "hidden_size", hidden_size, Rel.EQ, self.name)
        validator.check("c_shape", c_shape, "h_shape", h_shape, Rel.EQ, self.name)
        y_shape = (num_step, batch_size, hidden_size)
        # y, output_h, output_c, i, j, f, o, tanhc all share the per-step shape
        return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape

    def infer_dtype(self, x_dtype, w_dtype, b_dtype, seq_dtype, h_dtype, c_dtype):
        validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float32, mstype.float16), self.name)
        validator.check_tensor_type_same({"w dtype": w_dtype}, (mstype.float32, mstype.float16), self.name)
        validator.check_tensor_type_same({"b dtype": b_dtype}, (mstype.float32, mstype.float16), self.name)
        validator.check_tensor_type_same({"h dtype": h_dtype}, (mstype.float32, mstype.float16), self.name)
        validator.check_tensor_type_same({"c dtype": c_dtype}, (mstype.float32, mstype.float16), self.name)
        return b_dtype, x_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype
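
Finally, as a reference for the eight outputs DynamicRNN declares, a NumPy sketch of one time step under the standard LSTM equations, again assuming gate order (i, j, f, o) with `j` the tanh candidate (illustrative only; the Ascend kernel fuses all `num_step` steps and applies `forget_bias` as registered above):

# Illustrative only: one DynamicRNN time step under the standard LSTM
# equations, gate order (i, j, f, o) assumed.
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def dynamic_rnn_step(x_t, h_prev, c_prev, w, b, forget_bias=0.0):
    # one fused matmul over [x_t, h_prev] produces all four gate pre-activations
    gates = np.concatenate([x_t, h_prev], axis=-1) @ w + b
    i_t, j_t, f_t, o_t = np.split(gates, 4, axis=-1)
    i_t, j_t = sigmoid(i_t), np.tanh(j_t)
    f_t, o_t = sigmoid(f_t + forget_bias), sigmoid(o_t)
    c_t = f_t * c_prev + i_t * j_t
    tanhc = np.tanh(c_t)
    h_t = o_t * tanhc
    return h_t, c_t, i_t, j_t, f_t, o_t, tanhc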