fix_bug_in_check_lamb_warmup_step

This commit is contained in:
wangnan39@huawei.com 2020-05-28 10:19:13 +08:00
parent fb7e4eac76
commit 810ccf80d8
2 changed files with 13 additions and 15 deletions

View File

@ -111,7 +111,6 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name):
"""Check the type of inputs."""
_ = warmup_steps
validator.check_float_positive('start_learning_rate', start_learning_rate, prim_name)
validator.check_float_legal_value('start_learning_rate', start_learning_rate, prim_name)
validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name)
@ -119,7 +118,7 @@ def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
validator.check_float_positive('power', power, prim_name)
validator.check_float_legal_value('power', power, prim_name)
validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
validator.check_integer('warmup_steps', decay_steps, 0, Rel.GT, prim_name)
validator.check_integer('warmup_steps', warmup_steps, 0, Rel.GE, prim_name)
validator.check_value_type("beta1", beta1, [float], prim_name)
validator.check_value_type("beta2", beta2, [float], prim_name)
validator.check_value_type("eps", eps, [float], prim_name)

View File

@ -14,6 +14,7 @@
# ============================================================================
""" test lamb """
import numpy as np
import pytest
import mindspore.nn as nn
from mindspore import Tensor, Parameter
@ -50,29 +51,27 @@ class NetWithoutWeight(nn.Cell):
return x
def test_lamb_1():
""" test_Lamb_1 """
def test_lamb_compile():
""" test_Lamb_compile """
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
net = Net()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=5)
optimizer = Lamb(net.trainable_params(), decay_steps=10)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
def test_lamb_2():
""" test_Lamb_2 """
inputs = Tensor(np.ones([1, 64]).astype(np.float32))
label = Tensor(np.zeros([1, 10]).astype(np.float32))
def test_lamb_error():
net = Net()
net.set_train()
loss = nn.SoftmaxCrossEntropyWithLogits()
optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=0)
with pytest.raises(TypeError):
Lamb(net.get_parameters(), decay_steps=6, warmup_steps=5.0)
net_with_loss = WithLossCell(net, loss)
train_network = TrainOneStepCell(net_with_loss, optimizer)
_executor.compile(train_network, inputs, label)
with pytest.raises(TypeError):
Lamb(net.get_parameters(), decay_steps=1.0)
with pytest.raises(ValueError):
Lamb(net.get_parameters(), decay_steps=0)