Fix data type bug of ApplyPowerSign inputs.

This commit is contained in:
liuxiao93 2020-07-23 17:24:34 +08:00
parent 684ff4f46b
commit 5a05e5601c
2 changed files with 9 additions and 6 deletions

View File

@ -2934,7 +2934,8 @@ class Tan(PrimitiveWithInfer):
Computes tangent of `input_x` element-wise.
Inputs:
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. Data type should be
float16, float32 or int32.
Outputs:
Tensor, has the same shape as `input_x`.
@ -2953,7 +2954,8 @@ class Tan(PrimitiveWithInfer):
return x_shape
def infer_dtype(self, x_type):
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
valid_types = [mstype.float16, mstype.float32, mstype.int32]
validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
return x_type

View File

@ -4281,6 +4281,7 @@ class ApplyPowerSign(PrimitiveWithInfer):
Inputs:
- **var** (Parameter) - Variable tensor to be updated. With float32 or float16 data type.
If data type of `var` is float16, all inputs must have the same data type as `var`.
- **m** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`.
- **lr** (Union[Number, Tensor]) - The learning rate value, should be a scalar.
With float32 or float16 data type.
@ -4323,11 +4324,11 @@ class ApplyPowerSign(PrimitiveWithInfer):
__mindspore_signature__ = (
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
('logbase', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2),
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('logbase', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('sign_decay', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
sig_dtype.T3),
('beta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T4),
sig_dtype.T),
('beta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
)