From 831335ac2e5c93cf9bb01f0d8421ddc44e163ee7 Mon Sep 17 00:00:00 2001 From: Louis Fortier-Dubois Date: Mon, 13 Nov 2023 07:20:50 -0500 Subject: [PATCH] Perf/wgpu/reduce dim (#943) * new reduce half working * surprisingly working * good on elongated matrix, bad on balanced ones * working and clean * autotune not tested, tests fail at non contiguous * fixed * autotune tested * mean dim * some fixes * clippy --- burn-wgpu/Cargo.toml | 4 + burn-wgpu/benches/reduction.rs | 108 +++++++++++ burn-wgpu/src/compute/tune_key.rs | 8 +- burn-wgpu/src/kernel/mod.rs | 4 +- burn-wgpu/src/kernel/reduce/base.rs | 22 +++ burn-wgpu/src/kernel/reduce/mod.rs | 9 + .../src/kernel/{ => reduce}/reduction.rs | 56 +++--- .../kernel/reduce/reduction_shared_memory.rs | 170 ++++++++++++++++++ burn-wgpu/src/kernel/reduce/tune/base.rs | 27 +++ burn-wgpu/src/kernel/reduce/tune/key.rs | 49 +++++ burn-wgpu/src/kernel/reduce/tune/mean_dim.rs | 112 ++++++++++++ burn-wgpu/src/kernel/reduce/tune/mod.rs | 9 + burn-wgpu/src/kernel/reduce/tune/sum_dim.rs | 112 ++++++++++++ burn-wgpu/src/ops/float_ops.rs | 33 +++- burn-wgpu/src/ops/int_ops.rs | 14 +- .../reduction/reduce_dim_shared_memory.wgsl | 92 ++++++++++ 16 files changed, 790 insertions(+), 39 deletions(-) create mode 100644 burn-wgpu/benches/reduction.rs create mode 100644 burn-wgpu/src/kernel/reduce/base.rs create mode 100644 burn-wgpu/src/kernel/reduce/mod.rs rename burn-wgpu/src/kernel/{ => reduce}/reduction.rs (84%) create mode 100644 burn-wgpu/src/kernel/reduce/reduction_shared_memory.rs create mode 100644 burn-wgpu/src/kernel/reduce/tune/base.rs create mode 100644 burn-wgpu/src/kernel/reduce/tune/key.rs create mode 100644 burn-wgpu/src/kernel/reduce/tune/mean_dim.rs create mode 100644 burn-wgpu/src/kernel/reduce/tune/mod.rs create mode 100644 burn-wgpu/src/kernel/reduce/tune/sum_dim.rs create mode 100644 burn-wgpu/src/template/reduction/reduce_dim_shared_memory.wgsl diff --git a/burn-wgpu/Cargo.toml b/burn-wgpu/Cargo.toml index f45f23392..4917e2423 100644 --- a/burn-wgpu/Cargo.toml +++ b/burn-wgpu/Cargo.toml @@ -57,3 +57,7 @@ serial_test = "2.0.0" [[bench]] name = "matmul" harness = false + +[[bench]] +name = "reduction" +harness = false diff --git a/burn-wgpu/benches/reduction.rs b/burn-wgpu/benches/reduction.rs new file mode 100644 index 000000000..7eac3440a --- /dev/null +++ b/burn-wgpu/benches/reduction.rs @@ -0,0 +1,108 @@ +use burn_common::benchmark::{run_benchmark, Benchmark}; +use burn_tensor::backend::Backend; +use burn_tensor::{Distribution, Shape, Tensor}; +use burn_wgpu::kernel::reduce::{init_reduce_output, sum_dim, sum_dim_shared_memory}; +use burn_wgpu::WgpuDevice; +use burn_wgpu::{AutoGraphicsApi, Wgpu}; +use derive_new::new; +use std::marker::PhantomData; + +use burn_wgpu::GraphicsApi; + +type WTensor = Tensor, D>; + +#[derive(new)] +struct ReduceBenchmark { + shape: Shape, + dim: usize, + num_repeats: usize, + device: B::Device, + reduce: PhantomData, +} + +trait ReduceFunction { + fn run(input: WTensor, dim: usize) -> WTensor; +} + +impl Benchmark for ReduceBenchmark, F, D> +where + F: ReduceFunction, + G: GraphicsApi, +{ + type Args = WTensor; + + fn name(&self) -> String { + format!( + "{:?} {:?} dim={:?}", + std::any::type_name::(), + self.shape.dims, + self.dim + ) + } + + fn num_samples(&self) -> usize { + 10 + } + + fn execute(&self, input: Self::Args) { + for _ in 0..self.num_repeats { + F::run(input.clone(), self.dim); + } + } + + fn prepare(&self) -> Self::Args { + WTensor::random_device(self.shape.clone(), Distribution::Default, 
&self.device) + } + + fn sync(&self) { + Wgpu::::sync(&self.device) + } +} + +macro_rules! bench_reduce { + ($benchmark:ident, $reduce_name:ident, $func:expr) => { + struct $reduce_name {} + impl ReduceFunction for $reduce_name { + fn run(input: WTensor, dim: usize) -> WTensor { + let input = input.into_primitive(); + let output = init_reduce_output(&input, dim); + Tensor::from_primitive($func(input, output, dim)) + } + } + type $benchmark = + ReduceBenchmark, $reduce_name, D>; + }; +} + +bench_reduce!(SumDimBenchmark, SumDim, sum_dim); +bench_reduce!( + SumDimSharedMemoryBenchmark, + SumDimSharedMemory, + sum_dim_shared_memory +); + +#[allow(dead_code)] +/// Runs the benchmarks for wgpu matmul implementations +pub fn bench(device: &WgpuDevice) { + let num_repeats = 3; + let shape = Shape::new([50, 8000, 50]); + let dim = 1; + + macro_rules! run_reduce_benchmark { + ($benchmark:ident) => { + run_benchmark($benchmark::new( + shape.clone(), + dim, + num_repeats, + device.clone(), + )); + }; + } + + run_reduce_benchmark!(SumDimSharedMemoryBenchmark); + run_reduce_benchmark!(SumDimBenchmark); +} + +fn main() { + bench(&WgpuDevice::BestAvailable) +} diff --git a/burn-wgpu/src/compute/tune_key.rs b/burn-wgpu/src/compute/tune_key.rs index 2345dc2d8..2b2ce2501 100644 --- a/burn-wgpu/src/compute/tune_key.rs +++ b/burn-wgpu/src/compute/tune_key.rs @@ -2,19 +2,25 @@ use std::fmt::Display; use burn_compute::tune::AutotuneKey; -use crate::kernel::matmul::MatmulAutotuneKey; +use crate::kernel::{matmul::MatmulAutotuneKey, reduce::ReduceAutotuneKey}; #[derive(Hash, Eq, PartialEq, Debug, Clone)] /// Key for all autotune-enabled operations pub enum WgpuAutotuneKey { /// Key for matmul operation Matmul(MatmulAutotuneKey), + /// Key for sum_dim operations + SumDim(ReduceAutotuneKey), + /// Key for mean_dim operations + MeanDim(ReduceAutotuneKey), } impl Display for WgpuAutotuneKey { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { WgpuAutotuneKey::Matmul(matmul_key) => std::fmt::Display::fmt(&matmul_key, f), + WgpuAutotuneKey::SumDim(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), + WgpuAutotuneKey::MeanDim(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), } } } diff --git a/burn-wgpu/src/kernel/mod.rs b/burn-wgpu/src/kernel/mod.rs index 59d182df3..e30f100ca 100644 --- a/burn-wgpu/src/kernel/mod.rs +++ b/burn-wgpu/src/kernel/mod.rs @@ -6,7 +6,6 @@ mod clamp; mod comparison; mod index; mod mask; -mod reduction; mod source; mod unary; mod unary_scalar; @@ -26,10 +25,11 @@ pub mod matmul; pub mod pool; /// Pseudo-random number generator kernels pub mod prng; +/// Reduction algorithms +pub mod reduce; pub(crate) use cat::*; pub(crate) use clamp::*; pub(crate) use comparison::*; pub(crate) use index::*; pub(crate) use mask::*; -pub(crate) use reduction::*; diff --git a/burn-wgpu/src/kernel/reduce/base.rs b/burn-wgpu/src/kernel/reduce/base.rs new file mode 100644 index 000000000..0f5836960 --- /dev/null +++ b/burn-wgpu/src/kernel/reduce/base.rs @@ -0,0 +1,22 @@ +use crate::{element::WgpuElement, tensor::WgpuTensor}; + +/// Creates an empty output tensor with reduce output shape +pub fn init_reduce_output( + input: &WgpuTensor, + reduce_dim: usize, +) -> WgpuTensor { + let mut shape_out = input.shape.clone(); + shape_out.dims[reduce_dim] = 1; + + // Create output handle + let num_elems_output = shape_out.num_elements(); + let handle = input + .client + .empty(num_elems_output * core::mem::size_of::()); + WgpuTensor::new( + input.client.clone(), + input.device.clone(), + 
shape_out.clone(), + handle, + ) +} diff --git a/burn-wgpu/src/kernel/reduce/mod.rs b/burn-wgpu/src/kernel/reduce/mod.rs new file mode 100644 index 000000000..a6bca6e3a --- /dev/null +++ b/burn-wgpu/src/kernel/reduce/mod.rs @@ -0,0 +1,9 @@ +mod base; +mod reduction; +mod reduction_shared_memory; +mod tune; + +pub use base::*; +pub use reduction::*; +pub use reduction_shared_memory::*; +pub use tune::*; diff --git a/burn-wgpu/src/kernel/reduction.rs b/burn-wgpu/src/kernel/reduce/reduction.rs similarity index 84% rename from burn-wgpu/src/kernel/reduction.rs rename to burn-wgpu/src/kernel/reduce/reduction.rs index e6a81af56..432f67882 100644 --- a/burn-wgpu/src/kernel/reduction.rs +++ b/burn-wgpu/src/kernel/reduce/reduction.rs @@ -1,18 +1,26 @@ -use super::{build_info, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT}; use crate::{ - compute::StaticKernel, element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl, + compute::StaticKernel, + element::WgpuElement, + kernel::{ + build_info, elemwise_workgroup, KernelSettings, SourceTemplate, StaticKernelSource, + WORKGROUP_DEFAULT, + }, + kernel_wgsl, tensor::WgpuTensor, }; use burn_tensor::Shape; -kernel_wgsl!(RecursiveSumRaw, "../template/reduction/recursive_sum.wgsl"); -kernel_wgsl!(ReductionDimRaw, "../template/reduction/reduce_dim.wgsl"); -kernel_wgsl!(ReductionArgsRaw, "../template/reduction/args.wgsl"); +kernel_wgsl!( + RecursiveSumRaw, + "../../template/reduction/recursive_sum.wgsl" +); +kernel_wgsl!(ReductionDimRaw, "../../template/reduction/reduce_dim.wgsl"); +kernel_wgsl!(ReductionArgsRaw, "../../template/reduction/args.wgsl"); -pub struct ArgsMax; -pub struct ArgsMin; -pub struct SumDim; -pub struct MeanDim; +pub(crate) struct ArgsMax; +pub(crate) struct ArgsMin; +pub(crate) struct SumDim; +pub(crate) struct MeanDim; impl StaticKernelSource for SumDim { fn source() -> SourceTemplate { @@ -79,37 +87,29 @@ pub fn sum(input: WgpuTensor) -> WgpuTenso /// Execute the sum dim kernel. pub fn sum_dim( input: WgpuTensor, + output: WgpuTensor, dim: usize, ) -> WgpuTensor { - reduction_dim::(input, dim) + reduction_dim::(input, output, dim) } /// Execute the mean dim kernel. 
pub fn mean_dim( input: WgpuTensor, + output: WgpuTensor, dim: usize, ) -> WgpuTensor { - reduction_dim::(input, dim) + reduction_dim::(input, output, dim) } fn reduction_dim( input: WgpuTensor, + output: WgpuTensor, dim: usize, ) -> WgpuTensor { - let mut shape_out = input.shape.clone(); - shape_out.dims[dim] = 1; - let num_elems = shape_out.num_elements(); - let handle = input.client.empty(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new( - input.client.clone(), - input.device.clone(), - shape_out, - handle, - ); - let kernel = StaticKernel::>::new( - elemwise_workgroup(num_elems, WORKGROUP_DEFAULT), + elemwise_workgroup(output.shape.num_elements(), WORKGROUP_DEFAULT), ); let mut info = build_info(&[&input, &output]); @@ -174,7 +174,10 @@ fn reduction_args_dim::random([6, 1024], Distribution::Default); let tensor_ref = Tensor::::from_data(tensor.to_data()); + let reduce_dim = 1; + let output = init_reduce_output(&tensor.clone().into_primitive(), reduce_dim); let val = Tensor::::from_primitive(reduction_dim::( tensor.into_primitive(), - 1, + output, + reduce_dim, )); let val_ref = tensor_ref.sum_dim(1); diff --git a/burn-wgpu/src/kernel/reduce/reduction_shared_memory.rs b/burn-wgpu/src/kernel/reduce/reduction_shared_memory.rs new file mode 100644 index 000000000..4d4fb43e3 --- /dev/null +++ b/burn-wgpu/src/kernel/reduce/reduction_shared_memory.rs @@ -0,0 +1,170 @@ +use crate::{ + compute::{StaticKernel, WorkGroup}, + element::WgpuElement, + kernel::{build_info, KernelSettings, SourceTemplate, StaticKernelSource, WORKGROUP_DEFAULT}, + kernel_wgsl, + tensor::WgpuTensor, +}; + +kernel_wgsl!( + ReductionDimSharedMemoryRaw, + "../../template/reduction/reduce_dim_shared_memory.wgsl" +); + +pub(crate) struct SumDimSharedMemory; +pub(crate) struct MeanDimSharedMemory; + +impl StaticKernelSource for SumDimSharedMemory { + fn source() -> SourceTemplate { + ReductionDimSharedMemoryRaw::source() + .register( + "shared_size", + (WORKGROUP_DEFAULT * WORKGROUP_DEFAULT).to_string(), + ) + .register("initial", 0.0.to_string()) + .register("update", "shared_memory[local_id] += value; ") + .register("assign", "output[output_position] = final_value; ") + } +} + +impl StaticKernelSource for MeanDimSharedMemory { + fn source() -> SourceTemplate { + ReductionDimSharedMemoryRaw::source() + .register( + "shared_size", + (WORKGROUP_DEFAULT * WORKGROUP_DEFAULT).to_string(), + ) + .register("initial", 0.0.to_string()) + .register("update", "shared_memory[local_id] += value; ") + .add_template( + "fn mean_dim(sum: {{ elem }}, dim: u32) -> {{ elem }} { + return sum / {{ elem }}(dim); + }", + ) + .register( + "assign", + "output[output_position] = mean_dim(final_value, shape_input_dim_reduce);", + ) + } +} + +/// Execute the sum dim kernel leveraging shared memory +/// Probably more efficient on tensors where the dimension to reduced +/// is much larger than the others +pub fn sum_dim_shared_memory( + input: WgpuTensor, + output: WgpuTensor, + dim: usize, +) -> WgpuTensor { + reduction_dim_shared_memory::(input, output, dim) +} + +/// Execute the mean dim kernel leveraging shared memory +/// Probably more efficient on tensors where the dimension to reduced +/// is much larger than the others +pub fn mean_dim_shared_memory( + input: WgpuTensor, + output: WgpuTensor, + dim: usize, +) -> WgpuTensor { + reduction_dim_shared_memory::(input, output, dim) +} + +fn reduction_dim_shared_memory( + input: WgpuTensor, + output: WgpuTensor, + reduce_dim: usize, +) -> WgpuTensor { + let num_elems_output = 
output.shape.num_elements(); + let n_workgroups_x = f32::ceil(f32::sqrt(num_elems_output as f32)); + let n_workgroups_y = f32::ceil(num_elems_output as f32 / n_workgroups_x); + let grid = WorkGroup::new(n_workgroups_x as u32, n_workgroups_y as u32, 1); + + let kernel = + StaticKernel::>::new( + grid, + ); + + // Build info + let mut info = build_info(&[&input, &output]); + + // Reduce groups are elements that are aligned along the reduce dim + let reduce_group_size = input.shape.dims[reduce_dim]; + let n_invocation_per_workgroup = WORKGROUP_DEFAULT * WORKGROUP_DEFAULT; + let n_reduce_elements_per_thread = + f32::ceil(reduce_group_size as f32 / n_invocation_per_workgroup as f32) as u32; + + // Add dimension of reduction and how many reduce elements are treated per thread + info.push(reduce_dim as u32); + info.push(n_reduce_elements_per_thread); + + let info_handle = input.client.create(bytemuck::cast_slice(&info)); + + input.client.execute( + Box::new(kernel), + &[&input.handle, &output.handle, &info_handle], + ); + + output +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + kernel::reduce::init_reduce_output, + tests::{ReferenceBackend, TestBackend}, + }; + use burn_tensor::{Distribution, Tensor}; + + #[test] + fn reduction_sum_dim_shared_memory_small() { + let tensor = Tensor::::random([700], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let reduce_dim = 0; + let output = init_reduce_output(&tensor.clone().into_primitive(), reduce_dim); + + let val = Tensor::::from_primitive(sum_dim_shared_memory( + tensor.into_primitive(), + output, + reduce_dim, + )); + let val_ref = tensor_ref.sum_dim(reduce_dim); + + val_ref.into_data().assert_approx_eq(&val.into_data(), 3); + } + + #[test] + fn reduction_sum_dim_shared_memory_medium() { + let tensor = Tensor::::random([6, 1024], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let reduce_dim = 1; + let output = init_reduce_output(&tensor.clone().into_primitive(), reduce_dim); + + let val = Tensor::::from_primitive(sum_dim_shared_memory( + tensor.into_primitive(), + output, + reduce_dim, + )); + let val_ref = tensor_ref.sum_dim(reduce_dim); + + val_ref.into_data().assert_approx_eq(&val.into_data(), 3); + } + + #[test] + fn reduction_sum_dim_shared_memory_large() { + let tensor = Tensor::::random([4, 1024, 50], Distribution::Default); + let tensor_ref = Tensor::::from_data(tensor.to_data()); + let reduce_dim = 2; + let output = init_reduce_output(&tensor.clone().into_primitive(), reduce_dim); + + let val = Tensor::::from_primitive(sum_dim_shared_memory( + tensor.into_primitive(), + output, + reduce_dim, + )); + let val_ref = tensor_ref.sum_dim(reduce_dim); + + val_ref.into_data().assert_approx_eq(&val.into_data(), 3); + } +} diff --git a/burn-wgpu/src/kernel/reduce/tune/base.rs b/burn-wgpu/src/kernel/reduce/tune/base.rs new file mode 100644 index 000000000..d52bf37dc --- /dev/null +++ b/burn-wgpu/src/kernel/reduce/tune/base.rs @@ -0,0 +1,27 @@ +#[macro_export] +/// Generate an autotune operation for a reduce kernel +macro_rules! 
reduce_tune_ops { + ($name:ident, $func:expr) => { + #[derive(new)] + pub(crate) struct $name { + input: WgpuTensor, + output: WgpuTensor, + reduce_dim: usize, + } + + impl AutotuneOperation for $name { + fn execute(self: Box) { + #[allow(clippy::redundant_closure_call)] + $func(self.input, self.output, self.reduce_dim); + } + + fn clone(&self) -> Box { + Box::new(Self { + input: self.input.clone(), + output: self.output.clone(), + reduce_dim: self.reduce_dim.clone(), + }) + } + } + }; +} diff --git a/burn-wgpu/src/kernel/reduce/tune/key.rs b/burn-wgpu/src/kernel/reduce/tune/key.rs new file mode 100644 index 000000000..db5e4b21b --- /dev/null +++ b/burn-wgpu/src/kernel/reduce/tune/key.rs @@ -0,0 +1,49 @@ +use std::{cmp::min, fmt::Display}; + +use burn_tensor::Shape; + +#[derive(Hash, Eq, PartialEq, Debug, Clone)] +/// Autotune key representative of reduce versions +pub struct ReduceAutotuneKey { + reduce_dim_length: usize, + others_product: usize, +} + +impl Display for ReduceAutotuneKey { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + f.write_str( + format!( + "Reduce - reduce_dim_length: {:?} others_product: {:?}", + self.reduce_dim_length, self.others_product + ) + .as_str(), + ) + } +} + +impl ReduceAutotuneKey { + /// Create a reduce autotune key from the input shape and reduce dim + pub fn new(shape: &Shape, reduce_dim: usize) -> Self { + let reduce_dim_length = shape.dims[reduce_dim]; + let mut others_product = 1; + for d in 0..D { + if d != reduce_dim { + others_product *= shape.dims[d] + } + } + Self { + reduce_dim_length: anchor(reduce_dim_length, None), + others_product: anchor(others_product, None), + } + } +} + +fn anchor(x: usize, max: Option) -> usize { + let exp = f32::ceil(f32::log2(x as f32)) as u32; + let power_of_2 = 2_u32.pow(exp) as usize; + if let Some(max) = max { + min(power_of_2, max) + } else { + power_of_2 + } +} diff --git a/burn-wgpu/src/kernel/reduce/tune/mean_dim.rs b/burn-wgpu/src/kernel/reduce/tune/mean_dim.rs new file mode 100644 index 000000000..a19fc7cf3 --- /dev/null +++ b/burn-wgpu/src/kernel/reduce/tune/mean_dim.rs @@ -0,0 +1,112 @@ +use burn_compute::tune::{AutotuneOperation, AutotuneOperationSet}; +use burn_tensor::{Element, ElementConversion}; + +use crate::{ + compute::WgpuAutotuneKey, + element::WgpuElement, + kernel::{ + prng::random_like_uniform, + reduce::{init_reduce_output, mean_dim, mean_dim_shared_memory}, + }, + ops::numeric::empty_device, + reduce_tune_ops, + tensor::WgpuTensor, +}; + +use super::ReduceAutotuneKey; + +/// Set of mean_dim implementations available for autotune +/// Autotune key is given by concatenating the closest upper power of 2 of +/// dim to reduce, and product of others +pub struct MeanDimAutotuneOperationSet { + key: WgpuAutotuneKey, + input: WgpuTensor, + output: WgpuTensor, + reduce_dim: usize, +} +impl MeanDimAutotuneOperationSet { + fn new(input: WgpuTensor, output: WgpuTensor, reduce_dim: usize) -> Self { + Self { + key: WgpuAutotuneKey::MeanDim(ReduceAutotuneKey::new(&input.shape, reduce_dim)), + input, + output, + reduce_dim, + } + } +} + +impl AutotuneOperationSet + for MeanDimAutotuneOperationSet +{ + fn key(&self) -> WgpuAutotuneKey { + self.key.clone() + } + + fn autotunables(&self) -> Vec> { + let random_bounds: (E, E) = ((-10.0).elem::(), (10.0).elem::()); + let input = random_like_uniform(&self.input, random_bounds.0, random_bounds.1); + + let output = empty_device( + self.output.client.clone(), + self.output.device.clone(), + self.output.shape.clone(), + ); + + vec![ + 
Box::new(MeanDimAutotune::::new( + input.clone(), + output.clone(), + self.reduce_dim, + )), + Box::new(MeanDimSharedMemoryAutotune::::new( + input.clone(), + output.clone(), + self.reduce_dim, + )), + ] + } + + fn fastest(self: Box, fastest_index: usize) -> Box { + // Warning: since AutotuneOperationSet shares his key with SumDimAutotuneOperationSet + // we must make sure the order here is correlated with SumDim + match fastest_index { + 0 => Box::new(MeanDimAutotune::::new( + self.input, + self.output, + self.reduce_dim, + )), + 1 => Box::new(MeanDimSharedMemoryAutotune::::new( + self.input, + self.output, + self.reduce_dim, + )), + _ => panic!("Fastest index is out of bound"), + } + } +} + +/// Executes autotune on mean_dim operation +pub fn mean_dim_autotune( + input: WgpuTensor, + reduce_dim: usize, +) -> WgpuTensor { + let client = input.client.clone(); + + let output = init_reduce_output(&input, reduce_dim); + + let operation_set = Box::new(MeanDimAutotuneOperationSet::::new( + input, + output.clone(), + reduce_dim, + )); + + client.execute_autotune(operation_set); + + output +} + +// Probably better on balanced tensor shapes +reduce_tune_ops!(MeanDimAutotune, mean_dim); + +// Probably better on tensors large along reduce dim +reduce_tune_ops!(MeanDimSharedMemoryAutotune, mean_dim_shared_memory); diff --git a/burn-wgpu/src/kernel/reduce/tune/mod.rs b/burn-wgpu/src/kernel/reduce/tune/mod.rs new file mode 100644 index 000000000..fed0dbb8b --- /dev/null +++ b/burn-wgpu/src/kernel/reduce/tune/mod.rs @@ -0,0 +1,9 @@ +mod base; +mod key; +mod mean_dim; +mod sum_dim; + +pub use base::*; +pub use key::*; +pub use mean_dim::*; +pub use sum_dim::*; diff --git a/burn-wgpu/src/kernel/reduce/tune/sum_dim.rs b/burn-wgpu/src/kernel/reduce/tune/sum_dim.rs new file mode 100644 index 000000000..a5831d701 --- /dev/null +++ b/burn-wgpu/src/kernel/reduce/tune/sum_dim.rs @@ -0,0 +1,112 @@ +use burn_compute::tune::{AutotuneOperation, AutotuneOperationSet}; +use burn_tensor::{Element, ElementConversion}; + +use crate::{ + compute::WgpuAutotuneKey, + element::WgpuElement, + kernel::{ + prng::random_like_uniform, + reduce::{init_reduce_output, sum_dim, sum_dim_shared_memory}, + }, + ops::numeric::empty_device, + reduce_tune_ops, + tensor::WgpuTensor, +}; + +use super::ReduceAutotuneKey; + +/// Set of sum_dim implementations available for autotune +/// Autotune key is given by concatenating the closest upper power of 2 of +/// dim to reduce, and product of others +pub struct SumDimAutotuneOperationSet { + key: WgpuAutotuneKey, + input: WgpuTensor, + output: WgpuTensor, + reduce_dim: usize, +} +impl SumDimAutotuneOperationSet { + fn new(input: WgpuTensor, output: WgpuTensor, reduce_dim: usize) -> Self { + Self { + key: WgpuAutotuneKey::SumDim(ReduceAutotuneKey::new(&input.shape, reduce_dim)), + input, + output, + reduce_dim, + } + } +} + +impl AutotuneOperationSet + for SumDimAutotuneOperationSet +{ + fn key(&self) -> WgpuAutotuneKey { + self.key.clone() + } + + fn autotunables(&self) -> Vec> { + let random_bounds: (E, E) = ((-10.0).elem::(), (10.0).elem::()); + let input = random_like_uniform(&self.input, random_bounds.0, random_bounds.1); + + let output = empty_device( + self.output.client.clone(), + self.output.device.clone(), + self.output.shape.clone(), + ); + + vec![ + Box::new(SumDimAutotune::::new( + input.clone(), + output.clone(), + self.reduce_dim, + )), + Box::new(SumDimSharedMemoryAutotune::::new( + input.clone(), + output.clone(), + self.reduce_dim, + )), + ] + } + + fn fastest(self: Box, 
fastest_index: usize) -> Box { + // Warning: since AutotuneOperationSet shares his key with MeanDimAutotuneOperationSet + // we must make sure the order here is correlated with MeanDim + match fastest_index { + 0 => Box::new(SumDimAutotune::::new( + self.input, + self.output, + self.reduce_dim, + )), + 1 => Box::new(SumDimSharedMemoryAutotune::::new( + self.input, + self.output, + self.reduce_dim, + )), + _ => panic!("Fastest index is out of bound"), + } + } +} + +/// Executes autotune on sum_dim operation +pub fn sum_dim_autotune( + input: WgpuTensor, + reduce_dim: usize, +) -> WgpuTensor { + let client = input.client.clone(); + + let output = init_reduce_output(&input, reduce_dim); + + let operation_set = Box::new(SumDimAutotuneOperationSet::::new( + input, + output.clone(), + reduce_dim, + )); + + client.execute_autotune(operation_set); + + output +} + +// Probably better on balanced tensor shapes +reduce_tune_ops!(SumDimAutotune, sum_dim); + +// Probably better on tensors large along reduce dim +reduce_tune_ops!(SumDimSharedMemoryAutotune, sum_dim_shared_memory); diff --git a/burn-wgpu/src/ops/float_ops.rs b/burn-wgpu/src/ops/float_ops.rs index 0ea2eaeff..4dbd3c722 100644 --- a/burn-wgpu/src/ops/float_ops.rs +++ b/burn-wgpu/src/ops/float_ops.rs @@ -6,8 +6,11 @@ use crate::kernel::matmul::matmul_autotune; #[cfg(not(feature = "autotune"))] use crate::kernel::matmul::vec4::matmul_tiling_2d_vec4; use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform}; +#[cfg(not(feature = "autotune"))] +use crate::kernel::reduce::init_reduce_output; use crate::kernel::{ - self, unary_default, unary_inplace_default, unary_scalar_default, unary_scalar_inplace_default, + self, reduce, unary_default, unary_inplace_default, unary_scalar_default, + unary_scalar_inplace_default, }; use crate::{unary, unary_inplace, unary_scalar, FloatElement, GraphicsApi, IntElement, Wgpu}; use crate::{unary_scalar_inplace, WgpuDevice}; @@ -311,15 +314,33 @@ where } fn sum(tensor: FloatTensor) -> FloatTensor { - kernel::sum(tensor) + reduce::sum(tensor) } fn sum_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - kernel::sum_dim(tensor, dim) + #[cfg(feature = "autotune")] + { + reduce::sum_dim_autotune(tensor, dim) + } + + #[cfg(not(feature = "autotune"))] + { + let output = init_reduce_output(&tensor, dim); + reduce::sum_dim(tensor, output, dim) + } } fn mean_dim(tensor: FloatTensor, dim: usize) -> FloatTensor { - kernel::mean_dim(tensor, dim) + #[cfg(feature = "autotune")] + { + reduce::mean_dim_autotune(tensor, dim) + } + + #[cfg(not(feature = "autotune"))] + { + let output = init_reduce_output(&tensor, dim); + reduce::mean_dim(tensor, output, dim) + } } fn to_full_precision( @@ -457,11 +478,11 @@ where } fn argmax(tensor: FloatTensor, dim: usize) -> IntTensor { - kernel::argmax(tensor, dim) + reduce::argmax(tensor, dim) } fn argmin(tensor: FloatTensor, dim: usize) -> IntTensor { - kernel::argmin(tensor, dim) + reduce::argmin(tensor, dim) } fn into_int(tensor: FloatTensor) -> IntTensor { diff --git a/burn-wgpu/src/ops/int_ops.rs b/burn-wgpu/src/ops/int_ops.rs index 3f28017c2..bbef2dd6a 100644 --- a/burn-wgpu/src/ops/int_ops.rs +++ b/burn-wgpu/src/ops/int_ops.rs @@ -1,4 +1,6 @@ use super::numeric; + +use crate::kernel::reduce::{self, init_reduce_output}; use crate::kernel::{unary_default, unary_inplace_default}; use crate::{ element::{FloatElement, IntElement}, @@ -257,23 +259,25 @@ where } fn int_sum(tensor: IntTensor) -> IntTensor { - kernel::sum(tensor) + kernel::reduce::sum(tensor) } fn 
int_sum_dim(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::sum_dim(tensor, dim) + let output = init_reduce_output(&tensor, dim); + reduce::sum_dim(tensor, output, dim) } fn int_mean_dim(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::mean_dim(tensor, dim) + let output = init_reduce_output(&tensor, dim); + reduce::mean_dim(tensor, output, dim) } fn int_argmax(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::argmax(tensor, dim) + kernel::reduce::argmax(tensor, dim) } fn int_argmin(tensor: IntTensor, dim: usize) -> IntTensor { - kernel::argmin(tensor, dim) + kernel::reduce::argmin(tensor, dim) } fn int_clamp_min( diff --git a/burn-wgpu/src/template/reduction/reduce_dim_shared_memory.wgsl b/burn-wgpu/src/template/reduction/reduce_dim_shared_memory.wgsl new file mode 100644 index 000000000..39e1c648d --- /dev/null +++ b/burn-wgpu/src/template/reduction/reduce_dim_shared_memory.wgsl @@ -0,0 +1,92 @@ +@group(0) +@binding(0) +var input: array<{{ elem }}>; + +@group(0) +@binding(1) +var output: array<{{ elem }}>; + +@group(0) +@binding(2) +var info: array; + +var shared_memory: array<{{ elem }}, {{ shared_size }}>; + +const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u; + +@compute +@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1) +fn main( + @builtin(local_invocation_id) local_invocation_id: vec3, + @builtin(num_workgroups) num_workgroups: vec3, + @builtin(workgroup_id) workgroup_id: vec3, +) { + let workgroup_size_x = {{ workgroup_size_x }}u; + let workgroup_size_y = {{ workgroup_size_y }}u; + + // To determine which reduce_group (not position, but absolute id) + let reduce_group_id = workgroup_id.y * num_workgroups.x + workgroup_id.x; + + // nth thread in the workgroup + let local_id = local_invocation_id.y * workgroup_size_x + local_invocation_id.x; + + // rank of the tensors + let rank: u32 = info[0]; + // dimension on which to reduce (in 0..rank) + let dim_reduce = info[4u * rank + 1u]; + // threads are responsible of how many inputs in one reduce_group + let n_input_values_per_thread = info[4u * rank + 2u]; + + let stride_input_dim_reduce = info[dim_reduce + 1u]; + let shape_input_dim_reduce = info[dim_reduce + 1u + 2u * rank]; + var n_threads = workgroup_size_x * workgroup_size_y; + + var index_offset: u32 = 0u; + + for (var i: u32 = 0u; i < rank; i++) { + let stride_input = info[i + 1u]; + let stride_output = info[i + 1u + rank]; + let shape_output = info[i + 1u + 3u * rank]; + + let num_block = reduce_group_id / stride_output % shape_output; + index_offset += num_block * stride_input; + } + + // Ensure shared memory starts at 0 + shared_memory[local_id] = {{ elem }}(0); + + for (var i = 0u; i < n_input_values_per_thread; i++) { + let nth = local_id + i * n_threads; + if nth < shape_input_dim_reduce { + let current_position = index_offset + nth * stride_input_dim_reduce; + let value = input[current_position]; + + {{ update }} + } + } + + workgroupBarrier(); + + let reduce_factor = 2u; + while n_threads > 1u { + n_threads /= reduce_factor; + + if local_id < n_threads { + for (var i = 1u; i < reduce_factor; i++) { + let read_position = local_id + i * n_threads; + let value = shared_memory[read_position]; + + {{ update }} + } + } + + workgroupBarrier(); + } + + if local_id == 0u { + let output_position = reduce_group_id; + let final_value = shared_memory[0u]; + + {{ assign }} + } +}
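For reference, the non-autotune path in float_ops.rs above composes the new reduce entry points in two steps: allocate the output with init_reduce_output, then hand it to the kernel. Below is a minimal, illustrative sketch of that pattern written as a crate-internal helper (the helper name is hypothetical); the WgpuElement bound on E and the const rank D are assumed to match the signatures used elsewhere in burn-wgpu.

// Sketch only: mirrors the non-autotune branch of float_ops::sum_dim in this patch.
use crate::element::WgpuElement;
use crate::kernel::reduce::{init_reduce_output, sum_dim};
use crate::tensor::WgpuTensor;

fn sum_dim_without_autotune<E: WgpuElement, const D: usize>(
    tensor: WgpuTensor<E, D>,
    dim: usize,
) -> WgpuTensor<E, D> {
    // Allocate the output once, with `dim` collapsed to length 1 ...
    let output = init_reduce_output(&tensor, dim);
    // ... then let the kernel write into it. With the "autotune" feature enabled,
    // `sum_dim_autotune(tensor, dim)` is used instead and picks between the naive
    // and shared-memory kernels at runtime.
    sum_dim(tensor, output, dim)
}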