From 7d31deb6fa2b972b96e571f42c38479648525e8d Mon Sep 17 00:00:00 2001
From: zhousiyi
Date: Tue, 4 Aug 2020 02:08:23 +0000
Subject: [PATCH] remove loss_scale range check so that FP32Imm(inf) compares
 equal

---
 mindspore/core/abstract/abstract_value.cc          | 21 ++++++++++++++++++---
 .../test_optimizer_with_loss_scale.py              | 11 +++++++++--
 2 files changed, 27 insertions(+), 5 deletions(-)

diff --git a/mindspore/core/abstract/abstract_value.cc b/mindspore/core/abstract/abstract_value.cc
index fb16cf0161..502c0c9464 100644
--- a/mindspore/core/abstract/abstract_value.cc
+++ b/mindspore/core/abstract/abstract_value.cc
@@ -38,9 +38,24 @@ bool AbstractBase::operator==(const AbstractBase &other) const {
                       << this->ToString() << ", other: " << other.ToString();
   }
 
-  bool value_equal = *value_ == *other.value_;
-  bool type_equal = *type_ == *other.type_;
-  bool shape_equal = *shape_ == *other.shape_;
+  bool value_equal = false;
+  if (value_ == other.value_) {
+    value_equal = true;
+  } else if (*value_ == *other.value_) {
+    value_equal = true;
+  }
+  bool type_equal = false;
+  if (type_ == other.type_) {
+    type_equal = true;
+  } else if (*type_ == *other.type_) {
+    type_equal = true;
+  }
+  bool shape_equal = false;
+  if (shape_ == other.shape_) {
+    shape_equal = true;
+  } else if (*shape_ == *other.shape_) {
+    shape_equal = true;
+  }
   return value_equal && type_equal && shape_equal;
 }
 
diff --git a/tests/ut/python/optimizer/test_optimizer_with_loss_scale.py b/tests/ut/python/optimizer/test_optimizer_with_loss_scale.py
index 6f77c4a361..cccaf8e3b8 100644
--- a/tests/ut/python/optimizer/test_optimizer_with_loss_scale.py
+++ b/tests/ut/python/optimizer/test_optimizer_with_loss_scale.py
@@ -276,7 +276,7 @@ def test_compile_fp16_lr_overflow_dynamic_graph():
     print("the result is ", output)
 
 
-def test_adam_compile():
+def adam_compile(loss_scale=1.0):
     inputs = Tensor(np.ones([15, 1]).astype(np.float32))
     label = Tensor(np.zeros([15, 1]).astype(np.float32))
     scaling_sens = Tensor(np.full((1), 1.0), dtype=mstype.float32)
@@ -284,10 +284,17 @@ def test_adam_compile():
     net = Net(1, 1)
     loss = MSELoss()
     optimizer = Adam(net.trainable_params(), learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8,
                      use_locking=False,
-                     use_nesterov=False, weight_decay=0.0, loss_scale=1.0)
+                     use_nesterov=False, weight_decay=0.0, loss_scale=loss_scale)
     net_with_loss = WithLossCell(net, loss)
     train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer)
     train_network.set_train()
     output = train_network(inputs, label, scaling_sens)
     print("the result is ", output)
+
+def test_adam_compile():
+    adam_compile()
+
+def test_adam_loss_scale_compile():
+    """ test setting loss_scale to 1e-40 """
+    adam_compile(loss_scale=1e-40)
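
A note on the C++ change: the pointer-identity checks added to
AbstractBase::operator== matter because an FP32Imm holding inf can fail a pure
value comparison. The sketch below is a minimal standalone illustration, not
MindSpore code: the epsilon-style operator== is an assumption about why
FP32Imm(inf) fails to equal FP32Imm(inf) by value (inf - inf is NaN, and every
comparison against NaN is false), and Fp32Like is a hypothetical stand-in for
FP32Imm.

#include <cmath>
#include <iostream>
#include <memory>

// Hypothetical stand-in for FP32Imm with an epsilon-based equality test,
// one plausible reason inf fails to compare equal to itself by value.
struct Fp32Like {
  float v;
  bool operator==(const Fp32Like &other) const {
    return std::fabs(v - other.v) < 1e-6f;  // inf - inf is NaN -> false
  }
};

int main() {
  auto a = std::make_shared<Fp32Like>(Fp32Like{INFINITY});
  auto b = a;  // same underlying object

  std::cout << std::boolalpha;
  // The dereferenced value comparison alone fails for inf:
  std::cout << (*a == *b) << '\n';                // false
  // The patched operator== tries pointer identity first, as in
  // `if (value_ == other.value_)`, so identical objects compare equal:
  std::cout << ((a == b) || (*a == *b)) << '\n';  // true
}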
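
The new test_adam_loss_scale_compile case exercises the same path: in float32,
loss_scale=1e-40 is representable only as a subnormal, and its reciprocal
overflows to inf, which is presumably where the FP32Imm(inf) constants that
the operator== fix targets come from during graph compilation. A small
standalone sketch of that float32 arithmetic (plain C++, independent of
MindSpore):

#include <cfloat>
#include <cmath>
#include <iostream>

int main() {
  // 1e-40 sits below FLT_MIN (~1.18e-38) but above the smallest float32
  // subnormal (~1.4e-45), so it is nonzero yet subnormal.
  float scale = 1e-40f;
  std::cout << (scale > 0.0f) << '\n';     // 1: nonzero
  std::cout << (scale < FLT_MIN) << '\n';  // 1: subnormal

  // Its reciprocal, 1e40, exceeds FLT_MAX (~3.4e38) and overflows to inf,
  // the value whose comparison the C++ change above has to handle.
  float recip = 1.0f / scale;
  std::cout << std::isinf(recip) << '\n';  // 1: inf
}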