mirror of https://github.com/tracel-ai/burn.git
Add more ops (#387)
commit c1e1e38a79
parent ecc67c58f9
@@ -1,7 +1,7 @@
 use alloc::vec::Vec;
 use core::ops::Range;
 
-use crate::{backend::Backend, tensor::Shape, Data};
+use crate::{backend::Backend, tensor::Shape, Data, ElementConversion};
 
 /// Int Tensor API for basic and numeric operations, see [tensor](crate::Tensor)
 /// for documentation on each function.
@@ -172,7 +172,9 @@ pub trait IntTensorOps<B: Backend> {
         lhs: B::IntTensorPrimitive<D>,
         rhs: B::IntElem,
     ) -> B::IntTensorPrimitive<D>;
-    fn int_neg<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<D>;
+    fn int_neg<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<D> {
+        Self::int_mul_scalar(tensor, (-1.0).elem::<B::IntElem>())
+    }
     fn int_zeros<const D: usize>(shape: Shape<D>, device: &B::Device) -> B::IntTensorPrimitive<D>;
     fn int_ones<const D: usize>(shape: Shape<D>, device: &B::Device) -> B::IntTensorPrimitive<D>;
     fn int_sum<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<1>;
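The hunk above turns int_neg from a required method into a provided one: its default body delegates to int_mul_scalar with a scalar of -1, so a backend only has to implement int_mul_scalar to get negation for free. A minimal standalone sketch of that trait-default pattern (illustrative names, not Burn's actual trait):

// Sketch only: implement `mul_scalar` and `neg` comes for free, like `int_neg` above.
trait NumericBackend {
    type Tensor;

    fn mul_scalar(tensor: Self::Tensor, scalar: f64) -> Self::Tensor;

    // Provided method: negation is multiplication by -1.
    fn neg(tensor: Self::Tensor) -> Self::Tensor {
        Self::mul_scalar(tensor, -1.0)
    }
}

struct CpuBackend;

impl NumericBackend for CpuBackend {
    type Tensor = Vec<f64>;

    fn mul_scalar(tensor: Self::Tensor, scalar: f64) -> Self::Tensor {
        tensor.into_iter().map(|x| x * scalar).collect()
    }
}

fn main() {
    // `neg` is never implemented by CpuBackend; the trait default is used.
    assert_eq!(CpuBackend::neg(vec![1.0, -2.0]), vec![-1.0, 2.0]);
}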
@@ -32,6 +32,7 @@ macro_rules! testgen_all {
         burn_tensor::testgen_erf!();
         burn_tensor::testgen_exp!();
         burn_tensor::testgen_log!();
+        burn_tensor::testgen_sqrt!();
         burn_tensor::testgen_log1p!();
         burn_tensor::testgen_index!();
         burn_tensor::testgen_gather_scatter!();
@@ -13,4 +13,18 @@ mod tests {
         let data_expected = Data::from([[0.0000, 0.8427, 0.9953], [1.0000, 1.0000, 1.0000]]);
         data_expected.assert_approx_eq(&data_actual, 3);
     }
+
+    #[test]
+    fn should_support_erf_ops_with_negative_number() {
+        let data = Data::from([[-0.056, -0.043, -0.089], [3.0, 4.0, 5.0]]);
+        let tensor = Tensor::<TestBackend, 2>::from_data(data);
+
+        let data_actual = tensor.erf().into_data();
+
+        let data_expected = Data::from([
+            [-0.06312324, -0.048490416, -0.10016122],
+            [1.0000, 1.0000, 1.0000],
+        ]);
+        data_expected.assert_approx_eq(&data_actual, 3);
+    }
 }
@@ -21,6 +21,7 @@ mod powf;
 mod repeat;
 mod reshape;
 mod sin;
+mod sqrt;
 mod sub;
 mod tanh;
 mod transpose;
@@ -0,0 +1,16 @@
+#[burn_tensor_testgen::testgen(sqrt)]
+mod tests {
+    use super::*;
+    use burn_tensor::{Data, Tensor};
+
+    #[test]
+    fn should_support_sqrt_ops() {
+        let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
+        let tensor = Tensor::<TestBackend, 2>::from_data(data);
+
+        let data_actual = tensor.sqrt().into_data();
+
+        let data_expected = Data::from([[0.0, 1.0, 1.41421], [1.73205, 2.0, 2.2360]]);
+        data_expected.assert_approx_eq(&data_actual, 3);
+    }
+}
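For reference, the expected values in the new sqrt test line up with f32::sqrt; a quick sanity check (an illustration only, not part of the PR):

fn main() {
    // sqrt(2) ≈ 1.41421, sqrt(3) ≈ 1.73205, sqrt(5) ≈ 2.2360, as in the test data above.
    for (x, expected) in [(2.0f32, 1.41421), (3.0, 1.73205), (5.0, 2.2360)] {
        assert!((x.sqrt() - expected).abs() < 1e-3);
    }
}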
@@ -38,6 +38,28 @@ macro_rules! unary {
             }
         }
     };
+    (
+        $struct:ident,
+        func $func:expr,
+        include $file:expr
+    ) => {
+        pub struct $struct;
+
+        impl $crate::kernel::KernelGenerator for $struct {
+            type Source = String;
+
+            fn generate() -> Self::Source {
+                $crate::kernel_wgsl!(Include, $file);
+
+                let source = $crate::kernel::UnaryRaw::generate().to_string();
+                let body = format!("output[global_id.x] = {}(input[global_id.x]);", $func);
+                let source = source.replace("BODY", &body);
+                let included: &str = Include::generate().as_ref();
+
+                format!("{}\n{}", included, source)
+            }
+        }
+    };
 }
 
 #[macro_export]
@@ -73,6 +95,28 @@ macro_rules! unary_inplace {
             }
         }
     };
+    (
+        $struct:ident,
+        func $func:expr,
+        include $file:expr
+    ) => {
+        pub struct $struct;
+
+        impl $crate::kernel::KernelGenerator for $struct {
+            type Source = String;
+
+            fn generate() -> Self::Source {
+                $crate::kernel_wgsl!(Include, $file);
+
+                let source = $crate::kernel::UnaryInplaceRaw::generate().to_string();
+                let body = format!("input[global_id.x] = {}(input[global_id.x]);", $func);
+                let source = source.replace("BODY", &body);
+                let included: &str = Include::generate().as_ref();
+
+                format!("{}\n{}", included, source)
+            }
+        }
+    };
 }
 
 pub fn unary<K: KernelGenerator, E: WGPUElement, const D: usize>(
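Both new macro arms build a WGSL kernel the same way: format a BODY statement from the given function name, splice it into the raw unary template, and prepend the contents of the included file (here the erf helper) so the generated shader can call it. A standalone sketch of that string templating, with made-up stand-ins for the real WGSL templates:

// Sketch only: `template` stands in for UnaryRaw's WGSL source and `included`
// for a helper file such as erf.wgsl; the real macros do this at the call site.
fn generate_unary_source(template: &str, func: &str, included: &str) -> String {
    let body = format!("output[global_id.x] = {}(input[global_id.x]);", func);
    let source = template.replace("BODY", &body);
    format!("{}\n{}", included, source)
}

fn main() {
    let template = "fn kernel() {\n    BODY\n}"; // placeholder template
    let helper = "fn erf(x: f32) -> f32 { return x; } // placeholder helper";
    println!("{}", generate_unary_source(template, "erf", helper));
}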
@@ -34,6 +34,12 @@ mod tests {
     burn_tensor::testgen_powf!();
     burn_tensor::testgen_exp!();
     burn_tensor::testgen_log!();
+    burn_tensor::testgen_log1p!();
+    burn_tensor::testgen_sqrt!();
+    burn_tensor::testgen_cos!();
+    burn_tensor::testgen_sin!();
+    burn_tensor::testgen_tanh!();
+    burn_tensor::testgen_erf!();
     burn_tensor::testgen_relu!();
     burn_tensor::testgen_matmul!();
     burn_tensor::testgen_reshape!();
@@ -50,6 +56,11 @@ mod tests {
     burn_autodiff::testgen_ad_powf!();
     burn_autodiff::testgen_ad_exp!();
     burn_autodiff::testgen_ad_log!();
+    burn_autodiff::testgen_ad_log1p!();
+    burn_autodiff::testgen_ad_sqrt!();
+    burn_autodiff::testgen_ad_cos!();
+    burn_autodiff::testgen_ad_sin!();
+    burn_autodiff::testgen_ad_tanh!();
     burn_autodiff::testgen_ad_matmul!();
     burn_autodiff::testgen_ad_reshape!();
     burn_autodiff::testgen_ad_transpose!();
@@ -48,10 +48,8 @@ where
         todo!()
     }
 
-    fn bool_device<const D: usize>(
-        _tensor: &<WGPUBackend<G, F, I> as Backend>::BoolTensorPrimitive<D>,
-    ) -> <WGPUBackend<G, F, I> as Backend>::Device {
-        todo!()
+    fn bool_device<const D: usize>(tensor: &BoolTensor<Self, D>) -> Device<Self> {
+        tensor.context.device.clone()
     }
 
     fn bool_to_device<const D: usize>(
@@ -70,56 +70,64 @@ where
         lhs: FloatTensor<Self, D>,
         rhs: FloatTensor<Self, D>,
     ) -> FloatTensor<Self, D> {
-        NumericOps::add(lhs, rhs)
+        NumericOps::<G>::add(lhs, rhs)
     }
 
     fn add_scalar<const D: usize>(
         lhs: FloatTensor<Self, D>,
         rhs: FloatElem<Self>,
     ) -> FloatTensor<Self, D> {
-        NumericOps::add_scalar(lhs, rhs)
+        NumericOps::<G>::add_scalar(lhs, rhs)
     }
 
+    fn zeros<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
+        NumericOps::<G>::zeros(shape, device)
+    }
+
+    fn ones<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> FloatTensor<Self, D> {
+        NumericOps::<G>::ones(shape, device)
+    }
+
     fn sub<const D: usize>(
         lhs: FloatTensor<Self, D>,
         rhs: FloatTensor<Self, D>,
     ) -> FloatTensor<Self, D> {
-        NumericOps::sub(lhs, rhs)
+        NumericOps::<G>::sub(lhs, rhs)
    }
 
     fn sub_scalar<const D: usize>(
         lhs: FloatTensor<Self, D>,
         rhs: FloatElem<Self>,
     ) -> FloatTensor<Self, D> {
-        NumericOps::sub_scalar(lhs, rhs)
+        NumericOps::<G>::sub_scalar(lhs, rhs)
     }
 
     fn mul<const D: usize>(
         lhs: FloatTensor<Self, D>,
         rhs: FloatTensor<Self, D>,
     ) -> FloatTensor<Self, D> {
-        NumericOps::mul(lhs, rhs)
+        NumericOps::<G>::mul(lhs, rhs)
     }
 
     fn mul_scalar<const D: usize>(
         lhs: FloatTensor<Self, D>,
         rhs: FloatElem<Self>,
     ) -> FloatTensor<Self, D> {
-        NumericOps::mul_scalar(lhs, rhs)
+        NumericOps::<G>::mul_scalar(lhs, rhs)
     }
 
     fn div<const D: usize>(
         lhs: FloatTensor<Self, D>,
         rhs: FloatTensor<Self, D>,
     ) -> FloatTensor<Self, D> {
-        NumericOps::div(lhs, rhs)
+        NumericOps::<G>::div(lhs, rhs)
     }
 
     fn div_scalar<const D: usize>(
         lhs: FloatTensor<Self, D>,
         rhs: FloatElem<Self>,
     ) -> FloatTensor<Self, D> {
-        NumericOps::div_scalar(lhs, rhs)
+        NumericOps::<G>::div_scalar(lhs, rhs)
     }
 
     fn matmul<const D: usize>(
@@ -343,10 +351,15 @@ where
         unary::<Log, F, D>(tensor)
     }
 
-    fn log1p<const D: usize>(
-        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
-    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
-        todo!()
+    fn log1p<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
+        unary!(Log1p, body "output[global_id.x] = log(1.0 + input[global_id.x]);");
+        unary_inplace!(Log1pInplace, body "input[global_id.x] = log(1.0 + input[global_id.x]);");
+
+        if tensor.can_mut() {
+            return unary_inplace::<Log1pInplace, F, D>(tensor);
+        }
+
+        unary::<Log1p, F, D>(tensor)
     }
 
     fn powf<const D: usize>(lhs: FloatTensor<Self, D>, rhs: f32) -> FloatTensor<Self, D> {
@@ -360,34 +373,59 @@ where
         unary_scalar::<Powf, F, D>(lhs, rhs.elem())
     }
 
-    fn sqrt<const D: usize>(
-        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
-    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
-        todo!()
+    fn sqrt<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
+        unary!(Sqrt, func "sqrt");
+        unary_inplace!(SqrtInplace, func "sqrt");
+
+        if tensor.can_mut() {
+            return unary_inplace::<SqrtInplace, F, D>(tensor);
+        }
+
+        unary::<Sqrt, F, D>(tensor)
     }
 
-    fn cos<const D: usize>(
-        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
-    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
-        todo!()
+    fn cos<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
+        unary!(Cos, func "cos");
+        unary_inplace!(CosInplace, func "cos");
+
+        if tensor.can_mut() {
+            return unary_inplace::<CosInplace, F, D>(tensor);
+        }
+
+        unary::<Cos, F, D>(tensor)
     }
 
-    fn sin<const D: usize>(
-        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
-    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
-        todo!()
+    fn sin<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
+        unary!(Sin, func "sin");
+        unary_inplace!(SinInplace, func "sin");
+
+        if tensor.can_mut() {
+            return unary_inplace::<SinInplace, F, D>(tensor);
+        }
+
+        unary::<Sin, F, D>(tensor)
     }
 
-    fn tanh<const D: usize>(
-        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
-    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
-        todo!()
+    fn tanh<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
+        unary!(Tanh, func "tanh");
+        unary_inplace!(TanhInplace, func "tanh");
+
+        if tensor.can_mut() {
+            return unary_inplace::<TanhInplace, F, D>(tensor);
+        }
+
+        unary::<Tanh, F, D>(tensor)
     }
 
-    fn erf<const D: usize>(
-        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
-    ) -> <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D> {
-        todo!()
+    fn erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
+        unary!(Erf, func "erf", include "../template/erf.wgsl");
+        unary_inplace!(ErfInplace, func "erf", include "../template/erf.wgsl");
+
+        if tensor.can_mut() {
+            return unary_inplace::<ErfInplace, F, D>(tensor);
+        }
+
+        unary::<Erf, F, D>(tensor)
     }
 
     fn cat<const D: usize>(
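Every op implemented above follows the same dispatch shape: declare an out-of-place and an in-place kernel with the macros, reuse the input buffer when can_mut() allows it, and allocate a fresh output otherwise. A simplified CPU-side sketch of that pattern (stand-in types, not the actual WGSL path):

use std::sync::Arc;

// Stand-in tensor: exclusive ownership of the buffer is what makes in-place legal.
struct Tensor {
    data: Arc<Vec<f32>>,
}

impl Tensor {
    fn can_mut(&self) -> bool {
        Arc::strong_count(&self.data) == 1
    }
}

fn sqrt_inplace(mut tensor: Tensor) -> Tensor {
    // Safe because can_mut() guaranteed we are the only owner.
    Arc::get_mut(&mut tensor.data)
        .expect("exclusive buffer")
        .iter_mut()
        .for_each(|x| *x = x.sqrt());
    tensor
}

fn sqrt_new(tensor: Tensor) -> Tensor {
    Tensor {
        data: Arc::new(tensor.data.iter().map(|x| x.sqrt()).collect()),
    }
}

fn sqrt(tensor: Tensor) -> Tensor {
    if tensor.can_mut() {
        return sqrt_inplace(tensor);
    }
    sqrt_new(tensor)
}

fn main() {
    let out = sqrt(Tensor { data: Arc::new(vec![4.0, 9.0]) });
    assert_eq!(*out.data, vec![2.0, 3.0]);
}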
@@ -17,10 +17,8 @@ where
         BaseOps::<G>::empty(shape, device)
     }
 
-    fn int_shape<const D: usize>(
-        _tensor: &<WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
-    ) -> Shape<D> {
-        todo!()
+    fn int_shape<const D: usize>(tensor: &IntTensor<Self, D>) -> Shape<D> {
+        tensor.shape.clone()
     }
 
     fn int_into_data<const D: usize>(tensor: IntTensor<Self, D>) -> Data<I, D> {
@@ -34,10 +32,8 @@ where
         BaseOps::<G>::from_data(data, device)
     }
 
-    fn int_device<const D: usize>(
-        _tensor: &<WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
-    ) -> <WGPUBackend<G, F, I> as Backend>::Device {
-        todo!()
+    fn int_device<const D: usize>(tensor: &IntTensor<Self, D>) -> Device<Self> {
+        tensor.context.device.clone()
     }
 
     fn int_to_device<const D: usize>(
@@ -200,76 +196,64 @@ where
         lhs: IntTensor<Self, D>,
         rhs: IntTensor<Self, D>,
     ) -> IntTensor<Self, D> {
-        NumericOps::add::<I, D>(lhs, rhs)
+        NumericOps::<G>::add::<I, D>(lhs, rhs)
     }
 
     fn int_add_scalar<const D: usize>(
         lhs: IntTensor<Self, D>,
         rhs: IntElem<Self>,
     ) -> IntTensor<Self, D> {
-        NumericOps::add_scalar(lhs, rhs)
+        NumericOps::<G>::add_scalar(lhs, rhs)
     }
 
     fn int_sub<const D: usize>(
         lhs: IntTensor<Self, D>,
         rhs: IntTensor<Self, D>,
     ) -> IntTensor<Self, D> {
-        NumericOps::sub(lhs, rhs)
+        NumericOps::<G>::sub(lhs, rhs)
     }
 
     fn int_sub_scalar<const D: usize>(
         lhs: IntTensor<Self, D>,
         rhs: IntElem<Self>,
     ) -> IntTensor<Self, D> {
-        NumericOps::sub_scalar(lhs, rhs)
+        NumericOps::<G>::sub_scalar(lhs, rhs)
     }
 
     fn int_mul<const D: usize>(
         lhs: IntTensor<Self, D>,
         rhs: IntTensor<Self, D>,
     ) -> IntTensor<Self, D> {
-        NumericOps::mul(lhs, rhs)
+        NumericOps::<G>::mul(lhs, rhs)
     }
 
     fn int_mul_scalar<const D: usize>(
         lhs: IntTensor<Self, D>,
         rhs: IntElem<Self>,
     ) -> IntTensor<Self, D> {
-        NumericOps::mul_scalar(lhs, rhs)
+        NumericOps::<G>::mul_scalar(lhs, rhs)
     }
 
     fn int_div<const D: usize>(
         lhs: IntTensor<Self, D>,
         rhs: IntTensor<Self, D>,
     ) -> IntTensor<Self, D> {
-        NumericOps::div(lhs, rhs)
+        NumericOps::<G>::div(lhs, rhs)
     }
 
     fn int_div_scalar<const D: usize>(
         lhs: IntTensor<Self, D>,
         rhs: IntElem<Self>,
     ) -> IntTensor<Self, D> {
-        NumericOps::div_scalar(lhs, rhs)
+        NumericOps::<G>::div_scalar(lhs, rhs)
    }
 
-    fn int_neg<const D: usize>(
-        _tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
-    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
-        todo!()
+    fn int_zeros<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
+        NumericOps::<G>::zeros(shape, device)
     }
 
-    fn int_zeros<const D: usize>(
-        _shape: Shape<D>,
-        _device: &<WGPUBackend<G, F, I> as Backend>::Device,
-    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
-        todo!()
-    }
-
-    fn int_ones<const D: usize>(
-        _shape: Shape<D>,
-        _device: &<WGPUBackend<G, F, I> as Backend>::Device,
-    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
-        todo!()
+    fn int_ones<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
+        NumericOps::<G>::ones(shape, device)
     }
 
     fn int_sum<const D: usize>(
@@ -1,12 +1,46 @@
+use std::marker::PhantomData;
+use std::sync::Arc;
+
+use burn_tensor::{Element, ElementConversion, Shape};
+
 use crate::kernel::{binary_elemwise, binary_elemwise_inplace, unary_scalar, unary_scalar_inplace};
+use crate::pool::get_context;
 use crate::{
     binary_elemwise, binary_elemwise_inplace, element::WGPUElement, tensor::WgpuTensor,
     unary_scalar, unary_scalar_inplace,
 };
+use crate::{GraphicsApi, WgpuDevice};
 
-pub struct NumericOps;
+pub struct NumericOps<G: GraphicsApi> {
+    _g: PhantomData<G>,
+}
 
-impl NumericOps {
+impl<G: GraphicsApi> NumericOps<G> {
+    pub fn zeros<E: WGPUElement, const D: usize>(
+        shape: Shape<D>,
+        device: &WgpuDevice,
+    ) -> WgpuTensor<E, D> {
+        let context = get_context::<G>(device);
+
+        let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::<E>());
+
+        WgpuTensor::new(context, shape, Arc::new(buffer))
+    }
+
+    pub fn ones<E: WGPUElement + Element, const D: usize>(
+        shape: Shape<D>,
+        device: &WgpuDevice,
+    ) -> WgpuTensor<E, D> {
+        let context = get_context::<G>(device);
+
+        let buffer = context.create_buffer(shape.num_elements() * core::mem::size_of::<E>());
+
+        Self::add_scalar(
+            WgpuTensor::new(context, shape, Arc::new(buffer)),
+            1i32.elem::<E>(),
+        )
+    }
+
     pub fn add<E: WGPUElement, const D: usize>(
         lhs: WgpuTensor<E, D>,
         rhs: WgpuTensor<E, D>,
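NumericOps gains a GraphicsApi type parameter, carried only by PhantomData, so zeros and ones can fetch a context for the right graphics backend (ones is built by adding the scalar 1 to a fresh buffer). Call sites then select the backend with a turbofish, which is why every NumericOps::add(..) above became NumericOps::<G>::add(..). A minimal sketch of that zero-sized type-parameter pattern (illustrative names, not Burn's real types):

use std::marker::PhantomData;

// Illustrative stand-in for a graphics API marker trait.
trait GraphicsApi {
    fn name() -> &'static str;
}

struct Vulkan;
impl GraphicsApi for Vulkan {
    fn name() -> &'static str {
        "vulkan"
    }
}

// No runtime data: PhantomData just pins the type parameter to the struct.
struct Ops<G: GraphicsApi> {
    _g: PhantomData<G>,
}

impl<G: GraphicsApi> Ops<G> {
    fn describe() -> String {
        format!("numeric ops specialized for {}", G::name())
    }
}

fn main() {
    // The turbofish picks the backend, mirroring NumericOps::<G>::add(lhs, rhs).
    println!("{}", Ops::<Vulkan>::describe());
}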
@@ -0,0 +1,25 @@
+/// An approximation of the error function: https://en.wikipedia.org/wiki/Error_function#Numerical_approximations
+///
+/// > (maximum error: 1.5×10−7)
+/// > All of these approximations are valid for x ≥ 0. To use these approximations for negative x, use the fact that erf x is an odd function, so erf x = −erf(−x).
+fn erf_positive(x: elem) -> elem {
+    let p = 0.3275911;
+    let a1 = 0.254829592;
+    let a2 = -0.284496736;
+    let a3 = 1.421413741;
+    let a4 = -1.453152027;
+    let a5 = 1.061405429;
+
+    let t = 1.0 / (1.0 + p * abs(x));
+    let tmp = ((((a5 * t + a4) * t) + a3) * t + a2) * t + a1;
+
+    return 1.0 - (tmp * t * exp(-x * x));
+}
+
+fn erf(x: elem) -> elem {
+    if (x < 0.0) {
+        return -1.0 * erf_positive(-1.0 * x);
+    }
+
+    return erf_positive(x);
+}
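The WGSL template above is an Abramowitz–Stegun style polynomial approximation of erf, extended to negative inputs through the odd symmetry erf(−x) = −erf(x). A Rust transcription of the same formula, handy for cross-checking the expected test values (a sketch, not part of the PR):

fn erf_positive(x: f32) -> f32 {
    // Same constants as the WGSL template above.
    let p = 0.3275911;
    let a1 = 0.254829592;
    let a2 = -0.284496736;
    let a3 = 1.421413741;
    let a4 = -1.453152027;
    let a5 = 1.061405429;

    let t = 1.0 / (1.0 + p * x.abs());
    let tmp = ((((a5 * t + a4) * t) + a3) * t + a2) * t + a1;

    1.0 - tmp * t * (-x * x).exp()
}

fn erf(x: f32) -> f32 {
    // erf is odd: erf(-x) = -erf(x).
    if x < 0.0 {
        return -erf_positive(-x);
    }
    erf_positive(x)
}

fn main() {
    // Close to the expected values in the erf tests, e.g. erf(1.0) ≈ 0.8427.
    assert!((erf(1.0) - 0.8427).abs() < 1e-3);
    assert!((erf(-0.056) - (-0.0631)).abs() < 1e-3);
}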