diff --git a/burn-tch/src/lib.rs b/burn-tch/src/lib.rs index 18d75df71..6d01154a7 100644 --- a/burn-tch/src/lib.rs +++ b/burn-tch/src/lib.rs @@ -1,4 +1,5 @@ #![warn(missing_docs)] +#![allow(clippy::single_range_in_vec_init)] //! Burn Tch Backend diff --git a/burn-train/src/learner/epoch.rs b/burn-train/src/learner/epoch.rs index 2761cf7aa..3e7e31e31 100644 --- a/burn-train/src/learner/epoch.rs +++ b/burn-train/src/learner/epoch.rs @@ -87,9 +87,8 @@ impl TrainEpoch { ) -> (M, O) where B: ADBackend, - M: ADModule, + M: TrainStep + ADModule, O: Optimizer, - M: TrainStep, LR: LRScheduler, { log::info!("Executing training step for epoch {}", self.epoch,); @@ -114,11 +113,11 @@ impl TrainEpoch { if accumulation <= accumulation_current { let grads = accumulator.grads(); - model = optim.step(lr, model, grads); + model = model.optimize(&mut optim, lr, grads); accumulation_current = 0; } } - None => model = optim.step(lr, model, item.grads), + None => model = model.optimize(&mut optim, lr, item.grads), } let item = LearnerItem::new( @@ -204,7 +203,7 @@ impl TrainEpoch { if accumulation <= accumulation_current { let grads = accumulator.grads(); - model = optim.step(lr, model, grads); + model = model.optimize(&mut optim, lr, grads); accumulation_current = 0; } diff --git a/burn-train/src/learner/train_val.rs b/burn-train/src/learner/train_val.rs index 4908a2f71..cc3789b4b 100644 --- a/burn-train/src/learner/train_val.rs +++ b/burn-train/src/learner/train_val.rs @@ -35,21 +35,52 @@ impl TrainOutput { } } -/// Trait for a training step. +/// Trait to be implemented for training models. +/// +/// The [step](TrainStep::step) method needs to be manually implemented for all structs. +/// +/// The [optimize](TrainStep::optimize) method can be overridden if you want to control how the +/// optimizer is used to update the model. This can be useful if you want to call custom mutable +/// functions on your model (e.g., clipping the weights) before or after the optimizer is used. 
+/// +/// # Notes +/// +/// To be used with the [Learner](Learner) struct, the struct which implements this trait must +/// also implement the [ADModule](ADModule) trait, which is done automatically with the +/// [Module](burn_core::module::Module) derive. pub trait TrainStep { - /// Runs a training step. + /// Runs the training step, which executes the forward and backward passes. /// /// # Arguments /// - /// * `item` - The item to train on. + /// * `item` - The training input for the model. /// /// # Returns /// - /// The training output. + /// The training output containing the model output and the gradients. fn step(&self, item: TI) -> TrainOutput; + /// Optimize the current module with the provided gradients and learning rate. + /// + /// # Arguments + /// + /// * `optim` - Optimizer used for training this model. + /// * `lr` - The learning rate used for this step. + /// * `grads` - The gradients of each parameter in the current model. + /// + /// # Returns + /// + /// The updated model. + fn optimize(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self + where + B: ADBackend, + O: Optimizer, + Self: ADModule, + { + optim.step(lr, self, grads) + } } -/// Trait for a validation step. +/// Trait to be implemented for validating models. pub trait ValidStep { /// Runs a validation step. ///