mirror of https://github.com/tracel-ai/burn.git
draft: Perf/ndarray matmul (#214)
This commit is contained in:
parent
d8e5b3fed1
commit
a2ec774c37
|
@ -13,7 +13,6 @@ pub fn backward<B: Backend, const D: usize>(root: ADTensor<B, D>) -> Gradients {
|
||||||
|
|
||||||
fn build_tape(root: NodeRef, graph: Graph) -> Vec<Vec<StepBoxed>> {
|
fn build_tape(root: NodeRef, graph: Graph) -> Vec<Vec<StepBoxed>> {
|
||||||
let mut tape = (0..root.order)
|
let mut tape = (0..root.order)
|
||||||
.into_iter()
|
|
||||||
.map(|_| Vec::with_capacity(1))
|
.map(|_| Vec::with_capacity(1))
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
|
|
@ -18,9 +18,11 @@ std = [
|
||||||
"burn-tensor/std",
|
"burn-tensor/std",
|
||||||
"burn-common/std",
|
"burn-common/std",
|
||||||
"burn-autodiff",
|
"burn-autodiff",
|
||||||
|
"rayon",
|
||||||
"ndarray/std",
|
"ndarray/std",
|
||||||
"ndarray/rayon",
|
"ndarray/rayon",
|
||||||
"ndarray/matrixmultiply-threading",
|
"matrixmultiply/std",
|
||||||
|
"matrixmultiply/threading",
|
||||||
]
|
]
|
||||||
|
|
||||||
blas-accelerate = ["ndarray/blas", "blas-src/accelerate"] # Accelerate framework (macOS only)
|
blas-accelerate = ["ndarray/blas", "blas-src/accelerate"] # Accelerate framework (macOS only)
|
||||||
|
@ -40,7 +42,12 @@ burn-autodiff = {path = "../burn-autodiff", features = ["export_tests"], optiona
|
||||||
burn-common = {path = "../burn-common", default-features = false}
|
burn-common = {path = "../burn-common", default-features = false}
|
||||||
burn-tensor = {path = "../burn-tensor", default-features = false, features = ["export_tests"]}
|
burn-tensor = {path = "../burn-tensor", default-features = false, features = ["export_tests"]}
|
||||||
|
|
||||||
|
|
||||||
|
matrixmultiply = {version = "0.3.2", default-features = false}
|
||||||
|
rayon = {version= "1.6", optional = true}
|
||||||
|
|
||||||
blas-src = {version = "0.8.0", default-features = false, optional = true}# no-std compatible
|
blas-src = {version = "0.8.0", default-features = false, optional = true}# no-std compatible
|
||||||
|
|
||||||
derive-new = {workspace = true}
|
derive-new = {workspace = true}
|
||||||
libm = {workspace = true}
|
libm = {workspace = true}
|
||||||
ndarray = {workspace = true}
|
ndarray = {workspace = true}
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
use alloc::string::String;
|
use alloc::string::String;
|
||||||
use core::marker::PhantomData;
|
use core::marker::PhantomData;
|
||||||
|
|
||||||
use crate::element::NdArrayElement;
|
use crate::element::FloatNdArrayElement;
|
||||||
use crate::NdArrayTensor;
|
use crate::NdArrayTensor;
|
||||||
|
|
||||||
use burn_tensor::backend::Backend;
|
use burn_tensor::backend::Backend;
|
||||||
|
@ -32,7 +32,7 @@ pub struct NdArrayBackend<E> {
|
||||||
phantom: PhantomData<E>,
|
phantom: PhantomData<E>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<E: NdArrayElement> Backend for NdArrayBackend<E> {
|
impl<E: FloatNdArrayElement> Backend for NdArrayBackend<E> {
|
||||||
type Device = NdArrayDevice;
|
type Device = NdArrayDevice;
|
||||||
type FullPrecisionElem = f32;
|
type FullPrecisionElem = f32;
|
||||||
type FullPrecisionBackend = NdArrayBackend<f32>;
|
type FullPrecisionBackend = NdArrayBackend<f32>;
|
||||||
|
|
|
@ -1,6 +1,13 @@
|
||||||
use burn_tensor::Element;
|
use burn_tensor::Element;
|
||||||
use libm::{exp, log, log1p, pow, sqrt};
|
use libm::{exp, log, log1p, pow, sqrt};
|
||||||
use libm::{expf, log1pf, logf, powf, sqrtf};
|
use libm::{expf, log1pf, logf, powf, sqrtf};
|
||||||
|
use ndarray::LinalgScalar;
|
||||||
|
|
||||||
|
pub(crate) trait FloatNdArrayElement: NdArrayElement + LinalgScalar
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
pub(crate) trait NdArrayElement:
|
pub(crate) trait NdArrayElement:
|
||||||
Element
|
Element
|
||||||
|
@ -21,6 +28,7 @@ pub(crate) trait ExpElement {
|
||||||
fn sqrt_elem(self) -> Self;
|
fn sqrt_elem(self) -> Self;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl FloatNdArrayElement for f64 {}
|
||||||
impl NdArrayElement for f64 {}
|
impl NdArrayElement for f64 {}
|
||||||
impl ExpElement for f64 {
|
impl ExpElement for f64 {
|
||||||
fn exp_elem(self) -> Self {
|
fn exp_elem(self) -> Self {
|
||||||
|
@ -44,24 +52,30 @@ impl ExpElement for f64 {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl FloatNdArrayElement for f32 {}
|
||||||
impl NdArrayElement for f32 {}
|
impl NdArrayElement for f32 {}
|
||||||
impl ExpElement for f32 {
|
impl ExpElement for f32 {
|
||||||
|
#[inline(always)]
|
||||||
fn exp_elem(self) -> Self {
|
fn exp_elem(self) -> Self {
|
||||||
expf(self)
|
expf(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
fn log_elem(self) -> Self {
|
fn log_elem(self) -> Self {
|
||||||
logf(self)
|
logf(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
fn log1p_elem(self) -> Self {
|
fn log1p_elem(self) -> Self {
|
||||||
log1pf(self)
|
log1pf(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
fn pow_elem(self, value: f32) -> Self {
|
fn pow_elem(self, value: f32) -> Self {
|
||||||
powf(self, value)
|
powf(self, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[inline(always)]
|
||||||
fn sqrt_elem(self) -> Self {
|
fn sqrt_elem(self) -> Self {
|
||||||
sqrtf(self)
|
sqrtf(self)
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,9 +13,11 @@ extern crate blas_src;
|
||||||
mod backend;
|
mod backend;
|
||||||
mod element;
|
mod element;
|
||||||
mod ops;
|
mod ops;
|
||||||
|
mod sharing;
|
||||||
mod tensor;
|
mod tensor;
|
||||||
|
|
||||||
pub use backend::*;
|
pub use backend::*;
|
||||||
|
pub(crate) use sharing::*;
|
||||||
pub(crate) use tensor::*;
|
pub(crate) use tensor::*;
|
||||||
|
|
||||||
extern crate alloc;
|
extern crate alloc;
|
||||||
|
|
|
@ -10,7 +10,7 @@ use ndarray::SliceInfoElem;
|
||||||
|
|
||||||
use crate::element::NdArrayElement;
|
use crate::element::NdArrayElement;
|
||||||
use crate::ops::macros::{keepdim, mean_dim, sum_dim};
|
use crate::ops::macros::{keepdim, mean_dim, sum_dim};
|
||||||
use crate::{tensor::NdArrayTensor, to_nd_array_tensor};
|
use crate::{reshape, tensor::NdArrayTensor};
|
||||||
|
|
||||||
pub struct NdArrayOps<E> {
|
pub struct NdArrayOps<E> {
|
||||||
e: PhantomData<E>,
|
e: PhantomData<E>,
|
||||||
|
@ -51,15 +51,12 @@ where
|
||||||
tensor: NdArrayTensor<E, D1>,
|
tensor: NdArrayTensor<E, D1>,
|
||||||
shape: Shape<D2>,
|
shape: Shape<D2>,
|
||||||
) -> NdArrayTensor<E, D2> {
|
) -> NdArrayTensor<E, D2> {
|
||||||
match D2 {
|
reshape!(
|
||||||
1 => to_nd_array_tensor!(1, shape, tensor.array),
|
ty E,
|
||||||
2 => to_nd_array_tensor!(2, shape, tensor.array),
|
shape shape,
|
||||||
3 => to_nd_array_tensor!(3, shape, tensor.array),
|
array tensor.array,
|
||||||
4 => to_nd_array_tensor!(4, shape, tensor.array),
|
d D2
|
||||||
5 => to_nd_array_tensor!(5, shape, tensor.array),
|
)
|
||||||
6 => to_nd_array_tensor!(6, shape, tensor.array),
|
|
||||||
_ => panic!("NdArrayTensor support only 6 dimensions."),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn cat<const D: usize>(
|
pub fn cat<const D: usize>(
|
||||||
|
|
|
@ -1,20 +1,20 @@
|
||||||
// Language
|
// Language
|
||||||
use alloc::vec;
|
use alloc::vec;
|
||||||
use alloc::vec::Vec;
|
use alloc::vec::Vec;
|
||||||
use burn_tensor::ops::BoolTensorOps;
|
use burn_tensor::ops::{BoolTensorOps, IntTensorOps};
|
||||||
use core::ops::Range;
|
use core::ops::Range;
|
||||||
|
|
||||||
// Current crate
|
// Current crate
|
||||||
use crate::element::NdArrayElement;
|
use crate::element::FloatNdArrayElement;
|
||||||
use crate::NdArrayDevice;
|
use crate::NdArrayDevice;
|
||||||
use crate::{tensor::NdArrayTensor, NdArrayBackend};
|
use crate::{tensor::NdArrayTensor, NdArrayBackend};
|
||||||
|
|
||||||
// Workspace crates
|
// Workspace crates
|
||||||
use burn_tensor::{backend::Backend, ops::TensorOps, Data, Shape};
|
use burn_tensor::{backend::Backend, Data, Shape};
|
||||||
|
|
||||||
use super::NdArrayOps;
|
use super::NdArrayOps;
|
||||||
|
|
||||||
impl<E: NdArrayElement> BoolTensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
impl<E: FloatNdArrayElement> BoolTensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
||||||
fn bool_from_data<const D: usize>(
|
fn bool_from_data<const D: usize>(
|
||||||
data: Data<bool, D>,
|
data: Data<bool, D>,
|
||||||
_device: &NdArrayDevice,
|
_device: &NdArrayDevice,
|
||||||
|
@ -68,7 +68,7 @@ impl<E: NdArrayElement> BoolTensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
||||||
tensor: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
|
tensor: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
|
||||||
) -> NdArrayTensor<i64, D> {
|
) -> NdArrayTensor<i64, D> {
|
||||||
let data = Self::bool_into_data(tensor);
|
let data = Self::bool_into_data(tensor);
|
||||||
NdArrayBackend::<i64>::from_data(data.convert(), &NdArrayDevice::Cpu)
|
NdArrayBackend::<E>::int_from_data(data.convert(), &NdArrayDevice::Cpu)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn bool_device<const D: usize>(
|
fn bool_device<const D: usize>(
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
use alloc::vec::Vec;
|
use alloc::vec::Vec;
|
||||||
|
|
||||||
use super::padding::apply_padding2d;
|
use super::padding::apply_padding2d;
|
||||||
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice};
|
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice};
|
||||||
|
|
||||||
use burn_tensor::{ops::TensorOps, Shape};
|
use burn_tensor::{ops::TensorOps, Shape};
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ use libm::ceilf;
|
||||||
|
|
||||||
/// This method is not the most efficient, but it serves as a basic implementation that is easy to understand.
|
/// This method is not the most efficient, but it serves as a basic implementation that is easy to understand.
|
||||||
/// A more optimized version should be used in its place.
|
/// A more optimized version should be used in its place.
|
||||||
pub(crate) fn conv2d_naive<E: NdArrayElement>(
|
pub(crate) fn conv2d_naive<E: FloatNdArrayElement>(
|
||||||
x: NdArrayTensor<E, 4>,
|
x: NdArrayTensor<E, 4>,
|
||||||
weight: NdArrayTensor<E, 4>,
|
weight: NdArrayTensor<E, 4>,
|
||||||
bias: Option<NdArrayTensor<E, 1>>,
|
bias: Option<NdArrayTensor<E, 1>>,
|
||||||
|
@ -35,7 +35,7 @@ pub(crate) fn conv2d_naive<E: NdArrayElement>(
|
||||||
NdArrayBackend::cat(results, 0)
|
NdArrayBackend::cat(results, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn conv2d_naive_no_batch_size<E: NdArrayElement>(
|
pub(crate) fn conv2d_naive_no_batch_size<E: FloatNdArrayElement>(
|
||||||
x: NdArrayTensor<E, 3>,
|
x: NdArrayTensor<E, 3>,
|
||||||
weight: NdArrayTensor<E, 4>,
|
weight: NdArrayTensor<E, 4>,
|
||||||
bias: Option<NdArrayTensor<E, 1>>,
|
bias: Option<NdArrayTensor<E, 1>>,
|
||||||
|
@ -81,7 +81,7 @@ pub(crate) fn conv2d_naive_no_batch_size<E: NdArrayElement>(
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
fn conv2d_with_kernel<E: NdArrayElement>(
|
fn conv2d_with_kernel<E: FloatNdArrayElement>(
|
||||||
x: NdArrayTensor<E, 2>,
|
x: NdArrayTensor<E, 2>,
|
||||||
kernel: NdArrayTensor<E, 2>,
|
kernel: NdArrayTensor<E, 2>,
|
||||||
stride: [usize; 2],
|
stride: [usize; 2],
|
||||||
|
|
|
@ -5,7 +5,7 @@ use burn_tensor::ops::IntTensorOps;
|
||||||
use core::ops::Range;
|
use core::ops::Range;
|
||||||
|
|
||||||
// Current crate
|
// Current crate
|
||||||
use crate::element::NdArrayElement;
|
use crate::element::FloatNdArrayElement;
|
||||||
use crate::NdArrayDevice;
|
use crate::NdArrayDevice;
|
||||||
use crate::{tensor::NdArrayTensor, NdArrayBackend};
|
use crate::{tensor::NdArrayTensor, NdArrayBackend};
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ use burn_tensor::{backend::Backend, Data, Shape};
|
||||||
|
|
||||||
use super::{NdArrayMathOps, NdArrayOps};
|
use super::{NdArrayMathOps, NdArrayOps};
|
||||||
|
|
||||||
impl<E: NdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
impl<E: FloatNdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
||||||
fn int_from_data<const D: usize>(
|
fn int_from_data<const D: usize>(
|
||||||
data: Data<i64, D>,
|
data: Data<i64, D>,
|
||||||
_device: &NdArrayDevice,
|
_device: &NdArrayDevice,
|
||||||
|
|
|
@ -0,0 +1,126 @@
|
||||||
|
use crate::UnsafeSharedRef;
|
||||||
|
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
|
||||||
|
use burn_tensor::ElementConversion;
|
||||||
|
use burn_tensor::{ops::TensorOps, Shape};
|
||||||
|
use ndarray::s;
|
||||||
|
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
use rayon::prelude::*;
|
||||||
|
|
||||||
|
pub(crate) fn matmul<E, const D: usize>(
|
||||||
|
lhs: NdArrayTensor<E, D>,
|
||||||
|
rhs: NdArrayTensor<E, D>,
|
||||||
|
) -> NdArrayTensor<E, D>
|
||||||
|
where
|
||||||
|
E: FloatNdArrayElement,
|
||||||
|
{
|
||||||
|
let shape_ori_lhs = lhs.shape();
|
||||||
|
let shape_ori_rhs = rhs.shape();
|
||||||
|
|
||||||
|
let lhs = reshape(lhs);
|
||||||
|
let rhs = reshape(rhs);
|
||||||
|
|
||||||
|
let [batch_size_lhs, m, _] = lhs.shape().dims;
|
||||||
|
let [batch_size_rhs, _, n] = rhs.shape().dims;
|
||||||
|
|
||||||
|
let mut shape_out = match batch_size_lhs > batch_size_rhs {
|
||||||
|
true => shape_ori_lhs,
|
||||||
|
false => shape_ori_rhs,
|
||||||
|
};
|
||||||
|
shape_out.dims[D - 2] = m;
|
||||||
|
shape_out.dims[D - 1] = n;
|
||||||
|
|
||||||
|
let out = general_matmul(lhs, rhs);
|
||||||
|
|
||||||
|
NdArrayBackend::<E>::reshape(out, shape_out)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn general_matmul<E: FloatNdArrayElement>(
|
||||||
|
lhs: NdArrayTensor<E, 3>,
|
||||||
|
rhs: NdArrayTensor<E, 3>,
|
||||||
|
) -> NdArrayTensor<E, 3> {
|
||||||
|
let run = || {
|
||||||
|
let [batch_size_lhs, m, _] = lhs.shape().dims;
|
||||||
|
let [batch_size_rhs, k, n] = rhs.shape().dims;
|
||||||
|
let batch_size = usize::max(batch_size_rhs, batch_size_lhs);
|
||||||
|
|
||||||
|
if batch_size_lhs > batch_size && batch_size_lhs != 1 {
|
||||||
|
panic!("Broadcast on multiple dimensions is not yet supported");
|
||||||
|
}
|
||||||
|
|
||||||
|
if batch_size_rhs > batch_size && batch_size_rhs != 1 {
|
||||||
|
panic!("Broadcast on multiple dimensions is not yet supported");
|
||||||
|
}
|
||||||
|
|
||||||
|
let alpha: E = 1.0.elem();
|
||||||
|
let beta: E = 0.0.elem();
|
||||||
|
|
||||||
|
let mut out_array = ndarray::Array3::<E>::zeros((batch_size, m, n));
|
||||||
|
let unsafe_shared_out_array = UnsafeSharedRef::new(&mut out_array);
|
||||||
|
|
||||||
|
let lhs_array = lhs.array.into_shape((batch_size_lhs, m, k)).unwrap();
|
||||||
|
let rhs_array = rhs.array.into_shape((batch_size_rhs, k, n)).unwrap();
|
||||||
|
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
let iter = (0..batch_size).into_par_iter();
|
||||||
|
#[cfg(not(feature = "std"))]
|
||||||
|
let iter = 0..batch_size;
|
||||||
|
|
||||||
|
iter.for_each(|b| {
|
||||||
|
let lhs_slice = match batch_size_lhs == 1 {
|
||||||
|
true => lhs_array.slice(s!(0, .., ..)),
|
||||||
|
false => lhs_array.slice(s!(b, .., ..)),
|
||||||
|
};
|
||||||
|
let rhs_slice = match batch_size_rhs == 1 {
|
||||||
|
true => rhs_array.slice(s!(0, .., ..)),
|
||||||
|
false => rhs_array.slice(s!(b, .., ..)),
|
||||||
|
};
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
let mut out_slice = unsafe_shared_out_array.get().slice_mut(s!(b, .., ..));
|
||||||
|
|
||||||
|
ndarray::linalg::general_mat_mul(
|
||||||
|
alpha,
|
||||||
|
&lhs_slice,
|
||||||
|
&rhs_slice,
|
||||||
|
beta,
|
||||||
|
&mut out_slice,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
NdArrayTensor::new(out_array.into_shared().into_dyn())
|
||||||
|
};
|
||||||
|
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
let output = rayon::scope(|_| run());
|
||||||
|
#[cfg(not(feature = "std"))]
|
||||||
|
let output = run();
|
||||||
|
|
||||||
|
output
|
||||||
|
}
|
||||||
|
|
||||||
|
fn reshape<E: FloatNdArrayElement, const D: usize>(
|
||||||
|
tensor: NdArrayTensor<E, D>,
|
||||||
|
) -> NdArrayTensor<E, 3> {
|
||||||
|
let shape = tensor.shape();
|
||||||
|
|
||||||
|
if D < 2 {
|
||||||
|
NdArrayBackend::<E>::reshape(tensor, Shape::new([1, 1, shape.dims[0]]))
|
||||||
|
} else {
|
||||||
|
let batch_size = batch_size(&shape);
|
||||||
|
let size0 = shape.dims[D - 2];
|
||||||
|
let size1 = shape.dims[D - 1];
|
||||||
|
|
||||||
|
NdArrayBackend::<E>::reshape(tensor, Shape::new([batch_size, size0, size1]))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn batch_size<const D: usize>(shape: &Shape<D>) -> usize {
|
||||||
|
let mut num_batch = 1;
|
||||||
|
for i in 0..D - 2 {
|
||||||
|
num_batch *= shape.dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
num_batch
|
||||||
|
}
|
|
@ -1,15 +1,18 @@
|
||||||
use alloc::{vec, vec::Vec};
|
use alloc::{vec, vec::Vec};
|
||||||
|
|
||||||
use super::padding::apply_padding2d;
|
use super::padding::apply_padding2d;
|
||||||
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice};
|
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice};
|
||||||
|
|
||||||
use burn_tensor::{ops::TensorOps, Data, Shape};
|
use burn_tensor::{
|
||||||
|
ops::{IntTensorOps, TensorOps},
|
||||||
|
Data, Shape,
|
||||||
|
};
|
||||||
|
|
||||||
use libm::ceilf;
|
use libm::ceilf;
|
||||||
|
|
||||||
/// This method is not the most efficient, but it serves as a basic implementation that is easy to understand.
|
/// This method is not the most efficient, but it serves as a basic implementation that is easy to understand.
|
||||||
/// A more optimized version should be used in its place.
|
/// A more optimized version should be used in its place.
|
||||||
pub(crate) fn max_pool2d_with_indexes_naive<E: NdArrayElement>(
|
pub(crate) fn max_pool2d_with_indexes_naive<E: FloatNdArrayElement>(
|
||||||
x: NdArrayTensor<E, 4>,
|
x: NdArrayTensor<E, 4>,
|
||||||
kernel_size: [usize; 2],
|
kernel_size: [usize; 2],
|
||||||
stride: [usize; 2],
|
stride: [usize; 2],
|
||||||
|
@ -33,13 +36,14 @@ pub(crate) fn max_pool2d_with_indexes_naive<E: NdArrayElement>(
|
||||||
let [heigth, width] = matrix.shape().dims;
|
let [heigth, width] = matrix.shape().dims;
|
||||||
|
|
||||||
let matrix = NdArrayBackend::reshape(matrix, Shape::new([1, 1, heigth, width]));
|
let matrix = NdArrayBackend::reshape(matrix, Shape::new([1, 1, heigth, width]));
|
||||||
let indexes = NdArrayBackend::reshape(indexes, Shape::new([1, 1, heigth, width]));
|
let indexes =
|
||||||
|
NdArrayBackend::<E>::int_reshape(indexes, Shape::new([1, 1, heigth, width]));
|
||||||
|
|
||||||
batch.push(matrix);
|
batch.push(matrix);
|
||||||
batch_indexes.push(indexes);
|
batch_indexes.push(indexes);
|
||||||
}
|
}
|
||||||
let batch = NdArrayBackend::cat(batch, 1);
|
let batch = NdArrayBackend::cat(batch, 1);
|
||||||
let batch_indexes = NdArrayBackend::cat(batch_indexes, 1);
|
let batch_indexes = NdArrayBackend::<E>::int_cat(batch_indexes, 1);
|
||||||
|
|
||||||
batches.push(batch);
|
batches.push(batch);
|
||||||
batches_indexes.push(batch_indexes);
|
batches_indexes.push(batch_indexes);
|
||||||
|
@ -47,11 +51,11 @@ pub(crate) fn max_pool2d_with_indexes_naive<E: NdArrayElement>(
|
||||||
|
|
||||||
(
|
(
|
||||||
NdArrayBackend::cat(batches, 0),
|
NdArrayBackend::cat(batches, 0),
|
||||||
NdArrayBackend::cat(batches_indexes, 0),
|
NdArrayBackend::<E>::int_cat(batches_indexes, 0),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn max_pool2d_backward_naive<E: NdArrayElement>(
|
pub(crate) fn max_pool2d_backward_naive<E: FloatNdArrayElement>(
|
||||||
x: NdArrayTensor<E, 4>,
|
x: NdArrayTensor<E, 4>,
|
||||||
_kernel_size: [usize; 2],
|
_kernel_size: [usize; 2],
|
||||||
_stride: [usize; 2],
|
_stride: [usize; 2],
|
||||||
|
@ -66,8 +70,10 @@ pub(crate) fn max_pool2d_backward_naive<E: NdArrayElement>(
|
||||||
output_grad,
|
output_grad,
|
||||||
Shape::new([batch_size, channels, heigth * width]),
|
Shape::new([batch_size, channels, heigth * width]),
|
||||||
);
|
);
|
||||||
let indexes_flatten =
|
let indexes_flatten = NdArrayBackend::<E>::int_reshape(
|
||||||
NdArrayBackend::reshape(indexes, Shape::new([batch_size, channels, heigth * width]));
|
indexes,
|
||||||
|
Shape::new([batch_size, channels, heigth * width]),
|
||||||
|
);
|
||||||
let mut output_flatten = NdArrayBackend::zeros(
|
let mut output_flatten = NdArrayBackend::zeros(
|
||||||
Shape::new([batch_size, channels, heigth_x * width_x]),
|
Shape::new([batch_size, channels, heigth_x * width_x]),
|
||||||
&NdArrayDevice::Cpu,
|
&NdArrayDevice::Cpu,
|
||||||
|
@ -76,9 +82,11 @@ pub(crate) fn max_pool2d_backward_naive<E: NdArrayElement>(
|
||||||
for b in 0..batch_size {
|
for b in 0..batch_size {
|
||||||
for c in 0..channels {
|
for c in 0..channels {
|
||||||
for i in 0..(heigth * width) {
|
for i in 0..(heigth * width) {
|
||||||
let index =
|
let index = NdArrayBackend::<E>::int_index(
|
||||||
NdArrayBackend::index(indexes_flatten.clone(), [b..b + 1, c..c + 1, i..i + 1]);
|
indexes_flatten.clone(),
|
||||||
let index = NdArrayBackend::into_data(index).value[0] as usize;
|
[b..b + 1, c..c + 1, i..i + 1],
|
||||||
|
);
|
||||||
|
let index = NdArrayBackend::<E>::int_into_data(index).value[0] as usize;
|
||||||
|
|
||||||
let current_value = NdArrayBackend::index(
|
let current_value = NdArrayBackend::index(
|
||||||
output_flatten.clone(),
|
output_flatten.clone(),
|
||||||
|
@ -102,7 +110,7 @@ pub(crate) fn max_pool2d_backward_naive<E: NdArrayElement>(
|
||||||
NdArrayBackend::reshape(output_flatten, x.shape())
|
NdArrayBackend::reshape(output_flatten, x.shape())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn max_pool2d_with_kernel<E: NdArrayElement>(
|
fn max_pool2d_with_kernel<E: FloatNdArrayElement>(
|
||||||
x: NdArrayTensor<E, 2>,
|
x: NdArrayTensor<E, 2>,
|
||||||
kernel_size: [usize; 2],
|
kernel_size: [usize; 2],
|
||||||
stride: [usize; 2],
|
stride: [usize; 2],
|
||||||
|
@ -117,7 +125,7 @@ fn max_pool2d_with_kernel<E: NdArrayElement>(
|
||||||
let mut output =
|
let mut output =
|
||||||
NdArrayBackend::empty(Shape::new([heigth_new, width_new]), &NdArrayDevice::Cpu);
|
NdArrayBackend::empty(Shape::new([heigth_new, width_new]), &NdArrayDevice::Cpu);
|
||||||
let mut indexes =
|
let mut indexes =
|
||||||
NdArrayBackend::empty(Shape::new([heigth_new, width_new]), &NdArrayDevice::Cpu);
|
NdArrayBackend::<E>::int_empty(Shape::new([heigth_new, width_new]), &NdArrayDevice::Cpu);
|
||||||
|
|
||||||
for i in 0..heigth_new {
|
for i in 0..heigth_new {
|
||||||
for j in 0..width_new {
|
for j in 0..width_new {
|
||||||
|
@ -127,7 +135,7 @@ fn max_pool2d_with_kernel<E: NdArrayElement>(
|
||||||
let x_ij = NdArrayBackend::index(x.clone(), [i_x..i_x + k1, j_x..j_x + k2]);
|
let x_ij = NdArrayBackend::index(x.clone(), [i_x..i_x + k1, j_x..j_x + k2]);
|
||||||
let x_flatten = NdArrayBackend::reshape(x_ij, Shape::new([k1 * k2]));
|
let x_flatten = NdArrayBackend::reshape(x_ij, Shape::new([k1 * k2]));
|
||||||
let index = NdArrayBackend::argmax(x_flatten.clone(), 0);
|
let index = NdArrayBackend::argmax(x_flatten.clone(), 0);
|
||||||
let index = NdArrayBackend::into_data(index).value[0];
|
let index = NdArrayBackend::<E>::int_into_data(index).value[0];
|
||||||
let value = NdArrayBackend::into_data(x_flatten).value[index as usize];
|
let value = NdArrayBackend::into_data(x_flatten).value[index as usize];
|
||||||
let value = NdArrayBackend::from_data(
|
let value = NdArrayBackend::from_data(
|
||||||
Data::new(vec![value], Shape::new([1, 1])),
|
Data::new(vec![value], Shape::new([1, 1])),
|
||||||
|
@ -141,12 +149,12 @@ fn max_pool2d_with_kernel<E: NdArrayElement>(
|
||||||
let h = (heigth - (2 * p1)) as i64;
|
let h = (heigth - (2 * p1)) as i64;
|
||||||
let index = ii * h + jj;
|
let index = ii * h + jj;
|
||||||
|
|
||||||
let index = NdArrayBackend::from_data(
|
let index = NdArrayBackend::<E>::int_from_data(
|
||||||
Data::new(vec![index], Shape::new([1, 1])),
|
Data::new(vec![index], Shape::new([1, 1])),
|
||||||
&NdArrayDevice::Cpu,
|
&NdArrayDevice::Cpu,
|
||||||
);
|
);
|
||||||
|
|
||||||
indexes = NdArrayBackend::index_assign(indexes, [i..i + 1, j..j + 1], index);
|
indexes = NdArrayBackend::<E>::int_index_assign(indexes, [i..i + 1, j..j + 1], index);
|
||||||
output = NdArrayBackend::index_assign(output, [i..i + 1, j..j + 1], value);
|
output = NdArrayBackend::index_assign(output, [i..i + 1, j..j + 1], value);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ mod tensor;
|
||||||
|
|
||||||
pub(crate) mod conv;
|
pub(crate) mod conv;
|
||||||
pub(crate) mod macros;
|
pub(crate) mod macros;
|
||||||
|
pub(crate) mod matmul;
|
||||||
pub(crate) mod maxpool;
|
pub(crate) mod maxpool;
|
||||||
pub(crate) mod padding;
|
pub(crate) mod padding;
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use alloc::vec::Vec;
|
use alloc::vec::Vec;
|
||||||
|
|
||||||
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice};
|
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice};
|
||||||
|
|
||||||
use burn_tensor::{ops::*, Shape};
|
use burn_tensor::{ops::*, Shape};
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ use super::{
|
||||||
maxpool::{max_pool2d_backward_naive, max_pool2d_with_indexes_naive},
|
maxpool::{max_pool2d_backward_naive, max_pool2d_with_indexes_naive},
|
||||||
};
|
};
|
||||||
|
|
||||||
impl<E: NdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
impl<E: FloatNdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
||||||
fn embedding(
|
fn embedding(
|
||||||
weights: NdArrayTensor<E, 2>,
|
weights: NdArrayTensor<E, 2>,
|
||||||
indexes: NdArrayTensor<i64, 2>,
|
indexes: NdArrayTensor<i64, 2>,
|
||||||
|
@ -19,7 +19,8 @@ impl<E: NdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
||||||
|
|
||||||
let mut tensors = Vec::with_capacity(batch_size * seq_length);
|
let mut tensors = Vec::with_capacity(batch_size * seq_length);
|
||||||
|
|
||||||
for index in NdArrayBackend::reshape(indexes, Shape::new([batch_size * seq_length]))
|
for index in
|
||||||
|
NdArrayBackend::<E>::int_reshape(indexes, Shape::new([batch_size * seq_length]))
|
||||||
.array
|
.array
|
||||||
.iter()
|
.iter()
|
||||||
{
|
{
|
||||||
|
@ -46,7 +47,7 @@ impl<E: NdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
||||||
NdArrayBackend::reshape(output, Shape::new([batch_size * seq_length, d_model]));
|
NdArrayBackend::reshape(output, Shape::new([batch_size * seq_length, d_model]));
|
||||||
|
|
||||||
for (index_output, index) in
|
for (index_output, index) in
|
||||||
NdArrayBackend::reshape(indexes, Shape::new([batch_size * seq_length]))
|
NdArrayBackend::<E>::int_reshape(indexes, Shape::new([batch_size * seq_length]))
|
||||||
.array
|
.array
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice};
|
use crate::{element::FloatNdArrayElement, tensor::NdArrayTensor, NdArrayBackend, NdArrayDevice};
|
||||||
use burn_tensor::{ops::TensorOps, Shape};
|
use burn_tensor::{ops::TensorOps, Shape};
|
||||||
|
|
||||||
pub(crate) fn apply_padding2d<E: NdArrayElement>(
|
pub(crate) fn apply_padding2d<E: FloatNdArrayElement>(
|
||||||
x: NdArrayTensor<E, 2>,
|
x: NdArrayTensor<E, 2>,
|
||||||
padding: [usize; 2],
|
padding: [usize; 2],
|
||||||
) -> NdArrayTensor<E, 2> {
|
) -> NdArrayTensor<E, 2> {
|
||||||
|
|
|
@ -3,9 +3,9 @@ use alloc::vec::Vec;
|
||||||
use core::cmp::Ordering;
|
use core::cmp::Ordering;
|
||||||
use core::ops::Range;
|
use core::ops::Range;
|
||||||
|
|
||||||
|
use crate::element::FloatNdArrayElement;
|
||||||
// Current crate
|
// Current crate
|
||||||
use crate::tensor::BatchMatrix;
|
use crate::{tensor::NdArrayTensor, NdArrayBackend};
|
||||||
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
|
|
||||||
use crate::{NdArrayDevice, SEED};
|
use crate::{NdArrayDevice, SEED};
|
||||||
|
|
||||||
// Workspace crates
|
// Workspace crates
|
||||||
|
@ -16,9 +16,9 @@ use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Sha
|
||||||
// External crates
|
// External crates
|
||||||
use libm::{cos, erf, sin, tanh};
|
use libm::{cos, erf, sin, tanh};
|
||||||
|
|
||||||
use super::{NdArrayMathOps, NdArrayOps};
|
use super::{matmul::matmul, NdArrayMathOps, NdArrayOps};
|
||||||
|
|
||||||
impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
impl<E: FloatNdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
||||||
fn from_data<const D: usize>(data: Data<E, D>, _device: &NdArrayDevice) -> NdArrayTensor<E, D> {
|
fn from_data<const D: usize>(data: Data<E, D>, _device: &NdArrayDevice) -> NdArrayTensor<E, D> {
|
||||||
NdArrayTensor::from_data(data)
|
NdArrayTensor::from_data(data)
|
||||||
}
|
}
|
||||||
|
@ -124,11 +124,7 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
||||||
lhs: NdArrayTensor<E, D>,
|
lhs: NdArrayTensor<E, D>,
|
||||||
rhs: NdArrayTensor<E, D>,
|
rhs: NdArrayTensor<E, D>,
|
||||||
) -> NdArrayTensor<E, D> {
|
) -> NdArrayTensor<E, D> {
|
||||||
let batch_self = BatchMatrix::from_ndarray(lhs.array.clone(), lhs.shape());
|
matmul(lhs, rhs)
|
||||||
let batch_other = BatchMatrix::from_ndarray(rhs.array.clone(), rhs.shape());
|
|
||||||
let output = batch_self.matmul(batch_other);
|
|
||||||
|
|
||||||
NdArrayTensor::from_bmatrix(output)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn neg<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
|
fn neg<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
|
||||||
|
@ -392,7 +388,7 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn arg<E: NdArrayElement, F, const D: usize>(
|
fn arg<E: FloatNdArrayElement, F, const D: usize>(
|
||||||
tensor: NdArrayTensor<E, D>,
|
tensor: NdArrayTensor<E, D>,
|
||||||
dim: usize,
|
dim: usize,
|
||||||
cmp: F,
|
cmp: F,
|
||||||
|
|
|
@ -0,0 +1,19 @@
|
||||||
|
use core::cell::UnsafeCell;
|
||||||
|
|
||||||
|
/// Wrapper that lets one exclusive reference be handed out through a shared `&self`.
///
/// Similar to `SyncUnsafeCell`, see [Rust issues](https://github.com/rust-lang/rust/issues/95439).
/// The wrapper performs no synchronization: every call to [`UnsafeSharedRef::get`] returns
/// the same `&'a mut T`, and it is entirely up to the caller to keep the resulting accesses
/// from overlapping.
pub(crate) struct UnsafeSharedRef<'a, T> {
    cell: UnsafeCell<&'a mut T>,
}

// SAFETY: sharing the wrapper across threads hands `&mut T` to those threads, which is only
// acceptable when `T` itself may be sent to another thread. Freedom from data races is part
// of the contract of the `unsafe fn get`, not of this impl.
unsafe impl<'a, T: Send> Sync for UnsafeSharedRef<'a, T> {}

impl<'a, T> UnsafeSharedRef<'a, T> {
    /// Wraps the exclusive reference `data`.
    pub fn new(data: &'a mut T) -> Self {
        Self {
            cell: UnsafeCell::new(data),
        }
    }

    /// Returns a copy of the wrapped mutable reference.
    ///
    /// # Safety
    ///
    /// Each call yields another `&'a mut T` to the same value. The caller must guarantee
    /// that the references obtained this way are never used to access overlapping data at
    /// the same time (for example, each user only touches its own disjoint region).
    pub unsafe fn get(&self) -> &'a mut T {
        // SAFETY: the cell always holds the valid reference stored by `new`; reading it
        // bitwise duplicates that reference, which the caller promised to use without
        // aliasing conflicts.
        unsafe { self.cell.get().read() }
    }
}
|
|
@ -1,12 +1,8 @@
|
||||||
use alloc::vec::Vec;
|
|
||||||
|
|
||||||
use super::element::NdArrayElement;
|
|
||||||
|
|
||||||
use burn_tensor::{Data, Shape};
|
use burn_tensor::{Data, Shape};
|
||||||
|
|
||||||
use ndarray::{s, ArcArray, Array, Axis, Dim, Ix2, Ix3, IxDyn};
|
use ndarray::{ArcArray, Array, Dim, IxDyn};
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(new, Debug, Clone)]
|
||||||
pub struct NdArrayTensor<E, const D: usize> {
|
pub struct NdArrayTensor<E, const D: usize> {
|
||||||
pub array: ArcArray<E, IxDyn>,
|
pub array: ArcArray<E, IxDyn>,
|
||||||
}
|
}
|
||||||
|
@ -20,7 +16,7 @@ impl<E, const D: usize> NdArrayTensor<E, D> {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod utils {
|
mod utils {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::NdArrayBackend;
|
use crate::{element::FloatNdArrayElement, NdArrayBackend};
|
||||||
use burn_tensor::ops::TensorOps;
|
use burn_tensor::ops::TensorOps;
|
||||||
|
|
||||||
impl<E, const D: usize> NdArrayTensor<E, D>
|
impl<E, const D: usize> NdArrayTensor<E, D>
|
||||||
|
@ -29,105 +25,13 @@ mod utils {
|
||||||
{
|
{
|
||||||
pub(crate) fn into_data(self) -> Data<E, D>
|
pub(crate) fn into_data(self) -> Data<E, D>
|
||||||
where
|
where
|
||||||
E: NdArrayElement,
|
E: FloatNdArrayElement,
|
||||||
{
|
{
|
||||||
<NdArrayBackend<E> as TensorOps<NdArrayBackend<E>>>::into_data::<D>(self)
|
<NdArrayBackend<E> as TensorOps<NdArrayBackend<E>>>::into_data::<D>(self)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(new)]
|
|
||||||
pub(crate) struct BatchMatrix<E, const D: usize> {
|
|
||||||
pub arrays: Vec<ArcArray<E, Ix2>>,
|
|
||||||
pub shape: Shape<D>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<E, const D: usize> BatchMatrix<E, D>
|
|
||||||
where
|
|
||||||
E: NdArrayElement,
|
|
||||||
{
|
|
||||||
pub fn from_ndarray(array: ArcArray<E, IxDyn>, shape: Shape<D>) -> Self {
|
|
||||||
let mut arrays = Vec::new();
|
|
||||||
if D < 2 {
|
|
||||||
let array = array.reshape((1, shape.dims[0]));
|
|
||||||
arrays.push(array);
|
|
||||||
} else {
|
|
||||||
let batch_size = batch_size(&shape);
|
|
||||||
let size0 = shape.dims[D - 2];
|
|
||||||
let size1 = shape.dims[D - 1];
|
|
||||||
let array_global = array.reshape((batch_size, size0, size1));
|
|
||||||
for b in 0..batch_size {
|
|
||||||
let array = array_global.slice(s!(b, .., ..));
|
|
||||||
let array = array.into_owned().into_shared();
|
|
||||||
arrays.push(array);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Self { arrays, shape }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn matmul(self, other: BatchMatrix<E, D>) -> Self {
|
|
||||||
let require_broadcast = self.arrays.len() != other.arrays.len();
|
|
||||||
if require_broadcast {
|
|
||||||
return self.matmul_broadcast(other);
|
|
||||||
}
|
|
||||||
|
|
||||||
let self_iter = self.arrays.iter();
|
|
||||||
let other_iter = other.arrays.iter();
|
|
||||||
|
|
||||||
let arrays = self_iter
|
|
||||||
.zip(other_iter)
|
|
||||||
.map(|(lhs, rhs)| lhs.dot(rhs))
|
|
||||||
.map(|output| output.into_shared())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let mut shape = self.shape;
|
|
||||||
shape.dims[D - 1] = other.shape.dims[D - 1];
|
|
||||||
|
|
||||||
Self::new(arrays, shape)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn matmul_broadcast(self, other: BatchMatrix<E, D>) -> Self {
|
|
||||||
let valid_broadcast = self.arrays.len() == 1 || other.arrays.len() == 1;
|
|
||||||
if !valid_broadcast {
|
|
||||||
panic!("Invalid broadcast => {:?} , {:?}", self.shape, other.shape);
|
|
||||||
}
|
|
||||||
let batch_size = usize::max(self.arrays.len(), other.arrays.len());
|
|
||||||
let mut arrays = Vec::with_capacity(batch_size);
|
|
||||||
|
|
||||||
for batch in 0..batch_size {
|
|
||||||
let self_tensor = if self.arrays.len() == 1 {
|
|
||||||
&self.arrays[0]
|
|
||||||
} else {
|
|
||||||
&self.arrays[batch]
|
|
||||||
};
|
|
||||||
|
|
||||||
let other_tensor = if other.arrays.len() == 1 {
|
|
||||||
&other.arrays[0]
|
|
||||||
} else {
|
|
||||||
&other.arrays[batch]
|
|
||||||
};
|
|
||||||
|
|
||||||
let tensor = self_tensor.dot(other_tensor);
|
|
||||||
arrays.push(tensor.into_shared());
|
|
||||||
}
|
|
||||||
|
|
||||||
let mut shape = self.shape;
|
|
||||||
shape.dims[D - 1] = other.shape.dims[D - 1];
|
|
||||||
|
|
||||||
Self::new(arrays, shape)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
fn batch_size<const D: usize>(shape: &Shape<D>) -> usize {
|
|
||||||
let mut num_batch = 1;
|
|
||||||
for i in 0..D - 2 {
|
|
||||||
num_batch *= shape.dims[i];
|
|
||||||
}
|
|
||||||
|
|
||||||
num_batch
|
|
||||||
}
|
|
||||||
|
|
||||||
#[macro_export(local_inner_macros)]
|
#[macro_export(local_inner_macros)]
|
||||||
macro_rules! to_typed_dims {
|
macro_rules! to_typed_dims {
|
||||||
(
|
(
|
||||||
|
@ -145,60 +49,45 @@ macro_rules! to_typed_dims {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[macro_export(local_inner_macros)]
|
#[macro_export(local_inner_macros)]
|
||||||
macro_rules! to_nd_array_tensor {
|
macro_rules! reshape {
|
||||||
(
|
(
|
||||||
$n:expr,
|
ty $ty:ty,
|
||||||
$shape:expr,
|
n $n:expr,
|
||||||
$array:expr
|
shape $shape:expr,
|
||||||
|
array $array:expr
|
||||||
) => {{
|
) => {{
|
||||||
let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
|
let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
|
||||||
let array: ndarray::ArcArray<E, Dim<[usize; $n]>> = $array.reshape(dim);
|
let safe_into_shape =
|
||||||
let array = array.into_dyn();
|
$array.is_standard_layout() || $array.raw_view().reversed_axes().is_standard_layout();
|
||||||
|
let array: ndarray::ArcArray<$ty, Dim<[usize; $n]>> = match safe_into_shape {
|
||||||
NdArrayTensor { array }
|
true => $array
|
||||||
}};
|
.into_shape(dim)
|
||||||
(
|
.expect("Safe to change shape without relayout")
|
||||||
bool,
|
.into_shared(),
|
||||||
$n:expr,
|
false => $array.reshape(dim),
|
||||||
$shape:expr,
|
|
||||||
$array:expr
|
|
||||||
) => {{
|
|
||||||
let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
|
|
||||||
let array: ndarray::ArcArray<bool, Dim<[usize; $n]>> = $array.reshape(dim);
|
|
||||||
let array = array.into_dyn();
|
|
||||||
|
|
||||||
NdArrayTensor { array }
|
|
||||||
}};
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<E, const D: usize> NdArrayTensor<E, D>
|
|
||||||
where
|
|
||||||
E: Default + Clone,
|
|
||||||
{
|
|
||||||
pub(crate) fn from_bmatrix(bmatrix: BatchMatrix<E, D>) -> NdArrayTensor<E, D> {
|
|
||||||
let shape = bmatrix.shape.clone();
|
|
||||||
let to_array = |data: BatchMatrix<E, D>| {
|
|
||||||
let dims = data.shape.dims;
|
|
||||||
let mut array: Array<E, Ix3> = Array::default((0, dims[D - 2], dims[D - 1]));
|
|
||||||
|
|
||||||
for item in data.arrays {
|
|
||||||
array.push(Axis(0), item.view()).unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
array.into_shared()
|
|
||||||
};
|
};
|
||||||
|
let array = array.into_dyn();
|
||||||
|
|
||||||
match D {
|
NdArrayTensor::new(array)
|
||||||
1 => to_nd_array_tensor!(1, shape, to_array(bmatrix)),
|
}};
|
||||||
2 => to_nd_array_tensor!(2, shape, to_array(bmatrix)),
|
(
|
||||||
3 => to_nd_array_tensor!(3, shape, to_array(bmatrix)),
|
ty $ty:ty,
|
||||||
4 => to_nd_array_tensor!(4, shape, to_array(bmatrix)),
|
shape $shape:expr,
|
||||||
5 => to_nd_array_tensor!(5, shape, to_array(bmatrix)),
|
array $array:expr,
|
||||||
6 => to_nd_array_tensor!(6, shape, to_array(bmatrix)),
|
d $D:expr
|
||||||
_ => panic!(""),
|
) => {{
|
||||||
}
|
match $D {
|
||||||
|
1 => reshape!(ty $ty, n 1, shape $shape, array $array),
|
||||||
|
2 => reshape!(ty $ty, n 2, shape $shape, array $array),
|
||||||
|
3 => reshape!(ty $ty, n 3, shape $shape, array $array),
|
||||||
|
4 => reshape!(ty $ty, n 4, shape $shape, array $array),
|
||||||
|
5 => reshape!(ty $ty, n 5, shape $shape, array $array),
|
||||||
|
6 => reshape!(ty $ty, n 6, shape $shape, array $array),
|
||||||
|
_ => core::panic!("NdArray supports arrays up to 6 dimensions, received: {}", $D),
|
||||||
}
|
}
|
||||||
|
}};
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<E, const D: usize> NdArrayTensor<E, D>
|
impl<E, const D: usize> NdArrayTensor<E, D>
|
||||||
where
|
where
|
||||||
E: Default + Clone,
|
E: Default + Clone,
|
||||||
|
@ -206,16 +95,14 @@ where
|
||||||
pub fn from_data(data: Data<E, D>) -> NdArrayTensor<E, D> {
|
pub fn from_data(data: Data<E, D>) -> NdArrayTensor<E, D> {
|
||||||
let shape = data.shape.clone();
|
let shape = data.shape.clone();
|
||||||
let to_array = |data: Data<E, D>| Array::from_iter(data.value.into_iter()).into_shared();
|
let to_array = |data: Data<E, D>| Array::from_iter(data.value.into_iter()).into_shared();
|
||||||
|
let array = to_array(data);
|
||||||
|
|
||||||
match D {
|
reshape!(
|
||||||
1 => to_nd_array_tensor!(1, shape, to_array(data)),
|
ty E,
|
||||||
2 => to_nd_array_tensor!(2, shape, to_array(data)),
|
shape shape,
|
||||||
3 => to_nd_array_tensor!(3, shape, to_array(data)),
|
array array,
|
||||||
4 => to_nd_array_tensor!(4, shape, to_array(data)),
|
d D
|
||||||
5 => to_nd_array_tensor!(5, shape, to_array(data)),
|
)
|
||||||
6 => to_nd_array_tensor!(6, shape, to_array(data)),
|
|
||||||
_ => panic!(""),
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -32,4 +32,19 @@ mod tests {
|
||||||
Data::from([[[18.0, 28.0], [14.0, 23.0]]])
|
Data::from([[[18.0, 28.0], [14.0, 23.0]]])
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// A left operand with a single batch entry is multiplied against both batch entries of the
/// right operand, so the result must contain two identical products.
#[test]
fn test_matmul_broadcast_1() {
    let lhs = Tensor::<TestBackend, 3>::from_data(Data::from([[[1.0, 7.0], [2.0, 3.0]]]));
    let rhs = Tensor::<TestBackend, 3>::from_data(Data::from([
        [[4.0, 7.0], [2.0, 3.0]],
        [[4.0, 7.0], [2.0, 3.0]],
    ]));

    let product = lhs.matmul(rhs);

    assert_eq!(
        product.into_data(),
        Data::from([[[18.0, 28.0], [14.0, 23.0]], [[18.0, 28.0], [14.0, 23.0]]])
    );
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue