mirror of https://github.com/tracel-ai/burn.git
Introduce autotuning to `conv2d` and `conv_transpose2d` with a new `im2col`/`GEMM` algorithm (#2287)
This commit is contained in:
parent
2c8514ce7f
commit
97af8c6d28
|
@ -305,6 +305,7 @@ dependencies = [
|
|||
"derive-new",
|
||||
"dirs 5.0.1",
|
||||
"github-device-flow",
|
||||
"half",
|
||||
"indicatif",
|
||||
"os_info",
|
||||
"percent-encoding",
|
||||
|
|
|
@ -11,11 +11,12 @@ version.workspace = true
|
|||
|
||||
[features]
|
||||
# we depend on wgpu and autotune by default because we use the burn-wgpu crate to get system information
|
||||
default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"]
|
||||
candle-accelerate = ["burn/candle", "burn/accelerate"]
|
||||
candle-cpu = ["burn/candle"]
|
||||
candle-cuda = ["burn/candle-cuda"]
|
||||
candle-metal = ["burn/candle", "burn/metal"]
|
||||
candle-accelerate = ["burn/candle", "burn/accelerate"]
|
||||
cuda-jit = ["burn/cuda-jit"]
|
||||
default = ["burn/std", "burn/autodiff", "burn/wgpu", "burn/autotune"]
|
||||
ndarray = ["burn/ndarray"]
|
||||
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
|
||||
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
|
||||
|
@ -24,7 +25,6 @@ tch-cpu = ["burn/tch"]
|
|||
tch-gpu = ["burn/tch"]
|
||||
wgpu = ["burn/wgpu", "burn/autotune"]
|
||||
wgpu-fusion = ["wgpu", "burn/fusion"]
|
||||
cuda-jit = ["burn/cuda-jit"]
|
||||
|
||||
[dependencies]
|
||||
arboard = { workspace = true }
|
||||
|
@ -33,11 +33,13 @@ burn-common = { path = "../crates/burn-common", version = "0.15.0" }
|
|||
burn-wgpu = { path = "../crates/burn-wgpu", default-features = false, version = "0.15.0", optional = true }
|
||||
clap = { workspace = true }
|
||||
colored = { workspace = true }
|
||||
cubecl = { workspace = true, features = ["wgpu"] }
|
||||
derive-new = { workspace = true }
|
||||
dirs = { workspace = true }
|
||||
github-device-flow = { workspace = true }
|
||||
os_info = { workspace = true }
|
||||
half = { workspace = true }
|
||||
indicatif = { workspace = true }
|
||||
os_info = { workspace = true }
|
||||
percent-encoding = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
reqwest = { workspace = true, features = ["blocking", "json"] }
|
||||
|
@ -48,69 +50,68 @@ strum_macros = { workspace = true }
|
|||
sysinfo = { workspace = true, features = ["serde"] }
|
||||
wgpu = { workspace = true }
|
||||
wsl = { workspace = true }
|
||||
cubecl = { workspace = true, features = ["wgpu"] }
|
||||
|
||||
[dev-dependencies]
|
||||
rstest = { workspace = true }
|
||||
serial_test = { workspace = true }
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "unary"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "binary"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "max-pool2d"
|
||||
path = "benches/max_pool2d.rs"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "conv-transpose2d"
|
||||
path = "benches/conv_transpose2d.rs"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "conv-transpose3d"
|
||||
path = "benches/conv_transpose3d.rs"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "conv2d"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "conv3d"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "matmul"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "data"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "load-record"
|
||||
harness = false
|
||||
name = "load-record"
|
||||
path = "benches/load_record.rs"
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "custom-gelu"
|
||||
path = "benches/custom_gelu.rs"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
harness = false
|
||||
name = "resnet50"
|
||||
path = "benches/resnet.rs"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "autodiff"
|
||||
harness = false
|
||||
name = "autodiff"
|
||||
|
||||
[[bin]]
|
||||
name = "burnbench"
|
||||
|
|
|
@ -3,7 +3,6 @@ use std::fs;
|
|||
use std::path::Path;
|
||||
use std::process::Command;
|
||||
|
||||
const MODELS_DIR: &str = "/tmp/models";
|
||||
const MODELS_REPO: &str = "https://github.com/tracel-ai/models.git";
|
||||
|
||||
// Patch resnet code (remove pretrained feature code)
|
||||
|
@ -224,8 +223,10 @@ where
|
|||
}
|
||||
|
||||
fn main() {
|
||||
let models_dir = std::env::temp_dir().join("models");
|
||||
let models_dir = models_dir.as_path();
|
||||
// Checkout ResNet code from models repo
|
||||
let models_dir = Path::new(MODELS_DIR);
|
||||
let models_dir = Path::new(models_dir);
|
||||
if !models_dir.join(".git").exists() {
|
||||
run("git", |command| {
|
||||
command
|
||||
|
@ -233,7 +234,7 @@ fn main() {
|
|||
.arg("--depth=1")
|
||||
.arg("--no-checkout")
|
||||
.arg(MODELS_REPO)
|
||||
.arg(MODELS_DIR)
|
||||
.arg(models_dir)
|
||||
});
|
||||
|
||||
run("git", |command| {
|
||||
|
@ -266,10 +267,12 @@ fn main() {
|
|||
let source_path = models_dir.join("resnet-burn").join("resnet").join("src");
|
||||
let dest_path = Path::new(&out_dir);
|
||||
|
||||
for file in fs::read_dir(source_path).unwrap() {
|
||||
let source_file = file.unwrap().path();
|
||||
let dest_file = dest_path.join(source_file.file_name().unwrap());
|
||||
fs::copy(source_file, dest_file).expect("should copy file successfully");
|
||||
if let Ok(source_path) = fs::read_dir(source_path) {
|
||||
for file in source_path {
|
||||
let source_file = file.unwrap().path();
|
||||
let dest_file = dest_path.join(source_file.file_name().unwrap());
|
||||
fs::copy(source_file, dest_file).expect("should copy file successfully");
|
||||
}
|
||||
}
|
||||
|
||||
// Delete cloned repository contents
|
||||
|
|
|
@ -131,7 +131,7 @@ macro_rules! bench_on_backend {
|
|||
{
|
||||
use burn::backend::cuda_jit::{Cuda, CudaDevice};
|
||||
|
||||
bench::<Cuda>(&CudaDevice::default(), feature_name, url, token);
|
||||
bench::<Cuda<half::f16>>(&CudaDevice::default(), feature_name, url, token);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
@ -11,12 +11,9 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-jit"
|
|||
version.workspace = true
|
||||
|
||||
[features]
|
||||
default = ["autotune", "std", "fusion", "cubecl/default"]
|
||||
std = ["cubecl/std"]
|
||||
doc = ["default"]
|
||||
autotune = []
|
||||
template = []
|
||||
fusion = ["burn-fusion"]
|
||||
default = ["autotune", "std", "fusion", "cubecl/default"]
|
||||
doc = ["default"]
|
||||
export_tests = [
|
||||
"burn-tensor-testgen",
|
||||
"serial_test",
|
||||
|
@ -25,32 +22,37 @@ export_tests = [
|
|||
"burn-ndarray",
|
||||
"fusion",
|
||||
]
|
||||
fusion = ["burn-fusion"]
|
||||
std = ["cubecl/std"]
|
||||
template = []
|
||||
|
||||
[dependencies]
|
||||
cubecl = { workspace = true, features = ["linalg"] }
|
||||
burn-common = { path = "../burn-common", version = "0.15.0" }
|
||||
burn-tensor = { path = "../burn-tensor", version = "0.15.0", features = ["cubecl"] }
|
||||
burn-fusion = { path = "../burn-fusion", version = "0.15.0", optional = true }
|
||||
burn-tensor = { path = "../burn-tensor", version = "0.15.0", features = [
|
||||
"cubecl",
|
||||
] }
|
||||
cubecl = { workspace = true, features = ["linalg"] }
|
||||
|
||||
bytemuck = { workspace = true }
|
||||
derive-new = { workspace = true }
|
||||
half = { workspace = true, features = ["bytemuck"] }
|
||||
log = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
rand = { workspace = true }
|
||||
spin = { workspace = true }
|
||||
half = { workspace = true, features = ["bytemuck"] }
|
||||
|
||||
# Template
|
||||
serde = { workspace = true }
|
||||
text_placeholder = { workspace = true, features = ["struct_context"] }
|
||||
|
||||
hashbrown = { workspace = true }
|
||||
burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.15.0", optional = true }
|
||||
hashbrown = { workspace = true }
|
||||
|
||||
# When exporting tests
|
||||
serial_test = { workspace = true, optional = true }
|
||||
burn-autodiff = { path = "../burn-autodiff", version = "0.15.0", default-features = false, optional = true }
|
||||
burn-ndarray = { path = "../burn-ndarray", version = "0.15.0", optional = true }
|
||||
serial_test = { workspace = true, optional = true }
|
||||
|
||||
[package.metadata.docs.rs]
|
||||
features = ["doc"]
|
||||
|
|
|
@ -511,14 +511,14 @@ impl TraceBuilder {
|
|||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
Operator::AtomicCompareAndSwap(_op) => {
|
||||
// Nothing to do.
|
||||
}
|
||||
Operator::Magnitude(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
Operator::AtomicCompareAndSwap(_op) => {
|
||||
// Nothing to do.
|
||||
}
|
||||
},
|
||||
Operation::Procedure(proc) => {
|
||||
match proc {
|
||||
|
|
|
@ -0,0 +1,125 @@
|
|||
use burn_tensor::{
|
||||
ops::{ConvOptions, ConvTransposeOptions},
|
||||
TensorData,
|
||||
};
|
||||
|
||||
use crate::{tensor::JitTensor, FloatElement, IntElement, JitElement, JitRuntime};
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
use super::conv2d_autotune;
|
||||
use super::{
|
||||
conv2d_direct, conv2d_im2col, conv_transpose2d_autotune, conv_transpose2d_col2im,
|
||||
conv_transpose2d_direct, implicit_gemm::conv2d_implicit_gemm,
|
||||
};
|
||||
|
||||
/// The strategy to be used when launching a convolution kernel.
///
/// The type is a plain tag with no payload, so it is cheap to copy, compare and print.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Conv2dStrategy {
    /// A simple direct convolution.
    Direct,
    /// Using autotune to choose the best kernel based on runtime information.
    ///
    /// Only available when the `autotune` feature is enabled.
    #[cfg(feature = "autotune")]
    Autotune,
    /// GEMM (im2col) based implementation of convolution. Significantly increased memory usage.
    Gemm,
    /// Implicit GEMM implementation of convolution. Lower memory usage but requires CMMA and
    /// has constraints on tensor shape.
    ImplicitGemm,
}

impl Default for Conv2dStrategy {
    /// Returns `Autotune` when the `autotune` feature is enabled and `Direct` otherwise.
    fn default() -> Self {
        // if autotune is enabled, default to autotune
        #[cfg(feature = "autotune")]
        return Conv2dStrategy::Autotune;

        // if autotune is disabled, default to the more memory-conservative algorithm
        #[cfg(not(feature = "autotune"))]
        Conv2dStrategy::Direct
    }
}
|
||||
|
||||
/// The strategy to be used when launching a conv_transpose kernel.
///
/// The type is a plain tag with no payload, so it is cheap to copy, compare and print.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConvTranspose2dStrategy {
    /// A simple direct transposed convolution.
    Direct,
    /// Using autotune to choose the best kernel based on runtime information.
    ///
    /// Only available when the `autotune` feature is enabled.
    #[cfg(feature = "autotune")]
    Autotune,
    /// GEMM (col2im) based implementation of transposed convolution. Significantly increased
    /// memory usage.
    Gemm,
}

impl Default for ConvTranspose2dStrategy {
    /// Returns `Autotune` when the `autotune` feature is enabled and `Direct` otherwise.
    fn default() -> Self {
        // if autotune is enabled, default to autotune
        #[cfg(feature = "autotune")]
        return ConvTranspose2dStrategy::Autotune;

        // if autotune is disabled, default to the more memory-conservative algorithm
        #[cfg(not(feature = "autotune"))]
        ConvTranspose2dStrategy::Direct
    }
}
|
||||
|
||||
/// Perform a 2D convolution with the given strategy
|
||||
///
|
||||
/// * `input` - The input feature map
|
||||
/// * `weight` - The weights (filter) applied to each kernel
|
||||
/// * `bias` - The bias added to each channel
|
||||
/// * `options` - The options to use for the convolution
|
||||
/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
|
||||
///
|
||||
pub fn conv2d<R: JitRuntime, E: FloatElement, I: IntElement>(
|
||||
input: JitTensor<R, E, 4>,
|
||||
weight: JitTensor<R, E, 4>,
|
||||
bias: Option<JitTensor<R, E, 1>>,
|
||||
options: ConvOptions<2>,
|
||||
strategy: Conv2dStrategy,
|
||||
) -> JitTensor<R, E, 4> {
|
||||
match strategy {
|
||||
Conv2dStrategy::Direct => conv2d_direct::<R, E, I>(input, weight, bias, options),
|
||||
#[cfg(feature = "autotune")]
|
||||
Conv2dStrategy::Autotune => conv2d_autotune::<R, E, I>(input, weight, bias, options),
|
||||
Conv2dStrategy::Gemm => conv2d_im2col::<R, E, I>(input, weight, bias, options),
|
||||
Conv2dStrategy::ImplicitGemm => {
|
||||
conv2d_implicit_gemm::<R, E, I>(input, weight, bias, options)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a 2D convolution with the given strategy
|
||||
///
|
||||
/// * `input` - The input feature map
|
||||
/// * `weight` - The weights (filter) applied to each kernel
|
||||
/// * `bias` - The bias added to each channel
|
||||
/// * `options` - The options to use for the convolution
|
||||
/// * `strategy` - The convolution algorithm to use. Autotune will pick the fastest available option.
|
||||
///
|
||||
pub fn conv_transpose2d<R: JitRuntime, E: FloatElement, I: IntElement>(
|
||||
input: JitTensor<R, E, 4>,
|
||||
weight: JitTensor<R, E, 4>,
|
||||
bias: Option<JitTensor<R, E, 1>>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
strategy: ConvTranspose2dStrategy,
|
||||
) -> JitTensor<R, E, 4> {
|
||||
match strategy {
|
||||
ConvTranspose2dStrategy::Direct => {
|
||||
conv_transpose2d_direct::<R, E, I>(input, weight, bias, options)
|
||||
}
|
||||
#[cfg(feature = "autotune")]
|
||||
ConvTranspose2dStrategy::Autotune => {
|
||||
conv_transpose2d_autotune::<R, E, I>(input, weight, bias, options)
|
||||
}
|
||||
ConvTranspose2dStrategy::Gemm => {
|
||||
conv_transpose2d_col2im::<R, E, I>(input, weight, bias, options)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub(crate) fn debug_data<R: JitRuntime, E: JitElement, const D: usize>(
|
||||
tensor: JitTensor<R, E, D>,
|
||||
) -> TensorData {
|
||||
let bytes = tensor.client.read(tensor.handle.binding());
|
||||
TensorData::new(E::from_bytes(&bytes).to_vec(), tensor.shape)
|
||||
}
|
|
@ -0,0 +1,297 @@
|
|||
use burn_tensor::{
|
||||
ops::{conv::calculate_conv_transpose_output_size, ConvTransposeOptions, FloatTensorOps as _},
|
||||
Shape,
|
||||
};
|
||||
use cubecl::{calculate_cube_count_elemwise, prelude::*};
|
||||
|
||||
use crate::{
|
||||
kernel::into_contiguous,
|
||||
ops::{numeric::empty_device, reshape, swap_dims},
|
||||
tensor::JitTensor,
|
||||
FloatElement, IntElement, JitBackend, JitRuntime,
|
||||
};
|
||||
|
||||
use super::batches_per_run;
|
||||
|
||||
/// Perform a 2D convolution transposition using the GEMM (col2im) algorithm.
|
||||
///
|
||||
/// * `input` - The input feature map
|
||||
/// * `weight` - The weights (filter) applied to each kernel
|
||||
/// * `bias` - The bias added to each channel
|
||||
/// * `options` - The options to use for the convolution
|
||||
///
|
||||
pub fn conv_transpose2d_col2im<R: JitRuntime, E: FloatElement, I: IntElement>(
|
||||
input: JitTensor<R, E, 4>,
|
||||
weight: JitTensor<R, E, 4>,
|
||||
bias: Option<JitTensor<R, E, 1>>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let [input_channels, im_ch_per_group, kernel_h, kernel_w] = weight.shape.dims;
|
||||
let [batch_size, _, input_h, input_w] = input.shape.dims;
|
||||
let groups = options.groups;
|
||||
let input_ch_per_group = input_channels / groups;
|
||||
let ConvTransposeOptions {
|
||||
padding: [padding_h, padding_w],
|
||||
padding_out: [padding_out_h, padding_out_w],
|
||||
dilation: [dilation_h, dilation_w],
|
||||
stride: [stride_h, stride_w],
|
||||
..
|
||||
} = options.clone();
|
||||
|
||||
let im_h = calculate_conv_transpose_output_size(
|
||||
kernel_h,
|
||||
stride_h,
|
||||
padding_h,
|
||||
padding_out_h,
|
||||
dilation_h,
|
||||
input_h,
|
||||
);
|
||||
let im_w = calculate_conv_transpose_output_size(
|
||||
kernel_w,
|
||||
stride_w,
|
||||
padding_w,
|
||||
padding_out_w,
|
||||
dilation_w,
|
||||
input_w,
|
||||
);
|
||||
let im_channels = im_ch_per_group * groups;
|
||||
|
||||
let batches_per_run = batches_per_run(batch_size, input_h, input_w);
|
||||
let col_shape_0 = im_ch_per_group * kernel_h * kernel_w;
|
||||
|
||||
let weight = reshape(
|
||||
weight.clone(),
|
||||
Shape::new([groups, input_ch_per_group, col_shape_0]),
|
||||
);
|
||||
let weight = into_contiguous(swap_dims(weight, 1, 2));
|
||||
|
||||
if batches_per_run != batch_size {
|
||||
let runs = batch_size / batches_per_run;
|
||||
|
||||
let im_shape = Shape::new([runs, batches_per_run, im_channels, im_h, im_w]);
|
||||
let mut image = empty_device(input.client.clone(), input.device.clone(), im_shape);
|
||||
|
||||
let input_shape = Shape::new([runs, batches_per_run, input_channels, input_h, input_w]);
|
||||
let input = reshape(input, input_shape);
|
||||
let input_shape_run = Shape::new([batches_per_run, input_channels, input_h, input_w]);
|
||||
let im_shape_run = Shape::new([1, batches_per_run, im_channels, im_h, im_w]);
|
||||
|
||||
for run in 0..runs {
|
||||
let input = JitBackend::<R, E, I>::float_narrow(input.clone(), 0, run, 1);
|
||||
let input = reshape(input, input_shape_run.clone());
|
||||
let image_run = execute::<R, E, I>(
|
||||
input,
|
||||
weight.clone(),
|
||||
bias.clone(),
|
||||
options.clone(),
|
||||
im_ch_per_group,
|
||||
im_h,
|
||||
im_w,
|
||||
kernel_h,
|
||||
kernel_w,
|
||||
);
|
||||
let image_run = reshape(image_run, im_shape_run.clone());
|
||||
image = JitBackend::<R, E, I>::float_slice_assign(
|
||||
image,
|
||||
[
|
||||
run..run + 1,
|
||||
0..batches_per_run,
|
||||
0..im_channels,
|
||||
0..im_h,
|
||||
0..im_w,
|
||||
],
|
||||
image_run,
|
||||
)
|
||||
}
|
||||
reshape(image, Shape::new([batch_size, im_channels, im_h, im_w]))
|
||||
} else {
|
||||
execute::<R, E, I>(
|
||||
input,
|
||||
weight,
|
||||
bias,
|
||||
options,
|
||||
im_ch_per_group,
|
||||
im_h,
|
||||
im_w,
|
||||
kernel_h,
|
||||
kernel_w,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn execute<R: JitRuntime, E: FloatElement, I: IntElement>(
|
||||
input: JitTensor<R, E, 4>,
|
||||
weight: JitTensor<R, E, 3>,
|
||||
bias: Option<JitTensor<R, E, 1>>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
im_ch_per_group: usize,
|
||||
im_h: usize,
|
||||
im_w: usize,
|
||||
kernel_h: usize,
|
||||
kernel_w: usize,
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let [batch_size, _, input_h, input_w] = input.shape.dims;
|
||||
let [groups, col_shape_0, input_ch_per_group] = weight.shape.dims;
|
||||
|
||||
let im_channels = im_ch_per_group * groups;
|
||||
|
||||
let im_shape = Shape::new([batch_size, im_channels, im_h, im_w]);
|
||||
|
||||
let col_shape_1 = batch_size * input_h * input_w;
|
||||
|
||||
let input = swap_dims(input, 0, 1);
|
||||
let input_shape = Shape::new([groups, input_ch_per_group, col_shape_1]);
|
||||
let input = reshape(input, input_shape);
|
||||
|
||||
let columns = JitBackend::<R, E, I>::float_matmul(weight, input);
|
||||
let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1]));
|
||||
|
||||
col2im(
|
||||
columns, bias, im_shape, kernel_h, kernel_w, input_h, input_w, options,
|
||||
)
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn col2im<R: JitRuntime, E: FloatElement>(
|
||||
columns: JitTensor<R, E, 2>,
|
||||
bias: Option<JitTensor<R, E, 1>>,
|
||||
im_shape: Shape<4>,
|
||||
kernel_h: usize,
|
||||
kernel_w: usize,
|
||||
out_h: usize,
|
||||
out_w: usize,
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let [_, col_size_1] = columns.shape.dims;
|
||||
|
||||
let columns = into_contiguous(columns);
|
||||
let has_bias = bias.is_some();
|
||||
let bias = bias.map(into_contiguous).unwrap_or_else(|| {
|
||||
empty_device(
|
||||
columns.client.clone(),
|
||||
columns.device.clone(),
|
||||
Shape::new([1]),
|
||||
)
|
||||
});
|
||||
|
||||
let num_elems = im_shape.num_elements();
|
||||
let out = empty_device(
|
||||
columns.client.clone(),
|
||||
columns.device.clone(),
|
||||
im_shape.clone(),
|
||||
);
|
||||
|
||||
let vectorization = 1;
|
||||
let cube_dim = CubeDim::default();
|
||||
let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);
|
||||
|
||||
unsafe {
|
||||
col2im_kernel::launch_unchecked::<E, R>(
|
||||
&columns.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
columns.as_tensor_arg(vectorization),
|
||||
bias.as_tensor_arg(vectorization),
|
||||
out.as_tensor_arg(vectorization),
|
||||
Col2ImArgsLaunch::new(
|
||||
ScalarArg::new(out_h as u32),
|
||||
ScalarArg::new(out_w as u32),
|
||||
ScalarArg::new(kernel_h as u32),
|
||||
ScalarArg::new(kernel_w as u32),
|
||||
ScalarArg::new(options.padding[0] as u32),
|
||||
ScalarArg::new(options.padding[1] as u32),
|
||||
ScalarArg::new(options.dilation[0] as u32),
|
||||
ScalarArg::new(options.dilation[1] as u32),
|
||||
ScalarArg::new(options.stride[0] as u32),
|
||||
ScalarArg::new(options.stride[1] as u32),
|
||||
ScalarArg::new(col_size_1 as u32),
|
||||
),
|
||||
has_bias,
|
||||
)
|
||||
};
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
#[derive(CubeLaunch)]
|
||||
struct Col2ImArgs {
|
||||
out_h: u32,
|
||||
out_w: u32,
|
||||
|
||||
kernel_h: u32,
|
||||
kernel_w: u32,
|
||||
|
||||
pad_h: u32,
|
||||
pad_w: u32,
|
||||
dilation_h: u32,
|
||||
dilation_w: u32,
|
||||
stride_h: u32,
|
||||
stride_w: u32,
|
||||
|
||||
col_size_1: u32,
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked)]
|
||||
fn col2im_kernel<F: Float>(
|
||||
columns: &Tensor<F>,
|
||||
bias: &Tensor<F>,
|
||||
image: &mut Tensor<F>,
|
||||
args: &Col2ImArgs,
|
||||
#[comptime] has_bias: bool,
|
||||
) {
|
||||
if ABSOLUTE_POS > image.len() {
|
||||
return;
|
||||
}
|
||||
|
||||
let _ = bias[0]; // Keep in bind group
|
||||
|
||||
let im_x = ABSOLUTE_POS % image.shape(3) + args.pad_w;
|
||||
let im_y = ABSOLUTE_POS / image.stride(2) % image.shape(2) + args.pad_h;
|
||||
let ch_im = ABSOLUTE_POS / image.stride(1) % image.shape(1);
|
||||
let batch = ABSOLUTE_POS / image.stride(0);
|
||||
|
||||
let kernel_extent_w = (args.kernel_w - 1) * args.dilation_w + 1;
|
||||
let kernel_extent_h = (args.kernel_h - 1) * args.dilation_h + 1;
|
||||
|
||||
let mut val = F::new(0.0);
|
||||
|
||||
let x_col_start = if im_x >= kernel_extent_w {
|
||||
(im_x - kernel_extent_w) / args.stride_w + 1
|
||||
} else {
|
||||
0u32
|
||||
};
|
||||
let x_col_end = Min::min(im_x / args.stride_w + 1, args.out_w);
|
||||
let y_col_start = if im_y >= kernel_extent_h {
|
||||
(im_y - kernel_extent_h) / args.stride_h + 1
|
||||
} else {
|
||||
0u32
|
||||
};
|
||||
let y_col_end = Min::min(im_y / args.stride_h + 1, args.out_h);
|
||||
|
||||
for col_y in y_col_start..y_col_end {
|
||||
let kernel_y = im_y - col_y * args.stride_h;
|
||||
for col_x in x_col_start..x_col_end {
|
||||
let kernel_x = im_x - col_x * args.stride_w;
|
||||
|
||||
if kernel_y % args.dilation_h == 0 && kernel_x % args.dilation_w == 0 {
|
||||
let kernel_y = kernel_y / args.dilation_h;
|
||||
let kernel_x = kernel_x / args.dilation_w;
|
||||
|
||||
let col_pos = ch_im * args.kernel_h * args.kernel_w * args.col_size_1
|
||||
+ kernel_y * args.kernel_w * args.col_size_1
|
||||
+ kernel_x * args.col_size_1
|
||||
+ batch * args.out_h * args.out_w
|
||||
+ col_y * args.out_w
|
||||
+ col_x;
|
||||
val += columns[col_pos];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if has_bias {
|
||||
image[ABSOLUTE_POS] = val + bias[ch_im];
|
||||
} else {
|
||||
image[ABSOLUTE_POS] = val;
|
||||
}
|
||||
}
|
|
@ -1,9 +1,8 @@
|
|||
use cubecl::{calculate_cube_count_elemwise, prelude::*};
|
||||
|
||||
use burn_tensor::{
|
||||
ops::{conv::calculate_conv_output_size, ConvOptions},
|
||||
Shape,
|
||||
};
|
||||
use cubecl::{calculate_cube_count_elemwise, prelude::*};
|
||||
|
||||
use crate::{
|
||||
kernel::into_contiguous,
|
||||
|
@ -12,7 +11,7 @@ use crate::{
|
|||
reshape,
|
||||
},
|
||||
tensor::JitTensor,
|
||||
FloatElement, JitRuntime,
|
||||
FloatElement, IntElement, JitRuntime,
|
||||
};
|
||||
|
||||
#[derive(CubeLaunch)]
|
||||
|
@ -23,11 +22,11 @@ struct Conv2dArgs {
|
|||
dilation_1: u32,
|
||||
padding_0: u32,
|
||||
padding_1: u32,
|
||||
groups: u32,
|
||||
channels_per_group: u32,
|
||||
}
|
||||
|
||||
#[cube(launch)]
|
||||
fn conv2d_kernel<F: Float>(
|
||||
fn direct_conv2d_kernel<F: Float>(
|
||||
input: &Tensor<F>,
|
||||
weight: &Tensor<F>,
|
||||
bias: &Tensor<F>,
|
||||
|
@ -50,7 +49,7 @@ fn conv2d_kernel<F: Float>(
|
|||
let oh = ABSOLUTE_POS / output.stride(2) % output.shape(2);
|
||||
let ow = ABSOLUTE_POS / output.stride(3) % output.shape(3);
|
||||
|
||||
let g = (weight.shape(0) + oc) % args.groups;
|
||||
let g = oc / args.channels_per_group;
|
||||
let ic_start = in_channels * g;
|
||||
let ic_end = ic_start + in_channels;
|
||||
let mut sum = bias[oc];
|
||||
|
@ -114,41 +113,46 @@ fn conv2d_kernel<F: Float>(
|
|||
output[ABSOLUTE_POS] = sum;
|
||||
}
|
||||
|
||||
pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
|
||||
/// Perform a 2D convolution using the direct convolution algorithm.
|
||||
///
|
||||
/// * `input` - The input feature map
|
||||
/// * `weight` - The weights (filter) applied to each kernel
|
||||
/// * `bias` - The bias added to each channel
|
||||
/// * `options` - The options to use for the convolution
|
||||
///
|
||||
#[allow(clippy::extra_unused_type_parameters)]
|
||||
pub fn conv2d_direct<R: JitRuntime, E: FloatElement, I: IntElement>(
|
||||
input: JitTensor<R, E, 4>,
|
||||
weight: JitTensor<R, E, 4>,
|
||||
bias: Option<JitTensor<R, E, 1>>,
|
||||
options: ConvOptions<2>,
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let input = into_contiguous(input);
|
||||
let weight = into_contiguous(weight);
|
||||
let [batch_size, _, in_height, in_width] = input.shape.dims;
|
||||
let [out_channels, _, kernel_0, kernel_1] = weight.shape.dims;
|
||||
let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims;
|
||||
let channels_per_group = out_channels / options.groups;
|
||||
|
||||
// Limit loop unrolling factor to 8 or smaller
|
||||
let kernel_1_unroll = if kernel_1 > 8 {
|
||||
None
|
||||
} else {
|
||||
Some(kernel_1 as u32)
|
||||
};
|
||||
let kernel_w_unroll = (kernel_w <= 8).then_some(kernel_w as u32);
|
||||
|
||||
let out_0 = calculate_conv_output_size(
|
||||
kernel_0,
|
||||
let out_h = calculate_conv_output_size(
|
||||
kernel_h,
|
||||
options.stride[0],
|
||||
options.padding[0],
|
||||
options.dilation[0],
|
||||
in_height,
|
||||
);
|
||||
let out_1 = calculate_conv_output_size(
|
||||
kernel_1,
|
||||
let out_w = calculate_conv_output_size(
|
||||
kernel_w,
|
||||
options.stride[1],
|
||||
options.padding[1],
|
||||
options.dilation[1],
|
||||
in_width,
|
||||
);
|
||||
|
||||
let shape_out = Shape::new([batch_size, out_channels, out_0, out_1]);
|
||||
let input = into_contiguous(input);
|
||||
let weight = into_contiguous(weight);
|
||||
|
||||
let shape_out = Shape::new([batch_size, out_channels, out_h, out_w]);
|
||||
let output = empty_device(
|
||||
input.client.clone(),
|
||||
input.device.clone(),
|
||||
|
@ -170,7 +174,7 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
|
|||
let cube_dim = CubeDim::default();
|
||||
let cube_count = calculate_cube_count_elemwise(num_elems_output, cube_dim);
|
||||
|
||||
conv2d_kernel::launch::<E, R>(
|
||||
direct_conv2d_kernel::launch::<E, R>(
|
||||
&input.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
|
@ -185,9 +189,9 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
|
|||
ScalarArg::new(options.dilation[1] as u32),
|
||||
ScalarArg::new(options.padding[0] as u32),
|
||||
ScalarArg::new(options.padding[1] as u32),
|
||||
ScalarArg::new(options.groups as u32),
|
||||
ScalarArg::new(channels_per_group as u32),
|
||||
),
|
||||
kernel_1_unroll,
|
||||
kernel_w_unroll,
|
||||
);
|
||||
|
||||
output
|
|
@ -0,0 +1,297 @@
|
|||
use burn_tensor::{
|
||||
ops::{conv::calculate_conv_output_size, ConvOptions, FloatTensorOps as _},
|
||||
Shape,
|
||||
};
|
||||
use cubecl::{calculate_cube_count_elemwise, prelude::*};
|
||||
|
||||
use crate::{
|
||||
kernel::into_contiguous,
|
||||
ops::{numeric::empty_device, reshape, swap_dims},
|
||||
tensor::JitTensor,
|
||||
FloatElement, IntElement, JitBackend, JitRuntime,
|
||||
};
|
||||
|
||||
#[derive(CubeLaunch)]
|
||||
struct Im2ColArgs {
|
||||
stride_h: u32,
|
||||
stride_w: u32,
|
||||
dilation_h: u32,
|
||||
dilation_w: u32,
|
||||
padding_h: u32,
|
||||
padding_w: u32,
|
||||
|
||||
kernel_h: u32,
|
||||
kernel_w: u32,
|
||||
out_h: u32,
|
||||
out_w: u32,
|
||||
|
||||
col_size_1: u32,
|
||||
num_elements: u32,
|
||||
}
|
||||
|
||||
#[cube(launch_unchecked)]
|
||||
fn im2col_kernel<F: Float>(
|
||||
image: &Tensor<F>,
|
||||
columns: &mut Tensor<F>,
|
||||
args: &Im2ColArgs,
|
||||
#[comptime] kernel_w_unroll: Option<u32>,
|
||||
#[comptime] has_padding: bool,
|
||||
) {
|
||||
// position shape: [in_channels, batch_size, out_h, out_w]
|
||||
// columns shape: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]]
|
||||
|
||||
let batch_size = image.shape(0);
|
||||
let height = image.shape(2);
|
||||
let width = image.shape(3);
|
||||
|
||||
let out_h = args.out_h;
|
||||
let out_w = args.out_w;
|
||||
|
||||
if ABSOLUTE_POS > args.num_elements {
|
||||
return;
|
||||
}
|
||||
|
||||
let out_x = ABSOLUTE_POS % out_w;
|
||||
let out_y = ABSOLUTE_POS / out_w % out_h;
|
||||
let batch = ABSOLUTE_POS / (out_w * out_h) % batch_size;
|
||||
let channel = ABSOLUTE_POS / (out_w * out_h * batch_size) % image.shape(1);
|
||||
|
||||
let kernel_w = kernel_w_unroll.unwrap_or(args.kernel_w);
|
||||
let unroll_w = kernel_w_unroll.is_some();
|
||||
|
||||
let image_idx = batch * image.stride(0) + channel * image.stride(1);
|
||||
let col_idx = channel * args.kernel_h * kernel_w * args.col_size_1
|
||||
+ batch * out_h * out_w
|
||||
+ out_y * out_w
|
||||
+ out_x;
|
||||
|
||||
for kernel_y in 0..args.kernel_h {
|
||||
#[unroll(unroll_w)]
|
||||
for kernel_x in 0..kernel_w {
|
||||
let kernel_pos = kernel_y * kernel_w + kernel_x;
|
||||
let col_pos = col_idx + kernel_pos * args.col_size_1;
|
||||
|
||||
if has_padding {
|
||||
let y = (out_y * args.stride_h + kernel_y * args.dilation_h) as i32
|
||||
- args.padding_h as i32;
|
||||
let x = (out_x * args.stride_w + kernel_x * args.dilation_w) as i32
|
||||
- args.padding_w as i32;
|
||||
if y >= 0 && x >= 0 && y < height as i32 && x < width as i32 {
|
||||
let image_ptr = image_idx + y as u32 * width + x as u32;
|
||||
columns[col_pos] = image[image_ptr];
|
||||
} else {
|
||||
columns[col_pos] = F::new(0.0)
|
||||
};
|
||||
} else {
|
||||
let y = out_y * args.stride_h + kernel_y * args.dilation_h;
|
||||
let x = out_x * args.stride_w + kernel_x * args.dilation_w;
|
||||
let image_ptr = image_idx + y * image.stride(2) + x * image.stride(3);
|
||||
columns[col_pos] = image[image_ptr];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(test))]
|
||||
pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> usize {
|
||||
let cube_count_per_batch = (out_h * out_w).div_ceil(cubecl::SUBCUBE_DIM_APPROX);
|
||||
let max_cube_count = u16::MAX as usize;
|
||||
let max_simultaneous = (max_cube_count / cube_count_per_batch).min(batch_size);
|
||||
if max_simultaneous == 0 {
|
||||
panic!("Image too large to run even one batch at once");
|
||||
}
|
||||
(0..=max_simultaneous)
|
||||
.rev()
|
||||
.find(|per_run| batch_size % per_run == 0)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(unused)]
|
||||
pub(crate) fn batches_per_run(batch_size: usize, out_h: usize, out_w: usize) -> usize {
|
||||
1
|
||||
}
|
||||
|
||||
fn im2col<R: JitRuntime, E: FloatElement>(
|
||||
input: JitTensor<R, E, 4>,
|
||||
options: ConvOptions<2>,
|
||||
kernel_h: usize,
|
||||
kernel_w: usize,
|
||||
out_h: usize,
|
||||
out_w: usize,
|
||||
) -> JitTensor<R, E, 2> {
|
||||
let input = into_contiguous(input);
|
||||
let [batch_size, in_channels, _, _] = input.shape.dims;
|
||||
|
||||
let col_shape_0 = in_channels * kernel_h * kernel_w;
|
||||
let col_shape_1 = batch_size * out_h * out_w;
|
||||
let shape_col = Shape::new([col_shape_0, col_shape_1]);
|
||||
let columns = empty_device(
|
||||
input.client.clone(),
|
||||
input.device.clone(),
|
||||
shape_col.clone(),
|
||||
);
|
||||
|
||||
let num_elems = in_channels * batch_size * out_h * out_w;
|
||||
let cube_dim = CubeDim::default();
|
||||
let cube_count = calculate_cube_count_elemwise(num_elems, cube_dim);
|
||||
|
||||
let kernel_w_unroll = (kernel_w <= 8).then_some(kernel_w as u32);
|
||||
|
||||
let vectorization = 1;
|
||||
|
||||
unsafe {
|
||||
im2col_kernel::launch_unchecked::<E, R>(
|
||||
&input.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
input.as_handle_ref().as_tensor_arg(vectorization),
|
||||
columns.as_handle_ref().as_tensor_arg(vectorization),
|
||||
Im2ColArgsLaunch::new(
|
||||
ScalarArg::new(options.stride[0] as u32),
|
||||
ScalarArg::new(options.stride[1] as u32),
|
||||
ScalarArg::new(options.dilation[0] as u32),
|
||||
ScalarArg::new(options.dilation[1] as u32),
|
||||
ScalarArg::new(options.padding[0] as u32),
|
||||
ScalarArg::new(options.padding[1] as u32),
|
||||
ScalarArg::new(kernel_h as u32),
|
||||
ScalarArg::new(kernel_w as u32),
|
||||
ScalarArg::new(out_h as u32),
|
||||
ScalarArg::new(out_w as u32),
|
||||
ScalarArg::new(col_shape_1 as u32),
|
||||
ScalarArg::new(num_elems as u32),
|
||||
),
|
||||
kernel_w_unroll,
|
||||
options.padding != [0, 0],
|
||||
)
|
||||
};
|
||||
|
||||
columns
|
||||
}
|
||||
|
||||
/// Perform a 2D convolution using the GEMM (im2col) algorithm.
|
||||
///
|
||||
/// * `input` - The input feature map
|
||||
/// * `weight` - The weights (filter) applied to each kernel
|
||||
/// * `bias` - The bias added to each channel
|
||||
/// * `options` - The options to use for the convolution
|
||||
///
|
||||
pub fn conv2d_im2col<R: JitRuntime, E: FloatElement, I: IntElement>(
|
||||
input: JitTensor<R, E, 4>,
|
||||
weight: JitTensor<R, E, 4>,
|
||||
bias: Option<JitTensor<R, E, 1>>,
|
||||
options: ConvOptions<2>,
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let [batch_size, in_channels, in_height, in_width] = input.shape.dims;
|
||||
let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims;
|
||||
|
||||
let out_h = calculate_conv_output_size(
|
||||
kernel_h,
|
||||
options.stride[0],
|
||||
options.padding[0],
|
||||
options.dilation[0],
|
||||
in_height,
|
||||
);
|
||||
let out_w = calculate_conv_output_size(
|
||||
kernel_w,
|
||||
options.stride[1],
|
||||
options.padding[1],
|
||||
options.dilation[1],
|
||||
in_width,
|
||||
);
|
||||
|
||||
if kernel_h == 1 && kernel_w == 1 && in_height == out_h && in_width == out_w {
|
||||
// Special case for 1x1 kernels (sometimes used to scale the image by a set of weights)
|
||||
return execute_1x1_kernel::<R, E, I>(input, weight, bias, options);
|
||||
}
|
||||
|
||||
let batches_per_run = batches_per_run(batch_size, out_h, out_w);
|
||||
|
||||
let mut out = if batches_per_run != batch_size {
|
||||
let runs = batch_size / batches_per_run;
|
||||
let out_shape = Shape::new([runs, out_channels, batches_per_run, out_h, out_w]);
|
||||
let mut out = empty_device(input.client.clone(), input.device.clone(), out_shape);
|
||||
let in_shape = Shape::new([runs, batches_per_run, in_channels, in_height, in_width]);
|
||||
let input = reshape(input, in_shape);
|
||||
let in_shape_run = Shape::new([batches_per_run, in_channels, in_height, in_width]);
|
||||
let out_shape_run = Shape::new([1, out_channels, batches_per_run, out_h, out_w]);
|
||||
|
||||
for run in 0..runs {
|
||||
let input = JitBackend::<R, E, I>::float_narrow(input.clone(), 0, run, 1);
|
||||
let input = reshape(input, in_shape_run.clone());
|
||||
let run_out = execute::<R, E, I>(input, weight.clone(), options.clone(), out_h, out_w);
|
||||
let run_out = reshape(run_out, out_shape_run.clone());
|
||||
out = JitBackend::<R, E, I>::float_slice_assign(
|
||||
out,
|
||||
[
|
||||
run..run + 1,
|
||||
0..out_channels,
|
||||
0..batches_per_run,
|
||||
0..out_h,
|
||||
0..out_w,
|
||||
],
|
||||
run_out,
|
||||
);
|
||||
}
|
||||
let out = swap_dims(out, 1, 2);
|
||||
reshape(out, Shape::new([batch_size, out_channels, out_h, out_w]))
|
||||
} else {
|
||||
let out = execute::<R, E, I>(input, weight, options, out_h, out_w);
|
||||
let out = reshape(out, Shape::new([out_channels, batch_size, out_h, out_w]));
|
||||
swap_dims(out, 0, 1)
|
||||
};
|
||||
|
||||
if let Some(bias) = bias {
|
||||
let bias = reshape(bias, Shape::new([1, out_channels, 1, 1]));
|
||||
out = JitBackend::<R, E, I>::float_add(out, bias)
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
fn execute_1x1_kernel<R: JitRuntime, E: FloatElement, I: IntElement>(
|
||||
input: JitTensor<R, E, 4>,
|
||||
weight: JitTensor<R, E, 4>,
|
||||
bias: Option<JitTensor<R, E, 1>>,
|
||||
options: ConvOptions<2>,
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let [batch_size, _, height, width] = input.shape.dims;
|
||||
let [out_channels, in_c_per_grp, _, _] = weight.shape.dims;
|
||||
let groups = options.groups;
|
||||
let out_c_per_grp = out_channels / groups;
|
||||
|
||||
let input = swap_dims(input, 0, 1); // [CNHW]
|
||||
|
||||
let weight = reshape(weight, Shape::new([groups, out_c_per_grp, in_c_per_grp]));
|
||||
let in_shape = Shape::new([groups, in_c_per_grp, batch_size * height * width]);
|
||||
let input = reshape(input, in_shape);
|
||||
let out = JitBackend::<R, E, I>::float_matmul(weight, input);
|
||||
let mut out = reshape(out, Shape::new([out_channels, batch_size, height, width]));
|
||||
|
||||
if let Some(bias) = bias {
|
||||
let bias = reshape(bias, Shape::new([out_channels, 1, 1, 1]));
|
||||
out = JitBackend::<R, E, I>::float_add(out, bias)
|
||||
}
|
||||
|
||||
swap_dims(out, 0, 1)
|
||||
}
|
||||
|
||||
fn execute<R: JitRuntime, E: FloatElement, I: IntElement>(
|
||||
input: JitTensor<R, E, 4>,
|
||||
weight: JitTensor<R, E, 4>,
|
||||
options: ConvOptions<2>,
|
||||
out_h: usize,
|
||||
out_w: usize,
|
||||
) -> JitTensor<R, E, 3> {
|
||||
let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims;
|
||||
let groups = options.groups;
|
||||
|
||||
let columns = im2col(input, options.clone(), kernel_h, kernel_w, out_h, out_w);
|
||||
let [col_shape_0, col_shape_1] = columns.shape.dims;
|
||||
let col_shape_0 = col_shape_0 / groups;
|
||||
let out_c_per_group = out_channels / groups;
|
||||
|
||||
let columns = reshape(columns, Shape::new([groups, col_shape_0, col_shape_1]));
|
||||
let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_shape_0]));
|
||||
|
||||
JitBackend::<R, E, I>::float_matmul(weight, columns)
|
||||
}
|
|
@ -0,0 +1,562 @@
|
|||
use burn_tensor::{
|
||||
ops::{conv::calculate_conv_output_size, ConvOptions, FloatTensorOps},
|
||||
Shape,
|
||||
};
|
||||
use cmma::{Matrix, MatrixIdent, MatrixLayout};
|
||||
use cubecl::{
|
||||
cube,
|
||||
ir::{Elem, FloatKind},
|
||||
prelude::*,
|
||||
Compiler, CubeCount, CubeDim, Feature,
|
||||
};
|
||||
use half::f16;
|
||||
|
||||
use crate::{
|
||||
ops::{numeric::empty_device, permute, reshape},
|
||||
tensor::JitTensor,
|
||||
FloatElement, IntElement, JitBackend, JitRuntime,
|
||||
};
|
||||
|
||||
/// Perform a 2D convolution using the implicit GEMM algorithm. Requries `cmma` to be available.
|
||||
///
|
||||
/// * `input` - The input feature map
|
||||
/// * `weight` - The weights (filter) applied to each kernel
|
||||
/// * `bias` - The bias added to each channel
|
||||
/// * `options` - The options to use for the convolution
|
||||
///
|
||||
pub fn conv2d_implicit_gemm<R: JitRuntime, F: FloatElement, I: IntElement>(
|
||||
input: JitTensor<R, F, 4>,
|
||||
weight: JitTensor<R, F, 4>,
|
||||
bias: Option<JitTensor<R, F, 1>>,
|
||||
options: ConvOptions<2>,
|
||||
) -> JitTensor<R, F, 4> {
|
||||
let [batch_size, in_channels, height, width] = input.shape.dims;
|
||||
let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims;
|
||||
|
||||
let out_h = calculate_conv_output_size(
|
||||
kernel_h,
|
||||
options.stride[0],
|
||||
options.padding[0],
|
||||
options.dilation[0],
|
||||
height,
|
||||
);
|
||||
let out_w = calculate_conv_output_size(
|
||||
kernel_w,
|
||||
options.stride[1],
|
||||
options.padding[1],
|
||||
options.dilation[1],
|
||||
width,
|
||||
);
|
||||
|
||||
if !can_do_implicit_gemm(&input, &weight, &options, out_h, out_w) {
|
||||
panic!(
|
||||
"Requirements for implicit GEMM not met:
|
||||
- CMMA must be available
|
||||
- `batch_size * out_h * out_w` must be divisible by 16
|
||||
- `out_channels` must be divisible by 16
|
||||
- `in_channels * kernel_h * kernel_w` must be divisible by 16
|
||||
- `groups` must be 1
|
||||
"
|
||||
);
|
||||
}
|
||||
|
||||
let out_shape = Shape::new([batch_size, out_h, out_w, out_channels]);
|
||||
let mut out = empty_device(input.client.clone(), input.device.clone(), out_shape);
|
||||
|
||||
// Implicit GEMM matrix size
|
||||
let gemm_m = (batch_size * out_h * out_w) as u32;
|
||||
let gemm_n = out_channels as u32;
|
||||
let gemm_k = in_channels * kernel_h * kernel_w;
|
||||
let slice_size = kernel_h * kernel_w * in_channels;
|
||||
|
||||
let cmma_m = 16;
|
||||
let cmma_n = 16;
|
||||
let cmma_k = 16;
|
||||
|
||||
let warp_size = 32;
|
||||
let warps_per_cube = 8;
|
||||
|
||||
let cube_dim_x = 128;
|
||||
let cube_dim_y = 2;
|
||||
|
||||
assert!(cube_dim_y * cube_dim_x / warp_size == warps_per_cube);
|
||||
|
||||
let settings = GemmSettings {
|
||||
cmma_m,
|
||||
cmma_n,
|
||||
cmma_k,
|
||||
warp_size,
|
||||
warps_per_cube,
|
||||
cube_dim_x,
|
||||
};
|
||||
|
||||
// `CUBE_DIM_X` must be a multiple of `WARP_SIZE`
|
||||
// 128x2 means we have 8 warps and a cube computes a 32x64 output tile
|
||||
let cube_dim = CubeDim {
|
||||
x: cube_dim_x,
|
||||
y: cube_dim_y,
|
||||
z: 1,
|
||||
};
|
||||
|
||||
let cube_count_x = gemm_m.div_ceil(cmma_m * cube_dim_x / warp_size);
|
||||
let cube_count_y = gemm_n.div_ceil(cmma_n * cube_dim_y);
|
||||
|
||||
// If div floor == div ceil then the cubes are aligned with the input dimensions
|
||||
let aligned = gemm_m / (cmma_m * cube_dim_x / warp_size) == cube_count_x
|
||||
&& gemm_n / (cmma_n * cube_dim_y) == cube_count_y;
|
||||
|
||||
let cube_count = CubeCount::Static(cube_count_x, cube_count_y, 1);
|
||||
|
||||
unsafe {
|
||||
implicit_gemm_kernel::launch_unchecked::<F, f16, R>(
|
||||
&input.client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
input.as_tensor_arg(1),
|
||||
weight.as_tensor_arg(1),
|
||||
out.as_tensor_arg(1),
|
||||
DimensionsLaunch::new(
|
||||
ScalarArg::new(gemm_m),
|
||||
ScalarArg::new(gemm_n),
|
||||
ScalarArg::new(gemm_k as u32),
|
||||
ScalarArg::new(slice_size as u32),
|
||||
ScalarArg::new(out_h as u32),
|
||||
ScalarArg::new(out_w as u32),
|
||||
),
|
||||
ConvArgsLaunch::new(
|
||||
ScalarArg::new(options.stride[0] as u32),
|
||||
ScalarArg::new(options.stride[1] as u32),
|
||||
ScalarArg::new(options.padding[0] as i32),
|
||||
ScalarArg::new(options.padding[1] as i32),
|
||||
ScalarArg::new(options.dilation[0] as u32),
|
||||
ScalarArg::new(options.dilation[1] as u32),
|
||||
),
|
||||
settings,
|
||||
KernelSettings {
|
||||
kernel_h: kernel_h as u32,
|
||||
kernel_w: kernel_w as u32,
|
||||
has_padding: options.padding != [0, 0],
|
||||
aligned,
|
||||
},
|
||||
)
|
||||
};
|
||||
|
||||
if let Some(bias) = bias {
|
||||
let bias = reshape(bias, Shape::new([1, 1, 1, out_channels]));
|
||||
out = JitBackend::<R, F, I>::float_add(out, bias);
|
||||
}
|
||||
|
||||
permute(out, [0, 3, 1, 2])
|
||||
}
|
||||
|
||||
#[derive(CubeLaunch)]
|
||||
struct ConvArgs {
|
||||
stride_h: u32,
|
||||
stride_w: u32,
|
||||
pad_h: i32,
|
||||
pad_w: i32,
|
||||
dilation_h: u32,
|
||||
dilation_w: u32,
|
||||
}
|
||||
|
||||
#[derive(CubeLaunch)]
|
||||
struct Dimensions {
|
||||
gemm_m: u32,
|
||||
gemm_n: u32,
|
||||
gemm_k: u32,
|
||||
slice_size: u32,
|
||||
|
||||
out_h: u32,
|
||||
out_w: u32,
|
||||
}
|
||||
|
||||
/// Compile-time tiling configuration of the implicit GEMM kernel.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
struct GemmSettings {
    /// M, N and K sizes of a single CMMA operation.
    cmma_m: u32,
    cmma_n: u32,
    cmma_k: u32,

    /// Number of units per warp.
    warp_size: u32,
    /// Number of warps in a cube.
    warps_per_cube: u32,

    /// Cube size along x, in units.
    cube_dim_x: u32,
}
|
||||
|
||||
/// Compile-time properties of the convolution the implicit GEMM kernel is specialized for.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
struct KernelSettings {
    /// Kernel height and width.
    kernel_h: u32,
    kernel_w: u32,
    /// Whether any padding is applied; enables the bounds-checked input load.
    has_padding: bool,
    /// Whether the cubes exactly tile the GEMM output, so no range check is needed.
    aligned: bool,
}
|
||||
|
||||
#[derive(Clone, Copy, CubeType)]
|
||||
struct Positions {
|
||||
global_m: u32,
|
||||
global_n: u32,
|
||||
|
||||
intra_warp_unit_idx: u32,
|
||||
cube_linear_warp_idx: u32,
|
||||
}
|
||||
|
||||
#[derive(CubeType)]
|
||||
struct Matrices<F: Float, FAcc: Float> {
|
||||
a: Matrix<F>,
|
||||
b: Matrix<F>,
|
||||
acc: Matrix<FAcc>,
|
||||
}
|
||||
|
||||
#[allow(clippy::collapsible_else_if)]
|
||||
#[cube(launch_unchecked)]
|
||||
fn implicit_gemm_kernel<F: Float, FMat: Float>(
|
||||
input: &Tensor<F>,
|
||||
weight: &Tensor<F>,
|
||||
out: &mut Tensor<F>,
|
||||
dims: &Dimensions,
|
||||
args: &ConvArgs,
|
||||
#[comptime] gemm_settings: GemmSettings,
|
||||
#[comptime] kernel_settings: KernelSettings,
|
||||
) {
|
||||
let GemmSettings {
|
||||
cmma_m,
|
||||
cmma_n,
|
||||
cmma_k,
|
||||
warps_per_cube,
|
||||
..
|
||||
} = gemm_settings;
|
||||
|
||||
let cmma_out_tile_size = cmma_m * cmma_n;
|
||||
let cmma_input_tile_size = cmma_m * cmma_k;
|
||||
let cmma_filter_tile_size = cmma_k * cmma_n;
|
||||
|
||||
let pos = calculate_positions(gemm_settings);
|
||||
|
||||
// Shared memory tiles, currently only holds enough data for
|
||||
// each warp to have its own tile for a single MMA op (8 * 16 * 16 elements)
|
||||
// conceptually a WARPS_PER_CUBE x (CMMA_M * CMMA_K) matrix
|
||||
let mut smem_input_tile = SharedMemory::<FMat>::new(cmma_input_tile_size * warps_per_cube);
|
||||
let mut smem_weight_tile = SharedMemory::<FMat>::new(cmma_filter_tile_size * warps_per_cube);
|
||||
|
||||
let input_tile_start = pos.cube_linear_warp_idx * cmma_input_tile_size;
|
||||
let weight_tile_start = pos.cube_linear_warp_idx * cmma_filter_tile_size;
|
||||
let input_tile =
|
||||
smem_input_tile.slice_mut(input_tile_start, input_tile_start + cmma_input_tile_size);
|
||||
let weight_tile =
|
||||
smem_weight_tile.slice_mut(weight_tile_start, weight_tile_start + cmma_filter_tile_size);
|
||||
|
||||
let out_pos = pos.global_n + pos.global_m * dims.gemm_n;
|
||||
let out = out.slice_mut(out_pos, out_pos + cmma_out_tile_size);
|
||||
|
||||
if kernel_settings.aligned {
|
||||
execute_gemm(
|
||||
input,
|
||||
weight,
|
||||
out,
|
||||
input_tile,
|
||||
weight_tile,
|
||||
dims,
|
||||
&pos,
|
||||
args,
|
||||
gemm_settings,
|
||||
kernel_settings,
|
||||
);
|
||||
} else {
|
||||
if pos.global_m < dims.gemm_m && pos.global_n < dims.gemm_n {
|
||||
execute_gemm(
|
||||
input,
|
||||
weight,
|
||||
out,
|
||||
input_tile,
|
||||
weight_tile,
|
||||
dims,
|
||||
&pos,
|
||||
args,
|
||||
gemm_settings,
|
||||
kernel_settings,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn calculate_positions(#[comptime] gemm_settings: GemmSettings) -> Positions {
|
||||
let GemmSettings {
|
||||
cmma_m,
|
||||
cmma_n,
|
||||
warp_size,
|
||||
cube_dim_x,
|
||||
..
|
||||
} = gemm_settings;
|
||||
|
||||
// Tile using a 2D grid (over the output), each threadblock
|
||||
// is (128, 2) -> (4,2) = 8 warps -> 32x64 output
|
||||
let global_warp_m = ABSOLUTE_POS_X / warp_size;
|
||||
let global_warp_n = ABSOLUTE_POS_Y;
|
||||
let cube_warp_m = UNIT_POS_X / warp_size;
|
||||
let cube_warp_n = UNIT_POS_Y;
|
||||
let num_warps_m = cube_dim_x / warp_size;
|
||||
let intra_warp_unit_idx = UNIT_POS_X % warp_size; // Thread idx within warp (0 to 31)
|
||||
let cube_linear_warp_idx = (cube_warp_n * num_warps_m) + cube_warp_m; // Warp idx within a block (0 to WARPS_PER_BLOCK - 1)
|
||||
|
||||
Positions {
|
||||
global_m: global_warp_m * cmma_m,
|
||||
global_n: global_warp_n * cmma_n,
|
||||
intra_warp_unit_idx,
|
||||
cube_linear_warp_idx,
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn make_matrices<F: Float, FAcc: Float>(
|
||||
#[comptime] gemm_settings: GemmSettings,
|
||||
) -> Matrices<F, FAcc> {
|
||||
let GemmSettings {
|
||||
cmma_m,
|
||||
cmma_n,
|
||||
cmma_k,
|
||||
..
|
||||
} = gemm_settings;
|
||||
|
||||
let matrices = Matrices::<F, FAcc> {
|
||||
a: Matrix::<F>::new(
|
||||
MatrixIdent::A,
|
||||
cmma_m,
|
||||
cmma_n,
|
||||
cmma_k,
|
||||
MatrixLayout::RowMajor,
|
||||
),
|
||||
b: Matrix::<F>::new(
|
||||
MatrixIdent::B,
|
||||
cmma_m,
|
||||
cmma_n,
|
||||
cmma_k,
|
||||
MatrixLayout::RowMajor,
|
||||
),
|
||||
acc: Matrix::<FAcc>::new(
|
||||
MatrixIdent::Accumulator,
|
||||
cmma_m,
|
||||
cmma_n,
|
||||
cmma_k,
|
||||
MatrixLayout::Undefined,
|
||||
),
|
||||
};
|
||||
|
||||
cmma::fill(&matrices.acc, FAcc::new(0.0));
|
||||
|
||||
matrices
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn execute_gemm<F: Float, FMat: Float>(
|
||||
input: &Tensor<F>,
|
||||
weight: &Tensor<F>,
|
||||
out: &mut SliceMut<F>,
|
||||
input_tile: &mut SliceMut<FMat>,
|
||||
weight_tile: &mut SliceMut<FMat>,
|
||||
dims: &Dimensions,
|
||||
pos: &Positions,
|
||||
args: &ConvArgs,
|
||||
#[comptime] g_settings: GemmSettings,
|
||||
#[comptime] k_settings: KernelSettings,
|
||||
) {
|
||||
let GemmSettings { cmma_n, cmma_k, .. } = g_settings;
|
||||
|
||||
let matrices = make_matrices::<FMat, F>(g_settings);
|
||||
|
||||
// Loop over the K-dimension
|
||||
for k in range_stepped(0, dims.gemm_k, cmma_k) {
|
||||
// Load into smem...
|
||||
// Each warp should load the 16x16 tile it's responsible for
|
||||
// i.e. each thread needs to load 8 elements of input and 8 elements of weight
|
||||
|
||||
load_input_tile(
|
||||
input, args, input_tile, dims, pos, k, g_settings, k_settings,
|
||||
);
|
||||
|
||||
load_weight_tile(weight, weight_tile, pos, k, g_settings, k_settings);
|
||||
|
||||
// Run CMMA
|
||||
cmma::load(&matrices.a, input_tile.as_slice(), cmma_k);
|
||||
cmma::load(&matrices.b, weight_tile.as_slice(), cmma_n);
|
||||
|
||||
cmma::execute::<FMat, FMat, F, F>(&matrices.a, &matrices.b, &matrices.acc, &matrices.acc);
|
||||
}
|
||||
|
||||
cmma::store(out, &matrices.acc, dims.gemm_n, MatrixLayout::RowMajor);
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn load_input_tile<F: Float, FMat: Float>(
|
||||
input: &Tensor<F>,
|
||||
args: &ConvArgs,
|
||||
tile: &mut SliceMut<FMat>,
|
||||
dims: &Dimensions,
|
||||
pos: &Positions,
|
||||
k: u32,
|
||||
#[comptime] gemm_settings: GemmSettings,
|
||||
#[comptime] kernel_settings: KernelSettings,
|
||||
) {
|
||||
let GemmSettings {
|
||||
cmma_m,
|
||||
cmma_k,
|
||||
warp_size,
|
||||
..
|
||||
} = gemm_settings;
|
||||
|
||||
let KernelSettings {
|
||||
kernel_h,
|
||||
kernel_w,
|
||||
has_padding,
|
||||
..
|
||||
} = kernel_settings;
|
||||
|
||||
let kernel_size = kernel_h * kernel_w;
|
||||
let cmma_input_tile_size = cmma_m * cmma_k;
|
||||
|
||||
let height = input.shape(2) as i32;
|
||||
let width = input.shape(3) as i32;
|
||||
|
||||
// Row strides in the implicit GEMM matrix
|
||||
let batch_stride = dims.out_h * dims.out_w;
|
||||
let y_stride = dims.out_w;
|
||||
let x_stride = 1;
|
||||
|
||||
// Start index within a slice (0 to `kernel_size * channels - 1`) that a half warp (16 units) is responsible for
|
||||
let slice_start_idx = k % dims.slice_size;
|
||||
|
||||
for m in range_stepped(pos.intra_warp_unit_idx, cmma_input_tile_size, warp_size) {
|
||||
// Compute where in the slice we are starting
|
||||
|
||||
// Slices are always `kernel_size * channels` elements wide so we can compute where inside a slice
|
||||
// we are and also which row the slice is in relative to the start of the CMMA matrix
|
||||
|
||||
let rel_slice_row = m / cmma_k; // Relative row (0 - 15)
|
||||
let abs_slice_row = pos.global_m + rel_slice_row; // Row of the matrix the slice is on
|
||||
|
||||
// Actual index within a slice (0 to `kernel_size * channels - 1`) that the thread is repsonsible for
|
||||
let my_slice_idx = (slice_start_idx + (m % cmma_k)) % dims.slice_size;
|
||||
|
||||
// Given the row of the matrix that the slice is in, and the index of the thread
|
||||
// within a slice, want to compute what input element to load...
|
||||
// first compute coordinates in output space (center of the kernel in MxK matrix A)
|
||||
let batch = abs_slice_row / batch_stride;
|
||||
let out_y = (abs_slice_row % batch_stride) / y_stride;
|
||||
let out_x = ((abs_slice_row % batch_stride) % y_stride) / x_stride;
|
||||
|
||||
let channel = my_slice_idx / kernel_size;
|
||||
|
||||
let kernel_y = (my_slice_idx / kernel_w) % kernel_h;
|
||||
let kernel_x = my_slice_idx % kernel_w;
|
||||
|
||||
if has_padding {
|
||||
let y = (out_y * args.stride_h + kernel_y * args.dilation_h) as i32 - args.pad_h;
|
||||
let x = (out_x * args.stride_w + kernel_x * args.dilation_w) as i32 - args.pad_w;
|
||||
|
||||
if x >= 0 && x < width && y >= 0 && y < height {
|
||||
tile[m] = FMat::cast_from(
|
||||
input[batch * input.stride(0)
|
||||
+ y as u32 * input.stride(2)
|
||||
+ x as u32 * input.stride(3)
|
||||
+ channel * input.stride(1)],
|
||||
);
|
||||
} else {
|
||||
tile[m] = FMat::new(0.0);
|
||||
}
|
||||
} else {
|
||||
let y = out_y * args.stride_h + kernel_y * args.dilation_h;
|
||||
let x = out_x * args.stride_w + kernel_x * args.dilation_w;
|
||||
|
||||
tile[m] = FMat::cast_from(
|
||||
input[batch * input.stride(0)
|
||||
+ y * input.stride(2)
|
||||
+ x * input.stride(3)
|
||||
+ channel * input.stride(1)],
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn load_weight_tile<F: Float, FMat: Float>(
|
||||
weight: &Tensor<F>,
|
||||
tile: &mut SliceMut<FMat>,
|
||||
pos: &Positions,
|
||||
k: u32,
|
||||
#[comptime] gemm_settings: GemmSettings,
|
||||
#[comptime] kernel_settings: KernelSettings,
|
||||
) {
|
||||
let GemmSettings {
|
||||
cmma_n,
|
||||
cmma_k,
|
||||
warp_size,
|
||||
..
|
||||
} = gemm_settings;
|
||||
|
||||
let KernelSettings {
|
||||
kernel_h, kernel_w, ..
|
||||
} = kernel_settings;
|
||||
|
||||
let kernel_size = kernel_h * kernel_w;
|
||||
let cmma_filter_tile_size = cmma_k * cmma_n;
|
||||
|
||||
for n in range_stepped(pos.intra_warp_unit_idx, cmma_filter_tile_size, warp_size) {
|
||||
// Compute where in the slice we are starting
|
||||
let rel_slice_row = n / cmma_k; // Relative row (0 - 15)
|
||||
let abs_slice_row = k + rel_slice_row; // Row of the matrix the slice is on
|
||||
let abs_slice_col = pos.global_n + (n % 16); // Row of the matrix the slice is on
|
||||
|
||||
// Given the row of the matrix that the slice is in, and the index of the unit
|
||||
// within a slice, want to compute what weight element to load...
|
||||
let out_channel = abs_slice_col;
|
||||
let in_channel = abs_slice_row / kernel_size;
|
||||
let kernel_y = (abs_slice_row % kernel_size) / kernel_h;
|
||||
let kernel_x = abs_slice_row % kernel_w;
|
||||
|
||||
tile[n] = FMat::cast_from(
|
||||
weight[out_channel * weight.stride(0)
|
||||
+ in_channel * weight.stride(1)
|
||||
+ kernel_y * weight.stride(2)
|
||||
+ kernel_x * weight.stride(3)],
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn can_do_implicit_gemm<R: JitRuntime, E: FloatElement>(
|
||||
input: &JitTensor<R, E, 4>,
|
||||
weight: &JitTensor<R, E, 4>,
|
||||
options: &ConvOptions<2>,
|
||||
out_h: usize,
|
||||
out_w: usize,
|
||||
) -> bool {
|
||||
let [batch_size, in_channels, _, _] = input.shape.dims;
|
||||
let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims;
|
||||
|
||||
let cmma_m = 16;
|
||||
let cmma_n = 16;
|
||||
let cmma_k = 16;
|
||||
let warps_per_cube = 8;
|
||||
|
||||
let gemm_m = batch_size * out_h * out_w;
|
||||
let gemm_n = out_channels;
|
||||
let gemm_k = in_channels * kernel_h * kernel_w;
|
||||
|
||||
let smem_size = ((cmma_m + cmma_n) * cmma_k * warps_per_cube) as usize * size_of::<f16>();
|
||||
|
||||
cmma_available::<R>(&input.device)
|
||||
&& <R::Compiler as Compiler>::max_shared_memory_size() >= smem_size
|
||||
&& gemm_m % 16 == 0
|
||||
&& gemm_n % 16 == 0
|
||||
&& gemm_k % 16 == 0
|
||||
&& options.groups == 1
|
||||
}
|
||||
|
||||
fn cmma_available<R: JitRuntime>(device: &R::JitDevice) -> bool {
|
||||
R::client(device).features().enabled(Feature::Cmma {
|
||||
a: Elem::Float(FloatKind::F16),
|
||||
b: Elem::Float(FloatKind::F16),
|
||||
c: Elem::Float(FloatKind::F32),
|
||||
m: 16,
|
||||
k: 16,
|
||||
n: 16,
|
||||
})
|
||||
}
|
|
@ -0,0 +1,18 @@
|
|||
mod base;
|
||||
mod col2im;
|
||||
mod direct;
|
||||
mod im2col;
|
||||
mod implicit_gemm;
|
||||
mod transpose_direct;
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
mod tune;
|
||||
|
||||
pub use base::*;
|
||||
pub use col2im::*;
|
||||
pub use direct::*;
|
||||
pub use im2col::*;
|
||||
pub use implicit_gemm::*;
|
||||
pub use transpose_direct::*;
|
||||
#[cfg(feature = "autotune")]
|
||||
pub use tune::*;
|
|
@ -14,9 +14,9 @@ use crate::{
|
|||
reshape,
|
||||
},
|
||||
tensor::JitTensor,
|
||||
JitRuntime,
|
||||
IntElement, JitRuntime,
|
||||
};
|
||||
use burn_tensor::{ops::ConvTransposeOptions, Element, Shape};
|
||||
use burn_tensor::{ops::ConvTransposeOptions, Shape};
|
||||
|
||||
#[derive(new)]
|
||||
struct Conv2dTransposeEagerKernel<R, E> {
|
||||
|
@ -364,7 +364,15 @@ impl<R: JitRuntime, E: JitElement> Kernel for Conv2dTransposeEagerKernel<R, E> {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn conv_transpose2d<R: JitRuntime, E: JitElement + Element>(
|
||||
/// Perform a 2D convolution transposition using the direct algorithm.
|
||||
///
|
||||
/// * `input` - The input feature map
|
||||
/// * `weight` - The weights (filter) applied to each kernel
|
||||
/// * `bias` - The bias added to each channel
|
||||
/// * `options` - The options to use for the convolution
|
||||
///
|
||||
#[allow(clippy::extra_unused_type_parameters)]
|
||||
pub fn conv_transpose2d_direct<R: JitRuntime, E: JitElement, I: IntElement>(
|
||||
input: JitTensor<R, E, 4>,
|
||||
weight: JitTensor<R, E, 4>,
|
||||
bias: Option<JitTensor<R, E, 1>>,
|
|
@ -0,0 +1,152 @@
|
|||
use burn_tensor::{
|
||||
ops::{conv::calculate_conv_output_size, ConvOptions},
|
||||
ElementConversion, Shape,
|
||||
};
|
||||
use cubecl::{
|
||||
tune,
|
||||
tune::{local_tuner, tune_with, LocalTuner},
|
||||
AutotuneKey,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
kernel::{
|
||||
conv::{can_do_implicit_gemm, conv2d_direct, conv2d_im2col, conv2d_implicit_gemm},
|
||||
prng::random_uniform,
|
||||
},
|
||||
tensor::JitTensor,
|
||||
FloatElement, IntElement, JitAutotuneKey, JitRuntime, JitTuneId,
|
||||
};
|
||||
|
||||
/// Executes autotune on conv2d operations
|
||||
pub fn conv2d_autotune<R: JitRuntime, E: FloatElement, I: IntElement>(
|
||||
input: JitTensor<R, E, 4>,
|
||||
weights: JitTensor<R, E, 4>,
|
||||
bias: Option<JitTensor<R, E, 1>>,
|
||||
options: ConvOptions<2>,
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let client = input.client.clone();
|
||||
|
||||
static TUNER: LocalTuner<JitAutotuneKey, JitTuneId> = local_tuner!();
|
||||
|
||||
TUNER.execute(
|
||||
&JitTuneId::new::<R>(&input.device),
|
||||
&client,
|
||||
Box::new(Conv2dOperations::<R, E, I>::new(
|
||||
input, weights, bias, options,
|
||||
)),
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
|
||||
/// Autotune key representative of matmul versions
|
||||
pub struct Conv2dAutotuneKey {
|
||||
pub kernel_size: [usize; 2],
|
||||
pub stride: [usize; 2],
|
||||
pub padding: [usize; 2],
|
||||
pub dilation: [usize; 2],
|
||||
pub groups: usize,
|
||||
#[autotune(anchor)]
|
||||
pub in_channels: usize,
|
||||
#[autotune(anchor)]
|
||||
pub out_channels: usize,
|
||||
#[autotune(anchor)]
|
||||
pub height: usize,
|
||||
#[autotune(anchor)]
|
||||
pub width: usize,
|
||||
#[autotune(anchor)]
|
||||
pub batch_size: usize,
|
||||
pub has_bias: bool,
|
||||
}
|
||||
|
||||
#[tune(
|
||||
operations(conv2d_direct, conv2d_im2col, conv2d_implicit_gemm),
|
||||
create_key = create_key,
|
||||
should_run = should_run
|
||||
)]
|
||||
pub fn conv2d_operations<R: JitRuntime, E: FloatElement, I: IntElement>(
|
||||
key: JitAutotuneKey,
|
||||
input: JitTensor<R, E, 4>,
|
||||
weights: JitTensor<R, E, 4>,
|
||||
bias: Option<JitTensor<R, E, 1>>,
|
||||
options: ConvOptions<2>,
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let device = &input.device;
|
||||
let key = match key {
|
||||
JitAutotuneKey::Conv2d(key) => key,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let random_bounds: (E, E) = ((-1.0).elem::<E>(), (1.0).elem::<E>());
|
||||
let input_shape = Shape::new([key.batch_size, key.in_channels, key.height, key.width]);
|
||||
let input = random_uniform(input_shape, device, random_bounds.0, random_bounds.1);
|
||||
let c_per_grp = key.in_channels / key.groups;
|
||||
let [kernel_h, kernel_w] = key.kernel_size;
|
||||
let weight_shape = Shape::new([key.out_channels, c_per_grp, kernel_h, kernel_w]);
|
||||
let weights = random_uniform(weight_shape, device, random_bounds.0, random_bounds.1);
|
||||
let bias_shape = Shape::new([key.out_channels]);
|
||||
let bias = key
|
||||
.has_bias
|
||||
.then(|| random_uniform(bias_shape, device, random_bounds.0, random_bounds.1));
|
||||
|
||||
tune_with!(input, weights, bias, options)
|
||||
}
|
||||
|
||||
fn should_run<R: JitRuntime, F: FloatElement, I: IntElement>(
|
||||
op: &Conv2dOperations<R, F, I>,
|
||||
_key: &JitAutotuneKey,
|
||||
index: usize,
|
||||
) -> bool {
|
||||
match index {
|
||||
2 => {
|
||||
let [_, _, height, width] = op.input.shape.dims;
|
||||
let [_, _, kernel_h, kernel_w] = op.weights.shape.dims;
|
||||
let o = &op.options;
|
||||
let out_h = calculate_conv_output_size(
|
||||
kernel_h,
|
||||
o.stride[0],
|
||||
o.padding[0],
|
||||
o.dilation[0],
|
||||
height,
|
||||
);
|
||||
let out_w = calculate_conv_output_size(
|
||||
kernel_w,
|
||||
o.stride[1],
|
||||
o.padding[1],
|
||||
o.dilation[1],
|
||||
width,
|
||||
);
|
||||
can_do_implicit_gemm(&op.input, &op.weights, &op.options, out_h, out_w)
|
||||
}
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
||||
fn create_key<R: JitRuntime, E: FloatElement>(
|
||||
input: &JitTensor<R, E, 4>,
|
||||
weights: &JitTensor<R, E, 4>,
|
||||
bias: &Option<JitTensor<R, E, 1>>,
|
||||
options: &ConvOptions<2>,
|
||||
) -> JitAutotuneKey {
|
||||
let [batch_size, in_channels, height, width] = input.shape.dims;
|
||||
let [out_channels, _, kernel_h, kernel_w] = weights.shape.dims;
|
||||
let ConvOptions {
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
} = options.clone();
|
||||
JitAutotuneKey::Conv2d(Conv2dAutotuneKey::new(
|
||||
[kernel_h, kernel_w],
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
in_channels,
|
||||
out_channels,
|
||||
height,
|
||||
width,
|
||||
batch_size,
|
||||
bias.is_some(),
|
||||
))
|
||||
}
|
|
@ -0,0 +1,117 @@
|
|||
use burn_tensor::{ops::ConvTransposeOptions, ElementConversion, Shape};
|
||||
use cubecl::{
|
||||
tune,
|
||||
tune::{local_tuner, tune_with, LocalTuner},
|
||||
AutotuneKey,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
kernel::{
|
||||
conv::{conv_transpose2d_col2im, conv_transpose2d_direct},
|
||||
prng::random_uniform,
|
||||
},
|
||||
tensor::JitTensor,
|
||||
FloatElement, IntElement, JitAutotuneKey, JitRuntime, JitTuneId,
|
||||
};
|
||||
|
||||
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
|
||||
/// Autotune key representative of matmul versions
|
||||
pub struct ConvTranspose2dAutotuneKey {
|
||||
pub kernel_size: [usize; 2],
|
||||
pub stride: [usize; 2],
|
||||
pub padding: [usize; 2],
|
||||
pub padding_out: [usize; 2],
|
||||
pub dilation: [usize; 2],
|
||||
pub groups: usize,
|
||||
#[autotune(anchor)]
|
||||
pub in_channels: usize,
|
||||
#[autotune(anchor)]
|
||||
pub out_channels: usize,
|
||||
#[autotune(anchor)]
|
||||
pub height: usize,
|
||||
#[autotune(anchor)]
|
||||
pub width: usize,
|
||||
#[autotune(anchor)]
|
||||
pub batch_size: usize,
|
||||
pub has_bias: bool,
|
||||
}
|
||||
|
||||
/// Executes autotune on conv2d operations
|
||||
pub fn conv_transpose2d_autotune<R: JitRuntime, E: FloatElement, I: IntElement>(
|
||||
input: JitTensor<R, E, 4>,
|
||||
weights: JitTensor<R, E, 4>,
|
||||
bias: Option<JitTensor<R, E, 1>>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let client = input.client.clone();
|
||||
|
||||
static TUNER: LocalTuner<JitAutotuneKey, JitTuneId> = local_tuner!();
|
||||
|
||||
TUNER.execute(
|
||||
&JitTuneId::new::<R>(&input.device),
|
||||
&client,
|
||||
Box::new(ConvTranspose2dOperations::<R, E, I>::new(
|
||||
input, weights, bias, options,
|
||||
)),
|
||||
)
|
||||
}
|
||||
|
||||
#[tune(operations(conv_transpose2d_direct, conv_transpose2d_col2im), create_key = create_key)]
|
||||
pub fn conv_transpose2d_operations<R: JitRuntime, E: FloatElement, I: IntElement>(
|
||||
key: JitAutotuneKey,
|
||||
input: JitTensor<R, E, 4>,
|
||||
weights: JitTensor<R, E, 4>,
|
||||
bias: Option<JitTensor<R, E, 1>>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let key = match key {
|
||||
JitAutotuneKey::ConvTranspose2d(key) => key,
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let device = &input.device;
|
||||
|
||||
let random_bounds: (E, E) = ((-1.0).elem::<E>(), (1.0).elem::<E>());
|
||||
let input_shape = Shape::new([key.batch_size, key.in_channels, key.height, key.width]);
|
||||
let input = random_uniform(input_shape, device, random_bounds.0, random_bounds.1);
|
||||
let c_per_grp = key.in_channels / key.groups;
|
||||
let [kernel_h, kernel_w] = key.kernel_size;
|
||||
let weight_shape = Shape::new([key.out_channels, c_per_grp, kernel_h, kernel_w]);
|
||||
let weights = random_uniform(weight_shape, device, random_bounds.0, random_bounds.1);
|
||||
let bias_shape = Shape::new([key.out_channels]);
|
||||
let bias = key
|
||||
.has_bias
|
||||
.then(|| random_uniform(bias_shape, device, random_bounds.0, random_bounds.1));
|
||||
tune_with!(input, weights, bias, options)
|
||||
}
|
||||
|
||||
fn create_key<R: JitRuntime, E: FloatElement>(
|
||||
input: &JitTensor<R, E, 4>,
|
||||
weights: &JitTensor<R, E, 4>,
|
||||
bias: &Option<JitTensor<R, E, 1>>,
|
||||
options: &ConvTransposeOptions<2>,
|
||||
) -> JitAutotuneKey {
|
||||
let [batch_size, in_channels, height, width] = input.shape.dims;
|
||||
let [out_channels, _, kernel_h, kernel_w] = weights.shape.dims;
|
||||
let ConvTransposeOptions {
|
||||
stride,
|
||||
padding,
|
||||
dilation,
|
||||
groups,
|
||||
padding_out,
|
||||
} = options.clone();
|
||||
JitAutotuneKey::ConvTranspose2d(ConvTranspose2dAutotuneKey::new(
|
||||
[kernel_h, kernel_w],
|
||||
stride,
|
||||
padding,
|
||||
padding_out,
|
||||
dilation,
|
||||
groups,
|
||||
in_channels,
|
||||
out_channels,
|
||||
height,
|
||||
width,
|
||||
batch_size,
|
||||
bias.is_some(),
|
||||
))
|
||||
}
|
|
@ -0,0 +1,5 @@
|
|||
mod conv2d;
|
||||
mod conv_transpose2d;
|
||||
|
||||
pub use conv2d::*;
|
||||
pub use conv_transpose2d::*;
|
|
@ -1,13 +1,13 @@
|
|||
mod conv2d;
|
||||
mod conv3d;
|
||||
mod conv_transpose2d;
|
||||
mod conv_transpose3d;
|
||||
mod deform_conv2d;
|
||||
mod deform_conv_transpose2d;
|
||||
|
||||
pub(crate) use conv2d::*;
|
||||
pub(crate) use conv3d::*;
|
||||
pub(crate) use conv_transpose2d::*;
|
||||
pub(crate) use conv_transpose3d::*;
|
||||
pub(crate) use deform_conv2d::*;
|
||||
pub(crate) use deform_conv_transpose2d::*;
|
||||
|
||||
pub use conv2d::{conv2d, conv_transpose2d, Conv2dStrategy, ConvTranspose2dStrategy};
|
||||
|
|
|
@ -16,7 +16,7 @@ pub enum MatmulStrategy {
|
|||
grid_y: usize,
|
||||
},
|
||||
#[cfg(feature = "autotune")]
|
||||
/// Using autotune to chose the best kernel based on runtime information.
|
||||
/// Using autotune to choose the best kernel based on runtime information.
|
||||
Autotune,
|
||||
/// Cube implementation of matmul.
|
||||
Cube,
|
||||
|
|
|
@ -1,4 +1,10 @@
|
|||
use crate::{kernel, FloatElement, IntElement, JitBackend, JitRuntime};
|
||||
use crate::{
|
||||
kernel::{
|
||||
self,
|
||||
conv::{Conv2dStrategy, ConvTranspose2dStrategy},
|
||||
},
|
||||
FloatElement, IntElement, JitBackend, JitRuntime,
|
||||
};
|
||||
use burn_tensor::ops::{
|
||||
ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateOptions,
|
||||
MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
|
||||
|
@ -17,7 +23,7 @@ where
|
|||
bias: Option<FloatTensor<Self, 1>>,
|
||||
options: ConvOptions<2>,
|
||||
) -> FloatTensor<Self, 4> {
|
||||
kernel::conv::conv2d(x, weight, bias, options)
|
||||
kernel::conv::conv2d::<R, F, I>(x, weight, bias, options, Conv2dStrategy::default())
|
||||
}
|
||||
|
||||
fn deform_conv2d(
|
||||
|
@ -58,7 +64,13 @@ where
|
|||
bias: Option<FloatTensor<Self, 1>>,
|
||||
options: ConvTransposeOptions<2>,
|
||||
) -> FloatTensor<Self, 4> {
|
||||
kernel::conv::conv_transpose2d(x, weight, bias, options)
|
||||
kernel::conv::conv_transpose2d::<R, F, I>(
|
||||
x,
|
||||
weight,
|
||||
bias,
|
||||
options,
|
||||
ConvTranspose2dStrategy::default(),
|
||||
)
|
||||
}
|
||||
|
||||
fn conv_transpose3d(
|
||||
|
|
|
@ -6,7 +6,7 @@ mod tests {
|
|||
};
|
||||
|
||||
#[test]
|
||||
fn avg_pool2d_should_work_with_multiple_invocations() {
|
||||
fn avg_pool2d_should_match_reference_backend() {
|
||||
let tensor = Tensor::<TestBackend, 4>::random(
|
||||
[32, 32, 32, 32],
|
||||
Distribution::Default,
|
||||
|
@ -29,7 +29,7 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn avg_pool2d_backward_should_work_with_multiple_invocations() {
|
||||
fn avg_pool2d_backward_should_match_reference_backend() {
|
||||
TestBackend::seed(0);
|
||||
ReferenceBackend::seed(0);
|
||||
let tensor = Tensor::<TestBackend, 4>::random(
|
||||
|
|
|
@ -4,12 +4,12 @@ mod tests {
|
|||
use burn_tensor::{backend::Backend, Distribution, Tensor};
|
||||
|
||||
#[test]
|
||||
fn cat_should_support_multiple_invocations_dim0() {
|
||||
fn cat_should_match_reference_backend_dim0() {
|
||||
test_same_as_reference([6, 256], 2, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cat_should_support_multiple_invocations_dim1() {
|
||||
fn cat_should_match_reference_backend_dim1() {
|
||||
test_same_as_reference([6, 256], 2, 1);
|
||||
}
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ mod tests {
|
|||
use burn_tensor::{module, Distribution, Tensor};
|
||||
|
||||
#[test]
|
||||
fn conv2d_should_work_with_multiple_invocations() {
|
||||
fn conv2d_should_match_reference_backend() {
|
||||
let test_device = Default::default();
|
||||
let input =
|
||||
Tensor::<TestBackend, 4>::random([6, 16, 32, 32], Distribution::Default, &test_device);
|
||||
|
@ -26,4 +26,28 @@ mod tests {
|
|||
.into_data()
|
||||
.assert_approx_eq(&output_ref.into_data(), 3);
|
||||
}
|
||||
|
||||
/// Runs `conv2d` with bias on a small random problem (4x16x6x6 input, 16x16x3x3 weights,
/// stride 1, padding 2, dilation 1, one group) on both the test backend and the reference
/// backend, and checks that the outputs agree to 1 decimal place.
#[test]
fn conv2d_should_match_reference_backend_implicit() {
    let device = Default::default();
    let x = Tensor::<TestBackend, 4>::random([4, 16, 6, 6], Distribution::Default, &device);
    let w = Tensor::<TestBackend, 4>::random([16, 16, 3, 3], Distribution::Default, &device);
    let b = Tensor::<TestBackend, 1>::random([16], Distribution::Default, &device);

    // Mirror the exact same data onto the reference backend.
    let device_ref = Default::default();
    let x_ref = Tensor::<ReferenceBackend, 4>::from_data(x.to_data(), &device_ref);
    let w_ref = Tensor::<ReferenceBackend, 4>::from_data(w.to_data(), &device_ref);
    let b_ref = Tensor::<ReferenceBackend, 1>::from_data(b.to_data(), &device_ref);

    let options = burn_tensor::ops::ConvOptions::new([1, 1], [2, 2], [1, 1], 1);

    let actual = module::conv2d(x, w, Some(b), options.clone());
    let expected = module::conv2d(x_ref, w_ref, Some(b_ref), options);

    let actual_data = actual.into_data();
    let expected_data = expected.into_data();
    actual_data.assert_approx_eq(&expected_data, 1);
}
|
||||
}
|
||||
|
|
|
@ -4,7 +4,7 @@ mod tests {
|
|||
use burn_tensor::{module, Distribution, Tensor};
|
||||
|
||||
#[test]
|
||||
fn conv3d_should_work_with_multiple_invocations() {
|
||||
fn conv3d_should_match_reference_backend() {
|
||||
let test_device = Default::default();
|
||||
let input = Tensor::<TestBackend, 5>::random(
|
||||
[6, 16, 32, 32, 32],
|
||||
|
|
|
@ -4,7 +4,7 @@ mod tests {
|
|||
use burn_tensor::{backend::Backend, module, Distribution, Tensor};
|
||||
|
||||
#[test]
|
||||
fn conv_transpose2d_should_work_with_multiple_invocations() {
|
||||
fn conv_transpose2d_should_match_reference_backend() {
|
||||
TestBackend::seed(0);
|
||||
|
||||
let height = 8;
|
||||
|
|
|
@ -4,7 +4,7 @@ mod tests {
|
|||
use burn_tensor::{backend::Backend, module, Distribution, Tensor};
|
||||
|
||||
#[test]
|
||||
fn conv_transpose3d_should_work_with_multiple_invocations() {
|
||||
fn conv_transpose3d_should_match_reference_backend() {
|
||||
TestBackend::seed(0);
|
||||
|
||||
let depth = 8;
|
||||
|
|
|
@ -5,7 +5,7 @@ mod tests {
|
|||
use burn_tensor::{Bool, Distribution, Tensor, TensorPrimitive};
|
||||
|
||||
#[test]
|
||||
fn mask_fill_should_work_with_multiple_invocations() {
|
||||
fn mask_fill_should_match_reference_backend() {
|
||||
let (tensor, mask, tensor_ref, mask_ref) = inputs_mask_fill();
|
||||
|
||||
let actual = Tensor::<TestBackend, 3>::from_primitive(TensorPrimitive::Float(mask_fill(
|
||||
|
@ -22,7 +22,7 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn mask_fill_inplace_should_work_with_multiple_invocations() {
|
||||
fn mask_fill_inplace_should_match_reference_backend() {
|
||||
let (tensor, mask, tensor_ref, mask_ref) = inputs_mask_fill();
|
||||
|
||||
let actual = Tensor::<TestBackend, 3>::from_primitive(TensorPrimitive::Float(mask_fill(
|
||||
|
|
|
@ -5,7 +5,7 @@ mod tests {
|
|||
use burn_tensor::{backend::Backend, Bool, Distribution, Tensor, TensorPrimitive};
|
||||
|
||||
#[test]
|
||||
fn mask_where_should_work_with_multiple_invocations() {
|
||||
fn mask_where_should_match_reference_backend() {
|
||||
let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where();
|
||||
|
||||
let actual = tensor.mask_where(mask, value);
|
||||
|
@ -16,7 +16,7 @@ mod tests {
|
|||
.assert_approx_eq(&actual.into_data(), 3);
|
||||
}
|
||||
#[test]
|
||||
fn mask_where_inplace_lhs_should_work_with_multiple_invocations() {
|
||||
fn mask_where_inplace_lhs_should_match_reference_backend() {
|
||||
let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where();
|
||||
|
||||
let actual = Tensor::<TestBackend, 3>::from_primitive(TensorPrimitive::Float(mask_where(
|
||||
|
@ -33,7 +33,7 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn mask_where_inplace_rhs_should_work_with_multiple_invocation() {
|
||||
fn mask_where_inplace_rhs_should_match_reference_backend() {
|
||||
let (tensor, value, mask, tensor_ref, value_ref, mask_ref) = inputs_mask_where();
|
||||
|
||||
let actual = Tensor::<TestBackend, 3>::from_primitive(TensorPrimitive::Float(mask_where(
|
||||
|
|
|
@ -4,7 +4,7 @@ mod tests {
|
|||
use burn_tensor::{module, Distribution, Tensor};
|
||||
|
||||
#[test]
|
||||
pub fn max_pool2d_should_work_with_multiple_invocations() {
|
||||
pub fn max_pool2d_should_match_reference_backends() {
|
||||
let tensor = Tensor::<TestBackend, 4>::random(
|
||||
[32, 32, 32, 32],
|
||||
Distribution::Default,
|
||||
|
@ -26,7 +26,7 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
pub fn max_pool2d_with_indices_should_work_with_multiple_invocations() {
|
||||
pub fn max_pool2d_with_indices_should_match_reference_backend() {
|
||||
let tensor = Tensor::<TestBackend, 4>::random(
|
||||
[32, 32, 32, 32],
|
||||
Distribution::Default,
|
||||
|
|
|
@ -4,7 +4,7 @@ mod tests {
|
|||
use burn_tensor::{module, ops::ModuleOps, Distribution, Tensor, TensorPrimitive};
|
||||
|
||||
#[test]
|
||||
pub fn max_pool2d_with_indices_backward_should_work_with_multiple_invocations() {
|
||||
pub fn max_pool2d_with_indices_backward_should_match_reference_backend() {
|
||||
let test_device = Default::default();
|
||||
let tensor =
|
||||
Tensor::<TestBackend, 4>::random([32, 32, 32, 32], Distribution::Default, &test_device);
|
||||
|
|
|
@ -10,7 +10,7 @@ mod reduction {
|
|||
};
|
||||
|
||||
#[test]
|
||||
fn reduction_sum_dim_should_work_with_multiple_invocations() {
|
||||
fn reduction_sum_dim_should_match_reference_backend() {
|
||||
let tensor =
|
||||
Tensor::<TestBackend, 2>::random([6, 1024], Distribution::Default, &Default::default());
|
||||
let tensor_ref =
|
||||
|
@ -33,7 +33,7 @@ mod reduction {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn reduction_prod_dim_should_work_with_multiple_invocations() {
|
||||
fn reduction_prod_dim_should_match_reference_backend() {
|
||||
let tensor =
|
||||
Tensor::<TestBackend, 2>::random([6, 1024], Distribution::Default, &Default::default());
|
||||
let tensor_ref =
|
||||
|
@ -56,7 +56,7 @@ mod reduction {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn reduction_argmin_dim_should_work_with_multiple_invocations() {
|
||||
fn reduction_argmin_dim_should_match_reference_backend() {
|
||||
let tensor =
|
||||
Tensor::<TestBackend, 2>::random([6, 1024], Distribution::Default, &Default::default());
|
||||
let tensor_ref =
|
||||
|
@ -75,7 +75,7 @@ mod reduction {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn reduction_argmax_dim_should_work_with_multiple_invocations() {
|
||||
fn reduction_argmax_dim_should_match_reference_backend() {
|
||||
let tensor =
|
||||
Tensor::<TestBackend, 2>::random([6, 1024], Distribution::Default, &Default::default());
|
||||
let tensor_ref =
|
||||
|
@ -290,7 +290,7 @@ mod reduction {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn reduction_sum_should_work_with_multiple_invocations() {
|
||||
fn reduction_sum_should_match_reference_backend() {
|
||||
let tensor =
|
||||
Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default, &Default::default());
|
||||
let tensor_ref =
|
||||
|
@ -306,7 +306,7 @@ mod reduction {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn reduction_prod_should_work_with_multiple_invocations() {
|
||||
fn reduction_prod_should_match_reference_backend() {
|
||||
let tensor =
|
||||
Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default, &Default::default());
|
||||
let tensor_ref =
|
||||
|
|
|
@ -1,4 +1,8 @@
|
|||
use crate::kernel::{matmul::MatmulAutotuneKey, reduce::ReduceAutotuneKey};
|
||||
use crate::kernel::{
|
||||
conv::{Conv2dAutotuneKey, ConvTranspose2dAutotuneKey},
|
||||
matmul::MatmulAutotuneKey,
|
||||
reduce::ReduceAutotuneKey,
|
||||
};
|
||||
use cubecl::tune::AutotuneKey;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Display;
|
||||
|
@ -13,6 +17,10 @@ pub enum JitAutotuneKey {
|
|||
Matmul(MatmulAutotuneKey),
|
||||
/// Key for reduce dim operations
|
||||
ReduceDim(ReduceAutotuneKey),
|
||||
/// Key for convolution operations
|
||||
Conv2d(Conv2dAutotuneKey),
|
||||
/// Key for transpose convolution operations
|
||||
ConvTranspose2d(ConvTranspose2dAutotuneKey),
|
||||
#[cfg(any(feature = "fusion", test))]
|
||||
/// Key for fused element wise operations.
|
||||
FusionElemWise(FusionElemWiseAutotuneKey),
|
||||
|
@ -25,6 +33,8 @@ impl Display for JitAutotuneKey {
|
|||
JitAutotuneKey::ReduceDim(reduce_key) => std::fmt::Display::fmt(&reduce_key, f),
|
||||
#[cfg(any(feature = "fusion", test))]
|
||||
JitAutotuneKey::FusionElemWise(reduce_key) => std::fmt::Display::fmt(&reduce_key, f),
|
||||
JitAutotuneKey::Conv2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f),
|
||||
JitAutotuneKey::ConvTranspose2d(conv2d_key) => std::fmt::Display::fmt(&conv2d_key, f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -109,6 +109,7 @@ pub(crate) fn conv2d<E: FloatNdArrayElement, Q: QuantElement>(
|
|||
let [stride_height, stride_width] = options.stride;
|
||||
let [batch_size, _in_channels, in_height, in_width] = x.shape().dims;
|
||||
let [out_channels, in_channels, kernel_height, kernel_width] = weight.shape().dims;
|
||||
let channels_per_group = out_channels / options.groups;
|
||||
|
||||
let out_height = calculate_conv_output_size(
|
||||
kernel_height,
|
||||
|
@ -141,7 +142,7 @@ pub(crate) fn conv2d<E: FloatNdArrayElement, Q: QuantElement>(
|
|||
|(k, mut output)| {
|
||||
let b = k / out_channels;
|
||||
let oc = k % out_channels;
|
||||
let g = k % options.groups;
|
||||
let g = oc / channels_per_group;
|
||||
|
||||
for ic in (in_channels * g)..(in_channels * (g + 1)) {
|
||||
let weight_ic = ic - (g * in_channels);
|
||||
|
|
|
@ -43,7 +43,7 @@ pub fn calculate_conv_transpose_output_size(
|
|||
dilation: usize,
|
||||
size_in: usize,
|
||||
) -> usize {
|
||||
(size_in - 1) * stride + dilation * (kernel_size - 1) + padding_out - 2 * padding + 1
|
||||
(size_in - 1) * stride + (dilation * (kernel_size - 1) + 1) + padding_out - 2 * padding
|
||||
}
|
||||
|
||||
/// Calculate the expected output size when doing a pooling operation.
|
||||
|
|
|
@ -69,6 +69,49 @@ mod tests {
|
|||
]]));
|
||||
}
|
||||
|
||||
/// Grouped convolution where each group holds more than one channel:
/// 4 input channels and 4 output channels split across 2 groups, 3x3 kernel on a 5x5 input
/// with unit stride and dilation and no padding.
#[test]
fn test_conv2d_groups_multiple_channels() {
    let case = Conv2dTestCase {
        // Problem size.
        batch_size: 1,
        height: 5,
        width: 5,
        // Channel layout.
        channels_in: 4,
        channels_out: 4,
        groups: 2,
        // Kernel geometry.
        kernel_size_1: 3,
        kernel_size_2: 3,
        stride_1: 1,
        stride_2: 1,
        padding_1: 0,
        padding_2: 0,
        dilation_1: 1,
        dilation_2: 1,
    };

    case.assert_output(TestTensor::from([[
        [
            [4035.0, 4188.0, 4341.0],
            [4800.0, 4953.0, 5106.0],
            [5565.0, 5718.0, 5871.0],
        ],
        [
            [10030.0, 10507.0, 10984.0],
            [12415.0, 12892.0, 13369.0],
            [14800.0, 15277.0, 15754.0],
        ],
        [
            [56075.0, 56876.0, 57677.0],
            [60080.0, 60881.0, 61682.0],
            [64085.0, 64886.0, 65687.0],
        ],
        [
            [78270.0, 79395.0, 80520.0],
            [83895.0, 85020.0, 86145.0],
            [89520.0, 90645.0, 91770.0],
        ],
    ]]));
}
|
||||
|
||||
#[test]
|
||||
fn test_conv2d_complex() {
|
||||
let test = Conv2dTestCase {
|
||||
|
|
|
@ -28,6 +28,7 @@ mod tests {
|
|||
|
||||
test.assert_output(TestTensor::from([[[[5.0, 11.0], [23.0, 29.0]]]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose2d_simple_2() {
|
||||
let test = ConvTranspose2dTestCase {
|
||||
|
@ -71,6 +72,34 @@ mod tests {
|
|||
]]));
|
||||
}
|
||||
|
||||
/// Smallest non-trivial transposed convolution: one input and output channel, 2x2 kernel on
/// a 2x2 input with unit stride and dilation and no padding, producing a 3x3 output.
#[test]
fn test_conv_transpose2d_simple_3() {
    let case = ConvTranspose2dTestCase {
        // Problem size.
        batch_size: 1,
        height: 2,
        width: 2,
        // Channel layout.
        channels_in: 1,
        channels_out: 1,
        groups: 1,
        // Kernel geometry.
        kernel_size_1: 2,
        kernel_size_2: 2,
        stride_1: 1,
        stride_2: 1,
        dilation_1: 1,
        dilation_2: 1,
        // Padding.
        padding_1: 0,
        padding_2: 0,
        padding_out_1: 0,
        padding_out_2: 0,
    };

    case.assert_output(TestTensor::from([[[
        [0.0, 0.0, 1.0],
        [0.0, 4.0, 6.0],
        [4.0, 12.0, 9.0],
    ]]]));
}
|
||||
|
||||
#[test]
|
||||
fn test_conv_transpose2d_stride_2() {
|
||||
let test = ConvTranspose2dTestCase {
|
||||
|
|
Loading…
Reference in New Issue