Experimental/named tensor (#113)

This commit is contained in:
Nathaniel Simard 2022-11-23 19:05:46 -05:00 committed by GitHub
parent c4c739d91b
commit e0e787f87d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 383 additions and 0 deletions

View File

@ -65,6 +65,11 @@ where
}
pub fn matmul(self, other: BatchMatrix<E, D>) -> Self {
let require_broadcast = self.arrays.len() != other.arrays.len();
if require_broadcast {
return self.matmul_broadcast(other);
}
let self_iter = self.arrays.iter();
let other_iter = other.arrays.iter();
@ -79,6 +84,37 @@ where
Self::new(arrays, shape)
}
/// Batched matrix multiplication where one side holds a single matrix that is
/// broadcast against every matrix of the other side (1-to-N broadcasting).
///
/// # Panics
///
/// If neither side has exactly one matrix; only 1-to-N broadcasting is
/// supported here.
fn matmul_broadcast(self, other: BatchMatrix<E, D>) -> Self {
    let valid_broadcast = self.arrays.len() == 1 || other.arrays.len() == 1;
    if !valid_broadcast {
        panic!("Invalid broadcast => {:?} , {:?}", self.shape, other.shape);
    }

    let batch_size = usize::max(self.arrays.len(), other.arrays.len());
    let mut arrays = Vec::with_capacity(batch_size);

    for batch in 0..batch_size {
        // The broadcast side always contributes its single matrix (index 0);
        // the other side advances with the batch index.
        let self_tensor = if self.arrays.len() == 1 {
            &self.arrays[0]
        } else {
            &self.arrays[batch]
        };
        let other_tensor = if other.arrays.len() == 1 {
            &other.arrays[0]
        } else {
            &other.arrays[batch]
        };
        let tensor = self_tensor.dot(other_tensor);
        arrays.push(tensor.into_shared());
    }

    // Output shape: the batch dims must follow the side that actually has
    // `batch_size` matrices (previously the shape always came from `self`,
    // which was wrong whenever `self` was the broadcast side); the trailing
    // matrix dims are (rows of self, cols of other).
    let rows = self.shape.dims[D - 2];
    let cols = other.shape.dims[D - 1];
    let mut shape = if self.arrays.len() >= other.arrays.len() {
        self.shape
    } else {
        other.shape
    };
    shape.dims[D - 2] = rows;
    shape.dims[D - 1] = cols;

    Self::new(arrays, shape)
}
}
fn batch_size<const D: usize>(shape: &Shape<D>) -> usize {

View File

@ -17,6 +17,7 @@ edition = "2021"
[features]
default = []
export_tests = ["burn-tensor-testgen"]
named_tensor = []
[dependencies]
burn-tensor-testgen = { version = "0.3.0", path = "../burn-tensor-testgen", optional = true }

View File

@ -17,3 +17,8 @@ pub mod activation;
pub mod backend;
pub mod loss;
pub mod module;
#[cfg(feature = "named_tensor")]
mod named;
#[cfg(feature = "named_tensor")]
pub use named::*;

View File

@ -0,0 +1,61 @@
use crate::backend::Backend;
use crate::{Distribution, NamedDims, Shape, Tensor};
/// A tensor with named dimensions.
///
/// The names are carried in the type parameter `D` — a tuple of marker types
/// implementing `Dim` — so mismatched dimension names become compile-time
/// errors instead of runtime shape panics.
pub struct NamedTensor<B: Backend, D: NamedDims<B>> {
// The underlying unnamed tensor; its rank is fixed by `D::Tensor`.
pub(crate) tensor: D::Tensor,
}
impl<B: Backend, const D: usize, ND> std::fmt::Display for NamedTensor<B, ND>
where
    ND: NamedDims<B, Tensor = Tensor<B, D>>,
{
    /// Formats as `NamedTensor[shape=[..], dims=[..]]`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `write!` renders straight into the formatter instead of building an
        // intermediate String with `format!` and copying it via `write_str`.
        write!(
            f,
            "NamedTensor[shape={:?}, dims={}]",
            self.shape().dims,
            ND::to_string(),
        )
    }
}
impl<B: Backend, const D: usize, ND> NamedTensor<B, ND>
where
ND: NamedDims<B, Tensor = Tensor<B, D>>,
{
/// Create a named tensor from a tensor.
pub fn from_tensor(tensor: Tensor<B, D>) -> Self {
Self { tensor }
}
/// Create a random named tensor of the given shape where each element is sampled from
/// the given distribution.
pub fn random<S: Into<Shape<D>>>(shape: S, distribution: Distribution<B::Elem>) -> Self {
Self::from_tensor(Tensor::random(shape, distribution))
}
/// Returns the shape of the current tensor.
pub fn shape(&self) -> &Shape<D> {
self.tensor.shape()
}
/// Applies element wise multiplication operation.
///
/// `y = x2 * x1`
pub fn mul(&self, rhs: &Self) -> Self {
Self::from_tensor(self.tensor.mul(&rhs.tensor))
}
/// Reshape the tensor to have the given shape.
///
/// # Panics
///
/// If the tensor can not be reshape to the given shape.
pub fn reshape<const D2: usize, S, ND2>(&self, shape: S, _: ND2) -> NamedTensor<B, ND2>
where
S: Into<Shape<D2>>,
ND2: NamedDims<B, Tensor = Tensor<B, D2>>,
{
NamedTensor::from_tensor(self.tensor.reshape(shape))
}
}

View File

@ -0,0 +1,84 @@
use crate::backend::Backend;
use crate::Tensor;
/// A single named tensor dimension (e.g. `Batch`), usually declared with the
/// `NamedDim!` macro.
pub trait Dim {
// Human-readable name of the dimension, used for display purposes only.
fn to_string() -> String;
}
/// A tuple of `Dim`s naming every axis of a tensor; ties the name tuple to the
/// concrete unnamed tensor type of the matching rank.
pub trait NamedDims<B: Backend> {
// The unnamed tensor that carries the data, e.g. `Tensor<B, 2>` for 2 names.
type Tensor;
// Human-readable list of the dimension names, e.g. `[Batch, DModel]`.
fn to_string() -> String;
}
/// Declares a named dimension: a public unit struct implementing [`Dim`].
///
/// ```ignore
/// NamedDim!(Batch);
/// ```
#[macro_export]
macro_rules! NamedDim {
    ($name:ident) => {
        pub struct $name;
        // `$crate::Dim` so the exported macro also expands correctly in
        // downstream crates that have not imported `Dim` themselves.
        impl $crate::Dim for $name {
            fn to_string() -> String {
                stringify!($name).to_string()
            }
        }
    };
}
/// Dimension names for a rank-1 tensor.
// `B: Backend` was previously bounded twice (generics list and where clause);
// one bound suffices.
impl<B, D1> NamedDims<B> for (D1,)
where
    B: Backend,
    D1: Dim,
{
    type Tensor = Tensor<B, 1>;

    fn to_string() -> String {
        format!("[{}]", D1::to_string())
    }
}
/// Dimension names for a rank-2 tensor.
// `B: Backend` was previously bounded twice (generics list and where clause);
// one bound suffices.
impl<B, D1, D2> NamedDims<B> for (D1, D2)
where
    B: Backend,
    D1: Dim,
    D2: Dim,
{
    type Tensor = Tensor<B, 2>;

    fn to_string() -> String {
        format!("[{}, {}]", D1::to_string(), D2::to_string())
    }
}
/// Dimension names for a rank-3 tensor.
// `B: Backend` was previously bounded twice (generics list and where clause);
// one bound suffices.
impl<B, D1, D2, D3> NamedDims<B> for (D1, D2, D3)
where
    B: Backend,
    D1: Dim,
    D2: Dim,
    D3: Dim,
{
    type Tensor = Tensor<B, 3>;

    fn to_string() -> String {
        format!(
            "[{}, {}, {}]",
            D1::to_string(),
            D2::to_string(),
            D3::to_string()
        )
    }
}
/// Dimension names for a rank-4 tensor.
// `B: Backend` was previously bounded twice (generics list and where clause);
// one bound suffices.
impl<B, D1, D2, D3, D4> NamedDims<B> for (D1, D2, D3, D4)
where
    B: Backend,
    D1: Dim,
    D2: Dim,
    D3: Dim,
    D4: Dim,
{
    type Tensor = Tensor<B, 4>;

    fn to_string() -> String {
        format!(
            "[{}, {}, {}, {}]",
            D1::to_string(),
            D2::to_string(),
            D3::to_string(),
            D4::to_string()
        )
    }
}

View File

@ -0,0 +1,59 @@
use crate::backend::Backend;
use crate::{Dim, NamedDims, NamedTensor, Tensor};
/// Matrix multiplication between two named tensors; the valid combinations of
/// input and output dimension names are encoded as trait implementations.
pub trait Matmul<Rhs, Out> {
fn matmul(&self, rhs: &Rhs) -> Out;
}
impl<B: Backend, const D: usize, ND> NamedTensor<B, ND>
where
    ND: NamedDims<B, Tensor = Tensor<B, D>>,
{
    /// Applies the matrix multiplication operation.
    ///
    /// `C = AB`
    ///
    /// The `Self: Matmul<..>` bound restricts calls to name combinations for
    /// which an implementation exists, so mismatched names fail to compile.
    ///
    /// # Panics
    ///
    /// If the two tensors don't have a compatible shape.
    pub fn matmul<RhsDims, OutDims>(
        &self,
        rhs: &NamedTensor<B, RhsDims>,
    ) -> NamedTensor<B, OutDims>
    where
        RhsDims: NamedDims<B, Tensor = Tensor<B, D>>,
        OutDims: NamedDims<B, Tensor = Tensor<B, D>>,
        Self: Matmul<NamedTensor<B, RhsDims>, NamedTensor<B, OutDims>>,
    {
        <Self as Matmul<_, _>>::matmul(self, rhs)
    }
}
// 2D: (X, Y) x (Y, Z) => (X, Z); the inner name Y must match.
impl<B: Backend, X: Dim, Y: Dim, Z: Dim> Matmul<NamedTensor<B, (Y, Z)>, NamedTensor<B, (X, Z)>>
for NamedTensor<B, (X, Y)>
{
fn matmul(&self, rhs: &NamedTensor<B, (Y, Z)>) -> NamedTensor<B, (X, Z)> {
NamedTensor::from_tensor(self.tensor.matmul(&rhs.tensor))
}
}
// 3D: same as 2D with a leading Batch name that must match on both sides.
impl<B: Backend, Batch: Dim, X: Dim, Y: Dim, Z: Dim>
Matmul<NamedTensor<B, (Batch, Y, Z)>, NamedTensor<B, (Batch, X, Z)>>
for NamedTensor<B, (Batch, X, Y)>
{
fn matmul(&self, rhs: &NamedTensor<B, (Batch, Y, Z)>) -> NamedTensor<B, (Batch, X, Z)> {
NamedTensor::from_tensor(self.tensor.matmul(&rhs.tensor))
}
}
// 4D: two leading batch names that must match on both sides.
impl<B: Backend, Batch1: Dim, Batch2: Dim, X: Dim, Y: Dim, Z: Dim>
Matmul<NamedTensor<B, (Batch1, Batch2, Y, Z)>, NamedTensor<B, (Batch1, Batch2, X, Z)>>
for NamedTensor<B, (Batch1, Batch2, X, Y)>
{
fn matmul(
&self,
rhs: &NamedTensor<B, (Batch1, Batch2, Y, Z)>,
) -> NamedTensor<B, (Batch1, Batch2, X, Z)> {
NamedTensor::from_tensor(self.tensor.matmul(&rhs.tensor))
}
}

View File

@ -0,0 +1,7 @@
mod base;
mod dims;
mod matmul;
mod permut;

pub use base::*;
pub use dims::*;
// Re-export the op traits so downstream code can name `Matmul` and `Permut`
// in its own bounds; the inherent methods on `NamedTensor` already work
// without these imports.
pub use matmul::*;
pub use permut::*;

View File

@ -0,0 +1,62 @@
use crate::backend::Backend;
use crate::{Dim, NamedDims, NamedTensor, Tensor};
/// Swap the two dimensions selected by the const indices `D1` and `D2`,
/// producing output type `N` whose name tuple has the names swapped.
// NOTE(review): "permut" looks like a typo for "permute"; renaming would
// break the public API, so the existing name is kept.
pub trait Permut<N, const D1: usize, const D2: usize> {
fn permut(&self) -> N;
}
impl<B: Backend, const D: usize, ND> NamedTensor<B, ND>
where
    ND: NamedDims<B, Tensor = Tensor<B, D>>,
{
    /// Permut two dimensions.
    ///
    /// Swaps axes `D1` and `D2`; the `Self: Permut<..>` bound ensures the
    /// output name tuple `ND2` matches the swap at compile time.
    pub fn permut<ND2, const D1: usize, const D2: usize>(&self) -> NamedTensor<B, ND2>
    where
        ND2: NamedDims<B, Tensor = Tensor<B, D>>,
        Self: Permut<NamedTensor<B, ND2>, D1, D2>,
    {
        <Self as Permut<NamedTensor<B, ND2>, D1, D2>>::permut(self)
    }
}
// Generates one `Permut` impl per call.
//
// Arms are keyed by tensor rank (2/3/4) so the right number of `Dim` generics
// is introduced; `$output` is the name tuple after the swap and
// `($dim1, $dim2)` are the axis indices forwarded to `swap_dims`.
macro_rules! generate_permut {
(2 => $output:ty, ($dim1:expr, $dim2:expr)) => {
impl<B: Backend, D1: Dim, D2: Dim> Permut<NamedTensor<B, $output>, $dim1, $dim2>
for NamedTensor<B, (D1, D2)>
{
fn permut(&self) -> NamedTensor<B, $output> {
NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2))
}
}
};
(3 => $output:ty, ($dim1:expr, $dim2:expr)) => {
impl<B: Backend, D1: Dim, D2: Dim, D3: Dim> Permut<NamedTensor<B, $output>, $dim1, $dim2>
for NamedTensor<B, (D1, D2, D3)>
{
fn permut(&self) -> NamedTensor<B, $output> {
NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2))
}
}
};
(4 => $output:ty, ($dim1:expr, $dim2:expr)) => {
impl<B: Backend, D1: Dim, D2: Dim, D3: Dim, D4: Dim>
Permut<NamedTensor<B, $output>, $dim1, $dim2> for NamedTensor<B, (D1, D2, D3, D4)>
{
fn permut(&self) -> NamedTensor<B, $output> {
NamedTensor::from_tensor(self.tensor.swap_dims($dim1, $dim2))
}
}
};
}
// Every supported swap: the output tuple is (D1..Dn) with the two indexed
// names exchanged.
generate_permut!(2 => (D2, D1), (0, 1));
generate_permut!(3 => (D2, D1, D3), (0, 1));
generate_permut!(3 => (D3, D2, D1), (0, 2));
generate_permut!(3 => (D1, D3, D2), (1, 2));
generate_permut!(4 => (D2, D1, D3, D4), (0, 1));
generate_permut!(4 => (D3, D2, D1, D4), (0, 2));
generate_permut!(4 => (D4, D2, D3, D1), (0, 3));
generate_permut!(4 => (D1, D3, D2, D4), (1, 2));
generate_permut!(4 => (D1, D4, D3, D2), (1, 3));

View File

@ -11,6 +11,7 @@ license = "MIT/Apache-2.0"
edition = "2021"
[features]
named_tensor = ["burn-tensor/named_tensor"]
[dependencies]
burn-tensor = { version = "0.3.0", path = "../burn-tensor" }

View File

@ -0,0 +1,15 @@
[package]
name = "named-tensor"
version = "0.1.0"
authors = ["nathanielsimard <nathaniel.simard.42@gmail.com>"]
license = "MIT OR Apache-2.0"
edition = "2021"
publish = false
[dependencies]
burn = { path = "../../burn", features = ["named_tensor"] }
burn-autodiff = { path = "../../burn-autodiff" }
burn-ndarray = { path = "../../burn-ndarray" }
# Serialization
serde = { version = "1.0", features = ["derive"] }

View File

@ -0,0 +1,5 @@
use named_tensor;
fn main() {
// Run the named-tensor demo on the CPU ndarray backend.
named_tensor::run::<burn_ndarray::NdArrayBackend<f32>>();
}

View File

@ -0,0 +1,47 @@
use burn::tensor::backend::Backend;
use burn::tensor::{Dim, Distribution, NamedDim, NamedTensor};
// Named dimensions used by the demo.
// NOTE(review): `SeqLenght` is a typo for `SeqLength`; kept as-is because the
// generated struct is part of this crate's public surface.
NamedDim!(Batch);
NamedDim!(SeqLenght);
NamedDim!(DModel);
/// Demonstrates compile-time checked named-tensor operations (matmul, mul,
/// permut) on backend `B`, printing each result.
pub fn run<B: Backend>() {
let batch_size = 32;
let seq_length = 48;
let d_model = 24;
// Weights use a batch dim of 1 so matmul broadcasts them over the batch.
let weights = NamedTensor::<B, (Batch, DModel, DModel)>::random(
[1, d_model, d_model],
Distribution::Standard,
);
let input = NamedTensor::<B, (Batch, SeqLenght, DModel)>::random(
[batch_size, seq_length, d_model],
Distribution::Standard,
);
// Doesn't compile
//
// mismatched types
// expected reference `&NamedTensor<B, (Batch, DModel, _)>`
// found reference `&NamedTensor<B, (Batch, SeqLenght, DModel)>`
// let output = weights.matmul(&input);
let output = input.matmul(&weights);
// Doesn't compile
//
// mismatched types
// expected reference `&NamedTensor<B, (Batch, SeqLenght, DModel)>`
// found reference `&NamedTensor<B, (Batch, DModel, DModel)>`
// let output = output.mul(&weights);
let output = output.mul(&input);
// Swap axes 1 and 2: (Batch, SeqLenght, DModel) -> (Batch, DModel, SeqLenght).
let permut = output.permut::<_, 1, 2>();
println!("Weights => {}", weights);
println!("Input => {}", input);
println!("Output => {}", output);
println!("Permut => {}", permut);
}