Update module doc + add sponsors section (#267)

Nathaniel Simard 2023-04-02 17:37:01 -04:00 committed by GitHub
parent d3887bcd3d
commit 2c151a5570
9 changed files with 73 additions and 33 deletions

View File

@@ -28,6 +28,7 @@ __Sections__
* [Config](#config)
* [Learner](#learner)
* [no_std support](#no_std-support)
+* [Sponsors](#sponsors)
* [License](#license)

## Features
@@ -123,18 +124,17 @@ fn main() {
#### Module

The `Module` derive allows you to create your own neural network modules, similar to PyTorch.

-Note that the `Module` derive generates all the necessary methods to make your type essentially a parameter container.
-It makes no assumptions about how the forward function is declared.
+The derive function only generates the necessary methods to essentially act as a parameter container for your type; it makes no assumptions about how the forward pass is declared.

```rust
use burn::nn;
-use burn::module::{Param, Module};
+use burn::module::Module;
use burn::tensor::backend::Backend;

#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
-    linear_inner: Param<Linear<B>>,
-    linear_outer: Param<Linear<B>>,
+    linear_inner: Linear<B>,
+    linear_outer: Linear<B>,
    dropout: Dropout,
    gelu: GELU,
}
@@ -150,7 +150,8 @@ impl<B: Backend> PositionWiseFeedForward<B> {
}
```

-Note that only the fields wrapped inside `Param` are updated during training, and the other fields should implement the `Clone` trait.
+Note that all fields declared in the struct must also implement the `Module` trait.
+The `Tensor` struct doesn't implement `Module`, but `Param<Tensor<B, D>>` does.

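For context, a minimal sketch (not part of this commit; the type and field names are hypothetical) of a module that keeps a learnable field wrapped in `Param` next to a plain constant field, mirroring how `Dropout` carries a `prob: f64`:

```rust
use burn::module::{Module, Param};
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

// Hypothetical module: `bias` is a learnable parameter because it is wrapped
// in `Param`, while `scale` is a constant value carried along with the module.
#[derive(Module, Debug)]
pub struct ScaledBias<B: Backend> {
    bias: Param<Tensor<B, 1>>,
    scale: f64,
}
```
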
#### Config
@@ -189,6 +190,7 @@ In order to create a learner, you must use the `LearnerBuilder`.
```rust
use burn::train::LearnerBuilder;
use burn::train::metric::{AccuracyMetric, LossMetric};
+use burn::record::DefaultRecordSettings;

fn main() {
    let dataloader_train = ...;
@@ -202,7 +204,7 @@ fn main() {
        .metric_valid_plot(AccuracyMetric::new())
        .metric_train(LossMetric::new())
        .metric_valid(LossMetric::new())
-        .with_file_checkpointer::<f32>(2)
+        .with_file_checkpointer::<DefaultRecordSettings>(2)
        .num_epochs(10)
        .build(model, optim);
@@ -222,6 +224,15 @@ Additionally `burn-core` and `burn-tensor` crates support `no_std` with `alloc`
Note, under the `no_std` mode, a random seed is generated during the build time if the seed is not initialized by `Backend::seed` method.
Additionally, [spin::mutex::Mutex](https://docs.rs/spin/latest/spin/mutex/struct.Mutex.html) is used in place of [std::sync::Mutex](https://doc.rust-lang.org/std/sync/struct.Mutex.html) under the `no_std` mode.
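As an illustration (not part of this commit; the helper function is hypothetical), the seed can be set explicitly through the backend so that `no_std` builds do not rely on the seed generated at build time:

```rust
use burn::tensor::backend::Backend;

// Hypothetical helper: seed whichever backend `B` is in use once at startup,
// before any random tensors are created.
fn init_rng<B: Backend>(seed: u64) {
    B::seed(seed);
}
```
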
+## Sponsors
+
+You can sponsor the founder of Burn from his [GitHub Sponsors profile](https://github.com/sponsors/nathanielsimard).
+The Burn-rs organization doesn't yet have a fiscal entity, but other sponsorship methods might become available as the project grows.
+
+Thanks to all current sponsors 🙏.
+
+<a href="https://github.com/smallstepman"><img src="https://github.com/smallstepman.png" width="60px" style="border-radius: 50%;" alt="smallstepman" /></a>

## License

Burn is distributed under the terms of both the MIT license and the Apache License (Version 2.0).

View File

@@ -3,7 +3,6 @@ use alloc::vec::Vec;
use crate as burn;

use crate::config::Config;
-use crate::constant;
use crate::module::Module;
use crate::module::Param;
use crate::nn::Initializer;
@@ -34,7 +33,7 @@ pub struct Conv1dConfig {
}

/// Padding configuration for 1D convolution [config](Conv1dConfig).
-#[derive(Config, Debug)]
+#[derive(Module, Config, Debug)]
pub enum Conv1dPaddingConfig {
    /// Dynamicaly calculate the amount of padding necessary to ensure that the output size will be
    /// the same as the input.
@@ -43,8 +42,6 @@ pub enum Conv1dPaddingConfig {
    Explicit(usize),
}

-constant!(Conv1dPaddingConfig);

/// Applies a 1D convolution over input tensors.
///
/// # Params

View File

@@ -3,7 +3,6 @@ use alloc::vec::Vec;
use crate as burn;

use crate::config::Config;
-use crate::constant;
use crate::module::Module;
use crate::module::Param;
use crate::nn::Initializer;
@@ -33,7 +32,7 @@ pub struct Conv2dConfig {
}

/// Padding configuration for 2D convolution [config](Conv2dConfig).
-#[derive(Config, Debug)]
+#[derive(Module, Config, Debug)]
pub enum Conv2dPaddingConfig {
    /// Dynamicaly calculate the amount of padding necessary to ensure that the output size will be
    /// the same as the input.
@@ -44,8 +43,6 @@ pub enum Conv2dPaddingConfig {
    Explicit(usize, usize),
}

-constant!(Conv2dPaddingConfig);

/// Applies a 2D convolution over input tensors.
///
/// # Params

View File

@@ -1,7 +1,7 @@
use crate as burn;

use crate::config::Config;
-use crate::constant;
+use crate::module::Module;
use crate::tensor::backend::Backend;
use crate::tensor::{Distribution, Tensor};
@@ -18,13 +18,11 @@ pub struct DropoutConfig {
/// [Improving neural networks by preventing co-adaptation of feature detectors](https://arxiv.org/abs/1207.0580).
///
/// The input is also scaled during training to `1 / (1 - prob_keep)`.
-#[derive(Clone, Debug)]
+#[derive(Module, Clone, Debug)]
pub struct Dropout {
    prob: f64,
}

-constant!(Dropout);

impl DropoutConfig {
    /// Initialize a new [dropout](Dropout) module.
    pub fn init(&self) -> Dropout {

View File

@@ -1,15 +1,13 @@
use crate as burn;

-use crate::constant;
+use crate::module::Module;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;

/// Applies the Gaussian Error Linear Units function element-wise.
-#[derive(Clone, Debug, Default)]
+#[derive(Module, Clone, Debug, Default)]
pub struct GELU {}

-constant!(GELU);

impl GELU {
    /// Create the module.
    pub fn new() -> Self {

View File

@@ -1,6 +1,7 @@
-use crate::{self as burn, constant};
+use crate as burn;

use crate::config::Config;
+use crate::module::Module;
use crate::nn::conv::Conv2dPaddingConfig;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
@@ -25,15 +26,13 @@ pub struct MaxPool2dConfig {
pub type MaxPool2dPaddingConfig = Conv2dPaddingConfig;

/// Applies a 2D max pooling over input tensors.
-#[derive(Debug, Clone)]
+#[derive(Module, Debug, Clone)]
pub struct MaxPool2d {
    stride: [usize; 2],
    kernel_size: [usize; 2],
    padding: MaxPool2dPaddingConfig,
}

-constant!(MaxPool2d);

impl MaxPool2dConfig {
    /// Initialize a new [max pool 2d](MaxPool2d) module.
    pub fn init(&self) -> MaxPool2d {

View File

@@ -1,17 +1,15 @@
use crate as burn;

-use crate::constant;
+use crate::module::Module;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;

/// Applies the rectified linear unit function element-wise:
///
/// `y = max(0, x)`
-#[derive(Clone, Debug, Default)]
+#[derive(Module, Clone, Debug, Default)]
pub struct ReLU {}

-constant!(ReLU);

impl ReLU {
    /// Create the module.
    pub fn new() -> Self {

View File

@@ -10,8 +10,6 @@ use module::module_derive_impl;
#[proc_macro_derive(Module)]
pub fn module_derive(input: TokenStream) -> TokenStream {
    let input = syn::parse(input).unwrap();

-    // panic!("{}", gen);

    module_derive_impl(&input)
}

View File

@@ -2,9 +2,53 @@ use super::{fn_generator::FnGenerator, record::RecordGenerator};
use crate::module::display;

use proc_macro::TokenStream;
use quote::quote;
+use syn::parse_quote;
+pub(crate) fn constant_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
+    let name = &ast.ident;
+    let (_, generics_ty, generics_where) = ast.generics.split_for_impl();
+
+    let backend: syn::Generics = parse_quote! { <B: burn::tensor::backend::Backend >};
+    let backend_ad: syn::Generics = parse_quote! { <B: burn::tensor::backend::ADBackend >};
+    let mut generics_module = ast.generics.clone();
+    let mut generics_module_ad = ast.generics.clone();
+
+    for param in backend.params.into_iter() {
+        generics_module.params.push(param);
+    }
+    for param in backend_ad.params.into_iter() {
+        generics_module_ad.params.push(param);
+    }
+    let (generics_module, _, _) = generics_module.split_for_impl();
+    let (generics_module_ad, _, _) = generics_module_ad.split_for_impl();
+
+    let gen = quote! {
+        impl #generics_module burn::module::Module<B> for #name #generics_ty #generics_where {
+            burn::constant!(module);
+        }
+
+        impl #generics_module_ad burn::module::ADModule<B> for #name #generics_ty #generics_where {
+            burn::constant!(ad_module, #name #generics_ty);
+        }
+    };
+
+    gen.into()
+}
+
pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
    let name = &ast.ident;
+    let has_backend = ast
+        .generics
+        .type_params()
+        .map(|param| param.ident == "B")
+        .reduce(|accum, is_backend| is_backend || accum)
+        .unwrap_or(false);
+
+    if !has_backend {
+        return constant_derive_impl(ast);
+    }
+
    let (generics, generics_ty, generics_where) = ast.generics.split_for_impl();
    let display_fn = display::display_fn(name);