Fix data type bug of ApplyPowerSign inputs.
This commit is contained in:
parent
684ff4f46b
commit
5a05e5601c
|
@ -2934,7 +2934,8 @@ class Tan(PrimitiveWithInfer):
|
|||
Computes tangent of `input_x` element-wise.
|
||||
|
||||
Inputs:
|
||||
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`.
|
||||
- **input_x** (Tensor) - The shape of tensor is :math:`(x_1, x_2, ..., x_R)`. Data type should be
|
||||
float16, float32 or int32.
|
||||
|
||||
Outputs:
|
||||
Tensor, has the same shape as `input_x`.
|
||||
|
@ -2953,7 +2954,8 @@ class Tan(PrimitiveWithInfer):
|
|||
return x_shape
|
||||
|
||||
def infer_dtype(self, x_type):
|
||||
validator.check_tensor_type_same({'x': x_type}, mstype.number_type, self.name)
|
||||
valid_types = [mstype.float16, mstype.float32, mstype.int32]
|
||||
validator.check_tensor_type_same({'x': x_type}, valid_types, self.name)
|
||||
return x_type
|
||||
|
||||
|
||||
|
|
|
@ -4281,6 +4281,7 @@ class ApplyPowerSign(PrimitiveWithInfer):
|
|||
|
||||
Inputs:
|
||||
- **var** (Parameter) - Variable tensor to be updated. With float32 or float16 data type.
|
||||
If data type of `var` is float16, all inputs must have the same data type as `var`.
|
||||
- **m** (Parameter) - Variable tensor to be updated. Has the same dtype as `var`.
|
||||
- **lr** (Union[Number, Tensor]) - The learning rate value, should be a scalar.
|
||||
With float32 or float16 data type.
|
||||
|
@ -4323,11 +4324,11 @@ class ApplyPowerSign(PrimitiveWithInfer):
|
|||
__mindspore_signature__ = (
|
||||
('var', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('m', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T1),
|
||||
('logbase', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T2),
|
||||
('lr', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('logbase', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('sign_decay', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
|
||||
sig_dtype.T3),
|
||||
('beta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T4),
|
||||
sig_dtype.T),
|
||||
('beta', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
|
||||
('grad', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue