remove loss_scale range check to make FP32Imm(inf) comparison equal
This commit is contained in:
parent
7c03073143
commit
7d31deb6fa
|
@ -38,9 +38,24 @@ bool AbstractBase::operator==(const AbstractBase &other) const {
|
|||
<< this->ToString() << ", other: " << other.ToString();
|
||||
}
|
||||
|
||||
bool value_equal = *value_ == *other.value_;
|
||||
bool type_equal = *type_ == *other.type_;
|
||||
bool shape_equal = *shape_ == *other.shape_;
|
||||
bool value_equal = false;
|
||||
if (value_ == other.value_) {
|
||||
value_equal = true;
|
||||
} else if (*value_ == *other.value_) {
|
||||
value_equal = true;
|
||||
}
|
||||
bool type_equal = false;
|
||||
if (type_ == other.type_) {
|
||||
type_equal = true;
|
||||
} else if (*type_ == *other.type_) {
|
||||
type_equal = true;
|
||||
}
|
||||
bool shape_equal = false;
|
||||
if (shape_ == other.shape_) {
|
||||
shape_equal = true;
|
||||
} else if (*shape_ == *other.shape_) {
|
||||
shape_equal = true;
|
||||
}
|
||||
return value_equal && type_equal && shape_equal;
|
||||
}
|
||||
|
||||
|
|
|
@ -276,7 +276,7 @@ def test_compile_fp16_lr_overflow_dynamic_graph():
|
|||
print("the result is ", output)
|
||||
|
||||
|
||||
def test_adam_compile():
|
||||
def adam_compile(loss_scale=1.0):
|
||||
inputs = Tensor(np.ones([15, 1]).astype(np.float32))
|
||||
label = Tensor(np.zeros([15, 1]).astype(np.float32))
|
||||
scaling_sens = Tensor(np.full((1), 1.0), dtype=mstype.float32)
|
||||
|
@ -284,10 +284,17 @@ def test_adam_compile():
|
|||
|
||||
loss = MSELoss()
|
||||
optimizer = Adam(net.trainable_params(), learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False,
|
||||
use_nesterov=False, weight_decay=0.0, loss_scale=1.0)
|
||||
use_nesterov=False, weight_decay=0.0, loss_scale=loss_scale)
|
||||
|
||||
net_with_loss = WithLossCell(net, loss)
|
||||
train_network = TrainOneStepWithLossScaleCell(net_with_loss, optimizer)
|
||||
train_network.set_train()
|
||||
output = train_network(inputs, label, scaling_sens)
|
||||
print("the result is ", output)
|
||||
|
||||
def test_adam_compile():
|
||||
adam_compile()
|
||||
|
||||
def test_adam_loss_scale_compile():
|
||||
""" test setting loss_scale to 1e-40 """
|
||||
adam_compile(loss_scale=1e-40)
|
||||
|
|
Loading…
Reference in New Issue