Add arange with steps op for int tensor (#490)

This commit is contained in:
Dilshod Tadjibaev 2023-07-13 16:09:57 -05:00 committed by GitHub
parent 53c088209d
commit ef421f0ae9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 86 additions and 4 deletions

View File

@ -5,14 +5,44 @@ impl<B> Tensor<B, 1, Int>
where
B: Backend,
{
/// Returns a new integer tensor on the default device which values are generated from the given range.
/// Returns a new integer tensor on the default device.
///
/// # Arguments
///
/// * `range` - The range of values to generate.
pub fn arange(range: Range<usize>) -> Self {
Tensor::new(B::arange(range, &B::Device::default()))
}
/// Returns a new integer tensor on the specified device which values are generated from the given range.
/// Returns a new integer tensor on the default device.
///
/// # Arguments
///
/// * `range` - The range of values to generate.
/// * `step` - The step between each value.
pub fn arange_step(range: Range<usize>, step: usize) -> Self {
Tensor::new(B::arange_step(range, step, &B::Device::default()))
}
/// Returns a new integer tensor on the specified device.
///
/// # Arguments
///
/// * `range` - The range of values to generate.
/// * `device` - The device to create the tensor on.
pub fn arange_device(range: Range<usize>, device: &B::Device) -> Self {
Tensor::new(B::arange(range, device))
}
/// Returns a new integer tensor on the specified device.
///
/// # Arguments
///
/// * `range` - The range of values to generate.
/// * `step` - The step between each value.
pub fn arange_step_device(range: Range<usize>, step: usize, device: &B::Device) -> Self {
Tensor::new(B::arange_step(range, step, device))
}
}
impl<const D: usize, B> Tensor<B, D, Int>

View File

@ -136,12 +136,35 @@ pub trait TensorOps<B: Backend> {
/// # Returns
///
/// The tensor with the given values.
///
/// # Remarks
///
/// Uses `arange_step` with a step size of 1 under the hood.
fn arange(range: Range<usize>, device: &B::Device) -> B::IntTensorPrimitive<1> {
let shape = Shape::new([range.end - range.start]);
Self::arange_step(range, 1, device)
}
/// Creates a new tensor with values from the given range with the given step size.
///
/// # Arguments
///
/// * `range` - The range of values.
/// * `step` - The step size.
/// * `device` - The device to create the tensor on.
///
/// # Returns
///
/// The tensor with the given values.
fn arange_step(
range: Range<usize>,
step: usize,
device: &B::Device,
) -> B::IntTensorPrimitive<1> {
let value = range
.into_iter()
.step_by(step)
.map(|i| (i as i64).elem())
.collect::<Vec<B::IntElem>>();
let shape = Shape::new([value.len()]);
let data = Data::new(value, shape);
B::int_from_data(data, device)
}

View File

@ -28,6 +28,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_add!();
burn_tensor::testgen_aggregation!();
burn_tensor::testgen_arange!();
burn_tensor::testgen_arange_step!();
burn_tensor::testgen_arg!();
burn_tensor::testgen_cat!();
burn_tensor::testgen_cos!();

View File

@ -0,0 +1,27 @@
#[burn_tensor_testgen::testgen(arange_step)]
mod tests {
use super::*;
use burn_tensor::{Data, Int, Tensor};
#[test]
fn test_arange_step() {
// Test correct sequence of numbers when the range is 0..9 and the step is 1
let tensor = Tensor::<TestBackend, 1, Int>::arange_step(0..9, 1);
assert_eq!(tensor.into_data(), Data::from([0, 1, 2, 3, 4, 5, 6, 7, 8]));
// Test correct sequence of numbers when the range is 0..3 and the step is 2
let tensor = Tensor::<TestBackend, 1, Int>::arange_step(0..3, 2);
assert_eq!(tensor.into_data(), Data::from([0, 2]));
// Test correct sequence of numbers when the range is 0..2 and the step is 5
let tensor = Tensor::<TestBackend, 1, Int>::arange_step(0..2, 5);
assert_eq!(tensor.into_data(), Data::from([0]));
}
#[test]
#[should_panic]
fn test_arange_step_panic() {
// Test that arange_step panics when the step is 0
let _tensor = Tensor::<TestBackend, 1, Int>::arange_step(0..3, 0);
}
}

View File

@ -1,6 +1,7 @@
mod add;
mod aggregation;
mod arange;
mod arange_step;
mod arg;
mod cat;
mod cos;