forked from OSSInnovation/mindspore
enable use float type learning rate in lars optimizer
This commit is contained in:
parent
930a1fb0a8
commit
4cbcd8e907
|
@ -13,12 +13,14 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""lars optimizer"""
|
||||
from typing import Iterable
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.common import Tensor
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.common.parameter import Parameter
|
||||
from mindspore.nn.cell import Cell
|
||||
from .optimizer import grad_scale
|
||||
|
||||
|
@ -111,7 +113,8 @@ class LARS(Cell):
|
|||
self.gather = None
|
||||
self.global_step = None
|
||||
self.axis = None
|
||||
if not isinstance(self.learning_rate, float):
|
||||
if isinstance(self.learning_rate.default_input, Iterable) or \
|
||||
(isinstance(self.learning_rate.default_input, Tensor) and self.learning_rate.default_input.dim() == 1):
|
||||
self.dynamic_lr = True
|
||||
self.assignadd = P.AssignAdd()
|
||||
self.gather = P.GatherV2()
|
||||
|
@ -124,7 +127,7 @@ class LARS(Cell):
|
|||
lr = self.gather(self.learning_rate, self.global_step, self.axis)
|
||||
F.control_depend(lr, self.assignadd(self.global_step, 1))
|
||||
else:
|
||||
lr = F.scalar_to_array(self.learning_rate)
|
||||
lr = self.learning_rate
|
||||
if self.reciprocal_scale != 1.0:
|
||||
gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)
|
||||
|
||||
|
|
|
@ -46,7 +46,7 @@ class Net(nn.Cell):
|
|||
return x
|
||||
|
||||
|
||||
def test_lars():
|
||||
def test_lars_multi_step_lr():
|
||||
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
|
||||
label = Tensor(np.zeros([1, 10]).astype(np.float32))
|
||||
net = Net()
|
||||
|
@ -61,3 +61,20 @@ def test_lars():
|
|||
net_with_loss = WithLossCell(net, loss)
|
||||
train_network = TrainOneStepCell(net_with_loss, optimizer)
|
||||
_executor.compile(train_network, inputs, label)
|
||||
|
||||
|
||||
def test_lars_float_lr():
|
||||
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
|
||||
label = Tensor(np.zeros([1, 10]).astype(np.float32))
|
||||
net = Net()
|
||||
net.set_train()
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits()
|
||||
|
||||
lr = 0.1
|
||||
SGD = Momentum(net.trainable_params(), lr, 0.9)
|
||||
optimizer = LARS(SGD, epsilon=1e-08, hyperpara=0.02, decay_filter=lambda x: 'bn' not in x.name,
|
||||
lars_filter=lambda x: 'bn' not in x.name)
|
||||
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
train_network = TrainOneStepCell(net_with_loss, optimizer)
|
||||
_executor.compile(train_network, inputs, label)
|
Loading…
Reference in New Issue