Add deform_conv2d as implemented in torchvision (#2147)

Genna Wingert 2024-09-23 21:17:23 +02:00 committed by GitHub
parent f19e0c5393
commit 2c8514ce7f
27 changed files with 4586 additions and 18 deletions

Cargo.lock (generated)

@@ -245,6 +245,12 @@ version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
[[package]]
name = "atomic_float"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "628d228f918ac3b82fe590352cc719d30664a0c13ca3a60266fe02c7132d480a"
[[package]]
name = "atty"
version = "0.2.14"
@@ -672,6 +678,7 @@ dependencies = [
name = "burn-ndarray"
version = "0.15.0"
dependencies = [
"atomic_float",
"blas-src",
"burn-autodiff",
"burn-common",


@@ -27,6 +27,7 @@ readme = "README.md"
version = "0.15.0"
[workspace.dependencies]
atomic_float = "1"
bytemuck = "1.18.0"
candle-core = { version = "0.6.0" }
clap = { version = "4.5.18", features = ["derive"] }


@@ -190,6 +190,7 @@ Burn comes with built-in modules that you can use to build your own modules.
| `ConvTranspose1d` | `nn.ConvTranspose1d` |
| `ConvTranspose2d` | `nn.ConvTranspose2d` |
| `ConvTranspose3d` | `nn.ConvTranspose3d` |
| `DeformConv2d` | `torchvision.ops.DeformConv2d` |
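
As with the other layers in this table, the Burn module mirrors its PyTorch counterpart. A minimal usage sketch (illustrative only; assumes a backend type `B` and a `device` in scope):

```rust
use burn::nn::conv::DeformConv2dConfig;
use burn::tensor::Tensor;

// 4 input channels, 8 output channels, 3x3 kernel.
let conv = DeformConv2dConfig::new([4, 8], [3, 3]).init::<B>(&device);
// The offset tensor carries 2 * offset_groups * kernel_h * kernel_w channels.
let input = Tensor::<B, 4>::zeros([1, 4, 16, 16], &device);
let offset = Tensor::<B, 4>::zeros([1, 18, 14, 14], &device);
let output = conv.forward(input, offset, None); // [1, 8, 14, 14]
```
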
### Pooling


@@ -441,6 +441,343 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
}
}
fn deform_conv2d(
x: AutodiffTensor<B, 4>,
offset: AutodiffTensor<B, 4>,
weight: AutodiffTensor<B, 4>,
mask: Option<AutodiffTensor<B, 4>>,
bias: Option<AutodiffTensor<B, 1>>,
options: DeformConvOptions<2>,
) -> AutodiffTensor<B, 4> {
#[derive(Debug)]
struct DeformConv2DWithMaskWithBias;
#[derive(Debug)]
struct DeformConv2DWithMaskNoBias;
#[derive(Debug)]
struct DeformConv2DNoMaskWithBias;
#[derive(Debug)]
struct DeformConv2DNoMaskNoBias;
impl<B: Backend> Backward<B, 4, 5> for DeformConv2DWithMaskWithBias {
type State = (NodeID, NodeID, NodeID, NodeID, NodeID, DeformConvOptions<2>);
fn backward(
self,
ops: Ops<Self::State, 5>,
grads: &mut Gradients,
checkpointer: &mut Checkpointer,
) {
let [node_x, node_offset, node_weight, node_mask, node_bias] = ops.parents;
let grad = grads.consume::<B, 4>(&ops.node);
let (x_state, offset_state, weight_state, mask_state, bias_state, options) =
ops.state;
let x = checkpointer.retrieve_node_output(x_state);
let offset = checkpointer.retrieve_node_output(offset_state);
let weight = checkpointer.retrieve_node_output(weight_state);
let mask = Some(checkpointer.retrieve_node_output(mask_state));
let bias = Some(checkpointer.retrieve_node_output(bias_state));
let backward =
B::deform_conv2d_backward(x, offset, weight, mask, bias, grad, options);
if let Some(node) = node_x {
grads.register::<B, 4>(node.id, backward.x_grad)
}
if let Some(node) = node_offset {
grads.register::<B, 4>(node.id, backward.offset_grad)
}
if let Some(node) = node_weight {
grads.register::<B, 4>(node.id, backward.weight_grad)
}
if let Some(node) = node_mask {
grads.register::<B, 4>(node.id, backward.mask_grad.unwrap())
}
if let Some(node) = node_bias {
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
}
}
}
impl<B: Backend> Backward<B, 4, 4> for DeformConv2DWithMaskNoBias {
type State = (NodeID, NodeID, NodeID, NodeID, DeformConvOptions<2>);
fn backward(
self,
ops: Ops<Self::State, 4>,
grads: &mut Gradients,
checkpointer: &mut Checkpointer,
) {
let [node_x, node_offset, node_weight, node_mask] = ops.parents;
let grad = grads.consume::<B, 4>(&ops.node);
let (x_state, offset_state, weight_state, mask_state, options) = ops.state;
let x = checkpointer.retrieve_node_output(x_state);
let offset = checkpointer.retrieve_node_output(offset_state);
let weight = checkpointer.retrieve_node_output(weight_state);
let mask = Some(checkpointer.retrieve_node_output(mask_state));
let backward =
B::deform_conv2d_backward(x, offset, weight, mask, None, grad, options);
if let Some(node) = node_x {
grads.register::<B, 4>(node.id, backward.x_grad)
}
if let Some(node) = node_offset {
grads.register::<B, 4>(node.id, backward.offset_grad)
}
if let Some(node) = node_weight {
grads.register::<B, 4>(node.id, backward.weight_grad)
}
if let Some(node) = node_mask {
grads.register::<B, 4>(node.id, backward.mask_grad.unwrap())
}
}
}
impl<B: Backend> Backward<B, 4, 4> for DeformConv2DNoMaskWithBias {
type State = (NodeID, NodeID, NodeID, NodeID, DeformConvOptions<2>);
fn backward(
self,
ops: Ops<Self::State, 4>,
grads: &mut Gradients,
checkpointer: &mut Checkpointer,
) {
let [node_x, node_offset, node_weight, node_bias] = ops.parents;
let grad = grads.consume::<B, 4>(&ops.node);
let (x_state, offset_state, weight_state, bias_state, options) = ops.state;
let x = checkpointer.retrieve_node_output(x_state);
let offset = checkpointer.retrieve_node_output(offset_state);
let weight = checkpointer.retrieve_node_output(weight_state);
let bias = Some(checkpointer.retrieve_node_output(bias_state));
let backward =
B::deform_conv2d_backward(x, offset, weight, None, bias, grad, options);
if let Some(node) = node_x {
grads.register::<B, 4>(node.id, backward.x_grad)
}
if let Some(node) = node_offset {
grads.register::<B, 4>(node.id, backward.offset_grad)
}
if let Some(node) = node_weight {
grads.register::<B, 4>(node.id, backward.weight_grad)
}
if let Some(node) = node_bias {
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
}
}
}
impl<B: Backend> Backward<B, 4, 3> for DeformConv2DNoMaskNoBias {
type State = (NodeID, NodeID, NodeID, DeformConvOptions<2>);
fn backward(
self,
ops: Ops<Self::State, 3>,
grads: &mut Gradients,
checkpointer: &mut Checkpointer,
) {
let [node_x, node_offset, node_weight] = ops.parents;
let grad = grads.consume::<B, 4>(&ops.node);
let (x_state, offset_state, weight_state, options) = ops.state;
let x = checkpointer.retrieve_node_output(x_state);
let offset = checkpointer.retrieve_node_output(offset_state);
let weight = checkpointer.retrieve_node_output(weight_state);
let backward =
B::deform_conv2d_backward(x, offset, weight, None, None, grad, options);
if let Some(node) = node_x {
grads.register::<B, 4>(node.id, backward.x_grad)
}
if let Some(node) = node_offset {
grads.register::<B, 4>(node.id, backward.offset_grad)
}
if let Some(node) = node_weight {
grads.register::<B, 4>(node.id, backward.weight_grad)
}
}
}
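// `Backward<B, D, N>` is const-generic over the number of parent nodes N, so a
// separate zero-sized op is required for each combination of optional inputs:
// mask and bias (5 parents), mask only (4), bias only (4), or neither (3).
// The match below dispatches to the variant matching the provided arguments.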
match (mask, bias) {
(Some(mask), Some(bias)) => match DeformConv2DWithMaskWithBias
.prepare::<C>([
x.node.clone(),
offset.node.clone(),
weight.node.clone(),
mask.node.clone(),
bias.node.clone(),
])
.compute_bound()
.stateful()
{
OpsKind::Tracked(mut prep) => {
let x_state = prep.checkpoint(&x);
let offset_state = prep.checkpoint(&offset);
let weight_state = prep.checkpoint(&weight);
let mask_state = prep.checkpoint(&mask);
let bias_state = prep.checkpoint(&bias);
prep.finish(
(
x_state,
offset_state,
weight_state,
mask_state,
bias_state,
options.clone(),
),
B::deform_conv2d(
x.primitive,
offset.primitive,
weight.primitive,
Some(mask.primitive),
Some(bias.primitive),
options,
),
)
}
OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d(
x.primitive,
offset.primitive,
weight.primitive,
Some(mask.primitive),
Some(bias.primitive),
options,
)),
},
(Some(mask), None) => match DeformConv2DWithMaskNoBias
.prepare::<C>([
x.node.clone(),
offset.node.clone(),
weight.node.clone(),
mask.node.clone(),
])
.compute_bound()
.stateful()
{
OpsKind::Tracked(mut prep) => {
let x_state = prep.checkpoint(&x);
let offset_state = prep.checkpoint(&offset);
let weight_state = prep.checkpoint(&weight);
let mask_state = prep.checkpoint(&mask);
prep.finish(
(
x_state,
offset_state,
weight_state,
mask_state,
options.clone(),
),
B::deform_conv2d(
x.primitive,
offset.primitive,
weight.primitive,
Some(mask.primitive),
None,
options,
),
)
}
OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d(
x.primitive,
offset.primitive,
weight.primitive,
Some(mask.primitive),
None,
options,
)),
},
(None, Some(bias)) => match DeformConv2DNoMaskWithBias
.prepare::<C>([
x.node.clone(),
offset.node.clone(),
weight.node.clone(),
bias.node.clone(),
])
.compute_bound()
.stateful()
{
OpsKind::Tracked(mut prep) => {
let x_state = prep.checkpoint(&x);
let offset_state = prep.checkpoint(&offset);
let weight_state = prep.checkpoint(&weight);
let bias_state = prep.checkpoint(&bias);
prep.finish(
(
x_state,
offset_state,
weight_state,
bias_state,
options.clone(),
),
B::deform_conv2d(
x.primitive,
offset.primitive,
weight.primitive,
None,
Some(bias.primitive),
options,
),
)
}
OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d(
x.primitive,
offset.primitive,
weight.primitive,
None,
Some(bias.primitive),
options,
)),
},
(None, None) => match DeformConv2DNoMaskNoBias
.prepare::<C>([x.node.clone(), offset.node.clone(), weight.node.clone()])
.compute_bound()
.stateful()
{
OpsKind::Tracked(mut prep) => {
let x_state = prep.checkpoint(&x);
let offset_state = prep.checkpoint(&offset);
let weight_state = prep.checkpoint(&weight);
prep.finish(
(x_state, offset_state, weight_state, options.clone()),
B::deform_conv2d(
x.primitive,
offset.primitive,
weight.primitive,
None,
None,
options,
),
)
}
OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d(
x.primitive,
offset.primitive,
weight.primitive,
None,
None,
options,
)),
},
}
}
fn deform_conv2d_backward(
_x: AutodiffTensor<B, 4>,
_offset: AutodiffTensor<B, 4>,
_weight: AutodiffTensor<B, 4>,
_mask: Option<AutodiffTensor<B, 4>>,
_bias: Option<AutodiffTensor<B, 1>>,
_output_grad: AutodiffTensor<B, 4>,
_options: DeformConvOptions<2>,
) -> DeformConv2dBackward<Self> {
panic!("Can't differentiate deform conv 2d backward.");
}
fn conv_transpose2d(
x: AutodiffTensor<B, 4>,
weight: AutodiffTensor<B, 4>,

File diff suppressed because it is too large.


@@ -21,6 +21,7 @@ mod conv_transpose2d;
mod conv_transpose3d;
mod cos;
mod cross_entropy;
mod deform_conv2d;
mod div;
mod erf;
mod exp;
@@ -82,6 +83,8 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_conv1d!();
burn_autodiff::testgen_ad_conv2d!();
burn_autodiff::testgen_ad_conv3d!();
#[cfg(not(target_os = "macos"))] // Wgpu on macOS currently doesn't support atomic compare exchange
burn_autodiff::testgen_ad_deform_conv2d!();
burn_autodiff::testgen_ad_conv_transpose1d!();
burn_autodiff::testgen_ad_conv_transpose2d!();
burn_autodiff::testgen_ad_conv_transpose3d!();


@@ -1,7 +1,8 @@
use burn_tensor::{
ops::{
ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, InterpolateMode,
InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, UnfoldOptions,
ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, FloatTensor,
IntTensor, InterpolateMode, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices,
ModuleOps, UnfoldOptions,
},
Shape,
};
@@ -77,6 +78,29 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I
})
}
fn deform_conv2d(
_x: FloatTensor<Self, 4>,
_offset: FloatTensor<Self, 4>,
_weight: FloatTensor<Self, 4>,
_mask: Option<FloatTensor<Self, 4>>,
_bias: Option<FloatTensor<Self, 1>>,
_options: DeformConvOptions<2>,
) -> FloatTensor<Self, 4> {
unimplemented!("Candle does not support deformable convolutions")
}
fn deform_conv2d_backward(
_x: FloatTensor<Self, 4>,
_offset: FloatTensor<Self, 4>,
_weight: FloatTensor<Self, 4>,
_mask: Option<FloatTensor<Self, 4>>,
_bias: Option<FloatTensor<Self, 1>>,
_output_grad: FloatTensor<Self, 4>,
_options: DeformConvOptions<2>,
) -> DeformConv2dBackward<Self> {
unimplemented!("Candle does not support deformable convolutions")
}
fn conv3d(
x: FloatTensor<Self, 5>,
weight: FloatTensor<Self, 5>,


@@ -0,0 +1,263 @@
use alloc::format;
use burn_tensor::ops::DeformConvOptions;
use crate as burn;
use crate::config::Config;
use crate::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay, Param};
use crate::nn::Initializer;
use crate::nn::PaddingConfig2d;
use crate::tensor::backend::Backend;
use crate::tensor::module::deform_conv2d;
use crate::tensor::Tensor;
use crate::nn::conv::checks;
/// Configuration to create a [deformable 2D convolution](DeformConv2d) layer, using the [init function](DeformConv2dConfig::init).
#[derive(Config, Debug)]
pub struct DeformConv2dConfig {
/// The number of channels.
pub channels: [usize; 2],
/// The size of the kernel.
pub kernel_size: [usize; 2],
/// The stride of the convolution.
#[config(default = "[1, 1]")]
pub stride: [usize; 2],
/// Spacing between kernel elements.
#[config(default = "[1, 1]")]
pub dilation: [usize; 2],
/// Controls the connections between input and output channels.
#[config(default = "1")]
pub weight_groups: usize,
/// The number of offset groups. Input channels are divided evenly among the groups, each using its own set of sampling offsets.
#[config(default = "1")]
pub offset_groups: usize,
/// The padding configuration.
#[config(default = "PaddingConfig2d::Valid")]
pub padding: PaddingConfig2d,
/// If bias should be added to the output.
#[config(default = true)]
pub bias: bool,
/// The type of function used to initialize neural network parameters
#[config(
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
)]
pub initializer: Initializer,
}
/// Applies a deformable 2D convolution over input tensors.
///
/// Should be created with [DeformConv2dConfig].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct DeformConv2d<B: Backend> {
/// Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2]`
pub weight: Param<Tensor<B, 4>>,
/// Tensor of shape `[channels_out]`
pub bias: Option<Param<Tensor<B, 1>>>,
/// Stride of the convolution.
pub stride: [usize; 2],
/// Size of the kernel.
pub kernel_size: [usize; 2],
/// Spacing between kernel elements.
pub dilation: [usize; 2],
/// Controls the connections between input and output channels.
pub weight_groups: usize,
/// The number of offset groups.
pub offset_groups: usize,
/// The padding configuration.
pub padding: Ignored<PaddingConfig2d>,
}
impl DeformConv2dConfig {
/// Initialize a new [DeformConv2d](DeformConv2d) module.
pub fn init<B: Backend>(&self, device: &B::Device) -> DeformConv2d<B> {
checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.weight_groups);
let shape = [
self.channels[1],
self.channels[0] / self.weight_groups,
self.kernel_size[0],
self.kernel_size[1],
];
let k = self.kernel_size.iter().product::<usize>();
let fan_in = self.channels[0] / self.weight_groups * k;
let fan_out = self.channels[1] / self.weight_groups * k;
let weight = self
.initializer
.init_with(shape, Some(fan_in), Some(fan_out), device);
let mut bias = None;
if self.bias {
bias = Some(self.initializer.init_with(
[self.channels[1]],
Some(fan_in),
Some(fan_out),
device,
));
}
DeformConv2d {
weight,
bias,
stride: self.stride,
kernel_size: self.kernel_size,
dilation: self.dilation,
padding: Ignored(self.padding.clone()),
weight_groups: self.weight_groups,
offset_groups: self.offset_groups,
}
}
}
impl<B: Backend> ModuleDisplay for DeformConv2d<B> {
fn custom_settings(&self) -> Option<DisplaySettings> {
DisplaySettings::new()
.with_new_line_after_attribute(false)
.optional()
}
fn custom_content(&self, content: Content) -> Option<Content> {
// Since padding does not implement ModuleDisplay, we need to format it manually.
let padding_formatted = format!("{}", &self.padding);
// Format stride, kernel_size, and dilation as array-style strings rather than as individually indexed values.
let stride = format!("{:?}", self.stride);
let kernel_size = format!("{:?}", self.kernel_size);
let dilation = format!("{:?}", self.dilation);
content
.add("stride", &stride)
.add("kernel_size", &kernel_size)
.add("dilation", &dilation)
.add("weight_groups", &self.weight_groups)
.add("offset_groups", &self.offset_groups)
.add("padding", &padding_formatted)
.optional()
}
}
impl<B: Backend> DeformConv2d<B> {
/// Applies the forward pass on the input tensor.
///
/// See [deform_conv2d](crate::tensor::module::deform_conv2d) for more information.
///
/// # Shapes
///
/// - input: `[batch_size, channels_in, height_in, width_in]`
/// - offset: `[batch_size, 2 * offset_groups * kernel_height * kernel_width, height_out, width_out]`
/// - mask: `[batch_size, offset_groups * kernel_height * kernel_width, height_out, width_out]`
/// - output: `[batch_size, channels_out, height_out, width_out]`
pub fn forward(
&self,
input: Tensor<B, 4>,
offset: Tensor<B, 4>,
mask: Option<Tensor<B, 4>>,
) -> Tensor<B, 4> {
let [_batch_size, _channels_in, height_in, width_in] = input.dims();
let padding =
self.padding
.calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride);
deform_conv2d(
input,
offset,
self.weight.val(),
mask,
self.bias.as_ref().map(|bias| bias.val()),
DeformConvOptions::new(
self.stride,
padding,
self.dilation,
self.weight_groups,
self.offset_groups,
),
)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::TensorData;
use crate::TestBackend;
#[test]
fn initializer_default() {
TestBackend::seed(0);
let config = DeformConv2dConfig::new([5, 1], [5, 5]);
let k = (config.channels[0] * config.kernel_size[0] * config.kernel_size[1]) as f64;
let k = (config.offset_groups as f64 / k).sqrt() as f32;
let device = Default::default();
let conv = config.init::<TestBackend>(&device);
conv.weight.to_data().assert_within_range(-k..k);
}
#[test]
fn initializer_zeros() {
TestBackend::seed(0);
let config = DeformConv2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros);
let device = Default::default();
let conv = config.init::<TestBackend>(&device);
assert_eq!(config.initializer, Initializer::Zeros);
conv.weight
.to_data()
.assert_approx_eq(&TensorData::zeros::<f32, _>(conv.weight.shape()), 3);
}
#[test]
fn initializer_fan_out() {
TestBackend::seed(0);
let init = Initializer::KaimingUniform {
gain: 1.0 / 3.0f64.sqrt(),
fan_out_only: true, // test that fan_out is passed to `init_with()`
};
let device = Default::default();
let config = DeformConv2dConfig::new([5, 1], [5, 5]).with_initializer(init.clone());
let _ = config.init::<TestBackend>(&device);
assert_eq!(config.initializer, init);
}
#[test]
fn initializer_fan_with_groups_is_valid() {
TestBackend::seed(0);
let init = Initializer::KaimingUniform {
gain: 1.0 / 3.0f64.sqrt(),
fan_out_only: true,
};
let device = Default::default();
let config = DeformConv2dConfig::new([4, 4], [1, 1])
.with_initializer(init.clone())
.with_weight_groups(4);
let _ = config.init::<TestBackend>(&device);
assert_eq!(config.initializer, init);
}
#[test]
#[should_panic = "Both channels must be divisible by the number of groups."]
fn channels_with_groups_is_invalid() {
let device = Default::default();
let config = DeformConv2dConfig::new([1, 4], [1, 1]).with_weight_groups(4);
let _ = config.init::<TestBackend>(&device);
}
#[test]
fn display() {
let config = DeformConv2dConfig::new([5, 1], [5, 5]);
let conv = config.init::<TestBackend>(&Default::default());
assert_eq!(
alloc::format!("{}", conv),
"DeformConv2d {stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], weight_groups: 1, offset_groups: 1, padding: Valid, params: 126}"
);
}
}
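A short sketch of the forward shapes documented above (illustrative only; `B` is any backend, and the default `Valid` padding with stride 1 is assumed):

let device = Default::default();
let conv = DeformConv2dConfig::new([2, 4], [3, 3]).init::<B>(&device);
// With a 3x3 kernel, Valid padding and stride 1: height_out = height_in - 2.
let input = Tensor::<B, 4>::zeros([1, 2, 8, 8], &device);
let offset = Tensor::<B, 4>::zeros([1, 2 * 3 * 3, 6, 6], &device); // 18 channels
let mask = Tensor::<B, 4>::ones([1, 3 * 3, 6, 6], &device); // 9 channels
let out = conv.forward(input, offset, Some(mask));
assert_eq!(out.dims(), [1, 4, 6, 6]);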


@@ -4,6 +4,7 @@ mod conv3d;
mod conv_transpose1d;
mod conv_transpose2d;
mod conv_transpose3d;
mod deform_conv2d;
pub(crate) mod checks;
@@ -13,3 +14,4 @@ pub use conv3d::*;
pub use conv_transpose1d::*;
pub use conv_transpose2d::*;
pub use conv_transpose3d::*;
pub use deform_conv2d::*;


@@ -5,9 +5,9 @@ use burn_tensor::{
calculate_conv_output_size, calculate_conv_transpose_output_size,
calculate_pool_output_size,
},
ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, InterpolateOptions,
MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices,
ModuleOps,
ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, FloatTensor,
IntTensor, InterpolateOptions, MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward,
MaxPool2dWithIndices, ModuleOps,
},
repr::*,
Element,
@@ -153,6 +153,211 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
out
}
fn deform_conv2d(
x: FloatTensor<Self, 4>,
offset: FloatTensor<Self, 4>,
weight: FloatTensor<Self, 4>,
mask: Option<FloatTensor<Self, 4>>,
bias: Option<FloatTensor<Self, 1>>,
options: DeformConvOptions<2>,
) -> FloatTensor<Self, 4> {
make_ops!(
DeformConv2dOps,
DeformConv2dDescription,
|args: DeformConv2dDescription, handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 4>(&args.x);
let offset = handles.get_float_tensor::<B, 4>(&args.offset);
let weight = handles.get_float_tensor::<B, 4>(&args.weight);
let mask = args
.mask
.as_ref()
.map(|mask| handles.get_float_tensor::<B, 4>(mask));
let bias = args
.bias
.as_ref()
.map(|bias| handles.get_float_tensor::<B, 1>(bias));
let output =
B::deform_conv2d(x, offset, weight, mask, bias, args.options.clone().into());
handles.register_float_tensor::<B, 4>(&args.out.id, output);
}
);
let size_0 = calculate_conv_output_size(
weight.shape[2],
options.stride[0],
options.padding[0],
options.dilation[0],
x.shape[2],
);
let size_1 = calculate_conv_output_size(
weight.shape[3],
options.stride[1],
options.padding[1],
options.dilation[1],
x.shape[3],
);
let stream_1 = x.stream;
let stream_2 = offset.stream;
let stream_3 = weight.stream;
let stream_4 = mask.as_ref().map(|m| m.stream);
let stream_5 = bias.as_ref().map(|b| b.stream);
let shape = vec![x.shape[0], weight.shape[0], size_0, size_1];
let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype());
let desc = DeformConv2dDescription {
x: x.into_description(),
offset: offset.into_description(),
weight: weight.into_description(),
mask: mask.map(|mask| mask.into_description()),
bias: bias.map(|bias| bias.into_description()),
options: options.into(),
out: out.to_description_out(),
};
let streams = match (stream_4, stream_5) {
(Some(stream_4), Some(stream_5)) => {
vec![stream_1, stream_2, stream_3, stream_4, stream_5]
}
(Some(stream_4), None) => {
vec![stream_1, stream_2, stream_3, stream_4]
}
(None, Some(stream_5)) => {
vec![stream_1, stream_2, stream_3, stream_5]
}
(None, None) => vec![stream_1, stream_2, stream_3],
};
out.client.register(
streams,
OperationDescription::Module(ModuleOperationDescription::DeformableConv2d(Box::new(
desc.clone(),
))),
DeformConv2dOps::<B>::new(desc),
);
out
}
fn deform_conv2d_backward(
x: FloatTensor<Self, 4>,
offset: FloatTensor<Self, 4>,
weight: FloatTensor<Self, 4>,
mask: Option<FloatTensor<Self, 4>>,
bias: Option<FloatTensor<Self, 1>>,
output_grad: FloatTensor<Self, 4>,
options: DeformConvOptions<2>,
) -> DeformConv2dBackward<Self> {
make_ops!(
DeformConv2dBackwardOps,
DeformConv2dBackwardDescription,
|args: DeformConv2dBackwardDescription, handles: &mut HandleContainer<B::Handle>| {
let x = handles.get_float_tensor::<B, 4>(&args.x);
let offset = handles.get_float_tensor::<B, 4>(&args.offset);
let weight = handles.get_float_tensor::<B, 4>(&args.weight);
let mask = args
.mask
.as_ref()
.map(|mask| handles.get_float_tensor::<B, 4>(mask));
let bias = args
.bias
.as_ref()
.map(|bias| handles.get_float_tensor::<B, 1>(bias));
let output_grad = handles.get_float_tensor::<B, 4>(&args.out_grad);
let output = B::deform_conv2d_backward(
x,
offset,
weight,
mask,
bias,
output_grad,
args.options.clone().into(),
);
handles.register_float_tensor::<B, 4>(&args.input_grad.id, output.x_grad);
handles.register_float_tensor::<B, 4>(&args.offset_grad.id, output.offset_grad);
handles.register_float_tensor::<B, 4>(&args.weight_grad.id, output.weight_grad);
if let Some((mask_grad, field)) = output.mask_grad.zip(args.mask_grad.as_ref()) {
handles.register_float_tensor::<B, 4>(&field.id, mask_grad);
}
if let Some((bias_grad, field)) = output.bias_grad.zip(args.bias_grad.as_ref()) {
handles.register_float_tensor::<B, 1>(&field.id, bias_grad);
}
}
);
let input_grad = x
.client
.tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());
let offset_grad = offset
.client
.tensor_uninitialized(offset.shape.clone(), B::FloatElem::dtype());
let weight_grad = offset
.client
.tensor_uninitialized(weight.shape.clone(), B::FloatElem::dtype());
let mask_grad = mask.as_ref().map(|mask| {
offset
.client
.tensor_uninitialized(mask.shape.clone(), B::FloatElem::dtype())
});
let bias_grad = bias.as_ref().map(|bias| {
offset
.client
.tensor_uninitialized(bias.shape.clone(), B::FloatElem::dtype())
});
let stream_1 = x.stream;
let stream_2 = offset.stream;
let stream_3 = weight.stream;
let stream_4 = mask.as_ref().map(|m| m.stream);
let stream_5 = bias.as_ref().map(|b| b.stream);
let stream_6 = output_grad.stream;
let desc = DeformConv2dBackwardDescription {
x: x.into_description(),
offset: offset.into_description(),
weight: weight.into_description(),
mask: mask.map(|mask| mask.into_description()),
bias: bias.map(|bias| bias.into_description()),
options: options.into(),
out_grad: output_grad.into_description(),
input_grad: input_grad.to_description_out(),
offset_grad: offset_grad.to_description_out(),
weight_grad: weight_grad.to_description_out(),
mask_grad: mask_grad
.as_ref()
.map(|mask_grad| mask_grad.to_description_out()),
bias_grad: bias_grad
.as_ref()
.map(|bias_grad| bias_grad.to_description_out()),
};
let streams = match (stream_4, stream_5) {
(Some(stream_4), Some(stream_5)) => {
vec![stream_1, stream_2, stream_3, stream_4, stream_5, stream_6]
}
(Some(stream_4), None) => {
vec![stream_1, stream_2, stream_3, stream_4, stream_6]
}
(None, Some(stream_5)) => {
vec![stream_1, stream_2, stream_3, stream_5, stream_6]
}
(None, None) => vec![stream_1, stream_2, stream_3, stream_6],
};
input_grad.client.register(
streams,
OperationDescription::Module(ModuleOperationDescription::DeformableConv2dBackward(
Box::new(desc.clone()),
)),
DeformConv2dBackwardOps::<B>::new(desc),
);
DeformConv2dBackward::new(input_grad, offset_grad, weight_grad, mask_grad, bias_grad)
}
fn conv3d(
x: FloatTensor<Self, 5>,
weight: FloatTensor<Self, 5>,


@@ -180,6 +180,35 @@ impl RelativeOps for ModuleOperationDescription {
out: desc.out.to_relative(converter),
})
}
ModuleOperationDescription::DeformableConv2d(desc) => {
ModuleOperationDescription::DeformableConv2d(Box::new(DeformConv2dDescription {
x: desc.x.to_relative(converter),
offset: desc.offset.to_relative(converter),
weight: desc.weight.to_relative(converter),
mask: desc.mask.as_ref().map(|t| t.to_relative(converter)),
bias: desc.bias.as_ref().map(|t| t.to_relative(converter)),
options: desc.options.clone(),
out: desc.out.to_relative(converter),
}))
}
ModuleOperationDescription::DeformableConv2dBackward(desc) => {
ModuleOperationDescription::DeformableConv2dBackward(Box::new(
DeformConv2dBackwardDescription {
x: desc.x.to_relative(converter),
offset: desc.offset.to_relative(converter),
weight: desc.weight.to_relative(converter),
mask: desc.mask.as_ref().map(|t| t.to_relative(converter)),
bias: desc.bias.as_ref().map(|t| t.to_relative(converter)),
out_grad: desc.out_grad.to_relative(converter),
options: desc.options.clone(),
input_grad: desc.input_grad.to_relative(converter),
offset_grad: desc.offset_grad.to_relative(converter),
weight_grad: desc.weight_grad.to_relative(converter),
mask_grad: desc.mask_grad.as_ref().map(|t| t.to_relative(converter)),
bias_grad: desc.bias_grad.as_ref().map(|t| t.to_relative(converter)),
},
))
}
ModuleOperationDescription::ConvTranspose1d(desc) => {
ModuleOperationDescription::ConvTranspose1d(ConvTranspose1dDescription {
x: desc.x.to_relative(converter),


@@ -0,0 +1,316 @@
use cubecl::{calculate_cube_count_elemwise, prelude::*};
use burn_tensor::{
ops::{conv::calculate_conv_output_size, DeformConvOptions, FloatTensorOps as _},
Shape,
};
use crate::{
kernel::into_contiguous,
ops::{
numeric::{ones_device, zeros_device},
reshape, swap_dims,
},
tensor::JitTensor,
FloatElement, IntElement, JitBackend, JitRuntime,
};
#[derive(CubeLaunch)]
struct DeformConv2dArgs<F: Float> {
conv_stride_h: u32,
conv_stride_w: u32,
dilation_h: u32,
dilation_w: u32,
padding_h: F,
padding_w: F,
offset_groups: u32,
kernel_height: u32,
kernel_width: u32,
out_h: u32,
out_w: u32,
col_stride_0: u32,
}
#[cube(launch)]
fn deform_im2col_kernel<F: Float>(
input: &Tensor<F>,
offset: &Tensor<F>,
mask: &Tensor<F>,
columns: &mut Tensor<F>,
args: &DeformConv2dArgs<F>,
#[comptime] kernel_h_unroll: Option<u32>,
#[comptime] kernel_w_unroll: Option<u32>,
#[comptime] use_mask: bool,
) {
// position shape: [in_channels, batch_size, out_h, out_w]
// columns shape: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]]
let kernel_height = kernel_h_unroll.unwrap_or(args.kernel_height);
let unroll_h = kernel_h_unroll.is_some();
let kernel_width = kernel_w_unroll.unwrap_or(args.kernel_width);
let unroll_w = kernel_w_unroll.is_some();
// Keep mask in bind group
let default_mask_value = mask[0];
let out_h = args.out_h;
let out_w = args.out_w;
let batch_size = input.shape(0);
let in_channels = input.shape(1);
let height = input.shape(2);
let width = input.shape(3);
let col_stride_0 = args.col_stride_0;
let out_x = ABSOLUTE_POS % out_w;
let out_y = (ABSOLUTE_POS / out_w) % out_h;
let out_batch = (ABSOLUTE_POS / (out_w * out_h)) % batch_size;
let in_channel = ABSOLUTE_POS / (out_w * out_h * batch_size);
let out_channel = in_channel * kernel_height * kernel_width;
let channels_per_offset_group = in_channels / args.offset_groups;
let group_index = in_channel / channels_per_offset_group;
let mut col_base_idx =
out_channel * col_stride_0 + out_batch * (out_h * out_w) + out_y * out_w + out_x;
let input_base_idx = out_batch * input.stride(0) + in_channel * input.stride(1);
let offset_base_idx = out_batch * offset.stride(0)
+ group_index * kernel_height * kernel_width * 2 * out_h * out_w;
let mut mask_base_idx = 0;
if use_mask {
mask_base_idx =
out_batch * mask.stride(0) + group_index * kernel_height * kernel_width * out_h * out_w;
}
#[unroll(unroll_h)]
for kernel_y in 0..kernel_height {
#[unroll(unroll_w)]
for kernel_x in 0..kernel_width {
let mask_index = kernel_y * kernel_width + kernel_x;
let offset_index = mask_index * 2;
let mut mask_value = default_mask_value;
if use_mask {
mask_value = mask[mask_base_idx
+ mask_index * mask.stride(1)
+ out_y * mask.stride(2)
+ out_x * mask.stride(3)];
}
let offset_y = offset[offset_base_idx
+ offset_index * offset.stride(1)
+ out_y * offset.stride(2)
+ out_x * offset.stride(3)];
let offset_x = offset[offset_base_idx
+ (offset_index + 1) * offset.stride(1)
+ out_y * offset.stride(2)
+ out_x * offset.stride(3)];
let y = F::cast_from(out_y * args.conv_stride_h + kernel_y * args.dilation_h)
- args.padding_h
+ offset_y;
let x = F::cast_from(out_x * args.conv_stride_w + kernel_x * args.dilation_w)
- args.padding_w
+ offset_x;
let interpolated = bilinear_interpolate(input, height, width, y, x, input_base_idx);
columns[col_base_idx] = mask_value * interpolated;
col_base_idx += col_stride_0;
}
}
}
#[cube]
pub(crate) fn bilinear_interpolate<F: Float>(
input: &Tensor<F>,
height: u32,
width: u32,
y: F,
x: F,
offset: u32,
) -> F {
// To simplify code
let y = f32::cast_from(y);
let x = f32::cast_from(x);
let mut result = F::new(0.0);
if y > -1.0 && height as f32 > y && x > -1.0 && width as f32 > x {
let in_w = u32::cast_from(width);
let y_low = f32::floor(y);
let x_low = f32::floor(x);
let y_high = (y_low + 1.) as u32;
let x_high = (x_low + 1.) as u32;
let zero = F::new(0.0);
let v1: F = if y_low >= 0. && x_low >= 0. {
input[offset + y_low as u32 * in_w + x_low as u32]
} else {
zero
};
let v2: F = if y_low >= 0. && x_high < width {
input[offset + y_low as u32 * in_w + x_high]
} else {
zero
};
let v3: F = if y_high < height && x_low >= 0. {
input[offset + y_high * in_w + x_low as u32]
} else {
zero
};
let v4: F = if y_high < height && x_high < width {
input[offset + y_high * in_w + x_high]
} else {
zero
};
let l_y = y - y_low;
let l_x = x - x_low;
let h_y = 1.0 - l_y;
let h_x = 1.0 - l_x;
let w1 = F::cast_from(h_y * h_x);
let w2 = F::cast_from(h_y * l_x);
let w3 = F::cast_from(l_y * h_x);
let w4 = F::cast_from(l_y * l_x);
result = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
}
result
}
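// A CPU reference for the interpolation above (an illustrative sketch, not part
// of the kernel; the bounds checks the kernel performs are omitted). The four
// weights w1..w4 sum to 1, so sampling the exact center of a 2x2 neighborhood
// averages its values: bilinear_reference(&[0.0, 1.0, 2.0, 3.0], 2, 0.5, 0.5) == 1.5.
fn bilinear_reference(input: &[f32], width: usize, y: f32, x: f32) -> f32 {
    let (y_low, x_low) = (y.floor(), x.floor());
    let (l_y, l_x) = (y - y_low, x - x_low);
    let (h_y, h_x) = (1.0 - l_y, 1.0 - l_x);
    let at = |r: usize, c: usize| input[r * width + c];
    let (r, c) = (y_low as usize, x_low as usize);
    h_y * h_x * at(r, c) + h_y * l_x * at(r, c + 1)
        + l_y * h_x * at(r + 1, c) + l_y * l_x * at(r + 1, c + 1)
}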
pub(crate) fn deform_im2col<R: JitRuntime, E: FloatElement>(
input: JitTensor<R, E, 4>,
offset: JitTensor<R, E, 4>,
mask: Option<JitTensor<R, E, 4>>,
options: DeformConvOptions<2>,
out_dims: (usize, usize),
kernel_dims: (usize, usize),
) -> JitTensor<R, E, 2> {
let client = input.client.clone();
let device = input.device.clone();
let [batch_size, in_channels, _, _] = input.shape.dims;
let (out_height, out_width) = out_dims;
let (kernel_height, kernel_width) = kernel_dims;
let shape_out = Shape::new([
in_channels * kernel_height * kernel_width,
batch_size * out_height * out_width,
]);
let output = zeros_device(client.clone(), device.clone(), shape_out.clone());
let use_mask = mask.is_some();
let mask = mask.unwrap_or_else(|| {
ones_device(
client.clone(),
device.clone(),
Shape::new([
offset.shape.dims[0],
offset.shape.dims[1] / 2,
offset.shape.dims[2],
offset.shape.dims[3],
]),
)
});
let num_kernels = in_channels * batch_size * out_height * out_width;
let cube_dim = CubeDim::default();
let cube_count = calculate_cube_count_elemwise(num_kernels, cube_dim);
deform_im2col_kernel::launch::<E, R>(
&input.client,
cube_count,
cube_dim,
input.as_handle_ref().as_tensor_arg(1),
offset.as_handle_ref().as_tensor_arg(1),
mask.as_handle_ref().as_tensor_arg(1),
output.as_handle_ref().as_tensor_arg(1),
DeformConv2dArgsLaunch::new(
ScalarArg::new(options.stride[0] as u32),
ScalarArg::new(options.stride[1] as u32),
ScalarArg::new(options.dilation[0] as u32),
ScalarArg::new(options.dilation[1] as u32),
ScalarArg::new(E::from_elem(options.padding[0] as f32)),
ScalarArg::new(E::from_elem(options.padding[1] as f32)),
ScalarArg::new(options.offset_groups as u32),
ScalarArg::new(kernel_height as u32),
ScalarArg::new(kernel_width as u32),
ScalarArg::new(out_height as u32),
ScalarArg::new(out_width as u32),
ScalarArg::new(output.strides[0] as u32),
),
Some(kernel_height as u32),
Some(kernel_width as u32),
use_mask,
);
output
}
pub(crate) fn deform_conv2d<R: JitRuntime, E: FloatElement, I: IntElement>(
input: JitTensor<R, E, 4>,
offset: JitTensor<R, E, 4>,
weight: JitTensor<R, E, 4>,
mask: Option<JitTensor<R, E, 4>>,
bias: Option<JitTensor<R, E, 1>>,
options: DeformConvOptions<2>,
) -> JitTensor<R, E, 4> {
let input = into_contiguous(input);
let offset = into_contiguous(offset);
let weight = into_contiguous(weight);
let mask = mask.map(|it| into_contiguous(it));
let bias = bias.map(|it| into_contiguous(it));
let [batch_size, _, in_height, in_width] = input.shape.dims;
let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims;
let groups = options.weight_groups;
let out_h = calculate_conv_output_size(
kernel_h,
options.stride[0],
options.padding[0],
options.dilation[0],
in_height,
);
let out_w = calculate_conv_output_size(
kernel_w,
options.stride[1],
options.padding[1],
options.dilation[1],
in_width,
);
let out_dims = (out_h, out_w);
let columns = deform_im2col(input, offset, mask, options, out_dims, (kernel_h, kernel_w));
let [col_size_0, col_size_1] = columns.shape.dims;
let col_size_0 = col_size_0 / groups;
let out_c_per_group = out_channels / groups;
let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_size_0]));
let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1]));
let out = JitBackend::<R, E, I>::float_matmul(weight, columns);
let out = reshape(out, Shape::new([out_channels, batch_size, out_h, out_w]));
let out = swap_dims(out, 0, 1);
if let Some(bias) = bias {
let bias = reshape(bias, Shape::new([1, out_channels, 1, 1]));
JitBackend::<R, E, I>::float_add(out, bias)
} else {
out
}
}
pub(crate) fn index<R: JitRuntime, E: FloatElement, I: IntElement>(
tensor: JitTensor<R, E, 3>,
index: usize,
) -> JitTensor<R, E, 2> {
let [_, shape_0, shape_1] = tensor.shape.dims;
let tensor = JitBackend::<R, E, I>::float_narrow(tensor, 0, index, 1);
reshape(tensor, Shape::new([shape_0, shape_1]))
}
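The forward pass above uses an im2col + grouped GEMM formulation: the column matrix has shape [in_channels * kernel_h * kernel_w, batch_size * out_h * out_w], the weight is reshaped per group, and the matmul result is reshaped back to NCHW. A self-contained sketch of the shape bookkeeping (plain Rust, illustrative only; calculate_conv_output_size implements the standard formula shown in the comment):

// out = (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1
fn conv_output_size(kernel: usize, stride: usize, pad: usize, dilation: usize, size: usize) -> usize {
    (size + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1
}

let (batch, in_c, out_c, groups) = (2, 8, 16, 2);
let (kh, kw, h, w) = (3, 3, 32, 32);
let out_h = conv_output_size(kh, 1, 1, 1, h); // stride 1, padding 1, dilation 1
let out_w = conv_output_size(kw, 1, 1, 1, w);
let col_rows = in_c * kh * kw; // split into `groups` chunks along this axis
let col_cols = batch * out_h * out_w;
// Per-group GEMM: [out_c / groups, col_rows / groups] x [col_rows / groups, col_cols]
assert_eq!((out_c % groups, col_rows % groups), (0, 0));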


@@ -0,0 +1,591 @@
use burn_tensor::{
ops::{DeformConv2dBackward, DeformConvOptions, FloatTensorOps as _},
Shape,
};
use cubecl::{calculate_cube_count_elemwise, cube, prelude::*, CubeDim, CubeLaunch};
use crate::{
kernel::into_contiguous,
ops::{
numeric::{empty_device, ones_device, zeros_device},
reshape, swap_dims,
},
tensor::JitTensor,
FloatElement, IntElement, JitBackend, JitRuntime,
};
use super::{bilinear_interpolate, deform_im2col, index};
/// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions.
#[allow(clippy::single_range_in_vec_init)]
pub(crate) fn deform_conv2d_backward<R: JitRuntime, E: FloatElement, I: IntElement>(
input: JitTensor<R, E, 4>,
offset: JitTensor<R, E, 4>,
weight: JitTensor<R, E, 4>,
mask: Option<JitTensor<R, E, 4>>,
bias: Option<JitTensor<R, E, 1>>,
out_grad: JitTensor<R, E, 4>,
options: DeformConvOptions<2>,
) -> DeformConv2dBackward<JitBackend<R, E, I>> {
let [_, _, out_h, out_w] = out_grad.shape.dims;
let [_, _, kernel_h, kernel_w] = weight.shape.dims;
let gradient_bias = bias.map(|bias| {
let grad = JitBackend::<R, E, I>::float_sum_dim(out_grad.clone(), 0);
let grad = JitBackend::<R, E, I>::float_sum_dim(grad, 2);
let grad = JitBackend::<R, E, I>::float_sum_dim(grad, 3);
reshape(grad, bias.shape)
});
let input = into_contiguous(input);
let offset = into_contiguous(offset);
let mask = mask.map(|it| into_contiguous(it));
let (input_gradient, offset_gradient, mask_gradient) = backward_gradient_inputs::<R, E, I>(
input.clone(),
weight.clone(),
offset.clone(),
mask.clone(),
out_grad.clone(),
&options,
(kernel_h, kernel_w),
);
let weight_grad = compute_weight_grad::<R, E, I>(
input,
offset,
mask,
out_grad,
options,
(kernel_h, kernel_w),
(out_h, out_w),
);
DeformConv2dBackward::new(
input_gradient,
offset_gradient,
weight_grad,
mask_gradient,
gradient_bias,
)
}
fn compute_weight_grad<R: JitRuntime, E: FloatElement, I: IntElement>(
input: JitTensor<R, E, 4>,
offset: JitTensor<R, E, 4>,
mask: Option<JitTensor<R, E, 4>>,
out_grad: JitTensor<R, E, 4>,
options: DeformConvOptions<2>,
kernel_dims: (usize, usize),
out_dims: (usize, usize),
) -> JitTensor<R, E, 4> {
let [_, in_channels, _, _] = input.shape.dims;
let [_, out_channels, _, _] = out_grad.shape.dims;
let (kernel_h, kernel_w) = kernel_dims;
let groups = options.weight_groups;
let in_c_per_group = in_channels / groups;
let out_c_per_group = out_channels / groups;
let columns = deform_im2col(input, offset, mask, options, out_dims, kernel_dims);
let [col_size_0, col_size_1] = columns.shape.dims;
let col_size_0 = col_size_0 / groups;
let out_grad = swap_dims(out_grad, 0, 1);
let out_grad = reshape(out_grad, Shape::new([groups, out_c_per_group, col_size_1]));
let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1]));
let columns = swap_dims(columns, 1, 2);
let grad_weight = JitBackend::<R, E, I>::float_matmul(out_grad, columns);
JitBackend::<R, E, I>::float_reshape(
grad_weight,
Shape::new([out_channels, in_c_per_group, kernel_h, kernel_w]),
)
}
type InputGradients<R, E> = (
JitTensor<R, E, 4>,
JitTensor<R, E, 4>,
Option<JitTensor<R, E, 4>>,
);
fn backward_gradient_inputs<R: JitRuntime, E: FloatElement, I: IntElement>(
image: JitTensor<R, E, 4>,
weight: JitTensor<R, E, 4>,
offset: JitTensor<R, E, 4>,
mask: Option<JitTensor<R, E, 4>>,
out_grad: JitTensor<R, E, 4>,
options: &DeformConvOptions<2>,
kernel_dims: (usize, usize),
) -> InputGradients<R, E> {
let client = out_grad.client.clone();
let device = out_grad.device.clone();
let [out_channels, in_c_per_group, kernel_h, kernel_w] = weight.shape.dims;
let [batch_size, _, out_h, out_w] = out_grad.shape.dims;
let groups = options.weight_groups;
let out_c_per_group = out_channels / groups;
let col_shape_0 = in_c_per_group * kernel_h * kernel_w;
let col_shape_1 = batch_size * out_h * out_w;
let col_shape = Shape::new([groups, col_shape_0, col_shape_1]);
let mut columns = empty_device(client, device, col_shape);
let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_shape_0]));
let out_grad = swap_dims(out_grad, 0, 1);
let out_grad_shape = Shape::new([groups, out_c_per_group, col_shape_1]);
let out_grad = reshape(out_grad, out_grad_shape);
for group in 0..groups {
let weight = swap_dims(index::<R, E, I>(weight.clone(), group), 0, 1);
let out_grad = index::<R, E, I>(out_grad.clone(), group);
let values = JitBackend::<R, E, I>::float_matmul(weight, out_grad);
let values = reshape(values, Shape::new([1, col_shape_0, col_shape_1]));
columns = JitBackend::<R, E, I>::float_slice_assign(
columns,
[group..group + 1, 0..col_shape_0, 0..col_shape_1],
values,
);
}
let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1]));
let input_shape = image.shape.clone();
let (offset_gradient, mask_gradient) = compute_offset_and_mask_gradient::<R, E>(
columns.clone(),
image,
offset.clone(),
mask.clone(),
options,
kernel_dims,
);
let input_gradient =
compute_input_grad::<R, E>(columns, offset, mask, options, kernel_dims, input_shape);
(input_gradient, offset_gradient, mask_gradient)
}
fn compute_offset_and_mask_gradient<R: JitRuntime, E: FloatElement>(
columns: JitTensor<R, E, 2>,
image: JitTensor<R, E, 4>,
offset: JitTensor<R, E, 4>,
mask: Option<JitTensor<R, E, 4>>,
options: &DeformConvOptions<2>,
kernel_dims: (usize, usize),
) -> (JitTensor<R, E, 4>, Option<JitTensor<R, E, 4>>) {
let client = offset.client.clone();
let device = offset.device.clone();
let (kernel_height, kernel_width) = kernel_dims;
let use_mask = mask.is_some();
let mask = mask.unwrap_or_else(|| {
ones_device(
client.clone(),
device.clone(),
Shape::new([
offset.shape.dims[0],
offset.shape.dims[1] / 2,
offset.shape.dims[2],
offset.shape.dims[3],
]),
)
});
let grad_offset = empty_device(client.clone(), device.clone(), offset.shape.clone());
let grad_mask = empty_device(client.clone(), device.clone(), mask.shape.clone());
let num_elements_offset = offset.shape.num_elements();
let cube_dim = CubeDim::default();
let cube_count = calculate_cube_count_elemwise(num_elements_offset, cube_dim);
deform_col2img_coord_kernel::launch::<E, R>(
&image.client,
cube_count,
cube_dim,
image.as_handle_ref().as_tensor_arg(1),
offset.as_handle_ref().as_tensor_arg(1),
mask.as_handle_ref().as_tensor_arg(1),
columns.as_handle_ref().as_tensor_arg(1),
grad_offset.as_handle_ref().as_tensor_arg(1),
grad_mask.as_handle_ref().as_tensor_arg(1),
DeformConv2dCol2ImgCoordArgsLaunch::new(
ScalarArg::new(options.stride[0] as u32),
ScalarArg::new(options.stride[1] as u32),
ScalarArg::new(options.dilation[0] as u32),
ScalarArg::new(options.dilation[1] as u32),
ScalarArg::new(E::from_elem(options.padding[0] as f32)),
ScalarArg::new(E::from_elem(options.padding[1] as f32)),
ScalarArg::new(options.offset_groups as u32),
ScalarArg::new(kernel_height as u32),
ScalarArg::new(kernel_width as u32),
),
use_mask,
);
let mask_gradient = if use_mask { Some(grad_mask) } else { None };
(grad_offset, mask_gradient)
}
#[derive(CubeLaunch)]
struct DeformConv2dCol2ImgCoordArgs<F: Float> {
stride_h: u32,
stride_w: u32,
dilation_h: u32,
dilation_w: u32,
pad_h: F,
pad_w: F,
offset_groups: u32,
kernel_height: u32,
kernel_width: u32,
}
#[allow(clippy::collapsible_if)]
#[cube(launch)]
fn deform_col2img_coord_kernel<F: Float>(
image: &Tensor<F>,
offset: &Tensor<F>,
mask: &Tensor<F>,
columns: &Tensor<F>,
grad_offset: &mut Tensor<F>,
grad_mask: &mut Tensor<F>,
args: &DeformConv2dCol2ImgCoordArgs<F>,
#[comptime] use_mask: bool,
) {
// Position format: [batch, [offset_group, kernel_h, kernel_w, 2], out_h, out_w]
// Alternatively: [batch, offset_channels, out_h, out_w]
let offset_channels = offset.shape(1);
let out_h = offset.shape(2);
let out_w = offset.shape(3);
let batch_size = image.shape(0);
let in_channels = image.shape(1);
let height = image.shape(2);
let width = image.shape(3);
let kernel_w = args.kernel_width;
let kernel_h = args.kernel_height;
let n_offset_groups = args.offset_groups;
let _ = mask[0]; // Make sure mask isn't removed from bind group
let mut grad_offset_val = F::new(0.0);
let mut grad_mask_val = F::new(0.0);
let w = ABSOLUTE_POS % out_w;
let h = (ABSOLUTE_POS / out_w) % out_h;
let w_w = (ABSOLUTE_POS / (out_w * out_h * 2)) % kernel_w;
let w_h = (ABSOLUTE_POS / (out_w * out_h * 2 * kernel_w)) % kernel_h;
let c = (ABSOLUTE_POS / (out_w * out_h)) % offset_channels;
let b = ABSOLUTE_POS / (out_w * out_h * offset_channels);
let offset_group = c / (kernel_h * kernel_w * 2);
let col_step = kernel_h * kernel_w;
let channels_per_offset_group = in_channels / args.offset_groups;
let col_base_idx =
offset_group * channels_per_offset_group * kernel_h * kernel_w * batch_size * out_w * out_h;
let mut image_base_idx =
(b * n_offset_groups + offset_group) * channels_per_offset_group * height * width;
let offset_base_idx =
(b * n_offset_groups + offset_group) * 2 * kernel_h * kernel_w * out_h * out_w;
let mask_base_idx = (b * n_offset_groups + offset_group) * kernel_h * kernel_w * out_h * out_w;
let offset_c = c - offset_group * 2 * kernel_h * kernel_w;
let is_y_direction = offset_c % 2 == 0;
let c_bound = channels_per_offset_group * kernel_h * kernel_w;
for col_c in range_stepped(offset_c / 2, c_bound, col_step) {
let col_pos = (((col_c * batch_size + b) * out_h) + h) * out_w + w;
let out_x = col_pos % out_w;
let out_y = (col_pos / out_w) % out_h;
let j = (col_pos / (out_w * out_h * batch_size)) % kernel_w;
let i = (col_pos / (out_w * out_h * batch_size * kernel_w)) % kernel_h;
let mask_idx = i * kernel_w + j;
let offset_idx = mask_idx * 2;
let offset_y_idx = (offset_idx * out_h + out_y) * out_w + out_x;
let offset_x_idx = ((offset_idx + 1) * out_h + out_y) * out_w + out_x;
let offset_y = offset[offset_base_idx + offset_y_idx];
let offset_x = offset[offset_base_idx + offset_x_idx];
let mask_value = if use_mask {
mask[mask_base_idx + (mask_idx * out_h + out_y) * out_w + out_x]
} else {
F::new(1.0)
};
let y = F::cast_from(out_y * args.stride_h + i * args.dilation_h) - args.pad_h + offset_y;
let x = F::cast_from(out_x * args.stride_w + j * args.dilation_w) - args.pad_w + offset_x;
let weight = get_coordinate_weight(
image.slice(image_base_idx, image.len()),
height,
width,
y,
x,
is_y_direction,
);
grad_offset_val += mask_value * weight * columns[col_base_idx + col_pos];
if use_mask {
if is_y_direction {
grad_mask_val += columns[col_base_idx + col_pos]
* bilinear_interpolate(image, height, width, y, x, image_base_idx);
}
}
image_base_idx += height * width;
}
grad_offset[ABSOLUTE_POS] = grad_offset_val;
if use_mask {
if is_y_direction {
let idx = ((((b * n_offset_groups + offset_group) * kernel_h + w_h) * kernel_w + w_w)
* out_h
+ h)
* out_w
+ w;
grad_mask[idx] = grad_mask_val
}
}
}
#[cube]
fn get_coordinate_weight<F: Float>(
input: &Slice<'_, F>,
height: u32,
width: u32,
y: F,
x: F,
is_y_direction: bool,
) -> F {
let stride_y = width;
let y = f32::cast_from(y);
let x = f32::cast_from(x);
let y_low = f32::floor(y);
let x_low = f32::floor(x);
let y_high = y_low + 1.;
let x_high = x_low + 1.;
let valid_y_low = y_low >= 0. && y_low < height as f32;
let valid_y_high = y_high >= 0. && y_high < height as f32;
let valid_x_low = x_low >= 0. && x_low < width as f32;
let valid_x_high = x_high >= 0. && x_high < width as f32;
let bottom_left = if valid_y_low && valid_x_low {
input[y_low as u32 * stride_y + x_low as u32]
} else {
F::new(0.0)
};
let bottom_right = if valid_y_low && valid_x_high {
input[y_low as u32 * stride_y + x_high as u32]
} else {
F::new(0.0)
};
let top_left = if valid_y_high && valid_x_low {
input[y_high as u32 * stride_y + x_low as u32]
} else {
F::new(0.0)
};
let top_right = if valid_y_high && valid_x_high {
input[y_high as u32 * stride_y + x_high as u32]
} else {
F::new(0.0)
};
if is_y_direction {
let delta_x = F::cast_from(x - x_low);
delta_x * (top_right - bottom_right) + (F::new(1.0) - delta_x) * (top_left - bottom_left)
} else {
let delta_y = F::cast_from(y - y_low);
delta_y * (top_right - top_left) + (F::new(1.0) - delta_y) * (bottom_right - bottom_left)
}
}
fn compute_input_grad<R: JitRuntime, E: FloatElement>(
columns: JitTensor<R, E, 2>,
offset: JitTensor<R, E, 4>,
mask: Option<JitTensor<R, E, 4>>,
options: &DeformConvOptions<2>,
kernel_dims: (usize, usize),
input_shape: Shape<4>,
) -> JitTensor<R, E, 4> {
let client = offset.client.clone();
let device = offset.device.clone();
let [batch_size, in_channels, height, width] = input_shape.dims;
let (kernel_height, kernel_width) = kernel_dims;
let grad_in = zeros_device::<R, E, 4>(
client.clone(),
device.clone(),
Shape::new([batch_size, in_channels, height, width]),
);
let use_mask = mask.is_some();
let mask = mask
.unwrap_or_else(|| ones_device(client.clone(), device.clone(), Shape::new([1, 1, 1, 1])));
let num_elements = columns.shape.num_elements();
let cube_dim = CubeDim::default();
let cube_count = calculate_cube_count_elemwise(num_elements, cube_dim);
deform_col2img_kernel::launch::<E, R>(
&offset.client,
cube_count,
cube_dim,
offset.as_tensor_arg(1),
mask.as_tensor_arg(1),
columns.as_tensor_arg(1),
grad_in.as_tensor_arg(1),
DeformConv2dCol2ImgArgsLaunch::new(
ScalarArg::new(options.stride[0] as u32),
ScalarArg::new(options.stride[1] as u32),
ScalarArg::new(options.dilation[0] as u32),
ScalarArg::new(options.dilation[1] as u32),
ScalarArg::new(options.padding[0] as f32),
ScalarArg::new(options.padding[1] as f32),
ScalarArg::new(options.offset_groups as u32),
ScalarArg::new(batch_size as u32),
ScalarArg::new(in_channels as u32),
ScalarArg::new(height as u32),
ScalarArg::new(width as u32),
ScalarArg::new(kernel_height as u32),
ScalarArg::new(kernel_width as u32),
),
use_mask,
);
grad_in
}
#[derive(CubeLaunch)]
struct DeformConv2dCol2ImgArgs {
stride_h: u32,
stride_w: u32,
dilation_h: u32,
dilation_w: u32,
pad_h: f32,
pad_w: f32,
offset_groups: u32,
batch_size: u32,
in_channels: u32,
height: u32,
width: u32,
kernel_height: u32,
kernel_width: u32,
}
#[cube(launch)]
fn deform_col2img_kernel<F: Float>(
offset: &Tensor<F>,
mask: &Tensor<F>,
columns: &Tensor<F>,
grad_input: &mut Tensor<AtomicU32>,
args: &DeformConv2dCol2ImgArgs,
#[comptime] use_mask: bool,
) {
// Position format: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]]
let _ = mask[0]; // Keep mask in bind group
let n_in_channels = args.in_channels;
let height = args.height;
let width = args.width;
let out_h = offset.shape(2);
let out_w = offset.shape(3);
let kernel_h = args.kernel_height;
let kernel_w = args.kernel_width;
let n_offset_groups = args.offset_groups;
let batch_size = args.batch_size;
let out_x = ABSOLUTE_POS % out_w;
let out_y = (ABSOLUTE_POS / out_w) % out_h;
let batch = (ABSOLUTE_POS / (out_w * out_h)) % batch_size;
let kernel_x = (ABSOLUTE_POS / (out_w * out_h * batch_size)) % kernel_w;
let kernel_y = (ABSOLUTE_POS / (out_w * out_h * batch_size * kernel_w)) % kernel_h;
let in_channel = ABSOLUTE_POS / (out_w * out_h * batch_size * kernel_w * kernel_h);
let channels_per_offset_group = n_in_channels / n_offset_groups;
let offset_group = in_channel / channels_per_offset_group;
let offset_base_idx =
(batch * n_offset_groups + offset_group) * 2 * kernel_h * kernel_w * out_h * out_w;
let mask_idx = kernel_y * kernel_w + kernel_x;
let offset_idx = mask_idx * 2;
let offset_y_idx = (offset_idx * out_h + out_y) * out_w + out_x;
let offset_x_idx = ((offset_idx + 1) * out_h + out_y) * out_w + out_x;
let offset_y = f32::cast_from(offset[offset_base_idx + offset_y_idx]);
let offset_x = f32::cast_from(offset[offset_base_idx + offset_x_idx]);
let mask_value = if use_mask {
let mask_base_idx =
(batch * n_offset_groups + offset_group) * kernel_h * kernel_w * out_h * out_w;
mask[mask_base_idx + (mask_idx * out_h + out_y) * out_w + out_x]
} else {
F::new(1.0)
};
let y =
f32::cast_from(out_y * args.stride_h + kernel_y * args.dilation_h) - args.pad_h + offset_y;
let x =
f32::cast_from(out_x * args.stride_w + kernel_x * args.dilation_w) - args.pad_w + offset_x;
for dy in -1..=1 {
#[unroll]
for dx in -1..=1 {
let yp = f32::floor(y) + dy as f32;
let xp = f32::floor(x) + dx as f32;
if yp >= 0.0
&& yp < height as f32
&& xp >= 0.0
&& xp < width as f32
&& f32::abs(y - yp) < 1.0
&& f32::abs(x - xp) < 1.0
{
let gradient_pos =
((batch * n_in_channels + in_channel) * height + yp as u32) * width + xp as u32;
let weight = (1.0 - f32::abs(y - yp)) * (1.0 - f32::abs(x - xp));
let value = mask_value * F::cast_from(weight) * columns[ABSOLUTE_POS];
float_atomic_add::<F>(&mut grad_input[gradient_pos], value);
}
}
}
}
#[cube]
fn float_atomic_add<F: Float>(ptr: &mut AtomicU32, value: F) {
if value != F::new(0.0) {
let mut v = AtomicU32::load(ptr);
loop {
let prev = v;
let v_float = F::bitcast_from(v);
let new = u32::bitcast_from(v_float + value);
v = AtomicU32::compare_and_swap(ptr, v, new);
if prev == v {
break;
}
}
}
}
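float_atomic_add emulates a floating-point atomic add with a compare-and-swap loop over the value's bit pattern, since the gradient buffer is bound as a Tensor<AtomicU32> and native float atomics aren't generally available. The same pattern expressed with std atomics on the CPU (an illustrative sketch, not backend code):

use std::sync::atomic::{AtomicU32, Ordering};

// Atomically add `value` to the f32 stored as bits in `cell`.
fn float_atomic_add_cpu(cell: &AtomicU32, value: f32) {
    if value == 0.0 {
        return; // nothing to add, skip the CAS traffic
    }
    let mut current = cell.load(Ordering::Relaxed);
    loop {
        let new = (f32::from_bits(current) + value).to_bits();
        match cell.compare_exchange(current, new, Ordering::Relaxed, Ordering::Relaxed) {
            Ok(_) => break,
            // Another thread won the race; retry with the value it wrote.
            Err(actual) => current = actual,
        }
    }
}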


@@ -2,8 +2,12 @@ mod conv2d;
mod conv3d;
mod conv_transpose2d;
mod conv_transpose3d;
mod deform_conv2d;
mod deform_conv_transpose2d;
pub(crate) use conv2d::*;
pub(crate) use conv3d::*;
pub(crate) use conv_transpose2d::*;
pub(crate) use conv_transpose3d::*;
pub(crate) use deform_conv2d::*;
pub(crate) use deform_conv_transpose2d::*;


@@ -1,7 +1,7 @@
use crate::{kernel, FloatElement, IntElement, JitBackend, JitRuntime};
use burn_tensor::ops::{
ConvOptions, ConvTransposeOptions, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices,
ModuleOps,
ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateOptions,
MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
};
use burn_tensor::ops::{FloatTensor, IntTensor};
@@ -20,6 +20,29 @@
kernel::conv::conv2d(x, weight, bias, options)
}
fn deform_conv2d(
x: FloatTensor<Self, 4>,
offset: FloatTensor<Self, 4>,
weight: FloatTensor<Self, 4>,
mask: Option<FloatTensor<Self, 4>>,
bias: Option<FloatTensor<Self, 1>>,
options: DeformConvOptions<2>,
) -> FloatTensor<Self, 4> {
kernel::conv::deform_conv2d::<R, F, I>(x, offset, weight, mask, bias, options)
}
fn deform_conv2d_backward(
x: FloatTensor<Self, 4>,
offset: FloatTensor<Self, 4>,
weight: FloatTensor<Self, 4>,
mask: Option<FloatTensor<Self, 4>>,
bias: Option<FloatTensor<Self, 1>>,
output_grad: FloatTensor<Self, 4>,
options: DeformConvOptions<2>,
) -> DeformConv2dBackward<Self> {
kernel::conv::deform_conv2d_backward(x, offset, weight, mask, bias, output_grad, options)
}
fn conv3d(
x: FloatTensor<Self, 5>,
weight: FloatTensor<Self, 5>,


@@ -12,6 +12,7 @@ version.workspace = true
[features]
default = ["std"]
doc = ["default"]
std = [
"burn-autodiff",
"burn-common/std",
@@ -24,7 +25,6 @@ std = [
"rand/std",
"num-traits/std",
]
doc = ["default"]
blas-accelerate = [
"blas-src/accelerate", # Accelerate framework (macOS only)
@@ -46,10 +46,11 @@ burn-autodiff = { path = "../burn-autodiff", version = "0.15.0", optional = true
burn-common = { path = "../burn-common", version = "0.15.0", default-features = false }
burn-tensor = { path = "../burn-tensor", version = "0.15.0", default-features = false }
matrixmultiply = { workspace = true, default-features = false }
atomic_float = { workspace = true }
blas-src = { workspace = true, default-features = false, optional = true } # no-std compatible
derive-new = { workspace = true }
libm = { workspace = true }
matrixmultiply = { workspace = true, default-features = false }
ndarray = { workspace = true }
num-traits = { workspace = true }
openblas-src = { workspace = true, optional = true }


@@ -0,0 +1,656 @@
use burn_common::{iter_par, run_par};
use burn_tensor::ops::{conv::calculate_conv_output_size, DeformConvOptions};
use core::ops::AddAssign;
use ndarray::{
s, Array2, Array4, ArrayView2, ArrayView3, ArrayView4, ArrayView6, ArrayViewMut2, Axis, Dim,
Ix4,
};
#[cfg(not(feature = "std"))]
use num_traits::Float;
use crate::{element::QuantElement, FloatNdArrayElement, NdArrayTensor};
use super::matmul::matmul;
#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn deform_im2col_kernel<F: FloatNdArrayElement>(
out_y: usize,
out_x: usize,
input: ArrayView2<F>,
offset: ArrayView3<F>,
mask: Option<ArrayView2<F>>,
mut columns: ArrayViewMut2<F>,
args: DeformConvOptions<2>,
(kernel_h, kernel_w): (usize, usize),
) {
// position shape: [in_channels, batch_size, out_h, out_w]
// columns shape: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]]
let (height, width) = input.dim();
for kernel_y in 0..kernel_h {
for kernel_x in 0..kernel_w {
let mask_value = mask
.map(|it| it[[kernel_y, kernel_x]])
.unwrap_or_else(|| F::from_elem(1.0));
let offset = offset.slice(s![kernel_y, kernel_x, ..]);
let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0])
- F::from_elem(args.padding[0])
+ offset[0];
let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1])
- F::from_elem(args.padding[1])
+ offset[1];
let interpolated = bilinear_interpolate(input, height, width, y, x);
columns[[kernel_y, kernel_x]] = mask_value * interpolated;
}
}
}
fn bilinear_interpolate<F: FloatNdArrayElement>(
input: ArrayView2<F>,
height: usize,
width: usize,
y: F,
x: F,
) -> F {
// To simplify code
let y = y.to_f32();
let x = x.to_f32();
let mut result = F::from_elem(0.0);
if y > -1.0 && y < height as f32 && x > -1.0 && x < width as f32 {
let y_low = f32::floor(y);
let x_low = f32::floor(x);
let y_high = (y_low + 1.) as usize;
let x_high = (x_low + 1.) as usize;
let zero = F::from_elem(0.0);
let v1: F = if y_low >= 0. && x_low >= 0. {
input[[y_low as usize, x_low as usize]]
} else {
zero
};
let v2: F = if y_low >= 0. && x_high < width {
input[[y_low as usize, x_high]]
} else {
zero
};
let v3: F = if y_high < height && x_low >= 0. {
input[[y_high, x_low as usize]]
} else {
zero
};
let v4: F = if y_high < height && x_high < width {
input[[y_high, x_high]]
} else {
zero
};
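// Bilinear weights: l_* is the fractional distance to the low corner and
// h_* = 1 - l_*, so each corner value is weighted by the opposite area.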
let l_y = y - y_low;
let l_x = x - x_low;
let h_y = 1.0 - l_y;
let h_x = 1.0 - l_x;
let w1 = F::from_elem(h_y * h_x);
let w2 = F::from_elem(h_y * l_x);
let w3 = F::from_elem(l_y * h_x);
let w4 = F::from_elem(l_y * l_x);
result = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
}
result
}
pub(crate) fn deform_conv2d<F: FloatNdArrayElement>(
input: NdArrayTensor<F, 4>,
offset: NdArrayTensor<F, 4>,
weight: NdArrayTensor<F, 4>,
mask: Option<NdArrayTensor<F, 4>>,
bias: Option<NdArrayTensor<F, 1>>,
args: DeformConvOptions<2>,
) -> NdArrayTensor<F, 4> {
let [batch_size, _, in_height, in_width] = input.shape().dims;
let [out_channels, _, kernel_h, kernel_w] = weight.shape().dims;
let groups = args.weight_groups;
let weight = weight.array.as_standard_layout();
let out_h = calculate_conv_output_size(
kernel_h,
args.stride[0],
args.padding[0],
args.dilation[0],
in_height,
);
let out_w = calculate_conv_output_size(
kernel_w,
args.stride[1],
args.padding[1],
args.dilation[1],
in_width,
);
let out_dims = (out_h, out_w);
let input = input.array.into_dimensionality::<Ix4>().unwrap();
let offset = offset.array.into_dimensionality::<Ix4>().unwrap();
let mask = mask.as_ref().map(|it| {
it.array
.to_shape((
batch_size,
args.offset_groups,
kernel_h,
kernel_w,
out_h,
out_w,
))
.unwrap()
});
let columns = deform_im2col(
input.view(),
offset.view(),
mask.as_ref().map(|it| it.view()),
args,
out_dims,
(kernel_h, kernel_w),
);
let (col_size_0, col_size_1) = columns.dim();
let col_size_0 = col_size_0 / groups;
let out_c_per_group = out_channels / groups;
let weight = weight
.to_shape((groups, out_c_per_group, col_size_0))
.unwrap();
let columns = columns.to_shape((groups, col_size_0, col_size_1)).unwrap();
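// The deformable convolution reduces to one GEMM per weight group:
// [out_c_per_group, col_size_0] x [col_size_0, batch_size * out_h * out_w].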
let out = matmul(
NdArrayTensor::<_, 3>::new(weight.to_owned().into_dyn().into_shared()),
NdArrayTensor::<_, 3>::new(columns.to_owned().into_dyn().into_shared()),
);
let mut out = out
.array
.into_shape_with_order((out_channels, batch_size, out_h, out_w))
.unwrap();
out.swap_axes(0, 1);
if let Some(bias) = bias {
let bias = bias.array.to_shape((1, out_channels, 1, 1)).unwrap();
out.add_assign(&bias);
}
NdArrayTensor::new(out.into_dyn().into_shared())
}
pub(crate) fn deform_im2col<F: FloatNdArrayElement>(
input: ArrayView4<F>,
offset: ArrayView4<F>,
mask: Option<ArrayView6<F>>,
args: DeformConvOptions<2>,
out_dims: (usize, usize),
kernel_dims: (usize, usize),
) -> Array2<F> {
let (batch_size, in_channels, _, _) = input.dim();
let (kernel_h, kernel_w) = kernel_dims;
let (out_h, out_w) = out_dims;
let channels_per_offset_group = in_channels / args.offset_groups;
let mut columns = Array4::zeros(Dim([
in_channels,
kernel_h,
kernel_w,
batch_size * out_h * out_w,
]));
let groups = args.offset_groups;
run_par!(|| {
iter_par!(columns.axis_iter_mut(Axis(3)))
.enumerate()
.for_each(|(index, mut columns)| {
let out_x = index % out_w;
let out_y = (index / out_w) % out_h;
let batch = (index / (out_w * out_h)) % batch_size;
let offset = offset.slice(s![batch, .., out_y, out_x]);
let offset = offset.to_shape((groups, kernel_h, kernel_w, 2)).unwrap();
let mask = mask
.as_ref()
.map(|it| it.slice(s![batch, .., .., .., out_y, out_x]));
columns
.axis_iter_mut(Axis(0))
.enumerate()
.for_each(|(in_channel, mut columns)| {
let group_index = in_channel / channels_per_offset_group;
deform_im2col_kernel(
out_y,
out_x,
input.slice(s![batch, in_channel, .., ..]),
offset.slice(s![group_index, .., .., ..]),
mask.as_ref().map(|it| it.slice(s![group_index, .., ..])),
columns.view_mut(),
args.clone(),
kernel_dims,
);
});
});
});
columns
// `columns` is allocated above, so it is known to be contiguous
.into_shape_with_order((
in_channels * kernel_h * kernel_w,
batch_size * out_h * out_w,
))
.unwrap()
}
pub mod backward {
#[cfg(target_has_atomic = "32")]
use core::sync::atomic::Ordering;
use crate::NdArray;
use atomic_float::AtomicF32;
use burn_tensor::ops::DeformConv2dBackward;
use ndarray::{Array1, Array5, ArrayView4, ArrayView6, Ix4};
use super::*;
/// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass, using im2col and grouped matrix multiplication.
pub(crate) fn deform_conv2d_backward<F: FloatNdArrayElement, Q: QuantElement>(
input: NdArrayTensor<F, 4>,
offset: NdArrayTensor<F, 4>,
weight: NdArrayTensor<F, 4>,
mask: Option<NdArrayTensor<F, 4>>,
bias: Option<NdArrayTensor<F, 1>>,
out_grad: NdArrayTensor<F, 4>,
args: DeformConvOptions<2>,
) -> DeformConv2dBackward<NdArray<F, Q>> {
let [batch_size, out_channels, out_h, out_w] = out_grad.shape().dims;
let [_, _, kernel_h, kernel_w] = weight.shape().dims;
let groups = args.weight_groups;
let out_c_per_group = out_channels / groups;
let col_shape_1 = batch_size * out_h * out_w;
let mut out_grad = out_grad.array.into_dimensionality::<Ix4>().unwrap();
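// The bias gradient is the output gradient summed over the batch and both
// spatial dimensions.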
let gradient_bias = bias.map(|_| {
let out_grad = out_grad
.clone()
.sum_axis(Axis(0))
.sum_axis(Axis(1))
.sum_axis(Axis(1));
NdArrayTensor::new(out_grad.into_dyn().into_shared())
});
out_grad.swap_axes(0, 1);
let out_grad = out_grad
.to_shape((groups, out_c_per_group, col_shape_1))
.unwrap();
let input = input.array.into_dimensionality::<Ix4>().unwrap();
let offset = offset.array.into_dimensionality::<Ix4>().unwrap();
let mask = mask.map(|it| {
it.array
.into_shape_with_order((
batch_size,
args.offset_groups,
kernel_h,
kernel_w,
out_h,
out_w,
))
.unwrap()
});
let (input_gradient, offset_gradient, mask_gradient) = backward_gradient_inputs(
input.view(),
weight,
offset.view(),
mask.as_ref().map(|it| it.view()),
out_grad.view(),
&args,
(kernel_h, kernel_w),
);
let weight_grad = compute_weight_grad(
input.view(),
offset.view(),
mask.as_ref().map(|it| it.view()),
out_grad.view(),
args,
(kernel_h, kernel_w),
(out_h, out_w),
);
DeformConv2dBackward::new(
input_gradient,
offset_gradient,
weight_grad,
mask_gradient,
gradient_bias,
)
}
fn compute_weight_grad<F: FloatNdArrayElement>(
input: ArrayView4<F>,
offset: ArrayView4<F>,
mask: Option<ArrayView6<F>>,
out_grad: ArrayView3<F>,
options: DeformConvOptions<2>,
kernel_dims: (usize, usize),
out_dims: (usize, usize),
) -> NdArrayTensor<F, 4> {
let in_channels = input.dim().1;
let (groups, out_c_per_group, _) = out_grad.dim();
let (kernel_h, kernel_w) = kernel_dims;
let in_c_per_group = in_channels / groups;
let columns = deform_im2col(input, offset, mask, options, out_dims, kernel_dims);
let (col_size_0, col_size_1) = columns.dim();
let col_size_0 = col_size_0 / groups;
let mut columns = columns.to_shape((groups, col_size_0, col_size_1)).unwrap();
columns.swap_axes(1, 2);
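// Per weight group: grad_weight = out_grad x columns^T.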
let grad_weight = matmul(
NdArrayTensor::<_, 3>::new(out_grad.to_owned().into_dyn().into_shared()),
NdArrayTensor::<_, 3>::new(columns.to_owned().into_dyn().into_shared()),
);
let grad_weight = grad_weight
.array
.into_shape_with_order((out_c_per_group * groups, in_c_per_group, kernel_h, kernel_w))
.unwrap();
NdArrayTensor::new(grad_weight.into_dyn().into_shared())
}
type InputGradients<F> = (
NdArrayTensor<F, 4>,
NdArrayTensor<F, 4>,
Option<NdArrayTensor<F, 4>>,
);
fn backward_gradient_inputs<F: FloatNdArrayElement>(
image: ArrayView4<F>,
weight: NdArrayTensor<F, 4>,
offset: ArrayView4<F>,
mask: Option<ArrayView6<F>>,
out_grad: ArrayView3<F>,
args: &DeformConvOptions<2>,
kernel_dims: (usize, usize),
) -> InputGradients<F> {
let input_shape = image.dim();
let in_channels = input_shape.1;
let [out_channels, in_c_per_group, kernel_h, kernel_w] = weight.shape().dims;
let (batch_size, _, out_h, out_w) = offset.dim();
let groups = args.weight_groups;
let out_c_per_group = out_channels / groups;
let col_shape_0 = in_c_per_group * kernel_h * kernel_w;
let mut weight = weight
.array
.to_shape((groups, out_c_per_group, col_shape_0))
.unwrap();
weight.swap_axes(1, 2);
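// Propagate the output gradient back through the GEMM: columns = weight^T x out_grad.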
let columns = matmul(
NdArrayTensor::<_, 3>::new(weight.to_owned().into_dyn().into_shared()),
NdArrayTensor::<_, 3>::new(out_grad.to_owned().into_dyn().into_shared()),
);
let columns = columns
.array
.to_shape((in_channels, kernel_h, kernel_w, batch_size, out_h, out_w))
.unwrap();
let (offset_gradient, mask_gradient) = compute_offset_and_mask_gradient(
columns.view(),
image.view(),
offset,
mask,
args,
kernel_dims,
);
let input_gradient =
compute_input_grad(columns.view(), offset, mask, args, kernel_dims, input_shape);
(input_gradient, offset_gradient, mask_gradient)
}
fn compute_offset_and_mask_gradient<F: FloatNdArrayElement>(
columns: ArrayView6<F>,
image: ArrayView4<F>,
offset: ArrayView4<F>,
mask: Option<ArrayView6<F>>,
args: &DeformConvOptions<2>,
kernel_dims: (usize, usize),
) -> (NdArrayTensor<F, 4>, Option<NdArrayTensor<F, 4>>) {
let (kernel_h, kernel_w) = kernel_dims;
let (_, in_channels, height, width) = image.dim();
let (batch_size, offset_channels, out_h, out_w) = offset.dim();
let offs_groups = args.offset_groups;
let channels_per_offset_group = in_channels / args.offset_groups;
let mut grad_offset = Array5::zeros((
offs_groups,
kernel_h,
kernel_w,
2,
batch_size * out_h * out_w,
));
let mut grad_mask =
Array4::zeros((offs_groups, kernel_h, kernel_w, batch_size * out_h * out_w));
grad_mask
.axis_iter_mut(Axis(3))
.zip(grad_offset.axis_iter_mut(Axis(4)))
.enumerate()
.for_each(|(index, (mut grad_mask, mut grad_offset))| {
let out_x = index % out_w;
let out_y = (index / out_w) % out_h;
let batch = index / (out_w * out_h);
let offset = offset.slice(s![batch, .., out_y, out_x]);
let offset = offset
.to_shape((offs_groups, kernel_h, kernel_w, 2))
.unwrap();
let mask: Option<ArrayView3<F>> = mask
.as_ref()
.map(|mask| mask.slice(s![batch, .., .., .., out_y, out_x]));
let columns = columns.slice(s![.., .., .., batch, out_y, out_x]);
let image = image.slice(s![batch, .., .., ..]);
for ((group, kernel_y, kernel_x), grad_mask) in grad_mask.indexed_iter_mut() {
let grad_mask: &mut F = grad_mask;
let mut grad_offset = grad_offset.slice_mut(s![group, kernel_y, kernel_x, ..]);
let offset = offset.slice(s![group, kernel_y, kernel_x, ..]);
let mask = mask.map(|it| it[[group, kernel_y, kernel_x]]);
let columns = columns.slice(s![.., kernel_y, kernel_x]);
let group_offset = group * channels_per_offset_group;
let image = image.slice(s![group_offset.., .., ..]);
let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0])
- F::from_elem(args.padding[0])
+ offset[0];
let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1])
- F::from_elem(args.padding[1])
+ offset[1];
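// Offsets are interleaved (y, x) pairs, so even indices differentiate with
// respect to y. The mask gradient only needs to be accumulated once per
// kernel element, so it is folded into the y-direction iteration.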
for (i, grad_offset) in grad_offset.iter_mut().enumerate() {
let is_y_direction = i % 2 == 0;
let use_mask = mask.is_some();
for channel in 0..channels_per_offset_group {
let mask = mask.unwrap_or_else(|| F::one());
let image = image.index_axis(Axis(0), channel);
let weight =
get_coordinate_weight(image, height, width, y, x, is_y_direction);
*grad_offset += mask * weight * columns[channel];
if use_mask && is_y_direction {
*grad_mask += columns[channel]
* bilinear_interpolate(image, height, width, y, x);
}
}
}
}
});
let mask_gradient = mask.map(|_| {
let mut grad_mask = grad_mask
.into_shape_with_order((offset_channels / 2, batch_size, out_h, out_w))
.unwrap();
grad_mask.swap_axes(0, 1);
NdArrayTensor::new(grad_mask.into_dyn().into_shared())
});
let mut grad_offset = grad_offset
.into_shape_with_order((offset_channels, batch_size, out_h, out_w))
.unwrap();
grad_offset.swap_axes(0, 1);
let offset_gradient = NdArrayTensor::new(grad_offset.into_dyn().into_shared());
(offset_gradient, mask_gradient)
}
fn get_coordinate_weight<F: FloatNdArrayElement>(
input: ArrayView2<F>,
height: usize,
width: usize,
y: F,
x: F,
is_y_direction: bool,
) -> F {
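// Analytic derivative of the bilinear interpolation with respect to the
// sampling coordinate (y when `is_y_direction`, x otherwise).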
let y = y.to_f32();
let x = x.to_f32();
let y_low = f32::floor(y);
let x_low = f32::floor(x);
let y_high = y_low + 1.;
let x_high = x_low + 1.;
let valid_y_low = y_low >= 0. && y_low < height as f32;
let valid_y_high = y_high >= 0. && y_high < height as f32;
let valid_x_low = x_low >= 0. && x_low < width as f32;
let valid_x_high = x_high >= 0. && x_high < width as f32;
let bottom_left = if valid_y_low && valid_x_low {
input[[y_low as usize, x_low as usize]]
} else {
F::zero()
};
let bottom_right = if valid_y_low && valid_x_high {
input[[y_low as usize, x_high as usize]]
} else {
F::zero()
};
let top_left = if valid_y_high && valid_x_low {
input[[y_high as usize, x_low as usize]]
} else {
F::zero()
};
let top_right = if valid_y_high && valid_x_high {
input[[y_high as usize, x_high as usize]]
} else {
F::zero()
};
if is_y_direction {
let delta_x = F::from_elem(x - x_low);
delta_x * (top_right - bottom_right) + (F::one() - delta_x) * (top_left - bottom_left)
} else {
let delta_y = F::from_elem(y - y_low);
delta_y * (top_right - top_left) + (F::one() - delta_y) * (bottom_right - bottom_left)
}
}
fn compute_input_grad<F: FloatNdArrayElement>(
columns: ArrayView6<F>,
offset: ArrayView4<F>,
mask: Option<ArrayView6<F>>,
args: &DeformConvOptions<2>,
kernel_dims: (usize, usize),
input_shape: (usize, usize, usize, usize),
) -> NdArrayTensor<F, 4> {
let (batch_size, in_channels, height, width) = input_shape;
let (kernel_h, kernel_w) = kernel_dims;
let offs_groups = args.offset_groups;
let channels_per_offset_group = in_channels / offs_groups;
let grad_in =
Array4::from_shape_simple_fn((batch_size, in_channels, height, width), || {
AtomicF32::new(0.0)
});
run_par!(|| {
iter_par!(columns.indexed_iter()).for_each(
|((in_channel, kernel_y, kernel_x, batch, out_y, out_x), col)| {
let group = in_channel / channels_per_offset_group;
let offset = offset.slice(s![batch, .., out_y, out_x]);
let offset = offset
.to_shape((offs_groups, kernel_h, kernel_w, 2))
.unwrap();
let offset = offset.slice(s![group, kernel_y, kernel_x, ..]);
let offset = [offset[0], offset[1]];
let mask = mask
.as_ref()
.map(|it| it[[batch, group, kernel_y, kernel_x, out_y, out_x]].to_f32());
let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0])
- F::from_elem(args.padding[0])
+ offset[0];
let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1])
- F::from_elem(args.padding[1])
+ offset[1];
let grad_in = grad_in.slice(s![batch, in_channel, .., ..]);
deform_col2img_kernel(y.to_f32(), x.to_f32(), mask, col.to_f32(), grad_in);
},
)
});
let grad_in: Array1<F> = grad_in
.into_iter()
.map(|it| F::from_elem(it.into_inner()))
.collect();
let grad_in = grad_in
.into_shape_with_order((batch_size, in_channels, height, width))
.unwrap();
NdArrayTensor::new(grad_in.into_dyn().into_shared())
}
fn deform_col2img_kernel(
y: f32,
x: f32,
mask: Option<f32>,
col: f32,
grad_input: ArrayView2<AtomicF32>,
) {
let (height, width) = grad_input.dim();
let mask_value = mask.unwrap_or(1.0);
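// Scatter the column gradient to the (up to) four integer pixels that
// contributed to the bilinear sample; atomic adds keep the parallel
// accumulation race-free.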
for dy in -1..=1 {
for dx in -1..=1 {
let yp = f32::floor(y) + dy as f32;
let xp = f32::floor(x) + dx as f32;
if yp >= 0.0
&& yp < height as f32
&& xp >= 0.0
&& xp < width as f32
&& f32::abs(y - yp) < 1.0
&& f32::abs(x - xp) < 1.0
{
let weight = (1.0 - f32::abs(y - yp)) * (1.0 - f32::abs(x - xp));
#[cfg_attr(not(target_has_atomic = "32"), allow(unused))]
let value = mask_value * weight * col;
#[cfg(target_has_atomic = "32")]
grad_input[[yp as usize, xp as usize]].fetch_add(value, Ordering::AcqRel);
#[cfg(not(target_has_atomic = "32"))]
panic!("Can't use deformable convolution backwards pass without atomics");
}
}
}
}
}

View File

@ -9,6 +9,7 @@ mod tensor;
pub(crate) mod adaptive_avgpool;
pub(crate) mod avgpool;
pub(crate) mod conv;
pub(crate) mod deform_conv;
pub(crate) mod interpolate;
pub(crate) mod macros;
pub(crate) mod matmul;

View File

@ -2,6 +2,7 @@ use super::{
adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward},
avgpool::{avg_pool2d, avg_pool2d_backward},
conv::{conv2d, conv3d, conv_transpose2d, conv_transpose3d},
deform_conv::{backward::deform_conv2d_backward, deform_conv2d},
interpolate::{bicubic_interpolate, bilinear_interpolate, nearest_interpolate},
maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
};
@ -19,6 +20,29 @@ impl<E: FloatNdArrayElement, Q: QuantElement> ModuleOps<Self> for NdArray<E, Q>
conv2d::<E, Q>(x, weight, bias, options)
}
fn deform_conv2d(
x: NdArrayTensor<E, 4>,
offset: NdArrayTensor<E, 4>,
weight: NdArrayTensor<E, 4>,
mask: Option<NdArrayTensor<E, 4>>,
bias: Option<NdArrayTensor<E, 1>>,
options: DeformConvOptions<2>,
) -> NdArrayTensor<E, 4> {
deform_conv2d::<E>(x, offset, weight, mask, bias, options)
}
fn deform_conv2d_backward(
x: NdArrayTensor<E, 4>,
offset: NdArrayTensor<E, 4>,
weight: NdArrayTensor<E, 4>,
mask: Option<NdArrayTensor<E, 4>>,
bias: Option<NdArrayTensor<E, 1>>,
output_grad: NdArrayTensor<E, 4>,
options: DeformConvOptions<2>,
) -> DeformConv2dBackward<Self> {
deform_conv2d_backward(x, offset, weight, mask, bias, output_grad, options)
}
fn conv_transpose2d(
x: NdArrayTensor<E, 4>,
weight: NdArrayTensor<E, 4>,

View File

@ -1,7 +1,7 @@
use crate::{element::TchElement, LibTorch, QuantElement, TchTensor};
use burn_tensor::ops::{
ConvOptions, ConvTransposeOptions, InterpolateMode, InterpolateOptions, MaxPool1dWithIndices,
MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateMode,
InterpolateOptions, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
};
impl<E: TchElement, Q: QuantElement> ModuleOps<Self> for LibTorch<E, Q> {
@ -86,6 +86,29 @@ impl<E: TchElement, Q: QuantElement> ModuleOps<Self> for LibTorch<E, Q> {
TchTensor::new(tensor)
}
fn deform_conv2d(
_x: TchTensor<E, 4>,
_offset: TchTensor<E, 4>,
_weight: TchTensor<E, 4>,
_mask: Option<TchTensor<E, 4>>,
_bias: Option<TchTensor<E, 1>>,
_options: DeformConvOptions<2>,
) -> TchTensor<E, 4> {
unimplemented!("Torch bindings don't support deform_conv2d");
}
fn deform_conv2d_backward(
_x: TchTensor<E, 4>,
_offset: TchTensor<E, 4>,
_weight: TchTensor<E, 4>,
_mask: Option<TchTensor<E, 4>>,
_bias: Option<TchTensor<E, 1>>,
_out_grad: TchTensor<E, 4>,
_options: DeformConvOptions<2>,
) -> DeformConv2dBackward<Self> {
unimplemented!("Torch bindings don't support deform_conv2d");
}
fn conv_transpose1d(
x: TchTensor<E, 3>,
weight: TchTensor<E, 3>,

View File

@ -2,7 +2,9 @@ use serde::{Deserialize, Serialize};
use std::ops::Range;
use crate::{
ops::{ConvOptions, ConvTransposeOptions, InterpolateMode, InterpolateOptions},
ops::{
ConvOptions, ConvTransposeOptions, DeformConvOptions, InterpolateMode, InterpolateOptions,
},
repr::tensor::TensorDescription,
DType, Distribution, Element,
};
@ -74,6 +76,10 @@ pub enum ModuleOperationDescription {
Conv2d(Conv2dDescription),
/// Operation corresponding to [conv3d](crate::ops::ModuleOps::conv3d).
Conv3d(Conv3dDescription),
/// Operation corresponding to [deform_conv2d](crate::ops::ModuleOps::deform_conv2d).
DeformableConv2d(Box<DeformConv2dDescription>),
/// Operation corresponding to [deform_conv2d_backward](crate::ops::ModuleOps::deform_conv2d_backward).
DeformableConv2dBackward(Box<DeformConv2dBackwardDescription>),
/// Operation corresponding to [conv transpose 1d](crate::ops::ModuleOps::conv_transpose1d).
ConvTranspose1d(ConvTranspose1dDescription),
/// Operation corresponding to [conv transpose 2d](crate::ops::ModuleOps::conv_transpose2d).
@ -688,6 +694,35 @@ pub struct Conv2dDescription {
pub out: TensorDescription,
}
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformConv2dDescription {
pub x: TensorDescription,
pub offset: TensorDescription,
pub weight: TensorDescription,
pub mask: Option<TensorDescription>,
pub bias: Option<TensorDescription>,
pub options: DeformableConv2dOptionsDescription,
pub out: TensorDescription,
}
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformConv2dBackwardDescription {
pub x: TensorDescription,
pub offset: TensorDescription,
pub weight: TensorDescription,
pub mask: Option<TensorDescription>,
pub bias: Option<TensorDescription>,
pub out_grad: TensorDescription,
pub options: DeformableConv2dOptionsDescription,
pub input_grad: TensorDescription,
pub offset_grad: TensorDescription,
pub weight_grad: TensorDescription,
pub mask_grad: Option<TensorDescription>,
pub bias_grad: Option<TensorDescription>,
}
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dDescription {
@ -746,6 +781,16 @@ pub struct Conv2dOptionsDescription {
pub groups: usize,
}
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct DeformableConv2dOptionsDescription {
pub stride: [usize; 2],
pub padding: [usize; 2],
pub dilation: [usize; 2],
pub weight_groups: usize,
pub offset_groups: usize,
}
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct Conv3dOptionsDescription {
@ -818,6 +863,18 @@ impl From<ConvOptions<3>> for Conv3dOptionsDescription {
}
}
impl From<DeformConvOptions<2>> for DeformableConv2dOptionsDescription {
fn from(value: DeformConvOptions<2>) -> Self {
Self {
stride: value.stride,
padding: value.padding,
dilation: value.dilation,
weight_groups: value.weight_groups,
offset_groups: value.offset_groups,
}
}
}
impl From<ConvTransposeOptions<1>> for ConvTranspose1dOptionsDescription {
fn from(value: ConvTransposeOptions<1>) -> Self {
Self {
@ -887,6 +944,18 @@ impl From<Conv3dOptionsDescription> for ConvOptions<3> {
}
}
impl From<DeformableConv2dOptionsDescription> for DeformConvOptions<2> {
fn from(value: DeformableConv2dOptionsDescription) -> Self {
DeformConvOptions {
stride: value.stride,
padding: value.padding,
dilation: value.dilation,
weight_groups: value.weight_groups,
offset_groups: value.offset_groups,
}
}
}
impl From<ConvTranspose1dOptionsDescription> for ConvTransposeOptions<1> {
fn from(val: ConvTranspose1dOptionsDescription) -> Self {
ConvTransposeOptions {
@ -1404,6 +1473,22 @@ impl ModuleOperationDescription {
vec![&desc.x, &desc.weight, &desc.out]
}
}
ModuleOperationDescription::DeformableConv2d(desc) => match (&desc.mask, &desc.bias) {
(Some(mask), Some(bias)) => vec![&desc.x, &desc.offset, &desc.weight, &mask, &bias],
(Some(mask), None) => vec![&desc.x, &desc.offset, &desc.weight, &mask],
(None, Some(bias)) => vec![&desc.x, &desc.offset, &desc.weight, &bias],
(None, None) => vec![&desc.x, &desc.offset, &desc.weight],
},
ModuleOperationDescription::DeformableConv2dBackward(desc) => {
match (&desc.mask, &desc.bias) {
(Some(mask), Some(bias)) => {
vec![&desc.x, &desc.offset, &desc.weight, &mask, &bias]
}
(Some(mask), None) => vec![&desc.x, &desc.offset, &desc.weight, &mask],
(None, Some(bias)) => vec![&desc.x, &desc.offset, &desc.weight, &bias],
(None, None) => vec![&desc.x, &desc.offset, &desc.weight],
}
}
ModuleOperationDescription::ConvTranspose1d(desc) => {
if let Some(bias) = &desc.bias {
vec![&desc.x, &desc.weight, &bias, &desc.out]

View File

@ -4,6 +4,8 @@ use crate::{
Int, Tensor, TensorPrimitive,
};
use super::ops::DeformConvOptions;
/// Applies the [embedding module](crate::ops::ModuleOps::embedding).
pub fn embedding<B>(weights: Tensor<B, 2>, indices: Tensor<B, 2, Int>) -> Tensor<B, 3>
where
@ -69,6 +71,28 @@ where
)))
}
/// Applies a [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d).
pub fn deform_conv2d<B>(
x: Tensor<B, 4>,
offset: Tensor<B, 4>,
weight: Tensor<B, 4>,
mask: Option<Tensor<B, 4>>,
bias: Option<Tensor<B, 1>>,
options: DeformConvOptions<2>,
) -> Tensor<B, 4>
where
B: Backend,
{
Tensor::new(TensorPrimitive::Float(B::deform_conv2d(
x.primitive.tensor(),
offset.primitive.tensor(),
weight.primitive.tensor(),
mask.map(|m| m.primitive.tensor()),
bias.map(|b| b.primitive.tensor()),
options,
)))
}
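// Illustrative usage of the functional API above (a sketch; the shapes and
// values are assumptions, chosen for a 3x3 kernel with one offset group on a
// 4x4 input, which yields a 2x2 output and 2 * 1 * 3 * 3 = 18 offset channels):
//
// let device = Default::default();
// let x = Tensor::<B, 4>::ones([1, 2, 4, 4], &device);
// let offset = Tensor::<B, 4>::zeros([1, 18, 2, 2], &device);
// let weight = Tensor::<B, 4>::ones([3, 2, 3, 3], &device);
// let output = deform_conv2d(
//     x,
//     offset,
//     weight,
//     None, // mask
//     None, // bias
//     DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 1),
// );
// assert_eq!(output.dims(), [1, 3, 2, 2]);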
/// Applies a [1D transposed convolution](crate::ops::ModuleOps::conv_transpose1d).
pub fn conv_transpose1d<B>(
x: Tensor<B, 3>,

View File

@ -5,6 +5,51 @@ use crate::{
Shape,
};
/// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).
#[derive(new)]
pub struct Conv2dBackward<B: Backend> {
/// Gradient.
pub x_grad: FloatTensor<B, 4>,
/// Weights gradient.
pub weights_grad: FloatTensor<B, 4>,
/// Bias gradient.
pub bias_grad: Option<FloatTensor<B, 1>>,
}
/// Gradient computed during the backward pass for each tensor used by [deform_conv2d](ModuleOps::deform_conv2d).
#[derive(new)]
pub struct DeformConv2dBackward<B: Backend> {
/// Gradient.
pub x_grad: FloatTensor<B, 4>,
/// Offset gradient.
pub offset_grad: FloatTensor<B, 4>,
/// Weights gradient.
pub weight_grad: FloatTensor<B, 4>,
/// Mask gradient.
pub mask_grad: Option<FloatTensor<B, 4>>,
/// Bias gradient.
pub bias_grad: Option<FloatTensor<B, 1>>,
}
/// Gradient computed during the backward pass for each tensor used by [conv3d](ModuleOps::conv3d).
#[derive(new)]
pub struct Conv3dBackward<B: Backend> {
/// Gradient.
pub x_grad: FloatTensor<B, 5>,
/// Weights gradient.
pub weights_grad: FloatTensor<B, 5>,
/// Bias gradient.
pub bias_grad: Option<FloatTensor<B, 1>>,
}
/// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d).
#[derive(new)]
pub struct MaxPool1dBackward<B: Backend> {
@ -55,6 +100,25 @@ pub struct ConvOptions<const N: usize> {
pub groups: usize,
}
/// Deformable convolution options.
#[derive(new, Debug, Clone, Hash, PartialEq, Eq)]
pub struct DeformConvOptions<const N: usize> {
/// Stride.
pub stride: [usize; N],
/// Padding.
pub padding: [usize; N],
/// Dilation.
pub dilation: [usize; N],
/// Weight groups.
pub weight_groups: usize,
/// Offset groups.
pub offset_groups: usize,
}
/// Transposed convolution options.
#[derive(new, Debug, Clone, Hash, PartialEq, Eq)]
pub struct ConvTransposeOptions<const N: usize> {
@ -248,6 +312,33 @@ pub trait ModuleOps<B: Backend> {
) -> FloatTensor<B, 1> {
conv::conv2d_bias_backward::<B>(x, weight, bias, output_grad)
}
/// Two dimensional deformable convolution.
///
/// # Shapes
///
/// x: `[batch_size, channels_in, height, width]`,
/// offset: `[batch_size, 2 * offset_groups * kernel_size_1 * kernel_size_2, out_height, out_width]`,
/// weight: `[channels_out, channels_in / weight_groups, kernel_size_1, kernel_size_2]`,
/// mask: `[batch_size, offset_groups * kernel_size_1 * kernel_size_2, out_height, out_width]`,
/// bias: `[channels_out]`,
fn deform_conv2d(
x: FloatTensor<B, 4>,
offset: FloatTensor<B, 4>,
weight: FloatTensor<B, 4>,
mask: Option<FloatTensor<B, 4>>,
bias: Option<FloatTensor<B, 1>>,
options: DeformConvOptions<2>,
) -> FloatTensor<B, 4>;
/// Backward pass for the [deform_conv2d](ModuleOps::deform_conv2d) operation.
fn deform_conv2d_backward(
x: FloatTensor<B, 4>,
offset: FloatTensor<B, 4>,
weight: FloatTensor<B, 4>,
mask: Option<FloatTensor<B, 4>>,
bias: Option<FloatTensor<B, 1>>,
output_grad: FloatTensor<B, 4>,
options: DeformConvOptions<2>,
) -> DeformConv2dBackward<B>;
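// Shape sanity check (illustrative, assuming a 3x3 kernel and one offset
// group): offset has 2 * 1 * 3 * 3 = 18 channels and mask has 1 * 3 * 3 = 9.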
/// Three dimensional convolution.
///
/// # Shapes

View File

@ -26,6 +26,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_module_conv1d!();
burn_tensor::testgen_module_conv2d!();
burn_tensor::testgen_module_conv3d!();
burn_tensor::testgen_module_deform_conv2d!();
burn_tensor::testgen_module_conv_transpose1d!();
burn_tensor::testgen_module_conv_transpose2d!();
burn_tensor::testgen_module_conv_transpose3d!();

View File

@ -0,0 +1,439 @@
#[burn_tensor_testgen::testgen(module_deform_conv2d)]
mod tests {
use super::*;
use burn_tensor::module::deform_conv2d;
use burn_tensor::ops::{DeformConv2dBackward, DeformConvOptions, ModuleOps};
use burn_tensor::{Shape, Tensor};
#[test]
fn test_deform_conv2d_simple() {
let test = DeformConv2dTestCase {
batch_size: 1,
channels_in: 3,
channels_out: 5,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 0,
padding_2: 0,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
weight_groups: 1,
offset_groups: 1,
height: 4,
width: 4,
};
test.assert_output(Tensor::<TestBackend, 4>::from([[
[[0.9074, 0.6387], [0.5160, 0.4196]],
[[2.4259, 1.8008], [1.5449, 1.3112]],
[[3.9444, 2.9629], [2.5738, 2.2027]],
[[5.4629, 4.1250], [3.6027, 3.0943]],
[[6.9814, 5.2871], [4.6316, 3.9859]],
]]));
}
#[test]
fn test_deform_conv2d_batched() {
let test = DeformConv2dTestCase {
batch_size: 2,
channels_in: 3,
channels_out: 5,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 0,
padding_2: 0,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
weight_groups: 1,
offset_groups: 1,
height: 4,
width: 4,
};
test.assert_output(Tensor::<TestBackend, 4>::from([
[
[[0.2155, 0.1928], [0.1934, 0.1755]],
[[0.7251, 0.6759], [0.6877, 0.6485]],
[[1.2347, 1.1590], [1.1821, 1.1215]],
[[1.7443, 1.6421], [1.6764, 1.5945]],
[[2.2539, 2.1252], [2.1708, 2.0675]],
],
[
[[1.6530, 1.1369], [0.9840, 0.7184]],
[[4.8368, 3.4725], [3.1773, 2.4180]],
[[8.0206, 5.8080], [5.3705, 4.1176]],
[[11.2045, 8.1435], [7.5637, 5.8173]],
[[14.3883, 10.4790], [9.7570, 7.5169]],
],
]))
}
#[test]
fn test_deform_conv2d_weight_groups() {
let test = DeformConv2dTestCase {
batch_size: 1,
channels_in: 3,
channels_out: 6,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 0,
padding_2: 0,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
weight_groups: 3,
offset_groups: 1,
height: 4,
width: 4,
};
test.assert_output(Tensor::<TestBackend, 4>::from([[
[[0.1018, 0.0658], [0.0467, 0.0362]],
[[0.4125, 0.3367], [0.3069, 0.2824]],
[[1.3076, 1.0242], [0.9025, 0.8000]],
[[1.8405, 1.4581], [1.2994, 1.1588]],
[[3.4022, 2.6346], [2.3052, 2.0143]],
[[4.1574, 3.2315], [2.8389, 2.4857]],
]]))
}
#[test]
fn test_deform_conv2d_offset_groups() {
let test = DeformConv2dTestCase {
batch_size: 1,
channels_in: 3,
channels_out: 6,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 0,
padding_2: 0,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
weight_groups: 1,
offset_groups: 3,
height: 4,
width: 4,
};
test.assert_output(Tensor::<TestBackend, 4>::from([[
[[1.0794, 0.7676], [0.7209, 0.5337]],
[[2.7059, 2.0216], [1.9740, 1.5419]],
[[4.3325, 3.2755], [3.2271, 2.5501]],
[[5.9590, 4.5295], [4.4802, 3.5582]],
[[7.5855, 5.7835], [5.7333, 4.5664]],
[[9.2120, 7.0375], [6.9864, 5.5746]],
]]))
}
#[test]
fn test_deform_conv2d_different_kernel_size() {
let test = DeformConv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 3,
kernel_size_1: 3,
kernel_size_2: 4,
padding_1: 0,
padding_2: 0,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
weight_groups: 1,
offset_groups: 1,
height: 4,
width: 4,
};
test.assert_output(Tensor::<TestBackend, 4>::from([[
[[1.0669], [0.6329]],
[[2.9741], [2.0383]],
[[4.8812], [3.4437]],
]]))
}
#[test]
fn test_deform_conv2d_different_padding_size() {
let test = DeformConv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 3,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 2,
padding_2: 3,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
weight_groups: 1,
offset_groups: 1,
height: 4,
width: 4,
};
test.assert_output(Tensor::<TestBackend, 4>::from([[
[
[
0.1998, 0.3762, 0.5285, 0.6053, 0.3844, 0.1987, 0.0481, 0.0000,
],
[
0.2879, 0.5517, 0.7776, 0.8905, 0.5805, 0.3043, 0.0796, 0.0000,
],
[
0.3729, 0.7214, 1.0137, 1.1520, 0.7564, 0.3931, 0.1016, 0.0000,
],
[
0.1321, 0.3249, 0.4954, 0.5846, 0.4531, 0.2501, 0.0757, 0.0000,
],
[
0.0593, 0.1607, 0.2448, 0.2971, 0.2395, 0.1327, 0.0471, 0.0000,
],
[
0.0143, 0.0513, 0.0783, 0.0942, 0.0813, 0.0420, 0.0145, 0.0000,
],
],
[
[
0.7667, 1.1648, 1.5219, 1.7111, 1.2305, 0.8076, 0.4504, 0.3333,
],
[
0.9812, 1.6010, 2.1525, 2.4409, 1.7455, 1.0918, 0.5367, 0.3333,
],
[
1.1964, 2.0448, 2.7853, 3.1522, 2.2426, 1.3513, 0.6049, 0.3333,
],
[
0.6695, 1.1781, 1.6441, 1.9022, 1.5732, 1.0339, 0.5536, 0.3333,
],
[
0.4950, 0.7861, 1.0398, 1.2047, 1.0523, 0.7439, 0.4834, 0.3333,
],
[
0.3788, 0.4982, 0.5929, 0.6542, 0.6155, 0.4882, 0.3909, 0.3333,
],
],
[
[
1.3335, 1.9534, 2.5154, 2.8170, 2.0766, 1.4165, 0.8527, 0.6667,
],
[
1.6744, 2.6503, 3.5275, 3.9914, 2.9106, 1.8794, 0.9939, 0.6667,
],
[
2.0198, 3.3683, 4.5570, 5.1525, 3.7288, 2.3095, 1.1082, 0.6667,
],
[
1.2068, 2.0314, 2.7928, 3.2198, 2.6932, 1.8178, 1.0315, 0.6667,
],
[
0.9308, 1.4116, 1.8348, 2.1124, 1.8652, 1.3551, 0.9196, 0.6667,
],
[
0.7432, 0.9451, 1.1074, 1.2143, 1.1497, 0.9345, 0.7673, 0.6667,
],
],
]]))
}
#[test]
fn test_deform_conv2d_different_stride() {
let test = DeformConv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 3,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 0,
padding_2: 0,
stride_1: 1,
stride_2: 2,
dilation_1: 1,
dilation_2: 1,
weight_groups: 1,
offset_groups: 1,
height: 4,
width: 4,
};
test.assert_output(Tensor::<TestBackend, 4>::from([[
[[1.0647], [0.5783]],
[[2.9289], [1.8829]],
[[4.7931], [3.1875]],
]]))
}
#[test]
fn test_deform_conv2d_different_dilation() {
let test = DeformConv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 3,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 0,
padding_2: 0,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 2,
weight_groups: 1,
offset_groups: 1,
height: 5,
width: 5,
};
test.assert_output(Tensor::<TestBackend, 4>::from([[
[[0.6162], [0.7611], [0.4666]],
[[1.8578], [2.2684], [1.6208]],
[[3.0994], [3.7757], [2.7749]],
]]))
}
#[test]
fn test_deform_conv2d_different_width() {
let test = DeformConv2dTestCase {
batch_size: 1,
channels_in: 2,
channels_out: 3,
kernel_size_1: 3,
kernel_size_2: 3,
padding_1: 0,
padding_2: 0,
stride_1: 1,
stride_2: 1,
dilation_1: 1,
dilation_2: 1,
weight_groups: 1,
offset_groups: 1,
height: 6,
width: 4,
};
test.assert_output(Tensor::<TestBackend, 4>::from([[
[
[0.8909, 0.6016],
[1.0697, 0.7186],
[1.2618, 0.8433],
[0.6424, 0.5032],
],
[
[2.4670, 1.8168],
[2.9529, 2.1497],
[3.4805, 2.5090],
[2.0925, 1.7411],
],
[
[4.0432, 3.0321],
[4.8362, 3.5809],
[5.6992, 4.1746],
[3.5425, 2.9790],
],
]]))
}
struct DeformConv2dTestCase {
batch_size: usize,
channels_in: usize,
channels_out: usize,
kernel_size_1: usize,
kernel_size_2: usize,
padding_1: usize,
padding_2: usize,
stride_1: usize,
stride_2: usize,
dilation_1: usize,
dilation_2: usize,
weight_groups: usize,
offset_groups: usize,
height: usize,
width: usize,
}
impl DeformConv2dTestCase {
fn assert_output(self, y: Tensor<TestBackend, 4>) {
let out_height =
(self.height + 2 * self.padding_1 - self.dilation_1 * (self.kernel_size_1 - 1) - 1)
/ self.stride_1
+ 1;
let out_width =
(self.width + 2 * self.padding_2 - self.dilation_2 * (self.kernel_size_2 - 1) - 1)
/ self.stride_2
+ 1;
let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
let shape_weight = Shape::new([
self.channels_out,
self.channels_in / self.weight_groups,
self.kernel_size_1,
self.kernel_size_2,
]);
let shape_offset = Shape::new([
self.batch_size,
self.kernel_size_1 * self.kernel_size_2 * self.offset_groups * 2,
out_height,
out_width,
]);
let shape_mask = Shape::new([
self.batch_size,
self.kernel_size_1 * self.kernel_size_2 * self.offset_groups,
out_height,
out_width,
]);
let device = Default::default();
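// Deterministic inputs: integer ramps scaled into [0, 1) so expected outputs
// are reproducible across backends.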
let weight = Tensor::<TestBackend, 4>::from(
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
.reshape(shape_weight.clone())
.into_data(),
)
.div_scalar(shape_weight.num_elements() as f32);
let bias = Tensor::<TestBackend, 1>::from(
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
)
.div_scalar(self.channels_out as f32);
let x = Tensor::<TestBackend, 4>::from(
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
.reshape(shape_x.clone())
.into_data(),
)
.div_scalar(shape_x.num_elements() as f32);
let offset = Tensor::<TestBackend, 4>::from(
TestTensorInt::arange(0..shape_offset.num_elements() as i64, &device)
.reshape(shape_offset.clone())
.into_data(),
)
.div_scalar(shape_offset.num_elements() as f32);
let mask = Tensor::<TestBackend, 4>::from(
TestTensorInt::arange(0..shape_mask.num_elements() as i64, &device)
.reshape(shape_mask.clone())
.into_data(),
)
.div_scalar(shape_mask.num_elements() as f32);
let output = deform_conv2d(
x,
offset,
weight,
Some(mask),
Some(bias),
DeformConvOptions::new(
[self.stride_1, self.stride_2],
[self.padding_1, self.padding_2],
[self.dilation_1, self.dilation_2],
self.weight_groups,
self.offset_groups,
),
);
y.to_data().assert_approx_eq(&output.into_data(), 3);
}
}
}

View File

@ -10,6 +10,7 @@ mod conv3d;
mod conv_transpose1d;
mod conv_transpose2d;
mod conv_transpose3d;
mod deform_conv2d;
mod forward;
mod maxpool1d;
mod maxpool2d;

View File

@ -11,20 +11,22 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-wgpu"
version.workspace = true
[features]
default = ["std", "autotune", "fusion", "burn-jit/default", "cubecl/default"]
fusion = ["burn-fusion", "burn-jit/fusion"]
autotune = ["burn-jit/autotune"]
template = ["burn-jit/template", "cubecl/template"]
default = ["std", "autotune", "fusion", "burn-jit/default", "cubecl/default"]
doc = ["burn-jit/doc"]
std = ["burn-jit/std", "cubecl/std"]
fusion = ["burn-fusion", "burn-jit/fusion"]
simple-memory-management = ["cubecl/simple-memory-management"]
std = ["burn-jit/std", "cubecl/std"]
template = ["burn-jit/template", "cubecl/template"]
[dependencies]
cubecl = { workspace = true, features = ["wgpu"] }
burn-jit = { path = "../burn-jit", version = "0.15.0", default-features = false }
burn-tensor = { path = "../burn-tensor", version = "0.15.0", features = ["cubecl-wgpu"] }
burn-fusion = { path = "../burn-fusion", version = "0.15.0", optional = true }
burn-jit = { path = "../burn-jit", version = "0.15.0", default-features = false }
burn-tensor = { path = "../burn-tensor", version = "0.15.0", features = [
"cubecl-wgpu",
] }
[dev-dependencies]
burn-jit = { path = "../burn-jit", version = "0.15.0", default-features = false, features = [