forked from mindspore-Ecosystem/mindspore
fix some bug for DynamicGRUV2.
This commit is contained in:
parent
3836665f29
commit
4817978d28
|
@ -6491,7 +6491,7 @@ class DynamicRNN(PrimitiveWithInfer):
|
|||
|
||||
class DynamicGRUV2(PrimitiveWithInfer):
|
||||
r"""
|
||||
DynamicGRUV2 Operator.
|
||||
Applies a single-layer gated recurrent unit (GRU) to an input sequence.
|
||||
|
||||
Args:
|
||||
direction (str): A string identifying the direction in the op. Default: 'UNIDIRECTIONAL'.
|
||||
|
@ -6532,19 +6532,19 @@ class DynamicGRUV2(PrimitiveWithInfer):
|
|||
- **y** (Tensor) - A Tensor of shape :math:
|
||||
if num_proj > 0 `(num_step, batch_size, min(hidden_size, num_proj)`,
|
||||
if num_proj == 0 `(num_step, batch_size, hidden_size)`.
|
||||
Has the same data type with input `bais_type`.
|
||||
Has the same data type with input `bias_type`.
|
||||
- **output_h** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
|
||||
Has the same data type with input `bais_type`.
|
||||
Has the same data type with input `bias_type`.
|
||||
- **update** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
|
||||
Has the same data type with input `bais_type`.
|
||||
Has the same data type with input `bias_type`.
|
||||
- **reset** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
|
||||
Has the same data type with input `bais_type`.
|
||||
Has the same data type with input `bias_type`.
|
||||
- **new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
|
||||
Has the same data type with input `bais_type`.
|
||||
Has the same data type with input `bias_type`.
|
||||
- **hidden_new** (Tensor) - A Tensor of shape :math:`(\text{num_step}, \text{batch_size}, \text{hidden_size})`.
|
||||
Has the same data type with input `bais_type`.
|
||||
Has the same data type with input `bias_type`.
|
||||
|
||||
- If `bias_input` and `bias_hidden` both are `None`, `bias_type` is float32.
|
||||
- If `bias_input` and `bias_hidden` both are `None`, `bias_type` is date type of `init_h`.
|
||||
- If `bias_input` is not `None`, `bias_type` is the date type of `bias_input`.
|
||||
- If `bias_input` is `None` and `bias_hidden` is not `None, `bias_type` is the date type of `bias_hidden`.
|
||||
|
||||
|
@ -6563,6 +6563,15 @@ class DynamicGRUV2(PrimitiveWithInfer):
|
|||
>>> print(output[0].shape)
|
||||
(2, 8, 16)
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
sig.make_sig('x', dtype=sig.sig_dtype.T1),
|
||||
sig.make_sig('weight_input', dtype=sig.sig_dtype.T2),
|
||||
sig.make_sig('weight_hidden', dtype=sig.sig_dtype.T3),
|
||||
sig.make_sig('bias_input', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('bias_hidden', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('seq_length', dtype=sig.sig_dtype.T4),
|
||||
sig.make_sig('init_h', dtype=sig.sig_dtype.T),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self,
|
||||
|
@ -6631,7 +6640,7 @@ class DynamicGRUV2(PrimitiveWithInfer):
|
|||
validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name)
|
||||
validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name)
|
||||
validator.check_tensor_dtype_valid("init_h dtype", h_dtype, (mstype.float16, mstype.float32), self.name)
|
||||
b_dtype = mstype.float32
|
||||
b_dtype = h_dtype
|
||||
if binput_dtype is not None:
|
||||
validator.check_tensor_dtype_valid("bias input dtype", binput_dtype,
|
||||
(mstype.float16, mstype.float32), self.name)
|
||||
|
|
Loading…
Reference in New Issue