mirror of https://github.com/tracel-ai/burn.git
Feat/train/custom optimize method (#689)
* Add the possibility to add a custom optimize function for models
* Fix clippy
This commit is contained in:
parent 183620fb20
commit efee0ac296
@@ -1,4 +1,5 @@
 #![warn(missing_docs)]
+#![allow(clippy::single_range_in_vec_init)]
 
 //! Burn Tch Backend
 
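Aside (not part of the diff): `clippy::single_range_in_vec_init` warns when a `Vec` or array is initialized from a single range literal, suspecting a typo for `vec![x; len]`. Code that deliberately builds one-element lists of ranges (for example as slice indices) hits this as a false positive, which is presumably why the allow is added at the crate root here. A minimal sketch of the pattern the lint flags:

fn main() {
    // Intentional: a one-element Vec containing the range 0..5, which
    // clippy::single_range_in_vec_init would flag as a possible typo for `vec![0; 5]`.
    #[allow(clippy::single_range_in_vec_init)]
    let ranges: Vec<std::ops::Range<usize>> = vec![0..5];
    assert_eq!(ranges[0], 0..5);
}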
@@ -87,9 +87,8 @@ impl<TI> TrainEpoch<TI> {
     ) -> (M, O)
     where
         B: ADBackend,
-        M: ADModule<B>,
+        M: TrainStep<TI, TO> + ADModule<B>,
         O: Optimizer<M, B>,
-        M: TrainStep<TI, TO>,
         LR: LRScheduler,
     {
         log::info!("Executing training step for epoch {}", self.epoch,);
@@ -114,11 +113,11 @@ impl<TI> TrainEpoch<TI> {
 
                     if accumulation <= accumulation_current {
                         let grads = accumulator.grads();
-                        model = optim.step(lr, model, grads);
+                        model = model.optimize(&mut optim, lr, grads);
                         accumulation_current = 0;
                     }
                 }
-                None => model = optim.step(lr, model, item.grads),
+                None => model = model.optimize(&mut optim, lr, item.grads),
             }
 
             let item = LearnerItem::new(
@@ -204,7 +203,7 @@ impl<TI> TrainEpoch<TI> {
 
                     if accumulation <= accumulation_current {
                         let grads = accumulator.grads();
-                        model = optim.step(lr, model, grads);
+                        model = model.optimize(&mut optim, lr, grads);
                         accumulation_current = 0;
                     }
@@ -35,21 +35,52 @@ impl<TO> TrainOutput<TO> {
     }
 }
 
-/// Trait for a training step.
+/// Trait to be implemented for training models.
+///
+/// The [step](TrainStep::step) method needs to be manually implemented for all structs.
+///
+/// The [optimize](TrainStep::optimize) method can be overridden if you want to control how the
+/// optimizer is used to update the model. This can be useful if you want to call custom mutable
+/// functions on your model (e.g., clipping the weights) before or after the optimizer is used.
+///
+/// # Notes
+///
+/// To be used with the [Learner](Learner) struct, the struct which implements this trait must
+/// also implement the [ADModule](ADModule) trait, which is done automatically with the
+/// [Module](burn_core::module::Module) derive.
 pub trait TrainStep<TI, TO> {
-    /// Runs a training step.
+    /// Runs the training step, which executes the forward and backward passes.
     ///
     /// # Arguments
     ///
-    /// * `item` - The item to train on.
+    /// * `item` - The training input for the model.
     ///
     /// # Returns
     ///
-    /// The training output.
+    /// The training output containing the model output and the gradients.
     fn step(&self, item: TI) -> TrainOutput<TO>;
+    /// Optimize the current module with the provided gradients and learning rate.
+    ///
+    /// # Arguments
+    ///
+    /// * `optim`: Optimizer used for training this model.
+    /// * `lr`: The learning rate used for this step.
+    /// * `grads`: The gradients of each parameter in the current model.
+    ///
+    /// # Returns
+    ///
+    /// The updated model.
+    fn optimize<B, O>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self
+    where
+        B: ADBackend,
+        O: Optimizer<Self, B>,
+        Self: ADModule<B>,
+    {
+        optim.step(lr, self, grads)
+    }
 }
 
-/// Trait for a validation step.
+/// Trait to be implemented for validating models.
 pub trait ValidStep<VI, VO> {
     /// Runs a validation step.
     ///
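Not part of the diff: a minimal sketch of what the new hook enables, written against the trait signature added above. It assumes Burn's training API of this era (the `Module` derive, `ClassificationOutput`, `CrossEntropyLoss::new(None)`, `TrainOutput::new`); `MyModel`, `MyBatch`, and the skip-on-zero-learning-rate rule are hypothetical placeholders, and any other mutation of `self` (for example weight clipping) could be dropped in at the same spot before calling `optim.step`.

use burn::module::{ADModule, Module};
use burn::nn::loss::CrossEntropyLoss;
use burn::nn::Linear;
use burn::optim::{GradientsParams, Optimizer};
use burn::tensor::backend::{ADBackend, Backend};
use burn::tensor::{Int, Tensor};
use burn::train::{ClassificationOutput, TrainOutput, TrainStep};

// Hypothetical model and batch types, only here to make the sketch self-contained.
#[derive(Module, Debug)]
pub struct MyModel<B: Backend> {
    linear: Linear<B>,
}

pub struct MyBatch<B: Backend> {
    pub inputs: Tensor<B, 2>,
    pub targets: Tensor<B, 1, Int>,
}

impl<B: ADBackend> TrainStep<MyBatch<B>, ClassificationOutput<B>> for MyModel<B> {
    fn step(&self, batch: MyBatch<B>) -> TrainOutput<ClassificationOutput<B>> {
        // Forward and backward passes, as in any TrainStep implementation.
        let output = self.linear.forward(batch.inputs);
        let loss = CrossEntropyLoss::new(None).forward(output.clone(), batch.targets.clone());
        let gradients = loss.backward();
        TrainOutput::new(
            self,
            gradients,
            ClassificationOutput::new(loss, output, batch.targets),
        )
    }

    // Override of the new hook: control how (or whether) the optimizer updates the model.
    // Here the update is skipped while the scheduled learning rate is effectively zero;
    // clipping weights before or after `optim.step` would follow the same pattern.
    fn optimize<Back, O>(self, optim: &mut O, lr: f64, grads: GradientsParams) -> Self
    where
        Back: ADBackend,
        O: Optimizer<Self, Back>,
        Self: ADModule<Back>,
    {
        if lr <= f64::EPSILON {
            return self;
        }
        optim.step(lr, self, grads)
    }
}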