SGD优化器支持更新标量权重,合入r1.8

This commit is contained in:
tangdezhi_123 2023-02-21 14:32:23 +08:00
parent 17548c80e9
commit ae40f357cd
1 changed files with 0 additions and 2 deletions

View File

@ -3036,10 +3036,8 @@ class SGD(PrimitiveWithCheck):
def check_shape(self, parameters_shape, gradient_shape, learning_rate_shape,
accum_shape, momentum_shape, stat_shape):
validator.check_positive_int(len(parameters_shape), "parameters rank", self.name)
validator.check_int(len(gradient_shape), 0, Rel.GE, f'gradient rank', self.name)
validator.check_int(len(learning_rate_shape), 0, Rel.GE, f'learning rate rank', self.name)
validator.check_positive_int(len(accum_shape), "accumulation rank", self.name)
validator.check_int(len(momentum_shape), 0, Rel.GE, f'momentum rank', self.name)
validator.check_int(len(stat_shape), 0, Rel.GE, f'stat rank', self.name)
validator.check("gradient shape", gradient_shape, "stat shape", stat_shape, Rel.EQ, self.name)