mirror of https://github.com/tracel-ai/burn.git
Add missing documents (#424)
This commit is contained in:
parent
eda241f8cf
commit
825aaa9977
|
@ -1,3 +1,4 @@
|
|||
/// The graph module.
|
||||
pub mod graph;
|
||||
|
||||
pub(crate) mod node;
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use proc_macro2::TokenStream;
|
||||
use rust_format::{Config, Edition, Formatter, PostProcess, RustFmt};
|
||||
|
||||
/// Formats a token stream into a string.
|
||||
pub fn format_tokens(tokens: TokenStream) -> String {
|
||||
let fmt = code_formatter();
|
||||
|
||||
|
|
|
@ -1,13 +1,21 @@
|
|||
#![warn(missing_docs)]
|
||||
#![allow(clippy::ptr_arg)]
|
||||
#![allow(clippy::single_match)]
|
||||
#![allow(clippy::upper_case_acronyms)]
|
||||
|
||||
//! `burn-import` is a crate designed to simplify the process of importing models trained in other
|
||||
//! machine learning frameworks into the Burn framework. This tool generates a Rust source file that
|
||||
//! aligns the imported model with Burn's model and converts tensor data into a format compatible with
|
||||
//! Burn.
|
||||
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
/// The onnx module.
|
||||
#[cfg(feature = "onnx")]
|
||||
pub mod onnx;
|
||||
|
||||
/// The module for generating the burn code.
|
||||
pub mod burn;
|
||||
|
||||
mod formater;
|
||||
|
|
|
@ -72,13 +72,25 @@ pub enum TensorData {
|
|||
Bool(Vec<bool>),
|
||||
}
|
||||
|
||||
/// ONNX graph representation
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ONNXGraph {
|
||||
/// The nodes of the graph.
|
||||
pub nodes: Vec<Node>,
|
||||
|
||||
/// The inputs of the graph.
|
||||
pub inputs: Vec<Argument>,
|
||||
|
||||
/// The outputs of the graph.
|
||||
pub outputs: Vec<Argument>,
|
||||
|
||||
/// The states of the graph.
|
||||
pub states: Vec<State>,
|
||||
|
||||
/// The original node names.
|
||||
pub old_node_names: HashMap<String, String>,
|
||||
|
||||
/// The original input names.
|
||||
pub old_input_names: HashMap<String, String>,
|
||||
}
|
||||
|
||||
|
|
|
@ -50,6 +50,7 @@ pub struct ModelGen {
|
|||
}
|
||||
|
||||
impl ModelGen {
|
||||
/// Create a new `ModelGen`.
|
||||
pub fn new() -> Self {
|
||||
init_log().ok(); // Error when init multiple times are ignored.
|
||||
Self::default()
|
||||
|
@ -143,6 +144,7 @@ impl ModelGen {
|
|||
}
|
||||
|
||||
impl ONNXGraph {
|
||||
/// Converts ONNX graph to Burn graph.
|
||||
pub fn into_burn<PS: PrecisionSettings + 'static>(self) -> BurnGraph<PS> {
|
||||
let mut graph = BurnGraph::<PS>::default();
|
||||
|
||||
|
|
|
@ -16,8 +16,10 @@ use burn_common::stub::Mutex;
|
|||
|
||||
pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
|
||||
|
||||
/// The device type for the ndarray backend.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum NdArrayDevice {
|
||||
/// The CPU device.
|
||||
Cpu,
|
||||
}
|
||||
|
||||
|
@ -27,6 +29,7 @@ impl Default for NdArrayDevice {
|
|||
}
|
||||
}
|
||||
|
||||
/// The ndarray backend.
|
||||
#[derive(Clone, Copy, Default, Debug)]
|
||||
pub struct NdArrayBackend<E> {
|
||||
phantom: PhantomData<E>,
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
#![cfg_attr(not(feature = "std"), no_std)]
|
||||
#![warn(missing_docs)]
|
||||
|
||||
//! Burn ndarray backend.
|
||||
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
/// Macro for running a function in parallel.
|
||||
#[macro_export(local_inner_macros)]
|
||||
macro_rules! run_par {
|
||||
(
|
||||
|
@ -16,6 +17,7 @@ macro_rules! run_par {
|
|||
}};
|
||||
}
|
||||
|
||||
/// Macro for iterating over a range in parallel.
|
||||
#[macro_export(local_inner_macros)]
|
||||
macro_rules! iter_par {
|
||||
(
|
||||
|
|
|
@ -32,6 +32,7 @@ mod utils {
|
|||
}
|
||||
}
|
||||
|
||||
/// Converts a slice of usize to a typed dimension.
|
||||
#[macro_export(local_inner_macros)]
|
||||
macro_rules! to_typed_dims {
|
||||
(
|
||||
|
@ -48,6 +49,7 @@ macro_rules! to_typed_dims {
|
|||
}};
|
||||
}
|
||||
|
||||
/// Reshapes an array into a tensor.
|
||||
#[macro_export(local_inner_macros)]
|
||||
macro_rules! reshape {
|
||||
(
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use proc_macro::TokenStream;
|
||||
use quote::{format_ident, quote};
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[proc_macro_attribute]
|
||||
pub fn testgen(attr: TokenStream, item: TokenStream) -> TokenStream {
|
||||
let item: proc_macro2::TokenStream = proc_macro2::TokenStream::from(item);
|
||||
|
|
|
@ -1,4 +1,8 @@
|
|||
#![cfg_attr(not(feature = "std"), no_std)]
|
||||
#![warn(missing_docs)]
|
||||
|
||||
//! This library provides multiple tensor implementations hidden behind an easy to use API
|
||||
//! that supports reverse mode automatic differentiation.
|
||||
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
@ -8,6 +12,7 @@ extern crate alloc;
|
|||
mod tensor;
|
||||
|
||||
#[cfg(feature = "export_tests")]
|
||||
#[allow(missing_docs)]
|
||||
mod tests;
|
||||
|
||||
pub use half::{bf16, f16};
|
||||
|
|
|
@ -8,6 +8,7 @@ use crate::{
|
|||
backend::Backend, check, check::TensorCheck, Bool, Data, Float, Int, Shape, TensorKind,
|
||||
};
|
||||
|
||||
/// A tensor with a given backend, shape and data type.
|
||||
#[derive(new, Clone, Debug)]
|
||||
pub struct Tensor<B, const D: usize, K = Float>
|
||||
where
|
||||
|
@ -401,43 +402,284 @@ where
|
|||
///
|
||||
/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
|
||||
pub trait BasicOps<B: Backend>: TensorKind<B> {
|
||||
/// The type of the tensor elements.
|
||||
type Elem: 'static;
|
||||
|
||||
/// Creates an empty tensor with the given shape.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `shape` - The shape of the tensor.
|
||||
/// * `device` - The device on which the tensor will be allocated.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The empty tensor.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For creating empty tensors, users should prefer the [Tensor::empty](Tensor::empty) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn empty<const D: usize>(shape: Shape<D>, device: &B::Device) -> Self::Primitive<D>;
|
||||
|
||||
/// Returns the shape of the tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The shape of the tensor.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For getting the shape of a tensor, users should prefer the [Tensor::shape](Tensor::shape) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn shape<const D: usize>(tensor: &Self::Primitive<D>) -> Shape<D>;
|
||||
|
||||
/// Reshapes the tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `shape` - The new shape of the tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The reshaped tensor.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For reshaping a tensor, users should prefer the [Tensor::reshape](Tensor::reshape) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn reshape<const D1: usize, const D2: usize>(
|
||||
tensor: Self::Primitive<D1>,
|
||||
shape: Shape<D2>,
|
||||
) -> Self::Primitive<D2>;
|
||||
|
||||
/// Select tensor elements corresponding for the given indexes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `indexes` - The indexes of the elements to select.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The selected elements.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For selecting elements of a tensor, users should prefer the [Tensor::index](Tensor::index) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn index<const D1: usize, const D2: usize>(
|
||||
tensor: Self::Primitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
) -> Self::Primitive<D1>;
|
||||
|
||||
/// Assigns the given value to the tensor elements corresponding for the given indexes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `indexes` - The indexes of the elements to select.
|
||||
/// * `value` - The value to assign.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the assigned values.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For assigning values to elements of a tensor, users should prefer the [Tensor::index_assign](Tensor::index_assign) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn index_assign<const D1: usize, const D2: usize>(
|
||||
tensor: Self::Primitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
value: Self::Primitive<D1>,
|
||||
) -> Self::Primitive<D1>;
|
||||
|
||||
/// Returns the device on which the tensor is allocated.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The device on which the tensor is allocated.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For getting the device of a tensor, users should prefer the [Tensor::device](Tensor::device) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn device<const D: usize>(tensor: &Self::Primitive<D>) -> B::Device;
|
||||
|
||||
/// Moves the tensor to the given device.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `device` - The device on which the tensor will be moved.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor on the given device.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For moving a tensor to a device, users should prefer the [Tensor::to_device](Tensor::to_device) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn to_device<const D: usize>(
|
||||
tensor: Self::Primitive<D>,
|
||||
device: &B::Device,
|
||||
) -> Self::Primitive<D>;
|
||||
|
||||
/// Extracts the data from the tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The data of the tensor.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For extracting the data of a tensor, users should prefer the [Tensor::into_data](Tensor::into_data) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn into_data<const D: usize>(tensor: Self::Primitive<D>) -> Data<Self::Elem, D>;
|
||||
|
||||
/// Creates a tensor from the given data.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `data` - The data of the tensor.
|
||||
/// * `device` - The device on which the tensor will be allocated.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For creating a tensor from data, users should prefer the [Tensor::from_data](Tensor::from_data) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn from_data<const D: usize>(
|
||||
data: Data<Self::Elem, D>,
|
||||
device: &B::Device,
|
||||
) -> Self::Primitive<D>;
|
||||
|
||||
/// Repeat the tensor along the given dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `dim` - The dimension along which the tensor will be repeated.
|
||||
/// * `times` - The number of times the tensor will be repeated.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The repeated tensor.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For repeating a tensor, users should prefer the [Tensor::repeat](Tensor::repeat) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn repeat<const D: usize>(
|
||||
tensor: Self::Primitive<D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
) -> Self::Primitive<D>;
|
||||
|
||||
/// Concatenates the given tensors along the given dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `vectors` - The tensors to concatenate.
|
||||
/// * `dim` - The dimension along which the tensors will be concatenated.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The concatenated tensor.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For concatenating tensors, users should prefer the [Tensor::cat](Tensor::cat) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn cat<const D: usize>(vectors: Vec<Self::Primitive<D>>, dim: usize) -> Self::Primitive<D>;
|
||||
|
||||
/// Equates the given tensors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor of booleans indicating whether the corresponding elements are equal.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For equating tensors, users should prefer the [Tensor::equal](Tensor::equal) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn equal<const D: usize>(
|
||||
lhs: Self::Primitive<D>,
|
||||
rhs: Self::Primitive<D>,
|
||||
) -> Tensor<B, D, Bool>;
|
||||
|
||||
/// Returns the name of the element type.
|
||||
fn elem_type_name() -> &'static str {
|
||||
core::any::type_name::<Self::Elem>()
|
||||
}
|
||||
|
|
|
@ -13,10 +13,12 @@ impl<const D: usize, B> Tensor<B, D>
|
|||
where
|
||||
B: Backend,
|
||||
{
|
||||
/// Converts the tensor into a primitive tensor.
|
||||
pub fn into_primitive(self) -> B::TensorPrimitive<D> {
|
||||
self.primitive
|
||||
}
|
||||
|
||||
/// Converts from a primitive tensor into a tensor.
|
||||
pub fn from_primitive(tensor: B::TensorPrimitive<D>) -> Self {
|
||||
Self::new(tensor)
|
||||
}
|
||||
|
@ -261,6 +263,7 @@ where
|
|||
}
|
||||
|
||||
impl<const D: usize, B: ADBackend> Tensor<B, D> {
|
||||
/// Backward pass of the tensor.
|
||||
pub fn backward(&self) -> B::Gradients {
|
||||
B::backward::<D>(self.primitive.clone())
|
||||
}
|
||||
|
@ -279,10 +282,20 @@ impl<const D: usize, B: ADBackend> Tensor<B, D> {
|
|||
B::grad_remove(&self.primitive, grads).map(Tensor::new)
|
||||
}
|
||||
|
||||
/// Returns the inner tensor without the autodiff information.
|
||||
pub fn inner(self) -> Tensor<B::InnerBackend, D> {
|
||||
Tensor::new(B::inner(self.primitive))
|
||||
}
|
||||
|
||||
/// Convert a tensor to the autodiff backend.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `inner` - The tensor to convert.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor converted to the autodiff backend.
|
||||
pub fn from_inner(inner: Tensor<B::InnerBackend, D>) -> Self {
|
||||
Self::new(B::from_inner(inner.primitive))
|
||||
}
|
||||
|
|
|
@ -1,14 +1,23 @@
|
|||
use crate::backend::Backend;
|
||||
|
||||
/// A type-level representation of the kind of a float tensor
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Float;
|
||||
|
||||
/// A type-level representation of the kind of a int tensor.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Int;
|
||||
|
||||
/// A type-level representation of the kind of a bool tensor.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct Bool;
|
||||
|
||||
/// A type-level representation of the kind of a tensor.
|
||||
pub trait TensorKind<B: Backend>: Clone + core::fmt::Debug {
|
||||
/// The primitive type of the tensor.
|
||||
type Primitive<const D: usize>: Clone + core::fmt::Debug;
|
||||
|
||||
/// The name of the tensor kind.
|
||||
fn name() -> &'static str;
|
||||
}
|
||||
|
||||
|
|
|
@ -397,103 +397,901 @@ pub trait Numeric<B: Backend>: BasicOps<B>
|
|||
where
|
||||
Self::Elem: Element,
|
||||
{
|
||||
/// Adds two tensors together.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The sum of the two tensors.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For adding tensors, users should prefer the [Tensor::add](Tensor::add) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn add<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Primitive<D>) -> Self::Primitive<D>;
|
||||
|
||||
/// Adds a scalar to a tensor element-wise.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The sum of the tensor and the scalar.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For adding a scalar to a tensor, users should prefer the [Tensor::add_scalar](Tensor::add_scalar) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn add_scalar<const D: usize, E: ElementConversion>(
|
||||
lhs: Self::Primitive<D>,
|
||||
rhs: E,
|
||||
) -> Self::Primitive<D>;
|
||||
|
||||
/// Subtracts two tensors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The difference of the two tensors.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For subtracting tensors, users should prefer the [Tensor::sub](Tensor::sub) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn sub<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Primitive<D>) -> Self::Primitive<D>;
|
||||
|
||||
/// Subtracts a scalar from a tensor element-wise.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The difference of the tensor and the scalar.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For subtracting a scalar from a tensor, users should prefer the [Tensor::sub_scalar](Tensor::sub_scalar) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn sub_scalar<const D: usize, E: ElementConversion>(
|
||||
lhs: Self::Primitive<D>,
|
||||
rhs: E,
|
||||
) -> Self::Primitive<D>;
|
||||
|
||||
/// Divides two tensors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The quotient of the two tensors.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For dividing tensors, users should prefer the [Tensor::div](Tensor::div) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn div<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Primitive<D>) -> Self::Primitive<D>;
|
||||
|
||||
/// Divides a tensor by a scalar element-wise.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The quotient of the tensor and the scalar.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For dividing a tensor by a scalar, users should prefer the [Tensor::div_scalar](Tensor::div_scalar) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn div_scalar<const D: usize, E: ElementConversion>(
|
||||
lhs: Self::Primitive<D>,
|
||||
rhs: E,
|
||||
) -> Self::Primitive<D>;
|
||||
|
||||
/// Multiplies two tensors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The product of the two tensors.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For multiplying tensors, users should prefer the [Tensor::mul](Tensor::mul) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn mul<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Primitive<D>) -> Self::Primitive<D>;
|
||||
|
||||
/// Multiplies a tensor by a scalar element-wise.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The product of the tensor and the scalar.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For multiplying a tensor by a scalar, users should prefer the [Tensor::mul_scalar](Tensor::mul_scalar) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn mul_scalar<const D: usize, E: ElementConversion>(
|
||||
lhs: Self::Primitive<D>,
|
||||
rhs: E,
|
||||
) -> Self::Primitive<D>;
|
||||
|
||||
/// Negates a tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to negate.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The negated tensor.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For negating a tensor, users should prefer the [Tensor::neg](Tensor::neg) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn neg<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<D>;
|
||||
|
||||
/// Creates a tensor filled with zeros.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `shape` - The shape of the tensor.
|
||||
/// * `device` - The device on which the tensor will be allocated.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor filled with zeros.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For creating a tensor filled with zeros, users should prefer the [Tensor::zeros](Tensor::zeros) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn zeros<const D: usize>(shape: Shape<D>, device: &B::Device) -> Self::Primitive<D>;
|
||||
|
||||
/// Creates a tensor filled with ones.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `shape` - The shape of the tensor.
|
||||
/// * `device` - The device on which the tensor will be allocated.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor filled with ones.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For creating a tensor filled with ones, users should prefer the [Tensor::ones](Tensor::ones) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn ones<const D: usize>(shape: Shape<D>, device: &B::Device) -> Self::Primitive<D>;
|
||||
|
||||
/// Sums all the elements of the tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to sum.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The sum of all the elements of the tensor.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For summing all the elements of a tensor, users should prefer the [Tensor::sum](Tensor::sum) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn sum<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1>;
|
||||
|
||||
/// Sums all the elements of the tensor along a dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to sum.
|
||||
/// * `dim` - The dimension along which to sum.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The sum of all the elements of the tensor along the specified dimension.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For summing all the elements of a tensor along a dimension, users should prefer the [Tensor::sum_dim](Tensor::sum_dim) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn sum_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;
|
||||
|
||||
/// Computes the mean of all the elements of the tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to compute the mean of.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The mean of all the elements of the tensor.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For computing the mean of all the elements of a tensor, users should prefer the [Tensor::mean](Tensor::mean) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn mean<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1>;
|
||||
|
||||
/// Computes the mean of all the elements of the tensor along a dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to compute the mean of.
|
||||
/// * `dim` - The dimension along which to compute the mean.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The mean of all the elements of the tensor along the specified dimension.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For computing the mean of all the elements of a tensor along a dimension, users should prefer
|
||||
/// the [Tensor::mean_dim](Tensor::mean_dim) function, which is more high-level and designed for public use.
|
||||
fn mean_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;
|
||||
|
||||
/// Element-wise equality between two tensors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor with the same shape as the input tensors, where each element is true if the
|
||||
/// corresponding elements of the input tensors are equal, and false otherwise.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For element-wise equality between two tensors, users should prefer the [Tensor::equal_elem](Tensor::equal_elem) function,
|
||||
fn equal_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool>;
|
||||
|
||||
/// Element-wise greater than comparison between two tensors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor with the same shape as the input tensors, where each element is true if the
|
||||
/// corresponding element of the left hand side tensor is greater than the corresponding element
|
||||
/// of the right hand side tensor, and false otherwise.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For element-wise greater than comparison between two tensors, users should prefer the [Tensor::greater](Tensor::greater) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn greater<const D: usize>(
|
||||
lhs: Self::Primitive<D>,
|
||||
rhs: Self::Primitive<D>,
|
||||
) -> Tensor<B, D, Bool>;
|
||||
|
||||
/// Element-wise greater than comparison between a tensor and a scalar.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor with the same shape as the input tensor, where each element is true if the
|
||||
/// corresponding element of the left hand side tensor is greater than the right hand side
|
||||
/// scalar, and false otherwise.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For element-wise greater than comparison between a tensor and a scalar, users should prefer
|
||||
/// the [Tensor::greater_elem](Tensor::greater_elem) function, which is more high-level and designed for public use.
|
||||
fn greater_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem)
|
||||
-> Tensor<B, D, Bool>;
|
||||
|
||||
/// Element-wise greater than or equal comparison between two tensors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor with the same shape as the input tensors, where each element is true if the
|
||||
/// corresponding element of the left hand side tensor is greater than or equal to the
|
||||
/// corresponding element of the right hand side tensor, and false otherwise.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For element-wise greater than or equal comparison between two tensors, users should prefer
|
||||
/// the [Tensor::greater_equal](Tensor::greater_equal) function, which is more high-level and designed for public use.
|
||||
fn greater_equal<const D: usize>(
|
||||
lhs: Self::Primitive<D>,
|
||||
rhs: Self::Primitive<D>,
|
||||
) -> Tensor<B, D, Bool>;
|
||||
|
||||
/// Element-wise greater than or equal comparison between a tensor and a scalar.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor with the same shape as the input tensor, where each element is true if the
|
||||
/// corresponding element of the left hand side tensor is greater than or equal to the right
|
||||
/// hand side scalar, and false otherwise.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For element-wise greater than or equal comparison between a tensor and a scalar, users should prefer
|
||||
/// the [Tensor::greater_equal_elem](Tensor::greater_equal_elem) function, which is more high-level and designed for public use.
|
||||
fn greater_equal_elem<const D: usize>(
|
||||
lhs: Self::Primitive<D>,
|
||||
rhs: Self::Elem,
|
||||
) -> Tensor<B, D, Bool>;
|
||||
|
||||
/// Element-wise less than comparison between two tensors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor with the same shape as the input tensors, where each element is true if the
|
||||
/// corresponding element of the left hand side tensor is less than the corresponding element of
|
||||
/// the right hand side tensor, and false otherwise.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For element-wise less than comparison between two tensors, users should prefer the [Tensor::lower](Tensor::lower) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn lower<const D: usize>(
|
||||
lhs: Self::Primitive<D>,
|
||||
rhs: Self::Primitive<D>,
|
||||
) -> Tensor<B, D, Bool>;
|
||||
|
||||
/// Element-wise less than comparison between a tensor and a scalar.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor with the same shape as the input tensor, where each element is true if the
|
||||
/// corresponding element of the left hand side tensor is less than the right hand side scalar,
|
||||
/// and false otherwise.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For element-wise less than comparison between a tensor and a scalar, users should prefer
|
||||
/// the [Tensor::lower_elem](Tensor::lower_elem) function, which is more high-level and designed for public use.
|
||||
fn lower_elem<const D: usize>(lhs: Self::Primitive<D>, rhs: Self::Elem) -> Tensor<B, D, Bool>;
|
||||
|
||||
/// Element-wise less than or equal comparison between two tensors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor with the same shape as the input tensors, where each element is true if the
|
||||
/// corresponding element of the left hand side tensor is less than or equal to the corresponding
|
||||
/// element of the right hand side tensor, and false otherwise.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For element-wise less than or equal comparison between two tensors, users should prefer
|
||||
/// the [Tensor::lower_equal](Tensor::lower_equal) function, which is more high-level and designed for public use.
|
||||
fn lower_equal<const D: usize>(
|
||||
lhs: Self::Primitive<D>,
|
||||
rhs: Self::Primitive<D>,
|
||||
) -> Tensor<B, D, Bool>;
|
||||
|
||||
/// Element-wise less than or equal comparison between a tensor and a scalar.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor with the same shape as the input tensor, where each element is true if the
|
||||
/// corresponding element of the left hand side tensor is less than or equal to the right hand
|
||||
/// side scalar, and false otherwise.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For element-wise less than or equal comparison between a tensor and a scalar, users should prefer
|
||||
/// the [Tensor::lower_equal_elem](Tensor::lower_equal_elem) function, which is more high-level and designed for public use.
|
||||
fn lower_equal_elem<const D: usize>(
|
||||
lhs: Self::Primitive<D>,
|
||||
rhs: Self::Elem,
|
||||
) -> Tensor<B, D, Bool>;
|
||||
|
||||
/// Selects elements from a tensor based on a boolean mask.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to select elements from if the corresponding element of the mask is true.
|
||||
/// * `mask` - The boolean mask to use for selecting elements.
|
||||
/// * `source` - The tensor to select elements from when the corresponding element of the mask is false.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape as the input tensors, where each element is taken from the
|
||||
/// corresponding element of the left hand side tensor if the corresponding element of the mask
|
||||
/// is true, and from the corresponding element of the right hand side tensor otherwise.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For selecting elements from a tensor based on a boolean mask, users should prefer the
|
||||
/// [Tensor::mask_where](Tensor::mask_where) function, which is more high-level and designed for public use.
|
||||
fn mask_where<const D: usize>(
|
||||
tensor: Self::Primitive<D>,
|
||||
mask: Tensor<B, D, Bool>,
|
||||
source: Self::Primitive<D>,
|
||||
) -> Self::Primitive<D>;
|
||||
|
||||
/// Fills elements of a tensor based on a boolean mask.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor where will be overwritten with the value
|
||||
/// when the corresponding element of the mask is true.
|
||||
/// * `mask` - The boolean mask to use for filling elements.
|
||||
/// * `value` - The value to fill elements with when the corresponding element of the mask is true.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape as the input tensors, where each element is taken from the
|
||||
/// corresponding element unmodified if the corresponding element of the mask is false, and
|
||||
/// filled with the value otherwise.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For filling elements of a tensor based on a boolean mask, users should prefer the
|
||||
/// [Tensor::mask_fill](Tensor::mask_fill) function, which is more high-level and designed for public use.
|
||||
fn mask_fill<const D: usize>(
|
||||
tensor: Self::Primitive<D>,
|
||||
mask: Tensor<B, D, Bool>,
|
||||
value: Self::Elem,
|
||||
) -> Self::Primitive<D>;
|
||||
|
||||
/// Gathers elements from a tensor along an axis.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dim` - The axis along which to gather elements.
|
||||
/// * `tensor` - The tensor to gather elements from.
|
||||
/// * `indexes` - The indexes of the elements to gather.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape as the input tensor, where each element is taken from the
|
||||
/// corresponding element of the input tensor at the corresponding index along the specified axis.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For gathering elements from a tensor along an axis, users should prefer the
|
||||
/// [Tensor::gather](Tensor::gather) function, which is more high-level and designed for public use.
|
||||
fn gather<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: Self::Primitive<D>,
|
||||
indexes: Tensor<B, D, Int>,
|
||||
) -> Self::Primitive<D>;
|
||||
|
||||
/// Scatters elements into a tensor along an axis.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dim` - The axis along which to scatter elements.
|
||||
/// * `tensor` - The tensor to scatter elements into.
|
||||
/// * `indices` - The indexes of the elements to scatter.
|
||||
/// * `values` - The values to scatter into the tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape as the input tensor, where each element is taken from the
|
||||
/// corresponding element of the input tensor at the corresponding index along the specified axis,
|
||||
/// except for the elements at the specified indexes, which are taken from the corresponding
|
||||
/// element of the values tensor.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For scattering elements into a tensor along an axis, users should prefer the [Tensor::scatter](Tensor::scatter) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn scatter<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: Self::Primitive<D>,
|
||||
indexes: Tensor<B, D, Int>,
|
||||
values: Self::Primitive<D>,
|
||||
) -> Self::Primitive<D>;
|
||||
|
||||
/// Select tensor elements along the given dimension corresponding for the given indexes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to select elements from.
|
||||
/// * `dim` - The axis along which to select elements.
|
||||
/// * `indexes` - The indexes of the elements to select.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape as the input tensor, where each element is taken from the
|
||||
/// corresponding element of the input tensor at the corresponding index along the specified axis.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For selecting elements from a tensor along an axis, users should prefer the
|
||||
/// [Tensor::index_select](Tensor::index_select) function, which is more high-level and designed for public use.
|
||||
fn index_select<const D: usize>(
|
||||
tensor: Self::Primitive<D>,
|
||||
dim: usize,
|
||||
indexes: Tensor<B, 1, Int>,
|
||||
) -> Self::Primitive<D>;
|
||||
|
||||
/// Assign the selected elements along the given dimension corresponding to the given indexes
|
||||
/// from the value tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to assign elements to.
|
||||
/// * `dim` - The axis along which to assign elements.
|
||||
/// * `indexes` - The indexes of the elements to assign.
|
||||
/// * `values` - The values to assign to the tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape as the input tensor, where each element is taken from the
|
||||
/// corresponding element of the input tensor at the corresponding index along the specified axis,
|
||||
/// except for the elements at the specified indexes, which are taken from the corresponding
|
||||
/// element of the values tensor.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For assigning elements to a tensor along an axis, users should prefer the
|
||||
/// [Tensor::index_select_assign](Tensor::index_select_assign) function, which is more high-level and designed for public use.
|
||||
fn index_select_assign<const D1: usize, const D2: usize>(
|
||||
tensor: Self::Primitive<D1>,
|
||||
dim: usize,
|
||||
indexes: Tensor<B, 1, Int>,
|
||||
values: Self::Primitive<D2>,
|
||||
) -> Self::Primitive<D1>;
|
||||
|
||||
/// Gets the indexes of the maximum elements of a tensor along an axis.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dim` - The axis along which to get the indexes of the maximum elements.
|
||||
/// * `tensor` - The tensor to get the indexes of the maximum elements from.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape as the input tensor, where each element is the index of the
|
||||
/// maximum element of the input tensor at the corresponding index along the specified axis.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For getting the indexes of the maximum elements of a tensor along an axis, users should prefer the
|
||||
/// [Tensor::argmax](Tensor::argmax) function, which is more high-level and designed for public use.
|
||||
fn argmax<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Gets the indexes of the minimum elements of a tensor along an axis.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dim` - The axis along which to get the indexes of the minimum elements.
|
||||
/// * `tensor` - The tensor to get the indexes of the minimum elements from.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape as the input tensor, where each element is the index of the
|
||||
/// minimum element of the input tensor at the corresponding index along the specified axis.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For getting the indexes of the minimum elements of a tensor along an axis, users should prefer the
|
||||
/// [Tensor::argmin](Tensor::argmin) function, which is more high-level and designed for public use.
|
||||
fn argmin<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Gets the maximum elements of a tensor along an axis.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dim` - The axis along which to get the maximum elements.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A single-element tensor containing the maximum element of the input tensor.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For getting the maximum elements of a tensor along an axis, users should prefer the
|
||||
/// [Tensor::max](Tensor::max) function, which is more high-level and designed for public use.
|
||||
fn max<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1>;
|
||||
|
||||
/// Gets the maximum elements of a tensor along an axis.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the maximum elements from.
|
||||
/// * `dim` - The axis along which to get the maximum elements.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape as the input tensor, where each element is the maximum element
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For getting the maximum elements of a tensor along an axis, users should prefer the
|
||||
/// [Tensor::max_dim](Tensor::max_dim) function, which is more high-level and designed for public use.
|
||||
fn max_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;
|
||||
|
||||
/// Gets the maximum elements of a tensor along an axis.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the maximum elements from.
|
||||
/// * `dim` - The axis along which to get the maximum elements.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tuple containing the maximum element of the input tensor, and a tensor with the same shape
|
||||
/// as the input tensor, where each element is the index of the maximum element of the input tensor
|
||||
/// at the corresponding index along the specified axis.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For getting the maximum elements of a tensor along an axis, users should prefer the
|
||||
/// [Tensor::max_dim_with_indexes](Tensor::max_dim_with_indexes) function, which is more high-level and designed for public use.
|
||||
fn max_dim_with_indexes<const D: usize>(
|
||||
tensor: Self::Primitive<D>,
|
||||
dim: usize,
|
||||
) -> (Self::Primitive<D>, B::IntTensorPrimitive<D>);
|
||||
|
||||
/// Gets the minimum elements of a tensor along an axis.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the minimum elements from.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A single-element tensor containing the minimum element of the input tensor.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For getting the minimum elements of a tensor along an axis, users should prefer the
|
||||
/// [Tensor::min](Tensor::min) function, which is more high-level and designed for public use.
|
||||
fn min<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<1>;
|
||||
|
||||
/// Gets the minimum elements of a tensor along an axis.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the minimum elements from.
|
||||
/// * `dim` - The axis along which to get the minimum elements.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape as the input tensor, where each element is the minimum element
|
||||
/// of the input tensor at the corresponding index along the specified axis.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For getting the minimum elements of a tensor along an axis, users should prefer the
|
||||
/// [Tensor::min_dim](Tensor::min_dim) function, which is more high-level and designed for public use.
|
||||
fn min_dim<const D: usize>(tensor: Self::Primitive<D>, dim: usize) -> Self::Primitive<D>;
|
||||
|
||||
/// Gets the minimum elements and indices of a tensor along an axis.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the minimum elements from.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape as the input tensor and corresponding indices, where
|
||||
/// each element is the minimum element of the input tensor at the corresponding index
|
||||
/// along the specified axis.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// This is a low-level function used internally by the library to call different backend functions
|
||||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For getting the minimum elements of a tensor along an axis, users should prefer the
|
||||
/// [Tensor::min_dim_with_indexes](Tensor::min_dim_with_indexes) function, which is more high-level and designed for public use.
|
||||
fn min_dim_with_indexes<const D: usize>(
|
||||
tensor: Self::Primitive<D>,
|
||||
dim: usize,
|
||||
|
|
|
@ -101,21 +101,75 @@ pub(crate) type ADBackendTensorPrimitive<const D: usize, B> =
|
|||
|
||||
/// Trait that allows a backend to support autodiff.
|
||||
pub trait ADBackend: Backend {
|
||||
/// The inner backend type.
|
||||
type InnerBackend: Backend<Device = Self::Device, FloatElem = Self::FloatElem>;
|
||||
|
||||
/// Gradients type.
|
||||
type Gradients: Send + Sync;
|
||||
|
||||
/// Backward pass.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor is the last node of computational graph where the gradients are computed.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The gradients.
|
||||
fn backward<const D: usize>(tensor: Self::TensorPrimitive<D>) -> Self::Gradients;
|
||||
|
||||
/// Returns the gradients of a tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to extract the gradients from.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An optional tensor containing the gradient.
|
||||
fn grad<const D: usize>(
|
||||
tensor: &Self::TensorPrimitive<D>,
|
||||
grads: &Self::Gradients,
|
||||
) -> Option<ADBackendTensorPrimitive<D, Self>>;
|
||||
|
||||
/// Pops the gradients of a tensor and returns them.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to pop the gradients from.
|
||||
/// * `grads` - The gradients.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An optional tensor containing the given gradients.
|
||||
fn grad_remove<const D: usize>(
|
||||
tensor: &Self::TensorPrimitive<D>,
|
||||
grads: &mut Self::Gradients,
|
||||
) -> Option<ADBackendTensorPrimitive<D, Self>>;
|
||||
|
||||
/// Returns the tensor with inner backend type.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the inner backend tensor for.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The inner backend tensor.
|
||||
fn inner<const D: usize>(
|
||||
tensor: Self::TensorPrimitive<D>,
|
||||
) -> <Self::InnerBackend as Backend>::TensorPrimitive<D>;
|
||||
|
||||
/// Converts the inner backend tensor to the autodiff backend tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The inner backend tensor to convert.
|
||||
///
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The autodiff backend tensor.
|
||||
fn from_inner<const D: usize>(
|
||||
tensor: <Self::InnerBackend as Backend>::TensorPrimitive<D>,
|
||||
) -> Self::TensorPrimitive<D>;
|
||||
|
|
|
@ -6,26 +6,42 @@ use crate::{tensor::Shape, Element, ElementConversion};
|
|||
|
||||
use rand::{distributions::Standard, Rng, RngCore};
|
||||
|
||||
/// Data structure for serializing and deserializing tensor data.
|
||||
#[derive(serde::Serialize, serde::Deserialize, Debug, PartialEq, Eq, Clone, new)]
|
||||
pub struct DataSerialize<E> {
|
||||
/// The values of the tensor.
|
||||
pub value: Vec<E>,
|
||||
/// The shape of the tensor.
|
||||
pub shape: Vec<usize>,
|
||||
}
|
||||
|
||||
/// Data structure for tensors.
|
||||
#[derive(new, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct Data<E, const D: usize> {
|
||||
/// The values of the tensor.
|
||||
pub value: Vec<E>,
|
||||
|
||||
/// The shape of the tensor.
|
||||
pub shape: Shape<D>,
|
||||
}
|
||||
|
||||
/// Distribution for random value of a tensor.
|
||||
#[derive(Clone, Copy)]
|
||||
pub enum Distribution<E> {
|
||||
/// Standard distribution.
|
||||
Standard,
|
||||
|
||||
/// Bernoulli distribution with the given probability.
|
||||
Bernoulli(f64),
|
||||
|
||||
/// Uniform distribution. The range is inclusive.
|
||||
Uniform(E, E),
|
||||
|
||||
/// Normal distribution with the given mean and standard deviation.
|
||||
Normal(f64, f64),
|
||||
}
|
||||
|
||||
/// Distribution sampler for random value of a tensor.
|
||||
#[derive(new)]
|
||||
pub struct DistributionSampler<'a, E, R>
|
||||
where
|
||||
|
@ -37,14 +53,22 @@ where
|
|||
rng: &'a mut R,
|
||||
}
|
||||
|
||||
/// Distribution sampler kind for random value of a tensor.
|
||||
pub enum DistributionSamplerKind<E>
|
||||
where
|
||||
Standard: rand::distributions::Distribution<E>,
|
||||
E: rand::distributions::uniform::SampleUniform,
|
||||
{
|
||||
/// Standard distribution.
|
||||
Standard(rand::distributions::Standard),
|
||||
|
||||
/// Uniform distribution.
|
||||
Uniform(rand::distributions::Uniform<E>),
|
||||
|
||||
/// Bernoulli distribution.
|
||||
Bernoulli(rand::distributions::Bernoulli),
|
||||
|
||||
/// Normal distribution.
|
||||
Normal(rand_distr::Normal<f64>),
|
||||
}
|
||||
|
||||
|
@ -55,6 +79,7 @@ where
|
|||
E: Element,
|
||||
R: RngCore,
|
||||
{
|
||||
/// Sames a random value from the distribution.
|
||||
pub fn sample(&mut self) -> E {
|
||||
match &self.kind {
|
||||
DistributionSamplerKind::Standard(distribution) => self.rng.sample(distribution),
|
||||
|
@ -76,6 +101,15 @@ where
|
|||
Standard: rand::distributions::Distribution<E>,
|
||||
E: rand::distributions::uniform::SampleUniform,
|
||||
{
|
||||
/// Creates a new distribution sampler.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `rng` - The random number generator.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The distribution sampler.
|
||||
pub fn sampler<R: RngCore>(self, rng: &'_ mut R) -> DistributionSampler<'_, E, R> {
|
||||
let kind = match self {
|
||||
Distribution::Standard => {
|
||||
|
@ -100,6 +134,11 @@ impl<E> Distribution<E>
|
|||
where
|
||||
E: Element,
|
||||
{
|
||||
/// Converts the distribution to a different element type.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The converted distribution.
|
||||
pub fn convert<EOther: Element>(self) -> Distribution<EOther> {
|
||||
match self {
|
||||
Distribution::Standard => Distribution::Standard,
|
||||
|
@ -113,6 +152,7 @@ where
|
|||
}
|
||||
|
||||
impl<const D: usize, E: Element> Data<E, D> {
|
||||
/// Converts the data to a different element type.
|
||||
pub fn convert<EOther: Element>(self) -> Data<EOther, D> {
|
||||
let value: Vec<EOther> = self.value.into_iter().map(|a| a.elem()).collect();
|
||||
|
||||
|
@ -138,6 +178,7 @@ impl<const D: usize, E: Element> Data<E, D> {
|
|||
}
|
||||
|
||||
impl<E: Element> DataSerialize<E> {
|
||||
/// Converts the data to a different element type.
|
||||
pub fn convert<EOther: Element>(self) -> DataSerialize<EOther> {
|
||||
let value: Vec<EOther> = self.value.into_iter().map(|a| a.elem()).collect();
|
||||
|
||||
|
@ -149,6 +190,7 @@ impl<E: Element> DataSerialize<E> {
|
|||
}
|
||||
|
||||
impl<const D: usize> Data<bool, D> {
|
||||
/// Converts the data to a different element type.
|
||||
pub fn convert<E: Element>(self) -> Data<E, D> {
|
||||
let value: Vec<E> = self.value.into_iter().map(|a| (a as i64).elem()).collect();
|
||||
|
||||
|
@ -160,6 +202,7 @@ impl<const D: usize> Data<bool, D> {
|
|||
}
|
||||
|
||||
impl<E: Element, const D: usize> Data<E, D> {
|
||||
/// Populates the data with random values.
|
||||
pub fn random<R: RngCore>(shape: Shape<D>, distribution: Distribution<E>, rng: &mut R) -> Self {
|
||||
let num_elements = shape.num_elements();
|
||||
let mut data = Vec::with_capacity(num_elements);
|
||||
|
@ -176,6 +219,7 @@ impl<E: core::fmt::Debug, const D: usize> Data<E, D>
|
|||
where
|
||||
E: Element,
|
||||
{
|
||||
/// Populates the data with zeros.
|
||||
pub fn zeros<S: Into<Shape<D>>>(shape: S) -> Data<E, D> {
|
||||
let shape = shape.into();
|
||||
let num_elements = shape.num_elements();
|
||||
|
@ -187,15 +231,13 @@ where
|
|||
|
||||
Data::new(data, shape)
|
||||
}
|
||||
pub fn zeros_(shape: Shape<D>, _kind: E) -> Data<E, D> {
|
||||
Self::zeros(shape)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: core::fmt::Debug, const D: usize> Data<E, D>
|
||||
where
|
||||
E: Element,
|
||||
{
|
||||
/// Populates the data with ones.
|
||||
pub fn ones(shape: Shape<D>) -> Data<E, D> {
|
||||
let num_elements = shape.num_elements();
|
||||
let mut data = Vec::with_capacity(num_elements);
|
||||
|
@ -206,12 +248,14 @@ where
|
|||
|
||||
Data::new(data, shape)
|
||||
}
|
||||
pub fn ones_(shape: Shape<D>, _kind: E) -> Data<E, D> {
|
||||
Self::ones(shape)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: core::fmt::Debug + Copy, const D: usize> Data<E, D> {
|
||||
/// Serializes the data.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The serialized data.
|
||||
pub fn serialize(&self) -> DataSerialize<E> {
|
||||
DataSerialize {
|
||||
value: self.value.clone(),
|
||||
|
@ -221,6 +265,16 @@ impl<E: core::fmt::Debug + Copy, const D: usize> Data<E, D> {
|
|||
}
|
||||
|
||||
impl<E: Into<f64> + Clone + core::fmt::Debug + PartialEq, const D: usize> Data<E, D> {
|
||||
/// Asserts the data is approximately equal to another data.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `other` - The other data.
|
||||
/// * `precision` - The precision of the comparison.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if the data is not approximately equal.
|
||||
pub fn assert_approx_eq(&self, other: &Self, precision: usize) {
|
||||
let mut message = String::new();
|
||||
if self.shape != other.shape {
|
||||
|
@ -270,6 +324,7 @@ impl<E: Into<f64> + Clone + core::fmt::Debug + PartialEq, const D: usize> Data<E
|
|||
}
|
||||
|
||||
impl<const D: usize> Data<usize, D> {
|
||||
/// Converts the usize data to a different element type.
|
||||
pub fn from_usize<O: num_traits::FromPrimitive>(self) -> Data<O, D> {
|
||||
let value: Vec<O> = self
|
||||
.value
|
||||
|
|
|
@ -3,6 +3,7 @@ use half::{bf16, f16};
|
|||
use num_traits::ToPrimitive;
|
||||
use rand::RngCore;
|
||||
|
||||
/// Element trait for tensor.
|
||||
pub trait Element:
|
||||
ToPrimitive
|
||||
+ ElementRandom
|
||||
|
@ -18,29 +19,63 @@ pub trait Element:
|
|||
{
|
||||
}
|
||||
|
||||
/// Element conversion trait for tensor.
|
||||
pub trait ElementConversion {
|
||||
/// Converts an element to another element.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `elem` - The element to convert.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The converted element.
|
||||
fn from_elem<E: ToPrimitive>(elem: E) -> Self;
|
||||
|
||||
/// Converts and returns the converted element.
|
||||
fn elem<E: Element>(self) -> E;
|
||||
}
|
||||
|
||||
/// Element trait for random value of a tensor.
|
||||
pub trait ElementRandom {
|
||||
/// Returns a random value for the given distribution.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `distribution` - The distribution to sample from.
|
||||
/// * `rng` - The random number generator.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The random value.
|
||||
fn random<R: RngCore>(distribution: Distribution<Self>, rng: &mut R) -> Self
|
||||
where
|
||||
Self: Sized;
|
||||
}
|
||||
|
||||
/// Element precision trait for tensor.
|
||||
#[derive(Clone, PartialEq, Eq, Copy, Debug)]
|
||||
pub enum Precision {
|
||||
/// Double precision, e.g. f64.
|
||||
Double,
|
||||
|
||||
/// Full precision, e.g. f32.
|
||||
Full,
|
||||
|
||||
/// Half precision, e.g. f16.
|
||||
Half,
|
||||
|
||||
/// Other precision.
|
||||
Other,
|
||||
}
|
||||
|
||||
/// Element precision trait for tensor.
|
||||
pub trait ElementPrecision {
|
||||
/// Returns the precision of the element.
|
||||
fn precision() -> Precision;
|
||||
}
|
||||
|
||||
/// Macro to implement the element trait for a type.
|
||||
#[macro_export]
|
||||
macro_rules! make_element {
|
||||
(
|
||||
|
|
|
@ -1,6 +1,16 @@
|
|||
use crate::backend::Backend;
|
||||
use crate::{activation, Tensor};
|
||||
|
||||
/// Computes the log softmax cross entropy between logits and target probabilities.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `logits` - The logits.
|
||||
/// * `target_probs` - The target probabilities.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The log softmax cross entropy.
|
||||
pub fn cross_entropy_with_logits<B: Backend, const D: usize>(
|
||||
logits: Tensor<B, D>,
|
||||
target_probs: Tensor<B, D>,
|
||||
|
|
|
@ -10,11 +10,22 @@ pub use data::*;
|
|||
pub use element::*;
|
||||
pub use shape::*;
|
||||
|
||||
/// The activation module.
|
||||
pub mod activation;
|
||||
|
||||
/// The backend module.
|
||||
pub mod backend;
|
||||
|
||||
/// The container module.
|
||||
pub mod container;
|
||||
|
||||
/// The loss module.
|
||||
pub mod loss;
|
||||
|
||||
/// The burn module.
|
||||
pub mod module;
|
||||
|
||||
/// Operations on tensors module.
|
||||
pub mod ops;
|
||||
|
||||
#[cfg(feature = "experimental-named-tensor")]
|
||||
|
|
|
@ -4,15 +4,22 @@ use alloc::string::String;
|
|||
use crate::backend::Backend;
|
||||
use crate::Tensor;
|
||||
|
||||
/// Dimension trait.
|
||||
pub trait Dim: core::fmt::Debug {
|
||||
/// Converts the dimension to a string.
|
||||
fn to_string() -> String;
|
||||
}
|
||||
|
||||
/// Named dimensions trait.
|
||||
pub trait NamedDims<B: Backend>: core::fmt::Debug {
|
||||
/// Tensor type.
|
||||
type Tensor;
|
||||
|
||||
/// Converts the named dimensions to a string.
|
||||
fn to_string() -> String;
|
||||
}
|
||||
|
||||
/// Named dimension macro.
|
||||
#[macro_export]
|
||||
macro_rules! NamedDim {
|
||||
($name:ident) => {
|
||||
|
|
|
@ -5,11 +5,30 @@ use core::f64::consts::SQRT_2;
|
|||
///
|
||||
/// This trait let backend implementations override activation functions for better performance.
|
||||
pub trait ActivationOps<B: Backend> {
|
||||
/// Applies the ReLU activation function.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The output tensor.
|
||||
fn relu<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
|
||||
let mask = B::lower_equal_elem(tensor.clone(), 0.elem());
|
||||
|
||||
B::mask_fill(tensor, mask, 0.elem())
|
||||
}
|
||||
|
||||
/// Applies the ReLU activation function backward.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `output` - The output tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The gradient.
|
||||
fn relu_backward<const D: usize>(
|
||||
output: B::TensorPrimitive<D>,
|
||||
grad: B::TensorPrimitive<D>,
|
||||
|
@ -18,6 +37,16 @@ pub trait ActivationOps<B: Backend> {
|
|||
|
||||
B::mask_fill(grad, mask, 0.elem())
|
||||
}
|
||||
|
||||
/// Applies the Gelu activation function.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The output tensor.
|
||||
fn gelu<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
|
||||
let x = B::div_scalar(tensor.clone(), SQRT_2.elem());
|
||||
let x = B::erf(x);
|
||||
|
@ -27,6 +56,16 @@ pub trait ActivationOps<B: Backend> {
|
|||
B::div_scalar(x, 2i32.elem())
|
||||
}
|
||||
|
||||
/// Applies the Gelu activation function backward.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `x` - The tensor.
|
||||
/// * `grad` - The gradient.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The output tensor.
|
||||
fn gelu_backward<const D: usize>(
|
||||
x: B::TensorPrimitive<D>,
|
||||
grad: B::TensorPrimitive<D>,
|
||||
|
|
|
@ -6,37 +6,157 @@ use crate::{backend::Backend, tensor::Shape, Data};
|
|||
/// Bool Tensor API for basic operations, see [tensor](crate::Tensor)
|
||||
/// for documentation on each function.
|
||||
pub trait BoolTensorOps<B: Backend> {
|
||||
/// Creates a new bool tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `shape` - The shape of the tensor.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The boolean tensor with the given shape.
|
||||
fn bool_empty<const D: usize>(shape: Shape<D>, device: &B::Device)
|
||||
-> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Returns the shape of the tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The shape of the tensor.
|
||||
fn bool_shape<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> Shape<D>;
|
||||
|
||||
/// Converts the tensor to a data structure.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The data structure with the tensor's data.
|
||||
fn bool_into_data<const D: usize>(tensor: B::BoolTensorPrimitive<D>) -> Data<bool, D>;
|
||||
|
||||
/// Gets the data from the tensor.
|
||||
///
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `data` - The data structure.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The data cloned from the data structure.
|
||||
fn bool_to_data<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> Data<bool, D> {
|
||||
Self::bool_into_data(tensor.clone())
|
||||
}
|
||||
|
||||
/// Creates a tensor from the data structure.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `data` - The data structure.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the data.
|
||||
fn bool_from_data<const D: usize>(
|
||||
data: Data<bool, D>,
|
||||
device: &B::Device,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Converts bool tensor to int tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The int tensor with the same data as the bool tensor.
|
||||
fn bool_into_int<const D: usize>(tensor: B::BoolTensorPrimitive<D>)
|
||||
-> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Gets the device of the tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The device of the tensor.
|
||||
fn bool_device<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> B::Device;
|
||||
|
||||
/// Moves the tensor to the device.
|
||||
fn bool_to_device<const D: usize>(
|
||||
tensor: B::BoolTensorPrimitive<D>,
|
||||
device: &B::Device,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Reshapes the tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `shape` - The new shape.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the new shape.
|
||||
fn bool_reshape<const D1: usize, const D2: usize>(
|
||||
tensor: B::BoolTensorPrimitive<D1>,
|
||||
shape: Shape<D2>,
|
||||
) -> B::BoolTensorPrimitive<D2>;
|
||||
|
||||
/// Gets the values from the tensor for the given indexes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `indexes` - The indexes to get the values from.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the values for the given indexes.
|
||||
fn bool_index<const D1: usize, const D2: usize>(
|
||||
tensor: B::BoolTensorPrimitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
) -> B::BoolTensorPrimitive<D1>;
|
||||
|
||||
/// Sets the values in the tensor for the given indexes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `indexes` - The indexes to set the values for.
|
||||
/// * `value` - The values to set.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the values set for the given indexes.
|
||||
fn bool_index_assign<const D1: usize, const D2: usize>(
|
||||
tensor: B::BoolTensorPrimitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
value: B::BoolTensorPrimitive<D1>,
|
||||
) -> B::BoolTensorPrimitive<D1>;
|
||||
|
||||
/// Repeats one dimension of the tensor a given number of times along that dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `dim` - The dimension to repeat.
|
||||
/// * `times` - The number of times to repeat the dimension.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the dimension repeated.
|
||||
fn bool_repeat<const D: usize>(
|
||||
tensor: B::BoolTensorPrimitive<D>,
|
||||
dim: usize,
|
||||
|
@ -65,14 +185,47 @@ pub trait BoolTensorOps<B: Backend> {
|
|||
|
||||
tensor_output
|
||||
}
|
||||
|
||||
/// Concatenates the tensors along the given dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensors` - The tensors to concatenate.
|
||||
/// * `dim` - The dimension to concatenate along.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the tensors concatenated along the given dimension.
|
||||
fn bool_cat<const D: usize>(
|
||||
tensors: Vec<B::BoolTensorPrimitive<D>>,
|
||||
dim: usize,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Equates the two tensors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the result of the equate.
|
||||
fn bool_equal<const D: usize>(
|
||||
lhs: B::BoolTensorPrimitive<D>,
|
||||
rhs: B::BoolTensorPrimitive<D>,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Equates the tensor with the element.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side element.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the result of the equate.
|
||||
fn bool_equal_elem<const D: usize>(
|
||||
lhs: B::BoolTensorPrimitive<D>,
|
||||
rhs: bool,
|
||||
|
|
|
@ -6,66 +6,246 @@ use crate::{backend::Backend, tensor::Shape, Data, ElementConversion};
|
|||
/// Int Tensor API for basic and numeric operations, see [tensor](crate::Tensor)
|
||||
/// for documentation on each function.
|
||||
pub trait IntTensorOps<B: Backend> {
|
||||
/// Creates a new int tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `shape` - The shape of the tensor.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The integer tensor with the given shape.
|
||||
fn int_empty<const D: usize>(shape: Shape<D>, device: &B::Device) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Returns the shape of the tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The shape of the tensor.
|
||||
fn int_shape<const D: usize>(tensor: &B::IntTensorPrimitive<D>) -> Shape<D>;
|
||||
|
||||
/// Converts the tensor to a data structure.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The data structure with the tensor's data.
|
||||
fn int_into_data<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> Data<B::IntElem, D>;
|
||||
|
||||
/// Gets the data from the tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The data cloned from the data structure.
|
||||
fn int_to_data<const D: usize>(tensor: &B::IntTensorPrimitive<D>) -> Data<B::IntElem, D> {
|
||||
Self::int_into_data(tensor.clone())
|
||||
}
|
||||
|
||||
/// Creates a tensor from the data structure.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `data` - The data structure.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the data.
|
||||
fn int_from_data<const D: usize>(
|
||||
data: Data<B::IntElem, D>,
|
||||
device: &B::Device,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Gets the device of the tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The device of the tensor.
|
||||
fn int_device<const D: usize>(tensor: &B::IntTensorPrimitive<D>) -> B::Device;
|
||||
|
||||
/// Moves the tensor to the given device.
|
||||
fn int_to_device<const D: usize>(
|
||||
tensor: B::IntTensorPrimitive<D>,
|
||||
device: &B::Device,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Reshapes the tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `shape` - The new shape.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the new shape.
|
||||
fn int_reshape<const D1: usize, const D2: usize>(
|
||||
tensor: B::IntTensorPrimitive<D1>,
|
||||
shape: Shape<D2>,
|
||||
) -> B::IntTensorPrimitive<D2>;
|
||||
|
||||
/// Gets the element at the given indices.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `indices` - The indices.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The elements at the given indices.
|
||||
fn int_index<const D1: usize, const D2: usize>(
|
||||
tensor: B::IntTensorPrimitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
) -> B::IntTensorPrimitive<D1>;
|
||||
|
||||
/// Sets the element at the given indices.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `indices` - The indices.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the element at the given indices set.
|
||||
fn int_index_assign<const D1: usize, const D2: usize>(
|
||||
tensor: B::IntTensorPrimitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
indices: [Range<usize>; D2],
|
||||
value: B::IntTensorPrimitive<D1>,
|
||||
) -> B::IntTensorPrimitive<D1>;
|
||||
|
||||
/// Fills the tensor with values from the source tensor if the mask is true at the given
|
||||
/// indices.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `mask` - The mask.
|
||||
/// * `source` - The source tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the values filled.
|
||||
fn int_mask_where<const D: usize>(
|
||||
tensor: B::IntTensorPrimitive<D>,
|
||||
mask: B::BoolTensorPrimitive<D>,
|
||||
source: B::IntTensorPrimitive<D>,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Fills the tensor with the given value if the mask is true at the given indices.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `mask` - The mask.
|
||||
/// * `value` - The value.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the values filled.
|
||||
fn int_mask_fill<const D: usize>(
|
||||
tensor: B::IntTensorPrimitive<D>,
|
||||
mask: B::BoolTensorPrimitive<D>,
|
||||
value: B::IntElem,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Gather elements from the tensor at the given indices.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dim` - The dimension to gather from.
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `indices` - The indices.
|
||||
fn int_gather<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: B::IntTensorPrimitive<D>,
|
||||
indexes: B::IntTensorPrimitive<D>,
|
||||
indices: B::IntTensorPrimitive<D>,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Scatter a given value to the tensor at the given indices.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dim` - The dimension to scatter to.
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `indices` - The indices.
|
||||
/// * `value` - The value.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the values scattered.
|
||||
fn int_scatter<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: B::IntTensorPrimitive<D>,
|
||||
indexes: B::IntTensorPrimitive<D>,
|
||||
indices: B::IntTensorPrimitive<D>,
|
||||
value: B::IntTensorPrimitive<D>,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Select tensor elements along the given dimension corresponding to the given indices.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `dim` - The dimension to select from.
|
||||
/// * `indexes` - The indexes.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the selected elements.
|
||||
fn int_index_select_dim<const D: usize>(
|
||||
tensor: B::IntTensorPrimitive<D>,
|
||||
dim: usize,
|
||||
indexes: B::IntTensorPrimitive<1>,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Assign the selected elements along the given dimension corresponding to the given indices
|
||||
/// to the given value.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `dim` - The dimension to select from.
|
||||
/// * `indexes` - The indexes.
|
||||
/// * `value` - The value.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the selected elements assigned to the given value.
|
||||
fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
|
||||
tensor: B::IntTensorPrimitive<D1>,
|
||||
dim: usize,
|
||||
indexes: B::IntTensorPrimitive<1>,
|
||||
value: B::IntTensorPrimitive<D2>,
|
||||
) -> B::IntTensorPrimitive<D1>;
|
||||
|
||||
/// Repeats the tensor along the given dimension the given number of times.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `dim` - The dimension to repeat.
|
||||
/// * `times` - The number of times to repeat.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the given dimension repeated the given number of times.
|
||||
fn int_repeat<const D: usize>(
|
||||
tensor: B::IntTensorPrimitive<D>,
|
||||
dim: usize,
|
||||
|
@ -94,114 +274,438 @@ pub trait IntTensorOps<B: Backend> {
|
|||
|
||||
tensor_output
|
||||
}
|
||||
|
||||
/// Concatenates the given tensors along the given dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensors` - The tensors.
|
||||
/// * `dim` - The dimension to concatenate along.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The concatenated tensor.
|
||||
fn int_cat<const D: usize>(
|
||||
tensors: Vec<B::IntTensorPrimitive<D>>,
|
||||
dim: usize,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Elementwise equality comparison.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The boolean tensor with the result of the comparison.
|
||||
fn int_equal<const D: usize>(
|
||||
lhs: B::IntTensorPrimitive<D>,
|
||||
rhs: B::IntTensorPrimitive<D>,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Elementwise equality comparison with a scalar.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The boolean tensor with the result of the comparison.
|
||||
fn int_equal_elem<const D: usize>(
|
||||
lhs: B::IntTensorPrimitive<D>,
|
||||
rhs: B::IntElem,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Elementwise greater than comparison.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The boolean tensor with the result of the comparison.
|
||||
fn int_greater<const D: usize>(
|
||||
lhs: B::IntTensorPrimitive<D>,
|
||||
rhs: B::IntTensorPrimitive<D>,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Elementwise greater than comparison with a scalar.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The boolean tensor with the result of the comparison.
|
||||
fn int_greater_elem<const D: usize>(
|
||||
lhs: B::IntTensorPrimitive<D>,
|
||||
rhs: B::IntElem,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Elementwise greater than or equal comparison.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The boolean tensor with the result of the comparison.
|
||||
fn int_greater_equal<const D: usize>(
|
||||
lhs: B::IntTensorPrimitive<D>,
|
||||
rhs: B::IntTensorPrimitive<D>,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Elementwise greater than or equal comparison with a scalar.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The boolean tensor with the result of the comparison.
|
||||
fn int_greater_equal_elem<const D: usize>(
|
||||
lhs: B::IntTensorPrimitive<D>,
|
||||
rhs: B::IntElem,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Elementwise less than comparison.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The boolean tensor with the result of the comparison.
|
||||
fn int_lower<const D: usize>(
|
||||
lhs: B::IntTensorPrimitive<D>,
|
||||
rhs: B::IntTensorPrimitive<D>,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Elementwise less than comparison with a scalar.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The boolean tensor with the result of the comparison.
|
||||
fn int_lower_elem<const D: usize>(
|
||||
lhs: B::IntTensorPrimitive<D>,
|
||||
rhs: B::IntElem,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Elementwise less than or equal comparison.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The boolean tensor with the result of the comparison.
|
||||
fn int_lower_equal<const D: usize>(
|
||||
lhs: B::IntTensorPrimitive<D>,
|
||||
rhs: B::IntTensorPrimitive<D>,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Elementwise less than or equal comparison with a scalar.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The boolean tensor with the result of the comparison.
|
||||
fn int_lower_equal_elem<const D: usize>(
|
||||
lhs: B::IntTensorPrimitive<D>,
|
||||
rhs: B::IntElem,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
// ==== NUMERIC ==== //
|
||||
|
||||
/// Elementwise addition.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The result of the addition.
|
||||
fn int_add<const D: usize>(
|
||||
lhs: B::IntTensorPrimitive<D>,
|
||||
rhs: B::IntTensorPrimitive<D>,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Elementwise addition with a scalar.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The result of the addition.
|
||||
fn int_add_scalar<const D: usize>(
|
||||
lhs: B::IntTensorPrimitive<D>,
|
||||
rhs: B::IntElem,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Elementwise subtraction.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The result of the subtraction.
|
||||
fn int_sub<const D: usize>(
|
||||
lhs: B::IntTensorPrimitive<D>,
|
||||
rhs: B::IntTensorPrimitive<D>,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Elementwise subtraction with a scalar.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The result of the subtraction.
|
||||
fn int_sub_scalar<const D: usize>(
|
||||
lhs: B::IntTensorPrimitive<D>,
|
||||
rhs: B::IntElem,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Elementwise multiplication.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The result of the multiplication.
|
||||
fn int_mul<const D: usize>(
|
||||
lhs: B::IntTensorPrimitive<D>,
|
||||
rhs: B::IntTensorPrimitive<D>,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Elementwise multiplication with a scalar.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The result of the multiplication.
|
||||
fn int_mul_scalar<const D: usize>(
|
||||
lhs: B::IntTensorPrimitive<D>,
|
||||
rhs: B::IntElem,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Elementwise division.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The result of the division.
|
||||
fn int_div<const D: usize>(
|
||||
lhs: B::IntTensorPrimitive<D>,
|
||||
rhs: B::IntTensorPrimitive<D>,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Elementwise division with a scalar.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The result of the division.
|
||||
fn int_div_scalar<const D: usize>(
|
||||
lhs: B::IntTensorPrimitive<D>,
|
||||
rhs: B::IntElem,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Elementwise negation.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to negate.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The negated tensor.
|
||||
fn int_neg<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<D> {
|
||||
Self::int_mul_scalar(tensor, (-1.0).elem::<B::IntElem>())
|
||||
}
|
||||
|
||||
/// Creates a tensor of zeros.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `shape` - The shape of the tensor.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor of zeros.
|
||||
fn int_zeros<const D: usize>(shape: Shape<D>, device: &B::Device) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Creates a tensor of ones.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `shape` - The shape of the tensor.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor of ones.
|
||||
fn int_ones<const D: usize>(shape: Shape<D>, device: &B::Device) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Sums all elements in the tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to sum.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The sum of all elements in the tensor.
|
||||
fn int_sum<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<1>;
|
||||
|
||||
/// Sums all elements in the tensor along a dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to sum.
|
||||
/// * `dim` - The dimension to sum along.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The sum of all elements in the tensor along the dimension.
|
||||
fn int_sum_dim<const D: usize>(
|
||||
tensor: B::IntTensorPrimitive<D>,
|
||||
dim: usize,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Computes the mean of all elements in the tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to compute the mean of.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The mean of all elements in the tensor.
|
||||
fn int_mean<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<1>;
|
||||
|
||||
/// Computes the mean of all elements in the tensor along a dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to compute the mean of.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The mean of all elements in the tensor along the dimension.
|
||||
fn int_mean_dim<const D: usize>(
|
||||
tensor: B::IntTensorPrimitive<D>,
|
||||
dim: usize,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Gets the indices of the maximum elements along a dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the maximum indices of.
|
||||
/// * `dim` - The dimension to get the maximum indices along.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The indices of the maximum elements along the dimension.
|
||||
fn int_argmax<const D: usize>(
|
||||
tensor: B::IntTensorPrimitive<D>,
|
||||
dim: usize,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Gets the indices of the minimum elements along a dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the minimum indices of.
|
||||
/// * `dim` - The dimension to get the minimum indices along.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The indices of the minimum elements along the dimension.
|
||||
fn int_argmin<const D: usize>(
|
||||
tensor: B::IntTensorPrimitive<D>,
|
||||
dim: usize,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Gets the maximum element in the tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the maximum element of.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The maximum element in the tensor.
|
||||
fn int_max<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<1> {
|
||||
let shape = B::int_shape(&tensor);
|
||||
let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
|
||||
|
||||
B::int_max_dim(tensor, 0)
|
||||
}
|
||||
|
||||
/// Gets the maximum element in the tensor along a dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the maximum element of.
|
||||
/// * `dim` - The dimension to get the maximum element along.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The maximum element in the tensor along the dimension.
|
||||
fn int_max_dim<const D: usize>(
|
||||
tensor: B::IntTensorPrimitive<D>,
|
||||
dim: usize,
|
||||
|
@ -210,6 +714,17 @@ pub trait IntTensorOps<B: Backend> {
|
|||
|
||||
B::int_gather(D - 1, tensor, index)
|
||||
}
|
||||
|
||||
/// Gets the maximum elements and corresponding indices along a dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the maximum elements and indices of.
|
||||
/// * `dim` - The dimension to get the maximum elements and indices along.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The maximum elements and corresponding indices along the dimension.
|
||||
fn int_max_dim_with_indexes<const D: usize>(
|
||||
tensor: B::IntTensorPrimitive<D>,
|
||||
dim: usize,
|
||||
|
@ -219,12 +734,33 @@ pub trait IntTensorOps<B: Backend> {
|
|||
|
||||
(values, index)
|
||||
}
|
||||
|
||||
/// Gets the minimum element in the tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the minimum element of.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The minimum element in the tensor.
|
||||
fn int_min<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<1> {
|
||||
let shape = B::int_shape(&tensor);
|
||||
let tensor = B::int_reshape(tensor, Shape::new([shape.num_elements()]));
|
||||
|
||||
B::int_min_dim(tensor, 0)
|
||||
}
|
||||
|
||||
/// Gets the minimum elements in the tensor along a dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the minimum element of.
|
||||
/// * `dim` - The dimension to get the minimum element along.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The minimum element in the tensor along the dimension.
|
||||
fn int_min_dim<const D: usize>(
|
||||
tensor: B::IntTensorPrimitive<D>,
|
||||
dim: usize,
|
||||
|
@ -233,6 +769,17 @@ pub trait IntTensorOps<B: Backend> {
|
|||
|
||||
B::int_gather(D - 1, tensor, index)
|
||||
}
|
||||
|
||||
/// Gets the minimum elements and corresponding indices along a dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the minimum elements and indices of.
|
||||
/// * `dim` - The dimension to get the minimum elements and indices along.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The minimum elements and corresponding indices along the dimension.
|
||||
fn int_min_dim_with_indexes<const D: usize>(
|
||||
tensor: B::IntTensorPrimitive<D>,
|
||||
dim: usize,
|
||||
|
|
|
@ -4,52 +4,93 @@ use crate::{backend::Backend, Shape};
|
|||
/// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).
|
||||
#[derive(new)]
|
||||
pub struct Conv2dBackward<B: Backend> {
|
||||
/// Gradient.
|
||||
pub x_grad: B::TensorPrimitive<4>,
|
||||
|
||||
/// Weights gradient.
|
||||
pub weights_grad: B::TensorPrimitive<4>,
|
||||
|
||||
/// Bias gradient.
|
||||
pub bias_grad: Option<B::TensorPrimitive<1>>,
|
||||
}
|
||||
|
||||
/// Gradient computed during the backward pass for each tensor used by [max_pool2d](ModuleOps::max_pool2d).
|
||||
#[derive(new)]
|
||||
pub struct MaxPool2dBackward<B: Backend> {
|
||||
/// Gradient.
|
||||
pub x_grad: B::TensorPrimitive<4>,
|
||||
}
|
||||
|
||||
/// Results from [max_pool2d](ModuleOps::max_pool2d_with_indexes).
|
||||
#[derive(new)]
|
||||
pub struct MaxPool2dWithIndexes<B: Backend> {
|
||||
/// The output tensor.
|
||||
pub output: B::TensorPrimitive<4>,
|
||||
|
||||
/// The indexes tensor.
|
||||
pub indexes: B::IntTensorPrimitive<4>,
|
||||
}
|
||||
|
||||
/// Gradient computed during the backward pass for each tensor used by [conv1d](ModuleOps::conv1d).
|
||||
#[derive(new)]
|
||||
pub struct Conv1dBackward<B: Backend> {
|
||||
/// Gradient.
|
||||
pub x_grad: B::TensorPrimitive<3>,
|
||||
|
||||
/// Weights gradient.
|
||||
pub weights_grad: B::TensorPrimitive<3>,
|
||||
|
||||
/// Bias gradient.
|
||||
pub bias_grad: Option<B::TensorPrimitive<1>>,
|
||||
}
|
||||
|
||||
/// Convolution options.
|
||||
#[derive(new, Debug, Clone)]
|
||||
pub struct ConvOptions<const N: usize> {
|
||||
/// Stride.
|
||||
pub stride: [usize; N],
|
||||
|
||||
/// Padding.
|
||||
pub padding: [usize; N],
|
||||
|
||||
/// Dilation.
|
||||
pub dilation: [usize; N],
|
||||
|
||||
/// Groups.
|
||||
pub groups: usize,
|
||||
}
|
||||
|
||||
/// Transposed convolution options.
|
||||
#[derive(new, Debug, Clone)]
|
||||
pub struct ConvTransposeOptions<const N: usize> {
|
||||
/// Stride.
|
||||
pub stride: [usize; N],
|
||||
|
||||
/// Padding.
|
||||
pub padding: [usize; N],
|
||||
|
||||
/// Padding out.
|
||||
pub padding_out: [usize; N],
|
||||
|
||||
/// Dilation.
|
||||
pub dilation: [usize; N],
|
||||
|
||||
/// Groups.
|
||||
pub groups: usize,
|
||||
}
|
||||
|
||||
/// Module operations trait.
|
||||
pub trait ModuleOps<B: Backend> {
|
||||
/// Embedding operation.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `weights` - The embedding weights.
|
||||
/// * `indexes` - The indexes tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The output tensor.
|
||||
fn embedding(
|
||||
weights: B::TensorPrimitive<2>,
|
||||
indexes: B::IntTensorPrimitive<2>,
|
||||
|
@ -62,6 +103,18 @@ pub trait ModuleOps<B: Backend> {
|
|||
|
||||
B::reshape(output, Shape::new([batch_size, seq_length, d_model]))
|
||||
}
|
||||
|
||||
/// Embedding backward operation.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `weights` - The embedding weights.
|
||||
/// * `output_grad` - The output gradient.
|
||||
/// * `indexes` - The indexes tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The gradient.
|
||||
fn embedding_backward(
|
||||
weights: B::TensorPrimitive<2>,
|
||||
output_grad: B::TensorPrimitive<3>,
|
||||
|
@ -77,6 +130,7 @@ pub trait ModuleOps<B: Backend> {
|
|||
|
||||
B::index_select_assign(grad, 0, indexes, output_grad)
|
||||
}
|
||||
|
||||
/// Two dimensional convolution.
|
||||
///
|
||||
/// # Shapes
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
/// Module with convolution operations.
|
||||
pub mod conv;
|
||||
|
||||
/// Module with pooling operations.
|
||||
pub mod pool;
|
||||
|
||||
mod base;
|
||||
|
|
|
@ -5,31 +5,137 @@ use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversi
|
|||
|
||||
/// Operations on float tensors.
|
||||
pub trait TensorOps<B: Backend> {
|
||||
/// Creates a new tensor from the data structure.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `data` - The data structure.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the given data.
|
||||
fn from_data<const D: usize>(
|
||||
data: Data<B::FloatElem, D>,
|
||||
device: &B::Device,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Creates a new tensor with random values.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `shape` - The shape of the tensor.
|
||||
/// * `distribution` - The distribution to sample from.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the given shape and random values.
|
||||
fn random<const D: usize>(
|
||||
shape: Shape<D>,
|
||||
distribution: Distribution<B::FloatElem>,
|
||||
device: &B::Device,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Creates a new tensor with zeros.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `shape` - The shape of the tensor.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the given shape and zeros.
|
||||
fn zeros<const D: usize>(shape: Shape<D>, device: &B::Device) -> B::TensorPrimitive<D> {
|
||||
Self::from_data(Data::zeros(shape), device)
|
||||
}
|
||||
|
||||
/// Creates a new tensor with ones.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `shape` - The shape of the tensor.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the given shape and ones.
|
||||
fn ones<const D: usize>(shape: Shape<D>, device: &B::Device) -> B::TensorPrimitive<D> {
|
||||
Self::from_data(Data::ones(shape), device)
|
||||
}
|
||||
|
||||
/// Gets the shape of the tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The shape of the tensor.
|
||||
fn shape<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Shape<D>;
|
||||
|
||||
/// Converts the tensor to a data structure.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The data structure with the tensor's data.
|
||||
fn to_data<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Data<B::FloatElem, D>;
|
||||
|
||||
/// Converts the tensor to a data structure.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The data structure with the tensor's data.
|
||||
fn into_data<const D: usize>(tensor: B::TensorPrimitive<D>) -> Data<B::FloatElem, D> {
|
||||
Self::to_data(&tensor)
|
||||
}
|
||||
|
||||
/// Gets the device of the tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The device of the tensor.
|
||||
fn device<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::Device;
|
||||
|
||||
/// Moves the tensor to the given device.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `device` - The device to move the tensor to.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor on the given device.
|
||||
fn to_device<const D: usize>(
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
device: &B::Device,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Creates a new tensor with values from the given range.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `range` - The range of values.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the given values.
|
||||
fn arange(range: Range<usize>, device: &B::Device) -> B::IntTensorPrimitive<1> {
|
||||
let shape = Shape::new([range.end - range.start]);
|
||||
let value = range
|
||||
|
@ -39,7 +145,30 @@ pub trait TensorOps<B: Backend> {
|
|||
let data = Data::new(value, shape);
|
||||
B::int_from_data(data, device)
|
||||
}
|
||||
|
||||
/// Creates an empty tensor with the given shape.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `shape` - The shape of the tensor.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The empty tensor with the given shape.
|
||||
fn empty<const D: usize>(shape: Shape<D>, device: &B::Device) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Repeat the tensor along the given dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor.
|
||||
/// * `dim` - The dimension to repeat.
|
||||
/// * `times` - The number of times to repeat the dimension.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the given dimension repeated.
|
||||
fn repeat<const D: usize>(
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
dim: usize,
|
||||
|
@ -68,142 +197,479 @@ pub trait TensorOps<B: Backend> {
|
|||
|
||||
tensor_output
|
||||
}
|
||||
|
||||
/// Adds two tensors together.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The result of adding the two tensors together.
|
||||
fn add<const D: usize>(
|
||||
lhs: B::TensorPrimitive<D>,
|
||||
rhs: B::TensorPrimitive<D>,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Adds a scalar to a tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The result of adding the scalar to the tensor.
|
||||
fn add_scalar<const D: usize>(
|
||||
lhs: B::TensorPrimitive<D>,
|
||||
rhs: B::FloatElem,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Subtracts two tensors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The result of subtracting the two tensors.
|
||||
fn sub<const D: usize>(
|
||||
lhs: B::TensorPrimitive<D>,
|
||||
rhs: B::TensorPrimitive<D>,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Subtracts a scalar from a tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The result of subtracting the scalar from the tensor.
|
||||
fn sub_scalar<const D: usize>(
|
||||
lhs: B::TensorPrimitive<D>,
|
||||
rhs: B::FloatElem,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Multiplies two tensors together element-wise.
|
||||
fn mul<const D: usize>(
|
||||
lhs: B::TensorPrimitive<D>,
|
||||
rhs: B::TensorPrimitive<D>,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Multiplies a tensor by a scalar.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The result of multiplying the tensor by the scalar.
|
||||
fn mul_scalar<const D: usize>(
|
||||
lhs: B::TensorPrimitive<D>,
|
||||
rhs: B::FloatElem,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Divides two tensors element-wise.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The result of dividing the two tensors.
|
||||
fn div<const D: usize>(
|
||||
lhs: B::TensorPrimitive<D>,
|
||||
rhs: B::TensorPrimitive<D>,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Divides a tensor by a scalar.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The result of dividing the tensor by the scalar.
|
||||
fn div_scalar<const D: usize>(
|
||||
lhs: B::TensorPrimitive<D>,
|
||||
rhs: B::FloatElem,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Multiplies two tensors together using matrix multiplication.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The result of multiplying the two tensors together using matrix multiplication.
|
||||
fn matmul<const D: usize>(
|
||||
lhs: B::TensorPrimitive<D>,
|
||||
rhs: B::TensorPrimitive<D>,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Negates a tensor element-wise.
|
||||
fn neg<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
|
||||
Self::mul_scalar(tensor, (-1.0_f32).elem::<B::FloatElem>())
|
||||
}
|
||||
|
||||
/// Transposes a tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to transpose.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The transposed tensor.
|
||||
fn transpose<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
|
||||
Self::swap_dims(tensor, D - 2, D - 1)
|
||||
}
|
||||
|
||||
/// Swaps two dimensions of a tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to swap the dimensions of.
|
||||
/// * `dim1` - The first dimension to swap.
|
||||
/// * `dim2` - The second dimension to swap.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the dimensions swapped.
|
||||
fn swap_dims<const D: usize>(
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
dim1: usize,
|
||||
dim2: usize,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Reshapes a tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to reshape.
|
||||
/// * `shape` - The new shape of the tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the new shape.
|
||||
fn reshape<const D1: usize, const D2: usize>(
|
||||
tensor: B::TensorPrimitive<D1>,
|
||||
shape: Shape<D2>,
|
||||
) -> B::TensorPrimitive<D2>;
|
||||
|
||||
/// Gather elements from a tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dim` - The dimension to gather from.
|
||||
/// * `tensor` - The tensor to gather from.
|
||||
/// * `indexes` - The indexes to gather.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The gathered elements.
|
||||
fn gather<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
indexes: B::IntTensorPrimitive<D>,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Scatter elements into a tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dim` - The dimension to scatter into.
|
||||
/// * `tensor` - The tensor to scatter into.
|
||||
/// * `indexes` - The indexes to scatter into.
|
||||
/// * `value` - The value to scatter.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the scattered elements.
|
||||
fn scatter<const D: usize>(
|
||||
dim: usize,
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
indexes: B::IntTensorPrimitive<D>,
|
||||
value: B::TensorPrimitive<D>,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Select tensor elements along the given dimension corresponding for the given indexes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to select from.
|
||||
/// * `dim` - The dimension to select from.
|
||||
/// * `indexes` - The indexes to select.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The selected elements.
|
||||
fn index_select<const D: usize>(
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
dim: usize,
|
||||
indexes: B::IntTensorPrimitive<1>,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Assign the selected elements along the given dimension corresponding for the given indexes
|
||||
/// to the given value.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to select from.
|
||||
/// * `dim` - The dimension to select from.
|
||||
/// * `indexes` - The indexes to select.
|
||||
/// * `value` - The value to assign.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the selected elements assigned to the given value.
|
||||
fn index_select_assign<const D1: usize, const D2: usize>(
|
||||
tensor: B::TensorPrimitive<D1>,
|
||||
dim: usize,
|
||||
indexes: B::IntTensorPrimitive<1>,
|
||||
value: B::TensorPrimitive<D2>,
|
||||
) -> B::TensorPrimitive<D1>;
|
||||
|
||||
/// Select tensor elements corresponding for the given indexes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to select from.
|
||||
/// * `indexes` - The indexes to select.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The selected elements in a new tensor.
|
||||
fn index<const D1: usize, const D2: usize>(
|
||||
tensor: B::TensorPrimitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
) -> B::TensorPrimitive<D1>;
|
||||
|
||||
/// Assign the selected elements corresponding for the given indexes to the given value.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to select from.
|
||||
/// * `indexes` - The indexes to select.
|
||||
/// * `value` - The value to assign.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the selected elements assigned to the given value.
|
||||
fn index_assign<const D1: usize, const D2: usize>(
|
||||
tensor: B::TensorPrimitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
value: B::TensorPrimitive<D1>,
|
||||
) -> B::TensorPrimitive<D1>;
|
||||
|
||||
/// Update the given tensor with the value tensor where the mask is true.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to select from.
|
||||
/// * `mask` - The boolean mask to select with.
|
||||
/// * `value` - The value to assign to the selected elements from the value tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the selected elements assigned to the given value.
|
||||
fn mask_where<const D: usize>(
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
mask: B::BoolTensorPrimitive<D>,
|
||||
value: B::TensorPrimitive<D>,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Update the given tensor with the value where the mask is true.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to select from.
|
||||
/// * `mask` - The boolean mask to select with.
|
||||
/// * `value` - The value to assign to the selected elements.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the selected elements assigned to the given value.
|
||||
fn mask_fill<const D: usize>(
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
mask: B::BoolTensorPrimitive<D>,
|
||||
value: B::FloatElem,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Equal comparison of two tensors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor with the result of the comparison.
|
||||
fn equal<const D: usize>(
|
||||
lhs: B::TensorPrimitive<D>,
|
||||
rhs: B::TensorPrimitive<D>,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Equal comparison of a tensor and a scalar.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor with the result of the comparison.
|
||||
fn equal_elem<const D: usize>(
|
||||
lhs: B::TensorPrimitive<D>,
|
||||
rhs: B::FloatElem,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Greater than comparison of two tensors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor with the result of the comparison.
|
||||
fn greater<const D: usize>(
|
||||
lhs: B::TensorPrimitive<D>,
|
||||
rhs: B::TensorPrimitive<D>,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Greater than comparison of a tensor and a scalar.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor with the result of the comparison.
|
||||
fn greater_elem<const D: usize>(
|
||||
lhs: B::TensorPrimitive<D>,
|
||||
rhs: B::FloatElem,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Greater than or equal comparison of two tensors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor with the result of the comparison.
|
||||
fn greater_equal<const D: usize>(
|
||||
lhs: B::TensorPrimitive<D>,
|
||||
rhs: B::TensorPrimitive<D>,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Greater than or equal comparison of a tensor and a scalar.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor with the result of the comparison.
|
||||
fn greater_equal_elem<const D: usize>(
|
||||
lhs: B::TensorPrimitive<D>,
|
||||
rhs: B::FloatElem,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Less than comparison of two tensors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor with the result of the comparison.
|
||||
fn lower<const D: usize>(
|
||||
lhs: B::TensorPrimitive<D>,
|
||||
rhs: B::TensorPrimitive<D>,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Less than comparison of a tensor and a scalar.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor with the result of the comparison.
|
||||
fn lower_elem<const D: usize>(
|
||||
lhs: B::TensorPrimitive<D>,
|
||||
rhs: B::FloatElem,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Less than or equal comparison of two tensors.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side tensor.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor with the result of the comparison.
|
||||
fn lower_equal<const D: usize>(
|
||||
lhs: B::TensorPrimitive<D>,
|
||||
rhs: B::TensorPrimitive<D>,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Less than or equal comparison of a tensor and a scalar.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `lhs` - The left hand side tensor.
|
||||
/// * `rhs` - The right hand side scalar.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A boolean tensor with the result of the comparison.
|
||||
fn lower_equal_elem<const D: usize>(
|
||||
lhs: B::TensorPrimitive<D>,
|
||||
rhs: B::FloatElem,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
|
||||
/// Detaches a tensor from the computation graph.
|
||||
fn detach<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
|
||||
// Should only be overriden by autodiff backends.
|
||||
tensor
|
||||
}
|
||||
|
||||
/// Sets the `require_grad` flag of a tensor.
|
||||
fn set_require_grad<const D: usize>(
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
_require_grad: bool,
|
||||
|
@ -211,54 +677,273 @@ pub trait TensorOps<B: Backend> {
|
|||
// Should only be overriden by autodiff backends.
|
||||
tensor
|
||||
}
|
||||
|
||||
/// Returns the `require_grad` flag of a tensor.
|
||||
fn is_require_grad<const D: usize>(_tensor: &B::TensorPrimitive<D>) -> bool {
|
||||
// Should only be overriden by autodiff backends.
|
||||
false
|
||||
}
|
||||
|
||||
/// Sum of all elements in a tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to sum.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A scalar tensor with the sum of all elements in `tensor`.
|
||||
fn sum<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<1>;
|
||||
|
||||
/// Sum of all elements in a tensor along a dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to sum.
|
||||
/// * `dim` - The dimension along which to sum.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the sum of all elements in `tensor` along `dim`.
|
||||
fn sum_dim<const D: usize>(tensor: B::TensorPrimitive<D>, dim: usize) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Mean of all elements in a tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to mean.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A scalar tensor with the mean of all elements in `tensor`.
|
||||
fn mean<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<1>;
|
||||
|
||||
/// Mean of all elements in a tensor along a dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to mean.
|
||||
/// * `dim` - The dimension along which to mean.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the mean of all elements in `tensor` along `dim`.
|
||||
fn mean_dim<const D: usize>(tensor: B::TensorPrimitive<D>, dim: usize)
|
||||
-> B::TensorPrimitive<D>;
|
||||
|
||||
/// Converts a tensor to full precision.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to convert.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same values as `tensor` but with full precision.
|
||||
fn to_full_precision<const D: usize>(
|
||||
tensor: &B::TensorPrimitive<D>,
|
||||
) -> <B::FullPrecisionBackend as Backend>::TensorPrimitive<D>;
|
||||
|
||||
/// Converts a tensor from full precision.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to convert.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same values as `tensor` but with the precision of the backend.
|
||||
fn from_full_precision<const D: usize>(
|
||||
tensor: <B::FullPrecisionBackend as Backend>::TensorPrimitive<D>,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Returns a new tensor with exponential values.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to exponentiate.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape as `tensor` with exponential values.
|
||||
fn exp<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Returns a new tensor with natural logarithm values.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to take the logarithm of.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape as `tensor` with natural logarithm values.
|
||||
fn log<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Returns a new tensor with logarithm values of (1 + Xi).
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to take the logarithm of.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape as `tensor` with logarithm values of (1 + Xi).
|
||||
fn log1p<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Returns a new tensor with values raised to the power of `value`.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to exponentiate.
|
||||
/// * `value` - The exponent.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape as `tensor` with values raised to the power of `value`.
|
||||
fn powf<const D: usize>(tensor: B::TensorPrimitive<D>, value: f32) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Returns a new tensor with square root values.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to take the square root of.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape as `tensor` with square root values.
|
||||
fn sqrt<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Returns a new tensor with cosine values.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to take the cosine of.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape as `tensor` with cosine values.
|
||||
fn cos<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Returns a new tensor with sine values.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to take the sine of.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape as `tensor` with sine values.
|
||||
fn sin<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Returns a new tensor with tangent values.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to take the tangent of.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape as `tensor` with tangent values.
|
||||
fn tanh<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Returns a new tensor with the error function values.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to take the error function of.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the same shape as `tensor` with error function values.
|
||||
fn erf<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Catcatenates tensors along a dimension.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensors` - The tensors to catcatenate.
|
||||
/// * `dim` - The dimension along which to catcatenate.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the catcatenated tensors along `dim`.
|
||||
fn cat<const D: usize>(
|
||||
tensors: Vec<B::TensorPrimitive<D>>,
|
||||
dim: usize,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
|
||||
/// Gets the indexes of the maximum elements of a tensor along an axis.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the maximum elements of.
|
||||
/// * `dim` - The dimension along which to get the maximum elements.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the indexes of the maximum elements of `tensor` along `dim`.
|
||||
fn argmax<const D: usize>(
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
dim: usize,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Gets the indexes of the minimum elements of a tensor along an axis.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the minimum elements of.
|
||||
/// * `dim` - The dimension along which to get the minimum elements.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the indexes of the minimum elements of `tensor` along `dim`.
|
||||
fn argmin<const D: usize>(
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
dim: usize,
|
||||
) -> B::IntTensorPrimitive<D>;
|
||||
|
||||
/// Gets the maximum element of a tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the maximum elements of.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the maximum element of `tensor`.
|
||||
fn max<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<1> {
|
||||
let shape = B::shape(&tensor);
|
||||
let tensor = B::reshape(tensor, Shape::new([shape.num_elements()]));
|
||||
|
||||
B::max_dim(tensor, 0)
|
||||
}
|
||||
|
||||
/// Gets the maximum elements of a tensor along an axis.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the maximum elements of.
|
||||
/// * `dim` - The dimension along which to get the maximum elements.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the maximum elements of `tensor` along `dim`.
|
||||
fn max_dim<const D: usize>(tensor: B::TensorPrimitive<D>, dim: usize) -> B::TensorPrimitive<D> {
|
||||
let index = B::argmax(tensor.clone(), dim);
|
||||
|
||||
B::gather(D - 1, tensor, index)
|
||||
}
|
||||
|
||||
/// Gets the maximum elements of a tensor along an axis and their indexes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the maximum elements of.
|
||||
/// * `dim` - The dimension along which to get the maximum elements.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tuple with the maximum elements of `tensor` along `dim` and their indexes.
|
||||
fn max_dim_with_indexes<const D: usize>(
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
dim: usize,
|
||||
|
@ -268,17 +953,49 @@ pub trait TensorOps<B: Backend> {
|
|||
|
||||
(values, index)
|
||||
}
|
||||
|
||||
/// Gets the minimum element of a tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the minimum elements of.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the minimum element of `tensor`.
|
||||
fn min<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<1> {
|
||||
let shape = B::shape(&tensor);
|
||||
let tensor = B::reshape(tensor, Shape::new([shape.num_elements()]));
|
||||
|
||||
B::min_dim(tensor, 0)
|
||||
}
|
||||
|
||||
/// Gets the minimum elements of a tensor along an axis.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the minimum elements of.
|
||||
/// * `dim` - The dimension along which to get the minimum elements.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tensor with the minimum elements of `tensor` along `dim`.
|
||||
fn min_dim<const D: usize>(tensor: B::TensorPrimitive<D>, dim: usize) -> B::TensorPrimitive<D> {
|
||||
let index = B::argmin(tensor.clone(), dim);
|
||||
|
||||
B::gather(D - 1, tensor, index)
|
||||
}
|
||||
|
||||
/// Gets the minimum elements of a tensor along an axis and their indexes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to get the minimum elements of.
|
||||
/// * `dim` - The dimension along which to get the minimum elements.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A tuple with the minimum elements of `tensor` along `dim` and their indexes.
|
||||
fn min_dim_with_indexes<const D: usize>(
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
dim: usize,
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
use alloc::vec::Vec;
|
||||
|
||||
/// Shape of a tensor.
|
||||
#[derive(new, Debug, Clone, PartialEq, Eq)]
|
||||
pub struct Shape<const D: usize> {
|
||||
/// The dimensions of the tensor.
|
||||
pub dims: [usize; D],
|
||||
}
|
||||
|
||||
|
|
|
@ -3,6 +3,7 @@ mod module;
|
|||
mod ops;
|
||||
mod stats;
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[macro_export]
|
||||
macro_rules! testgen_all {
|
||||
() => {
|
||||
|
|
|
@ -12,6 +12,7 @@ enum Message<T, V> {
|
|||
End,
|
||||
}
|
||||
|
||||
/// Async trainer callback tracker.
|
||||
pub struct AsyncTrainerCallback<T, V> {
|
||||
sender: mpsc::Sender<Message<T, V>>,
|
||||
handler: Option<JoinHandle<()>>,
|
||||
|
@ -52,6 +53,7 @@ impl<T, V> CallbackThread<T, V> {
|
|||
}
|
||||
|
||||
impl<T: Send + Sync + 'static, V: Send + Sync + 'static> AsyncTrainerCallback<T, V> {
|
||||
/// Create a new async trainer callback.
|
||||
pub fn new(callback: Box<dyn LearnerCallback<T, V>>) -> Self {
|
||||
let (sender, receiver) = mpsc::channel();
|
||||
let thread = CallbackThread::new(Mutex::new(callback), receiver);
|
||||
|
|
|
@ -1,18 +1,38 @@
|
|||
use burn_core::{data::dataloader::Progress, LearningRate};
|
||||
|
||||
/// The base trait for trainer callbacks.
|
||||
pub trait LearnerCallback<T, V>: Send {
|
||||
/// Called when a training item is logged.
|
||||
fn on_train_item(&mut self, _item: LearnerItem<T>) {}
|
||||
|
||||
/// Called when a validation item is logged.
|
||||
fn on_valid_item(&mut self, _item: LearnerItem<V>) {}
|
||||
|
||||
/// Called when a training epoch is finished.
|
||||
fn on_train_end_epoch(&mut self, _epoch: usize) {}
|
||||
|
||||
/// Called when a validation epoch is finished.
|
||||
fn on_valid_end_epoch(&mut self, _epoch: usize) {}
|
||||
}
|
||||
|
||||
/// A learner item.
|
||||
#[derive(new)]
|
||||
pub struct LearnerItem<T> {
|
||||
/// The item.
|
||||
pub item: T,
|
||||
|
||||
/// The progress.
|
||||
pub progress: Progress,
|
||||
|
||||
/// The epoch.
|
||||
pub epoch: usize,
|
||||
|
||||
/// The total number of epochs.
|
||||
pub epoch_total: usize,
|
||||
|
||||
/// The iteration.
|
||||
pub iteration: usize,
|
||||
|
||||
/// The learning rate.
|
||||
pub lr: Option<LearningRate>,
|
||||
}
|
||||
|
|
|
@ -26,6 +26,7 @@ impl<R: Record> CheckpointerThread<R> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Async checkpointer.
|
||||
pub struct AsyncCheckpointer<E> {
|
||||
checkpointer: Arc<dyn Checkpointer<E> + Send + Sync>,
|
||||
sender: mpsc::SyncSender<Message<E>>,
|
||||
|
@ -33,6 +34,15 @@ pub struct AsyncCheckpointer<E> {
|
|||
}
|
||||
|
||||
impl<R: Record + 'static> AsyncCheckpointer<R> {
|
||||
/// Create a new async checkpointer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `checkpointer` - The checkpointer.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The async checkpointer.
|
||||
pub fn new(checkpointer: Arc<dyn Checkpointer<R> + Send + Sync>) -> Self {
|
||||
// Only on checkpoint can be done in advance.
|
||||
let (sender, receiver) = mpsc::sync_channel(0);
|
||||
|
|
|
@ -1,13 +1,36 @@
|
|||
use burn_core::record::{Record, RecorderError};
|
||||
|
||||
/// The error type for checkpointer.
|
||||
#[derive(Debug)]
|
||||
pub enum CheckpointerError {
|
||||
/// IO error.
|
||||
IOError(std::io::Error),
|
||||
|
||||
/// Recorder error.
|
||||
RecorderError(RecorderError),
|
||||
|
||||
/// Other errors.
|
||||
Unknown(String),
|
||||
}
|
||||
|
||||
/// The trait for checkpointer.
|
||||
pub trait Checkpointer<R: Record> {
|
||||
/// Save the record.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `epoch` - The epoch.
|
||||
/// * `record` - The record.
|
||||
fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError>;
|
||||
|
||||
/// Restore the record.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `epoch` - The epoch.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The record.
|
||||
fn restore(&self, epoch: usize) -> Result<R, CheckpointerError>;
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use super::{Checkpointer, CheckpointerError};
|
||||
use burn_core::record::{FileRecorder, Record};
|
||||
|
||||
/// The file checkpointer.
|
||||
pub struct FileCheckpointer<FR> {
|
||||
directory: String,
|
||||
name: String,
|
||||
|
@ -9,6 +10,14 @@ pub struct FileCheckpointer<FR> {
|
|||
}
|
||||
|
||||
impl<FR> FileCheckpointer<FR> {
|
||||
/// Creates a new file checkpointer.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `recorder` - The file recorder.
|
||||
/// * `directory` - The directory to save the checkpoints.
|
||||
/// * `name` - The name of the checkpoint.
|
||||
/// * `num_keep` - The number of checkpoints to keep.
|
||||
pub fn new(recorder: FR, directory: &str, name: &str, num_keep: usize) -> Self {
|
||||
std::fs::create_dir_all(directory).ok();
|
||||
|
||||
|
|
|
@ -44,6 +44,11 @@ where
|
|||
Optim: Optimizer<Model, B>,
|
||||
LR: LRScheduler,
|
||||
{
|
||||
/// Creates a new learner builder.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `directory` - The directory to save the checkpoints.
|
||||
pub fn new(directory: &str) -> Self {
|
||||
let renderer = Box::new(CLIDashboardRenderer::new());
|
||||
let logger_train = Box::new(FileMetricLogger::new(format!("{directory}/train").as_str()));
|
||||
|
|
|
@ -5,8 +5,13 @@ use burn_core::tensor::{Int, Tensor};
|
|||
/// Simple classification output adapted for multiple metrics.
|
||||
#[derive(new)]
|
||||
pub struct ClassificationOutput<B: Backend> {
|
||||
/// The loss.
|
||||
pub loss: Tensor<B, 1>,
|
||||
|
||||
/// The output.
|
||||
pub output: Tensor<B, 2>,
|
||||
|
||||
/// The targets.
|
||||
pub targets: Tensor<B, 1, Int>,
|
||||
}
|
||||
|
||||
|
|
|
@ -9,6 +9,7 @@ use std::sync::Arc;
|
|||
|
||||
use crate::{LearnerCallback, LearnerItem, MultiDevicesTrainStep, TrainStep, ValidStep};
|
||||
|
||||
/// A validation epoch.
|
||||
#[derive(new)]
|
||||
pub struct ValidEpoch<VI> {
|
||||
dataloader: Arc<dyn DataLoader<VI>>,
|
||||
|
@ -16,6 +17,7 @@ pub struct ValidEpoch<VI> {
|
|||
epoch_total: usize,
|
||||
}
|
||||
|
||||
/// A training epoch.
|
||||
#[derive(new)]
|
||||
pub struct TrainEpoch<TI> {
|
||||
dataloader: Arc<dyn DataLoader<TI>>,
|
||||
|
@ -25,6 +27,12 @@ pub struct TrainEpoch<TI> {
|
|||
}
|
||||
|
||||
impl<I> ValidEpoch<I> {
|
||||
/// Runs the validation epoch.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `model` - The model to validate.
|
||||
/// * `callback` - The callback to use.
|
||||
pub fn run<B, M, TO, VO>(&self, model: &M, callback: &mut Box<dyn LearnerCallback<TO, VO>>)
|
||||
where
|
||||
B: ADBackend,
|
||||
|
@ -58,6 +66,18 @@ impl<I> ValidEpoch<I> {
|
|||
}
|
||||
|
||||
impl<TI> TrainEpoch<TI> {
|
||||
/// Runs the training epoch.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `model` - The model to train.
|
||||
/// * `optim` - The optimizer to use.
|
||||
/// * `scheduler` - The learning rate scheduler to use.
|
||||
/// * `callback` - The callback to use.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The trained model and the optimizer.
|
||||
pub fn run<B, M, O, LR, TO, VO>(
|
||||
&self,
|
||||
mut model: M,
|
||||
|
@ -119,6 +139,19 @@ impl<TI> TrainEpoch<TI> {
|
|||
}
|
||||
|
||||
impl<TI> TrainEpoch<TI> {
|
||||
/// Runs the training epoch on multiple devices.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `model` - The model to train.
|
||||
/// * `optim` - The optimizer to use.
|
||||
/// * `lr_scheduler` - The learning rate scheduler to use.
|
||||
/// * `callback` - The callback to use.
|
||||
/// * `devices` - The devices to use.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The trained model and the optimizer.
|
||||
pub fn run_multi_device<B, M, O, S, TO, VO>(
|
||||
&self,
|
||||
mut model: M,
|
||||
|
|
|
@ -5,8 +5,13 @@ use burn_core::tensor::Tensor;
|
|||
/// Simple regression output adapted for multiple metrics.
|
||||
#[derive(new)]
|
||||
pub struct RegressionOutput<B: Backend> {
|
||||
/// The loss.
|
||||
pub loss: Tensor<B, 1>,
|
||||
|
||||
/// The output.
|
||||
pub output: Tensor<B, 2>,
|
||||
|
||||
/// The targets.
|
||||
pub targets: Tensor<B, 2>,
|
||||
}
|
||||
|
||||
|
|
|
@ -1 +1,2 @@
|
|||
/// The trainer module.
|
||||
pub mod train;
|
||||
|
|
|
@ -5,6 +5,7 @@ use burn_core::{
|
|||
use std::sync::mpsc::{Receiver, Sender};
|
||||
use std::thread::spawn;
|
||||
|
||||
/// Multi devices train step.
|
||||
pub struct MultiDevicesTrainStep<B: ADBackend, M, TI, TO> {
|
||||
workers: Vec<Worker<B, M, TI>>,
|
||||
receiver: Receiver<TrainOutput<TO>>,
|
||||
|
@ -68,6 +69,15 @@ where
|
|||
TI: Send + 'static,
|
||||
TO: Send + 'static,
|
||||
{
|
||||
/// Create a new multi devices train step.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `devices` - Devices.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// MultiDevicesTrainStep instance.
|
||||
pub fn new(devices: &[B::Device]) -> Self
|
||||
where
|
||||
TI: Send + 'static,
|
||||
|
@ -93,6 +103,16 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// Collect outputs from workers for one step.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dataloader` - Dataloader.
|
||||
/// * `model` - Model.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// Outputs.
|
||||
pub fn step<'a>(
|
||||
&self,
|
||||
dataloader: &mut Box<dyn DataLoaderIterator<TI> + 'a>,
|
||||
|
|
|
@ -8,23 +8,58 @@ use burn_core::optim::{GradientsParams, Optimizer};
|
|||
use burn_core::tensor::backend::ADBackend;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// A training output.
|
||||
pub struct TrainOutput<TO> {
|
||||
/// The gradients.
|
||||
pub grads: GradientsParams,
|
||||
|
||||
/// The item.
|
||||
pub item: TO,
|
||||
}
|
||||
|
||||
impl<TO> TrainOutput<TO> {
|
||||
/// Creates a new training output.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `module` - The module.
|
||||
/// * `grads` - The gradients.
|
||||
/// * `item` - The item.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new training output.
|
||||
pub fn new<B: ADBackend, M: ADModule<B>>(module: &M, grads: B::Gradients, item: TO) -> Self {
|
||||
let grads = GradientsParams::from_grads(grads, module);
|
||||
Self { grads, item }
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait for a training step.
|
||||
pub trait TrainStep<TI, TO> {
|
||||
/// Runs a training step.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` - The item to train on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The training output.
|
||||
fn step(&self, item: TI) -> TrainOutput<TO>;
|
||||
}
|
||||
|
||||
/// Trait for a validation step.
|
||||
pub trait ValidStep<VI, VO> {
|
||||
/// Runs a validation step.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` - The item to validate on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The validation output.
|
||||
fn step(&self, item: VI) -> VO;
|
||||
}
|
||||
|
||||
|
@ -37,6 +72,16 @@ where
|
|||
O: Optimizer<M, B>,
|
||||
LR: LRScheduler,
|
||||
{
|
||||
/// Fits the model.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `dataloader_train` - The training dataloader.
|
||||
/// * `dataloader_valid` - The validation dataloader.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The fitted model.
|
||||
pub fn fit<TI, VI>(
|
||||
mut self,
|
||||
dataloader_train: Arc<dyn DataLoader<TI>>,
|
||||
|
|
|
@ -1,8 +1,17 @@
|
|||
#![warn(missing_docs)]
|
||||
|
||||
//! A library for training neural networks using the burn crate.
|
||||
|
||||
#[macro_use]
|
||||
extern crate derive_new;
|
||||
|
||||
/// The checkpoint module.
|
||||
pub mod checkpoint;
|
||||
|
||||
/// The logger module.
|
||||
pub mod logger;
|
||||
|
||||
/// The metric module.
|
||||
pub mod metric;
|
||||
|
||||
mod callback;
|
||||
|
|
|
@ -5,7 +5,7 @@ enum Message<T> {
|
|||
Log(T),
|
||||
End,
|
||||
}
|
||||
|
||||
/// Async logger.
|
||||
pub struct AsyncLogger<T> {
|
||||
sender: mpsc::Sender<Message<T>>,
|
||||
handler: Option<std::thread::JoinHandle<()>>,
|
||||
|
@ -34,6 +34,7 @@ impl<T> LoggerThread<T> {
|
|||
}
|
||||
|
||||
impl<T: Send + Sync + 'static> AsyncLogger<T> {
|
||||
/// Create a new async logger.
|
||||
pub fn new(logger: Box<dyn Logger<T>>) -> Self {
|
||||
let (sender, receiver) = mpsc::channel();
|
||||
let thread = LoggerThread::new(Mutex::new(logger), receiver);
|
||||
|
|
|
@ -1,9 +1,26 @@
|
|||
/// The logger trait.
|
||||
pub trait Logger<T>: Send {
|
||||
/// Logs an item.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` - The item.
|
||||
fn log(&mut self, item: T);
|
||||
}
|
||||
|
||||
/// The logger backend trait.
|
||||
pub trait LoggerBackend {
|
||||
/// The logger type.
|
||||
type Logger<T>: Logger<T>;
|
||||
|
||||
/// Create a new logger.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `epoch` - The epoch.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The logger.
|
||||
fn create<T>(&self, epoch: usize) -> Self::Logger<T>;
|
||||
}
|
||||
|
|
|
@ -1,11 +1,21 @@
|
|||
use super::Logger;
|
||||
use std::{fs::File, io::Write};
|
||||
|
||||
/// File logger.
|
||||
pub struct FileLogger {
|
||||
file: File,
|
||||
}
|
||||
|
||||
impl FileLogger {
|
||||
/// Create a new file logger.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - The path.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The file logger.
|
||||
pub fn new(path: &str) -> Self {
|
||||
let mut options = std::fs::File::options();
|
||||
let file = options
|
||||
|
|
|
@ -2,11 +2,24 @@ use super::{AsyncLogger, FileLogger, Logger};
|
|||
use crate::metric::MetricEntry;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Metric logger.
|
||||
pub trait MetricLogger: Send {
|
||||
/// Logs an item.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` - The item.
|
||||
fn log(&mut self, item: &MetricEntry);
|
||||
|
||||
/// Logs an epoch.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `epoch` - The epoch.
|
||||
fn epoch(&mut self, epoch: usize);
|
||||
}
|
||||
|
||||
/// The file metric logger.
|
||||
pub struct FileMetricLogger {
|
||||
loggers: HashMap<String, Box<dyn Logger<String>>>,
|
||||
directory: String,
|
||||
|
@ -14,6 +27,15 @@ pub struct FileMetricLogger {
|
|||
}
|
||||
|
||||
impl FileMetricLogger {
|
||||
/// Create a new file metric logger.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `directory` - The directory.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The file metric logger.
|
||||
pub fn new(directory: &str) -> Self {
|
||||
Self {
|
||||
loggers: HashMap::new(),
|
||||
|
|
|
@ -20,11 +20,12 @@ pub struct AccuracyInput<B: Backend> {
|
|||
}
|
||||
|
||||
impl<B: Backend> AccuracyMetric<B> {
|
||||
/// Create the metric.
|
||||
/// Creates the metric.
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Sets the pad token.
|
||||
pub fn with_pad_token(mut self, index: usize) -> Self {
|
||||
self.pad_token = Some(index);
|
||||
self
|
||||
|
|
|
@ -2,10 +2,19 @@ use burn_core::{data::dataloader::Progress, LearningRate};
|
|||
|
||||
/// Metric metadata that can be used when computing metrics.
|
||||
pub struct MetricMetadata {
|
||||
/// The current progress.
|
||||
pub progress: Progress,
|
||||
|
||||
/// The current epoch.
|
||||
pub epoch: usize,
|
||||
|
||||
/// The total number of epochs.
|
||||
pub epoch_total: usize,
|
||||
|
||||
/// The current iteration.
|
||||
pub iteration: usize,
|
||||
|
||||
/// The current learning rate.
|
||||
pub lr: Option<LearningRate>,
|
||||
}
|
||||
|
||||
|
@ -33,6 +42,7 @@ impl MetricMetadata {
|
|||
/// This is important since some conflict may happen when the model output is adapted for each
|
||||
/// metric's input type.
|
||||
pub trait Metric: Send + Sync {
|
||||
/// The input type of the metric.
|
||||
type Input;
|
||||
|
||||
/// Update the metric state and returns the current metric entry.
|
||||
|
@ -54,6 +64,7 @@ pub trait Adaptor<T> {
|
|||
///
|
||||
/// This is usefull to plot the values of a metric during training.
|
||||
pub trait Numeric {
|
||||
/// Returns the numeric value of the metric.
|
||||
fn value(&self) -> f64;
|
||||
}
|
||||
|
||||
|
|
|
@ -10,6 +10,7 @@ pub struct CUDAMetric {
|
|||
}
|
||||
|
||||
impl CUDAMetric {
|
||||
/// Creates a new metric for CUDA.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
nvml: Nvml::init().map(Some).unwrap_or_else(|err| {
|
||||
|
|
|
@ -5,14 +5,23 @@ use crate::{
|
|||
};
|
||||
use burn_core::data::dataloader::Progress;
|
||||
|
||||
/// Training progress.
|
||||
pub struct TrainingProgress {
|
||||
/// The progress.
|
||||
pub progress: Progress,
|
||||
|
||||
/// The epoch.
|
||||
pub epoch: usize,
|
||||
|
||||
/// The total number of epochs.
|
||||
pub epoch_total: usize,
|
||||
|
||||
/// The iteration.
|
||||
pub iteration: usize,
|
||||
}
|
||||
|
||||
impl TrainingProgress {
|
||||
/// Creates a new empy training progress.
|
||||
pub fn none() -> Self {
|
||||
Self {
|
||||
progress: Progress {
|
||||
|
@ -26,18 +35,47 @@ impl TrainingProgress {
|
|||
}
|
||||
}
|
||||
|
||||
/// A dashboard metric.
|
||||
pub enum DashboardMetricState {
|
||||
/// A generic metric.
|
||||
Generic(MetricEntry),
|
||||
|
||||
/// A numeric metric.
|
||||
Numeric(MetricEntry, f64),
|
||||
}
|
||||
|
||||
/// Trait for rendering dashboard metrics.
|
||||
pub trait DashboardRenderer: Send + Sync {
|
||||
/// Updates the training metric state.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `state` - The metric state.
|
||||
fn update_train(&mut self, state: DashboardMetricState);
|
||||
|
||||
/// Updates the validation metric state.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `state` - The metric state.
|
||||
fn update_valid(&mut self, state: DashboardMetricState);
|
||||
|
||||
/// Renders the training progress.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` - The training progress.
|
||||
fn render_train(&mut self, item: TrainingProgress);
|
||||
|
||||
/// Renders the validation progress.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` - The validation progress.
|
||||
fn render_valid(&mut self, item: TrainingProgress);
|
||||
}
|
||||
|
||||
/// A dashboard container for all metrics.
|
||||
pub struct Dashboard<T, V>
|
||||
where
|
||||
T: Send + Sync + 'static,
|
||||
|
@ -57,6 +95,17 @@ where
|
|||
T: Send + Sync + 'static,
|
||||
V: Send + Sync + 'static,
|
||||
{
|
||||
/// Creates a new dashboard.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `renderer` - The dashboard renderer.
|
||||
/// * `logger_train` - The training logger.
|
||||
/// * `logger_valid` - The validation logger.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new dashboard.
|
||||
pub fn new(
|
||||
renderer: Box<dyn DashboardRenderer>,
|
||||
logger_train: Box<dyn MetricLogger>,
|
||||
|
@ -73,6 +122,11 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
/// Registers a training metric.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `metric` - The metric.
|
||||
pub fn register_train<M: Metric + 'static>(&mut self, metric: M)
|
||||
where
|
||||
T: Adaptor<M::Input>,
|
||||
|
@ -81,6 +135,11 @@ where
|
|||
.push(Box::new(MetricWrapper::new(metric)));
|
||||
}
|
||||
|
||||
/// Registers a training numeric metric.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `metric` - The metric.
|
||||
pub fn register_train_plot<M: Numeric + Metric + 'static>(&mut self, metric: M)
|
||||
where
|
||||
T: Adaptor<M::Input>,
|
||||
|
@ -88,6 +147,12 @@ where
|
|||
self.metrics_train_numeric
|
||||
.push(Box::new(MetricWrapper::new(metric)));
|
||||
}
|
||||
|
||||
/// Registers a validation metric.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `metric` - The metric.
|
||||
pub fn register_valid<M: Metric + 'static>(&mut self, metric: M)
|
||||
where
|
||||
V: Adaptor<M::Input>,
|
||||
|
@ -96,6 +161,11 @@ where
|
|||
.push(Box::new(MetricWrapper::new(metric)));
|
||||
}
|
||||
|
||||
/// Registers a validation numeric metric.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `metric` - The metric.
|
||||
pub fn register_valid_plot<M: Numeric + Metric + 'static>(&mut self, metric: M)
|
||||
where
|
||||
V: Adaptor<M::Input>,
|
||||
|
|
|
@ -4,6 +4,7 @@ use std::{collections::HashMap, fmt::Write};
|
|||
|
||||
static MAX_REFRESH_RATE_MILLIS: u128 = 250;
|
||||
|
||||
/// The CLI dashboard renderer.
|
||||
pub struct CLIDashboardRenderer {
|
||||
pb_epoch: ProgressBar,
|
||||
pb_iteration: ProgressBar,
|
||||
|
@ -94,6 +95,7 @@ impl DashboardRenderer for CLIDashboardRenderer {
|
|||
}
|
||||
|
||||
impl CLIDashboardRenderer {
|
||||
/// Create a new CLI dashboard renderer.
|
||||
pub fn new() -> Self {
|
||||
let pb = MultiProgress::new();
|
||||
let pb_epoch = ProgressBar::new(0);
|
||||
|
@ -238,6 +240,7 @@ impl CLIDashboardRenderer {
|
|||
self.last_update = std::time::Instant::now();
|
||||
}
|
||||
|
||||
/// Registers a new metric to be displayed.
|
||||
pub fn register_key_item(
|
||||
&self,
|
||||
key: &'static str,
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
/// Command line interface module for the dashboard.
|
||||
pub mod cli;
|
||||
|
||||
mod base;
|
||||
|
|
|
@ -2,6 +2,7 @@ use rgb::RGB8;
|
|||
use terminal_size::{terminal_size, Height, Width};
|
||||
use textplots::{Chart, ColorPlot, Shape};
|
||||
|
||||
/// Text plot.
|
||||
pub struct TextPlot {
|
||||
train: Vec<(f32, f32)>,
|
||||
valid: Vec<(f32, f32)>,
|
||||
|
@ -16,6 +17,7 @@ impl Default for TextPlot {
|
|||
}
|
||||
|
||||
impl TextPlot {
|
||||
/// Creates a new text plot.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
train: Vec::new(),
|
||||
|
@ -25,6 +27,16 @@ impl TextPlot {
|
|||
}
|
||||
}
|
||||
|
||||
/// Merges two text plots.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `self` - The first text plot.
|
||||
/// * `other` - The second text plot.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The merged text plot.
|
||||
pub fn merge(self, other: Self) -> Self {
|
||||
let mut other = other;
|
||||
let mut train = self.train;
|
||||
|
@ -41,6 +53,11 @@ impl TextPlot {
|
|||
}
|
||||
}
|
||||
|
||||
/// Updates the text plot with a new item for training.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` - The new item.
|
||||
pub fn update_train(&mut self, item: f32) {
|
||||
self.iteration += 1;
|
||||
self.train.push((self.iteration as f32, item));
|
||||
|
@ -61,6 +78,11 @@ impl TextPlot {
|
|||
}
|
||||
}
|
||||
|
||||
/// Updates the text plot with a new item for validation.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `item` - The new item.
|
||||
pub fn update_valid(&mut self, item: f32) {
|
||||
self.iteration += 1;
|
||||
self.valid.push((self.iteration as f32, item));
|
||||
|
@ -81,6 +103,11 @@ impl TextPlot {
|
|||
}
|
||||
}
|
||||
|
||||
/// Renders the text plot.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The rendered text plot.
|
||||
pub fn render(&self) -> String {
|
||||
let train_color = RGB8::new(255, 140, 140);
|
||||
let valid_color = RGB8::new(140, 140, 255);
|
||||
|
|
|
@ -10,6 +10,7 @@ pub struct LearningRateMetric {
|
|||
}
|
||||
|
||||
impl LearningRateMetric {
|
||||
/// Creates a new learning rate metric.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
state: NumericMetricState::new(),
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
/// Dashboard module for training progress.
|
||||
pub mod dashboard;
|
||||
|
||||
/// State module for dashboard metrics.
|
||||
pub mod state;
|
||||
|
||||
mod acc;
|
||||
|
|
Loading…
Reference in New Issue