mirror of https://github.com/tracel-ai/burn.git

Refactor Param wrapping only for Tensor (#259)

parent 7364d09d32
commit 32d38bebc3
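The diff below changes the `Module` trait from an associated `Backend` type to a generic parameter (`Module<B: Backend>`) and keeps `Param` wrapping only for tensors; sub-modules and constant fields are now stored directly. A minimal sketch of a module written against the refactored trait, based on the doc-comment example further down in this diff (the field names are illustrative only, not part of the commit):

    use burn::{
        module::{Module, Param},
        nn,
        tensor::{backend::Backend, Tensor},
    };

    #[derive(Module, Debug)]
    struct MyModule<B: Backend> {
        // Sub-modules are no longer wrapped in Param.
        linear: nn::Linear<B>,
        // Raw tensors still use Param so they are tracked as parameters.
        scale: Param<Tensor<B, 1>>,
        // Constant (non-parameter) fields are stored directly.
        n_heads: usize,
    }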
@@ -11,9 +11,6 @@ use burn_tensor::Tensor;
/// This will make your module trainable, savable and loadable via
/// [state](Module::state) and [load](Module::load).
///
/// Module concrete types should define their parameters via the [Param](crate::module::Param)
/// struct.
///
/// # Example
///
/// A module should have a [backend](crate::tensor::backend::Backend) defined as a generic
@@ -26,39 +23,34 @@ use burn_tensor::Tensor;
///
/// use burn::{
///     nn,
///     module::{Param, Module},
///     module::Module,
///     tensor::Tensor,
///     tensor::backend::Backend,
/// };
///
/// #[derive(Module, Debug)]
/// struct MyModule<B: Backend> {
///     my_param: Param<nn::Linear<B>>,
///     my_param: nn::Linear<B>,
///     my_other_field: usize,
/// }
/// ```
pub trait Module: Clone + Send + Sync + core::fmt::Debug + core::fmt::Display {
    type Backend: Backend;

pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
    /// Get the device list of the module and all of its sub-modules.
    fn devices(&self) -> Vec<<Self::Backend as Backend>::Device>;
    fn devices(&self) -> Vec<B::Device>;

    /// Move the module and all of its sub-modules to the given device.
    fn to_device(self, device: &<Self::Backend as Backend>::Device) -> Self;
    fn to_device(self, device: &B::Device) -> Self;

    /// Load the module state.
    fn load(
        self,
        state: &State<<Self::Backend as Backend>::FloatElem>,
    ) -> Result<Self, LoadingError>;
    fn load(self, state: &State<B::FloatElem>) -> Result<Self, LoadingError>;

    /// Get the module state.
    fn state(&self) -> State<<Self::Backend as Backend>::FloatElem>;
    fn state(&self) -> State<B::FloatElem>;

    /// Detach the module from the graph.
    fn detach(self) -> Self;

    /// Get the number of parameters the module has, including all of its sub-modules.
    fn num_params(&self) -> usize;

    /// Visit each tensor in the module with a [visitor](ModuleVisitor).
    fn visit<V: ModuleVisitor<Self::Backend>>(&self, visitor: &mut V);
    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V);

    /// Map each tensor in the module with a [mapper](ModuleMapper).
    fn map<M: ModuleMapper<Self::Backend>>(self, mapper: &mut M) -> Self;
    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self;
}

pub trait ModuleVisitor<B: Backend> {
@@ -70,11 +62,8 @@ pub trait ModuleMapper<B: Backend> {
}

/// Module with auto-differentiation backend.
pub trait ADModule:
    Module<Backend = Self::ADBackend> + Send + Sync + core::fmt::Debug + core::fmt::Display
{
    type ADBackend: ADBackend;
    type InnerModule: Module<Backend = <Self::ADBackend as ADBackend>::InnerBackend>;
pub trait ADModule<B: ADBackend>: Module<B> + Send + Sync + core::fmt::Debug {
    type InnerModule: Module<B::InnerBackend>;

    /// Get the same module, but on the inner backend without auto-differentiation.
    fn inner(self) -> Self::InnerModule;
@@ -0,0 +1,86 @@
use crate as burn;

#[macro_export]
macro_rules! constant {
    (module) => {
        fn devices(&self) -> alloc::vec::Vec<<B as burn_tensor::backend::Backend>::Device> {
            alloc::vec::Vec::new()
        }

        fn to_device(self, _device: &<B as burn_tensor::backend::Backend>::Device) -> Self {
            self
        }

        fn load(
            self,
            _state: &burn::module::State<<B as burn_tensor::backend::Backend>::FloatElem>,
        ) -> Result<Self, burn::module::LoadingError> {
            Ok(self)
        }

        fn state(&self) -> burn::module::State<<B as burn_tensor::backend::Backend>::FloatElem> {
            burn::module::State::StateNamed(burn::module::StateNamed::new())
        }

        fn detach(self) -> Self {
            self
        }

        fn num_params(&self) -> usize {
            0
        }

        fn visit<V: burn::module::ModuleVisitor<B>>(&self, _visitor: &mut V) {
            // Nothing to do
        }

        fn map<M: burn::module::ModuleMapper<B>>(self, _mapper: &mut M) -> Self {
            self
        }
    };

    (ad_module, $type:ty) => {
        type InnerModule = $type;

        fn inner(self) -> Self::InnerModule {
            self
        }

        fn from_inner(module: Self::InnerModule) -> Self {
            module
        }
    };

    ($type:ty) => {
        impl<B: burn::tensor::backend::Backend> burn::module::Module<B> for $type {
            constant!(module);
        }

        impl<B: burn::tensor::backend::ADBackend> burn::module::ADModule<B> for $type {
            constant!(ad_module, $type);
        }
    };
}

// General Types
constant!(alloc::string::String);
constant!(bool);

// Float Types
constant!(f64);
constant!(f32);
constant!(half::bf16);
constant!(half::f16);

// Unsigned Integer Types
constant!(usize);
constant!(u64);
constant!(u32);
constant!(u16);
constant!(u8);

// Signed Integer Types
constant!(i64);
constant!(i32);
constant!(i16);
constant!(i8);
@@ -1,13 +1,13 @@
mod base;
mod constant;
mod id;
mod module;
mod primitive;
mod running;
mod tensor;
mod visitor;

pub use base::*;
pub use id::*;
pub use module::*;
pub use running::*;
pub use tensor::*;
pub use visitor::*;
@ -1,201 +0,0 @@
|
|||
use alloc::{format, vec::Vec};
|
||||
|
||||
use super::{load_with_id, state_with_id, Param, ParamId};
|
||||
use crate::module::{
|
||||
ADModule, LoadingError, Module, ModuleMapper, ModuleVisitor, State, StateNamed,
|
||||
};
|
||||
use crate::tensor::backend::Backend;
|
||||
|
||||
impl<M: Module> From<M> for Param<M> {
|
||||
fn from(value: M) -> Self {
|
||||
Param {
|
||||
id: ParamId::new(),
|
||||
value,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<M: Module> From<Vec<M>> for Param<Vec<M>> {
|
||||
fn from(value: Vec<M>) -> Self {
|
||||
Param {
|
||||
id: ParamId::new(),
|
||||
value,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl<M: Module> Module for Param<M> {
|
||||
type Backend = M::Backend;
|
||||
|
||||
fn num_params(&self) -> usize {
|
||||
self.value.num_params()
|
||||
}
|
||||
|
||||
fn devices(&self) -> Vec<<M::Backend as Backend>::Device> {
|
||||
self.value.devices()
|
||||
}
|
||||
|
||||
fn to_device(self, device: &<Self::Backend as Backend>::Device) -> Self {
|
||||
Param {
|
||||
id: self.id,
|
||||
value: self.value.to_device(device),
|
||||
}
|
||||
}
|
||||
|
||||
fn state(&self) -> State<<M::Backend as Backend>::FloatElem> {
|
||||
let state = self.value.state();
|
||||
|
||||
state_with_id(self.id.clone(), state)
|
||||
}
|
||||
|
||||
fn load(self, state: &State<<M::Backend as Backend>::FloatElem>) -> Result<Self, LoadingError> {
|
||||
let (id, state) = load_with_id(state)?;
|
||||
|
||||
Ok(Self {
|
||||
id: id.clone(),
|
||||
value: self.value.load(state)?,
|
||||
})
|
||||
}
|
||||
|
||||
fn detach(self) -> Self {
|
||||
Param {
|
||||
id: self.id,
|
||||
value: self.value.detach(),
|
||||
}
|
||||
}
|
||||
|
||||
fn visit<V: ModuleVisitor<Self::Backend>>(&self, visitor: &mut V) {
|
||||
self.value.visit(visitor);
|
||||
}
|
||||
|
||||
fn map<V: ModuleMapper<Self::Backend>>(self, mapper: &mut V) -> Self {
|
||||
Self {
|
||||
id: self.id,
|
||||
value: self.value.map(mapper),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<M: Module> Module for Param<Vec<M>> {
|
||||
type Backend = M::Backend;
|
||||
|
||||
fn num_params(&self) -> usize {
|
||||
let mut num_params = 0;
|
||||
for module in self.value.iter() {
|
||||
num_params += module.num_params();
|
||||
}
|
||||
|
||||
num_params
|
||||
}
|
||||
|
||||
fn devices(&self) -> Vec<<M::Backend as Backend>::Device> {
|
||||
let mut devices = Vec::new();
|
||||
for module in self.value.iter() {
|
||||
devices.append(&mut module.devices());
|
||||
}
|
||||
devices
|
||||
}
|
||||
|
||||
fn to_device(self, device: &<M::Backend as Backend>::Device) -> Self {
|
||||
Param {
|
||||
id: self.id,
|
||||
value: self
|
||||
.value
|
||||
.into_iter()
|
||||
.map(|val| val.to_device(device))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
fn state(&self) -> State<<M::Backend as Backend>::FloatElem> {
|
||||
let mut state = StateNamed::new();
|
||||
|
||||
for (i, module) in self.value.iter().enumerate() {
|
||||
state.register_state(format!("mod-{i}").as_str(), module.state());
|
||||
}
|
||||
|
||||
let state = State::StateNamed(state);
|
||||
|
||||
state_with_id(self.id.clone(), state)
|
||||
}
|
||||
|
||||
fn load(self, state: &State<<M::Backend as Backend>::FloatElem>) -> Result<Self, LoadingError> {
|
||||
let (id, state) = load_with_id(state)?;
|
||||
let id = id.clone();
|
||||
|
||||
let num = self.value.len();
|
||||
let mut modules = Vec::with_capacity(num);
|
||||
|
||||
for (i, module) in self.value.into_iter().enumerate() {
|
||||
let module = module
|
||||
.load(state.get(format!("mod-{i}").as_str()).ok_or_else(|| {
|
||||
LoadingError::new(format!(
|
||||
"Invalid number of modules, expected {num} modules missing #{i}"
|
||||
))
|
||||
})?)
|
||||
.map_err(|err| LoadingError::new(format!("Can't load modules mod-{i}: {err}")))?;
|
||||
|
||||
modules.push(module);
|
||||
}
|
||||
|
||||
Ok(Self { id, value: modules })
|
||||
}
|
||||
|
||||
fn detach(self) -> Self {
|
||||
Param {
|
||||
id: self.id,
|
||||
value: self.value.into_iter().map(|val| val.detach()).collect(),
|
||||
}
|
||||
}
|
||||
|
||||
fn visit<V: ModuleVisitor<Self::Backend>>(&self, visitor: &mut V) {
|
||||
for module in self.value.iter() {
|
||||
module.visit(visitor);
|
||||
}
|
||||
}
|
||||
|
||||
fn map<V: ModuleMapper<Self::Backend>>(self, mapper: &mut V) -> Self {
|
||||
Self {
|
||||
id: self.id,
|
||||
value: self.value.into_iter().map(|val| val.map(mapper)).collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<M: ADModule> ADModule for Param<Vec<M>> {
|
||||
type ADBackend = M::ADBackend;
|
||||
|
||||
type InnerModule = Param<Vec<M::InnerModule>>;
|
||||
|
||||
fn inner(self) -> Self::InnerModule {
|
||||
Param::from(
|
||||
self.value
|
||||
.into_iter()
|
||||
.map(|v| v.inner())
|
||||
.collect::<Vec<_>>(),
|
||||
)
|
||||
}
|
||||
|
||||
fn from_inner(module: Self::InnerModule) -> Self {
|
||||
Param {
|
||||
id: module.id,
|
||||
value: module.value.into_iter().map(ADModule::from_inner).collect(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<M: ADModule> ADModule for Param<M> {
|
||||
type ADBackend = M::ADBackend;
|
||||
|
||||
type InnerModule = Param<M::InnerModule>;
|
||||
|
||||
fn inner(self) -> Self::InnerModule {
|
||||
Param::from(self.value.inner())
|
||||
}
|
||||
|
||||
fn from_inner(module: Self::InnerModule) -> Self {
|
||||
Param {
|
||||
id: module.id,
|
||||
value: ADModule::from_inner(module.value),
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,248 @@
|
|||
use crate::module::{
|
||||
ADModule, LoadingError, Module, ModuleMapper, ModuleVisitor, State, StateNamed,
|
||||
};
|
||||
use alloc::format;
|
||||
use alloc::vec::Vec;
|
||||
use burn_tensor::backend::{ADBackend, Backend};
|
||||
use core::fmt::Debug;
|
||||
|
||||
impl<T, B> Module<B> for Option<T>
|
||||
where
|
||||
T: Module<B> + Debug + Send + Sync + Clone,
|
||||
B: Backend,
|
||||
{
|
||||
fn devices(&self) -> Vec<<B as burn_tensor::backend::Backend>::Device> {
|
||||
if let Some(module) = self {
|
||||
return Module::<B>::devices(module);
|
||||
}
|
||||
|
||||
Vec::new()
|
||||
}
|
||||
|
||||
fn to_device(self, device: &<B as burn_tensor::backend::Backend>::Device) -> Self {
|
||||
self.map(|module| module.to_device(device))
|
||||
}
|
||||
|
||||
fn load(self, state: &State<B::FloatElem>) -> Result<Self, LoadingError> {
|
||||
self.map(|module| module.load(state).map(|val| Some(val)))
|
||||
.unwrap_or(Ok(None))
|
||||
}
|
||||
|
||||
fn state(&self) -> State<B::FloatElem> {
|
||||
if let Some(module) = self {
|
||||
return module.state();
|
||||
}
|
||||
|
||||
State::StateNamed(StateNamed::new())
|
||||
}
|
||||
|
||||
fn detach(self) -> Self {
|
||||
self.map(|module| module.detach())
|
||||
}
|
||||
|
||||
fn num_params(&self) -> usize {
|
||||
match &self {
|
||||
Some(module) => module.num_params(),
|
||||
None => 0,
|
||||
}
|
||||
}
|
||||
|
||||
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
|
||||
if let Some(module) = self {
|
||||
module.visit(visitor)
|
||||
}
|
||||
}
|
||||
|
||||
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
|
||||
self.map(|module| module.map(mapper))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B> ADModule<B> for Option<T>
|
||||
where
|
||||
T: ADModule<B> + Debug + Send + Sync + Clone,
|
||||
B: ADBackend,
|
||||
{
|
||||
type InnerModule = Option<T::InnerModule>;
|
||||
|
||||
fn inner(self) -> Self::InnerModule {
|
||||
self.map(|module| module.inner())
|
||||
}
|
||||
|
||||
fn from_inner(module: Self::InnerModule) -> Self {
|
||||
module.map(|module| T::from_inner(module))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B> Module<B> for Vec<T>
|
||||
where
|
||||
T: Module<B> + Debug + Send + Sync + Clone,
|
||||
B: Backend,
|
||||
{
|
||||
fn devices(&self) -> Vec<<B as burn_tensor::backend::Backend>::Device> {
|
||||
let mut devices = Vec::new();
|
||||
for module in self.iter() {
|
||||
devices.append(&mut module.devices());
|
||||
}
|
||||
devices
|
||||
}
|
||||
|
||||
fn to_device(self, device: &<B as burn_tensor::backend::Backend>::Device) -> Self {
|
||||
self.into_iter().map(|val| val.to_device(device)).collect()
|
||||
}
|
||||
|
||||
fn load(self, state: &State<B::FloatElem>) -> Result<Self, LoadingError> {
|
||||
let num = self.len();
|
||||
let mut modules = Vec::with_capacity(num);
|
||||
|
||||
for (i, module) in self.into_iter().enumerate() {
|
||||
let module = module
|
||||
.load(state.get(format!("mod-{i}").as_str()).ok_or_else(|| {
|
||||
LoadingError::new(format!(
|
||||
"Invalid number of modules, expected {num} modules missing #{i}"
|
||||
))
|
||||
})?)
|
||||
.map_err(|err| LoadingError::new(format!("Can't load modules mod-{i}: {err}")))?;
|
||||
|
||||
modules.push(module);
|
||||
}
|
||||
|
||||
Ok(modules)
|
||||
}
|
||||
|
||||
fn state(&self) -> State<B::FloatElem> {
|
||||
let mut state = StateNamed::new();
|
||||
|
||||
for (i, module) in self.iter().enumerate() {
|
||||
state.register_state(format!("mod-{i}").as_str(), module.state());
|
||||
}
|
||||
|
||||
State::StateNamed(state)
|
||||
}
|
||||
|
||||
fn detach(self) -> Self {
|
||||
self.into_iter().map(|module| module.detach()).collect()
|
||||
}
|
||||
|
||||
fn num_params(&self) -> usize {
|
||||
let mut num_params = 0;
|
||||
for module in self.iter() {
|
||||
num_params += module.num_params();
|
||||
}
|
||||
|
||||
num_params
|
||||
}
|
||||
|
||||
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
|
||||
self.iter().for_each(|module| {
|
||||
module.visit(visitor);
|
||||
});
|
||||
}
|
||||
|
||||
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
|
||||
self.into_iter().map(|module| module.map(mapper)).collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, B> ADModule<B> for Vec<T>
|
||||
where
|
||||
T: ADModule<B> + Debug + Send + Sync + Clone,
|
||||
B: ADBackend,
|
||||
{
|
||||
type InnerModule = Vec<T::InnerModule>;
|
||||
|
||||
fn inner(self) -> Self::InnerModule {
|
||||
self.into_iter().map(|module| module.inner()).collect()
|
||||
}
|
||||
|
||||
fn from_inner(module: Self::InnerModule) -> Self {
|
||||
module
|
||||
.into_iter()
|
||||
.map(|module| T::from_inner(module))
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize, T, B> Module<B> for [T; N]
|
||||
where
|
||||
T: Module<B> + Debug + Send + Sync + Clone + Copy,
|
||||
B: Backend,
|
||||
{
|
||||
fn devices(&self) -> Vec<<B as burn_tensor::backend::Backend>::Device> {
|
||||
let mut devices = Vec::new();
|
||||
for module in self.iter() {
|
||||
devices.append(&mut module.devices());
|
||||
}
|
||||
devices
|
||||
}
|
||||
|
||||
fn to_device(self, device: &<B as burn_tensor::backend::Backend>::Device) -> Self {
|
||||
self.map(|val| val.to_device(device))
|
||||
}
|
||||
|
||||
fn load(mut self, state: &State<B::FloatElem>) -> Result<Self, LoadingError> {
|
||||
let num = self.len();
|
||||
|
||||
for (i, module) in self.into_iter().enumerate().take(N) {
|
||||
self[i] = module
|
||||
.load(state.get(format!("mod-{i}").as_str()).ok_or_else(|| {
|
||||
LoadingError::new(format!(
|
||||
"Invalid number of modules, expected {num} modules missing #{i}"
|
||||
))
|
||||
})?)
|
||||
.map_err(|err| LoadingError::new(format!("Can't load modules mod-{i}: {err}")))?;
|
||||
}
|
||||
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
fn state(&self) -> State<B::FloatElem> {
|
||||
let mut state = StateNamed::new();
|
||||
|
||||
for (i, module) in self.iter().enumerate() {
|
||||
state.register_state(format!("mod-{i}").as_str(), module.state());
|
||||
}
|
||||
|
||||
State::StateNamed(state)
|
||||
}
|
||||
|
||||
fn detach(self) -> Self {
|
||||
self.map(|module| module.detach())
|
||||
}
|
||||
|
||||
fn num_params(&self) -> usize {
|
||||
let mut num_params = 0;
|
||||
for module in self.iter() {
|
||||
num_params += module.num_params();
|
||||
}
|
||||
|
||||
num_params
|
||||
}
|
||||
|
||||
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
|
||||
self.iter().for_each(|module| {
|
||||
module.visit(visitor);
|
||||
});
|
||||
}
|
||||
|
||||
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
|
||||
self.map(|module| module.map(mapper))
|
||||
}
|
||||
}
|
||||
|
||||
impl<const N: usize, T, B> ADModule<B> for [T; N]
|
||||
where
|
||||
T: ADModule<B> + Debug + Send + Sync + Clone + Copy,
|
||||
T::InnerModule: Copy,
|
||||
B: ADBackend,
|
||||
{
|
||||
type InnerModule = [T::InnerModule; N];
|
||||
|
||||
fn inner(self) -> Self::InnerModule {
|
||||
self.map(|module| module.inner())
|
||||
}
|
||||
|
||||
fn from_inner(module: Self::InnerModule) -> Self {
|
||||
module.map(|module| T::from_inner(module))
|
||||
}
|
||||
}
|
|
@@ -55,9 +55,7 @@ impl<B: Backend, const D: usize> From<RunningState<Tensor<B, D>>>
    }
}

impl<const D: usize, B: Backend> Module for Param<RunningState<Tensor<B, D>>> {
    type Backend = B;

impl<const D: usize, B: Backend> Module<B> for Param<RunningState<Tensor<B, D>>> {
    fn num_params(&self) -> usize {
        let tensor = self.value.value.read().unwrap();
        tensor.shape().num_elements()
@@ -113,13 +111,13 @@ impl<const D: usize, B: Backend> Module for Param<RunningState<Tensor<B, D>>> {
        self
    }

    fn visit<V: ModuleVisitor<Self::Backend>>(&self, visitor: &mut V) {
    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
        let tensor = self.value.value.read().unwrap();

        visitor.visit(&self.id, &tensor)
    }

    fn map<M: ModuleMapper<Self::Backend>>(self, mapper: &mut M) -> Self {
    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
        let mut tensor = self.value.value.write().unwrap();
        let tensor_out = mapper.map(&self.id, tensor.clone());
@@ -208,9 +206,7 @@ impl<const D: usize, B: Backend> RunningState<Tensor<B, D>> {
    }
}

impl<const D: usize, B: ADBackend> ADModule for Param<RunningState<Tensor<B, D>>> {
    type ADBackend = B;

impl<const D: usize, B: ADBackend> ADModule<B> for Param<RunningState<Tensor<B, D>>> {
    type InnerModule = Param<RunningState<Tensor<B::InnerBackend, D>>>;

    fn inner(self) -> Self::InnerModule {
@ -1,9 +1,7 @@
|
|||
use alloc::{string::ToString, vec, vec::Vec};
|
||||
|
||||
use super::{load_with_id, state_with_id, Param, ParamId};
|
||||
use crate::module::{
|
||||
ADModule, LoadingError, Module, ModuleMapper, ModuleVisitor, State, StateNamed,
|
||||
};
|
||||
use crate::module::{ADModule, LoadingError, Module, ModuleMapper, ModuleVisitor, State};
|
||||
use crate::tensor::{
|
||||
backend::{ADBackend, Backend},
|
||||
Data, Tensor,
|
||||
|
@ -18,18 +16,7 @@ impl<B: Backend, const D: usize> From<Tensor<B, D>> for Param<Tensor<B, D>> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> From<Option<Tensor<B, D>>> for Param<Option<Tensor<B, D>>> {
|
||||
fn from(value: Option<Tensor<B, D>>) -> Self {
|
||||
Param {
|
||||
id: ParamId::new(),
|
||||
value: value.map(|tensor| tensor.require_grad()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: Backend> Module for Param<Tensor<B, D>> {
|
||||
type Backend = B;
|
||||
|
||||
impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>> {
|
||||
fn num_params(&self) -> usize {
|
||||
self.value.shape().num_elements()
|
||||
}
|
||||
|
@ -72,99 +59,17 @@ impl<const D: usize, B: Backend> Module for Param<Tensor<B, D>> {
|
|||
}
|
||||
}
|
||||
|
||||
fn visit<V: ModuleVisitor<Self::Backend>>(&self, visitor: &mut V) {
|
||||
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V) {
|
||||
visitor.visit(&self.id, &self.value)
|
||||
}
|
||||
|
||||
fn map<M: ModuleMapper<Self::Backend>>(self, mapper: &mut M) -> Self {
|
||||
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self {
|
||||
let value = mapper.map(&self.id, self.value).require_grad();
|
||||
Self { id: self.id, value }
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: Backend> Module for Param<Option<Tensor<B, D>>> {
|
||||
type Backend = B;
|
||||
|
||||
fn num_params(&self) -> usize {
|
||||
if let Some(value) = &self.value {
|
||||
return value.shape().num_elements();
|
||||
}
|
||||
|
||||
0
|
||||
}
|
||||
|
||||
fn devices(&self) -> Vec<B::Device> {
|
||||
if let Some(value) = &self.value {
|
||||
return vec![value.device()];
|
||||
}
|
||||
|
||||
vec![]
|
||||
}
|
||||
|
||||
fn to_device(self, device: &B::Device) -> Self {
|
||||
Self {
|
||||
id: self.id,
|
||||
value: self
|
||||
.value
|
||||
.map(|value| value.to_device(device).require_grad()),
|
||||
}
|
||||
}
|
||||
|
||||
fn state(&self) -> State<B::FloatElem> {
|
||||
let state = match &self.value {
|
||||
Some(value) => State::Data(value.to_data().serialize()),
|
||||
None => State::StateNamed(StateNamed::new()),
|
||||
};
|
||||
|
||||
state_with_id(self.id.clone(), state)
|
||||
}
|
||||
|
||||
fn load(self, state: &State<B::FloatElem>) -> Result<Self, LoadingError> {
|
||||
let (id, state) = load_with_id(state)?;
|
||||
let id = id.clone();
|
||||
|
||||
let tensor = if let Some(tensor) = self.value {
|
||||
let data = match state {
|
||||
State::Data(data) => data,
|
||||
_ => {
|
||||
return Err(LoadingError::new(
|
||||
"Can't load Option<Tensor> from NamedState".to_string(),
|
||||
))
|
||||
}
|
||||
};
|
||||
|
||||
Some(Tensor::from_data_device(Data::from(data), &tensor.device()).require_grad())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Self { id, value: tensor })
|
||||
}
|
||||
|
||||
fn detach(self) -> Self {
|
||||
Self {
|
||||
id: self.id,
|
||||
value: self.value.map(|value| value.detach().require_grad()),
|
||||
}
|
||||
}
|
||||
|
||||
fn visit<V: ModuleVisitor<Self::Backend>>(&self, visitor: &mut V) {
|
||||
if let Some(value) = &self.value {
|
||||
visitor.visit(&self.id, value)
|
||||
}
|
||||
}
|
||||
|
||||
fn map<M: ModuleMapper<Self::Backend>>(self, mapper: &mut M) -> Self {
|
||||
let value = self
|
||||
.value
|
||||
.map(|value| mapper.map(&self.id, value).require_grad());
|
||||
Self { id: self.id, value }
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: ADBackend> ADModule for Param<Tensor<B, D>> {
|
||||
type ADBackend = B;
|
||||
|
||||
impl<const D: usize, B: ADBackend> ADModule<B> for Param<Tensor<B, D>> {
|
||||
type InnerModule = Param<Tensor<B::InnerBackend, D>>;
|
||||
|
||||
fn inner(self) -> Self::InnerModule {
|
||||
|
@ -181,25 +86,3 @@ impl<const D: usize, B: ADBackend> ADModule for Param<Tensor<B, D>> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B: ADBackend> ADModule for Param<Option<Tensor<B, D>>> {
|
||||
type ADBackend = B;
|
||||
|
||||
type InnerModule = Param<Option<Tensor<B::InnerBackend, D>>>;
|
||||
|
||||
fn inner(self) -> Self::InnerModule {
|
||||
Param {
|
||||
id: self.id,
|
||||
value: self.value.map(|val| val.inner()),
|
||||
}
|
||||
}
|
||||
|
||||
fn from_inner(module: Self::InnerModule) -> Self {
|
||||
Param {
|
||||
id: module.id,
|
||||
value: module
|
||||
.value
|
||||
.map(|val| Tensor::from_inner(val).require_grad()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@@ -1,24 +1,31 @@
use alloc::vec::Vec;

use super::ParamId;
use crate::module::{Module, ModuleVisitor};
use alloc::vec::Vec;
use burn_tensor::{backend::Backend, Tensor};
use core::marker::PhantomData;

#[derive(new)]
struct ParamIdCollector<'a> {
struct ParamIdCollector<'a, M> {
    param_ids: &'a mut Vec<ParamId>,
    phantom: PhantomData<M>,
}

impl<'a, B: Backend> ModuleVisitor<B> for ParamIdCollector<'a> {
impl<'a, B, M> ModuleVisitor<B> for ParamIdCollector<'a, M>
where
    B: Backend,
    M: Module<B>,
{
    fn visit<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
        self.param_ids.push(id.clone());
    }
}

/// List all the parameter ids in a module.
pub fn list_param_ids<M: Module>(module: &M) -> Vec<ParamId> {
pub fn list_param_ids<M: Module<B>, B: Backend>(module: &M) -> Vec<ParamId> {
    let mut params_ids = Vec::new();
    let mut visitor = ParamIdCollector::new(&mut params_ids);
    let mut visitor = ParamIdCollector {
        param_ids: &mut params_ids,
        phantom: PhantomData::<M>::default(),
    };
    module.visit(&mut visitor);

    params_ids
@@ -5,7 +5,7 @@ use crate as burn;
use crate::nn::cache::TensorCache;
use crate::{
    config::Config,
    module::{Module, Param},
    module::Module,
    nn,
    tensor::{activation, backend::Backend, Bool, Tensor},
};
@@ -39,10 +39,10 @@ pub struct MultiHeadAttentionConfig {
/// - output: [Linear](nn::Linear) layer with `d_model` input and output features.
#[derive(Module, Debug)]
pub struct MultiHeadAttention<B: Backend> {
    query: Param<nn::Linear<B>>,
    key: Param<nn::Linear<B>>,
    value: Param<nn::Linear<B>>,
    output: Param<nn::Linear<B>>,
    query: nn::Linear<B>,
    key: nn::Linear<B>,
    value: nn::Linear<B>,
    output: nn::Linear<B>,
    dropout: nn::Dropout,
    activation: nn::GELU,
    n_heads: usize,
@@ -63,9 +63,7 @@ pub struct MhaInput<B: Backend> {
impl MultiHeadAttentionConfig {
    /// Initialize a new [multihead attention](MultiHeadAttention) module.
    pub fn init<B: Backend>(&self) -> MultiHeadAttention<B> {
        let linear = |config: &Self| {
            Param::from(nn::LinearConfig::new(config.d_model, config.d_model).init())
        };
        let linear = |config: &Self| nn::LinearConfig::new(config.d_model, config.d_model).init();

        MultiHeadAttention {
            query: linear(self),
@@ -172,7 +170,7 @@ impl<B: Backend> MultiHeadAttention<B> {

        let attention_linear = |cache: &mut TensorCache<B, 4>,
                                tensor: Tensor<B, 3>,
                                param: &Param<nn::Linear<B>>| {
                                param: &nn::Linear<B>| {
            cache.forward_autoregressive(tensor, 2, |tensor| self.attention_linear(tensor, param))
        };

@@ -235,7 +233,7 @@ impl<B: Backend> MultiHeadAttention<B> {
        activation::softmax(attn_scores, 3)
    }

    fn attention_linear(&self, x: Tensor<B, 3>, linear: &Param<nn::Linear<B>>) -> Tensor<B, 4> {
    fn attention_linear(&self, x: Tensor<B, 3>, linear: &nn::Linear<B>) -> Tensor<B, 4> {
        let [batch_size, seq_length, _d_model] = x.dims();
        linear
            .forward(x)
@@ -4,6 +4,7 @@ use alloc::vec::Vec;
use crate as burn;

use crate::config::Config;
use crate::constant;
use crate::module::Module;
use crate::module::Param;
use crate::nn::Initializer;
@@ -43,6 +44,8 @@ pub enum Conv1dPaddingConfig {
    Explicit(usize),
}

constant!(Conv1dPaddingConfig);

/// Applies a 1D convolution over input tensors.
///
/// # Params
@@ -55,7 +58,7 @@ pub enum Conv1dPaddingConfig {
#[derive(Module, Debug)]
pub struct Conv1d<B: Backend> {
    weight: Param<Tensor<B, 3>>,
    bias: Param<Option<Tensor<B, 1>>>,
    bias: Option<Param<Tensor<B, 1>>>,
    stride: usize,
    kernel_size: usize,
    padding: Option<Conv1dPaddingConfig>,
@@ -76,14 +79,14 @@ impl Conv1dConfig {
        let weight = initializer.init([self.channels_out, self.channels_in, self.kernel_size]);

        let bias = if self.bias {
            Some(initializer.init([self.channels_out]))
            Some(Param::from(initializer.init([self.channels_out])))
        } else {
            None
        };

        Conv1d {
            weight: Param::from(weight),
            bias: Param::from(bias),
            bias,
            stride: 1, // TODO: Add the stride to the config when properly supported.
            kernel_size: self.kernel_size,
            padding: self.padding.clone(),
@@ -115,7 +118,7 @@ impl<B: Backend> Conv1d<B> {
        conv1d(
            input,
            self.weight.val(),
            self.bias.val(),
            self.bias.as_ref().map(|bias| bias.val()),
            self.stride,
            padding,
        )
@ -3,6 +3,7 @@ use alloc::{format, vec::Vec};
|
|||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::constant;
|
||||
use crate::module::Module;
|
||||
use crate::module::Param;
|
||||
use crate::nn::Initializer;
|
||||
|
@ -43,6 +44,8 @@ pub enum Conv2dPaddingConfig {
|
|||
Explicit(usize, usize),
|
||||
}
|
||||
|
||||
constant!(Conv2dPaddingConfig);
|
||||
|
||||
/// Applies a 2D convolution over input tensors.
|
||||
///
|
||||
/// # Params
|
||||
|
@ -55,7 +58,7 @@ pub enum Conv2dPaddingConfig {
|
|||
#[derive(Module, Debug)]
|
||||
pub struct Conv2d<B: Backend> {
|
||||
weight: Param<Tensor<B, 4>>,
|
||||
bias: Param<Option<Tensor<B, 1>>>,
|
||||
bias: Option<Param<Tensor<B, 1>>>,
|
||||
stride: [usize; 2],
|
||||
kernel_size: [usize; 2],
|
||||
padding: Conv2dPaddingConfig,
|
||||
|
@ -88,7 +91,7 @@ impl Conv2dConfig {
|
|||
|
||||
Conv2d {
|
||||
weight: Param::from(weight),
|
||||
bias: Param::from(bias),
|
||||
bias: bias.map(Param::from),
|
||||
stride: [1, 1], // TODO: Add the stride to the config when properly supported.
|
||||
kernel_size: self.kernel_size,
|
||||
padding: self.padding.clone(),
|
||||
|
@ -111,7 +114,7 @@ impl<B: Backend> Conv2d<B> {
|
|||
conv2d(
|
||||
input,
|
||||
self.weight.val(),
|
||||
self.bias.val(),
|
||||
self.bias.as_ref().map(|bias| bias.val()),
|
||||
self.stride,
|
||||
padding,
|
||||
)
|
||||
|
|
|
@@ -1,5 +1,7 @@
use crate as burn;

use crate::config::Config;
use crate::constant;
use crate::tensor::backend::Backend;
use crate::tensor::{Distribution, Tensor};

@@ -21,6 +23,8 @@ pub struct Dropout {
    prob: f64,
}

constant!(Dropout);

impl DropoutConfig {
    /// Initialize a new [dropout](Dropout) module.
    pub fn init(&self) -> Dropout {
@@ -1,3 +1,6 @@
use crate as burn;

use crate::constant;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;

@@ -5,6 +8,8 @@ use crate::tensor::Tensor;
#[derive(Clone, Debug, Default)]
pub struct GELU {}

constant!(GELU);

impl GELU {
    /// Create the module.
    pub fn new() -> Self {
@@ -40,7 +40,7 @@ pub struct LinearConfig {
#[derive(Module, Debug)]
pub struct Linear<B: Backend> {
    weight: Param<Tensor<B, 2>>,
    bias: Param<Option<Tensor<B, 1>>>,
    bias: Option<Param<Tensor<B, 1>>>,
}

impl LinearConfig {
@@ -64,7 +64,7 @@ impl LinearConfig {

        Linear {
            weight: Param::from(weight),
            bias: Param::from(bias),
            bias: bias.map(Param::from),
        }
    }
}
@@ -79,8 +79,8 @@ impl<B: Backend> Linear<B> {
    pub fn forward<const D: usize>(&self, input: Tensor<B, D>) -> Tensor<B, D> {
        let output = input.matmul(self.weight.val().unsqueeze());

        match self.bias.val() {
            Some(bias) => output + bias.unsqueeze(),
        match &self.bias {
            Some(bias) => output + bias.val().unsqueeze(),
            None => output,
        }
    }
@@ -1,4 +1,4 @@
use crate as burn;
use crate::{self as burn, constant};

use crate::config::Config;
use crate::nn::conv::Conv2dPaddingConfig;
@@ -32,6 +32,8 @@ pub struct MaxPool2d {
    padding: MaxPool2dPaddingConfig,
}

constant!(MaxPool2d);

impl MaxPool2dConfig {
    /// Initialize a new [max pool 2d](MaxPool2d) module.
    pub fn init(&self) -> MaxPool2d {
@@ -1,3 +1,6 @@
use crate as burn;

use crate::constant;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;

@@ -7,6 +10,8 @@ use crate::tensor::Tensor;
#[derive(Clone, Debug, Default)]
pub struct ReLU {}

constant!(ReLU);

impl ReLU {
    /// Create the module.
    pub fn new() -> Self {
@ -9,7 +9,7 @@ use crate::{
|
|||
use super::{PositionWiseFeedForward, PositionWiseFeedForwardConfig};
|
||||
use crate::{
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
module::Module,
|
||||
nn::{
|
||||
attention::{MhaInput, MultiHeadAttention, MultiHeadAttentionConfig},
|
||||
Dropout, DropoutConfig, LayerNorm, LayerNormConfig,
|
||||
|
@ -43,7 +43,7 @@ pub struct TransformerEncoderConfig {
|
|||
/// - layers: transformer encoder layers with `d_model` input and output features.
|
||||
#[derive(Module, Debug)]
|
||||
pub struct TransformerEncoder<B: Backend> {
|
||||
layers: Param<Vec<TransformerEncoderLayer<B>>>,
|
||||
layers: Vec<TransformerEncoderLayer<B>>,
|
||||
}
|
||||
|
||||
/// [Transformer Encoder](TransformerEncoder) forward pass input argument.
|
||||
|
@ -83,9 +83,7 @@ impl TransformerEncoderConfig {
|
|||
.map(|_| TransformerEncoderLayer::new(self))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
TransformerEncoder {
|
||||
layers: Param::from(layers),
|
||||
}
|
||||
TransformerEncoder { layers }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -141,10 +139,10 @@ impl<B: Backend> TransformerEncoder<B> {
|
|||
|
||||
#[derive(Module, Debug)]
|
||||
struct TransformerEncoderLayer<B: Backend> {
|
||||
mha: Param<MultiHeadAttention<B>>,
|
||||
pwff: Param<PositionWiseFeedForward<B>>,
|
||||
norm_1: Param<LayerNorm<B>>,
|
||||
norm_2: Param<LayerNorm<B>>,
|
||||
mha: MultiHeadAttention<B>,
|
||||
pwff: PositionWiseFeedForward<B>,
|
||||
norm_1: LayerNorm<B>,
|
||||
norm_2: LayerNorm<B>,
|
||||
dropout: Dropout,
|
||||
norm_first: bool,
|
||||
}
|
||||
|
@ -162,10 +160,10 @@ impl<B: Backend> TransformerEncoderLayer<B> {
|
|||
.init();
|
||||
|
||||
Self {
|
||||
mha: Param::from(mha),
|
||||
norm_1: Param::from(norm_1),
|
||||
norm_2: Param::from(norm_2),
|
||||
pwff: Param::from(pwff),
|
||||
mha,
|
||||
norm_1,
|
||||
norm_2,
|
||||
pwff,
|
||||
dropout,
|
||||
norm_first: config.norm_first,
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ use crate as burn;
|
|||
|
||||
use crate::{
|
||||
config::Config,
|
||||
module::{Module, Param},
|
||||
module::Module,
|
||||
nn::{Dropout, DropoutConfig, Linear, LinearConfig, GELU},
|
||||
tensor::{backend::Backend, Tensor},
|
||||
};
|
||||
|
@ -29,8 +29,8 @@ pub struct PositionWiseFeedForwardConfig {
|
|||
/// - linear outer: Linear layer with `d_ff` input features and `d_model` output features.
|
||||
#[derive(Module, Debug)]
|
||||
pub struct PositionWiseFeedForward<B: Backend> {
|
||||
linear_inner: Param<Linear<B>>,
|
||||
linear_outer: Param<Linear<B>>,
|
||||
linear_inner: Linear<B>,
|
||||
linear_outer: Linear<B>,
|
||||
dropout: Dropout,
|
||||
gelu: GELU,
|
||||
}
|
||||
|
@ -39,8 +39,8 @@ impl PositionWiseFeedForwardConfig {
|
|||
/// Initialize a new [position-wise feed-forward](PositionWiseFeedForward) module.
|
||||
pub fn init<B: Backend>(&self) -> PositionWiseFeedForward<B> {
|
||||
PositionWiseFeedForward {
|
||||
linear_inner: Param::from(LinearConfig::new(self.d_model, self.d_ff).init()),
|
||||
linear_outer: Param::from(LinearConfig::new(self.d_ff, self.d_model).init()),
|
||||
linear_inner: LinearConfig::new(self.d_model, self.d_ff).init(),
|
||||
linear_outer: LinearConfig::new(self.d_ff, self.d_model).init(),
|
||||
dropout: DropoutConfig::new(self.dropout).init(),
|
||||
gelu: GELU::new(),
|
||||
}
|
||||
|
|
|
@@ -1,4 +1,4 @@
use crate as burn;
use crate::{self as burn, module::ADModule};

use super::{
    decay::{WeightDecay, WeightDecayConfig},
@@ -54,9 +54,7 @@ impl<B: ADBackend> Adam<B> {
    }
}

impl<B: ADBackend> Optimizer for Adam<B> {
    type Backend = B;

impl<M: ADModule<B>, B: ADBackend> Optimizer<M, B> for Adam<B> {
    fn update_tensor<const D: usize>(
        &mut self,
        id: &ParamId,
@ -2,25 +2,26 @@ use super::mapper::ModuleTensorUpdater;
|
|||
use super::visitor::{GradientsLoader, GradientsRegister};
|
||||
use super::GradientsParams;
|
||||
|
||||
use crate::module::{ADModule, LoadingError, Module, ParamId, State, StateNamed};
|
||||
use crate::tensor::backend::{ADBackend, Backend};
|
||||
use crate::module::{ADModule, LoadingError, ParamId, State, StateNamed};
|
||||
use crate::tensor::backend::ADBackend;
|
||||
use crate::tensor::{Data, Tensor};
|
||||
|
||||
pub trait Optimizer: Send + Sync {
|
||||
type Backend: ADBackend;
|
||||
|
||||
pub trait Optimizer<M, B>: Send + Sync
|
||||
where
|
||||
M: ADModule<B>,
|
||||
B: ADBackend,
|
||||
{
|
||||
/// Update the tensor parameter using the given the gradients.
|
||||
fn update_tensor<const D: usize>(
|
||||
&mut self,
|
||||
id: &ParamId,
|
||||
tensor: Tensor<Self::Backend, D>,
|
||||
grad: Tensor<<Self::Backend as ADBackend>::InnerBackend, D>,
|
||||
) -> Tensor<Self::Backend, D>;
|
||||
tensor: Tensor<B, D>,
|
||||
grad: Tensor<B::InnerBackend, D>,
|
||||
) -> Tensor<B, D>;
|
||||
|
||||
/// Update the parameters of the given module using the given the gradients.
|
||||
fn update_module<M>(&mut self, module: M, grads: GradientsParams) -> M
|
||||
fn update_module(&mut self, module: M, grads: GradientsParams) -> M
|
||||
where
|
||||
M: ADModule<ADBackend = Self::Backend>,
|
||||
Self: Sized,
|
||||
{
|
||||
let mut mapper = ModuleTensorUpdater::new(self, grads);
|
||||
|
@ -35,7 +36,7 @@ pub trait Optimizer: Send + Sync {
|
|||
fn register_param_state<const D: usize>(
|
||||
&self,
|
||||
_id: &ParamId,
|
||||
_state: &mut StateNamed<<Self::Backend as Backend>::FloatElem>,
|
||||
_state: &mut StateNamed<B::FloatElem>,
|
||||
) {
|
||||
// By default there is no state to register
|
||||
}
|
||||
|
@ -48,17 +49,14 @@ pub trait Optimizer: Send + Sync {
|
|||
fn load_param_state<const D: usize>(
|
||||
&mut self,
|
||||
_id: &ParamId,
|
||||
_state: &StateNamed<<Self::Backend as Backend>::FloatElem>,
|
||||
_device: &<Self::Backend as Backend>::Device,
|
||||
_state: &StateNamed<B::FloatElem>,
|
||||
_device: &B::Device,
|
||||
) {
|
||||
// By default there is no state to load
|
||||
}
|
||||
|
||||
/// Get the optimizer state for a given module.
|
||||
fn state<M: Module<Backend = Self::Backend>>(
|
||||
&self,
|
||||
module: &M,
|
||||
) -> State<<Self::Backend as Backend>::FloatElem>
|
||||
fn state(&self, module: &M) -> State<B::FloatElem>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
|
@ -70,11 +68,7 @@ pub trait Optimizer: Send + Sync {
|
|||
}
|
||||
|
||||
/// Load the optimizer state for a given module.
|
||||
fn load<M: Module<Backend = Self::Backend>>(
|
||||
&mut self,
|
||||
module: &M,
|
||||
state: &State<<Self::Backend as Backend>::FloatElem>,
|
||||
) -> Result<(), LoadingError>
|
||||
fn load(&mut self, module: &M, state: &State<B::FloatElem>) -> Result<(), LoadingError>
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
|
|
|
@ -1,36 +1,40 @@
|
|||
use crate::module::{Module, ModuleVisitor, ParamId};
|
||||
use core::marker::PhantomData;
|
||||
|
||||
use crate::module::{ADModule, ModuleVisitor, ParamId};
|
||||
|
||||
use burn_tensor::{backend::ADBackend, Tensor};
|
||||
|
||||
use super::GradientsParams;
|
||||
|
||||
/// Accumulate gradients into a single [Gradients](ADBackend::Gradients) object.
|
||||
pub struct GradientsAccumulator {
|
||||
pub struct GradientsAccumulator<M> {
|
||||
grads: GradientsParams,
|
||||
phantom: PhantomData<M>,
|
||||
}
|
||||
|
||||
impl Default for GradientsAccumulator {
|
||||
impl<M> Default for GradientsAccumulator<M> {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl GradientsAccumulator {
|
||||
impl<M> GradientsAccumulator<M> {
|
||||
/// Create a new gradients accumulator.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
grads: GradientsParams::new(),
|
||||
phantom: PhantomData::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl GradientsAccumulator {
|
||||
impl<M> GradientsAccumulator<M> {
|
||||
/// Accumulate the given gradients for each parameter in the given module.
|
||||
pub fn accumulate<B: ADBackend, M>(&mut self, module: &M, grads: GradientsParams)
|
||||
pub fn accumulate<B: ADBackend>(&mut self, module: &M, grads: GradientsParams)
|
||||
where
|
||||
M: Module<Backend = B>,
|
||||
M: ADModule<B>,
|
||||
{
|
||||
let mut visitor = ModuleGradsAccumulator::new(&mut self.grads, grads);
|
||||
let mut visitor = ModuleGradsAccumulator::<M>::new(&mut self.grads, grads);
|
||||
module.visit(&mut visitor);
|
||||
}
|
||||
|
||||
|
@ -44,12 +48,13 @@ impl GradientsAccumulator {
|
|||
}
|
||||
|
||||
#[derive(new)]
|
||||
struct ModuleGradsAccumulator<'a> {
|
||||
struct ModuleGradsAccumulator<'a, M> {
|
||||
grads: &'a mut GradientsParams,
|
||||
grads_new: GradientsParams,
|
||||
phantom: PhantomData<M>,
|
||||
}
|
||||
|
||||
impl<'a, B: ADBackend> ModuleVisitor<B> for ModuleGradsAccumulator<'a> {
|
||||
impl<'a, B: ADBackend, M: ADModule<B>> ModuleVisitor<B> for ModuleGradsAccumulator<'a, M> {
|
||||
fn visit<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
|
||||
let grad_updated = match self.grads_new.remove::<B::InnerBackend, D>(id) {
|
||||
Some(new) => match self.grads.remove::<B::InnerBackend, D>(id) {
|
||||
|
|
|
@ -63,23 +63,20 @@ impl GradientsParams {
|
|||
}
|
||||
|
||||
/// Change the device of each tensor gradients registered for the given [module](ADModule).
|
||||
pub fn to_device<M: ADModule>(
|
||||
pub fn to_device<B: ADBackend, M: ADModule<B>>(
|
||||
mut self,
|
||||
device: &<M::Backend as Backend>::Device,
|
||||
device: &B::Device,
|
||||
module: &M,
|
||||
) -> Self {
|
||||
let mut visitor = GradientsParamsChangeDevice::new(device, &mut self);
|
||||
let mut visitor = GradientsParamsChangeDevice::<M, B>::new(device, &mut self);
|
||||
module.visit(&mut visitor);
|
||||
self
|
||||
}
|
||||
|
||||
/// Extract each tensor gradients for the given [module](ADModule).
|
||||
pub fn from_grads<M: ADModule>(
|
||||
grads: <M::ADBackend as ADBackend>::Gradients,
|
||||
module: &M,
|
||||
) -> Self {
|
||||
pub fn from_grads<B: ADBackend, M: ADModule<B>>(grads: B::Gradients, module: &M) -> Self {
|
||||
let mut grads_params = GradientsParams::new();
|
||||
let mut visitor = GradientsParamsConverter::new(grads, &mut grads_params);
|
||||
let mut visitor = GradientsParamsConverter::<M, B>::new(grads, &mut grads_params);
|
||||
|
||||
module.visit(&mut visitor);
|
||||
grads_params
|
||||
|
|
|
@ -1,16 +1,24 @@
|
|||
use core::marker::PhantomData;
|
||||
|
||||
use burn_tensor::{backend::ADBackend, Tensor};
|
||||
|
||||
use crate::module::{ModuleMapper, ParamId};
|
||||
use crate::module::{ADModule, ModuleMapper, ParamId};
|
||||
|
||||
use super::{GradientsParams, Optimizer};
|
||||
|
||||
#[derive(new)]
|
||||
pub struct ModuleTensorUpdater<'a, O> {
|
||||
pub struct ModuleTensorUpdater<'a, M, O> {
|
||||
optimizer: &'a mut O,
|
||||
grads: GradientsParams,
|
||||
phatom: PhantomData<M>,
|
||||
}
|
||||
|
||||
impl<'a, B: ADBackend, O: Optimizer<Backend = B>> ModuleMapper<B> for ModuleTensorUpdater<'a, O> {
|
||||
impl<'a, M, B, O> ModuleMapper<B> for ModuleTensorUpdater<'a, M, O>
|
||||
where
|
||||
M: ADModule<B>,
|
||||
B: ADBackend,
|
||||
O: Optimizer<M, B>,
|
||||
{
|
||||
fn map<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
|
||||
if let Some(grad) = self.grads.remove::<B::InnerBackend, D>(id) {
|
||||
self.optimizer.update_tensor(id, tensor, grad)
|
||||
|
|
|
@@ -3,7 +3,7 @@ use crate as burn;
use super::decay::{WeightDecay, WeightDecayConfig};
use super::momentum::{Momentum, MomentumConfig};
use crate::config::Config;
use crate::module::{ParamId, StateNamed};
use crate::module::{ADModule, ParamId, StateNamed};
use crate::optim::Optimizer;
use crate::tensor::backend::ADBackend;
use crate::tensor::{ElementConversion, Tensor};
@@ -45,9 +45,7 @@ impl<B: ADBackend> Sgd<B> {
    }
}

impl<B: ADBackend> Optimizer for Sgd<B> {
    type Backend = B;

impl<M: ADModule<B>, B: ADBackend> Optimizer<M, B> for Sgd<B> {
    fn update_tensor<const D: usize>(
        &mut self,
        id: &ParamId,
@ -1,45 +1,65 @@
|
|||
use core::marker::PhantomData;
|
||||
|
||||
use super::{GradientsParams, Optimizer};
|
||||
use crate::module::{ModuleVisitor, ParamId, StateNamed};
|
||||
use crate::module::{ADModule, ModuleVisitor, ParamId, StateNamed};
|
||||
use burn_tensor::{backend::ADBackend, Tensor};
|
||||
|
||||
#[derive(new)]
|
||||
pub struct GradientsRegister<'a, B: ADBackend, O> {
|
||||
pub struct GradientsRegister<'a, M: ADModule<B>, B: ADBackend, O> {
|
||||
optimizer: &'a O,
|
||||
state: &'a mut StateNamed<B::FloatElem>,
|
||||
phatom: PhantomData<M>,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub struct GradientsLoader<'a, B: ADBackend, O> {
|
||||
pub struct GradientsLoader<'a, M: ADModule<B>, B: ADBackend, O> {
|
||||
optimizer: &'a mut O,
|
||||
state: &'a StateNamed<B::FloatElem>,
|
||||
phatom: PhantomData<M>,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub struct GradientsParamsConverter<'a, B: ADBackend> {
|
||||
pub struct GradientsParamsConverter<'a, M: ADModule<B>, B: ADBackend> {
|
||||
grads: B::Gradients,
|
||||
grads_params: &'a mut GradientsParams,
|
||||
phatom: PhantomData<M>,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub struct GradientsParamsChangeDevice<'a, B: ADBackend> {
|
||||
pub struct GradientsParamsChangeDevice<'a, M: ADModule<B>, B: ADBackend> {
|
||||
device: &'a B::Device,
|
||||
grads: &'a mut GradientsParams,
|
||||
phatom: PhantomData<M>,
|
||||
}
|
||||
|
||||
impl<'a, B: ADBackend, O: Optimizer<Backend = B>> ModuleVisitor<B> for GradientsRegister<'a, B, O> {
|
||||
impl<'a, B, M, O> ModuleVisitor<B> for GradientsRegister<'a, M, B, O>
|
||||
where
|
||||
B: ADBackend,
|
||||
M: ADModule<B>,
|
||||
O: Optimizer<M, B>,
|
||||
{
|
||||
fn visit<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
|
||||
self.optimizer.register_param_state::<D>(id, self.state)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: ADBackend, O: Optimizer<Backend = B>> ModuleVisitor<B> for GradientsLoader<'a, B, O> {
|
||||
impl<'a, B, M, O> ModuleVisitor<B> for GradientsLoader<'a, M, B, O>
|
||||
where
|
||||
B: ADBackend,
|
||||
M: ADModule<B>,
|
||||
O: Optimizer<M, B>,
|
||||
{
|
||||
fn visit<const D: usize>(&mut self, id: &ParamId, tensor: &Tensor<B, D>) {
|
||||
self.optimizer
|
||||
.load_param_state::<D>(id, self.state, &tensor.device())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, B: ADBackend> ModuleVisitor<B> for GradientsParamsConverter<'a, B> {
|
||||
impl<'a, B, M> ModuleVisitor<B> for GradientsParamsConverter<'a, M, B>
|
||||
where
|
||||
B: ADBackend,
|
||||
M: ADModule<B>,
|
||||
{
|
||||
fn visit<const D: usize>(&mut self, id: &ParamId, tensor: &Tensor<B, D>) {
|
||||
if let Some(grad) = tensor.grad_remove(&mut self.grads) {
|
||||
self.grads_params
|
||||
|
@ -48,7 +68,11 @@ impl<'a, B: ADBackend> ModuleVisitor<B> for GradientsParamsConverter<'a, B> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<'a, B: ADBackend> ModuleVisitor<B> for GradientsParamsChangeDevice<'a, B> {
|
||||
impl<'a, B, M> ModuleVisitor<B> for GradientsParamsChangeDevice<'a, M, B>
|
||||
where
|
||||
B: ADBackend,
|
||||
M: ADModule<B>,
|
||||
{
|
||||
fn visit<const D: usize>(&mut self, id: &ParamId, _tensor: &Tensor<B, D>) {
|
||||
if let Some(grad) = self.grads.remove::<B::InnerBackend, D>(id) {
|
||||
self.grads
|
||||
|
|
|
@ -28,7 +28,7 @@ where
|
|||
B: Backend,
|
||||
{
|
||||
weight: Param<Tensor<B, 2>>,
|
||||
basic: Param<ModuleBasic<B>>,
|
||||
basic: ModuleBasic<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleComposed<B> {
|
||||
|
@ -36,7 +36,7 @@ impl<B: Backend> ModuleComposed<B> {
|
|||
let weight = Tensor::random(Shape::new([20, 20]), Distribution::Standard);
|
||||
Self {
|
||||
weight: Param::from(weight),
|
||||
basic: Param::from(ModuleBasic::new()),
|
||||
basic: ModuleBasic::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use super::param::Param;
|
||||
use super::fn_generator::FnGenerator;
|
||||
use crate::module::display;
|
||||
use proc_macro::TokenStream;
|
||||
use quote::quote;
|
||||
|
@ -9,24 +9,22 @@ pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
|
|||
|
||||
let display_fn = display::display_fn(name);
|
||||
|
||||
let param = Param::from_ast(ast);
|
||||
let num_params_fn = param.gen_num_params_fn();
|
||||
let visit = param.gen_visit_fn();
|
||||
let map_mut = param.gen_map_fn();
|
||||
let devices_fn = param.gen_devices_fn();
|
||||
let to_device_fn = param.gen_to_device_fn();
|
||||
let state_fn = param.gen_state_fn();
|
||||
let load_fn = param.gen_load_fn();
|
||||
let inner_fn = param.gen_inner_fn();
|
||||
let from_inner_fn = param.gen_from_inner_fn();
|
||||
let detach_fn = param.gen_detach_fn();
|
||||
let clone_fn = param.gen_clone_fn();
|
||||
let generator = FnGenerator::from_ast(ast);
|
||||
let num_params_fn = generator.gen_num_params_fn();
|
||||
let visit = generator.gen_visit_fn();
|
||||
let map_mut = generator.gen_map_fn();
|
||||
let devices_fn = generator.gen_devices_fn();
|
||||
let to_device_fn = generator.gen_to_device_fn();
|
||||
let state_fn = generator.gen_state_fn();
|
||||
let load_fn = generator.gen_load_fn();
|
||||
let inner_fn = generator.gen_inner_fn();
|
||||
let from_inner_fn = generator.gen_from_inner_fn();
|
||||
let detach_fn = generator.gen_detach_fn();
|
||||
let clone_fn = generator.gen_clone_fn();
|
||||
let generics_names_except_backend = generics_names_except_backend(&ast.generics);
|
||||
|
||||
let gen = quote! {
|
||||
impl #generics burn::module::Module for #name #generics_ty #generics_where {
|
||||
type Backend=B;
|
||||
|
||||
impl #generics burn::module::Module<B> for #name #generics_ty #generics_where {
|
||||
#devices_fn
|
||||
#to_device_fn
|
||||
|
||||
|
@ -40,8 +38,7 @@ pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
|
|||
#map_mut
|
||||
}
|
||||
|
||||
impl #generics burn::module::ADModule for #name #generics_ty where B: burn::tensor::backend::ADBackend, {
|
||||
type ADBackend=B;
|
||||
impl #generics burn::module::ADModule<B> for #name #generics_ty where B: burn::tensor::backend::ADBackend, {
|
||||
type InnerModule=#name<B::InnerBackend, #generics_names_except_backend>;
|
||||
|
||||
#inner_fn
|
||||
|
|
|
@ -0,0 +1,246 @@
|
|||
use crate::shared::field::{parse_fields, FieldTypeAnalyzer};
|
||||
use proc_macro2::{Ident, TokenStream};
|
||||
use quote::quote;
|
||||
|
||||
pub struct FnGenerator {
|
||||
fields: Vec<FieldTypeAnalyzer>,
|
||||
}
|
||||
|
||||
impl FnGenerator {
|
||||
pub fn from_ast(ast: &syn::DeriveInput) -> Self {
|
||||
Self {
|
||||
fields: parse_fields(ast)
|
||||
.into_iter()
|
||||
.map(FieldTypeAnalyzer::new)
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn gen_num_params_fn(&self) -> TokenStream {
|
||||
let body = self.gen_fields_fn(|name| {
|
||||
quote! {
|
||||
num_params += burn::module::Module::<B>::num_params(&self.#name);
|
||||
}
|
||||
});
|
||||
|
||||
quote! {
|
||||
fn num_params(&self) -> usize {
|
||||
let mut num_params = 0;
|
||||
#body
|
||||
num_params
|
||||
}
|
||||
}
|
||||
}
|
||||
|
    pub fn gen_visit_fn(&self) -> TokenStream {
        let body = self.gen_fields_fn(|name| {
            quote! {
                burn::module::Module::visit(&self.#name, visitor);
            }
        });

        quote! {
            fn visit<V: burn::module::ModuleVisitor<B>>(&self, visitor: &mut V) {
                #body
            }
        }
    }

    pub fn gen_map_fn(&self) -> TokenStream {
        let (names, body) = self.gen_fields_fn_names(|name| {
            quote! {
                let #name = burn::module::Module::map(self.#name, mapper);
            }
        });

        quote! {
            fn map<M: burn::module::ModuleMapper<B>>(self, mapper: &mut M) -> Self {
                #body

                Self {
                    #(#names),*
                }
            }
        }
    }

    pub fn gen_devices_fn(&self) -> TokenStream {
        let body = self.gen_fields_fn(|name| {
            quote! {
                devices.append(&mut burn::module::Module::<B>::devices(&self.#name));
            }
        });

        quote! {
            fn devices(&self) -> Vec<B::Device> {
                let mut devices = Vec::new();
                #body
                devices
            }
        }
    }

    pub fn gen_to_device_fn(&self) -> TokenStream {
        let (names, body) = self.gen_fields_fn_names(|name| {
            quote! {
                let #name = burn::module::Module::<B>::to_device(self.#name, device);
            }
        });

        quote! {
            fn to_device(self, device: &B::Device) -> Self {
                #body

                Self {
                    #(#names),*
                }
            }
        }
    }

    pub fn gen_detach_fn(&self) -> TokenStream {
        let (names, body) = self.gen_fields_fn_names(|name| {
            quote! {
                let #name = burn::module::Module::<B>::detach(self.#name);
            }
        });

        quote! {
            fn detach(self) -> Self {
                #body

                Self {
                    #(#names),*
                }
            }
        }
    }

    pub fn gen_inner_fn(&self) -> TokenStream {
        let (names, body) = self.gen_fields_fn_names(|name| {
            quote! {
                let #name = burn::module::ADModule::<B>::inner(self.#name);
            }
        });

        quote! {
            fn inner(self) -> Self::InnerModule {
                #body

                Self::InnerModule {
                    #(#names),*
                }
            }
        }
    }

    pub fn gen_from_inner_fn(&self) -> TokenStream {
        let (names, body) = self.gen_fields_fn_names(|name| {
            quote! {
                let #name = burn::module::ADModule::<B>::from_inner(module.#name);
            }
        });

        quote! {
            fn from_inner(module: Self::InnerModule) -> Self {
                #body

                Self {
                    #(#names),*
                }
            }
        }
    }

    pub fn gen_clone_fn(&self) -> TokenStream {
        let (names, body) = self.gen_fields_fn_names(|name| {
            quote! {
                let #name = self.#name.clone();
            }
        });

        quote! {
            fn clone(&self) -> Self {
                #body

                Self {
                    #(#names),*
                }
            }
        }
    }

    pub fn gen_state_fn(&self) -> TokenStream {
        let body = self.gen_fields_fn(|name| {
            quote! {
                state.register_state(stringify!(#name), burn::module::Module::<B>::state(&self.#name));
            }
        });

        quote! {
            fn state(&self) -> burn::module::State<B::FloatElem>
            {
                let mut state = burn::module::StateNamed::new();
                #body
                burn::module::State::StateNamed(state)
            }
        }
    }

    pub fn gen_load_fn(&self) -> TokenStream {
        let (names, body) = self.gen_fields_fn_names(|name| {
            quote! {
                let state_mod = state.get(stringify!(#name)).ok_or(
                    burn::module::LoadingError::new(format!(
                        "Missing module '{}' from state",
                        stringify!(#name),
                    )))?;
                let #name = burn::module::Module::<B>::load(self.#name, state_mod).map_err(|err| {
                    burn::module::LoadingError::new(format!("Can't load module {}: {}", stringify!(#name), err))
                })?;
            }
        });

        quote! {
            fn load(self, state: &burn::module::State<B::FloatElem>) -> Result<Self, burn::module::LoadingError>
            {
                #body

                Ok(Self {
                    #(#names),*
                })
            }
        }
    }

    pub fn gen_fields_fn_names<F>(&self, func: F) -> (Vec<Ident>, TokenStream)
    where
        F: Fn(Ident) -> TokenStream,
    {
        let mut body = quote! {};
        let mut names = Vec::new();

        for field in self.fields.iter() {
            let name = field.ident();

            names.push(name.clone());
            body.extend(func(field.ident()));
        }

        (names, body)
    }

    pub fn gen_fields_fn<F>(&self, func: F) -> TokenStream
    where
        F: Fn(Ident) -> TokenStream,
    {
        let mut body = quote! {};

        for field in self.fields.iter() {
            body.extend(func(field.ident()));
        }

        body
    }
}
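For context, a minimal self-contained sketch (not part of this commit) of the generation pattern used above: per-field snippets are accumulated into a single TokenStream with `extend`, then spliced into a generated method with `quote!`. The field names here are hypothetical; only the `proc-macro2` and `quote` crates are assumed.

use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;

fn gen_visit_for(fields: &[&str]) -> TokenStream {
    // Accumulate one `visit` call per field, mirroring gen_visit_fn above.
    let mut body = quote! {};
    for field in fields {
        let name = Ident::new(field, Span::call_site());
        body.extend(quote! {
            burn::module::Module::visit(&self.#name, visitor);
        });
    }

    // Splice the accumulated statements into the generated method body.
    quote! {
        fn visit<V: burn::module::ModuleVisitor<B>>(&self, visitor: &mut V) {
            #body
        }
    }
}

fn main() {
    // Prints the `visit` method that would be generated for two hypothetical fields.
    println!("{}", gen_visit_for(&["linear", "dropout"]));
}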
@ -1,5 +1,5 @@
pub(crate) mod display;
pub(crate) mod param;
pub(crate) mod fn_generator;

mod base;
@ -1,319 +0,0 @@
use crate::shared::field::{parse_fields, FieldTypeAnalyzer};
use proc_macro2::{Ident, TokenStream};
use quote::quote;

pub struct Param {
    fields_param: Vec<FieldTypeAnalyzer>,
    fields_other: Vec<FieldTypeAnalyzer>,
}

impl Param {
    pub fn from_ast(ast: &syn::DeriveInput) -> Self {
        let fields_param = parse_fields(ast)
            .into_iter()
            .map(FieldTypeAnalyzer::new)
            .filter(FieldTypeAnalyzer::is_param)
            .collect();
        let fields_other = parse_fields(ast)
            .into_iter()
            .map(FieldTypeAnalyzer::new)
            .filter(|val| !FieldTypeAnalyzer::is_param(val))
            .collect();

        Self {
            fields_param,
            fields_other,
        }
    }

    pub fn gen_num_params_fn(&self) -> TokenStream {
        let mut body = quote! {
            let mut num_params = 0;
        };
        for field in self.fields_param.iter() {
            let name = field.ident();
            body.extend(quote! {
                num_params += self.#name.num_params();
            });
        }
        body.extend(quote! {
            num_params
        });

        quote! {
            fn num_params(&self) -> usize {
                #body
            }
        }
    }

    pub fn gen_visit_fn(&self) -> TokenStream {
        let mut body = quote! {};
        for field in self.fields_param.iter() {
            let name = field.ident();
            body.extend(quote! {
                self.#name.visit(visitor);
            });
        }

        quote! {
            fn visit<V: burn::module::ModuleVisitor<Self::Backend>>(&self, visitor: &mut V) {
                #body
            }
        }
    }

    pub fn gen_map_fn(&self) -> TokenStream {
        let (names, body) = self.gen_params_others_fn(
            |name| {
                quote! {
                    let #name = self.#name.map(mapper);
                }
            },
            |name| {
                quote! {
                    let #name = self.#name;
                }
            },
        );

        quote! {
            fn map<M: burn::module::ModuleMapper<Self::Backend>>(self, mapper: &mut M) -> Self {
                #body

                Self {
                    #(#names),*
                }
            }
        }
    }

    pub fn gen_devices_fn(&self) -> TokenStream {
        let mut body = quote! {
            let mut devices = Vec::new();
        };
        for field in self.fields_param.iter() {
            let name = field.ident();
            body.extend(quote! {
                devices.append(&mut self.#name.devices());
            });
        }

        body.extend(quote! {
            devices
        });

        quote! {
            fn devices(&self) -> Vec<B::Device> {
                #body
            }
        }
    }

    pub fn gen_to_device_fn(&self) -> TokenStream {
        let (names, body) = self.gen_params_others_fn(
            |name| {
                quote! {
                    let #name = self.#name.to_device(device);
                }
            },
            |name| {
                quote! {
                    let #name = self.#name.clone();
                }
            },
        );

        quote! {
            fn to_device(self, device: &B::Device) -> Self {
                #body

                Self {
                    #(#names),*
                }
            }
        }
    }

    pub fn gen_detach_fn(&self) -> TokenStream {
        let (names, body) = self.gen_params_others_fn(
            |name| {
                quote! {
                    let #name = self.#name.detach();
                }
            },
            |name| {
                quote! {
                    let #name = self.#name.clone();
                }
            },
        );

        quote! {
            fn detach(self) -> Self {
                #body

                Self {
                    #(#names),*
                }
            }
        }
    }

    pub fn gen_inner_fn(&self) -> TokenStream {
        let (names, body) = self.gen_params_others_fn(
            |name| {
                quote! {
                    let #name = self.#name.inner();
                }
            },
            |name| {
                quote! {
                    let #name = self.#name.clone();
                }
            },
        );

        quote! {
            fn inner(self) -> Self::InnerModule {
                #body

                Self::InnerModule {
                    #(#names),*
                }
            }
        }
    }

    pub fn gen_from_inner_fn(&self) -> TokenStream {
        let (names, body) = self.gen_params_others_fn(
            |name| {
                quote! {
                    let #name = burn::module::ADModule::from_inner(module.#name);
                }
            },
            |name| {
                quote! {
                    let #name = module.#name.clone();
                }
            },
        );

        quote! {
            fn from_inner(module: Self::InnerModule) -> Self {
                #body

                Self {
                    #(#names),*
                }
            }
        }
    }

    pub fn gen_clone_fn(&self) -> TokenStream {
        let mut body = quote! {};
        let mut names = Vec::new();
        let mut fields = Vec::new();

        fields.append(&mut self.fields_param.clone());
        fields.append(&mut self.fields_other.clone());
        for field in fields {
            let name = field.ident();
            names.push(name.clone());

            body.extend(quote! {
                let #name = self.#name.clone();
            });
        }

        quote! {
            fn clone(&self) -> Self {
                #body

                Self {
                    #(#names),*
                }
            }
        }
    }

    pub fn gen_state_fn(&self) -> TokenStream {
        let mut body = quote! {
            let mut state = burn::module::StateNamed::new();
        };
        for field in self.fields_param.iter() {
            let name = field.ident();
            body.extend(quote! {
                state.register_state(stringify!(#name), self.#name.state());
            });
        }

        quote! {
            fn state(&self) -> burn::module::State<<Self::Backend as burn::tensor::backend::Backend>::FloatElem>
            {
                #body
                burn::module::State::StateNamed(state)
            }
        }
    }

    pub fn gen_load_fn(&self) -> TokenStream {
        let (names, body) = self.gen_params_others_fn(|name| {
            quote! {
                let state_mod = state.get(stringify!(#name)).ok_or(
                    burn::module::LoadingError::new(format!(
                        "Missing module '{}' from state",
                        stringify!(#name),
                    )))?;
                let #name = self.#name.load(state_mod).map_err(|err| {
                    burn::module::LoadingError::new(format!("Can't load module {}: {}", stringify!(#name), err))
                })?;
            }
        }, |name| {
            quote! {
                let #name = self.#name.clone();
            }
        });

        quote! {
            fn load(self, state: &burn::module::State<<Self::Backend as burn::tensor::backend::Backend>::FloatElem>) -> Result<Self, burn::module::LoadingError>
            {
                #body

                Ok(Self {
                    #(#names),*
                })
            }
        }
    }

    pub fn gen_params_others_fn<FP, FO>(
        &self,
        func_params: FP,
        func_others: FO,
    ) -> (Vec<Ident>, TokenStream)
    where
        FP: Fn(Ident) -> TokenStream,
        FO: Fn(Ident) -> TokenStream,
    {
        let mut body = quote! {};
        let mut names = Vec::new();

        for field in self.fields_param.iter() {
            let name = field.ident();

            names.push(name.clone());
            body.extend(func_params(field.ident()));
        }

        for field in self.fields_other.iter() {
            let name = field.ident();

            names.push(name.clone());
            body.extend(func_others(field.ident()));
        }

        (names, body)
    }
}
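The removed generator above split fields into `Param` fields and everything else, and only the `Param` fields received the module operations (the rest were cloned, moved, or skipped). The new generator at the top of this diff applies the `Module` methods to every field instead. A hypothetical, self-contained sketch of that difference for `to_device` (again assuming only `proc-macro2` and `quote`; field names are made up):

use proc_macro2::{Ident, Span, TokenStream};
use quote::quote;

// Old strategy: move Param fields to the device, clone the others.
fn old_to_device(params: &[&str], others: &[&str]) -> TokenStream {
    let mut body = quote! {};
    for field in params {
        let name = Ident::new(field, Span::call_site());
        body.extend(quote! { let #name = self.#name.to_device(device); });
    }
    for field in others {
        let name = Ident::new(field, Span::call_site());
        body.extend(quote! { let #name = self.#name.clone(); });
    }
    body
}

// New strategy: every field is itself a Module, so every field gets the same call.
fn new_to_device(fields: &[&str]) -> TokenStream {
    let mut body = quote! {};
    for field in fields {
        let name = Ident::new(field, Span::call_site());
        body.extend(quote! {
            let #name = burn::module::Module::<B>::to_device(self.#name, device);
        });
    }
    body
}

fn main() {
    println!("old: {}", old_to_device(&["conv"], &["activation"]));
    println!("new: {}", new_to_device(&["conv", "activation"]));
}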
@ -73,10 +73,6 @@ impl FieldTypeAnalyzer {
            .into_iter()
            .map(AttributeAnalyzer::new)
    }

    pub fn is_param(&self) -> bool {
        self.is_of_type(&["Param", "burn::Param"])
    }
}

pub(crate) fn parse_fields(ast: &syn::DeriveInput) -> Vec<Field> {
@ -4,14 +4,14 @@ use alloc::{format, vec::Vec};

use burn::{
    config::Config,
    module::{Module, Param},
    module::Module,
    nn,
    tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct ConvBlock<B: Backend> {
    conv: Param<nn::conv::Conv2d<B>>,
    conv: nn::conv::Conv2d<B>,
    pool: nn::pool::MaxPool2d,
    activation: nn::GELU,
}

@ -34,7 +34,7 @@ impl<B: Backend> ConvBlock<B> {
        let activation = nn::GELU::new();

        Self {
            conv: Param::from(conv),
            conv,
            pool,
            activation,
        }
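The same migration repeats across the remaining example models below: the `Param` wrapper is dropped from sub-module fields, and the sub-modules are stored and constructed directly. A minimal before/after sketch with a hypothetical module (mirroring these diffs, not taken from the commit):

use burn::{module::Module, nn, tensor::backend::Backend};

// Before: sub-modules were wrapped in Param and built with Param::from.
// #[derive(Module, Debug)]
// pub struct Head<B: Backend> {
//     linear: Param<nn::Linear<B>>,
// }

// After: sub-modules are plain fields; Param is reserved for tensors.
#[derive(Module, Debug)]
pub struct Head<B: Backend> {
    linear: nn::Linear<B>,
}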
@ -4,7 +4,7 @@ use alloc::{format, vec::Vec};

use burn::{
    config::Config,
    module::{Module, Param},
    module::Module,
    nn,
    tensor::{backend::Backend, Tensor},
};

@ -26,7 +26,7 @@ pub struct MlpConfig {
/// Multilayer Perceptron module.
#[derive(Module, Debug)]
pub struct Mlp<B: Backend> {
    linears: Param<Vec<nn::Linear<B>>>,
    linears: Vec<nn::Linear<B>>,
    dropout: nn::Dropout,
    activation: nn::ReLU,
}

@ -41,7 +41,7 @@ impl<B: Backend> Mlp<B> {
        }

        Self {
            linears: Param::from(linears),
            linears,
            dropout: nn::DropoutConfig::new(0.3).init(),
            activation: nn::ReLU::new(),
        }
@ -9,7 +9,7 @@ use crate::{

use burn::{
    config::Config,
    module::{Module, Param},
    module::Module,
    nn,
    tensor::{backend::Backend, Tensor},
};

@ -30,26 +30,25 @@ pub struct MnistConfig {

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    mlp: Param<Mlp<B>>,
    conv: Param<ConvBlock<B>>,
    input: Param<nn::Linear<B>>,
    output: Param<nn::Linear<B>>,
    mlp: Mlp<B>,
    conv: ConvBlock<B>,
    input: nn::Linear<B>,
    output: nn::Linear<B>,
    num_classes: usize,
}

impl<B: Backend> Model<B> {
    pub fn new(config: &MnistConfig) -> Self {
        let mlp = Mlp::new(&config.mlp);

        let input = nn::LinearConfig::new(config.input_size, config.mlp.d_model).init();
        let output = nn::LinearConfig::new(config.mlp.d_model, config.output_size).init();
        let conv = ConvBlock::new(&ConvBlockConfig::new([1, 1]));

        Self {
            mlp: Param::from(mlp),
            conv: Param::from(conv),
            output: Param::from(output),
            input: Param::from(input),
            mlp,
            conv,
            output,
            input,
            num_classes: config.output_size,
        }
    }
@ -1,44 +1,45 @@
use crate::checkpoint::Checkpointer;
use crate::LearnerCallback;
use burn_core::module::{ADModule, Module};
use burn_core::module::ADModule;
use burn_core::optim::Optimizer;
use burn_core::tensor::backend::Backend;
use burn_core::tensor::backend::{ADBackend, Backend};

/// Learner struct encapsulating all components necessary to train a Neural Network model.
///
/// To create a learner, use the [builder](crate::train::LearnerBuilder) struct.
pub struct Learner<M, O, TO, VO>
pub struct Learner<B, M, O, TO, VO>
where
    M: ADModule,
    B: ADBackend,
    M: ADModule<B>,
    O: Optimizer<M, B>,
{
    pub(super) model: M,
    pub(super) optim: O,
    pub(super) num_epochs: usize,
    pub(super) callback: Box<dyn LearnerCallback<TO, VO>>,
    pub(super) checkpoint: Option<usize>,
    pub(super) checkpointer_model: CheckpointModel<M>,
    pub(super) checkpointer_optimizer: CheckpointOptim<M>,
    pub(super) checkpointer_model: CheckpointModel<B>,
    pub(super) checkpointer_optimizer: CheckpointOptim<B>,
    pub(super) grad_accumulation: Option<usize>,
    pub(super) devices: Vec<<M::Backend as Backend>::Device>,
    pub(super) devices: Vec<B::Device>,
}

type CheckpointModel<M> =
    Option<Box<dyn Checkpointer<<<M as Module>::Backend as Backend>::FloatElem>>>;
type CheckpointOptim<M> =
    Option<Box<dyn Checkpointer<<<M as Module>::Backend as Backend>::FloatElem>>>;
type CheckpointModel<B> = Option<Box<dyn Checkpointer<<B as Backend>::FloatElem>>>;
type CheckpointOptim<B> = Option<Box<dyn Checkpointer<<B as Backend>::FloatElem>>>;

impl<M, O, TO, VO> Learner<M, O, TO, VO>
impl<B, M, O, TO, VO> Learner<B, M, O, TO, VO>
where
    VO: Send + Sync + 'static,
    TO: Send + Sync + 'static,
    M: ADModule,
    O: Optimizer<Backend = M::Backend>,
    B: ADBackend,
    M: ADModule<B>,
    O: Optimizer<M, B>,
{
    pub(super) fn checkpoint(
        model: &M,
        optim: &O,
        checkpointer_model: &CheckpointModel<M>,
        checkpointer_optimizer: &CheckpointOptim<M>,
        checkpointer_model: &CheckpointModel<B>,
        checkpointer_optimizer: &CheckpointOptim<B>,
        epoch: usize,
    ) {
        if let Some(checkpointer) = &checkpointer_model {

@ -161,10 +161,10 @@ where
    }

    /// Create the [learner](Learner) from a [module](ADModule) and an
    pub fn build<M, O>(self, model: M, optim: O) -> Learner<M, O, T, V>
    pub fn build<M, O>(self, model: M, optim: O) -> Learner<B, M, O, T, V>
    where
        M: ADModule<ADBackend = B>,
        O: Optimizer<Backend = B>,
        M: ADModule<B>,
        O: Optimizer<M, B>,
    {
        self.init_logger();
        let callack = Box::new(self.dashboard);
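A small sketch of how the new bounds read downstream (hypothetical helper, not from the commit): training code is now generic over the backend `B` directly instead of going through an associated `Backend`/`ADBackend` type on the module.

use burn_core::module::{ADModule, Module};
use burn_core::tensor::backend::ADBackend;

// Hypothetical helper written in the new bound style.
fn total_params<B, M>(module: &M) -> usize
where
    B: ADBackend,
    M: ADModule<B>,
{
    // Module<B> is a supertrait of ADModule<B>, so num_params is available.
    Module::<B>::num_params(module)
}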
@ -2,7 +2,7 @@ use burn_core::{
    data::dataloader::DataLoader,
    module::ADModule,
    optim::{GradientsAccumulator, Optimizer},
    tensor::backend::Backend,
    tensor::backend::ADBackend,
};
use std::sync::Arc;

@ -24,9 +24,10 @@ pub struct TrainEpoch<TI> {
}

impl<I> ValidEpoch<I> {
    pub fn run<M, TO, VO>(&self, model: M, callback: &mut Box<dyn LearnerCallback<TO, VO>>) -> M
    pub fn run<B, M, TO, VO>(&self, model: M, callback: &mut Box<dyn LearnerCallback<TO, VO>>) -> M
    where
        M: ADModule,
        B: ADBackend,
        M: ADModule<B>,
        M::InnerModule: ValidStep<I, VO>,
    {
        log::info!("Executing validation step for epoch {}", self.epoch);

@ -55,15 +56,16 @@ impl<I> ValidEpoch<I> {
}

impl<TI> TrainEpoch<TI> {
    pub fn run<M, O, TO, VO>(
    pub fn run<B, M, O, TO, VO>(
        &self,
        mut model: M,
        mut optim: O,
        callback: &mut Box<dyn LearnerCallback<TO, VO>>,
    ) -> (M, O)
    where
        M: ADModule,
        O: Optimizer<Backend = M::ADBackend>,
        B: ADBackend,
        M: ADModule<B>,
        O: Optimizer<M, B>,
        M: TrainStep<TI, TO>,
    {
        log::info!("Executing training step for epoch {}", self.epoch,);

@ -108,17 +110,18 @@ impl<TI> TrainEpoch<TI> {
}

impl<TI> TrainEpoch<TI> {
    pub fn run_multi_device<M, O, TO, VO>(
    pub fn run_multi_device<B, M, O, TO, VO>(
        &self,
        mut model: M,
        mut optim: O,
        callback: &mut Box<dyn LearnerCallback<TO, VO>>,
        devices: Vec<<M::Backend as Backend>::Device>,
        devices: Vec<B::Device>,
    ) -> (M, O)
    where
        O: Optimizer<Backend = M::ADBackend>,
        B: ADBackend,
        M: ADModule<B> + 'static,
        O: Optimizer<M, B>,
        M: TrainStep<TI, TO>,
        M: ADModule + 'static,
        TI: Send + 'static,
        TO: Send + 'static,
    {

@ -20,9 +20,10 @@ struct Worker<B: ADBackend, M, TI> {
    device: B::Device,
}

impl<B: ADBackend, M, TI> Worker<B, M, TI>
impl<B, M, TI> Worker<B, M, TI>
where
    M: ADModule<ADBackend = B> + Clone,
    B: ADBackend,
    M: ADModule<B>,
{
    fn register(&self, item: TI, model: &M) {
        let message = Message {

@ -63,9 +64,9 @@ where
impl<B, M, TI, TO> MultiDevicesTrainStep<B, M, TI, TO>
where
    B: ADBackend,
    M: ADModule<B> + TrainStep<TI, TO> + Send + Clone + 'static,
    TI: Send + 'static,
    TO: Send + 'static,
    M: ADModule<ADBackend = B> + TrainStep<TI, TO> + Send + Clone + 'static,
{
    pub fn new(devices: &[B::Device]) -> Self
    where
@ -13,13 +13,8 @@ pub struct TrainOutput<TO> {
}

impl<TO> TrainOutput<TO> {
    pub fn new<M: ADModule>(
        module: &M,
        grads: <M::ADBackend as ADBackend>::Gradients,
        item: TO,
    ) -> Self {
    pub fn new<B: ADBackend, M: ADModule<B>>(module: &M, grads: B::Gradients, item: TO) -> Self {
        let grads = GradientsParams::from_grads(grads, module);

        Self { grads, item }
    }
}

@ -32,12 +27,13 @@ pub trait ValidStep<VI, VO> {
    fn step(&self, item: VI) -> VO;
}

impl<M, O, TO, VO> Learner<M, O, TO, VO>
impl<B, M, O, TO, VO> Learner<B, M, O, TO, VO>
where
    VO: Send + Sync + 'static,
    TO: Send + Sync + 'static,
    M: ADModule,
    O: Optimizer<Backend = M::Backend>,
    B: ADBackend,
    M: ADModule<B> + core::fmt::Display,
    O: Optimizer<M, B>,
{
    pub fn fit<TI, VI>(
        mut self,

Binary file not shown.
@ -5,18 +5,18 @@
use alloc::{format, vec::Vec};

use burn::{
    module::{Module, Param},
    module::Module,
    nn::{self, conv::Conv2dPaddingConfig, BatchNorm},
    tensor::{backend::Backend, Tensor},
};

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    conv1: Param<ConvBlock<B>>,
    conv2: Param<ConvBlock<B>>,
    conv3: Param<ConvBlock<B>>,
    fc1: Param<nn::Linear<B>>,
    fc2: Param<nn::Linear<B>>,
    conv1: ConvBlock<B>,
    conv2: ConvBlock<B>,
    conv3: ConvBlock<B>,
    fc1: nn::Linear<B>,
    fc2: nn::Linear<B>,
    activation: nn::GELU,
}

@ -37,11 +37,11 @@ impl<B: Backend> Model<B> {
            .init();

        Self {
            conv1: Param::from(conv1),
            conv2: Param::from(conv2),
            conv3: Param::from(conv3),
            fc1: Param::from(fc1),
            fc2: Param::from(fc2),
            conv1,
            conv2,
            conv3,
            fc1,
            fc2,
            activation: nn::GELU::new(),
        }
    }

@ -66,8 +66,8 @@ impl<B: Backend> Model<B> {

#[derive(Module, Debug)]
pub struct ConvBlock<B: Backend> {
    conv: Param<nn::conv::Conv2d<B>>,
    norm: Param<BatchNorm<B, 2>>,
    conv: nn::conv::Conv2d<B>,
    norm: BatchNorm<B, 2>,
    activation: nn::GELU,
}

@ -79,8 +79,8 @@ impl<B: Backend> ConvBlock<B> {
        let norm = nn::BatchNormConfig::new(channels[1]).init();

        Self {
            conv: Param::from(conv),
            norm: Param::from(norm),
            conv,
            norm,
            activation: nn::GELU::new(),
        }
    }
@ -1,7 +1,7 @@
use crate::data::MNISTBatch;

use burn::{
    module::{Module, Param},
    module::Module,
    nn::{self, conv::Conv2dPaddingConfig, loss::CrossEntropyLoss, BatchNorm},
    tensor::{
        backend::{ADBackend, Backend},

@ -12,12 +12,12 @@ use burn::{

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    conv1: Param<ConvBlock<B>>,
    conv2: Param<ConvBlock<B>>,
    conv3: Param<ConvBlock<B>>,
    conv1: ConvBlock<B>,
    conv2: ConvBlock<B>,
    conv3: ConvBlock<B>,
    dropout: nn::Dropout,
    fc1: Param<nn::Linear<B>>,
    fc2: Param<nn::Linear<B>>,
    fc1: nn::Linear<B>,
    fc2: nn::Linear<B>,
    activation: nn::GELU,
}

@ -39,11 +39,11 @@ impl<B: Backend> Model<B> {
        let dropout = nn::DropoutConfig::new(0.3).init();

        Self {
            conv1: Param::from(conv1),
            conv2: Param::from(conv2),
            conv3: Param::from(conv3),
            fc1: Param::from(fc1),
            fc2: Param::from(fc2),
            conv1,
            conv2,
            conv3,
            fc1,
            fc2,
            dropout,
            activation: nn::GELU::new(),
        }

@ -83,8 +83,8 @@ impl<B: Backend> Model<B> {

#[derive(Module, Debug)]
pub struct ConvBlock<B: Backend> {
    conv: Param<nn::conv::Conv2d<B>>,
    norm: Param<BatchNorm<B, 2>>,
    conv: nn::conv::Conv2d<B>,
    norm: BatchNorm<B, 2>,
    activation: nn::GELU,
}

@ -96,8 +96,8 @@ impl<B: Backend> ConvBlock<B> {
        let norm = nn::BatchNormConfig::new(channels[1]).init();

        Self {
            conv: Param::from(conv),
            norm: Param::from(norm),
            conv,
            norm,
            activation: nn::GELU::new(),
        }
    }
@ -1,7 +1,7 @@
use crate::data::{TextClassificationInferenceBatch, TextClassificationTrainingBatch};
use burn::{
    config::Config,
    module::{Module, Param},
    module::Module,
    nn::{
        loss::CrossEntropyLoss,
        transformer::{TransformerEncoder, TransformerEncoderConfig, TransformerEncoderInput},

@ -22,10 +22,10 @@ pub struct TextClassificationModelConfig {

#[derive(Module, Debug)]
pub struct TextClassificationModel<B: Backend> {
    transformer: Param<TransformerEncoder<B>>,
    embedding_token: Param<Embedding<B>>,
    embedding_pos: Param<Embedding<B>>,
    output: Param<Linear<B>>,
    transformer: TransformerEncoder<B>,
    embedding_token: Embedding<B>,
    embedding_pos: Embedding<B>,
    output: Linear<B>,
    n_classes: usize,
    max_seq_length: usize,
}

@ -40,10 +40,10 @@ impl TextClassificationModelConfig {
            EmbeddingConfig::new(self.max_seq_length, self.transformer.d_model).init();

        TextClassificationModel {
            transformer: Param::from(transformer),
            embedding_token: Param::from(embedding_token),
            embedding_pos: Param::from(embedding_pos),
            output: Param::from(output),
            transformer,
            embedding_token,
            embedding_pos,
            output,
            n_classes: self.n_classes,
            max_seq_length: self.max_seq_length,
        }
@ -1,7 +1,7 @@
use crate::data::TrainingTextGenerationBatch;
use burn::{
    config::Config,
    module::{Module, Param},
    module::Module,
    nn::{
        attention::generate_autoregressive_mask,
        loss::CrossEntropyLoss,

@ -23,10 +23,10 @@ pub struct TextGenerationModelConfig {

#[derive(Module, Debug)]
pub struct TextGenerationModel<B: Backend> {
    transformer: Param<TransformerEncoder<B>>,
    embedding_token: Param<Embedding<B>>,
    embedding_pos: Param<Embedding<B>>,
    output: Param<Linear<B>>,
    transformer: TransformerEncoder<B>,
    embedding_token: Embedding<B>,
    embedding_pos: Embedding<B>,
    output: Linear<B>,
    vocab_size: usize,
    pad_token: usize,
    max_seq_length: usize,

@ -42,10 +42,10 @@ impl TextGenerationModelConfig {
            EmbeddingConfig::new(self.max_seq_length, self.transformer.d_model).init();

        TextGenerationModel {
            transformer: Param::from(transformer),
            embedding_token: Param::from(embedding_token),
            embedding_pos: Param::from(embedding_pos),
            output: Param::from(output),
            transformer,
            embedding_token,
            embedding_pos,
            output,
            vocab_size: self.vocab_size,
            pad_token: self.pad_token,
            max_seq_length: self.max_seq_length,