diff --git a/burn-derive/src/lib.rs b/burn-derive/src/lib.rs
index 350d9c209..43cc042d2 100644
--- a/burn-derive/src/lib.rs
+++ b/burn-derive/src/lib.rs
@@ -29,6 +29,7 @@ fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
     let state_fn = param.gen_state_fn();
     let load_from_parent_fn = param.gen_load_from_parent_fn();
     let load_fn = param.gen_load_fn();
+    let inner_fn = param.gen_inner_fn();
 
     let gen = quote! {
         impl #generics burn::module::Module for #name #generics_ty #generics_where {
@@ -44,11 +45,17 @@ fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
 
             #load_fn
         }
 
+        impl #generics burn::module::ADModule for #name #generics_ty where B: burn::tensor::back::ad::Backend, {
+            type ADBackend=B;
+            type InnerModule=#name<B::InnerBackend>;
+
+            #inner_fn
+        }
         impl #generics std::fmt::Display for #name #generics_ty #generics_where {
             #display_fn
         }
     };
-
-    gen.into()
+    let tokens = gen.into();
+    tokens
 }
diff --git a/burn-derive/src/param.rs b/burn-derive/src/param.rs
index 20e31bada..5f71aca66 100644
--- a/burn-derive/src/param.rs
+++ b/burn-derive/src/param.rs
@@ -49,7 +49,7 @@ impl Param {
         }
 
         quote! {
-            fn update_params<O: burn::optim::Optimizer<B>>(&mut self, grads: &burn::tensor::Gradients, optim: &mut O)
+            fn update_params<O: burn::optim::Optimizer<Backend = B>>(&mut self, grads: &burn::tensor::Gradients, optim: &mut O)
             where
                 B: burn::tensor::back::ad::Backend {
                 #body
@@ -98,6 +98,30 @@ impl Param {
         .into()
     }
 
+    pub fn gen_inner_fn(&self) -> TokenStream {
+        let mut body = quote! {};
+        let mut names = Vec::new();
+        for field in self.fields.iter() {
+            let name = field.ident();
+            names.push(name.clone());
+
+            body.extend(quote! {
+                let #name = self.#name.inner();
+            });
+        }
+
+        quote! {
+            fn inner(&self) -> Self::InnerModule {
+                #body
+
+                Self::InnerModule {
+                    #(#names),*
+                }
+            }
+        }
+        .into()
+    }
+
     pub fn gen_state_fn(&self) -> TokenStream {
         let mut body = quote! {
             let mut state = burn::module::State::new(self.name());
diff --git a/burn-tensor/Cargo.toml b/burn-tensor/Cargo.toml
index 03ee44af1..deaeb8287 100644
--- a/burn-tensor/Cargo.toml
+++ b/burn-tensor/Cargo.toml
@@ -28,6 +28,7 @@ half = { version = "1.6", features = ["num-traits"] } # needs to be 1.6 to work
 
 # Backends
 tch = { version = "0.8", optional = true }
+lazy_static = "1.4"
 ndarray = { version = "0.15", optional = true }
 
 # Autodiff
diff --git a/burn-tensor/src/tensor/api/af.rs b/burn-tensor/src/tensor/api/af.rs
index 1c48df9ac..40cfd30d9 100644
--- a/burn-tensor/src/tensor/api/af.rs
+++ b/burn-tensor/src/tensor/api/af.rs
@@ -11,13 +11,13 @@ pub fn softmax<B: Backend, const D: usize>(tensor: &Tensor<B, D>, dim: usize) ->
 }
 
 pub fn log_softmax<B: Backend, const D: usize>(tensor: &Tensor<B, D>, dim: usize) -> Tensor<B, D> {
-    let tensor_tmp = match Precision::Half == B::Elem::precision() {
-        true => {
+    let tensor_tmp = match B::Elem::precision() {
+        Precision::Half => {
             let tensor_full = tensor.to_full_precision();
             let tensor_tmp = tensor_full.exp().sum_dim(dim).log();
             Tensor::from_full_precision(tensor_tmp)
         }
-        false => tensor.exp().sum_dim(dim).log(),
+        _ => tensor.exp().sum_dim(dim).log(),
     };
 
     tensor.sub(&tensor_tmp)
diff --git a/burn-tensor/src/tensor/backend/tch/tensor.rs b/burn-tensor/src/tensor/backend/tch/tensor.rs
index fb783b548..c821e69f9 100644
--- a/burn-tensor/src/tensor/backend/tch/tensor.rs
+++ b/burn-tensor/src/tensor/backend/tch/tensor.rs
@@ -1,5 +1,11 @@
 use crate::tensor::{ops::TensorOpsUtilities, Data, Element, Shape, TensorTrait};
 
+lazy_static::lazy_static! {
+    static ref NO_GRAD: tch::NoGradGuard = {
+        tch::no_grad_guard()
+    };
+}
+
 #[derive(Debug, PartialEq)]
 pub struct TchTensor<P: tch::kind::Element, const D: usize> {
     pub kind: TchKind<P>,
@@ -65,6 +71,8 @@ impl TchTensor {
         let shape_tch = TchShape::from(data.shape);
         let kind = TchKind::new();
         let tensor = tensor.reshape(&shape_tch.dims).to_kind(kind.kind());
+
+        lazy_static::initialize(&NO_GRAD);
         let tensor = tensor.set_requires_grad(false);
 
         Self {
@@ -81,6 +89,8 @@ impl T
         let device = tch::Device::Cpu;
         let kind = TchKind::new();
         let tensor = tch::Tensor::empty(&shape_tch.dims, (kind.kind(), device.clone()));
+
+        lazy_static::initialize(&NO_GRAD);
         let tensor = tensor.set_requires_grad(false);
 
         Self {
diff --git a/examples/mnist.rs b/examples/mnist.rs
index a1bab6f4d..825944be0 100644
--- a/examples/mnist.rs
+++ b/examples/mnist.rs
@@ -1,5 +1,5 @@
 use burn::data::dataloader::batcher::Batcher;
-use burn::data::dataloader::DataLoaderBuilder;
+use burn::data::dataloader::{DataLoaderBuilder, Detach};
 use burn::data::dataset::source::huggingface::{MNISTDataset, MNISTItem};
 use burn::module::{Forward, Module, Param};
 use burn::nn;
@@ -8,8 +8,8 @@ use burn::tensor::af::relu;
 use burn::tensor::back::{ad, Backend};
 use burn::tensor::losses::cross_entropy_with_logits;
 use burn::tensor::{Data, ElementConversion, Shape, Tensor};
-use burn::train::logger::{AsyncLogger, CLILogger};
-use burn::train::metric::{AccuracyMetric, CUDAMetric, LossMetric, Metric};
+use burn::train::logger::{AsyncLogger, CLILogger, TextPlot};
+use burn::train::metric::{AccuracyMetric, CUDAMetric, LossMetric};
 use burn::train::{ClassificationLearner, ClassificationOutput, SupervisedTrainer};
 use std::sync::Arc;
 
@@ -118,7 +118,16 @@ struct MNISTBatch<B: Backend> {
     targets: Tensor<B, 2>,
 }
 
-impl<B: ad::Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
+impl<B: Backend> Detach for MNISTBatch<B> {
+    fn detach(self) -> Self {
+        Self {
+            images: self.images.detach(),
+            targets: self.targets.detach(),
+        }
+    }
+}
+
+impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
     fn batch(&self, items: Vec<MNISTItem>) -> MNISTBatch<B> {
         let images = items
             .iter()
@@ -133,8 +142,8 @@ impl<B: ad::Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
             .map(|item| Tensor::<B, 2>::one_hot(item.label, 10))
             .collect();
 
-        let images = Tensor::cat(images, 0).to_device(self.device).detach();
-        let targets = Tensor::cat(targets, 0).to_device(self.device).detach();
+        let images = Tensor::cat(images, 0).to_device(self.device);
+        let targets = Tensor::cat(targets, 0).to_device(self.device);
 
         MNISTBatch { images, targets }
     }
@@ -143,18 +152,11 @@ impl<B: ad::Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
 fn run<B: ad::Backend>(device: B::Device) {
     let batch_size = 128;
     let learning_rate = 5.5e-2;
-    let num_epochs = 100;
+    let num_epochs = 10;
     let num_workers = 8;
     let num_layers = 4;
-    let hidden_dim = 1024;
+    let hidden_dim = 3024;
     let seed = 42;
-    let metrics = || -> Vec<Box<dyn Metric<ClassificationOutput<B>>>> {
-        vec![
-            Box::new(LossMetric::new()),
-            Box::new(AccuracyMetric::new()),
-            Box::new(CUDAMetric::new()),
-        ]
-    };
 
     let mut model: Model<B> = Model::new(784, hidden_dim, num_layers, 10);
     model.to_device(device);
@@ -167,25 +169,34 @@ fn run<B: ad::Backend>(device: B::Device) {
     );
     let optim: SGDOptimizer<B> = SGDOptimizer::new(learning_rate);
 
-    let batcher = Arc::new(MNISTBatcher::<B> { device });
-    let dataloader_train = DataLoaderBuilder::new(batcher.clone())
+    let batcher_train = Arc::new(MNISTBatcher::<B> { device });
+    let batcher_valid = Arc::new(MNISTBatcher::<B::InnerBackend> { device });
+    let dataloader_train = DataLoaderBuilder::new(batcher_train)
         .batch_size(batch_size)
         .shuffle(seed)
         .num_workers(num_workers)
         .build(Arc::new(MNISTDataset::train()));
-    let dataloader_test = DataLoaderBuilder::new(batcher.clone())
+    let dataloader_test = DataLoaderBuilder::new(batcher_valid)
        .batch_size(batch_size)
        .num_workers(num_workers)
        .build(Arc::new(MNISTDataset::test()));
 
-    let learner = ClassificationLearner::new(model);
+    let learner = ClassificationLearner::new(model, optim);
 
     let logger_train = Box::new(AsyncLogger::new(Box::new(CLILogger::new(
-        metrics(),
+        vec![
+            Box::new(TextPlot::new(LossMetric::new())),
+            Box::new(AccuracyMetric::new()),
+            Box::new(CUDAMetric::new()),
+        ],
         "Train".to_string(),
     ))));
     let logger_valid = Box::new(AsyncLogger::new(Box::new(CLILogger::new(
-        metrics(),
+        vec![
+            Box::new(TextPlot::new(LossMetric::new())),
+            Box::new(AccuracyMetric::new()),
+            Box::new(CUDAMetric::new()),
+        ],
         "Valid".to_string(),
     ))));
 
@@ -195,7 +206,6 @@ fn run<B: ad::Backend>(device: B::Device) {
         logger_train,
         logger_valid,
         learner,
-        optim,
     );
 
     trainer.run(num_epochs);
diff --git a/src/data/dataloader/mod.rs b/src/data/dataloader/mod.rs
index 75ff76654..10bd44105 100644
--- a/src/data/dataloader/mod.rs
+++ b/src/data/dataloader/mod.rs
@@ -11,3 +11,7 @@ pub use builder::*;
 pub use dataloader::*;
 pub use multithread::*;
 pub use strategy::*;
+
+pub trait Detach {
+    fn detach(self) -> Self;
+}
diff --git a/src/module/module.rs b/src/module/module.rs
index eed791128..c9140ad03 100644
--- a/src/module/module.rs
+++ b/src/module/module.rs
@@ -72,8 +72,11 @@ where
 pub trait Module: Send + Sync + std::fmt::Debug + std::fmt::Display {
     type Backend: back::Backend;
 
-    fn update_params<O: Optimizer<Self::Backend>>(&mut self, grads: &Gradients, optim: &mut O)
-    where
+    fn update_params<O: Optimizer<Backend = Self::Backend>>(
+        &mut self,
+        grads: &Gradients,
+        optim: &mut O,
+    ) where
         Self::Backend: back::ad::Backend;
     fn devices(&self) -> Vec<<Self::Backend as back::Backend>::Device>;
     fn to_device(&mut self, device: <Self::Backend as back::Backend>::Device);
@@ -93,6 +96,13 @@ pub trait Module: Send + Sync + std::fmt::Debug + std::fmt::Display {
     fn num_params(&self) -> usize;
 }
 
+pub trait ADModule: Module + Send + Sync + std::fmt::Debug + std::fmt::Display {
+    type ADBackend: back::ad::Backend;
+    type InnerModule: Module<Backend = <Self::ADBackend as back::ad::Backend>::InnerBackend>;
+
+    fn inner(&self) -> Self::InnerModule;
+}
+
 pub trait Forward<In, Out> {
     fn forward(&self, input: In) -> Out;
 }
diff --git a/src/module/param.rs b/src/module/param.rs
index e9dacc726..493460b7f 100644
--- a/src/module/param.rs
+++ b/src/module/param.rs
@@ -1,4 +1,4 @@
-use crate::module::{Module, State};
+use crate::module::{ADModule, Module, State};
 use crate::optim::Optimizer;
 use crate::tensor::{back, Gradients, Tensor};
 use serde::de::DeserializeOwned;
@@ -28,7 +28,7 @@ impl<B: back::Backend, const D: usize> Param<Tensor<B, D>> {
         self.value.shape().num_elements()
     }
 
-    pub fn update_params<O: Optimizer<B>>(&mut self, grads: &Gradients, optim: &mut O)
+    pub fn update_params<O: Optimizer<Backend = B>>(&mut self, grads: &Gradients, optim: &mut O)
     where
         B: back::ad::Backend,
     {
@@ -60,6 +60,13 @@ impl<B: back::Backend, const D: usize> Param<Tensor<B, D>> {
         let data = state.get(name);
         self.value = Tensor::from_data_device(data, self.value.device());
     }
+
+    pub fn inner(&self) -> Param<Tensor<B::InnerBackend, D>>
+    where
+        B: back::ad::Backend,
+    {
+        Param::new(self.value.inner())
+    }
 }
 
 impl<B: back::Backend, const D: usize> Param<Option<Tensor<B, D>>> {
@@ -71,7 +78,7 @@ impl<B: back::Backend, const D: usize> Param<Option<Tensor<B, D>>> {
         0
     }
 
-    pub fn update_params<O: Optimizer<B>>(&mut self, grads: &Gradients, optim: &mut O)
+    pub fn update_params<O: Optimizer<Backend = B>>(&mut self, grads: &Gradients, optim: &mut O)
     where
         B: back::ad::Backend,
     {
@@ -118,6 +125,16 @@ impl<B: back::Backend, const D: usize> Param<Option<Tensor<B, D>>> {
 
         self.value = value;
     }
+
+    pub fn inner(&self) -> Param<Option<Tensor<B::InnerBackend, D>>>
+    where
+        B: back::ad::Backend,
+    {
+        match &self.value {
+            Some(tensor) => Param::new(Some(tensor.inner())),
+            None => Param::new(None),
+        }
+    }
 }
 
 impl<M: Module> Param<M> {
@@ -125,8 +142,11 @@ impl<M: Module> Param<M> {
         self.value.num_params()
     }
 
-    pub fn update_params<O: Optimizer<M::Backend>>(&mut self, grads: &Gradients, optim: &mut O)
-    where
+    pub fn update_params<O: Optimizer<Backend = M::Backend>>(
+        &mut self,
+        grads: &Gradients,
+        optim: &mut O,
+    ) where
         M::Backend: back::ad::Backend,
     {
         self.value.update_params(grads, optim);
@@ -157,6 +177,14 @@ impl<M: Module> Param<M> {
     {
         self.value.load_from_parent(name, state);
     }
+
+    pub fn inner(&self) -> Param<M::InnerModule>
+    where
+        M: ADModule,
+        M::Backend: back::ad::Backend,
+    {
+        Param::new(self.value.inner())
+    }
 }
 
 impl<M: Module> Param<Vec<M>> {
@@ -169,8 +197,11 @@ impl<M: Module> Param<Vec<M>> {
         num_params
     }
 
-    pub fn update_params<O: Optimizer<M::Backend>>(&mut self, grads: &Gradients, optim: &mut O)
-    where
+    pub fn update_params<O: Optimizer<Backend = M::Backend>>(
+        &mut self,
+        grads: &Gradients,
+        optim: &mut O,
+    ) where
         M::Backend: back::ad::Backend,
     {
         for module in self.value.iter_mut() {
@@ -209,4 +240,12 @@ impl<M: Module> Param<Vec<M>> {
     {
         todo!();
     }
+
+    pub fn inner(&self) -> Param<Vec<M::InnerModule>>
+    where
+        M: ADModule,
+        M::Backend: back::ad::Backend,
+    {
+        Param::new(self.value.iter().map(|v| v.inner()).collect())
+    }
 }
diff --git a/src/optim/optim.rs b/src/optim/optim.rs
index a3be54490..a905cfc50 100644
--- a/src/optim/optim.rs
+++ b/src/optim/optim.rs
@@ -1,6 +1,8 @@
 use crate::tensor::back::ad::Backend;
 use crate::tensor::{Gradients, Tensor};
 
-pub trait Optimizer<B: Backend>: Send + Sync {
-    fn update<const D: usize>(&mut self, tensor: &mut Tensor<B, D>, grads: &Gradients);
+pub trait Optimizer: Send + Sync {
+    type Backend: Backend;
+
+    fn update<const D: usize>(&mut self, tensor: &mut Tensor<Self::Backend, D>, grads: &Gradients);
 }
diff --git a/src/optim/sgd.rs b/src/optim/sgd.rs
index b16c5293b..b3bedc60f 100644
--- a/src/optim/sgd.rs
+++ b/src/optim/sgd.rs
@@ -14,7 +14,9 @@ impl<B: Backend> SGDOptimizer<B> {
         Self { learning_rate }
     }
 }
-impl<B: Backend> Optimizer<B> for SGDOptimizer<B> {
+impl<B: Backend> Optimizer for SGDOptimizer<B> {
+    type Backend = B;
+
     fn update<const D: usize>(&mut self, tensor: &mut Tensor<B, D>, grads: &Gradients) {
         let grad = tensor.grad(&grads).unwrap();
         let delta = grad.mul_scalar(&self.learning_rate);
diff --git a/src/train/trainer/learner/basic.rs b/src/train/trainer/learner/basic.rs
index e834f4dcc..84cbe7b84 100644
--- a/src/train/trainer/learner/basic.rs
+++ b/src/train/trainer/learner/basic.rs
@@ -1,12 +1,14 @@
 use super::{Learner, Loss};
+use crate::module::ADModule;
 use crate::optim::Optimizer;
 use crate::tensor::back::{ad, Backend};
 use crate::train::metric;
 use burn_tensor::Tensor;
 
 #[derive(new)]
-pub struct BasicLearner<L> {
+pub struct BasicLearner<L, O> {
     model: L,
+    optim: O,
 }
 
 #[derive(new)]
@@ -23,23 +25,28 @@ impl metric::Metric> for metric::LossMetric {
     }
 }
 
-impl Learner, BasicOutput> for BasicLearner
+impl Learner, BasicOutput> for BasicLearner
 where
-    B: ad::Backend,
-    L: Loss,
-    O: Optimizer,
+    B: ad::Backend,
+    B2: Backend,
+    O: Optimizer,
+    L: Loss + ADModule,
+    L2: Loss,
+    O: Optimizer,
 {
-    fn train(&mut self, item: T, optim: &mut O) -> BasicOutput {
+    type Backend = B;
+
+    fn train(&mut self, item: T) -> BasicOutput {
         let loss = self.model.loss(item);
         let grads = loss.backward();
 
-        self.model.update_params(&grads, optim);
+        self.model.update_params(&grads, &mut self.optim);
 
         BasicOutput::new(loss)
     }
 
-    fn valid(&self, item: T) -> BasicOutput {
-        let loss = self.model.loss(item);
+    fn valid(&self, item: T) -> BasicOutput {
+        let loss = self.model.inner().loss(item);
         BasicOutput::new(loss)
     }
 }
diff --git a/src/train/trainer/learner/classification.rs b/src/train/trainer/learner/classification.rs
index e20d254df..80e7267c9 100644
--- a/src/train/trainer/learner/classification.rs
+++ b/src/train/trainer/learner/classification.rs
@@ -1,13 +1,14 @@
 use super::Learner;
-use crate::module::{Forward, Module};
+use crate::module::{ADModule, Forward, Module};
 use crate::optim::Optimizer;
 use crate::tensor::back::{ad, Backend};
 use crate::train::metric;
 use burn_tensor::Tensor;
 
 #[derive(new)]
-pub struct ClassificationLearner<M> {
+pub struct ClassificationLearner<M, O> {
     model: M,
+    optim: O,
 }
 
 #[derive(new)]
@@ -36,23 +37,27 @@ impl metric::Metric> for metric::AccuracyMet
     }
 }
 
-impl Learner, ClassificationOutput>
-    for ClassificationLearner
+impl
+    Learner, ClassificationOutput>
+    for ClassificationLearner
 where
     B: ad::Backend,
-    M: Forward> + Module,
-    O: Optimizer,
+    M: Forward> + ADModule,
+    M2: Forward> + Module,
+    O: Optimizer,
 {
-    fn train(&mut self, item: I, optim: &mut O) -> ClassificationOutput {
+    type Backend = B;
+
+    fn train(&mut self, item: I) -> ClassificationOutput {
         let output = self.model.forward(item);
         let grads = output.loss.backward();
 
-        self.model.update_params(&grads, optim);
+        self.model.update_params(&grads, &mut self.optim);
 
         output
     }
 
-    fn valid(&self, item: I) -> ClassificationOutput {
-        self.model.forward(item)
+    fn valid(&self, item: IV) -> ClassificationOutput {
+        self.model.inner().forward(item)
     }
 }
diff --git a/src/train/trainer/learner/learner.rs b/src/train/trainer/learner/learner.rs
index d0addf618..6ed33a48d 100644
--- a/src/train/trainer/learner/learner.rs
+++ b/src/train/trainer/learner/learner.rs
@@ -2,11 +2,15 @@ use crate::module::Module;
 use crate::tensor::back::Backend;
 use burn_tensor::Tensor;
 
-pub trait Loss<T>: Module {
-    fn loss(&self, item: T) -> Tensor<Self::Backend, 1>;
+pub trait Loss: Module {
+    type Item;
+
+    fn loss(&self, item: Self::Item) -> Tensor<Self::Backend, 1>;
 }
 
-pub trait Learner<T, V, TO, VO, O> {
-    fn train(&mut self, item: T, optim: &mut O) -> TO;
+pub trait Learner<T, V, TO, VO> {
+    type Backend: Backend;
+
+    fn train(&mut self, item: T) -> TO;
     fn valid(&self, item: V) -> VO;
 }
diff --git a/src/train/trainer/trainer.rs b/src/train/trainer/trainer.rs
index 05be1709e..def7fe9ba 100644
--- a/src/train/trainer/trainer.rs
+++ b/src/train/trainer/trainer.rs
@@ -1,30 +1,28 @@
 use super::{Learner, TrainerItem};
 use crate::data::dataloader::DataLoader;
-use crate::optim::Optimizer;
-use crate::tensor::back::ad;
+use crate::data::dataloader::Detach;
+use crate::tensor::back;
 use crate::train::logger::Logger;
 use std::sync::Arc;
 
-pub struct SupervisedTrainer<B, T, V, TO, VO, L, O>
+pub struct SupervisedTrainer<B, T, V, TO, VO, L>
 where
-    B: ad::Backend,
-    L: Learner<T, V, TO, VO, O>,
-    O: Optimizer<B>,
+    B: back::ad::Backend,
+    L: Learner<T, V, TO, VO, Backend = B>,
 {
     dataloader_train: Arc<dyn DataLoader<T>>,
     dataloader_valid: Arc<dyn DataLoader<V>>,
     logger_train: Box<dyn Logger<TrainerItem<TO>>>,
     logger_valid: Box<dyn Logger<TrainerItem<VO>>>,
     learner: L,
-    optimizer: O,
     _b: B,
 }
 
-impl<B, T, V, TO, VO, L, O> SupervisedTrainer<B, T, V, TO, VO, L, O>
+impl<B, T, V, TO, VO, L> SupervisedTrainer<B, T, V, TO, VO, L>
 where
-    B: ad::Backend,
-    L: Learner<T, V, TO, VO, O>,
-    O: Optimizer<B>,
+    B: back::ad::Backend,
+    T: Detach,
+    L: Learner<T, V, TO, VO, Backend = B>,
 {
     pub fn new(
         dataloader_train: Arc<dyn DataLoader<T>>,
@@ -32,13 +30,11 @@ where
         logger_train: Box<dyn Logger<TrainerItem<TO>>>,
         logger_valid: Box<dyn Logger<TrainerItem<VO>>>,
         learner: L,
-        optimizer: O,
     ) -> Self {
         Self {
             dataloader_train,
             dataloader_valid,
             learner,
-            optimizer,
             logger_train,
             logger_valid,
             _b: B::default(),
@@ -52,7 +48,7 @@ where
             num_epochs,
             &self.dataloader_train,
             &mut self.logger_train,
-            &mut |item| self.learner.train(item, &mut self.optimizer),
+            &mut |item| self.learner.train(item.detach()),
         );
 
         run_step(
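
Note on the pattern this patch introduces: `ADModule::inner()` gives the learner a view of the same weights on the plain (non-autodiff) backend, so `valid()` runs without recording a graph, while `Detach` lets the trainer cut training batches loose before each step. Below is a rough, self-contained sketch of that idea; the toy types (`Linear`, `TrackedLinear`) are invented for illustration and are not burn's API.

// Toy stand-ins mirroring the shape of the Module / ADModule traits added in
// src/module/module.rs above; not the real burn definitions.
trait Module {
    fn num_params(&self) -> usize;
}

trait ADModule: Module {
    type InnerModule: Module;

    // Same parameters, viewed without autodiff tracking: the valid() path.
    fn inner(&self) -> Self::InnerModule;
}

// A plain module, playing the role of the "inner backend" view.
#[derive(Clone)]
struct Linear {
    weight: Vec<f32>,
}

impl Module for Linear {
    fn num_params(&self) -> usize {
        self.weight.len()
    }
}

// The autodiff-tracked counterpart used during training.
struct TrackedLinear {
    value: Linear,
}

impl Module for TrackedLinear {
    fn num_params(&self) -> usize {
        self.value.num_params()
    }
}

impl ADModule for TrackedLinear {
    type InnerModule = Linear;

    fn inner(&self) -> Linear {
        // In the patch this maps every Param to its inner value; here we clone.
        self.value.clone()
    }
}

fn main() {
    let model = TrackedLinear { value: Linear { weight: vec![0.0; 4] } };
    let valid_model = model.inner(); // validation path, no graph recorded
    assert_eq!(model.num_params(), valid_model.num_params());
}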