diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index 57b8f84be12..f9db8917f94 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -6519,9 +6519,9 @@ class DynamicGRUV2(PrimitiveWithInfer): Tensor of shape :math:`(\text{hidden_size}, 3 \times \text{hidden_size})`. The data type must be float16. - **bias_input** (Tensor) - Input-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None. - The data type must be float16 or float32. + Has the same data type with input `init_h`. - **bias_hidden** (Tensor) - Hidden-hidden bias. Tensor of shape :math:`(3 \times \text{hidden_size})`, or None. - The data type must be float16 or float32. + Has the same data type with input `init_h`. - **seq_length** (Tensor) - The length of each batch. Tensor of shape :math:`(\text{batch_size})`. Only `None` is currently supported. - **init_h** (Tensor) - Hidden state of initial time. @@ -6563,15 +6563,6 @@ class DynamicGRUV2(PrimitiveWithInfer): >>> print(output[0].shape) (2, 8, 16) """ - __mindspore_signature__ = ( - sig.make_sig('x', dtype=sig.sig_dtype.T1), - sig.make_sig('weight_input', dtype=sig.sig_dtype.T2), - sig.make_sig('weight_hidden', dtype=sig.sig_dtype.T3), - sig.make_sig('bias_input', dtype=sig.sig_dtype.T), - sig.make_sig('bias_hidden', dtype=sig.sig_dtype.T), - sig.make_sig('seq_length', dtype=sig.sig_dtype.T4), - sig.make_sig('init_h', dtype=sig.sig_dtype.T), - ) @prim_attr_register def __init__(self, @@ -6639,15 +6630,16 @@ class DynamicGRUV2(PrimitiveWithInfer): validator.check_tensor_dtype_valid("x dtype", x_dtype, [mstype.float16], self.name) validator.check_tensor_dtype_valid("weight input dtype", winput_dtype, [mstype.float16], self.name) validator.check_tensor_dtype_valid("weight hidden dtype", whidden_dtype, [mstype.float16], self.name) - validator.check_tensor_dtype_valid("init_h dtype", h_dtype, (mstype.float16, mstype.float32), self.name) + valid_dtypes = [mstype.float16, mstype.float32] + validator.check_tensor_dtype_valid("init_h dtype", h_dtype, valid_dtypes, self.name) b_dtype = h_dtype if binput_dtype is not None: - validator.check_tensor_dtype_valid("bias input dtype", binput_dtype, - (mstype.float16, mstype.float32), self.name) + args = {'init_h': h_dtype, 'bias_input': binput_dtype} + validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) b_dtype = binput_dtype elif bhidden_dtype is not None: - validator.check_tensor_dtype_valid("bias hidden dtype", bhidden_dtype, - (mstype.float16, mstype.float32), self.name) + args = {'init_h': h_dtype, 'bias_hidden': bhidden_dtype} + validator.check_tensors_dtypes_same_and_valid(args, valid_dtypes, self.name) b_dtype = bhidden_dtype return b_dtype, b_dtype, b_dtype, b_dtype, b_dtype, b_dtype