Feat/inner module (#26)

Nathaniel Simard 2022-08-30 18:05:42 -04:00 committed by GitHub
parent 68e052513b
commit 674e078a85
15 changed files with 198 additions and 77 deletions

View File

@ -29,6 +29,7 @@ fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
let state_fn = param.gen_state_fn();
let load_from_parent_fn = param.gen_load_from_parent_fn();
let load_fn = param.gen_load_fn();
let inner_fn = param.gen_inner_fn();
let gen = quote! {
impl #generics burn::module::Module for #name #generics_ty #generics_where {
@ -44,11 +45,17 @@ fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
#load_fn
}
impl #generics burn::module::ADModule for #name #generics_ty where B: burn::tensor::back::ad::Backend, {
type ADBackend=B;
type InnerModule=#name<B::InnerBackend>;
#inner_fn
}
impl #generics std::fmt::Display for #name #generics_ty #generics_where {
#display_fn
}
};
gen.into()
let tokens = gen.into();
tokens
}

View File

@ -49,7 +49,7 @@ impl Param {
}
quote! {
fn update_params<O: burn::optim::Optimizer<B>>(&mut self, grads: &burn::tensor::Gradients, optim: &mut O)
fn update_params<O: burn::optim::Optimizer<Backend = B>>(&mut self, grads: &burn::tensor::Gradients, optim: &mut O)
where
B: burn::tensor::back::ad::Backend {
#body
@ -98,6 +98,30 @@ impl Param {
.into()
}
pub fn gen_inner_fn(&self) -> TokenStream {
let mut body = quote! {};
let mut names = Vec::new();
for field in self.fields.iter() {
let name = field.ident();
names.push(name.clone());
body.extend(quote! {
let #name = self.#name.inner();
});
}
quote! {
fn inner(&self) -> Self::InnerModule {
#body
Self::InnerModule {
#(#names),*
}
}
}
.into()
}
pub fn gen_state_fn(&self) -> TokenStream {
let mut body = quote! {
let mut state = burn::module::State::new(self.name());
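
gen_inner_fn follows the usual quote pattern: accumulate one statement per field, then splice the accumulated body and the field names into the generated method. A standalone sketch of that pattern (assuming proc-macro2 and quote as dependencies; the hard-coded field names stand in for the parsed struct fields):

use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;

// Same shape as gen_inner_fn: one `let <field> = self.<field>.inner();`
// per field, then a struct literal listing every field.
fn gen_inner_fn(field_names: &[Ident]) -> TokenStream {
    let mut body = quote! {};
    for name in field_names {
        body.extend(quote! { let #name = self.#name.inner(); });
    }
    quote! {
        fn inner(&self) -> Self::InnerModule {
            #body
            Self::InnerModule { #(#field_names),* }
        }
    }
}

fn main() {
    // Hypothetical field names; the real macro reads them from the struct.
    let fields = [
        Ident::new("linear", Span::call_site()),
        Ident::new("activation", Span::call_site()),
    ];
    // Prints the generated tokens, i.e. the body the derive macro would emit.
    println!("{}", gen_inner_fn(&fields));
}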

View File

@ -28,6 +28,7 @@ half = { version = "1.6", features = ["num-traits"] } # needs to be 1.6 to work
# Backends
tch = { version = "0.8", optional = true }
lazy_static = "1.4"
ndarray = { version = "0.15", optional = true }
# Autodiff

View File

@ -11,13 +11,13 @@ pub fn softmax<const D: usize, B: Backend>(tensor: &Tensor<B, D>, dim: usize) ->
}
pub fn log_softmax<const D: usize, B: Backend>(tensor: &Tensor<B, D>, dim: usize) -> Tensor<B, D> {
let tensor_tmp = match Precision::Half == B::Elem::precision() {
true => {
let tensor_tmp = match B::Elem::precision() {
Precision::Half => {
let tensor_full = tensor.to_full_precision();
let tensor_tmp = tensor_full.exp().sum_dim(dim).log();
Tensor::from_full_precision(tensor_tmp)
}
false => tensor.exp().sum_dim(dim).log(),
_ => tensor.exp().sum_dim(dim).log(),
};
tensor.sub(&tensor_tmp)
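
The match above only decides whether the reduction runs in full precision; the quantity itself is unchanged. A plain-Rust sketch of log-softmax along one dimension, independent of any backend:

// log_softmax(x)_i = x_i - ln(sum_j exp(x_j)); the Half branch above merely
// runs the exp/sum/log part in full precision before converting back.
fn log_softmax(x: &[f32]) -> Vec<f32> {
    let log_sum_exp = x.iter().map(|v| v.exp()).sum::<f32>().ln();
    x.iter().map(|v| v - log_sum_exp).collect()
}

fn main() {
    println!("{:?}", log_softmax(&[1.0, 2.0, 3.0]));
}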

View File

@ -1,5 +1,11 @@
use crate::tensor::{ops::TensorOpsUtilities, Data, Element, Shape, TensorTrait};
lazy_static::lazy_static! {
static ref NO_GRAD: tch::NoGradGuard = {
tch::no_grad_guard()
};
}
#[derive(Debug, PartialEq)]
pub struct TchTensor<P: tch::kind::Element, const D: usize> {
pub kind: TchKind<P>,
@ -65,6 +71,8 @@ impl<P: tch::kind::Element + Default, const D: usize> TchTensor<P, D> {
let shape_tch = TchShape::from(data.shape);
let kind = TchKind::new();
let tensor = tensor.reshape(&shape_tch.dims).to_kind(kind.kind());
lazy_static::initialize(&NO_GRAD);
let tensor = tensor.set_requires_grad(false);
Self {
@ -81,6 +89,8 @@ impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> T
let device = tch::Device::Cpu;
let kind = TchKind::new();
let tensor = tch::Tensor::empty(&shape_tch.dims, (kind.kind(), device.clone()));
lazy_static::initialize(&NO_GRAD);
let tensor = tensor.set_requires_grad(false);
Self {
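
Both hunks rely on a lazily initialized global guard. A minimal sketch of the pattern with a stand-in guard type in place of tch::NoGradGuard, assuming only the lazy_static crate already listed in Cargo.toml:

// The guard lives in a static, so it is created once and held for the whole
// program; lazy_static::initialize forces construction before first use.
struct NoGradGuard;

fn no_grad_guard() -> NoGradGuard {
    println!("gradient tracking disabled");
    NoGradGuard
}

lazy_static::lazy_static! {
    static ref NO_GRAD: NoGradGuard = no_grad_guard();
}

fn main() {
    lazy_static::initialize(&NO_GRAD);
    // ... build tensors here; gradient tracking stays off by default.
}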

View File

@ -1,5 +1,5 @@
use burn::data::dataloader::batcher::Batcher;
use burn::data::dataloader::DataLoaderBuilder;
use burn::data::dataloader::{DataLoaderBuilder, Detach};
use burn::data::dataset::source::huggingface::{MNISTDataset, MNISTItem};
use burn::module::{Forward, Module, Param};
use burn::nn;
@ -8,8 +8,8 @@ use burn::tensor::af::relu;
use burn::tensor::back::{ad, Backend};
use burn::tensor::losses::cross_entropy_with_logits;
use burn::tensor::{Data, ElementConversion, Shape, Tensor};
use burn::train::logger::{AsyncLogger, CLILogger};
use burn::train::metric::{AccuracyMetric, CUDAMetric, LossMetric, Metric};
use burn::train::logger::{AsyncLogger, CLILogger, TextPlot};
use burn::train::metric::{AccuracyMetric, CUDAMetric, LossMetric};
use burn::train::{ClassificationLearner, ClassificationOutput, SupervisedTrainer};
use std::sync::Arc;
@ -118,7 +118,16 @@ struct MNISTBatch<B: Backend> {
targets: Tensor<B, 2>,
}
impl<B: ad::Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
impl<B: ad::Backend> Detach for MNISTBatch<B> {
fn detach(self) -> Self {
Self {
images: self.images.detach(),
targets: self.targets.detach(),
}
}
}
impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
fn batch(&self, items: Vec<MNISTItem>) -> MNISTBatch<B> {
let images = items
.iter()
@ -133,8 +142,8 @@ impl<B: ad::Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
.map(|item| Tensor::<B, 2>::one_hot(item.label, 10))
.collect();
let images = Tensor::cat(images, 0).to_device(self.device).detach();
let targets = Tensor::cat(targets, 0).to_device(self.device).detach();
let images = Tensor::cat(images, 0).to_device(self.device);
let targets = Tensor::cat(targets, 0).to_device(self.device);
MNISTBatch { images, targets }
}
@ -143,18 +152,11 @@ impl<B: ad::Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
fn run<B: ad::Backend>(device: B::Device) {
let batch_size = 128;
let learning_rate = 5.5e-2;
let num_epochs = 100;
let num_epochs = 10;
let num_workers = 8;
let num_layers = 4;
let hidden_dim = 1024;
let hidden_dim = 3024;
let seed = 42;
let metrics = || -> Vec<Box<dyn Metric<ClassificationOutput<B>>>> {
vec![
Box::new(LossMetric::new()),
Box::new(AccuracyMetric::new()),
Box::new(CUDAMetric::new()),
]
};
let mut model: Model<B> = Model::new(784, hidden_dim, num_layers, 10);
model.to_device(device);
@ -167,25 +169,34 @@ fn run<B: ad::Backend>(device: B::Device) {
);
let optim: SGDOptimizer<B> = SGDOptimizer::new(learning_rate);
let batcher = Arc::new(MNISTBatcher::<B> { device });
let dataloader_train = DataLoaderBuilder::new(batcher.clone())
let batcher_train = Arc::new(MNISTBatcher::<B> { device });
let batcher_valid = Arc::new(MNISTBatcher::<B::InnerBackend> { device });
let dataloader_train = DataLoaderBuilder::new(batcher_train)
.batch_size(batch_size)
.shuffle(seed)
.num_workers(num_workers)
.build(Arc::new(MNISTDataset::train()));
let dataloader_test = DataLoaderBuilder::new(batcher.clone())
let dataloader_test = DataLoaderBuilder::new(batcher_valid)
.batch_size(batch_size)
.num_workers(num_workers)
.build(Arc::new(MNISTDataset::test()));
let learner = ClassificationLearner::new(model);
let learner = ClassificationLearner::new(model, optim);
let logger_train = Box::new(AsyncLogger::new(Box::new(CLILogger::new(
metrics(),
vec![
Box::new(TextPlot::new(LossMetric::new())),
Box::new(AccuracyMetric::new()),
Box::new(CUDAMetric::new()),
],
"Train".to_string(),
))));
let logger_valid = Box::new(AsyncLogger::new(Box::new(CLILogger::new(
metrics(),
vec![
Box::new(TextPlot::new(LossMetric::new())),
Box::new(AccuracyMetric::new()),
Box::new(CUDAMetric::new()),
],
"Valid".to_string(),
))));
@ -195,7 +206,6 @@ fn run<B: ad::Backend>(device: B::Device) {
logger_train,
logger_valid,
learner,
optim,
);
trainer.run(num_epochs);

View File

@ -11,3 +11,7 @@ pub use builder::*;
pub use dataloader::*;
pub use multithread::*;
pub use strategy::*;
pub trait Detach {
fn detach(self) -> Self;
}
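
A sketch of how this hook is used, with a toy batch type; the real implementation for MNISTBatch appears in the example above, and the supervised trainer below calls detach() on each training item:

// Toy batch standing in for a tensor-carrying batch; detach() is where a
// real implementation drops the autodiff graph attached to its tensors.
trait Detach {
    fn detach(self) -> Self;
}

struct Batch {
    images: Vec<f32>,
    tracks_gradients: bool,
}

impl Detach for Batch {
    fn detach(self) -> Self {
        Self { tracks_gradients: false, ..self }
    }
}

fn train_step(batch: Batch) {
    // Mirrors `self.learner.train(item.detach())` in the supervised trainer.
    let batch = batch.detach();
    assert!(!batch.tracks_gradients);
    println!("training on {} values", batch.images.len());
}

fn main() {
    train_step(Batch { images: vec![0.0; 784], tracks_gradients: true });
}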

View File

@ -72,8 +72,11 @@ where
pub trait Module: Send + Sync + std::fmt::Debug + std::fmt::Display {
type Backend: back::Backend;
fn update_params<O: Optimizer<Self::Backend>>(&mut self, grads: &Gradients, optim: &mut O)
where
fn update_params<O: Optimizer<Backend = Self::Backend>>(
&mut self,
grads: &Gradients,
optim: &mut O,
) where
Self::Backend: back::ad::Backend;
fn devices(&self) -> Vec<<Self::Backend as back::Backend>::Device>;
fn to_device(&mut self, device: <Self::Backend as back::Backend>::Device);
@ -93,6 +96,13 @@ pub trait Module: Send + Sync + std::fmt::Debug + std::fmt::Display {
fn num_params(&self) -> usize;
}
pub trait ADModule: Module + Send + Sync + std::fmt::Debug + std::fmt::Display {
type ADBackend: back::ad::Backend;
type InnerModule: Module<Backend = <Self::ADBackend as back::ad::Backend>::InnerBackend>;
fn inner(&self) -> Self::InnerModule;
}
pub trait Forward<In, Out> {
fn forward(&self, input: In) -> Out;
}
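
The relationship ADModule encodes, shown with self-contained stand-in types (a toy NdArray backend, an Ad wrapper, and a toy Linear module instead of the real burn ones): a module over an autodiff backend maps, field by field, to the same module over the backend's inner, autodiff-free counterpart.

// Stand-in traits mirroring the shape of Module/ADModule above.
trait Backend {
    type Elem;
}
trait ADBackend: Backend {
    type InnerBackend: Backend;
}

trait Module {
    type Backend: Backend;
}
trait ADModule: Module {
    type ADBackend: ADBackend;
    type InnerModule: Module<Backend = <Self::ADBackend as ADBackend>::InnerBackend>;
    fn inner(&self) -> Self::InnerModule;
}

// A plain backend and an autodiff wrapper around it.
struct NdArray;
impl Backend for NdArray {
    type Elem = f32;
}
struct Ad<B>(std::marker::PhantomData<B>);
impl<B: Backend> Backend for Ad<B> {
    type Elem = B::Elem;
}
impl<B: Backend> ADBackend for Ad<B> {
    type InnerBackend = B;
}

// A toy module; the derive macro generates impls of this shape automatically.
struct Linear<B: Backend> {
    weight: Vec<B::Elem>,
}
impl<B: Backend> Module for Linear<B> {
    type Backend = B;
}
impl<B: Backend> ADModule for Linear<Ad<B>>
where
    B::Elem: Clone,
{
    type ADBackend = Ad<B>;
    type InnerModule = Linear<B>;
    fn inner(&self) -> Linear<B> {
        // Field by field, exactly the shape gen_inner_fn emits.
        Linear { weight: self.weight.clone() }
    }
}

fn main() {
    let ad_model: Linear<Ad<NdArray>> = Linear { weight: vec![0.5; 4] };
    let valid_model: Linear<NdArray> = ad_model.inner();
    println!("{}", valid_model.weight.len());
}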

View File

@ -1,4 +1,4 @@
use crate::module::{Module, State};
use crate::module::{ADModule, Module, State};
use crate::optim::Optimizer;
use crate::tensor::{back, Gradients, Tensor};
use serde::de::DeserializeOwned;
@ -28,7 +28,7 @@ impl<const D: usize, B: back::Backend> Param<Tensor<B, D>> {
self.value.shape().num_elements()
}
pub fn update_params<O: Optimizer<B>>(&mut self, grads: &Gradients, optim: &mut O)
pub fn update_params<O: Optimizer<Backend = B>>(&mut self, grads: &Gradients, optim: &mut O)
where
B: back::ad::Backend,
{
@ -60,6 +60,13 @@ impl<const D: usize, B: back::Backend> Param<Tensor<B, D>> {
let data = state.get(name);
self.value = Tensor::from_data_device(data, self.value.device());
}
pub fn inner(&self) -> Param<Tensor<B::InnerBackend, D>>
where
B: back::ad::Backend,
{
Param::new(self.value.inner())
}
}
impl<const D: usize, B: back::Backend> Param<Option<Tensor<B, D>>> {
@ -71,7 +78,7 @@ impl<const D: usize, B: back::Backend> Param<Option<Tensor<B, D>>> {
0
}
pub fn update_params<O: Optimizer<B>>(&mut self, grads: &Gradients, optim: &mut O)
pub fn update_params<O: Optimizer<Backend = B>>(&mut self, grads: &Gradients, optim: &mut O)
where
B: back::ad::Backend,
{
@ -118,6 +125,16 @@ impl<const D: usize, B: back::Backend> Param<Option<Tensor<B, D>>> {
self.value = value;
}
pub fn inner(&self) -> Param<Option<Tensor<B::InnerBackend, D>>>
where
B: back::ad::Backend,
{
match &self.value {
Some(tensor) => Param::new(Some(tensor.inner())),
None => Param::new(None),
}
}
}
impl<M: Module> Param<M> {
@ -125,8 +142,11 @@ impl<M: Module> Param<M> {
self.value.num_params()
}
pub fn update_params<O: Optimizer<M::Backend>>(&mut self, grads: &Gradients, optim: &mut O)
where
pub fn update_params<O: Optimizer<Backend = M::Backend>>(
&mut self,
grads: &Gradients,
optim: &mut O,
) where
M::Backend: back::ad::Backend,
{
self.value.update_params(grads, optim);
@ -157,6 +177,14 @@ impl<M: Module> Param<M> {
{
self.value.load_from_parent(name, state);
}
pub fn inner(&self) -> Param<M::InnerModule>
where
M: ADModule,
M::Backend: back::ad::Backend,
{
Param::new(self.value.inner())
}
}
impl<M: Module> Param<Vec<M>> {
@ -169,8 +197,11 @@ impl<M: Module> Param<Vec<M>> {
num_params
}
pub fn update_params<O: Optimizer<M::Backend>>(&mut self, grads: &Gradients, optim: &mut O)
where
pub fn update_params<O: Optimizer<Backend = M::Backend>>(
&mut self,
grads: &Gradients,
optim: &mut O,
) where
M::Backend: back::ad::Backend,
{
for module in self.value.iter_mut() {
@ -209,4 +240,12 @@ impl<M: Module> Param<Vec<M>> {
{
todo!();
}
pub fn inner(&self) -> Param<Vec<M::InnerModule>>
where
M: ADModule,
M::Backend: back::ad::Backend,
{
Param::new(self.value.iter().map(|v| v.inner()).collect())
}
}

View File

@ -1,6 +1,8 @@
use crate::tensor::back::ad::Backend;
use crate::tensor::{Gradients, Tensor};
pub trait Optimizer<B: Backend>: Send + Sync {
fn update<const D: usize>(&mut self, tensor: &mut Tensor<B, D>, grads: &Gradients);
pub trait Optimizer: Send + Sync {
type Backend: Backend;
fn update<const D: usize>(&mut self, tensor: &mut Tensor<Self::Backend, D>, grads: &Gradients);
}
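
The same migration in isolation, with stand-in types: moving the backend from a generic parameter to an associated type means each optimizer names exactly one backend, and callers write the bound O: Optimizer<Backend = B>, as the update_params signatures elsewhere in this diff now do.

// Stand-ins for Backend and a gradient-free SGD, only to show the bound shape.
trait Backend {
    type Elem;
}
struct NdArray;
impl Backend for NdArray {
    type Elem = f32;
}

trait Optimizer: Send + Sync {
    type Backend: Backend;
    fn update(&mut self, param: &mut Vec<<Self::Backend as Backend>::Elem>);
}

struct Sgd {
    learning_rate: f32,
}
impl Optimizer for Sgd {
    type Backend = NdArray;
    fn update(&mut self, param: &mut Vec<f32>) {
        // Gradients are omitted; a real step would subtract lr * grad.
        for p in param.iter_mut() {
            *p -= self.learning_rate;
        }
    }
}

// The caller-side bound, matching `O: Optimizer<Backend = B>` in the diff.
fn step<B, O>(optim: &mut O, param: &mut Vec<B::Elem>)
where
    B: Backend,
    O: Optimizer<Backend = B>,
{
    optim.update(param);
}

fn main() {
    let mut optim = Sgd { learning_rate: 0.1 };
    let mut weights = vec![1.0_f32; 3];
    step::<NdArray, _>(&mut optim, &mut weights);
    println!("{weights:?}");
}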

View File

@ -14,7 +14,9 @@ impl<B: back::ad::Backend> SGDOptimizer<B> {
Self { learning_rate }
}
}
impl<B: back::ad::Backend> Optimizer<B> for SGDOptimizer<B> {
impl<B: back::ad::Backend> Optimizer for SGDOptimizer<B> {
type Backend = B;
fn update<const D: usize>(&mut self, tensor: &mut Tensor<B, D>, grads: &Gradients) {
let grad = tensor.grad(&grads).unwrap();
let delta = grad.mul_scalar(&self.learning_rate);

View File

@ -1,12 +1,14 @@
use super::{Learner, Loss};
use crate::module::ADModule;
use crate::optim::Optimizer;
use crate::tensor::back::{ad, Backend};
use crate::train::metric;
use burn_tensor::Tensor;
#[derive(new)]
pub struct BasicLearner<L> {
pub struct BasicLearner<L, O> {
model: L,
optim: O,
}
#[derive(new)]
@ -23,23 +25,28 @@ impl<B: Backend> metric::Metric<BasicOutput<B>> for metric::LossMetric {
}
}
impl<B, T, L, O> Learner<B, T, T, O, BasicOutput<B>, BasicOutput<B>> for BasicLearner<L>
impl<B, B2, T, L, L2, O> Learner<T, T, BasicOutput<B>, BasicOutput<B2>> for BasicLearner<L, O>
where
B: ad::Backend,
L: Loss<B, T>,
O: Optimizer<B>,
B: ad::Backend<InnerBackend = B2>,
B2: Backend,
O: Optimizer<Backend = B>,
L: Loss<Backend = B, Item = T> + ADModule<Backend = B, InnerModule = L2>,
L2: Loss<Backend = B::InnerBackend, Item = T>,
O: Optimizer<Backend = B>,
{
fn train(&mut self, item: T, optim: &mut O) -> BasicOutput<B> {
type Backend = B;
fn train(&mut self, item: T) -> BasicOutput<B> {
let loss = self.model.loss(item);
let grads = loss.backward();
self.model.update_params(&grads, optim);
self.model.update_params(&grads, &mut self.optim);
BasicOutput::new(loss)
}
fn valid(&self, item: T) -> BasicOutput<B> {
let loss = self.model.loss(item);
fn valid(&self, item: T) -> BasicOutput<B2> {
let loss = self.model.inner().loss(item);
BasicOutput::new(loss)
}
}

View File

@ -1,13 +1,14 @@
use super::Learner;
use crate::module::{Forward, Module};
use crate::module::{ADModule, Forward, Module};
use crate::optim::Optimizer;
use crate::tensor::back::{ad, Backend};
use crate::train::metric;
use burn_tensor::Tensor;
#[derive(new)]
pub struct ClassificationLearner<M> {
pub struct ClassificationLearner<M, O> {
model: M,
optim: O,
}
#[derive(new)]
@ -36,23 +37,27 @@ impl<B: Backend> metric::Metric<ClassificationOutput<B>> for metric::AccuracyMet
}
}
impl<B, I, M, O> Learner<B, I, I, O, ClassificationOutput<B>, ClassificationOutput<B>>
for ClassificationLearner<M>
impl<B, I, IV, M, M2, O>
Learner<I, IV, ClassificationOutput<B>, ClassificationOutput<B::InnerBackend>>
for ClassificationLearner<M, O>
where
B: ad::Backend,
M: Forward<I, ClassificationOutput<B>> + Module<Backend = B>,
O: Optimizer<B>,
M: Forward<I, ClassificationOutput<B>> + ADModule<Backend = B, InnerModule = M2>,
M2: Forward<IV, ClassificationOutput<B::InnerBackend>> + Module<Backend = B::InnerBackend>,
O: Optimizer<Backend = B>,
{
fn train(&mut self, item: I, optim: &mut O) -> ClassificationOutput<B> {
type Backend = B;
fn train(&mut self, item: I) -> ClassificationOutput<B> {
let output = self.model.forward(item);
let grads = output.loss.backward();
self.model.update_params(&grads, optim);
self.model.update_params(&grads, &mut self.optim);
output
}
fn valid(&self, item: I) -> ClassificationOutput<B> {
self.model.forward(item)
fn valid(&self, item: IV) -> ClassificationOutput<B::InnerBackend> {
self.model.inner().forward(item)
}
}

View File

@ -2,11 +2,15 @@ use crate::module::Module;
use crate::tensor::back::Backend;
use burn_tensor::Tensor;
pub trait Loss<B: Backend, T>: Module<Backend = B> {
fn loss(&self, item: T) -> Tensor<B, 1>;
pub trait Loss: Module {
type Item;
fn loss(&self, item: Self::Item) -> Tensor<Self::Backend, 1>;
}
pub trait Learner<B: Backend, T, V, O, TO, VO> {
fn train(&mut self, item: T, optim: &mut O) -> TO;
pub trait Learner<T, V, TO, VO> {
type Backend: Backend;
fn train(&mut self, item: T) -> TO;
fn valid(&self, item: V) -> VO;
}

View File

@ -1,30 +1,28 @@
use super::{Learner, TrainerItem};
use crate::data::dataloader::DataLoader;
use crate::optim::Optimizer;
use crate::tensor::back::ad;
use crate::data::dataloader::Detach;
use crate::tensor::back;
use crate::train::logger::Logger;
use std::sync::Arc;
pub struct SupervisedTrainer<B, T, V, L, O, TO, VO>
pub struct SupervisedTrainer<B, T, V, L, TO, VO>
where
B: ad::Backend,
L: Learner<B, T, V, O, TO, VO>,
O: Optimizer<B>,
B: back::ad::Backend,
L: Learner<T, V, TO, VO, Backend = B>,
{
dataloader_train: Arc<dyn DataLoader<T>>,
dataloader_valid: Arc<dyn DataLoader<V>>,
logger_train: Box<dyn Logger<TrainerItem<TO>>>,
logger_valid: Box<dyn Logger<TrainerItem<VO>>>,
learner: L,
optimizer: O,
_b: B,
}
impl<B, T, V, L, O, TO, VO> SupervisedTrainer<B, T, V, L, O, TO, VO>
impl<B, T, V, L, TO, VO> SupervisedTrainer<B, T, V, L, TO, VO>
where
B: ad::Backend,
L: Learner<B, T, V, O, TO, VO>,
O: Optimizer<B>,
B: back::ad::Backend,
T: Detach,
L: Learner<T, V, TO, VO, Backend = B>,
{
pub fn new(
dataloader_train: Arc<dyn DataLoader<T>>,
@ -32,13 +30,11 @@ where
logger_train: Box<dyn Logger<TrainerItem<TO>>>,
logger_valid: Box<dyn Logger<TrainerItem<VO>>>,
learner: L,
optimizer: O,
) -> Self {
Self {
dataloader_train,
dataloader_valid,
learner,
optimizer,
logger_train,
logger_valid,
_b: B::default(),
@ -52,7 +48,7 @@ where
num_epochs,
&self.dataloader_train,
&mut self.logger_train,
&mut |item| self.learner.train(item, &mut self.optimizer),
&mut |item| self.learner.train(item.detach()),
);
run_step(