mirror of https://github.com/tracel-ai/burn.git
Experimental/named tensor (#113)
This commit is contained in:
parent
c4c739d91b
commit
e0e787f87d
|
@ -65,6 +65,11 @@ where
|
|||
}
|
||||
|
||||
pub fn matmul(self, other: BatchMatrix<E, D>) -> Self {
|
||||
let require_broadcast = self.arrays.len() != other.arrays.len();
|
||||
if require_broadcast {
|
||||
return self.matmul_broadcast(other);
|
||||
}
|
||||
|
||||
let self_iter = self.arrays.iter();
|
||||
let other_iter = other.arrays.iter();
|
||||
|
||||
|
@ -79,6 +84,37 @@ where
|
|||
|
||||
Self::new(arrays, shape)
|
||||
}
|
||||
|
||||
fn matmul_broadcast(self, other: BatchMatrix<E, D>) -> Self {
|
||||
let valid_broadcast = self.arrays.len() == 1 || other.arrays.len() == 1;
|
||||
if !valid_broadcast {
|
||||
panic!("Invalid broadcast => {:?} , {:?}", self.shape, other.shape);
|
||||
}
|
||||
let batch_size = usize::max(self.arrays.len(), other.arrays.len());
|
||||
let mut arrays = Vec::with_capacity(batch_size);
|
||||
|
||||
for batch in 0..batch_size {
|
||||
let self_tensor = if self.arrays.len() == 1 {
|
||||
&self.arrays[0]
|
||||
} else {
|
||||
&self.arrays[batch]
|
||||
};
|
||||
|
||||
let other_tensor = if other.arrays.len() == 1 {
|
||||
&other.arrays[0]
|
||||
} else {
|
||||
&other.arrays[batch]
|
||||
};
|
||||
|
||||
let tensor = self_tensor.dot(other_tensor);
|
||||
arrays.push(tensor.into_shared());
|
||||
}
|
||||
|
||||
let mut shape = self.shape;
|
||||
shape.dims[D - 1] = other.shape.dims[D - 1];
|
||||
|
||||
Self::new(arrays, shape)
|
||||
}
|
||||
}
|
||||
|
||||
fn batch_size<const D: usize>(shape: &Shape<D>) -> usize {
|
||||
|
|
|
@ -17,6 +17,7 @@ edition = "2021"
|
|||
[features]
|
||||
default = []
|
||||
export_tests = ["burn-tensor-testgen"]
|
||||
named_tensor = []
|
||||
|
||||
[dependencies]
|
||||
burn-tensor-testgen = { version = "0.3.0", path = "../burn-tensor-testgen", optional = true }
|
||||
|
|
|
@ -17,3 +17,8 @@ pub mod activation;
|
|||
pub mod backend;
|
||||
pub mod loss;
|
||||
pub mod module;
|
||||
|
||||
#[cfg(feature = "named_tensor")]
|
||||
mod named;
|
||||
#[cfg(feature = "named_tensor")]
|
||||
pub use named::*;
|
||||
|
|
|
@ -0,0 +1,61 @@
|
|||
use crate::backend::Backend;
|
||||
use crate::{Distribution, NamedDims, Shape, Tensor};
|
||||
|
||||
/// A tensor with named dimensions.
///
/// The dimension names are carried in the type parameter `D`, a tuple of
/// marker types implementing [`NamedDims`], so mismatched dimension names are
/// rejected at compile time rather than at runtime.
pub struct NamedTensor<B: Backend, D: NamedDims<B>> {
    // The underlying unnamed tensor; its rank matches the number of named dims.
    pub(crate) tensor: D::Tensor,
}
|
||||
|
||||
impl<B: Backend, const D: usize, ND: NamedDims<B>> std::fmt::Display for NamedTensor<B, ND>
|
||||
where
|
||||
ND: NamedDims<B, Tensor = Tensor<B, D>>,
|
||||
{
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(&format!(
|
||||
"NamedTensor[shape={:?}, dims={}]",
|
||||
self.shape().dims,
|
||||
ND::to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize, ND> NamedTensor<B, ND>
|
||||
where
|
||||
ND: NamedDims<B, Tensor = Tensor<B, D>>,
|
||||
{
|
||||
/// Create a named tensor from a tensor.
|
||||
pub fn from_tensor(tensor: Tensor<B, D>) -> Self {
|
||||
Self { tensor }
|
||||
}
|
||||
|
||||
/// Create a random named tensor of the given shape where each element is sampled from
|
||||
/// the given distribution.
|
||||
pub fn random<S: Into<Shape<D>>>(shape: S, distribution: Distribution<B::Elem>) -> Self {
|
||||
Self::from_tensor(Tensor::random(shape, distribution))
|
||||
}
|
||||
|
||||
/// Returns the shape of the current tensor.
|
||||
pub fn shape(&self) -> &Shape<D> {
|
||||
self.tensor.shape()
|
||||
}
|
||||
|
||||
/// Applies element wise multiplication operation.
|
||||
///
|
||||
/// `y = x2 * x1`
|
||||
pub fn mul(&self, rhs: &Self) -> Self {
|
||||
Self::from_tensor(self.tensor.mul(&rhs.tensor))
|
||||
}
|
||||
|
||||
/// Reshape the tensor to have the given shape.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If the tensor can not be reshape to the given shape.
|
||||
pub fn reshape<const D2: usize, S, ND2>(&self, shape: S, _: ND2) -> NamedTensor<B, ND2>
|
||||
where
|
||||
S: Into<Shape<D2>>,
|
||||
ND2: NamedDims<B, Tensor = Tensor<B, D2>>,
|
||||
{
|
||||
NamedTensor::from_tensor(self.tensor.reshape(shape))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,84 @@
|
|||
use crate::backend::Backend;
|
||||
use crate::Tensor;
|
||||
|
||||
/// Marker trait for a single named tensor dimension.
pub trait Dim {
    /// Returns the dimension's name (the marker type's identifier).
    fn to_string() -> String;
}
|
||||
|
||||
/// Maps a tuple of [`Dim`] markers to the tensor type of the matching rank.
pub trait NamedDims<B: Backend> {
    /// The unnamed tensor type carrying this set of dimensions.
    type Tensor;
    /// Returns a human-readable list of the dimension names, e.g. `"[D1, D2]"`.
    fn to_string() -> String;
}
|
||||
|
||||
/// Declares a new named-dimension marker type.
///
/// `NamedDim!(Batch)` expands to a unit struct `Batch` implementing
/// [`Dim`], whose `to_string` returns `"Batch"`.
#[macro_export]
macro_rules! NamedDim {
    ($name:ident) => {
        pub struct $name;
        // `$crate::Dim` keeps the expansion self-contained: call sites no
        // longer need to have `Dim` imported for the macro to compile.
        impl $crate::Dim for $name {
            fn to_string() -> String {
                stringify!($name).to_string()
            }
        }
    };
}
|
||||
|
||||
impl<B: Backend, D1> NamedDims<B> for (D1,)
|
||||
where
|
||||
B: Backend,
|
||||
D1: Dim,
|
||||
{
|
||||
type Tensor = Tensor<B, 1>;
|
||||
fn to_string() -> String {
|
||||
format!("[{}]", D1::to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, D1, D2> NamedDims<B> for (D1, D2)
|
||||
where
|
||||
B: Backend,
|
||||
D1: Dim,
|
||||
D2: Dim,
|
||||
{
|
||||
type Tensor = Tensor<B, 2>;
|
||||
fn to_string() -> String {
|
||||
format!("[{}, {}]", D1::to_string(), D2::to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, D1, D2, D3> NamedDims<B> for (D1, D2, D3)
|
||||
where
|
||||
B: Backend,
|
||||
D1: Dim,
|
||||
D2: Dim,
|
||||
D3: Dim,
|
||||
{
|
||||
type Tensor = Tensor<B, 3>;
|
||||
fn to_string() -> String {
|
||||
format!(
|
||||
"[{}, {}, {}]",
|
||||
D1::to_string(),
|
||||
D2::to_string(),
|
||||
D3::to_string()
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, D1, D2, D3, D4> NamedDims<B> for (D1, D2, D3, D4)
|
||||
where
|
||||
B: Backend,
|
||||
D1: Dim,
|
||||
D2: Dim,
|
||||
D3: Dim,
|
||||
D4: Dim,
|
||||
{
|
||||
type Tensor = Tensor<B, 4>;
|
||||
fn to_string() -> String {
|
||||
format!(
|
||||
"[{}, {}, {}, {}]",
|
||||
D1::to_string(),
|
||||
D2::to_string(),
|
||||
D3::to_string(),
|
||||
D4::to_string()
|
||||
)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,59 @@
|
|||
use crate::backend::Backend;
|
||||
use crate::{Dim, NamedDims, NamedTensor, Tensor};
|
||||
|
||||
/// Matrix multiplication between two named tensors, producing `Out`.
///
/// Implementations encode which dimension-name combinations are valid, so an
/// incompatible multiplication is a compile-time error.
pub trait Matmul<Rhs, Out> {
    /// Computes `self · rhs`.
    fn matmul(&self, rhs: &Rhs) -> Out;
}
|
||||
|
||||
impl<B: Backend, const D: usize, ND> NamedTensor<B, ND>
where
    ND: NamedDims<B, Tensor = Tensor<B, D>>,
{
    /// Applies the matrix multiplication operation.
    ///
    /// `C = AB`
    ///
    /// The output dimension names are fixed by whichever `Matmul` impl
    /// matches the operands' names, so incompatible names fail to compile.
    ///
    /// # Panics
    ///
    /// If the two tensors don't have a compatible shape.
    pub fn matmul<NamedDimsRhs, NamedDimsOut>(
        &self,
        rhs: &NamedTensor<B, NamedDimsRhs>,
    ) -> NamedTensor<B, NamedDimsOut>
    where
        NamedDimsRhs: NamedDims<B, Tensor = Tensor<B, D>>,
        NamedDimsOut: NamedDims<B, Tensor = Tensor<B, D>>,
        Self: Matmul<NamedTensor<B, NamedDimsRhs>, NamedTensor<B, NamedDimsOut>>,
    {
        Matmul::matmul(self, rhs)
    }
}
|
||||
|
||||
impl<B: Backend, X: Dim, Y: Dim, Z: Dim> Matmul<NamedTensor<B, (Y, Z)>, NamedTensor<B, (X, Z)>>
|
||||
for NamedTensor<B, (X, Y)>
|
||||
{
|
||||
fn matmul(&self, rhs: &NamedTensor<B, (Y, Z)>) -> NamedTensor<B, (X, Z)> {
|
||||
NamedTensor::from_tensor(self.tensor.matmul(&rhs.tensor))
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, Batch: Dim, X: Dim, Y: Dim, Z: Dim>
|
||||
Matmul<NamedTensor<B, (Batch, Y, Z)>, NamedTensor<B, (Batch, X, Z)>>
|
||||
for NamedTensor<B, (Batch, X, Y)>
|
||||
{
|
||||
fn matmul(&self, rhs: &NamedTensor<B, (Batch, Y, Z)>) -> NamedTensor<B, (Batch, X, Z)> {
|
||||
NamedTensor::from_tensor(self.tensor.matmul(&rhs.tensor))
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, Batch1: Dim, Batch2: Dim, X: Dim, Y: Dim, Z: Dim>
|
||||
Matmul<NamedTensor<B, (Batch1, Batch2, Y, Z)>, NamedTensor<B, (Batch1, Batch2, X, Z)>>
|
||||
for NamedTensor<B, (Batch1, Batch2, X, Y)>
|
||||
{
|
||||
fn matmul(
|
||||
&self,
|
||||
rhs: &NamedTensor<B, (Batch1, Batch2, Y, Z)>,
|
||||
) -> NamedTensor<B, (Batch1, Batch2, X, Z)> {
|
||||
NamedTensor::from_tensor(self.tensor.matmul(&rhs.tensor))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
mod base;
|
||||
mod dims;
|
||||
mod matmul;
|
||||
mod permut;
|
||||
|
||||
pub use base::*;
|
||||
pub use dims::*;
|
|
@ -0,0 +1,62 @@
|
|||
use crate::backend::Backend;
|
||||
use crate::{Dim, NamedDims, NamedTensor, Tensor};
|
||||
|
||||
/// Swaps dimensions `D1` and `D2` of a named tensor, producing `N` with the
/// correspondingly reordered dimension names.
pub trait Permut<N, const D1: usize, const D2: usize> {
    /// Returns a tensor with dimensions `D1` and `D2` swapped.
    fn permut(&self) -> N;
}
|
||||
|
||||
impl<B: Backend, const D: usize, ND> NamedTensor<B, ND>
where
    ND: NamedDims<B, Tensor = Tensor<B, D>>,
{
    /// Permutes (swaps) the two dimensions given as const indices `D1` and `D2`.
    ///
    /// The output names `ND2` must correspond to the swapped order; only the
    /// swaps covered by a `Permut` impl compile.
    pub fn permut<ND2, const D1: usize, const D2: usize>(&self) -> NamedTensor<B, ND2>
    where
        ND2: NamedDims<B, Tensor = Tensor<B, D>>,
        Self: Permut<NamedTensor<B, ND2>, D1, D2>,
    {
        Permut::permut(self)
    }
}
|
||||
|
||||
/// Generates a `Permut` impl for a tensor of the given rank (2, 3 or 4),
/// swapping the two dimension indices and producing the reordered name tuple
/// given as `$output`.
macro_rules! generate_permut {
    (2 => $output:ty, ($dim1:expr, $dim2:expr)) => {
        impl<B: Backend, D1: Dim, D2: Dim> Permut<NamedTensor<B, $output>, $dim1, $dim2>
            for NamedTensor<B, (D1, D2)>
        {
            fn permut(&self) -> NamedTensor<B, $output> {
                NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2))
            }
        }
    };

    (3 => $output:ty, ($dim1:expr, $dim2:expr)) => {
        impl<B: Backend, D1: Dim, D2: Dim, D3: Dim> Permut<NamedTensor<B, $output>, $dim1, $dim2>
            for NamedTensor<B, (D1, D2, D3)>
        {
            fn permut(&self) -> NamedTensor<B, $output> {
                NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2))
            }
        }
    };

    (4 => $output:ty, ($dim1:expr, $dim2:expr)) => {
        impl<B: Backend, D1: Dim, D2: Dim, D3: Dim, D4: Dim>
            Permut<NamedTensor<B, $output>, $dim1, $dim2> for NamedTensor<B, (D1, D2, D3, D4)>
        {
            fn permut(&self) -> NamedTensor<B, $output> {
                NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2))
            }
        }
    };
}
|
||||
|
||||
generate_permut!(2 => (D2, D1), (0, 1));
|
||||
generate_permut!(3 => (D2, D1, D3), (0, 1));
|
||||
generate_permut!(3 => (D3, D2, D1), (0, 2));
|
||||
generate_permut!(3 => (D1, D3, D2), (1, 2));
|
||||
generate_permut!(4 => (D2, D1, D3, D4), (0, 1));
|
||||
generate_permut!(4 => (D3, D2, D1, D4), (0, 2));
|
||||
generate_permut!(4 => (D4, D2, D3, D1), (0, 3));
|
||||
generate_permut!(4 => (D1, D3, D2, D4), (1, 2));
|
||||
generate_permut!(4 => (D1, D4, D3, D2), (1, 3));
|
|
@ -11,6 +11,7 @@ license = "MIT/Apache-2.0"
|
|||
edition = "2021"
|
||||
|
||||
[features]
|
||||
named_tensor = ["burn-tensor/named_tensor"]
|
||||
|
||||
[dependencies]
|
||||
burn-tensor = { version = "0.3.0", path = "../burn-tensor" }
|
||||
|
|
|
@ -0,0 +1,15 @@
|
|||
[package]
|
||||
name = "named-tensor"
|
||||
version = "0.1.0"
|
||||
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
|
||||
license = "MIT OR Apache-2.0"
|
||||
edition = "2021"
|
||||
publish = false
|
||||
|
||||
[dependencies]
|
||||
burn = { path = "../../burn", features = ["named_tensor"] }
|
||||
burn-autodiff = { path = "../../burn-autodiff" }
|
||||
burn-ndarray = { path = "../../burn-ndarray" }
|
||||
|
||||
# Serialization
|
||||
serde = { version = "1.0", features = ["derive"] }
|
|
@ -0,0 +1,5 @@
|
|||
use named_tensor;
|
||||
|
||||
fn main() {
|
||||
named_tensor::run::<burn_ndarray::NdArrayBackend<f32>>();
|
||||
}
|
|
@ -0,0 +1,47 @@
|
|||
use burn::tensor::backend::Backend;
|
||||
use burn::tensor::{Dim, Distribution, NamedDim, NamedTensor};
|
||||
|
||||
NamedDim!(Batch);
|
||||
NamedDim!(SeqLenght);
|
||||
NamedDim!(DModel);
|
||||
|
||||
pub fn run<B: Backend>() {
|
||||
let batch_size = 32;
|
||||
let seq_length = 48;
|
||||
let d_model = 24;
|
||||
|
||||
let weights = NamedTensor::<B, (Batch, DModel, DModel)>::random(
|
||||
[1, d_model, d_model],
|
||||
Distribution::Standard,
|
||||
);
|
||||
|
||||
let input = NamedTensor::<B, (Batch, SeqLenght, DModel)>::random(
|
||||
[batch_size, seq_length, d_model],
|
||||
Distribution::Standard,
|
||||
);
|
||||
|
||||
// Doesn't compile
|
||||
//
|
||||
// mismatched types
|
||||
// expected reference `&NamedTensor<B, (Batch, DModel, _)>`
|
||||
// found reference `&NamedTensor<B, (Batch, SeqLenght, DModel)>`
|
||||
// let output = weights.matmul(&input);
|
||||
|
||||
let output = input.matmul(&weights);
|
||||
|
||||
// Doesn't compile
|
||||
//
|
||||
// mismatched types
|
||||
// expected reference `&NamedTensor<B, (Batch, SeqLenght, DModel)>`
|
||||
// found reference `&NamedTensor<B, (Batch, DModel, DModel)>`
|
||||
// let output = output.mul(&weights);
|
||||
|
||||
let output = output.mul(&input);
|
||||
|
||||
let permut = output.permut::<_, 1, 2>();
|
||||
|
||||
println!("Weights => {}", weights);
|
||||
println!("Input => {}", input);
|
||||
println!("Output => {}", output);
|
||||
println!("Permut => {}", permut);
|
||||
}
|
Loading…
Reference in New Issue