Introduce autotuning to `conv2d` and `conv_transpose2d` with a new `im2col`/`GEMM` algorithm (#2287)

This commit is contained in:
Genna Wingert 2024-09-23 21:54:50 +02:00 committed by GitHub
parent 2c8514ce7f
commit 97af8c6d28
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
35 changed files with 1807 additions and 96 deletions

Cargo.lock generated
View File

@ -305,6 +305,7 @@ dependencies = [
"derive-new",
"dirs 5.0.1",
"github-device-flow",
"half",
"indicatif",
"os_info",
"percent-encoding",

View File

@ -11,11 +11,12 @@ version.workspace = true
[features]
# we depend on wgpu and autotune by default because we use the burn-wgpu crate to get system information
default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"]
candle-accelerate = ["burn/candle", "burn/accelerate"]
candle-cpu = ["burn/candle"]
candle-cuda = ["burn/candle-cuda"]
candle-metal = ["burn/candle", "burn/metal"]
candle-accelerate = ["burn/candle", "burn/accelerate"]
cuda-jit = ["burn/cuda-jit"]
default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"]
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
@ -24,7 +25,6 @@ tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu", "burn/autotune"]
wgpu-fusion = ["wgpu", "burn/fusion"]
cuda-jit = ["burn/cuda-jit"]
[dependencies]
arboard = { workspace = true }
@ -33,11 +33,13 @@ burn-common = { path = "../crates/burn-common", version = "0.15.0" }
burn-wgpu = { path = "../crates/burn-wgpu", default-features = false, version = "0.15.0", optional = true }
clap = { workspace = true }
colored = { workspace = true }
cubecl = { workspace = true, features = ["wgpu"] }
derive-new = { workspace = true }
dirs = { workspace = true }
github-device-flow = { workspace = true }
os_info = { workspace = true }
half = { workspace = true }
indicatif = { workspace = true }
os_info = { workspace = true }
percent-encoding = { workspace = true }
rand = { workspace = true }
reqwest = { workspace = true, features = ["blocking", "json"] }
@ -48,69 +50,68 @@ strum_macros = { workspace = true }
sysinfo = { workspace = true, features = ["serde"] }
wgpu = { workspace = true }
wsl = { workspace = true }
cubecl = { workspace = true, features = ["wgpu"] }
[dev-dependencies]
rstest = { workspace = true }
serial_test = { workspace = true }
[[bench]]
harness = false
name = "unary"
harness = false
[[bench]]
harness = false
name = "binary"
harness = false
[[bench]]
harness = false
name = "max-pool2d"
path = "benches/max_pool2d.rs"
harness = false
[[bench]]
harness = false
name = "conv-transpose2d"
path = "benches/conv_transpose2d.rs"
harness = false
[[bench]]
harness = false
name = "conv-transpose3d"
path = "benches/conv_transpose3d.rs"
harness = false
[[bench]]
harness = false
name = "conv2d"
harness = false
[[bench]]
harness = false
name = "conv3d"
harness = false
[[bench]]
harness = false
name = "matmul"
harness = false
[[bench]]
harness = false
name = "data"
harness = false
[[bench]]
name = "load-record"
harness = false
name = "load-record"
path = "benches/load_record.rs"
[[bench]]
harness = false
name = "custom-gelu"
path = "benches/custom_gelu.rs"
harness = false
[[bench]]
harness = false
name = "resnet50"
path = "benches/resnet.rs"
harness = false
[[bench]]
name = "autodiff"
harness = false
name = "autodiff"
[[bin]]
name = "burnbench"

View File

@ -3,7 +3,6 @@ use std::fs;
use std::path::Path;
use std::process::Command;
const MODELS_DIR: &str = "/tmp/models";
const MODELS_REPO: &str = "https://github.com/tracel-ai/models.git";
// Patch resnet code (remove pretrained feature code)
@ -224,8 +223,10 @@ where
}
fn main() {
let models_dir = std::env::temp_dir().join("models");
let models_dir = models_dir.as_path();
// Checkout ResNet code from models repo
let models_dir = Path::new(MODELS_DIR);
let models_dir = Path::new(models_dir);
if !models_dir.join(".git").exists() {
run("git", |command| {
command
@ -233,7 +234,7 @@ fn main() {
.arg("--depth=1")
.arg("--no-checkout")
.arg(MODELS_REPO)
.arg(MODELS_DIR)
.arg(models_dir)
});
run("git", |command| {
@ -266,10 +267,12 @@ fn main() {
let source_path = models_dir.join("resnet-burn").join("resnet").join("src");
let dest_path = Path::new(&out_dir);
for file in fs::read_dir(source_path).unwrap() {
let source_file = file.unwrap().path();
let dest_file = dest_path.join(source_file.file_name().unwrap());
fs::copy(source_file, dest_file).expect("should copy file successfully");
if let Ok(source_path) = fs::read_dir(source_path) {
for file in source_path {
let source_file = file.unwrap().path();
let dest_file = dest_path.join(source_file.file_name().unwrap());
fs::copy(source_file, dest_file).expect("should copy file successfully");
}
}
// Delete cloned repository contents

View File

@ -131,7 +131,7 @@ macro_rules! bench_on_backend {
{
use burn::backend::cuda_jit::{Cuda, CudaDevice};
bench::<Cuda>(&CudaDevice::default(), feature_name, url, token);
bench::<Cuda<half::f16>>(&CudaDevice::default(), feature_name, url, token);
}
};
}

View File

@ -11,12 +11,9 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-jit"
version.workspace = true
[features]
default = ["autotune", "std", "fusion", "cubecl/default"]
std = ["cubecl/std"]
doc = ["default"]
autotune = []
template = []
fusion = ["burn-fusion"]
default = ["autotune", "std", "fusion", "cubecl/default"]
doc = ["default"]
export_tests = [
"burn-tensor-testgen",
"serial_test",
@ -25,32 +22,37 @@ export_tests = [
"burn-ndarray",
"fusion",
]
fusion = ["burn-fusion"]
std = ["cubecl/std"]
template = []
[dependencies]
cubecl = { workspace = true, features = ["linalg"] }
burn-common = { path = "../burn-common", version = "0.15.0" }
burn-tensor = { path = "../burn-tensor", version = "0.15.0", features = ["cubecl"] }
burn-fusion = { path = "../burn-fusion", version = "0.15.0", optional = true }
burn-tensor = { path = "../burn-tensor", version = "0.15.0", features = [
"cubecl",
] }
cubecl = { workspace = true, features = ["linalg"] }
bytemuck = { workspace = true }
derive-new = { workspace = true }
half = { workspace = true, features = ["bytemuck"] }
log = { workspace = true }
num-traits = { workspace = true }
rand = { workspace = true }
spin = { workspace = true }
half = { workspace = true, features = ["bytemuck"] }
# Template
serde = { workspace = true }
text_placeholder = { workspace = true, features = ["struct_context"] }
hashbrown = { workspace = true }
burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.15.0", optional = true }
hashbrown = { workspace = true }
# When exporting tests
serial_test = { workspace = true, optional = true }
burn-autodiff = { path = "../burn-autodiff", version = "0.15.0", default-features = false, optional = true }
burn-ndarray = { path = "../burn-ndarray", version = "0.15.0", optional = true }
serial_test = { workspace = true, optional = true }
[package.metadata.docs.rs]
features = ["doc"]

View File

@ -511,14 +511,14 @@ impl TraceBuilder {
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operator::AtomicCompareAndSwap(_op) => {
// Nothing to do.
}
Operator::Magnitude(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operator::AtomicCompareAndSwap(_op) => {
// Nothing to do.
}
},
Operation::Procedure(proc) => {
match proc {

View File

@ -0,0 +1,125 @@
use burn_tensor::{
ops::{ConvOptions, ConvTransposeOptions},
TensorData,
};
use crate::{tensor::JitTensor, FloatElement, IntElement, JitElement, JitRuntime};
#[cfg(feature = "autotune")]
use super::conv2d_autotune;
use super::{
conv2d_direct, conv2d_im2col, conv_transpose2d_autotune, conv_transpose2d_col2im,
conv_transpose2d_direct, implicit_gemm::conv2d_implicit_gemm,
};
/// The strategy to be used when launching a convolution kernel.
pub enum Conv2dStrategy {
/// A simple direct convolution.
Direct,
#[cfg(feature = "autotune")]
/// Using autotune to choose the best kernel based on runtime information.
Autotune,
/// GEMM (im2col) based implementation of convolution. Significantly increases memory usage.
Gemm,
/// Implicit GEMM implementation of convolution. Lower memory usage but requires CMMA and
/// has constraints on tensor shape.
ImplicitGemm,
}
impl Default for Conv2dStrategy {
fn default() -> Self {
// if autotune is enabled, default to autotune
#[cfg(feature = "autotune")]
return Conv2dStrategy::Autotune;
// if autotune is disabled, default to the more memory-conservative algorithm
#[cfg(not(feature = "autotune"))]
Conv2dStrategy::Direct
}
}
/// The strategy to be used when launching a conv_transpose kernel.
pub enum ConvTranspose2dStrategy {
/// A simple direct transposed convolution.
Direct,
#[cfg(feature = "autotune")]
/// Using autotune to choose the best kernel based on runtime information.
Autotune,
/// GEMM (col2im) based implementation of transposed convolution. Significantly increases memory usage.
Gemm,
}
impl Default for ConvTranspose2dStrategy {
fn default() -> Self {
// if autotune is enabled, default to autotune
#[cfg(feature = "autotune")]
return ConvTranspose2dStrategy::Autotune;
// if autotune is disabled, default to the more memory-conservative algorithm
#[cfg(not(feature = "autotune"))]
ConvTranspose2dStrategy::Direct
}
}
/// Perform a 2D convolution with the given strategy
///
/// * `input` - The input feature map
/// * `weight` - The weights (filter) applied to each kernel
/// * `bias` - The bias added to each channel
/// * `options` - The options to use for the convolution
/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
///
pub fn conv2d<R: JitRuntime, E: FloatElement, I: IntElement>(
input: JitTensor<R, E, 4>,
weight: JitTensor<R, E, 4>,
bias: Option<JitTensor<R, E, 1>>,
options: ConvOptions<2>,
strategy: Conv2dStrategy,
) -> JitTensor<R, E, 4> {
match strategy {
Conv2dStrategy::Direct => conv2d_direct::<R, E, I>(input, weight, bias, options),
#[cfg(feature = "autotune")]
Conv2dStrategy::Autotune => conv2d_autotune::<R, E, I>(input, weight, bias, options),
Conv2dStrategy::Gemm => conv2d_im2col::<R, E, I>(input, weight, bias, options),
Conv2dStrategy::ImplicitGemm => {
conv2d_implicit_gemm::<R, E, I>(input, weight, bias, options)
}
}
}
/// Perform a 2D transposed convolution with the given strategy
///
/// * `input` - The input feature map
/// * `weight` - The weights (filter) applied to each kernel
/// * `bias` - The bias added to each channel
/// * `options` - The options to use for the convolution
/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
///
pub fn conv_transpose2d<R: JitRuntime, E: FloatElement, I: IntElement>(
input: JitTensor<R, E, 4>,
weight: JitTensor<R, E, 4>,
bias: Option<JitTensor<R, E, 1>>,
options: ConvTransposeOptions<2>,
strategy: ConvTranspose2dStrategy,
) -> JitTensor<R, E, 4> {
match strategy {
ConvTranspose2dStrategy::Direct => {
conv_transpose2d_direct::<R, E, I>(input, weight, bias, options)
}
#[cfg(feature = "autotune")]
ConvTranspose2dStrategy::Autotune => {
conv_transpose2d_autotune::<R, E, I>(input, weight, bias, options)
}
ConvTranspose2dStrategy::Gemm => {
conv_transpose2d_col2im::<R, E, I>(input, weight, bias, options)
}
}
}
#[allow(unused)]
pub(crate) fn debug_data<R: JitRuntime, E: JitElement, const D: usize>(
tensor: JitTensor<R, E, D>,
) -> TensorData {
let bytes = tensor.client.read(tensor.handle.binding());
TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape)
}

View File

@ -0,0 +1,297 @@
use burn_tensor::{
ops::{conv::calculate_conv_transpose_output_size, ConvTransposeOptions, FloatTensorOps as _},
Shape,
};
use cubecl::{calculate_cube_count_elemwise, prelude::*};
use crate::{
kernel::into_contiguous,
ops::{numeric::empty_device, reshape, swap_dims},
tensor::JitTensor,
FloatElement, IntElement, JitBackend, JitRuntime,
};
use super::batches_per_run;
/// Perform a 2D convolution transposition using the GEMM (col2im) algorithm.
///
/// * `input` - The input feature map
/// * `weight` - The weights (filter) applied to each kernel
/// * `bias` - The bias added to each channel
/// * `options` - The options to use for the convolution
///
pub fn conv_transpose2d_col2im<R: JitRuntime, E: FloatElement, I: IntElement>(
input: JitTensor<R, E, 4>,
weight: JitTensor<R, E, 4>,
bias: Option<JitTensor<R, E, 1>>,
options: ConvTransposeOptions<2>,
) -> JitTensor<R, E, 4> {
let [input_channels, im_ch_per_group, kernel_h, kernel_w] = weight.shape.dims;
let [batch_size, _, input_h, input_w] = input.shape.dims;
let groups = options.groups;
let input_ch_per_group = input_channels / groups;
let ConvTransposeOptions {
padding: [padding_h, padding_w],
padding_out: [padding_out_h, padding_out_w],
dilation: [dilation_h, dilation_w],
stride: [stride_h, stride_w],
..
} = options.clone();
let im_h = calculate_conv_transpose_output_size(
kernel_h,
stride_h,
padding_h,
padding_out_h,
dilation_h,
input_h,
);
let im_w = calculate_conv_transpose_output_size(
kernel_w,
stride_w,
padding_w,
padding_out_w,
dilation_w,
input_w,
);
let im_channels = im_ch_per_group * groups;
let batches_per_run = batches_per_run(batch_size, input_h, input_w);
let col_shape_0 = im_ch_per_group * kernel_h * kernel_w;
let weight = reshape(
weight.clone(),
Shape::new([groups, input_ch_per_group, col_shape_0]),
);
let weight = into_contiguous(swap_dims(weight, 1, 2));
if batches_per_run != batch_size {
let runs = batch_size / batches_per_run;
let im_shape = Shape::new([runs, batches_per_run, im_channels, im_h, im_w]);
let mut image = empty_device(input.client.clone(), input.device.clone(), im_shape);
let input_shape = Shape::new([runs, batches_per_run, input_channels, input_h, input_w]);
let input = reshape(input, input_shape);
let input_shape_run = Shape::new([batches_per_run, input_channels, input_h, input_w]);
let im_shape_run = Shape::new([1, batches_per_run, im_channels, im_h, im_w]);
for run in 0..runs {
let input = JitBackend::<R, E, I>::float_narrow(input.clone(), 0, run, 1);
let input = reshape(input, input_shape_run.clone());
let image_run = execute::<R, E, I>(
input,
weight.clone(),
bias.clone(),
options.clone(),
im_ch_per_group,
im_h,
im_w,
kernel_h,
kernel_w,
);
let image_run = reshape(image_run, im_shape_run.clone());
image = JitBackend::<R, E, I>::float_slice_assign(
image,
[
run..run + 1,
0..batches_per_run,
0..im_channels,
0..im_h,
0..im_w,
],
image_run,
)
}
reshape(image, Shape::new([batch_size, im_channels, im_h, im_w]))
} else {
execute::<R, E, I>(
input,
weight,
bias,
options,
im_ch_per_group,
im_h,
im_w,
kernel_h,
kernel_w,
)
}
}
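The `im_h`/`im_w` values above are the spatial size of the reconstructed image. As a point of reference, here is a minimal standalone sketch of the arithmetic `calculate_conv_transpose_output_size` is assumed to perform (the helper itself lives outside this diff), following the usual transposed-convolution output-size formula:

```rust
// Assumed output-size formula for a transposed convolution; not part of the diff.
fn conv_transpose_output_size(
    kernel: usize,
    stride: usize,
    padding: usize,
    padding_out: usize,
    dilation: usize,
    input: usize,
) -> usize {
    (input - 1) * stride - 2 * padding + dilation * (kernel - 1) + padding_out + 1
}

fn main() {
    // A 16x16 input upsampled with a 4x4 kernel, stride 2, padding 1 -> 32x32 image.
    assert_eq!(conv_transpose_output_size(4, 2, 1, 0, 1, 16), 32);
}
```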
#[allow(clippy::too_many_arguments)]
fn execute<R: JitRuntime, E: FloatElement, I: IntElement>(
input: JitTensor<R, E, 4>,
weight: JitTensor<R, E, 3>,
bias: Option<JitTensor<R, E, 1>>,
options: ConvTransposeOptions<2>,
im_ch_per_group: usize,
im_h: usize,
im_w: usize,
kernel_h: usize,
kernel_w: usize,
) -> JitTensor<R, E, 4> {
let [batch_size, _, input_h, input_w] = input.shape.dims;
let [groups, col_shape_0, input_ch_per_group] = weight.shape.dims;
let im_channels = im_ch_per_group * groups;
let im_shape = Shape::new([batch_size, im_channels, im_h, im_w]);
let col_shape_1 = batch_size * input_h * input_w;
let input = swap_dims(input, 0, 1);
let input_shape = Shape::new([groups, input_ch_per_group, col_shape_1]);
let input = reshape(input, input_shape);
let columns = JitBackend::<R, E, I>::float_matmul(weight, input);
let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1]));
col2im(
columns, bias, im_shape, kernel_h, kernel_w, input_h, input_w, options,
)
}
#[allow(clippy::too_many_arguments)]
fn col2im<R: JitRuntime, E: FloatElement>(
columns: JitTensor<R, E, 2>,
bias: Option<JitTensor<R, E, 1>>,
im_shape: Shape<4>,
kernel_h: usize,
kernel_w: usize,
out_h: usize,
out_w: usize,
options: ConvTransposeOptions<2>,
) -> JitTensor<R, E, 4> {
let [_, col_size_1] = columns.shape.dims;
let columns = into_contiguous(columns);
let has_bias = bias.is_some();
let bias = bias.map(into_contiguous).unwrap_or_else(|| {
empty_device(
columns.client.clone(),
columns.device.clone(),
Shape::new([1]),
)
});
let num_elems = im_shape.num_elements();
let out = empty_device(
columns.client.clone(),
columns.device.clone(),
im_shape.clone(),
);
let vectorization = 1;
let cube_dim = CubeDim::default();
let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);
unsafe {
col2im_kernel::launch_unchecked::<E, R>(
&columns.client,
cube_count,
cube_dim,
columns.as_tensor_arg(vectorization),
bias.as_tensor_arg(vectorization),
out.as_tensor_arg(vectorization),
Col2ImArgsLaunch::new(
ScalarArg::new(out_h as u32),
ScalarArg::new(out_w as u32),
ScalarArg::new(kernel_h as u32),
ScalarArg::new(kernel_w as u32),
ScalarArg::new(options.padding[0] as u32),
ScalarArg::new(options.padding[1] as u32),
ScalarArg::new(options.dilation[0] as u32),
ScalarArg::new(options.dilation[1] as u32),
ScalarArg::new(options.stride[0] as u32),
ScalarArg::new(options.stride[1] as u32),
ScalarArg::new(col_size_1 as u32),
),
has_bias,
)
};
out
}
#[derive(CubeLaunch)]
struct Col2ImArgs {
out_h: u32,
out_w: u32,
kernel_h: u32,
kernel_w: u32,
pad_h: u32,
pad_w: u32,
dilation_h: u32,
dilation_w: u32,
stride_h: u32,
stride_w: u32,
col_size_1: u32,
}
#[cube(launch_unchecked)]
fn col2im_kernel<F: Float>(
columns: &Tensor<F>,
bias: &Tensor<F>,
image: &mut Tensor<F>,
args: &Col2ImArgs,
#[comptime] has_bias: bool,
) {
if ABSOLUTE_POS >= image.len() {
return;
}
let _ = bias[0]; // Keep in bind group
let im_x = ABSOLUTE_POS % image.shape(3) + args.pad_w;
let im_y = ABSOLUTE_POS / image.stride(2) % image.shape(2) + args.pad_h;
let ch_im = ABSOLUTE_POS / image.stride(1) % image.shape(1);
let batch = ABSOLUTE_POS / image.stride(0);
let kernel_extent_w = (args.kernel_w - 1) * args.dilation_w + 1;
let kernel_extent_h = (args.kernel_h - 1) * args.dilation_h + 1;
let mut val = F::new(0.0);
let x_col_start = if im_x >= kernel_extent_w {
(im_x - kernel_extent_w) / args.stride_w + 1
} else {
0u32
};
let x_col_end = Min::min(im_x / args.stride_w + 1, args.out_w);
let y_col_start = if im_y >= kernel_extent_h {
(im_y - kernel_extent_h) / args.stride_h + 1
} else {
0u32
};
let y_col_end = Min::min(im_y / args.stride_h + 1, args.out_h);
for col_y in y_col_start..y_col_end {
let kernel_y = im_y - col_y * args.stride_h;
for col_x in x_col_start..x_col_end {
let kernel_x = im_x - col_x * args.stride_w;
if kernel_y % args.dilation_h == 0 && kernel_x % args.dilation_w == 0 {
let kernel_y = kernel_y / args.dilation_h;
let kernel_x = kernel_x / args.dilation_w;
let col_pos = ch_im * args.kernel_h * args.kernel_w * args.col_size_1
+ kernel_y * args.kernel_w * args.col_size_1
+ kernel_x * args.col_size_1
+ batch * args.out_h * args.out_w
+ col_y * args.out_w
+ col_x;
val += columns[col_pos];
}
}
}
if has_bias {
image[ABSOLUTE_POS] = val + bias[ch_im];
} else {
image[ABSOLUTE_POS] = val;
}
}
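The `col_pos` computed in the kernel indexes a conceptual [channels * kernel_h * kernel_w, batch * spatial] column matrix: rows enumerate (channel, kernel_y, kernel_x) and columns enumerate (batch, col_y, col_x), with `col_size_1` columns per row. A standalone sketch of that layout, with the index written in factored form and checked against the distributed form the kernel uses (all sizes hypothetical):

```rust
// Layout sketch only; not part of the diff.
fn col_pos(
    ch: usize, ky: usize, kx: usize,      // row coordinates (channel, kernel_y, kernel_x)
    batch: usize, cy: usize, cx: usize,   // column coordinates (batch, col_y, col_x)
    kh: usize, kw: usize, oh: usize, ow: usize, col_size_1: usize,
) -> usize {
    let row = ch * kh * kw + ky * kw + kx;
    let col = batch * oh * ow + cy * ow + cx;
    row * col_size_1 + col
}

fn main() {
    let (kh, kw, oh, ow) = (3usize, 3usize, 8usize, 8usize);
    let col_size_1 = 2 * oh * ow; // two batches in this hypothetical layout

    // The kernel computes the same index with the row term distributed:
    let (ch, ky, kx, batch, cy, cx) = (1, 2, 0, 1, 7, 7);
    let distributed = ch * kh * kw * col_size_1
        + ky * kw * col_size_1
        + kx * col_size_1
        + batch * oh * ow
        + cy * ow
        + cx;
    assert_eq!(col_pos(ch, ky, kx, batch, cy, cx, kh, kw, oh, ow, col_size_1), distributed);
}
```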

View File

@ -1,9 +1,8 @@
use cubecl::{calculate_cube_count_elemwise, prelude::*};
use burn_tensor::{
ops::{conv::calculate_conv_output_size, ConvOptions},
Shape,
};
use cubecl::{calculate_cube_count_elemwise, prelude::*};
use crate::{
kernel::into_contiguous,
@ -12,7 +11,7 @@ use crate::{
reshape,
},
tensor::JitTensor,
FloatElement, JitRuntime,
FloatElement, IntElement, JitRuntime,
};
#[derive(CubeLaunch)]
@ -23,11 +22,11 @@ struct Conv2dArgs {
dilation_1: u32,
padding_0: u32,
padding_1: u32,
groups: u32,
channels_per_group: u32,
}
#[cube(launch)]
fn conv2d_kernel<F: Float>(
fn direct_conv2d_kernel<F: Float>(
input: &Tensor<F>,
weight: &Tensor<F>,
bias: &Tensor<F>,
@ -50,7 +49,7 @@ fn conv2d_kernel<F: Float>(
let oh = ABSOLUTE_POS / output.stride(2) % output.shape(2);
let ow = ABSOLUTE_POS / output.stride(3) % output.shape(3);
let g = (weight.shape(0) + oc) % args.groups;
let g = oc / args.channels_per_group;
let ic_start = in_channels * g;
let ic_end = ic_start + in_channels;
let mut sum = bias[oc];
@ -114,41 +113,46 @@ fn conv2d_kernel<F: Float>(
output[ABSOLUTE_POS] = sum;
}
pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
/// Perform a 2D convolution using the direct convolution algorithm.
///
/// * `input` - The input feature map
/// * `weight` - The weights (filter) applied to each kernel
/// * `bias` - The bias added to each channel
/// * `options` - The options to use for the convolution
///
#[allow(clippy::extra_unused_type_parameters)]
pub fn conv2d_direct<R: JitRuntime, E: FloatElement, I: IntElement>(
input: JitTensor<R, E, 4>,
weight: JitTensor<R, E, 4>,
bias: Option<JitTensor<R, E, 1>>,
options: ConvOptions<2>,
) -> JitTensor<R, E, 4> {
let input = into_contiguous(input);
let weight = into_contiguous(weight);
let [batch_size, _, in_height, in_width] = input.shape.dims;
let [out_channels, _, kernel_0, kernel_1] = weight.shape.dims;
let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims;
let channels_per_group = out_channels / options.groups;
// Limit loop unrolling factor to 8 or smaller
let kernel_1_unroll = if kernel_1 > 8 {
None
} else {
Some(kernel_1 as u32)
};
let kernel_w_unroll = (kernel_w <= 8).then_some(kernel_w as u32);
let out_0 = calculate_conv_output_size(
kernel_0,
let out_h = calculate_conv_output_size(
kernel_h,
options.stride[0],
options.padding[0],
options.dilation[0],
in_height,
);
let out_1 = calculate_conv_output_size(
kernel_1,
let out_w = calculate_conv_output_size(
kernel_w,
options.stride[1],
options.padding[1],
options.dilation[1],
in_width,
);
let shape_out = Shape::new([batch_size, out_channels, out_0, out_1]);
let input = into_contiguous(input);
let weight = into_contiguous(weight);
let shape_out = Shape::new([batch_size, out_channels, out_h, out_w]);
let output = empty_device(
input.client.clone(),
input.device.clone(),
@ -170,7 +174,7 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
let cube_dim = CubeDim::default();
let cube_count = calculate_cube_count_elemwise(num_elems_output, cube_dim);
conv2d_kernel::launch::<E, R>(
direct_conv2d_kernel::launch::<E, R>(
&input.client,
cube_count,
cube_dim,
@ -185,9 +189,9 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
ScalarArg::new(options.dilation[1] as u32),
ScalarArg::new(options.padding[0] as u32),
ScalarArg::new(options.padding[1] as u32),
ScalarArg::new(options.groups as u32),
ScalarArg::new(channels_per_group as u32),
),
kernel_1_unroll,
kernel_w_unroll,
);
output
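For reference, the renamed `out_h`/`out_w` values come from `calculate_conv_output_size`, which is assumed to implement the standard convolution output-size formula (the helper is defined outside this diff). A minimal standalone sketch of that arithmetic:

```rust
// Assumed output-size formula for a direct convolution; not part of the diff.
fn conv_output_size(kernel: usize, stride: usize, padding: usize, dilation: usize, input: usize) -> usize {
    (input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}

fn main() {
    // 32x32 input, 3x3 kernel, stride 1, padding 1, dilation 1 -> 32x32 output.
    assert_eq!(conv_output_size(3, 1, 1, 1, 32), 32);
    // Stride 2 roughly halves the spatial size (16x16 here).
    assert_eq!(conv_output_size(3, 2, 1, 1, 32), 16);
}
```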

View File

@ -0,0 +1,297 @@
use burn_tensor::{
ops::{conv::calculate_conv_output_size, ConvOptions, FloatTensorOps as _},
Shape,
};
use cubecl::{calculate_cube_count_elemwise, prelude::*};
use crate::{
kernel::into_contiguous,
ops::{numeric::empty_device, reshape, swap_dims},
tensor::JitTensor,
FloatElement, IntElement, JitBackend, JitRuntime,
};
#[derive(CubeLaunch)]
struct Im2ColArgs {
stride_h: u32,
stride_w: u32,
dilation_h: u32,
dilation_w: u32,
padding_h: u32,
padding_w: u32,
kernel_h: u32,
kernel_w: u32,
out_h: u32,
out_w: u32,
col_size_1: u32,
num_elements: u32,
}
#[cube(launch_unchecked)]
fn im2col_kernel<F: Float>(
image: &Tensor<F>,
columns: &mut Tensor<F>,
args: &Im2ColArgs,
#[comptime] kernel_w_unroll: Option<u32>,
#[comptime] has_padding: bool,
) {
// position shape: [in_channels, batch_size, out_h, out_w]
// columns shape: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]]
let batch_size = image.shape(0);
let height = image.shape(2);
let width = image.shape(3);
let out_h = args.out_h;
let out_w = args.out_w;
if ABSOLUTE_POS >= args.num_elements {
return;
}
let out_x = ABSOLUTE_POS % out_w;
let out_y = ABSOLUTE_POS / out_w % out_h;
let batch = ABSOLUTE_POS / (out_w * out_h) % batch_size;
let channel = ABSOLUTE_POS / (out_w * out_h * batch_size) % image.shape(1);
let kernel_w = kernel_w_unroll.unwrap_or(args.kernel_w);
let unroll_w = kernel_w_unroll.is_some();
let image_idx = batch * image.stride(0) + channel * image.stride(1);
let col_idx = channel * args.kernel_h * kernel_w * args.col_size_1
+ batch * out_h * out_w
+ out_y * out_w
+ out_x;
for kernel_y in 0..args.kernel_h {
#[unroll(unroll_w)]
for kernel_x in 0..kernel_w {
let kernel_pos = kernel_y * kernel_w + kernel_x;
let col_pos = col_idx + kernel_pos * args.col_size_1;
if has_padding {
let y = (out_y * args.stride_h + kernel_y * args.dilation_h) as i32
- args.padding_h as i32;
let x = (out_x * args.stride_w + kernel_x * args.dilation_w) as i32
- args.padding_w as i32;
if y >= 0 && x >= 0 && y < height as i32 && x < width as i32 {
let image_ptr = image_idx + y as u32 * width + x as u32;
columns[col_pos] = image[image_ptr];
} else {
columns[col_pos] = F::new(0.0)
};
} else {
let y = out_y * args.stride_h + kernel_y * args.dilation_h;
let x = out_x * args.stride_w + kernel_x * args.dilation_w;
let image_ptr = image_idx + y * image.stride(2) + x * image.stride(3);
columns[col_pos] = image[image_ptr];
}
}
}
}
#[cfg(not(test))]
pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> usize {
let cube_count_per_batch = (out_h * out_w).div_ceil(cubecl::SUBCUBE_DIM_APPROX);
let max_cube_count = u16::MAX as usize;
let max_simultaneous = (max_cube_count / cube_count_per_batch).min(batch_size);
if max_simultaneous == 0 {
panic!("Image too large to run even one batch at once");
}
(0..=max_simultaneous)
.rev()
.find(|per_run| batch_size % per_run == 0)
.unwrap()
}
#[cfg(test)]
#[allow(unused)]
pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> usize {
1
}
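The heuristic above splits the batch so the launched cube count stays below the `u16::MAX` dispatch limit, picking the largest divisor of `batch_size` that still fits. A standalone sketch of the same idea, with `cubecl::SUBCUBE_DIM_APPROX` replaced by an assumed value of 32 and the panic replaced by `None` (purely illustrative, not the function above):

```rust
// Illustrative batch-splitting heuristic; SUBCUBE_DIM_APPROX value is assumed.
const SUBCUBE_DIM_APPROX: usize = 32;

fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> Option<usize> {
    let cube_count_per_batch = (out_h * out_w).div_ceil(SUBCUBE_DIM_APPROX);
    let max_cube_count = u16::MAX as usize;
    let max_simultaneous = (max_cube_count / cube_count_per_batch).min(batch_size);
    if max_simultaneous == 0 {
        return None; // a single image is already too large
    }
    (1..=max_simultaneous).rev().find(|per_run| batch_size % per_run == 0)
}

fn main() {
    // Small outputs: the whole batch fits in one run.
    assert_eq!(batches_per_run(8, 64, 64), Some(8));
    // Large outputs: fall back to a divisor of the batch size (here 2 runs of 4).
    assert_eq!(batches_per_run(8, 640, 640), Some(4));
}
```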
fn im2col<R: JitRuntime, E: FloatElement>(
input: JitTensor<R, E, 4>,
options: ConvOptions<2>,
kernel_h: usize,
kernel_w: usize,
out_h: usize,
out_w: usize,
) -> JitTensor<R, E, 2> {
let input = into_contiguous(input);
let [batch_size, in_channels, _, _] = input.shape.dims;
let col_shape_0 = in_channels * kernel_h * kernel_w;
let col_shape_1 = batch_size * out_h * out_w;
let shape_col = Shape::new([col_shape_0, col_shape_1]);
let columns = empty_device(
input.client.clone(),
input.device.clone(),
shape_col.clone(),
);
let num_elems = in_channels * batch_size * out_h * out_w;
let cube_dim = CubeDim::default();
let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);
let kernel_w_unroll = (kernel_w <= 8).then_some(kernel_w as u32);
let vectorization = 1;
unsafe {
im2col_kernel::launch_unchecked::<E, R>(
&input.client,
cube_count,
cube_dim,
input.as_handle_ref().as_tensor_arg(vectorization),
columns.as_handle_ref().as_tensor_arg(vectorization),
Im2ColArgsLaunch::new(
ScalarArg::new(options.stride[0] as u32),
ScalarArg::new(options.stride[1] as u32),
ScalarArg::new(options.dilation[0] as u32),
ScalarArg::new(options.dilation[1] as u32),
ScalarArg::new(options.padding[0] as u32),
ScalarArg::new(options.padding[1] as u32),
ScalarArg::new(kernel_h as u32),
ScalarArg::new(kernel_w as u32),
ScalarArg::new(out_h as u32),
ScalarArg::new(out_w as u32),
ScalarArg::new(col_shape_1 as u32),
ScalarArg::new(num_elems as u32),
),
kernel_w_unroll,
options.padding != [0, 0],
)
};
columns
}
/// Perform a 2D convolution using the GEMM (im2col) algorithm.
///
/// * `input` - The input feature map
/// * `weight` - The weights (filter) applied to each kernel
/// * `bias` - The bias added to each channel
/// * `options` - The options to use for the convolution
///
pub fn conv2d_im2col<R: JitRuntime, E: FloatElement, I: IntElement>(
input: JitTensor<R, E, 4>,
weight: JitTensor<R, E, 4>,
bias: Option<JitTensor<R, E, 1>>,
options: ConvOptions<2>,
) -> JitTensor<R, E, 4> {
let [batch_size, in_channels, in_height, in_width] = input.shape.dims;
let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims;
let out_h = calculate_conv_output_size(
kernel_h,
options.stride[0],
options.padding[0],
options.dilation[0],
in_height,
);
let out_w = calculate_conv_output_size(
kernel_w,
options.stride[1],
options.padding[1],
options.dilation[1],
in_width,
);
if kernel_h == 1 && kernel_w == 1 && in_height == out_h && in_width == out_w {
// Special case for 1x1 kernels (sometimes used to scale the image by a set of weights)
return execute_1x1_kernel::<R, E, I>(input, weight, bias, options);
}
let batches_per_run = batches_per_run(batch_size, out_h, out_w);
let mut out = if batches_per_run != batch_size {
let runs = batch_size / batches_per_run;
let out_shape = Shape::new([runs, out_channels, batches_per_run, out_h, out_w]);
let mut out = empty_device(input.client.clone(), input.device.clone(), out_shape);
let in_shape = Shape::new([runs, batches_per_run, in_channels, in_height, in_width]);
let input = reshape(input, in_shape);
let in_shape_run = Shape::new([batches_per_run, in_channels, in_height, in_width]);
let out_shape_run = Shape::new([1, out_channels, batches_per_run, out_h, out_w]);
for run in 0..runs {
let input = JitBackend::<R, E, I>::float_narrow(input.clone(), 0, run, 1);
let input = reshape(input, in_shape_run.clone());
let run_out = execute::<R, E, I>(input, weight.clone(), options.clone(), out_h, out_w);
let run_out = reshape(run_out, out_shape_run.clone());
out = JitBackend::<R, E, I>::float_slice_assign(
out,
[
run..run + 1,
0..out_channels,
0..batches_per_run,
0..out_h,
0..out_w,
],
run_out,
);
}
let out = swap_dims(out, 1, 2);
reshape(out, Shape::new([batch_size, out_channels, out_h, out_w]))
} else {
let out = execute::<R, E, I>(input, weight, options, out_h, out_w);
let out = reshape(out, Shape::new([out_channels, batch_size, out_h, out_w]));
swap_dims(out, 0, 1)
};
if let Some(bias) = bias {
let bias = reshape(bias, Shape::new([1, out_channels, 1, 1]));
out = JitBackend::<R, E, I>::float_add(out, bias)
}
out
}
fn execute_1x1_kernel<R: JitRuntime, E: FloatElement, I: IntElement>(
input: JitTensor<R, E, 4>,
weight: JitTensor<R, E, 4>,
bias: Option<JitTensor<R, E, 1>>,
options: ConvOptions<2>,
) -> JitTensor<R, E, 4> {
let [batch_size, _, height, width] = input.shape.dims;
let [out_channels, in_c_per_grp, _, _] = weight.shape.dims;
let groups = options.groups;
let out_c_per_grp = out_channels / groups;
let input = swap_dims(input, 0, 1); // [CNHW]
let weight = reshape(weight, Shape::new([groups, out_c_per_grp, in_c_per_grp]));
let in_shape = Shape::new([groups, in_c_per_grp, batch_size * height * width]);
let input = reshape(input, in_shape);
let out = JitBackend::<R, E, I>::float_matmul(weight, input);
let mut out = reshape(out, Shape::new([out_channels, batch_size, height, width]));
if let Some(bias) = bias {
let bias = reshape(bias, Shape::new([out_channels, 1, 1, 1]));
out = JitBackend::<R, E, I>::float_add(out, bias)
}
swap_dims(out, 0, 1)
}
fn execute<R: JitRuntime, E: FloatElement, I: IntElement>(
input: JitTensor<R, E, 4>,
weight: JitTensor<R, E, 4>,
options: ConvOptions<2>,
out_h: usize,
out_w: usize,
) -> JitTensor<R, E, 3> {
let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims;
let groups = options.groups;
let columns = im2col(input, options.clone(), kernel_h, kernel_w, out_h, out_w);
let [col_shape_0, col_shape_1] = columns.shape.dims;
let col_shape_0 = col_shape_0 / groups;
let out_c_per_group = out_channels / groups;
let columns = reshape(columns, Shape::new([groups, col_shape_0, col_shape_1]));
let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_shape_0]));
JitBackend::<R, E, I>::float_matmul(weight, columns)
}
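To make the shape bookkeeping concrete, here is a small CPU-only sketch of the im2col + GEMM equivalence this file relies on: unfold the input into a `[in_channels * kernel_h * kernel_w, out_h * out_w]` matrix, multiply by the flattened `[out_channels, in_channels * kernel_h * kernel_w]` weights, and compare against a naive direct convolution. Single image, `groups = 1`, stride 1, no padding or dilation; all sizes are hypothetical and this is only an illustration of the algorithm, not the kernel launched above:

```rust
// CPU illustration of im2col + GEMM vs. direct convolution; not part of the diff.
fn main() {
    let (in_c, h, w) = (2usize, 4usize, 4usize);
    let (out_c, kh, kw) = (3usize, 3usize, 3usize);
    let (out_h, out_w) = (h - kh + 1, w - kw + 1);

    // Deterministic pseudo-random data.
    let input: Vec<f32> = (0..in_c * h * w).map(|i| (i as f32 * 0.37).sin()).collect();
    let weight: Vec<f32> = (0..out_c * in_c * kh * kw).map(|i| (i as f32 * 0.11).cos()).collect();

    // im2col: columns[(c*kh*kw + ky*kw + kx), (oy*out_w + ox)]
    let col_rows = in_c * kh * kw;
    let col_cols = out_h * out_w;
    let mut columns = vec![0.0f32; col_rows * col_cols];
    for c in 0..in_c {
        for ky in 0..kh {
            for kx in 0..kw {
                let row = c * kh * kw + ky * kw + kx;
                for oy in 0..out_h {
                    for ox in 0..out_w {
                        columns[row * col_cols + oy * out_w + ox] =
                            input[c * h * w + (oy + ky) * w + (ox + kx)];
                    }
                }
            }
        }
    }

    // GEMM: out[oc, pixel] = sum_k weight[oc, k] * columns[k, pixel]
    let mut gemm_out = vec![0.0f32; out_c * col_cols];
    for oc in 0..out_c {
        for k in 0..col_rows {
            let w_val = weight[oc * col_rows + k];
            for p in 0..col_cols {
                gemm_out[oc * col_cols + p] += w_val * columns[k * col_cols + p];
            }
        }
    }

    // Naive direct convolution for reference.
    for oc in 0..out_c {
        for oy in 0..out_h {
            for ox in 0..out_w {
                let mut acc = 0.0f32;
                for c in 0..in_c {
                    for ky in 0..kh {
                        for kx in 0..kw {
                            acc += input[c * h * w + (oy + ky) * w + (ox + kx)]
                                * weight[oc * col_rows + c * kh * kw + ky * kw + kx];
                        }
                    }
                }
                let gemm_val = gemm_out[oc * col_cols + oy * out_w + ox];
                assert!((acc - gemm_val).abs() < 1e-4);
            }
        }
    }
    println!("im2col + GEMM matches direct convolution");
}
```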

View File

@ -0,0 +1,562 @@
use burn_tensor::{
ops::{conv::calculate_conv_output_size, ConvOptions, FloatTensorOps},
Shape,
};
use cmma::{Matrix, MatrixIdent, MatrixLayout};
use cubecl::{
cube,
ir::{Elem, FloatKind},
prelude::*,
Compiler, CubeCount, CubeDim, Feature,
};
use half::f16;
use crate::{
ops::{numeric::empty_device, permute, reshape},
tensor::JitTensor,
FloatElement, IntElement, JitBackend, JitRuntime,
};
/// Perform a 2D convolution using the implicit GEMM algorithm. Requires `cmma` to be available.
///
/// * `input` - The input feature map
/// * `weight` - The weights (filter) applied to each kernel
/// * `bias` - The bias added to each channel
/// * `options` - The options to use for the convolution
///
pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(
input: JitTensor<R, F, 4>,
weight: JitTensor<R, F, 4>,
bias: Option<JitTensor<R, F, 1>>,
options: ConvOptions<2>,
) -> JitTensor<R, F, 4> {
let [batch_size, in_channels, height, width] = input.shape.dims;
let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims;
let out_h = calculate_conv_output_size(
kernel_h,
options.stride[0],
options.padding[0],
options.dilation[0],
height,
);
let out_w = calculate_conv_output_size(
kernel_w,
options.stride[1],
options.padding[1],
options.dilation[1],
width,
);
if !can_do_implicit_gemm(&input, &weight, &options, out_h, out_w) {
panic!(
"Requirements for implicit GEMM not met:
- CMMA must be available
- `batch_size * out_h * out_w` must be divisible by 16
- `out_channels` must be divisible by 16
- `in_channels * kernel_h * kernel_w` must be divisible by 16
- `groups` must be 1
"
);
}
let out_shape = Shape::new([batch_size, out_h, out_w, out_channels]);
let mut out = empty_device(input.client.clone(), input.device.clone(), out_shape);
// Implicit GEMM matrix size
let gemm_m = (batch_size * out_h * out_w) as u32;
let gemm_n = out_channels as u32;
let gemm_k = in_channels * kernel_h * kernel_w;
let slice_size = kernel_h * kernel_w * in_channels;
let cmma_m = 16;
let cmma_n = 16;
let cmma_k = 16;
let warp_size = 32;
let warps_per_cube = 8;
let cube_dim_x = 128;
let cube_dim_y = 2;
assert!(cube_dim_y * cube_dim_x / warp_size == warps_per_cube);
let settings = GemmSettings {
cmma_m,
cmma_n,
cmma_k,
warp_size,
warps_per_cube,
cube_dim_x,
};
// `CUBE_DIM_X` must be a multiple of `WARP_SIZE`
// 128x2 means we have 8 warps and a cube computes a 32x64 output tile
let cube_dim = CubeDim {
x: cube_dim_x,
y: cube_dim_y,
z: 1,
};
let cube_count_x = gemm_m.div_ceil(cmma_m * cube_dim_x / warp_size);
let cube_count_y = gemm_n.div_ceil(cmma_n * cube_dim_y);
// If div floor == div ceil then the cubes are aligned with the input dimensions
let aligned = gemm_m / (cmma_m * cube_dim_x / warp_size) == cube_count_x
&& gemm_n / (cmma_n * cube_dim_y) == cube_count_y;
let cube_count = CubeCount::Static(cube_count_x, cube_count_y, 1);
unsafe {
implicit_gemm_kernel::launch_unchecked::<F, f16, R>(
&input.client,
cube_count,
cube_dim,
input.as_tensor_arg(1),
weight.as_tensor_arg(1),
out.as_tensor_arg(1),
DimensionsLaunch::new(
ScalarArg::new(gemm_m),
ScalarArg::new(gemm_n),
ScalarArg::new(gemm_k as u32),
ScalarArg::new(slice_size as u32),
ScalarArg::new(out_h as u32),
ScalarArg::new(out_w as u32),
),
ConvArgsLaunch::new(
ScalarArg::new(options.stride[0] as u32),
ScalarArg::new(options.stride[1] as u32),
ScalarArg::new(options.padding[0] as i32),
ScalarArg::new(options.padding[1] as i32),
ScalarArg::new(options.dilation[0] as u32),
ScalarArg::new(options.dilation[1] as u32),
),
settings,
KernelSettings {
kernel_h: kernel_h as u32,
kernel_w: kernel_w as u32,
has_padding: options.padding != [0, 0],
aligned,
},
)
};
if let Some(bias) = bias {
let bias = reshape(bias, Shape::new([1, 1, 1, out_channels]));
out = JitBackend::<R, F, I>::float_add(out, bias);
}
permute(out, [0, 3, 1, 2])
}
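A standalone sketch of the launch arithmetic above: every output pixel is a row of the implicit GEMM, every output channel a column, the unfolded input patch is the reduction axis, and each (128, 2) cube covers 64 rows by 32 columns of that matrix (8 warps, one 16x16 accumulator tile each). The shapes below are hypothetical:

```rust
// Implicit GEMM launch arithmetic, mirrored on the CPU; not part of the diff.
fn main() {
    let (batch_size, out_h, out_w) = (4usize, 16usize, 16usize);
    let (in_channels, out_channels, kernel_h, kernel_w) = (16usize, 32usize, 3usize, 3usize);

    // Implicit GEMM problem size.
    let gemm_m = batch_size * out_h * out_w;
    let gemm_n = out_channels;
    let gemm_k = in_channels * kernel_h * kernel_w;

    // CMMA requires all three dimensions to be multiples of the 16x16x16 tile.
    assert!(gemm_m % 16 == 0 && gemm_n % 16 == 0 && gemm_k % 16 == 0);

    // 8 warps per (128, 2) cube, each warp owning one 16x16 accumulator tile.
    let (cmma_m, cmma_n, warp_size, cube_dim_x, cube_dim_y) = (16usize, 16usize, 32usize, 128usize, 2usize);
    let rows_per_cube = cmma_m * cube_dim_x / warp_size; // 64 rows of the GEMM
    let cols_per_cube = cmma_n * cube_dim_y;             // 32 columns of the GEMM
    let cube_count_x = gemm_m.div_ceil(rows_per_cube);
    let cube_count_y = gemm_n.div_ceil(cols_per_cube);
    let aligned = gemm_m % rows_per_cube == 0 && gemm_n % cols_per_cube == 0;

    assert_eq!((cube_count_x, cube_count_y, aligned), (16, 1, true));
}
```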
#[derive(CubeLaunch)]
struct ConvArgs {
stride_h: u32,
stride_w: u32,
pad_h: i32,
pad_w: i32,
dilation_h: u32,
dilation_w: u32,
}
#[derive(CubeLaunch)]
struct Dimensions {
gemm_m: u32,
gemm_n: u32,
gemm_k: u32,
slice_size: u32,
out_h: u32,
out_w: u32,
}
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
struct GemmSettings {
cmma_m: u32,
cmma_n: u32,
cmma_k: u32,
warp_size: u32,
warps_per_cube: u32,
cube_dim_x: u32,
}
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
struct KernelSettings {
kernel_h: u32,
kernel_w: u32,
has_padding: bool,
aligned: bool,
}
#[derive(Clone, Copy, CubeType)]
struct Positions {
global_m: u32,
global_n: u32,
intra_warp_unit_idx: u32,
cube_linear_warp_idx: u32,
}
#[derive(CubeType)]
struct Matrices<F: Float, FAcc: Float> {
a: Matrix<F>,
b: Matrix<F>,
acc: Matrix<FAcc>,
}
#[allow(clippy::collapsible_else_if)]
#[cube(launch_unchecked)]
fn implicit_gemm_kernel<F: Float, FMat: Float>(
input: &Tensor<F>,
weight: &Tensor<F>,
out: &mut Tensor<F>,
dims: &Dimensions,
args: &ConvArgs,
#[comptime] gemm_settings: GemmSettings,
#[comptime] kernel_settings: KernelSettings,
) {
let GemmSettings {
cmma_m,
cmma_n,
cmma_k,
warps_per_cube,
..
} = gemm_settings;
let cmma_out_tile_size = cmma_m * cmma_n;
let cmma_input_tile_size = cmma_m * cmma_k;
let cmma_filter_tile_size = cmma_k * cmma_n;
let pos = calculate_positions(gemm_settings);
// Shared memory tiles, currently only holds enough data for
// each warp to have its own tile for a single MMA op (8 * 16 * 16 elements)
// conceptually a WARPS_PER_CUBE x (CMMA_M * CMMA_K) matrix
let mut smem_input_tile = SharedMemory::<FMat>::new(cmma_input_tile_size * warps_per_cube);
let mut smem_weight_tile = SharedMemory::<FMat>::new(cmma_filter_tile_size * warps_per_cube);
let input_tile_start = pos.cube_linear_warp_idx * cmma_input_tile_size;
let weight_tile_start = pos.cube_linear_warp_idx * cmma_filter_tile_size;
let input_tile =
smem_input_tile.slice_mut(input_tile_start, input_tile_start + cmma_input_tile_size);
let weight_tile =
smem_weight_tile.slice_mut(weight_tile_start, weight_tile_start + cmma_filter_tile_size);
let out_pos = pos.global_n + pos.global_m * dims.gemm_n;
let out = out.slice_mut(out_pos, out_pos + cmma_out_tile_size);
if kernel_settings.aligned {
execute_gemm(
input,
weight,
out,
input_tile,
weight_tile,
dims,
&pos,
args,
gemm_settings,
kernel_settings,
);
} else {
if pos.global_m < dims.gemm_m && pos.global_n < dims.gemm_n {
execute_gemm(
input,
weight,
out,
input_tile,
weight_tile,
dims,
&pos,
args,
gemm_settings,
kernel_settings,
);
}
}
}
#[cube]
fn calculate_positions(#[comptime] gemm_settings: GemmSettings) -> Positions {
let GemmSettings {
cmma_m,
cmma_n,
warp_size,
cube_dim_x,
..
} = gemm_settings;
// Tile using a 2D grid (over the output), each threadblock
// is (128, 2) -> (4,2) = 8 warps -> 32x64 output
let global_warp_m = ABSOLUTE_POS_X / warp_size;
let global_warp_n = ABSOLUTE_POS_Y;
let cube_warp_m = UNIT_POS_X / warp_size;
let cube_warp_n = UNIT_POS_Y;
let num_warps_m = cube_dim_x / warp_size;
let intra_warp_unit_idx = UNIT_POS_X % warp_size; // Thread idx within warp (0 to 31)
let cube_linear_warp_idx = (cube_warp_n * num_warps_m) + cube_warp_m; // Warp idx within a block (0 to WARPS_PER_BLOCK - 1)
Positions {
global_m: global_warp_m * cmma_m,
global_n: global_warp_n * cmma_n,
intra_warp_unit_idx,
cube_linear_warp_idx,
}
}
#[cube]
fn make_matrices<F: Float, FAcc: Float>(
#[comptime] gemm_settings: GemmSettings,
) -> Matrices<F, FAcc> {
let GemmSettings {
cmma_m,
cmma_n,
cmma_k,
..
} = gemm_settings;
let matrices = Matrices::<F, FAcc> {
a: Matrix::<F>::new(
MatrixIdent::A,
cmma_m,
cmma_n,
cmma_k,
MatrixLayout::RowMajor,
),
b: Matrix::<F>::new(
MatrixIdent::B,
cmma_m,
cmma_n,
cmma_k,
MatrixLayout::RowMajor,
),
acc: Matrix::<FAcc>::new(
MatrixIdent::Accumulator,
cmma_m,
cmma_n,
cmma_k,
MatrixLayout::Undefined,
),
};
cmma::fill(&matrices.acc, FAcc::new(0.0));
matrices
}
#[cube]
fn execute_gemm<F: Float, FMat: Float>(
input: &Tensor<F>,
weight: &Tensor<F>,
out: &mut SliceMut<F>,
input_tile: &mut SliceMut<FMat>,
weight_tile: &mut SliceMut<FMat>,
dims: &Dimensions,
pos: &Positions,
args: &ConvArgs,
#[comptime] g_settings: GemmSettings,
#[comptime] k_settings: KernelSettings,
) {
let GemmSettings { cmma_n, cmma_k, .. } = g_settings;
let matrices = make_matrices::<FMat, F>(g_settings);
// Loop over the K-dimension
for k in range_stepped(0, dims.gemm_k, cmma_k) {
// Load into smem...
// Each warp should load the 16x16 tile it's responsible for
// i.e. each thread needs to load 8 elements of input and 8 elements of weight
load_input_tile(
input, args, input_tile, dims, pos, k, g_settings, k_settings,
);
load_weight_tile(weight, weight_tile, pos, k, g_settings, k_settings);
// Run CMMA
cmma::load(&matrices.a, input_tile.as_slice(), cmma_k);
cmma::load(&matrices.b, weight_tile.as_slice(), cmma_n);
cmma::execute::<FMat, FMat, F, F>(&matrices.a, &matrices.b, &matrices.acc, &matrices.acc);
}
cmma::store(out, &matrices.acc, dims.gemm_n, MatrixLayout::RowMajor);
}
#[cube]
fn load_input_tile<F: Float, FMat: Float>(
input: &Tensor<F>,
args: &ConvArgs,
tile: &mut SliceMut<FMat>,
dims: &Dimensions,
pos: &Positions,
k: u32,
#[comptime] gemm_settings: GemmSettings,
#[comptime] kernel_settings: KernelSettings,
) {
let GemmSettings {
cmma_m,
cmma_k,
warp_size,
..
} = gemm_settings;
let KernelSettings {
kernel_h,
kernel_w,
has_padding,
..
} = kernel_settings;
let kernel_size = kernel_h * kernel_w;
let cmma_input_tile_size = cmma_m * cmma_k;
let height = input.shape(2) as i32;
let width = input.shape(3) as i32;
// Row strides in the implicit GEMM matrix
let batch_stride = dims.out_h * dims.out_w;
let y_stride = dims.out_w;
let x_stride = 1;
// Start index within a slice (0 to `kernel_size * channels - 1`) that a half warp (16 units) is responsible for
let slice_start_idx = k % dims.slice_size;
for m in range_stepped(pos.intra_warp_unit_idx, cmma_input_tile_size, warp_size) {
// Compute where in the slice we are starting
// Slices are always `kernel_size * channels` elements wide so we can compute where inside a slice
// we are and also which row the slice is in relative to the start of the CMMA matrix
let rel_slice_row = m / cmma_k; // Relative row (0 - 15)
let abs_slice_row = pos.global_m + rel_slice_row; // Row of the matrix the slice is on
// Actual index within a slice (0 to `kernel_size * channels - 1`) that the thread is responsible for
let my_slice_idx = (slice_start_idx + (m % cmma_k)) % dims.slice_size;
// Given the row of the matrix that the slice is in, and the index of the thread
// within a slice, want to compute what input element to load...
// first compute coordinates in output space (center of the kernel in MxK matrix A)
let batch = abs_slice_row / batch_stride;
let out_y = (abs_slice_row % batch_stride) / y_stride;
let out_x = ((abs_slice_row % batch_stride) % y_stride) / x_stride;
let channel = my_slice_idx / kernel_size;
let kernel_y = (my_slice_idx / kernel_w) % kernel_h;
let kernel_x = my_slice_idx % kernel_w;
if has_padding {
let y = (out_y * args.stride_h + kernel_y * args.dilation_h) as i32 - args.pad_h;
let x = (out_x * args.stride_w + kernel_x * args.dilation_w) as i32 - args.pad_w;
if x >= 0 && x < width && y >= 0 && y < height {
tile[m] = FMat::cast_from(
input[batch * input.stride(0)
+ y as u32 * input.stride(2)
+ x as u32 * input.stride(3)
+ channel * input.stride(1)],
);
} else {
tile[m] = FMat::new(0.0);
}
} else {
let y = out_y * args.stride_h + kernel_y * args.dilation_h;
let x = out_x * args.stride_w + kernel_x * args.dilation_w;
tile[m] = FMat::cast_from(
input[batch * input.stride(0)
+ y * input.stride(2)
+ x * input.stride(3)
+ channel * input.stride(1)],
);
}
}
}
#[cube]
fn load_weight_tile<F: Float, FMat: Float>(
weight: &Tensor<F>,
tile: &mut SliceMut<FMat>,
pos: &Positions,
k: u32,
#[comptime] gemm_settings: GemmSettings,
#[comptime] kernel_settings: KernelSettings,
) {
let GemmSettings {
cmma_n,
cmma_k,
warp_size,
..
} = gemm_settings;
let KernelSettings {
kernel_h, kernel_w, ..
} = kernel_settings;
let kernel_size = kernel_h * kernel_w;
let cmma_filter_tile_size = cmma_k * cmma_n;
for n in range_stepped(pos.intra_warp_unit_idx, cmma_filter_tile_size, warp_size) {
// Compute where in the slice we are starting
let rel_slice_row = n / cmma_k; // Relative row (0 - 15)
let abs_slice_row = k + rel_slice_row; // Row of the matrix the slice is on
let abs_slice_col = pos.global_n + (n % 16); // Column of the matrix the slice is on
// Given the row of the matrix that the slice is in, and the index of the unit
// within a slice, want to compute what weight element to load...
let out_channel = abs_slice_col;
let in_channel = abs_slice_row / kernel_size;
let kernel_y = (abs_slice_row % kernel_size) / kernel_w;
let kernel_x = abs_slice_row % kernel_w;
tile[n] = FMat::cast_from(
weight[out_channel * weight.stride(0)
+ in_channel * weight.stride(1)
+ kernel_y * weight.stride(2)
+ kernel_x * weight.stride(3)],
);
}
}
pub(crate) fn can_do_implicit_gemm<R: JitRuntime, E: FloatElement>(
input: &JitTensor<R, E, 4>,
weight: &JitTensor<R, E, 4>,
options: &ConvOptions<2>,
out_h: usize,
out_w: usize,
) -> bool {
let [batch_size, in_channels, _, _] = input.shape.dims;
let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims;
let cmma_m = 16;
let cmma_n = 16;
let cmma_k = 16;
let warps_per_cube = 8;
let gemm_m = batch_size * out_h * out_w;
let gemm_n = out_channels;
let gemm_k = in_channels * kernel_h * kernel_w;
let smem_size = ((cmma_m + cmma_n) * cmma_k * warps_per_cube) as usize * size_of::<f16>();
cmma_available::<R>(&input.device)
&& <R::Compiler as Compiler>::max_shared_memory_size() >= smem_size
&& gemm_m % 16 == 0
&& gemm_n % 16 == 0
&& gemm_k % 16 == 0
&& options.groups == 1
}
fn cmma_available<R: JitRuntime>(device: &R::JitDevice) -> bool {
R::client(device).features().enabled(Feature::Cmma {
a: Elem::Float(FloatKind::F16),
b: Elem::Float(FloatKind::F16),
c: Elem::Float(FloatKind::F32),
m: 16,
k: 16,
n: 16,
})
}
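The shared-memory check in `can_do_implicit_gemm` budgets one f16 input tile and one f16 weight tile per warp. A quick standalone sketch of that arithmetic:

```rust
// Shared-memory budget for the implicit GEMM tiles; not part of the diff.
fn main() {
    let (cmma_m, cmma_n, cmma_k, warps_per_cube) = (16usize, 16usize, 16usize, 8usize);
    let smem_bytes = (cmma_m + cmma_n) * cmma_k * warps_per_cube * 2; // 2 bytes per f16
    assert_eq!(smem_bytes, 8192); // comfortably under typical 32-48 KiB shared-memory limits
}
```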

View File

@ -0,0 +1,18 @@
mod base;
mod col2im;
mod direct;
mod im2col;
mod implicit_gemm;
mod transpose_direct;
#[cfg(feature = "autotune")]
mod tune;
pub use base::*;
pub use col2im::*;
pub use direct::*;
pub use im2col::*;
pub use implicit_gemm::*;
pub use transpose_direct::*;
#[cfg(feature = "autotune")]
pub use tune::*;

View File

@ -14,9 +14,9 @@ use crate::{
reshape,
},
tensor::JitTensor,
JitRuntime,
IntElement, JitRuntime,
};
use burn_tensor::{ops::ConvTransposeOptions, Element, Shape};
use burn_tensor::{ops::ConvTransposeOptions, Shape};
#[derive(new)]
struct Conv2dTransposeEagerKernel<R, E> {
@ -364,7 +364,15 @@ impl<R: JitRuntime, E: JitElement> Kernel for Conv2dTransposeEagerKernel<R, E> {
}
}
pub(crate) fn conv_transpose2d<R: JitRuntime, E: JitElement + Element>(
/// Perform a 2D convolution transposition using the direct algorithm.
///
/// * `input` - The input feature map
/// * `weight` - The weights (filter) applied to each kernel
/// * `bias` - The bias added to each channel
/// * `options` - The options to use for the convolution
///
#[allow(clippy::extra_unused_type_parameters)]
pub fn conv_transpose2d_direct<R: JitRuntime, E: JitElement, I: IntElement>(
input: JitTensor<R, E, 4>,
weight: JitTensor<R, E, 4>,
bias: Option<JitTensor<R, E, 1>>,

View File

@ -0,0 +1,152 @@
use burn_tensor::{
ops::{conv::calculate_conv_output_size, ConvOptions},
ElementConversion, Shape,
};
use cubecl::{
tune,
tune::{local_tuner, tune_with, LocalTuner},
AutotuneKey,
};
use serde::{Deserialize, Serialize};
use crate::{
kernel::{
conv::{can_do_implicit_gemm, conv2d_direct, conv2d_im2col, conv2d_implicit_gemm},
prng::random_uniform,
},
tensor::JitTensor,
FloatElement, IntElement, JitAutotuneKey, JitRuntime, JitTuneId,
};
/// Executes autotune on conv2d operations
pub fn conv2d_autotune<R: JitRuntime, E: FloatElement, I: IntElement>(
input: JitTensor<R, E, 4>,
weights: JitTensor<R, E, 4>,
bias: Option<JitTensor<R, E, 1>>,
options: ConvOptions<2>,
) -> JitTensor<R, E, 4> {
let client = input.client.clone();
static TUNER: LocalTuner<JitAutotuneKey, JitTuneId> = local_tuner!();
TUNER.execute(
&JitTuneId::new::<R>(&input.device),
&client,
Box::new(Conv2dOperations::<R, E, I>::new(
input, weights, bias, options,
)),
)
}
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
/// Autotune key representative of conv2d versions
pub struct Conv2dAutotuneKey {
pub kernel_size: [usize; 2],
pub stride: [usize; 2],
pub padding: [usize; 2],
pub dilation: [usize; 2],
pub groups: usize,
#[autotune(anchor)]
pub in_channels: usize,
#[autotune(anchor)]
pub out_channels: usize,
#[autotune(anchor)]
pub height: usize,
#[autotune(anchor)]
pub width: usize,
#[autotune(anchor)]
pub batch_size: usize,
pub has_bias: bool,
}
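The `#[autotune(anchor)]` attribute coarsens the key so that similar shapes reuse one tuning result rather than re-benchmarking. A purely illustrative sketch, assuming anchoring rounds a dimension up to the next power of two (the exact behaviour is defined by the `AutotuneKey` derive, not by this diff):

```rust
// Assumed anchoring behaviour; shown for illustration only.
fn anchor(dim: usize) -> usize {
    dim.next_power_of_two()
}

fn main() {
    // Heights 120 and 128 would share one tuning result; 130 gets a new one.
    assert_eq!(anchor(120), 128);
    assert_eq!(anchor(128), 128);
    assert_eq!(anchor(130), 256);
}
```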
#[tune(
operations(conv2d_direct, conv2d_im2col, conv2d_implicit_gemm),
create_key = create_key,
should_run = should_run
)]
pub fn conv2d_operations<R: JitRuntime, E: FloatElement, I: IntElement>(
key: JitAutotuneKey,
input: JitTensor<R, E, 4>,
weights: JitTensor<R, E, 4>,
bias: Option<JitTensor<R, E, 1>>,
options: ConvOptions<2>,
) -> JitTensor<R, E, 4> {
let device = &input.device;
let key = match key {
JitAutotuneKey::Conv2d(key) => key,
_ => unreachable!(),
};
let random_bounds: (E, E) = ((-1.0).elem::<E>(), (1.0).elem::<E>());
let input_shape = Shape::new([key.batch_size, key.in_channels, key.height, key.width]);
let input = random_uniform(input_shape, device, random_bounds.0, random_bounds.1);
let c_per_grp = key.in_channels / key.groups;
let [kernel_h, kernel_w] = key.kernel_size;
let weight_shape = Shape::new([key.out_channels, c_per_grp, kernel_h, kernel_w]);
let weights = random_uniform(weight_shape, device, random_bounds.0, random_bounds.1);
let bias_shape = Shape::new([key.out_channels]);
let bias = key
.has_bias
.then(|| random_uniform(bias_shape, device, random_bounds.0, random_bounds.1));
tune_with!(input, weights, bias, options)
}
fn should_run<R: JitRuntime, F: FloatElement, I: IntElement>(
op: &Conv2dOperations<R, F, I>,
_key: &JitAutotuneKey,
index: usize,
) -> bool {
match index {
2 => {
let [_, _, height, width] = op.input.shape.dims;
let [_, _, kernel_h, kernel_w] = op.weights.shape.dims;
let o = &op.options;
let out_h = calculate_conv_output_size(
kernel_h,
o.stride[0],
o.padding[0],
o.dilation[0],
height,
);
let out_w = calculate_conv_output_size(
kernel_w,
o.stride[1],
o.padding[1],
o.dilation[1],
width,
);
can_do_implicit_gemm(&op.input, &op.weights, &op.options, out_h, out_w)
}
_ => true,
}
}
fn create_key<R: JitRuntime, E: FloatElement>(
input: &JitTensor<R, E, 4>,
weights: &JitTensor<R, E, 4>,
bias: &Option<JitTensor<R, E, 1>>,
options: &ConvOptions<2>,
) -> JitAutotuneKey {
let [batch_size, in_channels, height, width] = input.shape.dims;
let [out_channels, _, kernel_h, kernel_w] = weights.shape.dims;
let ConvOptions {
stride,
padding,
dilation,
groups,
} = options.clone();
JitAutotuneKey::Conv2d(Conv2dAutotuneKey::new(
[kernel_h, kernel_w],
stride,
padding,
dilation,
groups,
in_channels,
out_channels,
height,
width,
batch_size,
bias.is_some(),
))
}
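Conceptually, the tuner builds a key from the real inputs, generates random tensors of the (anchored) key shapes, times each candidate kernel, and caches the winner per key and device. A minimal CPU-only sketch of the "time each candidate, keep the fastest" idea, not the actual `LocalTuner` implementation:

```rust
use std::time::Instant;

// Conceptual autotune loop; the real tuner benchmarks more carefully and caches per key.
fn pick_fastest(candidates: &[fn(&[f32]) -> f32], input: &[f32]) -> usize {
    let mut best = (0usize, f64::INFINITY);
    for (i, f) in candidates.iter().enumerate() {
        let start = Instant::now();
        std::hint::black_box(f(std::hint::black_box(input)));
        let elapsed = start.elapsed().as_secs_f64();
        if elapsed < best.1 {
            best = (i, elapsed);
        }
    }
    best.0
}

fn main() {
    let sum: fn(&[f32]) -> f32 = |x| x.iter().sum();
    let sum_sq: fn(&[f32]) -> f32 = |x| x.iter().map(|v| v * v).sum();
    let data = vec![1.0f32; 1 << 20];
    println!("fastest candidate index: {}", pick_fastest(&[sum, sum_sq], &data));
}
```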

View File

@ -0,0 +1,117 @@
use burn_tensor::{ops::ConvTransposeOptions, ElementConversion, Shape};
use cubecl::{
tune,
tune::{local_tuner, tune_with, LocalTuner},
AutotuneKey,
};
use serde::{Deserialize, Serialize};
use crate::{
kernel::{
conv::{conv_transpose2d_col2im, conv_transpose2d_direct},
prng::random_uniform,
},
tensor::JitTensor,
FloatElement, IntElement, JitAutotuneKey, JitRuntime, JitTuneId,
};
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
/// Autotune key representative of conv_transpose2d versions
pub struct ConvTranspose2dAutotuneKey {
pub kernel_size: [usize; 2],
pub stride: [usize; 2],
pub padding: [usize; 2],
pub padding_out: [usize; 2],
pub dilation: [usize; 2],
pub groups: usize,
#[autotune(anchor)]
pub in_channels: usize,
#[autotune(anchor)]
pub out_channels: usize,
#[autotune(anchor)]
pub height: usize,
#[autotune(anchor)]
pub width: usize,
#[autotune(anchor)]
pub batch_size: usize,
pub has_bias: bool,
}
/// Executes autotune on conv_transpose2d operations
pub fn conv_transpose2d_autotune<R: JitRuntime, E: FloatElement, I: IntElement>(
input: JitTensor<R, E, 4>,
weights: JitTensor<R, E, 4>,
bias: Option<JitTensor<R, E, 1>>,
options: ConvTransposeOptions<2>,
) -> JitTensor<R, E, 4> {
let client = input.client.clone();
static TUNER: LocalTuner<JitAutotuneKey, JitTuneId> = local_tuner!();
TUNER.execute(
&JitTuneId::new::<R>(&input.device),
&client,
Box::new(ConvTranspose2dOperations::<R, E, I>::new(
input, weights, bias, options,
)),
)
}
#[tune(operations(conv_transpose2d_direct, conv_transpose2d_col2im), create_key = create_key)]
pub fn conv_transpose2d_operations<R: JitRuntime, E: FloatElement, I: IntElement>(
key: JitAutotuneKey,
input: JitTensor<R, E, 4>,
weights: JitTensor<R, E, 4>,
bias: Option<JitTensor<R, E, 1>>,
options: ConvTransposeOptions<2>,
) -> JitTensor<R, E, 4> {
let key = match key {
JitAutotuneKey::ConvTranspose2d(key) => key,
_ => unreachable!(),
};
let device = &input.device;
let random_bounds: (E, E) = ((-1.0).elem::<E>(), (1.0).elem::<E>());
let input_shape = Shape::new([key.batch_size, key.in_channels, key.height, key.width]);
let input = random_uniform(input_shape, device, random_bounds.0, random_bounds.1);
let c_per_grp = key.in_channels / key.groups;
let [kernel_h, kernel_w] = key.kernel_size;
let weight_shape = Shape::new([key.out_channels, c_per_grp, kernel_h, kernel_w]);
let weights = random_uniform(weight_shape, device, random_bounds.0, random_bounds.1);
let bias_shape = Shape::new([key.out_channels]);
let bias = key
.has_bias
.then(|| random_uniform(bias_shape, device, random_bounds.0, random_bounds.1));
tune_with!(input, weights, bias, options)
}
fn create_key<R: JitRuntime, E: FloatElement>(
input: &JitTensor<R, E, 4>,
weights: &JitTensor<R, E, 4>,
bias: &Option<JitTensor<R, E, 1>>,
options: &ConvTransposeOptions<2>,
) -> JitAutotuneKey {
let [batch_size, in_channels, height, width] = input.shape.dims;
let [out_channels, _, kernel_h, kernel_w] = weights.shape.dims;
let ConvTransposeOptions {
stride,
padding,
dilation,
groups,
padding_out,
} = options.clone();
JitAutotuneKey::ConvTranspose2d(ConvTranspose2dAutotuneKey::new(
[kernel_h, kernel_w],
stride,
padding,
padding_out,
dilation,
groups,
in_channels,
out_channels,
height,
width,
batch_size,
bias.is_some(),
))
}

View File

@ -0,0 +1,5 @@
mod conv2d;
mod conv_transpose2d;
pub use conv2d::*;
pub use conv_transpose2d::*;

View File

@ -1,13 +1,13 @@
mod conv2d;
mod conv3d;
mod conv_transpose2d;
mod conv_transpose3d;
mod deform_conv2d;
mod deform_conv_transpose2d;
pub(crate) use conv2d::*;
pub(crate) use conv3d::*;
pub(crate) use conv_transpose2d::*;
pub(crate) use conv_transpose3d::*;
pub(crate) use deform_conv2d::*;
pub(crate) use deform_conv_transpose2d::*;
pub use conv2d::{conv2d, conv_transpose2d, Conv2dStrategy, ConvTranspose2dStrategy};

View File

@ -16,7 +16,7 @@ pub enum MatmulStrategy {
grid_y: usize,
},
#[cfg(feature = "autotune")]
/// Using autotune to chose the best kernel based on runtime information.
/// Using autotune to choose the best kernel based on runtime information.
Autotune,
/// Cube implementation of matmul.
Cube,

View File

@ -1,4 +1,10 @@
use crate::{kernel, FloatElement, IntElement, JitBackend, JitRuntime};
use crate::{
kernel::{
self,
conv::{Conv2dStrategy, ConvTranspose2dStrategy},
},
FloatElement, IntElement, JitBackend, JitRuntime,
};
use burn_tensor::ops::{
ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateOptions,
MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
@ -17,7 +23,7 @@ where
bias: Option<FloatTensor<Self, 1>>,
options: ConvOptions<2>,
) -> FloatTensor<Self, 4> {
kernel::conv::conv2d(x, weight, bias, options)
kernel::conv::conv2d::<R, F, I>(x, weight, bias, options, Conv2dStrategy::default())
}
fn deform_conv2d(
@ -58,7 +64,13 @@ where
bias: Option<FloatTensor<Self, 1>>,
options: ConvTransposeOptions<2>,
) -> FloatTensor<Self, 4> {
kernel::conv::conv_transpose2d(x, weight, bias, options)
kernel::conv::conv_transpose2d::<R, F, I>(
x,
weight,
bias,
options,
ConvTranspose2dStrategy::default(),
)
}
fn conv_transpose3d(

View File

@ -6,7 +6,7 @@ mod tests {
};
#[test]
fn avg_pool2d_should_work_with_multiple_invocations() {
fn avg_pool2d_should_match_reference_backend() {
let tensor = Tensor::<TestBackend, 4>::random(
[32, 32, 32, 32],
Distribution::Default,
@ -29,7 +29,7 @@ mod tests {
}
#[test]
fn avg_pool2d_backward_should_work_with_multiple_invocations() {
fn avg_pool2d_backward_should_match_reference_backend() {
TestBackend::seed(0);
ReferenceBackend::seed(0);
let tensor = Tensor::<TestBackend, 4>::random(

View File

@ -4,12 +4,12 @@ mod tests {
use burn_tensor::{backend::Backend, Distribution, Tensor};
#[test]
fn cat_should_support_multiple_invocations_dim0() {
fn cat_should_match_reference_backend_dim0() {
test_same_as_reference([6, 256], 2, 0);
}
#[test]
fn cat_should_support_multiple_invocations_dim1() {
fn cat_should_match_reference_backend_dim1() {
test_same_as_reference([6, 256], 2, 1);
}

View File

@ -4,7 +4,7 @@ mod tests {
use burn_tensor::{module, Distribution, Tensor};
#[test]
fn conv2d_should_work_with_multiple_invocations() {
fn conv2d_should_match_reference_backend() {
let test_device = Default::default();
let input =
Tensor::<TestBackend, 4>::random([6, 16, 32, 32], Distribution::Default, &test_device);
@ -26,4 +26,28 @@ mod tests {
.into_data()
.assert_approx_eq(&output_ref.into_data(), 3);
}
#[test]
fn conv2d_should_match_reference_backend_implicit() {
let test_device = Default::default();
let input =
Tensor::<TestBackend, 4>::random([4, 16, 6, 6], Distribution::Default, &test_device);
let weight =
Tensor::<TestBackend, 4>::random([16, 16, 3, 3], Distribution::Default, &test_device);
let bias = Tensor::<TestBackend, 1>::random([16], Distribution::Default, &test_device);
let ref_device = Default::default();
let input_ref = Tensor::<ReferenceBackend, 4>::from_data(input.to_data(), &ref_device);
let weight_ref = Tensor::<ReferenceBackend, 4>::from_data(weight.to_data(), &ref_device);
let bias_ref = Tensor::<ReferenceBackend, 1>::from_data(bias.to_data(), &ref_device);
let options = burn_tensor::ops::ConvOptions::new([1, 1], [2, 2], [1, 1], 1);
let output = module::conv2d(input, weight, Some(bias), options.clone());
let output_ref = module::conv2d(input_ref, weight_ref, Some(bias_ref), options);
output
.into_data()
.assert_approx_eq(&output_ref.into_data(), 1);
}
}

View File

@ -4,7 +4,7 @@ mod tests {
use burn_tensor::{module, Distribution, Tensor};
#[test]
fn conv3d_should_work_with_multiple_invocations() {
fn conv3d_should_match_reference_backend() {
let test_device = Default::default();
let input = Tensor::<TestBackend, 5>::random(
[6, 16, 32, 32, 32],

View File

@ -4,7 +4,7 @@ mod tests {
use burn_tensor::{backend::Backend, module, Distribution, Tensor};
#[test]
fn conv_transpose2d_should_work_with_multiple_invocations() {
fn conv_transpose2d_should_match_reference_backend() {
TestBackend::seed(0);
let height = 8;

View File

@ -4,7 +4,7 @@ mod tests {
use burn_tensor::{backend::Backend, module, Distribution, Tensor};
#[test]
fn conv_transpose3d_should_work_with_multiple_invocations() {
fn conv_transpose3d_should_match_reference_backend() {
TestBackend::seed(0);
let depth = 8;

View File

@ -5,7 +5,7 @@ mod tests {
use burn_tensor::{Bool, Distribution, Tensor, TensorPrimitive};
#[test]
fn mask_fill_should_work_with_multiple_invocations() {
fn mask_fill_should_match_reference_backend() {
let (tensor, mask, tensor_ref, mask_ref) = inputs_mask_fill();
let actual = Tensor::<TestBackend, 3>::from_primitive(TensorPrimitive::Float(mask_fill(
@ -22,7 +22,7 @@ mod tests {
}
#[test]
fn mask_fill_inplace_should_work_with_multiple_invocations() {
fn mask_fill_inplace_should_match_reference_backend() {
let (tensor, mask, tensor_ref, mask_ref) = inputs_mask_fill();
let actual = Tensor::<TestBackend, 3>::from_primitive(TensorPrimitive::Float(mask_fill(

View File

@ -5,7 +5,7 @@ mod tests {
use burn_tensor::{backend::Backend, Bool, Distribution, Tensor, TensorPrimitive};
#[test]
fn mask_where_should_work_with_multiple_invocations() {
fn mask_where_should_match_reference_backend() {
let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where();
let actual = tensor.mask_where(mask, value);
@ -16,7 +16,7 @@ mod tests {
.assert_approx_eq(&actual.into_data(), 3);
}
#[test]
fn mask_where_inplace_lhs_should_work_with_multiple_invocations() {
fn mask_where_inplace_lhs_should_match_reference_backend() {
let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where();
let actual = Tensor::<TestBackend, 3>::from_primitive(TensorPrimitive::Float(mask_where(
@ -33,7 +33,7 @@ mod tests {
}
#[test]
fn mask_where_inplace_rhs_should_work_with_multiple_invocation() {
fn mask_where_inplace_rhs_should_match_reference_backend() {
let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where();
let actual = Tensor::<TestBackend, 3>::from_primitive(TensorPrimitive::Float(mask_where(

View File

@ -4,7 +4,7 @@ mod tests {
use burn_tensor::{module, Distribution, Tensor};
#[test]
pub fn max_pool2d_should_work_with_multiple_invocations() {
pub fn max_pool2d_should_match_reference_backend() {
let tensor = Tensor::<TestBackend, 4>::random(
[32, 32, 32, 32],
Distribution::Default,
@ -26,7 +26,7 @@ mod tests {
}
#[test]
pub fn max_pool2d_with_indices_should_work_with_multiple_invocations() {
pub fn max_pool2d_with_indices_should_match_reference_backend() {
let tensor = Tensor::<TestBackend, 4>::random(
[32, 32, 32, 32],
Distribution::Default,

View File

@ -4,7 +4,7 @@ mod tests {
use burn_tensor::{module, ops::ModuleOps, Distribution, Tensor, TensorPrimitive};
#[test]
pub fn max_pool2d_with_indices_backward_should_work_with_multiple_invocations() {
pub fn max_pool2d_with_indices_backward_should_match_reference_backend() {
let test_device = Default::default();
let tensor =
Tensor::<TestBackend, 4>::random([32, 32, 32, 32], Distribution::Default, &test_device);

View File

@ -10,7 +10,7 @@ mod reduction {
};
#[test]
fn reduction_sum_dim_should_work_with_multiple_invocations() {
fn reduction_sum_dim_should_match_reference_backend() {
let tensor =
Tensor::<TestBackend, 2>::random([6, 1024], Distribution::Default, &Default::default());
let tensor_ref =
@ -33,7 +33,7 @@ mod reduction {
}
#[test]
fn reduction_prod_dim_should_work_with_multiple_invocations() {
fn reduction_prod_dim_should_match_reference_backend() {
let tensor =
Tensor::<TestBackend, 2>::random([6, 1024], Distribution::Default, &Default::default());
let tensor_ref =
@ -56,7 +56,7 @@ mod reduction {
}
#[test]
fn reduction_argmin_dim_should_work_with_multiple_invocations() {
fn reduction_argmin_dim_should_match_reference_backend() {
let tensor =
Tensor::<TestBackend, 2>::random([6, 1024], Distribution::Default, &Default::default());
let tensor_ref =
@ -75,7 +75,7 @@ mod reduction {
}
#[test]
fn reduction_argmax_dim_should_work_with_multiple_invocations() {
fn reduction_argmax_dim_should_match_reference_backend() {
let tensor =
Tensor::<TestBackend, 2>::random([6, 1024], Distribution::Default, &Default::default());
let tensor_ref =
@ -290,7 +290,7 @@ mod reduction {
}
#[test]
fn reduction_sum_should_work_with_multiple_invocations() {
fn reduction_sum_should_match_reference_backend() {
let tensor =
Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default, &Default::default());
let tensor_ref =
@ -306,7 +306,7 @@ mod reduction {
}
#[test]
fn reduction_prod_should_work_with_multiple_invocations() {
fn reduction_prod_should_match_reference_backend() {
let tensor =
Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default, &Default::default());
let tensor_ref =

View File

@ -1,4 +1,8 @@
use crate::kernel::{matmul::MatmulAutotuneKey, reduce::ReduceAutotuneKey};
use crate::kernel::{
conv::{Conv2dAutotuneKey, ConvTranspose2dAutotuneKey},
matmul::MatmulAutotuneKey,
reduce::ReduceAutotuneKey,
};
use cubecl::tune::AutotuneKey;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
@ -13,6 +17,10 @@ pub enum JitAutotuneKey {
Matmul(MatmulAutotuneKey),
/// Key for reduce dim operations
ReduceDim(ReduceAutotuneKey),
/// Key for convolution operations
Conv2d(Conv2dAutotuneKey),
/// Key for transpose convolution operations
ConvTranspose2d(ConvTranspose2dAutotuneKey),
#[cfg(any(feature = "fusion", test))]
/// Key for fused element wise operations.
FusionElemWise(FusionElemWiseAutotuneKey),
@ -25,6 +33,8 @@ impl Display for JitAutotuneKey {
JitAutotuneKey::ReduceDim(reduce_key) => std::fmt::Display::fmt(&reduce_key, f),
#[cfg(any(feature = "fusion", test))]
JitAutotuneKey::FusionElemWise(reduce_key) => std::fmt::Display::fmt(&reduce_key, f),
JitAutotuneKey::Conv2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f),
JitAutotuneKey::ConvTranspose2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f),
}
}
}
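
An autotune key must capture everything that meaningfully affects kernel performance, because cached tuning results are looked up by key. The real Conv2dAutotuneKey is defined elsewhere in this commit and its fields are not shown here; the struct below is only a conceptual sketch of what such a key tends to contain:

// Conceptual sketch only; the field set is an assumption, not the real key.
#[derive(Hash, PartialEq, Eq, Clone, Debug)]
struct Conv2dKeySketch {
    kernel_size: [usize; 2],
    stride: [usize; 2],
    padding: [usize; 2],
    groups: usize,
    in_channels: usize,
    out_channels: usize,
    batch_size: usize,
}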

View File

@ -109,6 +109,7 @@ pub(crate) fn conv2d<E: FloatNdArrayElement, Q: QuantElement>(
let [stride_height, stride_width] = options.stride;
let [batch_size, _in_channels, in_height, in_width] = x.shape().dims;
let [out_channels, in_channels, kernel_height, kernel_width] = weight.shape().dims;
let channels_per_group = out_channels / options.groups;
let out_height = calculate_conv_output_size(
kernel_height,
@ -141,7 +142,7 @@ pub(crate) fn conv2d<E: FloatNdArrayElement, Q: QuantElement>(
|(k, mut output)| {
let b = k / out_channels;
let oc = k % out_channels;
let g = k % options.groups;
let g = oc / channels_per_group;
for ic in (in_channels * g)..(in_channels * (g + 1)) {
let weight_ic = ic - (g * in_channels);
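
The grouped-convolution fix above derives the group index from the output channel rather than from the flattened batch-by-channel loop index. A small standalone check of the corrected mapping (the helper is illustrative, not code from this commit):

// With out_channels = 4 and groups = 2, output channels 0 and 1 belong to
// group 0 and channels 2 and 3 to group 1. The previous `k % groups` only
// agrees with this in special cases such as groups == 1 or one output
// channel per group.
fn group_of(oc: usize, out_channels: usize, groups: usize) -> usize {
    let channels_per_group = out_channels / groups;
    oc / channels_per_group
}

fn main() {
    assert_eq!(group_of(0, 4, 2), 0);
    assert_eq!(group_of(1, 4, 2), 0);
    assert_eq!(group_of(2, 4, 2), 1);
    assert_eq!(group_of(3, 4, 2), 1);
}

The new test_conv2d_groups_multiple_channels test later in this diff exercises exactly this case: four output channels split across two groups.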

View File

@ -43,7 +43,7 @@ pub fn calculate_conv_transpose_output_size(
dilation: usize,
size_in: usize,
) -> usize {
(size_in - 1) * stride + dilation * (kernel_size - 1) + padding_out - 2 * padding + 1
(size_in - 1) * stride + (dilation * (kernel_size - 1) + 1) + padding_out - 2 * padding
}
/// Calculate the expected output size when doing a pooling operation.
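
The regrouped conv-transpose expression above is algebraically identical to the previous one; the change isolates the effective kernel extent, dilation * (kernel_size - 1) + 1, as a single term. A quick standalone check (not code from this commit), using the same parameters as the new test_conv_transpose2d_simple_3 case further down (size_in = 2, kernel_size = 2, stride = 1, dilation = 1, no padding), which produces the expected 3x3 output:

fn conv_transpose_out(
    size_in: usize,
    stride: usize,
    kernel_size: usize,
    dilation: usize,
    padding: usize,
    padding_out: usize,
) -> usize {
    (size_in - 1) * stride + (dilation * (kernel_size - 1) + 1) + padding_out - 2 * padding
}

fn main() {
    // (2 - 1) * 1 + (1 * (2 - 1) + 1) + 0 - 0 = 3
    assert_eq!(conv_transpose_out(2, 1, 2, 1, 0, 0), 3);
}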

View File

@ -69,6 +69,49 @@ mod tests {
]]));
}
#[test]
fn test_conv2d_groups_multiple_channels() {
let test = Conv2dTestCase {
batch_size: 1,
channels_in: 4,
channels_out: 4,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 0,
padding_2: 0,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 2,
height: 5,
width: 5,
};
test.assert_output(TestTensor::from([[
[
[4035., 4188., 4341.],
[4800., 4953., 5106.],
[5565., 5718., 5871.],
],
[
[10030., 10507., 10984.],
[12415., 12892., 13369.],
[14800., 15277., 15754.],
],
[
[56075., 56876., 57677.],
[60080., 60881., 61682.],
[64085., 64886., 65687.],
],
[
[78270., 79395., 80520.],
[83895., 85020., 86145.],
[89520., 90645., 91770.],
],
]]));
}
#[test]
fn test_conv2d_complex() {
let test = Conv2dTestCase {

View File

@ -28,6 +28,7 @@ mod tests {
test.assert_output(TestTensor::from([[[[5.0, 11.0], [23.0, 29.0]]]]));
}
#[test]
fn test_conv_transpose2d_simple_2() {
let test = ConvTranspose2dTestCase {
@ -71,6 +72,34 @@ mod tests {
]]));
}
#[test]
fn test_conv_transpose2d_simple_3() {
let test = ConvTranspose2dTestCase {
batch_size: 1,
channels_in: 1,
channels_out: 1,
kernel_size_1: 2,
kernel_size_2: 2,
padding_1: 0,
padding_2: 0,
padding_out_1: 0,
padding_out_2: 0,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
groups: 1,
height: 2,
width: 2,
};
test.assert_output(TestTensor::from([[[
[0.0, 0.0, 1.0],
[0.0, 4.0, 6.0],
[4.0, 12.0, 9.0],
]]]));
}
#[test]
fn test_conv_transpose2d_stride_2() {
let test = ConvTranspose2dTestCase {