diff --git a/mindspore/ops/operations/math_ops.py b/mindspore/ops/operations/math_ops.py index 940bf65576c..4fef5349cfc 100644 --- a/mindspore/ops/operations/math_ops.py +++ b/mindspore/ops/operations/math_ops.py @@ -2934,7 +2934,8 @@ class Tan(PrimitiveWithInfer): Computes tangent of `input_x` element-wise. Inputs: - - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. + - **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. Data type should be + float16, float32 or int32. Outputs: Tensor, has the same shape as `input_x`. @@ -2953,7 +2954,8 @@ class Tan(PrimitiveWithInfer): return x_shape def infer_dtype(self, x_type): - validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name) + valid_types = [mstype.float16, mstype.float32, mstype.int32] + validator.check_tensor_type_same({'x': x_type}, valid_types, self.name) return x_type diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py index f0ca9f4e374..061c97cc7bd 100644 --- a/mindspore/ops/operations/nn_ops.py +++ b/mindspore/ops/operations/nn_ops.py @@ -4281,6 +4281,7 @@ class ApplyPowerSign(PrimitiveWithInfer): Inputs: - **var** (Parameter) - Variable tensor to be updated. With float32 or float16 data type. + If data type of `var` is float16, all inputs must have the same data type as `var`. - **m** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`. - **lr** (Union[Number, Tensor]) - The learning rate value, should be a scalar. With float32 or float16 data type. @@ -4323,11 +4324,11 @@ class ApplyPowerSign(PrimitiveWithInfer): __mindspore_signature__ = ( ('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), ('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), - ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1), - ('logbase', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2), + ('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), + ('logbase', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), ('sign_decay', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, - sig_dtype.T3), - ('beta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T4), + sig_dtype.T), + ('beta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T), ('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T) )