forked from mindspore-Ecosystem/mindspore
fix_bug_in_check_lamb_warmup_step
parent fb7e4eac76
commit 810ccf80d8
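In _check_param_value for the Lamb optimizer, the warmup_steps check mistakenly passed decay_steps as the value to validate, so decay_steps was checked twice and warmup_steps not at all. This commit validates warmup_steps itself, relaxing the relation from Rel.GT to Rel.GE so a warmup of zero steps stays legal, and drops the now-unneeded "_ = warmup_steps" placeholder. The Lamb unit tests are reworked to match: test_lamb_1 becomes test_lamb_compile (no warmup_steps argument needed), and test_lamb_2 becomes test_lamb_error, which uses pytest.raises to cover the TypeError and ValueError paths.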
@@ -111,7 +111,6 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
 def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
                        end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name):
     """Check the type of inputs."""
-    _ = warmup_steps
     validator.check_float_positive('start_learning_rate', start_learning_rate, prim_name)
     validator.check_float_legal_value('start_learning_rate', start_learning_rate, prim_name)
     validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name)
@@ -119,7 +118,7 @@ def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
     validator.check_float_positive('power', power, prim_name)
     validator.check_float_legal_value('power', power, prim_name)
     validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
-    validator.check_integer('warmup_steps', decay_steps, 0, Rel.GT, prim_name)
+    validator.check_integer('warmup_steps', warmup_steps, 0, Rel.GE, prim_name)
     validator.check_value_type("beta1", beta1, [float], prim_name)
     validator.check_value_type("beta2", beta2, [float], prim_name)
     validator.check_value_type("eps", eps, [float], prim_name)
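Why the old line was a bug, in one sketch: with decay_steps passed where warmup_steps belongs, a bad warmup_steps value slipped through while errors about decay_steps were reported under the wrong name. The snippet below uses a hypothetical stand-in for validator.check_integer (the real helper lives in MindSpore's validator module and takes a Rel enum, not a string), purely to illustrate the before/after behavior:

# Hypothetical stand-in for validator.check_integer(arg_name, arg_value, value, rel, prim_name).
def check_integer(arg_name, arg_value, value, rel, prim_name):
    # Type check first; bool is excluded because bool is a subclass of int.
    if not isinstance(arg_value, int) or isinstance(arg_value, bool):
        raise TypeError(f"For '{prim_name}', the '{arg_name}' must be int, but got {type(arg_value).__name__}.")
    ok = arg_value > value if rel == "GT" else arg_value >= value
    if not ok:
        raise ValueError(f"For '{prim_name}', the '{arg_name}' must be {rel} {value}, but got {arg_value}.")
    return arg_value

decay_steps, warmup_steps = 10, -3

# Old (buggy) line: re-validates decay_steps under the name 'warmup_steps'.
check_integer('warmup_steps', decay_steps, 0, "GT", 'Lamb')  # passes; the -3 goes unnoticed

# Fixed line: warmup_steps itself is checked, and GE (not GT) keeps 0 legal.
check_integer('warmup_steps', 0, 0, "GE", 'Lamb')  # passes: zero warmup is allowed
try:
    check_integer('warmup_steps', warmup_steps, 0, "GE", 'Lamb')
except ValueError as err:
    print(err)  # For 'Lamb', the 'warmup_steps' must be GE 0, but got -3.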
@@ -14,6 +14,7 @@
 # ============================================================================
 """ test lamb """
 import numpy as np
+import pytest

 import mindspore.nn as nn
 from mindspore import Tensor, Parameter
@@ -50,29 +51,27 @@ class NetWithoutWeight(nn.Cell):
         return x


-def test_lamb_1():
-    """ test_Lamb_1 """
+def test_lamb_compile():
+    """ test_Lamb_compile """
     inputs = Tensor(np.ones([1, 64]).astype(np.float32))
     label = Tensor(np.zeros([1, 10]).astype(np.float32))
     net = Net()
     net.set_train()
     loss = nn.SoftmaxCrossEntropyWithLogits()
-    optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=5)
+    optimizer = Lamb(net.trainable_params(), decay_steps=10)

     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, optimizer)
     _executor.compile(train_network, inputs, label)


-def test_lamb_2():
-    """ test_Lamb_2 """
-    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
-    label = Tensor(np.zeros([1, 10]).astype(np.float32))
+def test_lamb_error():
     net = Net()
-    net.set_train()
-    loss = nn.SoftmaxCrossEntropyWithLogits()
-    optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=0)
+    with pytest.raises(TypeError):
+        Lamb(net.get_parameters(), decay_steps=6, warmup_steps=5.0)

-    net_with_loss = WithLossCell(net, loss)
-    train_network = TrainOneStepCell(net_with_loss, optimizer)
-    _executor.compile(train_network, inputs, label)
+    with pytest.raises(TypeError):
+        Lamb(net.get_parameters(), decay_steps=1.0)
+
+    with pytest.raises(ValueError):
+        Lamb(net.get_parameters(), decay_steps=0)
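The new test_lamb_error relies on pytest.raises, which passes only when its body raises the named exception and fails the test otherwise. A short illustration of the same pattern, run against the stand-in checker sketched above (the test names here are illustrative, not part of the MindSpore suite):

import pytest

def test_warmup_steps_must_be_int():
    # Mirrors Lamb(..., warmup_steps=5.0): a float fails the type check.
    with pytest.raises(TypeError):
        check_integer('warmup_steps', 5.0, 0, "GE", 'Lamb')

def test_zero_warmup_is_legal_but_zero_decay_is_not():
    # GE lets warmup_steps=0 through, while decay_steps keeps the strict GT
    # check, which is why Lamb(..., decay_steps=0) raises in test_lamb_error.
    assert check_integer('warmup_steps', 0, 0, "GE", 'Lamb') == 0
    with pytest.raises(ValueError):
        check_integer('decay_steps', 0, 0, "GT", 'Lamb')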