enable use float type learning rate in lars optimizer

This commit is contained in:
Ziyan 2020-03-30 15:17:05 +08:00
parent 930a1fb0a8
commit 4cbcd8e907
2 changed files with 24 additions and 4 deletions

View File

@ -13,12 +13,14 @@
# limitations under the License.
# ============================================================================
"""lars optimizer"""
from typing import Iterable
from mindspore.common import dtype as mstype
from mindspore.common import Tensor
from mindspore.common.initializer import initializer
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.common.parameter import Parameter
from mindspore.nn.cell import Cell
from .optimizer import grad_scale
@ -111,7 +113,8 @@ class LARS(Cell):
self.gather = None
self.global_step = None
self.axis = None
if not isinstance(self.learning_rate, float):
if isinstance(self.learning_rate.default_input, Iterable) or \
(isinstance(self.learning_rate.default_input, Tensor) and self.learning_rate.default_input.dim() == 1):
self.dynamic_lr = True
self.assignadd = P.AssignAdd()
self.gather = P.GatherV2()
@ -124,7 +127,7 @@ class LARS(Cell):
lr = self.gather(self.learning_rate, self.global_step, self.axis)
F.control_depend(lr, self.assignadd(self.global_step, 1))
else:
lr = F.scalar_to_array(self.learning_rate)
lr = self.learning_rate
if self.reciprocal_scale != 1.0:
gradients = self.hyper_map(F.partial(grad_scale, self.reciprocal_scale), gradients)

View File

@ -46,7 +46,7 @@ class Net(nn.Cell):
return x
def test_lars():
def test_lars_multi_step_lr():
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net()
@ -61,3 +61,20 @@ def test_lars():
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
def test_lars_float_lr():
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
lr = 0.1
SGD = Momentum(net.trainable_params(), lr, 0.9)
optimizer = LARS(SGD, epsilon=1e-08, hyperpara=0.02, decay_filter=lambda x: 'bn' not in x.name,
lars_filter=lambda x: 'bn' not in x.name)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)