From f182edfd448ebe478d6bff58113978abeaf5f5aa Mon Sep 17 00:00:00 2001
From: Ziyan
Date: Sat, 18 Apr 2020 12:50:47 +0800
Subject: [PATCH] fix lars base class type

---
 mindspore/nn/optim/lars.py      | 7 +++----
 mindspore/nn/optim/optimizer.py | 2 +-
 2 files changed, 4 insertions(+), 5 deletions(-)

diff --git a/mindspore/nn/optim/lars.py b/mindspore/nn/optim/lars.py
index c0cb71cfa6a..02538aa61a6 100755
--- a/mindspore/nn/optim/lars.py
+++ b/mindspore/nn/optim/lars.py
@@ -21,8 +21,7 @@ from mindspore.common.parameter import Parameter
 from mindspore.ops import operations as P
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
-from mindspore.nn.cell import Cell
-from .optimizer import grad_scale
+from .optimizer import grad_scale, Optimizer
 
 lars_opt = C.MultitypeFuncGraph("lars_opt")
 
@@ -61,7 +60,7 @@ def _tensor_run_opt_v2(lars, weight_decay, learning_rate, gradient, weight, deca
     return gradient
 
 
-class LARS(Cell):
+class LARS(Optimizer):
     """
     Implements the LARS algorithm with LARSUpdate Operator.
 
@@ -98,7 +97,7 @@ class LARS(Cell):
     def __init__(self, optimizer, epsilon=1e-05, hyperpara=0.001, weight_decay=0.0, use_clip=False,
                  decay_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name,
                  lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name, loss_scale=1.0):
-        super(LARS, self).__init__(auto_prefix=False)
+        super(LARS, self).__init__(0.0, [Parameter(Tensor(0.0), name="trivial")])
         self.opt = optimizer
         self.parameters = optimizer.parameters
         self.learning_rate = optimizer.learning_rate
diff --git a/mindspore/nn/optim/optimizer.py b/mindspore/nn/optim/optimizer.py
index 6c6d14ed7ab..00d3fd3b7bd 100755
--- a/mindspore/nn/optim/optimizer.py
+++ b/mindspore/nn/optim/optimizer.py
@@ -57,7 +57,7 @@ class Optimizer(Cell):
 
     def __init__(self, learning_rate, parameters, weight_decay=0.0, loss_scale=1.0,
                  decay_filter=lambda x: 'beta' not in x.name and 'gamma' not in x.name):
-        super(Optimizer, self).__init__()
+        super(Optimizer, self).__init__(auto_prefix=False)
         if isinstance(learning_rate, float):
             self.dynamic_lr = False
             self.gather = None
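
Usage note (not part of the patch): a minimal sketch of how the re-parented LARS wrapper is driven, assuming the MindSpore Python API of this era; Net, the loss choice, and the hyper-parameter values are illustrative placeholders, not taken from the patch. With LARS now subclassing Optimizer instead of Cell, the wrapper can be passed anywhere a plain Optimizer is accepted.

    from mindspore import nn
    from mindspore.train.model import Model

    # Hypothetical network Cell; any Cell with trainable parameters would do.
    net = Net()
    loss = nn.SoftmaxCrossEntropyWithLogits()
    # Inner optimizer whose per-layer updates LARS rescales.
    momentum = nn.Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
    # Wrap the inner optimizer; epsilon/hyperpara shown are the defaults from lars.py.
    opt = nn.LARS(momentum, epsilon=1e-05, hyperpara=0.001,
                  lars_filter=lambda x: 'LayerNorm' not in x.name and 'bias' not in x.name)
    # The wrapper is handed to Model (or TrainOneStepCell) like any other Optimizer.
    model = Model(net, loss_fn=loss, optimizer=opt)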