mirror of https://github.com/tracel-ai/burn.git
Add arange with steps op for int tensor (#490)
This commit is contained in:
parent
53c088209d
commit
ef421f0ae9
|
@ -5,14 +5,44 @@ impl<B> Tensor<B, 1, Int>
|
|||
where
|
||||
B: Backend,
|
||||
{
|
||||
/// Returns a new integer tensor on the default device which values are generated from the given range.
|
||||
/// Returns a new integer tensor on the default device.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `range` - The range of values to generate.
|
||||
pub fn arange(range: Range<usize>) -> Self {
|
||||
Tensor::new(B::arange(range, &B::Device::default()))
|
||||
}
|
||||
/// Returns a new integer tensor on the specified device which values are generated from the given range.
|
||||
|
||||
/// Returns a new integer tensor on the default device.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `range` - The range of values to generate.
|
||||
/// * `step` - The step between each value.
|
||||
pub fn arange_step(range: Range<usize>, step: usize) -> Self {
|
||||
Tensor::new(B::arange_step(range, step, &B::Device::default()))
|
||||
}
|
||||
|
||||
/// Returns a new integer tensor on the specified device.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `range` - The range of values to generate.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
pub fn arange_device(range: Range<usize>, device: &B::Device) -> Self {
|
||||
Tensor::new(B::arange(range, device))
|
||||
}
|
||||
|
||||
/// Returns a new integer tensor on the specified device.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `range` - The range of values to generate.
|
||||
/// * `step` - The step between each value.
|
||||
pub fn arange_step_device(range: Range<usize>, step: usize, device: &B::Device) -> Self {
|
||||
Tensor::new(B::arange_step(range, step, device))
|
||||
}
|
||||
}
|
||||
|
||||
impl<const D: usize, B> Tensor<B, D, Int>
|
||||
|
|
|
@ -136,12 +136,35 @@ pub trait TensorOps<B: Backend> {
|
|||
/// # Returns
|
||||
///
|
||||
/// The tensor with the given values.
|
||||
///
|
||||
/// # Remarks
|
||||
///
|
||||
/// Uses `arange_step` with a step size of 1 under the hood.
|
||||
fn arange(range: Range<usize>, device: &B::Device) -> B::IntTensorPrimitive<1> {
|
||||
let shape = Shape::new([range.end - range.start]);
|
||||
Self::arange_step(range, 1, device)
|
||||
}
|
||||
|
||||
/// Creates a new tensor with values from the given range with the given step size.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `range` - The range of values.
|
||||
/// * `step` - The step size.
|
||||
/// * `device` - The device to create the tensor on.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the given values.
|
||||
fn arange_step(
|
||||
range: Range<usize>,
|
||||
step: usize,
|
||||
device: &B::Device,
|
||||
) -> B::IntTensorPrimitive<1> {
|
||||
let value = range
|
||||
.into_iter()
|
||||
.step_by(step)
|
||||
.map(|i| (i as i64).elem())
|
||||
.collect::<Vec<B::IntElem>>();
|
||||
let shape = Shape::new([value.len()]);
|
||||
let data = Data::new(value, shape);
|
||||
B::int_from_data(data, device)
|
||||
}
|
||||
|
|
|
@ -28,6 +28,7 @@ macro_rules! testgen_all {
|
|||
burn_tensor::testgen_add!();
|
||||
burn_tensor::testgen_aggregation!();
|
||||
burn_tensor::testgen_arange!();
|
||||
burn_tensor::testgen_arange_step!();
|
||||
burn_tensor::testgen_arg!();
|
||||
burn_tensor::testgen_cat!();
|
||||
burn_tensor::testgen_cos!();
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
#[burn_tensor_testgen::testgen(arange_step)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::{Data, Int, Tensor};
|
||||
|
||||
#[test]
|
||||
fn test_arange_step() {
|
||||
// Test correct sequence of numbers when the range is 0..9 and the step is 1
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange_step(0..9, 1);
|
||||
assert_eq!(tensor.into_data(), Data::from([0, 1, 2, 3, 4, 5, 6, 7, 8]));
|
||||
|
||||
// Test correct sequence of numbers when the range is 0..3 and the step is 2
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange_step(0..3, 2);
|
||||
assert_eq!(tensor.into_data(), Data::from([0, 2]));
|
||||
|
||||
// Test correct sequence of numbers when the range is 0..2 and the step is 5
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange_step(0..2, 5);
|
||||
assert_eq!(tensor.into_data(), Data::from([0]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_arange_step_panic() {
|
||||
// Test that arange_step panics when the step is 0
|
||||
let _tensor = Tensor::<TestBackend, 1, Int>::arange_step(0..3, 0);
|
||||
}
|
||||
}
|
|
@ -1,6 +1,7 @@
|
|||
mod add;
|
||||
mod aggregation;
|
||||
mod arange;
|
||||
mod arange_step;
|
||||
mod arg;
|
||||
mod cat;
|
||||
mod cos;
|
||||
|
|
Loading…
Reference in New Issue