diff --git a/burn-autodiff/src/ops/int_tensor.rs b/burn-autodiff/src/ops/int_tensor.rs
index 983e429f2..1e031bd6e 100644
--- a/burn-autodiff/src/ops/int_tensor.rs
+++ b/burn-autodiff/src/ops/int_tensor.rs
@@ -195,6 +195,21 @@ impl IntTensorOps> for ADBackendDecorator {
         B::int_lower_equal_elem(lhs, rhs)
     }

+    fn int_index_select(
+        tensor: IntTensor,
+        indexes: IntTensor,
+    ) -> IntTensor {
+        B::int_index_select(tensor, indexes)
+    }
+
+    fn int_index_select_assign(
+        tensor: IntTensor,
+        indexes: IntTensor,
+        value: IntTensor,
+    ) -> IntTensor {
+        B::int_index_select_assign(tensor, indexes, value)
+    }
+
     fn int_index_select_dim(
         tensor: IntTensor,
         dim: usize,
diff --git a/burn-autodiff/src/ops/tensor.rs b/burn-autodiff/src/ops/tensor.rs
index aa0a4f6a1..45410a251 100644
--- a/burn-autodiff/src/ops/tensor.rs
+++ b/burn-autodiff/src/ops/tensor.rs
@@ -420,6 +420,94 @@ impl TensorOps> for ADBackendDecorator {
         }
     }

+    fn index_select(
+        tensor: ADTensor,
+        indexes: IntTensor,
+    ) -> ADTensor {
+        #[derive(Debug)]
+        struct IndexSelect;
+
+        impl Backward for IndexSelect {
+            type State = (IntTensor, Shape, B::Device);
+
+            fn backward(self, ops: Ops, grads: &mut Gradients) {
+                let (indexes, shape, device) = ops.state;
+
+                unary::(ops.parents, ops.node, grads, |grad| {
+                    let zeros = B::zeros(shape, &device);
+                    B::index_select_assign(zeros, indexes, grad)
+                });
+            }
+        }
+
+        match IndexSelect
+            .prepare([tensor.node], [tensor.graph])
+            .statefull()
+        {
+            OpsKind::Tracked(prep) => prep.finish(
+                (
+                    indexes.clone(),
+                    B::shape(&tensor.primitive),
+                    B::device(&tensor.primitive),
+                ),
+                B::index_select(tensor.primitive, indexes),
+            ),
+            OpsKind::UnTracked(prep) => prep.finish(B::index_select(tensor.primitive, indexes)),
+        }
+    }
+
+    fn index_select_assign(
+        tensor: ADTensor,
+        indexes: IntTensor,
+        value: ADTensor,
+    ) -> ADTensor {
+        #[derive(Debug)]
+        struct IndexSelectAssign;
+
+        impl Backward for IndexSelectAssign {
+            type State = (IntTensor, Shape, Shape, B::Device);
+
+            fn backward(self, ops: Ops, grads: &mut Gradients) {
+                let (indexes, shape_lhs, shape_rhs, device) = ops.state;
+                let [indexes_4lhs, indexes_4rhs] = duplicate(&ops.parents, Some(indexes));
+
+                binary::(
+                    ops.parents,
+                    ops.node,
+                    grads,
+                    |grad| {
+                        let zeros = B::zeros(shape_lhs, &device);
+                        B::index_select_assign(grad, indexes_4lhs.unwrap(), zeros)
+                    },
+                    |grad| {
+                        let zeros = B::zeros(shape_rhs, &device);
+                        B::index_select_assign(zeros, indexes_4rhs.unwrap(), grad)
+                    },
+                );
+            }
+        }
+
+        match IndexSelectAssign
+            .prepare([tensor.node, value.node], [tensor.graph, value.graph])
+            .statefull()
+        {
+            OpsKind::Tracked(prep) => prep.finish(
+                (
+                    indexes.clone(),
+                    B::shape(&tensor.primitive),
+                    B::shape(&value.primitive),
+                    B::device(&value.primitive),
+                ),
+                B::index_select_assign(tensor.primitive, indexes, value.primitive),
+            ),
+            OpsKind::UnTracked(prep) => prep.finish(B::index_select_assign(
+                tensor.primitive,
+                indexes,
+                value.primitive,
+            )),
+        }
+    }
+
     fn index_select_dim(
         tensor: ADTensor,
         dim: usize,
diff --git a/burn-autodiff/src/tests/index_select.rs b/burn-autodiff/src/tests/index_select.rs
new file mode 100644
index 000000000..89d7f1e2e
--- /dev/null
+++ b/burn-autodiff/src/tests/index_select.rs
@@ -0,0 +1,54 @@
+#[burn_tensor_testgen::testgen(ad_index_select)]
+mod tests {
+    use super::*;
+    use burn_tensor::Data;
+
+    #[test]
+    fn test_index_select_grad() {
+        let tensor_1 =
+            TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
+        let indexes = TestADTensor::from_data(Data::from([[2, 1, 0, 1, 2], [1, 0, 2, 1, 0]]));
+
+        let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
+        let tensor_3 = tensor_1.clone().index_select(indexes);
+        let tensor_4 = tensor_2.matmul(tensor_3);
+
+        let grads = tensor_4.backward();
+
+        let grad_1 = tensor_1.grad(&grads).unwrap();
+
+        assert_eq!(
+            grad_1.into_data(),
+            Data::from([[94., 150., 187.], [242., 305., 304.]])
+        );
+    }
+
+    #[test]
+    fn test_index_select_assign_grad() {
+        let tensor_1 =
+            TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
+        let values =
+            TestADTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])).require_grad();
+        let indexes = TestADTensor::from_data(Data::from([[2, 1, 0], [2, 0, 1]]));
+
+        let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
+        let tensor_3 = tensor_1
+            .clone()
+            .index_select_assign(indexes, values.clone());
+        let tensor_4 = tensor_2.matmul(tensor_3);
+
+        let grads = tensor_4.backward();
+
+        let grad_1 = tensor_1.grad(&grads).unwrap();
+        let grad_2 = values.grad(&grads).unwrap();
+
+        assert_eq!(
+            grad_1.into_data(),
+            Data::from([[127., 181., 235.], [226., 316., 406.]])
+        );
+        assert_eq!(
+            grad_2.into_data(),
+            Data::from([[19., 19., 19.], [64., 64., 64.]])
+        );
+    }
+}
diff --git a/burn-autodiff/src/tests/index_select_dim.rs b/burn-autodiff/src/tests/index_select_dim.rs
index 586dee910..dc31d418f 100644
--- a/burn-autodiff/src/tests/index_select_dim.rs
+++ b/burn-autodiff/src/tests/index_select_dim.rs
@@ -4,7 +4,27 @@ mod tests {
     use burn_tensor::Data;

     #[test]
-    fn test_select_grad() {
+    fn test_index_select_dim_grad() {
+        let tensor_1 =
+            TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
+        let indexes = TestADTensor::from_data(Data::from([1, 0]));
+
+        let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
+        let tensor_3 = tensor_1.clone().index_select_dim(0, indexes);
+        let tensor_4 = tensor_2.matmul(tensor_3);
+
+        let grads = tensor_4.backward();
+
+        let grad_1 = tensor_1.grad(&grads).unwrap();
+
+        assert_eq!(
+            grad_1.into_data(),
+            Data::from([[109., 148., 187.], [37., 58., 79.]])
+        );
+    }
+
+    #[test]
+    fn test_index_select_dim_assign_grad() {
         let tensor_1 =
             TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
         let values =
diff --git a/burn-autodiff/src/tests/mod.rs b/burn-autodiff/src/tests/mod.rs
index daf02d0b8..57534e9c6 100644
--- a/burn-autodiff/src/tests/mod.rs
+++ b/burn-autodiff/src/tests/mod.rs
@@ -11,6 +11,7 @@ mod div;
 mod erf;
 mod exp;
 mod index;
+mod index_select;
 mod index_select_dim;
 mod log;
 mod log1p;
@@ -53,8 +54,9 @@ macro_rules! testgen_all {
         burn_autodiff::testgen_ad_div!();
         burn_autodiff::testgen_ad_erf!();
         burn_autodiff::testgen_ad_exp!();
-        burn_autodiff::testgen_ad_index_select_dim!();
         burn_autodiff::testgen_ad_index!();
+        burn_autodiff::testgen_ad_index_select!();
+        burn_autodiff::testgen_ad_index_select_dim!();
         burn_autodiff::testgen_ad_log!();
         burn_autodiff::testgen_ad_log1p!();
         burn_autodiff::testgen_ad_mask!();
diff --git a/burn-ndarray/src/ops/base.rs b/burn-ndarray/src/ops/base.rs
index bcd66a96c..d8057a3f6 100644
--- a/burn-ndarray/src/ops/base.rs
+++ b/burn-ndarray/src/ops/base.rs
@@ -1,6 +1,8 @@
 use alloc::vec::Vec;
 use burn_tensor::Data;
 use core::{marker::PhantomData, ops::Range};
+use ndarray::s;
+use ndarray::Array2;

 use burn_tensor::Shape;
 use ndarray::Axis;
@@ -204,6 +206,85 @@ where
         }
     }

+    pub fn index_select(
+        tensor: NdArrayTensor,
+        indexes: NdArrayTensor,
+    ) -> NdArrayTensor {
+        let (shape_tensor, shape_indexes) = (tensor.shape(), indexes.shape());
+        let (size_tensor, size_index) = (shape_tensor.dims[D - 1], shape_indexes.dims[D - 1]);
+        let batch_size = Self::index_select_batch_size(&shape_tensor, &shape_indexes);
+
+        let indexes = NdArrayOps::reshape(indexes, Shape::new([batch_size, size_index])).array;
+        let tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array;
+        let mut output = Array2::zeros((batch_size, size_index));
+
+        for b in 0..batch_size {
+            let indexes = indexes.slice(s!(b, ..));
+
+            for (i, index) in indexes.iter().enumerate() {
+                output[[b, i]] = tensor[[b, *index as usize]];
+            }
+        }
+
+        NdArrayOps::reshape(
+            NdArrayTensor::::new(output.into_shared().into_dyn()),
+            shape_indexes,
+        )
+    }
+
+    pub fn index_select_assign(
+        tensor: NdArrayTensor,
+        indexes: NdArrayTensor,
+        value: NdArrayTensor,
+    ) -> NdArrayTensor {
+        let (shape_tensor, shape_indexes, shape_value) =
+            (tensor.shape(), indexes.shape(), value.shape());
+        let (size_tensor, size_index, size_value) = (
+            shape_tensor.dims[D - 1],
+            shape_indexes.dims[D - 1],
+            shape_value.dims[D - 1],
+        );
+        let batch_size = Self::index_select_batch_size(&shape_tensor, &shape_indexes);
+
+        if shape_value != shape_indexes {
+            panic!("Invalid dimension: the shape of the index tensor should be the same as the value tensor: Index {:?} value {:?}", shape_indexes.dims, shape_value.dims);
+        }
+
+        let indexes = NdArrayOps::reshape(indexes, Shape::new([batch_size, size_index])).array;
+        let value = NdArrayOps::reshape(value, Shape::new([batch_size, size_value])).array;
+        let mut tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array;
+
+        for b in 0..batch_size {
+            let indexes = indexes.slice(s!(b, ..));
+
+            for (i, index) in indexes.iter().enumerate() {
+                let index = *index as usize;
+                tensor[[b, index]] = tensor[[b, index]] + value[[b, i]];
+            }
+        }
+
+        NdArrayOps::reshape(
+            NdArrayTensor::::new(tensor.into_shared().into_dyn()),
+            shape_tensor,
+        )
+    }
+
+    fn index_select_batch_size(
+        shape_tensor: &Shape,
+        shape_indexes: &Shape,
+    ) -> usize {
+        let mut batch_size = 1;
+
+        for i in 0..D - 1 {
+            if shape_tensor.dims[i] != shape_indexes.dims[i] {
+                panic!("Unsupported dimension, only the last dimension can differ: Tensor {:?} Index {:?}", shape_tensor.dims, shape_indexes.dims);
+            }
+            batch_size *= shape_indexes.dims[i];
+        }
+
+        batch_size
+    }
+
     pub fn index_select_dim(
         tensor: NdArrayTensor,
         dim: usize,
diff --git a/burn-ndarray/src/ops/int_tensor.rs b/burn-ndarray/src/ops/int_tensor.rs
index abed05a49..310711276 100644
--- a/burn-ndarray/src/ops/int_tensor.rs
+++ b/burn-ndarray/src/ops/int_tensor.rs
@@ -264,6 +264,21 @@ impl IntTensorOps> for NdArrayBackend<
         NdArrayMathOps::mean_dim(tensor, dim)
     }

+    fn int_index_select(
+        tensor: NdArrayTensor,
+        indexes: NdArrayTensor,
+    ) -> NdArrayTensor {
+        NdArrayMathOps::index_select(tensor, indexes)
+    }
+
+    fn int_index_select_assign(
+        tensor: NdArrayTensor,
+        indexes: NdArrayTensor,
+        value: NdArrayTensor,
+    ) -> NdArrayTensor {
+        NdArrayMathOps::index_select_assign(tensor, indexes, value)
+    }
+
     fn int_index_select_dim(
         tensor: NdArrayTensor,
         dim: usize,
diff --git a/burn-ndarray/src/ops/tensor.rs b/burn-ndarray/src/ops/tensor.rs
index 7a24014e6..2641e6b37 100644
--- a/burn-ndarray/src/ops/tensor.rs
+++ b/burn-ndarray/src/ops/tensor.rs
@@ -151,6 +151,21 @@ impl TensorOps> for NdArrayBackend
         NdArrayOps::reshape(tensor, shape)
     }

+    fn index_select(
+        tensor: NdArrayTensor,
+        indexes: NdArrayTensor,
+    ) -> NdArrayTensor {
+        NdArrayMathOps::index_select(tensor, indexes)
+    }
+
+    fn index_select_assign(
+        tensor: NdArrayTensor,
+        indexes: NdArrayTensor,
+        value: NdArrayTensor,
+    ) -> NdArrayTensor {
+        NdArrayMathOps::index_select_assign(tensor, indexes, value)
+    }
+
     fn index_select_dim(
         tensor: NdArrayTensor,
         dim: usize,
diff --git a/burn-ndarray/src/tensor.rs b/burn-ndarray/src/tensor.rs
index c923b9c8d..38ba5aa12 100644
--- a/burn-ndarray/src/tensor.rs
+++ b/burn-ndarray/src/tensor.rs
@@ -58,7 +58,12 @@ macro_rules! reshape {
     ) => {{
         let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
         let safe_into_shape =
-            $array.is_standard_layout() || $array.raw_view().reversed_axes().is_standard_layout();
+            $array.is_standard_layout() ||
+            (
+                $array.ndim() > 1 &&
+                $array.raw_view().reversed_axes().is_standard_layout()
+            );
+
         let array: ndarray::ArcArray<$ty, Dim<[usize; $n]>> = match safe_into_shape {
             true => $array
                 .into_shape(dim)
diff --git a/burn-tch/src/ops/base.rs b/burn-tch/src/ops/base.rs
index 2cfeadfc8..21267ddc3 100644
--- a/burn-tch/src/ops/base.rs
+++ b/burn-tch/src/ops/base.rs
@@ -45,6 +45,25 @@ impl TchOps {
         TchTensor::new(tensor_original)
     }

+    pub fn index_select(
+        tensor: TchTensor,
+        indexes: TchTensor,
+    ) -> TchTensor {
+        let tensor = tensor.tensor.gather((D - 1) as i64, &indexes.tensor, false);
+        TchTensor::new(tensor)
+    }
+
+    pub fn index_select_assign(
+        tensor: TchTensor,
+        indexes: TchTensor,
+        value: TchTensor,
+    ) -> TchTensor {
+        let tensor = tensor
+            .tensor
+            .scatter_add((D - 1) as i64, &indexes.tensor, &value.tensor);
+        TchTensor::new(tensor)
+    }
+
     pub fn index_select_dim(
         tensor: TchTensor,
         dim: usize,
diff --git a/burn-tch/src/ops/int_tensor.rs b/burn-tch/src/ops/int_tensor.rs
index a9bf92af7..d66bcfe1d 100644
--- a/burn-tch/src/ops/int_tensor.rs
+++ b/burn-tch/src/ops/int_tensor.rs
@@ -243,6 +243,20 @@ impl IntTensorOps> for TchBackend {
     fn int_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor {
         TchOps::mean_dim(tensor, dim)
     }
+    fn int_index_select(
+        tensor: TchTensor,
+        indexes: TchTensor,
+    ) -> TchTensor {
+        TchOps::index_select(tensor, indexes)
+    }
+
+    fn int_index_select_assign(
+        tensor: TchTensor,
+        indexes: TchTensor,
+        value: TchTensor,
+    ) -> TchTensor {
+        TchOps::index_select_assign(tensor, indexes, value)
+    }

     fn int_index_select_dim(
         tensor: TchTensor,
diff --git a/burn-tch/src/ops/tensor.rs b/burn-tch/src/ops/tensor.rs
index 860ca4144..4f1fc3ff8 100644
--- a/burn-tch/src/ops/tensor.rs
+++ b/burn-tch/src/ops/tensor.rs
@@ -190,6 +190,21 @@ impl TensorOps> for TchBackend {
         TchTensor::new(tensor)
     }

+    fn index_select(
+        tensor: TchTensor,
+        indexes: TchTensor,
+    ) -> TchTensor {
+        TchOps::index_select(tensor, indexes)
+    }
+
+    fn index_select_assign(
+        tensor: TchTensor,
+        indexes: TchTensor,
+        value: TchTensor,
+    ) -> TchTensor {
+        TchOps::index_select_assign(tensor, indexes, value)
+    }
+
     fn index_select_dim(
         tensor: TchTensor,
         dim: usize,
diff --git a/burn-tensor/src/tensor/api/float.rs b/burn-tensor/src/tensor/api/float.rs
index 69e30c5f1..2e138f513 100644
--- a/burn-tensor/src/tensor/api/float.rs
+++ b/burn-tensor/src/tensor/api/float.rs
@@ -342,27 +342,6 @@ where
         self.reshape(shape)
     }

-    /// Index the tensor along the given dimension using the given indexes.
-    pub fn index_select_dim(self, dim: usize, indexes: Tensor) -> Self {
-        Self::new(B::index_select_dim(self.primitive, dim, indexes.primitive))
-    }
-
-    /// Return a new tensor with the same dimension, but with the values added to
-    /// the original tensor using the corresponding indexes provided along the given dimension.
-    pub fn index_select_dim_assign(
-        self,
-        dim: usize,
-        indexes: Tensor,
-        values: Tensor,
-    ) -> Self {
-        Self::new(B::index_select_dim_assign(
-            self.primitive,
-            dim,
-            indexes.primitive,
-            values.primitive,
-        ))
-    }
-
     pub(crate) fn relu(self) -> Self {
         Self::new(B::relu(self.primitive))
     }
diff --git a/burn-tensor/src/tensor/api/numeric.rs b/burn-tensor/src/tensor/api/numeric.rs
index 916af0140..4c90ab753 100644
--- a/burn-tensor/src/tensor/api/numeric.rs
+++ b/burn-tensor/src/tensor/api/numeric.rs
@@ -170,6 +170,52 @@ where
     pub fn lower_equal_elem(self, other: E) -> Tensor {
         K::lower_equal_elem(self.primitive, other.elem())
     }
+
+    /// Select the tensor elements corresponding to the given indexes.
+    ///
+    /// # Notes
+    ///
+    /// The index tensor should have the same shape as the original tensor except for the last
+    /// dimension.
+    pub fn index_select(self, indexes: Tensor) -> Self {
+        Self::new(K::index_select(self.primitive, indexes))
+    }
+
+    /// Assign the selected elements corresponding to the given indexes from the value tensor
+    /// to the original tensor using sum reduction.
+    ///
+    /// # Notes
+    ///
+    /// The index tensor should have the same shape as the original tensor except for the last
+    /// dimension. The value and index tensors should have the same shape.
+    pub fn index_select_assign(self, indexes: Tensor, values: Self) -> Self {
+        Self::new(K::index_select_assign(
+            self.primitive,
+            indexes,
+            values.primitive,
+        ))
+    }
+
+    /// Select the tensor elements along the given dimension corresponding to the given indexes.
+    pub fn index_select_dim(self, dim: usize, indexes: Tensor) -> Self {
+        Self::new(K::index_select_dim(self.primitive, dim, indexes))
+    }
+
+    /// Assign the selected elements along the given dimension corresponding to the given indexes
+    /// from the value tensor to the original tensor using sum reduction.
+    pub fn index_select_dim_assign(
+        self,
+        dim: usize,
+        indexes: Tensor,
+        values: Tensor,
+    ) -> Self {
+        Self::new(K::index_select_dim_assign(
+            self.primitive,
+            dim,
+            indexes,
+            values.primitive,
+        ))
+    }
 }

 /// Trait that list all operations that can be applied on all numerical tensors.
@@ -234,6 +280,15 @@ pub trait Numeric: TensorKind {
         lhs: Self::Primitive,
         rhs: Self::Elem,
     ) -> Tensor;
+    fn index_select(
+        tensor: Self::Primitive,
+        indexes: Tensor,
+    ) -> Self::Primitive;
+    fn index_select_assign(
+        tensor: Self::Primitive,
+        indexes: Tensor,
+        values: Self::Primitive,
+    ) -> Self::Primitive;
     fn index_select_dim(
         tensor: Self::Primitive,
         dim: usize,
@@ -389,6 +444,20 @@ impl Numeric for Int {
     ) -> Self::Primitive {
         B::int_index_select_dim_assign(tensor, dim, indexes.primitive, values)
     }
+    fn index_select(
+        tensor: Self::Primitive,
+        indexes: Tensor,
+    ) -> Self::Primitive {
+        B::int_index_select(tensor, indexes.primitive)
+    }
+
+    fn index_select_assign(
+        tensor: Self::Primitive,
+        indexes: Tensor,
+        values: Self::Primitive,
+    ) -> Self::Primitive {
+        B::int_index_select_assign(tensor, indexes.primitive, values)
+    }
 }

 impl Numeric for Float {
@@ -533,6 +602,21 @@ impl Numeric for Float {
     ) -> Self::Primitive {
         B::index_select_dim_assign(tensor, dim, indexes.primitive, values)
     }
+
+    fn index_select(
+        tensor: Self::Primitive,
+        indexes: Tensor,
+    ) -> Self::Primitive {
+        B::index_select(tensor, indexes.primitive)
+    }
+
+    fn index_select_assign(
+        tensor: Self::Primitive,
+        indexes: Tensor,
+        values: Self::Primitive,
+    ) -> Self::Primitive {
+        B::index_select_assign(tensor, indexes.primitive, values)
+    }
 }

 impl core::ops::Add for Tensor
diff --git a/burn-tensor/src/tensor/ops/int_tensor.rs b/burn-tensor/src/tensor/ops/int_tensor.rs
index a18b61a0a..485594e36 100644
--- a/burn-tensor/src/tensor/ops/int_tensor.rs
+++ b/burn-tensor/src/tensor/ops/int_tensor.rs
@@ -34,6 +34,15 @@ pub trait IntTensorOps {
         indexes: [Range; D2],
         value: B::IntTensorPrimitive,
     ) -> B::IntTensorPrimitive;
+    fn int_index_select(
+        tensor: B::IntTensorPrimitive,
+        indexes: B::IntTensorPrimitive,
+    ) -> B::IntTensorPrimitive;
+    fn int_index_select_assign(
+        tensor: B::IntTensorPrimitive,
+        indexes: B::IntTensorPrimitive,
+        value: B::IntTensorPrimitive,
+    ) -> B::IntTensorPrimitive;
     fn int_index_select_dim(
         tensor: B::IntTensorPrimitive,
         dim: usize,
diff --git a/burn-tensor/src/tensor/ops/tensor.rs b/burn-tensor/src/tensor/ops/tensor.rs
index 67a5a8520..350ec325e 100644
--- a/burn-tensor/src/tensor/ops/tensor.rs
+++ b/burn-tensor/src/tensor/ops/tensor.rs
@@ -115,6 +115,15 @@ pub trait TensorOps {
         tensor: B::TensorPrimitive,
         shape: Shape,
     ) -> B::TensorPrimitive;
+    fn index_select(
+        tensor: B::TensorPrimitive,
+        indexes: B::IntTensorPrimitive,
+    ) -> B::TensorPrimitive;
+    fn index_select_assign(
+        tensor: B::TensorPrimitive,
+        indexes: B::IntTensorPrimitive,
+        value: B::TensorPrimitive,
+    ) -> B::TensorPrimitive;
     fn index_select_dim(
         tensor: B::TensorPrimitive,
         dim: usize,
diff --git a/burn-tensor/src/tests/mod.rs b/burn-tensor/src/tests/mod.rs
index 9cecf91b9..cee10a725 100644
--- a/burn-tensor/src/tests/mod.rs
+++ b/burn-tensor/src/tests/mod.rs
@@ -28,8 +28,9 @@ macro_rules! testgen_all {
         burn_tensor::testgen_exp!();
         burn_tensor::testgen_log!();
         burn_tensor::testgen_log1p!();
-        burn_tensor::testgen_index_select_dim!();
         burn_tensor::testgen_index!();
+        burn_tensor::testgen_index_select!();
+        burn_tensor::testgen_index_select_dim!();
         burn_tensor::testgen_map_comparison!();
         burn_tensor::testgen_mask!();
         burn_tensor::testgen_matmul!();
diff --git a/burn-tensor/src/tests/ops/index_select.rs b/burn-tensor/src/tests/ops/index_select.rs
new file mode 100644
index 000000000..285a6b579
--- /dev/null
+++ b/burn-tensor/src/tests/ops/index_select.rs
@@ -0,0 +1,63 @@
+#[burn_tensor_testgen::testgen(index_select)]
+mod tests {
+    use super::*;
+    use burn_tensor::{Data, Tensor};
+
+    #[test]
+    fn should_select_1d() {
+        let tensor = TestTensor::from_data(Data::from([0.0, 1.0, 2.0]));
+        let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));
+
+        let output = tensor.index_select(indexes);
+
+        assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0]));
+    }
+
+    #[test]
+    fn should_select_2d() {
+        let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
+        let indexes = TestTensorInt::from_data(Data::from([[2, 1, 0, 0], [2, 0, 1, 2]]));
+
+        let output = tensor.index_select(indexes);
+
+        assert_eq!(
+            output.into_data(),
+            Data::from([[2.0, 1.0, 0.0, 0.0], [5.0, 3.0, 4.0, 5.0]])
+        );
+    }
+
+    #[test]
+    fn should_select_2d_only_1dim() {
+        let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
+        let indexes = TestTensorInt::from_data(Data::from([[1, 2]])).reshape([2, 1]);
+
+        let output = tensor.index_select(indexes);
+
+        assert_eq!(output.into_data(), Data::from([[1.0], [5.0]]));
+    }
+
+    #[test]
+    fn should_select_assign_1d() {
+        let tensor = TestTensor::from_data(Data::from([0.0, 0.0, 0.0]));
+        let values = TestTensor::from_data(Data::from([5.0, 4.0, 3.0]));
+        let indexes = TestTensorInt::from_data(Data::from([1, 0, 2]));
+
+        let output = tensor.index_select_assign(indexes, values);
+
+        assert_eq!(output.into_data(), Data::from([4.0, 5.0, 3.0]));
+    }
+
+    #[test]
+    fn should_select_assign_2d() {
+        let tensor = TestTensor::from_data(Data::from([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]));
+        let values = TestTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]));
+        let indexes = TestTensorInt::from_data(Data::from([[1, 0, 2], [1, 2, 0]]));
+
+        let output = tensor.index_select_assign(indexes, values);
+
+        assert_eq!(
+            output.into_data(),
+            Data::from([[2.0, 1.0, 3.0], [6.0, 4.0, 5.0]])
+        );
+    }
+}
diff --git a/burn-tensor/src/tests/ops/mod.rs b/burn-tensor/src/tests/ops/mod.rs
index 2c609b9d6..d2a8d4880 100644
--- a/burn-tensor/src/tests/ops/mod.rs
+++ b/burn-tensor/src/tests/ops/mod.rs
@@ -6,6 +6,7 @@ mod div;
 mod erf;
 mod exp;
 mod index;
+mod index_select;
 mod index_select_dim;
 mod log;
 mod log1p;
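
Note on the semantics shared by the three backends: `index_select` gathers along the last dimension (out[b][i] = tensor[b][indexes[b][i]]) and `index_select_assign` scatter-adds along the last dimension (out[b][indexes[b][i]] += value[b][i]), which is what the tch backend maps to `gather`/`scatter_add` and the ndarray backend implements with the batch-flattening loops above. The following is a minimal, dependency-free Rust sketch of that behaviour on 2D data, checked against the expected values in the new tests; the helper names and the plain Vec-based layout are illustrative only and are not part of the burn API.

// Gather along the last dimension: out[b][i] = tensor[b][indexes[b][i]].
fn index_select_2d(tensor: &[Vec<f32>], indexes: &[Vec<usize>]) -> Vec<Vec<f32>> {
    tensor
        .iter()
        .zip(indexes)
        .map(|(row, idx)| idx.iter().map(|&i| row[i]).collect())
        .collect()
}

// Scatter-add along the last dimension, starting from `tensor`:
// out[b][indexes[b][i]] += value[b][i].
fn index_select_assign_2d(
    tensor: &[Vec<f32>],
    indexes: &[Vec<usize>],
    value: &[Vec<f32>],
) -> Vec<Vec<f32>> {
    let mut out: Vec<Vec<f32>> = tensor.to_vec();
    for (b, idx) in indexes.iter().enumerate() {
        for (i, &index) in idx.iter().enumerate() {
            out[b][index] += value[b][i];
        }
    }
    out
}

fn main() {
    let tensor: Vec<Vec<f32>> = vec![vec![0.0, 1.0, 2.0], vec![3.0, 4.0, 5.0]];
    let indexes: Vec<Vec<usize>> = vec![vec![2, 1, 0, 0], vec![2, 0, 1, 2]];
    // Same data as `should_select_2d`: prints [[2.0, 1.0, 0.0, 0.0], [5.0, 3.0, 4.0, 5.0]].
    println!("{:?}", index_select_2d(&tensor, &indexes));

    let zeros: Vec<Vec<f32>> = vec![vec![0.0; 3]; 2];
    let values: Vec<Vec<f32>> = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
    let idx: Vec<Vec<usize>> = vec![vec![1, 0, 2], vec![1, 2, 0]];
    // Same data as `should_select_assign_2d`: prints [[2.0, 1.0, 3.0], [6.0, 4.0, 5.0]].
    println!("{:?}", index_select_assign_2d(&zeros, &idx, &values));
}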