From 2c8514ce7f6aa257c9702b312c000e740714887a Mon Sep 17 00:00:00 2001 From: Genna Wingert Date: Mon, 23 Sep 2024 21:17:23 +0200 Subject: [PATCH] Add deform_conv2d as implemented in torchvision (#2147) --- Cargo.lock | 7 + Cargo.toml | 1 + burn-book/src/building-blocks/module.md | 1 + crates/burn-autodiff/src/ops/module.rs | 337 ++++ .../burn-autodiff/src/tests/deform_conv2d.rs | 1414 +++++++++++++++++ crates/burn-autodiff/src/tests/mod.rs | 3 + crates/burn-candle/src/ops/module.rs | 28 +- crates/burn-core/src/nn/conv/deform_conv2d.rs | 263 +++ crates/burn-core/src/nn/conv/mod.rs | 2 + crates/burn-fusion/src/ops/module.rs | 211 ++- crates/burn-fusion/src/stream/context.rs | 29 + .../burn-jit/src/kernel/conv/deform_conv2d.rs | 316 ++++ .../kernel/conv/deform_conv_transpose2d.rs | 591 +++++++ crates/burn-jit/src/kernel/conv/mod.rs | 4 + crates/burn-jit/src/ops/module_ops.rs | 27 +- crates/burn-ndarray/Cargo.toml | 5 +- crates/burn-ndarray/src/ops/deform_conv.rs | 656 ++++++++ crates/burn-ndarray/src/ops/mod.rs | 1 + crates/burn-ndarray/src/ops/module.rs | 24 + crates/burn-tch/src/ops/module.rs | 27 +- crates/burn-tensor/src/repr/operation.rs | 87 +- crates/burn-tensor/src/tensor/module.rs | 24 + .../src/tensor/ops/modules/base.rs | 91 ++ crates/burn-tensor/src/tests/mod.rs | 1 + .../src/tests/module/deform_conv2d.rs | 439 +++++ crates/burn-tensor/src/tests/module/mod.rs | 1 + crates/burn-wgpu/Cargo.toml | 14 +- 27 files changed, 4586 insertions(+), 18 deletions(-) create mode 100644 crates/burn-autodiff/src/tests/deform_conv2d.rs create mode 100644 crates/burn-core/src/nn/conv/deform_conv2d.rs create mode 100644 crates/burn-jit/src/kernel/conv/deform_conv2d.rs create mode 100644 crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs create mode 100644 crates/burn-ndarray/src/ops/deform_conv.rs create mode 100644 crates/burn-tensor/src/tests/module/deform_conv2d.rs diff --git a/Cargo.lock b/Cargo.lock index 474788f85..9c9172817 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -245,6 +245,12 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "atomic_float" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "628d228f918ac3b82fe590352cc719d30664a0c13ca3a60266fe02c7132d480a" + [[package]] name = "atty" version = "0.2.14" @@ -672,6 +678,7 @@ dependencies = [ name = "burn-ndarray" version = "0.15.0" dependencies = [ + "atomic_float", "blas-src", "burn-autodiff", "burn-common", diff --git a/Cargo.toml b/Cargo.toml index 69c62721e..e38df0400 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ readme = "README.md" version = "0.15.0" [workspace.dependencies] +atomic_float = "1" bytemuck = "1.18.0" candle-core = { version = "0.6.0" } clap = { version = "4.5.18", features = ["derive"] } diff --git a/burn-book/src/building-blocks/module.md b/burn-book/src/building-blocks/module.md index 48f4bf9fa..4653b3ee1 100644 --- a/burn-book/src/building-blocks/module.md +++ b/burn-book/src/building-blocks/module.md @@ -190,6 +190,7 @@ Burn comes with built-in modules that you can use to build your own modules. 
| `ConvTranspose1d` | `nn.ConvTranspose1d` | | `ConvTranspose2d` | `nn.ConvTranspose2d` | | `ConvTranspose3d` | `nn.ConvTranspose3d` | +| `DeformConv2d` | `torchvision.ops.DeformConv2d` | ### Pooling diff --git a/crates/burn-autodiff/src/ops/module.rs b/crates/burn-autodiff/src/ops/module.rs index a0a70788b..315beba6e 100644 --- a/crates/burn-autodiff/src/ops/module.rs +++ b/crates/burn-autodiff/src/ops/module.rs @@ -441,6 +441,343 @@ impl ModuleOps> for Autodiff, + offset: AutodiffTensor, + weight: AutodiffTensor, + mask: Option>, + bias: Option>, + options: DeformConvOptions<2>, + ) -> AutodiffTensor { + #[derive(Debug)] + struct DeformConv2DWithMaskWithBias; + #[derive(Debug)] + struct DeformConv2DWithMaskNoBias; + #[derive(Debug)] + struct DeformConv2DNoMaskWithBias; + #[derive(Debug)] + struct DeformConv2DNoMaskNoBias; + + impl Backward for DeformConv2DWithMaskWithBias { + type State = (NodeID, NodeID, NodeID, NodeID, NodeID, DeformConvOptions<2>); + + fn backward( + self, + ops: Ops, + grads: &mut Gradients, + checkpointer: &mut Checkpointer, + ) { + let [node_x, node_offset, node_weight, node_mask, node_bias] = ops.parents; + let grad = grads.consume::(&ops.node); + + let (x_state, offset_state, weight_state, mask_state, bias_state, options) = + ops.state; + let x = checkpointer.retrieve_node_output(x_state); + let offset = checkpointer.retrieve_node_output(offset_state); + let weight = checkpointer.retrieve_node_output(weight_state); + let mask = Some(checkpointer.retrieve_node_output(mask_state)); + let bias = Some(checkpointer.retrieve_node_output(bias_state)); + + let backward = + B::deform_conv2d_backward(x, offset, weight, mask, bias, grad, options); + + if let Some(node) = node_x { + grads.register::(node.id, backward.x_grad) + } + if let Some(node) = node_offset { + grads.register::(node.id, backward.offset_grad) + } + if let Some(node) = node_weight { + grads.register::(node.id, backward.weight_grad) + } + if let Some(node) = node_mask { + grads.register::(node.id, backward.mask_grad.unwrap()) + } + if let Some(node) = node_bias { + grads.register::(node.id, backward.bias_grad.unwrap()) + } + } + } + + impl Backward for DeformConv2DWithMaskNoBias { + type State = (NodeID, NodeID, NodeID, NodeID, DeformConvOptions<2>); + + fn backward( + self, + ops: Ops, + grads: &mut Gradients, + checkpointer: &mut Checkpointer, + ) { + let [node_x, node_offset, node_weight, node_mask] = ops.parents; + let grad = grads.consume::(&ops.node); + + let (x_state, offset_state, weight_state, mask_state, options) = ops.state; + let x = checkpointer.retrieve_node_output(x_state); + let offset = checkpointer.retrieve_node_output(offset_state); + let weight = checkpointer.retrieve_node_output(weight_state); + let mask = Some(checkpointer.retrieve_node_output(mask_state)); + + let backward = + B::deform_conv2d_backward(x, offset, weight, mask, None, grad, options); + + if let Some(node) = node_x { + grads.register::(node.id, backward.x_grad) + } + if let Some(node) = node_offset { + grads.register::(node.id, backward.offset_grad) + } + if let Some(node) = node_weight { + grads.register::(node.id, backward.weight_grad) + } + if let Some(node) = node_mask { + grads.register::(node.id, backward.mask_grad.unwrap()) + } + } + } + + impl Backward for DeformConv2DNoMaskWithBias { + type State = (NodeID, NodeID, NodeID, NodeID, DeformConvOptions<2>); + + fn backward( + self, + ops: Ops, + grads: &mut Gradients, + checkpointer: &mut Checkpointer, + ) { + let [node_x, node_offset, node_weight, node_bias] = 
ops.parents; + let grad = grads.consume::(&ops.node); + + let (x_state, offset_state, weight_state, bias_state, options) = ops.state; + let x = checkpointer.retrieve_node_output(x_state); + let offset = checkpointer.retrieve_node_output(offset_state); + let weight = checkpointer.retrieve_node_output(weight_state); + let bias = Some(checkpointer.retrieve_node_output(bias_state)); + + let backward = + B::deform_conv2d_backward(x, offset, weight, None, bias, grad, options); + + if let Some(node) = node_x { + grads.register::(node.id, backward.x_grad) + } + if let Some(node) = node_offset { + grads.register::(node.id, backward.offset_grad) + } + if let Some(node) = node_weight { + grads.register::(node.id, backward.weight_grad) + } + if let Some(node) = node_bias { + grads.register::(node.id, backward.bias_grad.unwrap()) + } + } + } + + impl Backward for DeformConv2DNoMaskNoBias { + type State = (NodeID, NodeID, NodeID, DeformConvOptions<2>); + + fn backward( + self, + ops: Ops, + grads: &mut Gradients, + checkpointer: &mut Checkpointer, + ) { + let [node_x, node_offset, node_weight] = ops.parents; + let grad = grads.consume::(&ops.node); + + let (x_state, offset_state, weight_state, options) = ops.state; + let x = checkpointer.retrieve_node_output(x_state); + let offset = checkpointer.retrieve_node_output(offset_state); + let weight = checkpointer.retrieve_node_output(weight_state); + + let backward = + B::deform_conv2d_backward(x, offset, weight, None, None, grad, options); + + if let Some(node) = node_x { + grads.register::(node.id, backward.x_grad) + } + if let Some(node) = node_offset { + grads.register::(node.id, backward.offset_grad) + } + if let Some(node) = node_weight { + grads.register::(node.id, backward.weight_grad) + } + } + } + + match (mask, bias) { + (Some(mask), Some(bias)) => match DeformConv2DWithMaskWithBias + .prepare::([ + x.node.clone(), + offset.node.clone(), + weight.node.clone(), + mask.node.clone(), + bias.node.clone(), + ]) + .compute_bound() + .stateful() + { + OpsKind::Tracked(mut prep) => { + let x_state = prep.checkpoint(&x); + let offset_state = prep.checkpoint(&offset); + let weight_state = prep.checkpoint(&weight); + let mask_state = prep.checkpoint(&mask); + let bias_state = prep.checkpoint(&bias); + prep.finish( + ( + x_state, + offset_state, + weight_state, + mask_state, + bias_state, + options.clone(), + ), + B::deform_conv2d( + x.primitive, + offset.primitive, + weight.primitive, + Some(mask.primitive), + Some(bias.primitive), + options, + ), + ) + } + OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d( + x.primitive, + offset.primitive, + weight.primitive, + Some(mask.primitive), + Some(bias.primitive), + options, + )), + }, + (Some(mask), None) => match DeformConv2DWithMaskNoBias + .prepare::([ + x.node.clone(), + offset.node.clone(), + weight.node.clone(), + mask.node.clone(), + ]) + .compute_bound() + .stateful() + { + OpsKind::Tracked(mut prep) => { + let x_state = prep.checkpoint(&x); + let offset_state = prep.checkpoint(&offset); + let weight_state = prep.checkpoint(&weight); + let mask_state = prep.checkpoint(&mask); + prep.finish( + ( + x_state, + offset_state, + weight_state, + mask_state, + options.clone(), + ), + B::deform_conv2d( + x.primitive, + offset.primitive, + weight.primitive, + Some(mask.primitive), + None, + options, + ), + ) + } + OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d( + x.primitive, + offset.primitive, + weight.primitive, + Some(mask.primitive), + None, + options, + )), + }, + (None, Some(bias)) => match 
DeformConv2DNoMaskWithBias + .prepare::([ + x.node.clone(), + offset.node.clone(), + weight.node.clone(), + bias.node.clone(), + ]) + .compute_bound() + .stateful() + { + OpsKind::Tracked(mut prep) => { + let x_state = prep.checkpoint(&x); + let offset_state = prep.checkpoint(&offset); + let weight_state = prep.checkpoint(&weight); + let bias_state = prep.checkpoint(&bias); + prep.finish( + ( + x_state, + offset_state, + weight_state, + bias_state, + options.clone(), + ), + B::deform_conv2d( + x.primitive, + offset.primitive, + weight.primitive, + None, + Some(bias.primitive), + options, + ), + ) + } + OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d( + x.primitive, + offset.primitive, + weight.primitive, + None, + Some(bias.primitive), + options, + )), + }, + (None, None) => match DeformConv2DNoMaskNoBias + .prepare::([x.node.clone(), offset.node.clone(), weight.node.clone()]) + .compute_bound() + .stateful() + { + OpsKind::Tracked(mut prep) => { + let x_state = prep.checkpoint(&x); + let offset_state = prep.checkpoint(&offset); + let weight_state = prep.checkpoint(&weight); + prep.finish( + (x_state, offset_state, weight_state, options.clone()), + B::deform_conv2d( + x.primitive, + offset.primitive, + weight.primitive, + None, + None, + options, + ), + ) + } + OpsKind::UnTracked(prep) => prep.finish(B::deform_conv2d( + x.primitive, + offset.primitive, + weight.primitive, + None, + None, + options, + )), + }, + } + } + + fn deform_conv2d_backward( + _x: AutodiffTensor, + _offset: AutodiffTensor, + _weight: AutodiffTensor, + _mask: Option>, + _bias: Option>, + _output_grad: AutodiffTensor, + _options: DeformConvOptions<2>, + ) -> DeformConv2dBackward { + panic!("Can't differentiate deform conv 2d backward."); + } + fn conv_transpose2d( x: AutodiffTensor, weight: AutodiffTensor, diff --git a/crates/burn-autodiff/src/tests/deform_conv2d.rs b/crates/burn-autodiff/src/tests/deform_conv2d.rs new file mode 100644 index 000000000..13fe37e72 --- /dev/null +++ b/crates/burn-autodiff/src/tests/deform_conv2d.rs @@ -0,0 +1,1414 @@ +#[burn_tensor_testgen::testgen(ad_deform_conv2d)] +mod tests { + use super::*; + use burn_tensor::{module::deform_conv2d, ops::DeformConvOptions, Shape}; + + #[test] + fn test_deform_conv2d_basic() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 3, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 0, + padding_2: 0, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 1, + offset_groups: 1, + height: 4, + width: 4, + }; + let device = Default::default(); + let grads = Grads { + x: TestTensor::from_floats( + [[ + [ + [0.0000, 6.0678, 14.2071, 12.2477], + [11.2292, 33.7937, 50.1555, 44.0561], + [17.9294, 57.2174, 85.1505, 79.1840], + [18.0220, 73.6263, 126.8184, 151.6910], + ], + [ + [0.0000, 8.9783, 20.7620, 17.7888], + [16.2326, 48.7386, 71.7961, 62.5845], + [25.3808, 80.5195, 119.0949, 110.0938], + [25.0567, 101.8461, 174.3329, 206.6013], + ], + ]], + &device, + ), + offset: TestTensor::from_floats( + [[ + [[0.0000, 15.0000], [30.0000, 45.0000]], + [[0.0000, 3.7500], [7.5000, 11.2500]], + [[62.6667, 78.3333], [94.0000, 109.6667]], + [[15.6667, 19.5833], [23.5000, 27.4167]], + [[130.6667, 104.1250], [163.3333, 122.2732]], + [[32.6667, -492.9583], [40.8333, -787.1620]], + [[204.0000, 221.0000], [238.0000, 255.0000]], + [[51.0000, 55.2500], [59.5000, 63.7500]], + [[282.6667, 300.3333], [318.0000, 335.6667]], + [[70.6667, 75.0833], [79.5000, 83.9167]], + [[366.6667, 144.3750], [403.3333, 146.4121]], + 
[[91.6667, -1788.9860], [100.8333, -2392.7456]], + [[456.0000, 475.0000], [-2718.6250, -2953.2188]], + [[114.0000, 118.7500], [37.7361, 37.4063]], + [[550.6667, 570.3334], [-3404.5139, -3672.5312]], + [[137.6667, 142.5833], [28.6806, 27.5197]], + [[650.6667, 27.9584], [-4174.3657, -59.7509]], + [[162.6667, -3991.0139], [14.4028, -298.7557]], + ]], + &device, + ), + weight: TestTensor::from_floats( + [ + [ + [ + [0.7029, 2.8356, 5.1067], + [12.7492, 19.4745, 17.8345], + [22.0687, 25.9156, 14.6394], + ], + [ + [3.3696, 12.6134, 19.2671], + [36.7492, 50.5856, 43.5506], + [50.8774, 56.3292, 30.7470], + ], + ], + [ + [ + [0.7029, 2.8356, 5.1067], + [12.7492, 19.4745, 17.8345], + [22.0687, 25.9156, 14.6394], + ], + [ + [3.3696, 12.6134, 19.2671], + [36.7492, 50.5856, 43.5506], + [50.8774, 56.3292, 30.7470], + ], + ], + [ + [ + [0.7029, 2.8356, 5.1067], + [12.7492, 19.4745, 17.8345], + [22.0687, 25.9156, 14.6394], + ], + [ + [3.3696, 12.6134, 19.2671], + [36.7492, 50.5856, 43.5506], + [50.8774, 56.3292, 30.7470], + ], + ], + ], + &device, + ), + mask: TestTensor::from_floats( + [[ + [[1303.5000, 1447.8750], [1862.2500, 2006.6250]], + [[1571.1666, 1721.9581], [2154.7500, 2305.5417]], + [[1857.4999, 1396.7151], [2465.9167, 1753.2246]], + [[2315.5000, 2479.1250], [2948.7502, 3112.3750]], + [[2645.1665, 2815.2085], [3303.2500, 3473.2917]], + [[2993.5000, 1150.0625], [3676.4165, 1300.4055]], + [[3531.5000, 3714.3752], [1150.1876, 1148.4744]], + [[3923.1665, 4112.4585], [794.3865, 770.0470]], + [[4333.5000, 181.4101], [368.3260, 4.2679]], + ]], + &device, + ), + bias: TestTensor::from_floats([4., 4., 4.], &device), + }; + test.assert_grads(grads); + } + + #[test] + fn test_deform_conv2d_batched() { + let test = Conv2dTestCase { + batch_size: 2, + channels_in: 2, + channels_out: 3, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 0, + padding_2: 0, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 1, + offset_groups: 1, + height: 4, + width: 4, + }; + let device = Default::default(); + let grads = Grads { + x: TestTensor::from_floats( + [ + [ + [ + [0.0000, 3.4604, 8.7539, 6.8080], + [8.4661, 24.0784, 35.4610, 26.4276], + [19.5988, 51.0406, 68.4389, 53.4993], + [17.4698, 47.9106, 67.3808, 56.6063], + ], + [ + [0.0000, 5.1185, 12.7803, 9.8796], + [12.1957, 34.5728, 50.4616, 37.3777], + [27.4521, 71.1227, 94.5778, 73.4724], + [24.1147, 65.8443, 91.8995, 76.7475], + ], + ], + [ + [ + [6.3750, 19.3553, 26.4935, 22.5650], + [17.0026, 57.8088, 85.5580, 78.0746], + [20.7334, 86.5793, 139.4667, 136.4133], + [16.8126, 103.0225, 186.4502, 206.9613], + ], + [ + [9.5625, 28.8786, 39.1137, 32.9178], + [25.1984, 85.0747, 124.6941, 112.5691], + [30.0242, 124.2863, 198.6056, 192.4489], + [23.5826, 143.4660, 257.8752, 283.2587], + ], + ], + ], + &device, + ), + offset: TestTensor::from_floats( + [ + [ + [[0.0000, 7.5000], [15.0000, 22.5000]], + [[0.0000, 1.8750], [3.7500, 5.6250]], + [[31.3333, 39.1667], [47.0000, 54.8333]], + [[7.8333, 9.7917], [11.7500, 13.7083]], + [[65.3333, 62.7813], [81.6667, 75.4849]], + [[16.3333, -237.8021], [20.4167, -381.7280]], + [[102.0000, 110.5000], [119.0000, 127.5000]], + [[25.5000, 27.6250], [29.7500, 31.8750]], + [[141.3333, 150.1667], [159.0000, 167.8333]], + [[35.3333, 37.5417], [39.7500, 41.9583]], + [[183.3333, 132.3438], [201.6667, 142.0197]], + [[45.8333, -839.6840], [50.4167, -1133.4155]], + [[228.0000, 237.5000], [-1336.1562, -1452.1173]], + [[57.0000, 59.3750], [40.3090, 41.4141]], + [[275.3333, 285.1667], [-1670.5034, -1802.9244]], + [[68.8333, 
71.2917], [44.0451, 44.9841]], + [[325.3333, 174.7396], [-2045.1747, -1090.4585]], + [[81.3333, -1844.0659], [46.8090, -1150.2101]], + ], + [ + [[270.0000, 277.5000], [285.0000, 292.5000]], + [[67.5000, 69.3750], [71.2500, 73.1250]], + [[313.3333, 321.1667], [329.0000, 336.8333]], + [[78.3333, 80.2917], [82.2500, 84.2083]], + [[359.3333, 130.1563], [375.6667, 130.6099]], + [[89.8333, -4312.7603], [93.9167, -4893.6035]], + [[408.0000, 416.5000], [425.0000, 433.5000]], + [[102.0000, 104.1250], [106.2500, 108.3750]], + [[459.3333, 468.1667], [477.0000, 485.8333]], + [[114.8333, 117.0417], [119.2500, 121.4583]], + [[513.3334, 97.9688], [531.6667, 93.8947]], + [[128.3333, -6720.3926], [132.9167, -7504.5405]], + [[570.0000, 579.5000], [-7971.8438, -8251.0850]], + [[142.5000, 144.8750], [22.4965, 21.8203]], + [[629.3333, 639.1667], [-8948.2334, -9249.6641]], + [[157.3333, 159.7917], [15.7743, 14.8695]], + [[691.3333, 14.6145], [-9992.9453, -70.4040]], + [[172.8333, -9818.5234], [7.4132, -352.0222]], + ], + ], + &device, + ), + weight: TestTensor::from_floats( + [ + [ + [ + [77.7195, 89.8692, 69.0213], + [121.0760, 137.0775, 92.2989], + [100.0212, 106.5561, 61.1851], + ], + [ + [112.3862, 131.6470, 103.8793], + [177.0760, 200.1887, 138.2681], + [149.5922, 158.7074, 94.3991], + ], + ], + [ + [ + [77.7195, 89.8692, 69.0213], + [121.0760, 137.0775, 92.2989], + [100.0212, 106.5561, 61.1851], + ], + [ + [112.3862, 131.6470, 103.8793], + [177.0760, 200.1887, 138.2681], + [149.5922, 158.7074, 94.3991], + ], + ], + [ + [ + [77.7195, 89.8692, 69.0213], + [121.0760, 137.0775, 92.2989], + [100.0212, 106.5561, 61.1851], + ], + [ + [112.3862, 131.6470, 103.8793], + [177.0760, 200.1887, 138.2681], + [149.5922, 158.7074, 94.3991], + ], + ], + ], + &device, + ), + mask: TestTensor::from_floats( + [ + [ + [[1299.7499, 1439.4375], [1849.1249, 1988.8125]], + [[1528.0834, 1673.9791], [2101.8750, 2247.7708]], + [[1771.7500, 1624.9811], [2369.9583, 2099.5039]], + [[2183.7500, 2342.0625], [2806.3750, 2964.6875]], + [[2464.0833, 2628.6042], [3111.1250, 3275.6458]], + [[2759.7500, 1979.2551], [3431.2085, 2390.0286]], + [[3241.7498, 3418.6873], [2415.3589, 2500.8682]], + [[3574.0835, 3757.2292], [2394.3889, 2471.7510]], + [[3921.7500, 2095.5293], [2345.9363, 1199.5048]], + ], + [ + [[5957.2500, 6096.9375], [6506.6250, 6646.3125]], + [[6392.5835, 6538.4790], [6966.3750, 7112.2705]], + [[6843.2500, 2443.8982], [7441.4585, 2550.9199]], + [[7462.2505, 7620.5625], [8084.8745, 8243.1875]], + [[7949.5835, 8114.1045], [8596.6250, 8761.1465]], + [[8452.2500, 1591.6719], [9123.7080, 1589.9454]], + [[9141.2500, 9318.1875], [1414.3584, 1375.1803]], + [[9680.5840, 9863.7285], [949.0560, 897.3544]], + [[10235.2500, 213.4454], [428.2699, 2.4790]], + ], + ], + &device, + ), + bias: TestTensor::from_floats([8., 8., 8.], &device), + }; + test.assert_grads(grads); + } + + #[test] + fn test_deform_conv2d_different_kernel_size() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 4, + padding_1: 1, + padding_2: 1, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 1, + offset_groups: 1, + height: 4, + width: 4, + }; + let device = Default::default(); + let grads = Grads { + x: TestTensor::from_floats( + [[ + [ + [14.5585, 27.2496, 37.3820, 36.0394], + [33.1519, 60.4807, 81.2647, 78.6182], + [57.5201, 108.6233, 153.4136, 170.0730], + [54.7062, 102.5967, 144.3672, 162.6436], + ], + [ + [25.8364, 48.0884, 65.2492, 62.1033], + [56.8052, 102.9956, 
136.9831, 131.1209], + [96.1054, 179.7902, 250.5509, 272.6688], + [90.2110, 167.5679, 232.8473, 257.9347], + ], + ]], + &device, + ), + offset: TestTensor::from_floats( + [[ + [ + [0.0000, 5.3559, 11.7153], + [0.3125, 8.0000, 10.0000], + [0.7500, 14.0000, 16.0000], + [1.3125, 20.0000, 22.0000], + ], + [ + [0.0000, 0.0017, 0.0069], + [16.0625, 2.0000, 2.5000], + [44.2500, 3.5000, 4.0000], + [84.5625, 5.0000, 5.5000], + ], + [ + [67.4583, 79.9648, 93.5305], + [31.6667, 33.7778, 35.8889], + [38.0000, 40.1111, 42.2222], + [44.3333, 46.4444, 48.5556], + ], + [ + [0.5278, 0.5956, 0.6671], + [7.9167, 8.4444, 8.9722], + [9.5000, 10.0278, 10.5556], + [11.0833, 11.6111, 12.1389], + ], + [ + [154.7778, 175.1640, 151.8874], + [60.0000, 62.2222, 49.8997], + [66.6667, 68.8889, 54.3210], + [73.3333, 75.5555, 58.6034], + ], + [ + [2.2222, 2.3630, -33.6034], + [15.0000, 15.5556, -227.7485], + [16.6667, 17.2222, -323.1605], + [18.3333, 18.8889, -432.0448], + ], + [ + [264.1250, 202.1189, 0.0000], + [91.0000, 64.8148, 0.0000], + [98.0000, 68.6308, 0.0000], + [105.0000, 72.3009, 0.0000], + ], + [ + [5.2500, -72.6832, 0.0000], + [22.7500, -334.6296, 0.0000], + [24.5000, -461.1053, 0.0000], + [26.2500, -601.7269, 0.0000], + ], + [ + [44.0000, 119.7778, 122.2222], + [48.0486, 127.1111, 129.5556], + [52.2500, 134.4444, 136.8889], + [-313.8958, -800.7446, -850.7313], + ], + [ + [337.7778, 29.9444, 30.5556], + [484.8542, 31.7778, 32.3889], + [646.7500, 33.6111, 34.2222], + [490.9653, 22.3989, 22.6599], + ], + [ + [153.3333, 155.8889, 158.4444], + [161.0000, 163.5556, 166.1111], + [168.6667, 171.2222, 173.7778], + [-995.2491, -1054.5505, -1115.1342], + ], + [ + [38.3333, 38.9722, 39.6111], + [40.2500, 40.8889, 41.5278], + [42.1667, 42.8056, 43.4444], + [24.3377, 24.5351, 24.7281], + ], + [ + [192.0000, 194.6667, 89.0741], + [200.0000, 202.6667, 90.5463], + [208.0000, 210.6667, 91.8519], + [-1272.9375, -1343.5092, -581.1921], + ], + [ + [48.0000, 48.6667, -741.3703], + [50.0000, 50.6667, -978.8981], + [52.0000, 52.6667, -1232.5927], + [25.3125, 25.4352, -638.8311], + ], + [ + [233.3333, 87.7218, 0.0000], + [241.6667, 88.2716, 0.0000], + [250.0000, 88.6478, 0.0000], + [-1587.2161, -553.5372, 0.0000], + ], + [ + [58.3333, -901.1902, 0.0000], + [60.4167, -1179.9877, 0.0000], + [62.5000, -1475.6252, 0.0000], + [24.8915, -621.3175, 0.0000], + ], + [ + [196.4444, 280.2222, 283.1111], + [205.5625, 288.8889, 291.7778], + [-1173.4723, -1679.6113, -1771.2903], + [0.0000, 0.0000, 0.0000], + ], + [ + [1144.8890, 70.0556, 70.7778], + [1469.6459, 72.2222, 72.9444], + [502.9167, 22.9882, 22.9506], + [0.0000, 0.0000, 0.0000], + ], + [ + [324.0000, 327.0000, 330.0000], + [333.0000, 336.0000, 339.0000], + [-1931.4688, -2034.9608, -2139.9585], + [0.0000, 0.0000, 0.0000], + ], + [ + [81.0000, 81.7500, 82.5000], + [83.2500, 84.0000, 84.7500], + [19.5938, 19.4661, 19.3333], + [0.0000, 0.0000, 0.0000], + ], + [ + [373.3333, 376.4445, 44.8087], + [382.6667, 385.7778, 41.8596], + [-2313.7917, -2431.2759, -239.2101], + [0.0000, 0.0000, 0.0000], + ], + [ + [93.3333, 94.1111, -1904.9321], + [95.6667, 96.4444, -2344.7146], + [14.2917, 14.0621, -341.7283], + [0.0000, 0.0000, 0.0000], + ], + [ + [425.3333, 16.3684, 0.0000], + [435.0000, 12.1728, 0.0000], + [-2738.5173, -47.9289, 0.0000], + [0.0000, 0.0000, 0.0000], + ], + [ + [106.3333, -2178.7473, 0.0000], + [108.7500, -2670.6790, 0.0000], + [6.9479, -162.9574, 0.0000], + [0.0000, 0.0000, 0.0000], + ], + ]], + &device, + ), + weight: TestTensor::from_floats( + [ + [ + [ + [1.8560, 7.2034, 
12.8334, 11.9694], + [24.2368, 40.1255, 41.3964, 27.6420], + [43.6131, 57.5089, 46.0933, 25.1744], + ], + [ + [6.9899, 26.5803, 42.6186, 37.5014], + [75.6232, 116.9257, 113.2884, 72.5678], + [112.7249, 139.8264, 107.6534, 56.7994], + ], + ], + [ + [ + [1.8560, 7.2034, 12.8334, 11.9694], + [24.2368, 40.1255, 41.3964, 27.6420], + [43.6131, 57.5089, 46.0933, 25.1744], + ], + [ + [6.9899, 26.5803, 42.6186, 37.5014], + [75.6232, 116.9257, 113.2884, 72.5678], + [112.7249, 139.8264, 107.6534, 56.7994], + ], + ], + ], + &device, + ), + mask: TestTensor::from_floats( + [[ + [ + [0.0000, 2.6779, 5.8576], + [40.1562, 775.9999, 849.2499], + [66.3750, 1067.7499, 1140.9999], + [98.6563, 1359.5000, 1432.7499], + ], + [ + [67.4583, 76.8892, 86.8497], + [838.7916, 916.1111, 993.4306], + [1146.7500, 1224.0695, 1301.3889], + [1454.7083, 1532.0278, 1609.3472], + ], + [ + [154.7778, 171.6607, 146.0455], + [986.1667, 1067.5555, 875.6536], + [1310.3333, 1391.7222, 1110.8640], + [1634.5001, 1715.8888, 1339.3390], + ], + [ + [264.1250, 199.3876, 0.0000], + [1144.8751, 836.5740, 0.0000], + [1485.2499, 1056.2528, 0.0000], + [1825.6250, 1268.8589, 0.0000], + ], + [ + [380.0000, 1047.8611, 1137.3889], + [527.6354, 1404.4445, 1493.9722], + [682.6807, 1761.0276, 1850.5554], + [503.8855, 1256.3406, 1304.9355], + ], + [ + [1123.5000, 1217.0972, 1310.6943], + [1496.2917, 1589.8889, 1683.4861], + [1869.0834, 1962.6805, 2056.2778], + [1146.6998, 1190.1357, 1232.9299], + ], + [ + [1300.0001, 1397.6667, 651.2036], + [1689.0000, 1786.6667, 807.2734], + [2078.0000, 2175.6667, 955.2593], + [1060.7812, 1097.7451, 465.6539], + ], + [ + [1487.8334, 567.2195, 0.0000], + [1893.0416, 697.2655, 0.0000], + [2298.2500, 818.8910, 0.0000], + [947.2098, 323.8781, 0.0000], + ], + [ + [1216.4445, 1792.8055, 1898.6112], + [1536.4478, 2214.2222, 2320.0278], + [517.7084, 725.6571, 749.3920], + [0.0000, 0.0000, 0.0000], + ], + [ + [1897.5000, 2007.3749, 2117.2500], + [2335.1250, 2445.0000, 2554.8750], + [559.1096, 575.0975, 590.3336], + [0.0000, 0.0000, 0.0000], + ], + [ + [2119.3333, 2233.2776, 265.4414], + [2573.1667, 2687.1111, 290.7444], + [385.6317, 392.4502, 37.3766], + [0.0000, 0.0000, 0.0000], + ], + [ + [2352.5000, 90.0985, 0.0000], + [2822.5415, 78.5491, 0.0000], + [178.5990, 2.9309, 0.0000], + [0.0000, 0.0000, 0.0000], + ], + ]], + &device, + ), + bias: TestTensor::from_floats([12., 12.], &device), + }; + test.assert_grads(grads); + } + + #[test] + fn test_deform_conv2d_different_padding() { + let test = Conv2dTestCase { + batch_size: 1, + channels_in: 2, + channels_out: 2, + kernel_size_1: 3, + kernel_size_2: 3, + padding_1: 2, + padding_2: 3, + stride_1: 1, + stride_2: 1, + dilation_1: 1, + dilation_2: 1, + groups: 1, + offset_groups: 1, + height: 4, + width: 4, + }; + let device = Default::default(); + let grads = Grads { + x: TestTensor::from_floats( + [[ + [ + [60.6330, 60.9065, 61.1795, 61.4520], + [122.5578, 123.0882, 123.6186, 124.1490], + [126.8011, 127.3315, 127.8619, 128.3924], + [131.0444, 131.5749, 132.1053, 132.6357], + ], + [ + [102.0006, 102.4976, 102.9938, 103.4893], + [198.9330, 199.8306, 200.7282, 201.6259], + [206.1140, 207.0117, 207.9092, 208.8069], + [213.2949, 214.1926, 215.0903, 215.9879], + ], + ]], + &device, + ), + offset: TestTensor::from_floats( + [[ + [ + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + [ + 0.0000, 0.0000, 0.8951, 14.7606, 17.6042, 20.6981, 22.2004, 0.0000, + ], + [ + 0.0000, 0.0000, 0.6875, 9.5000, 10.0000, 10.5000, 10.1088, 0.0000, + ], + [ + 0.0000, 0.0000, 
1.1134, 13.5000, 14.0000, 14.5000, 13.6458, 0.0000, + ], + [ + 0.0000, 0.0000, 1.6134, 17.5000, 18.0000, 18.5000, 17.1088, 0.0000, + ], + [ + 0.0000, 0.0000, -12.3958, -122.3994, -130.7523, -139.3555, -131.5268, + 0.0000, + ], + ], + [ + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + [ + 0.0000, 0.0000, 0.1543, 0.0175, 0.0208, 0.0245, -0.3875, 0.0000, + ], + [ + 0.0000, 0.0000, 24.1875, 2.3750, 2.5000, 2.6250, -37.8634, 0.0000, + ], + [ + 0.0000, 0.0000, 48.0579, 3.3750, 3.5000, 3.6250, -66.7708, 0.0000, + ], + [ + 0.0000, 0.0000, 80.0023, 4.3750, 4.5000, 4.6250, -103.7523, 0.0000, + ], + [ + 0.0000, 0.0000, 113.2153, 5.1075, 5.2199, 5.3320, -139.7259, 0.0000, + ], + ], + [ + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + [ + 0.0000, 14.2060, 83.0176, 92.3794, 102.0100, 90.3563, 0.0000, 0.0000, + ], + [ + 0.0000, 6.5047, 35.4444, 35.9815, 36.5185, 29.9790, 0.0000, 0.0000, + ], + [ + 0.0000, 7.6683, 39.7407, 40.2778, 40.8148, 33.0719, 0.0000, 0.0000, + ], + [ + 0.0000, 8.9115, 44.0370, 44.5741, 45.1111, 36.0853, 0.0000, 0.0000, + ], + [ + 0.0000, -57.5230, -274.2679, -289.5471, -305.0951, -248.5786, 0.0000, + 0.0000, + ], + ], + [ + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + [ + 0.0000, 9.7492, 0.9554, 0.9810, 1.0069, -13.9305, 0.0000, 0.0000, + ], + [ + 0.0000, 96.0469, 8.8611, 8.9954, 9.1296, -129.9207, 0.0000, 0.0000, + ], + [ + 0.0000, 147.4348, 9.9352, 10.0694, 10.2037, -186.7187, 0.0000, 0.0000, + ], + [ + 0.0000, 207.4948, 11.0093, 11.1435, 11.2778, -252.1889, 0.0000, 0.0000, + ], + [ + 0.0000, 226.0500, 10.1534, 10.2520, 10.3504, -266.2553, 0.0000, 0.0000, + ], + ], + [ + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + [ + 44.2250, 159.8985, 176.6519, 193.6927, 146.2708, 0.0000, 0.0000, 0.0000, + ], + [ + 19.0508, 64.8704, 65.4444, 66.0185, 46.5532, 0.0000, 0.0000, 0.0000, + ], + [ + 21.0494, 69.4630, 70.0370, 70.6111, 49.1046, 0.0000, 0.0000, 0.0000, + ], + [ + 23.1331, 74.0556, 74.6296, 75.2037, 51.5710, 0.0000, 0.0000, 0.0000, + ], + [ + -141.2003, -445.3022, -468.3810, -491.7472, -341.5531, 0.0000, 0.0000, + 0.0000, + ], + ], + [ + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + [ + 35.6653, 3.5057, 3.5567, 3.6081, -48.7569, 0.0000, 0.0000, 0.0000, + ], + [ + 181.4047, 16.2176, 16.3611, 16.5046, -238.1361, 0.0000, 0.0000, 0.0000, + ], + [ + 263.8889, 17.3657, 17.5093, 17.6528, -326.4037, 0.0000, 0.0000, 0.0000, + ], + [ + 355.6433, 18.5139, 18.6574, 18.8009, -423.9413, 0.0000, 0.0000, 0.0000, + ], + [ + 318.7092, 14.3597, 14.4416, 14.5231, -369.8195, 0.0000, 0.0000, 0.0000, + ], + ], + [ + [ + 0.0000, 0.0000, 88.8467, 237.4784, 261.7312, 286.2899, 182.5087, 0.0000, + ], + [ + 0.0000, 0.0000, 37.6880, 94.7222, 95.3333, 95.9445, 57.4416, 0.0000, + ], + [ + 0.0000, 0.0000, 40.5625, 99.6111, 100.2222, 100.8333, 59.4107, 0.0000, + ], + [ + 0.0000, 0.0000, 43.5275, 104.5000, 105.1111, 105.7222, 61.2893, 0.0000, + ], + [ + 0.0000, 0.0000, -258.3244, -618.3539, -649.3403, -680.6325, -397.1010, + 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + ], + [ + [ + 0.0000, 0.0000, 76.2294, 7.5641, 7.6417, 7.7197, -102.7922, 0.0000, + ], + [ + 0.0000, 0.0000, 272.0152, 23.6806, 23.8333, 23.9861, -351.9442, 0.0000, + ], + [ + 0.0000, 0.0000, 386.0625, 24.9028, 25.0556, 25.2083, -472.1479, 0.0000, + ], + [ + 0.0000, 0.0000, 509.9781, 26.1250, 26.2778, 26.4306, -602.2200, 0.0000, + ], + [ + 0.0000, 0.0000, 378.4102, 
17.1237, 17.1875, 17.2510, -436.0007, 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + ], + [ + [ + 0.0000, 157.6233, 331.9385, 365.2834, 398.9526, 205.9885, 0.0000, + 0.0000, + ], + [ + 0.0000, 66.4959, 130.9259, 131.5741, 132.2222, 64.4360, 0.0000, 0.0000, + ], + [ + 0.0000, 70.3968, 136.1111, 136.7593, 137.4074, 65.6723, 0.0000, 0.0000, + ], + [ + 0.0000, 74.3938, 141.2963, 141.9444, 142.5926, 66.8125, 0.0000, 0.0000, + ], + [ + 0.0000, -432.7980, -827.4919, -867.9785, -908.7894, -425.0742, 0.0000, + 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + ], + [ + [ + 0.0000, 140.1500, 14.0440, 14.1529, 14.2623, -187.6569, 0.0000, 0.0000, + ], + [ + 0.0000, 386.8139, 32.7315, 32.8935, 33.0556, -494.7796, 0.0000, 0.0000, + ], + [ + 0.0000, 538.9267, 34.0278, 34.1898, 34.3519, -653.4219, 0.0000, 0.0000, + ], + [ + 0.0000, 701.5059, 35.3241, 35.4861, 35.6482, -822.5306, 0.0000, 0.0000, + ], + [ + 0.0000, 416.0446, 18.9036, 18.9446, 18.9853, -476.7288, 0.0000, 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + ], + [ + [ + 249.8765, 435.8685, 479.1788, 522.8320, 207.9198, 0.0000, 0.0000, + 0.0000, + ], + [ + 105.4170, 170.6111, 171.2963, 171.9815, 64.7500, 0.0000, 0.0000, 0.0000, + ], + [ + 110.4417, 176.0926, 176.7778, 177.4630, 65.1560, 0.0000, 0.0000, 0.0000, + ], + [ + 115.5679, 181.5741, 182.2592, 182.9445, 65.4606, 0.0000, 0.0000, 0.0000, + ], + [ + -662.7435, -1056.6418, -1107.5020, -1158.7047, -409.5102, 0.0000, + 0.0000, 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + ], + [ + [ + 227.1605, 22.9825, 23.1258, 23.2695, -303.1120, 0.0000, 0.0000, 0.0000, + ], + [ + 518.4952, 42.6528, 42.8241, 42.9954, -657.1573, 0.0000, 0.0000, 0.0000, + ], + [ + 712.2524, 44.0231, 44.1944, 44.3657, -857.8172, 0.0000, 0.0000, 0.0000, + ], + [ + 917.0740, 45.3935, 45.5648, 45.7361, -1069.5416, 0.0000, 0.0000, 0.0000, + ], + [ + 416.5815, 18.9978, 19.0131, 19.0280, -475.0315, 0.0000, 0.0000, 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + ], + [ + [ + 0.0000, 0.0000, 151.7503, 210.1667, 210.8889, 211.6111, 57.5069, 0.0000, + ], + [ + 0.0000, 0.0000, 157.9293, 215.9444, 216.6667, 217.3889, 57.0522, 0.0000, + ], + [ + 0.0000, 0.0000, 164.2153, 221.7222, 222.4445, 223.1667, 56.4905, 0.0000, + ], + [ + 0.0000, 0.0000, -931.7838, -1285.3538, -1346.5559, -1408.1194, + -346.7390, 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + ], + [ + [ + 0.0000, 0.0000, 655.6700, 52.5417, 52.7222, 52.9028, -824.9468, 0.0000, + ], + [ + 0.0000, 0.0000, 890.9725, 53.9861, 54.1667, 54.3472, -1067.5250, 0.0000, + ], + [ + 0.0000, 0.0000, 1137.9375, 55.4306, 55.6111, 55.7917, -1321.7656, + 0.0000, + ], + [ + 0.0000, 0.0000, 375.5806, 17.1810, 17.1695, 17.1576, -425.9937, 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + ], + [ + [ + 0.0000, 213.5215, 256.6296, 257.3889, 258.1481, 41.6529, 0.0000, 0.0000, + ], + [ + 0.0000, 221.0156, 262.7037, 263.4630, 264.2222, 40.1766, 0.0000, 0.0000, + ], + [ + 0.0000, 228.6223, 268.7778, 269.5370, 270.2963, 38.5878, 0.0000, 0.0000, + ], + [ + 0.0000, -1285.4667, -1554.2545, -1627.5306, -1701.1866, -228.2914, + 0.0000, 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 
0.0000, 0.0000, 0.0000, 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + ], + [ + [ + 0.0000, 823.3806, 64.1574, 64.3472, 64.5370, -1028.5327, 0.0000, 0.0000, + ], + [ + 0.0000, 1107.2965, 65.6759, 65.8657, 66.0556, -1320.0975, 0.0000, + 0.0000, + ], + [ + 0.0000, 1403.4730, 67.1944, 67.3843, 67.5741, -1623.9230, 0.0000, + 0.0000, + ], + [ + 0.0000, 288.1514, 13.2018, 13.1585, 13.1148, -323.5778, 0.0000, 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + ], + [ + [ + 288.7901, 306.5741, 307.3704, 308.1667, 15.7342, 0.0000, 0.0000, 0.0000, + ], + [ + 297.6968, 312.9444, 313.7407, 314.5370, 13.1389, 0.0000, 0.0000, 0.0000, + ], + [ + 306.7215, 319.3148, 320.1111, 320.9074, 10.4255, 0.0000, 0.0000, 0.0000, + ], + [ + -1711.5431, -1844.0131, -1930.2366, -2016.8586, -46.8461, 0.0000, + 0.0000, 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + ], + [ + [ + 1011.3581, 76.6435, 76.8426, 77.0417, -1255.0455, 0.0000, 0.0000, + 0.0000, + ], + [ + 1347.4663, 78.2361, 78.4352, 78.6343, -1599.1759, 0.0000, 0.0000, + 0.0000, + ], + [ + 1696.4333, 79.8287, 80.0278, 80.2269, -1956.1649, 0.0000, 0.0000, + 0.0000, + ], + [ + 146.7036, 6.6909, 6.6128, 6.5342, -159.2772, 0.0000, 0.0000, 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + ], + ]], + &device, + ), + weight: TestTensor::from_floats( + [ + [ + [ + [10.3420, 22.9881, 35.6342], + [46.9202, 59.5663, 72.2124], + [80.8816, 92.5915, 104.1585], + ], + [ + [29.2134, 68.8378, 108.4622], + [143.8251, 183.4495, 223.0739], + [228.0294, 256.7517, 283.8071], + ], + ], + [ + [ + [10.3420, 22.9881, 35.6342], + [46.9202, 59.5663, 72.2124], + [80.8816, 92.5915, 104.1585], + ], + [ + [29.2134, 68.8378, 108.4622], + [143.8251, 183.4495, 223.0739], + [228.0294, 256.7517, 283.8071], + ], + ], + ], + &device, + ), + mask: TestTensor::from_floats( + [[ + [ + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + [ + 0.0000, 0.0000, 0.4475, 7.3803, 8.8021, 10.3490, 11.1002, 0.0000, + ], + [ + 0.0000, 0.0000, 44.3438, 584.9374, 639.2500, 693.5624, 683.2628, 0.0000, + ], + [ + 0.0000, 0.0000, 68.3901, 803.4375, 857.7500, 912.0625, 874.6981, 0.0000, + ], + [ + 0.0000, 0.0000, 96.4734, 1021.9375, 1076.2500, 1130.5625, 1062.0959, + 0.0000, + ], + [ + 0.0000, 0.0000, 121.3021, 1168.4879, 1218.3738, 1268.1349, 1169.4447, + 0.0000, + ], + ], + [ + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + [ + 0.0000, 13.0845, 75.8609, 83.7678, 91.8090, 80.7282, 0.0000, 0.0000, + ], + [ + 0.0000, 118.9504, 649.4861, 707.8218, 766.1574, 658.0766, 0.0000, + 0.0000, + ], + [ + 0.0000, 170.6608, 884.1713, 942.5070, 1000.8427, 837.8093, 0.0000, + 0.0000, + ], + [ + 0.0000, 226.7073, 1118.8564, 1177.1923, 1235.5277, 1013.2059, 0.0000, + 0.0000, + ], + [ + 0.0000, 234.9397, 1106.2139, 1153.4156, 1200.4827, 966.2489, 0.0000, + 0.0000, + ], + ], + [ + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + [ + 42.5240, 153.0457, 168.3193, 183.7365, 138.1447, 0.0000, 0.0000, 0.0000, + ], + [ + 207.3196, 718.4328, 780.7916, 843.1504, 619.9750, 0.0000, 0.0000, + 0.0000, + ], + [ + 290.2778, 969.3032, 1031.6620, 1094.0208, 784.4216, 0.0000, 0.0000, + 0.0000, + ], 
+ [ + 377.8711, 1220.1736, 1282.5325, 1344.8912, 944.2330, 0.0000, 0.0000, + 0.0000, + ], + [ + 328.0830, 1025.4950, 1069.1306, 1112.6222, 766.0549, 0.0000, 0.0000, + 0.0000, + ], + ], + [ + [ + 0.0000, 0.0000, 88.2382, 235.0552, 258.1943, 281.4864, 178.8585, 0.0000, + ], + [ + 0.0000, 0.0000, 305.5755, 789.8680, 856.2500, 922.6319, 572.4661, + 0.0000, + ], + [ + 0.0000, 0.0000, 421.8090, 1056.9236, 1123.3055, 1189.6875, 719.5988, + 0.0000, + ], + [ + 0.0000, 0.0000, 542.9767, 1323.9792, 1390.3612, 1456.7430, 861.7973, + 0.0000, + ], + [ + 0.0000, 0.0000, 393.2916, 934.4397, 974.0104, 1013.4281, 586.9240, + 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + ], + [ + [ + 0.0000, 157.2149, 330.2274, 362.4734, 394.8816, 203.3744, 0.0000, + 0.0000, + ], + [ + 0.0000, 424.3406, 867.4954, 937.9005, 1008.3055, 505.6405, 0.0000, + 0.0000, + ], + [ + 0.0000, 578.8949, 1150.7361, 1221.1412, 1291.5463, 630.4140, 0.0000, + 0.0000, + ], + [ + 0.0000, 738.6825, 1433.9769, 1504.3820, 1574.7871, 749.9543, 0.0000, + 0.0000, + ], + [ + 0.0000, 429.9127, 816.5075, 850.7720, 884.8738, 411.1526, 0.0000, + 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + ], + [ + [ + 249.8765, 434.9642, 477.1987, 519.6047, 206.2156, 0.0000, 0.0000, + 0.0000, + ], + [ + 560.3093, 949.5208, 1023.9491, 1098.3773, 422.4583, 0.0000, 0.0000, + 0.0000, + ], + [ + 756.7681, 1248.9468, 1323.3750, 1397.8032, 521.2890, 0.0000, 0.0000, + 0.0000, + ], + [ + 958.7592, 1548.3728, 1622.8009, 1697.2292, 614.5874, 0.0000, 0.0000, + 0.0000, + ], + [ + 428.8339, 679.2698, 707.3463, 735.2509, 258.1694, 0.0000, 0.0000, + 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + ], + [ + [ + 0.0000, 0.0000, 707.6714, 1033.6875, 1112.1389, 1190.5902, 328.2950, + 0.0000, + ], + [ + 0.0000, 0.0000, 947.7794, 1349.2986, 1427.7500, 1506.2014, 399.4381, + 0.0000, + ], + [ + 0.0000, 0.0000, 1193.7189, 1664.9097, 1743.3611, 1821.8125, 464.7498, + 0.0000, + ], + [ + 0.0000, 0.0000, 388.7379, 532.5035, 553.9629, 575.2411, 140.6583, + 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + ], + [ + [ + 0.0000, 880.7972, 1124.3936, 1206.8680, 1289.3427, 209.6276, 0.0000, + 0.0000, + ], + [ + 0.0000, 1169.8828, 1456.1898, 1538.6644, 1621.1389, 247.7547, 0.0000, + 0.0000, + ], + [ + 0.0000, 1465.0988, 1787.9861, 1870.4606, 1952.9352, 279.7515, 0.0000, + 0.0000, + ], + [ + 0.0000, 297.3307, 356.3622, 369.8935, 383.2343, 50.9746, 0.0000, 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + ], + [ + [ + 1074.5679, 1219.4977, 1305.9954, 1392.4930, 71.1624, 0.0000, 0.0000, + 0.0000, + ], + [ + 1416.2147, 1567.4791, 1653.9769, 1740.4746, 72.6899, 0.0000, 0.0000, + 0.0000, + ], + [ + 1764.2908, 1915.4606, 2001.9585, 2088.4561, 67.7876, 0.0000, 0.0000, + 0.0000, + ], + [ + 151.0184, 160.0550, 164.7761, 169.2984, 3.8659, 0.0000, 0.0000, 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + [ + 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, + ], + ], + ]], + &device, + ), + bias: TestTensor::from_floats([48., 48.], &device), + }; + test.assert_grads(grads); + } + + struct Conv2dTestCase { + batch_size: usize, + channels_in: usize, + channels_out: usize, + kernel_size_1: usize, + kernel_size_2: 
usize, + padding_1: usize, + padding_2: usize, + stride_1: usize, + stride_2: usize, + dilation_1: usize, + dilation_2: usize, + groups: usize, + offset_groups: usize, + height: usize, + width: usize, + } + + struct Grads { + x: TestTensor<4>, + offset: TestTensor<4>, + weight: TestTensor<4>, + mask: TestTensor<4>, + bias: TestTensor<1>, + } + + impl Conv2dTestCase { + fn assert_grads(self, expected_grads: Grads) { + let out_height = + (self.height + 2 * self.padding_1 - self.dilation_1 * (self.kernel_size_1 - 1) - 1) + / self.stride_1 + + 1; + let out_width = + (self.width + 2 * self.padding_2 - self.dilation_2 * (self.kernel_size_2 - 1) - 1) + / self.stride_2 + + 1; + + let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]); + let shape_offset = Shape::new([ + self.batch_size, + 2 * self.offset_groups * self.kernel_size_1 * self.kernel_size_2, + out_height, + out_width, + ]); + let shape_weight = Shape::new([ + self.channels_out, + self.channels_in / self.groups, + self.kernel_size_1, + self.kernel_size_2, + ]); + let shape_mask = Shape::new([ + self.batch_size, + self.offset_groups * self.kernel_size_1 * self.kernel_size_2, + out_height, + out_width, + ]); + let device = Default::default(); + let weight = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device) + .reshape(shape_weight) + .into_data(), + &device, + ) + .require_grad(); + let bias = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(), + &device, + ) + .require_grad(); + let x = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_x.num_elements() as i64, &device) + .reshape(shape_x) + .into_data(), + &device, + ) + .require_grad(); + let offset = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_offset.num_elements() as i64, &device) + .reshape(shape_offset.clone()) + .into_data(), + &device, + ) + .div_scalar(shape_offset.num_elements() as f32) + .require_grad(); + + let mask = TestAutodiffTensor::from_data( + TestTensorInt::arange(0..shape_mask.num_elements() as i64, &device) + .reshape(shape_mask.clone()) + .into_data(), + &device, + ) + .div_scalar(shape_mask.num_elements() as f32) + .require_grad(); + + let output = deform_conv2d( + x.clone(), + offset.clone(), + weight.clone(), + Some(mask.clone()), + Some(bias.clone()), + DeformConvOptions::new( + [self.stride_1, self.stride_2], + [self.padding_1, self.padding_2], + [self.dilation_1, self.dilation_2], + self.groups, + self.offset_groups, + ), + ); + let grads = output.backward(); + + // Assert + let x_grad_actual = x.grad(&grads).unwrap(); + let offset_grad_actual = offset.grad(&grads).unwrap(); + let weight_grad_actual = weight.grad(&grads).unwrap(); + let mask_grad_actual = mask.grad(&grads).unwrap(); + let bias_grad_actual = bias.grad(&grads).unwrap(); + + println!("Testing bias"); + expected_grads + .bias + .to_data() + .assert_approx_eq(&bias_grad_actual.to_data(), 3); + println!("Testing input"); + expected_grads + .x + .to_data() + .assert_approx_eq(&x_grad_actual.to_data(), 3); + println!("Testing offset"); + expected_grads + .offset + .to_data() + .assert_approx_eq(&offset_grad_actual.to_data(), 3); + println!("Testing mask"); + expected_grads + .mask + .to_data() + .assert_approx_eq(&mask_grad_actual.to_data(), 3); + println!("Testing weight"); + expected_grads + .weight + .to_data() + .assert_approx_eq(&weight_grad_actual.to_data(), 3); + } + } +} diff --git a/crates/burn-autodiff/src/tests/mod.rs 
b/crates/burn-autodiff/src/tests/mod.rs
index af583d4ae..84a0eee27 100644
--- a/crates/burn-autodiff/src/tests/mod.rs
+++ b/crates/burn-autodiff/src/tests/mod.rs
@@ -21,6 +21,7 @@ mod conv_transpose2d;
 mod conv_transpose3d;
 mod cos;
 mod cross_entropy;
+mod deform_conv2d;
 mod div;
 mod erf;
 mod exp;
@@ -82,6 +83,8 @@ macro_rules! testgen_all {
         burn_autodiff::testgen_ad_conv1d!();
         burn_autodiff::testgen_ad_conv2d!();
         burn_autodiff::testgen_ad_conv3d!();
+        #[cfg(not(target_os = "macos"))] // Wgpu on macOS currently doesn't support atomic compare exchange
+        burn_autodiff::testgen_ad_deform_conv2d!();
         burn_autodiff::testgen_ad_conv_transpose1d!();
         burn_autodiff::testgen_ad_conv_transpose2d!();
         burn_autodiff::testgen_ad_conv_transpose3d!();
diff --git a/crates/burn-candle/src/ops/module.rs b/crates/burn-candle/src/ops/module.rs
index db9d4b059..553b8674f 100644
--- a/crates/burn-candle/src/ops/module.rs
+++ b/crates/burn-candle/src/ops/module.rs
@@ -1,7 +1,8 @@
 use burn_tensor::{
     ops::{
-        ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, InterpolateMode,
-        InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, UnfoldOptions,
+        ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, FloatTensor,
+        IntTensor, InterpolateMode, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices,
+        ModuleOps, UnfoldOptions,
     },
     Shape,
 };
@@ -77,6 +78,29 @@ impl<F: FloatCandleElement, I: IntCandleElement> ModuleOps<Candle<F, I>> for Candle<F, I> {
+    fn deform_conv2d(
+        _x: FloatTensor<Self>,
+        _offset: FloatTensor<Self>,
+        _weight: FloatTensor<Self>,
+        _mask: Option<FloatTensor<Self>>,
+        _bias: Option<FloatTensor<Self>>,
+        _options: DeformConvOptions<2>,
+    ) -> FloatTensor<Self> {
+        unimplemented!("Candle does not support deformable convolutions")
+    }
+
+    fn deform_conv2d_backward(
+        _x: FloatTensor<Self>,
+        _offset: FloatTensor<Self>,
+        _weight: FloatTensor<Self>,
+        _mask: Option<FloatTensor<Self>>,
+        _bias: Option<FloatTensor<Self>>,
+        _output_grad: FloatTensor<Self>,
+        _options: DeformConvOptions<2>,
+    ) -> DeformConv2dBackward<Self> {
+        unimplemented!("Candle does not support deformable convolutions")
+    }
+
     fn conv3d(
         x: FloatTensor<Self>,
         weight: FloatTensor<Self>,
diff --git a/crates/burn-core/src/nn/conv/deform_conv2d.rs b/crates/burn-core/src/nn/conv/deform_conv2d.rs
new file mode 100644
index 000000000..03becd9d4
--- /dev/null
+++ b/crates/burn-core/src/nn/conv/deform_conv2d.rs
@@ -0,0 +1,263 @@
+use alloc::format;
+use burn_tensor::ops::DeformConvOptions;
+
+use crate as burn;
+
+use crate::config::Config;
+use crate::module::{Content, DisplaySettings, Ignored, Module, ModuleDisplay, Param};
+use crate::nn::Initializer;
+use crate::nn::PaddingConfig2d;
+use crate::tensor::backend::Backend;
+use crate::tensor::module::deform_conv2d;
+use crate::tensor::Tensor;
+
+use crate::nn::conv::checks;
+
+/// Configuration to create a [deformable 2D convolution](DeformConv2d) layer, using the [init function](DeformConv2dConfig::init).
+#[derive(Config, Debug)]
+pub struct DeformConv2dConfig {
+    /// The number of channels.
+    pub channels: [usize; 2],
+    /// The size of the kernel.
+    pub kernel_size: [usize; 2],
+    /// The stride of the convolution.
+    #[config(default = "[1, 1]")]
+    pub stride: [usize; 2],
+    /// Spacing between kernel elements.
+    #[config(default = "[1, 1]")]
+    pub dilation: [usize; 2],
+    /// Controls the connections between input and output channels.
+    #[config(default = "1")]
+    pub weight_groups: usize,
+    /// Offset groups.
+    #[config(default = "1")]
+    pub offset_groups: usize,
+    /// The padding configuration.
+    #[config(default = "PaddingConfig2d::Valid")]
+    pub padding: PaddingConfig2d,
+    /// If bias should be added to the output.
+    #[config(default = true)]
+    pub bias: bool,
+    /// The type of function used to initialize neural network parameters.
+    #[config(
+        default = "Initializer::KaimingUniform{gain:1.0/num_traits::Float::sqrt(3.0),fan_out_only:false}"
+    )]
+    pub initializer: Initializer,
+}
+
+/// Applies a deformable 2D convolution over input tensors.
+///
+/// Should be created with [DeformConv2dConfig].
+#[derive(Module, Debug)]
+#[module(custom_display)]
+pub struct DeformConv2d<B: Backend> {
+    /// Tensor of shape `[channels_out, channels_in / groups, kernel_size_1, kernel_size_2]`
+    pub weight: Param<Tensor<B, 4>>,
+    /// Tensor of shape `[channels_out]`
+    pub bias: Option<Param<Tensor<B, 1>>>,
+    /// Stride of the convolution.
+    pub stride: [usize; 2],
+    /// Size of the kernel.
+    pub kernel_size: [usize; 2],
+    /// Spacing between kernel elements.
+    pub dilation: [usize; 2],
+    /// Controls the connections between input and output channels.
+    pub weight_groups: usize,
+    /// Offset groups.
+    pub offset_groups: usize,
+    /// The padding configuration.
+    pub padding: Ignored<PaddingConfig2d>,
+}
+
+impl DeformConv2dConfig {
+    /// Initialize a new [DeformConv2d](DeformConv2d) module.
+    pub fn init<B: Backend>(&self, device: &B::Device) -> DeformConv2d<B> {
+        checks::checks_channels_div_groups(self.channels[0], self.channels[1], self.weight_groups);
+
+        let shape = [
+            self.channels[1],
+            self.channels[0] / self.weight_groups,
+            self.kernel_size[0],
+            self.kernel_size[1],
+        ];
+
+        let k = self.kernel_size.iter().product::<usize>();
+        let fan_in = self.channels[0] / self.weight_groups * k;
+        let fan_out = self.channels[1] / self.weight_groups * k;
+
+        let weight = self
+            .initializer
+            .init_with(shape, Some(fan_in), Some(fan_out), device);
+        let mut bias = None;
+
+        if self.bias {
+            bias = Some(self.initializer.init_with(
+                [self.channels[1]],
+                Some(fan_in),
+                Some(fan_out),
+                device,
+            ));
+        }
+
+        DeformConv2d {
+            weight,
+            bias,
+            stride: self.stride,
+            kernel_size: self.kernel_size,
+            dilation: self.dilation,
+            padding: Ignored(self.padding.clone()),
+            weight_groups: self.weight_groups,
+            offset_groups: self.offset_groups,
+        }
+    }
+}
+
+impl<B: Backend> ModuleDisplay for DeformConv2d<B> {
+    fn custom_settings(&self) -> Option<DisplaySettings> {
+        DisplaySettings::new()
+            .with_new_line_after_attribute(false)
+            .optional()
+    }
+
+    fn custom_content(&self, content: Content) -> Option<Content> {
+        // Since padding does not implement ModuleDisplay, we need to format it manually.
+        let padding_formatted = format!("{}", &self.padding);
+
+        // Format the stride, kernel_size and dilation as strings, formatted as arrays instead of indexed.
+        let stride = format!("{:?}", self.stride);
+        let kernel_size = format!("{:?}", self.kernel_size);
+        let dilation = format!("{:?}", self.dilation);
+
+        content
+            .add("stride", &stride)
+            .add("kernel_size", &kernel_size)
+            .add("dilation", &dilation)
+            .add("weight_groups", &self.weight_groups)
+            .add("offset_groups", &self.offset_groups)
+            .add("padding", &padding_formatted)
+            .optional()
+    }
+}
+
+impl<B: Backend> DeformConv2d<B> {
+    /// Applies the forward pass on the input tensor.
+    ///
+    /// See [deform_conv2d](crate::tensor::module::deform_conv2d) for more information.
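+    ///
+    /// # Example
+    ///
+    /// A minimal usage sketch; the shapes, backend `B`, and `device` below are
+    /// illustrative assumptions, not part of this patch:
+    ///
+    /// ```rust,ignore
+    /// // 1x2x8x8 input, 3x3 kernel, valid padding, stride 1 -> 6x6 output.
+    /// let conv = DeformConv2dConfig::new([2, 4], [3, 3]).init::<B>(&device);
+    /// let input = Tensor::<B, 4>::zeros([1, 2, 8, 8], &device);
+    /// // One (y, x) offset pair per kernel element and output position:
+    /// // [batch, 2 * offset_groups * kernel_h * kernel_w, out_h, out_w].
+    /// let offset = Tensor::<B, 4>::zeros([1, 2 * 3 * 3, 6, 6], &device);
+    /// let output = conv.forward(input, offset, None); // [1, 4, 6, 6]
+    /// ```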
+ /// + /// # Shapes + /// + /// - input: `[batch_size, channels_in, height_in, width_in]` + /// - offset: `[batch_size, 2 * offset_groups * kernel_height * kernel_width, height_out, width_out]` + /// - mask: `[batch_size, offset_groups * kernel_height * kernel_width, height_out, width_out]` + /// - output: `[batch_size, channels_out, height_out, width_out]` + pub fn forward( + &self, + input: Tensor, + offset: Tensor, + mask: Option>, + ) -> Tensor { + let [_batch_size, _channels_in, height_in, width_in] = input.dims(); + let padding = + self.padding + .calculate_padding_2d(height_in, width_in, &self.kernel_size, &self.stride); + deform_conv2d( + input, + offset, + self.weight.val(), + mask, + self.bias.as_ref().map(|bias| bias.val()), + DeformConvOptions::new( + self.stride, + padding, + self.dilation, + self.weight_groups, + self.offset_groups, + ), + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tensor::TensorData; + use crate::TestBackend; + + #[test] + fn initializer_default() { + TestBackend::seed(0); + + let config = DeformConv2dConfig::new([5, 1], [5, 5]); + let k = (config.channels[0] * config.kernel_size[0] * config.kernel_size[1]) as f64; + let k = (config.offset_groups as f64 / k).sqrt() as f32; + let device = Default::default(); + let conv = config.init::(&device); + + conv.weight.to_data().assert_within_range(-k..k); + } + + #[test] + fn initializer_zeros() { + TestBackend::seed(0); + + let config = DeformConv2dConfig::new([5, 2], [5, 5]).with_initializer(Initializer::Zeros); + let device = Default::default(); + let conv = config.init::(&device); + + assert_eq!(config.initializer, Initializer::Zeros); + conv.weight + .to_data() + .assert_approx_eq(&TensorData::zeros::(conv.weight.shape()), 3); + } + + #[test] + fn initializer_fan_out() { + TestBackend::seed(0); + + let init = Initializer::KaimingUniform { + gain: 1.0 / 3.0f64.sqrt(), + fan_out_only: true, // test that fan_out is passed to `init_with()` + }; + let device = Default::default(); + let config = DeformConv2dConfig::new([5, 1], [5, 5]).with_initializer(init.clone()); + let _ = config.init::(&device); + + assert_eq!(config.initializer, init); + } + + #[test] + fn initializer_fan_with_groups_is_valid() { + TestBackend::seed(0); + + let init = Initializer::KaimingUniform { + gain: 1.0 / 3.0f64.sqrt(), + fan_out_only: true, + }; + let device = Default::default(); + let config = DeformConv2dConfig::new([4, 4], [1, 1]) + .with_initializer(init.clone()) + .with_weight_groups(4); + let _ = config.init::(&device); + + assert_eq!(config.initializer, init); + } + + #[test] + #[should_panic = "Both channels must be divisible by the number of groups."] + fn channels_with_groups_is_invalid() { + let device = Default::default(); + let config = DeformConv2dConfig::new([1, 4], [1, 1]).with_weight_groups(4); + let _ = config.init::(&device); + } + + #[test] + fn display() { + let config = DeformConv2dConfig::new([5, 1], [5, 5]); + let conv = config.init::(&Default::default()); + + assert_eq!( + alloc::format!("{}", conv), + "DeformConv2d {stride: [1, 1], kernel_size: [5, 5], dilation: [1, 1], weight_groups: 1, offset_groups: 1, padding: Valid, params: 126}" + ); + } +} diff --git a/crates/burn-core/src/nn/conv/mod.rs b/crates/burn-core/src/nn/conv/mod.rs index 81990e809..d233b6570 100644 --- a/crates/burn-core/src/nn/conv/mod.rs +++ b/crates/burn-core/src/nn/conv/mod.rs @@ -4,6 +4,7 @@ mod conv3d; mod conv_transpose1d; mod conv_transpose2d; mod conv_transpose3d; +mod deform_conv2d; pub(crate) mod checks; @@ 
-13,3 +14,4 @@ pub use conv3d::*; pub use conv_transpose1d::*; pub use conv_transpose2d::*; pub use conv_transpose3d::*; +pub use deform_conv2d::*; diff --git a/crates/burn-fusion/src/ops/module.rs b/crates/burn-fusion/src/ops/module.rs index 4cb98d87a..0d4b52980 100644 --- a/crates/burn-fusion/src/ops/module.rs +++ b/crates/burn-fusion/src/ops/module.rs @@ -5,9 +5,9 @@ use burn_tensor::{ calculate_conv_output_size, calculate_conv_transpose_output_size, calculate_pool_output_size, }, - ConvOptions, ConvTransposeOptions, FloatTensor, IntTensor, InterpolateOptions, - MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, - ModuleOps, + ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, FloatTensor, + IntTensor, InterpolateOptions, MaxPool1dBackward, MaxPool1dWithIndices, MaxPool2dBackward, + MaxPool2dWithIndices, ModuleOps, }, repr::*, Element, @@ -153,6 +153,211 @@ impl ModuleOps> for Fusion { out } + fn deform_conv2d( + x: FloatTensor, + offset: FloatTensor, + weight: FloatTensor, + mask: Option>, + bias: Option>, + options: DeformConvOptions<2>, + ) -> FloatTensor { + make_ops!( + DeformConv2dOps, + DeformConv2dDescription, + |args: DeformConv2dDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); + let offset = handles.get_float_tensor::(&args.offset); + let weight = handles.get_float_tensor::(&args.weight); + let mask = args + .mask + .as_ref() + .map(|mask| handles.get_float_tensor::(mask)); + let bias = args + .bias + .as_ref() + .map(|bias| handles.get_float_tensor::(bias)); + + let output = + B::deform_conv2d(x, offset, weight, mask, bias, args.options.clone().into()); + + handles.register_float_tensor::(&args.out.id, output); + } + ); + + let size_0 = calculate_conv_output_size( + weight.shape[2], + options.stride[0], + options.padding[0], + options.dilation[0], + x.shape[2], + ); + let size_1 = calculate_conv_output_size( + weight.shape[3], + options.stride[1], + options.padding[1], + options.dilation[1], + x.shape[3], + ); + + let stream_1 = x.stream; + let stream_2 = offset.stream; + let stream_3 = weight.stream; + let stream_4 = mask.as_ref().map(|m| m.stream); + let stream_5 = bias.as_ref().map(|b| b.stream); + let shape = vec![x.shape[0], weight.shape[0], size_0, size_1]; + let out = x.client.tensor_uninitialized(shape, B::FloatElem::dtype()); + + let desc = DeformConv2dDescription { + x: x.into_description(), + offset: offset.into_description(), + weight: weight.into_description(), + mask: mask.map(|mask| mask.into_description()), + bias: bias.map(|bias| bias.into_description()), + options: options.into(), + out: out.to_description_out(), + }; + + let streams = match (stream_4, stream_5) { + (Some(stream_4), Some(stream_5)) => { + vec![stream_1, stream_2, stream_3, stream_4, stream_5] + } + (Some(stream_4), None) => { + vec![stream_1, stream_2, stream_3, stream_4] + } + (None, Some(stream_5)) => { + vec![stream_1, stream_2, stream_3, stream_5] + } + (None, None) => vec![stream_1, stream_2, stream_3], + }; + out.client.register( + streams, + OperationDescription::Module(ModuleOperationDescription::DeformableConv2d(Box::new( + desc.clone(), + ))), + DeformConv2dOps::::new(desc), + ); + + out + } + + fn deform_conv2d_backward( + x: FloatTensor, + offset: FloatTensor, + weight: FloatTensor, + mask: Option>, + bias: Option>, + output_grad: FloatTensor, + options: DeformConvOptions<2>, + ) -> DeformConv2dBackward { + make_ops!( + DeformConv2dBackwardOps, + DeformConv2dBackwardDescription, + 
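+            // Runs when the fused stream replays this operation: resolve each tensor
+            // description to a concrete tensor, call the backend's deform_conv2d_backward,
+            // then register the produced gradients (mask and bias gradients only when
+            // those inputs exist).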
|args: DeformConv2dBackwardDescription, handles: &mut HandleContainer| { + let x = handles.get_float_tensor::(&args.x); + let offset = handles.get_float_tensor::(&args.offset); + let weight = handles.get_float_tensor::(&args.weight); + let mask = args + .mask + .as_ref() + .map(|mask| handles.get_float_tensor::(mask)); + let bias = args + .bias + .as_ref() + .map(|bias| handles.get_float_tensor::(bias)); + let output_grad = handles.get_float_tensor::(&args.out_grad); + + let output = B::deform_conv2d_backward( + x, + offset, + weight, + mask, + bias, + output_grad, + args.options.clone().into(), + ); + + handles.register_float_tensor::(&args.input_grad.id, output.x_grad); + handles.register_float_tensor::(&args.offset_grad.id, output.offset_grad); + handles.register_float_tensor::(&args.weight_grad.id, output.weight_grad); + if let Some((mask_grad, field)) = output.mask_grad.zip(args.mask_grad.as_ref()) { + handles.register_float_tensor::(&field.id, mask_grad); + } + if let Some((bias_grad, field)) = output.bias_grad.zip(args.bias_grad.as_ref()) { + handles.register_float_tensor::(&field.id, bias_grad); + } + } + ); + + let input_grad = x + .client + .tensor_uninitialized(x.shape.clone(), B::FloatElem::dtype()); + let offset_grad = offset + .client + .tensor_uninitialized(offset.shape.clone(), B::FloatElem::dtype()); + let weight_grad = offset + .client + .tensor_uninitialized(weight.shape.clone(), B::FloatElem::dtype()); + let mask_grad = mask.as_ref().map(|mask| { + offset + .client + .tensor_uninitialized(mask.shape.clone(), B::FloatElem::dtype()) + }); + let bias_grad = bias.as_ref().map(|bias| { + offset + .client + .tensor_uninitialized(bias.shape.clone(), B::FloatElem::dtype()) + }); + + let stream_1 = x.stream; + let stream_2 = offset.stream; + let stream_3 = weight.stream; + let stream_4 = mask.as_ref().map(|m| m.stream); + let stream_5 = bias.as_ref().map(|b| b.stream); + let stream_6 = output_grad.stream; + + let desc = DeformConv2dBackwardDescription { + x: x.into_description(), + offset: offset.into_description(), + weight: weight.into_description(), + mask: mask.map(|mask| mask.into_description()), + bias: bias.map(|bias| bias.into_description()), + options: options.into(), + out_grad: output_grad.into_description(), + input_grad: input_grad.to_description_out(), + offset_grad: offset_grad.to_description_out(), + weight_grad: weight_grad.to_description_out(), + mask_grad: mask_grad + .as_ref() + .map(|mask_grad| mask_grad.to_description_out()), + bias_grad: bias_grad + .as_ref() + .map(|bias_grad| bias_grad.to_description_out()), + }; + + let streams = match (stream_4, stream_5) { + (Some(stream_4), Some(stream_5)) => { + vec![stream_1, stream_2, stream_3, stream_4, stream_5, stream_6] + } + (Some(stream_4), None) => { + vec![stream_1, stream_2, stream_3, stream_4, stream_6] + } + (None, Some(stream_5)) => { + vec![stream_1, stream_2, stream_3, stream_5, stream_6] + } + (None, None) => vec![stream_1, stream_2, stream_3, stream_6], + }; + + input_grad.client.register( + streams, + OperationDescription::Module(ModuleOperationDescription::DeformableConv2dBackward( + Box::new(desc.clone()), + )), + DeformConv2dBackwardOps::::new(desc), + ); + + DeformConv2dBackward::new(input_grad, offset_grad, weight_grad, mask_grad, bias_grad) + } + fn conv3d( x: FloatTensor, weight: FloatTensor, diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index a1b6a0db0..6eee274bd 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ 
b/crates/burn-fusion/src/stream/context.rs @@ -180,6 +180,35 @@ impl RelativeOps for ModuleOperationDescription { out: desc.out.to_relative(converter), }) } + ModuleOperationDescription::DeformableConv2d(desc) => { + ModuleOperationDescription::DeformableConv2d(Box::new(DeformConv2dDescription { + x: desc.x.to_relative(converter), + offset: desc.offset.to_relative(converter), + weight: desc.weight.to_relative(converter), + mask: desc.mask.as_ref().map(|t| t.to_relative(converter)), + bias: desc.bias.as_ref().map(|t| t.to_relative(converter)), + options: desc.options.clone(), + out: desc.out.to_relative(converter), + })) + } + ModuleOperationDescription::DeformableConv2dBackward(desc) => { + ModuleOperationDescription::DeformableConv2dBackward(Box::new( + DeformConv2dBackwardDescription { + x: desc.x.to_relative(converter), + offset: desc.offset.to_relative(converter), + weight: desc.weight.to_relative(converter), + mask: desc.mask.as_ref().map(|t| t.to_relative(converter)), + bias: desc.bias.as_ref().map(|t| t.to_relative(converter)), + out_grad: desc.out_grad.to_relative(converter), + options: desc.options.clone(), + input_grad: desc.input_grad.to_relative(converter), + offset_grad: desc.offset_grad.to_relative(converter), + weight_grad: desc.weight_grad.to_relative(converter), + mask_grad: desc.mask_grad.as_ref().map(|t| t.to_relative(converter)), + bias_grad: desc.bias_grad.as_ref().map(|t| t.to_relative(converter)), + }, + )) + } ModuleOperationDescription::ConvTranspose1d(desc) => { ModuleOperationDescription::ConvTranspose1d(ConvTranspose1dDescription { x: desc.x.to_relative(converter), diff --git a/crates/burn-jit/src/kernel/conv/deform_conv2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv2d.rs new file mode 100644 index 000000000..c7e5c0aea --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/deform_conv2d.rs @@ -0,0 +1,316 @@ +use cubecl::{calculate_cube_count_elemwise, prelude::*}; + +use burn_tensor::{ + ops::{conv::calculate_conv_output_size, DeformConvOptions, FloatTensorOps as _}, + Shape, +}; + +use crate::{ + kernel::into_contiguous, + ops::{ + numeric::{ones_device, zeros_device}, + reshape, swap_dims, + }, + tensor::JitTensor, + FloatElement, IntElement, JitBackend, JitRuntime, +}; + +#[derive(CubeLaunch)] +struct DeformConv2dArgs { + conv_stride_h: u32, + conv_stride_w: u32, + dilation_h: u32, + dilation_w: u32, + padding_h: F, + padding_w: F, + offset_groups: u32, + + kernel_height: u32, + kernel_width: u32, + out_h: u32, + out_w: u32, + + col_stride_0: u32, +} + +#[cube(launch)] +fn deform_im2col_kernel( + input: &Tensor, + offset: &Tensor, + mask: &Tensor, + columns: &mut Tensor, + args: &DeformConv2dArgs, + #[comptime] kernel_h_unroll: Option, + #[comptime] kernel_w_unroll: Option, + #[comptime] use_mask: bool, +) { + // position shape: [in_channels, batch_size, out_h, out_w] + // columns shape: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]] + + let kernel_height = kernel_h_unroll.unwrap_or(args.kernel_height); + let unroll_h = kernel_h_unroll.is_some(); + let kernel_width = kernel_w_unroll.unwrap_or(args.kernel_width); + let unroll_w = kernel_w_unroll.is_some(); + + // Keep mask in bind group + let default_mask_value = mask[0]; + + let out_h = args.out_h; + let out_w = args.out_w; + let batch_size = input.shape(0); + let in_channels = input.shape(1); + let height = input.shape(2); + let width = input.shape(3); + let col_stride_0 = args.col_stride_0; + + let out_x = ABSOLUTE_POS % out_w; + let out_y = (ABSOLUTE_POS / out_w) % out_h; + let out_batch 
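+ // Decompose the linear invocation index into (in_channel, batch, out_y, out_x);
+ // each invocation fills one kernel_h * kernel_w block of the column buffer.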
= (ABSOLUTE_POS / (out_w * out_h)) % batch_size; + let in_channel = ABSOLUTE_POS / (out_w * out_h * batch_size); + let out_channel = in_channel * kernel_height * kernel_width; + + let channels_per_offset_group = in_channels / args.offset_groups; + let group_index = in_channel / channels_per_offset_group; + + let mut col_base_idx = + out_channel * col_stride_0 + out_batch * (out_h * out_w) + out_y * out_w + out_x; + + let input_base_idx = out_batch * input.stride(0) + in_channel * input.stride(1); + + let offset_base_idx = out_batch * offset.stride(0) + + group_index * kernel_height * kernel_width * 2 * out_h * out_w; + let mut mask_base_idx = 0; + if use_mask { + mask_base_idx = + out_batch * mask.stride(0) + group_index * kernel_height * kernel_width * out_h * out_w; + } + + #[unroll(unroll_h)] + for kernel_y in 0..kernel_height { + #[unroll(unroll_w)] + for kernel_x in 0..kernel_width { + let mask_index = kernel_y * kernel_width + kernel_x; + let offset_index = mask_index * 2; + + let mut mask_value = default_mask_value; + if use_mask { + mask_value = mask[mask_base_idx + + mask_index * mask.stride(1) + + out_y * mask.stride(2) + + out_x * mask.stride(3)]; + } + + let offset_y = offset[offset_base_idx + + offset_index * offset.stride(1) + + out_y * offset.stride(2) + + out_x * offset.stride(3)]; + let offset_x = offset[offset_base_idx + + (offset_index + 1) * offset.stride(1) + + out_y * offset.stride(2) + + out_x * offset.stride(3)]; + let y = F::cast_from(out_y * args.conv_stride_h + kernel_y * args.dilation_h) + - args.padding_h + + offset_y; + let x = F::cast_from(out_x * args.conv_stride_w + kernel_x * args.dilation_w) + - args.padding_w + + offset_x; + + let interpolated = bilinear_interpolate(input, height, width, y, x, input_base_idx); + + columns[col_base_idx] = mask_value * interpolated; + col_base_idx += col_stride_0; + } + } +} + +#[cube] +pub(crate) fn bilinear_interpolate( + input: &Tensor, + height: u32, + width: u32, + y: F, + x: F, + offset: u32, +) -> F { + // To simplify code + let y = f32::cast_from(y); + let x = f32::cast_from(x); + + let mut result = F::new(0.0); + if y > -1.0 && height as f32 > y && x > -1.0 && width as f32 > x { + let in_w = u32::cast_from(width); + + let y_low = f32::floor(y); + let x_low = f32::floor(x); + let y_high = (y_low + 1.) as u32; + let x_high = (x_low + 1.) as u32; + + let zero = F::new(0.0); + let v1: F = if y_low >= 0. && x_low >= 0. { + input[offset + y_low as u32 * in_w + x_low as u32] + } else { + zero + }; + let v2: F = if y_low >= 0. && x_high < width { + input[offset + y_low as u32 * in_w + x_high] + } else { + zero + }; + let v3: F = if y_high < height && x_low >= 0. 
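+ // Out-of-bounds corners read as zero, so sampling degrades gracefully at the
+ // image border; the four corner values are then blended with the usual
+ // bilinear weights (1 - |dy|) * (1 - |dx|).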
{ + input[offset + y_high * in_w + x_low as u32] + } else { + zero + }; + let v4: F = if y_high < height && x_high < width { + input[offset + y_high * in_w + x_high] + } else { + zero + }; + + let l_y = y - y_low; + let l_x = x - x_low; + let h_y = 1.0 - l_y; + let h_x = 1.0 - l_x; + + let w1 = F::cast_from(h_y * h_x); + let w2 = F::cast_from(h_y * l_x); + let w3 = F::cast_from(l_y * h_x); + let w4 = F::cast_from(l_y * l_x); + + result = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; + } + result +} + +pub(crate) fn deform_im2col( + input: JitTensor, + offset: JitTensor, + mask: Option>, + options: DeformConvOptions<2>, + out_dims: (usize, usize), + kernel_dims: (usize, usize), +) -> JitTensor { + let client = input.client.clone(); + let device = input.device.clone(); + + let [batch_size, in_channels, _, _] = input.shape.dims; + let (out_height, out_width) = out_dims; + let (kernel_height, kernel_width) = kernel_dims; + + let shape_out = Shape::new([ + in_channels * kernel_height * kernel_width, + batch_size * out_height * out_width, + ]); + + let output = zeros_device(client.clone(), device.clone(), shape_out.clone()); + let use_mask = mask.is_some(); + let mask = mask.unwrap_or_else(|| { + ones_device( + client.clone(), + device.clone(), + Shape::new([ + offset.shape.dims[0], + offset.shape.dims[1] / 2, + offset.shape.dims[2], + offset.shape.dims[3], + ]), + ) + }); + + let num_kernels = in_channels * batch_size * out_height * out_width; + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_kernels, cube_dim); + + deform_im2col_kernel::launch::( + &input.client, + cube_count, + cube_dim, + input.as_handle_ref().as_tensor_arg(1), + offset.as_handle_ref().as_tensor_arg(1), + mask.as_handle_ref().as_tensor_arg(1), + output.as_handle_ref().as_tensor_arg(1), + DeformConv2dArgsLaunch::new( + ScalarArg::new(options.stride[0] as u32), + ScalarArg::new(options.stride[1] as u32), + ScalarArg::new(options.dilation[0] as u32), + ScalarArg::new(options.dilation[1] as u32), + ScalarArg::new(E::from_elem(options.padding[0] as f32)), + ScalarArg::new(E::from_elem(options.padding[1] as f32)), + ScalarArg::new(options.offset_groups as u32), + ScalarArg::new(kernel_height as u32), + ScalarArg::new(kernel_width as u32), + ScalarArg::new(out_height as u32), + ScalarArg::new(out_width as u32), + ScalarArg::new(output.strides[0] as u32), + ), + Some(kernel_height as u32), + Some(kernel_width as u32), + use_mask, + ); + + output +} + +pub(crate) fn deform_conv2d( + input: JitTensor, + offset: JitTensor, + weight: JitTensor, + mask: Option>, + bias: Option>, + options: DeformConvOptions<2>, +) -> JitTensor { + let input = into_contiguous(input); + let offset = into_contiguous(offset); + let weight = into_contiguous(weight); + let mask = mask.map(|it| into_contiguous(it)); + let bias = bias.map(|it| into_contiguous(it)); + + let [batch_size, _, in_height, in_width] = input.shape.dims; + let [out_channels, _, kernel_h, kernel_w] = weight.shape.dims; + let groups = options.weight_groups; + + let out_h = calculate_conv_output_size( + kernel_h, + options.stride[0], + options.padding[0], + options.dilation[0], + in_height, + ); + let out_w = calculate_conv_output_size( + kernel_w, + options.stride[1], + options.padding[1], + options.dilation[1], + in_width, + ); + let out_dims = (out_h, out_w); + + let columns = deform_im2col(input, offset, mask, options, out_dims, (kernel_h, kernel_w)); + + let [col_size_0, col_size_1] = columns.shape.dims; + let col_size_0 = col_size_0 / groups; + let 
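+ // im2col formulation: the deformable sampling is already folded into `columns`,
+ // so the convolution itself reduces to one matmul per weight group between
+ // weight [groups, out_c / groups, col_size_0] and
+ // columns [groups, col_size_0, batch * out_h * out_w].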
out_c_per_group = out_channels / groups; + + let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_size_0])); + let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1])); + let out = JitBackend::::float_matmul(weight, columns); + + let out = reshape(out, Shape::new([out_channels, batch_size, out_h, out_w])); + let out = swap_dims(out, 0, 1); + + if let Some(bias) = bias { + let bias = reshape(bias, Shape::new([1, out_channels, 1, 1])); + JitBackend::::float_add(out, bias) + } else { + out + } +} + +pub(crate) fn index( + tensor: JitTensor, + index: usize, +) -> JitTensor { + let [_, shape_0, shape_1] = tensor.shape.dims; + let tensor = JitBackend::::float_narrow(tensor, 0, index, 1); + reshape(tensor, Shape::new([shape_0, shape_1])) +} diff --git a/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs new file mode 100644 index 000000000..996f3e87f --- /dev/null +++ b/crates/burn-jit/src/kernel/conv/deform_conv_transpose2d.rs @@ -0,0 +1,591 @@ +use burn_tensor::{ + ops::{DeformConv2dBackward, DeformConvOptions, FloatTensorOps as _}, + Shape, +}; +use cubecl::{calculate_cube_count_elemwise, cube, prelude::*, CubeDim, CubeLaunch}; + +use crate::{ + kernel::into_contiguous, + ops::{ + numeric::{empty_device, ones_device, zeros_device}, + reshape, swap_dims, + }, + tensor::JitTensor, + FloatElement, IntElement, JitBackend, JitRuntime, +}; + +use super::{bilinear_interpolate, deform_im2col, index}; + +/// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions. +#[allow(clippy::single_range_in_vec_init)] +pub(crate) fn deform_conv2d_backward( + input: JitTensor, + offset: JitTensor, + weight: JitTensor, + mask: Option>, + bias: Option>, + out_grad: JitTensor, + options: DeformConvOptions<2>, +) -> DeformConv2dBackward> { + let [_, _, out_h, out_w] = out_grad.shape.dims; + let [_, _, kernel_h, kernel_w] = weight.shape.dims; + + let gradient_bias = bias.map(|bias| { + let grad = JitBackend::::float_sum_dim(out_grad.clone(), 0); + let grad = JitBackend::::float_sum_dim(grad, 2); + let grad = JitBackend::::float_sum_dim(grad, 3); + + reshape(grad, bias.shape) + }); + + let input = into_contiguous(input); + let offset = into_contiguous(offset); + let mask = mask.map(|it| into_contiguous(it)); + + let (input_gradient, offset_gradient, mask_gradient) = backward_gradient_inputs::( + input.clone(), + weight.clone(), + offset.clone(), + mask.clone(), + out_grad.clone(), + &options, + (kernel_h, kernel_w), + ); + + let weight_grad = compute_weight_grad::( + input, + offset, + mask, + out_grad, + options, + (kernel_h, kernel_w), + (out_h, out_w), + ); + + DeformConv2dBackward::new( + input_gradient, + offset_gradient, + weight_grad, + mask_gradient, + gradient_bias, + ) +} + +fn compute_weight_grad( + input: JitTensor, + offset: JitTensor, + mask: Option>, + out_grad: JitTensor, + options: DeformConvOptions<2>, + kernel_dims: (usize, usize), + out_dims: (usize, usize), +) -> JitTensor { + let [_, in_channels, _, _] = input.shape.dims; + let [_, out_channels, _, _] = out_grad.shape.dims; + let (kernel_h, kernel_w) = kernel_dims; + let groups = options.weight_groups; + + let in_c_per_group = in_channels / groups; + let out_c_per_group = out_channels / groups; + + let columns = deform_im2col(input, offset, mask, options, out_dims, kernel_dims); + let [col_size_0, col_size_1] = columns.shape.dims; + let col_size_0 = col_size_0 / groups; + + let out_grad = 
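+ // The weight gradient mirrors the forward matmul: per group,
+ // grad_weight = matmul(grad_output, columns^T), with grad_output reshaped to
+ // [groups, out_c / groups, batch * out_h * out_w].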
swap_dims(out_grad, 0, 1); + let out_grad = reshape(out_grad, Shape::new([groups, out_c_per_group, col_size_1])); + + let columns = reshape(columns, Shape::new([groups, col_size_0, col_size_1])); + let columns = swap_dims(columns, 1, 2); + + let grad_weight = JitBackend::::float_matmul(out_grad, columns); + + JitBackend::::float_reshape( + grad_weight, + Shape::new([out_channels, in_c_per_group, kernel_h, kernel_w]), + ) +} + +type InputGradients = ( + JitTensor, + JitTensor, + Option>, +); + +fn backward_gradient_inputs( + image: JitTensor, + weight: JitTensor, + offset: JitTensor, + mask: Option>, + out_grad: JitTensor, + options: &DeformConvOptions<2>, + kernel_dims: (usize, usize), +) -> InputGradients { + let client = out_grad.client.clone(); + let device = out_grad.device.clone(); + + let [out_channels, in_c_per_group, kernel_h, kernel_w] = weight.shape.dims; + let [batch_size, _, out_h, out_w] = out_grad.shape.dims; + + let groups = options.weight_groups; + let out_c_per_group = out_channels / groups; + + let col_shape_0 = in_c_per_group * kernel_h * kernel_w; + let col_shape_1 = batch_size * out_h * out_w; + let col_shape = Shape::new([groups, col_shape_0, col_shape_1]); + let mut columns = empty_device(client, device, col_shape); + + let weight = reshape(weight, Shape::new([groups, out_c_per_group, col_shape_0])); + + let out_grad = swap_dims(out_grad, 0, 1); + let out_grad_shape = Shape::new([groups, out_c_per_group, col_shape_1]); + let out_grad = reshape(out_grad, out_grad_shape); + + for group in 0..groups { + let weight = swap_dims(index::(weight.clone(), group), 0, 1); + let out_grad = index::(out_grad.clone(), group); + let values = JitBackend::::float_matmul(weight, out_grad); + let values = reshape(values, Shape::new([1, col_shape_0, col_shape_1])); + columns = JitBackend::::float_slice_assign( + columns, + [group..group + 1, 0..col_shape_0, 0..col_shape_1], + values, + ); + } + + let columns = reshape(columns, Shape::new([col_shape_0 * groups, col_shape_1])); + + let input_shape = image.shape.clone(); + let (offset_gradient, mask_gradient) = compute_offset_and_mask_gradient::( + columns.clone(), + image, + offset.clone(), + mask.clone(), + options, + kernel_dims, + ); + + let input_gradient = + compute_input_grad::(columns, offset, mask, options, kernel_dims, input_shape); + + (input_gradient, offset_gradient, mask_gradient) +} + +fn compute_offset_and_mask_gradient( + columns: JitTensor, + image: JitTensor, + offset: JitTensor, + mask: Option>, + options: &DeformConvOptions<2>, + kernel_dims: (usize, usize), +) -> (JitTensor, Option>) { + let client = offset.client.clone(); + let device = offset.device.clone(); + let (kernel_height, kernel_width) = kernel_dims; + + let use_mask = mask.is_some(); + + let mask = mask.unwrap_or_else(|| { + ones_device( + client.clone(), + device.clone(), + Shape::new([ + offset.shape.dims[0], + offset.shape.dims[1] / 2, + offset.shape.dims[2], + offset.shape.dims[3], + ]), + ) + }); + + let grad_offset = empty_device(client.clone(), device.clone(), offset.shape.clone()); + let grad_mask = empty_device(client.clone(), device.clone(), mask.shape.clone()); + + let num_elements_offset = offset.shape.num_elements(); + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elements_offset, cube_dim); + + deform_col2img_coord_kernel::launch::( + &image.client, + cube_count, + cube_dim, + image.as_handle_ref().as_tensor_arg(1), + offset.as_handle_ref().as_tensor_arg(1), + mask.as_handle_ref().as_tensor_arg(1), + 
columns.as_handle_ref().as_tensor_arg(1), + grad_offset.as_handle_ref().as_tensor_arg(1), + grad_mask.as_handle_ref().as_tensor_arg(1), + DeformConv2dCol2ImgCoordArgsLaunch::new( + ScalarArg::new(options.stride[0] as u32), + ScalarArg::new(options.stride[1] as u32), + ScalarArg::new(options.dilation[0] as u32), + ScalarArg::new(options.dilation[1] as u32), + ScalarArg::new(E::from_elem(options.padding[0] as f32)), + ScalarArg::new(E::from_elem(options.padding[1] as f32)), + ScalarArg::new(options.offset_groups as u32), + ScalarArg::new(kernel_height as u32), + ScalarArg::new(kernel_width as u32), + ), + use_mask, + ); + + let mask_gradient = if use_mask { Some(grad_mask) } else { None }; + (grad_offset, mask_gradient) +} + +#[derive(CubeLaunch)] +struct DeformConv2dCol2ImgCoordArgs { + stride_h: u32, + stride_w: u32, + dilation_h: u32, + dilation_w: u32, + pad_h: F, + pad_w: F, + offset_groups: u32, + kernel_height: u32, + kernel_width: u32, +} + +#[allow(clippy::collapsible_if)] +#[cube(launch)] +fn deform_col2img_coord_kernel( + image: &Tensor, + offset: &Tensor, + mask: &Tensor, + columns: &Tensor, + grad_offset: &mut Tensor, + grad_mask: &mut Tensor, + args: &DeformConv2dCol2ImgCoordArgs, + #[comptime] use_mask: bool, +) { + // Position format: [batch, [offset_group, kernel_h, kernel_w, 2], out_h, out_w] + // Alternatively : [batch, offset_channels, out_h, out_w] + + let offset_channels = offset.shape(1); + let out_h = offset.shape(2); + let out_w = offset.shape(3); + let batch_size = image.shape(0); + let in_channels = image.shape(1); + let height = image.shape(2); + let width = image.shape(3); + let kernel_w = args.kernel_width; + let kernel_h = args.kernel_height; + let n_offset_groups = args.offset_groups; + let _ = mask[0]; // Make sure mask isn't removed from bind group + + let mut grad_offset_val = F::new(0.0); + let mut grad_mask_val = F::new(0.0); + + let w = ABSOLUTE_POS % out_w; + let h = (ABSOLUTE_POS / out_w) % out_h; + let w_w = (ABSOLUTE_POS / (out_w * out_h * 2)) % kernel_w; + let w_h = (ABSOLUTE_POS / (out_w * out_h * 2 * kernel_w)) % kernel_h; + let c = (ABSOLUTE_POS / (out_w * out_h)) % offset_channels; + let b = ABSOLUTE_POS / (out_w * out_h * offset_channels); + + let offset_group = c / (kernel_h * kernel_w * 2); + let col_step = kernel_h * kernel_w; + + let channels_per_offset_group = in_channels / args.offset_groups; + + let col_base_idx = + offset_group * channels_per_offset_group * kernel_h * kernel_w * batch_size * out_w * out_h; + let mut image_base_idx = + (b * n_offset_groups + offset_group) * channels_per_offset_group * height * width; + let offset_base_idx = + (b * n_offset_groups + offset_group) * 2 * kernel_h * kernel_w * out_h * out_w; + let mask_base_idx = (b * n_offset_groups + offset_group) * kernel_h * kernel_w * out_h * out_w; + + let offset_c = c - offset_group * 2 * kernel_h * kernel_w; + let is_y_direction = offset_c % 2 == 0; + + let c_bound = channels_per_offset_group * kernel_h * kernel_w; + + for col_c in range_stepped(offset_c / 2, c_bound, col_step) { + let col_pos = (((col_c * batch_size + b) * out_h) + h) * out_w + w; + + let out_x = col_pos % out_w; + let out_y = (col_pos / out_w) % out_h; + let j = (col_pos / (out_w * out_h * batch_size)) % kernel_w; + let i = (col_pos / (out_w * out_h * batch_size * kernel_w)) % kernel_h; + + let mask_idx = i * kernel_w + j; + let offset_idx = mask_idx * 2; + + let offset_y_idx = (offset_idx * out_h + out_y) * out_w + out_x; + let offset_x_idx = ((offset_idx + 1) * out_h + out_y) * out_w + out_x; + + 
let offset_y = offset[offset_base_idx + offset_y_idx]; + let offset_x = offset[offset_base_idx + offset_x_idx]; + + let mask_value = if use_mask { + mask[mask_base_idx + (mask_idx * out_h + out_y) * out_w + out_x] + } else { + F::new(1.0) + }; + + let y = F::cast_from(out_y * args.stride_h + i * args.dilation_h) - args.pad_h + offset_y; + let x = F::cast_from(out_x * args.stride_w + j * args.dilation_w) - args.pad_w + offset_x; + + let weight = get_coordinate_weight( + image.slice(image_base_idx, image.len()), + height, + width, + y, + x, + is_y_direction, + ); + + grad_offset_val += mask_value * weight * columns[col_base_idx + col_pos]; + + if use_mask { + if is_y_direction { + grad_mask_val += columns[col_base_idx + col_pos] + * bilinear_interpolate(image, height, width, y, x, image_base_idx); + } + } + + image_base_idx += height * width; + } + + grad_offset[ABSOLUTE_POS] = grad_offset_val; + + if use_mask { + if is_y_direction { + let idx = ((((b * n_offset_groups + offset_group) * kernel_h + w_h) * kernel_w + w_w) + * out_h + + h) + * out_w + + w; + grad_mask[idx] = grad_mask_val + } + } +} + +#[cube] +fn get_coordinate_weight( + input: &Slice<'_, F>, + height: u32, + width: u32, + y: F, + x: F, + is_y_direction: bool, +) -> F { + let stride_y = width; + + let y = f32::cast_from(y); + let x = f32::cast_from(x); + + let y_low = f32::floor(y); + let x_low = f32::floor(x); + let y_high = y_low + 1.; + let x_high = x_low + 1.; + + let valid_y_low = y_low >= 0. && y_low < height as f32; + let valid_y_high = y_high >= 0. && y_high < height as f32; + let valid_x_low = x_low >= 0. && x_low < width as f32; + let valid_x_high = x_high >= 0. && x_high < width as f32; + + let bottom_left = if valid_y_low && valid_x_low { + input[y_low as u32 * stride_y + x_low as u32] + } else { + F::new(0.0) + }; + let bottom_right = if valid_y_low && valid_x_high { + input[y_low as u32 * stride_y + x_high as u32] + } else { + F::new(0.0) + }; + let top_left = if valid_y_high && valid_x_low { + input[y_high as u32 * stride_y + x_low as u32] + } else { + F::new(0.0) + }; + let top_right = if valid_y_high && valid_x_high { + input[y_high as u32 * stride_y + x_high as u32] + } else { + F::new(0.0) + }; + + if is_y_direction { + let delta_x = F::cast_from(x - x_low); + delta_x * (top_right - bottom_right) + (F::new(1.0) - delta_x) * (top_left - bottom_left) + } else { + let delta_y = F::cast_from(y - y_low); + delta_y * (top_right - top_left) + (F::new(1.0) - delta_y) * (bottom_right - bottom_left) + } +} + +fn compute_input_grad( + columns: JitTensor, + offset: JitTensor, + mask: Option>, + options: &DeformConvOptions<2>, + kernel_dims: (usize, usize), + input_shape: Shape<4>, +) -> JitTensor { + let client = offset.client.clone(); + let device = offset.device.clone(); + + let [batch_size, in_channels, height, width] = input_shape.dims; + let (kernel_height, kernel_width) = kernel_dims; + + let grad_in = zeros_device::( + client.clone(), + device.clone(), + Shape::new([batch_size, in_channels, height, width]), + ); + + let use_mask = mask.is_some(); + let mask = mask + .unwrap_or_else(|| ones_device(client.clone(), device.clone(), Shape::new([1, 1, 1, 1]))); + + let num_elements = columns.shape.num_elements(); + let cube_dim = CubeDim::default(); + let cube_count = calculate_cube_count_elemwise(num_elements, cube_dim); + + deform_col2img_kernel::launch::( + &offset.client, + cube_count, + cube_dim, + offset.as_tensor_arg(1), + mask.as_tensor_arg(1), + columns.as_tensor_arg(1), + grad_in.as_tensor_arg(1), + 
DeformConv2dCol2ImgArgsLaunch::new( + ScalarArg::new(options.stride[0] as u32), + ScalarArg::new(options.stride[1] as u32), + ScalarArg::new(options.dilation[0] as u32), + ScalarArg::new(options.dilation[1] as u32), + ScalarArg::new(options.padding[0] as f32), + ScalarArg::new(options.padding[1] as f32), + ScalarArg::new(options.offset_groups as u32), + ScalarArg::new(batch_size as u32), + ScalarArg::new(in_channels as u32), + ScalarArg::new(height as u32), + ScalarArg::new(width as u32), + ScalarArg::new(kernel_height as u32), + ScalarArg::new(kernel_width as u32), + ), + use_mask, + ); + + grad_in +} + +#[derive(CubeLaunch)] +struct DeformConv2dCol2ImgArgs { + stride_h: u32, + stride_w: u32, + dilation_h: u32, + dilation_w: u32, + pad_h: f32, + pad_w: f32, + offset_groups: u32, + batch_size: u32, + in_channels: u32, + height: u32, + width: u32, + kernel_height: u32, + kernel_width: u32, +} + +#[cube(launch)] +fn deform_col2img_kernel( + offset: &Tensor, + mask: &Tensor, + columns: &Tensor, + grad_input: &mut Tensor, + args: &DeformConv2dCol2ImgArgs, + #[comptime] use_mask: bool, +) { + // Position format: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]] + let _ = mask[0]; // Keep mask in bind group + + let n_in_channels = args.in_channels; + let height = args.height; + let width = args.width; + let out_h = offset.shape(2); + let out_w = offset.shape(3); + let kernel_h = args.kernel_height; + let kernel_w = args.kernel_width; + let n_offset_groups = args.offset_groups; + let batch_size = args.batch_size; + + let out_x = ABSOLUTE_POS % out_w; + let out_y = (ABSOLUTE_POS / out_w) % out_h; + let batch = (ABSOLUTE_POS / (out_w * out_h)) % batch_size; + let kernel_x = (ABSOLUTE_POS / (out_w * out_h * batch_size)) % kernel_w; + let kernel_y = (ABSOLUTE_POS / (out_w * out_h * batch_size * kernel_w)) % kernel_h; + let in_channel = ABSOLUTE_POS / (out_w * out_h * batch_size * kernel_w * kernel_h); + + let channels_per_offset_group = n_in_channels / n_offset_groups; + let offset_group = in_channel / channels_per_offset_group; + + let offset_base_idx = + (batch * n_offset_groups + offset_group) * 2 * kernel_h * kernel_w * out_h * out_w; + + let mask_idx = kernel_y * kernel_w + kernel_x; + let offset_idx = mask_idx * 2; + + let offset_y_idx = (offset_idx * out_h + out_y) * out_w + out_x; + let offset_x_idx = ((offset_idx + 1) * out_h + out_y) * out_w + out_x; + + let offset_y = f32::cast_from(offset[offset_base_idx + offset_y_idx]); + let offset_x = f32::cast_from(offset[offset_base_idx + offset_x_idx]); + + let mask_value = if use_mask { + let mask_base_idx = + (batch * n_offset_groups + offset_group) * kernel_h * kernel_w * out_h * out_w; + + mask[mask_base_idx + (mask_idx * out_h + out_y) * out_w + out_x] + } else { + F::new(1.0) + }; + + let y = + f32::cast_from(out_y * args.stride_h + kernel_y * args.dilation_h) - args.pad_h + offset_y; + let x = + f32::cast_from(out_x * args.stride_w + kernel_x * args.dilation_w) - args.pad_w + offset_x; + + for dy in -1..=1 { + #[unroll] + for dx in -1..=1 { + let yp = f32::floor(y) + dy as f32; + let xp = f32::floor(x) + dx as f32; + + if yp >= 0.0 + && yp < height as f32 + && xp >= 0.0 + && xp < width as f32 + && f32::abs(y - yp) < 1.0 + && f32::abs(x - xp) < 1.0 + { + let gradient_pos = + ((batch * n_in_channels + in_channel) * height + yp as u32) * width + xp as u32; + + let weight = (1.0 - f32::abs(y - yp)) * (1.0 - f32::abs(x - xp)); + + let value = mask_value * F::cast_from(weight) * columns[ABSOLUTE_POS]; + + float_atomic_add::(&mut 
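+ // Neighbouring output positions write into overlapping bilinear footprints of
+ // the input, so the accumulation must be atomic; float_atomic_add below
+ // emulates a float atomic with a compare-and-swap loop over the raw u32 bits.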
grad_input[gradient_pos], value); + } + } + } +} + +#[cube] +fn float_atomic_add(ptr: &mut AtomicU32, value: F) { + if value != F::new(0.0) { + let mut v = AtomicU32::load(ptr); + loop { + let prev = v; + let v_float = F::bitcast_from(v); + let new = u32::bitcast_from(v_float + value); + v = AtomicU32::compare_and_swap(ptr, v, new); + if prev == v { + break; + } + } + } +} diff --git a/crates/burn-jit/src/kernel/conv/mod.rs b/crates/burn-jit/src/kernel/conv/mod.rs index 781e921d2..2c60c07bf 100644 --- a/crates/burn-jit/src/kernel/conv/mod.rs +++ b/crates/burn-jit/src/kernel/conv/mod.rs @@ -2,8 +2,12 @@ mod conv2d; mod conv3d; mod conv_transpose2d; mod conv_transpose3d; +mod deform_conv2d; +mod deform_conv_transpose2d; pub(crate) use conv2d::*; pub(crate) use conv3d::*; pub(crate) use conv_transpose2d::*; pub(crate) use conv_transpose3d::*; +pub(crate) use deform_conv2d::*; +pub(crate) use deform_conv_transpose2d::*; diff --git a/crates/burn-jit/src/ops/module_ops.rs b/crates/burn-jit/src/ops/module_ops.rs index 12c4044c0..48fc1db49 100644 --- a/crates/burn-jit/src/ops/module_ops.rs +++ b/crates/burn-jit/src/ops/module_ops.rs @@ -1,7 +1,7 @@ use crate::{kernel, FloatElement, IntElement, JitBackend, JitRuntime}; use burn_tensor::ops::{ - ConvOptions, ConvTransposeOptions, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices, - ModuleOps, + ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateOptions, + MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, }; use burn_tensor::ops::{FloatTensor, IntTensor}; @@ -20,6 +20,29 @@ where kernel::conv::conv2d(x, weight, bias, options) } + fn deform_conv2d( + x: FloatTensor, + offset: FloatTensor, + weight: FloatTensor, + mask: Option>, + bias: Option>, + options: DeformConvOptions<2>, + ) -> FloatTensor { + kernel::conv::deform_conv2d::(x, offset, weight, mask, bias, options) + } + + fn deform_conv2d_backward( + x: FloatTensor, + offset: FloatTensor, + weight: FloatTensor, + mask: Option>, + bias: Option>, + output_grad: FloatTensor, + options: DeformConvOptions<2>, + ) -> DeformConv2dBackward { + kernel::conv::deform_conv2d_backward(x, offset, weight, mask, bias, output_grad, options) + } + fn conv3d( x: FloatTensor, weight: FloatTensor, diff --git a/crates/burn-ndarray/Cargo.toml b/crates/burn-ndarray/Cargo.toml index da5b47b00..53a5ebfdb 100644 --- a/crates/burn-ndarray/Cargo.toml +++ b/crates/burn-ndarray/Cargo.toml @@ -12,6 +12,7 @@ version.workspace = true [features] default = ["std"] +doc = ["default"] std = [ "burn-autodiff", "burn-common/std", @@ -24,7 +25,6 @@ std = [ "rand/std", "num-traits/std", ] -doc = ["default"] blas-accelerate = [ "blas-src/accelerate", # Accelerate framework (macOS only) @@ -46,10 +46,11 @@ burn-autodiff = { path = "../burn-autodiff", version = "0.15.0", optional = true burn-common = { path = "../burn-common", version = "0.15.0", default-features = false } burn-tensor = { path = "../burn-tensor", version = "0.15.0", default-features = false } -matrixmultiply = { workspace = true, default-features = false } +atomic_float = { workspace = true } blas-src = { workspace = true, default-features = false, optional = true } # no-std compatible derive-new = { workspace = true } libm = { workspace = true } +matrixmultiply = { workspace = true, default-features = false } ndarray = { workspace = true } num-traits = { workspace = true } openblas-src = { workspace = true, optional = true } diff --git a/crates/burn-ndarray/src/ops/deform_conv.rs b/crates/burn-ndarray/src/ops/deform_conv.rs 
b/crates/burn-ndarray/src/ops/deform_conv.rs
new file mode 100644 index 000000000..6e96bbcd8 --- /dev/null +++ b/crates/burn-ndarray/src/ops/deform_conv.rs @@ -0,0 +1,656 @@ +use burn_common::{iter_par, run_par}; +use burn_tensor::ops::{conv::calculate_conv_output_size, DeformConvOptions}; +use core::ops::AddAssign; +use ndarray::{ + s, Array2, Array4, ArrayView2, ArrayView3, ArrayView4, ArrayView6, ArrayViewMut2, Axis, Dim, + Ix4, +}; +#[cfg(not(feature = "std"))] +use num_traits::Float; + +use crate::{element::QuantElement, FloatNdArrayElement, NdArrayTensor}; + +use super::matmul::matmul; + +#[inline(always)] +#[allow(clippy::too_many_arguments)] +fn deform_im2col_kernel( + out_y: usize, + out_x: usize, + input: ArrayView2, + offset: ArrayView3, + mask: Option>, + mut columns: ArrayViewMut2, + args: DeformConvOptions<2>, + (kernel_h, kernel_w): (usize, usize), +) { + // position shape: [in_channels, batch_size, out_h, out_w] + // columns shape: [[in_channels, kernel_h, kernel_w], [batch_size, out_h, out_w]] + + let (height, width) = input.dim(); + + for kernel_y in 0..kernel_h { + for kernel_x in 0..kernel_w { + let mask_value = mask + .map(|it| it[[kernel_y, kernel_x]]) + .unwrap_or_else(|| F::from_elem(1.0)); + + let offset = offset.slice(s![kernel_y, kernel_x, ..]); + let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0]) + - F::from_elem(args.padding[0]) + + offset[0]; + let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1]) + - F::from_elem(args.padding[1]) + + offset[1]; + + let interpolated = bilinear_interpolate(input, height, width, y, x); + + columns[[kernel_y, kernel_x]] = mask_value * interpolated; + } + } +} + +fn bilinear_interpolate( + input: ArrayView2, + height: usize, + width: usize, + y: F, + x: F, +) -> F { + // To simplify code + let y = y.to_f32(); + let x = x.to_f32(); + + let mut result = F::from_elem(0.0); + if y > -1.0 && height as f32 > y && x > -1.0 && width as f32 > x { + let y_low = f32::floor(y); + let x_low = f32::floor(x); + let y_high = (y_low + 1.) as usize; + let x_high = (x_low + 1.) as usize; + + let zero = F::from_elem(0.0); + let v1: F = if y_low >= 0. && x_low >= 0. { + input[[y_low as usize, x_low as usize]] + } else { + zero + }; + let v2: F = if y_low >= 0. && x_high < width { + input[[y_low as usize, x_high]] + } else { + zero + }; + let v3: F = if y_high < height && x_low >= 0. 
{ + input[[y_high, x_low as usize]] + } else { + zero + }; + let v4: F = if y_high < height && x_high < width { + input[[y_high, x_high]] + } else { + zero + }; + + let l_y = y - y_low; + let l_x = x - x_low; + let h_y = 1.0 - l_y; + let h_x = 1.0 - l_x; + + let w1 = F::from_elem(h_y * h_x); + let w2 = F::from_elem(h_y * l_x); + let w3 = F::from_elem(l_y * h_x); + let w4 = F::from_elem(l_y * l_x); + + result = w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4; + } + result +} + +pub(crate) fn deform_conv2d( + input: NdArrayTensor, + offset: NdArrayTensor, + weight: NdArrayTensor, + mask: Option>, + bias: Option>, + args: DeformConvOptions<2>, +) -> NdArrayTensor { + let [batch_size, _, in_height, in_width] = input.shape().dims; + let [out_channels, _, kernel_h, kernel_w] = weight.shape().dims; + let groups = args.weight_groups; + + let weight = weight.array.as_standard_layout(); + + let out_h = calculate_conv_output_size( + kernel_h, + args.stride[0], + args.padding[0], + args.dilation[0], + in_height, + ); + let out_w = calculate_conv_output_size( + kernel_w, + args.stride[1], + args.padding[1], + args.dilation[1], + in_width, + ); + let out_dims = (out_h, out_w); + + let input = input.array.into_dimensionality::().unwrap(); + let offset = offset.array.into_dimensionality::().unwrap(); + let mask = mask.as_ref().map(|it| { + it.array + .to_shape(( + batch_size, + args.offset_groups, + kernel_h, + kernel_w, + out_h, + out_w, + )) + .unwrap() + }); + + let columns = deform_im2col( + input.view(), + offset.view(), + mask.as_ref().map(|it| it.view()), + args, + out_dims, + (kernel_h, kernel_w), + ); + + let (col_size_0, col_size_1) = columns.dim(); + let col_size_0 = col_size_0 / groups; + let out_c_per_group = out_channels / groups; + + let weight = weight + .to_shape((groups, out_c_per_group, col_size_0)) + .unwrap(); + let columns = columns.to_shape((groups, col_size_0, col_size_1)).unwrap(); + let out = matmul( + NdArrayTensor::<_, 3>::new(weight.to_owned().into_dyn().into_shared()), + NdArrayTensor::<_, 3>::new(columns.to_owned().into_dyn().into_shared()), + ); + + let mut out = out + .array + .into_shape_with_order((out_channels, batch_size, out_h, out_w)) + .unwrap(); + out.swap_axes(0, 1); + + if let Some(bias) = bias { + let bias = bias.array.to_shape((1, out_channels, 1, 1)).unwrap(); + out.add_assign(&bias); + } + + NdArrayTensor::new(out.into_dyn().into_shared()) +} + +pub(crate) fn deform_im2col( + input: ArrayView4, + offset: ArrayView4, + mask: Option>, + args: DeformConvOptions<2>, + out_dims: (usize, usize), + kernel_dims: (usize, usize), +) -> Array2 { + let (batch_size, in_channels, _, _) = input.dim(); + let (kernel_h, kernel_w) = kernel_dims; + let (out_h, out_w) = out_dims; + let channels_per_offset_group = in_channels / args.offset_groups; + + let mut columns = Array4::zeros(Dim([ + in_channels, + kernel_h, + kernel_w, + batch_size * out_h * out_w, + ])); + + let groups = args.offset_groups; + + run_par!(|| { + iter_par!(columns.axis_iter_mut(Axis(3))) + .enumerate() + .for_each(|(index, mut columns)| { + let out_x = index % out_w; + let out_y = (index / out_w) % out_h; + let batch = (index / (out_w * out_h)) % batch_size; + let offset = offset.slice(s![batch, .., out_y, out_x]); + let offset = offset.to_shape((groups, kernel_h, kernel_w, 2)).unwrap(); + let mask = mask + .as_ref() + .map(|it| it.slice(s![batch, .., .., .., out_y, out_x])); + columns + .axis_iter_mut(Axis(0)) + .enumerate() + .for_each(|(in_channel, mut columns)| { + let group_index = in_channel / 
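+ // Input channels are partitioned into offset groups; each group samples the
+ // image with its own learned offset (and mask) field.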
channels_per_offset_group; + deform_im2col_kernel( + out_y, + out_x, + input.slice(s![batch, in_channel, .., ..]), + offset.slice(s![group_index, .., .., ..]), + mask.as_ref().map(|it| it.slice(s![group_index, .., ..])), + columns.view_mut(), + args.clone(), + kernel_dims, + ); + }); + }); + }); + + columns + // Columns is created here, so we know it's contiguous + .into_shape_with_order(( + in_channels * kernel_h * kernel_w, + batch_size * out_h * out_w, + )) + .unwrap() +} + +pub mod backward { + #[cfg(target_has_atomic = "32")] + use core::sync::atomic::Ordering; + + use crate::NdArray; + use atomic_float::AtomicF32; + use burn_tensor::ops::DeformConv2dBackward; + use ndarray::{Array1, Array5, ArrayView4, ArrayView6, Ix4}; + + use super::*; + + /// Calculate the [deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d) backward pass using convolutions. + pub(crate) fn deform_conv2d_backward( + input: NdArrayTensor, + offset: NdArrayTensor, + weight: NdArrayTensor, + mask: Option>, + bias: Option>, + out_grad: NdArrayTensor, + args: DeformConvOptions<2>, + ) -> DeformConv2dBackward> { + let [batch_size, out_channels, out_h, out_w] = out_grad.shape().dims; + let [_, _, kernel_h, kernel_w] = weight.shape().dims; + let groups = args.weight_groups; + let out_c_per_group = out_channels / groups; + let col_shape_1 = batch_size * out_h * out_w; + let mut out_grad = out_grad.array.into_dimensionality::().unwrap(); + + let gradient_bias = bias.map(|_| { + let out_grad = out_grad + .clone() + .sum_axis(Axis(0)) + .sum_axis(Axis(1)) + .sum_axis(Axis(1)); + + NdArrayTensor::new(out_grad.into_dyn().into_shared()) + }); + + out_grad.swap_axes(0, 1); + let out_grad = out_grad + .to_shape((groups, out_c_per_group, col_shape_1)) + .unwrap(); + + let input = input.array.into_dimensionality::().unwrap(); + let offset = offset.array.into_dimensionality::().unwrap(); + let mask = mask.map(|it| { + it.array + .into_shape_with_order(( + batch_size, + args.offset_groups, + kernel_h, + kernel_w, + out_h, + out_w, + )) + .unwrap() + }); + + let (input_gradient, offset_gradient, mask_gradient) = backward_gradient_inputs( + input.view(), + weight, + offset.view(), + mask.as_ref().map(|it| it.view()), + out_grad.view(), + &args, + (kernel_h, kernel_w), + ); + + let weight_grad = compute_weight_grad( + input.view(), + offset.view(), + mask.as_ref().map(|it| it.view()), + out_grad.view(), + args, + (kernel_h, kernel_w), + (out_h, out_w), + ); + + DeformConv2dBackward::new( + input_gradient, + offset_gradient, + weight_grad, + mask_gradient, + gradient_bias, + ) + } + + fn compute_weight_grad( + input: ArrayView4, + offset: ArrayView4, + mask: Option>, + out_grad: ArrayView3, + options: DeformConvOptions<2>, + kernel_dims: (usize, usize), + out_dims: (usize, usize), + ) -> NdArrayTensor { + let in_channels = input.dim().1; + let (groups, out_c_per_group, _) = out_grad.dim(); + let (kernel_h, kernel_w) = kernel_dims; + + let in_c_per_group = in_channels / groups; + + let columns = deform_im2col(input, offset, mask, options, out_dims, kernel_dims); + let (col_size_0, col_size_1) = columns.dim(); + let col_size_0 = col_size_0 / groups; + + let mut columns = columns.to_shape((groups, col_size_0, col_size_1)).unwrap(); + columns.swap_axes(1, 2); + + let grad_weight = matmul( + NdArrayTensor::<_, 3>::new(out_grad.to_owned().into_dyn().into_shared()), + NdArrayTensor::<_, 3>::new(columns.to_owned().into_dyn().into_shared()), + ); + + let grad_weight = grad_weight + .array + .into_shape_with_order((out_c_per_group * 
groups, in_c_per_group, kernel_h, kernel_w)) + .unwrap(); + NdArrayTensor::new(grad_weight.into_dyn().into_shared()) + } + + type InputGradients = ( + NdArrayTensor, + NdArrayTensor, + Option>, + ); + + fn backward_gradient_inputs( + image: ArrayView4, + weight: NdArrayTensor, + offset: ArrayView4, + mask: Option>, + out_grad: ArrayView3, + args: &DeformConvOptions<2>, + kernel_dims: (usize, usize), + ) -> InputGradients { + let input_shape = image.dim(); + let in_channels = input_shape.1; + let [out_channels, in_c_per_group, kernel_h, kernel_w] = weight.shape().dims; + let (batch_size, _, out_h, out_w) = offset.dim(); + + let groups = args.weight_groups; + let out_c_per_group = out_channels / groups; + + let col_shape_0 = in_c_per_group * kernel_h * kernel_w; + + let mut weight = weight + .array + .to_shape((groups, out_c_per_group, col_shape_0)) + .unwrap(); + weight.swap_axes(1, 2); + let columns = matmul( + NdArrayTensor::<_, 3>::new(weight.to_owned().into_dyn().into_shared()), + NdArrayTensor::<_, 3>::new(out_grad.to_owned().into_dyn().into_shared()), + ); + + let columns = columns + .array + .to_shape((in_channels, kernel_h, kernel_w, batch_size, out_h, out_w)) + .unwrap(); + + let (offset_gradient, mask_gradient) = compute_offset_and_mask_gradient( + columns.view(), + image.view(), + offset, + mask, + args, + kernel_dims, + ); + + let input_gradient = + compute_input_grad(columns.view(), offset, mask, args, kernel_dims, input_shape); + + (input_gradient, offset_gradient, mask_gradient) + } + + fn compute_offset_and_mask_gradient( + columns: ArrayView6, + image: ArrayView4, + offset: ArrayView4, + mask: Option>, + args: &DeformConvOptions<2>, + kernel_dims: (usize, usize), + ) -> (NdArrayTensor, Option>) { + let (kernel_h, kernel_w) = kernel_dims; + let (_, in_channels, height, width) = image.dim(); + let (batch_size, offset_channels, out_h, out_w) = offset.dim(); + let offs_groups = args.offset_groups; + let channels_per_offset_group = in_channels / args.offset_groups; + + let mut grad_offset = Array5::zeros(( + offs_groups, + kernel_h, + kernel_w, + 2, + batch_size * out_h * out_w, + )); + let mut grad_mask = + Array4::zeros((offs_groups, kernel_h, kernel_w, batch_size * out_h * out_w)); + + grad_mask + .axis_iter_mut(Axis(3)) + .zip(grad_offset.axis_iter_mut(Axis(4))) + .enumerate() + .for_each(|(index, (mut grad_mask, mut grad_offset))| { + let out_x = index % out_w; + let out_y = (index / out_w) % out_h; + let batch = index / (out_w * out_h); + let offset = offset.slice(s![batch, .., out_y, out_x]); + let offset = offset + .to_shape((offs_groups, kernel_h, kernel_w, 2)) + .unwrap(); + let mask: Option> = mask + .as_ref() + .map(|mask| mask.slice(s![batch, .., .., .., out_y, out_x])); + let columns = columns.slice(s![.., .., .., batch, out_y, out_x]); + let image = image.slice(s![batch, .., .., ..]); + + for ((group, kernel_y, kernel_x), grad_mask) in grad_mask.indexed_iter_mut() { + let grad_mask: &mut F = grad_mask; + let mut grad_offset = grad_offset.slice_mut(s![group, kernel_y, kernel_x, ..]); + let offset = offset.slice(s![group, kernel_y, kernel_x, ..]); + let mask = mask.map(|it| it[[group, kernel_y, kernel_x]]); + let columns = columns.slice(s![.., kernel_y, kernel_x]); + let group_offset = group * channels_per_offset_group; + let image = image.slice(s![group_offset.., .., ..]); + let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0]) + - F::from_elem(args.padding[0]) + + offset[0]; + let x = F::from_elem(out_x * args.stride[1] + kernel_x * 
args.dilation[1]) + - F::from_elem(args.padding[1]) + + offset[1]; + for (i, grad_offset) in grad_offset.iter_mut().enumerate() { + let is_y_direction = i % 2 == 0; + let use_mask = mask.is_some(); + + for channel in 0..channels_per_offset_group { + let mask = mask.unwrap_or_else(|| F::one()); + let image = image.index_axis(Axis(0), channel); + let weight = + get_coordinate_weight(image, height, width, y, x, is_y_direction); + *grad_offset += mask * weight * columns[channel]; + if use_mask && is_y_direction { + *grad_mask += columns[channel] + * bilinear_interpolate(image, height, width, y, x); + } + } + } + } + }); + + let mask_gradient = mask.map(|_| { + let mut grad_mask = grad_mask + .into_shape_with_order((offset_channels / 2, batch_size, out_h, out_w)) + .unwrap(); + grad_mask.swap_axes(0, 1); + NdArrayTensor::new(grad_mask.into_dyn().into_shared()) + }); + let mut grad_offset = grad_offset + .into_shape_with_order((offset_channels, batch_size, out_h, out_w)) + .unwrap(); + grad_offset.swap_axes(0, 1); + let offset_gradient = NdArrayTensor::new(grad_offset.into_dyn().into_shared()); + (offset_gradient, mask_gradient) + } + + fn get_coordinate_weight( + input: ArrayView2, + height: usize, + width: usize, + y: F, + x: F, + is_y_direction: bool, + ) -> F { + let y = y.to_f32(); + let x = x.to_f32(); + + let y_low = f32::floor(y); + let x_low = f32::floor(x); + let y_high = y_low + 1.; + let x_high = x_low + 1.; + + let valid_y_low = y_low >= 0. && y_low < height as f32; + let valid_y_high = y_high >= 0. && y_high < height as f32; + let valid_x_low = x_low >= 0. && x_low < width as f32; + let valid_x_high = x_high >= 0. && x_high < width as f32; + + let bottom_left = if valid_y_low && valid_x_low { + input[[y_low as usize, x_low as usize]] + } else { + F::zero() + }; + let bottom_right = if valid_y_low && valid_x_high { + input[[y_low as usize, x_high as usize]] + } else { + F::zero() + }; + let top_left = if valid_y_high && valid_x_low { + input[[y_high as usize, x_low as usize]] + } else { + F::zero() + }; + let top_right = if valid_y_high && valid_x_high { + input[[y_high as usize, x_high as usize]] + } else { + F::zero() + }; + + if is_y_direction { + let delta_x = F::from_elem(x - x_low); + delta_x * (top_right - bottom_right) + (F::one() - delta_x) * (top_left - bottom_left) + } else { + let delta_y = F::from_elem(y - y_low); + delta_y * (top_right - top_left) + (F::one() - delta_y) * (bottom_right - bottom_left) + } + } + + fn compute_input_grad( + columns: ArrayView6, + offset: ArrayView4, + mask: Option>, + args: &DeformConvOptions<2>, + kernel_dims: (usize, usize), + input_shape: (usize, usize, usize, usize), + ) -> NdArrayTensor { + let (batch_size, in_channels, height, width) = input_shape; + let (kernel_h, kernel_w) = kernel_dims; + let offs_groups = args.offset_groups; + let channels_per_offset_group = in_channels / offs_groups; + + let grad_in = + Array4::from_shape_simple_fn((batch_size, in_channels, height, width), || { + AtomicF32::new(0.0) + }); + + run_par!(|| { + iter_par!(columns.indexed_iter()).for_each( + |((in_channel, kernel_y, kernel_x, batch, out_y, out_x), col)| { + let group = in_channel / channels_per_offset_group; + let offset = offset.slice(s![batch, .., out_y, out_x]); + let offset = offset + .to_shape((offs_groups, kernel_h, kernel_w, 2)) + .unwrap(); + let offset = offset.slice(s![group, kernel_y, kernel_x, ..]); + let offset = [offset[0], offset[1]]; + let mask = mask + .as_ref() + .map(|it| it[[batch, group, kernel_y, kernel_x, out_y, 
out_x]].to_f32()); + let y = F::from_elem(out_y * args.stride[0] + kernel_y * args.dilation[0]) + - F::from_elem(args.padding[0]) + + offset[0]; + let x = F::from_elem(out_x * args.stride[1] + kernel_x * args.dilation[1]) + - F::from_elem(args.padding[1]) + + offset[1]; + let grad_in = grad_in.slice(s![batch, in_channel, .., ..]); + deform_col2img_kernel(y.to_f32(), x.to_f32(), mask, col.to_f32(), grad_in); + }, + ) + }); + + let grad_in: Array1 = grad_in + .into_iter() + .map(|it| F::from_elem(it.into_inner())) + .collect(); + let grad_in = grad_in + .into_shape_with_order((batch_size, in_channels, height, width)) + .unwrap(); + NdArrayTensor::new(grad_in.into_dyn().into_shared()) + } + + fn deform_col2img_kernel( + y: f32, + x: f32, + mask: Option, + col: f32, + grad_input: ArrayView2, + ) { + let (height, width) = grad_input.dim(); + let mask_value = mask.unwrap_or(1.0); + + for dy in -1..=1 { + for dx in -1..=1 { + let yp = f32::floor(y) + dy as f32; + let xp = f32::floor(x) + dx as f32; + + if yp >= 0.0 + && yp < height as f32 + && xp >= 0.0 + && xp < width as f32 + && f32::abs(y - yp) < 1.0 + && f32::abs(x - xp) < 1.0 + { + let weight = (1.0 - f32::abs(y - yp)) * (1.0 - f32::abs(x - xp)); + + #[cfg_attr(not(target_has_atomic = "32"), allow(unused))] + let value = mask_value * weight * col; + + #[cfg(target_has_atomic = "32")] + grad_input[[yp as usize, xp as usize]].fetch_add(value, Ordering::AcqRel); + #[cfg(not(target_has_atomic = "32"))] + panic!("Can't use deformable convolution backwards pass without atomics"); + } + } + } + } +} diff --git a/crates/burn-ndarray/src/ops/mod.rs b/crates/burn-ndarray/src/ops/mod.rs index 13a9cc242..56eb62eed 100644 --- a/crates/burn-ndarray/src/ops/mod.rs +++ b/crates/burn-ndarray/src/ops/mod.rs @@ -9,6 +9,7 @@ mod tensor; pub(crate) mod adaptive_avgpool; pub(crate) mod avgpool; pub(crate) mod conv; +pub(crate) mod deform_conv; pub(crate) mod interpolate; pub(crate) mod macros; pub(crate) mod matmul; diff --git a/crates/burn-ndarray/src/ops/module.rs b/crates/burn-ndarray/src/ops/module.rs index 7058149f8..69951bb8d 100644 --- a/crates/burn-ndarray/src/ops/module.rs +++ b/crates/burn-ndarray/src/ops/module.rs @@ -2,6 +2,7 @@ use super::{ adaptive_avgpool::{adaptive_avg_pool2d, adaptive_avg_pool2d_backward}, avgpool::{avg_pool2d, avg_pool2d_backward}, conv::{conv2d, conv3d, conv_transpose2d, conv_transpose3d}, + deform_conv::{backward::deform_conv2d_backward, deform_conv2d}, interpolate::{bicubic_interpolate, bilinear_interpolate, nearest_interpolate}, maxpool::{max_pool2d, max_pool2d_backward, max_pool2d_with_indices}, }; @@ -19,6 +20,29 @@ impl ModuleOps for NdArray conv2d::(x, weight, bias, options) } + fn deform_conv2d( + x: NdArrayTensor, + offset: NdArrayTensor, + weight: NdArrayTensor, + mask: Option>, + bias: Option>, + options: DeformConvOptions<2>, + ) -> NdArrayTensor { + deform_conv2d::(x, offset, weight, mask, bias, options) + } + + fn deform_conv2d_backward( + x: NdArrayTensor, + offset: NdArrayTensor, + weight: NdArrayTensor, + mask: Option>, + bias: Option>, + output_grad: NdArrayTensor, + options: DeformConvOptions<2>, + ) -> DeformConv2dBackward { + deform_conv2d_backward(x, offset, weight, mask, bias, output_grad, options) + } + fn conv_transpose2d( x: NdArrayTensor, weight: NdArrayTensor, diff --git a/crates/burn-tch/src/ops/module.rs b/crates/burn-tch/src/ops/module.rs index 60fde6a33..f16b83d0c 100644 --- a/crates/burn-tch/src/ops/module.rs +++ b/crates/burn-tch/src/ops/module.rs @@ -1,7 +1,7 @@ use 
crate::{element::TchElement, LibTorch, QuantElement, TchTensor}; use burn_tensor::ops::{ - ConvOptions, ConvTransposeOptions, InterpolateMode, InterpolateOptions, MaxPool1dWithIndices, - MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, + ConvOptions, ConvTransposeOptions, DeformConv2dBackward, DeformConvOptions, InterpolateMode, + InterpolateOptions, MaxPool1dWithIndices, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, }; impl ModuleOps for LibTorch { @@ -86,6 +86,29 @@ impl ModuleOps for LibTorch { TchTensor::new(tensor) } + fn deform_conv2d( + _x: TchTensor, + _offset: TchTensor, + _weight: TchTensor, + _mask: Option>, + _bias: Option>, + _options: DeformConvOptions<2>, + ) -> TchTensor { + unimplemented!("Torch bindings don't support deform_conv2d"); + } + + fn deform_conv2d_backward( + _x: TchTensor, + _offset: TchTensor, + _weight: TchTensor, + _mask: Option>, + _bias: Option>, + _out_grad: TchTensor, + _options: DeformConvOptions<2>, + ) -> DeformConv2dBackward { + unimplemented!("Torch bindings don't support deform_conv2d"); + } + fn conv_transpose1d( x: TchTensor, weight: TchTensor, diff --git a/crates/burn-tensor/src/repr/operation.rs b/crates/burn-tensor/src/repr/operation.rs index a314da1de..634867875 100644 --- a/crates/burn-tensor/src/repr/operation.rs +++ b/crates/burn-tensor/src/repr/operation.rs @@ -2,7 +2,9 @@ use serde::{Deserialize, Serialize}; use std::ops::Range; use crate::{ - ops::{ConvOptions, ConvTransposeOptions, InterpolateMode, InterpolateOptions}, + ops::{ + ConvOptions, ConvTransposeOptions, DeformConvOptions, InterpolateMode, InterpolateOptions, + }, repr::tensor::TensorDescription, DType, Distribution, Element, }; @@ -74,6 +76,10 @@ pub enum ModuleOperationDescription { Conv2d(Conv2dDescription), /// Operation corresponding to [conv3d](crate::ops::ModuleOps::conv3d). Conv3d(Conv3dDescription), + /// Operation corresponding to [deform_conv2d](crate::ops::ModuleOps::deform_conv2d) + DeformableConv2d(Box), + /// Operation corresponding to [deform_conv2d_backward](crate::ops::ModuleOps::deform_conv2d_backward) + DeformableConv2dBackward(Box), /// Operation corresponding to [conv transpose 1d](crate::ops::ModuleOps::conv_transpose1d). ConvTranspose1d(ConvTranspose1dDescription), /// Operation corresponding to [conv transpose 2d](crate::ops::ModuleOps::conv_transpose2d). 
@@ -688,6 +694,35 @@ pub struct Conv2dDescription {
     pub out: TensorDescription,
 }
 
+#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
+#[allow(missing_docs)]
+pub struct DeformConv2dDescription {
+    pub x: TensorDescription,
+    pub offset: TensorDescription,
+    pub weight: TensorDescription,
+    pub mask: Option<TensorDescription>,
+    pub bias: Option<TensorDescription>,
+    pub options: DeformableConv2dOptionsDescription,
+    pub out: TensorDescription,
+}
+
+#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
+#[allow(missing_docs)]
+pub struct DeformConv2dBackwardDescription {
+    pub x: TensorDescription,
+    pub offset: TensorDescription,
+    pub weight: TensorDescription,
+    pub mask: Option<TensorDescription>,
+    pub bias: Option<TensorDescription>,
+    pub out_grad: TensorDescription,
+    pub options: DeformableConv2dOptionsDescription,
+    pub input_grad: TensorDescription,
+    pub offset_grad: TensorDescription,
+    pub weight_grad: TensorDescription,
+    pub mask_grad: Option<TensorDescription>,
+    pub bias_grad: Option<TensorDescription>,
+}
+
 #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
 #[allow(missing_docs)]
 pub struct Conv3dDescription {
@@ -746,6 +781,16 @@ pub struct Conv2dOptionsDescription {
     pub groups: usize,
 }
 
+#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
+#[allow(missing_docs)]
+pub struct DeformableConv2dOptionsDescription {
+    pub stride: [usize; 2],
+    pub padding: [usize; 2],
+    pub dilation: [usize; 2],
+    pub weight_groups: usize,
+    pub offset_groups: usize,
+}
+
 #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
 #[allow(missing_docs)]
 pub struct Conv3dOptionsDescription {
@@ -818,6 +863,18 @@ impl From<ConvOptions<3>> for Conv3dOptionsDescription {
     }
 }
 
+impl From<DeformConvOptions<2>> for DeformableConv2dOptionsDescription {
+    fn from(value: DeformConvOptions<2>) -> Self {
+        Self {
+            stride: value.stride,
+            padding: value.padding,
+            dilation: value.dilation,
+            weight_groups: value.weight_groups,
+            offset_groups: value.offset_groups,
+        }
+    }
+}
+
 impl From<ConvTransposeOptions<1>> for ConvTranspose1dOptionsDescription {
     fn from(value: ConvTransposeOptions<1>) -> Self {
         Self {
@@ -887,6 +944,18 @@ impl From<Conv3dOptionsDescription> for ConvOptions<3> {
     }
 }
 
+impl From<DeformableConv2dOptionsDescription> for DeformConvOptions<2> {
+    fn from(value: DeformableConv2dOptionsDescription) -> Self {
+        DeformConvOptions {
+            stride: value.stride,
+            padding: value.padding,
+            dilation: value.dilation,
+            weight_groups: value.weight_groups,
+            offset_groups: value.offset_groups,
+        }
+    }
+}
+
 impl From<ConvTranspose1dOptionsDescription> for ConvTransposeOptions<1> {
     fn from(val: ConvTranspose1dOptionsDescription) -> Self {
         ConvTransposeOptions {
@@ -1404,6 +1473,22 @@ impl ModuleOperationDescription {
                     vec![&desc.x, &desc.weight, &desc.out]
                 }
             }
+            ModuleOperationDescription::DeformableConv2d(desc) => match (&desc.mask, &desc.bias) {
+                (Some(mask), Some(bias)) => {
+                    vec![&desc.x, &desc.offset, &desc.weight, &mask, &bias]
+                }
+                (Some(mask), None) => vec![&desc.x, &desc.offset, &desc.weight, &mask],
+                (None, Some(bias)) => vec![&desc.x, &desc.offset, &desc.weight, &bias],
+                (None, None) => vec![&desc.x, &desc.offset, &desc.weight],
+            },
+            ModuleOperationDescription::DeformableConv2dBackward(desc) => {
+                match (&desc.mask, &desc.bias) {
+                    (Some(mask), Some(bias)) => {
+                        vec![&desc.x, &desc.offset, &desc.weight, &mask, &bias]
+                    }
+                    (Some(mask), None) => vec![&desc.x, &desc.offset, &desc.weight, &mask],
+                    (None, Some(bias)) => vec![&desc.x, &desc.offset, &desc.weight, &bias],
+                    (None, None) => vec![&desc.x, &desc.offset, &desc.weight],
+                }
+            }
             ModuleOperationDescription::ConvTranspose1d(desc) => {
                 if let Some(bias) = &desc.bias {
                     vec![&desc.x, &desc.weight, &bias, &desc.out]
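Because burn-fusion replays operations from these serialized descriptions, the two `From` impls above must mirror each other field for field. A quick round-trip sketch (not part of the patch; it assumes the repr types are reachable from `burn_tensor::repr`, which may sit behind a feature flag):

use burn_tensor::{ops::DeformConvOptions, repr::DeformableConv2dOptionsDescription};

fn main() {
    // Constructor argument order follows the #[derive(new)] field order:
    // stride, padding, dilation, weight_groups, offset_groups.
    let options = DeformConvOptions::new([1, 1], [2, 3], [1, 1], 1, 3);

    let description: DeformableConv2dOptionsDescription = options.clone().into();
    let roundtrip: DeformConvOptions<2> = description.into();

    assert_eq!(options, roundtrip); // every field survives the conversion
}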
diff --git a/crates/burn-tensor/src/tensor/module.rs b/crates/burn-tensor/src/tensor/module.rs
index 9c66093e1..0980ec3e3 100644
--- a/crates/burn-tensor/src/tensor/module.rs
+++ b/crates/burn-tensor/src/tensor/module.rs
@@ -4,6 +4,8 @@ use crate::{
     Int, Tensor, TensorPrimitive,
 };
 
+use super::ops::DeformConvOptions;
+
 /// Applies the [embedding module](crate::ops::ModuleOps::embedding).
 pub fn embedding<B>(weights: Tensor<B, 2>, indices: Tensor<B, 2, Int>) -> Tensor<B, 3>
 where
@@ -69,6 +71,28 @@ where
     )))
 }
 
+/// Applies a [Deformable 2D convolution](crate::ops::ModuleOps::deform_conv2d).
+pub fn deform_conv2d<B>(
+    x: Tensor<B, 4>,
+    offset: Tensor<B, 4>,
+    weight: Tensor<B, 4>,
+    mask: Option<Tensor<B, 4>>,
+    bias: Option<Tensor<B, 1>>,
+    options: DeformConvOptions<2>,
+) -> Tensor<B, 4>
+where
+    B: Backend,
+{
+    Tensor::new(TensorPrimitive::Float(B::deform_conv2d(
+        x.primitive.tensor(),
+        offset.primitive.tensor(),
+        weight.primitive.tensor(),
+        mask.map(|m| m.primitive.tensor()),
+        bias.map(|b| b.primitive.tensor()),
+        options,
+    )))
+}
+
 /// Applies a [1D transposed convolution](crate::ops::ModuleOps::conv_transpose1d).
 pub fn conv_transpose1d<B>(
     x: Tensor<B, 3>,
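For orientation, the new functional API above can be driven end to end like this (a minimal sketch, not part of the patch; the NdArray backend, random inputs, and concrete shape numbers are illustrative, and the shapes follow the `ModuleOps` docs below):

use burn_ndarray::NdArray;
use burn_tensor::{module::deform_conv2d, ops::DeformConvOptions, Distribution, Tensor};

type B = NdArray<f32>;

fn main() {
    let device = Default::default();

    // x: [batch, c_in, h, w], weight: [c_out, c_in / weight_groups, kh, kw]
    let x = Tensor::<B, 4>::random([1, 3, 4, 4], Distribution::Default, &device);
    let weight = Tensor::<B, 4>::random([5, 3, 3, 3], Distribution::Default, &device);
    // offset: [batch, 2 * offset_groups * kh * kw, out_h, out_w] = [1, 18, 2, 2]
    let offset = Tensor::<B, 4>::random([1, 18, 2, 2], Distribution::Default, &device);

    let out = deform_conv2d(
        x,
        offset,
        weight,
        None, // mask is optional (v2-style modulation)
        None, // bias
        DeformConvOptions::new([1, 1], [0, 0], [1, 1], 1, 1),
    );

    assert_eq!(out.dims(), [1, 5, 2, 2]);
}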
diff --git a/crates/burn-tensor/src/tensor/ops/modules/base.rs b/crates/burn-tensor/src/tensor/ops/modules/base.rs
index 58ed8682c..66cdc460b 100644
--- a/crates/burn-tensor/src/tensor/ops/modules/base.rs
+++ b/crates/burn-tensor/src/tensor/ops/modules/base.rs
@@ -5,6 +5,51 @@ use crate::{
     Shape,
 };
 
+/// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).
+#[derive(new)]
+pub struct Conv2dBackward<B: Backend> {
+    /// Gradient.
+    pub x_grad: FloatTensor<B>,
+
+    /// Weights gradient.
+    pub weights_grad: FloatTensor<B>,
+
+    /// Bias gradient.
+    pub bias_grad: Option<FloatTensor<B>>,
+}
+
+/// Gradient computed during the backward pass for each tensor used by [deform_conv2d](ModuleOps::deform_conv2d).
+#[derive(new)]
+pub struct DeformConv2dBackward<B: Backend> {
+    /// Gradient.
+    pub x_grad: FloatTensor<B>,
+
+    /// Offset gradient.
+    pub offset_grad: FloatTensor<B>,
+
+    /// Weights gradient.
+    pub weight_grad: FloatTensor<B>,
+
+    /// Mask gradient.
+    pub mask_grad: Option<FloatTensor<B>>,
+
+    /// Bias gradient.
+    pub bias_grad: Option<FloatTensor<B>>,
+}
+
+/// Gradient computed during the backward pass for each tensor used by [conv3d](ModuleOps::conv3d).
+#[derive(new)]
+pub struct Conv3dBackward<B: Backend> {
+    /// Gradient.
+    pub x_grad: FloatTensor<B>,
+
+    /// Weights gradient.
+    pub weights_grad: FloatTensor<B>,
+
+    /// Bias gradient.
+    pub bias_grad: Option<FloatTensor<B>>,
+}
+
 /// Gradient computed during the backward pass for each tensor used by [max_pool1d](ModuleOps::max_pool1d).
 #[derive(new)]
 pub struct MaxPool1dBackward {
@@ -55,6 +100,25 @@ pub struct ConvOptions {
     pub groups: usize,
 }
 
+/// Deformable convolution options.
+#[derive(new, Debug, Clone, Hash, PartialEq, Eq)]
+pub struct DeformConvOptions<const N: usize> {
+    /// Stride.
+    pub stride: [usize; N],
+
+    /// Padding.
+    pub padding: [usize; N],
+
+    /// Dilation.
+    pub dilation: [usize; N],
+
+    /// Weight Groups.
+    pub weight_groups: usize,
+
+    /// Offset Groups.
+    pub offset_groups: usize,
+}
+
 /// Transposed convolution options.
 #[derive(new, Debug, Clone, Hash, PartialEq, Eq)]
 pub struct ConvTransposeOptions {
@@ -248,6 +312,33 @@ pub trait ModuleOps {
     ) -> FloatTensor<B> {
         conv::conv2d_bias_backward::<B>(x, weight, bias, output_grad)
     }
+
+    /// Two dimensional deformable convolution.
+    ///
+    /// The `offset` tensor carries a learned `(dy, dx)` displacement for every kernel element at
+    /// every output position, and the optional `mask` modulates each bilinear sample.
+    ///
+    /// # Shapes
+    ///
+    /// x: `[batch_size, channels_in, height, width]`,
+    /// offset: `[batch_size, 2 * offset_groups * kernel_size_1 * kernel_size_2, out_height, out_width]`,
+    /// weight: `[channels_out, channels_in / weight_groups, kernel_size_1, kernel_size_2]`,
+    /// mask: `[batch_size, offset_groups * kernel_size_1 * kernel_size_2, out_height, out_width]`,
+    /// bias: `[channels_out]`,
+    fn deform_conv2d(
+        x: FloatTensor<B>,
+        offset: FloatTensor<B>,
+        weight: FloatTensor<B>,
+        mask: Option<FloatTensor<B>>,
+        bias: Option<FloatTensor<B>>,
+        options: DeformConvOptions<2>,
+    ) -> FloatTensor<B>;
+    /// Backward pass for the [deform_conv2d](ModuleOps::deform_conv2d) operation.
+    fn deform_conv2d_backward(
+        x: FloatTensor<B>,
+        offset: FloatTensor<B>,
+        weight: FloatTensor<B>,
+        mask: Option<FloatTensor<B>>,
+        bias: Option<FloatTensor<B>>,
+        output_grad: FloatTensor<B>,
+        options: DeformConvOptions<2>,
+    ) -> DeformConv2dBackward<B>;
+
     /// Three dimensional convolution.
     ///
     /// # Shapes
diff --git a/crates/burn-tensor/src/tests/mod.rs b/crates/burn-tensor/src/tests/mod.rs
index 86966b3c0..e9ee65c3c 100644
--- a/crates/burn-tensor/src/tests/mod.rs
+++ b/crates/burn-tensor/src/tests/mod.rs
@@ -26,6 +26,7 @@ macro_rules! testgen_all {
         burn_tensor::testgen_module_conv1d!();
         burn_tensor::testgen_module_conv2d!();
         burn_tensor::testgen_module_conv3d!();
+        burn_tensor::testgen_module_deform_conv2d!();
         burn_tensor::testgen_module_conv_transpose1d!();
         burn_tensor::testgen_module_conv_transpose2d!();
         burn_tensor::testgen_module_conv_transpose3d!();
diff --git a/crates/burn-tensor/src/tests/module/deform_conv2d.rs b/crates/burn-tensor/src/tests/module/deform_conv2d.rs
new file mode 100644
index 000000000..ecb4231e0
--- /dev/null
+++ b/crates/burn-tensor/src/tests/module/deform_conv2d.rs
@@ -0,0 +1,439 @@
+#[burn_tensor_testgen::testgen(module_deform_conv2d)]
+mod tests {
+
+    use super::*;
+    use burn_tensor::module::deform_conv2d;
+    use burn_tensor::ops::{DeformConv2dBackward, DeformConvOptions, ModuleOps};
+    use burn_tensor::{Shape, Tensor};
+
+    #[test]
+    fn test_deform_conv2d_simple() {
+        let test = DeformConv2dTestCase {
+            batch_size: 1,
+            channels_in: 3,
+            channels_out: 5,
+            kernel_size_1: 3,
+            kernel_size_2: 3,
+            padding_1: 0,
+            padding_2: 0,
+            stride_1: 1,
+            stride_2: 1,
+            dilation_1: 1,
+            dilation_2: 1,
+            weight_groups: 1,
+            offset_groups: 1,
+            height: 4,
+            width: 4,
+        };
+
+        test.assert_output(Tensor::<TestBackend, 4>::from([[
+            [[0.9074, 0.6387], [0.5160, 0.4196]],
+            [[2.4259, 1.8008], [1.5449, 1.3112]],
+            [[3.9444, 2.9629], [2.5738, 2.2027]],
+            [[5.4629, 4.1250], [3.6027, 3.0943]],
+            [[6.9814, 5.2871], [4.6316, 3.9859]],
+        ]]));
+    }
+
+    #[test]
+    fn test_deform_conv2d_batched() {
+        let test = DeformConv2dTestCase {
+            batch_size: 2,
+            channels_in: 3,
+            channels_out: 5,
+            kernel_size_1: 3,
+            kernel_size_2: 3,
+            padding_1: 0,
+            padding_2: 0,
+            stride_1: 1,
+            stride_2: 1,
+            dilation_1: 1,
+            dilation_2: 1,
+            weight_groups: 1,
+            offset_groups: 1,
+            height: 4,
+            width: 4,
+        };
+
+        test.assert_output(Tensor::<TestBackend, 4>::from([
+            [
+                [[0.2155, 0.1928], [0.1934, 0.1755]],
+                [[0.7251, 0.6759], [0.6877, 0.6485]],
+                [[1.2347, 1.1590], [1.1821, 1.1215]],
+                [[1.7443, 1.6421], [1.6764, 1.5945]],
+                [[2.2539, 2.1252], [2.1708, 2.0675]],
+            ],
+            [
+                [[1.6530, 1.1369], [0.9840, 0.7184]],
+                [[4.8368, 3.4725], [3.1773, 2.4180]],
+                [[8.0206, 5.8080], [5.3705, 4.1176]],
+                [[11.2045, 8.1435], [7.5637, 5.8173]],
+                [[14.3883, 10.4790], [9.7570, 7.5169]],
+            ],
+        ]))
+    }
+
+    #[test]
+    fn test_deform_conv2d_weight_groups() {
+        let test = DeformConv2dTestCase {
+            batch_size: 1,
+            channels_in: 3,
+            channels_out: 6,
+            kernel_size_1: 3,
+            kernel_size_2: 3,
+            padding_1: 0,
+            padding_2: 0,
+            stride_1: 1,
+            stride_2: 1,
+            dilation_1: 1,
+            dilation_2: 1,
+            weight_groups: 3,
+            offset_groups: 1,
+            height: 4,
+            width: 4,
+        };
+
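+        // With weight_groups = 3, the 3 input channels are split into one channel per group and
+        // each group produces channels_out / weight_groups = 2 of the 6 output channels, hence
+        // the per-group weight shape [6, 1, 3, 3]; offset_groups = 1 keeps one shared offset field.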
+        test.assert_output(Tensor::<TestBackend, 4>::from([[
+            [[0.1018, 0.0658], [0.0467, 0.0362]],
+            [[0.4125, 0.3367], [0.3069, 0.2824]],
+            [[1.3076, 1.0242], [0.9025, 0.8000]],
+            [[1.8405, 1.4581], [1.2994, 1.1588]],
+            [[3.4022, 2.6346], [2.3052, 2.0143]],
+            [[4.1574, 3.2315], [2.8389, 2.4857]],
+        ]]))
+    }
+
+    #[test]
+    fn test_deform_conv2d_offset_groups() {
+        let test = DeformConv2dTestCase {
+            batch_size: 1,
+            channels_in: 3,
+            channels_out: 6,
+            kernel_size_1: 3,
+            kernel_size_2: 3,
+            padding_1: 0,
+            padding_2: 0,
+            stride_1: 1,
+            stride_2: 1,
+            dilation_1: 1,
+            dilation_2: 1,
+            weight_groups: 1,
+            offset_groups: 3,
+            height: 4,
+            width: 4,
+        };
+
+        test.assert_output(Tensor::<TestBackend, 4>::from([[
+            [[1.0794, 0.7676], [0.7209, 0.5337]],
+            [[2.7059, 2.0216], [1.9740, 1.5419]],
+            [[4.3325, 3.2755], [3.2271, 2.5501]],
+            [[5.9590, 4.5295], [4.4802, 3.5582]],
+            [[7.5855, 5.7835], [5.7333, 4.5664]],
+            [[9.2120, 7.0375], [6.9864, 5.5746]],
+        ]]))
+    }
+
+    #[test]
+    fn test_deform_conv2d_different_kernel_size() {
+        let test = DeformConv2dTestCase {
+            batch_size: 1,
+            channels_in: 2,
+            channels_out: 3,
+            kernel_size_1: 3,
+            kernel_size_2: 4,
+            padding_1: 0,
+            padding_2: 0,
+            stride_1: 1,
+            stride_2: 1,
+            dilation_1: 1,
+            dilation_2: 1,
+            weight_groups: 1,
+            offset_groups: 1,
+            height: 4,
+            width: 4,
+        };
+
+        test.assert_output(Tensor::<TestBackend, 4>::from([[
+            [[1.0669], [0.6329]],
+            [[2.9741], [2.0383]],
+            [[4.8812], [3.4437]],
+        ]]))
+    }
+
+    #[test]
+    fn test_deform_conv2d_different_padding_size() {
+        let test = DeformConv2dTestCase {
+            batch_size: 1,
+            channels_in: 2,
+            channels_out: 3,
+            kernel_size_1: 3,
+            kernel_size_2: 3,
+            padding_1: 2,
+            padding_2: 3,
+            stride_1: 1,
+            stride_2: 1,
+            dilation_1: 1,
+            dilation_2: 1,
+            weight_groups: 1,
+            offset_groups: 1,
+            height: 4,
+            width: 4,
+        };
+
+        test.assert_output(Tensor::<TestBackend, 4>::from([[
+            [
+                [
+                    0.1998, 0.3762, 0.5285, 0.6053, 0.3844, 0.1987, 0.0481, 0.0000,
+                ],
+                [
+                    0.2879, 0.5517, 0.7776, 0.8905, 0.5805, 0.3043, 0.0796, 0.0000,
+                ],
+                [
+                    0.3729, 0.7214, 1.0137, 1.1520, 0.7564, 0.3931, 0.1016, 0.0000,
+                ],
+                [
+                    0.1321, 0.3249, 0.4954, 0.5846, 0.4531, 0.2501, 0.0757, 0.0000,
+                ],
+                [
+                    0.0593, 0.1607, 0.2448, 0.2971, 0.2395, 0.1327, 0.0471, 0.0000,
+                ],
+                [
+                    0.0143, 0.0513, 0.0783, 0.0942, 0.0813, 0.0420, 0.0145, 0.0000,
+                ],
+            ],
+            [
+                [
+                    0.7667, 1.1648, 1.5219, 1.7111, 1.2305, 0.8076, 0.4504, 0.3333,
+                ],
+                [
+                    0.9812, 1.6010, 2.1525, 2.4409, 1.7455, 1.0918, 0.5367, 0.3333,
+                ],
+                [
+                    1.1964, 2.0448, 2.7853, 3.1522, 2.2426, 1.3513, 0.6049, 0.3333,
+                ],
+                [
+                    0.6695, 1.1781, 1.6441, 1.9022, 1.5732, 1.0339, 0.5536, 0.3333,
+                ],
+                [
+                    0.4950, 0.7861, 1.0398, 1.2047, 1.0523, 0.7439, 0.4834, 0.3333,
+                ],
+                [
+                    0.3788, 0.4982, 0.5929, 0.6542, 0.6155, 0.4882, 0.3909, 0.3333,
+                ],
+            ],
+            [
+                [
+                    1.3335, 1.9534, 2.5154, 2.8170, 2.0766, 1.4165, 0.8527, 0.6667,
+                ],
+                [
+                    1.6744, 2.6503, 3.5275, 3.9914, 2.9106, 1.8794, 0.9939, 0.6667,
+                ],
+                [
+                    2.0198, 3.3683, 4.5570, 5.1525, 3.7288, 2.3095, 1.1082, 0.6667,
+                ],
+                [
+                    1.2068, 2.0314, 2.7928, 3.2198, 2.6932, 1.8178, 1.0315, 0.6667,
+                ],
+                [
+                    0.9308, 1.4116, 1.8348, 2.1124, 1.8652, 1.3551, 0.9196, 0.6667,
+                ],
+                [
+                    0.7432, 0.9451, 1.1074, 1.2143, 1.1497, 0.9345, 0.7673, 0.6667,
+                ],
+            ],
+        ]]))
+    }
+
+    #[test]
+    fn test_deform_conv2d_different_stride() {
+        let test = DeformConv2dTestCase {
+            batch_size: 1,
+            channels_in: 2,
+            channels_out: 3,
+            kernel_size_1: 3,
+            kernel_size_2: 3,
+            padding_1: 0,
+            padding_2: 0,
+            stride_1: 1,
+            stride_2: 2,
+            dilation_1: 1,
+            dilation_2: 1,
+            weight_groups: 1,
+            offset_groups: 1,
+            height: 4,
+            width: 4,
+        };
+
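+        // Expected dims follow (in + 2 * pad - dilation * (kernel - 1) - 1) / stride + 1,
+        // exactly as computed in assert_output below:
+        // height: (4 - 2 - 1) / 1 + 1 = 2, width: (4 - 2 - 1) / 2 + 1 = 1.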
+        test.assert_output(Tensor::<TestBackend, 4>::from([[
+            [[1.0647], [0.5783]],
+            [[2.9289], [1.8829]],
+            [[4.7931], [3.1875]],
+        ]]))
+    }
+
+    #[test]
+    fn test_deform_conv2d_different_dilation() {
+        let test = DeformConv2dTestCase {
+            batch_size: 1,
+            channels_in: 2,
+            channels_out: 3,
+            kernel_size_1: 3,
+            kernel_size_2: 3,
+            padding_1: 0,
+            padding_2: 0,
+            stride_1: 1,
+            stride_2: 1,
+            dilation_1: 1,
+            dilation_2: 2,
+            weight_groups: 1,
+            offset_groups: 1,
+            height: 5,
+            width: 5,
+        };
+
+        test.assert_output(Tensor::<TestBackend, 4>::from([[
+            [[0.6162], [0.7611], [0.4666]],
+            [[1.8578], [2.2684], [1.6208]],
+            [[3.0994], [3.7757], [2.7749]],
+        ]]))
+    }
+
+    #[test]
+    fn test_deform_conv2d_different_width() {
+        let test = DeformConv2dTestCase {
+            batch_size: 1,
+            channels_in: 2,
+            channels_out: 3,
+            kernel_size_1: 3,
+            kernel_size_2: 3,
+            padding_1: 0,
+            padding_2: 0,
+            stride_1: 1,
+            stride_2: 1,
+            dilation_1: 1,
+            dilation_2: 1,
+            weight_groups: 1,
+            offset_groups: 1,
+            height: 6,
+            width: 4,
+        };
+
+        test.assert_output(Tensor::<TestBackend, 4>::from([[
+            [
+                [0.8909, 0.6016],
+                [1.0697, 0.7186],
+                [1.2618, 0.8433],
+                [0.6424, 0.5032],
+            ],
+            [
+                [2.4670, 1.8168],
+                [2.9529, 2.1497],
+                [3.4805, 2.5090],
+                [2.0925, 1.7411],
+            ],
+            [
+                [4.0432, 3.0321],
+                [4.8362, 3.5809],
+                [5.6992, 4.1746],
+                [3.5425, 2.9790],
+            ],
+        ]]))
+    }
+
+    struct DeformConv2dTestCase {
+        batch_size: usize,
+        channels_in: usize,
+        channels_out: usize,
+        kernel_size_1: usize,
+        kernel_size_2: usize,
+        padding_1: usize,
+        padding_2: usize,
+        stride_1: usize,
+        stride_2: usize,
+        dilation_1: usize,
+        dilation_2: usize,
+        weight_groups: usize,
+        offset_groups: usize,
+        height: usize,
+        width: usize,
+    }
+
+    impl DeformConv2dTestCase {
+        fn assert_output(self, y: Tensor<TestBackend, 4>) {
+            let out_height =
+                (self.height + 2 * self.padding_1 - self.dilation_1 * (self.kernel_size_1 - 1) - 1)
+                    / self.stride_1
+                    + 1;
+            let out_width =
+                (self.width + 2 * self.padding_2 - self.dilation_2 * (self.kernel_size_2 - 1) - 1)
+                    / self.stride_2
+                    + 1;
+
+            let shape_x = Shape::new([self.batch_size, self.channels_in, self.height, self.width]);
+            let shape_weight = Shape::new([
+                self.channels_out,
+                self.channels_in / self.weight_groups,
+                self.kernel_size_1,
+                self.kernel_size_2,
+            ]);
+            let shape_offset = Shape::new([
+                self.batch_size,
+                self.kernel_size_1 * self.kernel_size_2 * self.offset_groups * 2,
+                out_height,
+                out_width,
+            ]);
+            let shape_mask = Shape::new([
+                self.batch_size,
+                self.kernel_size_1 * self.kernel_size_2 * self.offset_groups,
+                out_height,
+                out_width,
+            ]);
+            let device = Default::default();
+            let weight = Tensor::<TestBackend, 4>::from(
+                TestTensorInt::arange(0..shape_weight.num_elements() as i64, &device)
+                    .reshape(shape_weight.clone())
+                    .into_data(),
+            )
+            .div_scalar(shape_weight.num_elements() as f32);
+            let bias = Tensor::<TestBackend, 1>::from(
+                TestTensorInt::arange(0..self.channels_out as i64, &device).into_data(),
+            )
+            .div_scalar(self.channels_out as f32);
+            let x = Tensor::<TestBackend, 4>::from(
+                TestTensorInt::arange(0..shape_x.num_elements() as i64, &device)
+                    .reshape(shape_x.clone())
+                    .into_data(),
+            )
+            .div_scalar(shape_x.num_elements() as f32);
+            let offset = Tensor::<TestBackend, 4>::from(
+                TestTensorInt::arange(0..shape_offset.num_elements() as i64, &device)
+                    .reshape(shape_offset.clone())
+                    .into_data(),
+            )
+            .div_scalar(shape_offset.num_elements() as f32);
+            let mask = Tensor::<TestBackend, 4>::from(
+                TestTensorInt::arange(0..shape_mask.num_elements() as i64, &device)
+                    .reshape(shape_mask.clone())
+                    .into_data(),
+            )
+            .div_scalar(shape_mask.num_elements() as f32);
+
+            let output = deform_conv2d(
+                x,
+                offset,
+                weight,
+                Some(mask),
+                Some(bias),
+                DeformConvOptions::new(
+                    [self.stride_1, self.stride_2],
+                    [self.padding_1, self.padding_2],
+                    [self.dilation_1, self.dilation_2],
+                    self.weight_groups,
+                    self.offset_groups,
+                ),
+            );
+
+            y.to_data().assert_approx_eq(&output.into_data(), 3);
+        }
+    }
+}
diff --git a/crates/burn-tensor/src/tests/module/mod.rs b/crates/burn-tensor/src/tests/module/mod.rs
index ab755d0ce..c4c982f70 100644
--- a/crates/burn-tensor/src/tests/module/mod.rs
+++ b/crates/burn-tensor/src/tests/module/mod.rs
@@ -10,6 +10,7 @@ mod conv3d;
 mod conv_transpose1d;
 mod conv_transpose2d;
 mod conv_transpose3d;
+mod deform_conv2d;
 mod forward;
 mod maxpool1d;
 mod maxpool2d;
diff --git a/crates/burn-wgpu/Cargo.toml b/crates/burn-wgpu/Cargo.toml
index b617027fd..4b5f9223a 100644
--- a/crates/burn-wgpu/Cargo.toml
+++ b/crates/burn-wgpu/Cargo.toml
@@ -11,20 +11,22 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-wgpu"
 version.workspace = true
 
 [features]
-default = ["std", "autotune", "fusion", "burn-jit/default", "cubecl/default"]
-fusion = ["burn-fusion", "burn-jit/fusion"]
 autotune = ["burn-jit/autotune"]
-template = ["burn-jit/template", "cubecl/template"]
+default = ["std", "autotune", "fusion", "burn-jit/default", "cubecl/default"]
 doc = ["burn-jit/doc"]
-std = ["burn-jit/std", "cubecl/std"]
+fusion = ["burn-fusion", "burn-jit/fusion"]
 simple-memory-management = ["cubecl/simple-memory-management"]
+std = ["burn-jit/std", "cubecl/std"]
+template = ["burn-jit/template", "cubecl/template"]
 
 [dependencies]
 cubecl = { workspace = true, features = ["wgpu"] }
-burn-jit = { path = "../burn-jit", version = "0.15.0", default-features = false }
-burn-tensor = { path = "../burn-tensor", version = "0.15.0", features = ["cubecl-wgpu"] }
 burn-fusion = { path = "../burn-fusion", version = "0.15.0", optional = true }
+burn-jit = { path = "../burn-jit", version = "0.15.0", default-features = false }
+burn-tensor = { path = "../burn-tensor", version = "0.15.0", features = [
+    "cubecl-wgpu",
+] }
 
 [dev-dependencies]
 burn-jit = { path = "../burn-jit", version = "0.15.0", default-features = false, features = [