Add Clone trait to the `OptimizerAdaptor` and Clone implementations to the optimizers (#1770)

getumen 2024-05-15 22:18:09 +09:00 committed by GitHub
parent f8a1356075
commit e823338750
GPG Key ID: B5690EEEBB952194
10 changed files with 14 additions and 1 deletion
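
The mechanics behind the change: `#[derive(Clone)]` on a generic struct only yields a usable `clone()` when every type parameter is `Clone`, so deriving `Clone` on `OptimizerAdaptor` requires the wrapped simple optimizer to be `Clone` as well, which is what the new supertrait on `SimpleOptimizer` guarantees. A minimal standalone sketch of that shape, with illustrative names rather than burn's real API:

```rust
// A toy stand-in for the SimpleOptimizer trait: Clone is now a supertrait,
// so every implementor must be cloneable.
trait SimpleOptim: Send + Sync + Clone {
    fn step(&self, param: f32, grad: f32) -> f32;
}

// A toy stand-in for OptimizerAdaptor: deriving Clone on a generic wrapper
// needs `O: Clone`, which the supertrait above provides for every optimizer.
#[derive(Clone)]
struct Adaptor<O: SimpleOptim> {
    optim: O,
}

// A toy optimizer; the real optimizers in this commit each gain the same derive.
#[derive(Clone)]
struct ToySgd {
    lr: f32,
}

impl SimpleOptim for ToySgd {
    fn step(&self, param: f32, grad: f32) -> f32 {
        param - self.lr * grad
    }
}

fn main() {
    let adaptor = Adaptor { optim: ToySgd { lr: 0.01 } };
    // What the commit enables: a configured optimizer stack can be duplicated,
    // e.g. to drive two independent but identically configured training loops.
    let copy = adaptor.clone();
    println!("{} {}", adaptor.optim.step(1.0, 0.5), copy.optim.step(1.0, 0.5));
}
```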

View File

@@ -30,6 +30,7 @@ impl GradientClippingConfig {
/// Gradient Clipping provides a way to mitigate exploding gradients
/// by clipping every component of the gradient by value or by norm during
/// backpropagation.
#[derive(Clone)]
pub enum GradientClipping {
/// Clip the gradient by value.
Value(f32),
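
As the doc comment above describes, clipping can be applied per component (by value) or to the whole gradient (by norm). A minimal scalar sketch of the two strategies on a plain `Vec<f32>`, not burn's tensor implementation:

```rust
/// Clip each gradient component into the range [-threshold, threshold].
fn clip_by_value(grad: &mut [f32], threshold: f32) {
    for g in grad.iter_mut() {
        *g = g.clamp(-threshold, threshold);
    }
}

/// Rescale the whole gradient so that its L2 norm does not exceed `max_norm`.
fn clip_by_norm(grad: &mut [f32], max_norm: f32) {
    let norm = grad.iter().map(|g| g * g).sum::<f32>().sqrt();
    if norm > max_norm {
        let scale = max_norm / norm;
        for g in grad.iter_mut() {
            *g *= scale;
        }
    }
}

fn main() {
    let mut by_value = vec![3.0_f32, -4.0, 0.5];
    clip_by_value(&mut by_value, 1.0);

    let mut by_norm = vec![3.0_f32, -4.0, 0.5];
    clip_by_norm(&mut by_norm, 1.0);

    println!("{by_value:?} {by_norm:?}"); // [1.0, -1.0, 0.5] and a vector rescaled to norm 1
}
```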

View File

@@ -26,6 +26,7 @@ pub struct AdaGradConfig {
}
/// AdaGrad optimizer
#[derive(Clone)]
pub struct AdaGrad<B: Backend> {
lr_decay: LrDecay,
weight_decay: Option<WeightDecay<B>>,
@@ -105,6 +106,7 @@ pub struct LrDecayState<B: Backend, const D: usize> {
sum: Tensor<B, D>,
}
#[derive(Clone)]
struct LrDecay {
lr_decay: f64,
epsilon: f32,
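
A scalar sketch of the textbook AdaGrad update that the `LrDecay` transform and the `sum` field of `LrDecayState` correspond to; burn's tensor implementation may differ in detail:

```rust
/// One AdaGrad step for a single parameter, written with plain scalars.
/// `sum` is the running sum of squared gradients; `time` counts steps taken.
fn adagrad_step(
    param: &mut f32,
    grad: f32,
    sum: &mut f32,
    time: &mut u64,
    lr: f64,
    lr_decay: f64,
    epsilon: f32,
) {
    *sum += grad * grad;
    *time += 1;
    // Decay the base learning rate as training progresses.
    let decayed_lr = (lr / (1.0 + *time as f64 * lr_decay)) as f32;
    // Scale the gradient by the accumulated squared-gradient history.
    let update = decayed_lr * grad / (sum.sqrt() + epsilon);
    *param -= update;
}

fn main() {
    let (mut param, mut sum, mut time) = (1.0_f32, 0.0_f32, 0_u64);
    adagrad_step(&mut param, 0.5, &mut sum, &mut time, 1e-2, 1e-3, 1e-8);
    println!("{param} {sum} {time}");
}
```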

View File

@@ -31,6 +31,7 @@ pub struct AdamConfig {
}
/// Adam optimizer as described in the paper [Adam: A Method for Stochastic Optimization](https://arxiv.org/pdf/1412.6980.pdf).
#[derive(Clone)]
pub struct Adam<B: Backend> {
momentum: AdaptiveMomentum,
weight_decay: Option<WeightDecay<B>>,
@@ -113,6 +114,7 @@ pub struct AdaptiveMomentumState<B: Backend, const D: usize> {
moment_2: Tensor<B, D>,
}
#[derive(Clone)]
struct AdaptiveMomentum {
beta_1: f32,
beta_2: f32,
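
A scalar sketch of the adaptive-momentum update from the cited paper, matching the moment state of `AdaptiveMomentumState` and the `beta_1`/`beta_2` fields above (illustrative, not burn's tensor code):

```rust
/// One Adam step for a single parameter, following the update rule in the paper.
fn adam_step(
    param: &mut f32,
    grad: f32,
    moment_1: &mut f32, // first-moment (mean) estimate
    moment_2: &mut f32, // second-moment (uncentered variance) estimate
    time: i32,          // 1-based step counter
    lr: f32,
    beta_1: f32,
    beta_2: f32,
    epsilon: f32,
) {
    *moment_1 = beta_1 * *moment_1 + (1.0 - beta_1) * grad;
    *moment_2 = beta_2 * *moment_2 + (1.0 - beta_2) * grad * grad;
    // Correct the bias introduced by zero-initializing the moments.
    let m_hat = *moment_1 / (1.0 - beta_1.powi(time));
    let v_hat = *moment_2 / (1.0 - beta_2.powi(time));
    *param -= lr * m_hat / (v_hat.sqrt() + epsilon);
}

fn main() {
    let (mut p, mut m1, mut m2) = (1.0_f32, 0.0_f32, 0.0_f32);
    adam_step(&mut p, 0.5, &mut m1, &mut m2, 1, 1e-3, 0.9, 0.999, 1e-8);
    println!("{p}");
}
```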

View File

@@ -30,6 +30,7 @@ pub struct AdamWConfig {
}
/// AdamW optimizer as described in the paper [Decoupled Weight Decay Regularization, Loshchilov and Hutter, 2019](https://arxiv.org/abs/1711.05101).
#[derive(Clone)]
pub struct AdamW<B: Backend> {
momentum: AdaptiveMomentumW,
weight_decay: f32,
@@ -112,6 +113,7 @@ pub struct AdaptiveMomentumWState<B: Backend, const D: usize> {
moment_2: Tensor<B, D>,
}
#[derive(Clone)]
struct AdaptiveMomentumW {
beta_1: f32,
beta_2: f32,
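
AdamW keeps Adam's moment estimates but applies the weight decay directly to the parameter instead of folding it into the gradient, which is the "decoupled" part of the cited paper. A scalar sketch:

```rust
/// One AdamW step: the same moment estimates as Adam, but the weight decay
/// acts on the parameter itself ("decoupled") rather than on the gradient.
fn adamw_step(
    param: &mut f32,
    grad: f32,
    moment_1: &mut f32,
    moment_2: &mut f32,
    time: i32,
    lr: f32,
    beta_1: f32,
    beta_2: f32,
    epsilon: f32,
    weight_decay: f32,
) {
    *moment_1 = beta_1 * *moment_1 + (1.0 - beta_1) * grad;
    *moment_2 = beta_2 * *moment_2 + (1.0 - beta_2) * grad * grad;
    let m_hat = *moment_1 / (1.0 - beta_1.powi(time));
    let v_hat = *moment_2 / (1.0 - beta_2.powi(time));
    // Decoupled decay: shrink the parameter alongside the Adam update.
    let update = lr * (m_hat / (v_hat.sqrt() + epsilon) + weight_decay * *param);
    *param -= update;
}

fn main() {
    let (mut p, mut m1, mut m2) = (1.0_f32, 0.0_f32, 0.0_f32);
    adamw_step(&mut p, 0.5, &mut m1, &mut m2, 1, 1e-3, 0.9, 0.999, 1e-8, 1e-2);
    println!("{p}");
}
```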

View File

@@ -20,6 +20,7 @@ pub struct WeightDecayState<B: Backend, const D: usize> {
}
/// Weight decay implementation that transforms gradients.
#[derive(Clone)]
pub struct WeightDecay<B: Backend> {
penalty: B::FloatElem,
}
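
The `WeightDecay` transform is the coupled (L2) form of regularization: the penalty term is added to the gradient before the optimizer consumes it, in contrast to AdamW's decoupled variant above. A one-line scalar sketch:

```rust
/// Coupled weight decay as a gradient transform: the L2 penalty term is
/// added to the raw gradient before the optimizer update uses it.
fn apply_weight_decay(grad: f32, param: f32, penalty: f32) -> f32 {
    grad + penalty * param
}

fn main() {
    println!("{}", apply_weight_decay(0.5, 1.0, 1e-2)); // 0.51
}
```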

View File

@@ -27,6 +27,7 @@ pub struct MomentumState<B: Backend, const D: usize> {
}
/// Momentum implementation that transforms gradients.
#[derive(Clone)]
pub struct Momentum<B: Backend> {
momentum: B::FloatElem,
dampening: f64,
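
A scalar sketch of the classic momentum transform behind the `momentum` and `dampening` fields: a velocity accumulates an exponentially weighted sum of past gradients and replaces the raw gradient (variants such as Nesterov momentum are omitted here):

```rust
/// Classic momentum as a gradient transform.
fn momentum_transform(grad: f32, velocity: &mut f32, momentum: f32, dampening: f32) -> f32 {
    *velocity = momentum * *velocity + (1.0 - dampening) * grad;
    *velocity
}

fn main() {
    let mut velocity = 0.0_f32;
    let g1 = momentum_transform(0.5, &mut velocity, 0.9, 0.0);
    let g2 = momentum_transform(0.5, &mut velocity, 0.9, 0.0);
    println!("{g1} {g2}"); // 0.5, then 0.95
}
```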

View File

@@ -64,6 +64,7 @@ impl RmsPropConfig {
/// Optimizer that implements stochastic gradient descent with momentum.
/// The optimizer can be configured with [RmsPropConfig](RmsPropConfig).
#[derive(Clone)]
pub struct RmsProp<B: Backend> {
alpha: f32,
// epsilon: f32,
@@ -251,6 +252,7 @@ impl<B: Backend, const D: usize> CenteredState<B, D> {
/// [RmsPropMomentum](RmsPropMomentum) stores the momentum configuration for the optimizer,
/// which is held in the [optimizer](RmsProp) itself and not passed in during the `step()` calculation.
#[derive(Clone)]
pub struct RmsPropMomentum {
momentum: f32,
epsilon: f32,
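
A scalar sketch of a plain (non-centered) RMSProp step with momentum applied to the scaled update, corresponding to the `alpha`, `momentum`, and `epsilon` fields above; the centered variant tracked by `CenteredState` additionally subtracts the square of a running gradient mean and is omitted:

```rust
/// One non-centered RMSProp step with momentum on the scaled update.
fn rmsprop_step(
    param: &mut f32,
    grad: f32,
    square_avg: &mut f32,   // running average of squared gradients
    momentum_buf: &mut f32, // momentum buffer over the scaled updates
    lr: f32,
    alpha: f32, // smoothing constant for the squared-gradient average
    momentum: f32,
    epsilon: f32,
) {
    *square_avg = alpha * *square_avg + (1.0 - alpha) * grad * grad;
    let scaled = grad / (square_avg.sqrt() + epsilon);
    *momentum_buf = momentum * *momentum_buf + scaled;
    *param -= lr * *momentum_buf;
}

fn main() {
    let (mut p, mut avg, mut buf) = (1.0_f32, 0.0_f32, 0.0_f32);
    rmsprop_step(&mut p, 0.5, &mut avg, &mut buf, 1e-2, 0.99, 0.9, 1e-8);
    println!("{p}");
}
```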

View File

@@ -25,6 +25,7 @@ pub struct SgdConfig {
/// Optimizer that implements stochastic gradient descent with momentum.
///
/// The optimizer can be configured with [SgdConfig](SgdConfig).
#[derive(Clone)]
pub struct Sgd<B: Backend> {
momentum: Option<Momentum<B>>,
weight_decay: Option<WeightDecay<B>>,
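
The `Sgd` struct composes the optional `WeightDecay` and `Momentum` transforms sketched above around a plain learning-rate step; a scalar sketch of that composition:

```rust
/// One SGD step composing optional weight decay and momentum, on scalars.
fn sgd_step(
    param: &mut f32,
    mut grad: f32,
    velocity: Option<&mut f32>,
    lr: f32,
    weight_decay: Option<f32>,
    momentum: f32,
    dampening: f32,
) {
    if let Some(penalty) = weight_decay {
        grad += penalty * *param; // coupled weight decay, as in the WeightDecay transform
    }
    if let Some(v) = velocity {
        *v = momentum * *v + (1.0 - dampening) * grad; // momentum transform
        grad = *v;
    }
    *param -= lr * grad;
}

fn main() {
    let (mut p, mut v) = (1.0_f32, 0.0_f32);
    sgd_step(&mut p, 0.5, Some(&mut v), 1e-2, Some(1e-4), 0.9, 0.0);
    println!("{p} {v}");
}
```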

View File

@@ -11,6 +11,7 @@ use hashbrown::HashMap;
/// Wrapper struct that adapts any [simple optimizer](SimpleOptimizer) into
/// an [optimizer](Optimizer).
#[derive(Clone)]
pub struct OptimizerAdaptor<O, M, B>
where
O: SimpleOptimizer<B::InnerBackend>,
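
For a generic wrapper like this, `#[derive(Clone)]` expands to a `Clone` impl that puts a `Clone` bound on every type parameter. A rough illustration on a simplified stand-in (the real field layout of `OptimizerAdaptor` is not shown in this diff and is only assumed here):

```rust
use std::marker::PhantomData;

// A simplified stand-in for the wrapper struct.
struct AdaptorLike<O, M, B> {
    optim: O,
    _marker: PhantomData<(M, B)>,
}

// Roughly what `#[derive(Clone)]` generates: a `Clone` bound is added for
// every type parameter, which is why the wrapped optimizer must be `Clone`.
impl<O: Clone, M: Clone, B: Clone> Clone for AdaptorLike<O, M, B> {
    fn clone(&self) -> Self {
        Self {
            optim: self.optim.clone(),
            _marker: PhantomData,
        }
    }
}

fn main() {
    let a = AdaptorLike::<f32, (), ()> { optim: 0.1, _marker: PhantomData };
    let _b = a.clone();
}
```

That implicit `O: Clone` requirement is only guaranteed to hold for every simple optimizer once the trait itself demands it, which is the `+ Clone` change in the last file of this commit.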

View File

@@ -6,7 +6,7 @@ use burn_tensor::{backend::Backend, Tensor};
///
/// Implementations don't have to handle missing gradients, loading and exporting records, navigate the
/// module parameter structure, handle tracked and untracked tensors, and the likes.
pub trait SimpleOptimizer<B>: Send + Sync
pub trait SimpleOptimizer<B>: Send + Sync + Clone
where
B: Backend,
{