Refactor Param wrapping only for Tensor (#259)

Nathaniel Simard 2023-03-31 16:45:10 -04:00 committed by GitHub
parent 7364d09d32
commit 32d38bebc3
44 changed files with 881 additions and 911 deletions
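The gist of the refactor: `Param` now wraps tensors only. Sub-modules are declared as plain fields, and the `Module`/`ADModule` traits take the backend as a generic parameter (`Module<B>`) instead of carrying an associated `type Backend`. A minimal before/after sketch of a module definition, in the spirit of the updated doc example below (`MyBlock` is a hypothetical type, not part of the commit):

use burn::{
    module::{Module, Param},
    nn,
    tensor::{backend::Backend, Tensor},
};

// Before this commit: every field, sub-modules included, was wrapped in Param.
// struct MyBlock<B: Backend> {
//     linear: Param<nn::Linear<B>>,
//     weight: Param<Tensor<B, 2>>,
// }

// After this commit: only tensors keep the Param wrapper.
#[derive(Module, Debug)]
struct MyBlock<B: Backend> {
    linear: nn::Linear<B>,
    weight: Param<Tensor<B, 2>>,
}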

View File

@ -11,9 +11,6 @@ use burn_tensor::Tensor;
/// This will make your module trainable, savable and loadable via
/// [state](Module::state) and [load](Module::load).
///
/// Module concrete types should define their parameters via the [Param](crate::module::Param)
/// struct.
///
/// # Example
///
/// A module should have a [backend](crate::tensor::backend::Backend) defined as a generic
@ -26,39 +23,34 @@ use burn_tensor::Tensor;
///
/// use burn::{
/// nn,
/// module::{Param, Module},
/// module::Module,
/// tensor::Tensor,
/// tensor::backend::Backend,
/// };
///
/// #[derive(Module, Debug)]
/// struct MyModule<B: Backend> {
/// my_param: Param<nn::Linear<B>>,
/// my_param: nn::Linear<B>,
/// my_other_field: usize,
/// }
/// ```
pub trait Module: Clone + Send + Sync + core::fmt::Debug + core::fmt::Display {
type Backend: Backend;
pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
/// Get the device list of the module and all of its sub-modules.
fn devices(&self) -> Vec<<Self::Backend as Backend>::Device>;
fn devices(&self) -> Vec<B::Device>;
/// Move the module and all of its sub-modules to the given device.
fn to_device(self, device: &<Self::Backend as Backend>::Device) -> Self;
fn to_device(self, device: &B::Device) -> Self;
/// Load the module state.
fn load(
self,
state: &State<<Self::Backend as Backend>::FloatElem>,
) -> Result<Self, LoadingError>;
fn load(self, state: &State<B::FloatElem>) -> Result<Self, LoadingError>;
/// Get the module state.
fn state(&self) -> State<<Self::Backend as Backend>::FloatElem>;
fn state(&self) -> State<B::FloatElem>;
/// Detach the module from the graph.
fn detach(self) -> Self;
/// Get the number of parameters the module has, including all of its sub-modules.
fn num_params(&self) -> usize;
/// Visit each tensor in the module with a [visitor](ModuleVisitor).
fn visit<V: ModuleVisitor<Self::Backend>>(&self, visitor: &mut V);
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V);
/// Map each tensor in the module with a [mapper](ModuleMapper).
fn map<M: ModuleMapper<Self::Backend>>(self, mapper: &mut M) -> Self;
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self;
}
pub trait ModuleVisitor<B: Backend> {
@ -70,11 +62,8 @@ pub trait ModuleMapper<B: Backend> {
}
/// Module with auto-differentiation backend.
pub trait ADModule:
Module<Backend = Self::ADBackend> + Send + Sync + core::fmt::Debug + core::fmt::Display
{
type ADBackend: ADBackend;
type InnerModule: Module<Backend = <Self::ADBackend as ADBackend>::InnerBackend>;
pub trait ADModule<B: ADBackend>: Module<B> + Send + Sync + core::fmt::Debug {
type InnerModule: Module<B::InnerBackend>;
/// Get the same module, but on the inner backend without auto-differentiation.
fn inner(self) -> Self::InnerModule;
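With the backend as a trait parameter, generic code constrains modules as `M: Module<B>` or `M: ADModule<B>` instead of `M: Module<Backend = B>`. A hedged sketch of a helper written against the new bounds (the function itself is illustrative, not part of the commit):

use burn::module::Module;
use burn::tensor::backend::Backend;

// Moves a module onto the first device it reports, or returns it unchanged
// when no device is registered.
fn to_first_device<B: Backend, M: Module<B>>(module: M) -> M {
    let devices = module.devices();
    match devices.first() {
        Some(device) => module.to_device(device),
        None => module,
    }
}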

View File

@ -0,0 +1,86 @@
use crate as burn;
#[macro_export]
macro_rules! constant {
(module) => {
fn devices(&self) -> alloc::vec::Vec<<B as burn_tensor::backend::Backend>::Device> {
alloc::vec::Vec::new()
}
fn to_device(self, _device: &<B as burn_tensor::backend::Backend>::Device) -> Self {
self
}
fn load(
self,
_state: &burn::module::State<<B as burn_tensor::backend::Backend>::FloatElem>,
) -> Result<Self, burn::module::LoadingError> {
Ok(self)
}
fn state(&self) -> burn::module::State<<B as burn_tensor::backend::Backend>::FloatElem> {
burn::module::State::StateNamed(burn::module::StateNamed::new())
}
fn detach(self) -> Self {
self
}
fn num_params(&self) -> usize {
0
}
fn visit<V: burn::module::ModuleVisitor<B>>(&self, _visitor: &mut V) {
// Nothing to do
}
fn map<M: burn::module::ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
self
}
};
(ad_module, $type:ty) => {
type InnerModule = $type;
fn inner(self) -> Self::InnerModule {
self
}
fn from_inner(module: Self::InnerModule) -> Self {
module
}
};
($type:ty) => {
impl<B: burn::tensor::backend::Backend> burn::module::Module<B> for $type {
constant!(module);
}
impl<B: burn::tensor::backend::ADBackend> burn::module::ADModule<B> for $type {
constant!(ad_module, $type);
}
};
}
// General Types
constant!(alloc::string::String);
constant!(bool);
// Float Types
constant!(f64);
constant!(f32);
constant!(half::bf16);
constant!(half::f16);
// Unsigned Integer Types
constant!(usize);
constant!(u64);
constant!(u32);
constant!(u16);
constant!(u8);
// Signed Integer Types
constant!(i64);
constant!(i32);
constant!(i16);
constant!(i8);
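The new `constant!` macro stamps out no-op `Module` and `ADModule` implementations for parameter-free types; later hunks apply it to `Conv1dPaddingConfig`, `Conv2dPaddingConfig`, `Dropout`, `GELU`, `MaxPool2d` and `ReLU`. A hedged sketch of invoking it on a custom configuration type from a downstream crate (`PoolingMode` is hypothetical; the expansion names paths under `burn`, `burn_tensor` and `alloc`, so those crates are assumed to be available):

// The expanded code refers to alloc::vec::Vec, so bring alloc into scope.
extern crate alloc;

use burn::constant;

// A plain, parameter-free value stored inside a module.
#[derive(Clone, Debug)]
pub enum PoolingMode {
    Max,
    Average,
}

// Give it the no-op Module and ADModule implementations so the derive
// macro can delegate to it like any other field.
constant!(PoolingMode);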

View File

@ -1,13 +1,13 @@
mod base;
mod constant;
mod id;
mod module;
mod primitive;
mod running;
mod tensor;
mod visitor;
pub use base::*;
pub use id::*;
pub use module::*;
pub use running::*;
pub use tensor::*;
pub use visitor::*;

View File

@ -1,201 +0,0 @@
use alloc::{format, vec::Vec};
use super::{load_with_id, state_with_id, Param, ParamId};
use crate::module::{
ADModule, LoadingError, Module, ModuleMapper, ModuleVisitor, State, StateNamed,
};
use crate::tensor::backend::Backend;
impl<M: Module> From<M> for Param<M> {
fn from(value: M) -> Self {
Param {
id: ParamId::new(),
value,
}
}
}
impl<M: Module> From<Vec<M>> for Param<Vec<M>> {
fn from(value: Vec<M>) -> Self {
Param {
id: ParamId::new(),
value,
}
}
}
impl<M: Module> Module for Param<M> {
type Backend = M::Backend;
fn num_params(&self) -> usize {
self.value.num_params()
}
fn devices(&self) -> Vec<<M::Backend as Backend>::Device> {
self.value.devices()
}
fn to_device(self, device: &<Self::Backend as Backend>::Device) -> Self {
Param {
id: self.id,
value: self.value.to_device(device),
}
}
fn state(&self) -> State<<M::Backend as Backend>::FloatElem> {
let state = self.value.state();
state_with_id(self.id.clone(), state)
}
fn load(self, state: &State<<M::Backend as Backend>::FloatElem>) -> Result<Self, LoadingError> {
let (id, state) = load_with_id(state)?;
Ok(Self {
id: id.clone(),
value: self.value.load(state)?,
})
}
fn detach(self) -> Self {
Param {
id: self.id,
value: self.value.detach(),
}
}
fn visit<V: ModuleVisitor<Self::Backend>>(&self, visitor: &mut V) {
self.value.visit(visitor);
}
fn map<V: ModuleMapper<Self::Backend>>(self, mapper: &mut V) -> Self {
Self {
id: self.id,
value: self.value.map(mapper),
}
}
}
impl<M: Module> Module for Param<Vec<M>> {
type Backend = M::Backend;
fn num_params(&self) -> usize {
let mut num_params = 0;
for module in self.value.iter() {
num_params += module.num_params();
}
num_params
}
fn devices(&self) -> Vec<<M::Backend as Backend>::Device> {
let mut devices = Vec::new();
for module in self.value.iter() {
devices.append(&mut module.devices());
}
devices
}
fn to_device(self, device: &<M::Backend as Backend>::Device) -> Self {
Param {
id: self.id,
value: self
.value
.into_iter()
.map(|val| val.to_device(device))
.collect(),
}
}
fn state(&self) -> State<<M::Backend as Backend>::FloatElem> {
let mut state = StateNamed::new();
for (i, module) in self.value.iter().enumerate() {
state.register_state(format!("mod-{i}").as_str(), module.state());
}
let state = State::StateNamed(state);
state_with_id(self.id.clone(), state)
}
fn load(self, state: &State<<M::Backend as Backend>::FloatElem>) -> Result<Self, LoadingError> {
let (id, state) = load_with_id(state)?;
let id = id.clone();
let num = self.value.len();
let mut modules = Vec::with_capacity(num);
for (i, module) in self.value.into_iter().enumerate() {
let module = module
.load(state.get(format!("mod-{i}").as_str()).ok_or_else(|| {
LoadingError::new(format!(
"Invalid number of modules, expected {num} modules missing #{i}"
))
})?)
.map_err(|err| LoadingError::new(format!("Can't load modules mod-{i}: {err}")))?;
modules.push(module);
}
Ok(Self { id, value: modules })
}
fn detach(self) -> Self {
Param {
id: self.id,
value: self.value.into_iter().map(|val| val.detach()).collect(),
}
}
fn visit<V: ModuleVisitor<Self::Backend>>(&self, visitor: &mut V) {
for module in self.value.iter() {
module.visit(visitor);
}
}
fn map<V: ModuleMapper<Self::Backend>>(self, mapper: &mut V) -> Self {
Self {
id: self.id,
value: self.value.into_iter().map(|val| val.map(mapper)).collect(),
}
}
}
impl<M: ADModule> ADModule for Param<Vec<M>> {
type ADBackend = M::ADBackend;
type InnerModule = Param<Vec<M::InnerModule>>;
fn inner(self) -> Self::InnerModule {
Param::from(
self.value
.into_iter()
.map(|v| v.inner())
.collect::<Vec<_>>(),
)
}
fn from_inner(module: Self::InnerModule) -> Self {
Param {
id: module.id,
value: module.value.into_iter().map(ADModule::from_inner).collect(),
}
}
}
impl<M: ADModule> ADModule for Param<M> {
type ADBackend = M::ADBackend;
type InnerModule = Param<M::InnerModule>;
fn inner(self) -> Self::InnerModule {
Param::from(self.value.inner())
}
fn from_inner(module: Self::InnerModule) -> Self {
Param {
id: module.id,
value: ADModule::from_inner(module.value),
}
}
}

View File

@ -0,0 +1,248 @@
use crate::module::{
ADModule, LoadingError, Module, ModuleMapper, ModuleVisitor, State, StateNamed,
};
use alloc::format;
use alloc::vec::Vec;
use burn_tensor::backend::{ADBackend, Backend};
use core::fmt::Debug;
impl<T, B> Module<B> for Option<T>
where
T: Module<B> + Debug + Send + Sync + Clone,
B: Backend,
{
fn devices(&self) -> Vec<<B as burn_tensor::backend::Backend>::Device> {
if let Some(module) = self {
return Module::<B>::devices(module);
}
Vec::new()
}
fn to_device(self, device: &<B as burn_tensor::backend::Backend>::Device) -> Self {
self.map(|module| module.to_device(device))
}
fn load(self, state: &State<B::FloatElem>) -> Result<Self, LoadingError> {
self.map(|module| module.load(state).map(|val| Some(val)))
.unwrap_or(Ok(None))
}
fn state(&self) -> State<B::FloatElem> {
if let Some(module) = self {
return module.state();
}
State::StateNamed(StateNamed::new())
}
fn detach(self) -> Self {
self.map(|module| module.detach())
}
fn num_params(&self) -> usize {
match &self {
Some(module) => module.num_params(),
None => 0,
}
}
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
if let Some(module) = self {
module.visit(visitor)
}
}
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
self.map(|module| module.map(mapper))
}
}
impl<T, B> ADModule<B> for Option<T>
where
T: ADModule<B> + Debug + Send + Sync + Clone,
B: ADBackend,
{
type InnerModule = Option<T::InnerModule>;
fn inner(self) -> Self::InnerModule {
self.map(|module| module.inner())
}
fn from_inner(module: Self::InnerModule) -> Self {
module.map(|module| T::from_inner(module))
}
}
impl<T, B> Module<B> for Vec<T>
where
T: Module<B> + Debug + Send + Sync + Clone,
B: Backend,
{
fn devices(&self) -> Vec<<B as burn_tensor::backend::Backend>::Device> {
let mut devices = Vec::new();
for module in self.iter() {
devices.append(&mut module.devices());
}
devices
}
fn to_device(self, device: &<B as burn_tensor::backend::Backend>::Device) -> Self {
self.into_iter().map(|val| val.to_device(device)).collect()
}
fn load(self, state: &State<B::FloatElem>) -> Result<Self, LoadingError> {
let num = self.len();
let mut modules = Vec::with_capacity(num);
for (i, module) in self.into_iter().enumerate() {
let module = module
.load(state.get(format!("mod-{i}").as_str()).ok_or_else(|| {
LoadingError::new(format!(
"Invalid number of modules, expected {num} modules missing #{i}"
))
})?)
.map_err(|err| LoadingError::new(format!("Can't load modules mod-{i}: {err}")))?;
modules.push(module);
}
Ok(modules)
}
fn state(&self) -> State<B::FloatElem> {
let mut state = StateNamed::new();
for (i, module) in self.iter().enumerate() {
state.register_state(format!("mod-{i}").as_str(), module.state());
}
State::StateNamed(state)
}
fn detach(self) -> Self {
self.into_iter().map(|module| module.detach()).collect()
}
fn num_params(&self) -> usize {
let mut num_params = 0;
for module in self.iter() {
num_params += module.num_params();
}
num_params
}
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
self.iter().for_each(|module| {
module.visit(visitor);
});
}
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
self.into_iter().map(|module| module.map(mapper)).collect()
}
}
impl<T, B> ADModule<B> for Vec<T>
where
T: ADModule<B> + Debug + Send + Sync + Clone,
B: ADBackend,
{
type InnerModule = Vec<T::InnerModule>;
fn inner(self) -> Self::InnerModule {
self.into_iter().map(|module| module.inner()).collect()
}
fn from_inner(module: Self::InnerModule) -> Self {
module
.into_iter()
.map(|module| T::from_inner(module))
.collect()
}
}
impl<const N: usize, T, B> Module<B> for [T; N]
where
T: Module<B> + Debug + Send + Sync + Clone + Copy,
B: Backend,
{
fn devices(&self) -> Vec<<B as burn_tensor::backend::Backend>::Device> {
let mut devices = Vec::new();
for module in self.iter() {
devices.append(&mut module.devices());
}
devices
}
fn to_device(self, device: &<B as burn_tensor::backend::Backend>::Device) -> Self {
self.map(|val| val.to_device(device))
}
fn load(mut self, state: &State<B::FloatElem>) -> Result<Self, LoadingError> {
let num = self.len();
for (i, module) in self.into_iter().enumerate().take(N) {
self[i] = module
.load(state.get(format!("mod-{i}").as_str()).ok_or_else(|| {
LoadingError::new(format!(
"Invalid number of modules, expected {num} modules missing #{i}"
))
})?)
.map_err(|err| LoadingError::new(format!("Can't load modules mod-{i}: {err}")))?;
}
Ok(self)
}
fn state(&self) -> State<B::FloatElem> {
let mut state = StateNamed::new();
for (i, module) in self.iter().enumerate() {
state.register_state(format!("mod-{i}").as_str(), module.state());
}
State::StateNamed(state)
}
fn detach(self) -> Self {
self.map(|module| module.detach())
}
fn num_params(&self) -> usize {
let mut num_params = 0;
for module in self.iter() {
num_params += module.num_params();
}
num_params
}
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
self.iter().for_each(|module| {
module.visit(visitor);
});
}
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
self.map(|module| module.map(mapper))
}
}
impl<const N: usize, T, B> ADModule<B> for [T; N]
where
T: ADModule<B> + Debug + Send + Sync + Clone + Copy,
T::InnerModule: Copy,
B: ADBackend,
{
type InnerModule = [T::InnerModule; N];
fn inner(self) -> Self::InnerModule {
self.map(|module| module.inner())
}
fn from_inner(module: Self::InnerModule) -> Self {
module.map(|module| T::from_inner(module))
}
}
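Since `Option<T>`, `Vec<T>` and `[T; N]` now implement `Module<B>` themselves, collections of sub-modules also drop the `Param` wrapper; the transformer encoder hunk below replaces `Param<Vec<TransformerEncoderLayer<B>>>` with a plain `Vec`. A small hedged sketch of a module relying on these blanket impls (`Stack` is hypothetical):

use burn::{module::Module, nn, tensor::backend::Backend};

// Both the Vec and the Option field are modules in their own right, so the
// derived visit/map/state/load recurse through them without any wrapper.
#[derive(Module, Debug)]
pub struct Stack<B: Backend> {
    layers: Vec<nn::Linear<B>>,
    head: Option<nn::Linear<B>>,
}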

View File

@ -55,9 +55,7 @@ impl<B: Backend, const D: usize> From<RunningState<Tensor<B, D>>>
}
}
impl<const D: usize, B: Backend> Module for Param<RunningState<Tensor<B, D>>> {
type Backend = B;
impl<const D: usize, B: Backend> Module<B> for Param<RunningState<Tensor<B, D>>> {
fn num_params(&self) -> usize {
let tensor = self.value.value.read().unwrap();
tensor.shape().num_elements()
@ -113,13 +111,13 @@ impl<const D: usize, B: Backend> Module for Param<RunningState<Tensor<B, D>>> {
self
}
fn visit<V: ModuleVisitor<Self::Backend>>(&self, visitor: &mut V) {
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
let tensor = self.value.value.read().unwrap();
visitor.visit(&self.id, &tensor)
}
fn map<M: ModuleMapper<Self::Backend>>(self, mapper: &mut M) -> Self {
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
let mut tensor = self.value.value.write().unwrap();
let tensor_out = mapper.map(&self.id, tensor.clone());
@ -208,9 +206,7 @@ impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
}
}
impl<const D: usize, B: ADBackend> ADModule for Param<RunningState<Tensor<B, D>>> {
type ADBackend = B;
impl<const D: usize, B: ADBackend> ADModule<B> for Param<RunningState<Tensor<B, D>>> {
type InnerModule = Param<RunningState<Tensor<B::InnerBackend, D>>>;
fn inner(self) -> Self::InnerModule {

View File

@ -1,9 +1,7 @@
use alloc::{string::ToString, vec, vec::Vec};
use super::{load_with_id, state_with_id, Param, ParamId};
use crate::module::{
ADModule, LoadingError, Module, ModuleMapper, ModuleVisitor, State, StateNamed,
};
use crate::module::{ADModule, LoadingError, Module, ModuleMapper, ModuleVisitor, State};
use crate::tensor::{
backend::{ADBackend, Backend},
Data, Tensor,
@ -18,18 +16,7 @@ impl<B: Backend, const D: usize> From<Tensor<B, D>> for Param<Tensor<B, D>> {
}
}
impl<B: Backend, const D: usize> From<Option<Tensor<B, D>>> for Param<Option<Tensor<B, D>>> {
fn from(value: Option<Tensor<B, D>>) -> Self {
Param {
id: ParamId::new(),
value: value.map(|tensor| tensor.require_grad()),
}
}
}
impl<const D: usize, B: Backend> Module for Param<Tensor<B, D>> {
type Backend = B;
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
fn num_params(&self) -> usize {
self.value.shape().num_elements()
}
@ -72,99 +59,17 @@ impl<const D: usize, B: Backend> Module for Param<Tensor<B, D>> {
}
}
fn visit<V: ModuleVisitor<Self::Backend>>(&self, visitor: &mut V) {
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
visitor.visit(&self.id, &self.value)
}
fn map<M: ModuleMapper<Self::Backend>>(self, mapper: &mut M) -> Self {
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
let value = mapper.map(&self.id, self.value).require_grad();
Self { id: self.id, value }
}
}
impl<const D: usize, B: Backend> Module for Param<Option<Tensor<B, D>>> {
type Backend = B;
fn num_params(&self) -> usize {
if let Some(value) = &self.value {
return value.shape().num_elements();
}
0
}
fn devices(&self) -> Vec<B::Device> {
if let Some(value) = &self.value {
return vec![value.device()];
}
vec![]
}
fn to_device(self, device: &B::Device) -> Self {
Self {
id: self.id,
value: self
.value
.map(|value| value.to_device(device).require_grad()),
}
}
fn state(&self) -> State<B::FloatElem> {
let state = match &self.value {
Some(value) => State::Data(value.to_data().serialize()),
None => State::StateNamed(StateNamed::new()),
};
state_with_id(self.id.clone(), state)
}
fn load(self, state: &State<B::FloatElem>) -> Result<Self, LoadingError> {
let (id, state) = load_with_id(state)?;
let id = id.clone();
let tensor = if let Some(tensor) = self.value {
let data = match state {
State::Data(data) => data,
_ => {
return Err(LoadingError::new(
"Can't load Option<Tensor> from NamedState".to_string(),
))
}
};
Some(Tensor::from_data_device(Data::from(data), &tensor.device()).require_grad())
} else {
None
};
Ok(Self { id, value: tensor })
}
fn detach(self) -> Self {
Self {
id: self.id,
value: self.value.map(|value| value.detach().require_grad()),
}
}
fn visit<V: ModuleVisitor<Self::Backend>>(&self, visitor: &mut V) {
if let Some(value) = &self.value {
visitor.visit(&self.id, value)
}
}
fn map<M: ModuleMapper<Self::Backend>>(self, mapper: &mut M) -> Self {
let value = self
.value
.map(|value| mapper.map(&self.id, value).require_grad());
Self { id: self.id, value }
}
}
impl<const D: usize, B: ADBackend> ADModule for Param<Tensor<B, D>> {
type ADBackend = B;
impl<const D: usize, B: ADBackend> ADModule<B> for Param<Tensor<B, D>> {
type InnerModule = Param<Tensor<B::InnerBackend, D>>;
fn inner(self) -> Self::InnerModule {
@ -181,25 +86,3 @@ impl<const D: usize, B: ADBackend> ADModule for Param<Tensor<B, D>> {
}
}
}
impl<const D: usize, B: ADBackend> ADModule for Param<Option<Tensor<B, D>>> {
type ADBackend = B;
type InnerModule = Param<Option<Tensor<B::InnerBackend, D>>>;
fn inner(self) -> Self::InnerModule {
Param {
id: self.id,
value: self.value.map(|val| val.inner()),
}
}
fn from_inner(module: Self::InnerModule) -> Self {
Param {
id: module.id,
value: module
.value
.map(|val| Tensor::from_inner(val).require_grad()),
}
}
}
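Tensors themselves are still wrapped explicitly: `Param::from` assigns a fresh `ParamId` and calls `require_grad` on the tensor. A short sketch mirroring the test module further down in this diff (the helper and its shape are illustrative; import paths are assumed):

use burn::module::Param;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};

// Builds a trainable 20x20 weight registered as a parameter.
fn init_weight<B: Backend>() -> Param<Tensor<B, 2>> {
    let weight = Tensor::random(Shape::new([20, 20]), Distribution::Standard);
    Param::from(weight)
}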

View File

@ -1,24 +1,31 @@
use alloc::vec::Vec;
use super::ParamId;
use crate::module::{Module, ModuleVisitor};
use alloc::vec::Vec;
use burn_tensor::{backend::Backend, Tensor};
use core::marker::PhantomData;
#[derive(new)]
struct ParamIdCollector<'a> {
struct ParamIdCollector<'a, M> {
param_ids: &'a mut Vec<ParamId>,
phantom: PhantomData<M>,
}
impl<'a, B: Backend> ModuleVisitor<B> for ParamIdCollector<'a> {
impl<'a, B, M> ModuleVisitor<B> for ParamIdCollector<'a, M>
where
B: Backend,
M: Module<B>,
{
fn visit<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
self.param_ids.push(id.clone());
}
}
/// List all the parameter ids in a module.
pub fn list_param_ids<M: Module>(module: &M) -> Vec<ParamId> {
pub fn list_param_ids<M: Module<B>, B: Backend>(module: &M) -> Vec<ParamId> {
let mut params_ids = Vec::new();
let mut visitor = ParamIdCollector::new(&mut params_ids);
let mut visitor = ParamIdCollector {
param_ids: &mut params_ids,
phantom: PhantomData::<M>::default(),
};
module.visit(&mut visitor);
params_ids
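`list_param_ids` is now generic over the backend as well as the module, so callers supply both type parameters. A hedged usage sketch (the helper is hypothetical and the `burn::module::list_param_ids` re-export path is assumed from the `pub use visitor::*;` in the mod.rs hunk above):

use burn::module::{list_param_ids, Module, ParamId};
use burn::tensor::backend::Backend;

// Counts how many parameter tensors a module registers by collecting their ids
// (not the number of elements, which is what num_params reports).
fn count_param_tensors<B: Backend, M: Module<B>>(module: &M) -> usize {
    let ids: Vec<ParamId> = list_param_ids::<M, B>(module);
    ids.len()
}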

View File

@ -5,7 +5,7 @@ use crate as burn;
use crate::nn::cache::TensorCache;
use crate::{
config::Config,
module::{Module, Param},
module::Module,
nn,
tensor::{activation, backend::Backend, Bool, Tensor},
};
@ -39,10 +39,10 @@ pub struct MultiHeadAttentionConfig {
/// - output: [Linear](nn::Linear) layer with `d_model` input and output features.
#[derive(Module, Debug)]
pub struct MultiHeadAttention<B: Backend> {
query: Param<nn::Linear<B>>,
key: Param<nn::Linear<B>>,
value: Param<nn::Linear<B>>,
output: Param<nn::Linear<B>>,
query: nn::Linear<B>,
key: nn::Linear<B>,
value: nn::Linear<B>,
output: nn::Linear<B>,
dropout: nn::Dropout,
activation: nn::GELU,
n_heads: usize,
@ -63,9 +63,7 @@ pub struct MhaInput<B: Backend> {
impl MultiHeadAttentionConfig {
/// Initialize a new [multihead attention](MultiHeadAttention) module.
pub fn init<B: Backend>(&self) -> MultiHeadAttention<B> {
let linear = |config: &Self| {
Param::from(nn::LinearConfig::new(config.d_model, config.d_model).init())
};
let linear = |config: &Self| nn::LinearConfig::new(config.d_model, config.d_model).init();
MultiHeadAttention {
query: linear(self),
@ -172,7 +170,7 @@ impl<B: Backend> MultiHeadAttention<B> {
let attention_linear = |cache: &mut TensorCache<B, 4>,
tensor: Tensor<B, 3>,
param: &Param<nn::Linear<B>>| {
param: &nn::Linear<B>| {
cache.forward_autoregressive(tensor, 2, |tensor| self.attention_linear(tensor, param))
};
@ -235,7 +233,7 @@ impl<B: Backend> MultiHeadAttention<B> {
activation::softmax(attn_scores, 3)
}
fn attention_linear(&self, x: Tensor<B, 3>, linear: &Param<nn::Linear<B>>) -> Tensor<B, 4> {
fn attention_linear(&self, x: Tensor<B, 3>, linear: &nn::Linear<B>) -> Tensor<B, 4> {
let [batch_size, seq_length, _d_model] = x.dims();
linear
.forward(x)

View File

@ -4,6 +4,7 @@ use alloc::vec::Vec;
use crate as burn;
use crate::config::Config;
use crate::constant;
use crate::module::Module;
use crate::module::Param;
use crate::nn::Initializer;
@ -43,6 +44,8 @@ pub enum Conv1dPaddingConfig {
Explicit(usize),
}
constant!(Conv1dPaddingConfig);
/// Applies a 1D convolution over input tensors.
///
/// # Params
@ -55,7 +58,7 @@ pub enum Conv1dPaddingConfig {
#[derive(Module, Debug)]
pub struct Conv1d<B: Backend> {
weight: Param<Tensor<B, 3>>,
bias: Param<Option<Tensor<B, 1>>>,
bias: Option<Param<Tensor<B, 1>>>,
stride: usize,
kernel_size: usize,
padding: Option<Conv1dPaddingConfig>,
@ -76,14 +79,14 @@ impl Conv1dConfig {
let weight = initializer.init([self.channels_out, self.channels_in, self.kernel_size]);
let bias = if self.bias {
Some(initializer.init([self.channels_out]))
Some(Param::from(initializer.init([self.channels_out])))
} else {
None
};
Conv1d {
weight: Param::from(weight),
bias: Param::from(bias),
bias,
stride: 1, // TODO: Add the stride to the config when properly supported.
kernel_size: self.kernel_size,
padding: self.padding.clone(),
@ -115,7 +118,7 @@ impl<B: Backend> Conv1d<B> {
conv1d(
input,
self.weight.val(),
self.bias.val(),
self.bias.as_ref().map(|bias| bias.val()),
self.stride,
padding,
)

View File

@ -3,6 +3,7 @@ use alloc::{format, vec::Vec};
use crate as burn;
use crate::config::Config;
use crate::constant;
use crate::module::Module;
use crate::module::Param;
use crate::nn::Initializer;
@ -43,6 +44,8 @@ pub enum Conv2dPaddingConfig {
Explicit(usize, usize),
}
constant!(Conv2dPaddingConfig);
/// Applies a 2D convolution over input tensors.
///
/// # Params
@ -55,7 +58,7 @@ pub enum Conv2dPaddingConfig {
#[derive(Module, Debug)]
pub struct Conv2d<B: Backend> {
weight: Param<Tensor<B, 4>>,
bias: Param<Option<Tensor<B, 1>>>,
bias: Option<Param<Tensor<B, 1>>>,
stride: [usize; 2],
kernel_size: [usize; 2],
padding: Conv2dPaddingConfig,
@ -88,7 +91,7 @@ impl Conv2dConfig {
Conv2d {
weight: Param::from(weight),
bias: Param::from(bias),
bias: bias.map(Param::from),
stride: [1, 1], // TODO: Add the stride to the config when properly supported.
kernel_size: self.kernel_size,
padding: self.padding.clone(),
@ -111,7 +114,7 @@ impl<B: Backend> Conv2d<B> {
conv2d(
input,
self.weight.val(),
self.bias.val(),
self.bias.as_ref().map(|bias| bias.val()),
self.stride,
padding,
)

View File

@ -1,5 +1,7 @@
use crate as burn;
use crate::config::Config;
use crate::constant;
use crate::tensor::backend::Backend;
use crate::tensor::{Distribution, Tensor};
@ -21,6 +23,8 @@ pub struct Dropout {
prob: f64,
}
constant!(Dropout);
impl DropoutConfig {
/// Initialize a new [dropout](Dropout) module.
pub fn init(&self) -> Dropout {

View File

@ -1,3 +1,6 @@
use crate as burn;
use crate::constant;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
@ -5,6 +8,8 @@ use crate::tensor::Tensor;
#[derive(Clone, Debug, Default)]
pub struct GELU {}
constant!(GELU);
impl GELU {
/// Create the module.
pub fn new() -> Self {

View File

@ -40,7 +40,7 @@ pub struct LinearConfig {
#[derive(Module, Debug)]
pub struct Linear<B: Backend> {
weight: Param<Tensor<B, 2>>,
bias: Param<Option<Tensor<B, 1>>>,
bias: Option<Param<Tensor<B, 1>>>,
}
impl LinearConfig {
@ -64,7 +64,7 @@ impl LinearConfig {
Linear {
weight: Param::from(weight),
bias: Param::from(bias),
bias: bias.map(Param::from),
}
}
}
@ -79,8 +79,8 @@ impl<B: Backend> Linear<B> {
pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
let output = input.matmul(self.weight.val().unsqueeze());
match self.bias.val() {
Some(bias) => output + bias.unsqueeze(),
match &self.bias {
Some(bias) => output + bias.val().unsqueeze(),
None => output,
}
}
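Optional biases move from `Param<Option<Tensor<B, 1>>>` to `Option<Param<Tensor<B, 1>>>`, so call sites unwrap the `Option` first and then read the tensor with `.val()`, as `Linear::forward` above now does. A hedged helper sketch of that access pattern (the function is illustrative, not part of the commit):

use burn::module::Param;
use burn::tensor::{backend::Backend, Tensor};

// Adds an optional rank-1 bias to an output tensor, broadcasting it with
// unsqueeze, mirroring Linear::forward in this diff.
fn add_bias<B: Backend, const D: usize>(
    output: Tensor<B, D>,
    bias: &Option<Param<Tensor<B, 1>>>,
) -> Tensor<B, D> {
    match bias {
        Some(bias) => output + bias.val().unsqueeze(),
        None => output,
    }
}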

View File

@ -1,4 +1,4 @@
use crate as burn;
use crate::{self as burn, constant};
use crate::config::Config;
use crate::nn::conv::Conv2dPaddingConfig;
@ -32,6 +32,8 @@ pub struct MaxPool2d {
padding: MaxPool2dPaddingConfig,
}
constant!(MaxPool2d);
impl MaxPool2dConfig {
/// Initialize a new [max pool 2d](MaxPool2d) module.
pub fn init(&self) -> MaxPool2d {

View File

@ -1,3 +1,6 @@
use crate as burn;
use crate::constant;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
@ -7,6 +10,8 @@ use crate::tensor::Tensor;
#[derive(Clone, Debug, Default)]
pub struct ReLU {}
constant!(ReLU);
impl ReLU {
/// Create the module.
pub fn new() -> Self {

View File

@ -9,7 +9,7 @@ use crate::{
use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
use crate::{
config::Config,
module::{Module, Param},
module::Module,
nn::{
attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
@ -43,7 +43,7 @@ pub struct TransformerEncoderConfig {
/// - layers: transformer encoder layers with `d_model` input and output features.
#[derive(Module, Debug)]
pub struct TransformerEncoder<B: Backend> {
layers: Param<Vec<TransformerEncoderLayer<B>>>,
layers: Vec<TransformerEncoderLayer<B>>,
}
/// [Transformer Encoder](TransformerEncoder) forward pass input argument.
@ -83,9 +83,7 @@ impl TransformerEncoderConfig {
.map(|_| TransformerEncoderLayer::new(self))
.collect::<Vec<_>>();
TransformerEncoder {
layers: Param::from(layers),
}
TransformerEncoder { layers }
}
}
@ -141,10 +139,10 @@ impl<B: Backend> TransformerEncoder<B> {
#[derive(Module, Debug)]
struct TransformerEncoderLayer<B: Backend> {
mha: Param<MultiHeadAttention<B>>,
pwff: Param<PositionWiseFeedForward<B>>,
norm_1: Param<LayerNorm<B>>,
norm_2: Param<LayerNorm<B>>,
mha: MultiHeadAttention<B>,
pwff: PositionWiseFeedForward<B>,
norm_1: LayerNorm<B>,
norm_2: LayerNorm<B>,
dropout: Dropout,
norm_first: bool,
}
@ -162,10 +160,10 @@ impl<B: Backend> TransformerEncoderLayer<B> {
.init();
Self {
mha: Param::from(mha),
norm_1: Param::from(norm_1),
norm_2: Param::from(norm_2),
pwff: Param::from(pwff),
mha,
norm_1,
norm_2,
pwff,
dropout,
norm_first: config.norm_first,
}

View File

@ -4,7 +4,7 @@ use crate as burn;
use crate::{
config::Config,
module::{Module, Param},
module::Module,
nn::{Dropout, DropoutConfig, Linear, LinearConfig, GELU},
tensor::{backend::Backend, Tensor},
};
@ -29,8 +29,8 @@ pub struct PositionWiseFeedForwardConfig {
/// - linear outer: Linear layer with `d_ff` input features and `d_model` output features.
#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
linear_inner: Param<Linear<B>>,
linear_outer: Param<Linear<B>>,
linear_inner: Linear<B>,
linear_outer: Linear<B>,
dropout: Dropout,
gelu: GELU,
}
@ -39,8 +39,8 @@ impl PositionWiseFeedForwardConfig {
/// Initialize a new [position-wise feed-forward](PositionWiseFeedForward) module.
pub fn init<B: Backend>(&self) -> PositionWiseFeedForward<B> {
PositionWiseFeedForward {
linear_inner: Param::from(LinearConfig::new(self.d_model, self.d_ff).init()),
linear_outer: Param::from(LinearConfig::new(self.d_ff, self.d_model).init()),
linear_inner: LinearConfig::new(self.d_model, self.d_ff).init(),
linear_outer: LinearConfig::new(self.d_ff, self.d_model).init(),
dropout: DropoutConfig::new(self.dropout).init(),
gelu: GELU::new(),
}

View File

@ -1,4 +1,4 @@
use crate as burn;
use crate::{self as burn, module::ADModule};
use super::{
decay::{WeightDecay, WeightDecayConfig},
@ -54,9 +54,7 @@ impl<B: ADBackend> Adam<B> {
}
}
impl<B: ADBackend> Optimizer for Adam<B> {
type Backend = B;
impl<M: ADModule<B>, B: ADBackend> Optimizer<M, B> for Adam<B> {
fn update_tensor<const D: usize>(
&mut self,
id: &ParamId,

View File

@ -2,25 +2,26 @@ use super::mapper::ModuleTensorUpdater;
use super::visitor::{GradientsLoader, GradientsRegister};
use super::GradientsParams;
use crate::module::{ADModule, LoadingError, Module, ParamId, State, StateNamed};
use crate::tensor::backend::{ADBackend, Backend};
use crate::module::{ADModule, LoadingError, ParamId, State, StateNamed};
use crate::tensor::backend::ADBackend;
use crate::tensor::{Data, Tensor};
pub trait Optimizer: Send + Sync {
type Backend: ADBackend;
pub trait Optimizer<M, B>: Send + Sync
where
M: ADModule<B>,
B: ADBackend,
{
/// Update the tensor parameter using the given gradients.
fn update_tensor<const D: usize>(
&mut self,
id: &ParamId,
tensor: Tensor<Self::Backend, D>,
grad: Tensor<<Self::Backend as ADBackend>::InnerBackend, D>,
) -> Tensor<Self::Backend, D>;
tensor: Tensor<B, D>,
grad: Tensor<B::InnerBackend, D>,
) -> Tensor<B, D>;
/// Update the parameters of the given module using the given gradients.
fn update_module<M>(&mut self, module: M, grads: GradientsParams) -> M
fn update_module(&mut self, module: M, grads: GradientsParams) -> M
where
M: ADModule<ADBackend = Self::Backend>,
Self: Sized,
{
let mut mapper = ModuleTensorUpdater::new(self, grads);
@ -35,7 +36,7 @@ pub trait Optimizer: Send + Sync {
fn register_param_state<const D: usize>(
&self,
_id: &ParamId,
_state: &mut StateNamed<<Self::Backend as Backend>::FloatElem>,
_state: &mut StateNamed<B::FloatElem>,
) {
// By default there is no state to register
}
@ -48,17 +49,14 @@ pub trait Optimizer: Send + Sync {
fn load_param_state<const D: usize>(
&mut self,
_id: &ParamId,
_state: &StateNamed<<Self::Backend as Backend>::FloatElem>,
_device: &<Self::Backend as Backend>::Device,
_state: &StateNamed<B::FloatElem>,
_device: &B::Device,
) {
// By default there is no state to load
}
/// Get the optimizer state for a given module.
fn state<M: Module<Backend = Self::Backend>>(
&self,
module: &M,
) -> State<<Self::Backend as Backend>::FloatElem>
fn state(&self, module: &M) -> State<B::FloatElem>
where
Self: Sized,
{
@ -70,11 +68,7 @@ pub trait Optimizer: Send + Sync {
}
/// Load the optimizer state for a given module.
fn load<M: Module<Backend = Self::Backend>>(
&mut self,
module: &M,
state: &State<<Self::Backend as Backend>::FloatElem>,
) -> Result<(), LoadingError>
fn load(&mut self, module: &M, state: &State<B::FloatElem>) -> Result<(), LoadingError>
where
Self: Sized,
{
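The `Optimizer` trait is now generic over both the module and the backend, replacing the old `type Backend` associated type; `update_module` no longer needs its own module type parameter. A hedged sketch of the usual update sequence, composed from items that appear in this diff (`GradientsParams::from_grads`, `update_module`); the re-export paths and the surrounding training loop are assumed:

use burn::module::ADModule;
use burn::optim::{GradientsParams, Optimizer};
use burn::tensor::backend::ADBackend;

// One optimizer step: convert the backward-pass gradients into per-parameter
// gradients, then let the optimizer map them onto the module's tensors.
fn optimizer_step<B, M, O>(model: M, optim: &mut O, grads: B::Gradients) -> M
where
    B: ADBackend,
    M: ADModule<B>,
    O: Optimizer<M, B>,
{
    let grads = GradientsParams::from_grads(grads, &model);
    optim.update_module(model, grads)
}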

View File

@ -1,36 +1,40 @@
use crate::module::{Module, ModuleVisitor, ParamId};
use core::marker::PhantomData;
use crate::module::{ADModule, ModuleVisitor, ParamId};
use burn_tensor::{backend::ADBackend, Tensor};
use super::GradientsParams;
/// Accumulate gradients into a single [Gradients](ADBackend::Gradients) object.
pub struct GradientsAccumulator {
pub struct GradientsAccumulator<M> {
grads: GradientsParams,
phantom: PhantomData<M>,
}
impl Default for GradientsAccumulator {
impl<M> Default for GradientsAccumulator<M> {
fn default() -> Self {
Self::new()
}
}
impl GradientsAccumulator {
impl<M> GradientsAccumulator<M> {
/// Create a new gradients accumulator.
pub fn new() -> Self {
Self {
grads: GradientsParams::new(),
phantom: PhantomData::default(),
}
}
}
impl GradientsAccumulator {
impl<M> GradientsAccumulator<M> {
/// Accumulate the given gradients for each parameter in the given module.
pub fn accumulate<B: ADBackend, M>(&mut self, module: &M, grads: GradientsParams)
pub fn accumulate<B: ADBackend>(&mut self, module: &M, grads: GradientsParams)
where
M: Module<Backend = B>,
M: ADModule<B>,
{
let mut visitor = ModuleGradsAccumulator::new(&mut self.grads, grads);
let mut visitor = ModuleGradsAccumulator::<M>::new(&mut self.grads, grads);
module.visit(&mut visitor);
}
@ -44,12 +48,13 @@ impl GradientsAccumulator {
}
#[derive(new)]
struct ModuleGradsAccumulator<'a> {
struct ModuleGradsAccumulator<'a, M> {
grads: &'a mut GradientsParams,
grads_new: GradientsParams,
phantom: PhantomData<M>,
}
impl<'a, B: ADBackend> ModuleVisitor<B> for ModuleGradsAccumulator<'a> {
impl<'a, B: ADBackend, M: ADModule<B>> ModuleVisitor<B> for ModuleGradsAccumulator<'a, M> {
fn visit<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
let grad_updated = match self.grads_new.remove::<B::InnerBackend, D>(id) {
Some(new) => match self.grads.remove::<B::InnerBackend, D>(id) {

View File

@ -63,23 +63,20 @@ impl GradientsParams {
}
/// Change the device of each tensor's gradients registered for the given [module](ADModule).
pub fn to_device<M: ADModule>(
pub fn to_device<B: ADBackend, M: ADModule<B>>(
mut self,
device: &<M::Backend as Backend>::Device,
device: &B::Device,
module: &M,
) -> Self {
let mut visitor = GradientsParamsChangeDevice::new(device, &mut self);
let mut visitor = GradientsParamsChangeDevice::<M, B>::new(device, &mut self);
module.visit(&mut visitor);
self
}
/// Extract each tensor's gradients for the given [module](ADModule).
pub fn from_grads<M: ADModule>(
grads: <M::ADBackend as ADBackend>::Gradients,
module: &M,
) -> Self {
pub fn from_grads<B: ADBackend, M: ADModule<B>>(grads: B::Gradients, module: &M) -> Self {
let mut grads_params = GradientsParams::new();
let mut visitor = GradientsParamsConverter::new(grads, &mut grads_params);
let mut visitor = GradientsParamsConverter::<M, B>::new(grads, &mut grads_params);
module.visit(&mut visitor);
grads_params

View File

@ -1,16 +1,24 @@
use core::marker::PhantomData;
use burn_tensor::{backend::ADBackend, Tensor};
use crate::module::{ModuleMapper, ParamId};
use crate::module::{ADModule, ModuleMapper, ParamId};
use super::{GradientsParams, Optimizer};
#[derive(new)]
pub struct ModuleTensorUpdater<'a, O> {
pub struct ModuleTensorUpdater<'a, M, O> {
optimizer: &'a mut O,
grads: GradientsParams,
phantom: PhantomData<M>,
}
impl<'a, B: ADBackend, O: Optimizer<Backend = B>> ModuleMapper<B> for ModuleTensorUpdater<'a, O> {
impl<'a, M, B, O> ModuleMapper<B> for ModuleTensorUpdater<'a, M, O>
where
M: ADModule<B>,
B: ADBackend,
O: Optimizer<M, B>,
{
fn map<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
if let Some(grad) = self.grads.remove::<B::InnerBackend, D>(id) {
self.optimizer.update_tensor(id, tensor, grad)

View File

@ -3,7 +3,7 @@ use crate as burn;
use super::decay::{WeightDecay, WeightDecayConfig};
use super::momentum::{Momentum, MomentumConfig};
use crate::config::Config;
use crate::module::{ParamId, StateNamed};
use crate::module::{ADModule, ParamId, StateNamed};
use crate::optim::Optimizer;
use crate::tensor::backend::ADBackend;
use crate::tensor::{ElementConversion, Tensor};
@ -45,9 +45,7 @@ impl<B: ADBackend> Sgd<B> {
}
}
impl<B: ADBackend> Optimizer for Sgd<B> {
type Backend = B;
impl<M: ADModule<B>, B: ADBackend> Optimizer<M, B> for Sgd<B> {
fn update_tensor<const D: usize>(
&mut self,
id: &ParamId,

View File

@ -1,45 +1,65 @@
use core::marker::PhantomData;
use super::{GradientsParams, Optimizer};
use crate::module::{ModuleVisitor, ParamId, StateNamed};
use crate::module::{ADModule, ModuleVisitor, ParamId, StateNamed};
use burn_tensor::{backend::ADBackend, Tensor};
#[derive(new)]
pub struct GradientsRegister<'a, B: ADBackend, O> {
pub struct GradientsRegister<'a, M: ADModule<B>, B: ADBackend, O> {
optimizer: &'a O,
state: &'a mut StateNamed<B::FloatElem>,
phantom: PhantomData<M>,
}
#[derive(new)]
pub struct GradientsLoader<'a, B: ADBackend, O> {
pub struct GradientsLoader<'a, M: ADModule<B>, B: ADBackend, O> {
optimizer: &'a mut O,
state: &'a StateNamed<B::FloatElem>,
phantom: PhantomData<M>,
}
#[derive(new)]
pub struct GradientsParamsConverter<'a, B: ADBackend> {
pub struct GradientsParamsConverter<'a, M: ADModule<B>, B: ADBackend> {
grads: B::Gradients,
grads_params: &'a mut GradientsParams,
phantom: PhantomData<M>,
}
#[derive(new)]
pub struct GradientsParamsChangeDevice<'a, B: ADBackend> {
pub struct GradientsParamsChangeDevice<'a, M: ADModule<B>, B: ADBackend> {
device: &'a B::Device,
grads: &'a mut GradientsParams,
phantom: PhantomData<M>,
}
impl<'a, B: ADBackend, O: Optimizer<Backend = B>> ModuleVisitor<B> for GradientsRegister<'a, B, O> {
impl<'a, B, M, O> ModuleVisitor<B> for GradientsRegister<'a, M, B, O>
where
B: ADBackend,
M: ADModule<B>,
O: Optimizer<M, B>,
{
fn visit<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
self.optimizer.register_param_state::<D>(id, self.state)
}
}
impl<'a, B: ADBackend, O: Optimizer<Backend = B>> ModuleVisitor<B> for GradientsLoader<'a, B, O> {
impl<'a, B, M, O> ModuleVisitor<B> for GradientsLoader<'a, M, B, O>
where
B: ADBackend,
M: ADModule<B>,
O: Optimizer<M, B>,
{
fn visit<const D: usize>(&mut self, id: &ParamId, tensor: &Tensor<B, D>) {
self.optimizer
.load_param_state::<D>(id, self.state, &tensor.device())
}
}
impl<'a, B: ADBackend> ModuleVisitor<B> for GradientsParamsConverter<'a, B> {
impl<'a, B, M> ModuleVisitor<B> for GradientsParamsConverter<'a, M, B>
where
B: ADBackend,
M: ADModule<B>,
{
fn visit<const D: usize>(&mut self, id: &ParamId, tensor: &Tensor<B, D>) {
if let Some(grad) = tensor.grad_remove(&mut self.grads) {
self.grads_params
@ -48,7 +68,11 @@ impl<'a, B: ADBackend> ModuleVisitor<B> for GradientsParamsConverter<'a, B> {
}
}
impl<'a, B: ADBackend> ModuleVisitor<B> for GradientsParamsChangeDevice<'a, B> {
impl<'a, B, M> ModuleVisitor<B> for GradientsParamsChangeDevice<'a, M, B>
where
B: ADBackend,
M: ADModule<B>,
{
fn visit<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
if let Some(grad) = self.grads.remove::<B::InnerBackend, D>(id) {
self.grads

View File

@ -28,7 +28,7 @@ where
B: Backend,
{
weight: Param<Tensor<B, 2>>,
basic: Param<ModuleBasic<B>>,
basic: ModuleBasic<B>,
}
impl<B: Backend> ModuleComposed<B> {
@ -36,7 +36,7 @@ impl<B: Backend> ModuleComposed<B> {
let weight = Tensor::random(Shape::new([20, 20]), Distribution::Standard);
Self {
weight: Param::from(weight),
basic: Param::from(ModuleBasic::new()),
basic: ModuleBasic::new(),
}
}
}

View File

@ -1,4 +1,4 @@
use super::param::Param;
use super::fn_generator::FnGenerator;
use crate::module::display;
use proc_macro::TokenStream;
use quote::quote;
@ -9,24 +9,22 @@ pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
let display_fn = display::display_fn(name);
let param = Param::from_ast(ast);
let num_params_fn = param.gen_num_params_fn();
let visit = param.gen_visit_fn();
let map_mut = param.gen_map_fn();
let devices_fn = param.gen_devices_fn();
let to_device_fn = param.gen_to_device_fn();
let state_fn = param.gen_state_fn();
let load_fn = param.gen_load_fn();
let inner_fn = param.gen_inner_fn();
let from_inner_fn = param.gen_from_inner_fn();
let detach_fn = param.gen_detach_fn();
let clone_fn = param.gen_clone_fn();
let generator = FnGenerator::from_ast(ast);
let num_params_fn = generator.gen_num_params_fn();
let visit = generator.gen_visit_fn();
let map_mut = generator.gen_map_fn();
let devices_fn = generator.gen_devices_fn();
let to_device_fn = generator.gen_to_device_fn();
let state_fn = generator.gen_state_fn();
let load_fn = generator.gen_load_fn();
let inner_fn = generator.gen_inner_fn();
let from_inner_fn = generator.gen_from_inner_fn();
let detach_fn = generator.gen_detach_fn();
let clone_fn = generator.gen_clone_fn();
let generics_names_except_backend = generics_names_except_backend(&ast.generics);
let gen = quote! {
impl #generics burn::module::Module for #name #generics_ty #generics_where {
type Backend=B;
impl #generics burn::module::Module<B> for #name #generics_ty #generics_where {
#devices_fn
#to_device_fn
@ -40,8 +38,7 @@ pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
#map_mut
}
impl #generics burn::module::ADModule for #name #generics_ty where B: burn::tensor::backend::ADBackend, {
type ADBackend=B;
impl #generics burn::module::ADModule<B> for #name #generics_ty where B: burn::tensor::backend::ADBackend, {
type InnerModule=#name<B::InnerBackend, #generics_names_except_backend>;
#inner_fn

View File

@ -0,0 +1,246 @@
use crate::shared::field::{parse_fields, FieldTypeAnalyzer};
use proc_macro2::{Ident, TokenStream};
use quote::quote;
pub struct FnGenerator {
fields: Vec<FieldTypeAnalyzer>,
}
impl FnGenerator {
pub fn from_ast(ast: &syn::DeriveInput) -> Self {
Self {
fields: parse_fields(ast)
.into_iter()
.map(FieldTypeAnalyzer::new)
.collect(),
}
}
pub fn gen_num_params_fn(&self) -> TokenStream {
let body = self.gen_fields_fn(|name| {
quote! {
num_params += burn::module::Module::<B>::num_params(&self.#name);
}
});
quote! {
fn num_params(&self) -> usize {
let mut num_params = 0;
#body
num_params
}
}
}
pub fn gen_visit_fn(&self) -> TokenStream {
let body = self.gen_fields_fn(|name| {
quote! {
burn::module::Module::visit(&self.#name, visitor);
}
});
quote! {
fn visit<V: burn::module::ModuleVisitor<B>>(&self, visitor: &mut V) {
#body
}
}
}
pub fn gen_map_fn(&self) -> TokenStream {
let (names, body) = self.gen_fields_fn_names(|name| {
quote! {
let #name = burn::module::Module::map(self.#name, mapper);
}
});
quote! {
fn map<M: burn::module::ModuleMapper<B>>(self, mapper: &mut M) -> Self {
#body
Self {
#(#names),*
}
}
}
}
pub fn gen_devices_fn(&self) -> TokenStream {
let body = self.gen_fields_fn(|name| {
quote! {
devices.append(&mut burn::module::Module::<B>::devices(&self.#name));
}
});
quote! {
fn devices(&self) -> Vec<B::Device> {
let mut devices = Vec::new();
#body
devices
}
}
}
pub fn gen_to_device_fn(&self) -> TokenStream {
let (names, body) = self.gen_fields_fn_names(|name| {
quote! {
let #name = burn::module::Module::<B>::to_device(self.#name, device);
}
});
quote! {
fn to_device(self, device: &B::Device) -> Self {
#body
Self {
#(#names),*
}
}
}
}
pub fn gen_detach_fn(&self) -> TokenStream {
let (names, body) = self.gen_fields_fn_names(|name| {
quote! {
let #name = burn::module::Module::<B>::detach(self.#name);
}
});
quote! {
fn detach(self) -> Self {
#body
Self {
#(#names),*
}
}
}
}
pub fn gen_inner_fn(&self) -> TokenStream {
let (names, body) = self.gen_fields_fn_names(|name| {
quote! {
let #name = burn::module::ADModule::<B>::inner(self.#name);
}
});
quote! {
fn inner(self) -> Self::InnerModule {
#body
Self::InnerModule {
#(#names),*
}
}
}
}
pub fn gen_from_inner_fn(&self) -> TokenStream {
let (names, body) = self.gen_fields_fn_names(|name| {
quote! {
let #name = burn::module::ADModule::<B>::from_inner(module.#name);
}
});
quote! {
fn from_inner(module: Self::InnerModule) -> Self {
#body
Self {
#(#names),*
}
}
}
}
pub fn gen_clone_fn(&self) -> TokenStream {
let (names, body) = self.gen_fields_fn_names(|name| {
quote! {
let #name = self.#name.clone();
}
});
quote! {
fn clone(&self) -> Self {
#body
Self {
#(#names),*
}
}
}
}
pub fn gen_state_fn(&self) -> TokenStream {
let body = self.gen_fields_fn(|name| {
quote! {
state.register_state(stringify!(#name), burn::module::Module::<B>::state(&self.#name));
}
});
quote! {
fn state(&self) -> burn::module::State<B::FloatElem>
{
let mut state = burn::module::StateNamed::new();
#body
burn::module::State::StateNamed(state)
}
}
}
pub fn gen_load_fn(&self) -> TokenStream {
let (names, body) = self.gen_fields_fn_names(|name| {
quote! {
let state_mod = state.get(stringify!(#name)).ok_or(
burn::module::LoadingError::new(format!(
"Missing module '{}' from state",
stringify!(#name),
)))?;
let #name = burn::module::Module::<B>::load(self.#name, state_mod).map_err(|err| {
burn::module::LoadingError::new(format!("Can't load module {}: {}", stringify!(#name), err))
})?;
}
});
quote! {
fn load(self, state: &burn::module::State<B::FloatElem>) -> Result<Self, burn::module::LoadingError>
{
#body
Ok(Self {
#(#names),*
})
}
}
}
pub fn gen_fields_fn_names<F>(&self, func: F) -> (Vec<Ident>, TokenStream)
where
F: Fn(Ident) -> TokenStream,
{
let mut body = quote! {};
let mut names = Vec::new();
for field in self.fields.iter() {
let name = field.ident();
names.push(name.clone());
body.extend(func(field.ident()));
}
(names, body)
}
pub fn gen_fields_fn<F>(&self, func: F) -> TokenStream
where
F: Fn(Ident) -> TokenStream,
{
let mut body = quote! {};
for field in self.fields.iter() {
body.extend(func(field.ident()));
}
body
}
}

View File

@ -1,5 +1,5 @@
pub(crate) mod display;
pub(crate) mod param;
pub(crate) mod fn_generator;
mod base;

View File

@ -1,319 +0,0 @@
use crate::shared::field::{parse_fields, FieldTypeAnalyzer};
use proc_macro2::{Ident, TokenStream};
use quote::quote;
pub struct Param {
fields_param: Vec<FieldTypeAnalyzer>,
fields_other: Vec<FieldTypeAnalyzer>,
}
impl Param {
pub fn from_ast(ast: &syn::DeriveInput) -> Self {
let fields_param = parse_fields(ast)
.into_iter()
.map(FieldTypeAnalyzer::new)
.filter(FieldTypeAnalyzer::is_param)
.collect();
let fields_other = parse_fields(ast)
.into_iter()
.map(FieldTypeAnalyzer::new)
.filter(|val| !FieldTypeAnalyzer::is_param(val))
.collect();
Self {
fields_param,
fields_other,
}
}
pub fn gen_num_params_fn(&self) -> TokenStream {
let mut body = quote! {
let mut num_params = 0;
};
for field in self.fields_param.iter() {
let name = field.ident();
body.extend(quote! {
num_params += self.#name.num_params();
});
}
body.extend(quote! {
num_params
});
quote! {
fn num_params(&self) -> usize {
#body
}
}
}
pub fn gen_visit_fn(&self) -> TokenStream {
let mut body = quote! {};
for field in self.fields_param.iter() {
let name = field.ident();
body.extend(quote! {
self.#name.visit(visitor);
});
}
quote! {
fn visit<V: burn::module::ModuleVisitor<Self::Backend>>(&self, visitor: &mut V) {
#body
}
}
}
pub fn gen_map_fn(&self) -> TokenStream {
let (names, body) = self.gen_params_others_fn(
|name| {
quote! {
let #name = self.#name.map(mapper);
}
},
|name| {
quote! {
let #name = self.#name;
}
},
);
quote! {
fn map<M: burn::module::ModuleMapper<Self::Backend>>(self, mapper: &mut M) -> Self {
#body
Self {
#(#names),*
}
}
}
}
pub fn gen_devices_fn(&self) -> TokenStream {
let mut body = quote! {
let mut devices = Vec::new();
};
for field in self.fields_param.iter() {
let name = field.ident();
body.extend(quote! {
devices.append(&mut self.#name.devices());
});
}
body.extend(quote! {
devices
});
quote! {
fn devices(&self) -> Vec<B::Device> {
#body
}
}
}
pub fn gen_to_device_fn(&self) -> TokenStream {
let (names, body) = self.gen_params_others_fn(
|name| {
quote! {
let #name = self.#name.to_device(device);
}
},
|name| {
quote! {
let #name = self.#name.clone();
}
},
);
quote! {
fn to_device(self, device: &B::Device) -> Self {
#body
Self {
#(#names),*
}
}
}
}
pub fn gen_detach_fn(&self) -> TokenStream {
let (names, body) = self.gen_params_others_fn(
|name| {
quote! {
let #name = self.#name.detach();
}
},
|name| {
quote! {
let #name = self.#name.clone();
}
},
);
quote! {
fn detach(self) -> Self {
#body
Self {
#(#names),*
}
}
}
}
pub fn gen_inner_fn(&self) -> TokenStream {
let (names, body) = self.gen_params_others_fn(
|name| {
quote! {
let #name = self.#name.inner();
}
},
|name| {
quote! {
let #name = self.#name.clone();
}
},
);
quote! {
fn inner(self) -> Self::InnerModule {
#body
Self::InnerModule {
#(#names),*
}
}
}
}
pub fn gen_from_inner_fn(&self) -> TokenStream {
let (names, body) = self.gen_params_others_fn(
|name| {
quote! {
let #name = burn::module::ADModule::from_inner(module.#name);
}
},
|name| {
quote! {
let #name = module.#name.clone();
}
},
);
quote! {
fn from_inner(module: Self::InnerModule) -> Self {
#body
Self {
#(#names),*
}
}
}
}
pub fn gen_clone_fn(&self) -> TokenStream {
let mut body = quote! {};
let mut names = Vec::new();
let mut fields = Vec::new();
fields.append(&mut self.fields_param.clone());
fields.append(&mut self.fields_other.clone());
for field in fields {
let name = field.ident();
names.push(name.clone());
body.extend(quote! {
let #name = self.#name.clone();
});
}
quote! {
fn clone(&self) -> Self {
#body
Self {
#(#names),*
}
}
}
}
pub fn gen_state_fn(&self) -> TokenStream {
let mut body = quote! {
let mut state = burn::module::StateNamed::new();
};
for field in self.fields_param.iter() {
let name = field.ident();
body.extend(quote! {
state.register_state(stringify!(#name), self.#name.state());
});
}
quote! {
fn state(&self) -> burn::module::State<<Self::Backend as burn::tensor::backend::Backend>::FloatElem>
{
#body
burn::module::State::StateNamed(state)
}
}
}
pub fn gen_load_fn(&self) -> TokenStream {
let (names, body) = self.gen_params_others_fn(|name| {
quote! {
let state_mod = state.get(stringify!(#name)).ok_or(
burn::module::LoadingError::new(format!(
"Missing module '{}' from state",
stringify!(#name),
)))?;
let #name = self.#name.load(state_mod).map_err(|err| {
burn::module::LoadingError::new(format!("Can't load module {}: {}", stringify!(#name), err))
})?;
}
}, |name| {
quote! {
let #name = self.#name.clone();
}
});
quote! {
fn load(self, state: &burn::module::State<<Self::Backend as burn::tensor::backend::Backend>::FloatElem>) -> Result<Self, burn::module::LoadingError>
{
#body
Ok(Self {
#(#names),*
})
}
}
}
pub fn gen_params_others_fn<FP, FO>(
&self,
func_params: FP,
func_others: FO,
) -> (Vec<Ident>, TokenStream)
where
FP: Fn(Ident) -> TokenStream,
FO: Fn(Ident) -> TokenStream,
{
let mut body = quote! {};
let mut names = Vec::new();
for field in self.fields_param.iter() {
let name = field.ident();
names.push(name.clone());
body.extend(func_params(field.ident()));
}
for field in self.fields_other.iter() {
let name = field.ident();
names.push(name.clone());
body.extend(func_others(field.ident()));
}
(names, body)
}
}

View File

@ -73,10 +73,6 @@ impl FieldTypeAnalyzer {
.into_iter()
.map(AttributeAnalyzer::new)
}
pub fn is_param(&self) -> bool {
self.is_of_type(&["Param", "burn::Param"])
}
}
pub(crate) fn parse_fields(ast: &syn::DeriveInput) -> Vec<Field> {

View File

@ -4,14 +4,14 @@ use alloc::{format, vec::Vec};
use burn::{
config::Config,
module::{Module, Param},
module::Module,
nn,
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct ConvBlock<B: Backend> {
conv: Param<nn::conv::Conv2d<B>>,
conv: nn::conv::Conv2d<B>,
pool: nn::pool::MaxPool2d,
activation: nn::GELU,
}
@ -34,7 +34,7 @@ impl<B: Backend> ConvBlock<B> {
let activation = nn::GELU::new();
Self {
conv: Param::from(conv),
conv,
pool,
activation,
}

View File

@ -4,7 +4,7 @@ use alloc::{format, vec::Vec};
use burn::{
config::Config,
module::{Module, Param},
module::Module,
nn,
tensor::{backend::Backend, Tensor},
};
@ -26,7 +26,7 @@ pub struct MlpConfig {
/// Multilayer Perceptron module.
#[derive(Module, Debug)]
pub struct Mlp<B: Backend> {
linears: Param<Vec<nn::Linear<B>>>,
linears: Vec<nn::Linear<B>>,
dropout: nn::Dropout,
activation: nn::ReLU,
}
@ -41,7 +41,7 @@ impl<B: Backend> Mlp<B> {
}
Self {
linears: Param::from(linears),
linears,
dropout: nn::DropoutConfig::new(0.3).init(),
activation: nn::ReLU::new(),
}

View File

@ -9,7 +9,7 @@ use crate::{
use burn::{
config::Config,
module::{Module, Param},
module::Module,
nn,
tensor::{backend::Backend, Tensor},
};
@ -30,26 +30,25 @@ pub struct MnistConfig {
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
mlp: Param<Mlp<B>>,
conv: Param<ConvBlock<B>>,
input: Param<nn::Linear<B>>,
output: Param<nn::Linear<B>>,
mlp: Mlp<B>,
conv: ConvBlock<B>,
input: nn::Linear<B>,
output: nn::Linear<B>,
num_classes: usize,
}
impl<B: Backend> Model<B> {
pub fn new(config: &MnistConfig) -> Self {
let mlp = Mlp::new(&config.mlp);
let input = nn::LinearConfig::new(config.input_size, config.mlp.d_model).init();
let output = nn::LinearConfig::new(config.mlp.d_model, config.output_size).init();
let conv = ConvBlock::new(&ConvBlockConfig::new([1, 1]));
Self {
mlp: Param::from(mlp),
conv: Param::from(conv),
output: Param::from(output),
input: Param::from(input),
mlp,
conv,
output,
input,
num_classes: config.output_size,
}
}

View File

@ -1,44 +1,45 @@
use crate::checkpoint::Checkpointer;
use crate::LearnerCallback;
use burn_core::module::{ADModule, Module};
use burn_core::module::ADModule;
use burn_core::optim::Optimizer;
use burn_core::tensor::backend::Backend;
use burn_core::tensor::backend::{ADBackend, Backend};
/// Learner struct encapsulating all components necessary to train a Neural Network model.
///
/// To create a learner, use the [builder](crate::train::LearnerBuilder) struct.
pub struct Learner<M, O, TO, VO>
pub struct Learner<B, M, O, TO, VO>
where
M: ADModule,
B: ADBackend,
M: ADModule<B>,
O: Optimizer<M, B>,
{
pub(super) model: M,
pub(super) optim: O,
pub(super) num_epochs: usize,
pub(super) callback: Box<dyn LearnerCallback<TO, VO>>,
pub(super) checkpoint: Option<usize>,
pub(super) checkpointer_model: CheckpointModel<M>,
pub(super) checkpointer_optimizer: CheckpointOptim<M>,
pub(super) checkpointer_model: CheckpointModel<B>,
pub(super) checkpointer_optimizer: CheckpointOptim<B>,
pub(super) grad_accumulation: Option<usize>,
pub(super) devices: Vec<<M::Backend as Backend>::Device>,
pub(super) devices: Vec<B::Device>,
}
type CheckpointModel<M> =
Option<Box<dyn Checkpointer<<<M as Module>::Backend as Backend>::FloatElem>>>;
type CheckpointOptim<M> =
Option<Box<dyn Checkpointer<<<M as Module>::Backend as Backend>::FloatElem>>>;
type CheckpointModel<B> = Option<Box<dyn Checkpointer<<B as Backend>::FloatElem>>>;
type CheckpointOptim<B> = Option<Box<dyn Checkpointer<<B as Backend>::FloatElem>>>;
impl<M, O, TO, VO> Learner<M, O, TO, VO>
impl<B, M, O, TO, VO> Learner<B, M, O, TO, VO>
where
VO: Send + Sync + 'static,
TO: Send + Sync + 'static,
M: ADModule,
O: Optimizer<Backend = M::Backend>,
B: ADBackend,
M: ADModule<B>,
O: Optimizer<M, B>,
{
pub(super) fn checkpoint(
model: &M,
optim: &O,
checkpointer_model: &CheckpointModel<M>,
checkpointer_optimizer: &CheckpointOptim<M>,
checkpointer_model: &CheckpointModel<B>,
checkpointer_optimizer: &CheckpointOptim<B>,
epoch: usize,
) {
if let Some(checkpointer) = &checkpointer_model {

View File

@ -161,10 +161,10 @@ where
}
/// Create the [learner](Learner) from a [module](ADModule) and an [optimizer](Optimizer).
pub fn build<M, O>(self, model: M, optim: O) -> Learner<M, O, T, V>
pub fn build<M, O>(self, model: M, optim: O) -> Learner<B, M, O, T, V>
where
M: ADModule<ADBackend = B>,
O: Optimizer<Backend = B>,
M: ADModule<B>,
O: Optimizer<M, B>,
{
self.init_logger();
let callack = Box::new(self.dashboard);

View File

@ -2,7 +2,7 @@ use burn_core::{
data::dataloader::DataLoader,
module::ADModule,
optim::{GradientsAccumulator, Optimizer},
tensor::backend::Backend,
tensor::backend::ADBackend,
};
use std::sync::Arc;
@ -24,9 +24,10 @@ pub struct TrainEpoch<TI> {
}
impl<I> ValidEpoch<I> {
pub fn run<M, TO, VO>(&self, model: M, callback: &mut Box<dyn LearnerCallback<TO, VO>>) -> M
pub fn run<B, M, TO, VO>(&self, model: M, callback: &mut Box<dyn LearnerCallback<TO, VO>>) -> M
where
M: ADModule,
B: ADBackend,
M: ADModule<B>,
M::InnerModule: ValidStep<I, VO>,
{
log::info!("Executing validation step for epoch {}", self.epoch);
@ -55,15 +56,16 @@ impl<I> ValidEpoch<I> {
}
impl<TI> TrainEpoch<TI> {
pub fn run<M, O, TO, VO>(
pub fn run<B, M, O, TO, VO>(
&self,
mut model: M,
mut optim: O,
callback: &mut Box<dyn LearnerCallback<TO, VO>>,
) -> (M, O)
where
M: ADModule,
O: Optimizer<Backend = M::ADBackend>,
B: ADBackend,
M: ADModule<B>,
O: Optimizer<M, B>,
M: TrainStep<TI, TO>,
{
log::info!("Executing training step for epoch {}", self.epoch,);
@ -108,17 +110,18 @@ impl<TI> TrainEpoch<TI> {
}
impl<TI> TrainEpoch<TI> {
pub fn run_multi_device<M, O, TO, VO>(
pub fn run_multi_device<B, M, O, TO, VO>(
&self,
mut model: M,
mut optim: O,
callback: &mut Box<dyn LearnerCallback<TO, VO>>,
devices: Vec<<M::Backend as Backend>::Device>,
devices: Vec<B::Device>,
) -> (M, O)
where
O: Optimizer<Backend = M::ADBackend>,
B: ADBackend,
M: ADModule<B> + 'static,
O: Optimizer<M, B>,
M: TrainStep<TI, TO>,
M: ADModule + 'static,
TI: Send + 'static,
TO: Send + 'static,
{

View File

@ -20,9 +20,10 @@ struct Worker<B: ADBackend, M, TI> {
device: B::Device,
}
impl<B: ADBackend, M, TI> Worker<B, M, TI>
impl<B, M, TI> Worker<B, M, TI>
where
M: ADModule<ADBackend = B> + Clone,
B: ADBackend,
M: ADModule<B>,
{
fn register(&self, item: TI, model: &M) {
let message = Message {
@ -63,9 +64,9 @@ where
impl<B, M, TI, TO> MultiDevicesTrainStep<B, M, TI, TO>
where
B: ADBackend,
M: ADModule<B> + TrainStep<TI, TO> + Send + Clone + 'static,
TI: Send + 'static,
TO: Send + 'static,
M: ADModule<ADBackend = B> + TrainStep<TI, TO> + Send + Clone + 'static,
{
pub fn new(devices: &[B::Device]) -> Self
where

View File

@ -13,13 +13,8 @@ pub struct TrainOutput<TO> {
}
impl<TO> TrainOutput<TO> {
pub fn new<M: ADModule>(
module: &M,
grads: <M::ADBackend as ADBackend>::Gradients,
item: TO,
) -> Self {
pub fn new<B: ADBackend, M: ADModule<B>>(module: &M, grads: B::Gradients, item: TO) -> Self {
let grads = GradientsParams::from_grads(grads, module);
Self { grads, item }
}
}
@ -32,12 +27,13 @@ pub trait ValidStep<VI, VO> {
fn step(&self, item: VI) -> VO;
}
impl<M, O, TO, VO> Learner<M, O, TO, VO>
impl<B, M, O, TO, VO> Learner<B, M, O, TO, VO>
where
VO: Send + Sync + 'static,
TO: Send + Sync + 'static,
M: ADModule,
O: Optimizer<Backend = M::Backend>,
B: ADBackend,
M: ADModule<B> + core::fmt::Display,
O: Optimizer<M, B>,
{
pub fn fit<TI, VI>(
mut self,

View File

@ -5,18 +5,18 @@
use alloc::{format, vec::Vec};
use burn::{
module::{Module, Param},
module::Module,
nn::{self, conv::Conv2dPaddingConfig, BatchNorm},
tensor::{backend::Backend, Tensor},
};
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
conv1: Param<ConvBlock<B>>,
conv2: Param<ConvBlock<B>>,
conv3: Param<ConvBlock<B>>,
fc1: Param<nn::Linear<B>>,
fc2: Param<nn::Linear<B>>,
conv1: ConvBlock<B>,
conv2: ConvBlock<B>,
conv3: ConvBlock<B>,
fc1: nn::Linear<B>,
fc2: nn::Linear<B>,
activation: nn::GELU,
}
@ -37,11 +37,11 @@ impl<B: Backend> Model<B> {
.init();
Self {
conv1: Param::from(conv1),
conv2: Param::from(conv2),
conv3: Param::from(conv3),
fc1: Param::from(fc1),
fc2: Param::from(fc2),
conv1,
conv2,
conv3,
fc1,
fc2,
activation: nn::GELU::new(),
}
}
@ -66,8 +66,8 @@ impl<B: Backend> Model<B> {
#[derive(Module, Debug)]
pub struct ConvBlock<B: Backend> {
conv: Param<nn::conv::Conv2d<B>>,
norm: Param<BatchNorm<B, 2>>,
conv: nn::conv::Conv2d<B>,
norm: BatchNorm<B, 2>,
activation: nn::GELU,
}
@ -79,8 +79,8 @@ impl<B: Backend> ConvBlock<B> {
let norm = nn::BatchNormConfig::new(channels[1]).init();
Self {
conv: Param::from(conv),
norm: Param::from(norm),
conv,
norm,
activation: nn::GELU::new(),
}
}

View File

@ -1,7 +1,7 @@
use crate::data::MNISTBatch;
use burn::{
module::{Module, Param},
module::Module,
nn::{self, conv::Conv2dPaddingConfig, loss::CrossEntropyLoss, BatchNorm},
tensor::{
backend::{ADBackend, Backend},
@ -12,12 +12,12 @@ use burn::{
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
conv1: Param<ConvBlock<B>>,
conv2: Param<ConvBlock<B>>,
conv3: Param<ConvBlock<B>>,
conv1: ConvBlock<B>,
conv2: ConvBlock<B>,
conv3: ConvBlock<B>,
dropout: nn::Dropout,
fc1: Param<nn::Linear<B>>,
fc2: Param<nn::Linear<B>>,
fc1: nn::Linear<B>,
fc2: nn::Linear<B>,
activation: nn::GELU,
}
@ -39,11 +39,11 @@ impl<B: Backend> Model<B> {
let dropout = nn::DropoutConfig::new(0.3).init();
Self {
conv1: Param::from(conv1),
conv2: Param::from(conv2),
conv3: Param::from(conv3),
fc1: Param::from(fc1),
fc2: Param::from(fc2),
conv1,
conv2,
conv3,
fc1,
fc2,
dropout,
activation: nn::GELU::new(),
}
@ -83,8 +83,8 @@ impl<B: Backend> Model<B> {
#[derive(Module, Debug)]
pub struct ConvBlock<B: Backend> {
conv: Param<nn::conv::Conv2d<B>>,
norm: Param<BatchNorm<B, 2>>,
conv: nn::conv::Conv2d<B>,
norm: BatchNorm<B, 2>,
activation: nn::GELU,
}
@ -96,8 +96,8 @@ impl<B: Backend> ConvBlock<B> {
let norm = nn::BatchNormConfig::new(channels[1]).init();
Self {
conv: Param::from(conv),
norm: Param::from(norm),
conv,
norm,
activation: nn::GELU::new(),
}
}

View File

@ -1,7 +1,7 @@
use crate::data::{TextClassificationInferenceBatch, TextClassificationTrainingBatch};
use burn::{
config::Config,
module::{Module, Param},
module::Module,
nn::{
loss::CrossEntropyLoss,
transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput},
@ -22,10 +22,10 @@ pub struct TextClassificationModelConfig {
#[derive(Module, Debug)]
pub struct TextClassificationModel<B: Backend> {
transformer: Param<TransformerEncoder<B>>,
embedding_token: Param<Embedding<B>>,
embedding_pos: Param<Embedding<B>>,
output: Param<Linear<B>>,
transformer: TransformerEncoder<B>,
embedding_token: Embedding<B>,
embedding_pos: Embedding<B>,
output: Linear<B>,
n_classes: usize,
max_seq_length: usize,
}
@ -40,10 +40,10 @@ impl TextClassificationModelConfig {
EmbeddingConfig::new(self.max_seq_length, self.transformer.d_model).init();
TextClassificationModel {
transformer: Param::from(transformer),
embedding_token: Param::from(embedding_token),
embedding_pos: Param::from(embedding_pos),
output: Param::from(output),
transformer,
embedding_token,
embedding_pos,
output,
n_classes: self.n_classes,
max_seq_length: self.max_seq_length,
}

View File

@ -1,7 +1,7 @@
use crate::data::TrainingTextGenerationBatch;
use burn::{
config::Config,
module::{Module, Param},
module::Module,
nn::{
attention::generate_autoregressive_mask,
loss::CrossEntropyLoss,
@ -23,10 +23,10 @@ pub struct TextGenerationModelConfig {
#[derive(Module, Debug)]
pub struct TextGenerationModel<B: Backend> {
transformer: Param<TransformerEncoder<B>>,
embedding_token: Param<Embedding<B>>,
embedding_pos: Param<Embedding<B>>,
output: Param<Linear<B>>,
transformer: TransformerEncoder<B>,
embedding_token: Embedding<B>,
embedding_pos: Embedding<B>,
output: Linear<B>,
vocab_size: usize,
pad_token: usize,
max_seq_length: usize,
@ -42,10 +42,10 @@ impl TextGenerationModelConfig {
EmbeddingConfig::new(self.max_seq_length, self.transformer.d_model).init();
TextGenerationModel {
transformer: Param::from(transformer),
embedding_token: Param::from(embedding_token),
embedding_pos: Param::from(embedding_pos),
output: Param::from(output),
transformer,
embedding_token,
embedding_pos,
output,
vocab_size: self.vocab_size,
pad_token: self.pad_token,
max_seq_length: self.max_seq_length,