Provide Tensor Padding Helpers #960 (#1097)

* Initial padding approach

Create a padding implementation for the last two dimensions of `Float` and `Int` tensors.

Create a `PadMode` enum, allowing `Constant` padding.

Create a `Padding` struct with uniform, asymmetric, height, and width constructors.

Create tests for the padding implementation.

* Update padding.rs

Remove an unneeded import.

* Update from Merge

Use the crate's `Element` trait.

Swap from the old `from_data()` to the new `from_data_devauto()`.

* Formatting Changes

Formatting changes from `cargo fmt --all`.

* Additional Format Change

One more formatting change that `cargo fmt` missed the first time.

* Changes to Example

Modify the example to ensure it works.

* Modify naming

Better names for the implementation and input variables.

* Modify API

- Change `Padding` to `PadSize`.
- Integrate the padding value into `PadMode`.
- Update tests and examples.

* Comments and print

Improve comments and naming, and remove a `println!`.

* Pad Fixes

Moved `pad` to the numeric implementation.

Simplified the `PadMode` `Element` handling.

Updated tensor creation.

Fixed the doc example.

* Fix test location

* Simplified pad API

* Fix failing unit tests

* Remove `bool_full`

* Rename `pads` to `padding`
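
Taken together, these commits leave a single `pad` method on numeric tensors, taking `(left, right, top, bottom)` amounts and a fill value. A minimal sketch of the final call shape (the `burn-ndarray` backend and the concrete values here are illustrative assumptions, not part of this change):

```rust
use burn_ndarray::NdArray;
use burn_tensor::Tensor;

fn main() {
    // A 2x3 float tensor, padded by (left, right, top, bottom) = (1, 2, 1, 1)
    // with the constant value 0.0, giving a 4x6 result.
    let tensor = Tensor::<NdArray, 2>::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
    let padded = tensor.pad((1, 2, 1, 1), 0.0);
    assert_eq!(padded.dims(), [4, 6]);
}
```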

---------

Co-authored-by: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com>
Authored by jcmullwh on 2024-03-27 13:46:55 -04:00; committed by GitHub
parent f6f6b5c0fa
commit 626457e1c6
7 changed files with 165 additions and 11 deletions

---

@@ -144,11 +144,11 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
| `tensor.all_dim(dim)` | `tensor.all(dim)` |
| `tensor.any()` | `tensor.any()` |
| `tensor.any_dim(dim)` | `tensor.any(dim)` |
| `tensor.expand(shape)` | `tensor.expand(shape)` |
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
| `tensor.device()` | `tensor.device` |
| `tensor.dims()` | `tensor.size()` |
| `tensor.equal(other)` | `x == y` |
| `tensor.expand(shape)` | `tensor.expand(shape)` |
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
| `tensor.flip(axes)` | `tensor.flip(axes)` |
| `tensor.into_data()` | N/A |
@@ -185,6 +185,8 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.all_close(other, atol, rtol)` | `torch.allclose(tensor, other, atol, rtol)` |
| `tensor.argmax(dim)` | `tensor.argmax(dim)` |
| `tensor.argmin(dim)` | `tensor.argmin(dim)` |
| `tensor.argsort(dim)` | `tensor.argsort(dim)` |
| `tensor.argsort_descending(dim)` | `tensor.argsort(dim, descending=True)` |
| `tensor.bool()` | `tensor.bool()` |
| `tensor.clamp(min, max)` | `torch.clamp(tensor, min=min, max=max)` |
| `tensor.clamp_max(max)` | `torch.clamp(tensor, max=max)` |
@@ -218,6 +220,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.mul_scalar(scalar)` or `tensor * scalar` | `tensor * scalar` |
| `tensor.neg()` or `-tensor` | `-tensor` |
| `tensor.not_equal_elem(scalar)` | `tensor.ne(scalar)` |
| `tensor.pad(pads, value)` | `torch.nn.functional.pad(input, pad, value)` |
| `tensor.powf(other)` or `tensor.powi(intother)` | `tensor.pow(other)` |
| `tensor.powf_scalar(scalar)` or `tensor.powi_scalar(intscalar)` | `tensor.pow(scalar)` |
| `tensor.prod()` | `tensor.prod()` |
@@ -226,20 +229,18 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
| `tensor.select(dim, indices)` | `tensor.index_select(dim, indices)` |
| `tensor.select_assign(dim, indices, values)` | N/A |
| `tensor.sign()` | `tensor.sign()` |
| `tensor.sort(dim)` | `tensor.sort(dim).values` |
| `tensor.sort_descending(dim)` | `tensor.sort(dim, descending=True).values` |
| `tensor.sort_descending_with_indices(dim)` | `tensor.sort(dim, descending=True)` |
| `tensor.sort_with_indices(dim)` | `tensor.sort(dim)` |
| `tensor.sub(other)` or `tensor - other` | `tensor - other` |
| `tensor.sub_scalar(scalar)` or `tensor - scalar` | `tensor - scalar` |
| `tensor.sum()` | `tensor.sum()` |
| `tensor.sum_dim(dim)` | `tensor.sum(dim, keepdim=True)` |
| `tensor.tril(diagonal)` | `torch.tril(tensor, diagonal)` |
| `tensor.triu(diagonal)` | `torch.triu(tensor, diagonal)` |
| `tensor.sort(dim)` | `tensor.sort(dim).values` |
| `tensor.sort_descending(dim)` | `tensor.sort(dim, descending=True).values` |
| `tensor.sort_with_indices(dim)` | `tensor.sort(dim)` |
| `tensor.sort_descending_with_indices(dim)` | `tensor.sort(dim, descending=True)` |
| `tensor.argsort(dim)` | `tensor.argsort(dim)` |
| `tensor.argsort_descending(dim)` | `tensor.argsort(dim, descending=True)` |
| `tensor.topk(k, dim)` | `tensor.topk(k, dim).values` |
| `tensor.topk_with_indices(k, dim)` | `tensor.topk(k, dim)` |
| `tensor.tril(diagonal)` | `torch.tril(tensor, diagonal)` |
| `tensor.triu(diagonal)` | `torch.triu(tensor, diagonal)` |
### Float Operations

---

@@ -150,7 +150,7 @@ mod tests {
        loss.into_data().assert_approx_eq(&Data::from([0.284]), 7);

        loss_sum
            .into_data()
            .assert_approx_eq(&Data::from([1.42]), 7);
            .assert_approx_eq(&Data::from([1.42]), 5);
    }

    #[cfg(feature = "std")]
---

@@ -1,3 +1,5 @@
use alloc::vec::Vec;
use crate::alloc::borrow::ToOwned;
use crate::{
@@ -736,6 +738,49 @@ where
            indices.select(dim, k_indices),
        )
    }

    /// Pad the tensor with the given value on the last two dimensions.
    ///
    /// # Arguments
    ///
    /// * `padding` - A tuple of four integers representing the padding on the left, right, top, and bottom.
    /// * `value` - The value to pad the tensor with.
    ///
    /// # Returns
    ///
    /// A new tensor with the given padding.
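    ///
    /// # Example
    ///
    /// A minimal sketch of the call shape, assuming an array `From` impl and
    /// the backend's default device (the values are illustrative):
    ///
    /// ```rust,ignore
    /// let tensor = Tensor::<B, 2>::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
    /// let padded = tensor.pad((1, 1, 1, 1), 0.0);
    /// assert_eq!(padded.dims(), [4, 5]); // one extra row and column per side
    /// ```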
    pub fn pad(self, padding: (usize, usize, usize, usize), value: K::Elem) -> Tensor<B, D, K> {
        let (left, right, top, bottom) = padding;

        let mut padded_dims: [usize; D] = self.dims();

        // Update the last two dimensions with padding
        padded_dims[D - 2] += top + bottom;
        padded_dims[D - 1] += left + right;

        // Create the ranges for the padded tensor
        let ranges: [core::ops::Range<usize>; D] = padded_dims
            .iter()
            .enumerate()
            .map(|(i, &dim)| {
                if i == D - 2 {
                    top..dim - bottom
                } else if i == D - 1 {
                    left..dim - right
                } else {
                    0..dim
                }
            })
            .collect::<Vec<core::ops::Range<usize>>>()
            .try_into()
            .unwrap();

        // Create the padded tensor
        let padded_tensor = Tensor::full(padded_dims, value, &self.device());

        // Assign the original tensor data to the appropriate slice of the padded tensor
        padded_tensor.slice_assign(ranges, self)
    }
}
impl<B, K> Tensor<B, 2, K>
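
To make the range arithmetic above concrete, here is a small self-contained sketch (plain Rust, no burn types; the sizes are hypothetical) tracing the computation for a 2x3 tensor padded with `(left, right, top, bottom) = (1, 2, 1, 1)`:

```rust
fn main() {
    const D: usize = 2;
    let (left, right, top, bottom) = (1usize, 2, 1, 1);

    // Start from the unpadded dims of a 2x3 tensor.
    let mut padded_dims: [usize; D] = [2, 3];
    padded_dims[D - 2] += top + bottom; // rows: 2 + 1 + 1 = 4
    padded_dims[D - 1] += left + right; // cols: 3 + 1 + 2 = 6

    // The original data occupies the interior slice of the padded tensor.
    let ranges: Vec<core::ops::Range<usize>> = padded_dims
        .iter()
        .enumerate()
        .map(|(i, &dim)| {
            if i == D - 2 {
                top..dim - bottom // rows 1..3
            } else if i == D - 1 {
                left..dim - right // cols 1..4
            } else {
                0..dim
            }
        })
        .collect();

    assert_eq!(padded_dims, [4, 6]);
    assert_eq!(ranges, vec![1..3, 1..4]);
}
```

`slice_assign` then writes the original tensor into that interior slice, leaving `value` everywhere else.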

---

@@ -103,5 +103,8 @@ macro_rules! testgen_all {
        // test clone invariance
        burn_tensor::testgen_clone_invariance!();

        // test padding
        burn_tensor::testgen_padding!();
    };
}

---

@@ -1,7 +1,7 @@
#[burn_tensor_testgen::testgen(full)]
mod tests {
    use super::*;
    use burn_tensor::{Data, Int, Shape, Tensor};
    use burn_tensor::{Bool, Data, Int, Shape, Tensor};

    #[test]
    fn test_data_full() {
@@ -22,5 +22,15 @@ mod tests {
        let int_tensor = Tensor::<TestBackend, 2, Int>::full([2, 2], 2, &device);
        let data_expected = Data::from([[2, 2], [2, 2]]);
        assert_eq!(data_expected, int_tensor.into_data());

        // TODO enable after adding support for bool
        // // Test full with bool
        // let bool_tensor = Tensor::<TestBackend, 2, Bool>::full([2, 2], true, &device);
        // let data_expected = Data::from([[true, true], [true, true]]);
        // assert_eq!(data_expected, bool_tensor.into_data());

        // let bool_tensor = Tensor::<TestBackend, 2, Bool>::full([2, 2], false, &device);
        // let data_expected = Data::from([[false, false], [false, false]]);
        // assert_eq!(data_expected, bool_tensor.into_data());
    }
}

---

@@ -35,6 +35,7 @@ mod mul;
mod narrow;
mod neg;
mod one_hot;
mod padding;
mod permute;
mod powf;
mod powf_scalar;

---

@@ -0,0 +1,94 @@
#[burn_tensor_testgen::testgen(padding)]
mod tests {
    use super::*;
    use burn_tensor::{Data, Int, Numeric, Shape, Tensor};

    #[test]
    fn padding_2d_test() {
        let unpadded_floats: [[f32; 3]; 2] = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]];
        let tensor = TestTensor::from(unpadded_floats);

        let padded_tensor = tensor.pad((2, 2, 2, 2), 1.1);

        let padded_primitive_data_expected = [
            [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1],
            [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1],
            [1.1, 1.1, 0.0, 1.0, 2.0, 1.1, 1.1],
            [1.1, 1.1, 3.0, 4.0, 5.0, 1.1, 1.1],
            [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1],
            [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1],
        ];
        let padded_data_expected = Data::from(padded_primitive_data_expected);
        let padded_data_actual = padded_tensor.into_data();
        assert_eq!(padded_data_expected, padded_data_actual);
    }

    #[test]
    fn padding_4d_test() {
        let unpadded_floats = [[[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]]];
        let tensor = TestTensor::from(unpadded_floats);

        let padded_tensor = tensor.pad((2, 2, 2, 2), 1.1);

        let padded_primitive_data_expected = [[[
            [1.1, 1.1, 1.1, 1.1, 1.1, 1.1],
            [1.1, 1.1, 1.1, 1.1, 1.1, 1.1],
            [1.1, 1.1, 0.0, 1.0, 1.1, 1.1],
            [1.1, 1.1, 2.0, 3.0, 1.1, 1.1],
            [1.1, 1.1, 4.0, 5.0, 1.1, 1.1],
            [1.1, 1.1, 1.1, 1.1, 1.1, 1.1],
            [1.1, 1.1, 1.1, 1.1, 1.1, 1.1],
        ]]];
        let padded_data_expected = Data::from(padded_primitive_data_expected);
        let padded_data_actual = padded_tensor.into_data();
        assert_eq!(padded_data_expected, padded_data_actual);
    }

    #[test]
    fn padding_asymmetric_test() {
        let unpadded_floats = [[[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]]];
        let tensor = TestTensor::from(unpadded_floats);

        let padded_tensor = tensor.pad((2, 1, 4, 3), 1.1);

        let padded_primitive_data_expected = [[[
            [1.1, 1.1, 1.1, 1.1, 1.1],
            [1.1, 1.1, 1.1, 1.1, 1.1],
            [1.1, 1.1, 1.1, 1.1, 1.1],
            [1.1, 1.1, 1.1, 1.1, 1.1],
            [1.1, 1.1, 0.0, 1.0, 1.1],
            [1.1, 1.1, 2.0, 3.0, 1.1],
            [1.1, 1.1, 4.0, 5.0, 1.1],
            [1.1, 1.1, 1.1, 1.1, 1.1],
            [1.1, 1.1, 1.1, 1.1, 1.1],
            [1.1, 1.1, 1.1, 1.1, 1.1],
        ]]];
        let padded_data_expected = Data::from(padded_primitive_data_expected);
        let padded_data_actual = padded_tensor.into_data();
        assert_eq!(padded_data_expected, padded_data_actual);
    }

    #[test]
    fn padding_asymmetric_integer_test() {
        let unpadded_ints = [[[[0, 1], [2, 3], [4, 5]]]];
        let tensor = TestTensorInt::from(unpadded_ints);

        let padded_tensor = tensor.pad((2, 1, 4, 3), 6);

        let padded_primitive_data_expected = [[[
            [6, 6, 6, 6, 6],
            [6, 6, 6, 6, 6],
            [6, 6, 6, 6, 6],
            [6, 6, 6, 6, 6],
            [6, 6, 0, 1, 6],
            [6, 6, 2, 3, 6],
            [6, 6, 4, 5, 6],
            [6, 6, 6, 6, 6],
            [6, 6, 6, 6, 6],
            [6, 6, 6, 6, 6],
        ]]];
        let padded_data_expected = Data::from(padded_primitive_data_expected);
        let padded_data_actual = padded_tensor.into_data();
        assert_eq!(padded_data_expected, padded_data_actual);
    }
}