Add missing documents (#424)

This commit is contained in:
Dilshod Tadjibaev 2023-06-23 08:28:34 -05:00 committed by GitHub
parent eda241f8cf
commit 825aaa9977
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
55 changed files with 3156 additions and 11 deletions

View File

@ -1,3 +1,4 @@
/// The graph module.
pub mod graph;
pub(crate) mod node;

View File

@ -1,6 +1,7 @@
use proc_macro2::TokenStream;
use rust_format::{Config, Edition, Formatter, PostProcess, RustFmt};
/// Formats a token stream into a string.
pub fn format_tokens(tokens: TokenStream) -> String {
let fmt = code_formatter();

View File

@ -1,13 +1,21 @@
#![warn(missing_docs)]
#![allow(clippy::ptr_arg)]
#![allow(clippy::single_match)]
#![allow(clippy::upper_case_acronyms)]
//! `burn-import` is a crate designed to simplify the process of importing models trained in other
//! machine learning frameworks into the Burn framework. This tool generates a Rust source file that
//! aligns the imported model with Burn's model and converts tensor data into a format compatible with
//! Burn.
#[macro_use]
extern crate derive_new;
/// The onnx module.
#[cfg(feature = "onnx")]
pub mod onnx;
/// The module for generating the burn code.
pub mod burn;
mod formater;

View File

@ -72,13 +72,25 @@ pub enum TensorData {
Bool(Vec<bool>),
}
/// ONNX graph representation.
///
/// Bundles the nodes, inputs, outputs and states of a graph together with two
/// `String`-to-`String` maps relating to the original node and input names.
#[derive(Debug, Clone)]
pub struct ONNXGraph {
/// The nodes of the graph.
pub nodes: Vec<Node>,
/// The inputs of the graph.
pub inputs: Vec<Argument>,
/// The outputs of the graph.
pub outputs: Vec<Argument>,
/// The states of the graph.
pub states: Vec<State>,
/// The original node names.
///
/// NOTE(review): presumably maps between the original and the renamed node
/// names; the key/value direction is not visible here — confirm against the code
/// that populates it.
pub old_node_names: HashMap<String, String>,
/// The original input names.
///
/// NOTE(review): presumably maps between the original and the renamed input
/// names; the key/value direction is not visible here — confirm against the code
/// that populates it.
pub old_input_names: HashMap<String, String>,
}

View File

@ -50,6 +50,7 @@ pub struct ModelGen {
}
impl ModelGen {
/// Create a new `ModelGen`.
pub fn new() -> Self {
init_log().ok(); // Error when init multiple times are ignored.
Self::default()
@ -143,6 +144,7 @@ impl ModelGen {
}
impl ONNXGraph {
/// Converts ONNX graph to Burn graph.
pub fn into_burn<PS: PrecisionSettings + 'static>(self) -> BurnGraph<PS> {
let mut graph = BurnGraph::<PS>::default();

View File

@ -16,8 +16,10 @@ use burn_common::stub::Mutex;
pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
/// The device type for the ndarray backend.
///
/// `Cpu` is the only variant of this enum.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum NdArrayDevice {
/// The CPU device.
Cpu,
}
@ -27,6 +29,7 @@ impl Default for NdArrayDevice {
}
}
/// The ndarray backend.
#[derive(Clone, Copy, Default, Debug)]
pub struct NdArrayBackend<E> {
phantom: PhantomData<E>,

View File

@ -1,4 +1,7 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
//! Burn ndarray backend.
#[macro_use]
extern crate derive_new;

View File

@ -1,3 +1,4 @@
/// Macro for running a function in parallel.
#[macro_export(local_inner_macros)]
macro_rules! run_par {
(
@ -16,6 +17,7 @@ macro_rules! run_par {
}};
}
/// Macro for iterating over a range in parallel.
#[macro_export(local_inner_macros)]
macro_rules! iter_par {
(

View File

@ -32,6 +32,7 @@ mod utils {
}
}
/// Converts a slice of usize to a typed dimension.
#[macro_export(local_inner_macros)]
macro_rules! to_typed_dims {
(
@ -48,6 +49,7 @@ macro_rules! to_typed_dims {
}};
}
/// Reshapes an array into a tensor.
#[macro_export(local_inner_macros)]
macro_rules! reshape {
(

View File

@ -1,6 +1,7 @@
use proc_macro::TokenStream;
use quote::{format_ident, quote};
#[allow(missing_docs)]
#[proc_macro_attribute]
pub fn testgen(attr: TokenStream, item: TokenStream) -> TokenStream {
let item: proc_macro2::TokenStream = proc_macro2::TokenStream::from(item);

View File

@ -1,4 +1,8 @@
#![cfg_attr(not(feature = "std"), no_std)]
#![warn(missing_docs)]
//! This library provides multiple tensor implementations hidden behind an easy to use API
//! that supports reverse mode automatic differentiation.
#[macro_use]
extern crate derive_new;
@ -8,6 +12,7 @@ extern crate alloc;
mod tensor;
#[cfg(feature = "export_tests")]
#[allow(missing_docs)]
mod tests;
pub use half::{bf16, f16};

View File

@ -8,6 +8,7 @@ use crate::{
backend::Backend, check, check::TensorCheck, Bool, Data, Float, Int, Shape, TensorKind,
};
/// A tensor with a given backend, shape and data type.
#[derive(new, Clone, Debug)]
pub struct Tensor<B, const D: usize, K = Float>
where
@ -401,43 +402,284 @@ where
///
/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
pub trait BasicOps<B: Backend>: TensorKind<B> {
/// The type of the tensor elements.
type Elem: 'static;
/// Creates an empty tensor with the given shape.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `device` - The device on which the tensor will be allocated.
///
/// # Returns
///
/// The empty tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For creating empty tensors, users should prefer the [Tensor::empty](Tensor::empty) function,
/// which is more high-level and designed for public use.
fn empty<const D: usize>(shape: Shape<D>, device: &B::Device) -> Self::Primitive<D>;
/// Returns the shape of the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The shape of the tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the shape of a tensor, users should prefer the [Tensor::shape](Tensor::shape) function,
/// which is more high-level and designed for public use.
fn shape<const D: usize>(tensor: &Self::Primitive<D>) -> Shape<D>;
/// Reshapes the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `shape` - The new shape of the tensor.
///
/// # Returns
///
/// The reshaped tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For reshaping a tensor, users should prefer the [Tensor::reshape](Tensor::reshape) function,
/// which is more high-level and designed for public use.
fn reshape<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
shape: Shape<D2>,
) -> Self::Primitive<D2>;
/// Selects the tensor elements corresponding to the given indexes.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `indexes` - The indexes of the elements to select, given as an array of `D2` ranges.
///
/// # Returns
///
/// The selected elements, as a primitive tensor with the same number of dimensions (`D1`)
/// as the input.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For selecting elements of a tensor, users should prefer the [Tensor::index](Tensor::index) function,
/// which is more high-level and designed for public use.
fn index<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
indexes: [Range<usize>; D2],
) -> Self::Primitive<D1>;
/// Assigns the given value to the tensor elements corresponding to the given indexes.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `indexes` - The indexes of the elements to assign, given as an array of `D2` ranges.
/// * `value` - The tensor holding the values to assign.
///
/// # Returns
///
/// The tensor with the assigned values.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For assigning values to elements of a tensor, users should prefer the [Tensor::index_assign](Tensor::index_assign) function,
/// which is more high-level and designed for public use.
fn index_assign<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
indexes: [Range<usize>; D2],
value: Self::Primitive<D1>,
) -> Self::Primitive<D1>;
/// Returns the device on which the tensor is allocated.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The device on which the tensor is allocated.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the device of a tensor, users should prefer the [Tensor::device](Tensor::device) function,
/// which is more high-level and designed for public use.
fn device<const D: usize>(tensor: &Self::Primitive<D>) -> B::Device;
/// Moves the tensor to the given device.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `device` - The device on which the tensor will be moved.
///
/// # Returns
///
/// The tensor on the given device.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For moving a tensor to a device, users should prefer the [Tensor::to_device](Tensor::to_device) function,
/// which is more high-level and designed for public use.
fn to_device<const D: usize>(
tensor: Self::Primitive<D>,
device: &B::Device,
) -> Self::Primitive<D>;
/// Extracts the data from the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The data of the tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For extracting the data of a tensor, users should prefer the [Tensor::into_data](Tensor::into_data) function,
/// which is more high-level and designed for public use.
fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Data<Self::Elem, D>;
/// Creates a tensor from the given data.
///
/// # Arguments
///
/// * `data` - The data of the tensor.
/// * `device` - The device on which the tensor will be allocated.
///
/// # Returns
///
/// The tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For creating a tensor from data, users should prefer the [Tensor::from_data](Tensor::from_data) function,
/// which is more high-level and designed for public use.
fn from_data<const D: usize>(
data: Data<Self::Elem, D>,
device: &B::Device,
) -> Self::Primitive<D>;
/// Repeat the tensor along the given dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `dim` - The dimension along which the tensor will be repeated.
/// * `times` - The number of times the tensor will be repeated.
///
/// # Returns
///
/// The repeated tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For repeating a tensor, users should prefer the [Tensor::repeat](Tensor::repeat) function,
/// which is more high-level and designed for public use.
fn repeat<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
times: usize,
) -> Self::Primitive<D>;
/// Concatenates the given tensors along the given dimension.
///
/// # Arguments
///
/// * `vectors` - The tensors to concatenate.
/// * `dim` - The dimension along which the tensors will be concatenated.
///
/// # Returns
///
/// The concatenated tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For concatenating tensors, users should prefer the [Tensor::cat](Tensor::cat) function,
/// which is more high-level and designed for public use.
fn cat<const D: usize>(vectors: Vec<Self::Primitive<D>>, dim: usize) -> Self::Primitive<D>;
/// Equates the given tensors.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The tensor of booleans indicating whether the corresponding elements are equal.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For equating tensors, users should prefer the [Tensor::equal](Tensor::equal) function,
/// which is more high-level and designed for public use.
fn equal<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>,
) -> Tensor<B, D, Bool>;
/// Returns the name of the element type.
///
/// # Returns
///
/// The string produced by [`core::any::type_name`] for `Self::Elem`.
fn elem_type_name() -> &'static str {
    use core::any::type_name;

    type_name::<Self::Elem>()
}

View File

@ -13,10 +13,12 @@ impl<const D: usize, B> Tensor<B, D>
where
B: Backend,
{
/// Converts the tensor into a primitive tensor.
pub fn into_primitive(self) -> B::TensorPrimitive<D> {
self.primitive
}
/// Converts from a primitive tensor into a tensor.
pub fn from_primitive(tensor: B::TensorPrimitive<D>) -> Self {
Self::new(tensor)
}
@ -261,6 +263,7 @@ where
}
impl<const D: usize, B: ADBackend> Tensor<B, D> {
/// Runs the backward pass starting from this tensor.
///
/// The tensor itself is only borrowed; a clone of its primitive is handed to the backend.
///
/// # Returns
///
/// The gradients produced by the backend.
pub fn backward(&self) -> B::Gradients {
    let primitive = self.primitive.clone();
    B::backward::<D>(primitive)
}
@ -279,10 +282,20 @@ impl<const D: usize, B: ADBackend> Tensor<B, D> {
B::grad_remove(&self.primitive, grads).map(Tensor::new)
}
/// Consumes the tensor and returns the corresponding tensor on the inner backend,
/// without the autodiff information.
///
/// # Returns
///
/// The tensor on the inner backend.
pub fn inner(self) -> Tensor<B::InnerBackend, D> {
    let inner_primitive = B::inner(self.primitive);
    Tensor::new(inner_primitive)
}
/// Converts a tensor of the inner backend into a tensor of the autodiff backend.
///
/// # Arguments
///
/// * `inner` - The inner-backend tensor to convert.
///
/// # Returns
///
/// The tensor converted to the autodiff backend.
pub fn from_inner(inner: Tensor<B::InnerBackend, D>) -> Self {
    let primitive = B::from_inner(inner.primitive);
    Self::new(primitive)
}

View File

@ -1,14 +1,23 @@
use crate::backend::Backend;
/// A type-level representation of the kind of a float tensor.
#[derive(Clone, Debug)]
pub struct Float;
/// A type-level representation of the kind of an int tensor.
#[derive(Clone, Debug)]
pub struct Int;
/// A type-level representation of the kind of a bool tensor.
///
/// This is a unit marker type; it carries no data of its own.
#[derive(Clone, Debug)]
pub struct Bool;
/// A type-level representation of the kind of a tensor.
///
/// Implementors associate a kind with a backend-specific primitive type and a name.
pub trait TensorKind<B: Backend>: Clone + core::fmt::Debug {
/// The primitive type of the tensor, parameterized by the const generic `D`.
type Primitive<const D: usize>: Clone + core::fmt::Debug;
/// The name of the tensor kind.
///
/// # Returns
///
/// A static string naming the kind.
fn name() -> &'static str;
}

View File

@ -397,103 +397,901 @@ pub trait Numeric<B: Backend>: BasicOps<B>
where
Self::Elem: Element,
{
/// Adds two tensors together.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The sum of the two tensors.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For adding tensors, users should prefer the [Tensor::add](Tensor::add) function,
/// which is more high-level and designed for public use.
fn add<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Primitive<D>) -> Self::Primitive<D>;
/// Adds a scalar to a tensor element-wise.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The sum of the tensor and the scalar.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For adding a scalar to a tensor, users should prefer the [Tensor::add_scalar](Tensor::add_scalar) function,
/// which is more high-level and designed for public use.
fn add_scalar<const D: usize, E: ElementConversion>(
lhs: Self::Primitive<D>,
rhs: E,
) -> Self::Primitive<D>;
/// Subtracts two tensors.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The difference of the two tensors.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For subtracting tensors, users should prefer the [Tensor::sub](Tensor::sub) function,
/// which is more high-level and designed for public use.
fn sub<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Primitive<D>) -> Self::Primitive<D>;
/// Subtracts a scalar from a tensor element-wise.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The difference of the tensor and the scalar.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For subtracting a scalar from a tensor, users should prefer the [Tensor::sub_scalar](Tensor::sub_scalar) function,
/// which is more high-level and designed for public use.
fn sub_scalar<const D: usize, E: ElementConversion>(
lhs: Self::Primitive<D>,
rhs: E,
) -> Self::Primitive<D>;
/// Divides two tensors.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The quotient of the two tensors.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For dividing tensors, users should prefer the [Tensor::div](Tensor::div) function,
/// which is more high-level and designed for public use.
fn div<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Primitive<D>) -> Self::Primitive<D>;
/// Divides a tensor by a scalar element-wise.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The quotient of the tensor and the scalar.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For dividing a tensor by a scalar, users should prefer the [Tensor::div_scalar](Tensor::div_scalar) function,
/// which is more high-level and designed for public use.
fn div_scalar<const D: usize, E: ElementConversion>(
lhs: Self::Primitive<D>,
rhs: E,
) -> Self::Primitive<D>;
/// Multiplies two tensors.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The product of the two tensors.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For multiplying tensors, users should prefer the [Tensor::mul](Tensor::mul) function,
/// which is more high-level and designed for public use.
fn mul<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Primitive<D>) -> Self::Primitive<D>;
/// Multiplies a tensor by a scalar element-wise.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The product of the tensor and the scalar.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For multiplying a tensor by a scalar, users should prefer the [Tensor::mul_scalar](Tensor::mul_scalar) function,
/// which is more high-level and designed for public use.
fn mul_scalar<const D: usize, E: ElementConversion>(
lhs: Self::Primitive<D>,
rhs: E,
) -> Self::Primitive<D>;
/// Negates a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to negate.
///
/// # Returns
///
/// The negated tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For negating a tensor, users should prefer the [Tensor::neg](Tensor::neg) function,
/// which is more high-level and designed for public use.
fn neg<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<D>;
/// Creates a tensor filled with zeros.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `device` - The device on which the tensor will be allocated.
///
/// # Returns
///
/// The tensor filled with zeros.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For creating a tensor filled with zeros, users should prefer the [Tensor::zeros](Tensor::zeros) function,
/// which is more high-level and designed for public use.
fn zeros<const D: usize>(shape: Shape<D>, device: &B::Device) -> Self::Primitive<D>;
/// Creates a tensor filled with ones.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `device` - The device on which the tensor will be allocated.
///
/// # Returns
///
/// The tensor filled with ones.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For creating a tensor filled with ones, users should prefer the [Tensor::ones](Tensor::ones) function,
/// which is more high-level and designed for public use.
fn ones<const D: usize>(shape: Shape<D>, device: &B::Device) -> Self::Primitive<D>;
/// Sums all the elements of the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to sum.
///
/// # Returns
///
/// The sum of all the elements of the tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For summing all the elements of a tensor, users should prefer the [Tensor::sum](Tensor::sum) function,
/// which is more high-level and designed for public use.
fn sum<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1>;
/// Sums all the elements of the tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to sum.
/// * `dim` - The dimension along which to sum.
///
/// # Returns
///
/// The sum of all the elements of the tensor along the specified dimension.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For summing all the elements of a tensor along a dimension, users should prefer the [Tensor::sum_dim](Tensor::sum_dim) function,
/// which is more high-level and designed for public use.
fn sum_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;
/// Computes the mean of all the elements of the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the mean of.
///
/// # Returns
///
/// The mean of all the elements of the tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For computing the mean of all the elements of a tensor, users should prefer the [Tensor::mean](Tensor::mean) function,
/// which is more high-level and designed for public use.
fn mean<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1>;
/// Computes the mean of all the elements of the tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the mean of.
/// * `dim` - The dimension along which to compute the mean.
///
/// # Returns
///
/// The mean of all the elements of the tensor along the specified dimension.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For computing the mean of all the elements of a tensor along a dimension, users should prefer
/// the [Tensor::mean_dim](Tensor::mean_dim) function, which is more high-level and designed for public use.
fn mean_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;
/// Element-wise equality comparison between a tensor and a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor with the same shape as the input tensor, where each element is true if the
/// corresponding element of the left hand side tensor is equal to the right hand side scalar,
/// and false otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For element-wise equality between a tensor and a scalar, users should prefer the
/// [Tensor::equal_elem](Tensor::equal_elem) function, which is more high-level and designed for public use.
fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool>;
/// Element-wise greater than comparison between two tensors.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// A boolean tensor with the same shape as the input tensors, where each element is true if the
/// corresponding element of the left hand side tensor is greater than the corresponding element
/// of the right hand side tensor, and false otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For element-wise greater than comparison between two tensors, users should prefer the [Tensor::greater](Tensor::greater) function,
/// which is more high-level and designed for public use.
fn greater<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>,
) -> Tensor<B, D, Bool>;
/// Element-wise greater than comparison between a tensor and a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor with the same shape as the input tensor, where each element is true if the
/// corresponding element of the left hand side tensor is greater than the right hand side
/// scalar, and false otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For element-wise greater than comparison between a tensor and a scalar, users should prefer
/// the [Tensor::greater_elem](Tensor::greater_elem) function, which is more high-level and designed for public use.
fn greater_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem)
-> Tensor<B, D, Bool>;
/// Element-wise greater than or equal comparison between two tensors.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// A boolean tensor with the same shape as the input tensors, where each element is true if the
/// corresponding element of the left hand side tensor is greater than or equal to the
/// corresponding element of the right hand side tensor, and false otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For element-wise greater than or equal comparison between two tensors, users should prefer
/// the [Tensor::greater_equal](Tensor::greater_equal) function, which is more high-level and designed for public use.
fn greater_equal<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>,
) -> Tensor<B, D, Bool>;
/// Element-wise greater than or equal comparison between a tensor and a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor with the same shape as the input tensor, where each element is true if the
/// corresponding element of the left hand side tensor is greater than or equal to the right
/// hand side scalar, and false otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For element-wise greater than or equal comparison between a tensor and a scalar, users should prefer
/// the [Tensor::greater_equal_elem](Tensor::greater_equal_elem) function, which is more high-level and designed for public use.
fn greater_equal_elem<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Elem,
) -> Tensor<B, D, Bool>;
/// Element-wise less than comparison between two tensors.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// A boolean tensor with the same shape as the input tensors, where each element is true if the
/// corresponding element of the left hand side tensor is less than the corresponding element of
/// the right hand side tensor, and false otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For element-wise less than comparison between two tensors, users should prefer the [Tensor::lower](Tensor::lower) function,
/// which is more high-level and designed for public use.
fn lower<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>,
) -> Tensor<B, D, Bool>;
/// Element-wise less than comparison between a tensor and a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor with the same shape as the input tensor, where each element is true if the
/// corresponding element of the left hand side tensor is less than the right hand side scalar,
/// and false otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For element-wise less than comparison between a tensor and a scalar, users should prefer
/// the [Tensor::lower_elem](Tensor::lower_elem) function, which is more high-level and designed for public use.
fn lower_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool>;
/// Element-wise less than or equal comparison between two tensors.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// A boolean tensor with the same shape as the input tensors, where each element is true if the
/// corresponding element of the left hand side tensor is less than or equal to the corresponding
/// element of the right hand side tensor, and false otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For element-wise less than or equal comparison between two tensors, users should prefer
/// the [Tensor::lower_equal](Tensor::lower_equal) function, which is more high-level and designed for public use.
fn lower_equal<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>,
) -> Tensor<B, D, Bool>;
/// Element-wise less than or equal comparison between a tensor and a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor with the same shape as the input tensor, where each element is true if the
/// corresponding element of the left hand side tensor is less than or equal to the right hand
/// side scalar, and false otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For element-wise less than or equal comparison between a tensor and a scalar, users should prefer
/// the [Tensor::lower_equal_elem](Tensor::lower_equal_elem) function, which is more high-level and designed for public use.
fn lower_equal_elem<const D: usize>(
lhs: Self::Primitive<D>,
rhs: Self::Elem,
) -> Tensor<B, D, Bool>;
/// Selects elements from a tensor based on a boolean mask.
///
/// # Arguments
///
/// * `tensor` - The tensor to select elements from if the corresponding element of the mask is true.
/// * `mask` - The boolean mask to use for selecting elements.
/// * `source` - The tensor to select elements from when the corresponding element of the mask is false.
///
/// # Returns
///
/// A tensor with the same shape as the input tensors, where each element is taken from the
/// corresponding element of `tensor` if the corresponding element of the mask is true, and
/// from the corresponding element of `source` otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For selecting elements from a tensor based on a boolean mask, users should prefer the
/// [Tensor::mask_where](Tensor::mask_where) function, which is more high-level and designed for public use.
// NOTE(review): the sibling `mask_fill` overwrites elements where the mask is true; confirm
// against the backend implementations that `mask_where` keeps `tensor` (rather than `source`)
// where the mask is true, and adjust the documentation above if it is the other way around.
fn mask_where<const D: usize>(
tensor: Self::Primitive<D>,
mask: Tensor<B, D, Bool>,
source: Self::Primitive<D>,
) -> Self::Primitive<D>;
/// Fills elements of a tensor based on a boolean mask.
///
/// # Arguments
///
/// * `tensor` - The tensor whose elements will be overwritten with the value
///              when the corresponding element of the mask is true.
/// * `mask` - The boolean mask to use for filling elements.
/// * `value` - The value to fill elements with when the corresponding element of the mask is true.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where each element is left unmodified
/// if the corresponding element of the mask is false, and is replaced by `value` otherwise.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For filling elements of a tensor based on a boolean mask, users should prefer the
/// [Tensor::mask_fill](Tensor::mask_fill) function, which is more high-level and designed for public use.
fn mask_fill<const D: usize>(
tensor: Self::Primitive<D>,
mask: Tensor<B, D, Bool>,
value: Self::Elem,
) -> Self::Primitive<D>;
/// Gathers elements from a tensor along an axis.
///
/// # Arguments
///
/// * `dim` - The axis along which to gather elements.
/// * `tensor` - The tensor to gather elements from.
/// * `indexes` - The indexes of the elements to gather.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where each element is taken from the
/// corresponding element of the input tensor at the corresponding index along the specified axis.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For gathering elements from a tensor along an axis, users should prefer the
/// [Tensor::gather](Tensor::gather) function, which is more high-level and designed for public use.
fn gather<const D: usize>(
dim: usize,
tensor: Self::Primitive<D>,
indexes: Tensor<B, D, Int>,
) -> Self::Primitive<D>;
/// Scatters elements into a tensor along an axis.
///
/// # Arguments
///
/// * `dim` - The axis along which to scatter elements.
/// * `tensor` - The tensor to scatter elements into.
/// * `indexes` - The indexes of the elements to scatter.
/// * `values` - The values to scatter into the tensor.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where each element is taken from the
/// corresponding element of the input tensor at the corresponding index along the specified axis,
/// except for the elements at the specified indexes, which are taken from the corresponding
/// element of the values tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For scattering elements into a tensor along an axis, users should prefer the [Tensor::scatter](Tensor::scatter) function,
/// which is more high-level and designed for public use.
fn scatter<const D: usize>(
dim: usize,
tensor: Self::Primitive<D>,
indexes: Tensor<B, D, Int>,
values: Self::Primitive<D>,
) -> Self::Primitive<D>;
/// Selects tensor elements along the given dimension corresponding to the given indexes.
///
/// # Arguments
///
/// * `tensor` - The tensor to select elements from.
/// * `dim` - The axis along which to select elements.
/// * `indexes` - The indexes of the elements to select.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where each element is taken from the
/// corresponding element of the input tensor at the corresponding index along the specified axis.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For selecting elements from a tensor along an axis, users should prefer the
/// [Tensor::index_select](Tensor::index_select) function, which is more high-level and designed for public use.
fn index_select<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
indexes: Tensor<B, 1, Int>,
) -> Self::Primitive<D>;
/// Assign the selected elements along the given dimension corresponding to the given indexes
/// from the value tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to assign elements to.
/// * `dim` - The axis along which to assign elements.
/// * `indexes` - The indexes of the elements to assign.
/// * `values` - The values to assign to the tensor.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where each element is taken from the
/// corresponding element of the input tensor at the corresponding index along the specified axis,
/// except for the elements at the specified indexes, which are taken from the corresponding
/// element of the values tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For assigning elements to a tensor along an axis, users should prefer the
/// [Tensor::index_select_assign](Tensor::index_select_assign) function, which is more high-level and designed for public use.
fn index_select_assign<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
dim: usize,
indexes: Tensor<B, 1, Int>,
values: Self::Primitive<D2>,
) -> Self::Primitive<D1>;
/// Gets the indexes of the maximum elements of a tensor along an axis.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the indexes of the maximum elements from.
/// * `dim` - The axis along which to get the indexes of the maximum elements.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where each element is the index of the
/// maximum element of the input tensor at the corresponding index along the specified axis.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the indexes of the maximum elements of a tensor along an axis, users should prefer the
/// [Tensor::argmax](Tensor::argmax) function, which is more high-level and designed for public use.
fn argmax<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> B::IntTensorPrimitive<D>;
/// Gets the indexes of the minimum elements of a tensor along an axis.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the indexes of the minimum elements from.
/// * `dim` - The axis along which to get the indexes of the minimum elements.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where each element is the index of the
/// minimum element of the input tensor at the corresponding index along the specified axis.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the indexes of the minimum elements of a tensor along an axis, users should prefer the
/// [Tensor::argmin](Tensor::argmin) function, which is more high-level and designed for public use.
fn argmin<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> B::IntTensorPrimitive<D>;
/// Gets the maximum element of a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum element from.
///
/// # Returns
///
/// A single-element tensor containing the maximum element of the input tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the maximum element of a tensor, users should prefer the
/// [Tensor::max](Tensor::max) function, which is more high-level and designed for public use.
fn max<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1>;
/// Gets the maximum elements of a tensor along an axis.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum elements from.
/// * `dim` - The axis along which to get the maximum elements.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where each element is the maximum element
/// of the input tensor at the corresponding index along the specified axis.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the maximum elements of a tensor along an axis, users should prefer the
/// [Tensor::max_dim](Tensor::max_dim) function, which is more high-level and designed for public use.
fn max_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;
/// Gets the maximum elements of a tensor along an axis, together with their indexes.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum elements from.
/// * `dim` - The axis along which to get the maximum elements.
///
/// # Returns
///
/// A tuple containing a tensor of the maximum elements of the input tensor along the specified
/// axis, and an integer tensor where each element is the index of the maximum element of the
/// input tensor at the corresponding index along the specified axis.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the maximum elements of a tensor along an axis, users should prefer the
/// [Tensor::max_dim_with_indexes](Tensor::max_dim_with_indexes) function, which is more high-level and designed for public use.
fn max_dim_with_indexes<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
) -> (Self::Primitive<D>, B::IntTensorPrimitive<D>);
/// Gets the minimum element of a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the minimum element from.
///
/// # Returns
///
/// A single-element tensor containing the minimum element of the input tensor.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the minimum element of a tensor, users should prefer the
/// [Tensor::min](Tensor::min) function, which is more high-level and designed for public use.
fn min<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1>;
/// Gets the minimum elements of a tensor along an axis.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the minimum elements from.
/// * `dim` - The axis along which to get the minimum elements.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor, where each element is the minimum element
/// of the input tensor at the corresponding index along the specified axis.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the minimum elements of a tensor along an axis, users should prefer the
/// [Tensor::min_dim](Tensor::min_dim) function, which is more high-level and designed for public use.
fn min_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;
/// Gets the minimum elements and indices of a tensor along an axis.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the minimum elements from.
/// * `dim` - The axis along which to get the minimum elements.
///
/// # Returns
///
/// A tensor with the same shape as the input tensor and corresponding indices, where
/// each element is the minimum element of the input tensor at the corresponding index
/// along the specified axis.
///
/// # Remarks
///
/// This is a low-level function used internally by the library to call different backend functions
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For getting the minimum elements of a tensor along an axis, users should prefer the
/// [Tensor::min_dim_with_indexes](Tensor::min_dim_with_indexes) function, which is more high-level and designed for public use.
fn min_dim_with_indexes<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,

View File

@ -101,21 +101,75 @@ pub(crate) type ADBackendTensorPrimitive<const D: usize, B> =
/// Trait that allows a backend to support autodiff.
pub trait ADBackend: Backend {
/// The inner backend type.
type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem>;
/// Gradients type.
type Gradients: Send + Sync;
/// Backward pass.
///
/// # Arguments
///
/// * `tensor` - The tensor that is the last node of the computational graph from which the gradients are computed.
///
/// # Returns
///
/// The gradients.
fn backward<const D: usize>(tensor: Self::TensorPrimitive<D>) -> Self::Gradients;
/// Returns the gradients of a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to extract the gradients from.
/// * `grads` - The gradients.
///
/// # Returns
///
/// An optional tensor containing the gradient.
fn grad<const D: usize>(
tensor: &Self::TensorPrimitive<D>,
grads: &Self::Gradients,
) -> Option<ADBackendTensorPrimitive<D, Self>>;
/// Pops the gradients of a tensor and returns them.
///
/// # Arguments
///
/// * `tensor` - The tensor to pop the gradients from.
/// * `grads` - The gradients.
///
/// # Returns
///
/// An optional tensor containing the popped gradient.
fn grad_remove<const D: usize>(
tensor: &Self::TensorPrimitive<D>,
grads: &mut Self::Gradients,
) -> Option<ADBackendTensorPrimitive<D, Self>>;
/// Returns the tensor with inner backend type.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the inner backend tensor for.
///
/// # Returns
///
/// The inner backend tensor.
fn inner<const D: usize>(
tensor: Self::TensorPrimitive<D>,
) -> <Self::InnerBackend as Backend>::TensorPrimitive<D>;
/// Converts the inner backend tensor to the autodiff backend tensor.
///
/// # Arguments
///
/// * `tensor` - The inner backend tensor to convert.
///
/// # Returns
///
/// The autodiff backend tensor.
fn from_inner<const D: usize>(
tensor: <Self::InnerBackend as Backend>::TensorPrimitive<D>,
) -> Self::TensorPrimitive<D>;

View File

@ -6,26 +6,42 @@ use crate::{tensor::Shape, Element, ElementConversion};
use rand::{distributions::Standard, Rng, RngCore};
/// Data structure for serializing and deserializing tensor data.
#[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq, Clone, new)]
pub struct DataSerialize<E> {
/// The values of the tensor.
pub value: Vec<E>,
/// The shape of the tensor.
pub shape: Vec<usize>,
}
/// Data structure for tensors.
#[derive(new, Debug, Clone, PartialEq, Eq)]
pub struct Data<E, const D: usize> {
/// The values of the tensor.
pub value: Vec<E>,
/// The shape of the tensor.
pub shape: Shape<D>,
}
/// Distribution for random value of a tensor.
#[derive(Clone, Copy)]
pub enum Distribution<E> {
/// Standard distribution.
Standard,
/// Bernoulli distribution with the given probability.
Bernoulli(f64),
/// Uniform distribution. The range is inclusive.
Uniform(E, E),
/// Normal distribution with the given mean and standard deviation.
Normal(f64, f64),
}
/// Distribution sampler for random value of a tensor.
#[derive(new)]
pub struct DistributionSampler<'a, E, R>
where
@ -37,14 +53,22 @@ where
rng: &'a mut R,
}
/// Distribution sampler kind for random value of a tensor.
pub enum DistributionSamplerKind<E>
where
Standard: rand::distributions::Distribution<E>,
E: rand::distributions::uniform::SampleUniform,
{
/// Standard distribution.
Standard(rand::distributions::Standard),
/// Uniform distribution.
Uniform(rand::distributions::Uniform<E>),
/// Bernoulli distribution.
Bernoulli(rand::distributions::Bernoulli),
/// Normal distribution.
Normal(rand_distr::Normal<f64>),
}
@ -55,6 +79,7 @@ where
E: Element,
R: RngCore,
{
/// Samples a random value from the distribution.
pub fn sample(&mut self) -> E {
match &self.kind {
DistributionSamplerKind::Standard(distribution) => self.rng.sample(distribution),
@ -76,6 +101,15 @@ where
Standard: rand::distributions::Distribution<E>,
E: rand::distributions::uniform::SampleUniform,
{
/// Creates a new distribution sampler.
///
/// # Arguments
///
/// * `rng` - The random number generator.
///
/// # Returns
///
/// The distribution sampler.
pub fn sampler<R: RngCore>(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R> {
let kind = match self {
Distribution::Standard => {
@ -100,6 +134,11 @@ impl<E> Distribution<E>
where
E: Element,
{
/// Converts the distribution to a different element type.
///
/// # Returns
///
/// The converted distribution.
pub fn convert<EOther: Element>(self) -> Distribution<EOther> {
match self {
Distribution::Standard => Distribution::Standard,
@ -113,6 +152,7 @@ where
}
impl<const D: usize, E: Element> Data<E, D> {
/// Converts the data to a different element type.
pub fn convert<EOther: Element>(self) -> Data<EOther, D> {
let value: Vec<EOther> = self.value.into_iter().map(|a| a.elem()).collect();
@ -138,6 +178,7 @@ impl<const D: usize, E: Element> Data<E, D> {
}
impl<E: Element> DataSerialize<E> {
/// Converts the data to a different element type.
pub fn convert<EOther: Element>(self) -> DataSerialize<EOther> {
let value: Vec<EOther> = self.value.into_iter().map(|a| a.elem()).collect();
@ -149,6 +190,7 @@ impl<E: Element> DataSerialize<E> {
}
impl<const D: usize> Data<bool, D> {
/// Converts the data to a different element type.
pub fn convert<E: Element>(self) -> Data<E, D> {
let value: Vec<E> = self.value.into_iter().map(|a| (a as i64).elem()).collect();
@ -160,6 +202,7 @@ impl<const D: usize> Data<bool, D> {
}
impl<E: Element, const D: usize> Data<E, D> {
/// Populates the data with random values.
pub fn random<R: RngCore>(shape: Shape<D>, distribution: Distribution<E>, rng: &mut R) -> Self {
let num_elements = shape.num_elements();
let mut data = Vec::with_capacity(num_elements);
@ -176,6 +219,7 @@ impl<E: core::fmt::Debug, const D: usize> Data<E, D>
where
E: Element,
{
/// Populates the data with zeros.
pub fn zeros<S: Into<Shape<D>>>(shape: S) -> Data<E, D> {
let shape = shape.into();
let num_elements = shape.num_elements();
@ -187,15 +231,13 @@ where
Data::new(data, shape)
}
pub fn zeros_(shape: Shape<D>, _kind: E) -> Data<E, D> {
Self::zeros(shape)
}
}
impl<E: core::fmt::Debug, const D: usize> Data<E, D>
where
E: Element,
{
/// Populates the data with ones.
pub fn ones(shape: Shape<D>) -> Data<E, D> {
let num_elements = shape.num_elements();
let mut data = Vec::with_capacity(num_elements);
@ -206,12 +248,14 @@ where
Data::new(data, shape)
}
pub fn ones_(shape: Shape<D>, _kind: E) -> Data<E, D> {
Self::ones(shape)
}
}
impl<E: core::fmt::Debug + Copy, const D: usize> Data<E, D> {
/// Serializes the data.
///
/// # Returns
///
/// The serialized data.
pub fn serialize(&self) -> DataSerialize<E> {
DataSerialize {
value: self.value.clone(),
@ -221,6 +265,16 @@ impl<E: core::fmt::Debug + Copy, const D: usize> Data<E, D> {
}
impl<E: Into<f64> + Clone + core::fmt::Debug + PartialEq, const D: usize> Data<E, D> {
/// Asserts the data is approximately equal to another data.
///
/// # Arguments
///
/// * `other` - The other data.
/// * `precision` - The precision of the comparison.
///
/// # Panics
///
/// Panics if the data is not approximately equal.
pub fn assert_approx_eq(&self, other: &Self, precision: usize) {
let mut message = String::new();
if self.shape != other.shape {
@ -270,6 +324,7 @@ impl<E: Into<f64> + Clone + core::fmt::Debug + PartialEq, const D: usize> Data<E
}
impl<const D: usize> Data<usize, D> {
/// Converts the usize data to a different element type.
pub fn from_usize<O: num_traits::FromPrimitive>(self) -> Data<O, D> {
let value: Vec<O> = self
.value

View File

@ -3,6 +3,7 @@ use half::{bf16, f16};
use num_traits::ToPrimitive;
use rand::RngCore;
/// Element trait for tensor.
pub trait Element:
ToPrimitive
+ ElementRandom
@ -18,29 +19,63 @@ pub trait Element:
{
}
/// Element conversion trait for tensor.
pub trait ElementConversion {
/// Converts an element to another element.
///
/// # Arguments
///
/// * `elem` - The element to convert.
///
/// # Returns
///
/// The converted element.
fn from_elem<E: ToPrimitive>(elem: E) -> Self;
/// Converts and returns the converted element.
fn elem<E: Element>(self) -> E;
}
/// Element trait for random value of a tensor.
pub trait ElementRandom {
/// Returns a random value for the given distribution.
///
/// # Arguments
///
/// * `distribution` - The distribution to sample from.
/// * `rng` - The random number generator.
///
/// # Returns
///
/// The random value.
fn random<R: RngCore>(distribution: Distribution<Self>, rng: &mut R) -> Self
where
Self: Sized;
}
/// The precision of a tensor element.
#[derive(Clone, PartialEq, Eq, Copy, Debug)]
pub enum Precision {
/// Double precision, e.g. f64.
Double,
/// Full precision, e.g. f32.
Full,
/// Half precision, e.g. f16.
Half,
/// Other precision.
Other,
}
/// Element precision trait for tensor.
pub trait ElementPrecision {
/// Returns the precision of the element.
fn precision() -> Precision;
}
/// Macro to implement the element trait for a type.
#[macro_export]
macro_rules! make_element {
(

View File

@ -1,6 +1,16 @@
use crate::backend::Backend;
use crate::{activation, Tensor};
/// Computes the log softmax cross entropy between logits and target probabilities.
///
/// # Arguments
///
/// * `logits` - The logits.
/// * `target_probs` - The target probabilities.
///
/// # Returns
///
/// The log softmax cross entropy.
pub fn cross_entropy_with_logits<B: Backend, const D: usize>(
logits: Tensor<B, D>,
target_probs: Tensor<B, D>,

View File

@ -10,11 +10,22 @@ pub use data::*;
pub use element::*;
pub use shape::*;
/// The activation module.
pub mod activation;
/// The backend module.
pub mod backend;
/// The container module.
pub mod container;
/// The loss module.
pub mod loss;
/// The burn module.
pub mod module;
/// Operations on tensors module.
pub mod ops;
#[cfg(feature = "experimental-named-tensor")]

View File

@ -4,15 +4,22 @@ use alloc::string::String;
use crate::backend::Backend;
use crate::Tensor;
/// Dimension trait.
pub trait Dim: core::fmt::Debug {
/// Converts the dimension to a string.
fn to_string() -> String;
}
/// Named dimensions trait.
pub trait NamedDims<B: Backend>: core::fmt::Debug {
/// Tensor type.
type Tensor;
/// Converts the named dimensions to a string.
fn to_string() -> String;
}
/// Named dimension macro.
#[macro_export]
macro_rules! NamedDim {
($name:ident) => {

View File

@ -5,11 +5,30 @@ use core::f64::consts::SQRT_2;
///
/// This trait let backend implementations override activation functions for better performance.
pub trait ActivationOps<B: Backend> {
/// Applies the ReLU activation function.
///
/// # Arguments
///
/// * `tensor` - The input tensor.
///
/// # Returns
///
/// The output tensor, built by filling the elements of `tensor` that are lower than or
/// equal to zero with zero.
fn relu<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
    // Flag every element that is <= 0; the flagged positions are handed to `mask_fill` below.
    let non_positive = B::lower_equal_elem(tensor.clone(), 0.elem());

    B::mask_fill(tensor, non_positive, 0.elem())
}
/// Applies the ReLU activation function backward.
///
/// # Arguments
///
/// * `output` - The output tensor.
/// * `grad` - The gradient.
///
/// # Returns
///
/// The gradient.
fn relu_backward<const D: usize>(
output: B::TensorPrimitive<D>,
grad: B::TensorPrimitive<D>,
@ -18,6 +37,16 @@ pub trait ActivationOps<B: Backend> {
B::mask_fill(grad, mask, 0.elem())
}
/// Applies the Gelu activation function.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The output tensor.
fn gelu<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
let x = B::div_scalar(tensor.clone(), SQRT_2.elem());
let x = B::erf(x);
@ -27,6 +56,16 @@ pub trait ActivationOps<B: Backend> {
B::div_scalar(x, 2i32.elem())
}
/// Applies the Gelu activation function backward.
///
/// # Arguments
///
/// * `x` - The tensor.
/// * `grad` - The gradient.
///
/// # Returns
///
/// The output tensor.
fn gelu_backward<const D: usize>(
x: B::TensorPrimitive<D>,
grad: B::TensorPrimitive<D>,

View File

@ -6,37 +6,157 @@ use crate::{backend::Backend, tensor::Shape, Data};
/// Bool Tensor API for basic operations, see [tensor](crate::Tensor)
/// for documentation on each function.
pub trait BoolTensorOps<B: Backend> {
/// Creates a new bool tensor.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The boolean tensor with the given shape.
fn bool_empty<const D: usize>(shape: Shape<D>, device: &B::Device)
-> B::BoolTensorPrimitive<D>;
/// Returns the shape of the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The shape of the tensor.
fn bool_shape<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> Shape<D>;
/// Converts the tensor to a data structure.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The data structure with the tensor's data.
fn bool_into_data<const D: usize>(tensor: B::BoolTensorPrimitive<D>) -> Data<bool, D>;
/// Gets the data from the tensor without consuming it.
///
/// # Arguments
///
/// * `tensor` - The tensor to read the data from.
///
/// # Returns
///
/// The data structure obtained by converting a clone of the tensor.
fn bool_to_data<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> Data<bool, D> {
    // Only a reference is available here, so clone the tensor before handing it to the
    // consuming conversion.
    let owned = tensor.clone();
    Self::bool_into_data(owned)
}
/// Creates a tensor from the data structure.
///
/// # Arguments
///
/// * `data` - The data structure.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The tensor with the data.
fn bool_from_data<const D: usize>(
data: Data<bool, D>,
device: &B::Device,
) -> B::BoolTensorPrimitive<D>;
/// Converts bool tensor to int tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The int tensor with the same data as the bool tensor.
fn bool_into_int<const D: usize>(tensor: B::BoolTensorPrimitive<D>)
-> B::IntTensorPrimitive<D>;
/// Gets the device of the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The device of the tensor.
fn bool_device<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> B::Device;
/// Moves the tensor to the device.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `device` - The device to move the tensor to.
///
/// # Returns
///
/// The tensor on the given device.
fn bool_to_device<const D: usize>(
tensor: B::BoolTensorPrimitive<D>,
device: &B::Device,
) -> B::BoolTensorPrimitive<D>;
/// Reshapes the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `shape` - The new shape.
///
/// # Returns
///
/// The tensor with the new shape.
fn bool_reshape<const D1: usize, const D2: usize>(
tensor: B::BoolTensorPrimitive<D1>,
shape: Shape<D2>,
) -> B::BoolTensorPrimitive<D2>;
/// Gets the values from the tensor for the given indexes.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `indexes` - The indexes to get the values from.
///
/// # Returns
///
/// The tensor with the values for the given indexes.
fn bool_index<const D1: usize, const D2: usize>(
tensor: B::BoolTensorPrimitive<D1>,
indexes: [Range<usize>; D2],
) -> B::BoolTensorPrimitive<D1>;
/// Sets the values in the tensor for the given indexes.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `indexes` - The indexes to set the values for.
/// * `value` - The values to set.
///
/// # Returns
///
/// The tensor with the values set for the given indexes.
fn bool_index_assign<const D1: usize, const D2: usize>(
tensor: B::BoolTensorPrimitive<D1>,
indexes: [Range<usize>; D2],
value: B::BoolTensorPrimitive<D1>,
) -> B::BoolTensorPrimitive<D1>;
/// Repeats one dimension of the tensor a given number of times along that dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `dim` - The dimension to repeat.
/// * `times` - The number of times to repeat the dimension.
///
/// # Returns
///
/// The tensor with the dimension repeated.
fn bool_repeat<const D: usize>(
tensor: B::BoolTensorPrimitive<D>,
dim: usize,
@ -65,14 +185,47 @@ pub trait BoolTensorOps<B: Backend> {
tensor_output
}
/// Concatenates the tensors along the given dimension.
///
/// # Arguments
///
/// * `tensors` - The tensors to concatenate.
/// * `dim` - The dimension to concatenate along.
///
/// # Returns
///
/// The tensor with the tensors concatenated along the given dimension.
fn bool_cat<const D: usize>(
tensors: Vec<B::BoolTensorPrimitive<D>>,
dim: usize,
) -> B::BoolTensorPrimitive<D>;
/// Equality comparison between two bool tensors.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The bool tensor with the result of the comparison.
fn bool_equal<const D: usize>(
lhs: B::BoolTensorPrimitive<D>,
rhs: B::BoolTensorPrimitive<D>,
) -> B::BoolTensorPrimitive<D>;
/// Equates the tensor with the element.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side element.
///
/// # Returns
///
/// The tensor with the result of the equate.
fn bool_equal_elem<const D: usize>(
lhs: B::BoolTensorPrimitive<D>,
rhs: bool,

View File

@ -6,66 +6,246 @@ use crate::{backend::Backend, tensor::Shape, Data, ElementConversion};
/// Int Tensor API for basic and numeric operations, see [tensor](crate::Tensor)
/// for documentation on each function.
pub trait IntTensorOps<B: Backend> {
/// Creates a new int tensor.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The integer tensor with the given shape.
fn int_empty<const D: usize>(shape: Shape<D>, device: &B::Device) -> B::IntTensorPrimitive<D>;
/// Returns the shape of the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The shape of the tensor.
fn int_shape<const D: usize>(tensor: &B::IntTensorPrimitive<D>) -> Shape<D>;
/// Converts the tensor to a data structure.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The data structure with the tensor's data.
fn int_into_data<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> Data<B::IntElem, D>;
/// Reads the data out of the tensor without consuming it.
///
/// The default implementation clones the tensor and hands the clone to
/// `int_into_data`.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The data structure holding the data of the cloned tensor.
fn int_to_data<const D: usize>(tensor: &B::IntTensorPrimitive<D>) -> Data<B::IntElem, D> {
    let owned = tensor.clone();
    Self::int_into_data(owned)
}
/// Creates a tensor from the data structure.
///
/// # Arguments
///
/// * `data` - The data structure.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The tensor with the data.
fn int_from_data<const D: usize>(
data: Data<B::IntElem, D>,
device: &B::Device,
) -> B::IntTensorPrimitive<D>;
/// Gets the device of the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The device of the tensor.
fn int_device<const D: usize>(tensor: &B::IntTensorPrimitive<D>) -> B::Device;
/// Moves the tensor to the given device.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `device` - The device to move the tensor to.
///
/// # Returns
///
/// The tensor on the given device.
fn int_to_device<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
device: &B::Device,
) -> B::IntTensorPrimitive<D>;
/// Reshapes the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `shape` - The new shape.
///
/// # Returns
///
/// The tensor with the new shape.
fn int_reshape<const D1: usize, const D2: usize>(
tensor: B::IntTensorPrimitive<D1>,
shape: Shape<D2>,
) -> B::IntTensorPrimitive<D2>;
/// Gets the values from the tensor for the given index ranges.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `indexes` - The index ranges to get the values from.
///
/// # Returns
///
/// The tensor with the values for the given index ranges.
fn int_index<const D1: usize, const D2: usize>(
tensor: B::IntTensorPrimitive<D1>,
indexes: [Range<usize>; D2],
) -> B::IntTensorPrimitive<D1>;
/// Sets the values in the tensor for the given index ranges.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `indices` - The index ranges to set the values for.
/// * `value` - The values to set.
///
/// # Returns
///
/// The tensor with the values set for the given index ranges.
// NOTE(review): the parameter list below carries both `indexes` and `indices` with the
// same type; this looks like a leftover from a rename - confirm which one is intended.
fn int_index_assign<const D1: usize, const D2: usize>(
tensor: B::IntTensorPrimitive<D1>,
indexes: [Range<usize>; D2],
indices: [Range<usize>; D2],
value: B::IntTensorPrimitive<D1>,
) -> B::IntTensorPrimitive<D1>;
/// Fills the tensor with values from the source tensor if the mask is true at the given
/// indices.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `mask` - The mask.
/// * `source` - The source tensor.
///
/// # Returns
///
/// The tensor with the values filled.
fn int_mask_where<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
mask: B::BoolTensorPrimitive<D>,
source: B::IntTensorPrimitive<D>,
) -> B::IntTensorPrimitive<D>;
/// Fills the tensor with the given value if the mask is true at the given indices.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `mask` - The mask.
/// * `value` - The value.
///
/// # Returns
///
/// The tensor with the values filled.
fn int_mask_fill<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
mask: B::BoolTensorPrimitive<D>,
value: B::IntElem,
) -> B::IntTensorPrimitive<D>;
/// Gather elements from the tensor at the given indices.
///
/// # Arguments
///
/// * `dim` - The dimension to gather from.
/// * `tensor` - The tensor.
/// * `indices` - The indices.
///
/// # Returns
///
/// The tensor with the gathered elements.
// NOTE(review): the parameter list below carries both `indexes` and `indices` with the
// same type; this looks like a leftover from a rename - confirm which one is intended.
fn int_gather<const D: usize>(
dim: usize,
tensor: B::IntTensorPrimitive<D>,
indexes: B::IntTensorPrimitive<D>,
indices: B::IntTensorPrimitive<D>,
) -> B::IntTensorPrimitive<D>;
/// Scatter the given values into the tensor at the given indices.
///
/// # Arguments
///
/// * `dim` - The dimension to scatter to.
/// * `tensor` - The tensor.
/// * `indices` - The indices.
/// * `value` - The tensor holding the values to scatter.
///
/// # Returns
///
/// The tensor with the values scattered.
// NOTE(review): the parameter list below carries both `indexes` and `indices` with the
// same type; this looks like a leftover from a rename - confirm which one is intended.
fn int_scatter<const D: usize>(
dim: usize,
tensor: B::IntTensorPrimitive<D>,
indexes: B::IntTensorPrimitive<D>,
indices: B::IntTensorPrimitive<D>,
value: B::IntTensorPrimitive<D>,
) -> B::IntTensorPrimitive<D>;
/// Select tensor elements along the given dimension corresponding to the given indices.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `dim` - The dimension to select from.
/// * `indexes` - The indexes.
///
/// # Returns
///
/// The tensor with the selected elements.
fn int_index_select_dim<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
dim: usize,
indexes: B::IntTensorPrimitive<1>,
) -> B::IntTensorPrimitive<D>;
/// Assign the selected elements along the given dimension corresponding to the given indices
/// to the given value.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `dim` - The dimension to select from.
/// * `indexes` - The indexes.
/// * `value` - The value.
///
/// # Returns
///
/// The tensor with the selected elements assigned to the given value.
fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
tensor: B::IntTensorPrimitive<D1>,
dim: usize,
indexes: B::IntTensorPrimitive<1>,
value: B::IntTensorPrimitive<D2>,
) -> B::IntTensorPrimitive<D1>;
/// Repeats the tensor along the given dimension the given number of times.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `dim` - The dimension to repeat.
/// * `times` - The number of times to repeat.
///
/// # Returns
///
/// The tensor with the given dimension repeated the given number of times.
fn int_repeat<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
dim: usize,
@ -94,114 +274,438 @@ pub trait IntTensorOps<B: Backend> {
tensor_output
}
/// Concatenates the given tensors along the given dimension.
///
/// # Arguments
///
/// * `tensors` - The tensors.
/// * `dim` - The dimension to concatenate along.
///
/// # Returns
///
/// The concatenated tensor.
fn int_cat<const D: usize>(
tensors: Vec<B::IntTensorPrimitive<D>>,
dim: usize,
) -> B::IntTensorPrimitive<D>;
/// Elementwise equality comparison.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The boolean tensor with the result of the comparison.
fn int_equal<const D: usize>(
lhs: B::IntTensorPrimitive<D>,
rhs: B::IntTensorPrimitive<D>,
) -> B::BoolTensorPrimitive<D>;
/// Elementwise equality comparison with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The boolean tensor with the result of the comparison.
fn int_equal_elem<const D: usize>(
lhs: B::IntTensorPrimitive<D>,
rhs: B::IntElem,
) -> B::BoolTensorPrimitive<D>;
/// Elementwise greater than comparison.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The boolean tensor with the result of the comparison.
fn int_greater<const D: usize>(
lhs: B::IntTensorPrimitive<D>,
rhs: B::IntTensorPrimitive<D>,
) -> B::BoolTensorPrimitive<D>;
/// Elementwise greater than comparison with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The boolean tensor with the result of the comparison.
fn int_greater_elem<const D: usize>(
lhs: B::IntTensorPrimitive<D>,
rhs: B::IntElem,
) -> B::BoolTensorPrimitive<D>;
/// Elementwise greater than or equal comparison.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The boolean tensor with the result of the comparison.
fn int_greater_equal<const D: usize>(
lhs: B::IntTensorPrimitive<D>,
rhs: B::IntTensorPrimitive<D>,
) -> B::BoolTensorPrimitive<D>;
/// Elementwise greater than or equal comparison with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The boolean tensor with the result of the comparison.
fn int_greater_equal_elem<const D: usize>(
lhs: B::IntTensorPrimitive<D>,
rhs: B::IntElem,
) -> B::BoolTensorPrimitive<D>;
/// Elementwise less than comparison.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The boolean tensor with the result of the comparison.
fn int_lower<const D: usize>(
lhs: B::IntTensorPrimitive<D>,
rhs: B::IntTensorPrimitive<D>,
) -> B::BoolTensorPrimitive<D>;
/// Elementwise less than comparison with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The boolean tensor with the result of the comparison.
fn int_lower_elem<const D: usize>(
lhs: B::IntTensorPrimitive<D>,
rhs: B::IntElem,
) -> B::BoolTensorPrimitive<D>;
/// Elementwise less than or equal comparison.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The boolean tensor with the result of the comparison.
fn int_lower_equal<const D: usize>(
lhs: B::IntTensorPrimitive<D>,
rhs: B::IntTensorPrimitive<D>,
) -> B::BoolTensorPrimitive<D>;
/// Elementwise less than or equal comparison with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The boolean tensor with the result of the comparison.
fn int_lower_equal_elem<const D: usize>(
lhs: B::IntTensorPrimitive<D>,
rhs: B::IntElem,
) -> B::BoolTensorPrimitive<D>;
// ==== NUMERIC ==== //
/// Elementwise addition.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The result of the addition.
fn int_add<const D: usize>(
lhs: B::IntTensorPrimitive<D>,
rhs: B::IntTensorPrimitive<D>,
) -> B::IntTensorPrimitive<D>;
/// Elementwise addition with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The result of the addition.
fn int_add_scalar<const D: usize>(
lhs: B::IntTensorPrimitive<D>,
rhs: B::IntElem,
) -> B::IntTensorPrimitive<D>;
/// Elementwise subtraction.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The result of the subtraction.
fn int_sub<const D: usize>(
lhs: B::IntTensorPrimitive<D>,
rhs: B::IntTensorPrimitive<D>,
) -> B::IntTensorPrimitive<D>;
/// Elementwise subtraction with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The result of the subtraction.
fn int_sub_scalar<const D: usize>(
lhs: B::IntTensorPrimitive<D>,
rhs: B::IntElem,
) -> B::IntTensorPrimitive<D>;
/// Elementwise multiplication.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The result of the multiplication.
fn int_mul<const D: usize>(
lhs: B::IntTensorPrimitive<D>,
rhs: B::IntTensorPrimitive<D>,
) -> B::IntTensorPrimitive<D>;
/// Elementwise multiplication with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The result of the multiplication.
fn int_mul_scalar<const D: usize>(
lhs: B::IntTensorPrimitive<D>,
rhs: B::IntElem,
) -> B::IntTensorPrimitive<D>;
/// Elementwise division.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The result of the division.
fn int_div<const D: usize>(
lhs: B::IntTensorPrimitive<D>,
rhs: B::IntTensorPrimitive<D>,
) -> B::IntTensorPrimitive<D>;
/// Elementwise division with a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The result of the division.
fn int_div_scalar<const D: usize>(
lhs: B::IntTensorPrimitive<D>,
rhs: B::IntElem,
) -> B::IntTensorPrimitive<D>;
/// Negates every element of the tensor.
///
/// The default implementation multiplies the tensor by `-1.0` converted to the
/// backend's int element type.
///
/// # Arguments
///
/// * `tensor` - The tensor to negate.
///
/// # Returns
///
/// The negated tensor.
fn int_neg<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<D> {
    let minus_one = (-1.0).elem::<B::IntElem>();
    Self::int_mul_scalar(tensor, minus_one)
}
/// Creates a tensor of zeros.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The tensor of zeros.
fn int_zeros<const D: usize>(shape: Shape<D>, device: &B::Device) -> B::IntTensorPrimitive<D>;
/// Creates a tensor of ones.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The tensor of ones.
fn int_ones<const D: usize>(shape: Shape<D>, device: &B::Device) -> B::IntTensorPrimitive<D>;
/// Sums all elements in the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to sum.
///
/// # Returns
///
/// The sum of all elements in the tensor.
fn int_sum<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<1>;
/// Sums all elements in the tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to sum.
/// * `dim` - The dimension to sum along.
///
/// # Returns
///
/// The sum of all elements in the tensor along the dimension.
fn int_sum_dim<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
dim: usize,
) -> B::IntTensorPrimitive<D>;
/// Computes the mean of all elements in the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the mean of.
///
/// # Returns
///
/// The mean of all elements in the tensor.
fn int_mean<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<1>;
/// Computes the mean of all elements in the tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the mean of.
/// * `dim` - The dimension to compute the mean along.
///
/// # Returns
///
/// The mean of all elements in the tensor along the dimension.
fn int_mean_dim<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
dim: usize,
) -> B::IntTensorPrimitive<D>;
/// Gets the indices of the maximum elements along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum indices of.
/// * `dim` - The dimension to get the maximum indices along.
///
/// # Returns
///
/// The indices of the maximum elements along the dimension.
fn int_argmax<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
dim: usize,
) -> B::IntTensorPrimitive<D>;
/// Gets the indices of the minimum elements along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the minimum indices of.
/// * `dim` - The dimension to get the minimum indices along.
///
/// # Returns
///
/// The indices of the minimum elements along the dimension.
fn int_argmin<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
dim: usize,
) -> B::IntTensorPrimitive<D>;
/// Finds the maximum element of the whole tensor.
///
/// The default implementation reshapes the tensor into a single dimension that
/// holds every element, then takes the maximum along dimension `0`.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum element of.
///
/// # Returns
///
/// The one dimensional tensor produced by the maximum along dimension `0`.
fn int_max<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<1> {
    let num_elements = B::int_shape(&tensor).num_elements();
    let flattened = B::int_reshape(tensor, Shape::new([num_elements]));

    B::int_max_dim(flattened, 0)
}
/// Gets the maximum element in the tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum element of.
/// * `dim` - The dimension to get the maximum element along.
///
/// # Returns
///
/// The maximum element in the tensor along the dimension.
fn int_max_dim<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
dim: usize,
@ -210,6 +714,17 @@ pub trait IntTensorOps<B: Backend> {
B::int_gather(D - 1, tensor, index)
}
/// Gets the maximum elements and corresponding indices along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum elements and indices of.
/// * `dim` - The dimension to get the maximum elements and indices along.
///
/// # Returns
///
/// The maximum elements and corresponding indices along the dimension.
fn int_max_dim_with_indexes<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
dim: usize,
@ -219,12 +734,33 @@ pub trait IntTensorOps<B: Backend> {
(values, index)
}
/// Finds the minimum element of the whole tensor.
///
/// The default implementation reshapes the tensor into a single dimension that
/// holds every element, then takes the minimum along dimension `0`.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the minimum element of.
///
/// # Returns
///
/// The one dimensional tensor produced by the minimum along dimension `0`.
fn int_min<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<1> {
    let num_elements = B::int_shape(&tensor).num_elements();
    let flattened = B::int_reshape(tensor, Shape::new([num_elements]));

    B::int_min_dim(flattened, 0)
}
/// Gets the minimum elements in the tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the minimum element of.
/// * `dim` - The dimension to get the minimum element along.
///
/// # Returns
///
/// The minimum element in the tensor along the dimension.
fn int_min_dim<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
dim: usize,
@ -233,6 +769,17 @@ pub trait IntTensorOps<B: Backend> {
B::int_gather(D - 1, tensor, index)
}
/// Gets the minimum elements and corresponding indices along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the minimum elements and indices of.
/// * `dim` - The dimension to get the minimum elements and indices along.
///
/// # Returns
///
/// The minimum elements and corresponding indices along the dimension.
fn int_min_dim_with_indexes<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
dim: usize,

View File

@ -4,52 +4,93 @@ use crate::{backend::Backend, Shape};
/// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
/// Gradient.
pub x_grad: B::TensorPrimitive<4>,
/// Weights gradient.
pub weights_grad: B::TensorPrimitive<4>,
/// Bias gradient.
pub bias_grad: Option<B::TensorPrimitive<1>>,
}
/// Gradient computed during the backward pass for each tensor used by [max_pool2d](ModuleOps::max_pool2d).
#[derive(new)]
pub struct MaxPool2dBackward<B: Backend> {
/// Gradient.
pub x_grad: B::TensorPrimitive<4>,
}
/// Results from [max_pool2d_with_indexes](ModuleOps::max_pool2d_with_indexes).
#[derive(new)]
pub struct MaxPool2dWithIndexes<B: Backend> {
/// The output tensor.
pub output: B::TensorPrimitive<4>,
/// The indexes tensor.
pub indexes: B::IntTensorPrimitive<4>,
}
/// Gradient computed during the backward pass for each tensor used by [conv1d](ModuleOps::conv1d).
#[derive(new)]
pub struct Conv1dBackward<B: Backend> {
/// Gradient.
pub x_grad: B::TensorPrimitive<3>,
/// Weights gradient.
pub weights_grad: B::TensorPrimitive<3>,
/// Bias gradient.
pub bias_grad: Option<B::TensorPrimitive<1>>,
}
/// Convolution options.
#[derive(new, Debug, Clone)]
pub struct ConvOptions<const N: usize> {
/// Stride.
pub stride: [usize; N],
/// Padding.
pub padding: [usize; N],
/// Dilation.
pub dilation: [usize; N],
/// Groups.
pub groups: usize,
}
/// Transposed convolution options.
#[derive(new, Debug, Clone)]
pub struct ConvTransposeOptions<const N: usize> {
/// Stride.
pub stride: [usize; N],
/// Padding.
pub padding: [usize; N],
/// Padding out.
pub padding_out: [usize; N],
/// Dilation.
pub dilation: [usize; N],
/// Groups.
pub groups: usize,
}
/// Module operations trait.
pub trait ModuleOps<B: Backend> {
/// Embedding operation.
///
/// # Arguments
///
/// * `weights` - The embedding weights.
/// * `indexes` - The indexes tensor.
///
/// # Returns
///
/// The output tensor.
fn embedding(
weights: B::TensorPrimitive<2>,
indexes: B::IntTensorPrimitive<2>,
@ -62,6 +103,18 @@ pub trait ModuleOps<B: Backend> {
B::reshape(output, Shape::new([batch_size, seq_length, d_model]))
}
/// Embedding backward operation.
///
/// # Arguments
///
/// * `weights` - The embedding weights.
/// * `output_grad` - The output gradient.
/// * `indexes` - The indexes tensor.
///
/// # Returns
///
/// The gradient.
fn embedding_backward(
weights: B::TensorPrimitive<2>,
output_grad: B::TensorPrimitive<3>,
@ -77,6 +130,7 @@ pub trait ModuleOps<B: Backend> {
B::index_select_assign(grad, 0, indexes, output_grad)
}
/// Two dimensional convolution.
///
/// # Shapes

View File

@ -1,4 +1,7 @@
/// Module with convolution operations.
pub mod conv;
/// Module with pooling operations.
pub mod pool;
mod base;

View File

@ -5,31 +5,137 @@ use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversi
/// Operations on float tensors.
pub trait TensorOps<B: Backend> {
/// Creates a new tensor from the data structure.
///
/// # Arguments
///
/// * `data` - The data structure.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The tensor with the given data.
fn from_data<const D: usize>(
data: Data<B::FloatElem, D>,
device: &B::Device,
) -> B::TensorPrimitive<D>;
/// Creates a new tensor with random values.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `distribution` - The distribution to sample from.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The tensor with the given shape and random values.
fn random<const D: usize>(
shape: Shape<D>,
distribution: Distribution<B::FloatElem>,
device: &B::Device,
) -> B::TensorPrimitive<D>;
/// Creates a new tensor with zeros.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The tensor with the given shape and zeros.
fn zeros<const D: usize>(shape: Shape<D>, device: &B::Device) -> B::TensorPrimitive<D> {
Self::from_data(Data::zeros(shape), device)
}
/// Creates a new tensor with ones.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The tensor with the given shape and ones.
fn ones<const D: usize>(shape: Shape<D>, device: &B::Device) -> B::TensorPrimitive<D> {
Self::from_data(Data::ones(shape), device)
}
/// Gets the shape of the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The shape of the tensor.
fn shape<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Shape<D>;
/// Converts the tensor to a data structure.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The data structure with the tensor's data.
fn to_data<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Data<B::FloatElem, D>;
/// Consumes the tensor and converts it to a data structure.
///
/// The default implementation borrows the owned tensor and delegates to
/// `to_data`.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The data structure with the tensor's data.
fn into_data<const D: usize>(tensor: B::TensorPrimitive<D>) -> Data<B::FloatElem, D> {
    let borrowed = &tensor;
    Self::to_data(borrowed)
}
/// Gets the device of the tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor.
///
/// # Returns
///
/// The device of the tensor.
fn device<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::Device;
/// Moves the tensor to the given device.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `device` - The device to move the tensor to.
///
/// # Returns
///
/// The tensor on the given device.
fn to_device<const D: usize>(
tensor: B::TensorPrimitive<D>,
device: &B::Device,
) -> B::TensorPrimitive<D>;
/// Creates a new tensor with values from the given range.
///
/// # Arguments
///
/// * `range` - The range of values.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The tensor with the given values.
fn arange(range: Range<usize>, device: &B::Device) -> B::IntTensorPrimitive<1> {
let shape = Shape::new([range.end - range.start]);
let value = range
@ -39,7 +145,30 @@ pub trait TensorOps<B: Backend> {
let data = Data::new(value, shape);
B::int_from_data(data, device)
}
/// Creates an empty tensor with the given shape.
///
/// # Arguments
///
/// * `shape` - The shape of the tensor.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The empty tensor with the given shape.
fn empty<const D: usize>(shape: Shape<D>, device: &B::Device) -> B::TensorPrimitive<D>;
/// Repeat the tensor along the given dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor.
/// * `dim` - The dimension to repeat.
/// * `times` - The number of times to repeat the dimension.
///
/// # Returns
///
/// The tensor with the given dimension repeated.
fn repeat<const D: usize>(
tensor: B::TensorPrimitive<D>,
dim: usize,
@ -68,142 +197,479 @@ pub trait TensorOps<B: Backend> {
tensor_output
}
/// Adds two tensors together.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The result of adding the two tensors together.
fn add<const D: usize>(
lhs: B::TensorPrimitive<D>,
rhs: B::TensorPrimitive<D>,
) -> B::TensorPrimitive<D>;
/// Adds a scalar to a tensor.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The result of adding the scalar to the tensor.
fn add_scalar<const D: usize>(
lhs: B::TensorPrimitive<D>,
rhs: B::FloatElem,
) -> B::TensorPrimitive<D>;
/// Subtracts two tensors.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The result of subtracting the two tensors.
fn sub<const D: usize>(
lhs: B::TensorPrimitive<D>,
rhs: B::TensorPrimitive<D>,
) -> B::TensorPrimitive<D>;
/// Subtracts a scalar from a tensor.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The result of subtracting the scalar from the tensor.
fn sub_scalar<const D: usize>(
lhs: B::TensorPrimitive<D>,
rhs: B::FloatElem,
) -> B::TensorPrimitive<D>;
/// Multiplies two tensors together element-wise.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The result of multiplying the two tensors together element-wise.
fn mul<const D: usize>(
lhs: B::TensorPrimitive<D>,
rhs: B::TensorPrimitive<D>,
) -> B::TensorPrimitive<D>;
/// Multiplies a tensor by a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The result of multiplying the tensor by the scalar.
fn mul_scalar<const D: usize>(
lhs: B::TensorPrimitive<D>,
rhs: B::FloatElem,
) -> B::TensorPrimitive<D>;
/// Divides two tensors element-wise.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The result of dividing the two tensors.
fn div<const D: usize>(
lhs: B::TensorPrimitive<D>,
rhs: B::TensorPrimitive<D>,
) -> B::TensorPrimitive<D>;
/// Divides a tensor by a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// The result of dividing the tensor by the scalar.
fn div_scalar<const D: usize>(
lhs: B::TensorPrimitive<D>,
rhs: B::FloatElem,
) -> B::TensorPrimitive<D>;
/// Multiplies two tensors together using matrix multiplication.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// The result of multiplying the two tensors together using matrix multiplication.
fn matmul<const D: usize>(
lhs: B::TensorPrimitive<D>,
rhs: B::TensorPrimitive<D>,
) -> B::TensorPrimitive<D>;
/// Negates a tensor element-wise.
///
/// The default implementation multiplies the tensor by `-1.0` converted to the
/// backend's float element type.
///
/// # Arguments
///
/// * `tensor` - The tensor to negate.
///
/// # Returns
///
/// The negated tensor.
fn neg<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
    let minus_one = (-1.0_f32).elem::<B::FloatElem>();
    Self::mul_scalar(tensor, minus_one)
}
/// Transposes a tensor by swapping its last two dimensions.
///
/// The swapped dimensions are computed as `D - 2` and `D - 1`, so the default
/// implementation only makes sense for tensors with at least two dimensions.
///
/// # Arguments
///
/// * `tensor` - The tensor to transpose.
///
/// # Returns
///
/// The transposed tensor.
fn transpose<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
    let (second_to_last, last) = (D - 2, D - 1);
    Self::swap_dims(tensor, second_to_last, last)
}
/// Swaps two dimensions of a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to swap the dimensions of.
/// * `dim1` - The first dimension to swap.
/// * `dim2` - The second dimension to swap.
///
/// # Returns
///
/// The tensor with the dimensions swapped.
fn swap_dims<const D: usize>(
tensor: B::TensorPrimitive<D>,
dim1: usize,
dim2: usize,
) -> B::TensorPrimitive<D>;
/// Reshapes a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to reshape.
/// * `shape` - The new shape of the tensor.
///
/// # Returns
///
/// The tensor with the new shape.
fn reshape<const D1: usize, const D2: usize>(
tensor: B::TensorPrimitive<D1>,
shape: Shape<D2>,
) -> B::TensorPrimitive<D2>;
/// Gather elements from a tensor.
///
/// # Arguments
///
/// * `dim` - The dimension to gather from.
/// * `tensor` - The tensor to gather from.
/// * `indexes` - The indexes to gather.
///
/// # Returns
///
/// The gathered elements.
fn gather<const D: usize>(
dim: usize,
tensor: B::TensorPrimitive<D>,
indexes: B::IntTensorPrimitive<D>,
) -> B::TensorPrimitive<D>;
/// Scatter elements into a tensor.
///
/// # Arguments
///
/// * `dim` - The dimension to scatter into.
/// * `tensor` - The tensor to scatter into.
/// * `indexes` - The indexes to scatter into.
/// * `value` - The value to scatter.
///
/// # Returns
///
/// The tensor with the scattered elements.
fn scatter<const D: usize>(
dim: usize,
tensor: B::TensorPrimitive<D>,
indexes: B::IntTensorPrimitive<D>,
value: B::TensorPrimitive<D>,
) -> B::TensorPrimitive<D>;
/// Select tensor elements along the given dimension corresponding to the given indexes.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `dim` - The dimension to select from.
/// * `indexes` - The indexes to select, given as a one-dimensional int tensor.
///
/// # Returns
///
/// The selected elements.
fn index_select<const D: usize>(
tensor: B::TensorPrimitive<D>,
dim: usize,
indexes: B::IntTensorPrimitive<1>,
) -> B::TensorPrimitive<D>;
/// Assign the selected elements along the given dimension corresponding to the given indexes
/// to the given value.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `dim` - The dimension to select from.
/// * `indexes` - The indexes to select, given as a one-dimensional int tensor.
/// * `value` - The value to assign.
///
/// # Returns
///
/// The tensor with the selected elements assigned to the given value.
fn index_select_assign<const D1: usize, const D2: usize>(
tensor: B::TensorPrimitive<D1>,
dim: usize,
indexes: B::IntTensorPrimitive<1>,
value: B::TensorPrimitive<D2>,
) -> B::TensorPrimitive<D1>;
/// Select tensor elements corresponding to the given indexes.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `indexes` - The index ranges to select.
///
/// # Returns
///
/// The selected elements in a new tensor.
fn index<const D1: usize, const D2: usize>(
tensor: B::TensorPrimitive<D1>,
indexes: [Range<usize>; D2],
) -> B::TensorPrimitive<D1>;
/// Assign the selected elements corresponding to the given indexes to the given value.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `indexes` - The index ranges to select.
/// * `value` - The value to assign.
///
/// # Returns
///
/// The tensor with the selected elements assigned to the given value.
fn index_assign<const D1: usize, const D2: usize>(
tensor: B::TensorPrimitive<D1>,
indexes: [Range<usize>; D2],
value: B::TensorPrimitive<D1>,
) -> B::TensorPrimitive<D1>;
/// Update the given tensor with the value tensor where the mask is true.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `mask` - The boolean mask to select with.
/// * `value` - The value to assign to the selected elements from the value tensor.
///
/// # Returns
///
/// The tensor with the selected elements assigned to the given value.
fn mask_where<const D: usize>(
tensor: B::TensorPrimitive<D>,
mask: B::BoolTensorPrimitive<D>,
value: B::TensorPrimitive<D>,
) -> B::TensorPrimitive<D>;
/// Update the given tensor with the value where the mask is true.
///
/// # Arguments
///
/// * `tensor` - The tensor to select from.
/// * `mask` - The boolean mask to select with.
/// * `value` - The value to assign to the selected elements.
///
/// # Returns
///
/// The tensor with the selected elements assigned to the given value.
fn mask_fill<const D: usize>(
tensor: B::TensorPrimitive<D>,
mask: B::BoolTensorPrimitive<D>,
value: B::FloatElem,
) -> B::TensorPrimitive<D>;
/// Equal comparison of two tensors.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn equal<const D: usize>(
lhs: B::TensorPrimitive<D>,
rhs: B::TensorPrimitive<D>,
) -> B::BoolTensorPrimitive<D>;
/// Equal comparison of a tensor and a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn equal_elem<const D: usize>(
lhs: B::TensorPrimitive<D>,
rhs: B::FloatElem,
) -> B::BoolTensorPrimitive<D>;
/// Greater than comparison of two tensors.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn greater<const D: usize>(
lhs: B::TensorPrimitive<D>,
rhs: B::TensorPrimitive<D>,
) -> B::BoolTensorPrimitive<D>;
/// Greater than comparison of a tensor and a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn greater_elem<const D: usize>(
lhs: B::TensorPrimitive<D>,
rhs: B::FloatElem,
) -> B::BoolTensorPrimitive<D>;
/// Greater than or equal comparison of two tensors.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn greater_equal<const D: usize>(
lhs: B::TensorPrimitive<D>,
rhs: B::TensorPrimitive<D>,
) -> B::BoolTensorPrimitive<D>;
/// Greater than or equal comparison of a tensor and a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn greater_equal_elem<const D: usize>(
lhs: B::TensorPrimitive<D>,
rhs: B::FloatElem,
) -> B::BoolTensorPrimitive<D>;
/// Less than comparison of two tensors.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn lower<const D: usize>(
lhs: B::TensorPrimitive<D>,
rhs: B::TensorPrimitive<D>,
) -> B::BoolTensorPrimitive<D>;
/// Less than comparison of a tensor and a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn lower_elem<const D: usize>(
lhs: B::TensorPrimitive<D>,
rhs: B::FloatElem,
) -> B::BoolTensorPrimitive<D>;
/// Less than or equal comparison of two tensors.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side tensor.
///
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn lower_equal<const D: usize>(
lhs: B::TensorPrimitive<D>,
rhs: B::TensorPrimitive<D>,
) -> B::BoolTensorPrimitive<D>;
/// Less than or equal comparison of a tensor and a scalar.
///
/// # Arguments
///
/// * `lhs` - The left hand side tensor.
/// * `rhs` - The right hand side scalar.
///
/// # Returns
///
/// A boolean tensor with the result of the comparison.
fn lower_equal_elem<const D: usize>(
lhs: B::TensorPrimitive<D>,
rhs: B::FloatElem,
) -> B::BoolTensorPrimitive<D>;
/// Detaches a tensor from the computation graph.
///
/// The default implementation returns `tensor` unchanged.
fn detach<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
// Should only be overridden by autodiff backends.
tensor
}
/// Sets the `require_grad` flag of a tensor.
fn set_require_grad<const D: usize>(
tensor: B::TensorPrimitive<D>,
_require_grad: bool,
@ -211,54 +677,273 @@ pub trait TensorOps<B: Backend> {
// Should only be overridden by autodiff backends.
tensor
}
/// Returns the `require_grad` flag of a tensor.
///
/// The default implementation always returns `false`.
fn is_require_grad<const D: usize>(_tensor: &B::TensorPrimitive<D>) -> bool {
// Should only be overridden by autodiff backends.
false
}
/// Sum of all elements in a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to sum.
///
/// # Returns
///
/// A scalar tensor with the sum of all elements in `tensor`.
fn sum<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<1>;
/// Sum of all elements in a tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to sum.
/// * `dim` - The dimension along which to sum.
///
/// # Returns
///
/// A tensor with the sum of all elements in `tensor` along `dim`.
fn sum_dim<const D: usize>(tensor: B::TensorPrimitive<D>, dim: usize) -> B::TensorPrimitive<D>;
/// Mean of all elements in a tensor.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the mean of.
///
/// # Returns
///
/// A scalar tensor with the mean of all elements in `tensor`.
fn mean<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<1>;
/// Mean of all elements in a tensor along a dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to compute the mean of.
/// * `dim` - The dimension along which to compute the mean.
///
/// # Returns
///
/// A tensor with the mean of all elements in `tensor` along `dim`.
fn mean_dim<const D: usize>(tensor: B::TensorPrimitive<D>, dim: usize)
-> B::TensorPrimitive<D>;
/// Converts a tensor to full precision.
///
/// # Arguments
///
/// * `tensor` - The tensor to convert.
///
/// # Returns
///
/// A tensor with the same values as `tensor` but with full precision.
fn to_full_precision<const D: usize>(
tensor: &B::TensorPrimitive<D>,
) -> <B::FullPrecisionBackend as Backend>::TensorPrimitive<D>;
/// Converts a tensor from full precision.
///
/// # Arguments
///
/// * `tensor` - The tensor to convert.
///
/// # Returns
///
/// A tensor with the same values as `tensor` but with the precision of the backend.
fn from_full_precision<const D: usize>(
tensor: <B::FullPrecisionBackend as Backend>::TensorPrimitive<D>,
) -> B::TensorPrimitive<D>;
/// Returns a new tensor with exponential values.
///
/// # Arguments
///
/// * `tensor` - The tensor to exponentiate.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with exponential values.
fn exp<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
/// Returns a new tensor with natural logarithm values.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the logarithm of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with natural logarithm values.
fn log<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
/// Returns a new tensor with logarithm values of (1 + Xi).
///
/// # Arguments
///
/// * `tensor` - The tensor to take the logarithm of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
fn log1p<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
/// Returns a new tensor with values raised to the power of `value`.
///
/// # Arguments
///
/// * `tensor` - The tensor to exponentiate.
/// * `value` - The exponent.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with values raised to the power of `value`.
fn powf<const D: usize>(tensor: B::TensorPrimitive<D>, value: f32) -> B::TensorPrimitive<D>;
/// Returns a new tensor with square root values.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the square root of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with square root values.
fn sqrt<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
/// Returns a new tensor with cosine values.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the cosine of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with cosine values.
fn cos<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
/// Returns a new tensor with sine values.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the sine of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with sine values.
fn sin<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
/// Returns a new tensor with hyperbolic tangent values.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the hyperbolic tangent of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with hyperbolic tangent values.
fn tanh<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
/// Returns a new tensor with the error function values.
///
/// # Arguments
///
/// * `tensor` - The tensor to take the error function of.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with error function values.
fn erf<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
/// Concatenates tensors along a dimension.
///
/// # Arguments
///
/// * `tensors` - The tensors to concatenate.
/// * `dim` - The dimension along which to concatenate.
///
/// # Returns
///
/// A tensor with the concatenated tensors along `dim`.
fn cat<const D: usize>(
tensors: Vec<B::TensorPrimitive<D>>,
dim: usize,
) -> B::TensorPrimitive<D>;
/// Gets the indexes of the maximum elements of a tensor along an axis.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum elements of.
/// * `dim` - The dimension along which to get the maximum elements.
///
/// # Returns
///
/// A tensor with the indexes of the maximum elements of `tensor` along `dim`.
fn argmax<const D: usize>(
tensor: B::TensorPrimitive<D>,
dim: usize,
) -> B::IntTensorPrimitive<D>;
/// Gets the indexes of the minimum elements of a tensor along an axis.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the minimum elements of.
/// * `dim` - The dimension along which to get the minimum elements.
///
/// # Returns
///
/// A tensor with the indexes of the minimum elements of `tensor` along `dim`.
fn argmin<const D: usize>(
tensor: B::TensorPrimitive<D>,
dim: usize,
) -> B::IntTensorPrimitive<D>;
/// Gets the maximum element of a tensor.
///
/// The default implementation flattens the tensor into a single dimension and
/// reduces it with `max_dim` along dimension `0`.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum element of.
///
/// # Returns
///
/// A one-dimensional tensor with the maximum element of `tensor`.
fn max<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<1> {
    let num_elements = B::shape(&tensor).num_elements();
    let flattened = B::reshape(tensor, Shape::new([num_elements]));

    B::max_dim(flattened, 0)
}
/// Gets the maximum elements of a tensor along an axis.
///
/// The default implementation computes the indexes with `argmax` and gathers the
/// corresponding values along the same dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum elements of.
/// * `dim` - The dimension along which to get the maximum elements.
///
/// # Returns
///
/// A tensor with the maximum elements of `tensor` along `dim`.
fn max_dim<const D: usize>(tensor: B::TensorPrimitive<D>, dim: usize) -> B::TensorPrimitive<D> {
    // `argmax` yields positions along `dim`, so the values have to be gathered along
    // that same dimension; gathering along the last dimension is only correct when
    // `dim == D - 1`.
    let index = B::argmax(tensor.clone(), dim);
    B::gather(dim, tensor, index)
}
/// Gets the maximum elements of a tensor along an axis and their indexes.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the maximum elements of.
/// * `dim` - The dimension along which to get the maximum elements.
///
/// # Returns
///
/// A tuple with the maximum elements of `tensor` along `dim` and their indexes.
fn max_dim_with_indexes<const D: usize>(
tensor: B::TensorPrimitive<D>,
dim: usize,
@ -268,17 +953,49 @@ pub trait TensorOps<B: Backend> {
(values, index)
}
/// Gets the minimum element of a tensor.
///
/// The default implementation flattens the tensor into a single dimension and
/// reduces it with `min_dim` along dimension `0`.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the minimum element of.
///
/// # Returns
///
/// A one-dimensional tensor with the minimum element of `tensor`.
fn min<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<1> {
    let num_elements = B::shape(&tensor).num_elements();
    let flattened = B::reshape(tensor, Shape::new([num_elements]));

    B::min_dim(flattened, 0)
}
/// Gets the minimum elements of a tensor along an axis.
///
/// The default implementation computes the indexes with `argmin` and gathers the
/// corresponding values along the same dimension.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the minimum elements of.
/// * `dim` - The dimension along which to get the minimum elements.
///
/// # Returns
///
/// A tensor with the minimum elements of `tensor` along `dim`.
fn min_dim<const D: usize>(tensor: B::TensorPrimitive<D>, dim: usize) -> B::TensorPrimitive<D> {
    // `argmin` yields positions along `dim`, so the values have to be gathered along
    // that same dimension; gathering along the last dimension is only correct when
    // `dim == D - 1`.
    let index = B::argmin(tensor.clone(), dim);
    B::gather(dim, tensor, index)
}
/// Gets the minimum elements of a tensor along an axis and their indexes.
///
/// # Arguments
///
/// * `tensor` - The tensor to get the minimum elements of.
/// * `dim` - The dimension along which to get the minimum elements.
///
/// # Returns
///
/// A tuple with the minimum elements of `tensor` along `dim` and their indexes.
fn min_dim_with_indexes<const D: usize>(
tensor: B::TensorPrimitive<D>,
dim: usize,

View File

@ -1,7 +1,9 @@
use alloc::vec::Vec;
/// Shape of a tensor.
#[derive(new, Debug, Clone, PartialEq, Eq)]
pub struct Shape<const D: usize> {
/// The dimensions of the tensor.
pub dims: [usize; D],
}

View File

@ -3,6 +3,7 @@ mod module;
mod ops;
mod stats;
#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_all {
() => {

View File

@ -12,6 +12,7 @@ enum Message<T, V> {
End,
}
/// Async trainer callback tracker.
pub struct AsyncTrainerCallback<T, V> {
sender: mpsc::Sender<Message<T, V>>,
handler: Option<JoinHandle<()>>,
@ -52,6 +53,7 @@ impl<T, V> CallbackThread<T, V> {
}
impl<T: Send + Sync + 'static, V: Send + Sync + 'static> AsyncTrainerCallback<T, V> {
/// Create a new async trainer callback.
pub fn new(callback: Box<dyn LearnerCallback<T, V>>) -> Self {
let (sender, receiver) = mpsc::channel();
let thread = CallbackThread::new(Mutex::new(callback), receiver);

View File

@ -1,18 +1,38 @@
use burn_core::{data::dataloader::Progress, LearningRate};
/// The base trait for trainer callbacks.
pub trait LearnerCallback<T, V>: Send {
/// Called when a training item is logged.
fn on_train_item(&mut self, _item: LearnerItem<T>) {}
/// Called when a validation item is logged.
fn on_valid_item(&mut self, _item: LearnerItem<V>) {}
/// Called when a training epoch is finished.
fn on_train_end_epoch(&mut self, _epoch: usize) {}
/// Called when a validation epoch is finished.
fn on_valid_end_epoch(&mut self, _epoch: usize) {}
}
/// A learner item.
#[derive(new)]
pub struct LearnerItem<T> {
/// The item.
pub item: T,
/// The progress.
pub progress: Progress,
/// The epoch.
pub epoch: usize,
/// The total number of epochs.
pub epoch_total: usize,
/// The iteration.
pub iteration: usize,
/// The learning rate.
pub lr: Option<LearningRate>,
}

View File

@ -26,6 +26,7 @@ impl<R: Record> CheckpointerThread<R> {
}
}
/// Async checkpointer.
pub struct AsyncCheckpointer<E> {
checkpointer: Arc<dyn Checkpointer<E> + Send + Sync>,
sender: mpsc::SyncSender<Message<E>>,
@ -33,6 +34,15 @@ pub struct AsyncCheckpointer<E> {
}
impl<R: Record + 'static> AsyncCheckpointer<R> {
/// Create a new async checkpointer.
///
/// # Arguments
///
/// * `checkpointer` - The checkpointer.
///
/// # Returns
///
/// The async checkpointer.
pub fn new(checkpointer: Arc<dyn Checkpointer<R> + Send + Sync>) -> Self {
// Only one checkpoint can be done in advance.
let (sender, receiver) = mpsc::sync_channel(0);

View File

@ -1,13 +1,36 @@
use burn_core::record::{Record, RecorderError};
/// The error type for checkpointer.
#[derive(Debug)]
pub enum CheckpointerError {
/// IO error.
IOError(std::io::Error),
/// Recorder error.
RecorderError(RecorderError),
/// Other errors.
Unknown(String),
}
/// The trait for checkpointer.
pub trait Checkpointer<R: Record> {
/// Save the record.
///
/// # Arguments
///
/// * `epoch` - The epoch.
/// * `record` - The record.
fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError>;
/// Restore the record.
///
/// # Arguments
///
/// * `epoch` - The epoch.
///
/// # Returns
///
/// The record.
fn restore(&self, epoch: usize) -> Result<R, CheckpointerError>;
}

View File

@ -1,6 +1,7 @@
use super::{Checkpointer, CheckpointerError};
use burn_core::record::{FileRecorder, Record};
/// The file checkpointer.
pub struct FileCheckpointer<FR> {
directory: String,
name: String,
@ -9,6 +10,14 @@ pub struct FileCheckpointer<FR> {
}
impl<FR> FileCheckpointer<FR> {
/// Creates a new file checkpointer.
///
/// # Arguments
///
/// * `recorder` - The file recorder.
/// * `directory` - The directory to save the checkpoints.
/// * `name` - The name of the checkpoint.
/// * `num_keep` - The number of checkpoints to keep.
pub fn new(recorder: FR, directory: &str, name: &str, num_keep: usize) -> Self {
std::fs::create_dir_all(directory).ok();

View File

@ -44,6 +44,11 @@ where
Optim: Optimizer<Model, B>,
LR: LRScheduler,
{
/// Creates a new learner builder.
///
/// # Arguments
///
/// * `directory` - The directory to save the checkpoints.
pub fn new(directory: &str) -> Self {
let renderer = Box::new(CLIDashboardRenderer::new());
let logger_train = Box::new(FileMetricLogger::new(format!("{directory}/train").as_str()));

View File

@ -5,8 +5,13 @@ use burn_core::tensor::{Int, Tensor};
/// Simple classification output adapted for multiple metrics.
#[derive(new)]
pub struct ClassificationOutput<B: Backend> {
/// The loss.
pub loss: Tensor<B, 1>,
/// The output.
pub output: Tensor<B, 2>,
/// The targets.
pub targets: Tensor<B, 1, Int>,
}

View File

@ -9,6 +9,7 @@ use std::sync::Arc;
use crate::{LearnerCallback, LearnerItem, MultiDevicesTrainStep, TrainStep, ValidStep};
/// A validation epoch.
#[derive(new)]
pub struct ValidEpoch<VI> {
dataloader: Arc<dyn DataLoader<VI>>,
@ -16,6 +17,7 @@ pub struct ValidEpoch<VI> {
epoch_total: usize,
}
/// A training epoch.
#[derive(new)]
pub struct TrainEpoch<TI> {
dataloader: Arc<dyn DataLoader<TI>>,
@ -25,6 +27,12 @@ pub struct TrainEpoch<TI> {
}
impl<I> ValidEpoch<I> {
/// Runs the validation epoch.
///
/// # Arguments
///
/// * `model` - The model to validate.
/// * `callback` - The callback to use.
pub fn run<B, M, TO, VO>(&self, model: &M, callback: &mut Box<dyn LearnerCallback<TO, VO>>)
where
B: ADBackend,
@ -58,6 +66,18 @@ impl<I> ValidEpoch<I> {
}
impl<TI> TrainEpoch<TI> {
/// Runs the training epoch.
///
/// # Arguments
///
/// * `model` - The model to train.
/// * `optim` - The optimizer to use.
/// * `scheduler` - The learning rate scheduler to use.
/// * `callback` - The callback to use.
///
/// # Returns
///
/// The trained model and the optimizer.
pub fn run<B, M, O, LR, TO, VO>(
&self,
mut model: M,
@ -119,6 +139,19 @@ impl<TI> TrainEpoch<TI> {
}
impl<TI> TrainEpoch<TI> {
/// Runs the training epoch on multiple devices.
///
/// # Arguments
///
/// * `model` - The model to train.
/// * `optim` - The optimizer to use.
/// * `lr_scheduler` - The learning rate scheduler to use.
/// * `callback` - The callback to use.
/// * `devices` - The devices to use.
///
/// # Returns
///
/// The trained model and the optimizer.
pub fn run_multi_device<B, M, O, S, TO, VO>(
&self,
mut model: M,

View File

@ -5,8 +5,13 @@ use burn_core::tensor::Tensor;
/// Simple regression output adapted for multiple metrics.
#[derive(new)]
pub struct RegressionOutput<B: Backend> {
/// The loss.
pub loss: Tensor<B, 1>,
/// The output.
pub output: Tensor<B, 2>,
/// The targets.
pub targets: Tensor<B, 2>,
}

View File

@ -1 +1,2 @@
/// The trainer module.
pub mod train;

View File

@ -5,6 +5,7 @@ use burn_core::{
use std::sync::mpsc::{Receiver, Sender};
use std::thread::spawn;
/// Multi devices train step.
pub struct MultiDevicesTrainStep<B: ADBackend, M, TI, TO> {
workers: Vec<Worker<B, M, TI>>,
receiver: Receiver<TrainOutput<TO>>,
@ -68,6 +69,15 @@ where
TI: Send + 'static,
TO: Send + 'static,
{
/// Create a new multi devices train step.
///
/// # Arguments
///
/// * `devices` - Devices.
///
/// # Returns
///
/// MultiDevicesTrainStep instance.
pub fn new(devices: &[B::Device]) -> Self
where
TI: Send + 'static,
@ -93,6 +103,16 @@ where
}
}
/// Collect outputs from workers for one step.
///
/// # Arguments
///
/// * `dataloader` - Dataloader.
/// * `model` - Model.
///
/// # Returns
///
/// Outputs.
pub fn step<'a>(
&self,
dataloader: &mut Box<dyn DataLoaderIterator<TI> + 'a>,

View File

@ -8,23 +8,58 @@ use burn_core::optim::{GradientsParams, Optimizer};
use burn_core::tensor::backend::ADBackend;
use std::sync::Arc;
/// A training output.
pub struct TrainOutput<TO> {
/// The gradients.
pub grads: GradientsParams,
/// The item.
pub item: TO,
}
impl<TO> TrainOutput<TO> {
    /// Builds a training output from raw backend gradients.
    ///
    /// # Arguments
    ///
    /// * `module` - The module used to convert the raw gradients into [`GradientsParams`].
    /// * `grads` - The gradients produced by the backend.
    /// * `item` - The item to carry alongside the gradients.
    ///
    /// # Returns
    ///
    /// A new training output holding the converted gradients and the item.
    pub fn new<B: ADBackend, M: ADModule<B>>(module: &M, grads: B::Gradients, item: TO) -> Self {
        Self {
            grads: GradientsParams::from_grads(grads, module),
            item,
        }
    }
}
/// Trait for a training step.
pub trait TrainStep<TI, TO> {
/// Runs a training step.
///
/// # Arguments
///
/// * `item` - The item to train on.
///
/// # Returns
///
/// The training output.
fn step(&self, item: TI) -> TrainOutput<TO>;
}
/// Trait for a validation step.
pub trait ValidStep<VI, VO> {
/// Runs a validation step.
///
/// # Arguments
///
/// * `item` - The item to validate on.
///
/// # Returns
///
/// The validation output.
fn step(&self, item: VI) -> VO;
}
@ -37,6 +72,16 @@ where
O: Optimizer<M, B>,
LR: LRScheduler,
{
/// Fits the model.
///
/// # Arguments
///
/// * `dataloader_train` - The training dataloader.
/// * `dataloader_valid` - The validation dataloader.
///
/// # Returns
///
/// The fitted model.
pub fn fit<TI, VI>(
mut self,
dataloader_train: Arc<dyn DataLoader<TI>>,

View File

@ -1,8 +1,17 @@
#![warn(missing_docs)]
//! A library for training neural networks using the burn crate.
#[macro_use]
extern crate derive_new;
/// The checkpoint module.
pub mod checkpoint;
/// The logger module.
pub mod logger;
/// The metric module.
pub mod metric;
mod callback;

View File

@ -5,7 +5,7 @@ enum Message<T> {
Log(T),
End,
}
/// Async logger.
pub struct AsyncLogger<T> {
sender: mpsc::Sender<Message<T>>,
handler: Option<std::thread::JoinHandle<()>>,
@ -34,6 +34,7 @@ impl<T> LoggerThread<T> {
}
impl<T: Send + Sync + 'static> AsyncLogger<T> {
/// Create a new async logger.
pub fn new(logger: Box<dyn Logger<T>>) -> Self {
let (sender, receiver) = mpsc::channel();
let thread = LoggerThread::new(Mutex::new(logger), receiver);

View File

@ -1,9 +1,26 @@
/// The logger trait.
pub trait Logger<T>: Send {
/// Logs an item.
///
/// # Arguments
///
/// * `item` - The item.
fn log(&mut self, item: T);
}
/// The logger backend trait.
pub trait LoggerBackend {
/// The logger type.
type Logger<T>: Logger<T>;
/// Create a new logger.
///
/// # Arguments
///
/// * `epoch` - The epoch.
///
/// # Returns
///
/// The logger.
fn create<T>(&self, epoch: usize) -> Self::Logger<T>;
}

View File

@ -1,11 +1,21 @@
use super::Logger;
use std::{fs::File, io::Write};
/// File logger.
pub struct FileLogger {
file: File,
}
impl FileLogger {
/// Create a new file logger.
///
/// # Arguments
///
/// * `path` - The path.
///
/// # Returns
///
/// The file logger.
pub fn new(path: &str) -> Self {
let mut options = std::fs::File::options();
let file = options

View File

@ -2,11 +2,24 @@ use super::{AsyncLogger, FileLogger, Logger};
use crate::metric::MetricEntry;
use std::collections::HashMap;
/// Metric logger.
pub trait MetricLogger: Send {
/// Logs an item.
///
/// # Arguments
///
/// * `item` - The item.
fn log(&mut self, item: &MetricEntry);
/// Logs an epoch.
///
/// # Arguments
///
/// * `epoch` - The epoch.
fn epoch(&mut self, epoch: usize);
}
/// The file metric logger.
pub struct FileMetricLogger {
loggers: HashMap<String, Box<dyn Logger<String>>>,
directory: String,
@ -14,6 +27,15 @@ pub struct FileMetricLogger {
}
impl FileMetricLogger {
/// Create a new file metric logger.
///
/// # Arguments
///
/// * `directory` - The directory.
///
/// # Returns
///
/// The file metric logger.
pub fn new(directory: &str) -> Self {
Self {
loggers: HashMap::new(),

View File

@ -20,11 +20,12 @@ pub struct AccuracyInput<B: Backend> {
}
impl<B: Backend> AccuracyMetric<B> {
/// Create the metric.
/// Creates the metric.
pub fn new() -> Self {
Self::default()
}
/// Sets the pad token.
pub fn with_pad_token(mut self, index: usize) -> Self {
self.pad_token = Some(index);
self

View File

@ -2,10 +2,19 @@ use burn_core::{data::dataloader::Progress, LearningRate};
/// Metric metadata that can be used when computing metrics.
pub struct MetricMetadata {
/// The current progress.
pub progress: Progress,
/// The current epoch.
pub epoch: usize,
/// The total number of epochs.
pub epoch_total: usize,
/// The current iteration.
pub iteration: usize,
/// The current learning rate.
pub lr: Option<LearningRate>,
}
@ -33,6 +42,7 @@ impl MetricMetadata {
/// This is important since some conflict may happen when the model output is adapted for each
/// metric's input type.
pub trait Metric: Send + Sync {
/// The input type of the metric.
type Input;
/// Update the metric state and returns the current metric entry.
@ -54,6 +64,7 @@ pub trait Adaptor<T> {
///
/// This is useful to plot the values of a metric during training.
pub trait Numeric {
/// Returns the numeric value of the metric.
fn value(&self) -> f64;
}

View File

@ -10,6 +10,7 @@ pub struct CUDAMetric {
}
impl CUDAMetric {
/// Creates a new metric for CUDA.
pub fn new() -> Self {
Self {
nvml: Nvml::init().map(Some).unwrap_or_else(|err| {

View File

@ -5,14 +5,23 @@ use crate::{
};
use burn_core::data::dataloader::Progress;
/// Training progress.
pub struct TrainingProgress {
/// The progress.
pub progress: Progress,
/// The epoch.
pub epoch: usize,
/// The total number of epochs.
pub epoch_total: usize,
/// The iteration.
pub iteration: usize,
}
impl TrainingProgress {
/// Creates a new empty training progress.
pub fn none() -> Self {
Self {
progress: Progress {
@ -26,18 +35,47 @@ impl TrainingProgress {
}
}
/// A dashboard metric.
pub enum DashboardMetricState {
/// A generic metric.
Generic(MetricEntry),
/// A numeric metric.
Numeric(MetricEntry, f64),
}
/// Trait for rendering dashboard metrics.
pub trait DashboardRenderer: Send + Sync {
/// Updates the training metric state.
///
/// # Arguments
///
/// * `state` - The metric state.
fn update_train(&mut self, state: DashboardMetricState);
/// Updates the validation metric state.
///
/// # Arguments
///
/// * `state` - The metric state.
fn update_valid(&mut self, state: DashboardMetricState);
/// Renders the training progress.
///
/// # Arguments
///
/// * `item` - The training progress.
fn render_train(&mut self, item: TrainingProgress);
/// Renders the validation progress.
///
/// # Arguments
///
/// * `item` - The validation progress.
fn render_valid(&mut self, item: TrainingProgress);
}
/// A dashboard container for all metrics.
pub struct Dashboard<T, V>
where
T: Send + Sync + 'static,
@ -57,6 +95,17 @@ where
T: Send + Sync + 'static,
V: Send + Sync + 'static,
{
/// Creates a new dashboard.
///
/// # Arguments
///
/// * `renderer` - The dashboard renderer.
/// * `logger_train` - The training logger.
/// * `logger_valid` - The validation logger.
///
/// # Returns
///
/// A new dashboard.
pub fn new(
renderer: Box<dyn DashboardRenderer>,
logger_train: Box<dyn MetricLogger>,
@ -73,6 +122,11 @@ where
}
}
/// Registers a training metric.
///
/// # Arguments
///
/// * `metric` - The metric.
pub fn register_train<M: Metric + 'static>(&mut self, metric: M)
where
T: Adaptor<M::Input>,
@ -81,6 +135,11 @@ where
.push(Box::new(MetricWrapper::new(metric)));
}
/// Registers a training numeric metric.
///
/// # Arguments
///
/// * `metric` - The metric.
pub fn register_train_plot<M: Numeric + Metric + 'static>(&mut self, metric: M)
where
T: Adaptor<M::Input>,
@ -88,6 +147,12 @@ where
self.metrics_train_numeric
.push(Box::new(MetricWrapper::new(metric)));
}
/// Registers a validation metric.
///
/// # Arguments
///
/// * `metric` - The metric.
pub fn register_valid<M: Metric + 'static>(&mut self, metric: M)
where
V: Adaptor<M::Input>,
@ -96,6 +161,11 @@ where
.push(Box::new(MetricWrapper::new(metric)));
}
/// Registers a validation numeric metric.
///
/// # Arguments
///
/// * `metric` - The metric.
pub fn register_valid_plot<M: Numeric + Metric + 'static>(&mut self, metric: M)
where
V: Adaptor<M::Input>,

View File

@ -4,6 +4,7 @@ use std::{collections::HashMap, fmt::Write};
static MAX_REFRESH_RATE_MILLIS: u128 = 250;
/// The CLI dashboard renderer.
pub struct CLIDashboardRenderer {
pb_epoch: ProgressBar,
pb_iteration: ProgressBar,
@ -94,6 +95,7 @@ impl DashboardRenderer for CLIDashboardRenderer {
}
impl CLIDashboardRenderer {
/// Creates a new CLI dashboard renderer.
pub fn new() -> Self {
let pb = MultiProgress::new();
let pb_epoch = ProgressBar::new(0);
@ -238,6 +240,7 @@ impl CLIDashboardRenderer {
self.last_update = std::time::Instant::now();
}
/// Registers a new metric to be displayed.
pub fn register_key_item(
&self,
key: &'static str,

View File

@ -1,3 +1,4 @@
/// Command line interface module for the dashboard.
pub mod cli;
mod base;

View File

@ -2,6 +2,7 @@ use rgb::RGB8;
use terminal_size::{terminal_size, Height, Width};
use textplots::{Chart, ColorPlot, Shape};
/// A text plot of the training and validation series that can be rendered to a string.
pub struct TextPlot {
train: Vec<(f32, f32)>,
valid: Vec<(f32, f32)>,
@ -16,6 +17,7 @@ impl Default for TextPlot {
}
impl TextPlot {
/// Creates a new text plot.
pub fn new() -> Self {
Self {
train: Vec::new(),
@ -25,6 +27,16 @@ impl TextPlot {
}
}
/// Merges two text plots.
///
/// # Arguments
///
/// * `self` - The first text plot.
/// * `other` - The second text plot.
///
/// # Returns
///
/// The merged text plot.
pub fn merge(self, other: Self) -> Self {
let mut other = other;
let mut train = self.train;
@ -41,6 +53,11 @@ impl TextPlot {
}
}
/// Updates the text plot with a new item for training.
///
/// # Arguments
///
/// * `item` - The new item.
pub fn update_train(&mut self, item: f32) {
self.iteration += 1;
self.train.push((self.iteration as f32, item));
@ -61,6 +78,11 @@ impl TextPlot {
}
}
/// Updates the text plot with a new item for validation.
///
/// # Arguments
///
/// * `item` - The new item.
pub fn update_valid(&mut self, item: f32) {
self.iteration += 1;
self.valid.push((self.iteration as f32, item));
@ -81,6 +103,11 @@ impl TextPlot {
}
}
/// Renders the text plot.
///
/// # Returns
///
/// The rendered text plot.
pub fn render(&self) -> String {
let train_color = RGB8::new(255, 140, 140);
let valid_color = RGB8::new(140, 140, 255);

View File

@ -10,6 +10,7 @@ pub struct LearningRateMetric {
}
impl LearningRateMetric {
/// Creates a new learning rate metric.
pub fn new() -> Self {
Self {
state: NumericMetricState::new(),

View File

@ -1,4 +1,7 @@
/// Dashboard module for training progress.
pub mod dashboard;
/// State module for dashboard metrics.
pub mod state;
mod acc;