!49166 SGD优化器支持更新标量权重,合入r1.8
Merge pull request !49166 from tangdezhi_123/r1.8
This commit is contained in:
commit
0579d6625d
|
@ -3036,10 +3036,8 @@ class SGD(PrimitiveWithCheck):
|
|||
|
||||
def check_shape(self, parameters_shape, gradient_shape, learning_rate_shape,
|
||||
accum_shape, momentum_shape, stat_shape):
|
||||
validator.check_positive_int(len(parameters_shape), "parameters rank", self.name)
|
||||
validator.check_int(len(gradient_shape), 0, Rel.GE, f'gradient rank', self.name)
|
||||
validator.check_int(len(learning_rate_shape), 0, Rel.GE, f'learning rate rank', self.name)
|
||||
validator.check_positive_int(len(accum_shape), "accumulation rank", self.name)
|
||||
validator.check_int(len(momentum_shape), 0, Rel.GE, f'momentum rank', self.name)
|
||||
validator.check_int(len(stat_shape), 0, Rel.GE, f'stat rank', self.name)
|
||||
validator.check("gradient shape", gradient_shape, "stat shape", stat_shape, Rel.EQ, self.name)
|
||||
|
|
Loading…
Reference in New Issue