!36872 add type check for Adadelta input

Merge pull request !36872 from fujianzhao/fix_adadelta
i-robot 2022-07-09 08:09:07 +00:00 committed by Gitee
commit f72d2f90b0
2 changed files with 6 additions and 3 deletions


@@ -47,7 +47,7 @@ mindspore.nn.Adadelta
 **Inputs:**
-- **grads** (tuple[Tensor]) - The gradients of `params` in the optimizer, with the same shape as `params`.
+- **grads** (tuple[Tensor]) - The gradients of `params` in the optimizer, with the same shape and data type as `params`. The data type must be float16 or float32.
 **Outputs:**

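For context, the documented contract can be exercised with a short sketch (the setup is hypothetical; `nn.Dense` and the shapes below are illustrative, not part of this change): `grads` must be a tuple of Tensors whose shapes and dtypes match the optimizer's `params`.

```python
from mindspore import nn, ops

net = nn.Dense(3, 2)  # hypothetical toy network
opt = nn.Adadelta(net.trainable_params(), learning_rate=1.0)

# Gradients that satisfy the documented contract: a tuple of Tensors
# with the same shape and dtype (float32 here) as each parameter.
grads = tuple(ops.zeros_like(p) for p in net.trainable_params())
out = opt(grads)  # applies one update step; returns Tensor(True)
```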

@@ -16,6 +16,7 @@
 from mindspore.ops import functional as F, composite as C, operations as P
 from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
+from mindspore.common.tensor import Tensor
 from .optimizer import Optimizer
 from .optimizer import opt_init_args_register
@@ -122,8 +123,8 @@ class Adadelta(Optimizer):
             the Cell with step as the input to get the weight decay value of current step.
 
     Inputs:
-        - **grads** (tuple[Tensor]) - The gradients of `params` in the optimizer, the shape is the same as the `params`
-          in optimizer.
+        - **grads** (tuple[Tensor]) - The gradients of `params` in the optimizer, has the same shape and data type as
+          the `params` in optimizer. With float16 or float32 data type.
 
     Outputs:
         Tensor[bool], the value is True.
@@ -183,6 +184,8 @@ class Adadelta(Optimizer):
         self.epsilon = epsilon
 
     def construct(self, grads):
+        if not isinstance(grads, tuple) or not isinstance(grads[0], Tensor):
+            raise TypeError("For 'Adadelta', the 'grads' must be a tuple of Tensor.")
         params = self.parameters
         grads = self.decay_weight(grads)
         grads = self.gradients_centralization(grads)
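The effect of the new check is visible when `construct` runs eagerly; a minimal sketch, assuming PyNative mode and the same hypothetical setup as above:

```python
import mindspore as ms
from mindspore import nn, ops

ms.set_context(mode=ms.PYNATIVE_MODE)  # execute construct() eagerly

net = nn.Dense(3, 2)  # hypothetical toy network
opt = nn.Adadelta(net.trainable_params())
grads = tuple(ops.zeros_like(p) for p in net.trainable_params())

# Passing a list instead of a tuple now fails fast with a clear error
# rather than failing later inside the update computation.
try:
    opt(list(grads))
except TypeError as e:
    print(e)  # For 'Adadelta', the 'grads' must be a tuple of Tensor.
```

Note that the guard only inspects `grads[0]`, which keeps the check cheap; a tuple whose later elements are not Tensors would still pass it.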