!49166 SGD优化器支持更新标量权重,合入r1.8

Merge pull request !49166 from tangdezhi_123/r1.8
This commit is contained in:
i-robot 2023-03-06 11:00:23 +00:00 committed by Gitee
commit 0579d6625d
No known key found for this signature in database
GPG Key ID: 173E9B9CA92EEF8F
1 changed files with 0 additions and 2 deletions

View File

@ -3036,10 +3036,8 @@ class SGD(PrimitiveWithCheck):
def check_shape(self, parameters_shape, gradient_shape, learning_rate_shape,
accum_shape, momentum_shape, stat_shape):
validator.check_positive_int(len(parameters_shape), "parameters rank", self.name)
validator.check_int(len(gradient_shape), 0, Rel.GE, f'gradient rank', self.name)
validator.check_int(len(learning_rate_shape), 0, Rel.GE, f'learning rate rank', self.name)
validator.check_positive_int(len(accum_shape), "accumulation rank", self.name)
validator.check_int(len(momentum_shape), 0, Rel.GE, f'momentum rank', self.name)
validator.check_int(len(stat_shape), 0, Rel.GE, f'stat rank', self.name)
validator.check("gradient shape", gradient_shape, "stat shape", stat_shape, Rel.EQ, self.name)