diff --git a/burn-autodiff/src/backend.rs b/burn-autodiff/src/backend.rs index 678116087..ab38a9c8b 100644 --- a/burn-autodiff/src/backend.rs +++ b/burn-autodiff/src/backend.rs @@ -1,6 +1,7 @@ use crate::{grads::Gradients, graph::backward::backward, tensor::ADTensor}; use burn_tensor::backend::{ADBackend, Backend}; +/// A decorator for a backend that enables automatic differentiation. #[derive(Clone, Copy, Debug, Default)] pub struct ADBackendDecorator { _b: B, diff --git a/burn-autodiff/src/grads.rs b/burn-autodiff/src/grads.rs index 5112a047e..37c7dd39b 100644 --- a/burn-autodiff/src/grads.rs +++ b/burn-autodiff/src/grads.rs @@ -5,6 +5,7 @@ use crate::{ tensor::ADTensor, }; +/// Gradient identifier. pub type GradID = String; /// Gradients container used during the backward pass. @@ -15,6 +16,7 @@ pub struct Gradients { type TensorPrimitive = ::TensorPrimitive; impl Gradients { + /// Creates a new gradients container. pub fn new( root_node: NodeRef, root_tensor: TensorPrimitive, @@ -28,7 +30,8 @@ impl Gradients { ); gradients } - /// Consume the gradients for a given tensor. + + /// Consumes the gradients for a given tensor. /// /// Each tensor should be consumed exactly 1 time if its gradients are only required during the /// backward pass, otherwise, it may be consume multiple times. @@ -48,7 +51,7 @@ impl Gradients { } } - /// Remove a grad tensor from the container. + /// Removes a grad tensor from the container. pub fn remove( &mut self, tensor: &ADTensor, @@ -58,6 +61,7 @@ impl Gradients { .map(|tensor| tensor.into_primitive()) } + /// Gets a grad tensor from the container. pub fn get( &self, tensor: &ADTensor, @@ -67,6 +71,7 @@ impl Gradients { .map(|tensor| tensor.into_primitive()) } + /// Registers a grad tensor in the container. pub fn register( &mut self, node: NodeRef, diff --git a/burn-autodiff/src/lib.rs b/burn-autodiff/src/lib.rs index 8bad2b1b7..d529ee3bd 100644 --- a/burn-autodiff/src/lib.rs +++ b/burn-autodiff/src/lib.rs @@ -1,3 +1,12 @@ +#![warn(missing_docs)] + +//! # Burn Autodiff +//! +//! This autodiff library is a part of the Burn project. It is a standalone crate +//! that can be used to perform automatic differentiation on tensors. It is +//! designed to be used with the Burn Tensor crate, but it can be used with any +//! tensor library that implements the `Backend` trait. + #[macro_use] extern crate derive_new; diff --git a/burn-autodiff/src/tests/mod.rs b/burn-autodiff/src/tests/mod.rs index 903b5c343..a8e2f4046 100644 --- a/burn-autodiff/src/tests/mod.rs +++ b/burn-autodiff/src/tests/mod.rs @@ -1,3 +1,5 @@ +#![allow(missing_docs)] + mod add; mod aggregation; mod avgpool1d; diff --git a/burn-common/README.md b/burn-common/README.md index c33c9dac9..955c579b4 100644 --- a/burn-common/README.md +++ b/burn-common/README.md @@ -1,3 +1,6 @@ -The `burn-common` package hosts code that _must_ be shared between burn packages (with `std` or `no_std` enabled). No other code should be placed in this package unless unavoidable. +# Burn Common + +The `burn-common` package hosts code that _must_ be shared between burn packages (with `std` or +`no_std` enabled). No other code should be placed in this package unless unavoidable. The package must build with `cargo build --no-default-features` as well. diff --git a/burn-common/src/id.rs b/burn-common/src/id.rs index 2ed52aa9e..6ac50e082 100644 --- a/burn-common/src/id.rs +++ b/burn-common/src/id.rs @@ -4,9 +4,11 @@ use crate::rand::{get_seeded_rng, Rng, SEED}; use uuid::{Builder, Bytes}; +/// Simple ID generator. 
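+///
+/// # Example
+///
+/// A minimal usage sketch (the `generate` signature is shown below):
+///
+/// ```ignore
+/// use burn_common::id::IdGenerator;
+///
+/// // Each call yields a fresh UUID-formatted string.
+/// let id: String = IdGenerator::generate();
+/// ```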
pub struct IdGenerator {} impl IdGenerator { + /// Generates a new ID in the form of a UUID. pub fn generate() -> String { let mut seed = SEED.lock().unwrap(); let mut rng = if let Some(rng_seeded) = seed.as_ref() { diff --git a/burn-common/src/lib.rs b/burn-common/src/lib.rs index c2e90eb44..39d062b97 100644 --- a/burn-common/src/lib.rs +++ b/burn-common/src/lib.rs @@ -1,7 +1,18 @@ #![cfg_attr(not(feature = "std"), no_std)] +#![warn(missing_docs)] +//! # Burn Common Library +//! +//! This library contains common types used by other Burn crates that must be shared. + +/// Id module contains types for unique identifiers. pub mod id; + +/// Rand module contains types for random number generation, for both `std` and +/// `no_std` environments. pub mod rand; + +/// Stub module contains stand-ins for `std` types in `no_std` environments. pub mod stub; extern crate alloc; diff --git a/burn-common/src/rand.rs b/burn-common/src/rand.rs index 2559c42bb..d71b37063 100644 --- a/burn-common/src/rand.rs +++ b/burn-common/src/rand.rs @@ -9,12 +9,14 @@ use crate::stub::Mutex; #[cfg(not(feature = "std"))] use const_random::const_random; +/// Returns a seeded random number generator using entropy. #[cfg(feature = "std")] #[inline(always)] pub fn get_seeded_rng() -> StdRng { StdRng::from_entropy() } +/// Returns a seeded random number generator using a pre-generated seed. #[cfg(not(feature = "std"))] #[inline(always)] pub fn get_seeded_rng() -> StdRng { diff --git a/burn-common/src/stub.rs b/burn-common/src/stub.rs index f05499802..93d0715f5 100644 --- a/burn-common/src/stub.rs +++ b/burn-common/src/stub.rs @@ -2,13 +2,19 @@ use spin::{ Mutex as MutexImported, MutexGuard, RwLock as RwLockImported, RwLockReadGuard, RwLockWriteGuard, }; -// Mutex wrapper to make spin::Mutex API compatible with std::sync::Mutex to swap +/// A mutual exclusion primitive useful for protecting shared data. +/// +/// This mutex will block threads waiting for the lock to become available. The +/// mutex can also be statically initialized or created via [Mutex::new]. +/// +/// A [Mutex] wrapper that makes the `spin::Mutex` API compatible with `std::sync::Mutex`, so one can be swapped for the other. #[derive(Debug)] pub struct Mutex { inner: MutexImported, } impl Mutex { + /// Creates a new mutex in an unlocked state ready for use. #[inline(always)] pub const fn new(value: T) -> Self { Self { @@ -16,19 +22,24 @@ impl Mutex { } } + /// Locks the mutex, blocking the current thread until it is able to do so. #[inline(always)] pub fn lock(&self) -> Result, alloc::string::String> { Ok(self.inner.lock()) } } -// Mutex wrapper to make spin::Mutex API compatible with std::sync::Mutex to swap +/// A reader-writer lock which is exclusively locked for writing or shared for reading. +/// This reader-writer lock will block threads waiting for the lock to become available. +/// The lock can also be statically initialized or created via [RwLock::new]. +/// A [RwLock] wrapper that makes the `spin::RwLock` API compatible with `std::sync::RwLock`, so one can be swapped for the other. #[derive(Debug)] pub struct RwLock { inner: RwLockImported, } impl RwLock { + /// Creates a new reader-writer lock in an unlocked state ready for use. #[inline(always)] pub const fn new(value: T) -> Self { Self { @@ -36,17 +47,23 @@ impl RwLock { } } + /// Locks this rwlock with shared read access, blocking the current thread + /// until it can be acquired.
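+ ///
+ /// # Example
+ ///
+ /// A small sketch of the intended drop-in usage; like `std::sync::RwLock`,
+ /// the stub returns a `Result` even though the underlying `spin` lock cannot fail:
+ ///
+ /// ```ignore
+ /// let lock = RwLock::new(5);
+ /// let guard = lock.read().unwrap();
+ /// assert_eq!(*guard, 5);
+ /// ```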
#[inline(always)] pub fn read(&self) -> Result, alloc::string::String> { Ok(self.inner.read()) } + /// Locks this rwlock with exclusive write access, blocking the current thread + /// until it can be acquired. #[inline(always)] pub fn write(&self) -> Result, alloc::string::String> { Ok(self.inner.write()) } } -// ThreadId stub when no std is available +/// A unique identifier for a running thread. +/// +/// This type is a stub for `std::thread::ThreadId` when `std` is not available. #[derive(Eq, PartialEq, Clone, Copy, Hash, Debug)] pub struct ThreadId(core::num::NonZeroU64); diff --git a/burn-core/src/config.rs b/burn-core/src/config.rs index f23b6f3f5..7389b865b 100644 --- a/burn-core/src/config.rs +++ b/burn-core/src/config.rs @@ -1,9 +1,13 @@ use alloc::{format, string::String, string::ToString}; pub use burn_derive::Config; +/// Configuration I/O error. #[derive(Debug)] pub enum ConfigError { + /// Invalid format. InvalidFormat(String), + + /// File not found. FileNotFound(String), } @@ -28,12 +32,31 @@ impl core::fmt::Display for ConfigError { #[cfg(feature = "std")] impl std::error::Error for ConfigError {} +/// Configuration trait. pub trait Config: serde::Serialize + serde::de::DeserializeOwned { + /// Saves the configuration to a file. + /// + /// # Arguments + /// + /// * `file` - File to save the configuration to. + /// + /// # Returns + /// + /// The output of the save operation. #[cfg(feature = "std")] fn save(&self, file: &str) -> std::io::Result<()> { std::fs::write(file, config_to_json(self)) } + /// Loads the configuration from a file. + /// + /// # Arguments + /// + /// * `file` - File to load the configuration from. + /// + /// # Returns + /// + /// The loaded configuration. #[cfg(feature = "std")] fn load(file: &str) -> Result { let content = std::fs::read_to_string(file) @@ -41,6 +64,15 @@ pub trait Config: serde::Serialize + serde::de::DeserializeOwned { config_from_str(&content) } + /// Loads the configuration from a binary buffer. + /// + /// # Arguments + /// + /// * `data` - Binary buffer to load the configuration from. + /// + /// # Returns + /// + /// The loaded configuration. fn load_binary(data: &[u8]) -> Result { let content = core::str::from_utf8(data).map_err(|_| { ConfigError::InvalidFormat("Could not parse data as utf-8.".to_string()) @@ -49,6 +81,15 @@ pub trait Config: serde::Serialize + serde::de::DeserializeOwned { } } +/// Converts a configuration to a JSON string. +/// +/// # Arguments +/// +/// * `config` - Configuration to convert. +/// +/// # Returns +/// +/// The JSON string. pub fn config_to_json(config: &C) -> String { serde_json::to_string_pretty(config).unwrap() } diff --git a/burn-core/src/data/dataloader/base.rs b/burn-core/src/data/dataloader/base.rs index 799389ac7..0222248ad 100644 --- a/burn-core/src/data/dataloader/base.rs +++ b/burn-core/src/data/dataloader/base.rs @@ -1,16 +1,24 @@ pub use crate::data::dataset::{Dataset, DatasetIterator}; use core::iter::Iterator; +/// A progress struct that can be used to track the progress of a data loader. #[derive(Clone, Debug)] pub struct Progress { + /// The number of items that have been processed. pub items_processed: usize, + + /// The total number of items that need to be processed. pub items_total: usize, } +/// A data loader iterator that can be used to iterate over a data loader. pub trait DataLoaderIterator: Iterator { + /// Returns the progress of the data loader. fn progress(&self) -> Progress; } +/// A data loader that can be used to iterate over a dataset.
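+///
+/// # Example
+///
+/// A hypothetical training-loop sketch; `dataloader` and its batch type are
+/// assumptions for illustration:
+///
+/// ```ignore
+/// let mut iterator = dataloader.iter();
+/// while let Some(batch) = iterator.next() {
+///     // `progress` reports how far along the epoch is.
+///     let progress = iterator.progress();
+///     println!("{}/{} items", progress.items_processed, progress.items_total);
+/// }
+/// ```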
pub trait DataLoader { + /// Returns a boxed [iterator](DataLoaderIterator) to iterate over the data loader. fn iter<'a>(&'a self) -> Box + 'a>; } diff --git a/burn-core/src/data/dataloader/batch.rs b/burn-core/src/data/dataloader/batch.rs index 07d2dc5d3..3a7cb1aa2 100644 --- a/burn-core/src/data/dataloader/batch.rs +++ b/burn-core/src/data/dataloader/batch.rs @@ -5,12 +5,14 @@ use super::{ use burn_dataset::{transform::PartialDataset, Dataset}; use std::sync::Arc; +/// A data loader that can be used to iterate over a dataset in batches. pub struct BatchDataLoader { strategy: Box>, dataset: Arc>, batcher: Arc>, } +/// A data loader iterator that can be used to iterate over a data loader. struct BatchDataloaderIterator { current_index: usize, strategy: Box>, @@ -19,6 +21,17 @@ struct BatchDataloaderIterator { } impl BatchDataLoader { + /// Creates a new batch data loader. + /// + /// # Arguments + /// + /// * `strategy` - The batch strategy. + /// * `dataset` - The dataset. + /// * `batcher` - The batcher. + /// + /// # Returns + /// + /// The batch data loader. pub fn new( strategy: Box>, dataset: Arc>, @@ -31,11 +44,24 @@ impl BatchDataLoader { } } } + impl BatchDataLoader where I: Send + Sync + Clone + 'static, O: Send + Sync + Clone + 'static, { + /// Creates a new multi-threaded batch data loader. + /// + /// # Arguments + /// + /// * `strategy` - The batch strategy. + /// * `dataset` - The dataset. + /// * `batcher` - The batcher. + /// * `num_threads` - The number of threads. + /// + /// # Returns + /// + /// The multi-threaded batch data loader. pub fn multi_thread( strategy: Box>, dataset: Arc>, @@ -65,6 +91,17 @@ impl DataLoader for BatchDataLoader { } impl BatchDataloaderIterator { + /// Creates a new batch data loader iterator. + /// + /// # Arguments + /// + /// * `strategy` - The batch strategy. + /// * `dataset` - The dataset. + /// * `batcher` - The batcher. + /// + /// # Returns + /// + /// The batch data loader iterator. pub fn new( strategy: Box>, dataset: Arc>, diff --git a/burn-core/src/data/dataloader/batcher.rs b/burn-core/src/data/dataloader/batcher.rs index bd52a84a5..724a2e3a5 100644 --- a/burn-core/src/data/dataloader/batcher.rs +++ b/burn-core/src/data/dataloader/batcher.rs @@ -1,4 +1,14 @@ +/// A trait for batching items of type `I` into items of type `O`. pub trait Batcher: Send + Sync { + /// Batches the given items. + /// + /// # Arguments + /// + /// * `items` - The items to batch. + /// + /// # Returns + /// + /// The batched items. fn batch(&self, items: Vec) -> O; } diff --git a/burn-core/src/data/dataloader/builder.rs b/burn-core/src/data/dataloader/builder.rs index 0adfd2ed7..859385b68 100644 --- a/burn-core/src/data/dataloader/builder.rs +++ b/burn-core/src/data/dataloader/builder.rs @@ -2,6 +2,7 @@ use super::{batcher::Batcher, BatchDataLoader, BatchStrategy, DataLoader, FixBat use burn_dataset::{transform::ShuffledDataset, Dataset}; use std::sync::Arc; +/// A builder for data loaders. pub struct DataLoaderBuilder { strategy: Option>>, batcher: Arc>, @@ -14,6 +15,15 @@ where I: Send + Sync + Clone + std::fmt::Debug + 'static, O: Send + Sync + Clone + std::fmt::Debug + 'static, { + /// Creates a new data loader builder. + /// + /// # Arguments + /// + /// * `batcher` - The batcher. + /// + /// # Returns + /// + /// The data loader builder. pub fn new(batcher: B) -> Self where B: Batcher + 'static, { Self { batcher: Arc::new(batcher), strategy: None, num_threads: None, shuffle: None, } } + /// Sets the batch size to a fixed number. The [fixed batch strategy](FixBatchStrategy) + /// will be used.
+ /// + /// # Arguments + /// + /// * `batch_size` - The batch size. + /// + /// # Returns + /// + /// The data loader builder. pub fn batch_size(mut self, batch_size: usize) -> Self { self.strategy = Some(Box::new(FixBatchStrategy::new(batch_size))); self } + /// Sets the seed for shuffling. + /// + /// # Arguments + /// + /// * `seed` - The seed. + /// + /// # Returns + /// + /// The data loader builder. pub fn shuffle(mut self, seed: u64) -> Self { self.shuffle = Some(seed); self } + /// Sets the number of workers. + /// + /// # Arguments + /// + /// * `num_workers` - The number of workers. + /// + /// # Returns + /// + /// The data loader builder. pub fn num_workers(mut self, num_workers: usize) -> Self { self.num_threads = Some(num_workers); self } + /// Builds the data loader. + /// + /// # Arguments + /// + /// * `dataset` - The dataset. + /// + /// # Returns + /// + /// The data loader. pub fn build(self, dataset: D) -> Arc> where D: Dataset + 'static, diff --git a/burn-core/src/data/dataloader/mod.rs b/burn-core/src/data/dataloader/mod.rs index 5d82f0f24..8b49d824e 100644 --- a/burn-core/src/data/dataloader/mod.rs +++ b/burn-core/src/data/dataloader/mod.rs @@ -4,6 +4,7 @@ mod builder; mod multithread; mod strategy; +/// Module for batching items. pub mod batcher; pub use base::*; diff --git a/burn-core/src/data/dataloader/multithread.rs b/burn-core/src/data/dataloader/multithread.rs index 1d4abcd5d..c9e029022 100644 --- a/burn-core/src/data/dataloader/multithread.rs +++ b/burn-core/src/data/dataloader/multithread.rs @@ -3,15 +3,20 @@ use std::collections::HashMap; use std::sync::{mpsc, Arc}; use std::thread; -static MAX_QUEUED_ITEMS: usize = 100; +const MAX_QUEUED_ITEMS: usize = 100; +/// A multi-threaded data loader that can be used to iterate over a dataset. pub struct MultiThreadDataLoader { dataloaders: Vec + Send + Sync>>, } +/// A message that can be sent between threads. #[derive(Debug)] pub enum Message { + /// A batch of items. Batch(usize, O, Progress), + + /// The thread is done. Done, } @@ -23,6 +28,15 @@ struct MultiThreadsDataloaderIterator { } impl MultiThreadDataLoader { + /// Creates a new multi-threaded data loader. + /// + /// # Arguments + /// + /// * `dataloaders` - The data loaders. + /// + /// # Returns + /// + /// The multi-threaded data loader. pub fn new(dataloaders: Vec + Send + Sync>>) -> Self { Self { dataloaders } } diff --git a/burn-core/src/data/dataloader/strategy.rs b/burn-core/src/data/dataloader/strategy.rs index 01594817d..9e09207ed 100644 --- a/burn-core/src/data/dataloader/strategy.rs +++ b/burn-core/src/data/dataloader/strategy.rs @@ -1,15 +1,47 @@ +/// A strategy to batch items. pub trait BatchStrategy: Send + Sync { + /// Adds an item to the strategy. + /// + /// # Arguments + /// + /// * `item` - The item to add. fn add(&mut self, item: I); + + /// Batches the items. + /// + /// # Arguments + /// + /// * `force` - Whether to force batching. + /// + /// # Returns + /// + /// The batched items. fn batch(&mut self, force: bool) -> Option>; + + /// Creates a new strategy of the same type. + /// + /// # Returns + /// + /// The new strategy. fn new_like(&self) -> Box>; } +/// A strategy to batch items with a fixed batch size. pub struct FixBatchStrategy { items: Vec, batch_size: usize, } impl FixBatchStrategy { + /// Creates a new strategy to batch items with a fixed batch size. + /// + /// # Arguments + /// + /// * `batch_size` - The batch size. + /// + /// # Returns + /// + /// The strategy. 
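+ ///
+ /// # Example
+ ///
+ /// A short sketch of the add/batch cycle:
+ ///
+ /// ```ignore
+ /// let mut strategy = FixBatchStrategy::new(2);
+ /// strategy.add(1);
+ /// assert!(strategy.batch(false).is_none()); // fewer items than the batch size
+ /// strategy.add(2);
+ /// assert_eq!(strategy.batch(false), Some(vec![1, 2]));
+ /// ```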
pub fn new(batch_size: usize) -> Self { FixBatchStrategy { items: Vec::with_capacity(batch_size), diff --git a/burn-core/src/data/mod.rs b/burn-core/src/data/mod.rs index 2df32acd8..5489cae64 100644 --- a/burn-core/src/data/mod.rs +++ b/burn-core/src/data/mod.rs @@ -1,4 +1,7 @@ +/// Dataloader module. pub mod dataloader; + +/// Dataset module. pub mod dataset { pub use burn_dataset::*; } diff --git a/burn-core/src/grad_clipping/base.rs b/burn-core/src/grad_clipping/base.rs index 05c5247e1..2018e7d1b 100644 --- a/burn-core/src/grad_clipping/base.rs +++ b/burn-core/src/grad_clipping/base.rs @@ -3,13 +3,22 @@ use crate as burn; use crate::{config::Config, tensor::Tensor}; use burn_tensor::{backend::Backend, ElementConversion}; +/// Gradient clipping provides a way to mitigate exploding gradients. #[derive(Config)] pub enum GradientClippingConfig { + /// Clip the gradient by value. Value(f32), + + /// Clip the gradient by norm. Norm(f32), } impl GradientClippingConfig { + /// Initialize the gradient clipping. + /// + /// # Returns + /// + /// The gradient clipping. pub fn init(&self) -> GradientClipping { match self { GradientClippingConfig::Value(val) => GradientClipping::Value(*val), @@ -22,11 +31,23 @@ impl GradientClippingConfig { /// by clipping every component of the gradient by value or by norm during /// backpropagation. pub enum GradientClipping { + /// Clip the gradient by value. Value(f32), + + /// Clip the gradient by norm. Norm(f32), } impl GradientClipping { + /// Clip the gradient. + /// + /// # Arguments + /// + /// * `grad` - The gradient to clip. + /// + /// # Returns + /// + /// The clipped gradient. pub fn clip_gradient(&self, grad: Tensor) -> Tensor { match self { GradientClipping::Value(threshold) => self.clip_by_value(grad, *threshold), diff --git a/burn-core/src/lib.rs b/burn-core/src/lib.rs index b60f94938..07cbe89bd 100644 --- a/burn-core/src/lib.rs +++ b/burn-core/src/lib.rs @@ -1,23 +1,39 @@ #![cfg_attr(not(feature = "std"), no_std)] +#![warn(missing_docs)] + +//! The core crate of Burn. #[macro_use] extern crate derive_new; +/// The configuration module. pub mod config; +/// Data module. #[cfg(feature = "std")] pub mod data; +/// Optimizer module. #[cfg(feature = "std")] pub mod optim; +/// Learning rate scheduler module. #[cfg(feature = "std")] pub mod lr_scheduler; +/// Gradient clipping module. pub mod grad_clipping; + +/// Module abstraction used to build neural networks. pub mod module; + +/// Neural network module. pub mod nn; + +/// Module for the recorder. pub mod record; + +/// Module for the tensor. pub mod tensor; extern crate alloc; diff --git a/burn-core/src/lr_scheduler/mod.rs b/burn-core/src/lr_scheduler/mod.rs index 74590e28e..026625005 100644 --- a/burn-core/src/lr_scheduler/mod.rs +++ b/burn-core/src/lr_scheduler/mod.rs @@ -1,4 +1,7 @@ +/// Constant learning rate scheduler. pub mod constant; + +/// Noam learning rate scheduler. pub mod noam; mod base; diff --git a/burn-core/src/module/base.rs b/burn-core/src/module/base.rs index dc89b0816..f601e0ba6 100644 --- a/burn-core/src/module/base.rs +++ b/burn-core/src/module/base.rs @@ -178,16 +178,21 @@ pub trait Module: Clone + Send + Sync + core::fmt::Debug { fn into_record(self) -> Self::Record; } +/// Module visitor trait. pub trait ModuleVisitor { + /// Visit a tensor in the module. fn visit(&mut self, id: &ParamId, tensor: &Tensor); } +/// Module mapper trait. pub trait ModuleMapper { + /// Map a tensor in the module.
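+ ///
+ /// # Example
+ ///
+ /// A hypothetical mapper that halves every parameter tensor (the generic
+ /// signature shown is an assumption):
+ ///
+ /// ```ignore
+ /// struct Halve;
+ ///
+ /// impl<B: Backend> ModuleMapper<B> for Halve {
+ ///     fn map<const D: usize>(&mut self, _id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
+ ///         tensor.mul_scalar(0.5)
+ ///     }
+ /// }
+ /// ```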
fn map(&mut self, id: &ParamId, tensor: Tensor) -> Tensor; } /// Module with auto-differentiation backend. pub trait ADModule: Module + Send + Sync + core::fmt::Debug { + /// Inner module without auto-differentiation. type InnerModule: Module; /// Get the same module, but on the inner backend without auto-differentiation. diff --git a/burn-core/src/module/param/base.rs b/burn-core/src/module/param/base.rs index 407019d35..6b405f352 100644 --- a/burn-core/src/module/param/base.rs +++ b/burn-core/src/module/param/base.rs @@ -15,6 +15,11 @@ impl core::fmt::Display for Param { } impl Param { + /// Gets the parameter value. + /// + /// # Returns + /// + /// The parameter value. pub fn val(&self) -> T { self.value.clone() } diff --git a/burn-core/src/module/param/constant.rs b/burn-core/src/module/param/constant.rs index d4b6cbaea..6d3be1209 100644 --- a/burn-core/src/module/param/constant.rs +++ b/burn-core/src/module/param/constant.rs @@ -36,7 +36,7 @@ impl Record for ConstantRecord { item } } - +/// Constant macro. #[macro_export] macro_rules! constant { (module) => { diff --git a/burn-core/src/module/param/id.rs b/burn-core/src/module/param/id.rs index 92f88d442..2828cf38c 100644 --- a/burn-core/src/module/param/id.rs +++ b/burn-core/src/module/param/id.rs @@ -1,6 +1,7 @@ use alloc::string::{String, ToString}; use burn_common::id::IdGenerator; +/// Parameter ID. #[derive(Debug, Hash, PartialEq, Eq, Clone)] pub struct ParamId { value: String, @@ -27,11 +28,14 @@ impl Default for ParamId { } impl ParamId { + /// Create a new parameter ID. pub fn new() -> Self { Self { value: IdGenerator::generate(), } } + + /// Convert the parameter ID into a string. pub fn into_string(self) -> String { self.value } diff --git a/burn-core/src/nn/attention/mask.rs b/burn-core/src/nn/attention/mask.rs index 976d34b98..ddf9b9c4a 100644 --- a/burn-core/src/nn/attention/mask.rs +++ b/burn-core/src/nn/attention/mask.rs @@ -22,8 +22,12 @@ pub fn generate_autoregressive_mask( mask.equal_elem(1_i64.elem::()) } +/// Generate a padding attention mask. pub struct GeneratePaddingMask { + /// The generated tensor. pub tensor: Tensor, + + /// The generated mask. pub mask: Tensor, } diff --git a/burn-core/src/nn/cache/base.rs b/burn-core/src/nn/cache/base.rs index 34a15c3c6..322c65c81 100644 --- a/burn-core/src/nn/cache/base.rs +++ b/burn-core/src/nn/cache/base.rs @@ -6,11 +6,17 @@ pub(crate) enum CacheState { Empty, } +/// A cache for a tensor. pub struct TensorCache { pub(crate) state: CacheState>, } impl TensorCache { + /// Creates a new empty cache. + /// + /// # Returns + /// + /// The empty cache. 
pub fn empty() -> Self { Self { state: CacheState::Empty, diff --git a/burn-core/src/nn/initializer.rs b/burn-core/src/nn/initializer.rs index 73b3df0d2..688f69c0f 100644 --- a/burn-core/src/nn/initializer.rs +++ b/burn-core/src/nn/initializer.rs @@ -11,27 +11,60 @@ use crate as burn; #[derive(Config, Debug, PartialEq)] pub enum Initializer { /// Fills tensor with specified value everywhere - Constant { value: f64 }, + Constant { + /// The value to fill the tensor with + value: f64, + }, /// Fills tensor with 1s everywhere Ones, /// Fills tensor with 0s everywhere Zeros, /// Fills tensor with values drawn uniformly between specified values - Uniform { min: f64, max: f64 }, + Uniform { + /// The minimum value to draw from + min: f64, + + /// The maximum value to draw from + max: f64, + }, /// Fills tensor with values drawn from normal distribution with specified mean and std - Normal { mean: f64, std: f64 }, + Normal { + /// The mean of the normal distribution + mean: f64, + + /// The standard deviation of the normal distribution + std: f64, + }, /// Fills tensor with values according to the uniform version of Kaiming initialization - KaimingUniform { gain: f64, fan_out_only: bool }, + KaimingUniform { + /// The gain to use in initialization formula + gain: f64, + + /// Whether to use fan out only in initialization formula + fan_out_only: bool, + }, /// Fills tensor with values according to the uniform version of Kaiming initialization - KaimingNormal { gain: f64, fan_out_only: bool }, + KaimingNormal { + /// The gain to use in initialization formula + gain: f64, + + /// Whether to use fan out only in initialization formula + fan_out_only: bool, + }, /// Fills tensor with values according to the uniform version of Xavier Glorot initialization /// described in [Understanding the difficulty of training deep feedforward neural networks /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf) - XavierUniform { gain: f64 }, + XavierUniform { + /// The gain to use in initialization formula + gain: f64, + }, /// Fills tensor with values according to the normal version of Xavier Glorot initialization /// described in [Understanding the difficulty of training deep feedforward neural networks /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf) - XavierNormal { gain: f64 }, + XavierNormal { + /// The gain to use in initialization formula + gain: f64, + }, } impl Initializer { diff --git a/burn-core/src/nn/loss/mse.rs b/burn-core/src/nn/loss/mse.rs index 95342b952..fcf2c8848 100644 --- a/burn-core/src/nn/loss/mse.rs +++ b/burn-core/src/nn/loss/mse.rs @@ -42,6 +42,7 @@ impl MSELoss { } } + /// Compute the criterion on the input tensor without reducing. pub fn forward_no_reduction( &self, logits: Tensor, diff --git a/burn-core/src/nn/loss/reduction.rs b/burn-core/src/nn/loss/reduction.rs index f0c42d543..499b17153 100644 --- a/burn-core/src/nn/loss/reduction.rs +++ b/burn-core/src/nn/loss/reduction.rs @@ -1,5 +1,11 @@ +/// The reduction type for the loss. pub enum Reduction { + /// The mean of the losses will be returned. Mean, + + /// The sum of the losses will be returned. Sum, + + /// The reduction strategy is chosen by the loss function; currently the mean of the losses will be returned.
Auto, } diff --git a/burn-core/src/nn/mod.rs b/burn-core/src/nn/mod.rs index ed8a56d54..9c1767fab 100644 --- a/burn-core/src/nn/mod.rs +++ b/burn-core/src/nn/mod.rs @@ -1,8 +1,19 @@ +/// Attention module pub mod attention; + +/// Cache module pub mod cache; + +/// Convolution module pub mod conv; + +/// Loss module pub mod loss; + +/// Pooling module pub mod pool; + +/// Transformer module pub mod transformer; mod dropout; diff --git a/burn-core/src/nn/rnn/gru.rs b/burn-core/src/nn/rnn/gru.rs index cebebca14..568ad7e74 100644 --- a/burn-core/src/nn/rnn/gru.rs +++ b/burn-core/src/nn/rnn/gru.rs @@ -11,6 +11,7 @@ use burn_tensor::activation; use super::gate_controller::GateController; +/// The configuration for a [gru](Gru) module. #[derive(Config)] pub struct GruConfig { /// The size of the input features. diff --git a/burn-core/src/nn/rnn/lstm.rs b/burn-core/src/nn/rnn/lstm.rs index de0e92e77..28eeb8753 100644 --- a/burn-core/src/nn/rnn/lstm.rs +++ b/burn-core/src/nn/rnn/lstm.rs @@ -11,6 +11,7 @@ use burn_tensor::activation; use super::gate_controller::GateController; +/// The configuration for a [lstm](Lstm) module. #[derive(Config)] pub struct LstmConfig { /// The size of the input features. diff --git a/burn-core/src/nn/rnn/mod.rs b/burn-core/src/nn/rnn/mod.rs index 81063b51a..8a8bd8e53 100644 --- a/burn-core/src/nn/rnn/mod.rs +++ b/burn-core/src/nn/rnn/mod.rs @@ -1,5 +1,9 @@ mod gate_controller; + +/// Gated Recurrent Unit module. pub mod gru; + +/// Long Short-Term Memory module. pub mod lstm; pub use gate_controller::*; diff --git a/burn-core/src/nn/transformer/decoder.rs b/burn-core/src/nn/transformer/decoder.rs index 1ccbf3757..f431b357b 100644 --- a/burn-core/src/nn/transformer/decoder.rs +++ b/burn-core/src/nn/transformer/decoder.rs @@ -47,6 +47,7 @@ pub struct TransformerDecoder { } impl TransformerDecoderConfig { + /// Initialize a new [Transformer Decoder](TransformerDecoder) module. pub fn init(&self) -> TransformerDecoder { let layers = (0..self.n_layers) .map(|_| TransformerDecoderLayer::new(self)) @@ -54,6 +55,12 @@ impl TransformerDecoderConfig { TransformerDecoder { layers } } + + /// Initialize a new [Transformer Decoder](TransformerDecoder) module with a record. + /// + /// # Params + /// + /// - record: the record to initialize the module with. pub fn init_with( &self, record: TransformerDecoderRecord, @@ -117,6 +124,7 @@ impl TransformerDecoderInput { } } +/// [Transformer Decoder](TransformerDecoder) layer module. #[derive(Module, Debug)] pub struct TransformerDecoderLayer { cross_attn: MultiHeadAttention, diff --git a/burn-core/src/nn/transformer/encoder.rs b/burn-core/src/nn/transformer/encoder.rs index fd8ba7758..7dabc80bb 100644 --- a/burn-core/src/nn/transformer/encoder.rs +++ b/burn-core/src/nn/transformer/encoder.rs @@ -151,6 +151,7 @@ impl TransformerEncoder { } } +/// Transformer encoder layer module. #[derive(Module, Debug)] pub struct TransformerEncoderLayer { mha: MultiHeadAttention, diff --git a/burn-core/src/optim/adam.rs b/burn-core/src/optim/adam.rs index e8a1f9b80..6ca16d6e7 100644 --- a/burn-core/src/optim/adam.rs +++ b/burn-core/src/optim/adam.rs @@ -12,6 +12,7 @@ use crate::optim::adaptor::OptimizerAdaptor; use crate::tensor::{backend::ADBackend, Tensor}; use burn_tensor::{backend::Backend, ElementConversion}; +/// Adam configuration. #[derive(Config)] pub struct AdamConfig { /// Parameter for Adam. @@ -35,6 +36,7 @@ pub struct Adam { weight_decay: Option>, } +/// Adam state. 
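+///
+/// The state is produced and updated by an optimizer built from an [AdamConfig],
+/// e.g. (a sketch; the values and elided type parameters are illustrative):
+///
+/// ```ignore
+/// let optim = AdamConfig::new()
+///     .with_epsilon(1e-8)
+///     .init();
+/// ```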
#[derive(Record, Clone, new)] pub struct AdamState { weight_decay: Option>, @@ -84,6 +86,11 @@ impl SimpleOptimizer for Adam { } impl AdamConfig { + /// Initialize Adam optimizer. + /// + /// # Returns + /// + /// Returns an optimizer that can be used to optimize a module. pub fn init>(&self) -> impl Optimizer { let optim = Adam { momentum: AdaptiveMomentum { @@ -102,6 +109,7 @@ impl AdamConfig { } } +/// Adaptive momentum state. #[derive(Record, new, Clone)] pub struct AdaptiveMomentumState { time: usize, @@ -164,6 +172,15 @@ impl AdaptiveMomentum { } impl AdaptiveMomentumState { + /// Move state to device. + /// + /// # Arguments + /// + /// * `device` - Device to move state to. + /// + /// # Returns + /// + /// Returns state moved to device. pub fn to_device(mut self, device: &B::Device) -> Self { self.moment_1 = self.moment_1.to_device(device); self.moment_2 = self.moment_2.to_device(device); diff --git a/burn-core/src/optim/decay.rs b/burn-core/src/optim/decay.rs index f3024c064..d545dcec4 100644 --- a/burn-core/src/optim/decay.rs +++ b/burn-core/src/optim/decay.rs @@ -13,6 +13,7 @@ pub struct WeightDecayConfig { pub penalty: f64, } +/// State of [WeightDecay](WeightDecay). #[derive(Record, Clone, new)] pub struct WeightDecayState { grad_last_step: Tensor, @@ -24,12 +25,24 @@ pub struct WeightDecay { } impl WeightDecay { + /// Creates a new [WeightDecay](WeightDecay) from a [WeightDecayConfig](WeightDecayConfig). pub fn new(config: &WeightDecayConfig) -> Self { Self { penalty: config.penalty.elem(), } } + /// Transforms a gradient. + /// + /// # Arguments + /// + /// * `grad` - Gradient to transform. + /// * `state` - State of the optimizer. + /// + /// # Returns + /// + /// * `grad` - Transformed gradient. + /// * `state` - State of the optimizer. pub fn transform( &self, grad: Tensor, @@ -47,6 +60,15 @@ impl WeightDecay { } impl WeightDecayState { + /// Moves the state to a device. + /// + /// # Arguments + /// + /// * `device` - Device to move the state to. + /// + /// # Returns + /// + /// * `self` - Moved state. pub fn to_device(mut self, device: &B::Device) -> Self { self.grad_last_step = self.grad_last_step.to_device(device); self diff --git a/burn-core/src/optim/grads.rs b/burn-core/src/optim/grads.rs index 74877259e..0793f2f05 100644 --- a/burn-core/src/optim/grads.rs +++ b/burn-core/src/optim/grads.rs @@ -15,6 +15,7 @@ pub struct GradientsParams { } impl GradientsParams { + /// Creates a new [GradientsParams](GradientsParams). pub fn new() -> Self { Self::default() } diff --git a/burn-core/src/optim/mod.rs b/burn-core/src/optim/mod.rs index 1b4dd0dc7..8bee30774 100644 --- a/burn-core/src/optim/mod.rs +++ b/burn-core/src/optim/mod.rs @@ -1,4 +1,7 @@ +/// Weight decay module for optimizers. pub mod decay; + +/// Momentum module for optimizers. pub mod momentum; mod adam; diff --git a/burn-core/src/optim/momentum.rs b/burn-core/src/optim/momentum.rs index 7d8cf89f8..5e00c8716 100644 --- a/burn-core/src/optim/momentum.rs +++ b/burn-core/src/optim/momentum.rs @@ -14,11 +14,13 @@ pub struct MomentumConfig { /// Dampening factor. #[config(default = 0.1)] pub dampening: f64, - /// Enables Nesterov momentum, see [On the importance of initialization and momentum in deep learning](http://www.cs.toronto.edu/~hinton/absps/momentum.pdf). + /// Enables Nesterov momentum, see [On the importance of initialization and + /// momentum in deep learning](http://www.cs.toronto.edu/~hinton/absps/momentum.pdf). #[config(default = false)] pub nesterov: bool, } +/// State of [Momentum](Momentum). 
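+///
+/// The state is created and threaded through [Momentum::transform], e.g.
+/// (a sketch; `grad` stands for a gradient tensor, and the optional state
+/// argument for the previous step's state):
+///
+/// ```ignore
+/// let momentum = Momentum::new(&MomentumConfig::new());
+/// // No state exists on the first step, hence `None`.
+/// let (grad, state) = momentum.transform(grad, None);
+/// ```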
#[derive(Record, Clone, new)] pub struct MomemtumState { velocity: Tensor, @@ -32,6 +34,7 @@ pub struct Momentum { } impl Momentum { + /// Creates a new [Momentum](Momentum) from a [MomentumConfig](MomentumConfig). pub fn new(config: &MomentumConfig) -> Self { Self { momentum: config.momentum.elem(), @@ -40,6 +43,17 @@ impl Momentum { } } + /// Transforms a gradient. + /// + /// # Arguments + /// + /// * `grad` - Gradient to transform. + /// * `state` - State of the optimizer. + /// + /// # Returns + /// + /// * `grad` - Transformed gradient. + /// * `state` - State of the optimizer. pub fn transform( &self, grad: Tensor, @@ -63,6 +77,15 @@ impl Momentum { } impl MomemtumState { + /// Moves the state to a device. + /// + /// # Arguments + /// + /// * `device` - Device to move the state to. + /// + /// # Returns + /// + /// * `self` - Moved state. pub fn to_device(mut self, device: &B::Device) -> Self { self.velocity = self.velocity.to_device(device); self diff --git a/burn-core/src/optim/sgd.rs b/burn-core/src/optim/sgd.rs index 02119eb6b..a50611cc7 100644 --- a/burn-core/src/optim/sgd.rs +++ b/burn-core/src/optim/sgd.rs @@ -30,6 +30,7 @@ pub struct Sgd { weight_decay: Option>, } +/// State of [Sgd](Sgd). #[derive(Record, Clone, new)] pub struct SgdState { weight_decay: Option>, } impl SgdConfig { + /// Initializes a new [Sgd](Sgd) optimizer from this config. pub fn init>( &self, ) -> OptimizerAdaptor, M, B> { diff --git a/burn-core/src/optim/simple/adaptor.rs b/burn-core/src/optim/simple/adaptor.rs index 062c69443..491bd742f 100644 --- a/burn-core/src/optim/simple/adaptor.rs +++ b/burn-core/src/optim/simple/adaptor.rs @@ -45,13 +45,22 @@ where M: ADModule, B: ADBackend, { + /// Sets the gradient clipping. + /// + /// # Arguments + /// + /// * `gradient_clipping` - The gradient clipping. + /// + /// # Returns + /// + /// The optimizer. pub fn with_grad_clipping(mut self, gradient_clipping: GradientClipping) -> Self { self.grad_clipping = Some(gradient_clipping); self } #[cfg(test)] - pub fn has_gradient_clipping(&self) -> bool { + pub(crate) fn has_gradient_clipping(&self) -> bool { self.grad_clipping.is_some() } } diff --git a/burn-core/src/optim/simple/mod.rs b/burn-core/src/optim/simple/mod.rs index 97171ba5a..bfc1ff62d 100644 --- a/burn-core/src/optim/simple/mod.rs +++ b/burn-core/src/optim/simple/mod.rs @@ -1,5 +1,8 @@ mod base; pub use base::*; +/// Adaptor module for optimizers. pub mod adaptor; + +/// Record module for optimizers. pub mod record; diff --git a/burn-core/src/optim/simple/record/base.rs b/burn-core/src/optim/simple/record/base.rs index d0b7cc740..e0bc9199d 100644 --- a/burn-core/src/optim/simple/record/base.rs +++ b/burn-core/src/optim/simple/record/base.rs @@ -10,12 +10,15 @@ use serde::{Deserialize, Serialize}; /// /// Records are versioned for backward compatibility, so old records can be loaded. pub enum AdaptorRecord, B: Backend> { + /// Version 1. V1(AdaptorRecordV1), } +/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item. #[derive(Serialize, Deserialize)] #[serde(bound = "")] pub enum AdaptorRecordItem, B: Backend, S: PrecisionSettings> { + /// Version 1. V1(AdaptorRecordItemV1), } @@ -56,12 +59,26 @@ where O: SimpleOptimizer, B: Backend, { + /// Converts the record into the optimizer state. + /// + /// # Returns + /// + /// The optimizer state.
pub fn into_state(self) -> O::State { match self { AdaptorRecord::V1(record) => record.into_state(), } } + /// Converts the optimizer state into the record. + /// + /// # Arguments + /// + /// * `state`: The optimizer state. + /// + /// # Returns + /// + /// The record. pub fn from_state(state: O::State) -> Self { Self::V1(AdaptorRecordV1::from_state(state)) } diff --git a/burn-core/src/optim/simple/record/v1.rs b/burn-core/src/optim/simple/record/v1.rs index 140bd04a3..9c4740347 100644 --- a/burn-core/src/optim/simple/record/v1.rs +++ b/burn-core/src/optim/simple/record/v1.rs @@ -6,14 +6,30 @@ use burn_tensor::backend::Backend; use core::any::Any; use serde::{Deserialize, Serialize}; +/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record (version 1). pub enum AdaptorRecordV1, B: Backend> { + /// Rank 1. Rank1(O::State<1>), + + /// Rank 2. Rank2(O::State<2>), + + /// Rank 3. Rank3(O::State<3>), + + /// Rank 4. Rank4(O::State<4>), + + /// Rank 5. Rank5(O::State<5>), + + /// Rank 6. Rank6(O::State<6>), + + /// Rank 7. Rank7(O::State<7>), + + /// Rank 8. Rank8(O::State<8>), } @@ -32,16 +48,32 @@ impl, B: Backend> Clone for AdaptorRecordV1 { } } +/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item. #[derive(Serialize, Deserialize)] #[serde(bound = "")] pub enum AdaptorRecordItemV1, B: Backend, S: PrecisionSettings> { + /// Rank 1. Rank1( as Record>::Item), + + /// Rank 2. Rank2( as Record>::Item), + + /// Rank 3. Rank3( as Record>::Item), + + /// Rank 4. Rank4( as Record>::Item), + + /// Rank 5. Rank5( as Record>::Item), + + /// Rank 6. Rank6( as Record>::Item), + + /// Rank 7. Rank7( as Record>::Item), + + /// Rank 8. Rank8( as Record>::Item), } @@ -50,6 +82,15 @@ where O: SimpleOptimizer, B: Backend, { + /// Convert the record into the state. + /// + /// # Returns + /// + /// The state. + /// + /// # Panics + /// + /// Panics if the state dimension is not supported. pub fn into_state(self) -> O::State { let boxed_state: Box = match self { AdaptorRecordV1::Rank1(s) => Box::new(s), @@ -66,6 +107,16 @@ .expect("Unsupported state dimension, dimension up to 8 are supported."); *state } + + /// Convert the state into the record. + /// + /// # Arguments + /// + /// * `state`: The state. + /// + /// # Returns + /// + /// The record. pub fn from_state(state: O::State) -> Self { let state: Box = Box::new(state); diff --git a/burn-core/src/record/base.rs b/burn-core/src/record/base.rs index 2d42fbe66..0c1633e8e 100644 --- a/burn-core/src/record/base.rs +++ b/burn-core/src/record/base.rs @@ -5,10 +5,12 @@ use serde::{de::DeserializeOwned, Serialize}; /// Trait to define a family of types which can be recorded using any [settings](PrecisionSettings). pub trait Record: Send + Sync { + /// Type of the item that can be serialized and deserialized. type Item: Serialize + DeserializeOwned; /// Convert the current record into the corresponding item that follows the given [settings](PrecisionSettings). fn into_item(self) -> Self::Item; + /// Convert the given item into a record. fn from_item(item: Self::Item) -> Self; } diff --git a/burn-core/src/record/file.rs b/burn-core/src/record/file.rs index 27aa1fd6e..dccdf1b62 100644 --- a/burn-core/src/record/file.rs +++ b/burn-core/src/record/file.rs @@ -8,6 +8,7 @@ use std::{fs::File, path::PathBuf}; pub trait FileRecorder: Recorder { + /// File extension of the format used by the recorder.
fn file_extension() -> &'static str; } diff --git a/burn-core/src/record/primitive.rs b/burn-core/src/record/primitive.rs index 7789f9016..d15b53984 100644 --- a/burn-core/src/record/primitive.rs +++ b/burn-core/src/record/primitive.rs @@ -134,6 +134,7 @@ primitive!(f64); primitive!(f32); // TODO: Remove the feature flag when half supports serde with no_std +// https://github.com/burn-rs/burn/issues/268 tracking issue #[cfg(feature = "std")] primitive!(half::bf16); #[cfg(feature = "std")] diff --git a/burn-core/src/record/recorder.rs b/burn-core/src/record/recorder.rs index 27899835f..8b3f34ebb 100644 --- a/burn-core/src/record/recorder.rs +++ b/burn-core/src/record/recorder.rs @@ -13,15 +13,28 @@ use super::{ /// Record any item implementing [Serialize](Serialize) and [DeserializeOwned](DeserializeOwned). pub trait Recorder: Send + Sync + core::default::Default + core::fmt::Debug + Clone { + /// Type of the settings used by the recorder. type Settings: PrecisionSettings; + /// Arguments used to record objects. type RecordArgs: Clone; + /// Record output type. type RecordOutput; + /// Arguments used to load recorded objects. type LoadArgs: Clone; - /// Record an item with the given arguments. + /// Records an item. + /// + /// # Arguments + /// + /// * `record` - The item to record. + /// * `args` - Arguments used to record the item. + /// + /// # Returns + /// + /// The output of the recording. fn record( &self, record: R, @@ -80,11 +93,35 @@ pub trait Recorder: Send + Sync + core::default::Default + core::fmt::Debug + Cl Ok(R::from_item(item.item)) } + /// Saves an item. + /// + /// This method is used by [record](Recorder::record) to save the item. + /// + /// # Arguments + /// + /// * `item` - Item to save. + /// * `args` - Arguments to use to save the item. + /// + /// # Returns + /// + /// The output of the save operation. fn save_item( &self, item: I, args: Self::RecordArgs, ) -> Result; + + /// Loads an item. + /// + /// This method is used by [load](Recorder::load) to load the item. + /// + /// # Arguments + /// + /// * `args` - Arguments to use to load the item. + /// + /// # Returns + /// + /// The loaded item. fn load_item(&self, args: Self::LoadArgs) -> Result; } @@ -98,9 +135,13 @@ fn recorder_metadata() -> BurnMetadata { ) } +/// Error that can occur when using a [Recorder](Recorder). #[derive(Debug)] pub enum RecorderError { + /// File not found. FileNotFound(String), + + /// Other error. Unknown(String), } @@ -118,22 +159,45 @@ pub(crate) fn bin_config() -> bincode::config::Configuration { bincode::config::standard() } +/// Metadata of a record. #[derive(new, Debug, Serialize, Deserialize, PartialEq, Eq)] pub struct BurnMetadata { + /// Float type used to record the item. pub float: String, + + /// Int type used to record the item. pub int: String, + + /// Format used to record the item. pub format: String, + + /// Burn record version used to record the item. pub version: String, + + /// Settings used to record the item. pub settings: String, } +/// Record that can be saved by a [Recorder](Recorder). #[derive(Serialize, Deserialize)] pub struct BurnRecord { + /// Metadata of the record. pub metadata: BurnMetadata, + + /// Item to record. pub item: I, } impl BurnRecord { + /// Creates a new record. + /// + /// # Arguments + /// + /// * `item` - Item to record. + /// + /// # Returns + /// + /// The new record. 
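+ ///
+ /// # Example
+ ///
+ /// A sketch, assuming the full-precision settings type and a serializable
+ /// `item`:
+ ///
+ /// ```ignore
+ /// let record = BurnRecord::new::<FullPrecisionSettings>(item);
+ /// ```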
pub fn new(item: I) -> Self { let metadata = recorder_metadata::(); @@ -141,8 +205,10 @@ impl BurnRecord { } } +/// Record that can be saved by a [Recorder](Recorder) without the item. #[derive(new, Debug, Serialize, Deserialize)] pub struct BurnRecordNoItem { + /// Metadata of the record. pub metadata: BurnMetadata, } diff --git a/burn-core/src/record/settings.rs b/burn-core/src/record/settings.rs index 0bc3ea76b..1ce7a5cef 100644 --- a/burn-core/src/record/settings.rs +++ b/burn-core/src/record/settings.rs @@ -5,7 +5,10 @@ use serde::{de::DeserializeOwned, Serialize}; pub trait PrecisionSettings: Send + Sync + core::fmt::Debug + core::default::Default + Clone { + /// Float element type. type FloatElem: Element + Serialize + DeserializeOwned; + + /// Integer element type. type IntElem: Element + Serialize + DeserializeOwned; } diff --git a/burn-dataset/src/audio/speech_commands.rs b/burn-dataset/src/audio/speech_commands.rs index 632a38375..28c2d34f2 100644 --- a/burn-dataset/src/audio/speech_commands.rs +++ b/burn-dataset/src/audio/speech_commands.rs @@ -12,7 +12,9 @@ type MappedDataset = MapperDataset, ConvertSamples, /// Enum representing speech command classes in the Speech Commands dataset. /// Class names are based on the Speech Commands dataset from Huggingface. -/// See: https://huggingface.co/datasets/speech_commands +/// See [speech_commands](https://huggingface.co/datasets/speech_commands) +/// for more information. +#[allow(missing_docs)] #[derive(Debug, Display, Clone, Copy, FromRepr, Serialize, Deserialize, EnumCount)] pub enum SpeechCommandClass { // Target command words @@ -66,8 +68,13 @@ pub enum SpeechCommandClass { /// Struct containing raw speech data returned from a database. #[derive(Clone, Debug, Serialize, Deserialize)] pub struct SpeechItemRaw { + /// Audio file bytes. pub audio_bytes: Vec, + + /// Label index. pub label: usize, + + /// Indicates if the label is unknown. pub is_unknown: bool, } diff --git a/burn-dataset/src/dataset/base.rs b/burn-dataset/src/dataset/base.rs index 7149fb277..eb53980c9 100644 --- a/burn-dataset/src/dataset/base.rs +++ b/burn-dataset/src/dataset/base.rs @@ -4,11 +4,18 @@ use crate::DatasetIterator; /// The dataset trait defines a basic collection of items with a predefined size. pub trait Dataset: Send + Sync { + /// Gets the item at the given index. fn get(&self, index: usize) -> Option; + + /// Gets the number of items in the dataset. fn len(&self) -> usize; + + /// Checks if the dataset is empty. fn is_empty(&self) -> bool { self.len() == 0 } + + /// Returns an iterator over the dataset. fn iter(&self) -> DatasetIterator<'_, I> where Self: Sized, diff --git a/burn-dataset/src/dataset/fake.rs b/burn-dataset/src/dataset/fake.rs index 0d761e10d..c27f8cf0d 100644 --- a/burn-dataset/src/dataset/fake.rs +++ b/burn-dataset/src/dataset/fake.rs @@ -7,6 +7,7 @@ pub struct FakeDataset { } impl> FakeDataset { + /// Create a new fake dataset with the given size. pub fn new(size: usize) -> Self { let mut items = Vec::with_capacity(size); for _ in 0..size { diff --git a/burn-dataset/src/dataset/in_memory.rs b/burn-dataset/src/dataset/in_memory.rs index 0440ea631..a3b167f0c 100644 --- a/burn-dataset/src/dataset/in_memory.rs +++ b/burn-dataset/src/dataset/in_memory.rs @@ -14,6 +14,7 @@ pub struct InMemDataset { } impl InMemDataset { + /// Creates a new in memory dataset from the given items. 
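+ ///
+ /// # Example
+ ///
+ /// A minimal sketch:
+ ///
+ /// ```ignore
+ /// let dataset = InMemDataset::new(vec![1, 2, 3]);
+ /// assert_eq!(dataset.len(), 3);
+ /// assert_eq!(dataset.get(0), Some(1));
+ /// ```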
pub fn new(items: Vec) -> Self { InMemDataset { items } } diff --git a/burn-dataset/src/dataset/iterator.rs b/burn-dataset/src/dataset/iterator.rs index 4b8140521..e513e1a08 100644 --- a/burn-dataset/src/dataset/iterator.rs +++ b/burn-dataset/src/dataset/iterator.rs @@ -8,6 +8,7 @@ pub struct DatasetIterator<'a, I> { } impl<'a, I> DatasetIterator<'a, I> { + /// Creates a new dataset iterator. pub fn new(dataset: &'a D) -> Self where D: Dataset, diff --git a/burn-dataset/src/dataset/sqlite.rs b/burn-dataset/src/dataset/sqlite.rs index f1369b359..270cad27c 100644 --- a/burn-dataset/src/dataset/sqlite.rs +++ b/burn-dataset/src/dataset/sqlite.rs @@ -21,28 +21,37 @@ use sanitize_filename::sanitize; use serde::{de::DeserializeOwned, Serialize}; use serde_rusqlite::{columns_from_statement, from_row_with_columns}; +/// Result type for the sqlite dataset. pub type Result = core::result::Result; +/// Sqlite dataset error. #[derive(thiserror::Error, Debug)] pub enum SqliteDatasetError { + /// IO related error. #[error("IO error: {0}")] Io(#[from] io::Error), + /// Sql related error. #[error("Sql error: {0}")] Sql(#[from] serde_rusqlite::rusqlite::Error), + /// Serde related error. #[error("Serde error: {0}")] Serde(#[from] rmp_serde::encode::Error), + /// The database file already exists error. #[error("Overwrite flag is set to false and the database file already exists: {0}")] FileExists(PathBuf), + /// Error when creating the connection pool. #[error("Failed to create connection pool: {0}")] ConnectionPool(#[from] r2d2::Error), + /// Error when persisting the temporary database file. #[error("Could not persist the temporary database file: {0}")] PersistDbFile(#[from] persist::Error), + /// Any other error. #[error("{0}")] Other(&'static str), } diff --git a/burn-dataset/src/lib.rs b/burn-dataset/src/lib.rs index 0b6b1d5fe..e27e5460b 100644 --- a/burn-dataset/src/lib.rs +++ b/burn-dataset/src/lib.rs @@ -1,11 +1,21 @@ +#![warn(missing_docs)] + +//! # Burn Dataset +//! +//! Burn Dataset is a library for creating and loading datasets. + #[macro_use] extern crate derive_new; extern crate dirs; +/// Sources for datasets. pub mod source; + +/// Transformations to be used with datasets. pub mod transform; +/// Audio datasets. #[cfg(feature = "audio")] pub mod audio; diff --git a/burn-dataset/src/source/huggingface/downloader.rs b/burn-dataset/src/source/huggingface/downloader.rs index 4549356f3..3e54ba126 100644 --- a/burn-dataset/src/source/huggingface/downloader.rs +++ b/burn-dataset/src/source/huggingface/downloader.rs @@ -11,13 +11,18 @@ use thiserror::Error; const PYTHON: &str = "python3"; const PYTHON_SOURCE: &str = include_str!("importer.py"); +/// Error type for [HuggingfaceDatasetLoader](HuggingfaceDatasetLoader). #[derive(Error, Debug)] pub enum ImporterError { + /// Unknown error. #[error("unknown: `{0}`")] Unknown(String), + + /// Fail to download python dependencies. #[error("fail to download python dependencies: `{0}`")] FailToDownloadPythonDependencies(String), + /// Fail to create sqlite dataset. #[error("sqlite dataset: `{0}`")] SqliteDataset(#[from] SqliteDatasetError), } diff --git a/burn-dataset/src/source/huggingface/mnist.rs b/burn-dataset/src/source/huggingface/mnist.rs index c4b17e9a4..88b37180a 100644 --- a/burn-dataset/src/source/huggingface/mnist.rs +++ b/burn-dataset/src/source/huggingface/mnist.rs @@ -8,9 +8,13 @@ use serde::{Deserialize, Serialize}; const WIDTH: usize = 28; const HEIGHT: usize = 28; +/// MNIST item. 
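+///
+/// Items are produced by the MNIST dataset, e.g. (a sketch; downloads the data
+/// on first use):
+///
+/// ```ignore
+/// let item = MNISTDataset::train().get(0).unwrap();
+/// assert_eq!(item.image.len(), 28); // 28x28 grayscale image
+/// ```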
#[derive(Deserialize, Serialize, Debug, Clone)] pub struct MNISTItem { + /// Image as a 2D array of floats. pub image: [[f32; WIDTH]; HEIGHT], + + /// Label of the image. pub label: usize, } @@ -66,10 +70,12 @@ impl Dataset for MNISTDataset { } impl MNISTDataset { + /// Creates a new train dataset. pub fn train() -> Self { Self::new("train") } + /// Creates a new test dataset. pub fn test() -> Self { Self::new("test") } diff --git a/burn-dataset/src/source/huggingface/mod.rs b/burn-dataset/src/source/huggingface/mod.rs index 075f569ef..7df5c163c 100644 --- a/burn-dataset/src/source/huggingface/mod.rs +++ b/burn-dataset/src/source/huggingface/mod.rs @@ -1,4 +1,4 @@ -pub mod downloader; +pub(crate) mod downloader; mod mnist; pub use downloader::*; diff --git a/burn-dataset/src/source/mod.rs b/burn-dataset/src/source/mod.rs index bb5c9337c..551fb0f9e 100644 --- a/burn-dataset/src/source/mod.rs +++ b/burn-dataset/src/source/mod.rs @@ -1 +1,2 @@ +/// Huggingface source pub mod huggingface; diff --git a/burn-dataset/src/transform/mapper.rs b/burn-dataset/src/transform/mapper.rs index ec1822520..b089a375e 100644 --- a/burn-dataset/src/transform/mapper.rs +++ b/burn-dataset/src/transform/mapper.rs @@ -3,6 +3,7 @@ use std::marker::PhantomData; /// Basic mapper trait to be used with the [mapper dataset](MapperDataset). pub trait Mapper: Send + Sync { + /// Maps an item of type I to an item of type O. fn map(&self, item: &I) -> O; } diff --git a/burn-dataset/src/transform/partial.rs b/burn-dataset/src/transform/partial.rs index 9b365444f..c8bd53f08 100644 --- a/burn-dataset/src/transform/partial.rs +++ b/burn-dataset/src/transform/partial.rs @@ -14,6 +14,7 @@ impl PartialDataset where D: Dataset, { + /// Splits a dataset into multiple partial datasets. pub fn split(dataset: D, num: usize) -> Vec, I>> { let dataset = Arc::new(dataset); // cheap cloning. diff --git a/burn-dataset/src/transform/random.rs b/burn-dataset/src/transform/random.rs index 2faea90fc..db43943ed 100644 --- a/burn-dataset/src/transform/random.rs +++ b/burn-dataset/src/transform/random.rs @@ -14,6 +14,7 @@ impl ShuffledDataset where D: Dataset, { + /// Creates a new shuffled dataset. pub fn new(dataset: D, rng: &mut StdRng) -> Self { let mut indexes = Vec::with_capacity(dataset.len()); for i in 0..dataset.len() { @@ -28,6 +29,7 @@ where } } + /// Creates a new shuffled dataset with a fixed seed. pub fn with_seed(dataset: D, seed: u64) -> Self { let mut rng = StdRng::seed_from_u64(seed); Self::new(dataset, &mut rng) diff --git a/burn-dataset/src/transform/sampler.rs b/burn-dataset/src/transform/sampler.rs index 709e69019..b9723a380 100644 --- a/burn-dataset/src/transform/sampler.rs +++ b/burn-dataset/src/transform/sampler.rs @@ -17,6 +17,7 @@ where D: Dataset, I: Send + Sync, { + /// Creates a new sampler dataset. pub fn new(dataset: D, size: usize) -> Self { let rng = Mutex::new(StdRng::from_entropy()); @@ -28,6 +29,7 @@ where } } + /// Generates random index using uniform distribution (0, dataset.len()). fn index(&self) -> usize { let mut rng = self.rng.lock().unwrap(); rng.sample(Uniform::new(0, self.dataset.len())) diff --git a/burn-derive/src/config/analyzer_struct.rs b/burn-derive/src/config/analyzer_struct.rs index 9a7673eaa..6736db232 100644 --- a/burn-derive/src/config/analyzer_struct.rs +++ b/burn-derive/src/config/analyzer_struct.rs @@ -190,6 +190,7 @@ impl ConfigAnalyzer for ConfigStructAnalyzer { } let body = quote! { + /// Create a new instance of the config. 
pub fn new( #(#names),* ) -> Self { @@ -208,6 +209,7 @@ impl ConfigAnalyzer for ConfigStructAnalyzer { let fn_name = Ident::new(&format!("with_{name}"), name.span()); body.extend(quote! { + /// Set the value for the field, overriding the default. pub fn #fn_name(mut self, #name: #ty) -> Self { self.#name = #name; self @@ -221,6 +223,7 @@ let fn_name = Ident::new(&format!("with_{name}"), name.span()); body.extend(quote! { + /// Set the value for the field, overriding the default. pub fn #fn_name(mut self, #name: #ty) -> Self { self.#name = #name; self diff --git a/burn-derive/src/lib.rs b/burn-derive/src/lib.rs index 611c1b279..0fea48d2d 100644 --- a/burn-derive/src/lib.rs +++ b/burn-derive/src/lib.rs @@ -1,3 +1,7 @@ +#![warn(missing_docs)] + +//! The derive crate of Burn. + use proc_macro::TokenStream; pub(crate) mod config; @@ -9,18 +13,21 @@ use config::config_attr_impl; use module::module_derive_impl; use record::record_derive_impl; +/// Derive macro for the module. #[proc_macro_derive(Module)] pub fn module_derive(input: TokenStream) -> TokenStream { let input = syn::parse(input).unwrap(); module_derive_impl(&input) } +/// Derive macro for the record. #[proc_macro_derive(Record)] pub fn record_derive(input: TokenStream) -> TokenStream { let input = syn::parse(input).unwrap(); record_derive_impl(&input) } +/// Derive macro for the config. #[proc_macro_derive(Config, attributes(config))] pub fn config_derive(input: TokenStream) -> TokenStream { let item = syn::parse(input).unwrap(); diff --git a/burn-derive/src/module/record.rs b/burn-derive/src/module/record.rs index 39d3570ed..0234529f4 100644 --- a/burn-derive/src/module/record.rs +++ b/burn-derive/src/module/record.rs @@ -26,6 +26,7 @@ impl ModuleRecordGenerator { let name = &field.field.ident; fields.extend(quote! { + /// The #name field. pub #name: <#ty as burn::module::Module>::Record, }); } @@ -33,6 +34,8 @@ impl ModuleRecordGenerator { let generics = &self.generics; quote! { + + /// The record type for the module. #[derive(burn::record::Record, Debug, Clone)] pub struct #name #generics { #fields diff --git a/burn-derive/src/record/generator.rs b/burn-derive/src/record/generator.rs index ea2099b5a..1e867c797 100644 --- a/burn-derive/src/record/generator.rs +++ b/burn-derive/src/record/generator.rs @@ -31,6 +31,7 @@ impl RecordGenerator { let name = &field.field.ident; fields.extend(quote! { + /// The #name field. pub #name: <#ty as burn::record::Record>::Item, }); bounds.extend(quote!{ @@ -42,6 +43,8 @@ impl RecordGenerator { let bound = bounds.to_string(); quote! { + + /// The record item type for the module. #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] #[serde(bound = #bound)] pub struct #name #generics { diff --git a/burn/src/lib.rs b/burn/src/lib.rs index 08032323a..7e12b4c09 100644 --- a/burn/src/lib.rs +++ b/burn/src/lib.rs @@ -1,7 +1,15 @@ #![cfg_attr(not(feature = "std"), no_std)] +#![warn(missing_docs)] + +//! # Burn +//! This library strives to serve as a comprehensive **deep learning framework** +//! written in Rust, offering exceptional flexibility. The main objective is to cater
//! to both researchers and practitioners by simplifying the process of experimenting, +//! training, and deploying models.
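+//!
+//! A small sketch of the `Config` derive re-exported by this crate (the config
+//! type here is hypothetical):
+//!
+//! ```ignore
+//! use burn::config::Config;
+//!
+//! #[derive(Config)]
+//! struct MyTrainingConfig {
+//!     #[config(default = 42)]
+//!     seed: u64,
+//! }
+//!
+//! // Fields with defaults are omitted from `new` and can be overridden with `with_*`.
+//! let config = MyTrainingConfig::new().with_seed(7);
+//! assert_eq!(config.seed, 7);
+//! ```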
pub use burn_core::*; +/// Train module #[cfg(feature = "train")] pub mod train { pub use burn_train::*; diff --git a/examples/text-classification/src/data/dataset.rs b/examples/text-classification/src/data/dataset.rs index 4f4a59773..28ae0f224 100644 --- a/examples/text-classification/src/data/dataset.rs +++ b/examples/text-classification/src/data/dataset.rs @@ -5,9 +5,7 @@ // the TextClassificationDataset trait. These implementations are designed to be used // with a machine learning framework for tasks such as training a text classification model. -use burn::data::dataset::{ - source::huggingface::downloader::HuggingfaceDatasetLoader, Dataset, SqliteDataset, -}; +use burn::data::dataset::{source::huggingface::HuggingfaceDatasetLoader, Dataset, SqliteDataset}; // Define a struct for text classification items #[derive(new, Clone, Debug)] diff --git a/examples/text-generation/src/data/dataset.rs b/examples/text-generation/src/data/dataset.rs index be85141e9..f19814358 100644 --- a/examples/text-generation/src/data/dataset.rs +++ b/examples/text-generation/src/data/dataset.rs @@ -1,6 +1,4 @@ -use burn::data::dataset::{ - source::huggingface::downloader::HuggingfaceDatasetLoader, Dataset, SqliteDataset, -}; +use burn::data::dataset::{source::huggingface::HuggingfaceDatasetLoader, Dataset, SqliteDataset}; #[derive(new, Clone, Debug)] pub struct TextGenerationItem {