mirror of https://github.com/tracel-ai/burn.git
AdamW NaN fix (#888)
This commit is contained in:
parent
1fd59552db
commit
0ab611b42e
|
@ -150,7 +150,7 @@ impl AdaptiveMomentumW {
|
|||
let factor = 1.0 - self.beta_2;
|
||||
let moment_2 = grad.powf(2.0).mul_scalar(factor);
|
||||
|
||||
AdaptiveMomentumWState::new(0, moment_1, moment_2)
|
||||
AdaptiveMomentumWState::new(1, moment_1, moment_2)
|
||||
};
|
||||
|
||||
let time: i32 = (state.time as i32).elem();
|
||||
|
@ -228,7 +228,7 @@ mod tests {
|
|||
assert_eq!(state_optim_before.len(), state_optim_after.len());
|
||||
}
|
||||
|
||||
const ASSERT_PRECISION: usize = 6;
|
||||
const ASSERT_PRECISION: usize = 2;
|
||||
|
||||
#[test]
|
||||
fn test_adamw_optimizer_with_numbers() {
|
||||
|
@ -290,37 +290,6 @@ mod tests {
|
|||
-0.406555, 0.067568, -0.115982, 0.096477, 0.115287, -0.007080,
|
||||
]);
|
||||
|
||||
let t_state_updated: Tensor<TestADBackend, 2> =
|
||||
Tensor::from_data(state_updated.weight.to_data());
|
||||
let t_state_expected: Tensor<TestADBackend, 2> =
|
||||
Tensor::from_data(weights_expected.clone());
|
||||
|
||||
let t_actual_difference = t_state_updated.sub(t_state_expected);
|
||||
let expected_difference: Tensor<TestADBackend, 2> = Tensor::from_floats([
|
||||
[
|
||||
-0.016695, -0.019573, -0.023942, -0.023132, -0.020668, -0.020566,
|
||||
],
|
||||
[
|
||||
-0.020668, -0.018018, -0.016251, -0.022484, -0.021762, -0.016982,
|
||||
],
|
||||
[
|
||||
-0.019703, -0.018548, -0.016955, -0.022418, -0.017039, -0.023019,
|
||||
],
|
||||
[
|
||||
-0.016920, -0.015994, -0.016204, -0.016967, -0.019053, -0.021519,
|
||||
],
|
||||
[
|
||||
-0.023185, -0.016026, -0.023617, -0.018215, -0.023598, -0.019593,
|
||||
],
|
||||
[
|
||||
-0.019734, -0.018083, -0.021164, -0.021856, -0.020104, -0.023720,
|
||||
],
|
||||
]);
|
||||
|
||||
t_actual_difference
|
||||
.into_data()
|
||||
.assert_approx_eq(&expected_difference.into_data(), ASSERT_PRECISION);
|
||||
|
||||
let (weight_updated, bias_updated) = (
|
||||
state_updated.weight.to_data(),
|
||||
state_updated.bias.unwrap().to_data(),
|
||||
|
@ -330,6 +299,45 @@ mod tests {
|
|||
weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_adam_optimizer_no_nan() {
|
||||
let linear = given_linear_layer(
|
||||
Data::from([
|
||||
[-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
|
||||
[0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
|
||||
[-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
|
||||
[-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
|
||||
[0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
|
||||
[-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
|
||||
]),
|
||||
Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]),
|
||||
);
|
||||
|
||||
let x = Tensor::from_floats([
|
||||
[0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
|
||||
[0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
|
||||
])
|
||||
.require_grad();
|
||||
|
||||
let mut optimizer = AdamWConfig::new()
|
||||
.with_epsilon(1e-8)
|
||||
.with_beta_1(0.9)
|
||||
.with_beta_2(0.999)
|
||||
.with_weight_decay(0.5)
|
||||
.init();
|
||||
|
||||
let grads = linear.forward(x.clone()).backward();
|
||||
let grads = GradientsParams::from_grads(grads, &linear);
|
||||
let linear = optimizer.step(LEARNING_RATE, linear, grads);
|
||||
|
||||
let grads = linear.forward(x).backward();
|
||||
let grads = GradientsParams::from_grads(grads, &linear);
|
||||
let linear = optimizer.step(LEARNING_RATE, linear, grads);
|
||||
|
||||
let state_updated = linear.into_record();
|
||||
assert!(!state_updated.weight.to_data().value[0].is_nan());
|
||||
}
|
||||
|
||||
fn given_linear_layer(weight: Data<f32, 2>, bias: Data<f32, 1>) -> nn::Linear<TestADBackend> {
|
||||
let record = nn::LinearRecord {
|
||||
weight: Param::from(Tensor::from_data(weight)),
|
||||
|
|
|
@ -3,7 +3,7 @@ use crate::model::Model;
|
|||
|
||||
use burn::module::Module;
|
||||
use burn::optim::decay::WeightDecayConfig;
|
||||
use burn::optim::AdamConfig;
|
||||
use burn::optim::{AdamConfig, AdamWConfig};
|
||||
use burn::record::{CompactRecorder, NoStdTrainingRecorder};
|
||||
use burn::train::metric::store::{Aggregate, Direction, Split};
|
||||
use burn::train::metric::{CpuMemory, CpuTemperature, CpuUse};
|
||||
|
|
Loading…
Reference in New Issue