Update module doc + add sponsors section (#267)

Nathaniel Simard 2023-04-02 17:37:01 -04:00 committed by GitHub
parent d3887bcd3d
commit 2c151a5570
9 changed files with 73 additions and 33 deletions

View File

@@ -28,6 +28,7 @@ __Sections__
* [Config](#config)
* [Learner](#learner)
* [no_std support](#no_std-support)
* [Sponsors](#sponsors)
* [License](#license)
## Features
@@ -123,18 +124,17 @@ fn main() {
#### Module
The `Module` derive allows you to create your own neural network modules, similar to PyTorch.
Note that the `Module` derive generates all the necessary methods to make your type essentially a parameter container.
It makes no assumptions about how the forward function is declared.
The derive function only generates the necessary methods to essentially act as a parameter container for your type; it makes no assumptions about how the forward pass is declared.
```rust
use burn::nn;
use burn::module::{Param, Module};
use burn::module::Module;
use burn::tensor::backend::Backend;
#[derive(Module, Debug)]
pub struct PositionWiseFeedForward<B: Backend> {
linear_inner: Param<Linear<B>>,
linear_outer: Param<Linear<B>>,
linear_inner: Linear<B>,
linear_outer: Linear<B>,
dropout: Dropout,
gelu: GELU,
}
@@ -150,7 +150,8 @@ impl<B: Backend> PositionWiseFeedForward<B> {
}
```
Note that only the fields wrapped inside `Param` are updated during training, and the other fields should implement the `Clone` trait.
Note that all fields declared in the struct must also implement the `Module` trait.
The `Tensor` struct doesn't implement `Module`, but `Param<Tensor<B, D>>` does.
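As a small, hedged sketch of that rule (the `LearnableScale` struct and its `scale` field are made-up names, not part of this commit), a learnable tensor can be stored directly by wrapping it in `Param`:
```rust
use burn::module::{Module, Param};
use burn::tensor::backend::Backend;
use burn::tensor::Tensor;

// A raw `Tensor<B, 1>` field would not satisfy the `Module` bound on fields,
// but `Param<Tensor<B, 1>>` does, marking the tensor as a learnable parameter.
#[derive(Module, Debug)]
pub struct LearnableScale<B: Backend> {
    scale: Param<Tensor<B, 1>>,
}
```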
#### Config
@@ -189,6 +190,7 @@ In order to create a learner, you must use the `LearnerBuilder`.
```rust
use burn::train::LearnerBuilder;
use burn::train::metric::{AccuracyMetric, LossMetric};
use burn::record::DefaultRecordSettings;
fn main() {
let dataloader_train = ...;
@@ -202,7 +204,7 @@ fn main() {
.metric_valid_plot(AccuracyMetric::new())
.metric_train(LossMetric::new())
.metric_valid(LossMetric::new())
.with_file_checkpointer::<f32>(2)
.with_file_checkpointer::<DefaultRecordSettings>(2)
.num_epochs(10)
.build(model, optim);
@@ -222,6 +224,15 @@ Additionally `burn-core` and `burn-tensor` crates support `no_std` with `alloc`
Note that under `no_std` mode, a random seed is generated at build time if the seed is not initialized by the `Backend::seed` method.
Additionally, [spin::mutex::Mutex](https://docs.rs/spin/latest/spin/mutex/struct.Mutex.html) is used in place of [std::sync::Mutex](https://doc.rust-lang.org/std/sync/struct.Mutex.html) in `no_std` mode.
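As a hedged illustration of the seeding behaviour described above (the `burn-ndarray` backend and the seed value are example choices, not taken from this commit), calling `Backend::seed` explicitly means the build-time seed is never used:
```rust
use burn::tensor::backend::Backend;
use burn_ndarray::NdArrayBackend;

fn main() {
    // Seed the backend explicitly; in `no_std` builds, skipping this call
    // would fall back to the seed generated at build time.
    <NdArrayBackend<f32> as Backend>::seed(42);
}
```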
## Sponsors
You can sponsor the founder of Burn through his [GitHub Sponsors profile](https://github.com/sponsors/nathanielsimard).
The Burn-rs organization doesn't yet have a fiscal entity, but other sponsorship options may become available as the project grows.
Thanks to all current sponsors 🙏.
<a href="https://github.com/smallstepman"><img src="https://github.com/smallstepman.png" width="60px" style="border-radius: 50%;" alt="smallstepman" /></a>
## License
Burn is distributed under the terms of both the MIT license and the Apache License (Version 2.0).

View File

@@ -3,7 +3,6 @@ use alloc::vec::Vec;
use crate as burn;
use crate::config::Config;
use crate::constant;
use crate::module::Module;
use crate::module::Param;
use crate::nn::Initializer;
@@ -34,7 +33,7 @@ pub struct Conv1dConfig {
}
/// Padding configuration for 1D convolution [config](Conv1dConfig).
#[derive(Config, Debug)]
#[derive(Module, Config, Debug)]
pub enum Conv1dPaddingConfig {
/// Dynamically calculate the amount of padding necessary to ensure that the output size will be
/// the same as the input.
@@ -43,8 +42,6 @@ pub enum Conv1dPaddingConfig {
Explicit(usize),
}
constant!(Conv1dPaddingConfig);
/// Applies a 1D convolution over input tensors.
///
/// # Params

View File

@@ -3,7 +3,6 @@ use alloc::vec::Vec;
use crate as burn;
use crate::config::Config;
use crate::constant;
use crate::module::Module;
use crate::module::Param;
use crate::nn::Initializer;
@@ -33,7 +32,7 @@ pub struct Conv2dConfig {
}
/// Padding configuration for 2D convolution [config](Conv2dConfig).
#[derive(Config, Debug)]
#[derive(Module, Config, Debug)]
pub enum Conv2dPaddingConfig {
/// Dynamically calculate the amount of padding necessary to ensure that the output size will be
/// the same as the input.
@@ -44,8 +43,6 @@ pub enum Conv2dPaddingConfig {
Explicit(usize, usize),
}
constant!(Conv2dPaddingConfig);
/// Applies a 2D convolution over input tensors.
///
/// # Params

View File

@@ -1,7 +1,7 @@
use crate as burn;
use crate::config::Config;
use crate::constant;
use crate::module::Module;
use crate::tensor::backend::Backend;
use crate::tensor::{Distribution, Tensor};
@@ -18,13 +18,11 @@ pub struct DropoutConfig {
/// [Improving neural networks by preventing co-adaptation of feature detectors](https://arxiv.org/abs/1207.0580).
///
/// The input is also scaled during training to `1 / (1 - prob_keep)`.
#[derive(Clone, Debug)]
#[derive(Module, Clone, Debug)]
pub struct Dropout {
prob: f64,
}
constant!(Dropout);
impl DropoutConfig {
/// Initialize a new [dropout](Dropout) module.
pub fn init(&self) -> Dropout {

View File

@@ -1,15 +1,13 @@
use crate as burn;
use crate::constant;
use crate::module::Module;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
/// Applies the Gaussian Error Linear Units function element-wise.
#[derive(Clone, Debug, Default)]
#[derive(Module, Clone, Debug, Default)]
pub struct GELU {}
constant!(GELU);
impl GELU {
/// Create the module.
pub fn new() -> Self {

View File

@@ -1,6 +1,7 @@
use crate::{self as burn, constant};
use crate as burn;
use crate::config::Config;
use crate::module::Module;
use crate::nn::conv::Conv2dPaddingConfig;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
@@ -25,15 +26,13 @@ pub struct MaxPool2dConfig {
pub type MaxPool2dPaddingConfig = Conv2dPaddingConfig;
/// Applies a 2D max pooling over input tensors.
#[derive(Debug, Clone)]
#[derive(Module, Debug, Clone)]
pub struct MaxPool2d {
stride: [usize; 2],
kernel_size: [usize; 2],
padding: MaxPool2dPaddingConfig,
}
constant!(MaxPool2d);
impl MaxPool2dConfig {
/// Initialize a new [max pool 2d](MaxPool2d) module.
pub fn init(&self) -> MaxPool2d {

View File

@@ -1,17 +1,15 @@
use crate as burn;
use crate::constant;
use crate::module::Module;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
/// Applies the rectified linear unit function element-wise:
///
/// `y = max(0, x)`
#[derive(Clone, Debug, Default)]
#[derive(Module, Clone, Debug, Default)]
pub struct ReLU {}
constant!(ReLU);
impl ReLU {
/// Create the module.
pub fn new() -> Self {

View File

@@ -10,8 +10,6 @@ use module::module_derive_impl;
#[proc_macro_derive(Module)]
pub fn module_derive(input: TokenStream) -> TokenStream {
let input = syn::parse(input).unwrap();
// panic!("{}", gen);
module_derive_impl(&input)
}

View File

@@ -2,9 +2,53 @@ use super::{fn_generator::FnGenerator, record::RecordGenerator};
use crate::module::display;
use proc_macro::TokenStream;
use quote::quote;
use syn::parse_quote;
pub(crate) fn constant_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
let name = &ast.ident;
let (_, generics_ty, generics_where) = ast.generics.split_for_impl();
let backend: syn::Generics = parse_quote! { <B: burn::tensor::backend::Backend >};
let backend_ad: syn::Generics = parse_quote! { <B: burn::tensor::backend::ADBackend >};
let mut generics_module = ast.generics.clone();
let mut generics_module_ad = ast.generics.clone();
for param in backend.params.into_iter() {
generics_module.params.push(param);
}
for param in backend_ad.params.into_iter() {
generics_module_ad.params.push(param);
}
let (generics_module, _, _) = generics_module.split_for_impl();
let (generics_module_ad, _, _) = generics_module_ad.split_for_impl();
let gen = quote! {
impl #generics_module burn::module::Module<B> for #name #generics_ty #generics_where {
burn::constant!(module);
}
impl #generics_module_ad burn::module::ADModule<B> for #name #generics_ty #generics_where {
burn::constant!(ad_module, #name #generics_ty);
}
};
gen.into()
}
pub(crate) fn module_derive_impl(ast: &syn::DeriveInput) -> TokenStream {
let name = &ast.ident;
let has_backend = ast
.generics
.type_params()
.map(|param| param.ident == "B")
.reduce(|accum, is_backend| is_backend || accum)
.unwrap_or(false);
if !has_backend {
return constant_derive_impl(ast);
}
let (generics, generics_ty, generics_where) = ast.generics.split_for_impl();
let display_fn = display::display_fn(name);