draft: Perf/ndarray matmul (#214)

Nathaniel Simard 2023-03-09 14:00:35 -05:00 committed by GitHub
parent d8e5b3fed1
commit a2ec774c37
18 changed files with 288 additions and 216 deletions

View File

@@ -13,7 +13,6 @@ pub fn backward<B: Backend, const D: usize>(root: ADTensor<B, D>) -> Gradients {
fn build_tape(root: NodeRef, graph: Graph) -> Vec<Vec<StepBoxed>> {
let mut tape = (0..root.order)
.into_iter()
.map(|_| Vec::with_capacity(1))
.collect::<Vec<_>>();

View File

@@ -18,9 +18,11 @@ std = [
"burn-tensor/std",
"burn-common/std",
"burn-autodiff",
"rayon",
"ndarray/std",
"ndarray/rayon",
"ndarray/matrixmultiply-threading",
"matrixmultiply/std",
"matrixmultiply/threading",
]
blas-accelerate = ["ndarray/blas", "blas-src/accelerate"] # Accelerate framework (macOS only)
@@ -40,7 +42,12 @@ burn-autodiff = {path = "../burn-autodiff", features = ["export_tests"], optiona
burn-common = {path = "../burn-common", default-features = false}
burn-tensor = {path = "../burn-tensor", default-features = false, features = ["export_tests"]}
matrixmultiply = {version = "0.3.2", default-features = false}
rayon = {version = "1.6", optional = true}
blas-src = {version = "0.8.0", default-features = false, optional = true} # no-std compatible
derive-new = {workspace = true}
libm = {workspace = true}
ndarray = {workspace = true}

View File

@@ -1,7 +1,7 @@
use alloc::string::String;
use core::marker::PhantomData;
use crate::element::NdArrayElement;
use crate::element::FloatNdArrayElement;
use crate::NdArrayTensor;
use burn_tensor::backend::Backend;
@@ -32,7 +32,7 @@ pub struct NdArrayBackend<E> {
phantom: PhantomData<E>,
}
impl<E: NdArrayElement> Backend for NdArrayBackend<E> {
impl<E: FloatNdArrayElement> Backend for NdArrayBackend<E> {
type Device = NdArrayDevice;
type FullPrecisionElem = f32;
type FullPrecisionBackend = NdArrayBackend<f32>;

View File

@@ -1,6 +1,13 @@
use burn_tensor::Element;
use libm::{exp, log, log1p, pow, sqrt};
use libm::{expf, log1pf, logf, powf, sqrtf};
use ndarray::LinalgScalar;
pub(crate) trait FloatNdArrayElement: NdArrayElement + LinalgScalar
where
Self: Sized,
{
}
pub(crate) trait NdArrayElement:
Element
@@ -21,6 +28,7 @@ pub(crate) trait ExpElement {
fn sqrt_elem(self) -> Self;
}
impl FloatNdArrayElement for f64 {}
impl NdArrayElement for f64 {}
impl ExpElement for f64 {
fn exp_elem(self) -> Self {
@@ -44,24 +52,30 @@ impl ExpElement for f64 {
}
}
impl FloatNdArrayElement for f32 {}
impl NdArrayElement for f32 {}
impl ExpElement for f32 {
#[inline(always)]
fn exp_elem(self) -> Self {
expf(self)
}
#[inline(always)]
fn log_elem(self) -> Self {
logf(self)
}
#[inline(always)]
fn log1p_elem(self) -> Self {
log1pf(self)
}
#[inline(always)]
fn pow_elem(self, value: f32) -> Self {
powf(self, value)
}
#[inline(always)]
fn sqrt_elem(self) -> Self {
sqrtf(self)
}

View File

@@ -13,9 +13,11 @@ extern crate blas_src;
mod backend;
mod element;
mod ops;
mod sharing;
mod tensor;
pub use backend::*;
pub(crate) use sharing::*;
pub(crate) use tensor::*;
extern crate alloc;

View File

@@ -10,7 +10,7 @@ use ndarray::SliceInfoElem;
use crate::element::NdArrayElement;
use crate::ops::macros::{keepdim, mean_dim, sum_dim};
use crate::{tensor::NdArrayTensor, to_nd_array_tensor};
use crate::{reshape, tensor::NdArrayTensor};
pub struct NdArrayOps<E> {
e: PhantomData<E>,
@@ -51,15 +51,12 @@ where
tensor: NdArrayTensor<E, D1>,
shape: Shape<D2>,
) -> NdArrayTensor<E, D2> {
match D2 {
1 => to_nd_array_tensor!(1, shape, tensor.array),
2 => to_nd_array_tensor!(2, shape, tensor.array),
3 => to_nd_array_tensor!(3, shape, tensor.array),
4 => to_nd_array_tensor!(4, shape, tensor.array),
5 => to_nd_array_tensor!(5, shape, tensor.array),
6 => to_nd_array_tensor!(6, shape, tensor.array),
_ => panic!("NdArrayTensor support only 6 dimensions."),
}
reshape!(
ty E,
shape shape,
array tensor.array,
d D2
)
}
pub fn cat<const D: usize>(

View File

@@ -1,20 +1,20 @@
// Language
use alloc::vec;
use alloc::vec::Vec;
use burn_tensor::ops::BoolTensorOps;
use burn_tensor::ops::{BoolTensorOps, IntTensorOps};
use core::ops::Range;
// Current crate
use crate::element::NdArrayElement;
use crate::element::FloatNdArrayElement;
use crate::NdArrayDevice;
use crate::{tensor::NdArrayTensor, NdArrayBackend};
// Workspace crates
use burn_tensor::{backend::Backend, ops::TensorOps, Data, Shape};
use burn_tensor::{backend::Backend, Data, Shape};
use super::NdArrayOps;
impl<E: NdArrayElement> BoolTensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
impl<E: FloatNdArrayElement> BoolTensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
fn bool_from_data<const D: usize>(
data: Data<bool, D>,
_device: &NdArrayDevice,
@@ -68,7 +68,7 @@ impl<E: NdArrayElement> BoolTensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
tensor: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
) -> NdArrayTensor<i64, D> {
let data = Self::bool_into_data(tensor);
NdArrayBackend::<i64>::from_data(data.convert(), &NdArrayDevice::Cpu)
NdArrayBackend::<E>::int_from_data(data.convert(), &NdArrayDevice::Cpu)
}
fn bool_device<const D: usize>(

View File

@@ -1,7 +1,7 @@
use alloc::vec::Vec;
use super::padding::apply_padding2d;
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice};
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice};
use burn_tensor::{ops::TensorOps, Shape};
@@ -9,7 +9,7 @@ use libm::ceilf;
/// This method is not the most efficient, but it serves as a basic implementation that is easy to understand.
/// A more optimized version should be used in its place.
pub(crate) fn conv2d_naive<E: NdArrayElement>(
pub(crate) fn conv2d_naive<E: FloatNdArrayElement>(
x: NdArrayTensor<E, 4>,
weight: NdArrayTensor<E, 4>,
bias: Option<NdArrayTensor<E, 1>>,
@@ -35,7 +35,7 @@ pub(crate) fn conv2d_naive<E: NdArrayElement>(
NdArrayBackend::cat(results, 0)
}
pub(crate) fn conv2d_naive_no_batch_size<E: NdArrayElement>(
pub(crate) fn conv2d_naive_no_batch_size<E: FloatNdArrayElement>(
x: NdArrayTensor<E, 3>,
weight: NdArrayTensor<E, 4>,
bias: Option<NdArrayTensor<E, 1>>,
@@ -81,7 +81,7 @@ pub(crate) fn conv2d_naive_no_batch_size<E: NdArrayElement>(
result
}
fn conv2d_with_kernel<E: NdArrayElement>(
fn conv2d_with_kernel<E: FloatNdArrayElement>(
x: NdArrayTensor<E, 2>,
kernel: NdArrayTensor<E, 2>,
stride: [usize; 2],

View File

@@ -5,7 +5,7 @@ use burn_tensor::ops::IntTensorOps;
use core::ops::Range;
// Current crate
use crate::element::NdArrayElement;
use crate::element::FloatNdArrayElement;
use crate::NdArrayDevice;
use crate::{tensor::NdArrayTensor, NdArrayBackend};
@@ -14,7 +14,7 @@ use burn_tensor::{backend::Backend, Data, Shape};
use super::{NdArrayMathOps, NdArrayOps};
impl<E: NdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
impl<E: FloatNdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
fn int_from_data<const D: usize>(
data: Data<i64, D>,
_device: &NdArrayDevice,

View File

@@ -0,0 +1,126 @@
use crate::UnsafeSharedRef;
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
use burn_tensor::ElementConversion;
use burn_tensor::{ops::TensorOps, Shape};
use ndarray::s;
#[cfg(feature = "std")]
use rayon::prelude::*;
pub(crate) fn matmul<E, const D: usize>(
lhs: NdArrayTensor<E, D>,
rhs: NdArrayTensor<E, D>,
) -> NdArrayTensor<E, D>
where
E: FloatNdArrayElement,
{
let shape_ori_lhs = lhs.shape();
let shape_ori_rhs = rhs.shape();
let lhs = reshape(lhs);
let rhs = reshape(rhs);
let [batch_size_lhs, m, _] = lhs.shape().dims;
let [batch_size_rhs, _, n] = rhs.shape().dims;
let mut shape_out = match batch_size_lhs > batch_size_rhs {
true => shape_ori_lhs,
false => shape_ori_rhs,
};
shape_out.dims[D - 2] = m;
shape_out.dims[D - 1] = n;
let out = general_matmul(lhs, rhs);
NdArrayBackend::<E>::reshape(out, shape_out)
}
fn general_matmul<E: FloatNdArrayElement>(
lhs: NdArrayTensor<E, 3>,
rhs: NdArrayTensor<E, 3>,
) -> NdArrayTensor<E, 3> {
let run = || {
let [batch_size_lhs, m, _] = lhs.shape().dims;
let [batch_size_rhs, k, n] = rhs.shape().dims;
let batch_size = usize::max(batch_size_rhs, batch_size_lhs);
if batch_size_lhs != batch_size && batch_size_lhs != 1 {
panic!("Broadcast on multiple dimensions is not yet supported");
}
if batch_size_rhs != batch_size && batch_size_rhs != 1 {
panic!("Broadcast on multiple dimensions is not yet supported");
}
let alpha: E = 1.0.elem();
let beta: E = 0.0.elem();
let mut out_array = ndarray::Array3::<E>::zeros((batch_size, m, n));
let unsafe_shared_out_array = UnsafeSharedRef::new(&mut out_array);
let lhs_array = lhs.array.into_shape((batch_size_lhs, m, k)).unwrap();
let rhs_array = rhs.array.into_shape((batch_size_rhs, k, n)).unwrap();
#[cfg(feature = "std")]
let iter = (0..batch_size).into_par_iter();
#[cfg(not(feature = "std"))]
let iter = 0..batch_size;
iter.for_each(|b| {
let lhs_slice = match batch_size_lhs == 1 {
true => lhs_array.slice(s!(0, .., ..)),
false => lhs_array.slice(s!(b, .., ..)),
};
let rhs_slice = match batch_size_rhs == 1 {
true => rhs_array.slice(s!(0, .., ..)),
false => rhs_array.slice(s!(b, .., ..)),
};
unsafe {
let mut out_slice = unsafe_shared_out_array.get().slice_mut(s!(b, .., ..));
ndarray::linalg::general_mat_mul(
alpha,
&lhs_slice,
&rhs_slice,
beta,
&mut out_slice,
);
}
});
NdArrayTensor::new(out_array.into_shared().into_dyn())
};
#[cfg(feature = "std")]
let output = rayon::scope(|_| run());
#[cfg(not(feature = "std"))]
let output = run();
output
}
fn reshape<E: FloatNdArrayElement, const D: usize>(
tensor: NdArrayTensor<E, D>,
) -> NdArrayTensor<E, 3> {
let shape = tensor.shape();
if D < 2 {
NdArrayBackend::<E>::reshape(tensor, Shape::new([1, 1, shape.dims[0]]))
} else {
let batch_size = batch_size(&shape);
let size0 = shape.dims[D - 2];
let size1 = shape.dims[D - 1];
NdArrayBackend::<E>::reshape(tensor, Shape::new([batch_size, size0, size1]))
}
}
fn batch_size<const D: usize>(shape: &Shape<D>) -> usize {
let mut num_batch = 1;
for i in 0..D - 2 {
num_batch *= shape.dims[i];
}
num_batch
}
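Stripped of the burn wrappers, the kernel above amounts to: flatten all leading dimensions into one batch axis, then run ndarray's general_mat_mul per batch, in parallel when std (and thus rayon) is available; this is also why FloatNdArrayElement carries the LinalgScalar bound. A minimal sketch on plain ndarray types, assuming ndarray built with its rayon feature and equal batch sizes (no broadcasting):

use ndarray::linalg::general_mat_mul;
use ndarray::parallel::prelude::*;
use ndarray::{Array3, Axis};

// Hypothetical free function, not part of the commit.
fn batched_matmul(lhs: &Array3<f32>, rhs: &Array3<f32>) -> Array3<f32> {
    let (batch, m, _k) = lhs.dim();
    let (_, _, n) = rhs.dim();
    let mut out = Array3::<f32>::zeros((batch, m, n));
    // axis_iter_mut yields disjoint mutable 2D views, so this
    // simplified version needs no unsafe sharing.
    out.axis_iter_mut(Axis(0))
        .into_par_iter()
        .enumerate()
        .for_each(|(b, mut out_2d)| {
            general_mat_mul(
                1.0,
                &lhs.index_axis(Axis(0), b),
                &rhs.index_axis(Axis(0), b),
                0.0,
                &mut out_2d,
            );
        });
    out
}

The committed kernel instead keeps one shared output allocation behind an UnsafeSharedRef (see the sharing module below) so the same loop can also serve the batch_size == 1 broadcast case without cloning slices.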

View File

@@ -1,15 +1,18 @@
use alloc::{vec, vec::Vec};
use super::padding::apply_padding2d;
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice};
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice};
use burn_tensor::{ops::TensorOps, Data, Shape};
use burn_tensor::{
ops::{IntTensorOps, TensorOps},
Data, Shape,
};
use libm::ceilf;
/// This method is not the most efficient, but it serves as a basic implementation that is easy to understand.
/// A more optimized version should be used in its place.
pub(crate) fn max_pool2d_with_indexes_naive<E: NdArrayElement>(
pub(crate) fn max_pool2d_with_indexes_naive<E: FloatNdArrayElement>(
x: NdArrayTensor<E, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
@@ -33,13 +36,14 @@ pub(crate) fn max_pool2d_with_indexes_naive<E: NdArrayElement>(
let [heigth, width] = matrix.shape().dims;
let matrix = NdArrayBackend::reshape(matrix, Shape::new([1, 1, heigth, width]));
let indexes = NdArrayBackend::reshape(indexes, Shape::new([1, 1, heigth, width]));
let indexes =
NdArrayBackend::<E>::int_reshape(indexes, Shape::new([1, 1, heigth, width]));
batch.push(matrix);
batch_indexes.push(indexes);
}
let batch = NdArrayBackend::cat(batch, 1);
let batch_indexes = NdArrayBackend::cat(batch_indexes, 1);
let batch_indexes = NdArrayBackend::<E>::int_cat(batch_indexes, 1);
batches.push(batch);
batches_indexes.push(batch_indexes);
@@ -47,11 +51,11 @@ pub(crate) fn max_pool2d_with_indexes_naive<E: NdArrayElement>(
(
NdArrayBackend::cat(batches, 0),
NdArrayBackend::cat(batches_indexes, 0),
NdArrayBackend::<E>::int_cat(batches_indexes, 0),
)
}
pub(crate) fn max_pool2d_backward_naive<E: NdArrayElement>(
pub(crate) fn max_pool2d_backward_naive<E: FloatNdArrayElement>(
x: NdArrayTensor<E, 4>,
_kernel_size: [usize; 2],
_stride: [usize; 2],
@@ -66,8 +70,10 @@ pub(crate) fn max_pool2d_backward_naive<E: NdArrayElement>(
output_grad,
Shape::new([batch_size, channels, heigth * width]),
);
let indexes_flatten =
NdArrayBackend::reshape(indexes, Shape::new([batch_size, channels, heigth * width]));
let indexes_flatten = NdArrayBackend::<E>::int_reshape(
indexes,
Shape::new([batch_size, channels, heigth * width]),
);
let mut output_flatten = NdArrayBackend::zeros(
Shape::new([batch_size, channels, heigth_x * width_x]),
&NdArrayDevice::Cpu,
@@ -76,9 +82,11 @@ pub(crate) fn max_pool2d_backward_naive<E: NdArrayElement>(
for b in 0..batch_size {
for c in 0..channels {
for i in 0..(heigth * width) {
let index =
NdArrayBackend::index(indexes_flatten.clone(), [b..b + 1, c..c + 1, i..i + 1]);
let index = NdArrayBackend::into_data(index).value[0] as usize;
let index = NdArrayBackend::<E>::int_index(
indexes_flatten.clone(),
[b..b + 1, c..c + 1, i..i + 1],
);
let index = NdArrayBackend::<E>::int_into_data(index).value[0] as usize;
let current_value = NdArrayBackend::index(
output_flatten.clone(),
@@ -102,7 +110,7 @@ pub(crate) fn max_pool2d_backward_naive<E: NdArrayElement>(
NdArrayBackend::reshape(output_flatten, x.shape())
}
fn max_pool2d_with_kernel<E: NdArrayElement>(
fn max_pool2d_with_kernel<E: FloatNdArrayElement>(
x: NdArrayTensor<E, 2>,
kernel_size: [usize; 2],
stride: [usize; 2],
@@ -117,7 +125,7 @@ fn max_pool2d_with_kernel<E: NdArrayElement>(
let mut output =
NdArrayBackend::empty(Shape::new([heigth_new, width_new]), &NdArrayDevice::Cpu);
let mut indexes =
NdArrayBackend::empty(Shape::new([heigth_new, width_new]), &NdArrayDevice::Cpu);
NdArrayBackend::<E>::int_empty(Shape::new([heigth_new, width_new]), &NdArrayDevice::Cpu);
for i in 0..heigth_new {
for j in 0..width_new {
@@ -127,7 +135,7 @@ fn max_pool2d_with_kernel<E: NdArrayElement>(
let x_ij = NdArrayBackend::index(x.clone(), [i_x..i_x + k1, j_x..j_x + k2]);
let x_flatten = NdArrayBackend::reshape(x_ij, Shape::new([k1 * k2]));
let index = NdArrayBackend::argmax(x_flatten.clone(), 0);
let index = NdArrayBackend::into_data(index).value[0];
let index = NdArrayBackend::<E>::int_into_data(index).value[0];
let value = NdArrayBackend::into_data(x_flatten).value[index as usize];
let value = NdArrayBackend::from_data(
Data::new(vec![value], Shape::new([1, 1])),
@@ -141,12 +149,12 @@ fn max_pool2d_with_kernel<E: NdArrayElement>(
let h = (heigth - (2 * p1)) as i64;
let index = ii * h + jj;
let index = NdArrayBackend::from_data(
let index = NdArrayBackend::<E>::int_from_data(
Data::new(vec![index], Shape::new([1, 1])),
&NdArrayDevice::Cpu,
);
indexes = NdArrayBackend::index_assign(indexes, [i..i + 1, j..j + 1], index);
indexes = NdArrayBackend::<E>::int_index_assign(indexes, [i..i + 1, j..j + 1], index);
output = NdArrayBackend::index_assign(output, [i..i + 1, j..j + 1], value);
}
}

View File

@@ -6,6 +6,7 @@ mod tensor;
pub(crate) mod conv;
pub(crate) mod macros;
pub(crate) mod matmul;
pub(crate) mod maxpool;
pub(crate) mod padding;

View File

@@ -1,6 +1,6 @@
use alloc::vec::Vec;
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice};
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice};
use burn_tensor::{ops::*, Shape};
@@ -9,7 +9,7 @@ use super::{
maxpool::{max_pool2d_backward_naive, max_pool2d_with_indexes_naive},
};
impl<E: NdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E> {
impl<E: FloatNdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E> {
fn embedding(
weights: NdArrayTensor<E, 2>,
indexes: NdArrayTensor<i64, 2>,
@@ -19,9 +19,10 @@ impl<E: NdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E> {
let mut tensors = Vec::with_capacity(batch_size * seq_length);
for index in NdArrayBackend::reshape(indexes, Shape::new([batch_size * seq_length]))
.array
.iter()
for index in
NdArrayBackend::<E>::int_reshape(indexes, Shape::new([batch_size * seq_length]))
.array
.iter()
{
let index = *index as usize;
tensors.push(NdArrayBackend::index(
@@ -46,7 +47,7 @@ impl<E: NdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E> {
NdArrayBackend::reshape(output, Shape::new([batch_size * seq_length, d_model]));
for (index_output, index) in
NdArrayBackend::reshape(indexes, Shape::new([batch_size * seq_length]))
NdArrayBackend::<E>::int_reshape(indexes, Shape::new([batch_size * seq_length]))
.array
.iter()
.enumerate()

View File

@@ -1,7 +1,7 @@
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice};
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice};
use burn_tensor::{ops::TensorOps, Shape};
pub(crate) fn apply_padding2d<E: NdArrayElement>(
pub(crate) fn apply_padding2d<E: FloatNdArrayElement>(
x: NdArrayTensor<E, 2>,
padding: [usize; 2],
) -> NdArrayTensor<E, 2> {

View File

@@ -3,9 +3,9 @@ use alloc::vec::Vec;
use core::cmp::Ordering;
use core::ops::Range;
use crate::element::FloatNdArrayElement;
// Current crate
use crate::tensor::BatchMatrix;
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
use crate::{tensor::NdArrayTensor, NdArrayBackend};
use crate::{NdArrayDevice, SEED};
// Workspace crates
@@ -16,9 +16,9 @@ use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Sha
// External crates
use libm::{cos, erf, sin, tanh};
use super::{NdArrayMathOps, NdArrayOps};
use super::{matmul::matmul, NdArrayMathOps, NdArrayOps};
impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
impl<E: FloatNdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
fn from_data<const D: usize>(data: Data<E, D>, _device: &NdArrayDevice) -> NdArrayTensor<E, D> {
NdArrayTensor::from_data(data)
}
@@ -124,11 +124,7 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
lhs: NdArrayTensor<E, D>,
rhs: NdArrayTensor<E, D>,
) -> NdArrayTensor<E, D> {
let batch_self = BatchMatrix::from_ndarray(lhs.array.clone(), lhs.shape());
let batch_other = BatchMatrix::from_ndarray(rhs.array.clone(), rhs.shape());
let output = batch_self.matmul(batch_other);
NdArrayTensor::from_bmatrix(output)
matmul(lhs, rhs)
}
fn neg<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
@@ -392,7 +388,7 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
}
}
fn arg<E: NdArrayElement, F, const D: usize>(
fn arg<E: FloatNdArrayElement, F, const D: usize>(
tensor: NdArrayTensor<E, D>,
dim: usize,
cmp: F,

View File

@@ -0,0 +1,19 @@
use core::cell::UnsafeCell;
/// Similar to `SyncUnsafeCell`; see the [Rust tracking issue](https://github.com/rust-lang/rust/issues/95439).
pub(crate) struct UnsafeSharedRef<'a, T> {
cell: UnsafeCell<&'a mut T>,
}
unsafe impl<'a, T> Sync for UnsafeSharedRef<'a, T> {}
impl<'a, T> UnsafeSharedRef<'a, T> {
pub fn new(data: &'a mut T) -> Self {
Self {
cell: UnsafeCell::new(data),
}
}
pub unsafe fn get(&self) -> &'a mut T {
unsafe { core::ptr::read(self.cell.get()) }
}
}
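A hypothetical usage sketch (not part of the commit) of the contract this wrapper relies on: every parallel task must touch a disjoint region of the shared value, exactly as the matmul kernel writes disjoint s!(b, .., ..) slices:

use rayon::prelude::*;

fn fill_squares(data: &mut Vec<i64>) {
    let len = data.len();
    let shared = UnsafeSharedRef::new(data);
    (0..len).into_par_iter().for_each(|i| {
        // SAFETY: each task writes a distinct index, so the mutable
        // references produced by `get` never alias the same element.
        unsafe {
            shared.get()[i] = (i * i) as i64;
        }
    });
}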

View File

@@ -1,12 +1,8 @@
use alloc::vec::Vec;
use super::element::NdArrayElement;
use burn_tensor::{Data, Shape};
use ndarray::{s, ArcArray, Array, Axis, Dim, Ix2, Ix3, IxDyn};
use ndarray::{ArcArray, Array, Dim, IxDyn};
#[derive(Debug, Clone)]
#[derive(new, Debug, Clone)]
pub struct NdArrayTensor<E, const D: usize> {
pub array: ArcArray<E, IxDyn>,
}
@@ -20,7 +16,7 @@ impl<E, const D: usize> NdArrayTensor<E, D> {
#[cfg(test)]
mod utils {
use super::*;
use crate::NdArrayBackend;
use crate::{element::FloatNdArrayElement, NdArrayBackend};
use burn_tensor::ops::TensorOps;
impl<E, const D: usize> NdArrayTensor<E, D>
@@ -29,105 +25,13 @@ mod utils {
{
pub(crate) fn into_data(self) -> Data<E, D>
where
E: NdArrayElement,
E: FloatNdArrayElement,
{
<NdArrayBackend<E> as TensorOps<NdArrayBackend<E>>>::into_data::<D>(self)
}
}
}
#[derive(new)]
pub(crate) struct BatchMatrix<E, const D: usize> {
pub arrays: Vec<ArcArray<E, Ix2>>,
pub shape: Shape<D>,
}
impl<E, const D: usize> BatchMatrix<E, D>
where
E: NdArrayElement,
{
pub fn from_ndarray(array: ArcArray<E, IxDyn>, shape: Shape<D>) -> Self {
let mut arrays = Vec::new();
if D < 2 {
let array = array.reshape((1, shape.dims[0]));
arrays.push(array);
} else {
let batch_size = batch_size(&shape);
let size0 = shape.dims[D - 2];
let size1 = shape.dims[D - 1];
let array_global = array.reshape((batch_size, size0, size1));
for b in 0..batch_size {
let array = array_global.slice(s!(b, .., ..));
let array = array.into_owned().into_shared();
arrays.push(array);
}
}
Self { arrays, shape }
}
pub fn matmul(self, other: BatchMatrix<E, D>) -> Self {
let require_broadcast = self.arrays.len() != other.arrays.len();
if require_broadcast {
return self.matmul_broadcast(other);
}
let self_iter = self.arrays.iter();
let other_iter = other.arrays.iter();
let arrays = self_iter
.zip(other_iter)
.map(|(lhs, rhs)| lhs.dot(rhs))
.map(|output| output.into_shared())
.collect();
let mut shape = self.shape;
shape.dims[D - 1] = other.shape.dims[D - 1];
Self::new(arrays, shape)
}
fn matmul_broadcast(self, other: BatchMatrix<E, D>) -> Self {
let valid_broadcast = self.arrays.len() == 1 || other.arrays.len() == 1;
if !valid_broadcast {
panic!("Invalid broadcast => {:?} , {:?}", self.shape, other.shape);
}
let batch_size = usize::max(self.arrays.len(), other.arrays.len());
let mut arrays = Vec::with_capacity(batch_size);
for batch in 0..batch_size {
let self_tensor = if self.arrays.len() == 1 {
&self.arrays[0]
} else {
&self.arrays[batch]
};
let other_tensor = if other.arrays.len() == 1 {
&other.arrays[0]
} else {
&other.arrays[batch]
};
let tensor = self_tensor.dot(other_tensor);
arrays.push(tensor.into_shared());
}
let mut shape = self.shape;
shape.dims[D - 1] = other.shape.dims[D - 1];
Self::new(arrays, shape)
}
}
fn batch_size<const D: usize>(shape: &Shape<D>) -> usize {
let mut num_batch = 1;
for i in 0..D - 2 {
num_batch *= shape.dims[i];
}
num_batch
}
#[macro_export(local_inner_macros)]
macro_rules! to_typed_dims {
(
@ -145,60 +49,45 @@ macro_rules! to_typed_dims {
}
#[macro_export(local_inner_macros)]
macro_rules! to_nd_array_tensor {
macro_rules! reshape {
(
$n:expr,
$shape:expr,
$array:expr
ty $ty:ty,
n $n:expr,
shape $shape:expr,
array $array:expr
) => {{
let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
let array: ndarray::ArcArray<E, Dim<[usize; $n]>> = $array.reshape(dim);
let array = array.into_dyn();
NdArrayTensor { array }
}};
(
bool,
$n:expr,
$shape:expr,
$array:expr
) => {{
let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
let array: ndarray::ArcArray<bool, Dim<[usize; $n]>> = $array.reshape(dim);
let array = array.into_dyn();
NdArrayTensor { array }
}};
}
impl<E, const D: usize> NdArrayTensor<E, D>
where
E: Default + Clone,
{
pub(crate) fn from_bmatrix(bmatrix: BatchMatrix<E, D>) -> NdArrayTensor<E, D> {
let shape = bmatrix.shape.clone();
let to_array = |data: BatchMatrix<E, D>| {
let dims = data.shape.dims;
let mut array: Array<E, Ix3> = Array::default((0, dims[D - 2], dims[D - 1]));
for item in data.arrays {
array.push(Axis(0), item.view()).unwrap();
}
array.into_shared()
let safe_into_shape =
$array.is_standard_layout() || $array.raw_view().reversed_axes().is_standard_layout();
let array: ndarray::ArcArray<$ty, Dim<[usize; $n]>> = match safe_into_shape {
true => $array
.into_shape(dim)
.expect("Safe to change shape without relayout")
.into_shared(),
false => $array.reshape(dim),
};
let array = array.into_dyn();
match D {
1 => to_nd_array_tensor!(1, shape, to_array(bmatrix)),
2 => to_nd_array_tensor!(2, shape, to_array(bmatrix)),
3 => to_nd_array_tensor!(3, shape, to_array(bmatrix)),
4 => to_nd_array_tensor!(4, shape, to_array(bmatrix)),
5 => to_nd_array_tensor!(5, shape, to_array(bmatrix)),
6 => to_nd_array_tensor!(6, shape, to_array(bmatrix)),
_ => panic!(""),
NdArrayTensor::new(array)
}};
(
ty $ty:ty,
shape $shape:expr,
array $array:expr,
d $D:expr
) => {{
match $D {
1 => reshape!(ty $ty, n 1, shape $shape, array $array),
2 => reshape!(ty $ty, n 2, shape $shape, array $array),
3 => reshape!(ty $ty, n 3, shape $shape, array $array),
4 => reshape!(ty $ty, n 4, shape $shape, array $array),
5 => reshape!(ty $ty, n 5, shape $shape, array $array),
6 => reshape!(ty $ty, n 6, shape $shape, array $array),
_ => core::panic!("NdArray supports arrays up to 6 dimensions, received: {}", $D),
}
}
}};
}
impl<E, const D: usize> NdArrayTensor<E, D>
where
E: Default + Clone,
@@ -206,16 +95,14 @@ where
pub fn from_data(data: Data<E, D>) -> NdArrayTensor<E, D> {
let shape = data.shape.clone();
let to_array = |data: Data<E, D>| Array::from_iter(data.value.into_iter()).into_shared();
let array = to_array(data);
match D {
1 => to_nd_array_tensor!(1, shape, to_array(data)),
2 => to_nd_array_tensor!(2, shape, to_array(data)),
3 => to_nd_array_tensor!(3, shape, to_array(data)),
4 => to_nd_array_tensor!(4, shape, to_array(data)),
5 => to_nd_array_tensor!(5, shape, to_array(data)),
6 => to_nd_array_tensor!(6, shape, to_array(data)),
_ => panic!(""),
}
reshape!(
ty E,
shape shape,
array array,
d D
)
}
}
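The safe_into_shape guard in reshape! is the performance point of this file: ArcArray::reshape clones the data whenever it cannot reuse the buffer, while into_shape merely reinterprets a contiguous one. A minimal sketch of the distinction on plain ndarray types (illustrative, not from the commit):

use ndarray::ArcArray2;

fn reshape_without_copy() {
    let a = ArcArray2::<f32>::zeros((2, 3));
    // Row-major (standard) layout: into_shape succeeds and reuses the
    // buffer; no element is copied.
    assert!(a.is_standard_layout());
    let flat = a.into_shape(6).unwrap();
    assert_eq!(flat.len(), 6);
    // For non-contiguous arrays (e.g. after slicing), into_shape cannot
    // reinterpret the buffer, which is why the macro checks the layout
    // first and falls back to reshape (a copy) otherwise.
}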

View File

@@ -32,4 +32,19 @@ mod tests {
Data::from([[[18.0, 28.0], [14.0, 23.0]]])
);
}
#[test]
fn test_matmul_broadcast_1() {
let data_1 = Data::from([[[1.0, 7.0], [2.0, 3.0]]]);
let data_2 = Data::from([[[4.0, 7.0], [2.0, 3.0]], [[4.0, 7.0], [2.0, 3.0]]]);
let tensor_1 = Tensor::<TestBackend, 3>::from_data(data_1);
let tensor_2 = Tensor::<TestBackend, 3>::from_data(data_2);
let tensor_3 = tensor_1.matmul(tensor_2);
assert_eq!(
tensor_3.into_data(),
Data::from([[[18.0, 28.0], [14.0, 23.0]], [[18.0, 28.0], [14.0, 23.0]]])
);
}
}