Feat/train/custom optimize method (#689)

* Allow models to define a custom optimize function

* Fix clippy
This commit is contained in:
Nathaniel Simard 2023-08-25 07:14:36 -04:00 committed by GitHub
parent 183620fb20
commit efee0ac296
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 41 additions and 10 deletions

View File

@ -1,4 +1,5 @@
#![warn(missing_docs)]
#![allow(clippy::single_range_in_vec_init)]
//! Burn Tch Backend

View File

@ -87,9 +87,8 @@ impl<TI> TrainEpoch<TI> {
) -> (M, O)
where
B: ADBackend,
M: ADModule<B>,
M: TrainStep<TI, TO> + ADModule<B>,
O: Optimizer<M, B>,
M: TrainStep<TI, TO>,
LR: LRScheduler,
{
log::info!("Executing training step for epoch {}", self.epoch,);
@ -114,11 +113,11 @@ impl<TI> TrainEpoch<TI> {
if accumulation <= accumulation_current {
let grads = accumulator.grads();
model = optim.step(lr, model, grads);
model = model.optimize(&mut optim, lr, grads);
accumulation_current = 0;
}
}
None => model = optim.step(lr, model, item.grads),
None => model = model.optimize(&mut optim, lr, item.grads),
}
let item = LearnerItem::new(
@ -204,7 +203,7 @@ impl<TI> TrainEpoch<TI> {
if accumulation <= accumulation_current {
let grads = accumulator.grads();
model = optim.step(lr, model, grads);
model = model.optimize(&mut optim, lr, grads);
accumulation_current = 0;
}

View File

@ -35,21 +35,52 @@ impl<TO> TrainOutput<TO> {
}
}
/// Trait for a training step.
/// Trait to be implemented for training models.
///
/// The [step](TrainStep::step) method needs to be manually implemented for all structs.
///
/// The [optimize](TrainStep::optimize) method can be overridden if you want to control how the
/// optimizer is used to update the model. This can be useful if you want to call custom mutable
/// functions on your model (e.g., clipping the weights) before or after the optimizer is used.
///
/// # Notes
///
/// To be used with the [Learner](Learner) struct, the struct which implements this trait must
/// also implement the [ADModule](ADModule) trait, which is done automatically with the
/// [Module](burn_core::module::Module) derive.
pub trait TrainStep<TI, TO> {
/// Runs a training step.
/// Runs the training step, which executes the forward and backward passes.
///
/// # Arguments
///
/// * `item` - The item to train on.
/// * `item` - The training input for the model.
///
/// # Returns
///
/// The training output.
/// The training output containing the model output and the gradients.
fn step(&self, item: TI) -> TrainOutput<TO>;
/// Optimize the current module with the provided gradients and learning rate.
///
/// This method consumes the model and returns the updated one. The default implementation
/// simply forwards to the optimizer's `step` method; override it to run custom logic on the
/// model before or after the optimizer update.
///
/// # Arguments
///
/// * `optim` - Optimizer used for training this model.
/// * `lr` - The learning rate used for this step.
/// * `grads` - The gradients of each parameter in the current model.
///
/// # Returns
///
/// The updated model.
fn optimize<B, O>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self
where
B: ADBackend,
O: Optimizer<Self, B>,
Self: ADModule<B>,
{
// Default behavior: hand the model and gradients to the optimizer unchanged.
optim.step(lr, self, grads)
}
}
/// Trait for a validation step.
/// Trait to be implemented for validating models.
pub trait ValidStep<VI, VO> {
/// Runs a validation step.
///