diff --git a/burn-autodiff/src/ops/int_tensor.rs b/burn-autodiff/src/ops/int_tensor.rs index ecdd53f1b..60f32ef29 100644 --- a/burn-autodiff/src/ops/int_tensor.rs +++ b/burn-autodiff/src/ops/int_tensor.rs @@ -195,19 +195,21 @@ impl IntTensorOps> for ADBackendDecorator { B::int_lower_equal_elem(lhs, rhs) } - fn int_index_select( + fn int_gather( + dim: usize, tensor: IntTensor, indexes: IntTensor, ) -> IntTensor { - B::int_index_select(tensor, indexes) + B::int_gather(dim, tensor, indexes) } - fn int_index_select_assign( + fn int_scatter( + dim: usize, tensor: IntTensor, indexes: IntTensor, value: IntTensor, ) -> IntTensor { - B::int_index_select_assign(tensor, indexes, value) + B::int_scatter(dim, tensor, indexes, value) } fn int_index_select_dim( diff --git a/burn-autodiff/src/ops/maxmin.rs b/burn-autodiff/src/ops/maxmin.rs index 93129f144..758d9c7e3 100644 --- a/burn-autodiff/src/ops/maxmin.rs +++ b/burn-autodiff/src/ops/maxmin.rs @@ -14,7 +14,7 @@ impl Backward for MaxMinDim { let device = B::device(&grad); let zeros = B::zeros(shape, &device); - B::index_select_assign(zeros, indexes, grad) + B::scatter(D - 1, zeros, indexes, grad) }); } } diff --git a/burn-autodiff/src/ops/tensor.rs b/burn-autodiff/src/ops/tensor.rs index b7a1a107c..b5910b91c 100644 --- a/burn-autodiff/src/ops/tensor.rs +++ b/burn-autodiff/src/ops/tensor.rs @@ -438,55 +438,55 @@ impl TensorOps> for ADBackendDecorator { } } - fn index_select( + fn gather( + dim: usize, tensor: ADTensor, indexes: IntTensor, ) -> ADTensor { #[derive(Debug)] - struct IndexSelect; + struct Gather; - impl Backward for IndexSelect { - type State = (IntTensor, Shape, B::Device); + impl Backward for Gather { + type State = (usize, IntTensor, Shape, B::Device); fn backward(self, ops: Ops, grads: &mut Gradients) { - let (indexes, shape, device) = ops.state; + let (dim, indexes, shape, device) = ops.state; unary::(ops.parents, ops.node, grads, |grad| { let zeros = B::zeros(shape, &device); - B::index_select_assign(zeros, indexes, grad) + B::scatter(dim, zeros, indexes, grad) }); } } - match IndexSelect - .prepare([tensor.node], [tensor.graph]) - .statefull() - { + match Gather.prepare([tensor.node], [tensor.graph]).statefull() { OpsKind::Tracked(prep) => prep.finish( ( + dim, indexes.clone(), B::shape(&tensor.primitive), B::device(&tensor.primitive), ), - B::index_select(tensor.primitive, indexes), + B::gather(dim, tensor.primitive, indexes), ), - OpsKind::UnTracked(prep) => prep.finish(B::index_select(tensor.primitive, indexes)), + OpsKind::UnTracked(prep) => prep.finish(B::gather(dim, tensor.primitive, indexes)), } } - fn index_select_assign( + fn scatter( + dim: usize, tensor: ADTensor, indexes: IntTensor, value: ADTensor, ) -> ADTensor { #[derive(Debug)] - struct IndexSelectAssign; + struct Scatter; - impl Backward for IndexSelectAssign { - type State = (IntTensor, Shape, Shape, B::Device); + impl Backward for Scatter { + type State = (usize, IntTensor, Shape, Shape, B::Device); fn backward(self, ops: Ops, grads: &mut Gradients) { - let (indexes, shape_lhs, shape_rhs, device) = ops.state; + let (dim, indexes, shape_lhs, shape_rhs, device) = ops.state; let [indexes_4lhs, indexes_4rhs] = duplicate(&ops.parents, Some(indexes)); binary::( @@ -495,38 +495,37 @@ impl TensorOps> for ADBackendDecorator { grads, |grad| { let zeros = B::zeros(shape_lhs, &device); - B::index_select_assign(grad, indexes_4lhs.unwrap(), zeros) + B::scatter(dim, grad, indexes_4lhs.unwrap(), zeros) }, |grad| { let zeros = B::zeros(shape_rhs, &device); - B::index_select_assign(zeros, indexes_4rhs.unwrap(), grad) + B::scatter(dim, zeros, indexes_4rhs.unwrap(), grad) }, ); } } - match IndexSelectAssign + match Scatter .prepare([tensor.node, value.node], [tensor.graph, value.graph]) .statefull() { OpsKind::Tracked(prep) => prep.finish( ( + dim, indexes.clone(), B::shape(&tensor.primitive), B::shape(&value.primitive), B::device(&value.primitive), ), - B::index_select_assign(tensor.primitive, indexes, value.primitive), + B::scatter(dim, tensor.primitive, indexes, value.primitive), ), - OpsKind::UnTracked(prep) => prep.finish(B::index_select_assign( - tensor.primitive, - indexes, - value.primitive, - )), + OpsKind::UnTracked(prep) => { + prep.finish(B::scatter(dim, tensor.primitive, indexes, value.primitive)) + } } } - fn index_select_dim( + fn index_select( tensor: ADTensor, dim: usize, indexes: IntTensor, @@ -542,7 +541,7 @@ impl TensorOps> for ADBackendDecorator { unary::(ops.parents, ops.node, grads, |grad| { let zeros = B::zeros(shape, &device); - B::index_select_dim_assign(zeros, dim, indexes, grad) + B::index_select_assign(zeros, dim, indexes, grad) }); } } @@ -558,15 +557,15 @@ impl TensorOps> for ADBackendDecorator { B::shape(&tensor.primitive), B::device(&tensor.primitive), ), - B::index_select_dim(tensor.primitive, dim, indexes), + B::index_select(tensor.primitive, dim, indexes), ), OpsKind::UnTracked(prep) => { - prep.finish(B::index_select_dim(tensor.primitive, dim, indexes)) + prep.finish(B::index_select(tensor.primitive, dim, indexes)) } } } - fn index_select_dim_assign( + fn index_select_assign( tensor: ADTensor, dim: usize, indexes: IntTensor, @@ -588,11 +587,11 @@ impl TensorOps> for ADBackendDecorator { grads, |grad| { let zeros = B::zeros(shape_lhs, &device); - B::index_select_dim_assign(grad, dim, indexes_4lhs.unwrap(), zeros) + B::index_select_assign(grad, dim, indexes_4lhs.unwrap(), zeros) }, |grad| { let zeros = B::zeros(shape_rhs, &device); - B::index_select_dim_assign(zeros, dim, indexes_4rhs.unwrap(), grad) + B::index_select_assign(zeros, dim, indexes_4rhs.unwrap(), grad) }, ); } @@ -610,9 +609,9 @@ impl TensorOps> for ADBackendDecorator { B::shape(&value.primitive), B::device(&value.primitive), ), - B::index_select_dim_assign(tensor.primitive, dim, indexes, value.primitive), + B::index_select_assign(tensor.primitive, dim, indexes, value.primitive), ), - OpsKind::UnTracked(prep) => prep.finish(B::index_select_dim_assign( + OpsKind::UnTracked(prep) => prep.finish(B::index_select_assign( tensor.primitive, dim, indexes, diff --git a/burn-autodiff/src/tests/index_select_dim.rs b/burn-autodiff/src/tests/gather_scatter.rs similarity index 64% rename from burn-autodiff/src/tests/index_select_dim.rs rename to burn-autodiff/src/tests/gather_scatter.rs index dc31d418f..cb5da6a7d 100644 --- a/burn-autodiff/src/tests/index_select_dim.rs +++ b/burn-autodiff/src/tests/gather_scatter.rs @@ -1,16 +1,16 @@ -#[burn_tensor_testgen::testgen(ad_index_select_dim)] +#[burn_tensor_testgen::testgen(ad_gather_scatter)] mod tests { use super::*; use burn_tensor::Data; #[test] - fn test_index_select_dim_grad() { + fn test_gather_grad() { let tensor_1 = TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad(); - let indexes = TestADTensor::from_data(Data::from([1, 0])); + let indexes = TestADTensor::from_data(Data::from([[2, 1, 0, 1, 2], [1, 0, 2, 1, 0]])); let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); - let tensor_3 = tensor_1.clone().index_select_dim(0, indexes); + let tensor_3 = tensor_1.clone().gather(1, indexes); let tensor_4 = tensor_2.matmul(tensor_3); let grads = tensor_4.backward(); @@ -19,22 +19,20 @@ mod tests { assert_eq!( grad_1.into_data(), - Data::from([[109., 148., 187.], [37., 58., 79.]]) + Data::from([[94., 150., 187.], [242., 305., 304.]]) ); } #[test] - fn test_index_select_dim_assign_grad() { + fn test_scatter_grad() { let tensor_1 = TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad(); let values = TestADTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])).require_grad(); - let indexes = TestADTensor::from_data(Data::from([1, 0])); + let indexes = TestADTensor::from_data(Data::from([[2, 1, 0], [2, 0, 1]])); let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); - let tensor_3 = tensor_1 - .clone() - .index_select_dim_assign(0, indexes, values.clone()); + let tensor_3 = tensor_1.clone().scatter(1, indexes, values.clone()); let tensor_4 = tensor_2.matmul(tensor_3); let grads = tensor_4.backward(); @@ -44,11 +42,11 @@ mod tests { assert_eq!( grad_1.into_data(), - Data::from([[127., 199., 271.], [172., 244., 316.]]) + Data::from([[127., 181., 235.], [226., 316., 406.]]) ); assert_eq!( grad_2.into_data(), - Data::from([[64., 64., 64.], [19., 19., 19.]]) + Data::from([[19., 19., 19.], [64., 64., 64.]]) ); } } diff --git a/burn-autodiff/src/tests/index_select.rs b/burn-autodiff/src/tests/index_select.rs index 89d7f1e2e..58bcb83c3 100644 --- a/burn-autodiff/src/tests/index_select.rs +++ b/burn-autodiff/src/tests/index_select.rs @@ -7,10 +7,10 @@ mod tests { fn test_index_select_grad() { let tensor_1 = TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad(); - let indexes = TestADTensor::from_data(Data::from([[2, 1, 0, 1, 2], [1, 0, 2, 1, 0]])); + let indexes = TestADTensor::from_data(Data::from([1, 0])); let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); - let tensor_3 = tensor_1.clone().index_select(indexes); + let tensor_3 = tensor_1.clone().index_select(0, indexes); let tensor_4 = tensor_2.matmul(tensor_3); let grads = tensor_4.backward(); @@ -19,7 +19,7 @@ mod tests { assert_eq!( grad_1.into_data(), - Data::from([[94., 150., 187.], [242., 305., 304.]]) + Data::from([[109., 148., 187.], [37., 58., 79.]]) ); } @@ -29,12 +29,12 @@ mod tests { TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad(); let values = TestADTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])).require_grad(); - let indexes = TestADTensor::from_data(Data::from([[2, 1, 0], [2, 0, 1]])); + let indexes = TestADTensor::from_data(Data::from([1, 0])); let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose()); let tensor_3 = tensor_1 .clone() - .index_select_assign(indexes, values.clone()); + .index_select_assign(0, indexes, values.clone()); let tensor_4 = tensor_2.matmul(tensor_3); let grads = tensor_4.backward(); @@ -44,11 +44,11 @@ mod tests { assert_eq!( grad_1.into_data(), - Data::from([[127., 181., 235.], [226., 316., 406.]]) + Data::from([[127., 199., 271.], [172., 244., 316.]]) ); assert_eq!( grad_2.into_data(), - Data::from([[19., 19., 19.], [64., 64., 64.]]) + Data::from([[64., 64., 64.], [19., 19., 19.]]) ); } } diff --git a/burn-autodiff/src/tests/mod.rs b/burn-autodiff/src/tests/mod.rs index b11bc17bb..903b5c343 100644 --- a/burn-autodiff/src/tests/mod.rs +++ b/burn-autodiff/src/tests/mod.rs @@ -13,10 +13,10 @@ mod cross_entropy; mod div; mod erf; mod exp; +mod gather_scatter; mod gelu; mod index; mod index_select; -mod index_select_dim; mod log; mod log1p; mod mask; @@ -70,8 +70,8 @@ macro_rules! testgen_all { burn_autodiff::testgen_ad_erf!(); burn_autodiff::testgen_ad_exp!(); burn_autodiff::testgen_ad_index!(); + burn_autodiff::testgen_ad_gather_scatter!(); burn_autodiff::testgen_ad_index_select!(); - burn_autodiff::testgen_ad_index_select_dim!(); burn_autodiff::testgen_ad_log!(); burn_autodiff::testgen_ad_log1p!(); burn_autodiff::testgen_ad_mask!(); diff --git a/burn-core/src/nn/loss/cross_entropy.rs b/burn-core/src/nn/loss/cross_entropy.rs index c06e4b475..64dd40486 100644 --- a/burn-core/src/nn/loss/cross_entropy.rs +++ b/burn-core/src/nn/loss/cross_entropy.rs @@ -29,7 +29,7 @@ impl CrossEntropyLoss { let mask = self.padding_mask(&targets); let tensor = activation::log_softmax(logits, 1); - let tensor = tensor.index_select(targets.reshape([batch_size, 1])); + let tensor = tensor.gather(1, targets.reshape([batch_size, 1])); let tensor = self.apply_mask(tensor.reshape([batch_size]), mask); tensor.mean().neg() diff --git a/burn-ndarray/src/ops/base.rs b/burn-ndarray/src/ops/base.rs index 02c3d8642..a9b4bad9e 100644 --- a/burn-ndarray/src/ops/base.rs +++ b/burn-ndarray/src/ops/base.rs @@ -208,13 +208,18 @@ where } } - pub fn index_select( - tensor: NdArrayTensor, - indexes: NdArrayTensor, + pub fn gather( + dim: usize, + mut tensor: NdArrayTensor, + mut indexes: NdArrayTensor, ) -> NdArrayTensor { + if dim != D - 1 { + tensor.array.swap_axes(D - 1, dim); + indexes.array.swap_axes(D - 1, dim); + } let (shape_tensor, shape_indexes) = (tensor.shape(), indexes.shape()); let (size_tensor, size_index) = (shape_tensor.dims[D - 1], shape_indexes.dims[D - 1]); - let batch_size = Self::index_select_batch_size(&shape_tensor, &shape_indexes); + let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indexes); let indexes = NdArrayOps::reshape(indexes, Shape::new([batch_size, size_index])).array; let tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array; @@ -228,17 +233,30 @@ where } } - NdArrayOps::reshape( + let mut output = NdArrayOps::reshape( NdArrayTensor::::new(output.into_shared().into_dyn()), shape_indexes, - ) + ); + + if dim != D - 1 { + output.array.swap_axes(D - 1, dim); + } + + output } - pub fn index_select_assign( - tensor: NdArrayTensor, - indexes: NdArrayTensor, - value: NdArrayTensor, + pub fn scatter( + dim: usize, + mut tensor: NdArrayTensor, + mut indexes: NdArrayTensor, + mut value: NdArrayTensor, ) -> NdArrayTensor { + if dim != D - 1 { + tensor.array.swap_axes(D - 1, dim); + indexes.array.swap_axes(D - 1, dim); + value.array.swap_axes(D - 1, dim); + } + let (shape_tensor, shape_indexes, shape_value) = (tensor.shape(), indexes.shape(), value.shape()); let (size_tensor, size_index, size_value) = ( @@ -246,7 +264,7 @@ where shape_indexes.dims[D - 1], shape_value.dims[D - 1], ); - let batch_size = Self::index_select_batch_size(&shape_tensor, &shape_indexes); + let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indexes); if shape_value != shape_indexes { panic!("Invalid dimension: the shape of the index tensor should be the same as the value tensor: Index {:?} value {:?}", shape_indexes.dims, shape_value.dims); @@ -265,10 +283,14 @@ where } } - NdArrayOps::reshape( + let mut output = NdArrayOps::reshape( NdArrayTensor::::new(tensor.into_shared().into_dyn()), shape_tensor, - ) + ); + if dim != D - 1 { + output.array.swap_axes(D - 1, dim); + } + output } pub fn mask_scatter( @@ -307,7 +329,7 @@ where NdArrayTensor::new(array) } - fn index_select_batch_size( + fn gather_batch_size( shape_tensor: &Shape, shape_indexes: &Shape, ) -> usize { @@ -323,7 +345,7 @@ where batch_size } - pub fn index_select_dim( + pub fn index_select( tensor: NdArrayTensor, dim: usize, indexes: NdArrayTensor, @@ -340,7 +362,7 @@ where NdArrayTensor::new(array.into_shared()) } - pub fn index_select_dim_assign( + pub fn index_select_assign( tensor: NdArrayTensor, dim: usize, indexes: NdArrayTensor, diff --git a/burn-ndarray/src/ops/int_tensor.rs b/burn-ndarray/src/ops/int_tensor.rs index 0ac8a29f5..8eaed2de6 100644 --- a/burn-ndarray/src/ops/int_tensor.rs +++ b/burn-ndarray/src/ops/int_tensor.rs @@ -280,19 +280,21 @@ impl IntTensorOps> for NdArrayBackend< NdArrayMathOps::mean_dim(tensor, dim) } - fn int_index_select( + fn int_gather( + dim: usize, tensor: NdArrayTensor, indexes: NdArrayTensor, ) -> NdArrayTensor { - NdArrayMathOps::index_select(tensor, indexes) + NdArrayMathOps::gather(dim, tensor, indexes) } - fn int_index_select_assign( + fn int_scatter( + dim: usize, tensor: NdArrayTensor, indexes: NdArrayTensor, value: NdArrayTensor, ) -> NdArrayTensor { - NdArrayMathOps::index_select_assign(tensor, indexes, value) + NdArrayMathOps::scatter(dim, tensor, indexes, value) } fn int_index_select_dim( @@ -300,7 +302,7 @@ impl IntTensorOps> for NdArrayBackend< dim: usize, indexes: NdArrayTensor, ) -> NdArrayTensor { - NdArrayMathOps::index_select_dim(tensor, dim, indexes) + NdArrayMathOps::index_select(tensor, dim, indexes) } fn int_index_select_dim_assign( @@ -309,7 +311,7 @@ impl IntTensorOps> for NdArrayBackend< indexes: NdArrayTensor, value: NdArrayTensor, ) -> NdArrayTensor { - NdArrayMathOps::index_select_dim_assign(tensor, dim, indexes, value) + NdArrayMathOps::index_select_assign(tensor, dim, indexes, value) } fn int_argmax( tensor: NdArrayTensor, diff --git a/burn-ndarray/src/ops/tensor.rs b/burn-ndarray/src/ops/tensor.rs index f8893d8c1..6390e7689 100644 --- a/burn-ndarray/src/ops/tensor.rs +++ b/burn-ndarray/src/ops/tensor.rs @@ -150,36 +150,38 @@ impl TensorOps> for NdArrayBackend NdArrayOps::reshape(tensor, shape) } - fn index_select( + fn gather( + dim: usize, tensor: NdArrayTensor, indexes: NdArrayTensor, ) -> NdArrayTensor { - NdArrayMathOps::index_select(tensor, indexes) + NdArrayMathOps::gather(dim, tensor, indexes) } - fn index_select_assign( + fn scatter( + dim: usize, tensor: NdArrayTensor, indexes: NdArrayTensor, value: NdArrayTensor, ) -> NdArrayTensor { - NdArrayMathOps::index_select_assign(tensor, indexes, value) + NdArrayMathOps::scatter(dim, tensor, indexes, value) } - fn index_select_dim( + fn index_select( tensor: NdArrayTensor, dim: usize, indexes: NdArrayTensor, ) -> NdArrayTensor { - NdArrayMathOps::index_select_dim(tensor, dim, indexes) + NdArrayMathOps::index_select(tensor, dim, indexes) } - fn index_select_dim_assign( + fn index_select_assign( tensor: NdArrayTensor, dim: usize, indexes: NdArrayTensor, value: NdArrayTensor, ) -> NdArrayTensor { - NdArrayMathOps::index_select_dim_assign(tensor, dim, indexes, value) + NdArrayMathOps::index_select_assign(tensor, dim, indexes, value) } fn index( diff --git a/burn-tch/src/ops/base.rs b/burn-tch/src/ops/base.rs index 66e1d05c1..920e34587 100644 --- a/burn-tch/src/ops/base.rs +++ b/burn-tch/src/ops/base.rs @@ -56,17 +56,19 @@ impl TchOps { TchTensor::new(tensor_original) } - pub fn index_select( + pub fn gather( + dim: usize, tensor: TchTensor, indexes: TchTensor, ) -> TchTensor { let storage = tensor.storage.clone(); - let tensor = tensor.tensor.gather((D - 1) as i64, &indexes.tensor, false); + let tensor = tensor.tensor.gather(dim as i64, &indexes.tensor, false); TchTensor::from_existing(tensor, storage) } - pub fn index_select_assign( + pub fn scatter( + dim: usize, tensor: TchTensor, indexes: TchTensor, value: TchTensor, @@ -74,7 +76,7 @@ impl TchOps { let storage = tensor.storage.clone(); let tensor = tensor .tensor - .scatter_add((D - 1) as i64, &indexes.tensor, &value.tensor); + .scatter_add(dim as i64, &indexes.tensor, &value.tensor); TchTensor::from_existing(tensor, storage) } diff --git a/burn-tch/src/ops/int_tensor.rs b/burn-tch/src/ops/int_tensor.rs index 9360179e3..2a05cb76e 100644 --- a/burn-tch/src/ops/int_tensor.rs +++ b/burn-tch/src/ops/int_tensor.rs @@ -231,19 +231,21 @@ impl IntTensorOps> for TchBackend { fn int_mean_dim(tensor: TchTensor, dim: usize) -> TchTensor { TchOps::mean_dim(tensor, dim) } - fn int_index_select( + fn int_gather( + dim: usize, tensor: TchTensor, indexes: TchTensor, ) -> TchTensor { - TchOps::index_select(tensor, indexes) + TchOps::gather(dim, tensor, indexes) } - fn int_index_select_assign( + fn int_scatter( + dim: usize, tensor: TchTensor, indexes: TchTensor, value: TchTensor, ) -> TchTensor { - TchOps::index_select_assign(tensor, indexes, value) + TchOps::scatter(dim, tensor, indexes, value) } fn int_index_select_dim( diff --git a/burn-tch/src/ops/tensor.rs b/burn-tch/src/ops/tensor.rs index fb4775a8a..44b41613c 100644 --- a/burn-tch/src/ops/tensor.rs +++ b/burn-tch/src/ops/tensor.rs @@ -176,22 +176,24 @@ impl TensorOps> for TchBackend { TchOps::reshape(tensor, shape) } - fn index_select( + fn gather( + dim: usize, tensor: TchTensor, indexes: TchTensor, ) -> TchTensor { - TchOps::index_select(tensor, indexes) + TchOps::gather(dim, tensor, indexes) } - fn index_select_assign( + fn scatter( + dim: usize, tensor: TchTensor, indexes: TchTensor, value: TchTensor, ) -> TchTensor { - TchOps::index_select_assign(tensor, indexes, value) + TchOps::scatter(dim, tensor, indexes, value) } - fn index_select_dim( + fn index_select( tensor: TchTensor, dim: usize, indexes: TchTensor, @@ -199,7 +201,7 @@ impl TensorOps> for TchBackend { TchOps::index_select_dim(tensor, dim, indexes) } - fn index_select_dim_assign( + fn index_select_assign( tensor: TchTensor, dim: usize, indexes: TchTensor, diff --git a/burn-tensor/src/tensor/api/check.rs b/burn-tensor/src/tensor/api/check.rs index d8904be39..24ec25f85 100644 --- a/burn-tensor/src/tensor/api/check.rs +++ b/burn-tensor/src/tensor/api/check.rs @@ -7,7 +7,7 @@ use core::ops::Range; /// The struct should always be used with the [check](crate::check) macro. /// -/// This is a simple public crate data structure that efficiently checks tensor operations and +/// This is a simple pub(crate) data structure that efficiently checks tensor operations and /// formats clear error messages. It's crucial that the checks are really fast, but it doesn't matter /// when a failed check is discovered since the program will panic. /// @@ -32,14 +32,14 @@ use core::ops::Range; /// such as the `index_select` operation. The downside of that approach is that all backend /// implementation might re-implement the same checks, which may result in uncessary code /// duplication. Maybe a combination of both strategies could help to cover all usecases. -pub enum TensorCheck { +pub(crate) enum TensorCheck { Ok, Failed(FailedTensorCheck), } impl TensorCheck { /// Checks device and shape compatibility for element wise binary operations. - pub fn binary_ops_ew>( + pub(crate) fn binary_ops_ew>( ops: &str, lhs: &Tensor, rhs: &Tensor, @@ -49,7 +49,7 @@ impl TensorCheck { .binary_ops_ew_shape(ops, &lhs.shape(), &rhs.shape()) } - pub fn into_scalar(shape: &Shape) -> Self { + pub(crate) fn into_scalar(shape: &Shape) -> Self { let mut check = Self::Ok; if shape.num_elements() != 1 { @@ -66,7 +66,7 @@ impl TensorCheck { check } - pub fn dim_ops(ops: &str, dim: usize) -> Self { + pub(crate) fn dim_ops(ops: &str, dim: usize) -> Self { let mut check = Self::Ok; if dim >= D { @@ -80,7 +80,7 @@ impl TensorCheck { check } - pub fn reshape( + pub(crate) fn reshape( original: &Shape, target: &Shape, ) -> Self { @@ -99,7 +99,10 @@ impl TensorCheck { check } - pub fn flatten(start_dim: usize, end_dim: usize) -> Self { + pub(crate) fn flatten( + start_dim: usize, + end_dim: usize, + ) -> Self { let mut check = Self::Ok; if start_dim > end_dim { @@ -130,7 +133,7 @@ impl TensorCheck { check } - pub fn unsqueeze() -> Self { + pub(crate) fn unsqueeze() -> Self { let mut check = Self::Ok; if D2 < D1 { check = check.register( @@ -144,7 +147,7 @@ impl TensorCheck { check } - pub fn swap_dims(dim1: usize, dim2: usize) -> Self { + pub(crate) fn swap_dims(dim1: usize, dim2: usize) -> Self { let mut check = Self::Ok; if dim1 > D || dim2 > D { @@ -160,7 +163,10 @@ impl TensorCheck { check } - pub fn matmul(lhs: &Tensor, rhs: &Tensor) -> Self { + pub(crate) fn matmul( + lhs: &Tensor, + rhs: &Tensor, + ) -> Self { let mut check = Self::Ok; check = check.binary_ops_device("Matmul", &lhs.device(), &rhs.device()); @@ -191,7 +197,7 @@ impl TensorCheck { check } - pub fn cat>( + pub(crate) fn cat>( tensors: &[Tensor], dim: usize, ) -> Self { @@ -241,7 +247,7 @@ impl TensorCheck { check } - pub fn index( + pub(crate) fn index( shape: &Shape, indexes: &[Range; D2], ) -> Self { @@ -298,7 +304,7 @@ impl TensorCheck { check } - pub fn index_assign( + pub(crate) fn index_assign( shape: &Shape, shape_value: &Shape, indexes: &[Range; D2], @@ -374,8 +380,104 @@ impl TensorCheck { check } + pub(crate) fn gather( + dim: usize, + shape: &Shape, + shape_indexes: &Shape, + ) -> Self { + Self::check_gather_scatter_indexes(Self::Ok, "Gather", dim, shape, shape_indexes) + } + + pub(crate) fn scatter( + dim: usize, + shape: &Shape, + shape_indexes: &Shape, + shape_value: &Shape, + ) -> Self { + let ops = "Scatter"; + let mut check = + Self::check_gather_scatter_indexes(Self::Ok, ops, dim, shape, shape_indexes); + + if shape_indexes != shape_value { + check = check.register( + ops, + TensorError::new( + "Indexes tensor shape should be the same as the value tensor shape." + .to_string(), + ) + .details(format!( + "The shape differs: {:?} != {:?}", + shape_indexes.dims, shape_value.dims + )), + ); + } + + check + } + + pub(crate) fn index_select(dim: usize) -> Self { + Self::check_index_select_basic::(Self::Ok, "index_select", dim) + } + + pub(crate) fn index_select_assign(dim: usize) -> Self { + Self::check_index_select_basic::(Self::Ok, "index_select_assign", dim) + } + + fn check_index_select_basic(mut check: Self, ops: &str, dim: usize) -> Self { + if dim > D { + check = check.register( + ops, + TensorError::new(format!( + "Can't index a tensor with ({D}) dimensions on axis ({dim})" + )), + ); + } + + check + } + fn check_gather_scatter_indexes( + mut check: Self, + ops: &str, + dim: usize, + shape: &Shape, + shape_indexes: &Shape, + ) -> Self { + if dim > D { + check = check.register( + ops, + TensorError::new(format!( + "Can't index a tensor with ({D}) dimensions on axis ({dim})" + )), + ); + } + + for i in 0..D { + if i == dim { + continue; + } + + let tensor_dim_i = shape.dims[i]; + let indexes_dim_i = shape_indexes.dims[i]; + + if tensor_dim_i != indexes_dim_i { + check = check.register( + ops, + TensorError::new( + "The tensor shape should be the same as the index tensor shape." + .to_string(), + ) + .details(format!( + "The shape differs at dimension {i}: {tensor_dim_i} != {indexes_dim_i}" + )), + ); + } + } + + check + } + /// Checks aggregate dimension such as mean and sum. - pub fn aggregate_dim(ops: &str, dim: usize) -> Self { + pub(crate) fn aggregate_dim(ops: &str, dim: usize) -> Self { let mut check = Self::Ok; if dim > D { @@ -409,7 +511,7 @@ impl TensorCheck { } /// Checks if shapes are compatible for element wise operations supporting broadcasting. - pub fn binary_ops_ew_shape( + pub(crate) fn binary_ops_ew_shape( self, ops: &str, lhs: &Shape, @@ -464,14 +566,14 @@ impl TensorCheck { } } -pub struct FailedTensorCheck { +pub(crate) struct FailedTensorCheck { ops: String, errors: Vec, } impl FailedTensorCheck { /// Format all the checks into a single message ready to be printed by a [panic](core::panic). - pub fn format(self) -> String { + pub(crate) fn format(self) -> String { self.errors.into_iter().enumerate().fold( format!( "=== Tensor Operation Error ===\n Operation: '{}'\n Reason:", @@ -488,14 +590,14 @@ struct TensorError { } impl TensorError { - pub fn new>(description: S) -> Self { + pub(crate) fn new>(description: S) -> Self { TensorError { description: description.into(), details: None, } } - pub fn details>(mut self, details: S) -> Self { + pub(crate) fn details>(mut self, details: S) -> Self { self.details = Some(details.into()); self } diff --git a/burn-tensor/src/tensor/api/int.rs b/burn-tensor/src/tensor/api/int.rs index 141e922cd..b973d9310 100644 --- a/burn-tensor/src/tensor/api/int.rs +++ b/burn-tensor/src/tensor/api/int.rs @@ -1,4 +1,4 @@ -use crate::{backend::Backend, Int, Tensor}; +use crate::{backend::Backend, Data, Int, Tensor}; use core::ops::Range; impl Tensor @@ -14,3 +14,25 @@ where Tensor::new(B::arange(range, device)) } } + +impl Tensor +where + B: Backend, +{ + /// Create a tensor from integers (i32). + /// + /// # Example + /// + /// ```rust + /// use burn_tensor::backend::Backend; + /// use burn_tensor::{Tensor, Int}; + /// + /// fn example() { + /// let _x: Tensor = Tensor::from_ints([1, 2]); + /// let _y: Tensor = Tensor::from_ints([[1, 2], [3, 4]]); + /// } + /// ``` + pub fn from_ints>>(ints: A) -> Self { + Self::from_data(ints.into().convert()) + } +} diff --git a/burn-tensor/src/tensor/api/numeric.rs b/burn-tensor/src/tensor/api/numeric.rs index 924223ba2..af1ff3e70 100644 --- a/burn-tensor/src/tensor/api/numeric.rs +++ b/burn-tensor/src/tensor/api/numeric.rs @@ -203,45 +203,83 @@ where Self::new(K::mask_fill(self.primitive, mask, value.elem())) } - /// Select the tensor elements corresponding to the given indexes. + /// Gather tensor elements corresponding to the given indexes from the specified dim. + /// + /// Example using a 3D tensor: + /// + /// `output[i, j, k] = input[indexes[i, j, k], j, k]; // dim = 0` + /// `output[i, j, k] = input[i, indexes[i, j, k], k]; // dim = 1` + /// `output[i, j, k] = input[i, j, indexes[i, j, k]]; // dim = 2` /// /// # Notes /// - /// The index tensor shoud have the same shape as the original tensor except for the last - /// dimension. - pub fn index_select(self, indexes: Tensor) -> Self { - Self::new(K::index_select(self.primitive, indexes)) + /// The index tensor shoud have the same shape as the original tensor except for the dim + /// specified. + pub fn gather(self, dim: usize, indexes: Tensor) -> Self { + check!(TensorCheck::gather::( + dim, + &self.shape(), + &indexes.shape() + )); + + Self::new(K::gather(dim, self.primitive, indexes)) } - /// Assign the selected elements corresponding to the given indexes from the value tensor - /// to the original tensor using sum reduction. + /// Assign the gathered elements corresponding to the given indexes along the speficied dimension + /// from the value tensor to the original tensor using sum reduction. + /// + /// Example using a 3D tensor: + /// + /// `input[indexes[i, j, k], j, k] += values[i, j, k]; // dim = 0` + /// `input[i, indexes[i, j, k], k] += values[i, j, k]; // dim = 1` + /// `input[i, j, indexes[i, j, k]] += values[i, j, k]; // dim = 2` /// /// # Notes /// - /// The index tensor shoud have the same shape as the original tensor except for the last + /// The index tensor shoud have the same shape as the original tensor except for the speficied /// dimension. The value and index tensors should have the same shape. - pub fn index_select_assign(self, indexes: Tensor, values: Self) -> Self { - Self::new(K::index_select_assign( - self.primitive, - indexes, - values.primitive, - )) + /// + /// Other references to the input tensor will not be modified by this operation. + pub fn scatter(self, dim: usize, indexes: Tensor, values: Self) -> Self { + check!(TensorCheck::scatter::( + dim, + &self.shape(), + &indexes.shape(), + &values.shape() + )); + + Self::new(K::scatter(dim, self.primitive, indexes, values.primitive)) } /// Select the tensor elements along the given dimension corresponding to the given indexes. - pub fn index_select_dim(self, dim: usize, indexes: Tensor) -> Self { - Self::new(K::index_select_dim(self.primitive, dim, indexes)) + /// + /// Example using a 3D tensor: + /// + /// `output[i, j, k] = input[indexes[i], j, k]; // dim = 0` + /// `output[i, j, k] = input[i, indexes[j], k]; // dim = 1` + /// `output[i, j, k] = input[i, j, indexes[k]]; // dim = 2` + pub fn index_select(self, dim: usize, indexes: Tensor) -> Self { + check!(TensorCheck::index_select::(dim)); + Self::new(K::index_select(self.primitive, dim, indexes)) } /// Assign the selected elements along the given dimension corresponding to the given indexes /// from the value tensor to the original tensor using sum reduction. - pub fn index_select_dim_assign( + /// + /// Example using a 3D tensor: + /// + /// `input[indexes[i], j, k] += values[i, j, k]; // dim = 0` + /// `input[i, indexes[j], k] += values[i, j, k]; // dim = 1` + /// `input[i, j, indexes[k]] += values[i, j, k]; // dim = 2` + pub fn index_select_assign( self, dim: usize, indexes: Tensor, values: Tensor, ) -> Self { - Self::new(K::index_select_dim_assign( + check!(TensorCheck::index_select_assign::(dim)); + + Self::new(K::index_select_assign( self.primitive, dim, indexes, @@ -412,21 +450,23 @@ where mask: Tensor, value: Self::Elem, ) -> Self::Primitive; - fn index_select( + fn gather( + dim: usize, tensor: Self::Primitive, indexes: Tensor, ) -> Self::Primitive; - fn index_select_assign( + fn scatter( + dim: usize, tensor: Self::Primitive, indexes: Tensor, values: Self::Primitive, ) -> Self::Primitive; - fn index_select_dim( + fn index_select( tensor: Self::Primitive, dim: usize, indexes: Tensor, ) -> Self::Primitive; - fn index_select_dim_assign( + fn index_select_assign( tensor: Self::Primitive, dim: usize, indexes: Tensor, @@ -588,7 +628,7 @@ impl Numeric for Int { B::int_mask_fill(tensor, mask.primitive, value) } - fn index_select_dim( + fn index_select( tensor: Self::Primitive, dim: usize, indexes: Tensor, @@ -596,7 +636,7 @@ impl Numeric for Int { B::int_index_select_dim(tensor, dim, indexes.primitive) } - fn index_select_dim_assign( + fn index_select_assign( tensor: Self::Primitive, dim: usize, indexes: Tensor, @@ -604,19 +644,21 @@ impl Numeric for Int { ) -> Self::Primitive { B::int_index_select_dim_assign(tensor, dim, indexes.primitive, values) } - fn index_select( + fn gather( + dim: usize, tensor: Self::Primitive, indexes: Tensor, ) -> Self::Primitive { - B::int_index_select(tensor, indexes.primitive) + B::int_gather(dim, tensor, indexes.primitive) } - fn index_select_assign( + fn scatter( + dim: usize, tensor: Self::Primitive, indexes: Tensor, values: Self::Primitive, ) -> Self::Primitive { - B::int_index_select_assign(tensor, indexes.primitive, values) + B::int_scatter(dim, tensor, indexes.primitive, values) } fn argmax( @@ -804,36 +846,38 @@ impl Numeric for Float { B::mask_fill(tensor, mask.primitive, value) } - fn index_select_dim( + fn index_select( tensor: Self::Primitive, dim: usize, indexes: Tensor, ) -> Self::Primitive { - B::index_select_dim(tensor, dim, indexes.primitive) + B::index_select(tensor, dim, indexes.primitive) } - fn index_select_dim_assign( + fn index_select_assign( tensor: Self::Primitive, dim: usize, indexes: Tensor, values: Self::Primitive, ) -> Self::Primitive { - B::index_select_dim_assign(tensor, dim, indexes.primitive, values) + B::index_select_assign(tensor, dim, indexes.primitive, values) } - fn index_select( + fn gather( + dim: usize, tensor: Self::Primitive, indexes: Tensor, ) -> Self::Primitive { - B::index_select(tensor, indexes.primitive) + B::gather(dim, tensor, indexes.primitive) } - fn index_select_assign( + fn scatter( + dim: usize, tensor: Self::Primitive, indexes: Tensor, values: Self::Primitive, ) -> Self::Primitive { - B::index_select_assign(tensor, indexes.primitive, values) + B::scatter(dim, tensor, indexes.primitive, values) } fn argmax( diff --git a/burn-tensor/src/tensor/ops/int_tensor.rs b/burn-tensor/src/tensor/ops/int_tensor.rs index 33fa9e15d..9f274daad 100644 --- a/burn-tensor/src/tensor/ops/int_tensor.rs +++ b/burn-tensor/src/tensor/ops/int_tensor.rs @@ -44,11 +44,13 @@ pub trait IntTensorOps { mask: B::BoolTensorPrimitive, value: B::IntElem, ) -> B::IntTensorPrimitive; - fn int_index_select( + fn int_gather( + dim: usize, tensor: B::IntTensorPrimitive, indexes: B::IntTensorPrimitive, ) -> B::IntTensorPrimitive; - fn int_index_select_assign( + fn int_scatter( + dim: usize, tensor: B::IntTensorPrimitive, indexes: B::IntTensorPrimitive, value: B::IntTensorPrimitive, @@ -204,14 +206,14 @@ pub trait IntTensorOps { ) -> B::IntTensorPrimitive { let index = B::int_argmax(tensor.clone(), dim); - B::int_index_select(tensor, index) + B::int_gather(D - 1, tensor, index) } fn int_max_dim_with_indexes( tensor: B::IntTensorPrimitive, dim: usize, ) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) { let index = B::int_argmax(tensor.clone(), dim); - let values = B::int_index_select(tensor, index.clone()); + let values = B::int_gather(D - 1, tensor, index.clone()); (values, index) } @@ -227,14 +229,14 @@ pub trait IntTensorOps { ) -> B::IntTensorPrimitive { let index = B::int_argmin(tensor.clone(), dim); - B::int_index_select(tensor, index) + B::int_gather(D - 1, tensor, index) } fn int_min_dim_with_indexes( tensor: B::IntTensorPrimitive, dim: usize, ) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) { let index = B::int_argmin(tensor.clone(), dim); - let values = B::int_index_select(tensor, index.clone()); + let values = B::int_gather(D - 1, tensor, index.clone()); (values, index) } diff --git a/burn-tensor/src/tensor/ops/tensor.rs b/burn-tensor/src/tensor/ops/tensor.rs index 70e11bbf5..d58f034d0 100644 --- a/burn-tensor/src/tensor/ops/tensor.rs +++ b/burn-tensor/src/tensor/ops/tensor.rs @@ -115,21 +115,23 @@ pub trait TensorOps { tensor: B::TensorPrimitive, shape: Shape, ) -> B::TensorPrimitive; - fn index_select( + fn gather( + dim: usize, tensor: B::TensorPrimitive, indexes: B::IntTensorPrimitive, ) -> B::TensorPrimitive; - fn index_select_assign( + fn scatter( + dim: usize, tensor: B::TensorPrimitive, indexes: B::IntTensorPrimitive, value: B::TensorPrimitive, ) -> B::TensorPrimitive; - fn index_select_dim( + fn index_select( tensor: B::TensorPrimitive, dim: usize, indexes: B::IntTensorPrimitive<1>, ) -> B::TensorPrimitive; - fn index_select_dim_assign( + fn index_select_assign( tensor: B::TensorPrimitive, dim: usize, indexes: B::IntTensorPrimitive<1>, @@ -251,14 +253,14 @@ pub trait TensorOps { fn max_dim(tensor: B::TensorPrimitive, dim: usize) -> B::TensorPrimitive { let index = B::argmax(tensor.clone(), dim); - B::index_select(tensor, index) + B::gather(D - 1, tensor, index) } fn max_dim_with_indexes( tensor: B::TensorPrimitive, dim: usize, ) -> (B::TensorPrimitive, B::IntTensorPrimitive) { let index = B::argmax(tensor.clone(), dim); - let values = B::index_select(tensor, index.clone()); + let values = B::gather(D - 1, tensor, index.clone()); (values, index) } @@ -271,14 +273,14 @@ pub trait TensorOps { fn min_dim(tensor: B::TensorPrimitive, dim: usize) -> B::TensorPrimitive { let index = B::argmin(tensor.clone(), dim); - B::index_select(tensor, index) + B::gather(D - 1, tensor, index) } fn min_dim_with_indexes( tensor: B::TensorPrimitive, dim: usize, ) -> (B::TensorPrimitive, B::IntTensorPrimitive) { let index = B::argmin(tensor.clone(), dim); - let values = B::index_select(tensor, index.clone()); + let values = B::gather(D - 1, tensor, index.clone()); (values, index) } diff --git a/burn-tensor/src/tests/mod.rs b/burn-tensor/src/tests/mod.rs index 2206b2450..4e97ac6c8 100644 --- a/burn-tensor/src/tests/mod.rs +++ b/burn-tensor/src/tests/mod.rs @@ -34,8 +34,8 @@ macro_rules! testgen_all { burn_tensor::testgen_log!(); burn_tensor::testgen_log1p!(); burn_tensor::testgen_index!(); + burn_tensor::testgen_gather_scatter!(); burn_tensor::testgen_index_select!(); - burn_tensor::testgen_index_select_dim!(); burn_tensor::testgen_map_comparison!(); burn_tensor::testgen_mask!(); burn_tensor::testgen_matmul!(); diff --git a/burn-tensor/src/tests/ops/gather_scatter.rs b/burn-tensor/src/tests/ops/gather_scatter.rs new file mode 100644 index 000000000..f2f514efc --- /dev/null +++ b/burn-tensor/src/tests/ops/gather_scatter.rs @@ -0,0 +1,132 @@ +#[burn_tensor_testgen::testgen(gather_scatter)] +mod tests { + use super::*; + use burn_tensor::{Data, Tensor}; + + #[test] + fn should_gather_1d_dim0() { + let tensor = TestTensor::from_floats([0.0, 1.0, 2.0]); + let indexes = TestTensorInt::from_ints([1, 1, 0, 1, 2]); + + let output = tensor.gather(0, indexes); + + assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0])); + } + + #[test] + fn should_gather_2d_dim0() { + let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let indexes = TestTensorInt::from_ints([[0, 1, 0], [1, 0, 1]]); + + let output = tensor.gather(0, indexes); + + assert_eq!( + output.into_data(), + Data::from([[0.0, 4.0, 2.0], [3.0, 1.0, 5.0]]) + ); + } + + #[test] + fn should_gather_2d_dim1() { + let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let indexes = TestTensorInt::from_ints([[2, 1, 0, 0], [2, 0, 1, 2]]); + + let output = tensor.gather(1, indexes); + + assert_eq!( + output.into_data(), + Data::from([[2.0, 1.0, 0.0, 0.0], [5.0, 3.0, 4.0, 5.0]]) + ); + } + + #[test] + fn should_gather_3d_dim1() { + let tensor = TestTensor::from_floats([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + ]); + let indexes = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]]); + + let output = tensor.gather(1, indexes); + + assert_eq!( + output.into_data(), + Data::from([ + [[3.0, 1.0, 2.0], [0.0, 4.0, 2.0]], + [[6.0, 7.0, 11.0], [6.0, 10.0, 11.0]] + ]) + ); + } + + #[test] + fn should_gather_2d_only_1dim() { + let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let indexes = TestTensorInt::from_ints([[1, 2]]).reshape([2, 1]); + + let output = tensor.gather(1, indexes); + + assert_eq!(output.into_data(), Data::from([[1.0], [5.0]])); + } + + #[test] + fn should_scatter_1d() { + let tensor = TestTensor::from_floats([0.0, 0.0, 0.0]); + let values = TestTensor::from_floats([5.0, 4.0, 3.0]); + let indexes = TestTensorInt::from_ints([1, 0, 2]); + + let output = tensor.scatter(0, indexes, values); + + assert_eq!(output.into_data(), Data::from([4.0, 5.0, 3.0])); + } + + #[test] + fn should_scatter_2d_dim0() { + let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]); + let values = TestTensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); + let indexes = TestTensorInt::from_ints([[1, 0, 1], [1, 1, 0]]); + + let output = tensor.scatter(0, indexes, values); + + assert_eq!( + output.into_data(), + Data::from([[0.0, 2.0, 6.0], [5.0, 5.0, 3.0]]) + ); + } + + #[test] + fn should_scatter_2d_dim1() { + let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]); + let values = TestTensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]); + let indexes = TestTensorInt::from_ints([[1, 0, 2], [1, 2, 0]]); + + let output = tensor.scatter(1, indexes, values); + + assert_eq!( + output.into_data(), + Data::from([[2.0, 1.0, 3.0], [6.0, 4.0, 5.0]]) + ); + } + + #[test] + fn should_scatter_3d_dim1() { + let tensor = TestTensor::from_floats([ + [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], + [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]], + ]); + let values = TestTensor::from_floats([ + [[12.0, 13.0, 14.0], [15.0, 16.0, 17.0]], + [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0]], + ]); + let indexes = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]]); + + let output = tensor.scatter(1, indexes, values); + + assert_eq!( + output.into_data(), + Data::from([ + [[15.0, 14.0, 33.0], [15.0, 20.0, 5.0]], + [[45.0, 26.0, 8.0], [9.0, 32.0, 54.0]] + ]) + ); + } +} diff --git a/burn-tensor/src/tests/ops/index_select.rs b/burn-tensor/src/tests/ops/index_select.rs index 285a6b579..9e0a84dec 100644 --- a/burn-tensor/src/tests/ops/index_select.rs +++ b/burn-tensor/src/tests/ops/index_select.rs @@ -8,56 +8,77 @@ mod tests { let tensor = TestTensor::from_data(Data::from([0.0, 1.0, 2.0])); let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2])); - let output = tensor.index_select(indexes); + let output = tensor.index_select(0, indexes); assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0])); } #[test] - fn should_select_2d() { + fn should_select_2d_dim0_same_num_dim() { let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])); - let indexes = TestTensorInt::from_data(Data::from([[2, 1, 0, 0], [2, 0, 1, 2]])); + let indexes = TestTensorInt::from_data(Data::from([1, 0])); - let output = tensor.index_select(indexes); + let output = tensor.index_select(0, indexes); assert_eq!( output.into_data(), - Data::from([[2.0, 1.0, 0.0, 0.0], [5.0, 3.0, 4.0, 5.0]]) + Data::from([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]) ); } #[test] - fn should_select_2d_only_1dim() { + fn should_select_2d_dim0_more_num_dim() { let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])); - let indexes = TestTensorInt::from_data(Data::from([[1, 2]])).reshape([2, 1]); + let indexes = TestTensorInt::from_data(Data::from([1, 0, 1, 1])); - let output = tensor.index_select(indexes); + let output = tensor.index_select(0, indexes); - assert_eq!(output.into_data(), Data::from([[1.0], [5.0]])); + assert_eq!( + output.into_data(), + Data::from([ + [3.0, 4.0, 5.0], + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [3.0, 4.0, 5.0] + ]) + ); + } + + #[test] + fn should_select_2d_dim1() { + let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])); + let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2])); + + let output = tensor.index_select(1, indexes); + + assert_eq!( + output.into_data(), + Data::from([[1.0, 1.0, 0.0, 1.0, 2.0], [4.0, 4.0, 3.0, 4.0, 5.0]]) + ); } #[test] fn should_select_assign_1d() { - let tensor = TestTensor::from_data(Data::from([0.0, 0.0, 0.0])); - let values = TestTensor::from_data(Data::from([5.0, 4.0, 3.0])); - let indexes = TestTensorInt::from_data(Data::from([1, 0, 2])); + let tensor = TestTensor::from_data(Data::from([0.0, 1.0, 2.0])); + let values = TestTensor::from_data(Data::from([5.0, 4.0, 3.0, 2.0, 1.0])); + let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2])); - let output = tensor.index_select_assign(indexes, values); + let output = tensor.index_select_assign(0, indexes, values); - assert_eq!(output.into_data(), Data::from([4.0, 5.0, 3.0])); + assert_eq!(output.into_data(), Data::from([3.0, 12.0, 3.0])); } #[test] - fn should_select_assign_2d() { - let tensor = TestTensor::from_data(Data::from([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])); + fn should_select_assign_2d_dim0() { + let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])); let values = TestTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])); - let indexes = TestTensorInt::from_data(Data::from([[1, 0, 2], [1, 2, 0]])); + let indexes = TestTensorInt::from_data(Data::from([1, 0])); - let output = tensor.index_select_assign(indexes, values); + let output = tensor.index_select_assign(0, indexes, values); assert_eq!( output.into_data(), - Data::from([[2.0, 1.0, 3.0], [6.0, 4.0, 5.0]]) + Data::from([[4.0, 6.0, 8.0], [4.0, 6.0, 8.0]]) ); } } diff --git a/burn-tensor/src/tests/ops/index_select_dim.rs b/burn-tensor/src/tests/ops/index_select_dim.rs deleted file mode 100644 index 6f284b970..000000000 --- a/burn-tensor/src/tests/ops/index_select_dim.rs +++ /dev/null @@ -1,84 +0,0 @@ -#[burn_tensor_testgen::testgen(index_select_dim)] -mod tests { - use super::*; - use burn_tensor::{Data, Tensor}; - - #[test] - fn should_select_1d() { - let tensor = TestTensor::from_data(Data::from([0.0, 1.0, 2.0])); - let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2])); - - let output = tensor.index_select_dim(0, indexes); - - assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0])); - } - - #[test] - fn should_select_2d_dim0_same_num_dim() { - let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])); - let indexes = TestTensorInt::from_data(Data::from([1, 0])); - - let output = tensor.index_select_dim(0, indexes); - - assert_eq!( - output.into_data(), - Data::from([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]]) - ); - } - - #[test] - fn should_select_2d_dim0_more_num_dim() { - let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])); - let indexes = TestTensorInt::from_data(Data::from([1, 0, 1, 1])); - - let output = tensor.index_select_dim(0, indexes); - - assert_eq!( - output.into_data(), - Data::from([ - [3.0, 4.0, 5.0], - [0.0, 1.0, 2.0], - [3.0, 4.0, 5.0], - [3.0, 4.0, 5.0] - ]) - ); - } - - #[test] - fn should_select_2d_dim1() { - let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])); - let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2])); - - let output = tensor.index_select_dim(1, indexes); - - assert_eq!( - output.into_data(), - Data::from([[1.0, 1.0, 0.0, 1.0, 2.0], [4.0, 4.0, 3.0, 4.0, 5.0]]) - ); - } - - #[test] - fn should_select_assign_1d() { - let tensor = TestTensor::from_data(Data::from([0.0, 1.0, 2.0])); - let values = TestTensor::from_data(Data::from([5.0, 4.0, 3.0, 2.0, 1.0])); - let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2])); - - let output = tensor.index_select_dim_assign(0, indexes, values); - - assert_eq!(output.into_data(), Data::from([3.0, 12.0, 3.0])); - } - - #[test] - fn should_select_assign_2d_dim0() { - let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])); - let values = TestTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])); - let indexes = TestTensorInt::from_data(Data::from([1, 0])); - - let output = tensor.index_select_dim_assign(0, indexes, values); - - assert_eq!( - output.into_data(), - Data::from([[4.0, 6.0, 8.0], [4.0, 6.0, 8.0]]) - ); - } -} diff --git a/burn-tensor/src/tests/ops/mod.rs b/burn-tensor/src/tests/ops/mod.rs index cbdf0d36c..72df84cd6 100644 --- a/burn-tensor/src/tests/ops/mod.rs +++ b/burn-tensor/src/tests/ops/mod.rs @@ -6,9 +6,9 @@ mod div; mod erf; mod exp; mod flatten; +mod gather_scatter; mod index; mod index_select; -mod index_select_dim; mod log; mod log1p; mod map_comparison;