!36872 add type check for Adadelta input
Merge pull request !36872 from fujianzhao/fix_adadelta
commit f72d2f90b0
@@ -47,7 +47,7 @@ mindspore.nn.Adadelta
 
 **Inputs:**
 
-- **grads** (tuple[Tensor]) - The gradients of `params` in the optimizer, with the same shape as `params`.
+- **grads** (tuple[Tensor]) - The gradients of `params` in the optimizer, with the same shape and data type as `params`. The data type is float16 or float32.
 
 **Outputs:**
@@ -16,6 +16,7 @@
 from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
+from mindspore.common.tensor import Tensor
 from .optimizer import Optimizer
 from .optimizer import opt_init_args_register
@@ -122,8 +123,8 @@ class Adadelta(Optimizer):
         the Cell with step as the input to get the weight decay value of current step.
 
     Inputs:
-        - **grads** (tuple[Tensor]) - The gradients of `params` in the optimizer, the shape is the same as the `params`
-          in optimizer.
+        - **grads** (tuple[Tensor]) - The gradients of `params` in the optimizer, has the same shape and data type as
+          the `params` in optimizer. With float16 or float32 data type.
 
     Outputs:
         Tensor[bool], the value is True.
@@ -183,6 +184,8 @@ class Adadelta(Optimizer):
         self.epsilon = epsilon
 
     def construct(self, grads):
+        if not isinstance(grads, tuple) or not isinstance(grads[0], Tensor):
+            raise TypeError("For 'Adadelta', the 'grads' must be a tuple of Tensor.")
         params = self.parameters
         grads = self.decay_weight(grads)
         grads = self.gradients_centralization(grads)
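For context, the sketch below shows roughly how the new guard surfaces to a user. Everything outside the diff is illustrative and assumed, not part of this change: the small `nn.Dense` network, the PyNative mode setting, and the direct `opt(grads)` call are just a minimal way to exercise `construct`.

    # Illustrative sketch only (not part of this PR). Assumes MindSpore is
    # installed and that nn.Adadelta follows its documented signature.
    import numpy as np
    from mindspore import Tensor, context, nn

    context.set_context(mode=context.PYNATIVE_MODE)

    net = nn.Dense(3, 1)                                # any network with parameters
    opt = nn.Adadelta(net.trainable_params(), learning_rate=1.0)

    # Well-formed input: a tuple of Tensors, one per parameter, with matching
    # shape and dtype (float32 here, as the updated docstring requires).
    grads = tuple(Tensor(np.zeros(p.shape, np.float32)) for p in net.trainable_params())
    opt(grads)    # applies one Adadelta update; the Outputs section says it returns True

    # After this change, a malformed input such as a plain list of Tensors, or a
    # tuple whose elements are not Tensor, fails fast in construct with:
    #   TypeError: For 'Adadelta', the 'grads' must be a tuple of Tensor.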