mirror of https://github.com/tracel-ai/burn.git
Add Clone trait to the `OptimizerAdaptor` and Clone implementations to the optimizers (#1770)
This commit is contained in:
parent f8a1356075
commit e823338750
@@ -30,6 +30,7 @@ impl GradientClippingConfig {
 /// Gradient Clipping provides a way to mitigate exploding gradients
 /// by clipping every component of the gradient by value or by norm during
 /// backpropagation.
+#[derive(Clone)]
 pub enum GradientClipping {
     /// Clip the gradient by value.
     Value(f32),
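The doc comment above distinguishes the two clipping strategies. As a rough, framework-free illustration (plain Rust over a flat slice, not burn's tensor-based implementation, with function names made up for this note): clipping by value clamps each gradient component independently, while clipping by norm rescales the whole gradient when its L2 norm exceeds a threshold.

// Illustrative sketch only; burn applies the same idea to tensors.
fn clip_by_value(grad: &mut [f32], threshold: f32) {
    for g in grad.iter_mut() {
        // clamp each component into [-threshold, threshold]
        *g = g.clamp(-threshold, threshold);
    }
}

fn clip_by_norm(grad: &mut [f32], max_norm: f32) {
    let norm = grad.iter().map(|g| g * g).sum::<f32>().sqrt();
    if norm > max_norm {
        // rescale so the overall L2 norm equals max_norm
        let scale = max_norm / norm;
        for g in grad.iter_mut() {
            *g *= scale;
        }
    }
}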
@@ -26,6 +26,7 @@ pub struct AdaGradConfig {
 }
 
 /// AdaGrad optimizer
+#[derive(Clone)]
 pub struct AdaGrad<B: Backend> {
     lr_decay: LrDecay,
     weight_decay: Option<WeightDecay<B>>,
@@ -105,6 +106,7 @@ pub struct LrDecayState<B: Backend, const D: usize> {
     sum: Tensor<B, D>,
 }
 
+#[derive(Clone)]
 struct LrDecay {
     lr_decay: f64,
     epsilon: f32,
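For context on the fields touched above: lr_decay and epsilon, together with the accumulated sum in LrDecayState, parameterize the usual AdaGrad rule. A scalar sketch of one step, with made-up names and one common form of the learning-rate decay (not necessarily the exact formula burn uses):

// Hypothetical scalar AdaGrad step; burn's implementation works on tensors.
fn adagrad_step(
    weight: &mut f32,
    grad: f32,
    sum: &mut f32,   // running sum of squared gradients
    lr: f64,
    lr_decay: f64,
    epsilon: f32,
    time: u64,       // number of steps taken so far
) {
    *sum += grad * grad;
    let decayed_lr = lr / (1.0 + (time as f64) * lr_decay);
    *weight -= (decayed_lr as f32) * grad / (sum.sqrt() + epsilon);
}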
@@ -31,6 +31,7 @@ pub struct AdamConfig {
 }
 
 /// Adam optimizer as described in the paper [Adam: A Method for Stochastic Optimization](https://arxiv.org/pdf/1412.6980.pdf).
+#[derive(Clone)]
 pub struct Adam<B: Backend> {
     momentum: AdaptiveMomentum,
     weight_decay: Option<WeightDecay<B>>,
@@ -113,6 +114,7 @@ pub struct AdaptiveMomentumState<B: Backend, const D: usize> {
     moment_2: Tensor<B, D>,
 }
 
+#[derive(Clone)]
 struct AdaptiveMomentum {
     beta_1: f32,
     beta_2: f32,
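beta_1 and beta_2 above are the exponential decay rates of the first and second moment estimates from the Adam paper. A minimal scalar sketch of the bias-corrected update (names invented for this note, not burn's API):

// Scalar sketch of one Adam update; burn's AdaptiveMomentum operates on tensors.
fn adam_step(
    weight: &mut f32,
    grad: f32,
    m1: &mut f32,    // first moment estimate
    m2: &mut f32,    // second moment estimate
    beta_1: f32,
    beta_2: f32,
    epsilon: f32,
    lr: f32,
    t: i32,          // 1-based step counter for bias correction
) {
    *m1 = beta_1 * *m1 + (1.0 - beta_1) * grad;
    *m2 = beta_2 * *m2 + (1.0 - beta_2) * grad * grad;
    let m1_hat = *m1 / (1.0 - beta_1.powi(t));
    let m2_hat = *m2 / (1.0 - beta_2.powi(t));
    *weight -= lr * m1_hat / (m2_hat.sqrt() + epsilon);
}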
@@ -30,6 +30,7 @@ pub struct AdamWConfig {
 }
 
 /// AdamW optimizer as described in the paper [Decoupled Weight Decay Regularization, Loshchilov and Hutter, 2019](https://arxiv.org/abs/1711.05101).
+#[derive(Clone)]
 pub struct AdamW<B: Backend> {
     momentum: AdaptiveMomentumW,
     weight_decay: f32,
@@ -112,6 +113,7 @@ pub struct AdaptiveMomentumWState<B: Backend, const D: usize> {
     moment_2: Tensor<B, D>,
 }
 
+#[derive(Clone)]
 struct AdaptiveMomentumW {
     beta_1: f32,
     beta_2: f32,
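Note that AdamW stores weight_decay as a plain f32 rather than an Option<WeightDecay<B>>: per the linked paper, the decay is decoupled and applied directly to the parameters instead of being folded into the gradient. A scalar sketch of that difference (illustrative names only):

// Decoupled weight decay: shrink the parameter directly, then apply the
// adaptive-moment step computed as in Adam.
fn adamw_weight_update(weight: &mut f32, adam_update: f32, lr: f32, weight_decay: f32) {
    *weight -= lr * weight_decay * *weight;
    *weight -= adam_update;
}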
@@ -20,6 +20,7 @@ pub struct WeightDecayState<B: Backend, const D: usize> {
 }
 
 /// Weight decay implementation that transforms gradients.
+#[derive(Clone)]
 pub struct WeightDecay<B: Backend> {
     penalty: B::FloatElem,
 }
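In contrast to AdamW's decoupled decay, this WeightDecay transforms the gradient itself (classic L2-style regularization) before the optimizer consumes it. A one-function scalar sketch (hypothetical helper, not burn's API):

// Fold the penalty into the gradient prior to the optimizer step.
fn decay_gradient(grad: f32, weight: f32, penalty: f32) -> f32 {
    grad + penalty * weight
}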
@@ -27,6 +27,7 @@ pub struct MomentumState<B: Backend, const D: usize> {
 }
 
 /// Momemtum implementation that transforms gradients.
+#[derive(Clone)]
 pub struct Momentum<B: Backend> {
     momentum: B::FloatElem,
     dampening: f64,
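The momentum and dampening fields correspond to the standard SGD-with-momentum velocity update. A scalar sketch with invented names, not burn's implementation:

// Velocity update for SGD with momentum and dampening; returns the
// transformed gradient the optimizer actually applies.
fn momentum_transform(grad: f32, velocity: &mut f32, momentum: f32, dampening: f64) -> f32 {
    let damp = dampening as f32;
    *velocity = momentum * *velocity + (1.0 - damp) * grad;
    *velocity
}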
@@ -64,6 +64,7 @@ impl RmsPropConfig {
 
 /// Optimizer that implements stochastic gradient descent with momentum.
 /// The optimizer can be configured with [RmsPropConfig](RmsPropConfig).
+#[derive(Clone)]
 pub struct RmsProp<B: Backend> {
     alpha: f32,
     // epsilon: f32,
@@ -251,6 +252,7 @@ impl<B: Backend, const D: usize> CenteredState<B, D> {
 
 /// [RmsPropMomentum](RmsPropMomentum) is to store config status for optimizer.
 /// (, which is stored in [optimizer](RmsProp) itself and not passed in during `step()` calculation)
+#[derive(Clone)]
 pub struct RmsPropMomentum {
     momentum: f32,
     epsilon: f32,
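alpha above is the smoothing constant for the running average of squared gradients, while momentum and epsilon come from RmsPropMomentum. A scalar sketch of a non-centered RmsProp step (illustrative names, and one common formulation rather than burn's exact one):

// Non-centered RmsProp: running average of squared gradients plus a momentum buffer.
fn rmsprop_step(
    weight: &mut f32,
    grad: f32,
    square_avg: &mut f32, // running average of squared gradients
    buf: &mut f32,        // momentum buffer
    alpha: f32,
    momentum: f32,
    epsilon: f32,
    lr: f32,
) {
    *square_avg = alpha * *square_avg + (1.0 - alpha) * grad * grad;
    *buf = momentum * *buf + grad / (square_avg.sqrt() + epsilon);
    *weight -= lr * *buf;
}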
@@ -25,6 +25,7 @@ pub struct SgdConfig {
 /// Optimizer that implements stochastic gradient descent with momentum.
 ///
 /// The optimizer can be configured with [SgdConfig](SgdConfig).
+#[derive(Clone)]
 pub struct Sgd<B: Backend> {
     momentum: Option<Momentum<B>>,
     weight_decay: Option<WeightDecay<B>>,
@@ -11,6 +11,7 @@ use hashbrown::HashMap;
 
 /// Wrapper struct that adapts any [simple optimizer](SimpleOptimizer) into
 /// an [optimizer](Optimizer).
+#[derive(Clone)]
 pub struct OptimizerAdaptor<O, M, B>
 where
     O: SimpleOptimizer<B::InnerBackend>,
@@ -6,7 +6,7 @@ use burn_tensor::{backend::Backend, Tensor};
 ///
 /// Implementations don't have to handle missing gradients, loading and exporting records, navigate the
 /// module parameter structure, handle tracked and untracked tensors, and the likes.
-pub trait SimpleOptimizer<B>: Send + Sync
+pub trait SimpleOptimizer<B>: Send + Sync + Clone
 where
     B: Backend,
 {
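The Clone supertrait added here is what the rest of the commit relies on: once every SimpleOptimizer is Clone, the #[derive(Clone)] on OptimizerAdaptor becomes usable by generic consumers, for example to snapshot optimizer state. A reduced, self-contained sketch of the pattern with toy types (none of these names are burn's):

// Toy reduction of the change: a Clone supertrait on the strategy trait
// makes a #[derive(Clone)] wrapper cloneable for any implementor.
trait SimpleStep: Send + Sync + Clone {
    fn step(&self, weight: f32, grad: f32) -> f32;
}

#[derive(Clone)]
struct Adaptor<O: SimpleStep> {
    inner: O,
}

#[derive(Clone)]
struct PlainSgd {
    lr: f32,
}

impl SimpleStep for PlainSgd {
    fn step(&self, weight: f32, grad: f32) -> f32 {
        weight - self.lr * grad
    }
}

fn main() {
    let optim = Adaptor { inner: PlainSgd { lr: 0.01 } };
    // Possible because every SimpleStep is now required to be Clone.
    let snapshot = optim.clone();
    let _ = (optim.inner.step(1.0, 0.5), snapshot);
}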