mirror of https://github.com/tracel-ai/burn.git
Update module doc + add sponsors section (#267)
This commit is contained in:
parent d3887bcd3d
commit 2c151a5570
Changed file: README.md (25 changed lines)
````diff
@@ -28,6 +28,7 @@ __Sections__
 * [Config](#config)
 * [Learner](#learner)
 * [no_std support](#no_std-support)
+* [Sponsors](#sponsors)
 * [License](#license)
 
 ## Features
@@ -123,18 +124,17 @@ fn main() {
 #### Module
 
 The `Module` derive allows you to create your own neural network modules, similar to PyTorch.
-Note that the `Module` derive generates all the necessary methods to make your type essentially a parameter container.
-It makes no assumptions about how the forward function is declared.
+The derive function only generates the necessary methods to essentially act as a parameter container for your type, it makes no assumptions about how the forward pass is declared.
 
 ```rust
 use burn::nn;
-use burn::module::{Param, Module};
+use burn::module::Module;
 use burn::tensor::backend::Backend;
 
 #[derive(Module, Debug)]
 pub struct PositionWiseFeedForward<B: Backend> {
-    linear_inner: Param<Linear<B>>,
-    linear_outer: Param<Linear<B>>,
+    linear_inner: Linear<B>,
+    linear_outer: Linear<B>,
     dropout: Dropout,
     gelu: GELU,
 }
@@ -150,7 +150,8 @@ impl<B: Backend> PositionWiseFeedForward<B> {
 }
 ```
 
-Note that only the fields wrapped inside `Param` are updated during training, and the other fields should implement the `Clone` trait.
+Note that all fields declared in the struct must also implement the `Module` trait.
+The `Tensor` struct doesn't implement `Module`, but `Param<Tensor<B, D>>` does.
 
 #### Config
 
````
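The hunk above elides the body of the `impl<B: Backend> PositionWiseFeedForward<B>` block. As a minimal sketch of the kind of hand-written forward pass the README is describing (the method name and tensor ranks here are illustrative, not quoted from the commit; `Linear`, `Dropout`, and `GELU` are the `burn::nn` types used in the struct):

```rust
use burn::nn::{Dropout, Linear, GELU};
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

impl<B: Backend> PositionWiseFeedForward<B> {
    /// The `Module` derive imposes no forward signature, so the input rank,
    /// output type, and even the method name are chosen by the author.
    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        let x = self.linear_inner.forward(input);
        let x = self.gelu.forward(x);
        let x = self.dropout.forward(x);
        self.linear_outer.forward(x)
    }
}
```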
````diff
@@ -189,6 +190,7 @@ In order to create a learner, you must use the `LearnerBuilder`.
 ```rust
 use burn::train::LearnerBuilder;
 use burn::train::metric::{AccuracyMetric, LossMetric};
+use burn::record::DefaultRecordSettings;
 
 fn main() {
     let dataloader_train = ...;
@@ -202,7 +204,7 @@ fn main() {
         .metric_valid_plot(AccuracyMetric::new())
         .metric_train(LossMetric::new())
         .metric_valid(LossMetric::new())
-        .with_file_checkpointer::<f32>(2)
+        .with_file_checkpointer::<DefaultRecordSettings>(2)
         .num_epochs(10)
         .build(model, optim);
 
````
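The builder chain ends at `.build(model, optim)`, which yields the learner itself. A one-line sketch of the usual next step (assuming a `fit` method on the built learner and a `dataloader_test` declared alongside `dataloader_train`; neither appears in this diff):

```rust
// Hypothetical continuation after `build`: train on the two dataloaders
// and get the trained model back.
let model_trained = learner.fit(dataloader_train, dataloader_test);
```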
````diff
@@ -222,6 +224,15 @@ Additionally `burn-core` and `burn-tensor` crates support `no_std` with `alloc`
 Note, under the `no_std` mode, a random seed is generated during the build time if the seed is not initialized by `Backend::seed` method.
 Additionally, [spin::mutex::Mutex](https://docs.rs/spin/latest/spin/mutex/struct.Mutex.html) is used in place of [std::sync::Mutex](https://doc.rust-lang.org/std/sync/struct.Mutex.html) under the `no_std` mode.
 
+## Sponsors
+
+You can sponsor the founder of Burn from his [GitHub Sponsors profile](https://github.com/sponsors/nathanielsimard).
+The Burn-rs organization doesn't yet have a fiscal entity, but other sponsor methods might become available as the project grows.
+
+Thanks to all current sponsors 🙏.
+
+<a href="https://github.com/smallstepman"><img src="https://github.com/smallstepman.png" width="60px" style="border-radius: 50%;" alt="nathanielsimard" /></a>
+
 ## License
 
 Burn is distributed under the terms of both the MIT license and the Apache License (Version 2.0).
````
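The `no_std` note above references `Backend::seed`. A minimal sketch of seeding explicitly, assuming the associated function takes a `u64` seed (its signature is not shown in this diff):

```rust
use burn::tensor::backend::Backend;

/// Seed the backend RNG up front so a `no_std` build doesn't fall back to
/// the seed generated at build time, as described in the README note above.
fn init_rng<B: Backend>() {
    B::seed(42);
}
```

The remaining hunks touch the Rust sources; the per-file headers were lost in the mirror, so each block below is labeled by the module it visibly edits.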
1D convolution module:

````diff
@@ -3,7 +3,6 @@ use alloc::vec::Vec;
 use crate as burn;
 
 use crate::config::Config;
-use crate::constant;
 use crate::module::Module;
 use crate::module::Param;
 use crate::nn::Initializer;
@@ -34,7 +33,7 @@ pub struct Conv1dConfig {
 }
 
 /// Padding configuration for 1D convolution [config](Conv1dConfig).
-#[derive(Config, Debug)]
+#[derive(Module, Config, Debug)]
 pub enum Conv1dPaddingConfig {
     /// Dynamicaly calculate the amount of padding necessary to ensure that the output size will be
     /// the same as the input.
@@ -43,8 +42,6 @@ pub enum Conv1dPaddingConfig {
     Explicit(usize),
 }
 
-constant!(Conv1dPaddingConfig);
-
 /// Applies a 1D convolution over input tensors.
 ///
 /// # Params
````
2D convolution module:

````diff
@@ -3,7 +3,6 @@ use alloc::vec::Vec;
 use crate as burn;
 
 use crate::config::Config;
-use crate::constant;
 use crate::module::Module;
 use crate::module::Param;
 use crate::nn::Initializer;
@@ -33,7 +32,7 @@ pub struct Conv2dConfig {
 }
 
 /// Padding configuration for 2D convolution [config](Conv2dConfig).
-#[derive(Config, Debug)]
+#[derive(Module, Config, Debug)]
 pub enum Conv2dPaddingConfig {
     /// Dynamicaly calculate the amount of padding necessary to ensure that the output size will be
     /// the same as the input.
@@ -44,8 +43,6 @@ pub enum Conv2dPaddingConfig {
     Explicit(usize, usize),
 }
 
-constant!(Conv2dPaddingConfig);
-
 /// Applies a 2D convolution over input tensors.
 ///
 /// # Params
````
Dropout module:

````diff
@@ -1,7 +1,7 @@
 use crate as burn;
 
 use crate::config::Config;
-use crate::constant;
+use crate::module::Module;
 use crate::tensor::backend::Backend;
 use crate::tensor::{Distribution, Tensor};
 
@@ -18,13 +18,11 @@ pub struct DropoutConfig {
 /// [Improving neural networks by preventing co-adaptation of feature detectors](https://arxiv.org/abs/1207.0580).
 ///
 /// The input is also scaled during training to `1 / (1 - prob_keep)`.
-#[derive(Clone, Debug)]
+#[derive(Module, Clone, Debug)]
 pub struct Dropout {
     prob: f64,
 }
 
-constant!(Dropout);
-
 impl DropoutConfig {
     /// Initialize a new [dropout](Dropout) module.
     pub fn init(&self) -> Dropout {
````
GELU module:

````diff
@@ -1,15 +1,13 @@
 use crate as burn;
 
-use crate::constant;
+use crate::module::Module;
 use crate::tensor::backend::Backend;
 use crate::tensor::Tensor;
 
 /// Applies the Gaussian Error Linear Units function element-wise.
-#[derive(Clone, Debug, Default)]
+#[derive(Module, Clone, Debug, Default)]
 pub struct GELU {}
 
-constant!(GELU);
-
 impl GELU {
     /// Create the module.
     pub fn new() -> Self {
````
2D max pooling module:

````diff
@@ -1,6 +1,7 @@
-use crate::{self as burn, constant};
+use crate as burn;
 
 use crate::config::Config;
+use crate::module::Module;
 use crate::nn::conv::Conv2dPaddingConfig;
 use crate::tensor::backend::Backend;
 use crate::tensor::Tensor;
@@ -25,15 +26,13 @@ pub struct MaxPool2dConfig {
 pub type MaxPool2dPaddingConfig = Conv2dPaddingConfig;
 
 /// Applies a 2D max pooling over input tensors.
-#[derive(Debug, Clone)]
+#[derive(Module, Debug, Clone)]
 pub struct MaxPool2d {
     stride: [usize; 2],
     kernel_size: [usize; 2],
     padding: MaxPool2dPaddingConfig,
 }
 
-constant!(MaxPool2d);
-
 impl MaxPool2dConfig {
     /// Initialize a new [max pool 2d](MaxPool2d) module.
     pub fn init(&self) -> MaxPool2d {
````
ReLU module:

````diff
@@ -1,17 +1,15 @@
 use crate as burn;
 
-use crate::constant;
+use crate::module::Module;
 use crate::tensor::backend::Backend;
 use crate::tensor::Tensor;
 
 /// Applies the rectified linear unit function element-wise:
 ///
 /// `y = max(0, x)`
-#[derive(Clone, Debug, Default)]
+#[derive(Module, Clone, Debug, Default)]
 pub struct ReLU {}
 
-constant!(ReLU);
-
 impl ReLU {
     /// Create the module.
     pub fn new() -> Self {
````
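The five module diffs above all make the same move: drop the `constant!(...)` macro call and add `Module` to the derive list for types that have no `B: Backend` parameter. A hypothetical user-side sketch of what this enables (the `Scale` type is invented for illustration and is not part of the commit):

```rust
use burn::module::Module;

/// Hypothetical stateless module with no `B: Backend` type parameter.
/// After this commit, deriving `Module` is enough; the old pattern of
/// `#[derive(Clone, Debug)]` plus `constant!(Scale);` is no longer needed.
#[derive(Module, Clone, Debug)]
pub struct Scale {
    factor: f64,
}
```

The derive-macro changes that make this work follow.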
Module derive entry point (burn-derive):

````diff
@@ -10,8 +10,6 @@ use module::module_derive_impl;
 #[proc_macro_derive(Module)]
 pub fn module_derive(input: TokenStream) -> TokenStream {
     let input = syn::parse(input).unwrap();
 
-    // panic!("{}", gen);
-
     module_derive_impl(&input)
 }
 
````
Module derive implementation (burn-derive):

````diff
@@ -2,9 +2,53 @@ use super::{fn_generator::FnGenerator, record::RecordGenerator};
 use crate::module::display;
 use proc_macro::TokenStream;
 use quote::quote;
+use syn::parse_quote;
 
+pub(crate) fn constant_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
+    let name = &ast.ident;
+    let (_, generics_ty, generics_where) = ast.generics.split_for_impl();
+
+    let backend: syn::Generics = parse_quote! { <B: burn::tensor::backend::Backend >};
+    let backend_ad: syn::Generics = parse_quote! { <B: burn::tensor::backend::ADBackend >};
+
+    let mut generics_module = ast.generics.clone();
+    let mut generics_module_ad = ast.generics.clone();
+
+    for param in backend.params.into_iter() {
+        generics_module.params.push(param);
+    }
+    for param in backend_ad.params.into_iter() {
+        generics_module_ad.params.push(param);
+    }
+    let (generics_module, _, _) = generics_module.split_for_impl();
+    let (generics_module_ad, _, _) = generics_module_ad.split_for_impl();
+
+    let gen = quote! {
+        impl #generics_module burn::module::Module<B> for #name #generics_ty #generics_where {
+            burn::constant!(module);
+        }
+
+        impl #generics_module_ad burn::module::ADModule<B> for #name #generics_ty #generics_where {
+            burn::constant!(ad_module, #name #generics_ty);
+        }
+    };
+
+    gen.into()
+}
+
 pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
     let name = &ast.ident;
+    let has_backend = ast
+        .generics
+        .type_params()
+        .map(|param| param.ident == "B")
+        .reduce(|accum, is_backend| is_backend || accum)
+        .unwrap_or(false);
+
+    if !has_backend {
+        return constant_derive_impl(ast);
+    }
+
     let (generics, generics_ty, generics_where) = ast.generics.split_for_impl();
 
     let display_fn = display::display_fn(name);
````
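For a concrete backend-free type like `Dropout`, the `quote!` template in `constant_derive_impl` should expand to roughly the following (a sketch obtained by substituting `#name` by hand, not captured compiler output):

```rust
// Approximate expansion of `#[derive(Module)]` on the backend-free `Dropout`:
// one impl per backend kind, each delegating to the `burn::constant!` macro.
impl<B: burn::tensor::backend::Backend> burn::module::Module<B> for Dropout {
    burn::constant!(module);
}

impl<B: burn::tensor::backend::ADBackend> burn::module::ADModule<B> for Dropout {
    burn::constant!(ad_module, Dropout);
}
```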