mirror of https://github.com/tracel-ai/burn.git
Add deform_conv2d as implemented in torchvision (#2147)
This commit is contained in:
parent
f19e0c5393
commit
2c8514ce7f
|
@ -245,6 +245,12 @@ version = "1.1.2"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0"
|
||||
|
||||
[[package]]
|
||||
name = "atomic_float"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "628d228f918ac3b82fe590352cc719d30664a0c13ca3a60266fe02c7132d480a"
|
||||
|
||||
[[package]]
|
||||
name = "atty"
|
||||
version = "0.2.14"
|
||||
|
@ -672,6 +678,7 @@ dependencies = [
|
|||
name = "burn-ndarray"
|
||||
version = "0.15.0"
|
||||
dependencies = [
|
||||
"atomic_float",
|
||||
"blas-src",
|
||||
"burn-autodiff",
|
||||
"burn-common",
|
||||
|
|
|
@ -27,6 +27,7 @@ readme = "README.md"
|
|||
version = "0.15.0"
|
||||
|
||||
[workspace.dependencies]
|
||||
atomic_float = "1"
|
||||
bytemuck = "1.18.0"
|
||||
candle-core = { version = "0.6.0" }
|
||||
clap = { version = "4.5.18", features = ["derive"] }
|
||||
|
|
|
@ -190,6 +190,7 @@ Burn comes with built-in modules that you can use to build your own modules.
|
|||
| `ConvTranspose1d` | `nn.ConvTranspose1d` |
|
||||
| `ConvTranspose2d` | `nn.ConvTranspose2d` |
|
||||
| `ConvTranspose3d` | `nn.ConvTranspose3d` |
|
||||
| `DeformConv2d` | `torchvision.ops.DeformConv2d` |
|
||||
|
||||
### Pooling
|
||||
|
||||
|
|
|
@ -441,6 +441,343 @@ impl<B: Backend, C: CheckpointStrategy> ModuleOps<Autodiff<B, C>> for Autodiff<B
|
|||
}
|
||||
}
|
||||
|
||||
fn deform_conv2d(
|
||||
x: AutodiffTensor<B, 4>,
|
||||
offset: AutodiffTensor<B, 4>,
|
||||
weight: AutodiffTensor<B, 4>,
|
||||
mask: Option<AutodiffTensor<B, 4>>,
|
||||
bias: Option<AutodiffTensor<B, 1>>,
|
||||
options: DeformConvOptions<2>,
|
||||
) -> AutodiffTensor<B, 4> {
|
||||
#[derive(Debug)]
|
||||
struct DeformConv2DWithMaskWithBias;
|
||||
#[derive(Debug)]
|
||||
struct DeformConv2DWithMaskNoBias;
|
||||
#[derive(Debug)]
|
||||
struct DeformConv2DNoMaskWithBias;
|
||||
#[derive(Debug)]
|
||||
struct DeformConv2DNoMaskNoBias;
|
||||
|
||||
impl<B: Backend> Backward<B, 4, 5> for DeformConv2DWithMaskWithBias {
|
||||
type State = (NodeID, NodeID, NodeID, NodeID, NodeID, DeformConvOptions<2>);
|
||||
|
||||
fn backward(
|
||||
self,
|
||||
ops: Ops<Self::State, 5>,
|
||||
grads: &mut Gradients,
|
||||
checkpointer: &mut Checkpointer,
|
||||
) {
|
||||
let [node_x, node_offset, node_weight, node_mask, node_bias] = ops.parents;
|
||||
let grad = grads.consume::<B, 4>(&ops.node);
|
||||
|
||||
let (x_state, offset_state, weight_state, mask_state, bias_state, options) =
|
||||
ops.state;
|
||||
let x = checkpointer.retrieve_node_output(x_state);
|
||||
let offset = checkpointer.retrieve_node_output(offset_state);
|
||||
let weight = checkpointer.retrieve_node_output(weight_state);
|
||||
let mask = Some(checkpointer.retrieve_node_output(mask_state));
|
||||
let bias = Some(checkpointer.retrieve_node_output(bias_state));
|
||||
|
||||
let backward =
|
||||
B::deform_conv2d_backward(x, offset, weight, mask, bias, grad, options);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 4>(node.id, backward.x_grad)
|
||||
}
|
||||
if let Some(node) = node_offset {
|
||||
grads.register::<B, 4>(node.id, backward.offset_grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 4>(node.id, backward.weight_grad)
|
||||
}
|
||||
if let Some(node) = node_mask {
|
||||
grads.register::<B, 4>(node.id, backward.mask_grad.unwrap())
|
||||
}
|
||||
if let Some(node) = node_bias {
|
||||
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Backward<B, 4, 4> for DeformConv2DWithMaskNoBias {
|
||||
type State = (NodeID, NodeID, NodeID, NodeID, DeformConvOptions<2>);
|
||||
|
||||
fn backward(
|
||||
self,
|
||||
ops: Ops<Self::State, 4>,
|
||||
grads: &mut Gradients,
|
||||
checkpointer: &mut Checkpointer,
|
||||
) {
|
||||
let [node_x, node_offset, node_weight, node_mask] = ops.parents;
|
||||
let grad = grads.consume::<B, 4>(&ops.node);
|
||||
|
||||
let (x_state, offset_state, weight_state, mask_state, options) = ops.state;
|
||||
let x = checkpointer.retrieve_node_output(x_state);
|
||||
let offset = checkpointer.retrieve_node_output(offset_state);
|
||||
let weight = checkpointer.retrieve_node_output(weight_state);
|
||||
let mask = Some(checkpointer.retrieve_node_output(mask_state));
|
||||
|
||||
let backward =
|
||||
B::deform_conv2d_backward(x, offset, weight, mask, None, grad, options);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 4>(node.id, backward.x_grad)
|
||||
}
|
||||
if let Some(node) = node_offset {
|
||||
grads.register::<B, 4>(node.id, backward.offset_grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 4>(node.id, backward.weight_grad)
|
||||
}
|
||||
if let Some(node) = node_mask {
|
||||
grads.register::<B, 4>(node.id, backward.mask_grad.unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Backward<B, 4, 4> for DeformConv2DNoMaskWithBias {
|
||||
type State = (NodeID, NodeID, NodeID, NodeID, DeformConvOptions<2>);
|
||||
|
||||
fn backward(
|
||||
self,
|
||||
ops: Ops<Self::State, 4>,
|
||||
grads: &mut Gradients,
|
||||
checkpointer: &mut Checkpointer,
|
||||
) {
|
||||
let [node_x, node_offset, node_weight, node_bias] = ops.parents;
|
||||
let grad = grads.consume::<B, 4>(&ops.node);
|
||||
|
||||
let (x_state, offset_state, weight_state, bias_state, options) = ops.state;
|
||||
let x = checkpointer.retrieve_node_output(x_state);
|
||||
let offset = checkpointer.retrieve_node_output(offset_state);
|
||||
let weight = checkpointer.retrieve_node_output(weight_state);
|
||||
let bias = Some(checkpointer.retrieve_node_output(bias_state));
|
||||
|
||||
let backward =
|
||||
B::deform_conv2d_backward(x, offset, weight, None, bias, grad, options);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 4>(node.id, backward.x_grad)
|
||||
}
|
||||
if let Some(node) = node_offset {
|
||||
grads.register::<B, 4>(node.id, backward.offset_grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 4>(node.id, backward.weight_grad)
|
||||
}
|
||||
if let Some(node) = node_bias {
|
||||
grads.register::<B, 1>(node.id, backward.bias_grad.unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Backward<B, 4, 3> for DeformConv2DNoMaskNoBias {
|
||||
type State = (NodeID, NodeID, NodeID, DeformConvOptions<2>);
|
||||
|
||||
fn backward(
|
||||
self,
|
||||
ops: Ops<Self::State, 3>,
|
||||
grads: &mut Gradients,
|
||||
checkpointer: &mut Checkpointer,
|
||||
) {
|
||||
let [node_x, node_offset, node_weight] = ops.parents;
|
||||
let grad = grads.consume::<B, 4>(&ops.node);
|
||||
|
||||
let (x_state, offset_state, weight_state, options) = ops.state;
|
||||
let x = checkpointer.retrieve_node_output(x_state);
|
||||
let offset = checkpointer.retrieve_node_output(offset_state);
|
||||
let weight = checkpointer.retrieve_node_output(weight_state);
|
||||
|
||||
let backward =
|
||||
B::deform_conv2d_backward(x, offset, weight, None, None, grad, options);
|
||||
|
||||
if let Some(node) = node_x {
|
||||
grads.register::<B, 4>(node.id, backward.x_grad)
|
||||
}
|
||||
if let Some(node) = node_offset {
|
||||
grads.register::<B, 4>(node.id, backward.offset_grad)
|
||||
}
|
||||
if let Some(node) = node_weight {
|
||||
grads.register::<B, 4>(node.id, backward.weight_grad)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
match (mask, bias) {
|
||||
(Some(mask), Some(bias)) => match DeformConv2DWithMaskWithBias
|
||||
.prepare::<C>([
|
||||
x.node.clone(),
|
||||
offset.node.clone(),
|
||||
weight.node.clone(),
|
||||
mask.node.clone(),
|
||||
bias.node.clone(),
|
||||
])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
OpsKind::Tracked(mut prep) => {
|
||||
let x_state = prep.checkpoint(&x);
|
||||
let offset_state = prep.checkpoint(&offset);
|
||||
let weight_state = prep.checkpoint(&weight);
|
||||
let mask_state = prep.checkpoint(&mask);
|
||||
let bias_state = prep.checkpoint(&bias);
|
||||
prep.finish(
|
||||
(
|
||||
x_state,
|
||||
offset_state,
|
||||
weight_state,
|
||||
mask_state,
|
||||
bias_state,
|
||||
options.clone(),
|
||||
),
|
||||
B::deform_conv2d(
|
||||
x.primitive,
|
||||
offset.primitive,
|
||||
weight.primitive,
|
||||
Some(mask.primitive),
|
||||
Some(bias.primitive),
|
||||
options,
|
||||
),
|
||||
)
|
||||
}
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d(
|
||||
x.primitive,
|
||||
offset.primitive,
|
||||
weight.primitive,
|
||||
Some(mask.primitive),
|
||||
Some(bias.primitive),
|
||||
options,
|
||||
)),
|
||||
},
|
||||
(Some(mask), None) => match DeformConv2DWithMaskNoBias
|
||||
.prepare::<C>([
|
||||
x.node.clone(),
|
||||
offset.node.clone(),
|
||||
weight.node.clone(),
|
||||
mask.node.clone(),
|
||||
])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
OpsKind::Tracked(mut prep) => {
|
||||
let x_state = prep.checkpoint(&x);
|
||||
let offset_state = prep.checkpoint(&offset);
|
||||
let weight_state = prep.checkpoint(&weight);
|
||||
let mask_state = prep.checkpoint(&mask);
|
||||
prep.finish(
|
||||
(
|
||||
x_state,
|
||||
offset_state,
|
||||
weight_state,
|
||||
mask_state,
|
||||
options.clone(),
|
||||
),
|
||||
B::deform_conv2d(
|
||||
x.primitive,
|
||||
offset.primitive,
|
||||
weight.primitive,
|
||||
Some(mask.primitive),
|
||||
None,
|
||||
options,
|
||||
),
|
||||
)
|
||||
}
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d(
|
||||
x.primitive,
|
||||
offset.primitive,
|
||||
weight.primitive,
|
||||
Some(mask.primitive),
|
||||
None,
|
||||
options,
|
||||
)),
|
||||
},
|
||||
(None, Some(bias)) => match DeformConv2DNoMaskWithBias
|
||||
.prepare::<C>([
|
||||
x.node.clone(),
|
||||
offset.node.clone(),
|
||||
weight.node.clone(),
|
||||
bias.node.clone(),
|
||||
])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
OpsKind::Tracked(mut prep) => {
|
||||
let x_state = prep.checkpoint(&x);
|
||||
let offset_state = prep.checkpoint(&offset);
|
||||
let weight_state = prep.checkpoint(&weight);
|
||||
let bias_state = prep.checkpoint(&bias);
|
||||
prep.finish(
|
||||
(
|
||||
x_state,
|
||||
offset_state,
|
||||
weight_state,
|
||||
bias_state,
|
||||
options.clone(),
|
||||
),
|
||||
B::deform_conv2d(
|
||||
x.primitive,
|
||||
offset.primitive,
|
||||
weight.primitive,
|
||||
None,
|
||||
Some(bias.primitive),
|
||||
options,
|
||||
),
|
||||
)
|
||||
}
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d(
|
||||
x.primitive,
|
||||
offset.primitive,
|
||||
weight.primitive,
|
||||
None,
|
||||
Some(bias.primitive),
|
||||
options,
|
||||
)),
|
||||
},
|
||||
(None, None) => match DeformConv2DNoMaskNoBias
|
||||
.prepare::<C>([x.node.clone(), offset.node.clone(), weight.node.clone()])
|
||||
.compute_bound()
|
||||
.stateful()
|
||||
{
|
||||
OpsKind::Tracked(mut prep) => {
|
||||
let x_state = prep.checkpoint(&x);
|
||||
let offset_state = prep.checkpoint(&offset);
|
||||
let weight_state = prep.checkpoint(&weight);
|
||||
prep.finish(
|
||||
(x_state, offset_state, weight_state, options.clone()),
|
||||
B::deform_conv2d(
|
||||
x.primitive,
|
||||
offset.primitive,
|
||||
weight.primitive,
|
||||
None,
|
||||
None,
|
||||
options,
|
||||
),
|
||||
)
|
||||
}
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d(
|
||||
x.primitive,
|
||||
offset.primitive,
|
||||
weight.primitive,
|
||||
None,
|
||||
None,
|
||||
options,
|
||||
)),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn deform_conv2d_backward(
|
||||
_x: AutodiffTensor<B, 4>,
|
||||
_offset: AutodiffTensor<B, 4>,
|
||||
_weight: AutodiffTensor<B, 4>,
|
||||
_mask: Option<AutodiffTensor<B, 4>>,
|
||||
_bias: Option<AutodiffTensor<B, 1>>,
|
||||
_output_grad: AutodiffTensor<B, 4>,
|
||||
_options: DeformConvOptions<2>,
|
||||
) -> DeformConv2dBackward<Self> {
|
||||
panic!("Can't differentiate deform conv 2d backward.");
|
||||
}
|
||||
|
||||
fn conv_transpose2d(
|
||||
x: AutodiffTensor<B, 4>,
|
||||
weight: AutodiffTensor<B, 4>,
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -21,6 +21,7 @@ mod conv_transpose2d;
|
|||
mod conv_transpose3d;
|
||||
mod cos;
|
||||
mod cross_entropy;
|
||||
mod deform_conv2d;
|
||||
mod div;
|
||||
mod erf;
|
||||
mod exp;
|
||||
|
@ -82,6 +83,8 @@ macro_rules! testgen_all {
|
|||
burn_autodiff::testgen_ad_conv1d!();
|
||||
burn_autodiff::testgen_ad_conv2d!();
|
||||
burn_autodiff::testgen_ad_conv3d!();
|
||||
#[cfg(not(target_os = "macos"))] // Wgpu on MacOS currently doesn't support atomic compare exchange
|
||||
burn_autodiff::testgen_ad_deform_conv2d!();
|
||||
burn_autodiff::testgen_ad_conv_transpose1d!();
|
||||
burn_autodiff::testgen_ad_conv_transpose2d!();
|
||||
burn_autodiff::testgen_ad_conv_transpose3d!();
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
use burn_tensor::{
|
||||
ops::{
|
||||
ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, InterpolateMode,
|
||||
InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, UnfoldOptions,
|
||||
ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, FloatTensor,
|
||||
IntTensor, InterpolateMode, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices,
|
||||
ModuleOps, UnfoldOptions,
|
||||
},
|
||||
Shape,
|
||||
};
|
||||
|
@ -77,6 +78,29 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Self> for Candle<F, I
|
|||
})
|
||||
}
|
||||
|
||||
fn deform_conv2d(
|
||||
x: FloatTensor<Self, 4>,
|
||||
offset: FloatTensor<Self, 4>,
|
||||
weight: FloatTensor<Self, 4>,
|
||||
mask: Option<FloatTensor<Self, 4>>,
|
||||
bias: Option<FloatTensor<Self, 1>>,
|
||||
options: DeformConvOptions<2>,
|
||||
) -> FloatTensor<Self, 4> {
|
||||
unimplemented!("Candle does not support deformable convolutions")
|
||||
}
|
||||
|
||||
fn deform_conv2d_backward(
|
||||
x: FloatTensor<Self, 4>,
|
||||
offset: FloatTensor<Self, 4>,
|
||||
weight: FloatTensor<Self, 4>,
|
||||
mask: Option<FloatTensor<Self, 4>>,
|
||||
bias: Option<FloatTensor<Self, 1>>,
|
||||
output_grad: FloatTensor<Self, 4>,
|
||||
options: DeformConvOptions<2>,
|
||||
) -> DeformConv2dBackward<Self> {
|
||||
unimplemented!("Candle does not support deformable convolutions")
|
||||
}
|
||||
|
||||
fn conv3d(
|
||||
x: FloatTensor<Self, 5>,
|
||||
weight: FloatTensor<Self, 5>,
|
||||
|
|
|
@ -0,0 +1,263 @@
|
|||
use alloc::format;
|
||||
use burn_tensor::ops::DeformConvOptions;
|
||||
|
||||
use crate as burn;
|
||||
|
||||
use crate::config::Config;
|
||||
use crate::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay, Param};
|
||||
use crate::nn::Initializer;
|
||||
use crate::nn::PaddingConfig2d;
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::module::deform_conv2d;
|
||||
use crate::tensor::Tensor;
|
||||
|
||||
use crate::nn::conv::checks;
|
||||
|
||||
/// Configuration to create a [deformable 2D convolution](DeformConv2d) layer, using the [init function](DeformConv2dConfig::init).
|
||||
#[derive(Config, Debug)]
|
||||
pub struct DeformConv2dConfig {
|
||||
/// The number of channels.
|
||||
pub channels: [usize; 2],
|
||||
/// The size of the kernel.
|
||||
pub kernel_size: [usize; 2],
|
||||
/// The stride of the convolution.
|
||||
#[config(default = "[1, 1]")]
|
||||
pub stride: [usize; 2],
|
||||
/// Spacing between kernel elements.
|
||||
#[config(default = "[1, 1]")]
|
||||
pub dilation: [usize; 2],
|
||||
/// Controls the connections between input and output channels.
|
||||
#[config(default = "1")]
|
||||
pub weight_groups: usize,
|
||||
/// Offset groups.
|
||||
#[config(default = "1")]
|
||||
pub offset_groups: usize,
|
||||
/// The padding configuration.
|
||||
#[config(default = "PaddingConfig2d::Valid")]
|
||||
pub padding: PaddingConfig2d,
|
||||
/// If bias should be added to the output.
|
||||
#[config(default = true)]
|
||||
pub bias: bool,
|
||||
/// The type of function used to initialize neural network parameters
|
||||
#[config(
|
||||
default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
|
||||
)]
|
||||
pub initializer: Initializer,
|
||||
}
|
||||
|
||||
/// Applies a deformable 2D convolution over input tensors.
|
||||
///
|
||||
/// Should be created with [DeformConv2dConfig].
|
||||
#[derive(Module, Debug)]
|
||||
#[module(custom_display)]
|
||||
pub struct DeformConv2d<B: Backend> {
|
||||
/// Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2]`
|
||||
pub weight: Param<Tensor<B, 4>>,
|
||||
/// Tensor of shape `[channels_out]`
|
||||
pub bias: Option<Param<Tensor<B, 1>>>,
|
||||
/// Stride of the convolution.
|
||||
pub stride: [usize; 2],
|
||||
/// Size of the kernel.
|
||||
pub kernel_size: [usize; 2],
|
||||
/// Spacing between kernel elements.
|
||||
pub dilation: [usize; 2],
|
||||
/// Controls the connections between input and output channels.
|
||||
pub weight_groups: usize,
|
||||
/// Offset groups.
|
||||
pub offset_groups: usize,
|
||||
/// The padding configuration.
|
||||
pub padding: Ignored<PaddingConfig2d>,
|
||||
}
|
||||
|
||||
impl DeformConv2dConfig {
|
||||
/// Initialize a new [DeformConv2d](DeformConv2d) module.
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> DeformConv2d<B> {
|
||||
checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.weight_groups);
|
||||
|
||||
let shape = [
|
||||
self.channels[1],
|
||||
self.channels[0] / self.weight_groups,
|
||||
self.kernel_size[0],
|
||||
self.kernel_size[1],
|
||||
];
|
||||
|
||||
let k = self.kernel_size.iter().product::<usize>();
|
||||
let fan_in = self.channels[0] / self.weight_groups * k;
|
||||
let fan_out = self.channels[1] / self.weight_groups * k;
|
||||
|
||||
let weight = self
|
||||
.initializer
|
||||
.init_with(shape, Some(fan_in), Some(fan_out), device);
|
||||
let mut bias = None;
|
||||
|
||||
if self.bias {
|
||||
bias = Some(self.initializer.init_with(
|
||||
[self.channels[1]],
|
||||
Some(fan_in),
|
||||
Some(fan_out),
|
||||
device,
|
||||
));
|
||||
}
|
||||
|
||||
DeformConv2d {
|
||||
weight,
|
||||
bias,
|
||||
stride: self.stride,
|
||||
kernel_size: self.kernel_size,
|
||||
dilation: self.dilation,
|
||||
padding: Ignored(self.padding.clone()),
|
||||
weight_groups: self.weight_groups,
|
||||
offset_groups: self.weight_groups,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ModuleDisplay for DeformConv2d<B> {
|
||||
fn custom_settings(&self) -> Option<DisplaySettings> {
|
||||
DisplaySettings::new()
|
||||
.with_new_line_after_attribute(false)
|
||||
.optional()
|
||||
}
|
||||
|
||||
fn custom_content(&self, content: Content) -> Option<Content> {
|
||||
// Since padding does not implement ModuleDisplay, we need to format it manually.
|
||||
let padding_formatted = format!("{}", &self.padding);
|
||||
|
||||
// Format the stride, kernel_size and dilation as strings, formatted as arrays instead of indexed.
|
||||
let stride = format!("{:?}", self.stride);
|
||||
let kernel_size = format!("{:?}", self.kernel_size);
|
||||
let dilation = format!("{:?}", self.dilation);
|
||||
|
||||
content
|
||||
.add("stride", &stride)
|
||||
.add("kernel_size", &kernel_size)
|
||||
.add("dilation", &dilation)
|
||||
.add("weight_groups", &self.weight_groups)
|
||||
.add("offset_groups", &self.offset_groups)
|
||||
.add("padding", &padding_formatted)
|
||||
.optional()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> DeformConv2d<B> {
|
||||
/// Applies the forward pass on the input tensor.
|
||||
///
|
||||
/// See [deform_conv2d](crate::tensor::module::deform_conv2d) for more information.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// - input: `[batch_size, channels_in, height_in, width_in]`
|
||||
/// - offset: `[batch_size, 2 * offset_groups * kernel_height * kernel_width, height_out, width_out]`
|
||||
/// - mask: `[batch_size, offset_groups * kernel_height * kernel_width, height_out, width_out]`
|
||||
/// - output: `[batch_size, channels_out, height_out, width_out]`
|
||||
pub fn forward(
|
||||
&self,
|
||||
input: Tensor<B, 4>,
|
||||
offset: Tensor<B, 4>,
|
||||
mask: Option<Tensor<B, 4>>,
|
||||
) -> Tensor<B, 4> {
|
||||
let [_batch_size, _channels_in, height_in, width_in] = input.dims();
|
||||
let padding =
|
||||
self.padding
|
||||
.calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride);
|
||||
deform_conv2d(
|
||||
input,
|
||||
offset,
|
||||
self.weight.val(),
|
||||
mask,
|
||||
self.bias.as_ref().map(|bias| bias.val()),
|
||||
DeformConvOptions::new(
|
||||
self.stride,
|
||||
padding,
|
||||
self.dilation,
|
||||
self.weight_groups,
|
||||
self.offset_groups,
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tensor::TensorData;
|
||||
use crate::TestBackend;
|
||||
|
||||
#[test]
|
||||
fn initializer_default() {
|
||||
TestBackend::seed(0);
|
||||
|
||||
let config = DeformConv2dConfig::new([5, 1], [5, 5]);
|
||||
let k = (config.channels[0] * config.kernel_size[0] * config.kernel_size[1]) as f64;
|
||||
let k = (config.offset_groups as f64 / k).sqrt() as f32;
|
||||
let device = Default::default();
|
||||
let conv = config.init::<TestBackend>(&device);
|
||||
|
||||
conv.weight.to_data().assert_within_range(-k..k);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn initializer_zeros() {
|
||||
TestBackend::seed(0);
|
||||
|
||||
let config = DeformConv2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros);
|
||||
let device = Default::default();
|
||||
let conv = config.init::<TestBackend>(&device);
|
||||
|
||||
assert_eq!(config.initializer, Initializer::Zeros);
|
||||
conv.weight
|
||||
.to_data()
|
||||
.assert_approx_eq(&TensorData::zeros::<f32, _>(conv.weight.shape()), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn initializer_fan_out() {
|
||||
TestBackend::seed(0);
|
||||
|
||||
let init = Initializer::KaimingUniform {
|
||||
gain: 1.0 / 3.0f64.sqrt(),
|
||||
fan_out_only: true, // test that fan_out is passed to `init_with()`
|
||||
};
|
||||
let device = Default::default();
|
||||
let config = DeformConv2dConfig::new([5, 1], [5, 5]).with_initializer(init.clone());
|
||||
let _ = config.init::<TestBackend>(&device);
|
||||
|
||||
assert_eq!(config.initializer, init);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn initializer_fan_with_groups_is_valid() {
|
||||
TestBackend::seed(0);
|
||||
|
||||
let init = Initializer::KaimingUniform {
|
||||
gain: 1.0 / 3.0f64.sqrt(),
|
||||
fan_out_only: true,
|
||||
};
|
||||
let device = Default::default();
|
||||
let config = DeformConv2dConfig::new([4, 4], [1, 1])
|
||||
.with_initializer(init.clone())
|
||||
.with_weight_groups(4);
|
||||
let _ = config.init::<TestBackend>(&device);
|
||||
|
||||
assert_eq!(config.initializer, init);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic = "Both channels must be divisible by the number of groups."]
|
||||
fn channels_with_groups_is_invalid() {
|
||||
let device = Default::default();
|
||||
let config = DeformConv2dConfig::new([1, 4], [1, 1]).with_weight_groups(4);
|
||||
let _ = config.init::<TestBackend>(&device);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn display() {
|
||||
let config = DeformConv2dConfig::new([5, 1], [5, 5]);
|
||||
let conv = config.init::<TestBackend>(&Default::default());
|
||||
|
||||
assert_eq!(
|
||||
alloc::format!("{}", conv),
|
||||
"DeformConv2d {stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], weight_groups: 1, offset_groups: 1, padding: Valid, params: 126}"
|
||||
);
|
||||
}
|
||||
}
|
|
@ -4,6 +4,7 @@ mod conv3d;
|
|||
mod conv_transpose1d;
|
||||
mod conv_transpose2d;
|
||||
mod conv_transpose3d;
|
||||
mod deform_conv2d;
|
||||
|
||||
pub(crate) mod checks;
|
||||
|
||||
|
@ -13,3 +14,4 @@ pub use conv3d::*;
|
|||
pub use conv_transpose1d::*;
|
||||
pub use conv_transpose2d::*;
|
||||
pub use conv_transpose3d::*;
|
||||
pub use deform_conv2d::*;
|
||||
|
|
|
@ -5,9 +5,9 @@ use burn_tensor::{
|
|||
calculate_conv_output_size, calculate_conv_transpose_output_size,
|
||||
calculate_pool_output_size,
|
||||
},
|
||||
ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, InterpolateOptions,
|
||||
MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices,
|
||||
ModuleOps,
|
||||
ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, FloatTensor,
|
||||
IntTensor, InterpolateOptions, MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward,
|
||||
MaxPool2dWithIndices, ModuleOps,
|
||||
},
|
||||
repr::*,
|
||||
Element,
|
||||
|
@ -153,6 +153,211 @@ impl<B: FusionBackend> ModuleOps<Fusion<B>> for Fusion<B> {
|
|||
out
|
||||
}
|
||||
|
||||
fn deform_conv2d(
|
||||
x: FloatTensor<Self, 4>,
|
||||
offset: FloatTensor<Self, 4>,
|
||||
weight: FloatTensor<Self, 4>,
|
||||
mask: Option<FloatTensor<Self, 4>>,
|
||||
bias: Option<FloatTensor<Self, 1>>,
|
||||
options: DeformConvOptions<2>,
|
||||
) -> FloatTensor<Self, 4> {
|
||||
make_ops!(
|
||||
DeformConv2dOps,
|
||||
DeformConv2dDescription,
|
||||
|args: DeformConv2dDescription, handles: &mut HandleContainer<B::Handle>| {
|
||||
let x = handles.get_float_tensor::<B, 4>(&args.x);
|
||||
let offset = handles.get_float_tensor::<B, 4>(&args.offset);
|
||||
let weight = handles.get_float_tensor::<B, 4>(&args.weight);
|
||||
let mask = args
|
||||
.mask
|
||||
.as_ref()
|
||||
.map(|mask| handles.get_float_tensor::<B, 4>(mask));
|
||||
let bias = args
|
||||
.bias
|
||||
.as_ref()
|
||||
.map(|bias| handles.get_float_tensor::<B, 1>(bias));
|
||||
|
||||
let output =
|
||||
B::deform_conv2d(x, offset, weight, mask, bias, args.options.clone().into());
|
||||
|
||||
handles.register_float_tensor::<B, 4>(&args.out.id, output);
|
||||
}
|
||||
);
|
||||
|
||||
let size_0 = calculate_conv_output_size(
|
||||
weight.shape[2],
|
||||
options.stride[0],
|
||||
options.padding[0],
|
||||
options.dilation[0],
|
||||
x.shape[2],
|
||||
);
|
||||
let size_1 = calculate_conv_output_size(
|
||||
weight.shape[3],
|
||||
options.stride[1],
|
||||
options.padding[1],
|
||||
options.dilation[1],
|
||||
x.shape[3],
|
||||
);
|
||||
|
||||
let stream_1 = x.stream;
|
||||
let stream_2 = offset.stream;
|
||||
let stream_3 = weight.stream;
|
||||
let stream_4 = mask.as_ref().map(|m| m.stream);
|
||||
let stream_5 = bias.as_ref().map(|b| b.stream);
|
||||
let shape = vec![x.shape[0], weight.shape[0], size_0, size_1];
|
||||
let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype());
|
||||
|
||||
let desc = DeformConv2dDescription {
|
||||
x: x.into_description(),
|
||||
offset: offset.into_description(),
|
||||
weight: weight.into_description(),
|
||||
mask: mask.map(|mask| mask.into_description()),
|
||||
bias: bias.map(|bias| bias.into_description()),
|
||||
options: options.into(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
let streams = match (stream_4, stream_5) {
|
||||
(Some(stream_4), Some(stream_5)) => {
|
||||
vec![stream_1, stream_2, stream_3, stream_4, stream_5]
|
||||
}
|
||||
(Some(stream_4), None) => {
|
||||
vec![stream_1, stream_2, stream_3, stream_4]
|
||||
}
|
||||
(None, Some(stream_5)) => {
|
||||
vec![stream_1, stream_2, stream_3, stream_5]
|
||||
}
|
||||
(None, None) => vec![stream_1, stream_2, stream_3],
|
||||
};
|
||||
out.client.register(
|
||||
streams,
|
||||
OperationDescription::Module(ModuleOperationDescription::DeformableConv2d(Box::new(
|
||||
desc.clone(),
|
||||
))),
|
||||
DeformConv2dOps::<B>::new(desc),
|
||||
);
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn deform_conv2d_backward(
|
||||
x: FloatTensor<Self, 4>,
|
||||
offset: FloatTensor<Self, 4>,
|
||||
weight: FloatTensor<Self, 4>,
|
||||
mask: Option<FloatTensor<Self, 4>>,
|
||||
bias: Option<FloatTensor<Self, 1>>,
|
||||
output_grad: FloatTensor<Self, 4>,
|
||||
options: DeformConvOptions<2>,
|
||||
) -> DeformConv2dBackward<Self> {
|
||||
make_ops!(
|
||||
DeformConv2dBackwardOps,
|
||||
DeformConv2dBackwardDescription,
|
||||
|args: DeformConv2dBackwardDescription, handles: &mut HandleContainer<B::Handle>| {
|
||||
let x = handles.get_float_tensor::<B, 4>(&args.x);
|
||||
let offset = handles.get_float_tensor::<B, 4>(&args.offset);
|
||||
let weight = handles.get_float_tensor::<B, 4>(&args.weight);
|
||||
let mask = args
|
||||
.mask
|
||||
.as_ref()
|
||||
.map(|mask| handles.get_float_tensor::<B, 4>(mask));
|
||||
let bias = args
|
||||
.bias
|
||||
.as_ref()
|
||||
.map(|bias| handles.get_float_tensor::<B, 1>(bias));
|
||||
let output_grad = handles.get_float_tensor::<B, 4>(&args.out_grad);
|
||||
|
||||
let output = B::deform_conv2d_backward(
|
||||
x,
|
||||
offset,
|
||||
weight,
|
||||
mask,
|
||||
bias,
|
||||
output_grad,
|
||||
args.options.clone().into(),
|
||||
);
|
||||
|
||||
handles.register_float_tensor::<B, 4>(&args.input_grad.id, output.x_grad);
|
||||
handles.register_float_tensor::<B, 4>(&args.offset_grad.id, output.offset_grad);
|
||||
handles.register_float_tensor::<B, 4>(&args.weight_grad.id, output.weight_grad);
|
||||
if let Some((mask_grad, field)) = output.mask_grad.zip(args.mask_grad.as_ref()) {
|
||||
handles.register_float_tensor::<B, 4>(&field.id, mask_grad);
|
||||
}
|
||||
if let Some((bias_grad, field)) = output.bias_grad.zip(args.bias_grad.as_ref()) {
|
||||
handles.register_float_tensor::<B, 1>(&field.id, bias_grad);
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
let input_grad = x
|
||||
.client
|
||||
.tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype());
|
||||
let offset_grad = offset
|
||||
.client
|
||||
.tensor_uninitialized(offset.shape.clone(), B::FloatElem::dtype());
|
||||
let weight_grad = offset
|
||||
.client
|
||||
.tensor_uninitialized(weight.shape.clone(), B::FloatElem::dtype());
|
||||
let mask_grad = mask.as_ref().map(|mask| {
|
||||
offset
|
||||
.client
|
||||
.tensor_uninitialized(mask.shape.clone(), B::FloatElem::dtype())
|
||||
});
|
||||
let bias_grad = bias.as_ref().map(|bias| {
|
||||
offset
|
||||
.client
|
||||
.tensor_uninitialized(bias.shape.clone(), B::FloatElem::dtype())
|
||||
});
|
||||
|
||||
let stream_1 = x.stream;
|
||||
let stream_2 = offset.stream;
|
||||
let stream_3 = weight.stream;
|
||||
let stream_4 = mask.as_ref().map(|m| m.stream);
|
||||
let stream_5 = bias.as_ref().map(|b| b.stream);
|
||||
let stream_6 = output_grad.stream;
|
||||
|
||||
let desc = DeformConv2dBackwardDescription {
|
||||
x: x.into_description(),
|
||||
offset: offset.into_description(),
|
||||
weight: weight.into_description(),
|
||||
mask: mask.map(|mask| mask.into_description()),
|
||||
bias: bias.map(|bias| bias.into_description()),
|
||||
options: options.into(),
|
||||
out_grad: output_grad.into_description(),
|
||||
input_grad: input_grad.to_description_out(),
|
||||
offset_grad: offset_grad.to_description_out(),
|
||||
weight_grad: weight_grad.to_description_out(),
|
||||
mask_grad: mask_grad
|
||||
.as_ref()
|
||||
.map(|mask_grad| mask_grad.to_description_out()),
|
||||
bias_grad: bias_grad
|
||||
.as_ref()
|
||||
.map(|bias_grad| bias_grad.to_description_out()),
|
||||
};
|
||||
|
||||
let streams = match (stream_4, stream_5) {
|
||||
(Some(stream_4), Some(stream_5)) => {
|
||||
vec![stream_1, stream_2, stream_3, stream_4, stream_5, stream_6]
|
||||
}
|
||||
(Some(stream_4), None) => {
|
||||
vec![stream_1, stream_2, stream_3, stream_4, stream_6]
|
||||
}
|
||||
(None, Some(stream_5)) => {
|
||||
vec![stream_1, stream_2, stream_3, stream_5, stream_6]
|
||||
}
|
||||
(None, None) => vec![stream_1, stream_2, stream_3, stream_6],
|
||||
};
|
||||
|
||||
input_grad.client.register(
|
||||
streams,
|
||||
OperationDescription::Module(ModuleOperationDescription::DeformableConv2dBackward(
|
||||
Box::new(desc.clone()),
|
||||
)),
|
||||
DeformConv2dBackwardOps::<B>::new(desc),
|
||||
);
|
||||
|
||||
DeformConv2dBackward::new(input_grad, offset_grad, weight_grad, mask_grad, bias_grad)
|
||||
}
|
||||
|
||||
fn conv3d(
|
||||
x: FloatTensor<Self, 5>,
|
||||
weight: FloatTensor<Self, 5>,
|
||||
|
|
|
@ -180,6 +180,35 @@ impl RelativeOps for ModuleOperationDescription {
|
|||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
ModuleOperationDescription::DeformableConv2d(desc) => {
|
||||
ModuleOperationDescription::DeformableConv2d(Box::new(DeformConv2dDescription {
|
||||
x: desc.x.to_relative(converter),
|
||||
offset: desc.offset.to_relative(converter),
|
||||
weight: desc.weight.to_relative(converter),
|
||||
mask: desc.mask.as_ref().map(|t| t.to_relative(converter)),
|
||||
bias: desc.bias.as_ref().map(|t| t.to_relative(converter)),
|
||||
options: desc.options.clone(),
|
||||
out: desc.out.to_relative(converter),
|
||||
}))
|
||||
}
|
||||
ModuleOperationDescription::DeformableConv2dBackward(desc) => {
|
||||
ModuleOperationDescription::DeformableConv2dBackward(Box::new(
|
||||
DeformConv2dBackwardDescription {
|
||||
x: desc.x.to_relative(converter),
|
||||
offset: desc.offset.to_relative(converter),
|
||||
weight: desc.weight.to_relative(converter),
|
||||
mask: desc.mask.as_ref().map(|t| t.to_relative(converter)),
|
||||
bias: desc.bias.as_ref().map(|t| t.to_relative(converter)),
|
||||
out_grad: desc.out_grad.to_relative(converter),
|
||||
options: desc.options.clone(),
|
||||
input_grad: desc.input_grad.to_relative(converter),
|
||||
offset_grad: desc.offset_grad.to_relative(converter),
|
||||
weight_grad: desc.weight_grad.to_relative(converter),
|
||||
mask_grad: desc.mask_grad.as_ref().map(|t| t.to_relative(converter)),
|
||||
bias_grad: desc.bias_grad.as_ref().map(|t| t.to_relative(converter)),
|
||||
},
|
||||
))
|
||||
}
|
||||
ModuleOperationDescription::ConvTranspose1d(desc) => {
|
||||
ModuleOperationDescription::ConvTranspose1d(ConvTranspose1dDescription {
|
||||
x: desc.x.to_relative(converter),
|
||||
|
|
|
@ -0,0 +1,316 @@
|
|||
use cubecl::{calculate_cube_count_elemwise, prelude::*};
|
||||
|
||||
use burn_tensor::{
|
||||
ops::{conv::calculate_conv_output_size, DeformConvOptions, FloatTensorOps as _},
|
||||
Shape,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
kernel::into_contiguous,
|
||||
ops::{
|
||||
numeric::{ones_device, zeros_device},
|
||||
reshape, swap_dims,
|
||||
},
|
||||
tensor::JitTensor,
|
||||
FloatElement, IntElement, JitBackend, JitRuntime,
|
||||
};
|
||||
|
||||
/// Scalar launch arguments for [`deform_im2col_kernel`].
#[derive(CubeLaunch)]
struct DeformConv2dArgs<F: Float> {
    // Convolution stride along height / width.
    conv_stride_h: u32,
    conv_stride_w: u32,
    // Kernel dilation along height / width.
    dilation_h: u32,
    dilation_w: u32,
    // Padding is passed as a float so it can be subtracted directly from the
    // fractional sampling coordinates inside the kernel.
    padding_h: F,
    padding_w: F,
    // Number of offset groups; input channels are partitioned among them.
    offset_groups: u32,

    kernel_height: u32,
    kernel_width: u32,
    out_h: u32,
    out_w: u32,

    // Stride of dimension 0 of the output `columns` buffer (distance between
    // consecutive kernel-element rows).
    col_stride_0: u32,
}
|
||||
|
||||
/// Deformable im2col kernel: each thread handles one
/// (in_channel, batch, out_y, out_x) position, samples the input at the
/// offset-shifted location of every kernel element (bilinear interpolation),
/// optionally scales by the modulation mask, and writes one value per kernel
/// element into `columns`.
#[cube(launch)]
fn deform_im2col_kernel<F: Float>(
    input: &Tensor<F>,
    offset: &Tensor<F>,
    mask: &Tensor<F>,
    columns: &mut Tensor<F>,
    args: &DeformConv2dArgs<F>,
    #[comptime] kernel_h_unroll: Option<u32>,
    #[comptime] kernel_w_unroll: Option<u32>,
    #[comptime] use_mask: bool,
) {
    // position shape: [in_channels, batch_size, out_h, out_w]
    // columns shape: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]]

    // Kernel dims are compile-time constants when provided, enabling unrolling.
    let kernel_height = kernel_h_unroll.unwrap_or(args.kernel_height);
    let unroll_h = kernel_h_unroll.is_some();
    let kernel_width = kernel_w_unroll.unwrap_or(args.kernel_width);
    let unroll_w = kernel_w_unroll.is_some();

    // Keep mask in bind group
    let default_mask_value = mask[0];

    let out_h = args.out_h;
    let out_w = args.out_w;
    let batch_size = input.shape(0);
    let in_channels = input.shape(1);
    let height = input.shape(2);
    let width = input.shape(3);
    let col_stride_0 = args.col_stride_0;

    // Decompose the linear thread index into (in_channel, batch, out_y, out_x).
    let out_x = ABSOLUTE_POS % out_w;
    let out_y = (ABSOLUTE_POS / out_w) % out_h;
    let out_batch = (ABSOLUTE_POS / (out_w * out_h)) % batch_size;
    let in_channel = ABSOLUTE_POS / (out_w * out_h * batch_size);
    // First `columns` row written by this thread (one row per kernel element).
    let out_channel = in_channel * kernel_height * kernel_width;

    let channels_per_offset_group = in_channels / args.offset_groups;
    let group_index = in_channel / channels_per_offset_group;

    let mut col_base_idx =
        out_channel * col_stride_0 + out_batch * (out_h * out_w) + out_y * out_w + out_x;

    let input_base_idx = out_batch * input.stride(0) + in_channel * input.stride(1);

    // Offsets hold 2 channels (y, x) per kernel element, per offset group.
    let offset_base_idx = out_batch * offset.stride(0)
        + group_index * kernel_height * kernel_width * 2 * out_h * out_w;
    let mut mask_base_idx = 0;
    if use_mask {
        mask_base_idx =
            out_batch * mask.stride(0) + group_index * kernel_height * kernel_width * out_h * out_w;
    }

    #[unroll(unroll_h)]
    for kernel_y in 0..kernel_height {
        #[unroll(unroll_w)]
        for kernel_x in 0..kernel_width {
            let mask_index = kernel_y * kernel_width + kernel_x;
            // Two offset channels (y then x) per kernel element.
            let offset_index = mask_index * 2;

            let mut mask_value = default_mask_value;
            if use_mask {
                mask_value = mask[mask_base_idx
                    + mask_index * mask.stride(1)
                    + out_y * mask.stride(2)
                    + out_x * mask.stride(3)];
            }

            let offset_y = offset[offset_base_idx
                + offset_index * offset.stride(1)
                + out_y * offset.stride(2)
                + out_x * offset.stride(3)];
            let offset_x = offset[offset_base_idx
                + (offset_index + 1) * offset.stride(1)
                + out_y * offset.stride(2)
                + out_x * offset.stride(3)];
            // Fractional sampling position = regular conv position + learned offset.
            let y = F::cast_from(out_y * args.conv_stride_h + kernel_y * args.dilation_h)
                - args.padding_h
                + offset_y;
            let x = F::cast_from(out_x * args.conv_stride_w + kernel_x * args.dilation_w)
                - args.padding_w
                + offset_x;

            let interpolated = bilinear_interpolate(input, height, width, y, x, input_base_idx);

            columns[col_base_idx] = mask_value * interpolated;
            // Advance to the row for the next kernel element.
            col_base_idx += col_stride_0;
        }
    }
}
|
||||
|
||||
/// Bilinearly interpolate `input` at the fractional position `(y, x)`, reading
/// from the 2D plane of size `height * width` starting at linear index
/// `offset`. Corner samples that fall outside the plane contribute zero, and a
/// position entirely outside `(-1, height) x (-1, width)` yields zero.
#[cube]
pub(crate) fn bilinear_interpolate<F: Float>(
    input: &Tensor<F>,
    height: u32,
    width: u32,
    y: F,
    x: F,
    offset: u32,
) -> F {
    // To simplify code
    let y = f32::cast_from(y);
    let x = f32::cast_from(x);

    let mut result = F::new(0.0);
    if y > -1.0 && height as f32 > y && x > -1.0 && width as f32 > x {
        let in_w = u32::cast_from(width);

        // Integer corners surrounding (y, x).
        let y_low = f32::floor(y);
        let x_low = f32::floor(x);
        let y_high = (y_low + 1.) as u32;
        let x_high = (x_low + 1.) as u32;

        let zero = F::new(0.0);
        // Four neighbouring pixels; out-of-range corners read as zero.
        let v1: F = if y_low >= 0. && x_low >= 0. {
            input[offset + y_low as u32 * in_w + x_low as u32]
        } else {
            zero
        };
        let v2: F = if y_low >= 0. && x_high < width {
            input[offset + y_low as u32 * in_w + x_high]
        } else {
            zero
        };
        let v3: F = if y_high < height && x_low >= 0. {
            input[offset + y_high * in_w + x_low as u32]
        } else {
            zero
        };
        let v4: F = if y_high < height && x_high < width {
            input[offset + y_high * in_w + x_high]
        } else {
            zero
        };

        // Interpolation weights from the fractional parts of (y, x).
        let l_y = y - y_low;
        let l_x = x - x_low;
        let h_y = 1.0 - l_y;
        let h_x = 1.0 - l_x;

        let w1 = F::cast_from(h_y * h_x);
        let w2 = F::cast_from(h_y * l_x);
        let w3 = F::cast_from(l_y * h_x);
        let w4 = F::cast_from(l_y * l_x);

        result = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
    }
    result
}
|
||||
|
||||
/// Build the deformable im2col matrix on the device.
///
/// Returns a tensor of shape
/// `[in_channels * kernel_h * kernel_w, batch_size * out_h * out_w]`.
///
/// When `mask` is `None`, a ones tensor of the matching modulation shape is
/// substituted so the kernel launch signature stays uniform; the `use_mask`
/// comptime flag still selects the unmasked code path inside the kernel.
pub(crate) fn deform_im2col<R: JitRuntime, E: FloatElement>(
    input: JitTensor<R, E, 4>,
    offset: JitTensor<R, E, 4>,
    mask: Option<JitTensor<R, E, 4>>,
    options: DeformConvOptions<2>,
    out_dims: (usize, usize),
    kernel_dims: (usize, usize),
) -> JitTensor<R, E, 2> {
    let client = input.client.clone();
    let device = input.device.clone();

    let [batch_size, in_channels, _, _] = input.shape.dims;
    let (out_height, out_width) = out_dims;
    let (kernel_height, kernel_width) = kernel_dims;

    let shape_out = Shape::new([
        in_channels * kernel_height * kernel_width,
        batch_size * out_height * out_width,
    ]);

    let output = zeros_device(client.clone(), device.clone(), shape_out.clone());
    let use_mask = mask.is_some();
    // Placeholder mask of ones; offset has 2 channels per mask channel.
    let mask = mask.unwrap_or_else(|| {
        ones_device(
            client.clone(),
            device.clone(),
            Shape::new([
                offset.shape.dims[0],
                offset.shape.dims[1] / 2,
                offset.shape.dims[2],
                offset.shape.dims[3],
            ]),
        )
    });

    // One thread per (in_channel, batch, out_y, out_x) position.
    let num_kernels = in_channels * batch_size * out_height * out_width;
    let cube_dim = CubeDim::default();
    let cube_count = calculate_cube_count_elemwise(num_kernels, cube_dim);

    deform_im2col_kernel::launch::<E, R>(
        &input.client,
        cube_count,
        cube_dim,
        input.as_handle_ref().as_tensor_arg(1),
        offset.as_handle_ref().as_tensor_arg(1),
        mask.as_handle_ref().as_tensor_arg(1),
        output.as_handle_ref().as_tensor_arg(1),
        DeformConv2dArgsLaunch::new(
            ScalarArg::new(options.stride[0] as u32),
            ScalarArg::new(options.stride[1] as u32),
            ScalarArg::new(options.dilation[0] as u32),
            ScalarArg::new(options.dilation[1] as u32),
            ScalarArg::new(E::from_elem(options.padding[0] as f32)),
            ScalarArg::new(E::from_elem(options.padding[1] as f32)),
            ScalarArg::new(options.offset_groups as u32),
            ScalarArg::new(kernel_height as u32),
            ScalarArg::new(kernel_width as u32),
            ScalarArg::new(out_height as u32),
            ScalarArg::new(out_width as u32),
            ScalarArg::new(output.strides[0] as u32),
        ),
        // Kernel dims are known on the host, so pass them as comptime values
        // to let the kernel unroll its inner loops.
        Some(kernel_height as u32),
        Some(kernel_width as u32),
        use_mask,
    );

    output
}
|
||||
|
||||
pub(crate) fn deform_conv2d<R: JitRuntime, E: FloatElement, I: IntElement>(
|
||||
input: JitTensor<R, E, 4>,
|
||||
offset: JitTensor<R, E, 4>,
|
||||
weight: JitTensor<R, E, 4>,
|
||||
mask: Option<JitTensor<R, E, 4>>,
|
||||
bias: Option<JitTensor<R, E, 1>>,
|
||||
options: DeformConvOptions<2>,
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let input = into_contiguous(input);
|
||||
let offset = into_contiguous(offset);
|
||||
let weight = into_contiguous(weight);
|
||||
let mask = mask.map(|it| into_contiguous(it));
|
||||
let bias = bias.map(|it| into_contiguous(it));
|
||||
|
||||
let [batch_size, _, in_height, in_width] = input.shape.dims;
|
||||
let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims;
|
||||
let groups = options.weight_groups;
|
||||
|
||||
let out_h = calculate_conv_output_size(
|
||||
kernel_h,
|
||||
options.stride[0],
|
||||
options.padding[0],
|
||||
options.dilation[0],
|
||||
in_height,
|
||||
);
|
||||
let out_w = calculate_conv_output_size(
|
||||
kernel_w,
|
||||
options.stride[1],
|
||||
options.padding[1],
|
||||
options.dilation[1],
|
||||
in_width,
|
||||
);
|
||||
let out_dims = (out_h, out_w);
|
||||
|
||||
let columns = deform_im2col(input, offset, mask, options, out_dims, (kernel_h, kernel_w));
|
||||
|
||||
let [col_size_0, col_size_1] = columns.shape.dims;
|
||||
let col_size_0 = col_size_0 / groups;
|
||||
let out_c_per_group = out_channels / groups;
|
||||
|
||||
let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_size_0]));
|
||||
let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1]));
|
||||
let out = JitBackend::<R, E, I>::float_matmul(weight, columns);
|
||||
|
||||
let out = reshape(out, Shape::new([out_channels, batch_size, out_h, out_w]));
|
||||
let out = swap_dims(out, 0, 1);
|
||||
|
||||
if let Some(bias) = bias {
|
||||
let bias = reshape(bias, Shape::new([1, out_channels, 1, 1]));
|
||||
JitBackend::<R, E, I>::float_add(out, bias)
|
||||
} else {
|
||||
out
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn index<R: JitRuntime, E: FloatElement, I: IntElement>(
|
||||
tensor: JitTensor<R, E, 3>,
|
||||
index: usize,
|
||||
) -> JitTensor<R, E, 2> {
|
||||
let [_, shape_0, shape_1] = tensor.shape.dims;
|
||||
let tensor = JitBackend::<R, E, I>::float_narrow(tensor, 0, index, 1);
|
||||
reshape(tensor, Shape::new([shape_0, shape_1]))
|
||||
}
|
|
@ -0,0 +1,591 @@
|
|||
use burn_tensor::{
|
||||
ops::{DeformConv2dBackward, DeformConvOptions, FloatTensorOps as _},
|
||||
Shape,
|
||||
};
|
||||
use cubecl::{calculate_cube_count_elemwise, cube, prelude::*, CubeDim, CubeLaunch};
|
||||
|
||||
use crate::{
|
||||
kernel::into_contiguous,
|
||||
ops::{
|
||||
numeric::{empty_device, ones_device, zeros_device},
|
||||
reshape, swap_dims,
|
||||
},
|
||||
tensor::JitTensor,
|
||||
FloatElement, IntElement, JitBackend, JitRuntime,
|
||||
};
|
||||
|
||||
use super::{bilinear_interpolate, deform_im2col, index};
|
||||
|
||||
/// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions.
///
/// Produces gradients for the input, offset, weight, and (when present) mask
/// and bias, bundled in a [`DeformConv2dBackward`].
#[allow(clippy::single_range_in_vec_init)]
pub(crate) fn deform_conv2d_backward<R: JitRuntime, E: FloatElement, I: IntElement>(
    input: JitTensor<R, E, 4>,
    offset: JitTensor<R, E, 4>,
    weight: JitTensor<R, E, 4>,
    mask: Option<JitTensor<R, E, 4>>,
    bias: Option<JitTensor<R, E, 1>>,
    out_grad: JitTensor<R, E, 4>,
    options: DeformConvOptions<2>,
) -> DeformConv2dBackward<JitBackend<R, E, I>> {
    let [_, _, out_h, out_w] = out_grad.shape.dims;
    let [_, _, kernel_h, kernel_w] = weight.shape.dims;

    // d(out)/d(bias) = 1, so the bias gradient is out_grad summed over the
    // batch and spatial dimensions (0, 2, 3), reshaped back to the bias shape.
    let gradient_bias = bias.map(|bias| {
        let grad = JitBackend::<R, E, I>::float_sum_dim(out_grad.clone(), 0);
        let grad = JitBackend::<R, E, I>::float_sum_dim(grad, 2);
        let grad = JitBackend::<R, E, I>::float_sum_dim(grad, 3);

        reshape(grad, bias.shape)
    });

    // The backward kernels index raw buffers, so inputs must be contiguous.
    let input = into_contiguous(input);
    let offset = into_contiguous(offset);
    let mask = mask.map(|it| into_contiguous(it));

    // Gradients w.r.t. input, offset, and mask share the weight^T * out_grad
    // column buffer, so they are computed together.
    let (input_gradient, offset_gradient, mask_gradient) = backward_gradient_inputs::<R, E, I>(
        input.clone(),
        weight.clone(),
        offset.clone(),
        mask.clone(),
        out_grad.clone(),
        &options,
        (kernel_h, kernel_w),
    );

    let weight_grad = compute_weight_grad::<R, E, I>(
        input,
        offset,
        mask,
        out_grad,
        options,
        (kernel_h, kernel_w),
        (out_h, out_w),
    );

    DeformConv2dBackward::new(
        input_gradient,
        offset_gradient,
        weight_grad,
        mask_gradient,
        gradient_bias,
    )
}
|
||||
|
||||
fn compute_weight_grad<R: JitRuntime, E: FloatElement, I: IntElement>(
|
||||
input: JitTensor<R, E, 4>,
|
||||
offset: JitTensor<R, E, 4>,
|
||||
mask: Option<JitTensor<R, E, 4>>,
|
||||
out_grad: JitTensor<R, E, 4>,
|
||||
options: DeformConvOptions<2>,
|
||||
kernel_dims: (usize, usize),
|
||||
out_dims: (usize, usize),
|
||||
) -> JitTensor<R, E, 4> {
|
||||
let [_, in_channels, _, _] = input.shape.dims;
|
||||
let [_, out_channels, _, _] = out_grad.shape.dims;
|
||||
let (kernel_h, kernel_w) = kernel_dims;
|
||||
let groups = options.weight_groups;
|
||||
|
||||
let in_c_per_group = in_channels / groups;
|
||||
let out_c_per_group = out_channels / groups;
|
||||
|
||||
let columns = deform_im2col(input, offset, mask, options, out_dims, kernel_dims);
|
||||
let [col_size_0, col_size_1] = columns.shape.dims;
|
||||
let col_size_0 = col_size_0 / groups;
|
||||
|
||||
let out_grad = swap_dims(out_grad, 0, 1);
|
||||
let out_grad = reshape(out_grad, Shape::new([groups, out_c_per_group, col_size_1]));
|
||||
|
||||
let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1]));
|
||||
let columns = swap_dims(columns, 1, 2);
|
||||
|
||||
let grad_weight = JitBackend::<R, E, I>::float_matmul(out_grad, columns);
|
||||
|
||||
JitBackend::<R, E, I>::float_reshape(
|
||||
grad_weight,
|
||||
Shape::new([out_channels, in_c_per_group, kernel_h, kernel_w]),
|
||||
)
|
||||
}
|
||||
|
||||
/// Bundle returned by [`backward_gradient_inputs`]:
/// (input gradient, offset gradient, optional mask gradient).
type InputGradients<R, E> = (
    JitTensor<R, E, 4>,
    JitTensor<R, E, 4>,
    Option<JitTensor<R, E, 4>>,
);
|
||||
|
||||
/// Compute the gradients w.r.t. the image, offsets, and (optional) mask.
///
/// First builds the gradient column buffer `columns = weight^T x out_grad`
/// per weight group, then scatters it back through the deformable sampling:
/// once via [`compute_offset_and_mask_gradient`] and once via
/// [`compute_input_grad`].
fn backward_gradient_inputs<R: JitRuntime, E: FloatElement, I: IntElement>(
    image: JitTensor<R, E, 4>,
    weight: JitTensor<R, E, 4>,
    offset: JitTensor<R, E, 4>,
    mask: Option<JitTensor<R, E, 4>>,
    out_grad: JitTensor<R, E, 4>,
    options: &DeformConvOptions<2>,
    kernel_dims: (usize, usize),
) -> InputGradients<R, E> {
    let client = out_grad.client.clone();
    let device = out_grad.device.clone();

    let [out_channels, in_c_per_group, kernel_h, kernel_w] = weight.shape.dims;
    let [batch_size, _, out_h, out_w] = out_grad.shape.dims;

    let groups = options.weight_groups;
    let out_c_per_group = out_channels / groups;

    // Column buffer layout: [groups, in_c/g * kh * kw, batch * out_h * out_w].
    let col_shape_0 = in_c_per_group * kernel_h * kernel_w;
    let col_shape_1 = batch_size * out_h * out_w;
    let col_shape = Shape::new([groups, col_shape_0, col_shape_1]);
    // Every slice is overwritten by the loop below, so empty is safe here.
    let mut columns = empty_device(client, device, col_shape);

    let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_shape_0]));

    // Channels first so out_grad flattens to [groups, out_c/g, b*oh*ow].
    let out_grad = swap_dims(out_grad, 0, 1);
    let out_grad_shape = Shape::new([groups, out_c_per_group, col_shape_1]);
    let out_grad = reshape(out_grad, out_grad_shape);

    // Per group: columns[g] = weight[g]^T x out_grad[g].
    for group in 0..groups {
        let weight = swap_dims(index::<R, E, I>(weight.clone(), group), 0, 1);
        let out_grad = index::<R, E, I>(out_grad.clone(), group);
        let values = JitBackend::<R, E, I>::float_matmul(weight, out_grad);
        let values = reshape(values, Shape::new([1, col_shape_0, col_shape_1]));
        columns = JitBackend::<R, E, I>::float_slice_assign(
            columns,
            [group..group + 1, 0..col_shape_0, 0..col_shape_1],
            values,
        );
    }

    // Collapse the group dimension; downstream kernels index the flat buffer.
    let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1]));

    let input_shape = image.shape.clone();
    let (offset_gradient, mask_gradient) = compute_offset_and_mask_gradient::<R, E>(
        columns.clone(),
        image,
        offset.clone(),
        mask.clone(),
        options,
        kernel_dims,
    );

    let input_gradient =
        compute_input_grad::<R, E>(columns, offset, mask, options, kernel_dims, input_shape);

    (input_gradient, offset_gradient, mask_gradient)
}
|
||||
|
||||
/// Compute the gradients w.r.t. the sampling offsets and (optional) mask by
/// launching [`deform_col2img_coord_kernel`] with one thread per offset
/// element.
///
/// Returns `(grad_offset, Some(grad_mask))` when a mask was supplied, and
/// `(grad_offset, None)` otherwise.
fn compute_offset_and_mask_gradient<R: JitRuntime, E: FloatElement>(
    columns: JitTensor<R, E, 2>,
    image: JitTensor<R, E, 4>,
    offset: JitTensor<R, E, 4>,
    mask: Option<JitTensor<R, E, 4>>,
    options: &DeformConvOptions<2>,
    kernel_dims: (usize, usize),
) -> (JitTensor<R, E, 4>, Option<JitTensor<R, E, 4>>) {
    let client = offset.client.clone();
    let device = offset.device.clone();
    let (kernel_height, kernel_width) = kernel_dims;

    let use_mask = mask.is_some();

    // Placeholder mask of ones so the launch signature stays uniform;
    // offset has 2 channels per mask channel.
    let mask = mask.unwrap_or_else(|| {
        ones_device(
            client.clone(),
            device.clone(),
            Shape::new([
                offset.shape.dims[0],
                offset.shape.dims[1] / 2,
                offset.shape.dims[2],
                offset.shape.dims[3],
            ]),
        )
    });

    // Every element is written by the kernel, so empty buffers are safe.
    let grad_offset = empty_device(client.clone(), device.clone(), offset.shape.clone());
    let grad_mask = empty_device(client.clone(), device.clone(), mask.shape.clone());

    // One thread per offset element (a y- or x-displacement scalar).
    let num_elements_offset = offset.shape.num_elements();
    let cube_dim = CubeDim::default();
    let cube_count = calculate_cube_count_elemwise(num_elements_offset, cube_dim);

    deform_col2img_coord_kernel::launch::<E, R>(
        &image.client,
        cube_count,
        cube_dim,
        image.as_handle_ref().as_tensor_arg(1),
        offset.as_handle_ref().as_tensor_arg(1),
        mask.as_handle_ref().as_tensor_arg(1),
        columns.as_handle_ref().as_tensor_arg(1),
        grad_offset.as_handle_ref().as_tensor_arg(1),
        grad_mask.as_handle_ref().as_tensor_arg(1),
        DeformConv2dCol2ImgCoordArgsLaunch::new(
            ScalarArg::new(options.stride[0] as u32),
            ScalarArg::new(options.stride[1] as u32),
            ScalarArg::new(options.dilation[0] as u32),
            ScalarArg::new(options.dilation[1] as u32),
            ScalarArg::new(E::from_elem(options.padding[0] as f32)),
            ScalarArg::new(E::from_elem(options.padding[1] as f32)),
            ScalarArg::new(options.offset_groups as u32),
            ScalarArg::new(kernel_height as u32),
            ScalarArg::new(kernel_width as u32),
        ),
        use_mask,
    );

    let mask_gradient = if use_mask { Some(grad_mask) } else { None };
    (grad_offset, mask_gradient)
}
|
||||
|
||||
/// Scalar launch arguments for [`deform_col2img_coord_kernel`].
#[derive(CubeLaunch)]
struct DeformConv2dCol2ImgCoordArgs<F: Float> {
    // Convolution stride along height / width.
    stride_h: u32,
    stride_w: u32,
    // Kernel dilation along height / width.
    dilation_h: u32,
    dilation_w: u32,
    // Padding as float, subtracted directly from fractional coordinates.
    pad_h: F,
    pad_w: F,
    // Number of offset groups.
    offset_groups: u32,
    kernel_height: u32,
    kernel_width: u32,
}
|
||||
|
||||
/// Backward kernel for the sampling coordinates: one thread per offset
/// element accumulates the gradient of the loss w.r.t. that (y or x)
/// displacement, and — for y-direction threads when a mask is used — the
/// gradient w.r.t. the corresponding mask element.
#[allow(clippy::collapsible_if)]
#[cube(launch)]
fn deform_col2img_coord_kernel<F: Float>(
    image: &Tensor<F>,
    offset: &Tensor<F>,
    mask: &Tensor<F>,
    columns: &Tensor<F>,
    grad_offset: &mut Tensor<F>,
    grad_mask: &mut Tensor<F>,
    args: &DeformConv2dCol2ImgCoordArgs<F>,
    #[comptime] use_mask: bool,
) {
    // Position format: [batch, [offset_group, kernel_h, kernel_w, 2], out_h, out_w]
    // Alternatively : [batch, offset_channels, out_h, out_w]

    let offset_channels = offset.shape(1);
    let out_h = offset.shape(2);
    let out_w = offset.shape(3);
    let batch_size = image.shape(0);
    let in_channels = image.shape(1);
    let height = image.shape(2);
    let width = image.shape(3);
    let kernel_w = args.kernel_width;
    let kernel_h = args.kernel_height;
    let n_offset_groups = args.offset_groups;
    let _ = mask[0]; // Make sure mask isn't removed from bind group

    let mut grad_offset_val = F::new(0.0);
    let mut grad_mask_val = F::new(0.0);

    // Decompose the linear thread index into the offset-element coordinates.
    let w = ABSOLUTE_POS % out_w;
    let h = (ABSOLUTE_POS / out_w) % out_h;
    let w_w = (ABSOLUTE_POS / (out_w * out_h * 2)) % kernel_w;
    let w_h = (ABSOLUTE_POS / (out_w * out_h * 2 * kernel_w)) % kernel_h;
    let c = (ABSOLUTE_POS / (out_w * out_h)) % offset_channels;
    let b = ABSOLUTE_POS / (out_w * out_h * offset_channels);

    let offset_group = c / (kernel_h * kernel_w * 2);
    let col_step = kernel_h * kernel_w;

    let channels_per_offset_group = in_channels / args.offset_groups;

    // Base indices into the flat buffers for this batch / offset group.
    let col_base_idx =
        offset_group * channels_per_offset_group * kernel_h * kernel_w * batch_size * out_w * out_h;
    let mut image_base_idx =
        (b * n_offset_groups + offset_group) * channels_per_offset_group * height * width;
    let offset_base_idx =
        (b * n_offset_groups + offset_group) * 2 * kernel_h * kernel_w * out_h * out_w;
    let mask_base_idx = (b * n_offset_groups + offset_group) * kernel_h * kernel_w * out_h * out_w;

    // Offset channel within the group; even channels are y, odd are x.
    let offset_c = c - offset_group * 2 * kernel_h * kernel_w;
    let is_y_direction = offset_c % 2 == 0;

    let c_bound = channels_per_offset_group * kernel_h * kernel_w;

    // One offset element influences every input channel of its group, so sum
    // contributions across channels (stepping by kernel_h * kernel_w).
    for col_c in range_stepped(offset_c / 2, c_bound, col_step) {
        let col_pos = (((col_c * batch_size + b) * out_h) + h) * out_w + w;

        // Recover the column position's coordinates.
        let out_x = col_pos % out_w;
        let out_y = (col_pos / out_w) % out_h;
        let j = (col_pos / (out_w * out_h * batch_size)) % kernel_w;
        let i = (col_pos / (out_w * out_h * batch_size * kernel_w)) % kernel_h;

        let mask_idx = i * kernel_w + j;
        let offset_idx = mask_idx * 2;

        let offset_y_idx = (offset_idx * out_h + out_y) * out_w + out_x;
        let offset_x_idx = ((offset_idx + 1) * out_h + out_y) * out_w + out_x;

        let offset_y = offset[offset_base_idx + offset_y_idx];
        let offset_x = offset[offset_base_idx + offset_x_idx];

        let mask_value = if use_mask {
            mask[mask_base_idx + (mask_idx * out_h + out_y) * out_w + out_x]
        } else {
            F::new(1.0)
        };

        // Fractional sampling position this offset element produced.
        let y = F::cast_from(out_y * args.stride_h + i * args.dilation_h) - args.pad_h + offset_y;
        let x = F::cast_from(out_x * args.stride_w + j * args.dilation_w) - args.pad_w + offset_x;

        // d(interpolated value)/d(coordinate) for this direction.
        let weight = get_coordinate_weight(
            image.slice(image_base_idx, image.len()),
            height,
            width,
            y,
            x,
            is_y_direction,
        );

        grad_offset_val += mask_value * weight * columns[col_base_idx + col_pos];

        // d(out)/d(mask) is the interpolated value itself; accumulate it only
        // once per mask element (on the y-direction thread).
        if use_mask {
            if is_y_direction {
                grad_mask_val += columns[col_base_idx + col_pos]
                    * bilinear_interpolate(image, height, width, y, x, image_base_idx);
            }
        }

        // Advance to the next input channel's plane.
        image_base_idx += height * width;
    }

    grad_offset[ABSOLUTE_POS] = grad_offset_val;

    if use_mask {
        if is_y_direction {
            let idx = ((((b * n_offset_groups + offset_group) * kernel_h + w_h) * kernel_w + w_w)
                * out_h
                + h)
                * out_w
                + w;
            grad_mask[idx] = grad_mask_val
        }
    }
}
|
||||
|
||||
/// Partial derivative of the bilinear interpolation at `(y, x)` with respect
/// to one coordinate (`y` when `is_y_direction`, else `x`), reading from a 2D
/// plane of `height * width` elements. Out-of-bounds corners contribute zero.
#[cube]
fn get_coordinate_weight<F: Float>(
    input: &Slice<'_, F>,
    height: u32,
    width: u32,
    y: F,
    x: F,
    is_y_direction: bool,
) -> F {
    let stride_y = width;

    let y = f32::cast_from(y);
    let x = f32::cast_from(x);

    // The four integer corners surrounding (y, x).
    let y_low = f32::floor(y);
    let x_low = f32::floor(x);
    let y_high = y_low + 1.;
    let x_high = x_low + 1.;

    let valid_y_low = y_low >= 0. && y_low < height as f32;
    let valid_y_high = y_high >= 0. && y_high < height as f32;
    let valid_x_low = x_low >= 0. && x_low < width as f32;
    let valid_x_high = x_high >= 0. && x_high < width as f32;

    // Corner samples; invalid corners read as zero.
    let bottom_left = if valid_y_low && valid_x_low {
        input[y_low as u32 * stride_y + x_low as u32]
    } else {
        F::new(0.0)
    };
    let bottom_right = if valid_y_low && valid_x_high {
        input[y_low as u32 * stride_y + x_high as u32]
    } else {
        F::new(0.0)
    };
    let top_left = if valid_y_high && valid_x_low {
        input[y_high as u32 * stride_y + x_low as u32]
    } else {
        F::new(0.0)
    };
    let top_right = if valid_y_high && valid_x_high {
        input[y_high as u32 * stride_y + x_high as u32]
    } else {
        F::new(0.0)
    };

    // The derivative of bilinear interpolation along one axis is the
    // difference of the corner pairs, blended by the other axis' fraction.
    if is_y_direction {
        let delta_x = F::cast_from(x - x_low);
        delta_x * (top_right - bottom_right) + (F::new(1.0) - delta_x) * (top_left - bottom_left)
    } else {
        let delta_y = F::cast_from(y - y_low);
        delta_y * (top_right - top_left) + (F::new(1.0) - delta_y) * (bottom_right - bottom_left)
    }
}
|
||||
|
||||
/// Compute the gradient w.r.t. the input image by scattering the gradient
/// column buffer back through the deformable sampling
/// ([`deform_col2img_kernel`], one thread per column element).
///
/// The gradient buffer is zero-initialized because the kernel accumulates
/// into it with atomic adds.
fn compute_input_grad<R: JitRuntime, E: FloatElement>(
    columns: JitTensor<R, E, 2>,
    offset: JitTensor<R, E, 4>,
    mask: Option<JitTensor<R, E, 4>>,
    options: &DeformConvOptions<2>,
    kernel_dims: (usize, usize),
    input_shape: Shape<4>,
) -> JitTensor<R, E, 4> {
    let client = offset.client.clone();
    let device = offset.device.clone();

    let [batch_size, in_channels, height, width] = input_shape.dims;
    let (kernel_height, kernel_width) = kernel_dims;

    let grad_in = zeros_device::<R, E, 4>(
        client.clone(),
        device.clone(),
        Shape::new([batch_size, in_channels, height, width]),
    );

    let use_mask = mask.is_some();
    // Minimal placeholder mask; the kernel only touches mask[0] when unused.
    let mask = mask
        .unwrap_or_else(|| ones_device(client.clone(), device.clone(), Shape::new([1, 1, 1, 1])));

    let num_elements = columns.shape.num_elements();
    let cube_dim = CubeDim::default();
    let cube_count = calculate_cube_count_elemwise(num_elements, cube_dim);

    // NOTE(review): grad_in is an E-typed buffer but the kernel binds it as
    // Tensor<AtomicU32> and accumulates via bitcast CAS — this presumably
    // relies on E being a 32-bit float; confirm for other float widths.
    deform_col2img_kernel::launch::<E, R>(
        &offset.client,
        cube_count,
        cube_dim,
        offset.as_tensor_arg(1),
        mask.as_tensor_arg(1),
        columns.as_tensor_arg(1),
        grad_in.as_tensor_arg(1),
        DeformConv2dCol2ImgArgsLaunch::new(
            ScalarArg::new(options.stride[0] as u32),
            ScalarArg::new(options.stride[1] as u32),
            ScalarArg::new(options.dilation[0] as u32),
            ScalarArg::new(options.dilation[1] as u32),
            ScalarArg::new(options.padding[0] as f32),
            ScalarArg::new(options.padding[1] as f32),
            ScalarArg::new(options.offset_groups as u32),
            ScalarArg::new(batch_size as u32),
            ScalarArg::new(in_channels as u32),
            ScalarArg::new(height as u32),
            ScalarArg::new(width as u32),
            ScalarArg::new(kernel_height as u32),
            ScalarArg::new(kernel_width as u32),
        ),
        use_mask,
    );

    grad_in
}
|
||||
|
||||
/// Scalar launch arguments for [`deform_col2img_kernel`].
#[derive(CubeLaunch)]
struct DeformConv2dCol2ImgArgs {
    // Convolution stride along height / width.
    stride_h: u32,
    stride_w: u32,
    // Kernel dilation along height / width.
    dilation_h: u32,
    dilation_w: u32,
    // Padding as f32, subtracted directly from fractional coordinates.
    pad_h: f32,
    pad_w: f32,
    // Number of offset groups.
    offset_groups: u32,
    // Input-image dimensions (the gradient target).
    batch_size: u32,
    in_channels: u32,
    height: u32,
    width: u32,
    kernel_height: u32,
    kernel_width: u32,
}
|
||||
|
||||
/// Backward kernel for the input image: one thread per column element
/// redistributes its gradient to the (up to four) input pixels that
/// contributed to its bilinear sample, using atomic float adds since many
/// column elements map to the same pixel.
#[cube(launch)]
fn deform_col2img_kernel<F: Float>(
    offset: &Tensor<F>,
    mask: &Tensor<F>,
    columns: &Tensor<F>,
    grad_input: &mut Tensor<AtomicU32>,
    args: &DeformConv2dCol2ImgArgs,
    #[comptime] use_mask: bool,
) {
    // Position format: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]]
    let _ = mask[0]; // Keep mask in bind group

    let n_in_channels = args.in_channels;
    let height = args.height;
    let width = args.width;
    let out_h = offset.shape(2);
    let out_w = offset.shape(3);
    let kernel_h = args.kernel_height;
    let kernel_w = args.kernel_width;
    let n_offset_groups = args.offset_groups;
    let batch_size = args.batch_size;

    // Decompose the linear thread index into the column-element coordinates.
    let out_x = ABSOLUTE_POS % out_w;
    let out_y = (ABSOLUTE_POS / out_w) % out_h;
    let batch = (ABSOLUTE_POS / (out_w * out_h)) % batch_size;
    let kernel_x = (ABSOLUTE_POS / (out_w * out_h * batch_size)) % kernel_w;
    let kernel_y = (ABSOLUTE_POS / (out_w * out_h * batch_size * kernel_w)) % kernel_h;
    let in_channel = ABSOLUTE_POS / (out_w * out_h * batch_size * kernel_w * kernel_h);

    let channels_per_offset_group = n_in_channels / n_offset_groups;
    let offset_group = in_channel / channels_per_offset_group;

    let offset_base_idx =
        (batch * n_offset_groups + offset_group) * 2 * kernel_h * kernel_w * out_h * out_w;

    // Two offset channels (y then x) per kernel element.
    let mask_idx = kernel_y * kernel_w + kernel_x;
    let offset_idx = mask_idx * 2;

    let offset_y_idx = (offset_idx * out_h + out_y) * out_w + out_x;
    let offset_x_idx = ((offset_idx + 1) * out_h + out_y) * out_w + out_x;

    let offset_y = f32::cast_from(offset[offset_base_idx + offset_y_idx]);
    let offset_x = f32::cast_from(offset[offset_base_idx + offset_x_idx]);

    let mask_value = if use_mask {
        let mask_base_idx =
            (batch * n_offset_groups + offset_group) * kernel_h * kernel_w * out_h * out_w;

        mask[mask_base_idx + (mask_idx * out_h + out_y) * out_w + out_x]
    } else {
        F::new(1.0)
    };

    // Fractional position this column element sampled from.
    let y =
        f32::cast_from(out_y * args.stride_h + kernel_y * args.dilation_h) - args.pad_h + offset_y;
    let x =
        f32::cast_from(out_x * args.stride_w + kernel_x * args.dilation_w) - args.pad_w + offset_x;

    // Visit the integer pixels around (y, x); only those within a distance
    // of 1 in both axes received bilinear weight and get gradient back.
    for dy in -1..=1 {
        #[unroll]
        for dx in -1..=1 {
            let yp = f32::floor(y) + dy as f32;
            let xp = f32::floor(x) + dx as f32;

            if yp >= 0.0
                && yp < height as f32
                && xp >= 0.0
                && xp < width as f32
                && f32::abs(y - yp) < 1.0
                && f32::abs(x - xp) < 1.0
            {
                let gradient_pos =
                    ((batch * n_in_channels + in_channel) * height + yp as u32) * width + xp as u32;

                // Bilinear weight of pixel (yp, xp) for the sample at (y, x).
                let weight = (1.0 - f32::abs(y - yp)) * (1.0 - f32::abs(x - xp));

                let value = mask_value * F::cast_from(weight) * columns[ABSOLUTE_POS];

                // Many threads may target the same pixel — accumulate atomically.
                float_atomic_add::<F>(&mut grad_input[gradient_pos], value);
            }
        }
    }
}
|
||||
|
||||
/// Atomically adds `value` to the float stored (bit-cast) in `ptr`.
///
/// The backend only exposes integer atomics, so this is a compare-and-swap
/// loop: load the current bits, add in float space, CAS the new bit pattern
/// back, and retry if another invocation raced in between. Zero values are
/// skipped to avoid needless atomic traffic.
#[cube]
fn float_atomic_add<F: Float>(ptr: &mut AtomicU32, value: F) {
    if value != F::new(0.0) {
        let mut v = AtomicU32::load(ptr);
        loop {
            let prev = v;
            let v_float = F::bitcast_from(v);
            let new = u32::bitcast_from(v_float + value);
            v = AtomicU32::compare_and_swap(ptr, v, new);
            if prev == v {
                // CAS succeeded: the value was unchanged between load and swap.
                break;
            }
        }
    }
}
|
|
@ -2,8 +2,12 @@ mod conv2d;
|
|||
mod conv3d;
|
||||
mod conv_transpose2d;
|
||||
mod conv_transpose3d;
|
||||
mod deform_conv2d;
|
||||
mod deform_conv_transpose2d;
|
||||
|
||||
pub(crate) use conv2d::*;
|
||||
pub(crate) use conv3d::*;
|
||||
pub(crate) use conv_transpose2d::*;
|
||||
pub(crate) use conv_transpose3d::*;
|
||||
pub(crate) use deform_conv2d::*;
|
||||
pub(crate) use deform_conv_transpose2d::*;
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{kernel, FloatElement, IntElement, JitBackend, JitRuntime};
|
||||
use burn_tensor::ops::{
|
||||
ConvOptions, ConvTransposeOptions, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices,
|
||||
ModuleOps,
|
||||
ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateOptions,
|
||||
MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
|
||||
};
|
||||
use burn_tensor::ops::{FloatTensor, IntTensor};
|
||||
|
||||
|
@ -20,6 +20,29 @@ where
|
|||
kernel::conv::conv2d(x, weight, bias, options)
|
||||
}
|
||||
|
||||
    /// Deformable 2D convolution (forward), dispatched to the JIT kernel.
    fn deform_conv2d(
        x: FloatTensor<Self, 4>,
        offset: FloatTensor<Self, 4>,
        weight: FloatTensor<Self, 4>,
        mask: Option<FloatTensor<Self, 4>>,
        bias: Option<FloatTensor<Self, 1>>,
        options: DeformConvOptions<2>,
    ) -> FloatTensor<Self, 4> {
        kernel::conv::deform_conv2d::<R, F, I>(x, offset, weight, mask, bias, options)
    }
|
||||
|
||||
    /// Backward pass of [`deform_conv2d`](Self::deform_conv2d), dispatched to
    /// the JIT kernel; returns gradients for all forward inputs.
    fn deform_conv2d_backward(
        x: FloatTensor<Self, 4>,
        offset: FloatTensor<Self, 4>,
        weight: FloatTensor<Self, 4>,
        mask: Option<FloatTensor<Self, 4>>,
        bias: Option<FloatTensor<Self, 1>>,
        output_grad: FloatTensor<Self, 4>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        kernel::conv::deform_conv2d_backward(x, offset, weight, mask, bias, output_grad, options)
    }
|
||||
|
||||
fn conv3d(
|
||||
x: FloatTensor<Self, 5>,
|
||||
weight: FloatTensor<Self, 5>,
|
||||
|
|
|
@ -12,6 +12,7 @@ version.workspace = true
|
|||
|
||||
[features]
|
||||
default = ["std"]
|
||||
doc = ["default"]
|
||||
std = [
|
||||
"burn-autodiff",
|
||||
"burn-common/std",
|
||||
|
@ -24,7 +25,6 @@ std = [
|
|||
"rand/std",
|
||||
"num-traits/std",
|
||||
]
|
||||
doc = ["default"]
|
||||
|
||||
blas-accelerate = [
|
||||
"blas-src/accelerate", # Accelerate framework (macOS only)
|
||||
|
@ -46,10 +46,11 @@ burn-autodiff = { path = "../burn-autodiff", version = "0.15.0", optional = true
|
|||
burn-common = { path = "../burn-common", version = "0.15.0", default-features = false }
|
||||
burn-tensor = { path = "../burn-tensor", version = "0.15.0", default-features = false }
|
||||
|
||||
matrixmultiply = { workspace = true, default-features = false }
|
||||
atomic_float = { workspace = true }
|
||||
blas-src = { workspace = true, default-features = false, optional = true } # no-std compatible
|
||||
derive-new = { workspace = true }
|
||||
libm = { workspace = true }
|
||||
matrixmultiply = { workspace = true, default-features = false }
|
||||
ndarray = { workspace = true }
|
||||
num-traits = { workspace = true }
|
||||
openblas-src = { workspace = true, optional = true }
|
||||
|
|
|
@ -0,0 +1,656 @@
|
|||
use burn_common::{iter_par, run_par};
|
||||
use burn_tensor::ops::{conv::calculate_conv_output_size, DeformConvOptions};
|
||||
use core::ops::AddAssign;
|
||||
use ndarray::{
|
||||
s, Array2, Array4, ArrayView2, ArrayView3, ArrayView4, ArrayView6, ArrayViewMut2, Axis, Dim,
|
||||
Ix4,
|
||||
};
|
||||
#[cfg(not(feature = "std"))]
|
||||
use num_traits::Float;
|
||||
|
||||
use crate::{element::QuantElement, FloatNdArrayElement, NdArrayTensor};
|
||||
|
||||
use super::matmul::matmul;
|
||||
|
||||
/// Fills one im2col column for a single (in_channel, batch, out_y, out_x)
/// position: for every kernel tap, bilinearly samples the input at the
/// offset-shifted location and writes `mask * sample` into `columns`.
#[inline(always)]
#[allow(clippy::too_many_arguments)]
fn deform_im2col_kernel<F: FloatNdArrayElement>(
    out_y: usize,
    out_x: usize,
    input: ArrayView2<F>,
    offset: ArrayView3<F>,
    mask: Option<ArrayView2<F>>,
    mut columns: ArrayViewMut2<F>,
    args: DeformConvOptions<2>,
    (kernel_h, kernel_w): (usize, usize),
) {
    // position shape: [in_channels, batch_size, out_h, out_w]
    // columns shape: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]]

    let (height, width) = input.dim();

    for kernel_y in 0..kernel_h {
        for kernel_x in 0..kernel_w {
            // Modulation scalar; 1.0 when running without a mask.
            let mask_value = mask
                .map(|it| it[[kernel_y, kernel_x]])
                .unwrap_or_else(|| F::from_elem(1.0));

            // Learned (y, x) offset for this kernel tap.
            let offset = offset.slice(s![kernel_y, kernel_x, ..]);
            // Sampling coordinates = regular conv tap position + learned offset.
            let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0])
                - F::from_elem(args.padding[0])
                + offset[0];
            let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1])
                - F::from_elem(args.padding[1])
                + offset[1];

            let interpolated = bilinear_interpolate(input, height, width, y, x);

            columns[[kernel_y, kernel_x]] = mask_value * interpolated;
        }
    }
}
|
||||
|
||||
fn bilinear_interpolate<F: FloatNdArrayElement>(
|
||||
input: ArrayView2<F>,
|
||||
height: usize,
|
||||
width: usize,
|
||||
y: F,
|
||||
x: F,
|
||||
) -> F {
|
||||
// To simplify code
|
||||
let y = y.to_f32();
|
||||
let x = x.to_f32();
|
||||
|
||||
let mut result = F::from_elem(0.0);
|
||||
if y > -1.0 && height as f32 > y && x > -1.0 && width as f32 > x {
|
||||
let y_low = f32::floor(y);
|
||||
let x_low = f32::floor(x);
|
||||
let y_high = (y_low + 1.) as usize;
|
||||
let x_high = (x_low + 1.) as usize;
|
||||
|
||||
let zero = F::from_elem(0.0);
|
||||
let v1: F = if y_low >= 0. && x_low >= 0. {
|
||||
input[[y_low as usize, x_low as usize]]
|
||||
} else {
|
||||
zero
|
||||
};
|
||||
let v2: F = if y_low >= 0. && x_high < width {
|
||||
input[[y_low as usize, x_high]]
|
||||
} else {
|
||||
zero
|
||||
};
|
||||
let v3: F = if y_high < height && x_low >= 0. {
|
||||
input[[y_high, x_low as usize]]
|
||||
} else {
|
||||
zero
|
||||
};
|
||||
let v4: F = if y_high < height && x_high < width {
|
||||
input[[y_high, x_high]]
|
||||
} else {
|
||||
zero
|
||||
};
|
||||
|
||||
let l_y = y - y_low;
|
||||
let l_x = x - x_low;
|
||||
let h_y = 1.0 - l_y;
|
||||
let h_x = 1.0 - l_x;
|
||||
|
||||
let w1 = F::from_elem(h_y * h_x);
|
||||
let w2 = F::from_elem(h_y * l_x);
|
||||
let w3 = F::from_elem(l_y * h_x);
|
||||
let w4 = F::from_elem(l_y * l_x);
|
||||
|
||||
result = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4;
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// Deformable 2D convolution (forward), matching `torchvision.ops.deform_conv2d`.
///
/// Implemented as im2col + grouped matmul: [`deform_im2col`] gathers the
/// offset-shifted, mask-modulated input samples into a column matrix, which
/// is multiplied by the reshaped weights, reshaped back to NCHW, and biased.
pub(crate) fn deform_conv2d<F: FloatNdArrayElement>(
    input: NdArrayTensor<F, 4>,
    offset: NdArrayTensor<F, 4>,
    weight: NdArrayTensor<F, 4>,
    mask: Option<NdArrayTensor<F, 4>>,
    bias: Option<NdArrayTensor<F, 1>>,
    args: DeformConvOptions<2>,
) -> NdArrayTensor<F, 4> {
    let [batch_size, _, in_height, in_width] = input.shape().dims;
    let [out_channels, _, kernel_h, kernel_w] = weight.shape().dims;
    let groups = args.weight_groups;

    // The reshape + matmul below need a contiguous weight layout.
    let weight = weight.array.as_standard_layout();

    let out_h = calculate_conv_output_size(
        kernel_h,
        args.stride[0],
        args.padding[0],
        args.dilation[0],
        in_height,
    );
    let out_w = calculate_conv_output_size(
        kernel_w,
        args.stride[1],
        args.padding[1],
        args.dilation[1],
        in_width,
    );
    let out_dims = (out_h, out_w);

    let input = input.array.into_dimensionality::<Ix4>().unwrap();
    let offset = offset.array.into_dimensionality::<Ix4>().unwrap();
    // View the mask as [batch, offset_group, kernel_h, kernel_w, out_h, out_w].
    let mask = mask.as_ref().map(|it| {
        it.array
            .to_shape((
                batch_size,
                args.offset_groups,
                kernel_h,
                kernel_w,
                out_h,
                out_w,
            ))
            .unwrap()
    });

    // columns: [in_channels * kernel_h * kernel_w, batch_size * out_h * out_w]
    let columns = deform_im2col(
        input.view(),
        offset.view(),
        mask.as_ref().map(|it| it.view()),
        args,
        out_dims,
        (kernel_h, kernel_w),
    );

    let (col_size_0, col_size_1) = columns.dim();
    let col_size_0 = col_size_0 / groups;
    let out_c_per_group = out_channels / groups;

    // Grouped matmul: [groups, out_c/g, k] x [groups, k, B*H*W].
    let weight = weight
        .to_shape((groups, out_c_per_group, col_size_0))
        .unwrap();
    let columns = columns.to_shape((groups, col_size_0, col_size_1)).unwrap();
    let out = matmul(
        NdArrayTensor::<_, 3>::new(weight.to_owned().into_dyn().into_shared()),
        NdArrayTensor::<_, 3>::new(columns.to_owned().into_dyn().into_shared()),
    );

    // Result is [out_channels, batch, out_h, out_w]; swap to NCHW.
    let mut out = out
        .array
        .into_shape_with_order((out_channels, batch_size, out_h, out_w))
        .unwrap();
    out.swap_axes(0, 1);

    if let Some(bias) = bias {
        // Broadcast the bias over batch and spatial dimensions.
        let bias = bias.array.to_shape((1, out_channels, 1, 1)).unwrap();
        out.add_assign(&bias);
    }

    NdArrayTensor::new(out.into_dyn().into_shared())
}
|
||||
|
||||
/// im2col for deformable convolution.
///
/// Produces a column matrix of shape
/// `[in_channels * kernel_h * kernel_w, batch_size * out_h * out_w]` where
/// each entry is the mask-modulated, bilinearly-sampled input value for one
/// kernel tap at one output position.
pub(crate) fn deform_im2col<F: FloatNdArrayElement>(
    input: ArrayView4<F>,
    offset: ArrayView4<F>,
    mask: Option<ArrayView6<F>>,
    args: DeformConvOptions<2>,
    out_dims: (usize, usize),
    kernel_dims: (usize, usize),
) -> Array2<F> {
    let (batch_size, in_channels, _, _) = input.dim();
    let (kernel_h, kernel_w) = kernel_dims;
    let (out_h, out_w) = out_dims;
    let channels_per_offset_group = in_channels / args.offset_groups;

    let mut columns = Array4::zeros(Dim([
        in_channels,
        kernel_h,
        kernel_w,
        batch_size * out_h * out_w,
    ]));

    let groups = args.offset_groups;

    // Parallelize over flattened output positions (last axis of `columns`);
    // each position writes a disjoint slice, so no synchronization is needed.
    run_par!(|| {
        iter_par!(columns.axis_iter_mut(Axis(3)))
            .enumerate()
            .for_each(|(index, mut columns)| {
                // Decode the flattened index back to (batch, out_y, out_x).
                let out_x = index % out_w;
                let out_y = (index / out_w) % out_h;
                let batch = (index / (out_w * out_h)) % batch_size;
                let offset = offset.slice(s![batch, .., out_y, out_x]);
                let offset = offset.to_shape((groups, kernel_h, kernel_w, 2)).unwrap();
                let mask = mask
                    .as_ref()
                    .map(|it| it.slice(s![batch, .., .., .., out_y, out_x]));
                columns
                    .axis_iter_mut(Axis(0))
                    .enumerate()
                    .for_each(|(in_channel, mut columns)| {
                        // Each input channel samples with its offset group's offsets.
                        let group_index = in_channel / channels_per_offset_group;
                        deform_im2col_kernel(
                            out_y,
                            out_x,
                            input.slice(s![batch, in_channel, .., ..]),
                            offset.slice(s![group_index, .., .., ..]),
                            mask.as_ref().map(|it| it.slice(s![group_index, .., ..])),
                            columns.view_mut(),
                            args.clone(),
                            kernel_dims,
                        );
                    });
            });
    });

    columns
        // Columns is created here, so we know it's contiguous
        .into_shape_with_order((
            in_channels * kernel_h * kernel_w,
            batch_size * out_h * out_w,
        ))
        .unwrap()
}
|
||||
|
||||
pub mod backward {
|
||||
#[cfg(target_has_atomic = "32")]
|
||||
use core::sync::atomic::Ordering;
|
||||
|
||||
use crate::NdArray;
|
||||
use atomic_float::AtomicF32;
|
||||
use burn_tensor::ops::DeformConv2dBackward;
|
||||
use ndarray::{Array1, Array5, ArrayView4, ArrayView6, Ix4};
|
||||
|
||||
use super::*;
|
||||
|
||||
    /// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions.
    ///
    /// Returns gradients w.r.t. input, offset, weight, and (when present)
    /// mask and bias, bundled in a [`DeformConv2dBackward`].
    pub(crate) fn deform_conv2d_backward<F: FloatNdArrayElement, Q: QuantElement>(
        input: NdArrayTensor<F, 4>,
        offset: NdArrayTensor<F, 4>,
        weight: NdArrayTensor<F, 4>,
        mask: Option<NdArrayTensor<F, 4>>,
        bias: Option<NdArrayTensor<F, 1>>,
        out_grad: NdArrayTensor<F, 4>,
        args: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<NdArray<F, Q>> {
        let [batch_size, out_channels, out_h, out_w] = out_grad.shape().dims;
        let [_, _, kernel_h, kernel_w] = weight.shape().dims;
        let groups = args.weight_groups;
        let out_c_per_group = out_channels / groups;
        let col_shape_1 = batch_size * out_h * out_w;
        let mut out_grad = out_grad.array.into_dimensionality::<Ix4>().unwrap();

        // dL/db: reduce the output gradient over batch and both spatial axes.
        let gradient_bias = bias.map(|_| {
            let out_grad = out_grad
                .clone()
                .sum_axis(Axis(0))
                .sum_axis(Axis(1))
                .sum_axis(Axis(1));

            NdArrayTensor::new(out_grad.into_dyn().into_shared())
        });

        // Reshape out_grad to [groups, out_c/g, B*H*W] for the grouped matmuls.
        out_grad.swap_axes(0, 1);
        let out_grad = out_grad
            .to_shape((groups, out_c_per_group, col_shape_1))
            .unwrap();

        let input = input.array.into_dimensionality::<Ix4>().unwrap();
        let offset = offset.array.into_dimensionality::<Ix4>().unwrap();
        // View the mask as [batch, offset_group, kernel_h, kernel_w, out_h, out_w].
        let mask = mask.map(|it| {
            it.array
                .into_shape_with_order((
                    batch_size,
                    args.offset_groups,
                    kernel_h,
                    kernel_w,
                    out_h,
                    out_w,
                ))
                .unwrap()
        });

        let (input_gradient, offset_gradient, mask_gradient) = backward_gradient_inputs(
            input.view(),
            weight,
            offset.view(),
            mask.as_ref().map(|it| it.view()),
            out_grad.view(),
            &args,
            (kernel_h, kernel_w),
        );

        let weight_grad = compute_weight_grad(
            input.view(),
            offset.view(),
            mask.as_ref().map(|it| it.view()),
            out_grad.view(),
            args,
            (kernel_h, kernel_w),
            (out_h, out_w),
        );

        DeformConv2dBackward::new(
            input_gradient,
            offset_gradient,
            weight_grad,
            mask_gradient,
            gradient_bias,
        )
    }
|
||||
|
||||
    /// dL/dW: re-runs [`deform_im2col`] on the forward inputs and contracts
    /// the columns with the output gradient (`grad_w = out_grad x columns^T`,
    /// per weight group).
    fn compute_weight_grad<F: FloatNdArrayElement>(
        input: ArrayView4<F>,
        offset: ArrayView4<F>,
        mask: Option<ArrayView6<F>>,
        out_grad: ArrayView3<F>,
        options: DeformConvOptions<2>,
        kernel_dims: (usize, usize),
        out_dims: (usize, usize),
    ) -> NdArrayTensor<F, 4> {
        let in_channels = input.dim().1;
        let (groups, out_c_per_group, _) = out_grad.dim();
        let (kernel_h, kernel_w) = kernel_dims;

        let in_c_per_group = in_channels / groups;

        let columns = deform_im2col(input, offset, mask, options, out_dims, kernel_dims);
        let (col_size_0, col_size_1) = columns.dim();
        let col_size_0 = col_size_0 / groups;

        // Transpose columns to [groups, B*H*W, k] so the matmul below yields
        // [groups, out_c/g, k].
        let mut columns = columns.to_shape((groups, col_size_0, col_size_1)).unwrap();
        columns.swap_axes(1, 2);

        let grad_weight = matmul(
            NdArrayTensor::<_, 3>::new(out_grad.to_owned().into_dyn().into_shared()),
            NdArrayTensor::<_, 3>::new(columns.to_owned().into_dyn().into_shared()),
        );

        // Back to the weight layout [out_channels, in_c/g, kH, kW].
        let grad_weight = grad_weight
            .array
            .into_shape_with_order((out_c_per_group * groups, in_c_per_group, kernel_h, kernel_w))
            .unwrap();
        NdArrayTensor::new(grad_weight.into_dyn().into_shared())
    }
|
||||
|
||||
    /// `(input_grad, offset_grad, optional mask_grad)` as returned by
    /// [`backward_gradient_inputs`].
    type InputGradients<F> = (
        NdArrayTensor<F, 4>,
        NdArrayTensor<F, 4>,
        Option<NdArrayTensor<F, 4>>,
    );

    /// Back-propagates the output gradient through the matmul into column
    /// space (`columns = W^T x out_grad`), then distributes the column
    /// gradients to the input image, the offsets, and the mask.
    fn backward_gradient_inputs<F: FloatNdArrayElement>(
        image: ArrayView4<F>,
        weight: NdArrayTensor<F, 4>,
        offset: ArrayView4<F>,
        mask: Option<ArrayView6<F>>,
        out_grad: ArrayView3<F>,
        args: &DeformConvOptions<2>,
        kernel_dims: (usize, usize),
    ) -> InputGradients<F> {
        let input_shape = image.dim();
        let in_channels = input_shape.1;
        let [out_channels, in_c_per_group, kernel_h, kernel_w] = weight.shape().dims;
        let (batch_size, _, out_h, out_w) = offset.dim();

        let groups = args.weight_groups;
        let out_c_per_group = out_channels / groups;

        let col_shape_0 = in_c_per_group * kernel_h * kernel_w;

        // W^T per group: [groups, k, out_c/g].
        let mut weight = weight
            .array
            .to_shape((groups, out_c_per_group, col_shape_0))
            .unwrap();
        weight.swap_axes(1, 2);
        let columns = matmul(
            NdArrayTensor::<_, 3>::new(weight.to_owned().into_dyn().into_shared()),
            NdArrayTensor::<_, 3>::new(out_grad.to_owned().into_dyn().into_shared()),
        );

        // Column gradients, viewed as [C, kH, kW, B, out_h, out_w].
        let columns = columns
            .array
            .to_shape((in_channels, kernel_h, kernel_w, batch_size, out_h, out_w))
            .unwrap();

        let (offset_gradient, mask_gradient) = compute_offset_and_mask_gradient(
            columns.view(),
            image.view(),
            offset,
            mask,
            args,
            kernel_dims,
        );

        let input_gradient =
            compute_input_grad(columns.view(), offset, mask, args, kernel_dims, input_shape);

        (input_gradient, offset_gradient, mask_gradient)
    }
|
||||
|
||||
    /// dL/doffset and (when a mask is present) dL/dmask.
    ///
    /// For every output position and kernel tap, the offset gradient is the
    /// column gradient times the derivative of the bilinear sample w.r.t. the
    /// y or x coordinate; the mask gradient is the column gradient times the
    /// unmodulated bilinear sample itself.
    fn compute_offset_and_mask_gradient<F: FloatNdArrayElement>(
        columns: ArrayView6<F>,
        image: ArrayView4<F>,
        offset: ArrayView4<F>,
        mask: Option<ArrayView6<F>>,
        args: &DeformConvOptions<2>,
        kernel_dims: (usize, usize),
    ) -> (NdArrayTensor<F, 4>, Option<NdArrayTensor<F, 4>>) {
        let (kernel_h, kernel_w) = kernel_dims;
        let (_, in_channels, height, width) = image.dim();
        let (batch_size, offset_channels, out_h, out_w) = offset.dim();
        let offs_groups = args.offset_groups;
        let channels_per_offset_group = in_channels / args.offset_groups;

        // Gradients laid out with the flattened output position last so the
        // iteration below can zip both buffers position by position.
        let mut grad_offset = Array5::zeros((
            offs_groups,
            kernel_h,
            kernel_w,
            2,
            batch_size * out_h * out_w,
        ));
        let mut grad_mask =
            Array4::zeros((offs_groups, kernel_h, kernel_w, batch_size * out_h * out_w));

        grad_mask
            .axis_iter_mut(Axis(3))
            .zip(grad_offset.axis_iter_mut(Axis(4)))
            .enumerate()
            .for_each(|(index, (mut grad_mask, mut grad_offset))| {
                // Decode the flattened index back to (batch, out_y, out_x).
                let out_x = index % out_w;
                let out_y = (index / out_w) % out_h;
                let batch = index / (out_w * out_h);
                let offset = offset.slice(s![batch, .., out_y, out_x]);
                let offset = offset
                    .to_shape((offs_groups, kernel_h, kernel_w, 2))
                    .unwrap();
                let mask: Option<ArrayView3<F>> = mask
                    .as_ref()
                    .map(|mask| mask.slice(s![batch, .., .., .., out_y, out_x]));
                let columns = columns.slice(s![.., .., .., batch, out_y, out_x]);
                let image = image.slice(s![batch, .., .., ..]);

                for ((group, kernel_y, kernel_x), grad_mask) in grad_mask.indexed_iter_mut() {
                    let grad_mask: &mut F = grad_mask;
                    let mut grad_offset = grad_offset.slice_mut(s![group, kernel_y, kernel_x, ..]);
                    let offset = offset.slice(s![group, kernel_y, kernel_x, ..]);
                    let mask = mask.map(|it| it[[group, kernel_y, kernel_x]]);
                    let columns = columns.slice(s![.., kernel_y, kernel_x]);
                    // Restrict the image to this offset group's channels; the
                    // channel loop below indexes relative to this slice.
                    let group_offset = group * channels_per_offset_group;
                    let image = image.slice(s![group_offset.., .., ..]);
                    // Sampling coordinates = conv tap position + learned offset.
                    let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0])
                        - F::from_elem(args.padding[0])
                        + offset[0];
                    let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1])
                        - F::from_elem(args.padding[1])
                        + offset[1];
                    for (i, grad_offset) in grad_offset.iter_mut().enumerate() {
                        // i == 0 is the y component, i == 1 the x component.
                        let is_y_direction = i % 2 == 0;
                        let use_mask = mask.is_some();

                        for channel in 0..channels_per_offset_group {
                            let mask = mask.unwrap_or_else(|| F::one());
                            let image = image.index_axis(Axis(0), channel);
                            let weight =
                                get_coordinate_weight(image, height, width, y, x, is_y_direction);
                            *grad_offset += mask * weight * columns[channel];
                            // Accumulate the mask gradient only on the y pass
                            // so it isn't double-counted across directions.
                            if use_mask && is_y_direction {
                                *grad_mask += columns[channel]
                                    * bilinear_interpolate(image, height, width, y, x);
                            }
                        }
                    }
                }
            });

        // Reshape the flat accumulators back to [batch, channels, out_h, out_w].
        let mask_gradient = mask.map(|_| {
            let mut grad_mask = grad_mask
                .into_shape_with_order((offset_channels / 2, batch_size, out_h, out_w))
                .unwrap();
            grad_mask.swap_axes(0, 1);
            NdArrayTensor::new(grad_mask.into_dyn().into_shared())
        });
        let mut grad_offset = grad_offset
            .into_shape_with_order((offset_channels, batch_size, out_h, out_w))
            .unwrap();
        grad_offset.swap_axes(0, 1);
        let offset_gradient = NdArrayTensor::new(grad_offset.into_dyn().into_shared());
        (offset_gradient, mask_gradient)
    }
|
||||
|
||||
fn get_coordinate_weight<F: FloatNdArrayElement>(
|
||||
input: ArrayView2<F>,
|
||||
height: usize,
|
||||
width: usize,
|
||||
y: F,
|
||||
x: F,
|
||||
is_y_direction: bool,
|
||||
) -> F {
|
||||
let y = y.to_f32();
|
||||
let x = x.to_f32();
|
||||
|
||||
let y_low = f32::floor(y);
|
||||
let x_low = f32::floor(x);
|
||||
let y_high = y_low + 1.;
|
||||
let x_high = x_low + 1.;
|
||||
|
||||
let valid_y_low = y_low >= 0. && y_low < height as f32;
|
||||
let valid_y_high = y_high >= 0. && y_high < height as f32;
|
||||
let valid_x_low = x_low >= 0. && x_low < width as f32;
|
||||
let valid_x_high = x_high >= 0. && x_high < width as f32;
|
||||
|
||||
let bottom_left = if valid_y_low && valid_x_low {
|
||||
input[[y_low as usize, x_low as usize]]
|
||||
} else {
|
||||
F::zero()
|
||||
};
|
||||
let bottom_right = if valid_y_low && valid_x_high {
|
||||
input[[y_low as usize, x_high as usize]]
|
||||
} else {
|
||||
F::zero()
|
||||
};
|
||||
let top_left = if valid_y_high && valid_x_low {
|
||||
input[[y_high as usize, x_low as usize]]
|
||||
} else {
|
||||
F::zero()
|
||||
};
|
||||
let top_right = if valid_y_high && valid_x_high {
|
||||
input[[y_high as usize, x_high as usize]]
|
||||
} else {
|
||||
F::zero()
|
||||
};
|
||||
|
||||
if is_y_direction {
|
||||
let delta_x = F::from_elem(x - x_low);
|
||||
delta_x * (top_right - bottom_right) + (F::one() - delta_x) * (top_left - bottom_left)
|
||||
} else {
|
||||
let delta_y = F::from_elem(y - y_low);
|
||||
delta_y * (top_right - top_left) + (F::one() - delta_y) * (bottom_right - bottom_left)
|
||||
}
|
||||
}
|
||||
|
||||
    /// dL/dinput: scatters every column gradient back onto the (up to four)
    /// input pixels its bilinear sample touched (col2img).
    ///
    /// Accumulation uses `AtomicF32` because the parallel iteration over
    /// column elements can hit the same input pixel from different threads.
    fn compute_input_grad<F: FloatNdArrayElement>(
        columns: ArrayView6<F>,
        offset: ArrayView4<F>,
        mask: Option<ArrayView6<F>>,
        args: &DeformConvOptions<2>,
        kernel_dims: (usize, usize),
        input_shape: (usize, usize, usize, usize),
    ) -> NdArrayTensor<F, 4> {
        let (batch_size, in_channels, height, width) = input_shape;
        let (kernel_h, kernel_w) = kernel_dims;
        let offs_groups = args.offset_groups;
        let channels_per_offset_group = in_channels / offs_groups;

        let grad_in =
            Array4::from_shape_simple_fn((batch_size, in_channels, height, width), || {
                AtomicF32::new(0.0)
            });

        run_par!(|| {
            iter_par!(columns.indexed_iter()).for_each(
                |((in_channel, kernel_y, kernel_x, batch, out_y, out_x), col)| {
                    let group = in_channel / channels_per_offset_group;
                    let offset = offset.slice(s![batch, .., out_y, out_x]);
                    let offset = offset
                        .to_shape((offs_groups, kernel_h, kernel_w, 2))
                        .unwrap();
                    let offset = offset.slice(s![group, kernel_y, kernel_x, ..]);
                    let offset = [offset[0], offset[1]];
                    let mask = mask
                        .as_ref()
                        .map(|it| it[[batch, group, kernel_y, kernel_x, out_y, out_x]].to_f32());
                    // Sampling coordinates = conv tap position + learned offset.
                    let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0])
                        - F::from_elem(args.padding[0])
                        + offset[0];
                    let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1])
                        - F::from_elem(args.padding[1])
                        + offset[1];
                    let grad_in = grad_in.slice(s![batch, in_channel, .., ..]);
                    deform_col2img_kernel(y.to_f32(), x.to_f32(), mask, col.to_f32(), grad_in);
                },
            )
        });

        // NOTE(review): accumulation happens in f32 (AtomicF32) even when F
        // is f64, so the precision of this gradient is bounded by f32.
        let grad_in: Array1<F> = grad_in
            .into_iter()
            .map(|it| F::from_elem(it.into_inner()))
            .collect();
        let grad_in = grad_in
            .into_shape_with_order((batch_size, in_channels, height, width))
            .unwrap();
        NdArrayTensor::new(grad_in.into_dyn().into_shared())
    }
|
||||
|
||||
    /// Adds `mask * bilinear_weight * col` to each integer neighbor of the
    /// sampling point `(y, x)` that lies inside the image.
    ///
    /// The 3x3 neighbor scan plus the `|delta| < 1` test reduces to the (up
    /// to four) corners of the bilinear footprint.
    fn deform_col2img_kernel(
        y: f32,
        x: f32,
        mask: Option<f32>,
        col: f32,
        grad_input: ArrayView2<AtomicF32>,
    ) {
        let (height, width) = grad_input.dim();
        let mask_value = mask.unwrap_or(1.0);

        for dy in -1..=1 {
            for dx in -1..=1 {
                let yp = f32::floor(y) + dy as f32;
                let xp = f32::floor(x) + dx as f32;

                if yp >= 0.0
                    && yp < height as f32
                    && xp >= 0.0
                    && xp < width as f32
                    && f32::abs(y - yp) < 1.0
                    && f32::abs(x - xp) < 1.0
                {
                    // Bilinear weight of this neighbor.
                    let weight = (1.0 - f32::abs(y - yp)) * (1.0 - f32::abs(x - xp));

                    #[cfg_attr(not(target_has_atomic = "32"), allow(unused))]
                    let value = mask_value * weight * col;

                    #[cfg(target_has_atomic = "32")]
                    grad_input[[yp as usize, xp as usize]].fetch_add(value, Ordering::AcqRel);
                    // Targets without 32-bit atomics cannot run this backward pass.
                    #[cfg(not(target_has_atomic = "32"))]
                    panic!("Can't use deformable convolution backwards pass without atomics");
                }
            }
        }
    }
||||
}
|
|
@ -9,6 +9,7 @@ mod tensor;
|
|||
pub(crate) mod adaptive_avgpool;
|
||||
pub(crate) mod avgpool;
|
||||
pub(crate) mod conv;
|
||||
pub(crate) mod deform_conv;
|
||||
pub(crate) mod interpolate;
|
||||
pub(crate) mod macros;
|
||||
pub(crate) mod matmul;
|
||||
|
|
|
@ -2,6 +2,7 @@ use super::{
|
|||
adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward},
|
||||
avgpool::{avg_pool2d, avg_pool2d_backward},
|
||||
conv::{conv2d, conv3d, conv_transpose2d, conv_transpose3d},
|
||||
deform_conv::{backward::deform_conv2d_backward, deform_conv2d},
|
||||
interpolate::{bicubic_interpolate, bilinear_interpolate, nearest_interpolate},
|
||||
maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices},
|
||||
};
|
||||
|
@ -19,6 +20,29 @@ impl<E: FloatNdArrayElement, Q: QuantElement> ModuleOps<Self> for NdArray<E, Q>
|
|||
conv2d::<E, Q>(x, weight, bias, options)
|
||||
}
|
||||
|
||||
    /// Deformable 2D convolution (forward), delegating to the ndarray kernel.
    fn deform_conv2d(
        x: NdArrayTensor<E, 4>,
        offset: NdArrayTensor<E, 4>,
        weight: NdArrayTensor<E, 4>,
        mask: Option<NdArrayTensor<E, 4>>,
        bias: Option<NdArrayTensor<E, 1>>,
        options: DeformConvOptions<2>,
    ) -> NdArrayTensor<E, 4> {
        deform_conv2d::<E>(x, offset, weight, mask, bias, options)
    }
|
||||
|
||||
    /// Backward pass of [`Self::deform_conv2d`], delegating to the ndarray
    /// kernel; returns gradients for all forward inputs.
    fn deform_conv2d_backward(
        x: NdArrayTensor<E, 4>,
        offset: NdArrayTensor<E, 4>,
        weight: NdArrayTensor<E, 4>,
        mask: Option<NdArrayTensor<E, 4>>,
        bias: Option<NdArrayTensor<E, 1>>,
        output_grad: NdArrayTensor<E, 4>,
        options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        deform_conv2d_backward(x, offset, weight, mask, bias, output_grad, options)
    }
|
||||
|
||||
fn conv_transpose2d(
|
||||
x: NdArrayTensor<E, 4>,
|
||||
weight: NdArrayTensor<E, 4>,
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{element::TchElement, LibTorch, QuantElement, TchTensor};
|
||||
use burn_tensor::ops::{
|
||||
ConvOptions, ConvTransposeOptions, InterpolateMode, InterpolateOptions, MaxPool1dWithIndices,
|
||||
MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
|
||||
ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateMode,
|
||||
InterpolateOptions, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps,
|
||||
};
|
||||
|
||||
impl<E: TchElement, Q: QuantElement> ModuleOps<Self> for LibTorch<E, Q> {
|
||||
|
@ -86,6 +86,29 @@ impl<E: TchElement, Q: QuantElement> ModuleOps<Self> for LibTorch<E, Q> {
|
|||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
    /// Deformable 2D convolution is not available through the Torch bindings
    /// used by this backend (torchvision ships it as a separate op), so this
    /// always panics.
    fn deform_conv2d(
        _x: TchTensor<E, 4>,
        _offset: TchTensor<E, 4>,
        _weight: TchTensor<E, 4>,
        _mask: Option<TchTensor<E, 4>>,
        _bias: Option<TchTensor<E, 1>>,
        _options: DeformConvOptions<2>,
    ) -> TchTensor<E, 4> {
        unimplemented!("Torch bindings don't support deform_conv2d");
    }
|
||||
|
||||
    /// Backward pass counterpart of [`Self::deform_conv2d`]; equally
    /// unsupported by the Torch bindings, so this always panics.
    fn deform_conv2d_backward(
        _x: TchTensor<E, 4>,
        _offset: TchTensor<E, 4>,
        _weight: TchTensor<E, 4>,
        _mask: Option<TchTensor<E, 4>>,
        _bias: Option<TchTensor<E, 1>>,
        _out_grad: TchTensor<E, 4>,
        _options: DeformConvOptions<2>,
    ) -> DeformConv2dBackward<Self> {
        unimplemented!("Torch bindings don't support deform_conv2d");
    }
|
||||
|
||||
fn conv_transpose1d(
|
||||
x: TchTensor<E, 3>,
|
||||
weight: TchTensor<E, 3>,
|
||||
|
|
|
@ -2,7 +2,9 @@ use serde::{Deserialize, Serialize};
|
|||
use std::ops::Range;
|
||||
|
||||
use crate::{
|
||||
ops::{ConvOptions, ConvTransposeOptions, InterpolateMode, InterpolateOptions},
|
||||
ops::{
|
||||
ConvOptions, ConvTransposeOptions, DeformConvOptions, InterpolateMode, InterpolateOptions,
|
||||
},
|
||||
repr::tensor::TensorDescription,
|
||||
DType, Distribution, Element,
|
||||
};
|
||||
|
@ -74,6 +76,10 @@ pub enum ModuleOperationDescription {
|
|||
Conv2d(Conv2dDescription),
|
||||
/// Operation corresponding to [conv3d](crate::ops::ModuleOps::conv3d).
|
||||
Conv3d(Conv3dDescription),
|
||||
/// Operation corresponding to [deform_conv2d](crate::ops::ModuleOps::deform_conv2d)
|
||||
DeformableConv2d(Box<DeformConv2dDescription>),
|
||||
/// Operation corresponding to [deform_conv2d_backward](crate::ops::ModuleOps::deform_conv2d_backward)
|
||||
DeformableConv2dBackward(Box<DeformConv2dBackwardDescription>),
|
||||
/// Operation corresponding to [conv transpose 1d](crate::ops::ModuleOps::conv_transpose1d).
|
||||
ConvTranspose1d(ConvTranspose1dDescription),
|
||||
/// Operation corresponding to [conv transpose 2d](crate::ops::ModuleOps::conv_transpose2d).
|
||||
|
@ -688,6 +694,35 @@ pub struct Conv2dDescription {
|
|||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct DeformConv2dDescription {
|
||||
pub x: TensorDescription,
|
||||
pub offset: TensorDescription,
|
||||
pub weight: TensorDescription,
|
||||
pub mask: Option<TensorDescription>,
|
||||
pub bias: Option<TensorDescription>,
|
||||
pub options: DeformableConv2dOptionsDescription,
|
||||
pub out: TensorDescription,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct DeformConv2dBackwardDescription {
|
||||
pub x: TensorDescription,
|
||||
pub offset: TensorDescription,
|
||||
pub weight: TensorDescription,
|
||||
pub mask: Option<TensorDescription>,
|
||||
pub bias: Option<TensorDescription>,
|
||||
pub out_grad: TensorDescription,
|
||||
pub options: DeformableConv2dOptionsDescription,
|
||||
pub input_grad: TensorDescription,
|
||||
pub offset_grad: TensorDescription,
|
||||
pub weight_grad: TensorDescription,
|
||||
pub mask_grad: Option<TensorDescription>,
|
||||
pub bias_grad: Option<TensorDescription>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct Conv3dDescription {
|
||||
|
@ -746,6 +781,16 @@ pub struct Conv2dOptionsDescription {
|
|||
pub groups: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct DeformableConv2dOptionsDescription {
|
||||
pub stride: [usize; 2],
|
||||
pub padding: [usize; 2],
|
||||
pub dilation: [usize; 2],
|
||||
pub weight_groups: usize,
|
||||
pub offset_groups: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct Conv3dOptionsDescription {
|
||||
|
@ -818,6 +863,18 @@ impl From<ConvOptions<3>> for Conv3dOptionsDescription {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<DeformConvOptions<2>> for DeformableConv2dOptionsDescription {
|
||||
fn from(value: DeformConvOptions<2>) -> Self {
|
||||
Self {
|
||||
stride: value.stride,
|
||||
padding: value.padding,
|
||||
dilation: value.dilation,
|
||||
weight_groups: value.weight_groups,
|
||||
offset_groups: value.offset_groups,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ConvTransposeOptions<1>> for ConvTranspose1dOptionsDescription {
|
||||
fn from(value: ConvTransposeOptions<1>) -> Self {
|
||||
Self {
|
||||
|
@ -887,6 +944,18 @@ impl From<Conv3dOptionsDescription> for ConvOptions<3> {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<DeformableConv2dOptionsDescription> for DeformConvOptions<2> {
|
||||
fn from(value: DeformableConv2dOptionsDescription) -> Self {
|
||||
DeformConvOptions {
|
||||
stride: value.stride,
|
||||
padding: value.padding,
|
||||
dilation: value.dilation,
|
||||
weight_groups: value.weight_groups,
|
||||
offset_groups: value.offset_groups,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<ConvTranspose1dOptionsDescription> for ConvTransposeOptions<1> {
|
||||
fn from(val: ConvTranspose1dOptionsDescription) -> Self {
|
||||
ConvTransposeOptions {
|
||||
|
@ -1404,6 +1473,22 @@ impl ModuleOperationDescription {
|
|||
vec![&desc.x, &desc.weight, &desc.out]
|
||||
}
|
||||
}
|
||||
ModuleOperationDescription::DeformableConv2d(desc) => match (&desc.mask, &desc.bias) {
|
||||
(Some(mask), Some(bias)) => vec![&desc.x, &desc.offset, &desc.weight, &mask, &bias],
|
||||
(Some(mask), None) => vec![&desc.x, &desc.offset, &desc.weight, &mask],
|
||||
(None, Some(bias)) => vec![&desc.x, &desc.offset, &desc.weight, &bias],
|
||||
(None, None) => vec![&desc.x, &desc.offset, &desc.weight],
|
||||
},
|
||||
ModuleOperationDescription::DeformableConv2dBackward(desc) => {
|
||||
match (&desc.mask, &desc.bias) {
|
||||
(Some(mask), Some(bias)) => {
|
||||
vec![&desc.x, &desc.offset, &desc.weight, &mask, &bias]
|
||||
}
|
||||
(Some(mask), None) => vec![&desc.x, &desc.offset, &desc.weight, &mask],
|
||||
(None, Some(bias)) => vec![&desc.x, &desc.offset, &desc.weight, &bias],
|
||||
(None, None) => vec![&desc.x, &desc.offset, &desc.weight],
|
||||
}
|
||||
}
|
||||
ModuleOperationDescription::ConvTranspose1d(desc) => {
|
||||
if let Some(bias) = &desc.bias {
|
||||
vec![&desc.x, &desc.weight, &bias, &desc.out]
|
||||
|
|
|
@ -4,6 +4,8 @@ use crate::{
|
|||
Int, Tensor, TensorPrimitive,
|
||||
};
|
||||
|
||||
use super::ops::DeformConvOptions;
|
||||
|
||||
/// Applies the [embedding module](crate::ops::ModuleOps::embedding).
|
||||
pub fn embedding<B>(weights: Tensor<B, 2>, indices: Tensor<B, 2, Int>) -> Tensor<B, 3>
|
||||
where
|
||||
|
@ -69,6 +71,28 @@ where
|
|||
)))
|
||||
}
|
||||
|
||||
/// Applies a [Deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d).
|
||||
pub fn deform_conv2d<B>(
|
||||
x: Tensor<B, 4>,
|
||||
offset: Tensor<B, 4>,
|
||||
weight: Tensor<B, 4>,
|
||||
mask: Option<Tensor<B, 4>>,
|
||||
bias: Option<Tensor<B, 1>>,
|
||||
options: DeformConvOptions<2>,
|
||||
) -> Tensor<B, 4>
|
||||
where
|
||||
B: Backend,
|
||||
{
|
||||
Tensor::new(TensorPrimitive::Float(B::deform_conv2d(
|
||||
x.primitive.tensor(),
|
||||
offset.primitive.tensor(),
|
||||
weight.primitive.tensor(),
|
||||
mask.map(|m| m.primitive.tensor()),
|
||||
bias.map(|b| b.primitive.tensor()),
|
||||
options,
|
||||
)))
|
||||
}
|
||||
|
||||
/// Applies a [1D transposed convolution](crate::ops::ModuleOps::conv_transpose1d).
|
||||
pub fn conv_transpose1d<B>(
|
||||
x: Tensor<B, 3>,
|
||||
|
|
|
@ -5,6 +5,51 @@ use crate::{
|
|||
Shape,
|
||||
};
|
||||
|
||||
/// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).
|
||||
#[derive(new)]
|
||||
pub struct Conv2dBackward<B: Backend> {
|
||||
/// Gradient.
|
||||
pub x_grad: FloatTensor<B, 4>,
|
||||
|
||||
/// Weights gradient.
|
||||
pub weights_grad: FloatTensor<B, 4>,
|
||||
|
||||
/// Bias gradient.
|
||||
pub bias_grad: Option<FloatTensor<B, 1>>,
|
||||
}
|
||||
|
||||
/// Gradient computed during the backward pass for each tensor used by [deform_conv2d](ModuleOps::deform_conv2d).
|
||||
#[derive(new)]
|
||||
pub struct DeformConv2dBackward<B: Backend> {
|
||||
/// Gradient.
|
||||
pub x_grad: FloatTensor<B, 4>,
|
||||
|
||||
/// Offset gradient.
|
||||
pub offset_grad: FloatTensor<B, 4>,
|
||||
|
||||
/// Weights gradient.
|
||||
pub weight_grad: FloatTensor<B, 4>,
|
||||
|
||||
/// Mask gradient.
|
||||
pub mask_grad: Option<FloatTensor<B, 4>>,
|
||||
|
||||
/// Bias gradient.
|
||||
pub bias_grad: Option<FloatTensor<B, 1>>,
|
||||
}
|
||||
|
||||
/// Gradient computed during the backward pass for each tensor used by [conv3d](ModuleOps::conv3d).
|
||||
#[derive(new)]
|
||||
pub struct Conv3dBackward<B: Backend> {
|
||||
/// Gradient.
|
||||
pub x_grad: FloatTensor<B, 5>,
|
||||
|
||||
/// Weights gradient.
|
||||
pub weights_grad: FloatTensor<B, 5>,
|
||||
|
||||
/// Bias gradient.
|
||||
pub bias_grad: Option<FloatTensor<B, 1>>,
|
||||
}
|
||||
|
||||
/// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d).
|
||||
#[derive(new)]
|
||||
pub struct MaxPool1dBackward<B: Backend> {
|
||||
|
@ -55,6 +100,25 @@ pub struct ConvOptions<const N: usize> {
|
|||
pub groups: usize,
|
||||
}
|
||||
|
||||
/// Convolution options.
|
||||
#[derive(new, Debug, Clone, Hash, PartialEq, Eq)]
|
||||
pub struct DeformConvOptions<const N: usize> {
|
||||
/// Stride.
|
||||
pub stride: [usize; N],
|
||||
|
||||
/// Padding.
|
||||
pub padding: [usize; N],
|
||||
|
||||
/// Dilation.
|
||||
pub dilation: [usize; N],
|
||||
|
||||
/// Weight Groups.
|
||||
pub weight_groups: usize,
|
||||
|
||||
/// Offset Groups.
|
||||
pub offset_groups: usize,
|
||||
}
|
||||
|
||||
/// Transposed convolution options.
|
||||
#[derive(new, Debug, Clone, Hash, PartialEq, Eq)]
|
||||
pub struct ConvTransposeOptions<const N: usize> {
|
||||
|
@ -248,6 +312,33 @@ pub trait ModuleOps<B: Backend> {
|
|||
) -> FloatTensor<B, 1> {
|
||||
conv::conv2d_bias_backward::<B>(x, weight, bias, output_grad)
|
||||
}
|
||||
|
||||
/// Two dimensional deformable convolution.
|
||||
///
|
||||
/// # Shapes
|
||||
///
|
||||
/// x: `[batch_size, channels_in, height, width]`,
|
||||
/// weight: `[channels_out, channels_in, kernel_size_1, kernel_size_2]`,
|
||||
/// bias: `[channels_out]`,
|
||||
fn deform_conv2d(
|
||||
x: FloatTensor<B, 4>,
|
||||
offset: FloatTensor<B, 4>,
|
||||
weight: FloatTensor<B, 4>,
|
||||
mask: Option<FloatTensor<B, 4>>,
|
||||
bias: Option<FloatTensor<B, 1>>,
|
||||
options: DeformConvOptions<2>,
|
||||
) -> FloatTensor<B, 4>;
|
||||
/// Backward pass for the [deform_conv2d](ModuleOps::deform_conv2d) operation.
|
||||
fn deform_conv2d_backward(
|
||||
x: FloatTensor<B, 4>,
|
||||
offset: FloatTensor<B, 4>,
|
||||
weight: FloatTensor<B, 4>,
|
||||
mask: Option<FloatTensor<B, 4>>,
|
||||
bias: Option<FloatTensor<B, 1>>,
|
||||
output_grad: FloatTensor<B, 4>,
|
||||
options: DeformConvOptions<2>,
|
||||
) -> DeformConv2dBackward<B>;
|
||||
|
||||
/// Three dimensional convolution.
|
||||
///
|
||||
/// # Shapes
|
||||
|
|
|
@ -26,6 +26,7 @@ macro_rules! testgen_all {
|
|||
burn_tensor::testgen_module_conv1d!();
|
||||
burn_tensor::testgen_module_conv2d!();
|
||||
burn_tensor::testgen_module_conv3d!();
|
||||
burn_tensor::testgen_module_deform_conv2d!();
|
||||
burn_tensor::testgen_module_conv_transpose1d!();
|
||||
burn_tensor::testgen_module_conv_transpose2d!();
|
||||
burn_tensor::testgen_module_conv_transpose3d!();
|
||||
|
|
|
@ -0,0 +1,439 @@
|
|||
#[burn_tensor_testgen::testgen(module_deform_conv2d)]
|
||||
mod tests {
|
||||
|
||||
use super::*;
|
||||
use burn_tensor::module::deform_conv2d;
|
||||
use burn_tensor::ops::{DeformConv2dBackward, DeformConvOptions, ModuleOps};
|
||||
use burn_tensor::{Shape, Tensor};
|
||||
|
||||
#[test]
|
||||
fn test_deform_conv2d_simple() {
|
||||
let test = DeformConv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 3,
|
||||
channels_out: 5,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
weight_groups: 1,
|
||||
offset_groups: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(Tensor::<TestBackend, 4>::from([[
|
||||
[[0.9074, 0.6387], [0.5160, 0.4196]],
|
||||
[[2.4259, 1.8008], [1.5449, 1.3112]],
|
||||
[[3.9444, 2.9629], [2.5738, 2.2027]],
|
||||
[[5.4629, 4.1250], [3.6027, 3.0943]],
|
||||
[[6.9814, 5.2871], [4.6316, 3.9859]],
|
||||
]]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deform_conv2d_batched() {
|
||||
let test = DeformConv2dTestCase {
|
||||
batch_size: 2,
|
||||
channels_in: 3,
|
||||
channels_out: 5,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
weight_groups: 1,
|
||||
offset_groups: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(Tensor::<TestBackend, 4>::from([
|
||||
[
|
||||
[[0.2155, 0.1928], [0.1934, 0.1755]],
|
||||
[[0.7251, 0.6759], [0.6877, 0.6485]],
|
||||
[[1.2347, 1.1590], [1.1821, 1.1215]],
|
||||
[[1.7443, 1.6421], [1.6764, 1.5945]],
|
||||
[[2.2539, 2.1252], [2.1708, 2.0675]],
|
||||
],
|
||||
[
|
||||
[[1.6530, 1.1369], [0.9840, 0.7184]],
|
||||
[[4.8368, 3.4725], [3.1773, 2.4180]],
|
||||
[[8.0206, 5.8080], [5.3705, 4.1176]],
|
||||
[[11.2045, 8.1435], [7.5637, 5.8173]],
|
||||
[[14.3883, 10.4790], [9.7570, 7.5169]],
|
||||
],
|
||||
]))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deform_conv2d_weight_groups() {
|
||||
let test = DeformConv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 3,
|
||||
channels_out: 6,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
weight_groups: 3,
|
||||
offset_groups: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(Tensor::<TestBackend, 4>::from([[
|
||||
[[0.1018, 0.0658], [0.0467, 0.0362]],
|
||||
[[0.4125, 0.3367], [0.3069, 0.2824]],
|
||||
[[1.3076, 1.0242], [0.9025, 0.8000]],
|
||||
[[1.8405, 1.4581], [1.2994, 1.1588]],
|
||||
[[3.4022, 2.6346], [2.3052, 2.0143]],
|
||||
[[4.1574, 3.2315], [2.8389, 2.4857]],
|
||||
]]))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deform_conv2d_offset_groups() {
|
||||
let test = DeformConv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 3,
|
||||
channels_out: 6,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
weight_groups: 1,
|
||||
offset_groups: 3,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(Tensor::<TestBackend, 4>::from([[
|
||||
[[1.0794, 0.7676], [0.7209, 0.5337]],
|
||||
[[2.7059, 2.0216], [1.9740, 1.5419]],
|
||||
[[4.3325, 3.2755], [3.2271, 2.5501]],
|
||||
[[5.9590, 4.5295], [4.4802, 3.5582]],
|
||||
[[7.5855, 5.7835], [5.7333, 4.5664]],
|
||||
[[9.2120, 7.0375], [6.9864, 5.5746]],
|
||||
]]))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deform_conv2d_different_kernel_size() {
|
||||
let test = DeformConv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 3,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 4,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
weight_groups: 1,
|
||||
offset_groups: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(Tensor::<TestBackend, 4>::from([[
|
||||
[[1.0669], [0.6329]],
|
||||
[[2.9741], [2.0383]],
|
||||
[[4.8812], [3.4437]],
|
||||
]]))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deform_conv2d_different_padding_size() {
|
||||
let test = DeformConv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 3,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 2,
|
||||
padding_2: 3,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
weight_groups: 1,
|
||||
offset_groups: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(Tensor::<TestBackend, 4>::from([[
|
||||
[
|
||||
[
|
||||
0.1998, 0.3762, 0.5285, 0.6053, 0.3844, 0.1987, 0.0481, 0.0000,
|
||||
],
|
||||
[
|
||||
0.2879, 0.5517, 0.7776, 0.8905, 0.5805, 0.3043, 0.0796, 0.0000,
|
||||
],
|
||||
[
|
||||
0.3729, 0.7214, 1.0137, 1.1520, 0.7564, 0.3931, 0.1016, 0.0000,
|
||||
],
|
||||
[
|
||||
0.1321, 0.3249, 0.4954, 0.5846, 0.4531, 0.2501, 0.0757, 0.0000,
|
||||
],
|
||||
[
|
||||
0.0593, 0.1607, 0.2448, 0.2971, 0.2395, 0.1327, 0.0471, 0.0000,
|
||||
],
|
||||
[
|
||||
0.0143, 0.0513, 0.0783, 0.0942, 0.0813, 0.0420, 0.0145, 0.0000,
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
0.7667, 1.1648, 1.5219, 1.7111, 1.2305, 0.8076, 0.4504, 0.3333,
|
||||
],
|
||||
[
|
||||
0.9812, 1.6010, 2.1525, 2.4409, 1.7455, 1.0918, 0.5367, 0.3333,
|
||||
],
|
||||
[
|
||||
1.1964, 2.0448, 2.7853, 3.1522, 2.2426, 1.3513, 0.6049, 0.3333,
|
||||
],
|
||||
[
|
||||
0.6695, 1.1781, 1.6441, 1.9022, 1.5732, 1.0339, 0.5536, 0.3333,
|
||||
],
|
||||
[
|
||||
0.4950, 0.7861, 1.0398, 1.2047, 1.0523, 0.7439, 0.4834, 0.3333,
|
||||
],
|
||||
[
|
||||
0.3788, 0.4982, 0.5929, 0.6542, 0.6155, 0.4882, 0.3909, 0.3333,
|
||||
],
|
||||
],
|
||||
[
|
||||
[
|
||||
1.3335, 1.9534, 2.5154, 2.8170, 2.0766, 1.4165, 0.8527, 0.6667,
|
||||
],
|
||||
[
|
||||
1.6744, 2.6503, 3.5275, 3.9914, 2.9106, 1.8794, 0.9939, 0.6667,
|
||||
],
|
||||
[
|
||||
2.0198, 3.3683, 4.5570, 5.1525, 3.7288, 2.3095, 1.1082, 0.6667,
|
||||
],
|
||||
[
|
||||
1.2068, 2.0314, 2.7928, 3.2198, 2.6932, 1.8178, 1.0315, 0.6667,
|
||||
],
|
||||
[
|
||||
0.9308, 1.4116, 1.8348, 2.1124, 1.8652, 1.3551, 0.9196, 0.6667,
|
||||
],
|
||||
[
|
||||
0.7432, 0.9451, 1.1074, 1.2143, 1.1497, 0.9345, 0.7673, 0.6667,
|
||||
],
|
||||
],
|
||||
]]))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deform_conv2d_different_stride() {
|
||||
let test = DeformConv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 3,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 2,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
weight_groups: 1,
|
||||
offset_groups: 1,
|
||||
height: 4,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(Tensor::<TestBackend, 4>::from([[
|
||||
[[1.0647], [0.5783]],
|
||||
[[2.9289], [1.8829]],
|
||||
[[4.7931], [3.1875]],
|
||||
]]))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deform_conv2d_different_dilation() {
|
||||
let test = DeformConv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 3,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 2,
|
||||
weight_groups: 1,
|
||||
offset_groups: 1,
|
||||
height: 5,
|
||||
width: 5,
|
||||
};
|
||||
|
||||
test.assert_output(Tensor::<TestBackend, 4>::from([[
|
||||
[[0.6162], [0.7611], [0.4666]],
|
||||
[[1.8578], [2.2684], [1.6208]],
|
||||
[[3.0994], [3.7757], [2.7749]],
|
||||
]]))
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_deform_conv2d_different_width() {
|
||||
let test = DeformConv2dTestCase {
|
||||
batch_size: 1,
|
||||
channels_in: 2,
|
||||
channels_out: 3,
|
||||
kernel_size_1: 3,
|
||||
kernel_size_2: 3,
|
||||
padding_1: 0,
|
||||
padding_2: 0,
|
||||
stride_1: 1,
|
||||
stride_2: 1,
|
||||
dilation_1: 1,
|
||||
dilation_2: 1,
|
||||
weight_groups: 1,
|
||||
offset_groups: 1,
|
||||
height: 6,
|
||||
width: 4,
|
||||
};
|
||||
|
||||
test.assert_output(Tensor::<TestBackend, 4>::from([[
|
||||
[
|
||||
[0.8909, 0.6016],
|
||||
[1.0697, 0.7186],
|
||||
[1.2618, 0.8433],
|
||||
[0.6424, 0.5032],
|
||||
],
|
||||
[
|
||||
[2.4670, 1.8168],
|
||||
[2.9529, 2.1497],
|
||||
[3.4805, 2.5090],
|
||||
[2.0925, 1.7411],
|
||||
],
|
||||
[
|
||||
[4.0432, 3.0321],
|
||||
[4.8362, 3.5809],
|
||||
[5.6992, 4.1746],
|
||||
[3.5425, 2.9790],
|
||||
],
|
||||
]]))
|
||||
}
|
||||
|
||||
struct DeformConv2dTestCase {
|
||||
batch_size: usize,
|
||||
channels_in: usize,
|
||||
channels_out: usize,
|
||||
kernel_size_1: usize,
|
||||
kernel_size_2: usize,
|
||||
padding_1: usize,
|
||||
padding_2: usize,
|
||||
stride_1: usize,
|
||||
stride_2: usize,
|
||||
dilation_1: usize,
|
||||
dilation_2: usize,
|
||||
weight_groups: usize,
|
||||
offset_groups: usize,
|
||||
height: usize,
|
||||
width: usize,
|
||||
}
|
||||
|
||||
impl DeformConv2dTestCase {
|
||||
fn assert_output(self, y: Tensor<TestBackend, 4>) {
|
||||
let out_height =
|
||||
(self.height + 2 * self.padding_1 - self.dilation_1 * (self.kernel_size_1 - 1) - 1)
|
||||
/ self.stride_1
|
||||
+ 1;
|
||||
let out_width =
|
||||
(self.width + 2 * self.padding_2 - self.dilation_2 * (self.kernel_size_2 - 1) - 1)
|
||||
/ self.stride_2
|
||||
+ 1;
|
||||
|
||||
let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
|
||||
let shape_weight = Shape::new([
|
||||
self.channels_out,
|
||||
self.channels_in / self.weight_groups,
|
||||
self.kernel_size_1,
|
||||
self.kernel_size_2,
|
||||
]);
|
||||
let shape_offset = Shape::new([
|
||||
self.batch_size,
|
||||
self.kernel_size_1 * self.kernel_size_2 * self.offset_groups * 2,
|
||||
out_height,
|
||||
out_width,
|
||||
]);
|
||||
let shape_mask = Shape::new([
|
||||
self.batch_size,
|
||||
self.kernel_size_1 * self.kernel_size_2 * self.offset_groups,
|
||||
out_height,
|
||||
out_width,
|
||||
]);
|
||||
let device = Default::default();
|
||||
let weight = Tensor::<TestBackend, 4>::from(
|
||||
TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
|
||||
.reshape(shape_weight.clone())
|
||||
.into_data(),
|
||||
)
|
||||
.div_scalar(shape_weight.num_elements() as f32);
|
||||
let bias = Tensor::<TestBackend, 1>::from(
|
||||
TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
|
||||
)
|
||||
.div_scalar(self.channels_out as f32);
|
||||
let x = Tensor::<TestBackend, 4>::from(
|
||||
TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
|
||||
.reshape(shape_x.clone())
|
||||
.into_data(),
|
||||
)
|
||||
.div_scalar(shape_x.num_elements() as f32);
|
||||
let offset = Tensor::<TestBackend, 4>::from(
|
||||
TestTensorInt::arange(0..shape_offset.num_elements() as i64, &device)
|
||||
.reshape(shape_offset.clone())
|
||||
.into_data(),
|
||||
)
|
||||
.div_scalar(shape_offset.num_elements() as f32);
|
||||
let mask = Tensor::<TestBackend, 4>::from(
|
||||
TestTensorInt::arange(0..shape_mask.num_elements() as i64, &device)
|
||||
.reshape(shape_mask.clone())
|
||||
.into_data(),
|
||||
)
|
||||
.div_scalar(shape_mask.num_elements() as f32);
|
||||
|
||||
let output = deform_conv2d(
|
||||
x,
|
||||
offset,
|
||||
weight,
|
||||
Some(mask),
|
||||
Some(bias),
|
||||
DeformConvOptions::new(
|
||||
[self.stride_1, self.stride_2],
|
||||
[self.padding_1, self.padding_2],
|
||||
[self.dilation_1, self.dilation_2],
|
||||
self.weight_groups,
|
||||
self.offset_groups,
|
||||
),
|
||||
);
|
||||
|
||||
y.to_data().assert_approx_eq(&output.into_data(), 3);
|
||||
}
|
||||
}
|
||||
}
|
|
@ -10,6 +10,7 @@ mod conv3d;
|
|||
mod conv_transpose1d;
|
||||
mod conv_transpose2d;
|
||||
mod conv_transpose3d;
|
||||
mod deform_conv2d;
|
||||
mod forward;
|
||||
mod maxpool1d;
|
||||
mod maxpool2d;
|
||||
|
|
|
@ -11,20 +11,22 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-wgpu"
|
|||
version.workspace = true
|
||||
|
||||
[features]
|
||||
default = ["std", "autotune", "fusion", "burn-jit/default", "cubecl/default"]
|
||||
fusion = ["burn-fusion", "burn-jit/fusion"]
|
||||
autotune = ["burn-jit/autotune"]
|
||||
template = ["burn-jit/template", "cubecl/template"]
|
||||
default = ["std", "autotune", "fusion", "burn-jit/default", "cubecl/default"]
|
||||
doc = ["burn-jit/doc"]
|
||||
std = ["burn-jit/std", "cubecl/std"]
|
||||
fusion = ["burn-fusion", "burn-jit/fusion"]
|
||||
simple-memory-management = ["cubecl/simple-memory-management"]
|
||||
std = ["burn-jit/std", "cubecl/std"]
|
||||
template = ["burn-jit/template", "cubecl/template"]
|
||||
|
||||
[dependencies]
|
||||
cubecl = { workspace = true, features = ["wgpu"] }
|
||||
|
||||
burn-jit = { path = "../burn-jit", version = "0.15.0", default-features = false }
|
||||
burn-tensor = { path = "../burn-tensor", version = "0.15.0", features = ["cubecl-wgpu"] }
|
||||
burn-fusion = { path = "../burn-fusion", version = "0.15.0", optional = true }
|
||||
burn-jit = { path = "../burn-jit", version = "0.15.0", default-features = false }
|
||||
burn-tensor = { path = "../burn-tensor", version = "0.15.0", features = [
|
||||
"cubecl-wgpu",
|
||||
] }
|
||||
|
||||
[dev-dependencies]
|
||||
burn-jit = { path = "../burn-jit", version = "0.15.0", default-features = false, features = [
|
||||
|
|
Loading…
Reference in New Issue