Add missing docs and enable missing_docs warn lint (#420)

Dilshod Tadjibaev 2023-06-21 13:12:13 -05:00 committed by GitHub
parent c4e4c25fef
commit eda241f8cf
73 changed files with 696 additions and 26 deletions

View File

@ -1,6 +1,7 @@
use crate::{grads::Gradients, graph::backward::backward, tensor::ADTensor};
use burn_tensor::backend::{ADBackend, Backend};
/// A decorator for a backend that enables automatic differentiation.
#[derive(Clone, Copy, Debug, Default)]
pub struct ADBackendDecorator<B> {
_b: B,

View File

@ -5,6 +5,7 @@ use crate::{
tensor::ADTensor,
};
/// Gradient identifier.
pub type GradID = String;
/// Gradients container used during the backward pass.
@ -15,6 +16,7 @@ pub struct Gradients {
type TensorPrimitive<B, const D: usize> = <B as Backend>::TensorPrimitive<D>;
impl Gradients {
/// Creates a new gradients container.
pub fn new<B: Backend, const D: usize>(
root_node: NodeRef,
root_tensor: TensorPrimitive<B, D>,
@ -28,7 +30,8 @@ impl Gradients {
);
gradients
}
/// Consume the gradients for a given tensor.
/// Consumes the gradients for a given tensor.
///
/// Each tensor should be consumed exactly once if its gradients are only required during the
/// backward pass; otherwise, it may be consumed multiple times.
@ -48,7 +51,7 @@ impl Gradients {
}
}
/// Remove a grad tensor from the container.
/// Removes a grad tensor from the container.
pub fn remove<B: Backend, const D: usize>(
&mut self,
tensor: &ADTensor<B, D>,
@ -58,6 +61,7 @@ impl Gradients {
.map(|tensor| tensor.into_primitive())
}
/// Gets a grad tensor from the container.
pub fn get<B: Backend, const D: usize>(
&self,
tensor: &ADTensor<B, D>,
@ -67,6 +71,7 @@ impl Gradients {
.map(|tensor| tensor.into_primitive())
}
/// Registers a grad tensor in the container.
pub fn register<B: Backend, const D: usize>(
&mut self,
node: NodeRef,

View File

@ -1,3 +1,12 @@
#![warn(missing_docs)]
//! # Burn Autodiff
//!
//! This autodiff library is a part of the Burn project. It is a standalone crate
//! that can be used to perform automatic differentiation on tensors. It is
//! designed to be used with the Burn Tensor crate, but it can be used with any
//! tensor library that implements the `Backend` trait.
#[macro_use]
extern crate derive_new;
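
To ground the crate-level docs above, here is a hedged usage sketch of a backward pass through the decorated backend. It is not part of this commit: the `burn-ndarray` inner backend, `require_grad`, and the exact tensor calls are assumptions about the surrounding API.

use burn_autodiff::ADBackendDecorator;
use burn_ndarray::NdArrayBackend;
use burn_tensor::{Distribution, Tensor};

type B = ADBackendDecorator<NdArrayBackend<f32>>;

fn main() {
    // Track a random tensor for gradients.
    let x = Tensor::<B, 2>::random([2, 3], Distribution::Default).require_grad();
    let y = x.clone().mul(x.clone()).sum();

    // Run the backward pass and read the gradient of `y` with respect to `x`.
    let grads = y.backward();
    let grad_x = x.grad(&grads).unwrap();
    println!("{grad_x}");
}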

View File

@ -1,3 +1,5 @@
#![allow(missing_docs)]
mod add;
mod aggregation;
mod avgpool1d;

View File

@ -1,3 +1,6 @@
The `burn-common` package hosts code that _must_ be shared between burn packages (with `std` or `no_std` enabled). No other code should be placed in this package unless unavoidable.
# Burn Common
The `burn-common` package hosts code that _must_ be shared between burn packages (with `std` or
`no_std` enabled). No other code should be placed in this package unless unavoidable.
The package must build with `cargo build --no-default-features` as well.

View File

@ -4,9 +4,11 @@ use crate::rand::{get_seeded_rng, Rng, SEED};
use uuid::{Builder, Bytes};
/// Simple ID generator.
pub struct IdGenerator {}
impl IdGenerator {
/// Generates a new ID in the form of a UUID.
pub fn generate() -> String {
let mut seed = SEED.lock().unwrap();
let mut rng = if let Some(rng_seeded) = seed.as_ref() {
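
A small usage sketch of the generator documented above (not part of the diff):

use burn_common::id::IdGenerator;

fn main() {
    // Each call yields a fresh UUID-based identifier.
    let first = IdGenerator::generate();
    let second = IdGenerator::generate();
    assert_ne!(first, second);
    println!("{first} {second}");
}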

View File

@ -1,7 +1,18 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
//! # Burn Common Library
//!
//! This library contains common types used by other Burn crates that must be shared.
/// Id module contains types for unique identifiers.
pub mod id;
/// Rand module contains types for random number generation for non-std environments and for
/// std environments.
pub mod rand;
/// Stub module contains stub types for both non-std and std environments.
pub mod stub;
extern crate alloc;

View File

@ -9,12 +9,14 @@ use crate::stub::Mutex;
#[cfg(not(feature = "std"))]
use const_random::const_random;
/// Returns a seeded random number generator using entropy.
#[cfg(feature = "std")]
#[inline(always)]
pub fn get_seeded_rng() -> StdRng {
StdRng::from_entropy()
}
/// Returns a seeded random number generator using a pre-generated seed.
#[cfg(not(feature = "std"))]
#[inline(always)]
pub fn get_seeded_rng() -> StdRng {

View File

@ -2,13 +2,19 @@ use spin::{
Mutex as MutexImported, MutexGuard, RwLock as RwLockImported, RwLockReadGuard, RwLockWriteGuard,
};
// Mutex wrapper to make spin::Mutex API compatible with std::sync::Mutex to swap
/// A mutual exclusion primitive useful for protecting shared data.
///
/// This mutex will block threads waiting for the lock to become available. The
/// mutex can also be statically initialized or created via [Mutex::new].
///
/// [Mutex] wrapper to make `spin::Mutex` API compatible with `std::sync::Mutex` to swap
#[derive(Debug)]
pub struct Mutex<T> {
inner: MutexImported<T>,
}
impl<T> Mutex<T> {
/// Creates a new mutex in an unlocked state ready for use.
#[inline(always)]
pub const fn new(value: T) -> Self {
Self {
@ -16,19 +22,24 @@ impl<T> Mutex<T> {
}
}
/// Locks the mutex blocking the current thread until it is able to do so.
#[inline(always)]
pub fn lock(&self) -> Result<MutexGuard<T>, alloc::string::String> {
Ok(self.inner.lock())
}
}
// Mutex wrapper to make spin::Mutex API compatible with std::sync::Mutex to swap
/// A reader-writer lock which is exclusively locked for writing or shared for reading.
/// This reader-writer lock will block threads waiting for the lock to become available.
/// The lock can also be statically initialized or created via [RwLock::new].
/// [RwLock] wrapper to make `spin::RwLock` API compatible with `std::sync::RwLock` to swap
#[derive(Debug)]
pub struct RwLock<T> {
inner: RwLockImported<T>,
}
impl<T> RwLock<T> {
/// Creates a new reader-writer lock in an unlocked state ready for use.
#[inline(always)]
pub const fn new(value: T) -> Self {
Self {
@ -36,17 +47,23 @@ impl<T> RwLock<T> {
}
}
/// Locks this rwlock with shared read access, blocking the current thread
/// until it can be acquired.
#[inline(always)]
pub fn read(&self) -> Result<RwLockReadGuard<T>, alloc::string::String> {
Ok(self.inner.read())
}
/// Locks this rwlock with exclusive write access, blocking the current thread
/// until it can be acquired.
#[inline(always)]
pub fn write(&self) -> Result<RwLockWriteGuard<T>, alloc::string::String> {
Ok(self.inner.write())
}
}
// ThreadId stub when no std is available
/// A unique identifier for a running thread.
///
/// This type is a stub used when `std` is not available, swapping in for `std::thread::ThreadId`.
#[derive(Eq, PartialEq, Clone, Copy, Hash, Debug)]
pub struct ThreadId(core::num::NonZeroU64);
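
To show why the wrappers keep the `std`-style API, here is a hedged sketch that calls the stub `Mutex` exactly as one would call `std::sync::Mutex`; the static counter is an invented example and assumes the `stub` module is available in your build configuration.

use burn_common::stub::Mutex;

// The `const fn new` documented above allows static initialization.
static COUNTER: Mutex<u32> = Mutex::new(0);

fn increment() -> u32 {
    // `lock` mirrors `std::sync::Mutex::lock`, returning a `Result` even
    // though the underlying spin lock cannot fail.
    let mut guard = COUNTER.lock().unwrap();
    *guard += 1;
    *guard
}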

View File

@ -1,9 +1,13 @@
use alloc::{format, string::String, string::ToString};
pub use burn_derive::Config;
/// Configuration IO error.
#[derive(Debug)]
pub enum ConfigError {
/// Invalid format.
InvalidFormat(String),
/// File not found.
FileNotFound(String),
}
@ -28,12 +32,31 @@ impl core::fmt::Display for ConfigError {
#[cfg(feature = "std")]
impl std::error::Error for ConfigError {}
/// Configuration trait.
pub trait Config: serde::Serialize + serde::de::DeserializeOwned {
/// Saves the configuration to a file.
///
/// # Arguments
///
/// * `file` - File to save the configuration to.
///
/// # Returns
///
/// The output of the save operation.
#[cfg(feature = "std")]
fn save(&self, file: &str) -> std::io::Result<()> {
std::fs::write(file, config_to_json(self))
}
/// Loads the configuration from a file.
///
/// # Arguments
///
/// * `file` - File to load the configuration from.
///
/// # Returns
///
/// The loaded configuration.
#[cfg(feature = "std")]
fn load(file: &str) -> Result<Self, ConfigError> {
let content = std::fs::read_to_string(file)
@ -41,6 +64,15 @@ pub trait Config: serde::Serialize + serde::de::DeserializeOwned {
config_from_str(&content)
}
/// Loads the configuration from a binary buffer.
///
/// # Arguments
///
/// * `data` - Binary buffer to load the configuration from.
///
/// # Returns
///
/// The loaded configuration.
fn load_binary(data: &[u8]) -> Result<Self, ConfigError> {
let content = core::str::from_utf8(data).map_err(|_| {
ConfigError::InvalidFormat("Could not parse data as utf-8.".to_string())
@ -49,6 +81,15 @@ pub trait Config: serde::Serialize + serde::de::DeserializeOwned {
}
}
/// Converts a configuration to a JSON string.
///
/// # Arguments
///
/// * `config` - Configuration to convert.
///
/// # Returns
///
/// The JSON string.
pub fn config_to_json<C: Config>(config: &C) -> String {
serde_json::to_string_pretty(config).unwrap()
}
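
A hedged end-to-end sketch of the trait documented above: deriving `Config`, saving it, and loading it back. The struct, its fields, and the file name are invented for illustration.

use burn::config::Config;

#[derive(Config, Debug)]
struct TrainingConfig {
    #[config(default = 64)]
    batch_size: usize,
    learning_rate: f64,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // `new` is generated by the derive and takes the fields without defaults.
    let config = TrainingConfig::new(1e-3);

    // `save` writes the pretty JSON produced by `config_to_json`.
    config.save("training.json")?;

    // `load` reads the file back into the typed configuration.
    let loaded = TrainingConfig::load("training.json")?;
    println!("{loaded:?}");
    Ok(())
}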

View File

@ -1,16 +1,24 @@
pub use crate::data::dataset::{Dataset, DatasetIterator};
use core::iter::Iterator;
/// A progress struct that can be used to track the progress of a data loader.
#[derive(Clone, Debug)]
pub struct Progress {
/// The number of items that have been processed.
pub items_processed: usize,
/// The total number of items that need to be processed.
pub items_total: usize,
}
/// A data loader iterator that can be used to iterate over a data loader.
pub trait DataLoaderIterator<O>: Iterator<Item = O> {
/// Returns the progress of the data loader.
fn progress(&self) -> Progress;
}
/// A data loader that can be used to iterate over a dataset.
pub trait DataLoader<O> {
/// Returns a boxed [iterator](DataLoaderIterator) to iterate over the data loader.
fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a>;
}

View File

@ -5,12 +5,14 @@ use super::{
use burn_dataset::{transform::PartialDataset, Dataset};
use std::sync::Arc;
/// A data loader that can be used to iterate over a dataset in batches.
pub struct BatchDataLoader<I, O> {
strategy: Box<dyn BatchStrategy<I>>,
dataset: Arc<dyn Dataset<I>>,
batcher: Arc<dyn Batcher<I, O>>,
}
/// A data loader iterator that can be used to iterate over a data loader.
struct BatchDataloaderIterator<I, O> {
current_index: usize,
strategy: Box<dyn BatchStrategy<I>>,
@ -19,6 +21,17 @@ struct BatchDataloaderIterator<I, O> {
}
impl<I, O> BatchDataLoader<I, O> {
/// Creates a new batch data loader.
///
/// # Arguments
///
/// * `strategy` - The batch strategy.
/// * `dataset` - The dataset.
/// * `batcher` - The batcher.
///
/// # Returns
///
/// The batch data loader.
pub fn new(
strategy: Box<dyn BatchStrategy<I>>,
dataset: Arc<dyn Dataset<I>>,
@ -31,11 +44,24 @@ impl<I, O> BatchDataLoader<I, O> {
}
}
}
impl<I, O> BatchDataLoader<I, O>
where
I: Send + Sync + Clone + 'static,
O: Send + Sync + Clone + 'static,
{
/// Creates a new multi-threaded batch data loader.
///
/// # Arguments
///
/// * `strategy` - The batch strategy.
/// * `dataset` - The dataset.
/// * `batcher` - The batcher.
/// * `num_threads` - The number of threads.
///
/// # Returns
///
/// The multi-threaded batch data loader.
pub fn multi_thread(
strategy: Box<dyn BatchStrategy<I>>,
dataset: Arc<dyn Dataset<I>>,
@ -65,6 +91,17 @@ impl<I, O> DataLoader<O> for BatchDataLoader<I, O> {
}
impl<I, O> BatchDataloaderIterator<I, O> {
/// Creates a new batch data loader iterator.
///
/// # Arguments
///
/// * `strategy` - The batch strategy.
/// * `dataset` - The dataset.
/// * `batcher` - The batcher.
///
/// # Returns
///
/// The batch data loader iterator.
pub fn new(
strategy: Box<dyn BatchStrategy<I>>,
dataset: Arc<dyn Dataset<I>>,

View File

@ -1,4 +1,14 @@
/// A trait for batching items of type `I` into items of type `O`.
pub trait Batcher<I, O>: Send + Sync {
/// Batches the given items.
///
/// # Arguments
///
/// * `items` - The items to batch.
///
/// # Returns
///
/// The batched items.
fn batch(&self, items: Vec<I>) -> O;
}
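
As an illustration of the trait above, a minimal hypothetical implementation that batches numbers by summing them; a real batcher would typically stack items into a tensor.

use burn::data::dataloader::batcher::Batcher;

struct SumBatcher;

impl Batcher<i32, i32> for SumBatcher {
    fn batch(&self, items: Vec<i32>) -> i32 {
        // Collapse the whole mini-batch into a single value to keep the
        // sketch self-contained.
        items.into_iter().sum()
    }
}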

View File

@ -2,6 +2,7 @@ use super::{batcher::Batcher, BatchDataLoader, BatchStrategy, DataLoader, FixBat
use burn_dataset::{transform::ShuffledDataset, Dataset};
use std::sync::Arc;
/// A builder for data loaders.
pub struct DataLoaderBuilder<I, O> {
strategy: Option<Box<dyn BatchStrategy<I>>>,
batcher: Arc<dyn Batcher<I, O>>,
@ -14,6 +15,15 @@ where
I: Send + Sync + Clone + std::fmt::Debug + 'static,
O: Send + Sync + Clone + std::fmt::Debug + 'static,
{
/// Creates a new data loader builder.
///
/// # Arguments
///
/// * `batcher` - The batcher.
///
/// # Returns
///
/// The data loader builder.
pub fn new<B>(batcher: B) -> Self
where
B: Batcher<I, O> + 'static,
@ -26,21 +36,58 @@ where
}
}
/// Sets the batch size to a fixed number. The [fix batch strategy](FixBatchStrategy)
/// will be used.
///
/// # Arguments
///
/// * `batch_size` - The batch size.
///
/// # Returns
///
/// The data loader builder.
pub fn batch_size(mut self, batch_size: usize) -> Self {
self.strategy = Some(Box::new(FixBatchStrategy::new(batch_size)));
self
}
/// Sets the seed for shuffling.
///
/// # Arguments
///
/// * `seed` - The seed.
///
/// # Returns
///
/// The data loader builder.
pub fn shuffle(mut self, seed: u64) -> Self {
self.shuffle = Some(seed);
self
}
/// Sets the number of workers.
///
/// # Arguments
///
/// * `num_workers` - The number of workers.
///
/// # Returns
///
/// The data loader builder.
pub fn num_workers(mut self, num_workers: usize) -> Self {
self.num_threads = Some(num_workers);
self
}
/// Builds the data loader.
///
/// # Arguments
///
/// * `dataset` - The dataset.
///
/// # Returns
///
/// The data loader.
pub fn build<D>(self, dataset: D) -> Arc<dyn DataLoader<O>>
where
D: Dataset<I> + 'static,
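
A hedged sketch wiring the builder together with an in-memory dataset and a trivial batcher; everything except the builder methods shown in the diff is invented for illustration.

use burn::data::dataloader::{batcher::Batcher, DataLoaderBuilder};
use burn::data::dataset::InMemDataset;

#[derive(Clone)]
struct IdentityBatcher;

impl Batcher<i32, Vec<i32>> for IdentityBatcher {
    fn batch(&self, items: Vec<i32>) -> Vec<i32> {
        items
    }
}

fn main() {
    let dataset = InMemDataset::new((0..100).collect::<Vec<i32>>());

    let dataloader = DataLoaderBuilder::new(IdentityBatcher)
        .batch_size(16) // fixed batch strategy
        .shuffle(42)    // seeded shuffling
        .num_workers(2) // multi-threaded loading
        .build(dataset);

    for batch in dataloader.iter() {
        println!("batch of {} items", batch.len());
    }
}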

View File

@ -4,6 +4,7 @@ mod builder;
mod multithread;
mod strategy;
/// Module for batching items.
pub mod batcher;
pub use base::*;

View File

@ -3,15 +3,20 @@ use std::collections::HashMap;
use std::sync::{mpsc, Arc};
use std::thread;
static MAX_QUEUED_ITEMS: usize = 100;
const MAX_QUEUED_ITEMS: usize = 100;
/// A multi-threaded data loader that can be used to iterate over a dataset.
pub struct MultiThreadDataLoader<O> {
dataloaders: Vec<Arc<dyn DataLoader<O> + Send + Sync>>,
}
/// A message that can be sent between threads.
#[derive(Debug)]
pub enum Message<O> {
/// A batch of items.
Batch(usize, O, Progress),
/// The thread is done.
Done,
}
@ -23,6 +28,15 @@ struct MultiThreadsDataloaderIterator<O> {
}
impl<O> MultiThreadDataLoader<O> {
/// Creates a new multi-threaded data loader.
///
/// # Arguments
///
/// * `dataloaders` - The data loaders.
///
/// # Returns
///
/// The multi-threaded data loader.
pub fn new(dataloaders: Vec<Arc<dyn DataLoader<O> + Send + Sync>>) -> Self {
Self { dataloaders }
}

View File

@ -1,15 +1,47 @@
/// A strategy to batch items.
pub trait BatchStrategy<I>: Send + Sync {
/// Adds an item to the strategy.
///
/// # Arguments
///
/// * `item` - The item to add.
fn add(&mut self, item: I);
/// Batches the items.
///
/// # Arguments
///
/// * `force` - Whether to force batching.
///
/// # Returns
///
/// The batched items.
fn batch(&mut self, force: bool) -> Option<Vec<I>>;
/// Creates a new strategy of the same type.
///
/// # Returns
///
/// The new strategy.
fn new_like(&self) -> Box<dyn BatchStrategy<I>>;
}
/// A strategy to batch items with a fixed batch size.
pub struct FixBatchStrategy<I> {
items: Vec<I>,
batch_size: usize,
}
impl<I> FixBatchStrategy<I> {
/// Creates a new strategy to batch items with a fixed batch size.
///
/// # Arguments
///
/// * `batch_size` - The batch size.
///
/// # Returns
///
/// The strategy.
pub fn new(batch_size: usize) -> Self {
FixBatchStrategy {
items: Vec::with_capacity(batch_size),
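
To make the strategy contract concrete, a small hedged sketch of the fixed-size strategy (assuming both types are re-exported from the dataloader module):

use burn::data::dataloader::{BatchStrategy, FixBatchStrategy};

fn main() {
    let mut strategy = FixBatchStrategy::new(3);
    strategy.add(1);
    strategy.add(2);

    // Below the batch size and not forced: no batch is produced yet.
    assert!(strategy.batch(false).is_none());

    strategy.add(3);
    assert_eq!(strategy.batch(false), Some(vec![1, 2, 3]));

    // With `force`, a partial batch is returned even below the batch size.
    strategy.add(4);
    assert_eq!(strategy.batch(true), Some(vec![4]));
}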

View File

@ -1,4 +1,7 @@
/// Dataloader module.
pub mod dataloader;
/// Dataset module.
pub mod dataset {
pub use burn_dataset::*;
}

View File

@ -3,13 +3,22 @@ use crate as burn;
use crate::{config::Config, tensor::Tensor};
use burn_tensor::{backend::Backend, ElementConversion};
/// Gradient clipping provides a way to mitigate exploding gradients.
#[derive(Config)]
pub enum GradientClippingConfig {
/// Clip the gradient by value.
Value(f32),
/// Clip the gradient by norm.
Norm(f32),
}
impl GradientClippingConfig {
/// Initialize the gradient clipping.
///
/// # Returns
///
/// The gradient clipping.
pub fn init(&self) -> GradientClipping {
match self {
GradientClippingConfig::Value(val) => GradientClipping::Value(*val),
@ -22,11 +31,23 @@ impl GradientClippingConfig {
/// by clipping every component of the gradient by value or by norm during
/// backpropagation.
pub enum GradientClipping {
/// Clip the gradient by value.
Value(f32),
/// Clip the gradient by norm.
Norm(f32),
}
impl GradientClipping {
/// Clip the gradient.
///
/// # Arguments
///
/// * `grad` - The gradient to clip.
///
/// # Returns
///
/// The clipped gradient.
pub fn clip_gradient<B: Backend, const D: usize>(&self, grad: Tensor<B, D>) -> Tensor<B, D> {
match self {
GradientClipping::Value(threshold) => self.clip_by_value(grad, *threshold),
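
A hedged sketch of the clipping API documented above; the backend generic, the threshold, and the function wrapper are assumptions for illustration.

use burn::grad_clipping::GradientClippingConfig;
use burn::tensor::{backend::Backend, Tensor};

fn clip_by_norm<B: Backend, const D: usize>(grad: Tensor<B, D>) -> Tensor<B, D> {
    // Build the clipper from its config, then clip a single gradient tensor.
    let clipper = GradientClippingConfig::Norm(1.0).init();
    clipper.clip_gradient(grad)
}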

View File

@ -1,23 +1,39 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
//! The core crate of Burn.
#[macro_use]
extern crate derive_new;
/// The configuration module.
pub mod config;
/// Data module.
#[cfg(feature = "std")]
pub mod data;
/// Optimizer module.
#[cfg(feature = "std")]
pub mod optim;
/// Learning rate scheduler module.
#[cfg(feature = "std")]
pub mod lr_scheduler;
/// Gradient clipping module.
pub mod grad_clipping;
/// Module for the neural network module.
pub mod module;
/// Neural network module.
pub mod nn;
/// Module for the recorder.
pub mod record;
/// Module for the tensor.
pub mod tensor;
extern crate alloc;

View File

@ -1,4 +1,7 @@
/// Constant learning rate scheduler.
pub mod constant;
/// Noam learning rate scheduler.
pub mod noam;
mod base;

View File

@ -178,16 +178,21 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
fn into_record(self) -> Self::Record;
}
/// Module visitor trait.
pub trait ModuleVisitor<B: Backend> {
/// Visit a tensor in the module.
fn visit<const D: usize>(&mut self, id: &ParamId, tensor: &Tensor<B, D>);
}
/// Module mapper trait.
pub trait ModuleMapper<B: Backend> {
/// Map a tensor in the module.
fn map<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D>;
}
/// Module with auto-differentiation backend.
pub trait ADModule<B: ADBackend>: Module<B> + Send + Sync + core::fmt::Debug {
/// Inner module without auto-differentiation.
type InnerModule: Module<B::InnerBackend>;
/// Get the same module, but on the inner backend without auto-differentiation.
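
A hedged sketch of the visitor trait above: a visitor that counts parameter tensors. The module value itself is assumed to exist elsewhere.

use burn::module::{ModuleVisitor, ParamId};
use burn::tensor::{backend::Backend, Tensor};

#[derive(Default)]
struct ParamCounter {
    count: usize,
}

impl<B: Backend> ModuleVisitor<B> for ParamCounter {
    fn visit<const D: usize>(&mut self, _id: &ParamId, _tensor: &Tensor<B, D>) {
        self.count += 1;
    }
}

// Usage, assuming `model` implements `Module<B>`:
//     let mut counter = ParamCounter::default();
//     model.visit(&mut counter);
//     println!("{} parameter tensors", counter.count);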

View File

@ -15,6 +15,11 @@ impl<T> core::fmt::Display for Param<T> {
}
impl<T: Clone> Param<T> {
/// Gets the parameter value.
///
/// # Returns
///
/// The parameter value.
pub fn val(&self) -> T {
self.value.clone()
}

View File

@ -36,7 +36,7 @@ impl Record for ConstantRecord {
item
}
}
/// Constant macro.
#[macro_export]
macro_rules! constant {
(module) => {

View File

@ -1,6 +1,7 @@
use alloc::string::{String, ToString};
use burn_common::id::IdGenerator;
/// Parameter ID.
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
pub struct ParamId {
value: String,
@ -27,11 +28,14 @@ impl Default for ParamId {
}
impl ParamId {
/// Create a new parameter ID.
pub fn new() -> Self {
Self {
value: IdGenerator::generate(),
}
}
/// Convert the parameter ID into a string.
pub fn into_string(self) -> String {
self.value
}

View File

@ -22,8 +22,12 @@ pub fn generate_autoregressive_mask<B: Backend>(
mask.equal_elem(1_i64.elem::<i64>())
}
/// Generate a padding attention mask.
pub struct GeneratePaddingMask<B: Backend> {
/// The generated tensor.
pub tensor: Tensor<B, 2, Int>,
/// The generated mask.
pub mask: Tensor<B, 2, Bool>,
}

View File

@ -6,11 +6,17 @@ pub(crate) enum CacheState<T> {
Empty,
}
/// A cache for a tensor.
pub struct TensorCache<B: Backend, const D: usize> {
pub(crate) state: CacheState<Tensor<B, D>>,
}
impl<B: Backend, const D: usize> TensorCache<B, D> {
/// Creates a new empty cache.
///
/// # Returns
///
/// The empty cache.
pub fn empty() -> Self {
Self {
state: CacheState::Empty,

View File

@ -11,27 +11,60 @@ use crate as burn;
#[derive(Config, Debug, PartialEq)]
pub enum Initializer {
/// Fills tensor with specified value everywhere
Constant { value: f64 },
Constant {
/// The value to fill the tensor with
value: f64,
},
/// Fills tensor with 1s everywhere
Ones,
/// Fills tensor with 0s everywhere
Zeros,
/// Fills tensor with values drawn uniformly between specified values
Uniform { min: f64, max: f64 },
Uniform {
/// The minimum value to draw from
min: f64,
/// The maximum value to draw from
max: f64,
},
/// Fills tensor with values drawn from normal distribution with specified mean and std
Normal { mean: f64, std: f64 },
Normal {
/// The mean of the normal distribution
mean: f64,
/// The standard deviation of the normal distribution
std: f64,
},
/// Fills tensor with values according to the uniform version of Kaiming initialization
KaimingUniform { gain: f64, fan_out_only: bool },
KaimingUniform {
/// The gain to use in initialization formula
gain: f64,
/// Whether to use fan out only in initialization formula
fan_out_only: bool,
},
/// Fills tensor with values according to the uniform version of Kaiming initialization
KaimingNormal { gain: f64, fan_out_only: bool },
KaimingNormal {
/// The gain to use in initialization formula
gain: f64,
/// Whether to use fan out only in initialization formula
fan_out_only: bool,
},
/// Fills tensor with values according to the uniform version of Xavier Glorot initialization
/// described in [Understanding the difficulty of training deep feedforward neural networks
/// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
XavierUniform { gain: f64 },
XavierUniform {
/// The gain to use in initialization formula
gain: f64,
},
/// Fills tensor with values according to the normal version of Xavier Glorot initialization
/// described in [Understanding the difficulty of training deep feedforward neural networks
/// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
XavierNormal { gain: f64 },
XavierNormal {
/// The gain to use in initialization formula
gain: f64,
},
}
impl Initializer {

View File

@ -42,6 +42,7 @@ impl<B: Backend> MSELoss<B> {
}
}
/// Compute the criterion on the input tensor without reducing.
pub fn forward_no_reduction<const D: usize>(
&self,
logits: Tensor<B, D>,

View File

@ -1,5 +1,11 @@
/// The reduction type for the loss.
pub enum Reduction {
/// The mean of the losses will be returned.
Mean,
/// The sum of the losses will be returned.
Sum,
/// The reduction is chosen automatically; the mean of the losses will be returned.
Auto,
}

View File

@ -1,8 +1,19 @@
/// Attention module
pub mod attention;
/// Cache module
pub mod cache;
/// Convolution module
pub mod conv;
/// Loss module
pub mod loss;
/// Pooling module
pub mod pool;
/// Transformer module
pub mod transformer;
mod dropout;

View File

@ -11,6 +11,7 @@ use burn_tensor::activation;
use super::gate_controller::GateController;
/// The configuration for a [gru](Gru) module.
#[derive(Config)]
pub struct GruConfig {
/// The size of the input features.

View File

@ -11,6 +11,7 @@ use burn_tensor::activation;
use super::gate_controller::GateController;
/// The configuration for a [lstm](Lstm) module.
#[derive(Config)]
pub struct LstmConfig {
/// The size of the input features.

View File

@ -1,5 +1,9 @@
mod gate_controller;
/// Gated Recurrent Unit module.
pub mod gru;
/// Long Short-Term Memory module.
pub mod lstm;
pub use gate_controller::*;

View File

@ -47,6 +47,7 @@ pub struct TransformerDecoder<B: Backend> {
}
impl TransformerDecoderConfig {
/// Initialize a new [Transformer Decoder](TransformerDecoder) module.
pub fn init<B: Backend>(&self) -> TransformerDecoder<B> {
let layers = (0..self.n_layers)
.map(|_| TransformerDecoderLayer::new(self))
@ -54,6 +55,12 @@ impl TransformerDecoderConfig {
TransformerDecoder { layers }
}
/// Initialize a new [Transformer Decoder](TransformerDecoder) module with a record.
///
/// # Arguments
///
/// * `record` - The record to initialize the module with.
pub fn init_with<B: Backend>(
&self,
record: TransformerDecoderRecord<B>,
@ -117,6 +124,7 @@ impl<B: Backend> TransformerDecoderInput<B> {
}
}
/// [Transformer Decoder](TransformerDecoder) layer module.
#[derive(Module, Debug)]
pub struct TransformerDecoderLayer<B: Backend> {
cross_attn: MultiHeadAttention<B>,

View File

@ -151,6 +151,7 @@ impl<B: Backend> TransformerEncoder<B> {
}
}
/// Transformer encoder layer module.
#[derive(Module, Debug)]
pub struct TransformerEncoderLayer<B: Backend> {
mha: MultiHeadAttention<B>,

View File

@ -12,6 +12,7 @@ use crate::optim::adaptor::OptimizerAdaptor;
use crate::tensor::{backend::ADBackend, Tensor};
use burn_tensor::{backend::Backend, ElementConversion};
/// Adam configuration.
#[derive(Config)]
pub struct AdamConfig {
/// Parameter for Adam.
@ -35,6 +36,7 @@ pub struct Adam<B: Backend> {
weight_decay: Option<WeightDecay<B>>,
}
/// Adam state.
#[derive(Record, Clone, new)]
pub struct AdamState<B: Backend, const D: usize> {
weight_decay: Option<WeightDecayState<B, D>>,
@ -84,6 +86,11 @@ impl<B: Backend> SimpleOptimizer<B> for Adam<B> {
}
impl AdamConfig {
/// Initialize Adam optimizer.
///
/// # Returns
///
/// Returns an optimizer that can be used to optimize a module.
pub fn init<B: ADBackend, M: ADModule<B>>(&self) -> impl Optimizer<M, B> {
let optim = Adam {
momentum: AdaptiveMomentum {
@ -102,6 +109,7 @@ impl AdamConfig {
}
}
/// Adaptive momentum state.
#[derive(Record, new, Clone)]
pub struct AdaptiveMomentumState<B: Backend, const D: usize> {
time: usize,
@ -164,6 +172,15 @@ impl AdaptiveMomentum {
}
impl<B: Backend, const D: usize> AdaptiveMomentumState<B, D> {
/// Move state to device.
///
/// # Arguments
///
/// * `device` - Device to move state to.
///
/// # Returns
///
/// Returns state moved to device.
pub fn to_device(mut self, device: &B::Device) -> Self {
self.moment_1 = self.moment_1.to_device(device);
self.moment_2 = self.moment_2.to_device(device);
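
A hedged sketch of driving the optimizer from its config; the `step` signature with an explicit learning rate and the all-defaulted `AdamConfig::new()` reflect the burn API of this period and should be treated as assumptions.

use burn::module::ADModule;
use burn::optim::{AdamConfig, GradientsParams, Optimizer};
use burn::tensor::backend::ADBackend;

fn train_step<B: ADBackend, M: ADModule<B>>(model: M, grads: GradientsParams) -> M {
    // Build Adam from its configuration.
    let mut optim = AdamConfig::new().init();

    // Apply one optimization step with a fixed learning rate.
    optim.step(1e-3, model, grads)
}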

View File

@ -13,6 +13,7 @@ pub struct WeightDecayConfig {
pub penalty: f64,
}
/// State of [WeightDecay](WeightDecay).
#[derive(Record, Clone, new)]
pub struct WeightDecayState<B: Backend, const D: usize> {
grad_last_step: Tensor<B, D>,
@ -24,12 +25,24 @@ pub struct WeightDecay<B: Backend> {
}
impl<B: Backend> WeightDecay<B> {
/// Creates a new [WeightDecay](WeightDecay) from a [WeightDecayConfig](WeightDecayConfig).
pub fn new(config: &WeightDecayConfig) -> Self {
Self {
penalty: config.penalty.elem(),
}
}
/// Transforms a gradient.
///
/// # Arguments
///
/// * `grad` - Gradient to transform.
/// * `state` - State of the optimizer.
///
/// # Returns
///
/// * `grad` - Transformed gradient.
/// * `state` - State of the optimizer.
pub fn transform<const D: usize>(
&self,
grad: Tensor<B, D>,
@ -47,6 +60,15 @@ impl<B: Backend> WeightDecay<B> {
}
impl<B: Backend, const D: usize> WeightDecayState<B, D> {
/// Moves the state to a device.
///
/// # Arguments
///
/// * `device` - Device to move the state to.
///
/// # Returns
///
/// * `self` - Moved state.
pub fn to_device(mut self, device: &B::Device) -> Self {
self.grad_last_step = self.grad_last_step.to_device(device);
self

View File

@ -15,6 +15,7 @@ pub struct GradientsParams {
}
impl GradientsParams {
/// Creates a new [GradientsParams](GradientsParams).
pub fn new() -> Self {
Self::default()
}

View File

@ -1,4 +1,7 @@
/// Weight decay module for optimizers.
pub mod decay;
/// Momentum module for optimizers.
pub mod momentum;
mod adam;

View File

@ -14,11 +14,13 @@ pub struct MomentumConfig {
/// Dampening factor.
#[config(default = 0.1)]
pub dampening: f64,
/// Enables Nesterov momentum, see [On the importance of initialization and momentum in deep learning](http://www.cs.toronto.edu/~hinton/absps/momentum.pdf).
/// Enables Nesterov momentum, see [On the importance of initialization and
/// momentum in deep learning](http://www.cs.toronto.edu/~hinton/absps/momentum.pdf).
#[config(default = false)]
pub nesterov: bool,
}
/// State of [Momentum](Momentum).
#[derive(Record, Clone, new)]
pub struct MomemtumState<B: Backend, const D: usize> {
velocity: Tensor<B, D>,
@ -32,6 +34,7 @@ pub struct Momentum<B: Backend> {
}
impl<B: Backend> Momentum<B> {
/// Creates a new [Momentum](Momentum) from a [MomentumConfig](MomentumConfig).
pub fn new(config: &MomentumConfig) -> Self {
Self {
momentum: config.momentum.elem(),
@ -40,6 +43,17 @@ impl<B: Backend> Momentum<B> {
}
}
/// Transforms a gradient.
///
/// # Arguments
///
/// * `grad` - Gradient to transform.
/// * `state` - State of the optimizer.
///
/// # Returns
///
/// * `grad` - Transformed gradient.
/// * `state` - State of the optimizer.
pub fn transform<const D: usize>(
&self,
grad: Tensor<B, D>,
@ -63,6 +77,15 @@ impl<B: Backend> Momentum<B> {
}
impl<B: Backend, const D: usize> MomemtumState<B, D> {
/// Moves the state to a device.
///
/// # Arguments
///
/// * `device` - Device to move the state to.
///
/// # Returns
///
/// * `self` - Moved state.
pub fn to_device(mut self, device: &B::Device) -> Self {
self.velocity = self.velocity.to_device(device);
self

View File

@ -30,6 +30,7 @@ pub struct Sgd<B: Backend> {
weight_decay: Option<WeightDecay<B>>,
}
/// State of [Sgd](Sgd).
#[derive(Record, Clone, new)]
pub struct SgdState<B: Backend, const D: usize> {
weight_decay: Option<WeightDecayState<B, D>>,
@ -37,6 +38,7 @@ pub struct SgdState<B: Backend, const D: usize> {
}
impl SgdConfig {
/// Initializes the [Sgd](Sgd) optimizer from this config.
pub fn init<B: ADBackend, M: ADModule<B>>(
&self,
) -> OptimizerAdaptor<Sgd<B::InnerBackend>, M, B> {

View File

@ -45,13 +45,22 @@ where
M: ADModule<B>,
B: ADBackend,
{
/// Sets the gradient clipping.
///
/// # Arguments
///
/// * `gradient_clipping` - The gradient clipping.
///
/// # Returns
///
/// The optimizer.
pub fn with_grad_clipping(mut self, gradient_clipping: GradientClipping) -> Self {
self.grad_clipping = Some(gradient_clipping);
self
}
#[cfg(test)]
pub fn has_gradient_clipping(&self) -> bool {
pub(crate) fn has_gradient_clipping(&self) -> bool {
self.grad_clipping.is_some()
}
}

View File

@ -1,5 +1,8 @@
mod base;
pub use base::*;
/// Adaptor module for optimizers.
pub mod adaptor;
/// Record module for optimizers.
pub mod record;

View File

@ -10,12 +10,15 @@ use serde::{Deserialize, Serialize};
///
/// Records are versioned for backward compatibility, so old records can be loaded.
pub enum AdaptorRecord<O: SimpleOptimizer<B>, B: Backend> {
/// Version 1.
V1(AdaptorRecordV1<O, B>),
}
/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.
#[derive(Serialize, Deserialize)]
#[serde(bound = "")]
pub enum AdaptorRecordItem<O: SimpleOptimizer<B>, B: Backend, S: PrecisionSettings> {
/// Version 1.
V1(AdaptorRecordItemV1<O, B, S>),
}
@ -56,12 +59,26 @@ where
O: SimpleOptimizer<B>,
B: Backend,
{
/// Converts the record into the optimizer state.
///
/// # Returns
///
/// The optimizer state.
pub fn into_state<const D: usize>(self) -> O::State<D> {
match self {
AdaptorRecord::V1(record) => record.into_state(),
}
}
/// Converts the optimizer state into the record.
///
/// # Arguments
///
/// * `state`: The optimizer state.
///
/// # Returns
///
/// The record.
pub fn from_state<const D: usize>(state: O::State<D>) -> Self {
Self::V1(AdaptorRecordV1::from_state(state))
}

View File

@ -6,14 +6,30 @@ use burn_tensor::backend::Backend;
use core::any::Any;
use serde::{Deserialize, Serialize};
/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record, version 1.
pub enum AdaptorRecordV1<O: SimpleOptimizer<B>, B: Backend> {
/// Rank 1.
Rank1(O::State<1>),
/// Rank 2.
Rank2(O::State<2>),
/// Rank 3.
Rank3(O::State<3>),
/// Rank 4.
Rank4(O::State<4>),
/// Rank 5.
Rank5(O::State<5>),
/// Rank 6.
Rank6(O::State<6>),
/// Rank 7.
Rank7(O::State<7>),
/// Rank 8.
Rank8(O::State<8>),
}
@ -32,16 +48,32 @@ impl<O: SimpleOptimizer<B>, B: Backend> Clone for AdaptorRecordV1<O, B> {
}
}
/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.
#[derive(Serialize, Deserialize)]
#[serde(bound = "")]
pub enum AdaptorRecordItemV1<O: SimpleOptimizer<B>, B: Backend, S: PrecisionSettings> {
/// Rank 1.
Rank1(<O::State<1> as Record>::Item<S>),
/// Rank 2.
Rank2(<O::State<2> as Record>::Item<S>),
/// Rank 3.
Rank3(<O::State<3> as Record>::Item<S>),
/// Rank 4.
Rank4(<O::State<4> as Record>::Item<S>),
/// Rank 5.
Rank5(<O::State<5> as Record>::Item<S>),
/// Rank 6.
Rank6(<O::State<6> as Record>::Item<S>),
/// Rank 7.
Rank7(<O::State<7> as Record>::Item<S>),
/// Rank 8.
Rank8(<O::State<8> as Record>::Item<S>),
}
@ -50,6 +82,15 @@ where
O: SimpleOptimizer<B>,
B: Backend,
{
/// Convert the record into the state.
///
/// # Returns
///
/// The state.
///
/// # Panics
///
/// Panics if the state dimension is not supported.
pub fn into_state<const D: usize>(self) -> O::State<D> {
let boxed_state: Box<dyn Any> = match self {
AdaptorRecordV1::Rank1(s) => Box::new(s),
@ -66,6 +107,16 @@ where
.expect("Unsupported state dimension, dimension up to 8 are supported.");
*state
}
/// Convert the state into the record.
///
/// # Arguments
///
/// * `state`: The state.
///
/// # Returns
///
/// The record.
pub fn from_state<const D: usize>(state: O::State<D>) -> Self {
let state: Box<dyn Any> = Box::new(state);

View File

@ -5,10 +5,12 @@ use serde::{de::DeserializeOwned, Serialize};
/// Trait to define a family of types which can be recorded using any [settings](PrecisionSettings).
pub trait Record: Send + Sync {
/// Type of the item that can be serialized and deserialized.
type Item<S: PrecisionSettings>: Serialize + DeserializeOwned;
/// Convert the current record into the corresponding item that follows the given [settings](PrecisionSettings).
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S>;
/// Convert the given item into a record.
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self;
}

View File

@ -8,6 +8,7 @@ use std::{fs::File, path::PathBuf};
pub trait FileRecorder:
Recorder<RecordArgs = PathBuf, RecordOutput = (), LoadArgs = PathBuf>
{
/// File extension of the format used by the recorder.
fn file_extension() -> &'static str;
}

View File

@ -134,6 +134,7 @@ primitive!(f64);
primitive!(f32);
// TODO: Remove the feature flag when half supports serde with no_std
// https://github.com/burn-rs/burn/issues/268 tracking issue
#[cfg(feature = "std")]
primitive!(half::bf16);
#[cfg(feature = "std")]

View File

@ -13,15 +13,28 @@ use super::{
/// Record any item implementing [Serialize](Serialize) and [DeserializeOwned](DeserializeOwned).
pub trait Recorder: Send + Sync + core::default::Default + core::fmt::Debug + Clone {
/// Type of the settings used by the recorder.
type Settings: PrecisionSettings;
/// Arguments used to record objects.
type RecordArgs: Clone;
/// Record output type.
type RecordOutput;
/// Arguments used to load recorded objects.
type LoadArgs: Clone;
/// Record an item with the given arguments.
/// Records an item.
///
/// # Arguments
///
/// * `record` - The item to record.
/// * `args` - Arguments used to record the item.
///
/// # Returns
///
/// The output of the recording.
fn record<R: Record>(
&self,
record: R,
@ -80,11 +93,35 @@ pub trait Recorder: Send + Sync + core::default::Default + core::fmt::Debug + Cl
Ok(R::from_item(item.item))
}
/// Saves an item.
///
/// This method is used by [record](Recorder::record) to save the item.
///
/// # Arguments
///
/// * `item` - Item to save.
/// * `args` - Arguments to use to save the item.
///
/// # Returns
///
/// The output of the save operation.
fn save_item<I: Serialize>(
&self,
item: I,
args: Self::RecordArgs,
) -> Result<Self::RecordOutput, RecorderError>;
/// Loads an item.
///
/// This method is used by [load](Recorder::load) to load the item.
///
/// # Arguments
///
/// * `args` - Arguments to use to load the item.
///
/// # Returns
///
/// The loaded item.
fn load_item<I: DeserializeOwned>(&self, args: Self::LoadArgs) -> Result<I, RecorderError>;
}
@ -98,9 +135,13 @@ fn recorder_metadata<R: Recorder>() -> BurnMetadata {
)
}
/// Error that can occur when using a [Recorder](Recorder).
#[derive(Debug)]
pub enum RecorderError {
/// File not found.
FileNotFound(String),
/// Other error.
Unknown(String),
}
@ -118,22 +159,45 @@ pub(crate) fn bin_config() -> bincode::config::Configuration {
bincode::config::standard()
}
/// Metadata of a record.
#[derive(new, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct BurnMetadata {
/// Float type used to record the item.
pub float: String,
/// Int type used to record the item.
pub int: String,
/// Format used to record the item.
pub format: String,
/// Burn record version used to record the item.
pub version: String,
/// Settings used to record the item.
pub settings: String,
}
/// Record that can be saved by a [Recorder](Recorder).
#[derive(Serialize, Deserialize)]
pub struct BurnRecord<I> {
/// Metadata of the record.
pub metadata: BurnMetadata,
/// Item to record.
pub item: I,
}
impl<I> BurnRecord<I> {
/// Creates a new record.
///
/// # Arguments
///
/// * `item` - Item to record.
///
/// # Returns
///
/// The new record.
pub fn new<R: Recorder>(item: I) -> Self {
let metadata = recorder_metadata::<R>();
@ -141,8 +205,10 @@ impl<I> BurnRecord<I> {
}
}
/// Record that can be saved by a [Recorder](Recorder) without the item.
#[derive(new, Debug, Serialize, Deserialize)]
pub struct BurnRecordNoItem {
/// Metadata of the record.
pub metadata: BurnMetadata,
}

View File

@ -5,7 +5,10 @@ use serde::{de::DeserializeOwned, Serialize};
pub trait PrecisionSettings:
Send + Sync + core::fmt::Debug + core::default::Default + Clone
{
/// Float element type.
type FloatElem: Element + Serialize + DeserializeOwned;
/// Integer element type.
type IntElem: Element + Serialize + DeserializeOwned;
}

View File

@ -12,7 +12,9 @@ type MappedDataset = MapperDataset<SqliteDataset<SpeechItemRaw>, ConvertSamples,
/// Enum representing speech command classes in the Speech Commands dataset.
/// Class names are based on the Speech Commands dataset from Huggingface.
/// See: https://huggingface.co/datasets/speech_commands
/// See [speech_commands](https://huggingface.co/datasets/speech_commands)
/// for more information.
#[allow(missing_docs)]
#[derive(Debug, Display, Clone, Copy, FromRepr, Serialize, Deserialize, EnumCount)]
pub enum SpeechCommandClass {
// Target command words
@ -66,8 +68,13 @@ pub enum SpeechCommandClass {
/// Struct containing raw speech data returned from a database.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SpeechItemRaw {
/// Audio file bytes.
pub audio_bytes: Vec<u8>,
/// Label index.
pub label: usize,
/// Indicates if the label is unknown.
pub is_unknown: bool,
}

View File

@ -4,11 +4,18 @@ use crate::DatasetIterator;
/// The dataset trait defines a basic collection of items with a predefined size.
pub trait Dataset<I>: Send + Sync {
/// Gets the item at the given index.
fn get(&self, index: usize) -> Option<I>;
/// Gets the number of items in the dataset.
fn len(&self) -> usize;
/// Checks if the dataset is empty.
fn is_empty(&self) -> bool {
self.len() == 0
}
/// Returns an iterator over the dataset.
fn iter(&self) -> DatasetIterator<'_, I>
where
Self: Sized,
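
A minimal hedged sketch implementing the trait above for a vector-backed dataset; the type is invented for illustration.

use burn_dataset::Dataset;

struct Sentences {
    items: Vec<String>,
}

impl Dataset<String> for Sentences {
    fn get(&self, index: usize) -> Option<String> {
        self.items.get(index).cloned()
    }

    fn len(&self) -> usize {
        self.items.len()
    }
}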

View File

@ -7,6 +7,7 @@ pub struct FakeDataset<I> {
}
impl<I: Dummy<Faker>> FakeDataset<I> {
/// Create a new fake dataset with the given size.
pub fn new(size: usize) -> Self {
let mut items = Vec::with_capacity(size);
for _ in 0..size {

View File

@ -14,6 +14,7 @@ pub struct InMemDataset<I> {
}
impl<I> InMemDataset<I> {
/// Creates a new in memory dataset from the given items.
pub fn new(items: Vec<I>) -> Self {
InMemDataset { items }
}

View File

@ -8,6 +8,7 @@ pub struct DatasetIterator<'a, I> {
}
impl<'a, I> DatasetIterator<'a, I> {
/// Creates a new dataset iterator.
pub fn new<D>(dataset: &'a D) -> Self
where
D: Dataset<I>,

View File

@ -21,28 +21,37 @@ use sanitize_filename::sanitize;
use serde::{de::DeserializeOwned, Serialize};
use serde_rusqlite::{columns_from_statement, from_row_with_columns};
/// Result type for the sqlite dataset.
pub type Result<T> = core::result::Result<T, SqliteDatasetError>;
/// Sqlite dataset error.
#[derive(thiserror::Error, Debug)]
pub enum SqliteDatasetError {
/// IO related error.
#[error("IO error: {0}")]
Io(#[from] io::Error),
/// Sql related error.
#[error("Sql error: {0}")]
Sql(#[from] serde_rusqlite::rusqlite::Error),
/// Serde related error.
#[error("Serde error: {0}")]
Serde(#[from] rmp_serde::encode::Error),
/// The database file already exists error.
#[error("Overwrite flag is set to false and the database file already exists: {0}")]
FileExists(PathBuf),
/// Error when creating the connection pool.
#[error("Failed to create connection pool: {0}")]
ConnectionPool(#[from] r2d2::Error),
/// Error when persisting the temporary database file.
#[error("Could not persist the temporary database file: {0}")]
PersistDbFile(#[from] persist::Error<Writable>),
/// Any other error.
#[error("{0}")]
Other(&'static str),
}

View File

@ -1,11 +1,21 @@
#![warn(missing_docs)]
//! # Burn Dataset
//!
//! Burn Dataset is a library for creating and loading datasets.
#[macro_use]
extern crate derive_new;
extern crate dirs;
/// Sources for datasets.
pub mod source;
/// Transformations to be used with datasets.
pub mod transform;
/// Audio datasets.
#[cfg(feature = "audio")]
pub mod audio;

View File

@ -11,13 +11,18 @@ use thiserror::Error;
const PYTHON: &str = "python3";
const PYTHON_SOURCE: &str = include_str!("importer.py");
/// Error type for [HuggingfaceDatasetLoader](HuggingfaceDatasetLoader).
#[derive(Error, Debug)]
pub enum ImporterError {
/// Unknown error.
#[error("unknown: `{0}`")]
Unknown(String),
/// Failure to download Python dependencies.
#[error("fail to download python dependencies: `{0}`")]
FailToDownloadPythonDependencies(String),
/// Failure to create the SQLite dataset.
#[error("sqlite dataset: `{0}`")]
SqliteDataset(#[from] SqliteDatasetError),
}

View File

@ -8,9 +8,13 @@ use serde::{Deserialize, Serialize};
const WIDTH: usize = 28;
const HEIGHT: usize = 28;
/// MNIST item.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct MNISTItem {
/// Image as a 2D array of floats.
pub image: [[f32; WIDTH]; HEIGHT],
/// Label of the image.
pub label: usize,
}
@ -66,10 +70,12 @@ impl Dataset<MNISTItem> for MNISTDataset {
}
impl MNISTDataset {
/// Creates a new train dataset.
pub fn train() -> Self {
Self::new("train")
}
/// Creates a new test dataset.
pub fn test() -> Self {
Self::new("test")
}

View File

@ -1,4 +1,4 @@
pub mod downloader;
pub(crate) mod downloader;
mod mnist;
pub use downloader::*;

View File

@ -1 +1,2 @@
/// Huggingface source
pub mod huggingface;

View File

@ -3,6 +3,7 @@ use std::marker::PhantomData;
/// Basic mapper trait to be used with the [mapper dataset](MapperDataset).
pub trait Mapper<I, O>: Send + Sync {
/// Maps an item of type I to an item of type O.
fn map(&self, item: &I) -> O;
}

View File

@ -14,6 +14,7 @@ impl<D, I> PartialDataset<D, I>
where
D: Dataset<I>,
{
/// Splits a dataset into multiple partial datasets.
pub fn split(dataset: D, num: usize) -> Vec<PartialDataset<Arc<D>, I>> {
let dataset = Arc::new(dataset); // cheap cloning.

View File

@ -14,6 +14,7 @@ impl<D, I> ShuffledDataset<D, I>
where
D: Dataset<I>,
{
/// Creates a new shuffled dataset.
pub fn new(dataset: D, rng: &mut StdRng) -> Self {
let mut indexes = Vec::with_capacity(dataset.len());
for i in 0..dataset.len() {
@ -28,6 +29,7 @@ where
}
}
/// Creates a new shuffled dataset with a fixed seed.
pub fn with_seed(dataset: D, seed: u64) -> Self {
let mut rng = StdRng::seed_from_u64(seed);
Self::new(dataset, &mut rng)
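
A short hedged sketch of the seeded constructor documented above, wrapping an in-memory dataset:

use burn_dataset::{transform::ShuffledDataset, Dataset, InMemDataset};

fn main() {
    let base = InMemDataset::new((0..10).collect::<Vec<i32>>());

    // The same seed always yields the same shuffled order.
    let shuffled = ShuffledDataset::with_seed(base, 7);

    for index in 0..shuffled.len() {
        print!("{:?} ", shuffled.get(index));
    }
}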

View File

@ -17,6 +17,7 @@ where
D: Dataset<I>,
I: Send + Sync,
{
/// Creates a new sampler dataset.
pub fn new(dataset: D, size: usize) -> Self {
let rng = Mutex::new(StdRng::from_entropy());
@ -28,6 +29,7 @@ where
}
}
/// Generates a random index using a uniform distribution over `[0, dataset.len())`.
fn index(&self) -> usize {
let mut rng = self.rng.lock().unwrap();
rng.sample(Uniform::new(0, self.dataset.len()))

View File

@ -190,6 +190,7 @@ impl ConfigAnalyzer for ConfigStructAnalyzer {
}
let body = quote! {
/// Create a new instance of the config.
pub fn new(
#(#names),*
) -> Self {
@ -208,6 +209,7 @@ impl ConfigAnalyzer for ConfigStructAnalyzer {
let fn_name = Ident::new(&format!("with_{name}"), name.span());
body.extend(quote! {
/// Set the default value for the field.
pub fn #fn_name(mut self, #name: #ty) -> Self {
self.#name = #name;
self
@ -221,6 +223,7 @@ impl ConfigAnalyzer for ConfigStructAnalyzer {
let fn_name = Ident::new(&format!("with_{name}"), name.span());
body.extend(quote! {
/// Set the default value for the field.
pub fn #fn_name(mut self, #name: #ty) -> Self {
self.#name = #name;
self

View File

@ -1,3 +1,7 @@
#![warn(missing_docs)]
//! The derive crate of Burn.
use proc_macro::TokenStream;
pub(crate) mod config;
@ -9,18 +13,21 @@ use config::config_attr_impl;
use module::module_derive_impl;
use record::record_derive_impl;
/// Derive macro for the module.
#[proc_macro_derive(Module)]
pub fn module_derive(input: TokenStream) -> TokenStream {
let input = syn::parse(input).unwrap();
module_derive_impl(&input)
}
/// Derive macro for the record.
#[proc_macro_derive(Record)]
pub fn record_derive(input: TokenStream) -> TokenStream {
let input = syn::parse(input).unwrap();
record_derive_impl(&input)
}
/// Derive macro for the config.
#[proc_macro_derive(Config, attributes(config))]
pub fn config_derive(input: TokenStream) -> TokenStream {
let item = syn::parse(input).unwrap();

View File

@ -26,6 +26,7 @@ impl ModuleRecordGenerator {
let name = &field.field.ident;
fields.extend(quote! {
/// The #name field.
pub #name: <#ty as burn::module::Module<B>>::Record,
});
}
@ -33,6 +34,8 @@ impl ModuleRecordGenerator {
let generics = &self.generics;
quote! {
/// The record type for the module.
#[derive(burn::record::Record, Debug, Clone)]
pub struct #name #generics {
#fields

View File

@ -31,6 +31,7 @@ impl RecordGenerator {
let name = &field.field.ident;
fields.extend(quote! {
/// The #name field.
pub #name: <#ty as burn::record::Record>::Item<S>,
});
bounds.extend(quote!{
@ -42,6 +43,8 @@ impl RecordGenerator {
let bound = bounds.to_string();
quote! {
/// The record item type for the module.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
#[serde(bound = #bound)]
pub struct #name #generics {

View File

@ -1,7 +1,15 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
//! # Burn
//! This library strives to serve as a comprehensive **deep learning framework**
//! written in Rust, offering exceptional flexibility. The main objective is to cater
//! to both researchers and practitioners by simplifying the process of experimenting,
//! training, and deploying models.
pub use burn_core::*;
/// Train module
#[cfg(feature = "train")]
pub mod train {
pub use burn_train::*;

View File

@ -5,9 +5,7 @@
// the TextClassificationDataset trait. These implementations are designed to be used
// with a machine learning framework for tasks such as training a text classification model.
use burn::data::dataset::{
source::huggingface::downloader::HuggingfaceDatasetLoader, Dataset, SqliteDataset,
};
use burn::data::dataset::{source::huggingface::HuggingfaceDatasetLoader, Dataset, SqliteDataset};
// Define a struct for text classification items
#[derive(new, Clone, Debug)]

View File

@ -1,6 +1,4 @@
use burn::data::dataset::{
source::huggingface::downloader::HuggingfaceDatasetLoader, Dataset, SqliteDataset,
};
use burn::data::dataset::{source::huggingface::HuggingfaceDatasetLoader, Dataset, SqliteDataset};
#[derive(new, Clone, Debug)]
pub struct TextGenerationItem {