mirror of https://github.com/tracel-ai/burn.git
Feat/inner module (#26)
This commit is contained in:
parent 68e052513b
commit 674e078a85
@@ -29,6 +29,7 @@ fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
     let state_fn = param.gen_state_fn();
     let load_from_parent_fn = param.gen_load_from_parent_fn();
     let load_fn = param.gen_load_fn();
+    let inner_fn = param.gen_inner_fn();

     let gen = quote! {
         impl #generics burn::module::Module for #name #generics_ty #generics_where {
@@ -44,11 +45,17 @@ fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
             #load_fn
         }

+        impl #generics burn::module::ADModule for #name #generics_ty where B: burn::tensor::back::ad::Backend, {
+            type ADBackend=B;
+            type InnerModule=#name<B::InnerBackend>;
+
+            #inner_fn
+        }
+
         impl #generics std::fmt::Display for #name #generics_ty #generics_where {
             #display_fn
         }
     };

-    gen.into()
+    let tokens = gen.into();
+    tokens
 }
@@ -49,7 +49,7 @@ impl Param {
         }

         quote! {
-            fn update_params<O: burn::optim::Optimizer<B>>(&mut self, grads: &burn::tensor::Gradients, optim: &mut O)
+            fn update_params<O: burn::optim::Optimizer<Backend = B>>(&mut self, grads: &burn::tensor::Gradients, optim: &mut O)
             where
                 B: burn::tensor::back::ad::Backend {
                 #body
@@ -98,6 +98,30 @@ impl Param {
         .into()
     }

+    pub fn gen_inner_fn(&self) -> TokenStream {
+        let mut body = quote! {};
+        let mut names = Vec::new();
+
+        for field in self.fields.iter() {
+            let name = field.ident();
+            names.push(name.clone());
+
+            body.extend(quote! {
+                let #name = self.#name.inner();
+            });
+        }
+
+        quote! {
+            fn inner(&self) -> Self::InnerModule {
+                #body
+
+                Self::InnerModule {
+                    #(#names),*
+                }
+            }
+        }
+        .into()
+    }
+
     pub fn gen_state_fn(&self) -> TokenStream {
         let mut body = quote! {
             let mut state = burn::module::State::new(self.name());
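gen_inner_fn above builds the method body one field at a time and then splices the collected field names back into a struct literal. Below is a minimal standalone sketch of that quote!/TokenStream pattern, run outside any real derive macro; it assumes the proc-macro2 and quote crates and uses made-up field names purely for illustration.

use proc_macro2::TokenStream;
use quote::{format_ident, quote};

fn main() {
    let mut body = quote! {};
    let mut names = Vec::new();

    // One `let <field> = self.<field>.inner();` statement per field,
    // mirroring the loop in gen_inner_fn.
    for field in ["linear", "dropout"] {
        let name = format_ident!("{}", field);
        names.push(name.clone());
        body.extend(quote! {
            let #name = self.#name.inner();
        });
    }

    // Splice the statements and the field names into the generated method.
    let generated: TokenStream = quote! {
        fn inner(&self) -> Self::InnerModule {
            #body
            Self::InnerModule { #(#names),* }
        }
    };
    println!("{}", generated);
}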
@@ -28,6 +28,7 @@ half = { version = "1.6", features = ["num-traits"] } # needs to be 1.6 to work

 # Backends
 tch = { version = "0.8", optional = true }
+lazy_static = "1.4"
 ndarray = { version = "0.15", optional = true }

 # Autodiff
@@ -11,13 +11,13 @@ pub fn softmax<const D: usize, B: Backend>(tensor: &Tensor<B, D>, dim: usize) ->
 }

 pub fn log_softmax<const D: usize, B: Backend>(tensor: &Tensor<B, D>, dim: usize) -> Tensor<B, D> {
-    let tensor_tmp = match Precision::Half == B::Elem::precision() {
-        true => {
+    let tensor_tmp = match B::Elem::precision() {
+        Precision::Half => {
             let tensor_full = tensor.to_full_precision();
             let tensor_tmp = tensor_full.exp().sum_dim(dim).log();
             Tensor::from_full_precision(tensor_tmp)
         }
-        false => tensor.exp().sum_dim(dim).log(),
+        _ => tensor.exp().sum_dim(dim).log(),
     };

     tensor.sub(&tensor_tmp)
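For reference, log_softmax above computes x - log(sum_j exp(x_j)) along the given dimension, upcasting to full precision when the element type is half precision. A standalone scalar sketch of the same formula over a plain slice, independent of the burn tensor API:

// Naive log-softmax over a slice: x_i - ln(sum_j exp(x_j)).
// (A numerically careful version would subtract max(x) first.)
fn log_softmax(xs: &[f32]) -> Vec<f32> {
    let log_sum_exp = xs.iter().map(|x| x.exp()).sum::<f32>().ln();
    xs.iter().map(|x| x - log_sum_exp).collect()
}

fn main() {
    println!("{:?}", log_softmax(&[1.0, 2.0, 3.0]));
}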
@@ -1,5 +1,11 @@
 use crate::tensor::{ops::TensorOpsUtilities, Data, Element, Shape, TensorTrait};

+lazy_static::lazy_static! {
+    static ref NO_GRAD: tch::NoGradGuard = {
+        tch::no_grad_guard()
+    };
+}
+
 #[derive(Debug, PartialEq)]
 pub struct TchTensor<P: tch::kind::Element, const D: usize> {
     pub kind: TchKind<P>,
@@ -65,6 +71,8 @@ impl<P: tch::kind::Element + Default, const D: usize> TchTensor<P, D> {
         let shape_tch = TchShape::from(data.shape);
         let kind = TchKind::new();
         let tensor = tensor.reshape(&shape_tch.dims).to_kind(kind.kind());

+        lazy_static::initialize(&NO_GRAD);
+        let tensor = tensor.set_requires_grad(false);

         Self {
@@ -81,6 +89,8 @@ impl<P: tch::kind::Element + Default + Copy + std::fmt::Debug, const D: usize> T
         let device = tch::Device::Cpu;
         let kind = TchKind::new();
         let tensor = tch::Tensor::empty(&shape_tch.dims, (kind.kind(), device.clone()));

+        lazy_static::initialize(&NO_GRAD);
+        let tensor = tensor.set_requires_grad(false);

         Self {
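Both constructors above force the NO_GRAD guard with lazy_static::initialize before turning requires_grad off. A minimal sketch of that lazy_static pattern, using a plain placeholder value instead of tch::NoGradGuard (assumed here only for illustration; it relies on the lazy_static crate added in the Cargo.toml hunk):

lazy_static::lazy_static! {
    // Initialized once, on first access, and kept alive for the process lifetime.
    static ref GUARD: String = String::from("no-grad guard stand-in");
}

fn main() {
    // Force initialization eagerly, as the tensor constructors above do,
    // so the guard exists before any tensor work happens.
    lazy_static::initialize(&GUARD);
    println!("{}", *GUARD);
}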
@@ -1,5 +1,5 @@
 use burn::data::dataloader::batcher::Batcher;
-use burn::data::dataloader::DataLoaderBuilder;
+use burn::data::dataloader::{DataLoaderBuilder, Detach};
 use burn::data::dataset::source::huggingface::{MNISTDataset, MNISTItem};
 use burn::module::{Forward, Module, Param};
 use burn::nn;
@@ -8,8 +8,8 @@ use burn::tensor::af::relu;
 use burn::tensor::back::{ad, Backend};
 use burn::tensor::losses::cross_entropy_with_logits;
 use burn::tensor::{Data, ElementConversion, Shape, Tensor};
-use burn::train::logger::{AsyncLogger, CLILogger};
-use burn::train::metric::{AccuracyMetric, CUDAMetric, LossMetric, Metric};
+use burn::train::logger::{AsyncLogger, CLILogger, TextPlot};
+use burn::train::metric::{AccuracyMetric, CUDAMetric, LossMetric};
 use burn::train::{ClassificationLearner, ClassificationOutput, SupervisedTrainer};
 use std::sync::Arc;

@@ -118,7 +118,16 @@ struct MNISTBatch<B: Backend> {
     targets: Tensor<B, 2>,
 }

-impl<B: ad::Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
+impl<B: ad::Backend> Detach for MNISTBatch<B> {
+    fn detach(self) -> Self {
+        Self {
+            images: self.images.detach(),
+            targets: self.targets.detach(),
+        }
+    }
+}
+
+impl<B: Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
     fn batch(&self, items: Vec<MNISTItem>) -> MNISTBatch<B> {
         let images = items
             .iter()
@@ -133,8 +142,8 @@ impl<B: ad::Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
             .map(|item| Tensor::<B, 2>::one_hot(item.label, 10))
             .collect();

-        let images = Tensor::cat(images, 0).to_device(self.device).detach();
-        let targets = Tensor::cat(targets, 0).to_device(self.device).detach();
+        let images = Tensor::cat(images, 0).to_device(self.device);
+        let targets = Tensor::cat(targets, 0).to_device(self.device);

         MNISTBatch { images, targets }
     }
@@ -143,18 +152,11 @@ impl<B: ad::Backend> Batcher<MNISTItem, MNISTBatch<B>> for MNISTBatcher<B> {
 fn run<B: ad::Backend>(device: B::Device) {
     let batch_size = 128;
     let learning_rate = 5.5e-2;
-    let num_epochs = 100;
+    let num_epochs = 10;
     let num_workers = 8;
     let num_layers = 4;
-    let hidden_dim = 1024;
+    let hidden_dim = 3024;
     let seed = 42;
-    let metrics = || -> Vec<Box<dyn Metric<ClassificationOutput<B>>>> {
-        vec![
-            Box::new(LossMetric::new()),
-            Box::new(AccuracyMetric::new()),
-            Box::new(CUDAMetric::new()),
-        ]
-    };

     let mut model: Model<B> = Model::new(784, hidden_dim, num_layers, 10);
     model.to_device(device);
@@ -167,25 +169,34 @@ fn run<B: ad::Backend>(device: B::Device) {
     );

     let optim: SGDOptimizer<B> = SGDOptimizer::new(learning_rate);
-    let batcher = Arc::new(MNISTBatcher::<B> { device });
-    let dataloader_train = DataLoaderBuilder::new(batcher.clone())
+    let batcher_train = Arc::new(MNISTBatcher::<B> { device });
+    let batcher_valid = Arc::new(MNISTBatcher::<B::InnerBackend> { device });
+    let dataloader_train = DataLoaderBuilder::new(batcher_train)
         .batch_size(batch_size)
         .shuffle(seed)
         .num_workers(num_workers)
         .build(Arc::new(MNISTDataset::train()));
-    let dataloader_test = DataLoaderBuilder::new(batcher.clone())
+    let dataloader_test = DataLoaderBuilder::new(batcher_valid)
         .batch_size(batch_size)
         .num_workers(num_workers)
         .build(Arc::new(MNISTDataset::test()));

-    let learner = ClassificationLearner::new(model);
+    let learner = ClassificationLearner::new(model, optim);

     let logger_train = Box::new(AsyncLogger::new(Box::new(CLILogger::new(
-        metrics(),
+        vec![
+            Box::new(TextPlot::new(LossMetric::new())),
+            Box::new(AccuracyMetric::new()),
+            Box::new(CUDAMetric::new()),
+        ],
         "Train".to_string(),
     ))));
     let logger_valid = Box::new(AsyncLogger::new(Box::new(CLILogger::new(
-        metrics(),
+        vec![
+            Box::new(TextPlot::new(LossMetric::new())),
+            Box::new(AccuracyMetric::new()),
+            Box::new(CUDAMetric::new()),
+        ],
         "Valid".to_string(),
     ))));

@@ -195,7 +206,6 @@ fn run<B: ad::Backend>(device: B::Device) {
         logger_train,
         logger_valid,
         learner,
-        optim,
     );

     trainer.run(num_epochs);
@@ -11,3 +11,7 @@ pub use builder::*;
 pub use dataloader::*;
 pub use multithread::*;
 pub use strategy::*;
+
+pub trait Detach {
+    fn detach(self) -> Self;
+}
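MNISTBatch's impl in the example above is the real use of this new trait; the stand-in below only sketches the shape of a Detach impl for a toy batch type (placeholder fields, not burn's tensors).

pub trait Detach {
    fn detach(self) -> Self;
}

#[derive(Debug)]
struct Batch {
    images: Vec<f32>,
}

impl Detach for Batch {
    // A real tensor batch would drop its autodiff graph here, as
    // Tensor::detach does in the MNIST batcher; a plain Vec has nothing to drop.
    fn detach(self) -> Self {
        self
    }
}

fn main() {
    let batch = Batch { images: vec![0.0; 4] }.detach();
    println!("{:?}", batch);
}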
@@ -72,8 +72,11 @@ where
 pub trait Module: Send + Sync + std::fmt::Debug + std::fmt::Display {
     type Backend: back::Backend;

-    fn update_params<O: Optimizer<Self::Backend>>(&mut self, grads: &Gradients, optim: &mut O)
-    where
+    fn update_params<O: Optimizer<Backend = Self::Backend>>(
+        &mut self,
+        grads: &Gradients,
+        optim: &mut O,
+    ) where
         Self::Backend: back::ad::Backend;
     fn devices(&self) -> Vec<<Self::Backend as back::Backend>::Device>;
     fn to_device(&mut self, device: <Self::Backend as back::Backend>::Device);
@@ -93,6 +96,13 @@ pub trait Module: Send + Sync + std::fmt::Debug + std::fmt::Display {
     fn num_params(&self) -> usize;
 }

+pub trait ADModule: Module + Send + Sync + std::fmt::Debug + std::fmt::Display {
+    type ADBackend: back::ad::Backend;
+    type InnerModule: Module<Backend = <Self::ADBackend as back::ad::Backend>::InnerBackend>;
+
+    fn inner(&self) -> Self::InnerModule;
+}
+
 pub trait Forward<In, Out> {
     fn forward(&self, input: In) -> Out;
 }
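A self-contained sketch of the ADModule relationship introduced above, using simplified stand-in traits rather than burn's real Backend/Module definitions: an autodiff-capable model exposes inner(), returning the same model typed over the gradient-free inner backend. (In the real trait, InnerModule's backend is additionally tied to the ADBackend's InnerBackend.)

trait Backend {}

trait AdBackend: Backend {
    type InnerBackend: Backend;
}

trait Module {
    type Backend: Backend;
}

trait AdModule: Module {
    type InnerModule: Module;
    fn inner(&self) -> Self::InnerModule;
}

// Two placeholder backends: one tracking gradients, one plain.
struct NdArray;
impl Backend for NdArray {}

struct Autodiff;
impl Backend for Autodiff {}
impl AdBackend for Autodiff {
    type InnerBackend = NdArray;
}

// A model generic over the backend; inner() strips the autodiff layer.
struct Model<B: Backend> {
    _backend: std::marker::PhantomData<B>,
}

impl<B: Backend> Module for Model<B> {
    type Backend = B;
}

impl AdModule for Model<Autodiff> {
    type InnerModule = Model<NdArray>;

    fn inner(&self) -> Self::InnerModule {
        Model { _backend: std::marker::PhantomData }
    }
}

fn main() {
    let model = Model::<Autodiff> { _backend: std::marker::PhantomData };
    let _valid_model: Model<NdArray> = model.inner();
}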
@@ -1,4 +1,4 @@
-use crate::module::{Module, State};
+use crate::module::{ADModule, Module, State};
 use crate::optim::Optimizer;
 use crate::tensor::{back, Gradients, Tensor};
 use serde::de::DeserializeOwned;
@@ -28,7 +28,7 @@ impl<const D: usize, B: back::Backend> Param<Tensor<B, D>> {
         self.value.shape().num_elements()
     }

-    pub fn update_params<O: Optimizer<B>>(&mut self, grads: &Gradients, optim: &mut O)
+    pub fn update_params<O: Optimizer<Backend = B>>(&mut self, grads: &Gradients, optim: &mut O)
     where
         B: back::ad::Backend,
     {
@@ -60,6 +60,13 @@ impl<const D: usize, B: back::Backend> Param<Tensor<B, D>> {
         let data = state.get(name);
         self.value = Tensor::from_data_device(data, self.value.device());
     }
+
+    pub fn inner(&self) -> Param<Tensor<B::InnerBackend, D>>
+    where
+        B: back::ad::Backend,
+    {
+        Param::new(self.value.inner())
+    }
 }

 impl<const D: usize, B: back::Backend> Param<Option<Tensor<B, D>>> {
@@ -71,7 +78,7 @@ impl<const D: usize, B: back::Backend> Param<Option<Tensor<B, D>>> {
         0
     }

-    pub fn update_params<O: Optimizer<B>>(&mut self, grads: &Gradients, optim: &mut O)
+    pub fn update_params<O: Optimizer<Backend = B>>(&mut self, grads: &Gradients, optim: &mut O)
     where
         B: back::ad::Backend,
     {
@@ -118,6 +125,16 @@ impl<const D: usize, B: back::Backend> Param<Option<Tensor<B, D>>> {

         self.value = value;
     }
+
+    pub fn inner(&self) -> Param<Option<Tensor<B::InnerBackend, D>>>
+    where
+        B: back::ad::Backend,
+    {
+        match &self.value {
+            Some(tensor) => Param::new(Some(tensor.inner())),
+            None => Param::new(None),
+        }
+    }
 }

 impl<M: Module> Param<M> {
@@ -125,8 +142,11 @@ impl<M: Module> Param<M> {
         self.value.num_params()
     }

-    pub fn update_params<O: Optimizer<M::Backend>>(&mut self, grads: &Gradients, optim: &mut O)
-    where
+    pub fn update_params<O: Optimizer<Backend = M::Backend>>(
+        &mut self,
+        grads: &Gradients,
+        optim: &mut O,
+    ) where
         M::Backend: back::ad::Backend,
     {
         self.value.update_params(grads, optim);
@@ -157,6 +177,14 @@ impl<M: Module> Param<M> {
     {
         self.value.load_from_parent(name, state);
     }
+
+    pub fn inner(&self) -> Param<M::InnerModule>
+    where
+        M: ADModule,
+        M::Backend: back::ad::Backend,
+    {
+        Param::new(self.value.inner())
+    }
 }

 impl<M: Module> Param<Vec<M>> {
@@ -169,8 +197,11 @@ impl<M: Module> Param<Vec<M>> {
         num_params
     }

-    pub fn update_params<O: Optimizer<M::Backend>>(&mut self, grads: &Gradients, optim: &mut O)
-    where
+    pub fn update_params<O: Optimizer<Backend = M::Backend>>(
+        &mut self,
+        grads: &Gradients,
+        optim: &mut O,
+    ) where
         M::Backend: back::ad::Backend,
     {
         for module in self.value.iter_mut() {
@@ -209,4 +240,12 @@ impl<M: Module> Param<Vec<M>> {
     {
         todo!();
     }
+
+    pub fn inner(&self) -> Param<Vec<M::InnerModule>>
+    where
+        M: ADModule,
+        M::Backend: back::ad::Backend,
+    {
+        Param::new(self.value.iter().map(|v| v.inner()).collect())
+    }
 }
@@ -1,6 +1,8 @@
 use crate::tensor::back::ad::Backend;
 use crate::tensor::{Gradients, Tensor};

-pub trait Optimizer<B: Backend>: Send + Sync {
-    fn update<const D: usize>(&mut self, tensor: &mut Tensor<B, D>, grads: &Gradients);
+pub trait Optimizer: Send + Sync {
+    type Backend: Backend;
+
+    fn update<const D: usize>(&mut self, tensor: &mut Tensor<Self::Backend, D>, grads: &Gradients);
 }
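The hunk above moves the backend from a trait type parameter to an associated type, so bounds throughout the crate read O: Optimizer<Backend = B> instead of O: Optimizer<B>. A small standalone sketch of that refactor with placeholder types (not the real burn traits):

trait Backend {}

struct NdArray;
impl Backend for NdArray {}

// Before: the backend was a generic parameter of the trait, so every bound
// had to repeat it positionally, e.g. `O: Optimizer<B>`.
trait OptimizerOld<B: Backend>: Send + Sync {
    fn step(&mut self);
}

// After: the backend is an associated type, and bounds become
// `O: Optimizer<Backend = B>`, matching the update_params signatures above.
trait Optimizer: Send + Sync {
    type Backend: Backend;
    fn step(&mut self);
}

struct Sgd {
    learning_rate: f32,
}

impl Optimizer for Sgd {
    type Backend = NdArray;

    fn step(&mut self) {
        // A real optimizer would apply `-learning_rate * grad` to each tensor here.
        let _ = self.learning_rate;
    }
}

fn run<O: Optimizer<Backend = NdArray>>(optim: &mut O) {
    optim.step();
}

fn main() {
    let mut sgd = Sgd { learning_rate: 0.05 };
    run(&mut sgd);
}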
@@ -14,7 +14,9 @@ impl<B: back::ad::Backend> SGDOptimizer<B> {
         Self { learning_rate }
     }
 }
-impl<B: back::ad::Backend> Optimizer<B> for SGDOptimizer<B> {
+impl<B: back::ad::Backend> Optimizer for SGDOptimizer<B> {
+    type Backend = B;
+
     fn update<const D: usize>(&mut self, tensor: &mut Tensor<B, D>, grads: &Gradients) {
         let grad = tensor.grad(&grads).unwrap();
         let delta = grad.mul_scalar(&self.learning_rate);
@@ -1,12 +1,14 @@
 use super::{Learner, Loss};
+use crate::module::ADModule;
 use crate::optim::Optimizer;
 use crate::tensor::back::{ad, Backend};
 use crate::train::metric;
 use burn_tensor::Tensor;

 #[derive(new)]
-pub struct BasicLearner<L> {
+pub struct BasicLearner<L, O> {
     model: L,
+    optim: O,
 }

 #[derive(new)]
@@ -23,23 +25,28 @@ impl<B: Backend> metric::Metric<BasicOutput<B>> for metric::LossMetric {
     }
 }

-impl<B, T, L, O> Learner<B, T, T, O, BasicOutput<B>, BasicOutput<B>> for BasicLearner<L>
+impl<B, B2, T, L, L2, O> Learner<T, T, BasicOutput<B>, BasicOutput<B2>> for BasicLearner<L, O>
 where
-    B: ad::Backend,
-    L: Loss<B, T>,
-    O: Optimizer<B>,
+    B: ad::Backend<InnerBackend = B2>,
+    B2: Backend,
+    O: Optimizer<Backend = B>,
+    L: Loss<Backend = B, Item = T> + ADModule<Backend = B, InnerModule = L2>,
+    L2: Loss<Backend = B::InnerBackend, Item = T>,
+    O: Optimizer<Backend = B>,
 {
-    fn train(&mut self, item: T, optim: &mut O) -> BasicOutput<B> {
+    type Backend = B;
+
+    fn train(&mut self, item: T) -> BasicOutput<B> {
         let loss = self.model.loss(item);
         let grads = loss.backward();

-        self.model.update_params(&grads, optim);
+        self.model.update_params(&grads, &mut self.optim);

         BasicOutput::new(loss)
     }

-    fn valid(&self, item: T) -> BasicOutput<B> {
-        let loss = self.model.loss(item);
+    fn valid(&self, item: T) -> BasicOutput<B2> {
+        let loss = self.model.inner().loss(item);
         BasicOutput::new(loss)
     }
 }
@@ -1,13 +1,14 @@
 use super::Learner;
-use crate::module::{Forward, Module};
+use crate::module::{ADModule, Forward, Module};
 use crate::optim::Optimizer;
 use crate::tensor::back::{ad, Backend};
 use crate::train::metric;
 use burn_tensor::Tensor;

 #[derive(new)]
-pub struct ClassificationLearner<M> {
+pub struct ClassificationLearner<M, O> {
     model: M,
+    optim: O,
 }

 #[derive(new)]
@@ -36,23 +37,27 @@ impl<B: Backend> metric::Metric<ClassificationOutput<B>> for metric::AccuracyMet
     }
 }

-impl<B, I, M, O> Learner<B, I, I, O, ClassificationOutput<B>, ClassificationOutput<B>>
-    for ClassificationLearner<M>
+impl<B, I, IV, M, M2, O>
+    Learner<I, IV, ClassificationOutput<B>, ClassificationOutput<B::InnerBackend>>
+    for ClassificationLearner<M, O>
 where
     B: ad::Backend,
-    M: Forward<I, ClassificationOutput<B>> + Module<Backend = B>,
-    O: Optimizer<B>,
+    M: Forward<I, ClassificationOutput<B>> + ADModule<Backend = B, InnerModule = M2>,
+    M2: Forward<IV, ClassificationOutput<B::InnerBackend>> + Module<Backend = B::InnerBackend>,
+    O: Optimizer<Backend = B>,
 {
-    fn train(&mut self, item: I, optim: &mut O) -> ClassificationOutput<B> {
+    type Backend = B;
+
+    fn train(&mut self, item: I) -> ClassificationOutput<B> {
         let output = self.model.forward(item);
         let grads = output.loss.backward();

-        self.model.update_params(&grads, optim);
+        self.model.update_params(&grads, &mut self.optim);

         output
     }

-    fn valid(&self, item: I) -> ClassificationOutput<B> {
-        self.model.forward(item)
+    fn valid(&self, item: IV) -> ClassificationOutput<B::InnerBackend> {
+        self.model.inner().forward(item)
     }
 }
@@ -2,11 +2,15 @@ use crate::module::Module;
 use crate::tensor::back::Backend;
 use burn_tensor::Tensor;

-pub trait Loss<B: Backend, T>: Module<Backend = B> {
-    fn loss(&self, item: T) -> Tensor<B, 1>;
+pub trait Loss: Module {
+    type Item;
+
+    fn loss(&self, item: Self::Item) -> Tensor<Self::Backend, 1>;
 }

-pub trait Learner<B: Backend, T, V, O, TO, VO> {
-    fn train(&mut self, item: T, optim: &mut O) -> TO;
+pub trait Learner<T, V, TO, VO> {
+    type Backend: Backend;
+
+    fn train(&mut self, item: T) -> TO;
     fn valid(&self, item: V) -> VO;
 }
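A toy sketch of the reworked Learner contract above, with scalar stand-ins for tensors and backends (nothing here is burn's real API): the learner owns its optimizer, train() no longer takes one, and the backend is an associated type.

trait Learner<T, V, TO, VO> {
    type Backend;

    fn train(&mut self, item: T) -> TO;
    fn valid(&self, item: V) -> VO;
}

struct Sgd {
    learning_rate: f32,
}

struct BasicLearner {
    weight: f32,
    // The learner now owns its optimizer, mirroring BasicLearner<L, O> above.
    optim: Sgd,
}

impl Learner<f32, f32, f32, f32> for BasicLearner {
    type Backend = ();

    fn train(&mut self, target: f32) -> f32 {
        // Toy squared loss and gradient step standing in for
        // loss.backward() + update_params on a real module.
        let loss = (self.weight - target).powi(2);
        self.weight -= self.optim.learning_rate * 2.0 * (self.weight - target);
        loss
    }

    fn valid(&self, target: f32) -> f32 {
        (self.weight - target).powi(2)
    }
}

fn main() {
    let mut learner = BasicLearner {
        weight: 0.0,
        optim: Sgd { learning_rate: 0.1 },
    };
    for _ in 0..20 {
        learner.train(1.0);
    }
    println!("valid loss: {}", learner.valid(1.0));
}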
@@ -1,30 +1,28 @@
 use super::{Learner, TrainerItem};
 use crate::data::dataloader::DataLoader;
-use crate::optim::Optimizer;
-use crate::tensor::back::ad;
+use crate::data::dataloader::Detach;
+use crate::tensor::back;
 use crate::train::logger::Logger;
 use std::sync::Arc;

-pub struct SupervisedTrainer<B, T, V, L, O, TO, VO>
+pub struct SupervisedTrainer<B, T, V, L, TO, VO>
 where
-    B: ad::Backend,
-    L: Learner<B, T, V, O, TO, VO>,
-    O: Optimizer<B>,
+    B: back::ad::Backend,
+    L: Learner<T, V, TO, VO, Backend = B>,
 {
     dataloader_train: Arc<dyn DataLoader<T>>,
     dataloader_valid: Arc<dyn DataLoader<V>>,
     logger_train: Box<dyn Logger<TrainerItem<TO>>>,
     logger_valid: Box<dyn Logger<TrainerItem<VO>>>,
     learner: L,
-    optimizer: O,
     _b: B,
 }

-impl<B, T, V, L, O, TO, VO> SupervisedTrainer<B, T, V, L, O, TO, VO>
+impl<B, T, V, L, TO, VO> SupervisedTrainer<B, T, V, L, TO, VO>
 where
-    B: ad::Backend,
-    L: Learner<B, T, V, O, TO, VO>,
-    O: Optimizer<B>,
+    B: back::ad::Backend,
+    T: Detach,
+    L: Learner<T, V, TO, VO, Backend = B>,
 {
     pub fn new(
         dataloader_train: Arc<dyn DataLoader<T>>,
@@ -32,13 +30,11 @@ where
         logger_train: Box<dyn Logger<TrainerItem<TO>>>,
         logger_valid: Box<dyn Logger<TrainerItem<VO>>>,
         learner: L,
-        optimizer: O,
     ) -> Self {
         Self {
             dataloader_train,
             dataloader_valid,
             learner,
-            optimizer,
             logger_train,
             logger_valid,
             _b: B::default(),
@@ -52,7 +48,7 @@ where
             num_epochs,
             &self.dataloader_train,
             &mut self.logger_train,
-            &mut |item| self.learner.train(item, &mut self.optimizer),
+            &mut |item| self.learner.train(item.detach()),
         );

         run_step(