commit
7210751e4a
|
@ -2189,7 +2189,7 @@ class SparseApplyFtrlD(PrimitiveWithInfer):
|
||||||
self.l1 = validator.check_type("l1", l1, [float])
|
self.l1 = validator.check_type("l1", l1, [float])
|
||||||
self.l2 = validator.check_type("l2", l2, [float])
|
self.l2 = validator.check_type("l2", l2, [float])
|
||||||
self.lr_power = validator.check_type("lr_power", lr_power, [float])
|
self.lr_power = validator.check_type("lr_power", lr_power, [float])
|
||||||
self.use_locking = validator.check_type("use_locking", use_locaking, [bool])
|
self.use_locking = validator.check_type("use_locking", use_locking, [bool])
|
||||||
|
|
||||||
def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
|
def infer_shape(self, var_shape, accum_shape, linear_shape, grad_shape, indices_shape):
|
||||||
validator.check_param_equal('var shape', var_shape, 'accum shape', accum_shape)
|
validator.check_param_equal('var shape', var_shape, 'accum shape', accum_shape)
|
||||||
|
@ -2210,7 +2210,6 @@ class SparseApplyFtrlD(PrimitiveWithInfer):
|
||||||
validator.check_subclass("linear_type", linear_type, mstype.tensor)
|
validator.check_subclass("linear_type", linear_type, mstype.tensor)
|
||||||
validator.check_subclass("grad_type", grad_type, mstype.tensor)
|
validator.check_subclass("grad_type", grad_type, mstype.tensor)
|
||||||
validator.check_subclass("indices_type", indices_type, mstype.tensor)
|
validator.check_subclass("indices_type", indices_type, mstype.tensor)
|
||||||
validator.check_subclass('indices_type', indices_type, [mstype.int32])
|
|
||||||
|
|
||||||
return var_type
|
return var_type
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue