Move API inner.DynamicGRUV2 to P.DynamicGRUV2.
commit 712ad98a92
parent 97a10bfa7c
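For callers, the visible effect of this commit is only the import location: `DynamicGRUV2` moves from the internal `inner` operations module into the public operations namespace. A minimal before/after sketch (the module aliases follow common MindSpore conventions and are assumptions, not part of this diff):

# Sketch only; aliases are illustrative, not taken from the diff.
# Before this commit the primitive was reached through the internal module:
#     from mindspore.ops.operations import _inner_ops as inner
#     gru = inner.DynamicGRUV2()
# After this commit it is a public primitive:
from mindspore.ops import operations as P
gru = P.DynamicGRUV2()  # constructed with the defaults documented in the class below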
@@ -892,7 +892,7 @@ def get_bprop_dynamic_rnn(self):
return bprop


@bprop_getters.register(inner.DynamicGRUV2)
@bprop_getters.register(P.DynamicGRUV2)
def get_bprop_dynamic_gru_v2(self):
"""Grad definition for `DynamicGRUV2` operation."""
dynamic_gru_v2_grad = G.DynamicGRUV2Grad(self.direction, self.cell_depth, self.keep_prob, self.cell_clip,
@@ -72,7 +72,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
MaxPoolWithArgmax, OneHot, Pad, MirrorPad, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
ResizeBilinear, Sigmoid,
SigmoidCrossEntropyWithLogits,
SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN,
SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2,
SoftmaxCrossEntropyWithLogits, ROIAlign,
SparseSoftmaxCrossEntropyWithLogits, Tanh,
TopK, BinaryCrossEntropy, KLDivLoss, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
@@ -119,6 +119,7 @@ __all__ = [
'Rsqrt',
'Sqrt',
'Square',
'DynamicGRUV2',
'SquaredDifference',
'Xdivy',
'Xlogy',
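Since 'DynamicGRUV2' is now listed in `__all__`, the primitive is exported from the public operations package. A small sketch of what that makes possible (the module alias used here is an assumption):

# Sketch: the public operations package now exposes DynamicGRUV2.
import mindspore.ops.operations as ops

assert "DynamicGRUV2" in ops.__all__
gru = ops.DynamicGRUV2()  # defaults: direction='UNIDIRECTIONAL', gate_order='rzh', ...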
@@ -529,159 +529,6 @@ class MatrixSetDiag(PrimitiveWithInfer):
return assist_shape


class DynamicGRUV2(PrimitiveWithInfer):
r"""
DynamicGRUV2 Operator.

Args:
direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'.
Only 'UNIDIRECTIONAL' is currently supported.
cell_depth (int): An integer identifying the cell depth in the op. Default: 1.
keep_prob (float): A float identifying the keep prob in the op. Default: 1.0.
cell_clip (float): A float identifying the cell clip in the op. Default: -1.0.
num_proj (int): An integer identifying the num proj in the op. Default: 0.
time_major (bool): A bool identifying the time major in the op. Default: True.
activation (str): A string identifying the type of activation function in the op. Default: 'tanh'.
Only 'tanh' is currently supported.
gate_order (str): A string identifying the gate order in weight and bias. Default: 'rzh'.
'zrh' is another option.
reset_after (bool): A bool identifying whether to apply reset gate after matrix multiplication. Default: True.
is_training (bool): A bool identifying whether the op is in training mode. Default: True.

Inputs:
- **x** (Tensor) - Current words.
Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{input_size})`.
The data type must be float16.
- **weight_input** (Tensor) - Input-hidden weight.
Tensor of shape :math:`(\text{input_size}, 3 \times \text{hidden_size})`.
The data type must be float16.
- **weight_hidden** (Tensor) - Hidden-hidden weight.
Tensor of shape :math:`(\text{hidden_size}, 3 \times \text{hidden_size})`.
The data type must be float16.
- **bias_input** (Tensor) - Input-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None.
The data type must be float16 or float32.
- **bias_hidden** (Tensor) - Hidden-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None.
The data type must be float16 or float32.
- **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`.
Only `None` is currently supported.
- **init_h** (Tensor) - Hidden state at the initial time step.
Tensor of shape :math:`(\text{batch_size}, \text{hidden_size})`.
The data type must be float16 or float32.

Outputs:
- **y** (Tensor) - A Tensor of shape:
if num_proj > 0, `(num_step, batch_size, min(hidden_size, num_proj))`;
if num_proj == 0, `(num_step, batch_size, hidden_size)`.
Has the same data type as `bias_type`.
- **output_h** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
Has the same data type as `bias_type`.
- **update** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
Has the same data type as `bias_type`.
- **reset** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
Has the same data type as `bias_type`.
- **new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
Has the same data type as `bias_type`.
- **hidden_new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
Has the same data type as `bias_type`.

- If `bias_input` and `bias_hidden` both are `None`, `bias_type` is float32.
- If `bias_input` is not `None`, `bias_type` is the data type of `bias_input`.
- If `bias_input` is `None` and `bias_hidden` is not `None`, `bias_type` is the data type of `bias_hidden`.

Examples:
>>> x = Tensor(np.random.rand(2, 8, 64).astype(np.float16))
>>> weight_i = Tensor(np.random.rand(64, 48).astype(np.float16))
>>> weight_h = Tensor(np.random.rand(16, 48).astype(np.float16))
>>> bias_i = Tensor(np.random.rand(48).astype(np.float16))
>>> bias_h = Tensor(np.random.rand(48).astype(np.float16))
>>> init_h = Tensor(np.random.rand(8, 16).astype(np.float16))
>>> dynamic_gru_v2 = ops.DynamicGRUV2()
>>> output = dynamic_gru_v2(x, weight_i, weight_h, bias_i, bias_h, None, init_h)
>>> result = output[0].shape
>>> print(result)
(2, 8, 16)
"""

@prim_attr_register
def __init__(self,
direction='UNIDIRECTIONAL',
cell_depth=1,
keep_prob=1.0,
cell_clip=-1.0,
num_proj=0,
time_major=True,
activation="tanh",
gate_order="rzh",
reset_after=True,
is_training=True):
self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name)
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name)
self.num_proj = validator.check_non_negative_int(num_proj, "num_proj", self.name)
self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name)
self.is_training = validator.check_value_type("is_training", is_training, [bool], self.name)
self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name)
self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
self.add_prim_attr("io_format", "ND")

def infer_shape(self, x_shape, winput_shape, whidden_shape, binput_shape, bhidden_shape, seq_shape, h_shape):
validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name)
validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name)
validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name)

num_step, batch_size, input_size = x_shape
hidden_size = winput_shape[-1] // 3
if winput_shape[-1] % 3 != 0:
raise ValueError(f"For {self.name}, weight_input_shape[-1] should be a multiple of 3.")

self.placeholder_index = [3, 4, 5]
if binput_shape is not None:
validator.check_int(len(binput_shape), 1, Rel.EQ, "bias input shape rank", self.name)
validator.check("bias_input_shape", binput_shape, "3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name)
self.placeholder_index.remove(3)
if bhidden_shape is not None:
validator.check_int(len(bhidden_shape), 1, Rel.EQ, "bias hidden shape rank", self.name)
validator.check("bias_hidden_shape", bhidden_shape,
"3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name)
self.placeholder_index.remove(4)
if seq_shape is not None:
raise ValueError(f"For {self.name}, seq_shape should be None.")

validator.check_int(len(h_shape), 2, Rel.EQ, "init_h shape rank", self.name)
validator.check("init_h_shape[0]", h_shape[0], "batch_size", batch_size, Rel.EQ, self.name)
validator.check("init_h_shape[1]", h_shape[1], "hidden_size", hidden_size, Rel.EQ, self.name)
validator.check("weight_input_shape[-1]", winput_shape[-1], "weight_hidden_shape[-1]",
whidden_shape[-1], Rel.EQ, self.name)
validator.check("weight_input_shape[0]", winput_shape[0], "input_size", input_size, Rel.EQ, self.name)
validator.check("weight_hidden_shape[0]", whidden_shape[0], "hidden_size", hidden_size, Rel.EQ, self.name)
if self.num_proj > 0:
y_shape = (num_step, batch_size, min(hidden_size, self.num_proj))
else:
y_shape = (num_step, batch_size, hidden_size)
out_shape = (num_step, batch_size, hidden_size)
self.add_prim_attr("placeholder_index", self.placeholder_index)
return y_shape, out_shape, out_shape, out_shape, out_shape, out_shape

def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype):
validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name)
validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name)
validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name)
validator.check_tensor_dtype_valid("init_h dtype", h_dtype, (mstype.float16, mstype.float32), self.name)
b_dtype = mstype.float32
if binput_dtype is not None:
validator.check_tensor_dtype_valid("bias input dtype", binput_dtype,
(mstype.float16, mstype.float32), self.name)
b_dtype = binput_dtype
elif bhidden_dtype is not None:
validator.check_tensor_dtype_valid("bias hidden dtype", bhidden_dtype,
(mstype.float16, mstype.float32), self.name)
b_dtype = bhidden_dtype

return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype


class ConfusionMulGrad(PrimitiveWithInfer):
"""
`output0` is the dot product result of input0 and input1.
@@ -6403,32 +6403,18 @@ class DynamicRNN(PrimitiveWithInfer):
- **tanhct** (Tensor) - A Tensor of shape (`num_step`, `batch_size`, `hidden_size`).
Has the same type as input `b`.

Supported Platforms:
``Ascend``

Examples:
>>> import mindspore
>>> import mindspore.nn as nn
>>> import numpy as np
>>> from mindspore import Parameter
>>> from mindspore import Tensor
>>> from mindspore.ops import operations as ops
>>> import mindspore.context as context
>>> context.set_context(mode=context.GRAPH_MODE)
>>> class DynamicRNNNet(nn.Cell):
>>> def __init__(self):
>>> super(DynamicRNNNet, self).__init__()
>>> self.dynamic_rnn = ops.DynamicRNN()
>>>
>>> def construct(self, x, w, b, init_h, init_c):
>>> out = self.dynamic_rnn(x, w, b, None, init_h, init_c)
>>> return out
>>>
>>> x = Tensor(np.random.rand(2, 16, 64).astype(np.float16))
>>> w = Tensor(np.random.rand(96, 128).astype(np.float16))
>>> b = Tensor(np.random.rand(128).astype(np.float16))
>>> init_h = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
>>> init_c = Tensor(np.random.rand(1, 16, 32).astype(np.float16))
>>> net = DynamicRNNNet()
>>> output = net(x, w, b, init_h, init_c)
>>> output[0].shape
>>> dynamic_rnn = ops.DynamicRNN()
>>> output = dynamic_rnn(x, w, b, None, init_h, init_c)
>>> print(output[0].shape)
(2, 16, 32)
"""
@@ -6493,6 +6479,161 @@ class DynamicRNN(PrimitiveWithInfer):
return b_dtype, x_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype


class DynamicGRUV2(PrimitiveWithInfer):
r"""
DynamicGRUV2 Operator.

Args:
direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'.
Only 'UNIDIRECTIONAL' is currently supported.
cell_depth (int): An integer identifying the cell depth in the op. Default: 1.
keep_prob (float): A float identifying the keep prob in the op. Default: 1.0.
cell_clip (float): A float identifying the cell clip in the op. Default: -1.0.
num_proj (int): An integer identifying the num proj in the op. Default: 0.
time_major (bool): A bool identifying the time major in the op. Default: True.
activation (str): A string identifying the type of activation function in the op. Default: 'tanh'.
Only 'tanh' is currently supported.
gate_order (str): A string identifying the gate order in weight and bias. Default: 'rzh'.
'zrh' is another option.
reset_after (bool): A bool identifying whether to apply reset gate after matrix multiplication. Default: True.
is_training (bool): A bool identifying whether the op is in training mode. Default: True.

Inputs:
- **x** (Tensor) - Current words.
Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{input_size})`.
The data type must be float16.
- **weight_input** (Tensor) - Input-hidden weight.
Tensor of shape :math:`(\text{input_size}, 3 \times \text{hidden_size})`.
The data type must be float16.
- **weight_hidden** (Tensor) - Hidden-hidden weight.
Tensor of shape :math:`(\text{hidden_size}, 3 \times \text{hidden_size})`.
The data type must be float16.
- **bias_input** (Tensor) - Input-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None.
The data type must be float16 or float32.
- **bias_hidden** (Tensor) - Hidden-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None.
The data type must be float16 or float32.
- **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`.
Only `None` is currently supported.
- **init_h** (Tensor) - Hidden state at the initial time step.
Tensor of shape :math:`(\text{batch_size}, \text{hidden_size})`.
The data type must be float16 or float32.

Outputs:
- **y** (Tensor) - A Tensor of shape:
if num_proj > 0, `(num_step, batch_size, min(hidden_size, num_proj))`;
if num_proj == 0, `(num_step, batch_size, hidden_size)`.
Has the same data type as `bias_type`.
- **output_h** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
Has the same data type as `bias_type`.
- **update** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
Has the same data type as `bias_type`.
- **reset** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
Has the same data type as `bias_type`.
- **new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
Has the same data type as `bias_type`.
- **hidden_new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
Has the same data type as `bias_type`.

- If `bias_input` and `bias_hidden` both are `None`, `bias_type` is float32.
- If `bias_input` is not `None`, `bias_type` is the data type of `bias_input`.
- If `bias_input` is `None` and `bias_hidden` is not `None`, `bias_type` is the data type of `bias_hidden`.

Supported Platforms:
``Ascend``

Examples:
>>> x = Tensor(np.random.rand(2, 8, 64).astype(np.float16))
>>> weight_i = Tensor(np.random.rand(64, 48).astype(np.float16))
>>> weight_h = Tensor(np.random.rand(16, 48).astype(np.float16))
>>> bias_i = Tensor(np.random.rand(48).astype(np.float16))
>>> bias_h = Tensor(np.random.rand(48).astype(np.float16))
>>> init_h = Tensor(np.random.rand(8, 16).astype(np.float16))
>>> dynamic_gru_v2 = ops.DynamicGRUV2()
>>> output = dynamic_gru_v2(x, weight_i, weight_h, bias_i, bias_h, None, init_h)
>>> print(output[0].shape)
(2, 8, 16)
"""

@prim_attr_register
def __init__(self,
direction='UNIDIRECTIONAL',
cell_depth=1,
keep_prob=1.0,
cell_clip=-1.0,
num_proj=0,
time_major=True,
activation="tanh",
gate_order="rzh",
reset_after=True,
is_training=True):
self.cell_depth = validator.check_value_type("cell_depth", cell_depth, [int], self.name)
self.keep_prob = validator.check_value_type("keep_prob", keep_prob, [float], self.name)
self.cell_clip = validator.check_value_type("cell_clip", cell_clip, [float], self.name)
self.num_proj = validator.check_non_negative_int(num_proj, "num_proj", self.name)
self.time_major = validator.check_value_type("time_major", time_major, [bool], self.name)
self.is_training = validator.check_value_type("is_training", is_training, [bool], self.name)
self.direction = validator.check_string(direction, ['UNIDIRECTIONAL'], "direction", self.name)
self.activation = validator.check_string(activation, ['tanh'], "activation", self.name)
self.gate_order = validator.check_string(gate_order, ['zrh', 'rzh'], "gate_order", self.name)
self.reset_after = validator.check_value_type("reset_after", reset_after, [bool], self.name)
self.add_prim_attr("io_format", "ND")

def infer_shape(self, x_shape, winput_shape, whidden_shape, binput_shape, bhidden_shape, seq_shape, h_shape):
validator.check_int(len(x_shape), 3, Rel.EQ, "x shape", self.name)
validator.check_int(len(winput_shape), 2, Rel.EQ, "weight input shape rank", self.name)
validator.check_int(len(whidden_shape), 2, Rel.EQ, "weight hidden shape rank", self.name)

num_step, batch_size, input_size = x_shape
hidden_size = winput_shape[-1] // 3
if winput_shape[-1] % 3 != 0:
raise ValueError(f"For {self.name}, weight_input_shape[-1] should be a multiple of 3.")

self.placeholder_index = [3, 4, 5]
if binput_shape is not None:
validator.check_int(len(binput_shape), 1, Rel.EQ, "bias input shape rank", self.name)
validator.check("bias_input_shape", binput_shape, "3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name)
self.placeholder_index.remove(3)
if bhidden_shape is not None:
validator.check_int(len(bhidden_shape), 1, Rel.EQ, "bias hidden shape rank", self.name)
validator.check("bias_hidden_shape", bhidden_shape,
"3 * hidden_shape", [3 * hidden_size], Rel.EQ, self.name)
self.placeholder_index.remove(4)
if seq_shape is not None:
raise ValueError(f"For {self.name}, seq_shape should be None.")

validator.check_int(len(h_shape), 2, Rel.EQ, "init_h shape rank", self.name)
validator.check("init_h_shape[0]", h_shape[0], "batch_size", batch_size, Rel.EQ, self.name)
validator.check("init_h_shape[1]", h_shape[1], "hidden_size", hidden_size, Rel.EQ, self.name)
validator.check("weight_input_shape[-1]", winput_shape[-1], "weight_hidden_shape[-1]",
whidden_shape[-1], Rel.EQ, self.name)
validator.check("weight_input_shape[0]", winput_shape[0], "input_size", input_size, Rel.EQ, self.name)
validator.check("weight_hidden_shape[0]", whidden_shape[0], "hidden_size", hidden_size, Rel.EQ, self.name)
if self.num_proj > 0:
y_shape = (num_step, batch_size, min(hidden_size, self.num_proj))
else:
y_shape = (num_step, batch_size, hidden_size)
out_shape = (num_step, batch_size, hidden_size)
self.add_prim_attr("placeholder_index", self.placeholder_index)
return y_shape, out_shape, out_shape, out_shape, out_shape, out_shape

def infer_dtype(self, x_dtype, winput_dtype, whidden_dtype, binput_dtype, bhidden_dtype, seq_dtype, h_dtype):
validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name)
validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name)
validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name)
validator.check_tensor_dtype_valid("init_h dtype", h_dtype, (mstype.float16, mstype.float32), self.name)
b_dtype = mstype.float32
if binput_dtype is not None:
validator.check_tensor_dtype_valid("bias input dtype", binput_dtype,
(mstype.float16, mstype.float32), self.name)
b_dtype = binput_dtype
elif bhidden_dtype is not None:
validator.check_tensor_dtype_valid("bias hidden dtype", bhidden_dtype,
(mstype.float16, mstype.float32), self.name)
b_dtype = bhidden_dtype

return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype


class InTopK(PrimitiveWithInfer):
r"""
Determines whether the targets are in the top `k` predictions.
@@ -822,7 +822,7 @@ class DynamicGRUV2Net(nn.Cell):

def __init__(self):
super(DynamicGRUV2Net, self).__init__()
self.dynamic_gru = inner.DynamicGRUV2()
self.dynamic_gru = P.DynamicGRUV2()

def construct(self, x, w_i, w_h, b_i, b_h, init_h):
return self.dynamic_gru(x, w_i, w_h, b_i, b_h, None, init_h)
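The test net now builds the primitive from the public `P` namespace. A minimal sketch of driving this net end to end, reusing the tensor shapes from the `DynamicGRUV2` docstring example above (the surrounding test scaffolding is assumed and not shown in this diff):

# Sketch: exercise DynamicGRUV2Net with the shapes from the docstring example.
import numpy as np
from mindspore import Tensor, context

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")  # the op is Ascend-only
net = DynamicGRUV2Net()
x = Tensor(np.random.rand(2, 8, 64).astype(np.float16))
w_i = Tensor(np.random.rand(64, 48).astype(np.float16))
w_h = Tensor(np.random.rand(16, 48).astype(np.float16))
b_i = Tensor(np.random.rand(48).astype(np.float16))
b_h = Tensor(np.random.rand(48).astype(np.float16))
init_h = Tensor(np.random.rand(8, 16).astype(np.float16))
output = net(x, w_i, w_h, b_i, b_h, init_h)
print(output[0].shape)  # (2, 8, 16), as in the docstring example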