diff --git a/Cargo.lock b/Cargo.lock index 9c9172817..9839a3497 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -305,6 +305,7 @@ dependencies = [ "derive-new", "dirs 5.0.1", "github-device-flow", + "half", "indicatif", "os_info", "percent-encoding", diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml index aab974ad8..14e1512c4 100644 --- a/backend-comparison/Cargo.toml +++ b/backend-comparison/Cargo.toml @@ -11,11 +11,12 @@ version.workspace = true [features] # we depend on wgpu and autotune by default because we use the burn-wgpu crate to get system information -default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"] +candle-accelerate = ["burn/candle", "burn/accelerate"] candle-cpu = ["burn/candle"] candle-cuda = ["burn/candle-cuda"] candle-metal = ["burn/candle", "burn/metal"] -candle-accelerate = ["burn/candle", "burn/accelerate"] +cuda-jit = ["burn/cuda-jit"] +default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"] ndarray = ["burn/ndarray"] ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"] ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"] @@ -24,7 +25,6 @@ tch-cpu = ["burn/tch"] tch-gpu = ["burn/tch"] wgpu = ["burn/wgpu", "burn/autotune"] wgpu-fusion = ["wgpu", "burn/fusion"] -cuda-jit = ["burn/cuda-jit"] [dependencies] arboard = { workspace = true } @@ -33,11 +33,13 @@ burn-common = { path = "../crates/burn-common", version = "0.15.0" } burn-wgpu = { path = "../crates/burn-wgpu", default-features = false, version = "0.15.0", optional = true } clap = { workspace = true } colored = { workspace = true } +cubecl = { workspace = true, features = ["wgpu"] } derive-new = { workspace = true } dirs = { workspace = true } github-device-flow = { workspace = true } -os_info = { workspace = true } +half = { workspace = true } indicatif = { workspace = true } +os_info = { workspace = true } percent-encoding = { workspace = true } rand = { workspace = true } reqwest = { workspace = true, features = 
["blocking", "json"] } @@ -48,69 +50,68 @@ strum_macros = { workspace = true } sysinfo = { workspace = true, features = ["serde"] } wgpu = { workspace = true } wsl = { workspace = true } -cubecl = { workspace = true, features = ["wgpu"] } [dev-dependencies] rstest = { workspace = true } serial_test = { workspace = true } [[bench]] +harness = false name = "unary" -harness = false [[bench]] +harness = false name = "binary" -harness = false [[bench]] +harness = false name = "max-pool2d" path = "benches/max_pool2d.rs" -harness = false [[bench]] +harness = false name = "conv-transpose2d" path = "benches/conv_transpose2d.rs" -harness = false [[bench]] +harness = false name = "conv-transpose3d" path = "benches/conv_transpose3d.rs" -harness = false [[bench]] +harness = false name = "conv2d" -harness = false [[bench]] +harness = false name = "conv3d" -harness = false [[bench]] +harness = false name = "matmul" -harness = false [[bench]] +harness = false name = "data" -harness = false [[bench]] -name = "load-record" harness = false +name = "load-record" path = "benches/load_record.rs" [[bench]] +harness = false name = "custom-gelu" path = "benches/custom_gelu.rs" -harness = false [[bench]] +harness = false name = "resnet50" path = "benches/resnet.rs" -harness = false [[bench]] -name = "autodiff" harness = false +name = "autodiff" [[bin]] name = "burnbench" diff --git a/backend-comparison/build.rs b/backend-comparison/build.rs index 7ecb328e4..f25bcd5a6 100644 --- a/backend-comparison/build.rs +++ b/backend-comparison/build.rs @@ -3,7 +3,6 @@ use std::fs; use std::path::Path; use std::process::Command; -const MODELS_DIR: &str = "/tmp/models"; const MODELS_REPO: &str = "https://github.com/tracel-ai/models.git"; // Patch resnet code (remove pretrained feature code) @@ -224,8 +223,10 @@ where } fn main() { + let models_dir = std::env::temp_dir().join("models"); + let models_dir = models_dir.as_path(); // Checkout ResNet code from models repo - let models_dir = 
Path::new(MODELS_DIR); + let models_dir = Path::new(models_dir); if !models_dir.join(".git").exists() { run("git", |command| { command @@ -233,7 +234,7 @@ fn main() { .arg("--depth=1") .arg("--no-checkout") .arg(MODELS_REPO) - .arg(MODELS_DIR) + .arg(models_dir) }); run("git", |command| { @@ -266,10 +267,12 @@ fn main() { let source_path = models_dir.join("resnet-burn").join("resnet").join("src"); let dest_path = Path::new(&out_dir); - for file in fs::read_dir(source_path).unwrap() { - let source_file = file.unwrap().path(); - let dest_file = dest_path.join(source_file.file_name().unwrap()); - fs::copy(source_file, dest_file).expect("should copy file successfully"); + if let Ok(source_path) = fs::read_dir(source_path) { + for file in source_path { + let source_file = file.unwrap().path(); + let dest_file = dest_path.join(source_file.file_name().unwrap()); + fs::copy(source_file, dest_file).expect("should copy file successfully"); + } } // Delete cloned repository contents diff --git a/backend-comparison/src/lib.rs b/backend-comparison/src/lib.rs index 0f371fcc0..5f9fb23fa 100644 --- a/backend-comparison/src/lib.rs +++ b/backend-comparison/src/lib.rs @@ -131,7 +131,7 @@ macro_rules! 
bench_on_backend { { use burn::backend::cuda_jit::{Cuda, CudaDevice}; - bench::(&CudaDevice::default(), feature_name, url, token); + bench::>(&CudaDevice::default(), feature_name, url, token); } }; } diff --git a/crates/burn-jit/Cargo.toml b/crates/burn-jit/Cargo.toml index 42cb1b442..c7538e1d9 100644 --- a/crates/burn-jit/Cargo.toml +++ b/crates/burn-jit/Cargo.toml @@ -11,12 +11,9 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-jit" version.workspace = true [features] -default = ["autotune", "std", "fusion", "cubecl/default"] -std = ["cubecl/std"] -doc = ["default"] autotune = [] -template = [] -fusion = ["burn-fusion"] +default = ["autotune", "std", "fusion", "cubecl/default"] +doc = ["default"] export_tests = [ "burn-tensor-testgen", "serial_test", @@ -25,32 +22,37 @@ export_tests = [ "burn-ndarray", "fusion", ] +fusion = ["burn-fusion"] +std = ["cubecl/std"] +template = [] [dependencies] -cubecl = { workspace = true, features = ["linalg"] } burn-common = { path = "../burn-common", version = "0.15.0" } -burn-tensor = { path = "../burn-tensor", version = "0.15.0", features = ["cubecl"] } burn-fusion = { path = "../burn-fusion", version = "0.15.0", optional = true } +burn-tensor = { path = "../burn-tensor", version = "0.15.0", features = [ + "cubecl", +] } +cubecl = { workspace = true, features = ["linalg"] } bytemuck = { workspace = true } derive-new = { workspace = true } +half = { workspace = true, features = ["bytemuck"] } log = { workspace = true } num-traits = { workspace = true } rand = { workspace = true } spin = { workspace = true } -half = { workspace = true, features = ["bytemuck"] } # Template serde = { workspace = true } text_placeholder = { workspace = true, features = ["struct_context"] } -hashbrown = { workspace = true } burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.15.0", optional = true } +hashbrown = { workspace = true } # When exporting tests -serial_test = { workspace = true, optional = true } 
burn-autodiff = { path = "../burn-autodiff", version = "0.15.0", default-features = false, optional = true } burn-ndarray = { path = "../burn-ndarray", version = "0.15.0", optional = true } +serial_test = { workspace = true, optional = true } [package.metadata.docs.rs] features = ["doc"] diff --git a/crates/burn-jit/src/fusion/tracing/builder.rs b/crates/burn-jit/src/fusion/tracing/builder.rs index dc416d9ec..9e99f8fe0 100644 --- a/crates/burn-jit/src/fusion/tracing/builder.rs +++ b/crates/burn-jit/src/fusion/tracing/builder.rs @@ -511,14 +511,14 @@ impl TraceBuilder { &mut local_tensor_ids_input, &mut local_tensor_ids_output, ), - Operator::AtomicCompareAndSwap(_op) => { - // Nothing to do. - } Operator::Magnitude(op) => mark_unary( op, &mut local_tensor_ids_input, &mut local_tensor_ids_output, ), + Operator::AtomicCompareAndSwap(_op) => { + // Nothing to do. + } }, Operation::Procedure(proc) => { match proc { diff --git a/crates/burn-jit/src/kernel/conv/conv2d/base.rs b/crates/burn-jit/src/kernel/conv/conv2d/base.rs new file mode 100644 index 000000000..90e8adf5d --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/base.rs @@ -0,0 +1,125 @@ +use burn_tensor::{ + ops::{ConvOptions, ConvTransposeOptions}, + TensorData, +}; + +use crate::{tensor::JitTensor, FloatElement, IntElement, JitElement, JitRuntime}; + +#[cfg(feature = "autotune")] +use super::conv2d_autotune; +use super::{ + conv2d_direct, conv2d_im2col, conv_transpose2d_autotune, conv_transpose2d_col2im, + conv_transpose2d_direct, implicit_gemm::conv2d_implicit_gemm, +}; + +/// The strategy to be used when launching a convolution kernel. +pub enum Conv2dStrategy { + /// A simple direct convolution. + Direct, + #[cfg(feature = "autotune")] + /// Using autotune to choose the best kernel based on runtime information. + Autotune, + /// GEMM (im2col) based implementation of convolution. Significantly increased memory usage. + Gemm, + /// Implicit GEMM implementation of convolution. 
Lower memory usage but requires CMMA and + /// has constraints on tensor shape. + ImplicitGemm, +} + +impl Default for Conv2dStrategy { + fn default() -> Self { + // if autotune is enabled, default to autotune + #[cfg(feature = "autotune")] + return Conv2dStrategy::Autotune; + + // if autotune is disabled, default to the more memory-conservative algorithm + #[cfg(not(feature = "autotune"))] + Conv2dStrategy::Direct + } +} + +/// The strategy to be used when launching a conv_transpose kernel. +pub enum ConvTranspose2dStrategy { + /// A simple direct convolution. + Direct, + #[cfg(feature = "autotune")] + /// Using autotune to choose the best kernel based on runtime information. + Autotune, + /// GEMM (im2col) based implementation of convolution. Significantly increased memory usage. + Gemm, +} + +impl Default for ConvTranspose2dStrategy { + fn default() -> Self { + // if autotune is enabled, default to autotune + #[cfg(feature = "autotune")] + return ConvTranspose2dStrategy::Autotune; + + // if autotune is disabled, default to the more memory-conservative algorithm + #[cfg(not(feature = "autotune"))] + ConvTranspose2dStrategy::Direct + } +} + +/// Perform a 2D convolution with the given strategy +/// +/// * `input` - The input feature map +/// * `weight` - The weights (filter) applied to each kernel +/// * `bias` - The bias added to each channel +/// * `options` - The options to use for the convolution +/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option. 
+/// +pub fn conv2d( + input: JitTensor, + weight: JitTensor, + bias: Option>, + options: ConvOptions<2>, + strategy: Conv2dStrategy, +) -> JitTensor { + match strategy { + Conv2dStrategy::Direct => conv2d_direct::(input, weight, bias, options), + #[cfg(feature = "autotune")] + Conv2dStrategy::Autotune => conv2d_autotune::(input, weight, bias, options), + Conv2dStrategy::Gemm => conv2d_im2col::(input, weight, bias, options), + Conv2dStrategy::ImplicitGemm => { + conv2d_implicit_gemm::(input, weight, bias, options) + } + } +} + +/// Perform a 2D transposed convolution with the given strategy +/// +/// * `input` - The input feature map +/// * `weight` - The weights (filter) applied to each kernel +/// * `bias` - The bias added to each channel +/// * `options` - The options to use for the convolution +/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option. +/// +pub fn conv_transpose2d( + input: JitTensor, + weight: JitTensor, + bias: Option>, + options: ConvTransposeOptions<2>, + strategy: ConvTranspose2dStrategy, +) -> JitTensor { + match strategy { + ConvTranspose2dStrategy::Direct => { + conv_transpose2d_direct::(input, weight, bias, options) + } + #[cfg(feature = "autotune")] + ConvTranspose2dStrategy::Autotune => { + conv_transpose2d_autotune::(input, weight, bias, options) + } + ConvTranspose2dStrategy::Gemm => { + conv_transpose2d_col2im::(input, weight, bias, options) + } + } +} + +#[allow(unused)] +pub(crate) fn debug_data( + tensor: JitTensor, +) -> TensorData { + let bytes = tensor.client.read(tensor.handle.binding()); + TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape) +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs new file mode 100644 index 000000000..7b1181341 --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/col2im.rs @@ -0,0 +1,297 @@ +use burn_tensor::{ + ops::{conv::calculate_conv_transpose_output_size, ConvTransposeOptions, 
FloatTensorOps as _}, + Shape, +}; +use cubecl::{calculate_cube_count_elemwise, prelude::*}; + +use crate::{ + kernel::into_contiguous, + ops::{numeric::empty_device, reshape, swap_dims}, + tensor::JitTensor, + FloatElement, IntElement, JitBackend, JitRuntime, +}; + +use super::batches_per_run; + +/// Perform a 2D convolution transposition using the GEMM (col2im) algorithm. +/// +/// * `input` - The input feature map +/// * `weight` - The weights (filter) applied to each kernel +/// * `bias` - The bias added to each channel +/// * `options` - The options to use for the convolution +/// +pub fn conv_transpose2d_col2im( + input: JitTensor, + weight: JitTensor, + bias: Option>, + options: ConvTransposeOptions<2>, +) -> JitTensor { + let [input_channels, im_ch_per_group, kernel_h, kernel_w] = weight.shape.dims; + let [batch_size, _, input_h, input_w] = input.shape.dims; + let groups = options.groups; + let input_ch_per_group = input_channels / groups; + let ConvTransposeOptions { + padding: [padding_h, padding_w], + padding_out: [padding_out_h, padding_out_w], + dilation: [dilation_h, dilation_w], + stride: [stride_h, stride_w], + .. 
+ } = options.clone(); + + let im_h = calculate_conv_transpose_output_size( + kernel_h, + stride_h, + padding_h, + padding_out_h, + dilation_h, + input_h, + ); + let im_w = calculate_conv_transpose_output_size( + kernel_w, + stride_w, + padding_w, + padding_out_w, + dilation_w, + input_w, + ); + let im_channels = im_ch_per_group * groups; + + let batches_per_run = batches_per_run(batch_size, input_h, input_w); + let col_shape_0 = im_ch_per_group * kernel_h * kernel_w; + + let weight = reshape( + weight.clone(), + Shape::new([groups, input_ch_per_group, col_shape_0]), + ); + let weight = into_contiguous(swap_dims(weight, 1, 2)); + + if batches_per_run != batch_size { + let runs = batch_size / batches_per_run; + + let im_shape = Shape::new([runs, batches_per_run, im_channels, im_h, im_w]); + let mut image = empty_device(input.client.clone(), input.device.clone(), im_shape); + + let input_shape = Shape::new([runs, batches_per_run, input_channels, input_h, input_w]); + let input = reshape(input, input_shape); + let input_shape_run = Shape::new([batches_per_run, input_channels, input_h, input_w]); + let im_shape_run = Shape::new([1, batches_per_run, im_channels, im_h, im_w]); + + for run in 0..runs { + let input = JitBackend::::float_narrow(input.clone(), 0, run, 1); + let input = reshape(input, input_shape_run.clone()); + let image_run = execute::( + input, + weight.clone(), + bias.clone(), + options.clone(), + im_ch_per_group, + im_h, + im_w, + kernel_h, + kernel_w, + ); + let image_run = reshape(image_run, im_shape_run.clone()); + image = JitBackend::::float_slice_assign( + image, + [ + run..run + 1, + 0..batches_per_run, + 0..im_channels, + 0..im_h, + 0..im_w, + ], + image_run, + ) + } + reshape(image, Shape::new([batch_size, im_channels, im_h, im_w])) + } else { + execute::( + input, + weight, + bias, + options, + im_ch_per_group, + im_h, + im_w, + kernel_h, + kernel_w, + ) + } +} + +#[allow(clippy::too_many_arguments)] +fn execute( + input: JitTensor, + weight: 
JitTensor, + bias: Option>, + options: ConvTransposeOptions<2>, + im_ch_per_group: usize, + im_h: usize, + im_w: usize, + kernel_h: usize, + kernel_w: usize, +) -> JitTensor { + let [batch_size, _, input_h, input_w] = input.shape.dims; + let [groups, col_shape_0, input_ch_per_group] = weight.shape.dims; + + let im_channels = im_ch_per_group * groups; + + let im_shape = Shape::new([batch_size, im_channels, im_h, im_w]); + + let col_shape_1 = batch_size * input_h * input_w; + + let input = swap_dims(input, 0, 1); + let input_shape = Shape::new([groups, input_ch_per_group, col_shape_1]); + let input = reshape(input, input_shape); + + let columns = JitBackend::::float_matmul(weight, input); + let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1])); + + col2im( + columns, bias, im_shape, kernel_h, kernel_w, input_h, input_w, options, + ) +} + +#[allow(clippy::too_many_arguments)] +fn col2im( + columns: JitTensor, + bias: Option>, + im_shape: Shape<4>, + kernel_h: usize, + kernel_w: usize, + out_h: usize, + out_w: usize, + options: ConvTransposeOptions<2>, +) -> JitTensor { + let [_, col_size_1] = columns.shape.dims; + + let columns = into_contiguous(columns); + let has_bias = bias.is_some(); + let bias = bias.map(into_contiguous).unwrap_or_else(|| { + empty_device( + columns.client.clone(), + columns.device.clone(), + Shape::new([1]), + ) + }); + + let num_elems = im_shape.num_elements(); + let out = empty_device( + columns.client.clone(), + columns.device.clone(), + im_shape.clone(), + ); + + let vectorization = 1; + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim); + + unsafe { + col2im_kernel::launch_unchecked::( + &columns.client, + cube_count, + cube_dim, + columns.as_tensor_arg(vectorization), + bias.as_tensor_arg(vectorization), + out.as_tensor_arg(vectorization), + Col2ImArgsLaunch::new( + ScalarArg::new(out_h as u32), + ScalarArg::new(out_w as u32), + ScalarArg::new(kernel_h as 
u32), + ScalarArg::new(kernel_w as u32), + ScalarArg::new(options.padding[0] as u32), + ScalarArg::new(options.padding[1] as u32), + ScalarArg::new(options.dilation[0] as u32), + ScalarArg::new(options.dilation[1] as u32), + ScalarArg::new(options.stride[0] as u32), + ScalarArg::new(options.stride[1] as u32), + ScalarArg::new(col_size_1 as u32), + ), + has_bias, + ) + }; + + out +} + +#[derive(CubeLaunch)] +struct Col2ImArgs { + out_h: u32, + out_w: u32, + + kernel_h: u32, + kernel_w: u32, + + pad_h: u32, + pad_w: u32, + dilation_h: u32, + dilation_w: u32, + stride_h: u32, + stride_w: u32, + + col_size_1: u32, +} + +#[cube(launch_unchecked)] +fn col2im_kernel( + columns: &Tensor, + bias: &Tensor, + image: &mut Tensor, + args: &Col2ImArgs, + #[comptime] has_bias: bool, +) { + // Exit at or past the end: index `image.len()` is already out of bounds for the unchecked write below. + if ABSOLUTE_POS >= image.len() { + return; + } + + let _ = bias[0]; // Keep in bind group + + let im_x = ABSOLUTE_POS % image.shape(3) + args.pad_w; + let im_y = ABSOLUTE_POS / image.stride(2) % image.shape(2) + args.pad_h; + let ch_im = ABSOLUTE_POS / image.stride(1) % image.shape(1); + let batch = ABSOLUTE_POS / image.stride(0); + + let kernel_extent_w = (args.kernel_w - 1) * args.dilation_w + 1; + let kernel_extent_h = (args.kernel_h - 1) * args.dilation_h + 1; + + let mut val = F::new(0.0); + + let x_col_start = if im_x >= kernel_extent_w { + (im_x - kernel_extent_w) / args.stride_w + 1 + } else { + 0u32 + }; + let x_col_end = Min::min(im_x / args.stride_w + 1, args.out_w); + let y_col_start = if im_y >= kernel_extent_h { + (im_y - kernel_extent_h) / args.stride_h + 1 + } else { + 0u32 + }; + let y_col_end = Min::min(im_y / args.stride_h + 1, args.out_h); + + for col_y in y_col_start..y_col_end { + let kernel_y = im_y - col_y * args.stride_h; + for col_x in x_col_start..x_col_end { + let kernel_x = im_x - col_x * args.stride_w; + + if kernel_y % args.dilation_h == 0 && kernel_x % args.dilation_w == 0 { + let kernel_y = kernel_y / args.dilation_h; + let kernel_x = kernel_x / args.dilation_w; 
+ + let col_pos = ch_im * args.kernel_h * args.kernel_w * args.col_size_1 + + kernel_y * args.kernel_w * args.col_size_1 + + kernel_x * args.col_size_1 + + batch * args.out_h * args.out_w + + col_y * args.out_w + + col_x; + val += columns[col_pos]; + } + } + } + + if has_bias { + image[ABSOLUTE_POS] = val + bias[ch_im]; + } else { + image[ABSOLUTE_POS] = val; + } +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d.rs b/crates/burn-jit/src/kernel/conv/conv2d/direct.rs similarity index 82% rename from crates/burn-jit/src/kernel/conv/conv2d.rs rename to crates/burn-jit/src/kernel/conv/conv2d/direct.rs index c984008c1..ff8653554 100644 --- a/crates/burn-jit/src/kernel/conv/conv2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/direct.rs @@ -1,9 +1,8 @@ -use cubecl::{calculate_cube_count_elemwise, prelude::*}; - use burn_tensor::{ ops::{conv::calculate_conv_output_size, ConvOptions}, Shape, }; +use cubecl::{calculate_cube_count_elemwise, prelude::*}; use crate::{ kernel::into_contiguous, @@ -12,7 +11,7 @@ use crate::{ reshape, }, tensor::JitTensor, - FloatElement, JitRuntime, + FloatElement, IntElement, JitRuntime, }; #[derive(CubeLaunch)] @@ -23,11 +22,11 @@ struct Conv2dArgs { dilation_1: u32, padding_0: u32, padding_1: u32, - groups: u32, + channels_per_group: u32, } #[cube(launch)] -fn conv2d_kernel( +fn direct_conv2d_kernel( input: &Tensor, weight: &Tensor, bias: &Tensor, @@ -50,7 +49,7 @@ fn conv2d_kernel( let oh = ABSOLUTE_POS / output.stride(2) % output.shape(2); let ow = ABSOLUTE_POS / output.stride(3) % output.shape(3); - let g = (weight.shape(0) + oc) % args.groups; + let g = oc / args.channels_per_group; let ic_start = in_channels * g; let ic_end = ic_start + in_channels; let mut sum = bias[oc]; @@ -114,41 +113,46 @@ fn conv2d_kernel( output[ABSOLUTE_POS] = sum; } -pub(crate) fn conv2d( +/// Perform a 2D convolution using the direct convolution algorithm. 
+/// +/// * `input` - The input feature map +/// * `weight` - The weights (filter) applied to each kernel +/// * `bias` - The bias added to each channel +/// * `options` - The options to use for the convolution +/// +#[allow(clippy::extra_unused_type_parameters)] +pub fn conv2d_direct( input: JitTensor, weight: JitTensor, bias: Option>, options: ConvOptions<2>, ) -> JitTensor { - let input = into_contiguous(input); - let weight = into_contiguous(weight); let [batch_size, _, in_height, in_width] = input.shape.dims; - let [out_channels, _, kernel_0, kernel_1] = weight.shape.dims; + let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims; + let channels_per_group = out_channels / options.groups; // Limit loop unrolling factor to 8 or smaller - let kernel_1_unroll = if kernel_1 > 8 { - None - } else { - Some(kernel_1 as u32) - }; + let kernel_w_unroll = (kernel_w <= 8).then_some(kernel_w as u32); - let out_0 = calculate_conv_output_size( - kernel_0, + let out_h = calculate_conv_output_size( + kernel_h, options.stride[0], options.padding[0], options.dilation[0], in_height, ); - let out_1 = calculate_conv_output_size( - kernel_1, + let out_w = calculate_conv_output_size( + kernel_w, options.stride[1], options.padding[1], options.dilation[1], in_width, ); - let shape_out = Shape::new([batch_size, out_channels, out_0, out_1]); + let input = into_contiguous(input); + let weight = into_contiguous(weight); + let shape_out = Shape::new([batch_size, out_channels, out_h, out_w]); let output = empty_device( input.client.clone(), input.device.clone(), @@ -170,7 +174,7 @@ pub(crate) fn conv2d( let cube_dim = CubeDim::default(); let cube_count = calculate_cube_count_elemwise(num_elems_output, cube_dim); - conv2d_kernel::launch::( + direct_conv2d_kernel::launch::( &input.client, cube_count, cube_dim, @@ -185,9 +189,9 @@ pub(crate) fn conv2d( ScalarArg::new(options.dilation[1] as u32), ScalarArg::new(options.padding[0] as u32), ScalarArg::new(options.padding[1] as u32), - 
ScalarArg::new(options.groups as u32), + ScalarArg::new(channels_per_group as u32), ), - kernel_1_unroll, + kernel_w_unroll, ); output diff --git a/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs new file mode 100644 index 000000000..19b7bcf1d --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/im2col.rs @@ -0,0 +1,297 @@ +use burn_tensor::{ + ops::{conv::calculate_conv_output_size, ConvOptions, FloatTensorOps as _}, + Shape, +}; +use cubecl::{calculate_cube_count_elemwise, prelude::*}; + +use crate::{ + kernel::into_contiguous, + ops::{numeric::empty_device, reshape, swap_dims}, + tensor::JitTensor, + FloatElement, IntElement, JitBackend, JitRuntime, +}; + +#[derive(CubeLaunch)] +struct Im2ColArgs { + stride_h: u32, + stride_w: u32, + dilation_h: u32, + dilation_w: u32, + padding_h: u32, + padding_w: u32, + + kernel_h: u32, + kernel_w: u32, + out_h: u32, + out_w: u32, + + col_size_1: u32, + num_elements: u32, +} + +#[cube(launch_unchecked)] +fn im2col_kernel( + image: &Tensor, + columns: &mut Tensor, + args: &Im2ColArgs, + #[comptime] kernel_w_unroll: Option, + #[comptime] has_padding: bool, +) { + // position shape: [in_channels, batch_size, out_h, out_w] + // columns shape: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]] + + let batch_size = image.shape(0); + let height = image.shape(2); + let width = image.shape(3); + + let out_h = args.out_h; + let out_w = args.out_w; + + // Valid work items are 0..num_elements; exit at or past the end instead of repeating item 0's writes. + if ABSOLUTE_POS >= args.num_elements { + return; + } + + let out_x = ABSOLUTE_POS % out_w; + let out_y = ABSOLUTE_POS / out_w % out_h; + let batch = ABSOLUTE_POS / (out_w * out_h) % batch_size; + let channel = ABSOLUTE_POS / (out_w * out_h * batch_size) % image.shape(1); + + let kernel_w = kernel_w_unroll.unwrap_or(args.kernel_w); + let unroll_w = kernel_w_unroll.is_some(); + + let image_idx = batch * image.stride(0) + channel * image.stride(1); + let col_idx = channel * args.kernel_h * kernel_w * args.col_size_1 + 
batch * out_h * out_w + + out_y * out_w + + out_x; + + for kernel_y in 0..args.kernel_h { + #[unroll(unroll_w)] + for kernel_x in 0..kernel_w { + let kernel_pos = kernel_y * kernel_w + kernel_x; + let col_pos = col_idx + kernel_pos * args.col_size_1; + + if has_padding { + let y = (out_y * args.stride_h + kernel_y * args.dilation_h) as i32 + - args.padding_h as i32; + let x = (out_x * args.stride_w + kernel_x * args.dilation_w) as i32 + - args.padding_w as i32; + if y >= 0 && x >= 0 && y < height as i32 && x < width as i32 { + let image_ptr = image_idx + y as u32 * width + x as u32; + columns[col_pos] = image[image_ptr]; + } else { + columns[col_pos] = F::new(0.0) + }; + } else { + let y = out_y * args.stride_h + kernel_y * args.dilation_h; + let x = out_x * args.stride_w + kernel_x * args.dilation_w; + let image_ptr = image_idx + y * image.stride(2) + x * image.stride(3); + columns[col_pos] = image[image_ptr]; + } + } + } +} + +#[cfg(not(test))] +pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> usize { + let cube_count_per_batch = (out_h * out_w).div_ceil(cubecl::SUBCUBE_DIM_APPROX); + let max_cube_count = u16::MAX as usize; + let max_simultaneous = (max_cube_count / cube_count_per_batch).min(batch_size); + if max_simultaneous == 0 { + panic!("Image too large to run even one batch at once"); + } + (0..=max_simultaneous) + .rev() + .find(|per_run| batch_size % per_run == 0) + .unwrap() +} + +#[cfg(test)] +#[allow(unused)] +pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> usize { + 1 +} + +fn im2col( + input: JitTensor, + options: ConvOptions<2>, + kernel_h: usize, + kernel_w: usize, + out_h: usize, + out_w: usize, +) -> JitTensor { + let input = into_contiguous(input); + let [batch_size, in_channels, _, _] = input.shape.dims; + + let col_shape_0 = in_channels * kernel_h * kernel_w; + let col_shape_1 = batch_size * out_h * out_w; + let shape_col = Shape::new([col_shape_0, col_shape_1]); + let columns = 
empty_device( + input.client.clone(), + input.device.clone(), + shape_col.clone(), + ); + + let num_elems = in_channels * batch_size * out_h * out_w; + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim); + + let kernel_w_unroll = (kernel_w <= 8).then_some(kernel_w as u32); + + let vectorization = 1; + + unsafe { + im2col_kernel::launch_unchecked::( + &input.client, + cube_count, + cube_dim, + input.as_handle_ref().as_tensor_arg(vectorization), + columns.as_handle_ref().as_tensor_arg(vectorization), + Im2ColArgsLaunch::new( + ScalarArg::new(options.stride[0] as u32), + ScalarArg::new(options.stride[1] as u32), + ScalarArg::new(options.dilation[0] as u32), + ScalarArg::new(options.dilation[1] as u32), + ScalarArg::new(options.padding[0] as u32), + ScalarArg::new(options.padding[1] as u32), + ScalarArg::new(kernel_h as u32), + ScalarArg::new(kernel_w as u32), + ScalarArg::new(out_h as u32), + ScalarArg::new(out_w as u32), + ScalarArg::new(col_shape_1 as u32), + ScalarArg::new(num_elems as u32), + ), + kernel_w_unroll, + options.padding != [0, 0], + ) + }; + + columns +} + +/// Perform a 2D convolution using the GEMM (im2col) algorithm. 
+/// +/// * `input` - The input feature map +/// * `weight` - The weights (filter) applied to each kernel +/// * `bias` - The bias added to each channel +/// * `options` - The options to use for the convolution +/// +pub fn conv2d_im2col( + input: JitTensor, + weight: JitTensor, + bias: Option>, + options: ConvOptions<2>, +) -> JitTensor { + let [batch_size, in_channels, in_height, in_width] = input.shape.dims; + let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims; + + let out_h = calculate_conv_output_size( + kernel_h, + options.stride[0], + options.padding[0], + options.dilation[0], + in_height, + ); + let out_w = calculate_conv_output_size( + kernel_w, + options.stride[1], + options.padding[1], + options.dilation[1], + in_width, + ); + + if kernel_h == 1 && kernel_w == 1 && in_height == out_h && in_width == out_w { + // Special case for 1x1 kernels (sometimes used to scale the image by a set of weights) + return execute_1x1_kernel::(input, weight, bias, options); + } + + let batches_per_run = batches_per_run(batch_size, out_h, out_w); + + let mut out = if batches_per_run != batch_size { + let runs = batch_size / batches_per_run; + let out_shape = Shape::new([runs, out_channels, batches_per_run, out_h, out_w]); + let mut out = empty_device(input.client.clone(), input.device.clone(), out_shape); + let in_shape = Shape::new([runs, batches_per_run, in_channels, in_height, in_width]); + let input = reshape(input, in_shape); + let in_shape_run = Shape::new([batches_per_run, in_channels, in_height, in_width]); + let out_shape_run = Shape::new([1, out_channels, batches_per_run, out_h, out_w]); + + for run in 0..runs { + let input = JitBackend::::float_narrow(input.clone(), 0, run, 1); + let input = reshape(input, in_shape_run.clone()); + let run_out = execute::(input, weight.clone(), options.clone(), out_h, out_w); + let run_out = reshape(run_out, out_shape_run.clone()); + out = JitBackend::::float_slice_assign( + out, + [ + run..run + 1, + 0..out_channels, + 
0..batches_per_run, + 0..out_h, + 0..out_w, + ], + run_out, + ); + } + let out = swap_dims(out, 1, 2); + reshape(out, Shape::new([batch_size, out_channels, out_h, out_w])) + } else { + let out = execute::(input, weight, options, out_h, out_w); + let out = reshape(out, Shape::new([out_channels, batch_size, out_h, out_w])); + swap_dims(out, 0, 1) + }; + + if let Some(bias) = bias { + let bias = reshape(bias, Shape::new([1, out_channels, 1, 1])); + out = JitBackend::::float_add(out, bias) + } + out +} + +fn execute_1x1_kernel( + input: JitTensor, + weight: JitTensor, + bias: Option>, + options: ConvOptions<2>, +) -> JitTensor { + let [batch_size, _, height, width] = input.shape.dims; + let [out_channels, in_c_per_grp, _, _] = weight.shape.dims; + let groups = options.groups; + let out_c_per_grp = out_channels / groups; + + let input = swap_dims(input, 0, 1); // [CNHW] + + let weight = reshape(weight, Shape::new([groups, out_c_per_grp, in_c_per_grp])); + let in_shape = Shape::new([groups, in_c_per_grp, batch_size * height * width]); + let input = reshape(input, in_shape); + let out = JitBackend::::float_matmul(weight, input); + let mut out = reshape(out, Shape::new([out_channels, batch_size, height, width])); + + if let Some(bias) = bias { + let bias = reshape(bias, Shape::new([out_channels, 1, 1, 1])); + out = JitBackend::::float_add(out, bias) + } + + swap_dims(out, 0, 1) +} + +fn execute( + input: JitTensor, + weight: JitTensor, + options: ConvOptions<2>, + out_h: usize, + out_w: usize, +) -> JitTensor { + let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims; + let groups = options.groups; + + let columns = im2col(input, options.clone(), kernel_h, kernel_w, out_h, out_w); + let [col_shape_0, col_shape_1] = columns.shape.dims; + let col_shape_0 = col_shape_0 / groups; + let out_c_per_group = out_channels / groups; + + let columns = reshape(columns, Shape::new([groups, col_shape_0, col_shape_1])); + let weight = reshape(weight, Shape::new([groups, 
out_c_per_group, col_shape_0])); + + JitBackend::::float_matmul(weight, columns) +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs new file mode 100644 index 000000000..0dd7d090c --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/implicit_gemm.rs @@ -0,0 +1,562 @@ +use burn_tensor::{ + ops::{conv::calculate_conv_output_size, ConvOptions, FloatTensorOps}, + Shape, +}; +use cmma::{Matrix, MatrixIdent, MatrixLayout}; +use cubecl::{ + cube, + ir::{Elem, FloatKind}, + prelude::*, + Compiler, CubeCount, CubeDim, Feature, +}; +use half::f16; + +use crate::{ + ops::{numeric::empty_device, permute, reshape}, + tensor::JitTensor, + FloatElement, IntElement, JitBackend, JitRuntime, +}; + +/// Perform a 2D convolution using the implicit GEMM algorithm. Requires `cmma` to be available. +/// +/// * `input` - The input feature map +/// * `weight` - The weights (filter) applied to each kernel +/// * `bias` - The bias added to each channel +/// * `options` - The options to use for the convolution +/// +pub fn conv2d_implicit_gemm( + input: JitTensor, + weight: JitTensor, + bias: Option>, + options: ConvOptions<2>, +) -> JitTensor { + let [batch_size, in_channels, height, width] = input.shape.dims; + let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims; + + let out_h = calculate_conv_output_size( + kernel_h, + options.stride[0], + options.padding[0], + options.dilation[0], + height, + ); + let out_w = calculate_conv_output_size( + kernel_w, + options.stride[1], + options.padding[1], + options.dilation[1], + width, + ); + + if !can_do_implicit_gemm(&input, &weight, &options, out_h, out_w) { + panic!( + "Requirements for implicit GEMM not met: +- CMMA must be available +- `batch_size * out_h * out_w` must be divisible by 16 +- `out_channels` must be divisible by 16 +- `in_channels * kernel_h * kernel_w` must be divisible by 16 +- `groups` must be 1 + " + ); + } + + let out_shape = 
Shape::new([batch_size, out_h, out_w, out_channels]); + let mut out = empty_device(input.client.clone(), input.device.clone(), out_shape); + + // Implicit GEMM matrix size + let gemm_m = (batch_size * out_h * out_w) as u32; + let gemm_n = out_channels as u32; + let gemm_k = in_channels * kernel_h * kernel_w; + let slice_size = kernel_h * kernel_w * in_channels; + + let cmma_m = 16; + let cmma_n = 16; + let cmma_k = 16; + + let warp_size = 32; + let warps_per_cube = 8; + + let cube_dim_x = 128; + let cube_dim_y = 2; + + assert!(cube_dim_y * cube_dim_x / warp_size == warps_per_cube); + + let settings = GemmSettings { + cmma_m, + cmma_n, + cmma_k, + warp_size, + warps_per_cube, + cube_dim_x, + }; + + // `CUBE_DIM_X` must be a multiple of `WARP_SIZE` + // 128x2 means we have 8 warps and a cube computes a 32x64 output tile + let cube_dim = CubeDim { + x: cube_dim_x, + y: cube_dim_y, + z: 1, + }; + + let cube_count_x = gemm_m.div_ceil(cmma_m * cube_dim_x / warp_size); + let cube_count_y = gemm_n.div_ceil(cmma_n * cube_dim_y); + + // If div floor == div ceil then the cubes are aligned with the input dimensions + let aligned = gemm_m / (cmma_m * cube_dim_x / warp_size) == cube_count_x + && gemm_n / (cmma_n * cube_dim_y) == cube_count_y; + + let cube_count = CubeCount::Static(cube_count_x, cube_count_y, 1); + + unsafe { + implicit_gemm_kernel::launch_unchecked::( + &input.client, + cube_count, + cube_dim, + input.as_tensor_arg(1), + weight.as_tensor_arg(1), + out.as_tensor_arg(1), + DimensionsLaunch::new( + ScalarArg::new(gemm_m), + ScalarArg::new(gemm_n), + ScalarArg::new(gemm_k as u32), + ScalarArg::new(slice_size as u32), + ScalarArg::new(out_h as u32), + ScalarArg::new(out_w as u32), + ), + ConvArgsLaunch::new( + ScalarArg::new(options.stride[0] as u32), + ScalarArg::new(options.stride[1] as u32), + ScalarArg::new(options.padding[0] as i32), + ScalarArg::new(options.padding[1] as i32), + ScalarArg::new(options.dilation[0] as u32), + ScalarArg::new(options.dilation[1] as 
u32), + ), + settings, + KernelSettings { + kernel_h: kernel_h as u32, + kernel_w: kernel_w as u32, + has_padding: options.padding != [0, 0], + aligned, + }, + ) + }; + + if let Some(bias) = bias { + let bias = reshape(bias, Shape::new([1, 1, 1, out_channels])); + out = JitBackend::::float_add(out, bias); + } + + permute(out, [0, 3, 1, 2]) +} + +#[derive(CubeLaunch)] +struct ConvArgs { + stride_h: u32, + stride_w: u32, + pad_h: i32, + pad_w: i32, + dilation_h: u32, + dilation_w: u32, +} + +#[derive(CubeLaunch)] +struct Dimensions { + gemm_m: u32, + gemm_n: u32, + gemm_k: u32, + slice_size: u32, + + out_h: u32, + out_w: u32, +} + +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +struct GemmSettings { + cmma_m: u32, + cmma_n: u32, + cmma_k: u32, + + warp_size: u32, + warps_per_cube: u32, + + cube_dim_x: u32, +} + +#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)] +struct KernelSettings { + kernel_h: u32, + kernel_w: u32, + has_padding: bool, + aligned: bool, +} + +#[derive(Clone, Copy, CubeType)] +struct Positions { + global_m: u32, + global_n: u32, + + intra_warp_unit_idx: u32, + cube_linear_warp_idx: u32, +} + +#[derive(CubeType)] +struct Matrices { + a: Matrix, + b: Matrix, + acc: Matrix, +} + +#[allow(clippy::collapsible_else_if)] +#[cube(launch_unchecked)] +fn implicit_gemm_kernel( + input: &Tensor, + weight: &Tensor, + out: &mut Tensor, + dims: &Dimensions, + args: &ConvArgs, + #[comptime] gemm_settings: GemmSettings, + #[comptime] kernel_settings: KernelSettings, +) { + let GemmSettings { + cmma_m, + cmma_n, + cmma_k, + warps_per_cube, + .. 
+ } = gemm_settings; + + let cmma_out_tile_size = cmma_m * cmma_n; + let cmma_input_tile_size = cmma_m * cmma_k; + let cmma_filter_tile_size = cmma_k * cmma_n; + + let pos = calculate_positions(gemm_settings); + + // Shared memory tiles, currently only holds enough data for + // each warp to have its own tile for a single MMA op (8 * 16 * 16 elements) + // conceptually a WARPS_PER_CUBE x (CMMA_M * CMMA_K) matrix + let mut smem_input_tile = SharedMemory::::new(cmma_input_tile_size * warps_per_cube); + let mut smem_weight_tile = SharedMemory::::new(cmma_filter_tile_size * warps_per_cube); + + let input_tile_start = pos.cube_linear_warp_idx * cmma_input_tile_size; + let weight_tile_start = pos.cube_linear_warp_idx * cmma_filter_tile_size; + let input_tile = + smem_input_tile.slice_mut(input_tile_start, input_tile_start + cmma_input_tile_size); + let weight_tile = + smem_weight_tile.slice_mut(weight_tile_start, weight_tile_start + cmma_filter_tile_size); + + let out_pos = pos.global_n + pos.global_m * dims.gemm_n; + let out = out.slice_mut(out_pos, out_pos + cmma_out_tile_size); + + if kernel_settings.aligned { + execute_gemm( + input, + weight, + out, + input_tile, + weight_tile, + dims, + &pos, + args, + gemm_settings, + kernel_settings, + ); + } else { + if pos.global_m < dims.gemm_m && pos.global_n < dims.gemm_n { + execute_gemm( + input, + weight, + out, + input_tile, + weight_tile, + dims, + &pos, + args, + gemm_settings, + kernel_settings, + ); + } + } +} + +#[cube] +fn calculate_positions(#[comptime] gemm_settings: GemmSettings) -> Positions { + let GemmSettings { + cmma_m, + cmma_n, + warp_size, + cube_dim_x, + .. 
+ } = gemm_settings; + + // Tile using a 2D grid (over the output), each threadblock + // is (128, 2) -> (4,2) = 8 warps -> 32x64 output + let global_warp_m = ABSOLUTE_POS_X / warp_size; + let global_warp_n = ABSOLUTE_POS_Y; + let cube_warp_m = UNIT_POS_X / warp_size; + let cube_warp_n = UNIT_POS_Y; + let num_warps_m = cube_dim_x / warp_size; + let intra_warp_unit_idx = UNIT_POS_X % warp_size; // Thread idx within warp (0 to 31) + let cube_linear_warp_idx = (cube_warp_n * num_warps_m) + cube_warp_m; // Warp idx within a block (0 to WARPS_PER_BLOCK - 1) + + Positions { + global_m: global_warp_m * cmma_m, + global_n: global_warp_n * cmma_n, + intra_warp_unit_idx, + cube_linear_warp_idx, + } +} + +#[cube] +fn make_matrices( + #[comptime] gemm_settings: GemmSettings, +) -> Matrices { + let GemmSettings { + cmma_m, + cmma_n, + cmma_k, + .. + } = gemm_settings; + + let matrices = Matrices:: { + a: Matrix::::new( + MatrixIdent::A, + cmma_m, + cmma_n, + cmma_k, + MatrixLayout::RowMajor, + ), + b: Matrix::::new( + MatrixIdent::B, + cmma_m, + cmma_n, + cmma_k, + MatrixLayout::RowMajor, + ), + acc: Matrix::::new( + MatrixIdent::Accumulator, + cmma_m, + cmma_n, + cmma_k, + MatrixLayout::Undefined, + ), + }; + + cmma::fill(&matrices.acc, FAcc::new(0.0)); + + matrices +} + +#[cube] +fn execute_gemm( + input: &Tensor, + weight: &Tensor, + out: &mut SliceMut, + input_tile: &mut SliceMut, + weight_tile: &mut SliceMut, + dims: &Dimensions, + pos: &Positions, + args: &ConvArgs, + #[comptime] g_settings: GemmSettings, + #[comptime] k_settings: KernelSettings, +) { + let GemmSettings { cmma_n, cmma_k, .. } = g_settings; + + let matrices = make_matrices::(g_settings); + + // Loop over the K-dimension + for k in range_stepped(0, dims.gemm_k, cmma_k) { + // Load into smem... + // Each warp should load the 16x16 tile it's responsible for + // i.e. 
each thread needs to load 8 elements of input and 8 elements of weight + + load_input_tile( + input, args, input_tile, dims, pos, k, g_settings, k_settings, + ); + + load_weight_tile(weight, weight_tile, pos, k, g_settings, k_settings); + + // Run CMMA + cmma::load(&matrices.a, input_tile.as_slice(), cmma_k); + cmma::load(&matrices.b, weight_tile.as_slice(), cmma_n); + + cmma::execute::(&matrices.a, &matrices.b, &matrices.acc, &matrices.acc); + } + + cmma::store(out, &matrices.acc, dims.gemm_n, MatrixLayout::RowMajor); +} + +#[cube] +fn load_input_tile( + input: &Tensor, + args: &ConvArgs, + tile: &mut SliceMut, + dims: &Dimensions, + pos: &Positions, + k: u32, + #[comptime] gemm_settings: GemmSettings, + #[comptime] kernel_settings: KernelSettings, +) { + let GemmSettings { + cmma_m, + cmma_k, + warp_size, + .. + } = gemm_settings; + + let KernelSettings { + kernel_h, + kernel_w, + has_padding, + .. + } = kernel_settings; + + let kernel_size = kernel_h * kernel_w; + let cmma_input_tile_size = cmma_m * cmma_k; + + let height = input.shape(2) as i32; + let width = input.shape(3) as i32; + + // Row strides in the implicit GEMM matrix + let batch_stride = dims.out_h * dims.out_w; + let y_stride = dims.out_w; + let x_stride = 1; + + // Start index within a slice (0 to `kernel_size * channels - 1`) that a half warp (16 units) is responsible for + let slice_start_idx = k % dims.slice_size; + + for m in range_stepped(pos.intra_warp_unit_idx, cmma_input_tile_size, warp_size) { + // Compute where in the slice we are starting + + // Slices are always `kernel_size * channels` elements wide so we can compute where inside a slice + // we are and also which row the slice is in relative to the start of the CMMA matrix + + let rel_slice_row = m / cmma_k; // Relative row (0 - 15) + let abs_slice_row = pos.global_m + rel_slice_row; // Row of the matrix the slice is on + + // Actual index within a slice (0 to `kernel_size * channels - 1`) that the thread is responsible for + let 
my_slice_idx = (slice_start_idx + (m % cmma_k)) % dims.slice_size; + + // Given the row of the matrix that the slice is in, and the index of the thread + // within a slice, want to compute what input element to load... + // first compute coordinates in output space (center of the kernel in MxK matrix A) + let batch = abs_slice_row / batch_stride; + let out_y = (abs_slice_row % batch_stride) / y_stride; + let out_x = ((abs_slice_row % batch_stride) % y_stride) / x_stride; + + let channel = my_slice_idx / kernel_size; + + let kernel_y = (my_slice_idx / kernel_w) % kernel_h; + let kernel_x = my_slice_idx % kernel_w; + + if has_padding { + let y = (out_y * args.stride_h + kernel_y * args.dilation_h) as i32 - args.pad_h; + let x = (out_x * args.stride_w + kernel_x * args.dilation_w) as i32 - args.pad_w; + + if x >= 0 && x < width && y >= 0 && y < height { + tile[m] = FMat::cast_from( + input[batch * input.stride(0) + + y as u32 * input.stride(2) + + x as u32 * input.stride(3) + + channel * input.stride(1)], + ); + } else { + tile[m] = FMat::new(0.0); + } + } else { + let y = out_y * args.stride_h + kernel_y * args.dilation_h; + let x = out_x * args.stride_w + kernel_x * args.dilation_w; + + tile[m] = FMat::cast_from( + input[batch * input.stride(0) + + y * input.stride(2) + + x * input.stride(3) + + channel * input.stride(1)], + ); + } + } +} + +#[cube] +fn load_weight_tile( + weight: &Tensor, + tile: &mut SliceMut, + pos: &Positions, + k: u32, + #[comptime] gemm_settings: GemmSettings, + #[comptime] kernel_settings: KernelSettings, +) { + let GemmSettings { + cmma_n, + cmma_k, + warp_size, + .. + } = gemm_settings; + + let KernelSettings { + kernel_h, kernel_w, .. 
+ } = kernel_settings; + + let kernel_size = kernel_h * kernel_w; + let cmma_filter_tile_size = cmma_k * cmma_n; + + for n in range_stepped(pos.intra_warp_unit_idx, cmma_filter_tile_size, warp_size) { + // Compute where in the slice we are starting + let rel_slice_row = n / cmma_n; // Relative row (0 - 15) + let abs_slice_row = k + rel_slice_row; // Row of the matrix the slice is on + let abs_slice_col = pos.global_n + (n % cmma_n); // Column of the matrix the slice is on + + // Given the row of the matrix that the slice is in, and the index of the unit + // within a slice, want to compute what weight element to load... + let out_channel = abs_slice_col; + let in_channel = abs_slice_row / kernel_size; + let kernel_y = (abs_slice_row % kernel_size) / kernel_w; + let kernel_x = abs_slice_row % kernel_w; + + tile[n] = FMat::cast_from( + weight[out_channel * weight.stride(0) + + in_channel * weight.stride(1) + + kernel_y * weight.stride(2) + + kernel_x * weight.stride(3)], + ); + } +} + +pub(crate) fn can_do_implicit_gemm( + input: &JitTensor, + weight: &JitTensor, + options: &ConvOptions<2>, + out_h: usize, + out_w: usize, +) -> bool { + let [batch_size, in_channels, _, _] = input.shape.dims; + let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims; + + let cmma_m = 16; + let cmma_n = 16; + let cmma_k = 16; + let warps_per_cube = 8; + + let gemm_m = batch_size * out_h * out_w; + let gemm_n = out_channels; + let gemm_k = in_channels * kernel_h * kernel_w; + + let smem_size = ((cmma_m + cmma_n) * cmma_k * warps_per_cube) as usize * size_of::(); + + cmma_available::(&input.device) + && ::max_shared_memory_size() >= smem_size + && gemm_m % 16 == 0 + && gemm_n % 16 == 0 + && gemm_k % 16 == 0 + && options.groups == 1 +} + +fn cmma_available(device: &R::JitDevice) -> bool { + R::client(device).features().enabled(Feature::Cmma { + a: Elem::Float(FloatKind::F16), + b: Elem::Float(FloatKind::F16), + c: Elem::Float(FloatKind::F32), + m: 16, + k: 16, + n: 16, + }) +} diff --git 
a/crates/burn-jit/src/kernel/conv/conv2d/mod.rs b/crates/burn-jit/src/kernel/conv/conv2d/mod.rs new file mode 100644 index 000000000..40775fe48 --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/mod.rs @@ -0,0 +1,18 @@ +mod base; +mod col2im; +mod direct; +mod im2col; +mod implicit_gemm; +mod transpose_direct; + +#[cfg(feature = "autotune")] +mod tune; + +pub use base::*; +pub use col2im::*; +pub use direct::*; +pub use im2col::*; +pub use implicit_gemm::*; +pub use transpose_direct::*; +#[cfg(feature = "autotune")] +pub use tune::*; diff --git a/crates/burn-jit/src/kernel/conv/conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs similarity index 97% rename from crates/burn-jit/src/kernel/conv/conv_transpose2d.rs rename to crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs index 480aceef7..499a1cf3a 100644 --- a/crates/burn-jit/src/kernel/conv/conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv2d/transpose_direct.rs @@ -14,9 +14,9 @@ use crate::{ reshape, }, tensor::JitTensor, - JitRuntime, + IntElement, JitRuntime, }; -use burn_tensor::{ops::ConvTransposeOptions, Element, Shape}; +use burn_tensor::{ops::ConvTransposeOptions, Shape}; #[derive(new)] struct Conv2dTransposeEagerKernel { @@ -364,7 +364,15 @@ impl Kernel for Conv2dTransposeEagerKernel { } } -pub(crate) fn conv_transpose2d( +/// Perform a 2D convolution transposition using the direct algorithm. 
+/// +/// * `input` - The input feature map +/// * `weight` - The weights (filter) applied to each kernel +/// * `bias` - The bias added to each channel +/// * `options` - The options to use for the convolution +/// +#[allow(clippy::extra_unused_type_parameters)] +pub fn conv_transpose2d_direct( input: JitTensor, weight: JitTensor, bias: Option>, diff --git a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs new file mode 100644 index 000000000..99de01d77 --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv2d.rs @@ -0,0 +1,152 @@ +use burn_tensor::{ + ops::{conv::calculate_conv_output_size, ConvOptions}, + ElementConversion, Shape, +}; +use cubecl::{ + tune, + tune::{local_tuner, tune_with, LocalTuner}, + AutotuneKey, +}; +use serde::{Deserialize, Serialize}; + +use crate::{ + kernel::{ + conv::{can_do_implicit_gemm, conv2d_direct, conv2d_im2col, conv2d_implicit_gemm}, + prng::random_uniform, + }, + tensor::JitTensor, + FloatElement, IntElement, JitAutotuneKey, JitRuntime, JitTuneId, +}; + +/// Executes autotune on conv2d operations +pub fn conv2d_autotune( + input: JitTensor, + weights: JitTensor, + bias: Option>, + options: ConvOptions<2>, +) -> JitTensor { + let client = input.client.clone(); + + static TUNER: LocalTuner = local_tuner!(); + + TUNER.execute( + &JitTuneId::new::(&input.device), + &client, + Box::new(Conv2dOperations::::new( + input, weights, bias, options, + )), + ) +} + +#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] +/// Autotune key representative of conv2d versions +pub struct Conv2dAutotuneKey { + pub kernel_size: [usize; 2], + pub stride: [usize; 2], + pub padding: [usize; 2], + pub dilation: [usize; 2], + pub groups: usize, + #[autotune(anchor)] + pub in_channels: usize, + #[autotune(anchor)] + pub out_channels: usize, + #[autotune(anchor)] + pub height: usize, + #[autotune(anchor)] + pub width: usize, + #[autotune(anchor)] + pub 
batch_size: usize, + pub has_bias: bool, +} + +#[tune( + operations(conv2d_direct, conv2d_im2col, conv2d_implicit_gemm), + create_key = create_key, + should_run = should_run +)] +pub fn conv2d_operations( + key: JitAutotuneKey, + input: JitTensor, + weights: JitTensor, + bias: Option>, + options: ConvOptions<2>, +) -> JitTensor { + let device = &input.device; + let key = match key { + JitAutotuneKey::Conv2d(key) => key, + _ => unreachable!(), + }; + + let random_bounds: (E, E) = ((-1.0).elem::(), (1.0).elem::()); + let input_shape = Shape::new([key.batch_size, key.in_channels, key.height, key.width]); + let input = random_uniform(input_shape, device, random_bounds.0, random_bounds.1); + let c_per_grp = key.in_channels / key.groups; + let [kernel_h, kernel_w] = key.kernel_size; + let weight_shape = Shape::new([key.out_channels, c_per_grp, kernel_h, kernel_w]); + let weights = random_uniform(weight_shape, device, random_bounds.0, random_bounds.1); + let bias_shape = Shape::new([key.out_channels]); + let bias = key + .has_bias + .then(|| random_uniform(bias_shape, device, random_bounds.0, random_bounds.1)); + + tune_with!(input, weights, bias, options) +} + +fn should_run( + op: &Conv2dOperations, + _key: &JitAutotuneKey, + index: usize, +) -> bool { + match index { + 2 => { + let [_, _, height, width] = op.input.shape.dims; + let [_, _, kernel_h, kernel_w] = op.weights.shape.dims; + let o = &op.options; + let out_h = calculate_conv_output_size( + kernel_h, + o.stride[0], + o.padding[0], + o.dilation[0], + height, + ); + let out_w = calculate_conv_output_size( + kernel_w, + o.stride[1], + o.padding[1], + o.dilation[1], + width, + ); + can_do_implicit_gemm(&op.input, &op.weights, &op.options, out_h, out_w) + } + _ => true, + } +} + +fn create_key( + input: &JitTensor, + weights: &JitTensor, + bias: &Option>, + options: &ConvOptions<2>, +) -> JitAutotuneKey { + let [batch_size, in_channels, height, width] = input.shape.dims; + let [out_channels, _, kernel_h, kernel_w] = 
weights.shape.dims; + let ConvOptions { + stride, + padding, + dilation, + groups, + } = options.clone(); + JitAutotuneKey::Conv2d(Conv2dAutotuneKey::new( + [kernel_h, kernel_w], + stride, + padding, + dilation, + groups, + in_channels, + out_channels, + height, + width, + batch_size, + bias.is_some(), + )) +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs new file mode 100644 index 000000000..2ee3396a1 --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/conv2d/tune/conv_transpose2d.rs @@ -0,0 +1,117 @@ +use burn_tensor::{ops::ConvTransposeOptions, ElementConversion, Shape}; +use cubecl::{ + tune, + tune::{local_tuner, tune_with, LocalTuner}, + AutotuneKey, +}; +use serde::{Deserialize, Serialize}; + +use crate::{ + kernel::{ + conv::{conv_transpose2d_col2im, conv_transpose2d_direct}, + prng::random_uniform, + }, + tensor::JitTensor, + FloatElement, IntElement, JitAutotuneKey, JitRuntime, JitTuneId, +}; + +#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)] +/// Autotune key representative of conv_transpose2d versions +pub struct ConvTranspose2dAutotuneKey { + pub kernel_size: [usize; 2], + pub stride: [usize; 2], + pub padding: [usize; 2], + pub padding_out: [usize; 2], + pub dilation: [usize; 2], + pub groups: usize, + #[autotune(anchor)] + pub in_channels: usize, + #[autotune(anchor)] + pub out_channels: usize, + #[autotune(anchor)] + pub height: usize, + #[autotune(anchor)] + pub width: usize, + #[autotune(anchor)] + pub batch_size: usize, + pub has_bias: bool, +} + +/// Executes autotune on conv_transpose2d operations +pub fn conv_transpose2d_autotune( + input: JitTensor, + weights: JitTensor, + bias: Option>, + options: ConvTransposeOptions<2>, +) -> JitTensor { + let client = input.client.clone(); + + static TUNER: LocalTuner = local_tuner!(); + + TUNER.execute( + &JitTuneId::new::(&input.device), + &client, + Box::new(ConvTranspose2dOperations::::new( +
input, weights, bias, options, + )), + ) +} + +#[tune(operations(conv_transpose2d_direct, conv_transpose2d_col2im), create_key = create_key)] +pub fn conv_transpose2d_operations( + key: JitAutotuneKey, + input: JitTensor, + weights: JitTensor, + bias: Option>, + options: ConvTransposeOptions<2>, +) -> JitTensor { + let key = match key { + JitAutotuneKey::ConvTranspose2d(key) => key, + _ => unreachable!(), + }; + let device = &input.device; + + let random_bounds: (E, E) = ((-1.0).elem::(), (1.0).elem::()); + let input_shape = Shape::new([key.batch_size, key.in_channels, key.height, key.width]); + let input = random_uniform(input_shape, device, random_bounds.0, random_bounds.1); + let c_per_grp = key.in_channels / key.groups; + let [kernel_h, kernel_w] = key.kernel_size; + let weight_shape = Shape::new([key.out_channels, c_per_grp, kernel_h, kernel_w]); + let weights = random_uniform(weight_shape, device, random_bounds.0, random_bounds.1); + let bias_shape = Shape::new([key.out_channels]); + let bias = key + .has_bias + .then(|| random_uniform(bias_shape, device, random_bounds.0, random_bounds.1)); + tune_with!(input, weights, bias, options) +} + +fn create_key( + input: &JitTensor, + weights: &JitTensor, + bias: &Option>, + options: &ConvTransposeOptions<2>, +) -> JitAutotuneKey { + let [batch_size, in_channels, height, width] = input.shape.dims; + let [out_channels, _, kernel_h, kernel_w] = weights.shape.dims; + let ConvTransposeOptions { + stride, + padding, + dilation, + groups, + padding_out, + } = options.clone(); + JitAutotuneKey::ConvTranspose2d(ConvTranspose2dAutotuneKey::new( + [kernel_h, kernel_w], + stride, + padding, + padding_out, + dilation, + groups, + in_channels, + out_channels, + height, + width, + batch_size, + bias.is_some(), + )) +} diff --git a/crates/burn-jit/src/kernel/conv/conv2d/tune/mod.rs b/crates/burn-jit/src/kernel/conv/conv2d/tune/mod.rs new file mode 100644 index 000000000..c65e221c6 --- /dev/null +++ 
b/crates/burn-jit/src/kernel/conv/conv2d/tune/mod.rs @@ -0,0 +1,5 @@ +mod conv2d; +mod conv_transpose2d; + +pub use conv2d::*; +pub use conv_transpose2d::*; diff --git a/crates/burn-jit/src/kernel/conv/mod.rs b/crates/burn-jit/src/kernel/conv/mod.rs index 2c60c07bf..5ed7aa570 100644 --- a/crates/burn-jit/src/kernel/conv/mod.rs +++ b/crates/burn-jit/src/kernel/conv/mod.rs @@ -1,13 +1,13 @@ mod conv2d; mod conv3d; -mod conv_transpose2d; mod conv_transpose3d; mod deform_conv2d; mod deform_conv_transpose2d; pub(crate) use conv2d::*; pub(crate) use conv3d::*; -pub(crate) use conv_transpose2d::*; pub(crate) use conv_transpose3d::*; pub(crate) use deform_conv2d::*; pub(crate) use deform_conv_transpose2d::*; + +pub use conv2d::{conv2d, conv_transpose2d, Conv2dStrategy, ConvTranspose2dStrategy}; diff --git a/crates/burn-jit/src/kernel/matmul/base.rs b/crates/burn-jit/src/kernel/matmul/base.rs index 61d2dd2ba..f2c493e36 100644 --- a/crates/burn-jit/src/kernel/matmul/base.rs +++ b/crates/burn-jit/src/kernel/matmul/base.rs @@ -16,7 +16,7 @@ pub enum MatmulStrategy { grid_y: usize, }, #[cfg(feature = "autotune")] - /// Using autotune to chose the best kernel based on runtime information. + /// Using autotune to choose the best kernel based on runtime information. Autotune, /// Cube implementation of matmul. 
Cube, diff --git a/crates/burn-jit/src/ops/module_ops.rs b/crates/burn-jit/src/ops/module_ops.rs index 48fc1db49..eb9d29999 100644 --- a/crates/burn-jit/src/ops/module_ops.rs +++ b/crates/burn-jit/src/ops/module_ops.rs @@ -1,4 +1,10 @@ -use crate::{kernel, FloatElement, IntElement, JitBackend, JitRuntime}; +use crate::{ + kernel::{ + self, + conv::{Conv2dStrategy, ConvTranspose2dStrategy}, + }, + FloatElement, IntElement, JitBackend, JitRuntime, +}; use burn_tensor::ops::{ ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, @@ -17,7 +23,7 @@ where bias: Option>, options: ConvOptions<2>, ) -> FloatTensor { - kernel::conv::conv2d(x, weight, bias, options) + kernel::conv::conv2d::(x, weight, bias, options, Conv2dStrategy::default()) } fn deform_conv2d( @@ -58,7 +64,13 @@ where bias: Option>, options: ConvTransposeOptions<2>, ) -> FloatTensor { - kernel::conv::conv_transpose2d(x, weight, bias, options) + kernel::conv::conv_transpose2d::( + x, + weight, + bias, + options, + ConvTranspose2dStrategy::default(), + ) } fn conv_transpose3d( diff --git a/crates/burn-jit/src/tests/avg_pool2d.rs b/crates/burn-jit/src/tests/avg_pool2d.rs index 4e4a394da..54e7c2492 100644 --- a/crates/burn-jit/src/tests/avg_pool2d.rs +++ b/crates/burn-jit/src/tests/avg_pool2d.rs @@ -6,7 +6,7 @@ mod tests { }; #[test] - fn avg_pool2d_should_work_with_multiple_invocations() { + fn avg_pool2d_should_match_reference_backend() { let tensor = Tensor::::random( [32, 32, 32, 32], Distribution::Default, @@ -29,7 +29,7 @@ mod tests { } #[test] - fn avg_pool2d_backward_should_work_with_multiple_invocations() { + fn avg_pool2d_backward_should_match_reference_backend() { TestBackend::seed(0); ReferenceBackend::seed(0); let tensor = Tensor::::random( diff --git a/crates/burn-jit/src/tests/cat.rs b/crates/burn-jit/src/tests/cat.rs index 8d81e6180..8a5d4be54 100644 --- a/crates/burn-jit/src/tests/cat.rs +++ 
b/crates/burn-jit/src/tests/cat.rs @@ -4,12 +4,12 @@ mod tests { use burn_tensor::{backend::Backend, Distribution, Tensor}; #[test] - fn cat_should_support_multiple_invocations_dim0() { + fn cat_should_match_reference_backend_dim0() { test_same_as_reference([6, 256], 2, 0); } #[test] - fn cat_should_support_multiple_invocations_dim1() { + fn cat_should_match_reference_backend_dim1() { test_same_as_reference([6, 256], 2, 1); } diff --git a/crates/burn-jit/src/tests/conv2d.rs b/crates/burn-jit/src/tests/conv2d.rs index 91a4c5c64..e64d337e9 100644 --- a/crates/burn-jit/src/tests/conv2d.rs +++ b/crates/burn-jit/src/tests/conv2d.rs @@ -4,7 +4,7 @@ mod tests { use burn_tensor::{module, Distribution, Tensor}; #[test] - fn conv2d_should_work_with_multiple_invocations() { + fn conv2d_should_match_reference_backend() { let test_device = Default::default(); let input = Tensor::::random([6, 16, 32, 32], Distribution::Default, &test_device); @@ -26,4 +26,28 @@ mod tests { .into_data() .assert_approx_eq(&output_ref.into_data(), 3); } + + #[test] + fn conv2d_should_match_reference_backend_implicit() { + let test_device = Default::default(); + let input = + Tensor::::random([4, 16, 6, 6], Distribution::Default, &test_device); + let weight = + Tensor::::random([16, 16, 3, 3], Distribution::Default, &test_device); + let bias = Tensor::::random([16], Distribution::Default, &test_device); + let ref_device = Default::default(); + + let input_ref = Tensor::::from_data(input.to_data(), &ref_device); + let weight_ref = Tensor::::from_data(weight.to_data(), &ref_device); + let bias_ref = Tensor::::from_data(bias.to_data(), &ref_device); + + let options = burn_tensor::ops::ConvOptions::new([1, 1], [2, 2], [1, 1], 1); + + let output = module::conv2d(input, weight, Some(bias), options.clone()); + let output_ref = module::conv2d(input_ref, weight_ref, Some(bias_ref), options); + + output + .into_data() + .assert_approx_eq(&output_ref.into_data(), 1); + } } diff --git 
a/crates/burn-jit/src/tests/conv3d.rs b/crates/burn-jit/src/tests/conv3d.rs index 99abeba22..d6d7badaf 100644 --- a/crates/burn-jit/src/tests/conv3d.rs +++ b/crates/burn-jit/src/tests/conv3d.rs @@ -4,7 +4,7 @@ mod tests { use burn_tensor::{module, Distribution, Tensor}; #[test] - fn conv3d_should_work_with_multiple_invocations() { + fn conv3d_should_match_reference_backend() { let test_device = Default::default(); let input = Tensor::::random( [6, 16, 32, 32, 32], diff --git a/crates/burn-jit/src/tests/conv_transpose2d.rs b/crates/burn-jit/src/tests/conv_transpose2d.rs index bf44ea694..00db587a6 100644 --- a/crates/burn-jit/src/tests/conv_transpose2d.rs +++ b/crates/burn-jit/src/tests/conv_transpose2d.rs @@ -4,7 +4,7 @@ mod tests { use burn_tensor::{backend::Backend, module, Distribution, Tensor}; #[test] - fn conv_transpose2d_should_work_with_multiple_invocations() { + fn conv_transpose2d_should_match_reference_backend() { TestBackend::seed(0); let height = 8; diff --git a/crates/burn-jit/src/tests/conv_transpose3d.rs b/crates/burn-jit/src/tests/conv_transpose3d.rs index 1a582d58c..52362be9f 100644 --- a/crates/burn-jit/src/tests/conv_transpose3d.rs +++ b/crates/burn-jit/src/tests/conv_transpose3d.rs @@ -4,7 +4,7 @@ mod tests { use burn_tensor::{backend::Backend, module, Distribution, Tensor}; #[test] - fn conv_transpose3d_should_work_with_multiple_invocations() { + fn conv_transpose3d_should_match_reference_backend() { TestBackend::seed(0); let depth = 8; diff --git a/crates/burn-jit/src/tests/mask_fill.rs b/crates/burn-jit/src/tests/mask_fill.rs index e4a874cf0..b13a4d29c 100644 --- a/crates/burn-jit/src/tests/mask_fill.rs +++ b/crates/burn-jit/src/tests/mask_fill.rs @@ -5,7 +5,7 @@ mod tests { use burn_tensor::{Bool, Distribution, Tensor, TensorPrimitive}; #[test] - fn mask_fill_should_work_with_multiple_invocations() { + fn mask_fill_should_match_reference_backend() { let (tensor, mask, tensor_ref, mask_ref) = inputs_mask_fill(); let actual = 
Tensor::::from_primitive(TensorPrimitive::Float(mask_fill( @@ -22,7 +22,7 @@ mod tests { } #[test] - fn mask_fill_inplace_should_work_with_multiple_invocations() { + fn mask_fill_inplace_should_match_reference_backend() { let (tensor, mask, tensor_ref, mask_ref) = inputs_mask_fill(); let actual = Tensor::::from_primitive(TensorPrimitive::Float(mask_fill( diff --git a/crates/burn-jit/src/tests/mask_where.rs b/crates/burn-jit/src/tests/mask_where.rs index 82c62fc93..e49c3ef0a 100644 --- a/crates/burn-jit/src/tests/mask_where.rs +++ b/crates/burn-jit/src/tests/mask_where.rs @@ -5,7 +5,7 @@ mod tests { use burn_tensor::{backend::Backend, Bool, Distribution, Tensor, TensorPrimitive}; #[test] - fn mask_where_should_work_with_multiple_invocations() { + fn mask_where_should_match_reference_backend() { let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where(); let actual = tensor.mask_where(mask, value); @@ -16,7 +16,7 @@ mod tests { .assert_approx_eq(&actual.into_data(), 3); } #[test] - fn mask_where_inplace_lhs_should_work_with_multiple_invocations() { + fn mask_where_inplace_lhs_should_match_reference_backend() { let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where(); let actual = Tensor::::from_primitive(TensorPrimitive::Float(mask_where( @@ -33,7 +33,7 @@ mod tests { } #[test] - fn mask_where_inplace_rhs_should_work_with_multiple_invocation() { + fn mask_where_inplace_rhs_should_match_reference_backend() { let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where(); let actual = Tensor::::from_primitive(TensorPrimitive::Float(mask_where( diff --git a/crates/burn-jit/src/tests/max_pool2d.rs b/crates/burn-jit/src/tests/max_pool2d.rs index 950bee528..7cba862db 100644 --- a/crates/burn-jit/src/tests/max_pool2d.rs +++ b/crates/burn-jit/src/tests/max_pool2d.rs @@ -4,7 +4,7 @@ mod tests { use burn_tensor::{module, Distribution, Tensor}; #[test] - pub fn max_pool2d_should_work_with_multiple_invocations() { + 
pub fn max_pool2d_should_match_reference_backends() { let tensor = Tensor::::random( [32, 32, 32, 32], Distribution::Default, @@ -26,7 +26,7 @@ mod tests { } #[test] - pub fn max_pool2d_with_indices_should_work_with_multiple_invocations() { + pub fn max_pool2d_with_indices_should_match_reference_backend() { let tensor = Tensor::::random( [32, 32, 32, 32], Distribution::Default, diff --git a/crates/burn-jit/src/tests/max_pool2d_backward.rs b/crates/burn-jit/src/tests/max_pool2d_backward.rs index 1419f40c5..daf60c48e 100644 --- a/crates/burn-jit/src/tests/max_pool2d_backward.rs +++ b/crates/burn-jit/src/tests/max_pool2d_backward.rs @@ -4,7 +4,7 @@ mod tests { use burn_tensor::{module, ops::ModuleOps, Distribution, Tensor, TensorPrimitive}; #[test] - pub fn max_pool2d_with_indices_backward_should_work_with_multiple_invocations() { + pub fn max_pool2d_with_indices_backward_should_match_reference_backend() { let test_device = Default::default(); let tensor = Tensor::::random([32, 32, 32, 32], Distribution::Default, &test_device); diff --git a/crates/burn-jit/src/tests/reduce.rs b/crates/burn-jit/src/tests/reduce.rs index dc6ff8331..f00d34896 100644 --- a/crates/burn-jit/src/tests/reduce.rs +++ b/crates/burn-jit/src/tests/reduce.rs @@ -10,7 +10,7 @@ mod reduction { }; #[test] - fn reduction_sum_dim_should_work_with_multiple_invocations() { + fn reduction_sum_dim_should_match_reference_backend() { let tensor = Tensor::::random([6, 1024], Distribution::Default, &Default::default()); let tensor_ref = @@ -33,7 +33,7 @@ mod reduction { } #[test] - fn reduction_prod_dim_should_work_with_multiple_invocations() { + fn reduction_prod_dim_should_match_reference_backend() { let tensor = Tensor::::random([6, 1024], Distribution::Default, &Default::default()); let tensor_ref = @@ -56,7 +56,7 @@ mod reduction { } #[test] - fn reduction_argmin_dim_should_work_with_multiple_invocations() { + fn reduction_argmin_dim_should_match_reference_backend() { let tensor = Tensor::::random([6, 
1024], Distribution::Default, &Default::default()); let tensor_ref = @@ -75,7 +75,7 @@ mod reduction { } #[test] - fn reduction_argmax_dim_should_work_with_multiple_invocations() { + fn reduction_argmax_dim_should_match_reference_backend() { let tensor = Tensor::::random([6, 1024], Distribution::Default, &Default::default()); let tensor_ref = @@ -290,7 +290,7 @@ mod reduction { } #[test] - fn reduction_sum_should_work_with_multiple_invocations() { + fn reduction_sum_should_match_reference_backend() { let tensor = Tensor::::random([6, 256], Distribution::Default, &Default::default()); let tensor_ref = @@ -306,7 +306,7 @@ mod reduction { } #[test] - fn reduction_prod_should_work_with_multiple_invocations() { + fn reduction_prod_should_match_reference_backend() { let tensor = Tensor::::random([6, 256], Distribution::Default, &Default::default()); let tensor_ref = diff --git a/crates/burn-jit/src/tune_key.rs b/crates/burn-jit/src/tune_key.rs index ebc1e9202..e0ea918e2 100644 --- a/crates/burn-jit/src/tune_key.rs +++ b/crates/burn-jit/src/tune_key.rs @@ -1,4 +1,8 @@ -use crate::kernel::{matmul::MatmulAutotuneKey, reduce::ReduceAutotuneKey}; +use crate::kernel::{ + conv::{Conv2dAutotuneKey, ConvTranspose2dAutotuneKey}, + matmul::MatmulAutotuneKey, + reduce::ReduceAutotuneKey, +}; use cubecl::tune::AutotuneKey; use serde::{Deserialize, Serialize}; use std::fmt::Display; @@ -13,6 +17,10 @@ pub enum JitAutotuneKey { Matmul(MatmulAutotuneKey), /// Key for reduce dim operations ReduceDim(ReduceAutotuneKey), + /// Key for convolution operations + Conv2d(Conv2dAutotuneKey), + /// Key for transpose convolution operations + ConvTranspose2d(ConvTranspose2dAutotuneKey), #[cfg(any(feature = "fusion", test))] /// Key for fused element wise operations. 
FusionElemWise(FusionElemWiseAutotuneKey), @@ -25,6 +33,8 @@ impl Display for JitAutotuneKey { JitAutotuneKey::ReduceDim(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), #[cfg(any(feature = "fusion", test))] JitAutotuneKey::FusionElemWise(reduce_key) => std::fmt::Display::fmt(&reduce_key, f), + JitAutotuneKey::Conv2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f), + JitAutotuneKey::ConvTranspose2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f), } } } diff --git a/crates/burn-ndarray/src/ops/conv.rs b/crates/burn-ndarray/src/ops/conv.rs index 50bd319c9..2788f0c7d 100644 --- a/crates/burn-ndarray/src/ops/conv.rs +++ b/crates/burn-ndarray/src/ops/conv.rs @@ -109,6 +109,7 @@ pub(crate) fn conv2d( let [stride_height, stride_width] = options.stride; let [batch_size, _in_channels, in_height, in_width] = x.shape().dims; let [out_channels, in_channels, kernel_height, kernel_width] = weight.shape().dims; + let channels_per_group = out_channels / options.groups; let out_height = calculate_conv_output_size( kernel_height, @@ -141,7 +142,7 @@ pub(crate) fn conv2d( |(k, mut output)| { let b = k / out_channels; let oc = k % out_channels; - let g = k % options.groups; + let g = oc / channels_per_group; for ic in (in_channels * g)..(in_channels * (g + 1)) { let weight_ic = ic - (g * in_channels); diff --git a/crates/burn-tensor/src/tensor/ops/modules/conv.rs b/crates/burn-tensor/src/tensor/ops/modules/conv.rs index ba801aade..3c33805a9 100644 --- a/crates/burn-tensor/src/tensor/ops/modules/conv.rs +++ b/crates/burn-tensor/src/tensor/ops/modules/conv.rs @@ -43,7 +43,7 @@ pub fn calculate_conv_transpose_output_size( dilation: usize, size_in: usize, ) -> usize { - (size_in - 1) * stride + dilation * (kernel_size - 1) + padding_out - 2 * padding + 1 + (size_in - 1) * stride + (dilation * (kernel_size - 1) + 1) + padding_out - 2 * padding } /// Calculate the expected output size when doing a pooling operation. 
diff --git a/crates/burn-tensor/src/tests/module/conv2d.rs b/crates/burn-tensor/src/tests/module/conv2d.rs index 5c14eff65..aa8fdb826 100644 --- a/crates/burn-tensor/src/tests/module/conv2d.rs +++ b/crates/burn-tensor/src/tests/module/conv2d.rs @@ -69,6 +69,49 @@ mod tests { ]])); } + #[test] + fn test_conv2d_groups_multiple_channels() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 4, + channels_out: 4, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 0, + padding_2: 0, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 2, + height: 5, + width: 5, + }; + + test.assert_output(TestTensor::from([[ + [ + [4035., 4188., 4341.], + [4800., 4953., 5106.], + [5565., 5718., 5871.], + ], + [ + [10030., 10507., 10984.], + [12415., 12892., 13369.], + [14800., 15277., 15754.], + ], + [ + [56075., 56876., 57677.], + [60080., 60881., 61682.], + [64085., 64886., 65687.], + ], + [ + [78270., 79395., 80520.], + [83895., 85020., 86145.], + [89520., 90645., 91770.], + ], + ]])); + } + #[test] fn test_conv2d_complex() { let test = Conv2dTestCase { diff --git a/crates/burn-tensor/src/tests/module/conv_transpose2d.rs b/crates/burn-tensor/src/tests/module/conv_transpose2d.rs index a0da6e61d..677e733c5 100644 --- a/crates/burn-tensor/src/tests/module/conv_transpose2d.rs +++ b/crates/burn-tensor/src/tests/module/conv_transpose2d.rs @@ -28,6 +28,7 @@ mod tests { test.assert_output(TestTensor::from([[[[5.0, 11.0], [23.0, 29.0]]]])); } + #[test] fn test_conv_transpose2d_simple_2() { let test = ConvTranspose2dTestCase { @@ -71,6 +72,34 @@ mod tests { ]])); } + #[test] + fn test_conv_transpose2d_simple_3() { + let test = ConvTranspose2dTestCase { + batch_size: 1, + channels_in: 1, + channels_out: 1, + kernel_size_1: 2, + kernel_size_2: 2, + padding_1: 0, + padding_2: 0, + padding_out_1: 0, + padding_out_2: 0, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 1, + height: 2, + width: 2, + }; + + 
test.assert_output(TestTensor::from([[[ + [0.0, 0.0, 1.0], + [0.0, 4.0, 6.0], + [4.0, 12.0, 9.0], + ]]])); + } + #[test] fn test_conv_transpose2d_stride_2() { let test = ConvTranspose2dTestCase {