mirror of https://github.com/tracel-ai/burn.git
Fft (#1574)
* Added placeholders for FFT module * First working implementation of 1D FFT * Tidied things up a bit * More tests + placeholders for other backends * WGPU FFT almost works but not quite yet * WGPU FFT works * Tidied up WGPU FFT a bit * Added 1D ifft default implementation using 1D fft module op * More tests, added extra buffer to WGPU to not modify tensor in place
This commit is contained in:
parent
ce898ff899
commit
38ee355245
|
@ -1030,6 +1030,12 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
) -> <Autodiff<B> as Backend>::FloatTensorPrimitive<4> {
|
||||
panic!("Can't differentiate interpolate backward.");
|
||||
}
|
||||
|
||||
/// FFT entry point of the autodiff backend; the operation is intentionally unsupported.
///
/// # Panics
///
/// Always panics with "Can't differentiate FFT.", whatever the input is.
fn fft(_x: AutodiffTensor<B, 3>) -> AutodiffTensor<B, 3> {
|
||||
// Differentiating the transform would mean propagating imaginary gradients, which
|
||||
// is not practical; the FFT is expected to be used for feature extraction instead.
|
||||
panic!("Can't differentiate FFT.");
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
|
|
|
@ -266,4 +266,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I
|
|||
) -> FloatTensor<Self, 4> {
|
||||
panic!("interpolate_backward is not supported by Candle")
|
||||
}
|
||||
|
||||
/// FFT entry point of the Candle backend; the operation is not available here.
///
/// # Panics
///
/// Always panics with "fft is not supported by Candle".
fn fft(_x: FloatTensor<Self, 3>) -> FloatTensor<Self, 3> {
    // The argument is never read, so it carries an underscore prefix to avoid an
    // unused-variable warning.
    panic!("fft is not supported by Candle")
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
use crate::stream::InterpolateBackwardDescription;
|
||||
use crate::{
|
||||
client::FusionClient,
|
||||
stream::{
|
||||
|
@ -6,10 +5,11 @@ use crate::{
|
|||
AdaptiveAvgPool2dBackwardDescription, AdaptiveAvgPool2dDescription,
|
||||
AvgPool1dBackwardDescription, AvgPool1dDescription, AvgPool2dBackwardDescription,
|
||||
AvgPool2dDescription, Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription,
|
||||
ConvTranspose2dDescription, InterpolateDescription, MaxPool1dDescription,
|
||||
MaxPool1dWithIndicesBackwardDescription, MaxPool1dWithIndicesDescription,
|
||||
MaxPool2dDescription, MaxPool2dWithIndicesBackwardDescription,
|
||||
MaxPool2dWithIndicesDescription, Operation, OperationDescription,
|
||||
ConvTranspose2dDescription, InterpolateBackwardDescription, InterpolateDescription,
|
||||
MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription,
|
||||
MaxPool1dWithIndicesDescription, MaxPool2dDescription,
|
||||
MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription, Operation,
|
||||
OperationDescription,
|
||||
},
|
||||
Fusion, FusionBackend, HandleContainer,
|
||||
};
|
||||
|
@ -1054,4 +1054,8 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
);
|
||||
out
|
||||
}
|
||||
|
||||
/// FFT entry point of the fusion backend; not implemented yet.
///
/// # Panics
///
/// Always panics through `todo!()`.
fn fft(_x: FloatTensor<Self, 3>) -> FloatTensor<Self, 3> {
|
||||
todo!();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
use crate::{
|
||||
compute::StaticKernel,
|
||||
element::JitElement,
|
||||
kernel::{self, build_info, elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
|
||||
kernel_wgsl,
|
||||
ops::numeric::empty_device,
|
||||
tensor::JitTensor,
|
||||
Runtime,
|
||||
};
|
||||
use burn_tensor::Element;
|
||||
|
||||
kernel_wgsl!(FFT, "../../template/fft/fft.wgsl");
|
||||
|
||||
pub(crate) fn fft<R: Runtime, E: JitElement + Element>(
|
||||
input: JitTensor<R, E, 3>,
|
||||
) -> JitTensor<R, E, 3> {
|
||||
let input = kernel::into_contiguous(input);
|
||||
|
||||
let [_, num_samples, complex] = input.shape.dims;
|
||||
|
||||
if complex != 2 {
|
||||
panic!("Last dimension must have size exactly 2 (real, imaginary)");
|
||||
}
|
||||
|
||||
// Power of 2 => only 1 bit set => x & (x - 1) == 0
|
||||
if num_samples == 0 || (num_samples & (num_samples - 1)) != 0 {
|
||||
panic!("Fourier transform dimension must have a power of 2 size, perhaps consider zero padding")
|
||||
};
|
||||
|
||||
// Need to use two output buffers as the algorithm writes back and forth
|
||||
// at each iteration. We could reuse the input buffer but this would
|
||||
// modify in place which might be undesirable.
|
||||
let output_1: JitTensor<R, E, 3> = empty_device(
|
||||
input.client.clone(),
|
||||
input.device.clone(),
|
||||
input.shape.clone(),
|
||||
);
|
||||
let output_2: JitTensor<R, E, 3> = empty_device(
|
||||
input.client.clone(),
|
||||
input.device.clone(),
|
||||
input.shape.clone(),
|
||||
);
|
||||
|
||||
let num_elems = input.shape.num_elements();
|
||||
let num_fft_iters = (num_samples as f32).log2() as usize;
|
||||
|
||||
for fft_iter in 0..num_fft_iters {
|
||||
// "Ping pong" buffering
|
||||
let (x_tensor, x_hat_tensor) = {
|
||||
if fft_iter == 0 {
|
||||
(&input, &output_1)
|
||||
} else if fft_iter % 2 == 0 {
|
||||
(&output_2, &output_1)
|
||||
} else {
|
||||
(&output_1, &output_2)
|
||||
}
|
||||
};
|
||||
|
||||
let mut info = build_info(&[&x_tensor, &x_hat_tensor]);
|
||||
info.push(num_fft_iters as u32);
|
||||
info.push(fft_iter as u32);
|
||||
|
||||
let info_handle = input.client.create(bytemuck::cast_slice(&info));
|
||||
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<FFT, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
|
||||
input.client.execute(
|
||||
Box::new(kernel),
|
||||
&[&x_tensor.handle, &x_hat_tensor.handle, &info_handle],
|
||||
);
|
||||
}
|
||||
|
||||
// "Ping pong" buffering
|
||||
{
|
||||
if num_fft_iters == 0 {
|
||||
input
|
||||
} else if num_fft_iters % 2 == 0 {
|
||||
output_2
|
||||
} else {
|
||||
output_1
|
||||
}
|
||||
}
|
||||
}
|
|
@ -0,0 +1,3 @@
|
|||
mod fft1d;
|
||||
|
||||
pub use fft1d::*;
|
|
@ -19,6 +19,8 @@ pub use unary::*;
|
|||
|
||||
/// Convolution kernels
|
||||
pub mod conv;
|
||||
/// Fourier kernels
|
||||
pub mod fft;
|
||||
/// Interpolation kernels
|
||||
pub mod interpolate;
|
||||
/// Matmul kernels
|
||||
|
|
|
@ -118,4 +118,8 @@ impl<R: Runtime> ModuleOps<Self> for JitBackend<R> {
|
|||
) -> FloatTensor<Self, 4> {
|
||||
kernel::interpolate::interpolate_backward(x, grad, output_size, options)
|
||||
}
|
||||
|
||||
/// 1D FFT for the JIT backend; delegates directly to `kernel::fft::fft`.
fn fft(x: FloatTensor<Self, 3>) -> FloatTensor<Self, 3> {
|
||||
kernel::fft::fft(x)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,144 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read> input: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read_write> output: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read> info: array<u32, 32>;
|
||||
|
||||
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
|
||||
const PI: f32 = 3.141592653589793115997963468544185161590576171875;
|
||||
|
||||
// Runs ONE radix-2 stage of the FFT. Every invocation maps its flat id onto an output
// position (oy, ox, oc); only the invocation with oc == 0 does work, and it writes both
// the real and the imaginary component for that position.
//
// Layout of `info` as read below:
//   info[1..=3]   input strides      info[4..=6]    output strides
//   info[7..=9]   input shape        info[10..=12]  output shape
//   info[13]      total number of stages
//   info[14]      index of the stage executed by this dispatch
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_index) local_idx: u32,
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
    @builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
    //////////////////////////////////////////////////
    // Single FFT stage at a single (x, y) position //
    //////////////////////////////////////////////////

    let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;

    let input_stride_0 = info[1];
    let input_stride_1 = info[2];
    let input_stride_2 = info[3];
    let output_stride_0 = info[4];
    let output_stride_1 = info[5];
    let output_stride_2 = info[6];

    let input_shape_0 = info[7];
    let input_shape_1 = info[8];
    let input_shape_2 = info[9];
    let output_shape_0 = info[10];
    let output_shape_1 = info[11];
    let output_shape_2 = info[12];

    let num_fft_iters = info[13];
    let fft_iter = info[14];
    let is_final_iter: bool = (fft_iter + 1u) == num_fft_iters;

    // FFT is done over X dimension, parallelised over Y dimension. It's always
    // a 1D transform, but many 1D transforms are done in parallel.
    let oy = id / output_stride_0 % output_shape_0;
    let ox = id / output_stride_1 % output_shape_1;
    let oc = id / output_stride_2 % output_shape_2;
    let iy = oy;

    // Only ever 0 or 1 (real or imaginary). Arbitrarily choose the real index
    // to do both complex calculations.
    if (oc >= 1u) {
        return;
    }

    // Number of independent FFTs at this stage (starts at x_width/2, halves each time)
    let num_transforms: u32 = input_shape_1 >> (fft_iter + 1u);

    // Binary mask for extracting the index of E_k: every bit set except the
    // `num_transforms` bit. A full 32-bit complement is required here; a 16-bit mask
    // would drop the upper bits of `ox` for transforms longer than 65536 samples.
    let even_mask: u32 = ~num_transforms;

    // Returns if outside the output dimension.
    // NOTE(review): `oy` and `ox` are reduced modulo the output shape above, so this
    // guard cannot trigger; invocations past the end wrap around onto valid positions
    // and recompute the same values.
    if oy >= output_shape_0 || ox >= output_shape_1 {
        return;
    }

    // Position-dependent FFT Parameters
    let ix_even: u32 = ox & even_mask;
    let ix_odd: u32 = ix_even + num_transforms;
    let exponent: u32 = reverse_bits(ox >> (num_fft_iters - fft_iter), fft_iter);
    let negative_instead_of_plus: bool = (ox & num_transforms) > 0u;

    // Indices
    let i_even_re: u32 = iy * input_stride_0 + ix_even * input_stride_1 + 0u * input_stride_2;
    let i_even_im: u32 = iy * input_stride_0 + ix_even * input_stride_1 + 1u * input_stride_2;

    let i_odd_re: u32 = iy * input_stride_0 + ix_odd * input_stride_1 + 0u * input_stride_2;
    let i_odd_im: u32 = iy * input_stride_0 + ix_odd * input_stride_1 + 1u * input_stride_2;

    // Running the FFT algorithm like this results in a bit-reversed ordered
    // output, i.e. position b2 b1 b0 holds the element that belongs at b0 b1 b2.
    // On the last step, undo this mapping by choosing the output x differently.
    var ox_r = 0u;
    if is_final_iter {
        ox_r = reverse_bits(ox, num_fft_iters);
    } else {
        ox_r = ox;
    }

    let i_out_re: u32 = oy * output_stride_0 + ox_r * output_stride_1 + 0u * output_stride_2;
    let i_out_im: u32 = oy * output_stride_0 + ox_r * output_stride_1 + 1u * output_stride_2;

    // Here we compute the main computation step for each index.
    // See the last two equations of:
    // https://en.wikipedia.org/wiki/Cooley-Tukey_FFT_algorithm#The_radix-2_DIT_case
    // X_k = E_k + w_k * O_k
    // Where w_k is the +/- exp(-2 pi i k / n) term. Note the plus / minus
    // is included in the value of the weight.
    var pm1: f32 = 1.;
    if (negative_instead_of_plus) {
        pm1 = -1.;
    }

    // Width of the FFT at this stage (starts at 2, doubles each time)
    let n: u32 = 2u << fft_iter;
    let w_k_theta: f32 = - 2. * PI * f32(exponent) / f32(n);
    let w_k_re: f32 = pm1 * cos(w_k_theta);
    let w_k_im: f32 = pm1 * sin(w_k_theta);

    let e_k_re = input[i_even_re];
    let e_k_im = input[i_even_im];
    let o_k_re = input[i_odd_re];
    let o_k_im = input[i_odd_im];

    // Note the following:
    // Real part of (a + bj)(c + dj) = ac + bd(j*j) = ac - bd
    // Imaginary part of (a + bj)(c + dj) = ad(j) + bc(j) = (ad + bc)j
    // These are used for w_k * O_k; E_k real and imaginary parts are just added.
    output[i_out_re] = e_k_re + w_k_re * o_k_re - w_k_im * o_k_im;
    output[i_out_im] = e_k_im + w_k_re * o_k_im + w_k_im * o_k_re;
}
|
||||
|
||||
// Reverses the order of the lowest `num_bits` bits of `x_` and returns them in the low
// bits of the result; all other bits of the result are zero.
//
// For input:
// 00000000000000000000000000011011
// num_bits = 1 gives:
// 00000000000000000000000000000001
// num_bits = 2:
// 00000000000000000000000000000011
// num_bits = 3:
// 00000000000000000000000000000110
// num_bits = 4:
// 00000000000000000000000000001101
// etc...
fn reverse_bits(x_: u32, num_bits: u32) -> u32 {
    // Reversing zero bits yields zero. Handled explicitly because the general formula
    // would shift by 32, and a runtime shift amount is taken modulo the bit width, so
    // the full reversed word would be returned instead.
    if num_bits == 0u {
        return 0u;
    }

    // Reverse all 32 bits, then shift the `num_bits` of interest down to the bottom.
    return reverseBits(x_) >> (32u - min(32u, num_bits));
}
|
|
@ -0,0 +1,176 @@
|
|||
use burn_tensor::ElementConversion;
|
||||
use ndarray::Ix3;
|
||||
|
||||
use crate::{iter_range_par, run_par, FloatNdArrayElement, NdArrayTensor, UnsafeSharedRef};
|
||||
|
||||
/// Constants shared by every sample within one stage (iteration) of the radix-2 FFT.
struct SingleIterParams {
    /// Index of the current stage, starting at 0.
    iteration: usize,
    /// Number of stages still to run, counting the current one.
    remaining_iterations: usize,
    /// Single-bit mask, used as `sample_id & sign_mask`: a non-zero result means the
    /// twiddle factor is negated for that sample.
    sign_mask: usize,
    /// Complement of `sign_mask`, used as `sample_id & even_mask` to clear that bit and
    /// obtain the index of the even-side input of the butterfly.
    even_mask: usize,
    /// Argument (theta) of the complex number `z = e^{i\theta}` with `z^n = 1`, where
    /// `n = 2^(iteration + 1)` is the transform width at this stage: `-2 * pi / n`.
    nth_root_of_unity: f64,
}

impl SingleIterParams {
    /// Builds the parameters of stage `iteration` out of `required_iterations` stages.
    ///
    /// Expects `iteration < required_iterations`; otherwise the subtractions below
    /// underflow.
    fn new(iteration: usize, required_iterations: usize) -> Self {
        let remaining_iterations = required_iterations - iteration;

        // Easier to follow as binary strings: `stage_bit` has exactly one bit set.
        let stage_bit: usize = 1 << (remaining_iterations - 1);

        // Width of the sub-transforms at this stage: 2, 4, 8, ... Computed in 64 bits so
        // that late stages of very long transforms do not wrap into a negative number
        // (an untyped `2 << iteration` would be evaluated as `i32`).
        let stage_width = (1u64 << (iteration + 1)) as f64;

        SingleIterParams {
            iteration,
            remaining_iterations,
            sign_mask: stage_bit,
            even_mask: !stage_bit,
            nth_root_of_unity: -2. * std::f64::consts::PI / stage_width,
        }
    }
}
|
||||
|
||||
pub(crate) fn fft1d<E: FloatNdArrayElement>(input: NdArrayTensor<E, 3>) -> NdArrayTensor<E, 3> {
|
||||
let [batch_size, num_samples, complex] = input.shape().dims;
|
||||
|
||||
// Require complex input - an extra dimension is used that is always size 2, for complex.
|
||||
if complex != 2 {
|
||||
panic!(
|
||||
"The last dimension must have length 2 (real, imaginary). For real inputs, consider
|
||||
adding an extra dimension, with the imaginary part filled with zeros."
|
||||
)
|
||||
}
|
||||
|
||||
// This FFT implementation is Cooley-Tukey, only supports power of 2 along transform dim.
|
||||
if num_samples == 0 || (num_samples & (num_samples - 1)) != 0 {
|
||||
panic!(
|
||||
"The dimension that the FFT is computed over must be a power of 2, got {}",
|
||||
num_samples
|
||||
);
|
||||
}
|
||||
|
||||
let mut input = input
|
||||
.array
|
||||
.into_owned()
|
||||
.into_dimensionality::<Ix3>()
|
||||
.unwrap();
|
||||
let mut output = ndarray::Array3::<E>::zeros((batch_size, num_samples, 2));
|
||||
|
||||
let input_unsafe_ref = UnsafeSharedRef::new(&mut input);
|
||||
let output_unsafe_ref = UnsafeSharedRef::new(&mut output);
|
||||
|
||||
let num_fft_iters: usize = f32::log2(num_samples as f32) as usize;
|
||||
|
||||
for fft_iter in 0..num_fft_iters {
|
||||
let params = SingleIterParams::new(fft_iter, num_fft_iters);
|
||||
|
||||
let is_last_iter = (fft_iter + 1) == num_fft_iters;
|
||||
let even_iter = fft_iter % 2 == 0;
|
||||
|
||||
unsafe {
|
||||
run_par!(|| {
|
||||
iter_range_par!(0, num_samples).for_each(|sample_id| {
|
||||
let k_even = sample_id & params.even_mask;
|
||||
let k_odd = k_even + (num_samples >> (1 + params.iteration));
|
||||
let twiddle: (E, E) = get_twiddle_factor(¶ms, sample_id);
|
||||
|
||||
// Alternate between the array being written to and from.
|
||||
let (input_ref, output_ref) = match even_iter {
|
||||
true => (input_unsafe_ref.get(), output_unsafe_ref.get()),
|
||||
false => (output_unsafe_ref.get(), input_unsafe_ref.get()),
|
||||
};
|
||||
|
||||
for b in 0..batch_size {
|
||||
/*
|
||||
Where E_k are even-indexed inputs and O_k are odd-indexed inputs:
|
||||
X_k = E_k + twiddle_factor * O_k
|
||||
|
||||
Taking real and imaginary parts:
|
||||
Re(X_k) = Re(E_k) + Re(twiddle_factor) * Re(O_k) - Im(twiddle_factor) * Im(O_k)
|
||||
Im(X_k) = Im( as f32E_k) + Re(twiddle_factor) * Im(O_k) + Im(twiddle_factor) * Re(O_k)
|
||||
|
||||
See https://en.wikipedia.org/wiki/Cooley-Tukey_FFT_algorithm#The_radix-2_DIT_case
|
||||
*/
|
||||
|
||||
let e_k = (input_ref[(b, k_even, 0)], input_ref[(b, k_even, 1)]);
|
||||
let o_k = (input_ref[(b, k_odd, 0)], input_ref[(b, k_odd, 1)]);
|
||||
|
||||
let x_k = (
|
||||
e_k.0 + twiddle.0 * o_k.0 - twiddle.1 * o_k.1,
|
||||
e_k.1 + twiddle.0 * o_k.1 + twiddle.1 * o_k.0,
|
||||
);
|
||||
|
||||
let out_id = match is_last_iter {
|
||||
false => sample_id,
|
||||
true => reverse_bits(sample_id, num_fft_iters),
|
||||
};
|
||||
|
||||
output_ref[(b, out_id, 0)] = x_k.0;
|
||||
output_ref[(b, out_id, 1)] = x_k.1;
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let output = match num_fft_iters % 2 == 0 {
|
||||
true => input,
|
||||
false => output,
|
||||
};
|
||||
|
||||
NdArrayTensor::new(output.into_dyn().into_shared())
|
||||
}
|
||||
|
||||
fn get_twiddle_factor<E: burn_tensor::Element>(
|
||||
params: &SingleIterParams,
|
||||
sample_id: usize,
|
||||
) -> (E, E) {
|
||||
// Indices are bit reversed at each iteration, but there's also a different
|
||||
// number of values at each iteration. e.g., for the 3 iterations of an 8
|
||||
// length transform the mapping from sample_id to k is as follows:
|
||||
// 0: [0, 0, 0, 0, 1, 1, 1, 1]
|
||||
// 1: [0, 0, 2, 2, 1, 1, 3, 3]
|
||||
// 2: [0, 4, 2, 6, 1, 7, 3, 5]
|
||||
let k = reverse_bits(sample_id >> (params.remaining_iterations), params.iteration) as f64;
|
||||
|
||||
/*
|
||||
`nth_root_of_unity` has the property that
|
||||
|
||||
exp(-j * root) = 1.
|
||||
|
||||
Raising both sides to the power of any integer k,
|
||||
|
||||
exp(-j * root)^k = 1^k
|
||||
exp(-j * k * root) = 1
|
||||
|
||||
The nth root is special because it satisfies this.
|
||||
*/
|
||||
let complex_angle = k * params.nth_root_of_unity;
|
||||
let sign = {
|
||||
if (sample_id & params.sign_mask) > 0 {
|
||||
-1.
|
||||
} else {
|
||||
1.
|
||||
}
|
||||
};
|
||||
|
||||
// Euler's formula: exp^(ix) = cos(x) + i sin(x)
|
||||
// These are both real numbers, but represent coefficients of a + i*b.
|
||||
(
|
||||
(sign * f64::cos(complex_angle)).elem::<E>(),
|
||||
(sign * f64::sin(complex_angle)).elem::<E>(),
|
||||
)
|
||||
}
|
||||
|
||||
/// Reverses the order of the lowest `no_of_bits` bits of `n`.
///
/// The result holds those bits, mirrored, in its low positions; bits of `n` above
/// `no_of_bits` are ignored. With `no_of_bits == 0` the result is `0`.
fn reverse_bits(n: usize, no_of_bits: usize) -> usize {
    // Peel bits off the bottom of `pending` and push them onto the bottom of
    // `reversed`, so the first bit taken ends up in the highest position.
    let (reversed, _) = (0..no_of_bits).fold((0usize, n), |(reversed, pending), _| {
        ((reversed << 1) | (pending & 1), pending >> 1)
    });

    reversed
}
|
|
@ -8,6 +8,7 @@ mod tensor;
|
|||
pub(crate) mod adaptive_avgpool;
|
||||
pub(crate) mod avgpool;
|
||||
pub(crate) mod conv;
|
||||
pub(crate) mod fft;
|
||||
pub(crate) mod interpolate;
|
||||
pub(crate) mod macros;
|
||||
pub(crate) mod matmul;
|
||||
|
|
|
@ -2,6 +2,7 @@ use super::{
|
|||
adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward},
|
||||
avgpool::{avg_pool2d, avg_pool2d_backward},
|
||||
conv::{conv2d, conv_transpose2d},
|
||||
fft::fft1d,
|
||||
interpolate::{bicubic_interpolate, bilinear_interpolate, nearest_interpolate},
|
||||
maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
|
||||
};
|
||||
|
@ -130,4 +131,8 @@ impl<E: FloatNdArrayElement> ModuleOps<Self> for NdArray<E> {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// 1D FFT for the ndarray backend; delegates directly to `fft1d`.
fn fft(x: NdArrayTensor<E, 3>) -> NdArrayTensor<E, 3> {
|
||||
fft1d(x)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -344,4 +344,8 @@ impl<E: TchElement> ModuleOps<Self> for LibTorch<E> {
|
|||
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
/// FFT entry point of the LibTorch backend; not implemented yet.
///
/// # Panics
///
/// Always panics through `todo!()`.
fn fft(_x: TchTensor<E, 3>) -> TchTensor<E, 3> {
|
||||
todo!();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -231,3 +231,19 @@ where
|
|||
{
|
||||
Tensor::new(B::interpolate(x.primitive, output_size, options))
|
||||
}
|
||||
|
||||
/// Applies a [1D fast fourier transform (FFT)](crate::ops::ModuleOps::fft).
///
/// Thin wrapper: the tensor's primitive is handed to the backend's `fft` and the
/// result is wrapped in a new tensor. See the linked operation for the expected shapes.
|
||||
pub fn fft<B>(x: Tensor<B, 3>) -> Tensor<B, 3>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
Tensor::new(B::fft(x.primitive))
|
||||
}
|
||||
|
||||
/// Applies an [inverse 1D fast fourier transform (FFT)](crate::ops::ModuleOps::ifft).
///
/// Thin wrapper: the tensor's primitive is handed to the backend's `ifft` and the
/// result is wrapped in a new tensor. See the linked operation for the expected shapes.
|
||||
pub fn ifft<B>(x: Tensor<B, 3>) -> Tensor<B, 3>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
Tensor::new(B::ifft(x.primitive))
|
||||
}
|
||||
|
|
|
@ -2,7 +2,7 @@ use super::{conv, pool, unfold::unfold4d_using_conv2d};
|
|||
use crate::{
|
||||
backend::Backend,
|
||||
ops::{FloatTensor, IntTensor},
|
||||
Shape,
|
||||
ElementConversion, Shape,
|
||||
};
|
||||
|
||||
/// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).
|
||||
|
@ -488,4 +488,47 @@ pub trait ModuleOps<B: Backend> {
|
|||
output_size: [usize; 2],
|
||||
options: InterpolateOptions,
|
||||
) -> FloatTensor<B, 4>;
|
||||
|
||||
/// One dimensional Fast Fourier Transform (FFT).
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// x: `[batch_size, length, complex]`, where `complex` is exactly 2.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// X: `[batch_size, length, complex]`, where `complex` is exactly 2.
///
/// The `complex` axis holds the `(real, imaginary)` parts of each sample.
///
/// NOTE(review): individual backends may restrict `length` further (for example to
/// powers of two) and may panic otherwise — confirm per backend.
|
||||
fn fft(x: FloatTensor<B, 3>) -> FloatTensor<B, 3>;
|
||||
|
||||
/// One dimensional inverse Fast Fourier Transform (FFT).
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// x: `[batch_size, length, complex]`, where `complex` is exactly 2.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// X: `[batch_size, length, complex]`, where `complex` is exactly 2.
|
||||
fn ifft(x: FloatTensor<B, 3>) -> FloatTensor<B, 3> {
|
||||
let [b, s, _] = B::float_shape(&x).dims;
|
||||
|
||||
// Conjugate input
|
||||
let x = B::float_slice_assign(
|
||||
x.clone(),
|
||||
[0..b, 0..s, 1..2],
|
||||
B::float_neg(B::float_slice(x, [0..b, 0..s, 1..2])),
|
||||
);
|
||||
|
||||
// Use forwards 1D FFT of this backend to calculate inverse.
|
||||
let x = B::fft(x);
|
||||
|
||||
// Conjugate output
|
||||
let x = B::float_slice_assign(
|
||||
x.clone(),
|
||||
[0..b, 0..s, 1..2],
|
||||
B::float_neg(B::float_slice(x, [0..b, 0..s, 1..2])),
|
||||
);
|
||||
|
||||
B::float_div_scalar(x, (s as f32).elem::<B::FloatElem>())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -36,6 +36,7 @@ macro_rules! testgen_all {
|
|||
burn_tensor::testgen_module_nearest_interpolate!();
|
||||
burn_tensor::testgen_module_bilinear_interpolate!();
|
||||
burn_tensor::testgen_module_bicubic_interpolate!();
|
||||
burn_tensor::testgen_module_fft!();
|
||||
|
||||
// test ops
|
||||
burn_tensor::testgen_add!();
|
||||
|
|
|
@ -0,0 +1,335 @@
|
|||
#[burn_tensor_testgen::testgen(module_fft)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::module::{fft, ifft};
|
||||
use burn_tensor::Shape;
|
||||
|
||||
fn single_elem() -> (TestTensor<3>, TestTensor<3>) {
|
||||
(
|
||||
TestTensor::from([[[1., 0.]]]),
|
||||
TestTensor::from([[[1., 0.]]]),
|
||||
)
|
||||
}
|
||||
|
||||
fn delta_1d() -> (TestTensor<3>, TestTensor<3>) {
|
||||
(
|
||||
// Delta function -> Flat frequency spectrum
|
||||
TestTensor::from([[[1., 0.], [0., 0.]], [[0., 1.], [0., 0.]]]),
|
||||
TestTensor::from([[[1., 0.], [1., 0.]], [[0., 1.], [0., 1.]]]),
|
||||
)
|
||||
}
|
||||
|
||||
fn simple_1d() -> (TestTensor<3>, TestTensor<3>) {
|
||||
(
|
||||
TestTensor::from([[[1., 3.], [2.3, -1.]], [[-1., 0.], [-1., 0.]]]),
|
||||
TestTensor::from([[[3.3, 2.0], [-1.3, 4.0]], [[-2.0, 0.0], [0.0, 0.0]]]),
|
||||
)
|
||||
}
|
||||
|
||||
fn even_pow_2_1d() -> (TestTensor<3>, TestTensor<3>) {
|
||||
// FFT size 2^4 = 16
|
||||
(
|
||||
TestTensor::from([[
|
||||
[-0.2322, -0.8782],
|
||||
[0.0283, 0.0621],
|
||||
[0.6704, 0.9188],
|
||||
[-0.6543, 0.9008],
|
||||
[0.4139, 0.7885],
|
||||
[0.5795, 0.8549],
|
||||
[0.1336, 1.4838],
|
||||
[0.0967, 0.0182],
|
||||
[-0.2705, -1.6957],
|
||||
[1.2414, -0.0782],
|
||||
[1.0207, -1.2235],
|
||||
[1.0662, 1.0242],
|
||||
[0.5177, -0.1858],
|
||||
[-0.4157, 0.1690],
|
||||
[-0.5877, 0.9780],
|
||||
[1.0103, -0.9347],
|
||||
]]),
|
||||
TestTensor::from([[
|
||||
[4.6184, 2.2021],
|
||||
[1.7495, 2.2396],
|
||||
[-1.6534, -8.3570],
|
||||
[4.8833, 1.4184],
|
||||
[-0.8089, -4.0430],
|
||||
[-1.7017, 1.2770],
|
||||
[1.3325, 1.4228],
|
||||
[-1.3123, 4.6775],
|
||||
[-1.2865, -1.8301],
|
||||
[2.5052, 1.3927],
|
||||
[-6.7483, -2.2867],
|
||||
[-1.4950, -2.8296],
|
||||
[-0.8072, -4.2135],
|
||||
[1.4975, -1.2243],
|
||||
[1.3319, -3.4857],
|
||||
[-5.8203, -0.4112],
|
||||
]]),
|
||||
)
|
||||
}
|
||||
|
||||
fn odd_pow_2_1d() -> (TestTensor<3>, TestTensor<3>) {
|
||||
// FFT size 2^5 = 32
|
||||
(
|
||||
TestTensor::from([
|
||||
[
|
||||
[0.0427, -0.9848],
|
||||
[0.6244, 0.0441],
|
||||
[0.2514, 1.2113],
|
||||
[0.9988, 1.5251],
|
||||
[-0.0979, 0.3475],
|
||||
[0.6723, -1.9583],
|
||||
[0.4330, -0.4186],
|
||||
[0.2924, -0.1259],
|
||||
[-0.9504, -0.9404],
|
||||
[-0.8480, -1.4261],
|
||||
[1.6096, -1.6685],
|
||||
[-0.4452, -0.2472],
|
||||
[-0.7371, 1.9884],
|
||||
[0.4746, 0.9134],
|
||||
[0.3894, -0.7379],
|
||||
[-0.8402, -0.6750],
|
||||
[-1.2116, -0.0552],
|
||||
[-0.5223, 0.7814],
|
||||
[1.6739, -0.8242],
|
||||
[0.8573, 0.3881],
|
||||
[-0.0345, 1.4219],
|
||||
[0.1038, 1.9030],
|
||||
[-0.3701, 1.0827],
|
||||
[-0.7380, 1.3959],
|
||||
[1.2852, -1.1371],
|
||||
[0.4140, -0.2322],
|
||||
[-0.0631, 1.0053],
|
||||
[-0.3737, 0.2743],
|
||||
[2.1822, 0.7284],
|
||||
[0.6077, 1.2214],
|
||||
[-0.3532, -0.6910],
|
||||
[-0.5390, -0.0049],
|
||||
],
|
||||
[
|
||||
[0.7324, -0.7980],
|
||||
[0.8606, -1.7983],
|
||||
[0.4761, 0.5743],
|
||||
[0.1872, 1.0686],
|
||||
[0.3460, -0.2946],
|
||||
[0.9188, -1.6348],
|
||||
[-0.5201, 0.2241],
|
||||
[-1.8877, 0.9795],
|
||||
[0.4466, -0.5693],
|
||||
[1.8571, 0.6358],
|
||||
[-1.0278, 0.9778],
|
||||
[-0.5889, 0.1738],
|
||||
[-0.2158, -0.3417],
|
||||
[1.3702, 1.5540],
|
||||
[0.8480, -1.5970],
|
||||
[0.9213, -0.8815],
|
||||
[0.3899, -0.2117],
|
||||
[1.1022, -0.5712],
|
||||
[-1.2964, -1.4468],
|
||||
[-0.0170, -0.6138],
|
||||
[-0.4706, 0.5001],
|
||||
[0.0794, -1.2067],
|
||||
[1.1162, -0.6329],
|
||||
[0.1651, 0.1031],
|
||||
[0.3490, -0.2715],
|
||||
[0.6509, 0.0565],
|
||||
[1.0147, 0.2431],
|
||||
[0.8052, -0.9913],
|
||||
[0.9142, 0.9090],
|
||||
[1.0906, 0.5651],
|
||||
[1.5763, 0.5397],
|
||||
[0.3499, 0.4359],
|
||||
],
|
||||
]),
|
||||
TestTensor::from([
|
||||
[
|
||||
[4.7884, 4.1050],
|
||||
[-5.4979, 0.9704],
|
||||
[2.2100, 0.8929],
|
||||
[3.2625, 4.2363],
|
||||
[-6.3352, -14.4654],
|
||||
[7.1075, 0.2420],
|
||||
[2.7235, -4.1024],
|
||||
[-3.3785, -15.7396],
|
||||
[-4.3761, 0.0957],
|
||||
[6.6426, 3.5467],
|
||||
[2.1965, 6.7494],
|
||||
[-5.0247, 1.5957],
|
||||
[-0.1676, -1.6919],
|
||||
[-1.9747, -4.9014],
|
||||
[-1.8878, -5.7946],
|
||||
[-4.8639, -2.0970],
|
||||
[3.3108, -3.4494],
|
||||
[1.8428, 3.2496],
|
||||
[-2.1688, 2.2887],
|
||||
[1.8962, -0.4687],
|
||||
[3.0188, -7.4876],
|
||||
[-1.5869, 5.1769],
|
||||
[0.1114, 3.9235],
|
||||
[7.7439, -6.9099],
|
||||
[-1.8084, 4.7236],
|
||||
[11.4196, -0.7087],
|
||||
[-12.0418, 0.5287],
|
||||
[-3.4558, -2.9853],
|
||||
[-5.1034, -6.7704],
|
||||
[-6.3451, 2.8732],
|
||||
[-3.1719, 3.8135],
|
||||
[12.2821, -2.9533],
|
||||
],
|
||||
[
|
||||
[12.5437, -4.3209],
|
||||
[6.6402, 7.5440],
|
||||
[-1.5249, 1.4718],
|
||||
[-3.2183, -1.4000],
|
||||
[1.8995, -0.0458],
|
||||
[-5.3435, -6.2148],
|
||||
[-1.4322, -2.0123],
|
||||
[2.7689, -3.0387],
|
||||
[-2.3689, -7.9545],
|
||||
[-13.4452, -4.8216],
|
||||
[0.6306, 10.9365],
|
||||
[-0.9051, 0.5564],
|
||||
[-1.9752, -7.8168],
|
||||
[2.5935, 0.0319],
|
||||
[-4.3665, -0.1909],
|
||||
[7.7858, 1.6211],
|
||||
[-3.1862, -0.0701],
|
||||
[1.0940, 5.9505],
|
||||
[0.8552, -0.3137],
|
||||
[-1.4151, -1.5680],
|
||||
[4.4181, 2.5067],
|
||||
[2.9776, -3.9427],
|
||||
[-0.6781, 7.6331],
|
||||
[2.0782, 0.4852],
|
||||
[2.9787, 8.0347],
|
||||
[5.6104, -9.2327],
|
||||
[-0.1015, -9.4777],
|
||||
[-3.1915, 0.0535],
|
||||
[1.0344, -5.1372],
|
||||
[0.2312, 5.2132],
|
||||
[9.2310, -9.3980],
|
||||
[1.2195, -0.6194],
|
||||
],
|
||||
]),
|
||||
)
|
||||
}
|
||||
|
||||
fn reshaped_input_1d() -> (TestTensor<3>, TestTensor<3>) {
|
||||
// "Reshape" might just change strides, and not alter the underlying buffer.
|
||||
// This test ensures op is correctly using stride info.
|
||||
let x1 = TestTensor::<1>::from([
|
||||
0., 1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15.,
|
||||
]);
|
||||
let x3: TestTensor<3> = x1.reshape([2, 4, 2]).swap_dims(0, 2);
|
||||
|
||||
let x_hat = TestTensor::from([
|
||||
[
|
||||
[8., 0., -8., 0.],
|
||||
[12., 0., -8., 0.],
|
||||
[16., 0., -8., 0.],
|
||||
[20., 0., -8., 0.],
|
||||
],
|
||||
[
|
||||
[10., 0., -8., 0.],
|
||||
[14., 0., -8., 0.],
|
||||
[18., 0., -8., 0.],
|
||||
[22., 0., -8., 0.],
|
||||
],
|
||||
]);
|
||||
|
||||
(x3, x_hat)
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fft_single_elem() {
|
||||
let (x, x_hat) = single_elem();
|
||||
assert_output(fft(x), x_hat);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ifft_single_elem() {
|
||||
let (x, x_hat) = single_elem();
|
||||
assert_output(x, ifft(x_hat));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fft_delta_1d() {
|
||||
let (x, x_hat) = delta_1d();
|
||||
assert_output(fft(x), x_hat);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ifft_delta_1d() {
|
||||
let (x, x_hat) = delta_1d();
|
||||
assert_output(x, ifft(x_hat));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fft_simple_1d() {
|
||||
let (x, x_hat) = simple_1d();
|
||||
assert_output(fft(x), x_hat);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ifft_simple_1d() {
|
||||
let (x, x_hat) = simple_1d();
|
||||
assert_output(x, ifft(x_hat));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fft_even_pow_2_1d() {
|
||||
let (x, x_hat) = even_pow_2_1d();
|
||||
assert_output(fft(x), x_hat);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ifft_even_pow_2_1d() {
|
||||
let (x, x_hat) = even_pow_2_1d();
|
||||
assert_output(x, ifft(x_hat));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fft_odd_pow_2_1d() {
|
||||
let (x, x_hat) = odd_pow_2_1d();
|
||||
assert_output(fft(x), x_hat);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ifft_odd_pow_2_1d() {
|
||||
let (x, x_hat) = odd_pow_2_1d();
|
||||
assert_output(x, ifft(x_hat));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_fft_round_trip_precision() {
|
||||
let shape_x = Shape::new([1, 16, 2]);
|
||||
let x = TestTensor::from(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &Default::default())
|
||||
.reshape(shape_x)
|
||||
.into_data()
|
||||
.convert(),
|
||||
);
|
||||
|
||||
// Length 16 = 2^4 means 4 iterations each way, which can still accumulate rounding error.
|
||||
assert_output(x.clone(), ifft(fft(x)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_invalid_input_non_complex_structure() {
|
||||
// Last dim must have dimension 2 (real, imaginary)
|
||||
let x = TestTensor::from([[[0., 0., 0.], [0., 0., 0.]]]);
|
||||
fft(x);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_invalid_input_non_pow_2() {
|
||||
let x = TestTensor::from([[[0., 1.], [2., 3.], [4., 5.]]]);
|
||||
fft(x);
|
||||
}
|
||||
|
||||
fn assert_output(x: TestTensor<3>, y: TestTensor<3>) {
|
||||
x.to_data().assert_approx_eq(&y.into_data(), 3);
|
||||
}
|
||||
}
|
|
@ -8,6 +8,7 @@ mod conv1d;
|
|||
mod conv2d;
|
||||
mod conv_transpose1d;
|
||||
mod conv_transpose2d;
|
||||
mod fft;
|
||||
mod forward;
|
||||
mod maxpool1d;
|
||||
mod maxpool2d;
|
||||
|
|
Loading…
Reference in New Issue