mirror of https://github.com/tracel-ai/burn.git
Feat/RMSProp-optimizer (#607)
This commit is contained in:
parent
3264b1007c
commit
3e4adc4bc7
|
@ -176,8 +176,8 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
|
||||||
|
|
||||||
/// Map each tensor in the module with a [mapper](ModuleMapper).
|
/// Map each tensor in the module with a [mapper](ModuleMapper).
|
||||||
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self;
|
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self;
|
||||||
/// Load the module state from a record.
|
|
||||||
|
|
||||||
|
/// Load the module state from a record.
|
||||||
fn load_record(self, record: Self::Record) -> Self;
|
fn load_record(self, record: Self::Record) -> Self;
|
||||||
|
|
||||||
/// Convert the module into a record containing the state.
|
/// Convert the module into a record containing the state.
|
||||||
|
|
|
@ -16,7 +16,7 @@ pub struct WeightDecayConfig {
|
||||||
/// State of [WeightDecay](WeightDecay).
|
/// State of [WeightDecay](WeightDecay).
|
||||||
#[derive(Record, Clone, new)]
|
#[derive(Record, Clone, new)]
|
||||||
pub struct WeightDecayState<B: Backend, const D: usize> {
|
pub struct WeightDecayState<B: Backend, const D: usize> {
|
||||||
grad_last_step: Tensor<B, D>,
|
pub(crate) grad_last_step: Tensor<B, D>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Weight decay implementation that transforms gradients.
|
/// Weight decay implementation that transforms gradients.
|
||||||
|
@ -57,6 +57,15 @@ impl<B: Backend> WeightDecay<B> {
|
||||||
|
|
||||||
(grad, WeightDecayState::new(grad_last_step))
|
(grad, WeightDecayState::new(grad_last_step))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// temp fix for Transform.
|
||||||
|
pub fn transform_temp_fix<const D: usize>(
|
||||||
|
&self,
|
||||||
|
grad: Tensor<B, D>,
|
||||||
|
tensor: Tensor<B, D>,
|
||||||
|
) -> Tensor<B, D> {
|
||||||
|
tensor.mul_scalar(self.penalty).add(grad)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend, const D: usize> WeightDecayState<B, D> {
|
impl<B: Backend, const D: usize> WeightDecayState<B, D> {
|
||||||
|
|
|
@ -10,6 +10,7 @@ mod adamw;
|
||||||
mod base;
|
mod base;
|
||||||
mod grad_accum;
|
mod grad_accum;
|
||||||
mod grads;
|
mod grads;
|
||||||
|
mod rmsprop;
|
||||||
mod sgd;
|
mod sgd;
|
||||||
mod simple;
|
mod simple;
|
||||||
mod visitor;
|
mod visitor;
|
||||||
|
@ -20,5 +21,6 @@ pub use adamw::*;
|
||||||
pub use base::*;
|
pub use base::*;
|
||||||
pub use grad_accum::*;
|
pub use grad_accum::*;
|
||||||
pub use grads::*;
|
pub use grads::*;
|
||||||
|
pub use rmsprop::*;
|
||||||
pub use sgd::*;
|
pub use sgd::*;
|
||||||
pub use simple::*;
|
pub use simple::*;
|
||||||
|
|
|
@ -0,0 +1,528 @@
|
||||||
|
use crate::{
|
||||||
|
self as burn, grad_clipping::GradientClippingConfig, module::ADModule, record::Record,
|
||||||
|
LearningRate,
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
decay::{WeightDecay, WeightDecayConfig},
|
||||||
|
SimpleOptimizer,
|
||||||
|
};
|
||||||
|
use crate::config::Config;
|
||||||
|
use crate::optim::adaptor::OptimizerAdaptor;
|
||||||
|
use crate::tensor::{backend::ADBackend, Tensor};
|
||||||
|
use burn_tensor::backend::Backend;
|
||||||
|
|
||||||
|
/// Configuration to create the [RMSProp](RMSProp) optimizer.
|
||||||
|
#[derive(Config)]
|
||||||
|
pub struct RMSPropConfig {
|
||||||
|
/// Smoothing constant.
|
||||||
|
#[config(default = 0.99)]
|
||||||
|
alpha: f32,
|
||||||
|
/// momentum for RMSProp.
|
||||||
|
#[config(default = 0.9)]
|
||||||
|
momentum: f32,
|
||||||
|
/// A value required for numerical stability.
|
||||||
|
#[config(default = 1e-5)]
|
||||||
|
epsilon: f32,
|
||||||
|
/// if True, compute the centered RMSProp, the gradient is normalized by an estimation of its variance
|
||||||
|
#[config(default = false)]
|
||||||
|
centered: bool,
|
||||||
|
/// [Weight decay](WeightDecayConfig) config.
|
||||||
|
weight_decay: Option<WeightDecayConfig>,
|
||||||
|
/// [Gradient Clipping](GradientClippingConfig) config.
|
||||||
|
grad_clipping: Option<GradientClippingConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RMSPropConfig {
|
||||||
|
/// Initialize RMSProp optimizer.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// Returns an optimizer that can be used to optimize a module.
|
||||||
|
pub fn init<B: ADBackend, M: ADModule<B>>(
|
||||||
|
&self,
|
||||||
|
) -> OptimizerAdaptor<RMSProp<B::InnerBackend>, M, B> {
|
||||||
|
let weight_decay = self.weight_decay.as_ref().map(WeightDecay::new);
|
||||||
|
|
||||||
|
let mut optim = OptimizerAdaptor::from(RMSProp {
|
||||||
|
alpha: self.alpha,
|
||||||
|
centered: self.centered,
|
||||||
|
weight_decay,
|
||||||
|
momentum: RMSPropMomentum {
|
||||||
|
momentum: self.momentum,
|
||||||
|
epsilon: self.epsilon,
|
||||||
|
},
|
||||||
|
});
|
||||||
|
|
||||||
|
if let Some(config) = &self.grad_clipping {
|
||||||
|
optim = optim.with_grad_clipping(config.init());
|
||||||
|
}
|
||||||
|
|
||||||
|
optim
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Optimizer that implements stochastic gradient descent with momentum.
|
||||||
|
/// The optimizer can be configured with [RMSPropConfig](RMSPropConfig).
|
||||||
|
pub struct RMSProp<B: Backend> {
|
||||||
|
alpha: f32,
|
||||||
|
// epsilon: f32,
|
||||||
|
centered: bool,
|
||||||
|
// momentum: Option<Momentum<B>>,
|
||||||
|
momentum: RMSPropMomentum,
|
||||||
|
weight_decay: Option<WeightDecay<B>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> SimpleOptimizer<B> for RMSProp<B> {
|
||||||
|
type State<const D: usize> = RMSPropState<B, D>;
|
||||||
|
|
||||||
|
fn step<const D: usize>(
|
||||||
|
&self,
|
||||||
|
lr: LearningRate,
|
||||||
|
tensor: Tensor<B, D>,
|
||||||
|
mut grad: Tensor<B, D>,
|
||||||
|
state: Option<Self::State<D>>,
|
||||||
|
) -> (Tensor<B, D>, Option<Self::State<D>>) {
|
||||||
|
// fetch state for params
|
||||||
|
let mut state_square_avg = None;
|
||||||
|
let mut state_centered = None;
|
||||||
|
let mut state_momentum = None;
|
||||||
|
if let Some(state) = state {
|
||||||
|
state_square_avg = Some(state.square_avg);
|
||||||
|
state_centered = Some(state.centered);
|
||||||
|
state_momentum = state.momentum;
|
||||||
|
}
|
||||||
|
|
||||||
|
// weight_decay transform
|
||||||
|
if let Some(weight_decay) = &self.weight_decay {
|
||||||
|
grad = weight_decay.transform_temp_fix(grad, tensor.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
// square_avg transform
|
||||||
|
let (grad, state_square_avg) =
|
||||||
|
SquareAvgState::transform(self.alpha, grad, state_square_avg);
|
||||||
|
|
||||||
|
// centered transform
|
||||||
|
let (grad, state_square_avg, state_centered) = CenteredState::transform(
|
||||||
|
self.alpha,
|
||||||
|
self.centered,
|
||||||
|
grad,
|
||||||
|
state_square_avg,
|
||||||
|
state_centered,
|
||||||
|
);
|
||||||
|
|
||||||
|
// momentum transform
|
||||||
|
let (grad, state_centered, state_momentum) =
|
||||||
|
self.momentum
|
||||||
|
.transform(grad, state_centered, state_momentum);
|
||||||
|
|
||||||
|
// transition state
|
||||||
|
let state = RMSPropState::new(state_square_avg, state_centered, state_momentum);
|
||||||
|
|
||||||
|
// tensor param transform
|
||||||
|
let delta = grad.mul_scalar(lr);
|
||||||
|
(tensor - delta, Some(state))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn to_device<const D: usize>(
|
||||||
|
mut state: Self::State<D>,
|
||||||
|
device: &<B as Backend>::Device,
|
||||||
|
) -> Self::State<D> {
|
||||||
|
state.square_avg = state.square_avg.to_device(device);
|
||||||
|
state.centered = state.centered.to_device(device);
|
||||||
|
state.momentum = state.momentum.map(|momentum| momentum.to_device(device));
|
||||||
|
state
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// State of [RMSProp](RMSProp), passed into and returned from each optimizer step.
|
||||||
|
#[derive(Record, Clone, new)]
|
||||||
|
pub struct RMSPropState<B: Backend, const D: usize> {
|
||||||
|
/// Running average of the squared gradients.
square_avg: SquareAvgState<B, D>,
|
||||||
|
/// Centering state: optional running gradient average plus the `avg` term used as denominator.
centered: CenteredState<B, D>,
|
||||||
|
/// Momentum buffer; `None` when the momentum transform produced no state.
momentum: Option<RMSPropMomentumState<B, D>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// [SquareAvgState](SquareAvgState) is to store and pass optimizer step params.
|
||||||
|
#[derive(Record, Clone, new)]
|
||||||
|
pub struct SquareAvgState<B: Backend, const D: usize> {
|
||||||
|
square_avg: Tensor<B, D>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend, const D: usize> SquareAvgState<B, D> {
|
||||||
|
/// transform [SquareAvgState] to the next step
|
||||||
|
fn transform(alpha: f32, grad: Tensor<B, D>, state: Option<Self>) -> (Tensor<B, D>, Self) {
|
||||||
|
match state {
|
||||||
|
Some(state) => {
|
||||||
|
let square_avg = state
|
||||||
|
.square_avg
|
||||||
|
.clone()
|
||||||
|
.mul_scalar(alpha)
|
||||||
|
.add(grad.clone().powf(2.).mul_scalar(1. - alpha));
|
||||||
|
(grad, Self { square_avg })
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
let square_avg = grad.clone().powf(2.).mul_scalar(1. - alpha);
|
||||||
|
(grad, Self { square_avg })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Moves the state to a device.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `device` - Device to move the state to.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// * `self` - Moved state.
|
||||||
|
pub fn to_device(mut self, device: &B::Device) -> Self {
|
||||||
|
self.square_avg = self.square_avg.to_device(device);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// [CenteredState](CenteredState) is to store and pass optimizer step params.
|
||||||
|
#[derive(Record, Clone, new)]
|
||||||
|
pub struct CenteredState<B: Backend, const D: usize> {
|
||||||
|
grad_avg: Option<Tensor<B, D>>,
|
||||||
|
avg: Tensor<B, D>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend, const D: usize> CenteredState<B, D> {
|
||||||
|
/// transform [CenteredState] to the next step
|
||||||
|
fn transform(
|
||||||
|
alpha: f32,
|
||||||
|
centered: bool,
|
||||||
|
grad: Tensor<B, D>,
|
||||||
|
square_avg_state: SquareAvgState<B, D>,
|
||||||
|
centered_state: Option<Self>,
|
||||||
|
) -> (Tensor<B, D>, SquareAvgState<B, D>, Self) {
|
||||||
|
if centered {
|
||||||
|
let grad_avg_constant = grad.clone().mul_scalar(1. - alpha);
|
||||||
|
let grad_avg = match centered_state {
|
||||||
|
Some(state) => state
|
||||||
|
.grad_avg
|
||||||
|
.map_or(grad_avg_constant.clone(), move |grad_avg| {
|
||||||
|
grad_avg.clone().mul_scalar(alpha).add(grad_avg_constant)
|
||||||
|
}),
|
||||||
|
_ => grad_avg_constant,
|
||||||
|
};
|
||||||
|
let avg = square_avg_state
|
||||||
|
.square_avg
|
||||||
|
.clone()
|
||||||
|
.sub(grad_avg.clone().powf(2.));
|
||||||
|
|
||||||
|
(
|
||||||
|
grad,
|
||||||
|
square_avg_state,
|
||||||
|
Self {
|
||||||
|
grad_avg: Some(grad_avg),
|
||||||
|
avg,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
(
|
||||||
|
grad,
|
||||||
|
square_avg_state.clone(),
|
||||||
|
Self {
|
||||||
|
grad_avg: None,
|
||||||
|
avg: square_avg_state.square_avg,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Moves the state to a device.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `device` - Device to move the state to.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// * `self` - Moved state.
|
||||||
|
pub fn to_device(mut self, device: &B::Device) -> Self {
|
||||||
|
self.grad_avg = self.grad_avg.map(|grad_avg| grad_avg.to_device(device));
|
||||||
|
self.avg = self.avg.to_device(device);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// [RMSPropMomentum](RMSPropMomentum) is to store config status for optimizer.
|
||||||
|
/// (, which is stored in [optimizer](RMSProp) itself and not passed in during `step()` calculation)
|
||||||
|
pub struct RMSPropMomentum {
|
||||||
|
momentum: f32,
|
||||||
|
epsilon: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RMSPropMomentum {
|
||||||
|
/// transform [grad](Tensor) and [RMSPropMomentumState] to the next step
|
||||||
|
fn transform<B: Backend, const D: usize>(
|
||||||
|
&self,
|
||||||
|
grad: Tensor<B, D>,
|
||||||
|
centered_state: CenteredState<B, D>,
|
||||||
|
momentum_state: Option<RMSPropMomentumState<B, D>>,
|
||||||
|
) -> (
|
||||||
|
Tensor<B, D>,
|
||||||
|
CenteredState<B, D>,
|
||||||
|
Option<RMSPropMomentumState<B, D>>,
|
||||||
|
) {
|
||||||
|
let grad = grad
|
||||||
|
.clone()
|
||||||
|
.div(centered_state.avg.clone().sqrt().add_scalar(self.epsilon));
|
||||||
|
|
||||||
|
if self.momentum > 0. {
|
||||||
|
let buf = match momentum_state {
|
||||||
|
Some(state) => state
|
||||||
|
.buf
|
||||||
|
.clone()
|
||||||
|
.mul_scalar(self.momentum)
|
||||||
|
.add(grad.clone()),
|
||||||
|
_ => grad.clone(),
|
||||||
|
};
|
||||||
|
(
|
||||||
|
buf.clone(),
|
||||||
|
centered_state,
|
||||||
|
Some(RMSPropMomentumState { buf }),
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
(grad.clone(), centered_state, None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// [RMSPropMomentumState](RMSPropMomentumState) is to store and pass optimizer step params.
|
||||||
|
#[derive(Record, Clone, new)]
|
||||||
|
pub struct RMSPropMomentumState<B: Backend, const D: usize> {
|
||||||
|
buf: Tensor<B, D>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend, const D: usize> RMSPropMomentumState<B, D> {
|
||||||
|
/// Moves the state to a device.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `device` - Device to move the state to.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// * `self` - Moved state.
|
||||||
|
pub fn to_device(mut self, device: &B::Device) -> Self {
|
||||||
|
self.buf = self.buf.to_device(device);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use burn_tensor::Shape;
|
||||||
|
|
||||||
|
use super::*;
|
||||||
|
use crate::module::{Module, Param};
|
||||||
|
use crate::optim::{GradientsParams, Optimizer};
|
||||||
|
use crate::record::{BinFileRecorder, FullPrecisionSettings, Recorder};
|
||||||
|
use crate::tensor::{Data, Distribution, Tensor};
|
||||||
|
use crate::{nn, TestADBackend, TestBackend};
|
||||||
|
use tempfile::TempDir;
|
||||||
|
|
||||||
|
const LEARNING_RATE: LearningRate = 0.01;
|
||||||
|
const ASSERT_PRECISION: usize = 6;
|
||||||
|
|
||||||
|
/// Runs one optimizer step, records the optimizer state to a temporary file, then checks
/// that a fresh optimizer loaded with a copy of the record holds the same number of entries.
#[test]
fn test_rmsprop_optimizer_save_load_state() {
    let layer = nn::LinearConfig::new(6, 6).init();
    let input = Tensor::<TestADBackend, 2>::random([2, 6], Distribution::Default);
    let mut optimizer = create_rmsprop();

    let raw_grads = layer.forward(input).backward();
    let grads = GradientsParams::from_grads(raw_grads, &layer);
    let _layer = optimizer.step(LEARNING_RATE, layer, grads);

    let temp_dir = TempDir::new().unwrap();
    let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
    recorder
        .record(optimizer.to_record(), temp_dir.path().join("test_optim"))
        .unwrap();

    let state_optim_before = optimizer.to_record();
    let state_optim_before_copy = optimizer.to_record();
    let reloaded = create_rmsprop().load_record(state_optim_before_copy);
    let state_optim_after = reloaded.to_record();

    assert_eq!(state_optim_before.len(), state_optim_after.len());
}
|
||||||
|
|
||||||
|
/// Two optimizer steps on a layer with uniform weights; the simple starting values make
/// differences easy to spot when debugging.
#[test]
fn test_rmsprop_optimizer_with_numbers_basic() {
    let linear = given_linear_layer(Data::from([[1.; 6]; 6]), Data::from([0.5; 6]));
    let x_1 = Tensor::from_floats([
        [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
        [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
    ])
    .require_grad();
    let x_2 = Tensor::from_floats([
        [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
        [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
    ])
    .require_grad();

    let mut optimizer = RMSPropConfig::new()
        .with_alpha(0.99)
        .with_epsilon(1e-8)
        .with_weight_decay(WeightDecayConfig::new(0.05).into())
        .with_momentum(0.9)
        .with_centered(false)
        .init();

    // First step.
    let grads = linear.forward(x_1).backward();
    let grads = GradientsParams::from_grads(grads, &linear);
    let linear = optimizer.step(LEARNING_RATE, linear, grads);

    // Second step.
    let grads = linear.forward(x_2).backward();
    let grads = GradientsParams::from_grads(grads, &linear);
    let linear = optimizer.step(LEARNING_RATE, linear, grads);

    let state_updated = linear.into_record();
    let weight_updated = state_updated.weight.to_data();
    let bias_updated = state_updated.bias.unwrap().to_data();

    // Every row of the expected weights holds a single repeated value.
    let weights_expected = Data::from([
        [0.743937; 6],
        [0.783809; 6],
        [0.742881; 6],
        [0.740366; 6],
        [0.748005; 6],
        [0.743710; 6],
    ]);
    let bias_expected = Data::from([0.239199; 6]);

    bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION);
    weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION);
}
|
||||||
|
|
||||||
|
/// Two optimizer steps on a layer with fixed, non-uniform weights, compared against
/// precomputed expected values.
#[test]
fn test_rmsprop_optimizer_with_numbers() {
    let initial_weight = Data::from([
        [-0.3206, 0.1374, 0.4043, 0.3200, 0.0859, 0.0671],
        [0.0777, -0.0185, -0.3667, 0.2550, 0.1955, -0.2922],
        [-0.0190, 0.0346, -0.2962, 0.2484, -0.2780, 0.3130],
        [-0.2980, -0.2214, -0.3715, -0.2981, -0.0761, 0.1626],
        [0.3300, -0.2182, 0.3717, -0.1729, 0.3796, -0.0304],
        [-0.0159, -0.0120, 0.1258, 0.1921, 0.0293, 0.3833],
    ]);
    let initial_bias = Data::from([-0.3905, 0.0884, -0.0970, 0.1176, 0.1366, 0.0130]);
    let linear = given_linear_layer(initial_weight, initial_bias);

    let x_1 = Tensor::from_floats([
        [0.6294, 0.0940, 0.8176, 0.8824, 0.5228, 0.4310],
        [0.7152, 0.9559, 0.7893, 0.5684, 0.5939, 0.8883],
    ])
    .require_grad();
    let x_2 = Tensor::from_floats([
        [0.8491, 0.2108, 0.8939, 0.4433, 0.5527, 0.2528],
        [0.3270, 0.0412, 0.5538, 0.9605, 0.3195, 0.9085],
    ])
    .require_grad();

    let mut optimizer = RMSPropConfig::new()
        .with_alpha(0.99)
        .with_epsilon(1e-8)
        .with_weight_decay(WeightDecayConfig::new(0.05).into())
        .with_momentum(0.9)
        .with_centered(false)
        .init();

    // First step.
    let grads = linear.forward(x_1).backward();
    let grads = GradientsParams::from_grads(grads, &linear);
    let linear = optimizer.step(LEARNING_RATE, linear, grads);

    // Second step.
    let grads = linear.forward(x_2).backward();
    let grads = GradientsParams::from_grads(grads, &linear);
    let linear = optimizer.step(LEARNING_RATE, linear, grads);

    let state_updated = linear.into_record();
    let weights_expected = Data::from([
        [
            -0.576399, -0.118494, 0.148353, 0.064070, -0.169983, -0.188779,
        ],
        [
            -0.135571, -0.231448, -0.578445, 0.041143, -0.018162, -0.504207,
        ],
        [
            -0.275990, -0.222397, -0.553153, -0.008625, -0.534956, 0.055967,
        ],
        [
            -0.557575, -0.480979, -0.631072, -0.557675, -0.335686, -0.096997,
        ],
        [
            0.078313, -0.469618, 0.119993, -0.424341, 0.127890, -0.281912,
        ],
        [
            -0.271996, -0.268097, -0.130324, -0.064037, -0.226805, 0.127126,
        ],
    ]);
    let bias_expected = Data::from([
        -0.651299, -0.172400, -0.357800, -0.143200, -0.124200, -0.247800,
    ]);

    let weight_updated = state_updated.weight.to_data();
    let bias_updated = state_updated.bias.unwrap().to_data();

    bias_updated.assert_approx_eq(&bias_expected, ASSERT_PRECISION);
    weight_updated.assert_approx_eq(&weights_expected, ASSERT_PRECISION);
}
|
||||||
|
|
||||||
|
fn given_linear_layer(weight: Data<f32, 2>, bias: Data<f32, 1>) -> nn::Linear<TestADBackend> {
|
||||||
|
let record = nn::LinearRecord {
|
||||||
|
weight: Param::from(Tensor::from_data(weight)),
|
||||||
|
bias: Some(Param::from(Tensor::from_data(bias))),
|
||||||
|
};
|
||||||
|
|
||||||
|
nn::LinearConfig::new(6, 6).init_with(record)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
fn create_random_tensor() -> Tensor<TestADBackend, 2> {
|
||||||
|
Tensor::<TestADBackend, 2>::random(Shape::new([2, 20]), Distribution::Default)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn create_rmsprop(
|
||||||
|
) -> OptimizerAdaptor<RMSProp<TestBackend>, nn::Linear<TestADBackend>, TestADBackend> {
|
||||||
|
RMSPropConfig {
|
||||||
|
alpha: 0.99,
|
||||||
|
epsilon: 1e-9,
|
||||||
|
centered: false,
|
||||||
|
weight_decay: Some(WeightDecayConfig { penalty: 0.05 }),
|
||||||
|
momentum: 0.9,
|
||||||
|
grad_clipping: None,
|
||||||
|
..RMSPropConfig::new()
|
||||||
|
}
|
||||||
|
.init()
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue