forked from mindspore-Ecosystem/mindspore
!446 fix lars base class type
Merge pull request !446 from gziyan/fix_lars_base_class
commit aa54305268
@@ -21,8 +21,7 @@ from mindspore.common.parameter import Parameter
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.nn.cell import Cell
-from .optimizer import grad_scale
+from .optimizer import grad_scale, Optimizer

 lars_opt = C.MultitypeFuncGraph("lars_opt")
@@ -61,7 +60,7 @@ def _tensor_run_opt_v2(lars, weight_decay, learning_rate, gradient, weight, deca
     return gradient


-class LARS(Cell):
+class LARS(Optimizer):
     """
     Implements the LARS algorithm with LARSUpdate Operator.
@@ -98,7 +97,7 @@ class LARS(Cell):
     def __init__(self, optimizer, epsilon=1e-05, hyperpara=0.001, weight_decay=0.0, use_clip=False,
                  decay_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name,
                  lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name, loss_scale=1.0):
-        super(LARS, self).__init__(auto_prefix=False)
+        super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="trivial")])
         self.opt = optimizer
         self.parameters = optimizer.parameters
         self.learning_rate = optimizer.learning_rate
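With this change LARS derives from Optimizer rather than a bare Cell, so it can be handed to the training wrappers like any other MindSpore optimizer. A minimal usage sketch, not part of the patch; the Dense network and the Momentum hyperparameters are placeholders:

    import mindspore.nn as nn
    from mindspore.nn.optim import Momentum, LARS

    # Placeholder network; any Cell with trainable parameters would do.
    net = nn.Dense(10, 1)

    # Wrap a base optimizer with LARS; epsilon and hyperpara follow the
    # constructor signature shown in the hunk above.
    base_opt = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    lars_opt = LARS(base_opt, epsilon=1e-05, hyperpara=0.001)

The remaining hunk adjusts the Optimizer base class itself.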
@@ -57,7 +57,7 @@ class Optimizer(Cell):
     def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0,
                  decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
-        super(Optimizer, self).__init__()
+        super(Optimizer, self).__init__(auto_prefix=False)
         if isinstance(learning_rate, float):
             self.dynamic_lr = False
             self.gather = None
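For context, auto_prefix controls whether a Cell renames the parameters of the sub-cells assigned to it with its own attribute scope; an optimizer wrapper wants this off so the network's parameter names stay untouched. A rough illustration of the idea only, not part of the patch: the Inner/Wrapper cells are hypothetical and the exact naming behaviour can vary between MindSpore versions.

    import mindspore.nn as nn
    from mindspore import Tensor, Parameter

    class Inner(nn.Cell):
        def __init__(self):
            super(Inner, self).__init__()
            self.w = Parameter(Tensor(0.0), name="w")

    class Wrapper(nn.Cell):
        def __init__(self, auto_prefix):
            super(Wrapper, self).__init__(auto_prefix=auto_prefix)
            self.inner = Inner()

    # With auto_prefix=True the wrapper prefixes the inner parameter name
    # (e.g. "inner.w"); with auto_prefix=False the name stays plain ("w"),
    # which is the behaviour the Optimizer base class opts into here.
    print([p.name for p in Wrapper(True).get_parameters()])
    print([p.name for p in Wrapper(False).get_parameters()])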