mirror of https://github.com/tracel-ai/burn.git
Add missing docs and enable missing_docs warn lint (#420)
This commit is contained in:
parent c4e4c25fef
commit eda241f8cf
@@ -1,6 +1,7 @@
use crate::{grads::Gradients, graph::backward::backward, tensor::ADTensor};
use burn_tensor::backend::{ADBackend, Backend};

/// A decorator for a backend that enables automatic differentiation.
#[derive(Clone, Copy, Debug, Default)]
pub struct ADBackendDecorator<B> {
    _b: B,
@@ -5,6 +5,7 @@ use crate::{
    tensor::ADTensor,
};

/// Gradient identifier.
pub type GradID = String;

/// Gradients container used during the backward pass.
@@ -15,6 +16,7 @@ pub struct Gradients {
type TensorPrimitive<B, const D: usize> = <B as Backend>::TensorPrimitive<D>;

impl Gradients {
    /// Creates a new gradients container.
    pub fn new<B: Backend, const D: usize>(
        root_node: NodeRef,
        root_tensor: TensorPrimitive<B, D>,
@@ -28,7 +30,8 @@ impl Gradients {
        );
        gradients
    }
-    /// Consume the gradients for a given tensor.
+    /// Consumes the gradients for a given tensor.
    ///
    /// Each tensor should be consumed exactly 1 time if its gradients are only required during the
    /// backward pass, otherwise, it may be consume multiple times.
@@ -48,7 +51,7 @@ impl Gradients {
        }
    }

-    /// Remove a grad tensor from the container.
+    /// Removes a grad tensor from the container.
    pub fn remove<B: Backend, const D: usize>(
        &mut self,
        tensor: &ADTensor<B, D>,
@@ -58,6 +61,7 @@ impl Gradients {
            .map(|tensor| tensor.into_primitive())
    }

    /// Gets a grad tensor from the container.
    pub fn get<B: Backend, const D: usize>(
        &self,
        tensor: &ADTensor<B, D>,
@@ -67,6 +71,7 @@ impl Gradients {
            .map(|tensor| tensor.into_primitive())
    }

    /// Registers a grad tensor in the container.
    pub fn register<B: Backend, const D: usize>(
        &mut self,
        node: NodeRef,
@@ -1,3 +1,12 @@
#![warn(missing_docs)]

//! # Burn Autodiff
//!
//! This autodiff library is a part of the Burn project. It is a standalone crate
//! that can be used to perform automatic differentiation on tensors. It is
//! designed to be used with the Burn Tensor crate, but it can be used with any
//! tensor library that implements the `Backend` trait.

#[macro_use]
extern crate derive_new;
@@ -1,3 +1,5 @@
#![allow(missing_docs)]

mod add;
mod aggregation;
mod avgpool1d;
@@ -1,3 +1,6 @@
-The `burn-common` package hosts code that _must_ be shared between burn packages (with `std` or `no_std` enabled). No other code should be placed in this package unless unavoidable.
+# Burn Common
+
+The `burn-common` package hosts code that _must_ be shared between burn packages (with `std` or
+`no_std` enabled). No other code should be placed in this package unless unavoidable.

The package must build with `cargo build --no-default-features` as well.
@@ -4,9 +4,11 @@ use crate::rand::{get_seeded_rng, Rng, SEED};

use uuid::{Builder, Bytes};

/// Simple ID generator.
pub struct IdGenerator {}

impl IdGenerator {
    /// Generates a new ID in the form of a UUID.
    pub fn generate() -> String {
        let mut seed = SEED.lock().unwrap();
        let mut rng = if let Some(rng_seeded) = seed.as_ref() {
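As a quick illustration, here is a minimal sketch of calling the generator from application code (assuming `burn-common` as a dependency; the seeding logic visible in the hunk is internal):

```rust
use burn_common::id::IdGenerator;

fn main() {
    // Each call returns a fresh UUID-formatted string.
    let id: String = IdGenerator::generate();
    println!("generated id: {id}");
}
```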
@@ -1,7 +1,18 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]

//! # Burn Common Library
//!
//! This library contains common types used by other Burn crates that must be shared.

/// Id module contains types for unique identifiers.
pub mod id;

/// Rand module contains types for random number generation for non-std environments and for
/// std environments.
pub mod rand;

/// Stub module contains types for stubs for non-std environments and for std environments.
pub mod stub;

extern crate alloc;
@@ -9,12 +9,14 @@ use crate::stub::Mutex;
#[cfg(not(feature = "std"))]
use const_random::const_random;

/// Returns a seeded random number generator using entropy.
#[cfg(feature = "std")]
#[inline(always)]
pub fn get_seeded_rng() -> StdRng {
    StdRng::from_entropy()
}

/// Returns a seeded random number generator using a pre-generated seed.
#[cfg(not(feature = "std"))]
#[inline(always)]
pub fn get_seeded_rng() -> StdRng {
@@ -2,13 +2,19 @@ use spin::{
    Mutex as MutexImported, MutexGuard, RwLock as RwLockImported, RwLockReadGuard, RwLockWriteGuard,
};

-// Mutex wrapper to make spin::Mutex API compatible with std::sync::Mutex to swap
+/// A mutual exclusion primitive useful for protecting shared data
+///
+/// This mutex will block threads waiting for the lock to become available. The
+/// mutex can also be statically initialized or created via a [Mutex::new]
+///
+/// [Mutex] wrapper to make `spin::Mutex` API compatible with `std::sync::Mutex` to swap
#[derive(Debug)]
pub struct Mutex<T> {
    inner: MutexImported<T>,
}

impl<T> Mutex<T> {
    /// Creates a new mutex in an unlocked state ready for use.
    #[inline(always)]
    pub const fn new(value: T) -> Self {
        Self {
@@ -16,19 +22,24 @@ impl<T> Mutex<T> {
        }
    }

    /// Locks the mutex blocking the current thread until it is able to do so.
    #[inline(always)]
    pub fn lock(&self) -> Result<MutexGuard<T>, alloc::string::String> {
        Ok(self.inner.lock())
    }
}

-// Mutex wrapper to make spin::Mutex API compatible with std::sync::Mutex to swap
+/// A reader-writer lock which is exclusively locked for writing or shared for reading.
+/// This reader-writer lock will block threads waiting for the lock to become available.
+/// The lock can also be statically initialized or created via a [RwLock::new]
+/// [RwLock] wrapper to make `spin::RwLock` API compatible with `std::sync::RwLock` to swap
#[derive(Debug)]
pub struct RwLock<T> {
    inner: RwLockImported<T>,
}

impl<T> RwLock<T> {
    /// Creates a new reader-writer lock in an unlocked state ready for use.
    #[inline(always)]
    pub const fn new(value: T) -> Self {
        Self {
@@ -36,17 +47,23 @@ impl<T> RwLock<T> {
        }
    }

    /// Locks this rwlock with shared read access, blocking the current thread
    /// until it can be acquired.
    #[inline(always)]
    pub fn read(&self) -> Result<RwLockReadGuard<T>, alloc::string::String> {
        Ok(self.inner.read())
    }

    /// Locks this rwlock with exclusive write access, blocking the current thread
    /// until it can be acquired.
    #[inline(always)]
    pub fn write(&self) -> Result<RwLockWriteGuard<T>, alloc::string::String> {
        Ok(self.inner.write())
    }
}

-// ThreadId stub when no std is available
+/// A unique identifier for a running thread.
+///
+/// This module is a stub when no std is available to swap with std::thread::ThreadId.
#[derive(Eq, PartialEq, Clone, Copy, Hash, Debug)]
pub struct ThreadId(core::num::NonZeroU64);
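To illustrate the swap-in API these docs describe, a hedged sketch of using the stub `Mutex` the way `std::sync::Mutex` is used (assuming a `no_std` build where `burn_common::stub` provides these types; note that `lock()` here returns `Ok` unconditionally since the spin lock cannot be poisoned):

```rust
use burn_common::stub::Mutex;

// `new` is const, so the stub supports static initialization like std's Mutex.
static COUNTER: Mutex<u32> = Mutex::new(0);

fn increment() -> u32 {
    // The stub mirrors std's Result-returning API, so callers written
    // against std::sync::Mutex keep compiling unchanged.
    let mut guard = COUNTER.lock().unwrap();
    *guard += 1;
    *guard
}
```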
@@ -1,9 +1,13 @@
use alloc::{format, string::String, string::ToString};
pub use burn_derive::Config;

/// Configuration IO error.
#[derive(Debug)]
pub enum ConfigError {
    /// Invalid format.
    InvalidFormat(String),

    /// File not found.
    FileNotFound(String),
}
@@ -28,12 +32,31 @@ impl core::fmt::Display for ConfigError {
#[cfg(feature = "std")]
impl std::error::Error for ConfigError {}

/// Configuration trait.
pub trait Config: serde::Serialize + serde::de::DeserializeOwned {
    /// Saves the configuration to a file.
    ///
    /// # Arguments
    ///
    /// * `file` - File to save the configuration to.
    ///
    /// # Returns
    ///
    /// The output of the save operation.
    #[cfg(feature = "std")]
    fn save(&self, file: &str) -> std::io::Result<()> {
        std::fs::write(file, config_to_json(self))
    }

    /// Loads the configuration from a file.
    ///
    /// # Arguments
    ///
    /// * `file` - File to load the configuration from.
    ///
    /// # Returns
    ///
    /// The loaded configuration.
    #[cfg(feature = "std")]
    fn load(file: &str) -> Result<Self, ConfigError> {
        let content = std::fs::read_to_string(file)
@@ -41,6 +64,15 @@ pub trait Config: serde::Serialize + serde::de::DeserializeOwned {
        config_from_str(&content)
    }

    /// Loads the configuration from a binary buffer.
    ///
    /// # Arguments
    ///
    /// * `data` - Binary buffer to load the configuration from.
    ///
    /// # Returns
    ///
    /// The loaded configuration.
    fn load_binary(data: &[u8]) -> Result<Self, ConfigError> {
        let content = core::str::from_utf8(data).map_err(|_| {
            ConfigError::InvalidFormat("Could not parse data as utf-8.".to_string())
@@ -49,6 +81,15 @@ pub trait Config: serde::Serialize + serde::de::DeserializeOwned {
    }
}

/// Converts a configuration to a JSON string.
///
/// # Arguments
///
/// * `config` - Configuration to convert.
///
/// # Returns
///
/// The JSON string.
pub fn config_to_json<C: Config>(config: &C) -> String {
    serde_json::to_string_pretty(config).unwrap()
}
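As a usage illustration, a hedged sketch of defining a config and round-tripping it through the documented `save`/`load` API (the `MyTrainingConfig` type and its fields are hypothetical; this assumes the `Config` derive re-exported above and the `std` feature):

```rust
use burn::config::{Config, ConfigError};

#[derive(Config)]
pub struct MyTrainingConfig {
    /// Hypothetical field: number of training epochs (no default, so `new` takes it).
    pub num_epochs: usize,
    /// Hypothetical field: learning rate, with a derive-provided default.
    #[config(default = 1e-3)]
    pub learning_rate: f64,
}

fn roundtrip() -> Result<MyTrainingConfig, ConfigError> {
    let config = MyTrainingConfig::new(10);
    // `save` writes the pretty-printed JSON produced by `config_to_json`.
    config.save("config.json").expect("could not write config file");
    // `load` parses the file back into the same type.
    MyTrainingConfig::load("config.json")
}
```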
@@ -1,16 +1,24 @@
pub use crate::data::dataset::{Dataset, DatasetIterator};
use core::iter::Iterator;

/// A progress struct that can be used to track the progress of a data loader.
#[derive(Clone, Debug)]
pub struct Progress {
    /// The number of items that have been processed.
    pub items_processed: usize,

    /// The total number of items that need to be processed.
    pub items_total: usize,
}

/// A data loader iterator that can be used to iterate over a data loader.
pub trait DataLoaderIterator<O>: Iterator<Item = O> {
    /// Returns the progress of the data loader.
    fn progress(&self) -> Progress;
}

/// A data loader that can be used to iterate over a dataset.
pub trait DataLoader<O> {
    /// Returns a boxed [iterator](DataLoaderIterator) to iterate over the data loader.
    fn iter<'a>(&'a self) -> Box<dyn DataLoaderIterator<O> + 'a>;
}
@@ -5,12 +5,14 @@ use super::{
use burn_dataset::{transform::PartialDataset, Dataset};
use std::sync::Arc;

/// A data loader that can be used to iterate over a dataset in batches.
pub struct BatchDataLoader<I, O> {
    strategy: Box<dyn BatchStrategy<I>>,
    dataset: Arc<dyn Dataset<I>>,
    batcher: Arc<dyn Batcher<I, O>>,
}

/// A data loader iterator that can be used to iterate over a data loader.
struct BatchDataloaderIterator<I, O> {
    current_index: usize,
    strategy: Box<dyn BatchStrategy<I>>,
@@ -19,6 +21,17 @@ struct BatchDataloaderIterator<I, O> {
}

impl<I, O> BatchDataLoader<I, O> {
    /// Creates a new batch data loader.
    ///
    /// # Arguments
    ///
    /// * `strategy` - The batch strategy.
    /// * `dataset` - The dataset.
    /// * `batcher` - The batcher.
    ///
    /// # Returns
    ///
    /// The batch data loader.
    pub fn new(
        strategy: Box<dyn BatchStrategy<I>>,
        dataset: Arc<dyn Dataset<I>>,
@@ -31,11 +44,24 @@ impl<I, O> BatchDataLoader<I, O> {
        }
    }
}

impl<I, O> BatchDataLoader<I, O>
where
    I: Send + Sync + Clone + 'static,
    O: Send + Sync + Clone + 'static,
{
    /// Creates a new multi-threaded batch data loader.
    ///
    /// # Arguments
    ///
    /// * `strategy` - The batch strategy.
    /// * `dataset` - The dataset.
    /// * `batcher` - The batcher.
    /// * `num_threads` - The number of threads.
    ///
    /// # Returns
    ///
    /// The multi-threaded batch data loader.
    pub fn multi_thread(
        strategy: Box<dyn BatchStrategy<I>>,
        dataset: Arc<dyn Dataset<I>>,
@@ -65,6 +91,17 @@ impl<I, O> DataLoader<O> for BatchDataLoader<I, O> {
}

impl<I, O> BatchDataloaderIterator<I, O> {
    /// Creates a new batch data loader iterator.
    ///
    /// # Arguments
    ///
    /// * `strategy` - The batch strategy.
    /// * `dataset` - The dataset.
    /// * `batcher` - The batcher.
    ///
    /// # Returns
    ///
    /// The batch data loader iterator.
    pub fn new(
        strategy: Box<dyn BatchStrategy<I>>,
        dataset: Arc<dyn Dataset<I>>,
@@ -1,4 +1,14 @@
/// A trait for batching items of type `I` into items of type `O`.
pub trait Batcher<I, O>: Send + Sync {
    /// Batches the given items.
    ///
    /// # Arguments
    ///
    /// * `items` - The items to batch.
    ///
    /// # Returns
    ///
    /// The batched items.
    fn batch(&self, items: Vec<I>) -> O;
}
@@ -2,6 +2,7 @@ use super::{batcher::Batcher, BatchDataLoader, BatchStrategy, DataLoader, FixBat
use burn_dataset::{transform::ShuffledDataset, Dataset};
use std::sync::Arc;

/// A builder for data loaders.
pub struct DataLoaderBuilder<I, O> {
    strategy: Option<Box<dyn BatchStrategy<I>>>,
    batcher: Arc<dyn Batcher<I, O>>,
@@ -14,6 +15,15 @@ where
    I: Send + Sync + Clone + std::fmt::Debug + 'static,
    O: Send + Sync + Clone + std::fmt::Debug + 'static,
{
    /// Creates a new data loader builder.
    ///
    /// # Arguments
    ///
    /// * `batcher` - The batcher.
    ///
    /// # Returns
    ///
    /// The data loader builder.
    pub fn new<B>(batcher: B) -> Self
    where
        B: Batcher<I, O> + 'static,
@@ -26,21 +36,58 @@ where
        }
    }

    /// Sets the batch size to a fix number.The [fix batch strategy](FixBatchStrategy)
    /// will be used.
    ///
    /// # Arguments
    ///
    /// * `batch_size` - The batch size.
    ///
    /// # Returns
    ///
    /// The data loader builder.
    pub fn batch_size(mut self, batch_size: usize) -> Self {
        self.strategy = Some(Box::new(FixBatchStrategy::new(batch_size)));
        self
    }

    /// Sets the seed for shuffling.
    ///
    /// # Arguments
    ///
    /// * `seed` - The seed.
    ///
    /// # Returns
    ///
    /// The data loader builder.
    pub fn shuffle(mut self, seed: u64) -> Self {
        self.shuffle = Some(seed);
        self
    }

    /// Sets the number of workers.
    ///
    /// # Arguments
    ///
    /// * `num_workers` - The number of workers.
    ///
    /// # Returns
    ///
    /// The data loader builder.
    pub fn num_workers(mut self, num_workers: usize) -> Self {
        self.num_threads = Some(num_workers);
        self
    }

    /// Builds the data loader.
    ///
    /// # Arguments
    ///
    /// * `dataset` - The dataset.
    ///
    /// # Returns
    ///
    /// The data loader.
    pub fn build<D>(self, dataset: D) -> Arc<dyn DataLoader<O>>
    where
        D: Dataset<I> + 'static,
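A hedged sketch of the builder chain these docs describe (the batcher and dataset are placeholders passed in by the caller; the exact re-export paths under `burn::data` are assumed from the module declarations elsewhere in this commit):

```rust
use std::sync::Arc;
use burn::data::dataloader::{batcher::Batcher, DataLoader, DataLoaderBuilder};
use burn::data::dataset::Dataset;

fn build_loader<I, O>(
    batcher: impl Batcher<I, O> + 'static,
    dataset: impl Dataset<I> + 'static,
) -> Arc<dyn DataLoader<O>>
where
    I: Send + Sync + Clone + std::fmt::Debug + 'static,
    O: Send + Sync + Clone + std::fmt::Debug + 'static,
{
    DataLoaderBuilder::new(batcher)
        .batch_size(32) // fixed batch strategy
        .shuffle(42)    // seed for shuffling
        .num_workers(4) // multi-threaded loading
        .build(dataset)
}
```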
@@ -4,6 +4,7 @@ mod builder;
mod multithread;
mod strategy;

/// Module for batching items.
pub mod batcher;

pub use base::*;
@@ -3,15 +3,20 @@ use std::collections::HashMap;
use std::sync::{mpsc, Arc};
use std::thread;

-static MAX_QUEUED_ITEMS: usize = 100;
+const MAX_QUEUED_ITEMS: usize = 100;

/// A multi-threaded data loader that can be used to iterate over a dataset.
pub struct MultiThreadDataLoader<O> {
    dataloaders: Vec<Arc<dyn DataLoader<O> + Send + Sync>>,
}

/// A message that can be sent between threads.
#[derive(Debug)]
pub enum Message<O> {
    /// A batch of items.
    Batch(usize, O, Progress),

    /// The thread is done.
    Done,
}
@@ -23,6 +28,15 @@ struct MultiThreadsDataloaderIterator<O> {
}

impl<O> MultiThreadDataLoader<O> {
    /// Creates a new multi-threaded data loader.
    ///
    /// # Arguments
    ///
    /// * `dataloaders` - The data loaders.
    ///
    /// # Returns
    ///
    /// The multi-threaded data loader.
    pub fn new(dataloaders: Vec<Arc<dyn DataLoader<O> + Send + Sync>>) -> Self {
        Self { dataloaders }
    }
@@ -1,15 +1,47 @@
/// A strategy to batch items.
pub trait BatchStrategy<I>: Send + Sync {
    /// Adds an item to the strategy.
    ///
    /// # Arguments
    ///
    /// * `item` - The item to add.
    fn add(&mut self, item: I);

    /// Batches the items.
    ///
    /// # Arguments
    ///
    /// * `force` - Whether to force batching.
    ///
    /// # Returns
    ///
    /// The batched items.
    fn batch(&mut self, force: bool) -> Option<Vec<I>>;

    /// Creates a new strategy of the same type.
    ///
    /// # Returns
    ///
    /// The new strategy.
    fn new_like(&self) -> Box<dyn BatchStrategy<I>>;
}

/// A strategy to batch items with a fixed batch size.
pub struct FixBatchStrategy<I> {
    items: Vec<I>,
    batch_size: usize,
}

impl<I> FixBatchStrategy<I> {
    /// Creates a new strategy to batch items with a fixed batch size.
    ///
    /// # Arguments
    ///
    /// * `batch_size` - The batch size.
    ///
    /// # Returns
    ///
    /// The strategy.
    pub fn new(batch_size: usize) -> Self {
        FixBatchStrategy {
            items: Vec::with_capacity(batch_size),
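To make the trait contract concrete, a hedged sketch of driving `FixBatchStrategy` by hand (normally the data loader does this internally; the exact flush semantics of `batch` are assumed from the docs above, and the types are assumed to be re-exported at the dataloader module root):

```rust
use burn::data::dataloader::{BatchStrategy, FixBatchStrategy};

fn demo() {
    let mut strategy: Box<dyn BatchStrategy<i32>> = Box::new(FixBatchStrategy::new(3));
    for item in 0..5 {
        strategy.add(item);
    }
    // Once enough items are queued, a batch can likely be taken without
    // forcing; `force = true` flushes any remaining partial batch.
    let full = strategy.batch(false);
    let rest = strategy.batch(true);
    println!("{full:?} {rest:?}");
}
```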
@@ -1,4 +1,7 @@
/// Dataloader module.
pub mod dataloader;

/// Dataset module.
pub mod dataset {
    pub use burn_dataset::*;
}
@@ -3,13 +3,22 @@ use crate as burn;
use crate::{config::Config, tensor::Tensor};
use burn_tensor::{backend::Backend, ElementConversion};

/// Gradient Clipping provides a way to mitigate exploding gradients
#[derive(Config)]
pub enum GradientClippingConfig {
    /// Clip the gradient by value.
    Value(f32),

    /// Clip the gradient by norm.
    Norm(f32),
}

impl GradientClippingConfig {
    /// Initialize the gradient clipping.
    ///
    /// # Returns
    ///
    /// The gradient clipping.
    pub fn init(&self) -> GradientClipping {
        match self {
            GradientClippingConfig::Value(val) => GradientClipping::Value(*val),
@@ -22,11 +31,23 @@ impl GradientClippingConfig {
/// by clipping every component of the gradient by value or by norm during
/// backpropagation.
pub enum GradientClipping {
    /// Clip the gradient by value.
    Value(f32),

    /// Clip the gradient by norm.
    Norm(f32),
}

impl GradientClipping {
    /// Clip the gradient.
    ///
    /// # Arguments
    ///
    /// * `grad` - The gradient to clip.
    ///
    /// # Returns
    ///
    /// The clipped gradient.
    pub fn clip_gradient<B: Backend, const D: usize>(&self, grad: Tensor<B, D>) -> Tensor<B, D> {
        match self {
            GradientClipping::Value(threshold) => self.clip_by_value(grad, *threshold),
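A hedged sketch of the config-to-clipper flow documented above (the backend type `B` is left generic, and the clipping threshold is illustrative only):

```rust
use burn::grad_clipping::{GradientClipping, GradientClippingConfig};
use burn::tensor::{backend::Backend, Tensor};

fn clip<B: Backend, const D: usize>(grad: Tensor<B, D>) -> Tensor<B, D> {
    // Build the clipper from its config: clip every gradient
    // component into [-0.5, 0.5].
    let clipper: GradientClipping = GradientClippingConfig::Value(0.5).init();
    clipper.clip_gradient(grad)
}
```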
@@ -1,23 +1,39 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]

//! The core crate of Burn.

#[macro_use]
extern crate derive_new;

/// The configuration module.
pub mod config;

/// Data module.
#[cfg(feature = "std")]
pub mod data;

/// Optimizer module.
#[cfg(feature = "std")]
pub mod optim;

/// Learning rate scheduler module.
#[cfg(feature = "std")]
pub mod lr_scheduler;

/// Gradient clipping module.
pub mod grad_clipping;

/// Module for the neural network module.
pub mod module;

/// Neural network module.
pub mod nn;

/// Module for the recorder.
pub mod record;

/// Module for the tensor.
pub mod tensor;

extern crate alloc;
@@ -1,4 +1,7 @@
/// Constant learning rate scheduler
pub mod constant;

/// Noam Learning rate schedule
pub mod noam;

mod base;
@@ -178,16 +178,21 @@ pub trait Module<B: Backend>: Clone + Send + Sync + core::fmt::Debug {
    fn into_record(self) -> Self::Record;
}

/// Module visitor trait.
pub trait ModuleVisitor<B: Backend> {
    /// Visit a tensor in the module.
    fn visit<const D: usize>(&mut self, id: &ParamId, tensor: &Tensor<B, D>);
}

/// Module mapper trait.
pub trait ModuleMapper<B: Backend> {
    /// Map a tensor in the module.
    fn map<const D: usize>(&mut self, id: &ParamId, tensor: Tensor<B, D>) -> Tensor<B, D>;
}

/// Module with auto-differentiation backend.
pub trait ADModule<B: ADBackend>: Module<B> + Send + Sync + core::fmt::Debug {
    /// Inner module without auto-differentiation.
    type InnerModule: Module<B::InnerBackend>;

    /// Get the same module, but on the inner backend without auto-differentiation.
@@ -15,6 +15,11 @@ impl<T> core::fmt::Display for Param<T> {
}

impl<T: Clone> Param<T> {
    /// Gets the parameter value.
    ///
    /// # Returns
    ///
    /// The parameter value.
    pub fn val(&self) -> T {
        self.value.clone()
    }
@@ -36,7 +36,7 @@ impl Record for ConstantRecord {
        item
    }
}

/// Constant macro.
#[macro_export]
macro_rules! constant {
    (module) => {
@@ -1,6 +1,7 @@
use alloc::string::{String, ToString};
use burn_common::id::IdGenerator;

/// Parameter ID.
#[derive(Debug, Hash, PartialEq, Eq, Clone)]
pub struct ParamId {
    value: String,
@@ -27,11 +28,14 @@ impl Default for ParamId {
}

impl ParamId {
    /// Create a new parameter ID.
    pub fn new() -> Self {
        Self {
            value: IdGenerator::generate(),
        }
    }

    /// Convert the parameter ID into a string.
    pub fn into_string(self) -> String {
        self.value
    }
@@ -22,8 +22,12 @@ pub fn generate_autoregressive_mask<B: Backend>(
    mask.equal_elem(1_i64.elem::<i64>())
}

/// Generate a padding attention mask.
pub struct GeneratePaddingMask<B: Backend> {
    /// The generated tensor.
    pub tensor: Tensor<B, 2, Int>,

    /// The generated mask.
    pub mask: Tensor<B, 2, Bool>,
}
@@ -6,11 +6,17 @@ pub(crate) enum CacheState<T> {
    Empty,
}

/// A cache for a tensor.
pub struct TensorCache<B: Backend, const D: usize> {
    pub(crate) state: CacheState<Tensor<B, D>>,
}

impl<B: Backend, const D: usize> TensorCache<B, D> {
    /// Creates a new empty cache.
    ///
    /// # Returns
    ///
    /// The empty cache.
    pub fn empty() -> Self {
        Self {
            state: CacheState::Empty,
@@ -11,27 +11,60 @@ use crate as burn;
#[derive(Config, Debug, PartialEq)]
pub enum Initializer {
    /// Fills tensor with specified value everywhere
-    Constant { value: f64 },
+    Constant {
+        /// The value to fill the tensor with
+        value: f64,
+    },
    /// Fills tensor with 1s everywhere
    Ones,
    /// Fills tensor with 0s everywhere
    Zeros,
    /// Fills tensor with values drawn uniformly between specified values
-    Uniform { min: f64, max: f64 },
+    Uniform {
+        /// The minimum value to draw from
+        min: f64,
+
+        /// The maximum value to draw from
+        max: f64,
+    },
    /// Fills tensor with values drawn from normal distribution with specified mean and std
-    Normal { mean: f64, std: f64 },
+    Normal {
+        /// The mean of the normal distribution
+        mean: f64,
+
+        /// The standard deviation of the normal distribution
+        std: f64,
+    },
    /// Fills tensor with values according to the uniform version of Kaiming initialization
-    KaimingUniform { gain: f64, fan_out_only: bool },
+    KaimingUniform {
+        /// The gain to use in initialization formula
+        gain: f64,
+
+        /// Whether to use fan out only in initialization formula
+        fan_out_only: bool,
+    },
    /// Fills tensor with values according to the uniform version of Kaiming initialization
-    KaimingNormal { gain: f64, fan_out_only: bool },
+    KaimingNormal {
+        /// The gain to use in initialization formula
+        gain: f64,
+
+        /// Whether to use fan out only in initialization formula
+        fan_out_only: bool,
+    },
    /// Fills tensor with values according to the uniform version of Xavier Glorot initialization
    /// described in [Understanding the difficulty of training deep feedforward neural networks
    /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
-    XavierUniform { gain: f64 },
+    XavierUniform {
+        /// The gain to use in initialization formula
+        gain: f64,
+    },
    /// Fills tensor with values according to the normal version of Xavier Glorot initialization
    /// described in [Understanding the difficulty of training deep feedforward neural networks
    /// ](https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf)
-    XavierNormal { gain: f64 },
+    XavierNormal {
+        /// The gain to use in initialization formula
+        gain: f64,
+    },
}

impl Initializer {
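A hedged sketch showing how the restructured variants read at a call site (field names are the ones documented above; the `burn::nn` export path is assumed, and the gain values are illustrative):

```rust
use burn::nn::Initializer;

fn pick_initializers() -> Vec<Initializer> {
    vec![
        // Every field of the struct variants now carries its own doc.
        Initializer::Constant { value: 0.1 },
        Initializer::Uniform { min: -0.05, max: 0.05 },
        Initializer::Normal { mean: 0.0, std: 0.02 },
        Initializer::KaimingUniform { gain: 2f64.sqrt(), fan_out_only: false },
        Initializer::XavierNormal { gain: 1.0 },
    ]
}
```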
@@ -42,6 +42,7 @@ impl<B: Backend> MSELoss<B> {
    }
}

    /// Compute the criterion on the input tensor without reducing.
    pub fn forward_no_reduction<const D: usize>(
        &self,
        logits: Tensor<B, D>,
@@ -1,5 +1,11 @@
/// The reduction type for the loss.
pub enum Reduction {
    /// The mean of the losses will be returned.
    Mean,

    /// The sum of the losses will be returned.
    Sum,

    /// The mean of the losses will be returned.
    Auto,
}
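For context, a hedged sketch of how `Reduction` and the MSE loss above fit together (the `new` constructor, the `forward` signature, and the scalar return shape are assumed from the surrounding code, not spelled out in this diff):

```rust
use burn::nn::loss::{MSELoss, Reduction};
use burn::tensor::{backend::Backend, Tensor};

fn mse<B: Backend>(logits: Tensor<B, 2>, targets: Tensor<B, 2>) -> Tensor<B, 1> {
    let loss = MSELoss::new();
    // `Reduction::Mean` averages the per-element losses into a scalar tensor;
    // `forward_no_reduction` would keep the element-wise losses instead.
    loss.forward(logits, targets, Reduction::Mean)
}
```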
@@ -1,8 +1,19 @@
/// Attention module
pub mod attention;

/// Cache module
pub mod cache;

/// Convolution module
pub mod conv;

/// Loss module
pub mod loss;

/// Pooling module
pub mod pool;

/// Transformer module
pub mod transformer;

mod dropout;
@@ -11,6 +11,7 @@ use burn_tensor::activation;

use super::gate_controller::GateController;

/// The configuration for a [gru](Gru) module.
#[derive(Config)]
pub struct GruConfig {
    /// The size of the input features.
@@ -11,6 +11,7 @@ use burn_tensor::activation;

use super::gate_controller::GateController;

/// The configuration for a [lstm](Lstm) module.
#[derive(Config)]
pub struct LstmConfig {
    /// The size of the input features.
@@ -1,5 +1,9 @@
mod gate_controller;

/// Gated Recurrent Unit module.
pub mod gru;

/// Long Short-Term Memory module.
pub mod lstm;

pub use gate_controller::*;
@@ -47,6 +47,7 @@ pub struct TransformerDecoder<B: Backend> {
}

impl TransformerDecoderConfig {
    /// Initialize a new [Transformer Decoder](TransformerDecoder) module.
    pub fn init<B: Backend>(&self) -> TransformerDecoder<B> {
        let layers = (0..self.n_layers)
            .map(|_| TransformerDecoderLayer::new(self))
@@ -54,6 +55,12 @@ impl TransformerDecoderConfig {

        TransformerDecoder { layers }
    }

    /// Initialize a new [Transformer Decoder](TransformerDecoder) module with a record.
    ///
    /// # Params
    ///
    /// - record: the record to initialize the module with.
    pub fn init_with<B: Backend>(
        &self,
        record: TransformerDecoderRecord<B>,
@@ -117,6 +124,7 @@ impl<B: Backend> TransformerDecoderInput<B> {
    }
}

/// [Transformer Decoder](TransformerDecoder) layer module.
#[derive(Module, Debug)]
pub struct TransformerDecoderLayer<B: Backend> {
    cross_attn: MultiHeadAttention<B>,
@@ -151,6 +151,7 @@ impl<B: Backend> TransformerEncoder<B> {
    }
}

/// Transformer encoder layer module.
#[derive(Module, Debug)]
pub struct TransformerEncoderLayer<B: Backend> {
    mha: MultiHeadAttention<B>,
@@ -12,6 +12,7 @@ use crate::optim::adaptor::OptimizerAdaptor;
use crate::tensor::{backend::ADBackend, Tensor};
use burn_tensor::{backend::Backend, ElementConversion};

/// Adam configuration.
#[derive(Config)]
pub struct AdamConfig {
    /// Parameter for Adam.
@@ -35,6 +36,7 @@ pub struct Adam<B: Backend> {
    weight_decay: Option<WeightDecay<B>>,
}

/// Adam state.
#[derive(Record, Clone, new)]
pub struct AdamState<B: Backend, const D: usize> {
    weight_decay: Option<WeightDecayState<B, D>>,
@@ -84,6 +86,11 @@ impl<B: Backend> SimpleOptimizer<B> for Adam<B> {
}

impl AdamConfig {
    /// Initialize Adam optimizer.
    ///
    /// # Returns
    ///
    /// Returns an optimizer that can be used to optimize a module.
    pub fn init<B: ADBackend, M: ADModule<B>>(&self) -> impl Optimizer<M, B> {
        let optim = Adam {
            momentum: AdaptiveMomentum {
@@ -102,6 +109,7 @@ impl AdamConfig {
    }
}

/// Adaptive momentum state.
#[derive(Record, new, Clone)]
pub struct AdaptiveMomentumState<B: Backend, const D: usize> {
    time: usize,
@@ -164,6 +172,15 @@ impl AdaptiveMomentum {
}

impl<B: Backend, const D: usize> AdaptiveMomentumState<B, D> {
    /// Move state to device.
    ///
    /// # Arguments
    ///
    /// * `device` - Device to move state to.
    ///
    /// # Returns
    ///
    /// Returns state moved to device.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.moment_1 = self.moment_1.to_device(device);
        self.moment_2 = self.moment_2.to_device(device);
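A hedged sketch of the documented `AdamConfig::init` entry point (module and backend types are left generic; this assumes `AdamConfig`'s fields all carry `#[config(default = …)]`, so the derive-generated `new()` takes no arguments):

```rust
use burn::module::ADModule;
use burn::optim::{AdamConfig, Optimizer};
use burn::tensor::backend::ADBackend;

fn make_optimizer<B: ADBackend, M: ADModule<B>>() -> impl Optimizer<M, B> {
    // `Config`-derived builders provide `new()` plus `with_*` setters
    // for the defaulted fields.
    AdamConfig::new().init()
}
```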
@@ -13,6 +13,7 @@ pub struct WeightDecayConfig {
    pub penalty: f64,
}

/// State of [WeightDecay](WeightDecay).
#[derive(Record, Clone, new)]
pub struct WeightDecayState<B: Backend, const D: usize> {
    grad_last_step: Tensor<B, D>,
@@ -24,12 +25,24 @@ pub struct WeightDecay<B: Backend> {
}

impl<B: Backend> WeightDecay<B> {
    /// Creates a new [WeightDecay](WeightDecay) from a [WeightDecayConfig](WeightDecayConfig).
    pub fn new(config: &WeightDecayConfig) -> Self {
        Self {
            penalty: config.penalty.elem(),
        }
    }

    /// Transforms a gradient.
    ///
    /// # Arguments
    ///
    /// * `grad` - Gradient to transform.
    /// * `state` - State of the optimizer.
    ///
    /// # Returns
    ///
    /// * `grad` - Transformed gradient.
    /// * `state` - State of the optimizer.
    pub fn transform<const D: usize>(
        &self,
        grad: Tensor<B, D>,
@@ -47,6 +60,15 @@ impl<B: Backend> WeightDecay<B> {
}

impl<B: Backend, const D: usize> WeightDecayState<B, D> {
    /// Moves the state to a device.
    ///
    /// # Arguments
    ///
    /// * `device` - Device to move the state to.
    ///
    /// # Returns
    ///
    /// * `self` - Moved state.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.grad_last_step = self.grad_last_step.to_device(device);
        self
@@ -15,6 +15,7 @@ pub struct GradientsParams {
}

impl GradientsParams {
    /// Creates a new [GradientsParams](GradientsParams).
    pub fn new() -> Self {
        Self::default()
    }
@@ -1,4 +1,7 @@
/// Weight decay module for optimizers.
pub mod decay;

/// Momentum module for optimizers.
pub mod momentum;

mod adam;
@@ -14,11 +14,13 @@ pub struct MomentumConfig {
    /// Dampening factor.
    #[config(default = 0.1)]
    pub dampening: f64,
-    /// Enables Nesterov momentum, see [On the importance of initialization and momentum in deep learning](http://www.cs.toronto.edu/~hinton/absps/momentum.pdf).
+    /// Enables Nesterov momentum, see [On the importance of initialization and
+    /// momentum in deep learning](http://www.cs.toronto.edu/~hinton/absps/momentum.pdf).
    #[config(default = false)]
    pub nesterov: bool,
}

/// State of [Momentum](Momentum).
#[derive(Record, Clone, new)]
pub struct MomemtumState<B: Backend, const D: usize> {
    velocity: Tensor<B, D>,
@@ -32,6 +34,7 @@ pub struct Momentum<B: Backend> {
}

impl<B: Backend> Momentum<B> {
    /// Creates a new [Momentum](Momentum) from a [MomentumConfig](MomentumConfig).
    pub fn new(config: &MomentumConfig) -> Self {
        Self {
            momentum: config.momentum.elem(),
@@ -40,6 +43,17 @@ impl<B: Backend> Momentum<B> {
        }
    }

    /// Transforms a gradient.
    ///
    /// # Arguments
    ///
    /// * `grad` - Gradient to transform.
    /// * `state` - State of the optimizer.
    ///
    /// # Returns
    ///
    /// * `grad` - Transformed gradient.
    /// * `state` - State of the optimizer.
    pub fn transform<const D: usize>(
        &self,
        grad: Tensor<B, D>,
@@ -63,6 +77,15 @@ impl<B: Backend> Momentum<B> {
}

impl<B: Backend, const D: usize> MomemtumState<B, D> {
    /// Moves the state to a device.
    ///
    /// # Arguments
    ///
    /// * `device` - Device to move the state to.
    ///
    /// # Returns
    ///
    /// * `self` - Moved state.
    pub fn to_device(mut self, device: &B::Device) -> Self {
        self.velocity = self.velocity.to_device(device);
        self
@@ -30,6 +30,7 @@ pub struct Sgd<B: Backend> {
    weight_decay: Option<WeightDecay<B>>,
}

/// State of [Sgd](Sgd).
#[derive(Record, Clone, new)]
pub struct SgdState<B: Backend, const D: usize> {
    weight_decay: Option<WeightDecayState<B, D>>,
@@ -37,6 +38,7 @@ pub struct SgdState<B: Backend, const D: usize> {
}

impl SgdConfig {
    /// Creates a new [SgdConfig](SgdConfig) with default values.
    pub fn init<B: ADBackend, M: ADModule<B>>(
        &self,
    ) -> OptimizerAdaptor<Sgd<B::InnerBackend>, M, B> {
@@ -45,13 +45,22 @@ where
    M: ADModule<B>,
    B: ADBackend,
{
    /// Sets the gradient clipping.
    ///
    /// # Arguments
    ///
    /// * `gradient_clipping` - The gradient clipping.
    ///
    /// # Returns
    ///
    /// The optimizer.
    pub fn with_grad_clipping(mut self, gradient_clipping: GradientClipping) -> Self {
        self.grad_clipping = Some(gradient_clipping);
        self
    }

    #[cfg(test)]
-    pub fn has_gradient_clipping(&self) -> bool {
+    pub(crate) fn has_gradient_clipping(&self) -> bool {
        self.grad_clipping.is_some()
    }
}
@@ -1,5 +1,8 @@
mod base;
pub use base::*;

/// Adaptor module for optimizers.
pub mod adaptor;

/// Record module for optimizers.
pub mod record;
@@ -10,12 +10,15 @@ use serde::{Deserialize, Serialize};
///
/// Records are versioned for backward compatibility, so old records can be loaded.
pub enum AdaptorRecord<O: SimpleOptimizer<B>, B: Backend> {
    /// Version 1.
    V1(AdaptorRecordV1<O, B>),
}

/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.
#[derive(Serialize, Deserialize)]
#[serde(bound = "")]
pub enum AdaptorRecordItem<O: SimpleOptimizer<B>, B: Backend, S: PrecisionSettings> {
    /// Version 1.
    V1(AdaptorRecordItemV1<O, B, S>),
}
@@ -56,12 +59,26 @@ where
    O: SimpleOptimizer<B>,
    B: Backend,
{
    /// Converts the record into the optimizer state.
    ///
    /// # Returns
    ///
    /// The optimizer state.
    pub fn into_state<const D: usize>(self) -> O::State<D> {
        match self {
            AdaptorRecord::V1(record) => record.into_state(),
        }
    }

    /// Converts the optimizer state into the record.
    ///
    /// # Arguments
    ///
    /// * `state`: The optimizer state.
    ///
    /// # Returns
    ///
    /// The record.
    pub fn from_state<const D: usize>(state: O::State<D>) -> Self {
        Self::V1(AdaptorRecordV1::from_state(state))
    }
@@ -6,14 +6,30 @@ use burn_tensor::backend::Backend;
use core::any::Any;
use serde::{Deserialize, Serialize};

/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.
pub enum AdaptorRecordV1<O: SimpleOptimizer<B>, B: Backend> {
    /// Rank 1.
    Rank1(O::State<1>),

    /// Rank 2.
    Rank2(O::State<2>),

    /// Rank 3.
    Rank3(O::State<3>),

    /// Rank 4.
    Rank4(O::State<4>),

    /// Rank 5.
    Rank5(O::State<5>),

    /// Rank 6.
    Rank6(O::State<6>),

    /// Rank 7.
    Rank7(O::State<7>),

    /// Rank 8.
    Rank8(O::State<8>),
}
@@ -32,16 +48,32 @@ impl<O: SimpleOptimizer<B>, B: Backend> Clone for AdaptorRecordV1<O, B> {
    }
}

/// [Optimizer adaptor](crate::optim::simple::adaptor::OptimizerAdaptor) record item.
#[derive(Serialize, Deserialize)]
#[serde(bound = "")]
pub enum AdaptorRecordItemV1<O: SimpleOptimizer<B>, B: Backend, S: PrecisionSettings> {
    /// Rank 1.
    Rank1(<O::State<1> as Record>::Item<S>),

    /// Rank 2.
    Rank2(<O::State<2> as Record>::Item<S>),

    /// Rank 3.
    Rank3(<O::State<3> as Record>::Item<S>),

    /// Rank 4.
    Rank4(<O::State<4> as Record>::Item<S>),

    /// Rank 5.
    Rank5(<O::State<5> as Record>::Item<S>),

    /// Rank 6.
    Rank6(<O::State<6> as Record>::Item<S>),

    /// Rank 7.
    Rank7(<O::State<7> as Record>::Item<S>),

    /// Rank 8.
    Rank8(<O::State<8> as Record>::Item<S>),
}
@@ -50,6 +82,15 @@ where
    O: SimpleOptimizer<B>,
    B: Backend,
{
    /// Convert the record into the state.
    ///
    /// # Returns
    ///
    /// The state.
    ///
    /// # Panics
    ///
    /// Panics if the state dimension is not supported.
    pub fn into_state<const D: usize>(self) -> O::State<D> {
        let boxed_state: Box<dyn Any> = match self {
            AdaptorRecordV1::Rank1(s) => Box::new(s),
@@ -66,6 +107,16 @@ where
            .expect("Unsupported state dimension, dimension up to 8 are supported.");
        *state
    }

    /// Convert the state into the record.
    ///
    /// # Arguments
    ///
    /// * `state`: The state.
    ///
    /// # Returns
    ///
    /// The record.
    pub fn from_state<const D: usize>(state: O::State<D>) -> Self {
        let state: Box<dyn Any> = Box::new(state);
@@ -5,10 +5,12 @@ use serde::{de::DeserializeOwned, Serialize};

/// Trait to define a family of types which can be recorded using any [settings](PrecisionSettings).
pub trait Record: Send + Sync {
    /// Type of the item that can be serialized and deserialized.
    type Item<S: PrecisionSettings>: Serialize + DeserializeOwned;

    /// Convert the current record into the corresponding item that follows the given [settings](PrecisionSettings).
    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S>;

    /// Convert the given item into a record.
    fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self;
}
@@ -8,6 +8,7 @@ use std::{fs::File, path::PathBuf};
pub trait FileRecorder:
    Recorder<RecordArgs = PathBuf, RecordOutput = (), LoadArgs = PathBuf>
{
    /// File extension of the format used by the recorder.
    fn file_extension() -> &'static str;
}
@@ -134,6 +134,7 @@ primitive!(f64);
primitive!(f32);

// TODO: Remove the feature flag when half supports serde with no_std
// https://github.com/burn-rs/burn/issues/268 tracking issue
#[cfg(feature = "std")]
primitive!(half::bf16);
#[cfg(feature = "std")]
@@ -13,15 +13,28 @@ use super::{

/// Record any item implementing [Serialize](Serialize) and [DeserializeOwned](DeserializeOwned).
pub trait Recorder: Send + Sync + core::default::Default + core::fmt::Debug + Clone {
    /// Type of the settings used by the recorder.
    type Settings: PrecisionSettings;

    /// Arguments used to record objects.
    type RecordArgs: Clone;

    /// Record output type.
    type RecordOutput;

    /// Arguments used to load recorded objects.
    type LoadArgs: Clone;

-    /// Record an item with the given arguments.
+    /// Records an item.
    ///
    /// # Arguments
    ///
    /// * `record` - The item to record.
    /// * `args` - Arguments used to record the item.
    ///
    /// # Returns
    ///
    /// The output of the recording.
    fn record<R: Record>(
        &self,
        record: R,
@@ -80,11 +93,35 @@ pub trait Recorder: Send + Sync + core::default::Default + core::fmt::Debug + Cl
        Ok(R::from_item(item.item))
    }

    /// Saves an item.
    ///
    /// This method is used by [record](Recorder::record) to save the item.
    ///
    /// # Arguments
    ///
    /// * `item` - Item to save.
    /// * `args` - Arguments to use to save the item.
    ///
    /// # Returns
    ///
    /// The output of the save operation.
    fn save_item<I: Serialize>(
        &self,
        item: I,
        args: Self::RecordArgs,
    ) -> Result<Self::RecordOutput, RecorderError>;

    /// Loads an item.
    ///
    /// This method is used by [load](Recorder::load) to load the item.
    ///
    /// # Arguments
    ///
    /// * `args` - Arguments to use to load the item.
    ///
    /// # Returns
    ///
    /// The loaded item.
    fn load_item<I: DeserializeOwned>(&self, args: Self::LoadArgs) -> Result<I, RecorderError>;
}
@@ -98,9 +135,13 @@ fn recorder_metadata<R: Recorder>() -> BurnMetadata {
    )
}

/// Error that can occur when using a [Recorder](Recorder).
#[derive(Debug)]
pub enum RecorderError {
    /// File not found.
    FileNotFound(String),

    /// Other error.
    Unknown(String),
}
@@ -118,22 +159,45 @@ pub(crate) fn bin_config() -> bincode::config::Configuration {
    bincode::config::standard()
}

/// Metadata of a record.
#[derive(new, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct BurnMetadata {
    /// Float type used to record the item.
    pub float: String,

    /// Int type used to record the item.
    pub int: String,

    /// Format used to record the item.
    pub format: String,

    /// Burn record version used to record the item.
    pub version: String,

    /// Settings used to record the item.
    pub settings: String,
}

/// Record that can be saved by a [Recorder](Recorder).
#[derive(Serialize, Deserialize)]
pub struct BurnRecord<I> {
    /// Metadata of the record.
    pub metadata: BurnMetadata,

    /// Item to record.
    pub item: I,
}

impl<I> BurnRecord<I> {
    /// Creates a new record.
    ///
    /// # Arguments
    ///
    /// * `item` - Item to record.
    ///
    /// # Returns
    ///
    /// The new record.
    pub fn new<R: Recorder>(item: I) -> Self {
        let metadata = recorder_metadata::<R>();

@@ -141,8 +205,10 @@ impl<I> BurnRecord<I> {
    }
}

/// Record that can be saved by a [Recorder](Recorder) without the item.
#[derive(new, Debug, Serialize, Deserialize)]
pub struct BurnRecordNoItem {
    /// Metadata of the record.
    pub metadata: BurnMetadata,
}
@@ -5,7 +5,10 @@ use serde::{de::DeserializeOwned, Serialize};
pub trait PrecisionSettings:
    Send + Sync + core::fmt::Debug + core::default::Default + Clone
{
    /// Float element type.
    type FloatElem: Element + Serialize + DeserializeOwned;

    /// Integer element type.
    type IntElem: Element + Serialize + DeserializeOwned;
}
@@ -12,7 +12,9 @@ type MappedDataset = MapperDataset<SqliteDataset<SpeechItemRaw>, ConvertSamples,

/// Enum representing speech command classes in the Speech Commands dataset.
/// Class names are based on the Speech Commands dataset from Huggingface.
-/// See: https://huggingface.co/datasets/speech_commands
+/// See [speech_commands](https://huggingface.co/datasets/speech_commands)
+/// for more information.
#[allow(missing_docs)]
#[derive(Debug, Display, Clone, Copy, FromRepr, Serialize, Deserialize, EnumCount)]
pub enum SpeechCommandClass {
    // Target command words
@@ -66,8 +68,13 @@ pub enum SpeechCommandClass {
/// Struct containing raw speech data returned from a database.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SpeechItemRaw {
    /// Audio file bytes.
    pub audio_bytes: Vec<u8>,

    /// Label index.
    pub label: usize,

    /// Indicates if the label is unknown.
    pub is_unknown: bool,
}
@@ -4,11 +4,18 @@ use crate::DatasetIterator;

/// The dataset trait defines a basic collection of items with a predefined size.
pub trait Dataset<I>: Send + Sync {
    /// Gets the item at the given index.
    fn get(&self, index: usize) -> Option<I>;

    /// Gets the number of items in the dataset.
    fn len(&self) -> usize;

    /// Checks if the dataset is empty.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Returns an iterator over the dataset.
    fn iter(&self) -> DatasetIterator<'_, I>
    where
        Self: Sized,
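A hedged sketch of a minimal implementation of this trait (the `Squares` wrapper type is hypothetical, and the standalone `burn_dataset` crate path is used; `is_empty` and `iter` come for free as default methods):

```rust
use burn_dataset::Dataset;

/// Hypothetical dataset that serves the squares of 0..size.
struct Squares {
    size: usize,
}

impl Dataset<usize> for Squares {
    fn get(&self, index: usize) -> Option<usize> {
        (index < self.size).then(|| index * index)
    }

    fn len(&self) -> usize {
        self.size
    }
}

fn demo() {
    let dataset = Squares { size: 4 };
    // `iter` is provided by the trait's default method.
    let items: Vec<usize> = dataset.iter().collect();
    assert_eq!(items, vec![0, 1, 4, 9]);
}
```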
@@ -7,6 +7,7 @@ pub struct FakeDataset<I> {
}

impl<I: Dummy<Faker>> FakeDataset<I> {
    /// Create a new fake dataset with the given size.
    pub fn new(size: usize) -> Self {
        let mut items = Vec::with_capacity(size);
        for _ in 0..size {
@@ -14,6 +14,7 @@ pub struct InMemDataset<I> {
}

impl<I> InMemDataset<I> {
    /// Creates a new in memory dataset from the given items.
    pub fn new(items: Vec<I>) -> Self {
        InMemDataset { items }
    }
@@ -8,6 +8,7 @@ pub struct DatasetIterator<'a, I> {
}

impl<'a, I> DatasetIterator<'a, I> {
    /// Creates a new dataset iterator.
    pub fn new<D>(dataset: &'a D) -> Self
    where
        D: Dataset<I>,
@@ -21,28 +21,37 @@ use sanitize_filename::sanitize;
use serde::{de::DeserializeOwned, Serialize};
use serde_rusqlite::{columns_from_statement, from_row_with_columns};

/// Result type for the sqlite dataset.
pub type Result<T> = core::result::Result<T, SqliteDatasetError>;

/// Sqlite dataset error.
#[derive(thiserror::Error, Debug)]
pub enum SqliteDatasetError {
    /// IO related error.
    #[error("IO error: {0}")]
    Io(#[from] io::Error),

    /// Sql related error.
    #[error("Sql error: {0}")]
    Sql(#[from] serde_rusqlite::rusqlite::Error),

    /// Serde related error.
    #[error("Serde error: {0}")]
    Serde(#[from] rmp_serde::encode::Error),

    /// The database file already exists error.
    #[error("Overwrite flag is set to false and the database file already exists: {0}")]
    FileExists(PathBuf),

    /// Error when creating the connection pool.
    #[error("Failed to create connection pool: {0}")]
    ConnectionPool(#[from] r2d2::Error),

    /// Error when persisting the temporary database file.
    #[error("Could not persist the temporary database file: {0}")]
    PersistDbFile(#[from] persist::Error<Writable>),

    /// Any other error.
    #[error("{0}")]
    Other(&'static str),
}
@@ -1,11 +1,21 @@
#![warn(missing_docs)]

//! # Burn Dataset
//!
//! Burn Dataset is a library for creating and loading datasets.

#[macro_use]
extern crate derive_new;

extern crate dirs;

/// Sources for datasets.
pub mod source;

/// Transformations to be used with datasets.
pub mod transform;

/// Audio datasets.
#[cfg(feature = "audio")]
pub mod audio;
@@ -11,13 +11,18 @@ use thiserror::Error;
const PYTHON: &str = "python3";
const PYTHON_SOURCE: &str = include_str!("importer.py");

/// Error type for [HuggingfaceDatasetLoader](HuggingfaceDatasetLoader).
#[derive(Error, Debug)]
pub enum ImporterError {
    /// Unknown error.
    #[error("unknown: `{0}`")]
    Unknown(String),

    /// Fail to download python dependencies.
    #[error("fail to download python dependencies: `{0}`")]
    FailToDownloadPythonDependencies(String),

    /// Fail to create sqlite dataset.
    #[error("sqlite dataset: `{0}`")]
    SqliteDataset(#[from] SqliteDatasetError),
}
@@ -8,9 +8,13 @@ use serde::{Deserialize, Serialize};
const WIDTH: usize = 28;
const HEIGHT: usize = 28;

/// MNIST item.
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct MNISTItem {
    /// Image as a 2D array of floats.
    pub image: [[f32; WIDTH]; HEIGHT],

    /// Label of the image.
    pub label: usize,
}
@@ -66,10 +70,12 @@ impl Dataset<MNISTItem> for MNISTDataset {
}

impl MNISTDataset {
    /// Creates a new train dataset.
    pub fn train() -> Self {
        Self::new("train")
    }

    /// Creates a new test dataset.
    pub fn test() -> Self {
        Self::new("test")
    }
@@ -1,4 +1,4 @@
-pub mod downloader;
+pub(crate) mod downloader;
mod mnist;

pub use downloader::*;
@@ -1 +1,2 @@
/// Huggingface source
pub mod huggingface;
@@ -3,6 +3,7 @@ use std::marker::PhantomData;

/// Basic mapper trait to be used with the [mapper dataset](MapperDataset).
pub trait Mapper<I, O>: Send + Sync {
    /// Maps an item of type I to an item of type O.
    fn map(&self, item: &I) -> O;
}
@@ -14,6 +14,7 @@ impl<D, I> PartialDataset<D, I>
where
    D: Dataset<I>,
{
    /// Splits a dataset into multiple partial datasets.
    pub fn split(dataset: D, num: usize) -> Vec<PartialDataset<Arc<D>, I>> {
        let dataset = Arc::new(dataset); // cheap cloning.
@@ -14,6 +14,7 @@ impl<D, I> ShuffledDataset<D, I>
where
    D: Dataset<I>,
{
    /// Creates a new shuffled dataset.
    pub fn new(dataset: D, rng: &mut StdRng) -> Self {
        let mut indexes = Vec::with_capacity(dataset.len());
        for i in 0..dataset.len() {
@@ -28,6 +29,7 @@ where
    }
}

    /// Creates a new shuffled dataset with a fixed seed.
    pub fn with_seed(dataset: D, seed: u64) -> Self {
        let mut rng = StdRng::seed_from_u64(seed);
        Self::new(dataset, &mut rng)
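A hedged sketch combining the documented transforms (crate-level paths assumed from the `burn-dataset` module declarations in this commit; the integer items are placeholders):

```rust
use burn_dataset::transform::{PartialDataset, ShuffledDataset};
use burn_dataset::{Dataset, InMemDataset};

fn transforms() {
    let dataset = InMemDataset::new((0..100).collect::<Vec<i32>>());
    // Deterministic shuffle: the same seed yields the same permutation.
    let shuffled = ShuffledDataset::with_seed(dataset, 42);
    // Split the shuffled dataset into 4 roughly equal partial views.
    let parts = PartialDataset::split(shuffled, 4);
    assert_eq!(parts.len(), 4);
}
```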
@@ -17,6 +17,7 @@ where
    D: Dataset<I>,
    I: Send + Sync,
{
    /// Creates a new sampler dataset.
    pub fn new(dataset: D, size: usize) -> Self {
        let rng = Mutex::new(StdRng::from_entropy());

@@ -28,6 +29,7 @@ where
    }
}

    /// Generates random index using uniform distribution (0, dataset.len()).
    fn index(&self) -> usize {
        let mut rng = self.rng.lock().unwrap();
        rng.sample(Uniform::new(0, self.dataset.len()))
@@ -190,6 +190,7 @@ impl ConfigAnalyzer for ConfigStructAnalyzer {
    }

    let body = quote! {
        /// Create a new instance of the config.
        pub fn new(
            #(#names),*
        ) -> Self {
@@ -208,6 +209,7 @@ impl ConfigAnalyzer for ConfigStructAnalyzer {
    let fn_name = Ident::new(&format!("with_{name}"), name.span());

    body.extend(quote! {
        /// Set the default value for the field.
        pub fn #fn_name(mut self, #name: #ty) -> Self {
            self.#name = #name;
            self
@@ -221,6 +223,7 @@ impl ConfigAnalyzer for ConfigStructAnalyzer {
    let fn_name = Ident::new(&format!("with_{name}"), name.span());

    body.extend(quote! {
        /// Set the default value for the field.
        pub fn #fn_name(mut self, #name: #ty) -> Self {
            self.#name = #name;
            self
@@ -1,3 +1,7 @@
#![warn(missing_docs)]

//! The derive crate of Burn.

use proc_macro::TokenStream;

pub(crate) mod config;
@@ -9,18 +13,21 @@ use config::config_attr_impl;
use module::module_derive_impl;
use record::record_derive_impl;

/// Derive macro for the module.
#[proc_macro_derive(Module)]
pub fn module_derive(input: TokenStream) -> TokenStream {
    let input = syn::parse(input).unwrap();
    module_derive_impl(&input)
}

/// Derive macro for the record.
#[proc_macro_derive(Record)]
pub fn record_derive(input: TokenStream) -> TokenStream {
    let input = syn::parse(input).unwrap();
    record_derive_impl(&input)
}

/// Derive macro for the config.
#[proc_macro_derive(Config, attributes(config))]
pub fn config_derive(input: TokenStream) -> TokenStream {
    let item = syn::parse(input).unwrap();
@@ -26,6 +26,7 @@ impl ModuleRecordGenerator {
            let name = &field.field.ident;

            fields.extend(quote! {
                /// The #name field.
                pub #name: <#ty as burn::module::Module<B>>::Record,
            });
        }
@@ -33,6 +34,8 @@ impl ModuleRecordGenerator {
        let generics = &self.generics;

        quote! {

            /// The record type for the module.
            #[derive(burn::record::Record, Debug, Clone)]
            pub struct #name #generics {
                #fields
@@ -31,6 +31,7 @@ impl RecordGenerator {
            let name = &field.field.ident;

            fields.extend(quote! {
                /// The #name field.
                pub #name: <#ty as burn::record::Record>::Item<S>,
            });
            bounds.extend(quote!{
@@ -42,6 +43,8 @@ impl RecordGenerator {
        let bound = bounds.to_string();

        quote! {

            /// The record item type for the module.
            #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
            #[serde(bound = #bound)]
            pub struct #name #generics {
@@ -1,7 +1,15 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]

//! # Burn
//! This library strives to serve as a comprehensive **deep learning framework**,
//! offering exceptional flexibility and written in Rust. The main objective is to cater
//! to both researchers and practitioners by simplifying the process of experimenting,
//! training, and deploying models.

pub use burn_core::*;

/// Train module
#[cfg(feature = "train")]
pub mod train {
    pub use burn_train::*;
@@ -5,9 +5,7 @@
// the TextClassificationDataset trait. These implementations are designed to be used
// with a machine learning framework for tasks such as training a text classification model.

-use burn::data::dataset::{
-    source::huggingface::downloader::HuggingfaceDatasetLoader, Dataset, SqliteDataset,
-};
+use burn::data::dataset::{source::huggingface::HuggingfaceDatasetLoader, Dataset, SqliteDataset};

// Define a struct for text classification items
#[derive(new, Clone, Debug)]
@@ -1,6 +1,4 @@
-use burn::data::dataset::{
-    source::huggingface::downloader::HuggingfaceDatasetLoader, Dataset, SqliteDataset,
-};
+use burn::data::dataset::{source::huggingface::HuggingfaceDatasetLoader, Dataset, SqliteDataset};

#[derive(new, Clone, Debug)]
pub struct TextGenerationItem {