!10385 Modified data type requirement of bias of DynamicGRUV2.

From: @liu_xiao_93
Reviewed-by: @liangchenghui,@wuxuejian
Signed-off-by: @liangchenghui
This commit is contained in:
mindspore-ci-bot 2020-12-23 20:47:05 +08:00 committed by Gitee
commit 173265f11f
1 changed files with 8 additions and 16 deletions

View File

@ -6519,9 +6519,9 @@ class DynamicGRUV2(PrimitiveWithInfer):
Tensor of shape :math:`(\text{hidden_size}, 3 \times \text{hidden_size})`.
The data type must be float16.
- **bias_input** (Tensor) - Input-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None.
The data type must be float16 or float32.
Has the same data type with input `init_h`.
- **bias_hidden** (Tensor) - Hidden-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None.
The data type must be float16 or float32.
Has the same data type with input `init_h`.
- **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`.
Only `None` is currently supported.
- **init_h** (Tensor) - Hidden state of initial time.
@ -6563,15 +6563,6 @@ class DynamicGRUV2(PrimitiveWithInfer):
>>> print(output[0].shape)
(2, 8, 16)
"""
__mindspore_signature__ = (
sig.make_sig('x', dtype=sig.sig_dtype.T1),
sig.make_sig('weight_input', dtype=sig.sig_dtype.T2),
sig.make_sig('weight_hidden', dtype=sig.sig_dtype.T3),
sig.make_sig('bias_input', dtype=sig.sig_dtype.T),
sig.make_sig('bias_hidden', dtype=sig.sig_dtype.T),
sig.make_sig('seq_length', dtype=sig.sig_dtype.T4),
sig.make_sig('init_h', dtype=sig.sig_dtype.T),
)
@prim_attr_register
def __init__(self,
@ -6639,15 +6630,16 @@ class DynamicGRUV2(PrimitiveWithInfer):
validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name)
validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name)
validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name)
validator.check_tensor_dtype_valid("init_h dtype", h_dtype, (mstype.float16, mstype.float32), self.name)
valid_dtypes = [mstype.float16, mstype.float32]
validator.check_tensor_dtype_valid("init_h dtype", h_dtype, valid_dtypes, self.name)
b_dtype = h_dtype
if binput_dtype is not None:
validator.check_tensor_dtype_valid("bias input dtype", binput_dtype,
(mstype.float16, mstype.float32), self.name)
args = {'init_h': h_dtype, 'bias_input': binput_dtype}
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
b_dtype = binput_dtype
elif bhidden_dtype is not None:
validator.check_tensor_dtype_valid("bias hidden dtype", bhidden_dtype,
(mstype.float16, mstype.float32), self.name)
args = {'init_h': h_dtype, 'bias_hidden': bhidden_dtype}
validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name)
b_dtype = bhidden_dtype
return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype