diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md
index ea34c5c65..b46287e8d 100644
--- a/burn-book/src/building-blocks/tensor.md
+++ b/burn-book/src/building-blocks/tensor.md
@@ -144,11 +144,11 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
 | `tensor.all_dim(dim)` | `tensor.all(dim)` |
 | `tensor.any()` | `tensor.any()` |
 | `tensor.any_dim(dim)` | `tensor.any(dim)` |
-| `tensor.expand(shape)` | `tensor.expand(shape)` |
 | `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
 | `tensor.device()` | `tensor.device` |
 | `tensor.dims()` | `tensor.size()` |
 | `tensor.equal(other)` | `x == y` |
+| `tensor.expand(shape)` | `tensor.expand(shape)` |
 | `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
 | `tensor.flip(axes)` | `tensor.flip(axes)` |
 | `tensor.into_data()` | N/A |
@@ -185,6 +185,8 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
 | `tensor.all_close(other, atol, rtol)` | `torch.allclose(tensor, other, atol, rtol)` |
 | `tensor.argmax(dim)` | `tensor.argmax(dim)` |
 | `tensor.argmin(dim)` | `tensor.argmin(dim)` |
+| `tensor.argsort(dim)` | `tensor.argsort(dim)` |
+| `tensor.argsort_descending(dim)` | `tensor.argsort(dim, descending=True)` |
 | `tensor.bool()` | `tensor.bool()` |
 | `tensor.clamp(min, max)` | `torch.clamp(tensor, min=min, max=max)` |
 | `tensor.clamp_max(max)` | `torch.clamp(tensor, max=max)` |
@@ -218,6 +220,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
 | `tensor.mul_scalar(scalar)` or `tensor * scalar` | `tensor * scalar` |
 | `tensor.neg()` or `-tensor` | `-tensor` |
 | `tensor.not_equal_elem(scalar)` | `tensor.ne(scalar)` |
+| `tensor.pad(pads, value)` | `torch.nn.functional.pad(input, pad, value)` |
 | `tensor.powf(other)` or `tensor.powi(intother)` | `tensor.pow(other)` |
 | `tensor.powf_scalar(scalar)` or `tensor.powi_scalar(intscalar)` | `tensor.pow(scalar)` |
 | `tensor.prod()` | `tensor.prod()` |
@@ -226,20 +229,18 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
 | `tensor.select(dim, indices)` | `tensor.index_select(dim, indices)` |
 | `tensor.select_assign(dim, indices, values)` | N/A |
 | `tensor.sign()` | `tensor.sign()` |
+| `tensor.sort(dim)` | `tensor.sort(dim).values` |
+| `tensor.sort_descending(dim)` | `tensor.sort(dim, descending=True).values` |
+| `tensor.sort_descending_with_indices(dim)` | `tensor.sort(dim, descending=True)` |
+| `tensor.sort_with_indices(dim)` | `tensor.sort(dim)` |
 | `tensor.sub(other)` or `tensor - other` | `tensor - other` |
 | `tensor.sub_scalar(scalar)` or `tensor - scalar` | `tensor - scalar` |
 | `tensor.sum()` | `tensor.sum()` |
 | `tensor.sum_dim(dim)` | `tensor.sum(dim, keepdim=True)` |
-| `tensor.tril(diagonal)` | `torch.tril(tensor, diagonal)` |
-| `tensor.triu(diagonal)` | `torch.triu(tensor, diagonal)` |
-| `tensor.sort(dim)` | `tensor.sort(dim).values` |
-| `tensor.sort_descending(dim)` | `tensor.sort(dim, descending=True).values` |
-| `tensor.sort_with_indices(dim)` | `tensor.sort(dim)` |
-| `tensor.sort_descending_with_indices(dim)` | `tensor.sort(dim, descending=True)` |
-| `tensor.argsort(dim)` | `tensor.argsort(dim)` |
-| `tensor.argsort_descending(dim)` | `tensor.argsort(dim, descending=True)` |
 | `tensor.topk(k, dim)` | `tensor.topk(k, dim).values` |
 | `tensor.topk_with_indices(k, dim)` | `tensor.topk(k, dim)` |
+| `tensor.tril(diagonal)` | `torch.tril(tensor, diagonal)` |
+| `tensor.triu(diagonal)` | `torch.triu(tensor, diagonal)` |
 
 ### Float Operations
 
diff --git a/crates/burn-core/src/nn/loss/huber.rs b/crates/burn-core/src/nn/loss/huber.rs
index 055910ed4..a2ae13dd7 100644
--- a/crates/burn-core/src/nn/loss/huber.rs
+++ b/crates/burn-core/src/nn/loss/huber.rs
@@ -150,7 +150,7 @@ mod tests {
         loss.into_data().assert_approx_eq(&Data::from([0.284]), 7);
         loss_sum
             .into_data()
-            .assert_approx_eq(&Data::from([1.42]), 7);
+            .assert_approx_eq(&Data::from([1.42]), 5);
     }
 
     #[cfg(feature = "std")]
diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs
index 11b1c0754..87c2cc10d 100644
--- a/crates/burn-tensor/src/tensor/api/numeric.rs
+++ b/crates/burn-tensor/src/tensor/api/numeric.rs
@@ -1,3 +1,5 @@
+use alloc::vec::Vec;
+
 use crate::alloc::borrow::ToOwned;
 
 use crate::{
@@ -736,6 +738,49 @@ where
             indices.select(dim, k_indices),
         )
     }
+
+    /// Pad the tensor with the given value on the last two dimensions.
+    ///
+    /// # Arguments
+    ///
+    /// * `padding` - A tuple of four integers representing the padding on the left, right, top, and bottom.
+    /// * `value` - The value to pad the tensor with.
+    ///
+    /// # Returns
+    ///
+    /// A new tensor with the given padding.
+    pub fn pad(self, padding: (usize, usize, usize, usize), value: K::Elem) -> Tensor<B, D, K> {
+        let (left, right, top, bottom) = padding;
+
+        let mut padded_dims: [usize; D] = self.dims();
+
+        // Update the last two dimensions with padding
+        padded_dims[D - 2] += top + bottom;
+        padded_dims[D - 1] += left + right;
+
+        // Create the ranges for the padded tensor
+        let ranges: [core::ops::Range<usize>; D] = padded_dims
+            .iter()
+            .enumerate()
+            .map(|(i, &dim)| {
+                if i == D - 2 {
+                    top..dim - bottom
+                } else if i == D - 1 {
+                    left..dim - right
+                } else {
+                    0..dim
+                }
+            })
+            .collect::<Vec<core::ops::Range<usize>>>()
+            .try_into()
+            .unwrap();
+
+        // Create the padded tensor
+        let padded_tensor = Tensor::full(padded_dims, value, &self.device());
+
+        // Assign the original tensor data to the appropriate slice of the padded tensor
+        padded_tensor.slice_assign(ranges, self)
+    }
 }
 
 impl Tensor
diff --git a/crates/burn-tensor/src/tests/mod.rs b/crates/burn-tensor/src/tests/mod.rs
index 69580aa6b..feed0f75b 100644
--- a/crates/burn-tensor/src/tests/mod.rs
+++ b/crates/burn-tensor/src/tests/mod.rs
@@ -103,5 +103,8 @@ macro_rules! testgen_all {
 
         // test clone invariance
         burn_tensor::testgen_clone_invariance!();
+
+        // test padding
+        burn_tensor::testgen_padding!();
     };
 }
diff --git a/crates/burn-tensor/src/tests/ops/full.rs b/crates/burn-tensor/src/tests/ops/full.rs
index 562dcd03b..3f8245f67 100644
--- a/crates/burn-tensor/src/tests/ops/full.rs
+++ b/crates/burn-tensor/src/tests/ops/full.rs
@@ -1,7 +1,7 @@
 #[burn_tensor_testgen::testgen(full)]
 mod tests {
     use super::*;
-    use burn_tensor::{Data, Int, Shape, Tensor};
+    use burn_tensor::{Bool, Data, Int, Shape, Tensor};
 
     #[test]
     fn test_data_full() {
@@ -22,5 +22,15 @@ mod tests {
         let int_tensor = Tensor::<TestBackend, 2, Int>::full([2, 2], 2, &device);
         let data_expected = Data::from([[2, 2], [2, 2]]);
         assert_eq!(data_expected, int_tensor.into_data());
+
+        // TODO enable after adding support for bool
+        // // Test full with bool
+        // let bool_tensor = Tensor::<TestBackend, 2, Bool>::full([2, 2], true, &device);
+        // let data_expected = Data::from([[true, true], [true, true]]);
+        // assert_eq!(data_expected, bool_tensor.into_data());
+
+        // let bool_tensor = Tensor::<TestBackend, 2, Bool>::full([2, 2], false, &device);
+        // let data_expected = Data::from([[false, false], [false, false]]);
+        // assert_eq!(data_expected, bool_tensor.into_data());
     }
 }
diff --git a/crates/burn-tensor/src/tests/ops/mod.rs b/crates/burn-tensor/src/tests/ops/mod.rs
index 5a475b786..f91423b0a 100644
--- a/crates/burn-tensor/src/tests/ops/mod.rs
+++ b/crates/burn-tensor/src/tests/ops/mod.rs
@@ -35,6 +35,7 @@ mod mul;
 mod narrow;
 mod neg;
 mod one_hot;
+mod padding;
 mod permute;
 mod powf;
 mod powf_scalar;
diff --git a/crates/burn-tensor/src/tests/ops/padding.rs b/crates/burn-tensor/src/tests/ops/padding.rs
new file mode 100644
index 000000000..dbeba7d0e
--- /dev/null
+++ b/crates/burn-tensor/src/tests/ops/padding.rs
@@ -0,0 +1,98 @@
+#[burn_tensor_testgen::testgen(padding)]
+mod tests {
+    use super::*;
+    use burn_tensor::{Data, Int, Numeric, Shape, Tensor};
+
+    #[test]
+    fn padding_2d_test() {
+        let unpadded_floats: [[f32; 3]; 2] = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]];
+        let tensor = TestTensor::from(unpadded_floats);
+
+        let padded_tensor = tensor.pad((2, 2, 2, 2), 1.1);
+
+        let padded_primitive_data_expected = [
+            [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1],
+            [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1],
+            [1.1, 1.1, 0.0, 1.0, 2.0, 1.1, 1.1],
+            [1.1, 1.1, 3.0, 4.0, 5.0, 1.1, 1.1],
+            [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1],
+            [1.1, 1.1, 1.1, 1.1, 1.1, 1.1, 1.1],
+        ];
+
+        let padded_data_expected = Data::from(padded_primitive_data_expected);
+        let padded_data_actual = padded_tensor.into_data();
+        assert_eq!(padded_data_expected, padded_data_actual);
+    }
+
+    #[test]
+    fn padding_4d_test() {
+        let unpadded_floats = [[[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]]];
+        let tensor = TestTensor::from(unpadded_floats);
+
+        let padded_tensor = tensor.pad((2, 2, 2, 2), 1.1);
+
+        let padded_primitive_data_expected = [[[
+            [1.1, 1.1, 1.1, 1.1, 1.1, 1.1],
+            [1.1, 1.1, 1.1, 1.1, 1.1, 1.1],
+            [1.1, 1.1, 0.0, 1.0, 1.1, 1.1],
+            [1.1, 1.1, 2.0, 3.0, 1.1, 1.1],
+            [1.1, 1.1, 4.0, 5.0, 1.1, 1.1],
+            [1.1, 1.1, 1.1, 1.1, 1.1, 1.1],
+            [1.1, 1.1, 1.1, 1.1, 1.1, 1.1],
+        ]]];
+
+        let padded_data_expected = Data::from(padded_primitive_data_expected);
+        let padded_data_actual = padded_tensor.into_data();
+        assert_eq!(padded_data_expected, padded_data_actual);
+    }
+
+    #[test]
+    fn padding_asymmetric_test() {
+        let unpadded_floats = [[[[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]]];
+        let tensor = TestTensor::from(unpadded_floats);
+
+        let padded_tensor = tensor.pad((2, 1, 4, 3), 1.1);
+
+        let padded_primitive_data_expected = [[[
+            [1.1, 1.1, 1.1, 1.1, 1.1],
+            [1.1, 1.1, 1.1, 1.1, 1.1],
+            [1.1, 1.1, 1.1, 1.1, 1.1],
+            [1.1, 1.1, 1.1, 1.1, 1.1],
+            [1.1, 1.1, 0.0, 1.0, 1.1],
+            [1.1, 1.1, 2.0, 3.0, 1.1],
+            [1.1, 1.1, 4.0, 5.0, 1.1],
+            [1.1, 1.1, 1.1, 1.1, 1.1],
+            [1.1, 1.1, 1.1, 1.1, 1.1],
+            [1.1, 1.1, 1.1, 1.1, 1.1],
+        ]]];
+
+        let padded_data_expected = Data::from(padded_primitive_data_expected);
+        let padded_data_actual = padded_tensor.into_data();
+        assert_eq!(padded_data_expected, padded_data_actual);
+    }
+
+    #[test]
+    fn padding_asymmetric_integer_test() {
+        let unpadded_ints = [[[[0, 1], [2, 3], [4, 5]]]];
+
+        let tensor = TestTensorInt::from(unpadded_ints);
+        let padded_tensor = tensor.pad((2, 1, 4, 3), 6);
+
+        let padded_primitive_data_expected = [[[
+            [6, 6, 6, 6, 6],
+            [6, 6, 6, 6, 6],
+            [6, 6, 6, 6, 6],
+            [6, 6, 6, 6, 6],
+            [6, 6, 0, 1, 6],
+            [6, 6, 2, 3, 6],
+            [6, 6, 4, 5, 6],
+            [6, 6, 6, 6, 6],
+            [6, 6, 6, 6, 6],
+            [6, 6, 6, 6, 6],
+        ]]];
+
+        let padded_data_expected = Data::from(padded_primitive_data_expected);
+        let padded_data_actual = padded_tensor.into_data();
+        assert_eq!(padded_data_expected, padded_data_actual);
+    }
+}
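
For reviewers, here is a minimal usage sketch of the new `pad` method outside the test harness; it is not part of the patch. It assumes the `burn-ndarray` crate as a concrete backend and a 2D `Float` tensor, and the shapes and fill values are purely illustrative. The padding tuple is interpreted as `(left, right, top, bottom)` and applies to the last two dimensions, mirroring the tests above.

```rust
// Illustrative sketch only: assumes the NdArray backend from burn-ndarray.
use burn_ndarray::NdArray;
use burn_tensor::Tensor;

fn main() {
    let device = Default::default();

    // 2x3 float tensor filled with ones.
    let tensor = Tensor::<NdArray, 2>::full([2, 3], 1.0, &device);

    // Pad one column on the left/right and one row on the top/bottom with 0.0.
    // The tuple order is (left, right, top, bottom), so the result is 4x5.
    let padded = tensor.pad((1, 1, 1, 1), 0.0);
    assert_eq!(padded.dims(), [4, 5]);
}
```

Because `pad` is built from `Tensor::full` plus `slice_assign` in the generic numeric impl, it works for both `Float` and `Int` kinds without backend-specific kernels, which is what the integer test above exercises.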