From ef421f0ae99f938f3a0873ada4be3b6fba307f3a Mon Sep 17 00:00:00 2001 From: Dilshod Tadjibaev <939125+antimora@users.noreply.github.com> Date: Thu, 13 Jul 2023 16:09:57 -0500 Subject: [PATCH] Add arange with steps op for int tensor (#490) --- burn-tensor/src/tensor/api/int.rs | 34 ++++++++++++++++++++++-- burn-tensor/src/tensor/ops/tensor.rs | 27 +++++++++++++++++-- burn-tensor/src/tests/mod.rs | 1 + burn-tensor/src/tests/ops/arange_step.rs | 27 +++++++++++++++++++ burn-tensor/src/tests/ops/mod.rs | 1 + 5 files changed, 86 insertions(+), 4 deletions(-) create mode 100644 burn-tensor/src/tests/ops/arange_step.rs diff --git a/burn-tensor/src/tensor/api/int.rs b/burn-tensor/src/tensor/api/int.rs index b973d9310..308dfd724 100644 --- a/burn-tensor/src/tensor/api/int.rs +++ b/burn-tensor/src/tensor/api/int.rs @@ -5,14 +5,44 @@ impl Tensor where B: Backend, { - /// Returns a new integer tensor on the default device which values are generated from the given range. + /// Returns a new integer tensor on the default device. + /// + /// # Arguments + /// + /// * `range` - The range of values to generate. pub fn arange(range: Range) -> Self { Tensor::new(B::arange(range, &B::Device::default())) } - /// Returns a new integer tensor on the specified device which values are generated from the given range. + + /// Returns a new integer tensor on the default device. + /// + /// # Arguments + /// + /// * `range` - The range of values to generate. + /// * `step` - The step between each value. + pub fn arange_step(range: Range, step: usize) -> Self { + Tensor::new(B::arange_step(range, step, &B::Device::default())) + } + + /// Returns a new integer tensor on the specified device. + /// + /// # Arguments + /// + /// * `range` - The range of values to generate. + /// * `device` - The device to create the tensor on. pub fn arange_device(range: Range, device: &B::Device) -> Self { Tensor::new(B::arange(range, device)) } + + /// Returns a new integer tensor on the specified device. + /// + /// # Arguments + /// + /// * `range` - The range of values to generate. + /// * `step` - The step between each value. + pub fn arange_step_device(range: Range, step: usize, device: &B::Device) -> Self { + Tensor::new(B::arange_step(range, step, device)) + } } impl Tensor diff --git a/burn-tensor/src/tensor/ops/tensor.rs b/burn-tensor/src/tensor/ops/tensor.rs index 2321cc242..7ef21830a 100644 --- a/burn-tensor/src/tensor/ops/tensor.rs +++ b/burn-tensor/src/tensor/ops/tensor.rs @@ -136,12 +136,35 @@ pub trait TensorOps { /// # Returns /// /// The tensor with the given values. + /// + /// # Remarks + /// + /// Uses `arange_step` with a step size of 1 under the hood. fn arange(range: Range, device: &B::Device) -> B::IntTensorPrimitive<1> { - let shape = Shape::new([range.end - range.start]); + Self::arange_step(range, 1, device) + } + + /// Creates a new tensor with values from the given range with the given step size. + /// + /// # Arguments + /// + /// * `range` - The range of values. + /// * `step` - The step size. + /// * `device` - The device to create the tensor on. + /// + /// # Returns + /// + /// The tensor with the given values. + fn arange_step( + range: Range, + step: usize, + device: &B::Device, + ) -> B::IntTensorPrimitive<1> { let value = range - .into_iter() + .step_by(step) .map(|i| (i as i64).elem()) .collect::>(); + let shape = Shape::new([value.len()]); let data = Data::new(value, shape); B::int_from_data(data, device) } diff --git a/burn-tensor/src/tests/mod.rs b/burn-tensor/src/tests/mod.rs index ca4e4e79c..20d26513a 100644 --- a/burn-tensor/src/tests/mod.rs +++ b/burn-tensor/src/tests/mod.rs @@ -28,6 +28,7 @@ macro_rules! testgen_all { burn_tensor::testgen_add!(); burn_tensor::testgen_aggregation!(); burn_tensor::testgen_arange!(); + burn_tensor::testgen_arange_step!(); burn_tensor::testgen_arg!(); burn_tensor::testgen_cat!(); burn_tensor::testgen_cos!(); diff --git a/burn-tensor/src/tests/ops/arange_step.rs b/burn-tensor/src/tests/ops/arange_step.rs new file mode 100644 index 000000000..c75c5b865 --- /dev/null +++ b/burn-tensor/src/tests/ops/arange_step.rs @@ -0,0 +1,27 @@ +#[burn_tensor_testgen::testgen(arange_step)] +mod tests { + use super::*; + use burn_tensor::{Data, Int, Tensor}; + + #[test] + fn test_arange_step() { + // Test correct sequence of numbers when the range is 0..9 and the step is 1 + let tensor = Tensor::::arange_step(0..9, 1); + assert_eq!(tensor.into_data(), Data::from([0, 1, 2, 3, 4, 5, 6, 7, 8])); + + // Test correct sequence of numbers when the range is 0..3 and the step is 2 + let tensor = Tensor::::arange_step(0..3, 2); + assert_eq!(tensor.into_data(), Data::from([0, 2])); + + // Test correct sequence of numbers when the range is 0..2 and the step is 5 + let tensor = Tensor::::arange_step(0..2, 5); + assert_eq!(tensor.into_data(), Data::from([0])); + } + + #[test] + #[should_panic] + fn test_arange_step_panic() { + // Test that arange_step panics when the step is 0 + let _tensor = Tensor::::arange_step(0..3, 0); + } +} diff --git a/burn-tensor/src/tests/ops/mod.rs b/burn-tensor/src/tests/ops/mod.rs index d5c01d837..81163657c 100644 --- a/burn-tensor/src/tests/ops/mod.rs +++ b/burn-tensor/src/tests/ops/mod.rs @@ -1,6 +1,7 @@ mod add; mod aggregation; mod arange; +mod arange_step; mod arg; mod cat; mod cos;