Add more ops (#387)

Nathaniel Simard 2023-06-06 12:21:20 -04:00 committed by GitHub
parent ecc67c58f9
commit c1e1e38a79
12 changed files with 240 additions and 72 deletions

View File

@@ -1,7 +1,7 @@
use alloc::vec::Vec;
use core::ops::Range;
use crate::{backend::Backend, tensor::Shape, Data};
use crate::{backend::Backend, tensor::Shape, Data, ElementConversion};
/// Int Tensor API for basic and numeric operations, see [tensor](crate::Tensor)
/// for documentation on each function.
@@ -172,7 +172,9 @@ pub trait IntTensorOps<B: Backend> {
lhs: B::IntTensorPrimitive<D>,
rhs: B::IntElem,
) -> B::IntTensorPrimitive<D>;
fn int_neg<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<D>;
fn int_neg<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<D> {
Self::int_mul_scalar(tensor, (-1.0).elem::<B::IntElem>())
}
fn int_zeros<const D: usize>(shape: Shape<D>, device: &B::Device) -> B::IntTensorPrimitive<D>;
fn int_ones<const D: usize>(shape: Shape<D>, device: &B::Device) -> B::IntTensorPrimitive<D>;
fn int_sum<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<1>;
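With this change, int_neg becomes a provided trait method: any backend that implements int_mul_scalar gets negation for free via multiplication by −1, and can still override the default with a dedicated kernel. A minimal sketch of the same trait-default pattern, using a hypothetical SimpleIntOps trait rather than the real IntTensorOps:

// Hypothetical trait illustrating the default-method pattern above.
trait SimpleIntOps {
    fn mul_scalar(tensor: Vec<i64>, scalar: i64) -> Vec<i64>;

    // Provided method: implementors inherit it or override it.
    fn neg(tensor: Vec<i64>) -> Vec<i64> {
        Self::mul_scalar(tensor, -1)
    }
}

struct CpuBackend;

impl SimpleIntOps for CpuBackend {
    fn mul_scalar(tensor: Vec<i64>, scalar: i64) -> Vec<i64> {
        tensor.into_iter().map(|x| x * scalar).collect()
    }
}

// CpuBackend::neg(vec![1, -2, 3]) yields vec![-1, 2, -3].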

View File

@@ -32,6 +32,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_erf!();
burn_tensor::testgen_exp!();
burn_tensor::testgen_log!();
burn_tensor::testgen_sqrt!();
burn_tensor::testgen_log1p!();
burn_tensor::testgen_index!();
burn_tensor::testgen_gather_scatter!();

View File

@@ -13,4 +13,18 @@ mod tests {
let data_expected = Data::from([[0.0000, 0.8427, 0.9953], [1.0000, 1.0000, 1.0000]]);
data_expected.assert_approx_eq(&data_actual, 3);
}
#[test]
fn should_support_erf_ops_with_negative_number() {
let data = Data::from([[-0.056, -0.043, -0.089], [3.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data);
let data_actual = tensor.erf().into_data();
let data_expected = Data::from([
[-0.06312324, -0.048490416, -0.10016122],
[1.0000, 1.0000, 1.0000],
]);
data_expected.assert_approx_eq(&data_actual, 3);
}
}
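The expected values in the negative case follow from erf being odd: erf(−x) = −erf(x), so erf(−0.056) ≈ −0.0631.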

View File

@@ -21,6 +21,7 @@ mod powf;
mod repeat;
mod reshape;
mod sin;
mod sqrt;
mod sub;
mod tanh;
mod transpose;

View File

@@ -0,0 +1,16 @@
#[burn_tensor_testgen::testgen(sqrt)]
mod tests {
use super::*;
use burn_tensor::{Data, Tensor};
#[test]
fn should_support_sqrt_ops() {
let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data);
let data_actual = tensor.sqrt().into_data();
let data_expected = Data::from([[0.0, 1.0, 1.41421], [1.73205, 2.0, 2.2360]]);
data_expected.assert_approx_eq(&data_actual, 3);
}
}

View File

@@ -38,6 +38,28 @@ macro_rules! unary {
}
}
};
(
$struct:ident,
func $func:expr,
include $file:expr
) => {
pub struct $struct;
impl $crate::kernel::KernelGenerator for $struct {
type Source = String;
fn generate() -> Self::Source {
$crate::kernel_wgsl!(Include, $file);
let source = $crate::kernel::UnaryRaw::generate().to_string();
let body = format!("output[global_id.x] = {}(input[global_id.x]);", $func);
let source = source.replace("BODY", &body);
let included: &str = Include::generate().as_ref();
format!("{}\n{}", included, source)
}
}
};
}
#[macro_export]
@@ -73,6 +95,28 @@ macro_rules! unary_inplace {
}
}
};
(
$struct:ident,
func $func:expr,
include $file:expr
) => {
pub struct $struct;
impl $crate::kernel::KernelGenerator for $struct {
type Source = String;
fn generate() -> Self::Source {
$crate::kernel_wgsl!(Include, $file);
let source = $crate::kernel::UnaryInplaceRaw::generate().to_string();
let body = format!("input[global_id.x] = {}(input[global_id.x]);", $func);
let source = source.replace("BODY", &body);
let included: &str = Include::generate().as_ref();
format!("{}\n{}", included, source)
}
}
};
}
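Both macros gain a third arm, func $func, include $file: instead of splicing a raw WGSL body, it prepends an included WGSL snippet to the generated template and emits a one-line body that calls the named function elementwise. A hedged sketch of a call and the rough shape of its output, assuming the UnaryRaw template carries the BODY placeholder used above:

// Mirrors the erf implementation later in this commit.
unary!(Erf, func "erf", include "../template/erf.wgsl");

// Erf::generate() then returns, roughly:
//   <contents of erf.wgsl>
//   <the unary template, with BODY replaced by
//    output[global_id.x] = erf(input[global_id.x]);>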
pub fn unary<K: KernelGenerator, E: WGPUElement, const D: usize>(

View File

@@ -34,6 +34,12 @@ mod tests {
burn_tensor::testgen_powf!();
burn_tensor::testgen_exp!();
burn_tensor::testgen_log!();
burn_tensor::testgen_log1p!();
burn_tensor::testgen_sqrt!();
burn_tensor::testgen_cos!();
burn_tensor::testgen_sin!();
burn_tensor::testgen_tanh!();
burn_tensor::testgen_erf!();
burn_tensor::testgen_relu!();
burn_tensor::testgen_matmul!();
burn_tensor::testgen_reshape!();
@@ -50,6 +56,11 @@ mod tests {
burn_autodiff::testgen_ad_powf!();
burn_autodiff::testgen_ad_exp!();
burn_autodiff::testgen_ad_log!();
burn_autodiff::testgen_ad_log1p!();
burn_autodiff::testgen_ad_sqrt!();
burn_autodiff::testgen_ad_cos!();
burn_autodiff::testgen_ad_sin!();
burn_autodiff::testgen_ad_tanh!();
burn_autodiff::testgen_ad_matmul!();
burn_autodiff::testgen_ad_reshape!();
burn_autodiff::testgen_ad_transpose!();

View File

@@ -48,10 +48,8 @@ where
todo!()
}
fn bool_device<const D: usize>(
_tensor: &<WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::Device {
todo!()
fn bool_device<const D: usize>(tensor: &BoolTensor<Self, D>) -> Device<Self> {
tensor.context.device.clone()
}
fn bool_to_device<const D: usize>(

View File

@@ -70,56 +70,64 @@ where
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
NumericOps::add(lhs, rhs)
NumericOps::<G>::add(lhs, rhs)
}
fn add_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
NumericOps::add_scalar(lhs, rhs)
NumericOps::<G>::add_scalar(lhs, rhs)
}
fn zeros<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
NumericOps::<G>::zeros(shape, device)
}
fn ones<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
NumericOps::<G>::ones(shape, device)
}
fn sub<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
NumericOps::sub(lhs, rhs)
NumericOps::<G>::sub(lhs, rhs)
}
fn sub_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
NumericOps::sub_scalar(lhs, rhs)
NumericOps::<G>::sub_scalar(lhs, rhs)
}
fn mul<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
NumericOps::mul(lhs, rhs)
NumericOps::<G>::mul(lhs, rhs)
}
fn mul_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
NumericOps::mul_scalar(lhs, rhs)
NumericOps::<G>::mul_scalar(lhs, rhs)
}
fn div<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatTensor<Self, D>,
) -> FloatTensor<Self, D> {
NumericOps::div(lhs, rhs)
NumericOps::<G>::div(lhs, rhs)
}
fn div_scalar<const D: usize>(
lhs: FloatTensor<Self, D>,
rhs: FloatElem<Self>,
) -> FloatTensor<Self, D> {
NumericOps::div_scalar(lhs, rhs)
NumericOps::<G>::div_scalar(lhs, rhs)
}
fn matmul<const D: usize>(
@@ -343,10 +351,15 @@ where
unary::<Log, F, D>(tensor)
}
fn log1p<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
todo!()
fn log1p<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(Log1p, body "output[global_id.x] = log(1.0 + input[global_id.x]);");
unary_inplace!(Log1pInplace, body "input[global_id.x] = log(1.0 + input[global_id.x]);");
if tensor.can_mut() {
return unary_inplace::<Log1pInplace, F, D>(tensor);
}
unary::<Log1p, F, D>(tensor)
}
fn powf<const D: usize>(lhs: FloatTensor<Self, D>, rhs: f32) -> FloatTensor<Self, D> {
@@ -360,34 +373,59 @@ where
unary_scalar::<Powf, F, D>(lhs, rhs.elem())
}
fn sqrt<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
todo!()
fn sqrt<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(Sqrt, func "sqrt");
unary_inplace!(SqrtInplace, func "sqrt");
if tensor.can_mut() {
return unary_inplace::<SqrtInplace, F, D>(tensor);
}
unary::<Sqrt, F, D>(tensor)
}
fn cos<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
todo!()
fn cos<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(Cos, func "cos");
unary_inplace!(CosInplace, func "cos");
if tensor.can_mut() {
return unary_inplace::<CosInplace, F, D>(tensor);
}
unary::<Cos, F, D>(tensor)
}
fn sin<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
todo!()
fn sin<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(Sin, func "sin");
unary_inplace!(SinInplace, func "sin");
if tensor.can_mut() {
return unary_inplace::<SinInplace, F, D>(tensor);
}
unary::<Sin, F, D>(tensor)
}
fn tanh<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
todo!()
fn tanh<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(Tanh, func "tanh");
unary_inplace!(TanhInplace, func "tanh");
if tensor.can_mut() {
return unary_inplace::<TanhInplace, F, D>(tensor);
}
unary::<Tanh, F, D>(tensor)
}
fn erf<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
todo!()
fn erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(Erf, func "erf", include "../template/erf.wgsl");
unary_inplace!(ErfInplace, func "erf", include "../template/erf.wgsl");
if tensor.can_mut() {
return unary_inplace::<ErfInplace, F, D>(tensor);
}
unary::<Erf, F, D>(tensor)
}
fn cat<const D: usize>(
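Every op implemented above follows the same dispatch: declare an allocating kernel and an in-place twin, then reuse the input buffer when the tensor uniquely owns it. A sketch of the pattern for a hypothetical op foo (F and D are the surrounding impl's float element and rank; WGSL has no built-in foo, this only shows the shape):

fn foo<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
    unary!(Foo, func "foo");
    unary_inplace!(FooInplace, func "foo");

    // can_mut() is true when this handle is the sole owner of its buffer,
    // so mutating in place saves one allocation and one copy.
    if tensor.can_mut() {
        return unary_inplace::<FooInplace, F, D>(tensor);
    }
    unary::<Foo, F, D>(tensor)
}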

View File

@@ -17,10 +17,8 @@ where
BaseOps::<G>::empty(shape, device)
}
fn int_shape<const D: usize>(
_tensor: &<WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
) -> Shape<D> {
todo!()
fn int_shape<const D: usize>(tensor: &IntTensor<Self, D>) -> Shape<D> {
tensor.shape.clone()
}
fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> Data<I, D> {
@@ -34,10 +32,8 @@ where
BaseOps::<G>::from_data(data, device)
}
fn int_device<const D: usize>(
_tensor: &<WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::Device {
todo!()
fn int_device<const D: usize>(tensor: &IntTensor<Self, D>) -> Device<Self> {
tensor.context.device.clone()
}
fn int_to_device<const D: usize>(
@@ -200,76 +196,64 @@ where
lhs: IntTensor<Self, D>,
rhs: IntTensor<Self, D>,
) -> IntTensor<Self, D> {
NumericOps::add::<I, D>(lhs, rhs)
NumericOps::<G>::add::<I, D>(lhs, rhs)
}
fn int_add_scalar<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntElem<Self>,
) -> IntTensor<Self, D> {
NumericOps::add_scalar(lhs, rhs)
NumericOps::<G>::add_scalar(lhs, rhs)
}
fn int_sub<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntTensor<Self, D>,
) -> IntTensor<Self, D> {
NumericOps::sub(lhs, rhs)
NumericOps::<G>::sub(lhs, rhs)
}
fn int_sub_scalar<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntElem<Self>,
) -> IntTensor<Self, D> {
NumericOps::sub_scalar(lhs, rhs)
NumericOps::<G>::sub_scalar(lhs, rhs)
}
fn int_mul<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntTensor<Self, D>,
) -> IntTensor<Self, D> {
NumericOps::mul(lhs, rhs)
NumericOps::<G>::mul(lhs, rhs)
}
fn int_mul_scalar<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntElem<Self>,
) -> IntTensor<Self, D> {
NumericOps::mul_scalar(lhs, rhs)
NumericOps::<G>::mul_scalar(lhs, rhs)
}
fn int_div<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntTensor<Self, D>,
) -> IntTensor<Self, D> {
NumericOps::div(lhs, rhs)
NumericOps::<G>::div(lhs, rhs)
}
fn int_div_scalar<const D: usize>(
lhs: IntTensor<Self, D>,
rhs: IntElem<Self>,
) -> IntTensor<Self, D> {
NumericOps::div_scalar(lhs, rhs)
NumericOps::<G>::div_scalar(lhs, rhs)
}
fn int_neg<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
todo!()
fn int_zeros<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
NumericOps::<G>::zeros(shape, device)
}
fn int_zeros<const D: usize>(
_shape: Shape<D>,
_device: &<WGPUBackend<G, F, I> as Backend>::Device,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
todo!()
}
fn int_ones<const D: usize>(
_shape: Shape<D>,
_device: &<WGPUBackend<G, F, I> as Backend>::Device,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
todo!()
fn int_ones<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
NumericOps::<G>::ones(shape, device)
}
fn int_sum<const D: usize>(
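The integer ops now route through the same generic NumericOps::<G> helpers as the float ops, so both element types share one set of elementwise kernels, and int_zeros/int_ones lose their todo!() stubs in favor of the shared implementations below.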

View File

@@ -1,12 +1,46 @@
use std::marker::PhantomData;
use std::sync::Arc;
use burn_tensor::{Element, ElementConversion, Shape};
use crate::kernel::{binary_elemwise, binary_elemwise_inplace, unary_scalar, unary_scalar_inplace};
use crate::pool::get_context;
use crate::{
binary_elemwise, binary_elemwise_inplace, element::WGPUElement, tensor::WgpuTensor,
unary_scalar, unary_scalar_inplace,
};
use crate::{GraphicsApi, WgpuDevice};
pub struct NumericOps;
pub struct NumericOps<G: GraphicsApi> {
_g: PhantomData<G>,
}
impl<G: GraphicsApi> NumericOps<G> {
pub fn zeros<E: WGPUElement, const D: usize>(
shape: Shape<D>,
device: &WgpuDevice,
) -> WgpuTensor<E, D> {
let context = get_context::<G>(device);
let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::<E>());
WgpuTensor::new(context, shape, Arc::new(buffer))
}
pub fn ones<E: WGPUElement + Element, const D: usize>(
shape: Shape<D>,
device: &WgpuDevice,
) -> WgpuTensor<E, D> {
let context = get_context::<G>(device);
let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::<E>());
Self::add_scalar(
WgpuTensor::new(context, shape, Arc::new(buffer)),
1i32.elem::<E>(),
)
}
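zeros only allocates, relying on wgpu buffers being zero-initialized per the WebGPU spec; ones then reuses that and adds the scalar 1 elementwise. A hedged usage sketch, assuming AutoGraphicsApi and a default WgpuDevice are available as elsewhere in burn-wgpu (both names are assumptions here, not part of this diff):

// Assumed: AutoGraphicsApi and WgpuDevice::default() exist in this crate.
let device = WgpuDevice::default();
let zeros = NumericOps::<AutoGraphicsApi>::zeros::<f32, 2>(Shape::new([2, 3]), &device);
let ones = NumericOps::<AutoGraphicsApi>::ones::<f32, 2>(Shape::new([2, 3]), &device);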
impl NumericOps {
pub fn add<E: WGPUElement, const D: usize>(
lhs: WgpuTensor<E, D>,
rhs: WgpuTensor<E, D>,

View File

@@ -0,0 +1,25 @@
/// An approximation of the error function: https://en.wikipedia.org/wiki/Error_function#Numerical_approximations
///
/// > (maximum error: 1.5×10⁻⁷)
/// > All of these approximations are valid for x ≥ 0. To use these approximations for negative x, use the fact that erf is an odd function, so erf(−x) = −erf(x).
fn erf_positive(x: elem) -> elem {
let p = 0.3275911;
let a1 = 0.254829592;
let a2 = -0.284496736;
let a3 = 1.421413741;
let a4 = -1.453152027;
let a5 = 1.061405429;
let t = 1.0 / (1.0 + p * abs(x));
let tmp = ((((a5 * t + a4) * t) + a3) * t + a2) * t + a1;
return 1.0 - (tmp * t * exp(-x * x));
}
fn erf(x: elem) -> elem {
if (x < 0.0) {
return -1.0 * erf_positive(-1.0 * x);
}
return erf_positive(x);
}