forked from mindspore-Ecosystem/mindspore

!1555 fix bug in lamb warmup step check

Merge pull request !1555 from wangnan39/fix_bug_in_check_lamb_warmup_step

commit 2a6a3e012c
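The change below touches two files. In the Lamb optimizer, `_check_param_value` used to validate `decay_steps` a second time under the name `'warmup_steps'`, with the strict relation `Rel.GT`; the real `warmup_steps` argument was never inspected (the `_ = warmup_steps` line existed only to mark it as used). The fix checks the actual `warmup_steps` and relaxes the relation to `Rel.GE`, since zero warmup steps is a valid setting. The unit tests are reorganized to match: one compile smoke test, and one error test that asserts the expected exceptions with `pytest.raises`.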
@@ -111,7 +111,6 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
 def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
                        end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name):
     """Check the type of inputs."""
-    _ = warmup_steps
     validator.check_float_positive('start_learning_rate', start_learning_rate, prim_name)
     validator.check_float_legal_value('start_learning_rate', start_learning_rate, prim_name)
     validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name)
@@ -119,7 +118,7 @@ def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
     validator.check_float_positive('power', power, prim_name)
     validator.check_float_legal_value('power', power, prim_name)
     validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
-    validator.check_integer('warmup_steps', decay_steps, 0, Rel.GT, prim_name)
+    validator.check_integer('warmup_steps', warmup_steps, 0, Rel.GE, prim_name)
     validator.check_value_type("beta1", beta1, [float], prim_name)
     validator.check_value_type("beta2", beta2, [float], prim_name)
     validator.check_value_type("eps", eps, [float], prim_name)
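To make the one-line fix concrete, here is a minimal stand-in for `validator.check_integer` (a hypothetical reimplementation for illustration, not MindSpore's actual `_checkparam` code), run against both the old and the new call:

# Hypothetical stand-in for validator.check_integer, for illustration only.
def check_integer(arg_name, arg_value, value, rel, prim_name="Lamb"):
    """Require arg_value to be an int satisfying relation `rel` against `value`."""
    relations = {"GT": lambda x, y: x > y, "GE": lambda x, y: x >= y}
    if not isinstance(arg_value, int) or isinstance(arg_value, bool):
        raise TypeError(f"For '{prim_name}', '{arg_name}' must be an int, "
                        f"but got {type(arg_value).__name__}.")
    if not relations[rel](arg_value, value):
        raise ValueError(f"For '{prim_name}', '{arg_name}' must be {rel} {value}, "
                         f"but got {arg_value}.")

decay_steps, warmup_steps = 10, 0   # a legal "no warmup" configuration

# Old line: decay_steps is validated under the name 'warmup_steps', so the
# real warmup_steps is never inspected and any bad value slips through.
check_integer('warmup_steps', decay_steps, 0, "GT")   # passes, but checks nothing useful

# Had the old line checked the right variable, Rel.GT would still have
# rejected the legal warmup_steps=0 case:
try:
    check_integer('warmup_steps', warmup_steps, 0, "GT")
except ValueError as err:
    print(err)   # 'warmup_steps' must be GT 0, but got 0

# Fixed line: the right variable, and GE so that zero warmup is allowed.
check_integer('warmup_steps', warmup_steps, 0, "GE")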
@@ -14,6 +14,7 @@
 # ============================================================================
 """ test lamb """
 import numpy as np
+import pytest

 import mindspore.nn as nn
 from mindspore import Tensor, Parameter
@@ -50,29 +51,27 @@ class NetWithoutWeight(nn.Cell):
         return x


-def test_lamb_1():
-    """ test_Lamb_1 """
+def test_lamb_compile():
+    """ test_Lamb_compile """
     inputs = Tensor(np.ones([1, 64]).astype(np.float32))
     label = Tensor(np.zeros([1, 10]).astype(np.float32))
     net = Net()
     net.set_train()
     loss = nn.SoftmaxCrossEntropyWithLogits()
-    optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=5)
+    optimizer = Lamb(net.trainable_params(), decay_steps=10)

     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, optimizer)
     _executor.compile(train_network, inputs, label)


-def test_lamb_2():
-    """ test_Lamb_2 """
-    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
-    label = Tensor(np.zeros([1, 10]).astype(np.float32))
+def test_lamb_error():
     net = Net()
-    net.set_train()
-    loss = nn.SoftmaxCrossEntropyWithLogits()
-    optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=0)
+    with pytest.raises(TypeError):
+        Lamb(net.get_parameters(), decay_steps=6, warmup_steps=5.0)

-    net_with_loss = WithLossCell(net, loss)
-    train_network = TrainOneStepCell(net_with_loss, optimizer)
-    _executor.compile(train_network, inputs, label)
+    with pytest.raises(TypeError):
+        Lamb(net.get_parameters(), decay_steps=1.0)
+
+    with pytest.raises(ValueError):
+        Lamb(net.get_parameters(), decay_steps=0)
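The reworked `test_lamb_error` pins down the intended contract: float values for `warmup_steps` or `decay_steps` raise `TypeError`, and `decay_steps=0` raises `ValueError`. The asymmetry is deliberate: `decay_steps` keeps the strict `Rel.GT` check (presumably a decay schedule needs at least one step), while `warmup_steps` only needs `Rel.GE`, zero meaning no warmup.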