!8134 Add TBE op DynamicGRU for new and old backend.
Merge pull request !8134 from liuxiao93/DynamicGRU-new-old-backend
This commit is contained in:
commit
bbfef6233a
|
@ -187,6 +187,8 @@ constexpr const char kNameBasicLSTMCellWeightGrad[] = "BasicLSTMCellWeightGrad";
|
|||
constexpr const char kNameBasicLSTMCellCStateGrad[] = "BasicLSTMCellCStateGrad";
|
||||
constexpr const char kNameDynamicRNN[] = "DynamicRNN";
|
||||
constexpr const char kNameDynamicRNNGrad[] = "DynamicRNNGrad";
|
||||
constexpr const char kNameDynamicGRUV2[] = "DynamicGRUV2";
|
||||
constexpr const char kNameDynamicGRUV2Grad[] = "DynamicGRUV2Grad";
|
||||
constexpr const char kNameL2Loss[] = "L2Loss";
|
||||
constexpr const char kNameCTCLoss[] = "CTCLoss";
|
||||
constexpr const char kNameRange[] = "Range";
|
||||
|
|
|
@ -92,4 +92,42 @@ OUTPUT_MAP(DynamicRNNGrad) = {{0, OUTPUT_DESC(dw)},
|
|||
{3, OUTPUT_DESC(dh_prev)},
|
||||
{4, OUTPUT_DESC(dc_prev)}};
|
||||
REG_ADPT_DESC(DynamicRNNGrad, kNameDynamicRNNGrad, ADPT_DESC(DynamicRNNGrad))
|
||||
|
||||
// DynamicGRUV2
|
||||
INPUT_MAP(DynamicGRUV2) = {{1, INPUT_DESC(x)}, {2, INPUT_DESC(weight_input)}, {3, INPUT_DESC(weight_hidden)},
|
||||
{4, INPUT_DESC(bias_input)}, {5, INPUT_DESC(bias_hidden)}, {6, INPUT_DESC(seq_length)},
|
||||
{7, INPUT_DESC(init_h)}};
|
||||
ATTR_MAP(DynamicGRUV2) = {{"direction", ATTR_DESC(direction, AnyTraits<std::string>())},
|
||||
{"cell_depth", ATTR_DESC(cell_depth, AnyTraits<int64_t>())},
|
||||
{"keep_prob", ATTR_DESC(keep_prob, AnyTraits<float>())},
|
||||
{"cell_clip", ATTR_DESC(cell_clip, AnyTraits<float>())},
|
||||
{"num_proj", ATTR_DESC(num_proj, AnyTraits<int64_t>())},
|
||||
{"time_major", ATTR_DESC(time_major, AnyTraits<bool>())},
|
||||
{"activation", ATTR_DESC(direction, AnyTraits<std::string>())},
|
||||
{"gate_order", ATTR_DESC(gate_order, AnyTraits<std::string>())},
|
||||
{"reset_after", ATTR_DESC(reset_after, AnyTraits<bool>())},
|
||||
{"is_training", ATTR_DESC(is_training, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(DynamicGRUV2) = {{0, OUTPUT_DESC(y)}, {1, OUTPUT_DESC(output_h)}, {2, OUTPUT_DESC(update)},
|
||||
{3, OUTPUT_DESC(reset)}, {4, OUTPUT_DESC(new)}, {5, OUTPUT_DESC(hidden_new)}};
|
||||
REG_ADPT_DESC(DynamicGRUV2, kNameDynamicGRUV2, ADPT_DESC(DynamicGRUV2))
|
||||
|
||||
// DynamicGRUV2Grad
|
||||
INPUT_MAP(DynamicGRUV2Grad) = {
|
||||
{1, INPUT_DESC(x)}, {2, INPUT_DESC(weight_input)}, {3, INPUT_DESC(weight_hidden)},
|
||||
{4, INPUT_DESC(y)}, {5, INPUT_DESC(init_h)}, {6, INPUT_DESC(h)},
|
||||
{7, INPUT_DESC(dy)}, {8, INPUT_DESC(dh)}, {9, INPUT_DESC(update)},
|
||||
{10, INPUT_DESC(reset)}, {11, INPUT_DESC(new)}, {12, INPUT_DESC(hidden_new)},
|
||||
{13, INPUT_DESC(seq_length)}, {14, INPUT_DESC(mask)}};
|
||||
ATTR_MAP(DynamicGRUV2Grad) = {{"direction", ATTR_DESC(direction, AnyTraits<std::string>())},
|
||||
{"cell_depth", ATTR_DESC(cell_depth, AnyTraits<int64_t>())},
|
||||
{"keep_prob", ATTR_DESC(keep_prob, AnyTraits<float>())},
|
||||
{"cell_clip", ATTR_DESC(cell_clip, AnyTraits<float>())},
|
||||
{"num_proj", ATTR_DESC(num_proj, AnyTraits<int64_t>())},
|
||||
{"time_major", ATTR_DESC(time_major, AnyTraits<bool>())},
|
||||
{"bias_type", ATTR_DESC(bias_type, AnyTraits<std::string>())},
|
||||
{"gate_order", ATTR_DESC(gate_order, AnyTraits<std::string>())},
|
||||
{"reset_after", ATTR_DESC(reset_after, AnyTraits<bool>())}};
|
||||
OUTPUT_MAP(DynamicGRUV2Grad) = {{0, OUTPUT_DESC(dw_input)}, {1, OUTPUT_DESC(dw_hidden)}, {2, OUTPUT_DESC(db_input)},
|
||||
{3, OUTPUT_DESC(db_hidden)}, {4, OUTPUT_DESC(dx)}, {5, OUTPUT_DESC(dh_prev)}};
|
||||
REG_ADPT_DESC(DynamicGRUV2Grad, kNameDynamicGRUV2Grad, ADPT_DESC(DynamicGRUV2Grad))
|
||||
} // namespace mindspore::transform
|
||||
|
|
|
@ -40,5 +40,11 @@ DECLARE_OP_USE_OUTPUT(DynamicRNN)
|
|||
|
||||
DECLARE_OP_ADAPTER(DynamicRNNGrad)
|
||||
DECLARE_OP_USE_OUTPUT(DynamicRNNGrad)
|
||||
|
||||
DECLARE_OP_ADAPTER(DynamicGRUV2)
|
||||
DECLARE_OP_USE_OUTPUT(DynamicGRUV2)
|
||||
|
||||
DECLARE_OP_ADAPTER(DynamicGRUV2Grad)
|
||||
DECLARE_OP_USE_OUTPUT(DynamicGRUV2Grad)
|
||||
} // namespace mindspore::transform
|
||||
#endif // MINDSPORE_CCSRC_TRANSFORM_GRAPH_IR_OP_DECLARE_RNN_DECLARE_H_
|
||||
|
|
|
@ -225,6 +225,8 @@ constexpr auto kBasicLSTMCellInputGradOpName = "BasicLSTMCellInputGrad";
|
|||
constexpr auto kBasicLSTMCellOpName = "BasicLSTMCell";
|
||||
constexpr auto kDynamicRNNOpName = "DynamicRNN";
|
||||
constexpr auto kLSTMInputGradOpName = "LSTMInputGrad";
|
||||
constexpr auto kDynamicGRUOpName = "DynamicGRU";
|
||||
constexpr auto kGRUV2HiddenGrad = "GRUV2HiddenGrad";
|
||||
constexpr auto kFusedSparseFtrlName = "FusedSparseFtrl";
|
||||
constexpr auto kFusedSparseProximalAdagradName = "FusedSparseProximalAdagrad";
|
||||
constexpr auto kFusedSparseLazyAdamName = "FusedSparseLazyAdam";
|
||||
|
|
|
@ -110,6 +110,8 @@ inline const PrimitivePtr kPrimUniqueGrad = std::make_shared<Primitive>("UniqueG
|
|||
inline const PrimitivePtr kPrimExtractImagePatches = std::make_shared<Primitive>("ExtractImagePatches");
|
||||
inline const PrimitivePtr kPrimDynamicRNN = std::make_shared<Primitive>("DynamicRNN");
|
||||
inline const PrimitivePtr kPrimDynamicRNNGrad = std::make_shared<Primitive>("DynamicRNNGrad");
|
||||
inline const PrimitivePtr kPrimDynamicGRUV2 = std::make_shared<Primitive>("DynamicGRUV2");
|
||||
inline const PrimitivePtr kPrimDynamicGRUV2Grad = std::make_shared<Primitive>("DynamicGRUV2Grad");
|
||||
inline const PrimitivePtr kPrimScatterAdd = std::make_shared<Primitive>("ScatterAdd");
|
||||
inline const PrimitivePtr kPrimScatterUpdate = std::make_shared<Primitive>("ScatterUpdate");
|
||||
inline const PrimitivePtr kPrimDiv = std::make_shared<Primitive>("Div");
|
||||
|
|
|
@ -849,7 +849,16 @@ def get_bprop_lstm(self):
|
|||
@bprop_getters.register(P.DynamicRNN)
|
||||
def get_bprop_dynamic_rnn(self):
|
||||
"""Grad definition for `DynamicRNN` operation."""
|
||||
dynamic_rnn_grad = G.DynamicRNNGrad(forget_bias=self.forget_bias)
|
||||
dynamic_rnn_grad = G.DynamicRNNGrad(cell_type=self.cell_type,
|
||||
direction=self.direction,
|
||||
cell_depth=self.cell_depth,
|
||||
use_peephole=self.use_peephole,
|
||||
keep_prob=self.keep_prob,
|
||||
cell_clip=self.cell_clip,
|
||||
num_proj=self.num_proj,
|
||||
time_major=self.time_major,
|
||||
forget_bias=self.forget_bias)
|
||||
expand_dims = P.ExpandDims()
|
||||
|
||||
def bprop(x, w, b, seq_length, init_h, init_c, out, dout):
|
||||
dy, dh, dc, _, _, _, _, _, = dout
|
||||
|
@ -858,10 +867,30 @@ def get_bprop_dynamic_rnn(self):
|
|||
y, h, c, i, j, f, o, tanhct = out
|
||||
dw, db, dx, dh_prev, dc_prev = dynamic_rnn_grad(x, w, b, y, init_h[0], init_c[0], h,
|
||||
c, dy, dh, dc, i, j, f, o, tanhct)
|
||||
dh_prev = expand_dims(dh_prev, 0)
|
||||
dc_prev = expand_dims(dc_prev, 0)
|
||||
return dx, dw, db, (0), dh_prev, dc_prev
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(inner.DynamicGRUV2)
|
||||
def get_bprop_dynamic_gru_v2(self):
|
||||
"""Grad definition for `DynamicGRUV2` operation."""
|
||||
dynamic_gru_v2_grad = G.DynamicGRUV2Grad(self.direction, self.cell_depth, self.keep_prob, self.cell_clip,
|
||||
self.num_proj, self.time_major, 'double_bias', self.gate_order,
|
||||
self.reset_after)
|
||||
|
||||
def bprop(x, winput, whidden, binput, bhidden, seq, init_h, out, dout):
|
||||
y, out_h, update, reset, new, hidden_new = out
|
||||
dy, dout_h, _, _, _, _ = dout
|
||||
|
||||
dw_input, dw_hidden, db_input, db_hidden, dx, dh_prev = dynamic_gru_v2_grad(x, winput, whidden, y, init_h,
|
||||
out_h, dy, dout_h[-1], update,
|
||||
reset, new, hidden_new, None, None)
|
||||
return dx, dw_input, dw_hidden, db_input, db_hidden, (0), dh_prev
|
||||
return bprop
|
||||
|
||||
|
||||
@bprop_getters.register(P.SigmoidCrossEntropyWithLogits)
|
||||
def get_bprop_sigmoid_crossentropy_with_logits(self):
|
||||
"""Grad definition for `SigmoidCrossEntropyWithLogits` operation."""
|
||||
|
|
|
@ -286,6 +286,8 @@ from .basic_lstm_cell_c_state_grad import _basic_lstm_cell_c_state_grad_tbe
|
|||
from .basic_lstm_cell_weight_grad import _basic_lstm_cell_weight_grad_tbe
|
||||
from .basic_lstm_cell_input_grad import _basic_lstm_cell_input_grad_tbe
|
||||
from .dynamic_rnn import _dynamic_rnn_tbe
|
||||
from .dynamic_gru_v2 import _dynamic_gru_v2_tbe
|
||||
from .gru_v2_hidden_grad import _gru_v2_hidden_grad_tbe
|
||||
from .lstm_input_grad import _lstm_input_grad_tbe
|
||||
from .confusion_matrix import _confusion_matrix_tbe
|
||||
from .broadcast_to import _broadcast_to_tbe
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""DynamicGRUV2 op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
dynamic_gru_v2_op_info = TBERegOp("DynamicGRUV2") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("dynamic_gru_v2.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("dynamic_gru_v2") \
|
||||
.attr("direction", "optional", "str", "all", "UNIDIRECTIONAL") \
|
||||
.attr("cell_depth", "optional", "int", "all", "1") \
|
||||
.attr("keep_prob", "optional", "float", "all", "1") \
|
||||
.attr("cell_clip", "optional", "float", "all", "-1") \
|
||||
.attr("num_proj", "optional", "int", "all", "0") \
|
||||
.attr("time_major", "optional", "bool", "all", "true") \
|
||||
.attr("activation", "optional", "str", "all", "tanh") \
|
||||
.attr("gate_order", "optional", "str", "all", "rzh") \
|
||||
.attr("reset_after", "optional", "bool", "all", "true") \
|
||||
.attr("is_training", "optional", "bool", "all", "true") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "x", False, "required", "all") \
|
||||
.input(1, "weight_input", False, "required", "all") \
|
||||
.input(2, "weight_hidden", False, "required", "all") \
|
||||
.input(3, "bias_input", False, "optional", "all") \
|
||||
.input(4, "bias_hidden", False, "optional", "all") \
|
||||
.input(5, "seq_length", False, "optional", "all") \
|
||||
.input(6, "init_h", False, "optional", "all") \
|
||||
.output(0, "y", False, "required", "all") \
|
||||
.output(1, "output_h", False, "required", "all") \
|
||||
.output(2, "update", False, "optional", "all") \
|
||||
.output(3, "reset", False, "optional", "all") \
|
||||
.output(4, "new", False, "optional", "all") \
|
||||
.output(5, "hidden_new", False, "optional", "all") \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F32_Default,
|
||||
DataType.F32_Default, DataType.I32_Default, DataType.F32_FracNZ, DataType.F32_FracNZ,
|
||||
DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
|
||||
DataType.F32_FracNZ) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_Default,
|
||||
DataType.F16_Default, DataType.I32_Default, DataType.F16_FracNZ, DataType.F16_FracNZ,
|
||||
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
|
||||
DataType.F16_FracNZ) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(dynamic_gru_v2_op_info)
|
||||
def _dynamic_gru_v2_tbe():
|
||||
"""DynamicGRUV2 TBE register"""
|
||||
return
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""GRUV2HiddenGrad op"""
|
||||
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
|
||||
|
||||
gru_v2_hidden_grad_op_info = TBERegOp("GRUV2HiddenGrad") \
|
||||
.fusion_type("OPAQUE") \
|
||||
.async_flag(False) \
|
||||
.binfile_name("gru_v2_hidden_grad.so") \
|
||||
.compute_cost(10) \
|
||||
.kernel_name("gru_v2_hidden_grad") \
|
||||
.attr("gate_order", "optional", "str", "all", "zrh") \
|
||||
.partial_flag(True) \
|
||||
.input(0, "weight_input", False, "required", "all") \
|
||||
.input(1, "init_h", False, "required", "all") \
|
||||
.input(2, "h", False, "required", "all") \
|
||||
.input(3, "dy", False, "optional", "all") \
|
||||
.input(4, "dh", False, "optional", "all") \
|
||||
.input(5, "update", False, "optional", "all") \
|
||||
.input(6, "reset", False, "optional", "all") \
|
||||
.input(7, "new", False, "optional", "all") \
|
||||
.input(8, "hidden_new", False, "optional", "all") \
|
||||
.output(0, "dh_preh", False, "required", "all") \
|
||||
.output(1, "dgate_h", False, "required", "all") \
|
||||
.output(2, "dnt_x", False, "optional", "all") \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
|
||||
DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ,
|
||||
DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ, DataType.F32_FracNZ) \
|
||||
.dtype_format(DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
|
||||
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ,
|
||||
DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ, DataType.F16_FracNZ) \
|
||||
.get_op_info()
|
||||
|
||||
|
||||
@op_info_register(gru_v2_hidden_grad_op_info)
|
||||
def _gru_v2_hidden_grad_tbe():
|
||||
"""DynamicGRUV2 TBE register"""
|
||||
return
|
|
@ -1095,9 +1095,9 @@ class DynamicRNNGrad(PrimitiveWithInfer):
|
|||
def __init__(self,
|
||||
cell_type='LSTM',
|
||||
direction='UNIDIRECTIONAL',
|
||||
cell_depth=0,
|
||||
cell_depth=1,
|
||||
use_peephole=False,
|
||||
keep_prob=-1.0,
|
||||
keep_prob=1.0,
|
||||
cell_clip=-1.0,
|
||||
num_proj=0,
|
||||
time_major=True,
|
||||
|
@ -1135,6 +1135,147 @@ class DynamicRNNGrad(PrimitiveWithInfer):
|
|||
return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype
|
||||
|
||||
|
||||
class DynamicGRUV2Grad(PrimitiveWithInfer):
|
||||
r"""
|
||||
Computes the input gradients of DynamicGRUV2.
|
||||
|
||||
Args:
|
||||
direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'.
|
||||
Only 'UNIDIRECTIONAL' is currently supported.
|
||||
cell_depth (int): An integer identifying the cell depth in the op. Default: 1.
|
||||
keep_prob (float): A float identifying the keep prob in the op. Default: 1.0.
|
||||
cell_clip (float): A float identifying the cell clip in the op. Default: -1.0.
|
||||
num_proj (int): An integer identifying the num proj in the op. Default: 0.
|
||||
time_major (bool): A bool identifying the time major in the op. Default: True.
|
||||
bias_type (str): An string identifying the type of bias_type function in the op. Default to "double_bias".
|
||||
gate_order (str): An string identifying the gate order in weight and bias. Default: 'rzh.
|
||||
'zrh' is another option.
|
||||
reset_after (bool): An bool identifying whether to apply reset gate after matrix multiplication. Default: True.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Current words. Tensor of shape :math:`({num_step, batch_size, input_size)`.
|
||||
The data type must be float16 or float32.
|
||||
- **weight_input** (Tensor) - Weight. Tensor of shape :math:`(input_size, 3 x hidden_size)`.
|
||||
The data type must be float16 or float32.
|
||||
- **weight_hidden** (Tensor) - Bias. Tensor of shape :math:`(hidden_size, 3 x hidden_size)`.
|
||||
The data type must be float16 or float32.
|
||||
- **y** (Tensor) - A Tensor of shape :math:
|
||||
if num_proj > 0 `(num_step, batch_size, min(hidden_size, num_proj)`,
|
||||
if num_proj == 0 `(num_step, batch_size, hidden_size)`.
|
||||
The data type must be float16 or float32.
|
||||
- **init_h** (Tensor) - Hidden state of initial time.
|
||||
Tensor of shape :math:`(batch_size, hidden_size)`, or None.
|
||||
The data type must be float16 or float32.
|
||||
- **h** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`.
|
||||
The data type must be float16 or float32.
|
||||
- **dy** (Tensor) - Gradient of `y`, has the same shape and data type as `y`.
|
||||
- **dh** (Tensor) - Gradient of `h`, has the same shape and data type as `h`.
|
||||
- **update** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`.
|
||||
The data type must be float16 or float32.
|
||||
- **reset** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`.
|
||||
The data type must be float16 or float32.
|
||||
- **new** (Tensor) - A Tensor of shape :math:`({num_step, batch_size, hidden_size)`.
|
||||
The data type must be float16 or float32.
|
||||
- **hidden_new** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
|
||||
The data type must be float16 or float32.
|
||||
- **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(batch_size)`.
|
||||
Only `None` is currently supported.
|
||||
- **mask** (Tensor) - A 4-D Tensor. The data type must be float16 or float32.
|
||||
|
||||
Outputs:
|
||||
- **dw_input** (Tensor) - A Tensor has the same shape as `weight_input`.
|
||||
Has the same type with input `x`.
|
||||
- **dw_hidden** (Tensor) - A Tensor has the same shape as `weight_hidden`.
|
||||
Has the same type with input `x`.
|
||||
- **db_input** (Tensor) - A Tensor of shape :math:`(3 x hidden_size)`.
|
||||
Has the same type with input `x`.
|
||||
- **db_hidden** (Tensor) - A Tensor of shape :math:`(3 x hidden_size)`.
|
||||
Has the same type with input `x`.
|
||||
- **dx** (Tensor) - A Tensor of shape :math:`(num_step, batch_size, hidden_size)`.
|
||||
Has the same type with input `x`.
|
||||
- **dh_prev** (Tensor) - A Tensor of shape :math:`(batch_size, hidden_size)`.
|
||||
Has the same type with input `x`.
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self,
|
||||
direction='UNIDIRECTIONAL',
|
||||
cell_depth=1,
|
||||
keep_prob=1.0,
|
||||
cell_clip=-1.0,
|
||||
num_proj=0,
|
||||
time_major=True,
|
||||
bias_type="double_bias",
|
||||
gate_order="zrh",
|
||||
reset_after=True):
|
||||
self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name)
|
||||
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
|
||||
self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name)
|
||||
self.num_proj = validator.check_non_negative_int(num_proj, "num_proj", self.name)
|
||||
self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name)
|
||||
self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
|
||||
self.bias_type = validator.check_string(bias_type,
|
||||
['no_bias', 'single_bias', 'double_bias'], "bias_type", self.name)
|
||||
self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name)
|
||||
self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
|
||||
self.add_prim_attr("io_format", "ND")
|
||||
|
||||
def infer_shape(self, x_shape, winput_shape, whidden_shape, y_shape, init_h_shape, h_shape,
|
||||
dy_shape, dh_shape, update_shape, reset_shape, new_shape, hnew_shape, seq_shape, mask_shape):
|
||||
validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name)
|
||||
validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name)
|
||||
validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name)
|
||||
validator.check_int(len(y_shape), 3, Rel.EQ, "y shape rank", self.name)
|
||||
num_step, batch_size, input_size = x_shape
|
||||
hidden_size = whidden_shape[0]
|
||||
validator.check("weight_hidden_shape[-1]", whidden_shape[-1], "3 * hidden_size",
|
||||
3 * hidden_size, Rel.EQ, self.name)
|
||||
validator.check("weight_input_shape", winput_shape, "excepted shape",
|
||||
[input_size, 3 * hidden_size], Rel.EQ, self.name)
|
||||
if self.num_proj > 0:
|
||||
valid_y_shape = [num_step, batch_size, min(hidden_size, self.num_proj)]
|
||||
else:
|
||||
valid_y_shape = [num_step, batch_size, hidden_size]
|
||||
validator.check("y_shape", y_shape, "excepted shape", valid_y_shape, Rel.EQ, self.name)
|
||||
|
||||
validator.check("init_h_shape", init_h_shape, "excepted shape",
|
||||
[batch_size, hidden_size], Rel.EQ, self.name)
|
||||
valid_shape = [num_step, batch_size, hidden_size]
|
||||
validator.check("h_shape", h_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
|
||||
validator.check("dy_shape", dy_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
|
||||
validator.check("dh_shape", dh_shape, "excepted shape",
|
||||
[batch_size, hidden_size], Rel.EQ, self.name)
|
||||
validator.check("update_shape", update_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
|
||||
validator.check("reset_shape", reset_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
|
||||
validator.check("new_shape", new_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
|
||||
validator.check("hnew_shape", hnew_shape, "excepted shape", valid_shape, Rel.EQ, self.name)
|
||||
if seq_shape is not None:
|
||||
validator.check("seq_shape", seq_shape, "batch_size", batch_size, Rel.EQ, self.name)
|
||||
|
||||
dx_shape = (num_step, batch_size, input_size)
|
||||
dh_shape = (batch_size, hidden_size)
|
||||
dwinput_shape = (input_size, 3 * hidden_size)
|
||||
dwhidden_shape = (hidden_size, 3 * hidden_size)
|
||||
db_shape = (3 * hidden_size,)
|
||||
return dwinput_shape, dwhidden_shape, db_shape, db_shape, dx_shape, dh_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, y_dtype, init_h_dtype, h_dtype,
|
||||
dy_dtype, dh_dtype, update_dtype, reset_dtype, new_dtype, hnew_dtype, seq_dtype, mask_dtype):
|
||||
valid_types = (mstype.float16, mstype.float32)
|
||||
args = {"y_dtype": y_dtype, "init_h_dtype": init_h_dtype, "h_dtype": h_dtype,
|
||||
"dy_dtype": dy_dtype, "dh_dtype": dh_dtype, "update_dtype": update_dtype,
|
||||
"reset_dtype": reset_dtype, "new_dtype": new_dtype, "hnew_dtype": hnew_dtype}
|
||||
validator.check_tensor_type_same({"x_dtype": x_dtype}, valid_types, self.name)
|
||||
validator.check_tensor_type_same({"winput_dtype": winput_dtype}, valid_types, self.name)
|
||||
validator.check_tensor_type_same({"whidden_dtype": whidden_dtype}, valid_types, self.name)
|
||||
validator.check_tensor_type_same(args, valid_types, self.name)
|
||||
if seq_dtype is not None:
|
||||
validator.check_tensor_type_same({"seq_dtype": seq_dtype}, (mstype.float32, mstype.float16), self.name)
|
||||
if mask_dtype is not None:
|
||||
validator.check_tensor_type_same({"mask_dtype": mask_dtype}, (mstype.float32, mstype.float16), self.name)
|
||||
return x_dtype, x_dtype, x_dtype, x_dtype, x_dtype, x_dtype
|
||||
|
||||
|
||||
class PReLUGrad(PrimitiveWithInfer):
|
||||
r"""
|
||||
Gradients of PReLU operation.
|
||||
|
|
|
@ -451,6 +451,157 @@ class MatrixSetDiag(PrimitiveWithInfer):
|
|||
return assist_shape
|
||||
|
||||
|
||||
class DynamicGRUV2(PrimitiveWithInfer):
|
||||
r"""
|
||||
DynamicGRUV2 Operator.
|
||||
|
||||
Args:
|
||||
direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'.
|
||||
Only 'UNIDIRECTIONAL' is currently supported.
|
||||
cell_depth (int): An integer identifying the cell depth in the op. Default: 1.
|
||||
keep_prob (float): A float identifying the keep prob in the op. Default: 1.0.
|
||||
cell_clip (float): A float identifying the cell clip in the op. Default: -1.0.
|
||||
num_proj (int): An integer identifying the num proj in the op. Default: 0.
|
||||
time_major (bool): A bool identifying the time major in the op. Default: True.
|
||||
activation (str) : A string identifying the type of activation function in the op. Default: 'tanh'.
|
||||
Only 'tanh' is currently supported.
|
||||
gate_order (str): A string identifying the gate order in weight and bias. Default: 'rzh.
|
||||
'zrh' is another option.
|
||||
reset_after (bool): A bool identifying whether to apply reset gate after matrix multiplication. Default: True.
|
||||
is_training (bool): A bool identifying is training in the op. Default: True.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Current words.
|
||||
Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{input_size})`.
|
||||
The data type must be float16.
|
||||
- **weight_input** (Tensor) - Input-hidden weight.
|
||||
Tensor of shape :math:`(\text{input_size}, 3 \times \text{hidden_size})`.
|
||||
The data type must be float16.
|
||||
- **weight_hidden** (Tensor) - Hidden-hidden weight.
|
||||
Tensor of shape :math:`(\text{hidden_size}, 3 \times \text{hidden_size})`.
|
||||
The data type must be float16.
|
||||
- **bias_input** (Tensor) - Input-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None.
|
||||
The data type must be float16 or float32.
|
||||
- **bias_hidden** (Tensor) - Hidden-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None.
|
||||
The data type must be float16 or float32.
|
||||
- **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`.
|
||||
Only `None` is currently supported.
|
||||
- **init_h** (Tensor) - Hidden state of initial time.
|
||||
Tensor of shape :math:`(\text{batch_size}, \text{hidden_size})`, or None.
|
||||
The data type must be float16 or float32.
|
||||
|
||||
Outputs:
|
||||
- **y** (Tensor) - A Tensor of shape :math:
|
||||
if num_proj > 0 `(num_step, batch_size, min(hidden_size, num_proj)`,
|
||||
if num_proj == 0 `(num_step, batch_size, hidden_size)`.
|
||||
Has the same data type with input `bais_type`.
|
||||
- **output_h** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
|
||||
Has the same data type with input `bais_type`.
|
||||
- **update** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
|
||||
Has the same data type with input `bais_type`.
|
||||
- **reset** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
|
||||
Has the same data type with input `bais_type`.
|
||||
- **new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
|
||||
Has the same data type with input `bais_type`.
|
||||
- **hidden_new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
|
||||
Has the same data type with input `bais_type`.
|
||||
|
||||
- If `bias_input`, `bias_hidden` and `init_h` all are `None`, `bias_type` is float32.
|
||||
- If `bias_input` is not `None`, `bias_type` is the date type of `bias_input`.
|
||||
- If `bias_input` is `None` and `bias_hidden` is not `None, `bias_type` is the date type of `bias_hidden`.
|
||||
- Otherwise, `bias_type` is the date type of `init_h`.
|
||||
|
||||
Examples:
|
||||
>>> x = Tensor(np.random.rand(2, 8, 64).astype(np.float16))
|
||||
>>> weight_i = Tensor(np.random.rand(64, 48).astype(np.float16))
|
||||
>>> weight_h = Tensor(np.random.rand(16, 48).astype(np.float16))
|
||||
>>> bias_i = Tensor(np.random.rand(48).astype(np.float16))
|
||||
>>> bias_h = Tensor(np.random.rand(48).astype(np.float16))
|
||||
>>> init_h = Tensor(np.random.rand(8, 16).astype(np.float16))
|
||||
>>> dynamic_gru_v2 = P.DynamicGRUV2()
|
||||
>>> output = dynamic_gru_v2(x, weight_i, weight_h, bias_i, bias_h, None, init_h)
|
||||
>>> output[0].shape
|
||||
(2, 8, 16)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self,
|
||||
direction='UNIDIRECTIONAL',
|
||||
cell_depth=1,
|
||||
keep_prob=1.0,
|
||||
cell_clip=-1.0,
|
||||
num_proj=0,
|
||||
time_major=True,
|
||||
activation="tanh",
|
||||
gate_order="rzh",
|
||||
reset_after=True,
|
||||
is_training=True):
|
||||
self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name)
|
||||
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
|
||||
self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name)
|
||||
self.num_proj = validator.check_non_negative_int(num_proj, "num_proj", self.name)
|
||||
self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name)
|
||||
self.is_training = validator.check_value_type("is_training", is_training, [bool], self.name)
|
||||
self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
|
||||
self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
|
||||
self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name)
|
||||
self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
|
||||
self.add_prim_attr("io_format", "ND")
|
||||
|
||||
def infer_shape(self, x_shape, winput_shape, whidden_shape, binput_shape, bhidden_shape, seq_shape, h_shape):
|
||||
validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name)
|
||||
validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name)
|
||||
validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name)
|
||||
if binput_shape is not None:
|
||||
validator.check_int(len(binput_shape), 1, Rel.EQ, "bias input shape rank", self.name)
|
||||
if bhidden_shape is not None:
|
||||
validator.check_int(len(bhidden_shape), 1, Rel.EQ, "bias hidden shape rank", self.name)
|
||||
if h_shape is not None:
|
||||
validator.check_int(len(h_shape), 2, Rel.EQ, "init_h shape rank", self.name)
|
||||
if seq_shape is not None:
|
||||
raise ValueError(f"For {self.name}, seq_shape should be None.")
|
||||
|
||||
num_step, batch_size, input_size = x_shape
|
||||
hidden_size = winput_shape[-1] // 3
|
||||
|
||||
if winput_shape[-1] % 3 != 0:
|
||||
raise ValueError(f"For {self.name}, weight_input_shape[-1] should multiple of 3.")
|
||||
|
||||
validator.check("weight_input_shape[-1]", winput_shape[-1], "weight_hidden_shape[-1]",
|
||||
whidden_shape[-1], Rel.EQ, self.name)
|
||||
validator.check("bias_input_shape", binput_shape, "bias_hidden_shape", bhidden_shape, Rel.EQ, self.name)
|
||||
validator.check("weight_input_shape[0]", winput_shape[0], "input_size", input_size, Rel.EQ, self.name)
|
||||
validator.check("weight_hidden_shape[0]", whidden_shape[0], "hidden_size", hidden_size, Rel.EQ, self.name)
|
||||
if h_shape is not None:
|
||||
validator.check("init_h_shape[0]", h_shape[0], "batch_size", batch_size, Rel.EQ, self.name)
|
||||
validator.check("init_h_shape[1]", h_shape[1], "hidden_size", hidden_size, Rel.EQ, self.name)
|
||||
if self.num_proj > 0:
|
||||
y_shape = (num_step, batch_size, min(hidden_size, self.num_proj))
|
||||
else:
|
||||
y_shape = (num_step, batch_size, hidden_size)
|
||||
outh_shape = (num_step, batch_size, hidden_size)
|
||||
return y_shape, outh_shape, outh_shape, outh_shape, outh_shape, outh_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype):
|
||||
validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float16,), self.name)
|
||||
validator.check_tensor_type_same({"weight input dtype": winput_dtype}, (mstype.float16,), self.name)
|
||||
validator.check_tensor_type_same({"weight hidden dtype": whidden_dtype}, (mstype.float16,), self.name)
|
||||
b_dtype = mstype.float32
|
||||
if binput_dtype is not None:
|
||||
validator.check_tensor_type_same({"bias input dtype": binput_dtype},
|
||||
(mstype.float16, mstype.float32), self.name)
|
||||
b_dtype = binput_dtype
|
||||
elif bhidden_dtype is not None:
|
||||
validator.check_tensor_type_same({"bias hidden dtype": bhidden_dtype},
|
||||
(mstype.float16, mstype.float32), self.name)
|
||||
b_dtype = bhidden_dtype
|
||||
elif h_dtype is not None:
|
||||
validator.check_tensor_type_same({"init_h dtype": h_dtype},
|
||||
(mstype.float16, mstype.float32), self.name)
|
||||
b_dtype = h_dtype
|
||||
return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype
|
||||
|
||||
|
||||
class ConfusionMulGrad(PrimitiveWithInfer):
|
||||
"""
|
||||
`output0` is the dot product result of input0 and input1.
|
||||
|
|
|
@ -5611,33 +5611,35 @@ class DynamicRNN(PrimitiveWithInfer):
|
|||
DynamicRNN Operator.
|
||||
|
||||
Args:
|
||||
cell_type (str): An string identifying the cell type in the op. Default: 'LSTM'.
|
||||
cell_type (str): A string identifying the cell type in the op. Default: 'LSTM'.
|
||||
Only 'LSTM' is currently supported.
|
||||
direction (str): An string identifying the direction in the op. Default: 'UNIDIRECTIONAL'.
|
||||
direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'.
|
||||
Only 'UNIDIRECTIONAL' is currently supported.
|
||||
cell_depth (int): An integer identifying the cell depth in the op. Default: 1.
|
||||
use_peephole (bool): An bool identifying if use peephole in the op. Default: False.
|
||||
keep_prob (float): An float identifying the keep prob in the op. Default: 1.0.
|
||||
cell_clip (float): An float identifying the cell clip in the op. Default: -1.0.
|
||||
use_peephole (bool): A bool identifying if use peephole in the op. Default: False.
|
||||
keep_prob (float): A float identifying the keep prob in the op. Default: 1.0.
|
||||
cell_clip (float): A float identifying the cell clip in the op. Default: -1.0.
|
||||
num_proj (int): An integer identifying the num proj in the op. Default: 0.
|
||||
time_major (bool): An bool identifying the time major in the op. Default: True.
|
||||
time_major (bool): A bool identifying the time major in the op. Default: True.
|
||||
Only `True` is currently supported.
|
||||
activation (str): An string identifying the type of activation function in the op. Default: 'tanh'.
|
||||
activation (str): A string identifying the type of activation function in the op. Default: 'tanh'.
|
||||
Only 'tanh' is currently supported.
|
||||
forget_bias (float): An float identifying the forget bias in the op. Default: 0.0.
|
||||
is_training (bool): An bool identifying is training in the op. Default: True.
|
||||
forget_bias (float): A float identifying the forget bias in the op. Default: 0.0.
|
||||
is_training (bool): A bool identifying is training in the op. Default: True.
|
||||
|
||||
Inputs:
|
||||
- **x** (Tensor) - Current words. Tensor of shape (`num_step`, `batch_size`, `input_size`).
|
||||
The data type must be float16 or float32.
|
||||
The data type must be float16.
|
||||
- **w** (Tensor) - Weight. Tensor of shape (`input_size + hidden_size`, `4 x hidden_size`).
|
||||
The data type must be float16 or float32.
|
||||
The data type must be float16.
|
||||
- **b** (Tensor) - Bias. Tensor of shape (`4 x hidden_size`).
|
||||
The data type must be float16 or float32.
|
||||
- **seq_length** (Tensor) - The length of each batch. Tensor of shape (`batch_size`).
|
||||
Only `None` is currently supported.
|
||||
- **init_h** (Tensor) - Hidden state of initial time. Tensor of shape (1, `batch_size`, `hidden_size`).
|
||||
The data type must be float16.
|
||||
- **init_c** (Tensor) - Cell state of initial time. Tensor of shape (1, `batch_size`, `hidden_size`).
|
||||
The data type must be float16.
|
||||
|
||||
Outputs:
|
||||
- **y** (Tensor) - A Tensor of shape (`num_step`, `batch_size`, `hidden_size`).
|
||||
|
@ -5664,7 +5666,9 @@ class DynamicRNN(PrimitiveWithInfer):
|
|||
>>> init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
|
||||
>>> init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
|
||||
>>> dynamic_rnn = P.DynamicRNN()
|
||||
>>> output = lstm(x, w, b, None, init_h, init_c)
|
||||
>>> output = dynamic_rnn(x, w, b, None, init_h, init_c)
|
||||
>>> output[0].shape
|
||||
(2, 16, 32)
|
||||
"""
|
||||
|
||||
@prim_attr_register
|
||||
|
@ -5684,7 +5688,7 @@ class DynamicRNN(PrimitiveWithInfer):
|
|||
self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name)
|
||||
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
|
||||
self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name)
|
||||
self.num_proj = validator.check_value_type("num_proj", num_proj, [int], self.name)
|
||||
self.num_proj = validator.check_non_negative_int(num_proj, "num_proj", self.name)
|
||||
self.forget_bias = validator.check_value_type("forget_bias", forget_bias, [float], self.name)
|
||||
self.use_peephole = validator.check_value_type("use_peephole", use_peephole, [bool], self.name)
|
||||
self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name)
|
||||
|
@ -5721,11 +5725,11 @@ class DynamicRNN(PrimitiveWithInfer):
|
|||
return y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape, y_shape
|
||||
|
||||
def infer_dtype(self, x_dtype, w_dtype, b_dtype, seq_dtype, h_dtype, c_dtype):
|
||||
validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float32, mstype.float16), self.name)
|
||||
validator.check_tensor_type_same({"w dtype": w_dtype}, (mstype.float32, mstype.float16), self.name)
|
||||
validator.check_tensor_type_same({"x dtype": x_dtype}, (mstype.float16,), self.name)
|
||||
validator.check_tensor_type_same({"w dtype": w_dtype}, (mstype.float16,), self.name)
|
||||
validator.check_tensor_type_same({"b dtype": b_dtype}, (mstype.float32, mstype.float16), self.name)
|
||||
validator.check_tensor_type_same({"h dtype": h_dtype}, (mstype.float32, mstype.float16), self.name)
|
||||
validator.check_tensor_type_same({"c dtype": c_dtype}, (mstype.float32, mstype.float16), self.name)
|
||||
validator.check_tensor_type_same({"h dtype": h_dtype}, (mstype.float16,), self.name)
|
||||
validator.check_tensor_type_same({"c dtype": c_dtype}, (mstype.float16,), self.name)
|
||||
return b_dtype, x_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype
|
||||
|
||||
|
||||
|
|
|
@ -817,6 +817,17 @@ class BasicLSTMCellNet(nn.Cell):
|
|||
return self.lstm(x, h, c, w, b)
|
||||
|
||||
|
||||
class DynamicGRUV2Net(nn.Cell):
|
||||
""" DynamicGRUV2Net definition """
|
||||
|
||||
def __init__(self):
|
||||
super(DynamicGRUV2Net, self).__init__()
|
||||
self.dynamic_gru = inner.DynamicGRUV2()
|
||||
|
||||
def construct(self, x, w_i, w_h, b_i, b_h, init_h):
|
||||
return self.dynamic_gru(x, w_i, w_h, b_i, b_h, None, init_h)
|
||||
|
||||
|
||||
class EditDistance(nn.Cell):
|
||||
def __init__(self, hypothesis_shape, truth_shape, normalize=True):
|
||||
super(EditDistance, self).__init__()
|
||||
|
@ -2508,6 +2519,19 @@ test_case_other_ops = [
|
|||
Tensor(np.random.rand(1, 64).astype(np.float16)),
|
||||
Tensor(np.random.rand(1, 64).astype(np.float16)),
|
||||
Tensor(np.random.rand(1, 64).astype(np.float16))]}),
|
||||
('DynamicGRUV2Net', {
|
||||
'block': DynamicGRUV2Net(),
|
||||
'desc_inputs': [Tensor(np.random.rand(2, 8, 64).astype(np.float16)),
|
||||
Tensor(np.random.rand(64, 48).astype(np.float16)),
|
||||
Tensor(np.random.rand(16, 48).astype(np.float16)),
|
||||
Tensor(np.random.rand(48).astype(np.float16)),
|
||||
Tensor(np.random.rand(48).astype(np.float16)),
|
||||
Tensor(np.random.rand(8, 16).astype(np.float16))],
|
||||
'desc_bprop': [Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
|
||||
Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
|
||||
Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
|
||||
Tensor(np.random.rand(2, 8, 16).astype(np.float16)),
|
||||
Tensor(np.random.rand(2, 8, 16).astype(np.float16))]}),
|
||||
]
|
||||
|
||||
test_case_quant_ops = [
|
||||
|
|
Loading…
Reference in New Issue