* Added placeholders for FFT module

* First working implementation of 1D FFT

* Tidied things up a bit

* More tests + placeholders for other backends

* WGPU FFT almost works but not quite yet

* WGPU FFT works

* Tidied up WGPU FFT a bit

* Added 1D ifft default implementation using 1D fft module op

* More tests, added extra buffer to WGPU to not modify tensor in place
This commit is contained in:
Tom Wyllie 2024-04-04 18:30:59 +01:00 committed by GitHub
parent ce898ff899
commit 38ee355245
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
17 changed files with 839 additions and 6 deletions

View File

@ -1030,6 +1030,12 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
) -> <Autodiff<B> as Backend>::FloatTensorPrimitive<4> {
panic!("Can't differentiate interpolate backward.");
}
fn fft(_x: AutodiffTensor<B, 3>) -> AutodiffTensor<B, 3> {
// Imaginary gradients = not practical. People use FFT for feature
// extraction.
panic!("Can't differentiate FFT.");
}
}
#[derive(Debug)]

View File

@ -266,4 +266,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I
) -> FloatTensor<Self, 4> {
panic!("interpolate_backward is not supported by Candle")
}
fn fft(x: FloatTensor<Self, 3>) -> FloatTensor<Self, 3> {
panic!("fft is not supported by Candle")
}
}

View File

@ -1,4 +1,3 @@
use crate::stream::InterpolateBackwardDescription;
use crate::{
client::FusionClient,
stream::{
@ -6,10 +5,11 @@ use crate::{
AdaptiveAvgPool2dBackwardDescription, AdaptiveAvgPool2dDescription,
AvgPool1dBackwardDescription, AvgPool1dDescription, AvgPool2dBackwardDescription,
AvgPool2dDescription, Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription,
ConvTranspose2dDescription, InterpolateDescription, MaxPool1dDescription,
MaxPool1dWithIndicesBackwardDescription, MaxPool1dWithIndicesDescription,
MaxPool2dDescription, MaxPool2dWithIndicesBackwardDescription,
MaxPool2dWithIndicesDescription, Operation, OperationDescription,
ConvTranspose2dDescription, InterpolateBackwardDescription, InterpolateDescription,
MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription,
MaxPool1dWithIndicesDescription, MaxPool2dDescription,
MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription, Operation,
OperationDescription,
},
Fusion, FusionBackend, HandleContainer,
};
@ -1054,4 +1054,8 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
);
out
}
fn fft(_x: FloatTensor<Self, 3>) -> FloatTensor<Self, 3> {
todo!();
}
}

View File

@ -0,0 +1,84 @@
use crate::{
compute::StaticKernel,
element::JitElement,
kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
kernel_wgsl,
ops::numeric::empty_device,
tensor::JitTensor,
Runtime,
};
use burn_tensor::Element;
kernel_wgsl!(FFT, "../../template/fft/fft.wgsl");
pub(crate) fn fft<R: Runtime, E: JitElement + Element>(
input: JitTensor<R, E, 3>,
) -> JitTensor<R, E, 3> {
let input = kernel::into_contiguous(input);
let [_, num_samples, complex] = input.shape.dims;
if complex != 2 {
panic!("Last dimension must have size exactly 2 (real, imaginary)");
}
// Power of 2 => only 1 bit set => x & (x - 1) == 0
if num_samples == 0 || (num_samples & (num_samples - 1)) != 0 {
panic!("Fourier transform dimension must have a power of 2 size, perhaps consider zero padding")
};
// Need to use two output buffers as the algorithm writes back and forth
// at each iteration. We could reuse the input buffer but this would
// modify in place which might be undesirable.
let output_1: JitTensor<R, E, 3> = empty_device(
input.client.clone(),
input.device.clone(),
input.shape.clone(),
);
let output_2: JitTensor<R, E, 3> = empty_device(
input.client.clone(),
input.device.clone(),
input.shape.clone(),
);
let num_elems = input.shape.num_elements();
let num_fft_iters = (num_samples as f32).log2() as usize;
for fft_iter in 0..num_fft_iters {
// "Ping pong" buffering
let (x_tensor, x_hat_tensor) = {
if fft_iter == 0 {
(&input, &output_1)
} else if fft_iter % 2 == 0 {
(&output_2, &output_1)
} else {
(&output_1, &output_2)
}
};
let mut info = build_info(&[&x_tensor, &x_hat_tensor]);
info.push(num_fft_iters as u32);
info.push(fft_iter as u32);
let info_handle = input.client.create(bytemuck::cast_slice(&info));
let kernel = StaticKernel::<
KernelSettings<FFT, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
input.client.execute(
Box::new(kernel),
&[&x_tensor.handle, &x_hat_tensor.handle, &info_handle],
);
}
// "Ping pong" buffering
{
if num_fft_iters == 0 {
input
} else if num_fft_iters % 2 == 0 {
output_2
} else {
output_1
}
}
}

View File

@ -0,0 +1,3 @@
mod fft1d;
pub use fft1d::*;

View File

@ -19,6 +19,8 @@ pub use unary::*;
/// Convolution kernels
pub mod conv;
/// Fourier kernels
pub mod fft;
/// Interpolation kernels
pub mod interpolate;
/// Matmul kernels

View File

@ -118,4 +118,8 @@ impl<R: Runtime> ModuleOps<Self> for JitBackend<R> {
) -> FloatTensor<Self, 4> {
kernel::interpolate::interpolate_backward(x, grad, output_size, options)
}
fn fft(x: FloatTensor<Self, 3>) -> FloatTensor<Self, 3> {
kernel::fft::fft(x)
}
}

View File

@ -0,0 +1,144 @@
@group(0)
@binding(0)
var<storage, read> input: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(2)
var<storage, read> info: array<u32, 32>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
const PI: f32 = 3.141592653589793115997963468544185161590576171875;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_index) local_idx: u32,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
//////////////////////////////////////////////////
// Single FFT stage at a single (x, y) position //
//////////////////////////////////////////////////
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let input_stride_0 = info[1];
let input_stride_1 = info[2];
let input_stride_2 = info[3];
let output_stride_0 = info[4];
let output_stride_1 = info[5];
let output_stride_2 = info[6];
let input_shape_0 = info[7];
let input_shape_1 = info[8];
let input_shape_2 = info[9];
let output_shape_0 = info[10];
let output_shape_1 = info[11];
let output_shape_2 = info[12];
let num_fft_iters = info[13];
let fft_iter = info[14];
let is_final_iter: bool = (fft_iter + 1u) == num_fft_iters;
// FFT is done over X dimension, parallelised over Y dimension. It's always
// a 1D transform, but many 1D transforms are done in parallel.
let oy = id / output_stride_0 % output_shape_0;
let ox = id / output_stride_1 % output_shape_1;
let oc = id / output_stride_2 % output_shape_2;
let iy = oy;
// Only ever 0 or 1 (real or imaginary). Arbitrarily choose the real index
// to do both complex calculations.
if (oc >= 1u) {
return;
}
// Number of independent FFTs at this stage (starts at x_width/2, halves each time)
let num_transforms: u32 = input_shape_1 >> (fft_iter + 1u);
// Binary mask for extracting the index of E_k
let even_mask: u32 = num_transforms ^ 0xFFFFu;
// Returns if outside the output dimension
if oy >= output_shape_0 || ox >= output_shape_1 {
return;
}
// Position-dependent FFT Parameters
let ix_even: u32 = ox & even_mask;
let ix_odd: u32 = ix_even + num_transforms;
let exponent: u32 = reverse_bits(ox >> (num_fft_iters - fft_iter), fft_iter);
let negative_instead_of_plus: bool = (ox & num_transforms) > 0u;
// Indices
let i_even_re: u32 = iy * input_stride_0 + ix_even * input_stride_1 + 0u * input_stride_2;
let i_even_im: u32 = iy * input_stride_0 + ix_even * input_stride_1 + 1u * input_stride_2;
let i_odd_re: u32 = iy * input_stride_0 + ix_odd * input_stride_1 + 0u * input_stride_2;
let i_odd_im: u32 = iy * input_stride_0 + ix_odd * input_stride_1 + 1u * input_stride_2;
// Running the FFT algorithm like this results in a bit-reversed ordered
// output. i.e. the element 000, 001, ..., 110, 111 are now sorted if
// they were actually 000, 100, ..., 011, 111. On the last step, undo this
// mapping, by choosing ox differently.
var ox_r = 0u;
if is_final_iter {
ox_r = reverse_bits(ox, num_fft_iters);
} else {
ox_r = ox;
}
let i_out_re: u32 = oy * output_stride_0 + ox_r * output_stride_1 + 0u * output_stride_2;
let i_out_im: u32 = oy * output_stride_0 + ox_r * output_stride_1 + 1u * output_stride_2;
// Here we compute the main computation step for each index.
// See the last two equations of:
// https://en.wikipedia.org/wiki/Cooley-Tukey_FFT_algorithm#The_radix-2_DIT_case
// X_k = E_k + w_k * O_k
// Where w_k is the +/- exp(-2 pi i k / n) term. Note the plus / minus
// is included in the value of the weight.
var pm1: f32 = 1.;
if(negative_instead_of_plus) {
pm1 = -1.;
}
// Width of the FFT at this stage (starts at 2, doubles each time)
let n: u32 = 2u << fft_iter;
let w_k_theta: f32 = - 2. * PI * f32(exponent) / f32(n);
let w_k_re: f32 = pm1 * cos(w_k_theta);
let w_k_im: f32 = pm1 * sin(w_k_theta);
let e_k_re = input[i_even_re];
let e_k_im = input[i_even_im];
let o_k_re = input[i_odd_re];
let o_k_im = input[i_odd_im];
// Note the following:
// Real part of (a + bj)(c + dj) = ac + bd(j*j) = ac - bd
// Imaginary part of (a + bj)(c + dj) = ad(j) + bc(j) = (ad + bc)j
// These are used for w_k * O_k; E_k real and imaginary parts are just added.
output[i_out_re] = e_k_re + w_k_re * o_k_re - w_k_im * o_k_im;
output[i_out_im] = e_k_im + w_k_re * o_k_im + w_k_im * o_k_re;
}
fn reverse_bits(x_: u32, num_bits: u32) -> u32 {
// For input:
// 00000000000000000000000000011011
// num_bits = 1 gives:
// 00000000000000000000000000000001
// num_bits = 2:
// 00000000000000000000000000000011
// num_bits = 3:
// 00000000000000000000000000000110
// num_bits = 4:
// 00000000000000000000000000001101
// etc...
return reverseBits(x_) >> (32u - min(32u, num_bits));
}

View File

@ -0,0 +1,176 @@
use burn_tensor::ElementConversion;
use ndarray::Ix3;
use crate::{iter_range_par, run_par, FloatNdArrayElement, NdArrayTensor, UnsafeSharedRef};
struct SingleIterParams {
/// Nth iteration of the Fast Fourier transform
iteration: usize,
remaining_iterations: usize,
/// Mask used as "mask & sample_id" to extract the required bits
sign_mask: usize,
/// Mask used as "mask & sample_id" to extract the required bits
even_mask: usize,
/// Argument (theta) of complex `z = r e^{i\theta}` such that z^n = 1.
nth_root_of_unity: f64,
}
impl SingleIterParams {
fn new(iteration: usize, required_iterations: usize) -> Self {
let remaining_iterations = required_iterations - iteration;
// Some of these expressions look a little crazy - much easier to see how they
// work if inspecting these numbers as binary strings (most have only 1 bit).
SingleIterParams {
iteration,
remaining_iterations,
sign_mask: (1 << (remaining_iterations - 1)),
even_mask: !(1 << (remaining_iterations - 1)),
nth_root_of_unity: -2. * std::f64::consts::PI / ((2 << iteration) as f64),
}
}
}
pub(crate) fn fft1d<E: FloatNdArrayElement>(input: NdArrayTensor<E, 3>) -> NdArrayTensor<E, 3> {
let [batch_size, num_samples, complex] = input.shape().dims;
// Require complex input - an extra dimension is used that is always size 2, for complex.
if complex != 2 {
panic!(
"The last dimension must have length 2 (real, imaginary). For real inputs, consider
adding an extra dimension, with the imaginary part filled with zeros."
)
}
// This FFT implementation is Cooley-Tukey, only supports power of 2 along transform dim.
if num_samples == 0 || (num_samples & (num_samples - 1)) != 0 {
panic!(
"The dimension that the FFT is computed over must be a power of 2, got {}",
num_samples
);
}
let mut input = input
.array
.into_owned()
.into_dimensionality::<Ix3>()
.unwrap();
let mut output = ndarray::Array3::<E>::zeros((batch_size, num_samples, 2));
let input_unsafe_ref = UnsafeSharedRef::new(&mut input);
let output_unsafe_ref = UnsafeSharedRef::new(&mut output);
let num_fft_iters: usize = f32::log2(num_samples as f32) as usize;
for fft_iter in 0..num_fft_iters {
let params = SingleIterParams::new(fft_iter, num_fft_iters);
let is_last_iter = (fft_iter + 1) == num_fft_iters;
let even_iter = fft_iter % 2 == 0;
unsafe {
run_par!(|| {
iter_range_par!(0, num_samples).for_each(|sample_id| {
let k_even = sample_id & params.even_mask;
let k_odd = k_even + (num_samples >> (1 + params.iteration));
let twiddle: (E, E) = get_twiddle_factor(&params, sample_id);
// Alternate between the array being written to and from.
let (input_ref, output_ref) = match even_iter {
true => (input_unsafe_ref.get(), output_unsafe_ref.get()),
false => (output_unsafe_ref.get(), input_unsafe_ref.get()),
};
for b in 0..batch_size {
/*
Where E_k are even-indexed inputs and O_k are odd-indexed inputs:
X_k = E_k + twiddle_factor * O_k
Taking real and imaginary parts:
Re(X_k) = Re(E_k) + Re(twiddle_factor) * Re(O_k) - Im(twiddle_factor) * Im(O_k)
Im(X_k) = Im( as f32E_k) + Re(twiddle_factor) * Im(O_k) + Im(twiddle_factor) * Re(O_k)
See https://en.wikipedia.org/wiki/Cooley-Tukey_FFT_algorithm#The_radix-2_DIT_case
*/
let e_k = (input_ref[(b, k_even, 0)], input_ref[(b, k_even, 1)]);
let o_k = (input_ref[(b, k_odd, 0)], input_ref[(b, k_odd, 1)]);
let x_k = (
e_k.0 + twiddle.0 * o_k.0 - twiddle.1 * o_k.1,
e_k.1 + twiddle.0 * o_k.1 + twiddle.1 * o_k.0,
);
let out_id = match is_last_iter {
false => sample_id,
true => reverse_bits(sample_id, num_fft_iters),
};
output_ref[(b, out_id, 0)] = x_k.0;
output_ref[(b, out_id, 1)] = x_k.1;
}
});
});
}
}
let output = match num_fft_iters % 2 == 0 {
true => input,
false => output,
};
NdArrayTensor::new(output.into_dyn().into_shared())
}
fn get_twiddle_factor<E: burn_tensor::Element>(
params: &SingleIterParams,
sample_id: usize,
) -> (E, E) {
// Indices are bit reversed at each iteration, but there's also a different
// number of values at each iteration. e.g., for the 3 iterations of an 8
// length transform the mapping from sample_id to k is as follows:
// 0: [0, 0, 0, 0, 1, 1, 1, 1]
// 1: [0, 0, 2, 2, 1, 1, 3, 3]
// 2: [0, 4, 2, 6, 1, 7, 3, 5]
let k = reverse_bits(sample_id >> (params.remaining_iterations), params.iteration) as f64;
/*
`nth_root_of_unity` has the property that
exp(-j * root) = 1.
Raising both sides to the power of any integer k,
exp(-j * root)^k = 1^k
exp(-j * k * root) = 1
The nth root is special because it satisfies this.
*/
let complex_angle = k * params.nth_root_of_unity;
let sign = {
if (sample_id & params.sign_mask) > 0 {
-1.
} else {
1.
}
};
// Euler's formula: exp^(ix) = cos(x) + i sin(x)
// These are both real numbers, but represent coefficients of a + i*b.
(
(sign * f64::cos(complex_angle)).elem::<E>(),
(sign * f64::sin(complex_angle)).elem::<E>(),
)
}
fn reverse_bits(n: usize, no_of_bits: usize) -> usize {
let mut result = 0;
let mut n = n;
for _ in 0..no_of_bits {
result <<= 1;
result |= n & 1;
n >>= 1;
}
result
}

View File

@ -8,6 +8,7 @@ mod tensor;
pub(crate) mod adaptive_avgpool;
pub(crate) mod avgpool;
pub(crate) mod conv;
pub(crate) mod fft;
pub(crate) mod interpolate;
pub(crate) mod macros;
pub(crate) mod matmul;

View File

@ -2,6 +2,7 @@ use super::{
adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward},
avgpool::{avg_pool2d, avg_pool2d_backward},
conv::{conv2d, conv_transpose2d},
fft::fft1d,
interpolate::{bicubic_interpolate, bilinear_interpolate, nearest_interpolate},
maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
};
@ -130,4 +131,8 @@ impl<E: FloatNdArrayElement> ModuleOps<Self> for NdArray<E> {
}
}
}
fn fft(x: NdArrayTensor<E, 3>) -> NdArrayTensor<E, 3> {
fft1d(x)
}
}

View File

@ -344,4 +344,8 @@ impl<E: TchElement> ModuleOps<Self> for LibTorch<E> {
TchTensor::new(tensor)
}
fn fft(_x: TchTensor<E, 3>) -> TchTensor<E, 3> {
todo!();
}
}

View File

@ -231,3 +231,19 @@ where
{
Tensor::new(B::interpolate(x.primitive, output_size, options))
}
/// Applies a [1D fast fourier transform (FFT)](crate::ops::ModuleOps::fft).
pub fn fft<B>(x: Tensor<B, 3>) -> Tensor<B, 3>
where
B: Backend,
{
Tensor::new(B::fft(x.primitive))
}
/// Applies an [inverse 1D fast fourier transform (FFT)](crate::ops::ModuleOps::ifft).
pub fn ifft<B>(x: Tensor<B, 3>) -> Tensor<B, 3>
where
B: Backend,
{
Tensor::new(B::ifft(x.primitive))
}

View File

@ -2,7 +2,7 @@ use super::{conv, pool, unfold::unfold4d_using_conv2d};
use crate::{
backend::Backend,
ops::{FloatTensor, IntTensor},
Shape,
ElementConversion, Shape,
};
/// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).
@ -488,4 +488,47 @@ pub trait ModuleOps<B: Backend> {
output_size: [usize; 2],
options: InterpolateOptions,
) -> FloatTensor<B, 4>;
/// One dimensional Fast Fourier Transform (FFT).
///
/// # Shapes
///
/// x: `[batch_size, length, complex]`, where `complex` is exactly 2.
///
/// # Returns
///
/// X: `[batch_size, length, complex]`, where `complex` is exactly 2.
fn fft(x: FloatTensor<B, 3>) -> FloatTensor<B, 3>;
/// One dimensional inverse Fast Fourier Transform (FFT).
///
/// # Shapes
///
/// x: `[batch_size, length, complex]`, where `complex` is exactly 2.
///
/// # Returns
///
/// X: `[batch_size, length, complex]`, where `complex` is exactly 2.
fn ifft(x: FloatTensor<B, 3>) -> FloatTensor<B, 3> {
let [b, s, _] = B::float_shape(&x).dims;
// Conjugate input
let x = B::float_slice_assign(
x.clone(),
[0..b, 0..s, 1..2],
B::float_neg(B::float_slice(x, [0..b, 0..s, 1..2])),
);
// Use forwards 1D FFT of this backend to calculate inverse.
let x = B::fft(x);
// Conjugate output
let x = B::float_slice_assign(
x.clone(),
[0..b, 0..s, 1..2],
B::float_neg(B::float_slice(x, [0..b, 0..s, 1..2])),
);
B::float_div_scalar(x, (s as f32).elem::<B::FloatElem>())
}
}

View File

@ -36,6 +36,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_module_nearest_interpolate!();
burn_tensor::testgen_module_bilinear_interpolate!();
burn_tensor::testgen_module_bicubic_interpolate!();
burn_tensor::testgen_module_fft!();
// test ops
burn_tensor::testgen_add!();

View File

@ -0,0 +1,335 @@
#[burn_tensor_testgen::testgen(module_fft)]
mod tests {
use super::*;
use burn_tensor::module::{fft, ifft};
use burn_tensor::Shape;
fn single_elem() -> (TestTensor<3>, TestTensor<3>) {
(
TestTensor::from([[[1., 0.]]]),
TestTensor::from([[[1., 0.]]]),
)
}
fn delta_1d() -> (TestTensor<3>, TestTensor<3>) {
(
// Delta function -> Flat frequency spectrum
TestTensor::from([[[1., 0.], [0., 0.]], [[0., 1.], [0., 0.]]]),
TestTensor::from([[[1., 0.], [1., 0.]], [[0., 1.], [0., 1.]]]),
)
}
fn simple_1d() -> (TestTensor<3>, TestTensor<3>) {
(
TestTensor::from([[[1., 3.], [2.3, -1.]], [[-1., 0.], [-1., 0.]]]),
TestTensor::from([[[3.3, 2.0], [-1.3, 4.0]], [[-2.0, 0.0], [0.0, 0.0]]]),
)
}
fn even_pow_2_1d() -> (TestTensor<3>, TestTensor<3>) {
// FFT size 2^4 = 16
(
TestTensor::from([[
[-0.2322, -0.8782],
[0.0283, 0.0621],
[0.6704, 0.9188],
[-0.6543, 0.9008],
[0.4139, 0.7885],
[0.5795, 0.8549],
[0.1336, 1.4838],
[0.0967, 0.0182],
[-0.2705, -1.6957],
[1.2414, -0.0782],
[1.0207, -1.2235],
[1.0662, 1.0242],
[0.5177, -0.1858],
[-0.4157, 0.1690],
[-0.5877, 0.9780],
[1.0103, -0.9347],
]]),
TestTensor::from([[
[4.6184, 2.2021],
[1.7495, 2.2396],
[-1.6534, -8.3570],
[4.8833, 1.4184],
[-0.8089, -4.0430],
[-1.7017, 1.2770],
[1.3325, 1.4228],
[-1.3123, 4.6775],
[-1.2865, -1.8301],
[2.5052, 1.3927],
[-6.7483, -2.2867],
[-1.4950, -2.8296],
[-0.8072, -4.2135],
[1.4975, -1.2243],
[1.3319, -3.4857],
[-5.8203, -0.4112],
]]),
)
}
fn odd_pow_2_1d() -> (TestTensor<3>, TestTensor<3>) {
// FFT size 2^5 = 32
(
TestTensor::from([
[
[0.0427, -0.9848],
[0.6244, 0.0441],
[0.2514, 1.2113],
[0.9988, 1.5251],
[-0.0979, 0.3475],
[0.6723, -1.9583],
[0.4330, -0.4186],
[0.2924, -0.1259],
[-0.9504, -0.9404],
[-0.8480, -1.4261],
[1.6096, -1.6685],
[-0.4452, -0.2472],
[-0.7371, 1.9884],
[0.4746, 0.9134],
[0.3894, -0.7379],
[-0.8402, -0.6750],
[-1.2116, -0.0552],
[-0.5223, 0.7814],
[1.6739, -0.8242],
[0.8573, 0.3881],
[-0.0345, 1.4219],
[0.1038, 1.9030],
[-0.3701, 1.0827],
[-0.7380, 1.3959],
[1.2852, -1.1371],
[0.4140, -0.2322],
[-0.0631, 1.0053],
[-0.3737, 0.2743],
[2.1822, 0.7284],
[0.6077, 1.2214],
[-0.3532, -0.6910],
[-0.5390, -0.0049],
],
[
[0.7324, -0.7980],
[0.8606, -1.7983],
[0.4761, 0.5743],
[0.1872, 1.0686],
[0.3460, -0.2946],
[0.9188, -1.6348],
[-0.5201, 0.2241],
[-1.8877, 0.9795],
[0.4466, -0.5693],
[1.8571, 0.6358],
[-1.0278, 0.9778],
[-0.5889, 0.1738],
[-0.2158, -0.3417],
[1.3702, 1.5540],
[0.8480, -1.5970],
[0.9213, -0.8815],
[0.3899, -0.2117],
[1.1022, -0.5712],
[-1.2964, -1.4468],
[-0.0170, -0.6138],
[-0.4706, 0.5001],
[0.0794, -1.2067],
[1.1162, -0.6329],
[0.1651, 0.1031],
[0.3490, -0.2715],
[0.6509, 0.0565],
[1.0147, 0.2431],
[0.8052, -0.9913],
[0.9142, 0.9090],
[1.0906, 0.5651],
[1.5763, 0.5397],
[0.3499, 0.4359],
],
]),
TestTensor::from([
[
[4.7884, 4.1050],
[-5.4979, 0.9704],
[2.2100, 0.8929],
[3.2625, 4.2363],
[-6.3352, -14.4654],
[7.1075, 0.2420],
[2.7235, -4.1024],
[-3.3785, -15.7396],
[-4.3761, 0.0957],
[6.6426, 3.5467],
[2.1965, 6.7494],
[-5.0247, 1.5957],
[-0.1676, -1.6919],
[-1.9747, -4.9014],
[-1.8878, -5.7946],
[-4.8639, -2.0970],
[3.3108, -3.4494],
[1.8428, 3.2496],
[-2.1688, 2.2887],
[1.8962, -0.4687],
[3.0188, -7.4876],
[-1.5869, 5.1769],
[0.1114, 3.9235],
[7.7439, -6.9099],
[-1.8084, 4.7236],
[11.4196, -0.7087],
[-12.0418, 0.5287],
[-3.4558, -2.9853],
[-5.1034, -6.7704],
[-6.3451, 2.8732],
[-3.1719, 3.8135],
[12.2821, -2.9533],
],
[
[12.5437, -4.3209],
[6.6402, 7.5440],
[-1.5249, 1.4718],
[-3.2183, -1.4000],
[1.8995, -0.0458],
[-5.3435, -6.2148],
[-1.4322, -2.0123],
[2.7689, -3.0387],
[-2.3689, -7.9545],
[-13.4452, -4.8216],
[0.6306, 10.9365],
[-0.9051, 0.5564],
[-1.9752, -7.8168],
[2.5935, 0.0319],
[-4.3665, -0.1909],
[7.7858, 1.6211],
[-3.1862, -0.0701],
[1.0940, 5.9505],
[0.8552, -0.3137],
[-1.4151, -1.5680],
[4.4181, 2.5067],
[2.9776, -3.9427],
[-0.6781, 7.6331],
[2.0782, 0.4852],
[2.9787, 8.0347],
[5.6104, -9.2327],
[-0.1015, -9.4777],
[-3.1915, 0.0535],
[1.0344, -5.1372],
[0.2312, 5.2132],
[9.2310, -9.3980],
[1.2195, -0.6194],
],
]),
)
}
fn reshaped_input_1d() -> (TestTensor<3>, TestTensor<3>) {
// "Reshape" might just changes strides, and not alter underlying buffer.
// This test ensures op is correctly using stride info.
let x1 = TestTensor::<1>::from([
0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.,
]);
let x3: TestTensor<3> = x1.reshape([2, 4, 2]).swap_dims(0, 2);
let x_hat = TestTensor::from([
[
[8., 0., -8., 0.],
[12., 0., -8., 0.],
[16., 0., -8., 0.],
[20., 0., -8., 0.],
],
[
[10., 0., -8., 0.],
[14., 0., -8., 0.],
[18., 0., -8., 0.],
[22., 0., -8., 0.],
],
]);
(x3, x_hat)
}
#[test]
fn test_fft_single_elem() {
let (x, x_hat) = single_elem();
assert_output(fft(x), x_hat);
}
#[test]
fn test_ifft_single_elem() {
let (x, x_hat) = single_elem();
assert_output(x, ifft(x_hat));
}
#[test]
fn test_fft_delta_1d() {
let (x, x_hat) = delta_1d();
assert_output(fft(x), x_hat);
}
#[test]
fn test_ifft_delta_1d() {
let (x, x_hat) = delta_1d();
assert_output(x, ifft(x_hat));
}
#[test]
fn test_fft_simple_1d() {
let (x, x_hat) = simple_1d();
assert_output(fft(x), x_hat);
}
#[test]
fn test_ifft_simple_1d() {
let (x, x_hat) = simple_1d();
assert_output(x, ifft(x_hat));
}
#[test]
fn test_fft_even_pow_2_1d() {
let (x, x_hat) = even_pow_2_1d();
assert_output(fft(x), x_hat);
}
#[test]
fn test_ifft_even_pow_2_1d() {
let (x, x_hat) = even_pow_2_1d();
assert_output(x, ifft(x_hat));
}
#[test]
fn test_fft_odd_pow_2_1d() {
let (x, x_hat) = odd_pow_2_1d();
assert_output(fft(x), x_hat);
}
#[test]
fn test_ifft_odd_pow_2_1d() {
let (x, x_hat) = odd_pow_2_1d();
assert_output(x, ifft(x_hat));
}
#[test]
fn test_fft_round_trip_precision() {
let shape_x = Shape::new([1, 16, 2]);
let x = TestTensor::from(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &Default::default())
.reshape(shape_x)
.into_data()
.convert(),
);
// 2^10 shape = 9 iterations each way, can accumulate considerable error.
assert_output(x.clone(), ifft(fft(x)));
}
#[test]
#[should_panic]
fn test_invalid_input_non_complex_structure() {
// Last dim must have dimension 2 (real, imaginary)
let x = TestTensor::from([[[0., 0., 0.], [0., 0., 0.]]]);
fft(x);
}
#[test]
#[should_panic]
fn test_invalid_input_non_pow_2() {
let x = TestTensor::from([[[0., 1.], [2., 3.], [4., 5.]]]);
fft(x);
}
fn assert_output(x: TestTensor<3>, y: TestTensor<3>) {
x.to_data().assert_approx_eq(&y.into_data(), 3);
}
}

View File

@ -8,6 +8,7 @@ mod conv1d;
mod conv2d;
mod conv_transpose1d;
mod conv_transpose2d;
mod fft;
mod forward;
mod maxpool1d;
mod maxpool2d;