forked from mindspore-Ecosystem/mindspore
Modified indices supported dtype of SparseApplyProximalAdagrad.
This commit is contained in:
parent
61c83bef88
commit
c122e4bda7
|
@ -5187,7 +5187,7 @@ class SparseApplyProximalAdagrad(PrimitiveWithCheck):
|
||||||
- **grad** (Tensor) - A tensor of the same type as `var`, for the gradient.
|
- **grad** (Tensor) - A tensor of the same type as `var`, for the gradient.
|
||||||
- **indices** (Tensor) - A tensor of indices in the first dimension of `var` and `accum`.
|
- **indices** (Tensor) - A tensor of indices in the first dimension of `var` and `accum`.
|
||||||
If there are duplicates in `indices`, the behavior is undefined. Must be one of the
|
If there are duplicates in `indices`, the behavior is undefined. Must be one of the
|
||||||
following types: int16, int32, int64, uint16, uint32, uint64.
|
following types: int32, int64.
|
||||||
|
|
||||||
Outputs:
|
Outputs:
|
||||||
Tuple of 2 tensors, the updated parameters.
|
Tuple of 2 tensors, the updated parameters.
|
||||||
|
@ -5253,8 +5253,7 @@ class SparseApplyProximalAdagrad(PrimitiveWithCheck):
|
||||||
validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, [mstype.float16, mstype.float32], self.name)
|
validator.check_scalar_or_tensor_types_same({"lr": lr_dtype}, [mstype.float16, mstype.float32], self.name)
|
||||||
validator.check_scalar_or_tensor_types_same({"l1": l1_dtype}, [mstype.float16, mstype.float32], self.name)
|
validator.check_scalar_or_tensor_types_same({"l1": l1_dtype}, [mstype.float16, mstype.float32], self.name)
|
||||||
validator.check_scalar_or_tensor_types_same({"l2": l2_dtype}, [mstype.float16, mstype.float32], self.name)
|
validator.check_scalar_or_tensor_types_same({"l2": l2_dtype}, [mstype.float16, mstype.float32], self.name)
|
||||||
valid_dtypes = [mstype.int16, mstype.int32, mstype.int64,
|
valid_dtypes = [mstype.int32, mstype.int64]
|
||||||
mstype.uint16, mstype.uint32, mstype.uint64]
|
|
||||||
validator.check_tensor_dtype_valid('indices', indices_dtype, valid_dtypes, self.name)
|
validator.check_tensor_dtype_valid('indices', indices_dtype, valid_dtypes, self.name)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue