From 810ccf80d82b2e70a6f059185944642fdb674ebc Mon Sep 17 00:00:00 2001
From: "wangnan39@huawei.com" <wangnan39@huawei.com>
Date: Thu, 28 May 2020 10:19:13 +0800
Subject: [PATCH] fix_bug_in_check_lamb_warmup_step

---
 mindspore/nn/optim/lamb.py            |  3 +--
 tests/ut/python/nn/optim/test_lamb.py | 25 ++++++++++++-------------
 2 files changed, 13 insertions(+), 15 deletions(-)

diff --git a/mindspore/nn/optim/lamb.py b/mindspore/nn/optim/lamb.py
index b4d478f52ab..59b66869480 100755
--- a/mindspore/nn/optim/lamb.py
+++ b/mindspore/nn/optim/lamb.py
@@ -111,7 +111,6 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
 def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
                        end_learning_rate, power, beta1, beta2, eps, weight_decay, prim_name):
     """Check the type of inputs."""
-    _ = warmup_steps
     validator.check_float_positive('start_learning_rate', start_learning_rate, prim_name)
     validator.check_float_legal_value('start_learning_rate', start_learning_rate, prim_name)
     validator.check_value_type("end_learning_rate", end_learning_rate, [float], prim_name)
@@ -119,7 +118,7 @@ def _check_param_value(decay_steps, warmup_steps, start_learning_rate,
     validator.check_float_positive('power', power, prim_name)
     validator.check_float_legal_value('power', power, prim_name)
     validator.check_integer('decay_steps', decay_steps, 0, Rel.GT, prim_name)
-    validator.check_integer('warmup_steps', decay_steps, 0, Rel.GT, prim_name)
+    validator.check_integer('warmup_steps', warmup_steps, 0, Rel.GE, prim_name)
     validator.check_value_type("beta1", beta1, [float], prim_name)
     validator.check_value_type("beta2", beta2, [float], prim_name)
     validator.check_value_type("eps", eps, [float], prim_name)
diff --git a/tests/ut/python/nn/optim/test_lamb.py b/tests/ut/python/nn/optim/test_lamb.py
index 2d18207e0ec..4d229f0837d 100644
--- a/tests/ut/python/nn/optim/test_lamb.py
+++ b/tests/ut/python/nn/optim/test_lamb.py
@@ -14,6 +14,7 @@
 # ============================================================================
 """ test lamb """
 import numpy as np
+import pytest
 
 import mindspore.nn as nn
 from mindspore import Tensor, Parameter
@@ -50,29 +51,27 @@ class NetWithoutWeight(nn.Cell):
         return x
 
 
-def test_lamb_1():
-    """ test_Lamb_1 """
+def test_lamb_compile():
+    """ test_Lamb_compile """
     inputs = Tensor(np.ones([1, 64]).astype(np.float32))
     label = Tensor(np.zeros([1, 10]).astype(np.float32))
     net = Net()
     net.set_train()
     loss = nn.SoftmaxCrossEntropyWithLogits()
-    optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=5)
+    optimizer = Lamb(net.trainable_params(), decay_steps=10)
 
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepCell(net_with_loss, optimizer)
     _executor.compile(train_network, inputs, label)
 
 
-def test_lamb_2():
-    """ test_Lamb_2 """
-    inputs = Tensor(np.ones([1, 64]).astype(np.float32))
-    label = Tensor(np.zeros([1, 10]).astype(np.float32))
+def test_lamb_error():
     net = Net()
-    net.set_train()
-    loss = nn.SoftmaxCrossEntropyWithLogits()
-    optimizer = Lamb(net.trainable_params(), decay_steps=10, warmup_steps=0)
+    with pytest.raises(TypeError):
+        Lamb(net.get_parameters(), decay_steps=6, warmup_steps=5.0)
 
-    net_with_loss = WithLossCell(net, loss)
-    train_network = TrainOneStepCell(net_with_loss, optimizer)
-    _executor.compile(train_network, inputs, label)
+    with pytest.raises(TypeError):
+        Lamb(net.get_parameters(), decay_steps=1.0)
+
+    with pytest.raises(ValueError):
+        Lamb(net.get_parameters(), decay_steps=0)
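
Background on the fix: before this patch, _check_param_value passed decay_steps to the
'warmup_steps' validator, so warmup_steps was never actually checked, and the Rel.GT
bound would have rejected the legal value 0 once the right argument was used. The patch
validates warmup_steps itself with Rel.GE. Below is a minimal sketch of the post-patch
behavior, not part of the patch: it assumes a MindSpore build that includes this change,
and TinyNet is a hypothetical stand-in for the UT's Net.

    import numpy as np
    import pytest

    import mindspore.nn as nn
    from mindspore import Tensor, Parameter
    from mindspore.nn.optim import Lamb

    class TinyNet(nn.Cell):
        """Minimal Cell with one trainable weight (illustrative only)."""
        def __init__(self):
            super(TinyNet, self).__init__()
            self.w = Parameter(Tensor(np.ones([10]).astype(np.float32)), name="w")

        def construct(self, x):
            return x * self.w

    net = TinyNet()

    # warmup_steps=0 now passes validation (Rel.GE, checked on warmup_steps itself).
    Lamb(net.trainable_params(), decay_steps=10, warmup_steps=0)

    # A float warmup_steps is still a TypeError from validator.check_integer,
    # matching the error paths covered by the reworked test_lamb_error UT.
    with pytest.raises(TypeError):
        Lamb(net.trainable_params(), decay_steps=6, warmup_steps=5.0)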