!1555 fix bug in lamb warmup step check

Merge pull request !1555 from wangnan39/fix_bug_in_check_lamb_warmup_step
This commit is contained in:
mindspore-ci-bot 2020-05-28 16:08:43 +08:00 committed by Gitee
commit 2a6a3e012c
2 changed files with 13 additions and 15 deletions

View File

@ -111,7 +111,6 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
def _check_param_value(decay_steps, warmup_steps, start_learning_rate, def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name): end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name):
"""Check the type of inputs.""" """Check the type of inputs."""
_ = warmup_steps
validator.check_float_positive('start_learning_rate', start_learning_rate, prim_name) validator.check_float_positive('start_learning_rate', start_learning_rate, prim_name)
validator.check_float_legal_value('start_learning_rate', start_learning_rate, prim_name) validator.check_float_legal_value('start_learning_rate', start_learning_rate, prim_name)
validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name) validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name)
@ -119,7 +118,7 @@ def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
validator.check_float_positive('power', power, prim_name) validator.check_float_positive('power', power, prim_name)
validator.check_float_legal_value('power', power, prim_name) validator.check_float_legal_value('power', power, prim_name)
validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name) validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
validator.check_integer('warmup_steps', decay_steps, 0, Rel.GT, prim_name) validator.check_integer('warmup_steps', warmup_steps, 0, Rel.GE, prim_name)
validator.check_value_type("beta1", beta1, [float], prim_name) validator.check_value_type("beta1", beta1, [float], prim_name)
validator.check_value_type("beta2", beta2, [float], prim_name) validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_value_type("eps", eps, [float], prim_name) validator.check_value_type("eps", eps, [float], prim_name)

View File

@ -14,6 +14,7 @@
# ============================================================================ # ============================================================================
""" test lamb """ """ test lamb """
import numpy as np import numpy as np
import pytest
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import Tensor, Parameter from mindspore import Tensor, Parameter
@ -50,29 +51,27 @@ class NetWithoutWeight(nn.Cell):
return x return x
def test_lamb_1(): def test_lamb_compile():
""" test_Lamb_1 """ """ test_Lamb_compile """
inputs = Tensor(np.ones([1, 64]).astype(np.float32)) inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32)) label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net() net = Net()
net.set_train() net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits() loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=5) optimizer = Lamb(net.trainable_params(), decay_steps=10)
net_with_loss = WithLossCell(net, loss) net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer) train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label) _executor.compile(train_network, inputs, label)
def test_lamb_2(): def test_lamb_error():
""" test_Lamb_2 """
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net() net = Net()
net.set_train() with pytest.raises(TypeError):
loss = nn.SoftmaxCrossEntropyWithLogits() Lamb(net.get_parameters(), decay_steps=6, warmup_steps=5.0)
optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=0)
net_with_loss = WithLossCell(net, loss) with pytest.raises(TypeError):
train_network = TrainOneStepCell(net_with_loss, optimizer) Lamb(net.get_parameters(), decay_steps=1.0)
_executor.compile(train_network, inputs, label)
with pytest.raises(ValueError):
Lamb(net.get_parameters(), decay_steps=0)