forked from mindspore-Ecosystem/mindspore
!10385 Modified data type requirement of bias of DynamicGRUV2.
From: @liu_xiao_93 Reviewed-by: @liangchenghui,@wuxuejian Signed-off-by: @liangchenghui
This commit is contained in:
commit
173265f11f
|
@ -6519,9 +6519,9 @@ class DynamicGRUV2(PrimitiveWithInfer):
|
|||
Tensor of shape :math:`(\text{hidden_size}, 3 \times \text{hidden_size})`.
|
||||
The data type must be float16.
|
||||
- **bias_input** (Tensor) - Input-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None.
|
||||
The data type must be float16 or float32.
|
||||
Has the same data type with input `init_h`.
|
||||
- **bias_hidden** (Tensor) - Hidden-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None.
|
||||
The data type must be float16 or float32.
|
||||
Has the same data type with input `init_h`.
|
||||
- **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`.
|
||||
Only `None` is currently supported.
|
||||
- **init_h** (Tensor) - Hidden state of initial time.
|
||||
|
@ -6563,15 +6563,6 @@ class DynamicGRUV2(PrimitiveWithInfer):
|
|||
>>> print(output[0].shape)
|
||||
(2, 8, 16)
|
||||
"""
|
||||
__mindspore_signature__ = (
|
||||
sig.make_sig('x', dtype=sig.sig_dtype.T1),
|
||||
sig.make_sig('weight_input', dtype=sig.sig_dtype.T2),
|
||||
sig.make_sig('weight_hidden', dtype=sig.sig_dtype.T3),
|
||||
sig.make_sig('bias_input', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('bias_hidden', dtype=sig.sig_dtype.T),
|
||||
sig.make_sig('seq_length', dtype=sig.sig_dtype.T4),
|
||||
sig.make_sig('init_h', dtype=sig.sig_dtype.T),
|
||||
)
|
||||
|
||||
@prim_attr_register
|
||||
def __init__(self,
|
||||
|
@ -6639,15 +6630,16 @@ class DynamicGRUV2(PrimitiveWithInfer):
|
|||
validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name)
|
||||
validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name)
|
||||
validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name)
|
||||
validator.check_tensor_dtype_valid("init_h dtype", h_dtype, (mstype.float16, mstype.float32), self.name)
|
||||
valid_dtypes = [mstype.float16, mstype.float32]
|
||||
validator.check_tensor_dtype_valid("init_h dtype", h_dtype, valid_dtypes, self.name)
|
||||
b_dtype = h_dtype
|
||||
if binput_dtype is not None:
|
||||
validator.check_tensor_dtype_valid("bias input dtype", binput_dtype,
|
||||
(mstype.float16, mstype.float32), self.name)
|
||||
args = {'init_h': h_dtype, 'bias_input': binput_dtype}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
|
||||
b_dtype = binput_dtype
|
||||
elif bhidden_dtype is not None:
|
||||
validator.check_tensor_dtype_valid("bias hidden dtype", bhidden_dtype,
|
||||
(mstype.float16, mstype.float32), self.name)
|
||||
args = {'init_h': h_dtype, 'bias_hidden': bhidden_dtype}
|
||||
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
|
||||
b_dtype = bhidden_dtype
|
||||
|
||||
return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype
|
||||
|
|
Loading…
Reference in New Issue