Feat/index_select (#227)

This commit is contained in:
Nathaniel Simard 2023-03-12 17:44:22 -04:00 committed by GitHub
parent 9655b74b22
commit d09ab44979
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 514 additions and 25 deletions

View File

@ -195,6 +195,21 @@ impl<B: Backend> IntTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
B::int_lower_equal_elem(lhs, rhs) B::int_lower_equal_elem(lhs, rhs)
} }
fn int_index_select<const D: usize>(
tensor: IntTensor<B, D>,
indexes: IntTensor<B, D>,
) -> IntTensor<B, D> {
B::int_index_select(tensor, indexes)
}
fn int_index_select_assign<const D: usize>(
tensor: IntTensor<B, D>,
indexes: IntTensor<B, D>,
value: IntTensor<B, D>,
) -> IntTensor<B, D> {
B::int_index_select_assign(tensor, indexes, value)
}
fn int_index_select_dim<const D: usize>( fn int_index_select_dim<const D: usize>(
tensor: IntTensor<B, D>, tensor: IntTensor<B, D>,
dim: usize, dim: usize,

View File

@ -420,6 +420,94 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
} }
} }
fn index_select<const D: usize>(
tensor: ADTensor<B, D>,
indexes: IntTensor<B, D>,
) -> ADTensor<B, D> {
#[derive(Debug)]
struct IndexSelect;
impl<B: Backend, const D: usize> Backward<B, D, 1> for IndexSelect {
type State = (IntTensor<B, D>, Shape<D>, B::Device);
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
let (indexes, shape, device) = ops.state;
unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
let zeros = B::zeros(shape, &device);
B::index_select_assign(zeros, indexes, grad)
});
}
}
match IndexSelect
.prepare([tensor.node], [tensor.graph])
.statefull()
{
OpsKind::Tracked(prep) => prep.finish(
(
indexes.clone(),
B::shape(&tensor.primitive),
B::device(&tensor.primitive),
),
B::index_select(tensor.primitive, indexes),
),
OpsKind::UnTracked(prep) => prep.finish(B::index_select(tensor.primitive, indexes)),
}
}
fn index_select_assign<const D: usize>(
tensor: ADTensor<B, D>,
indexes: IntTensor<B, D>,
value: ADTensor<B, D>,
) -> ADTensor<B, D> {
#[derive(Debug)]
struct IndexSelectAssign;
impl<B: Backend, const D: usize> Backward<B, D, 2> for IndexSelectAssign {
type State = (IntTensor<B, D>, Shape<D>, Shape<D>, B::Device);
fn backward(self, ops: Ops<Self::State, 2>, grads: &mut Gradients) {
let (indexes, shape_lhs, shape_rhs, device) = ops.state;
let [indexes_4lhs, indexes_4rhs] = duplicate(&ops.parents, Some(indexes));
binary::<B, D, D, D, _, _>(
ops.parents,
ops.node,
grads,
|grad| {
let zeros = B::zeros(shape_lhs, &device);
B::index_select_assign(grad, indexes_4lhs.unwrap(), zeros)
},
|grad| {
let zeros = B::zeros(shape_rhs, &device);
B::index_select_assign(zeros, indexes_4rhs.unwrap(), grad)
},
);
}
}
match IndexSelectAssign
.prepare([tensor.node, value.node], [tensor.graph, value.graph])
.statefull()
{
OpsKind::Tracked(prep) => prep.finish(
(
indexes.clone(),
B::shape(&tensor.primitive),
B::shape(&value.primitive),
B::device(&value.primitive),
),
B::index_select_assign(tensor.primitive, indexes, value.primitive),
),
OpsKind::UnTracked(prep) => prep.finish(B::index_select_assign(
tensor.primitive,
indexes,
value.primitive,
)),
}
}
fn index_select_dim<const D: usize>( fn index_select_dim<const D: usize>(
tensor: ADTensor<B, D>, tensor: ADTensor<B, D>,
dim: usize, dim: usize,

View File

@ -0,0 +1,54 @@
#[burn_tensor_testgen::testgen(ad_index_select)]
mod tests {
use super::*;
use burn_tensor::Data;
#[test]
fn test_index_select_grad() {
let tensor_1 =
TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
let indexes = TestADTensor::from_data(Data::from([[2, 1, 0, 1, 2], [1, 0, 2, 1, 0]]));
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
let tensor_3 = tensor_1.clone().index_select(indexes);
let tensor_4 = tensor_2.matmul(tensor_3);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
assert_eq!(
grad_1.into_data(),
Data::from([[94., 150., 187.], [242., 305., 304.]])
);
}
#[test]
fn test_index_select_assign_grad() {
let tensor_1 =
TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
let values =
TestADTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])).require_grad();
let indexes = TestADTensor::from_data(Data::from([[2, 1, 0], [2, 0, 1]]));
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
let tensor_3 = tensor_1
.clone()
.index_select_assign(indexes, values.clone());
let tensor_4 = tensor_2.matmul(tensor_3);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = values.grad(&grads).unwrap();
assert_eq!(
grad_1.into_data(),
Data::from([[127., 181., 235.], [226., 316., 406.]])
);
assert_eq!(
grad_2.into_data(),
Data::from([[19., 19., 19.], [64., 64., 64.]])
);
}
}

View File

@ -4,7 +4,27 @@ mod tests {
use burn_tensor::Data; use burn_tensor::Data;
#[test] #[test]
fn test_select_grad() { fn test_index_select_dim_grad() {
let tensor_1 =
TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
let indexes = TestADTensor::from_data(Data::from([1, 0]));
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
let tensor_3 = tensor_1.clone().index_select_dim(0, indexes);
let tensor_4 = tensor_2.matmul(tensor_3);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
assert_eq!(
grad_1.into_data(),
Data::from([[109., 148., 187.], [37., 58., 79.]])
);
}
#[test]
fn test_index_select_dim_assign_grad() {
let tensor_1 = let tensor_1 =
TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad(); TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
let values = let values =

View File

@ -11,6 +11,7 @@ mod div;
mod erf; mod erf;
mod exp; mod exp;
mod index; mod index;
mod index_select;
mod index_select_dim; mod index_select_dim;
mod log; mod log;
mod log1p; mod log1p;
@ -53,8 +54,9 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_div!(); burn_autodiff::testgen_ad_div!();
burn_autodiff::testgen_ad_erf!(); burn_autodiff::testgen_ad_erf!();
burn_autodiff::testgen_ad_exp!(); burn_autodiff::testgen_ad_exp!();
burn_autodiff::testgen_ad_index_select_dim!();
burn_autodiff::testgen_ad_index!(); burn_autodiff::testgen_ad_index!();
burn_autodiff::testgen_ad_index_select!();
burn_autodiff::testgen_ad_index_select_dim!();
burn_autodiff::testgen_ad_log!(); burn_autodiff::testgen_ad_log!();
burn_autodiff::testgen_ad_log1p!(); burn_autodiff::testgen_ad_log1p!();
burn_autodiff::testgen_ad_mask!(); burn_autodiff::testgen_ad_mask!();

View File

@ -1,6 +1,8 @@
use alloc::vec::Vec; use alloc::vec::Vec;
use burn_tensor::Data; use burn_tensor::Data;
use core::{marker::PhantomData, ops::Range}; use core::{marker::PhantomData, ops::Range};
use ndarray::s;
use ndarray::Array2;
use burn_tensor::Shape; use burn_tensor::Shape;
use ndarray::Axis; use ndarray::Axis;
@ -204,6 +206,85 @@ where
} }
} }
pub fn index_select<const D: usize>(
tensor: NdArrayTensor<E, D>,
indexes: NdArrayTensor<i64, D>,
) -> NdArrayTensor<E, D> {
let (shape_tensor, shape_indexes) = (tensor.shape(), indexes.shape());
let (size_tensor, size_index) = (shape_tensor.dims[D - 1], shape_indexes.dims[D - 1]);
let batch_size = Self::index_select_batch_size(&shape_tensor, &shape_indexes);
let indexes = NdArrayOps::reshape(indexes, Shape::new([batch_size, size_index])).array;
let tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array;
let mut output = Array2::zeros((batch_size, size_index));
for b in 0..batch_size {
let indexes = indexes.slice(s!(b, ..));
for (i, index) in indexes.iter().enumerate() {
output[[b, i]] = tensor[[b, *index as usize]];
}
}
NdArrayOps::reshape(
NdArrayTensor::<E, 2>::new(output.into_shared().into_dyn()),
shape_indexes,
)
}
pub fn index_select_assign<const D: usize>(
tensor: NdArrayTensor<E, D>,
indexes: NdArrayTensor<i64, D>,
value: NdArrayTensor<E, D>,
) -> NdArrayTensor<E, D> {
let (shape_tensor, shape_indexes, shape_value) =
(tensor.shape(), indexes.shape(), value.shape());
let (size_tensor, size_index, size_value) = (
shape_tensor.dims[D - 1],
shape_indexes.dims[D - 1],
shape_value.dims[D - 1],
);
let batch_size = Self::index_select_batch_size(&shape_tensor, &shape_indexes);
if shape_value != shape_indexes {
panic!("Invalid dimension: the shape of the index tensor should be the same as the value tensor: Index {:?} value {:?}", shape_indexes.dims, shape_value.dims);
}
let indexes = NdArrayOps::reshape(indexes, Shape::new([batch_size, size_index])).array;
let value = NdArrayOps::reshape(value, Shape::new([batch_size, size_value])).array;
let mut tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array;
for b in 0..batch_size {
let indexes = indexes.slice(s!(b, ..));
for (i, index) in indexes.iter().enumerate() {
let index = *index as usize;
tensor[[b, index]] = tensor[[b, index]] + value[[b, i]];
}
}
NdArrayOps::reshape(
NdArrayTensor::<E, 2>::new(tensor.into_shared().into_dyn()),
shape_tensor,
)
}
fn index_select_batch_size<const D: usize>(
shape_tensor: &Shape<D>,
shape_indexes: &Shape<D>,
) -> usize {
let mut batch_size = 1;
for i in 0..D - 1 {
if shape_tensor.dims[i] != shape_indexes.dims[i] {
panic!("Unsupported dimension, only the last dimension can differ: Tensor {:?} Index {:?}", shape_tensor.dims, shape_indexes.dims);
}
batch_size *= shape_indexes.dims[i];
}
batch_size
}
pub fn index_select_dim<const D: usize>( pub fn index_select_dim<const D: usize>(
tensor: NdArrayTensor<E, D>, tensor: NdArrayTensor<E, D>,
dim: usize, dim: usize,

View File

@ -264,6 +264,21 @@ impl<E: FloatNdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<
NdArrayMathOps::mean_dim(tensor, dim) NdArrayMathOps::mean_dim(tensor, dim)
} }
fn int_index_select<const D: usize>(
tensor: NdArrayTensor<i64, D>,
indexes: NdArrayTensor<i64, D>,
) -> NdArrayTensor<i64, D> {
NdArrayMathOps::index_select(tensor, indexes)
}
fn int_index_select_assign<const D: usize>(
tensor: NdArrayTensor<i64, D>,
indexes: NdArrayTensor<i64, D>,
value: NdArrayTensor<i64, D>,
) -> NdArrayTensor<i64, D> {
NdArrayMathOps::index_select_assign(tensor, indexes, value)
}
fn int_index_select_dim<const D: usize>( fn int_index_select_dim<const D: usize>(
tensor: NdArrayTensor<i64, D>, tensor: NdArrayTensor<i64, D>,
dim: usize, dim: usize,

View File

@ -151,6 +151,21 @@ impl<E: FloatNdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E>
NdArrayOps::reshape(tensor, shape) NdArrayOps::reshape(tensor, shape)
} }
fn index_select<const D: usize>(
tensor: NdArrayTensor<E, D>,
indexes: NdArrayTensor<i64, D>,
) -> NdArrayTensor<E, D> {
NdArrayMathOps::index_select(tensor, indexes)
}
fn index_select_assign<const D: usize>(
tensor: NdArrayTensor<E, D>,
indexes: NdArrayTensor<i64, D>,
value: NdArrayTensor<E, D>,
) -> NdArrayTensor<E, D> {
NdArrayMathOps::index_select_assign(tensor, indexes, value)
}
fn index_select_dim<const D: usize>( fn index_select_dim<const D: usize>(
tensor: NdArrayTensor<E, D>, tensor: NdArrayTensor<E, D>,
dim: usize, dim: usize,

View File

@ -58,7 +58,12 @@ macro_rules! reshape {
) => {{ ) => {{
let dim = $crate::to_typed_dims!($n, $shape.dims, justdim); let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
let safe_into_shape = let safe_into_shape =
$array.is_standard_layout() || $array.raw_view().reversed_axes().is_standard_layout(); $array.is_standard_layout() ||
(
$array.ndim() > 1 &&
$array.raw_view().reversed_axes().is_standard_layout()
);
let array: ndarray::ArcArray<$ty, Dim<[usize; $n]>> = match safe_into_shape { let array: ndarray::ArcArray<$ty, Dim<[usize; $n]>> = match safe_into_shape {
true => $array true => $array
.into_shape(dim) .into_shape(dim)

View File

@ -45,6 +45,25 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
TchTensor::new(tensor_original) TchTensor::new(tensor_original)
} }
pub fn index_select<const D: usize>(
tensor: TchTensor<E, D>,
indexes: TchTensor<i64, D>,
) -> TchTensor<E, D> {
let tensor = tensor.tensor.gather((D - 1) as i64, &indexes.tensor, false);
TchTensor::new(tensor)
}
pub fn index_select_assign<const D: usize>(
tensor: TchTensor<E, D>,
indexes: TchTensor<i64, D>,
value: TchTensor<E, D>,
) -> TchTensor<E, D> {
let tensor = tensor
.tensor
.scatter_add((D - 1) as i64, &indexes.tensor, &value.tensor);
TchTensor::new(tensor)
}
pub fn index_select_dim<const D: usize>( pub fn index_select_dim<const D: usize>(
tensor: TchTensor<E, D>, tensor: TchTensor<E, D>,
dim: usize, dim: usize,

View File

@ -243,6 +243,20 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
fn int_mean_dim<const D: usize>(tensor: TchTensor<i64, D>, dim: usize) -> TchTensor<i64, D> { fn int_mean_dim<const D: usize>(tensor: TchTensor<i64, D>, dim: usize) -> TchTensor<i64, D> {
TchOps::mean_dim(tensor, dim) TchOps::mean_dim(tensor, dim)
} }
fn int_index_select<const D: usize>(
tensor: TchTensor<i64, D>,
indexes: TchTensor<i64, D>,
) -> TchTensor<i64, D> {
TchOps::index_select(tensor, indexes)
}
fn int_index_select_assign<const D: usize>(
tensor: TchTensor<i64, D>,
indexes: TchTensor<i64, D>,
value: TchTensor<i64, D>,
) -> TchTensor<i64, D> {
TchOps::index_select_assign(tensor, indexes, value)
}
fn int_index_select_dim<const D: usize>( fn int_index_select_dim<const D: usize>(
tensor: TchTensor<i64, D>, tensor: TchTensor<i64, D>,

View File

@ -190,6 +190,21 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
TchTensor::new(tensor) TchTensor::new(tensor)
} }
fn index_select<const D: usize>(
tensor: TchTensor<E, D>,
indexes: TchTensor<i64, D>,
) -> TchTensor<E, D> {
TchOps::index_select(tensor, indexes)
}
fn index_select_assign<const D: usize>(
tensor: TchTensor<E, D>,
indexes: TchTensor<i64, D>,
value: TchTensor<E, D>,
) -> TchTensor<E, D> {
TchOps::index_select_assign(tensor, indexes, value)
}
fn index_select_dim<const D: usize>( fn index_select_dim<const D: usize>(
tensor: TchTensor<E, D>, tensor: TchTensor<E, D>,
dim: usize, dim: usize,

View File

@ -342,27 +342,6 @@ where
self.reshape(shape) self.reshape(shape)
} }
/// Index the tensor along the given dimension using the given indexes.
pub fn index_select_dim(self, dim: usize, indexes: Tensor<B, 1, Int>) -> Self {
Self::new(B::index_select_dim(self.primitive, dim, indexes.primitive))
}
/// Return a new tensor with the same dimension, but with the values added to
/// the original tensor using the corresponding indexes provided along the given dimension.
pub fn index_select_dim_assign<const D2: usize>(
self,
dim: usize,
indexes: Tensor<B, 1, Int>,
values: Tensor<B, D2>,
) -> Self {
Self::new(B::index_select_dim_assign(
self.primitive,
dim,
indexes.primitive,
values.primitive,
))
}
pub(crate) fn relu(self) -> Self { pub(crate) fn relu(self) -> Self {
Self::new(B::relu(self.primitive)) Self::new(B::relu(self.primitive))
} }

View File

@ -170,6 +170,52 @@ where
pub fn lower_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> { pub fn lower_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
K::lower_equal_elem(self.primitive, other.elem()) K::lower_equal_elem(self.primitive, other.elem())
} }
/// Select the tensor elements corresponding to the given indexes.
///
/// # Notes
///
/// The index tensor shoud have the same shape as the original tensor except for the last
/// dimension.
pub fn index_select(self, indexes: Tensor<B, D, Int>) -> Self {
Self::new(K::index_select(self.primitive, indexes))
}
/// Assign the selected elements corresponding to the given indexes from the value tensor
/// to the original tensor using sum reduction.
///
/// # Notes
///
/// The index tensor shoud have the same shape as the original tensor except for the last
/// dimension. The value and index tensors should have the same shape.
pub fn index_select_assign(self, indexes: Tensor<B, D, Int>, values: Self) -> Self {
Self::new(K::index_select_assign(
self.primitive,
indexes,
values.primitive,
))
}
/// Select the tensor elements along the given dimension corresponding to the given indexes.
pub fn index_select_dim(self, dim: usize, indexes: Tensor<B, 1, Int>) -> Self {
Self::new(K::index_select_dim(self.primitive, dim, indexes))
}
/// Assign the selected elements along the given dimension corresponding to the given indexes
/// from the value tensor to the original tensor using sum reduction.
pub fn index_select_dim_assign<const D2: usize>(
self,
dim: usize,
indexes: Tensor<B, 1, Int>,
values: Tensor<B, D2, K>,
) -> Self {
Self::new(K::index_select_dim_assign(
self.primitive,
dim,
indexes,
values.primitive,
))
}
} }
/// Trait that list all operations that can be applied on all numerical tensors. /// Trait that list all operations that can be applied on all numerical tensors.
@ -234,6 +280,15 @@ pub trait Numeric<B: Backend>: TensorKind<B> {
lhs: Self::Primitive<D>, lhs: Self::Primitive<D>,
rhs: Self::Elem, rhs: Self::Elem,
) -> Tensor<B, D, Bool>; ) -> Tensor<B, D, Bool>;
fn index_select<const D: usize>(
tensor: Self::Primitive<D>,
indexes: Tensor<B, D, Int>,
) -> Self::Primitive<D>;
fn index_select_assign<const D: usize>(
tensor: Self::Primitive<D>,
indexes: Tensor<B, D, Int>,
values: Self::Primitive<D>,
) -> Self::Primitive<D>;
fn index_select_dim<const D: usize>( fn index_select_dim<const D: usize>(
tensor: Self::Primitive<D>, tensor: Self::Primitive<D>,
dim: usize, dim: usize,
@ -389,6 +444,20 @@ impl<B: Backend> Numeric<B> for Int {
) -> Self::Primitive<D1> { ) -> Self::Primitive<D1> {
B::int_index_select_dim_assign(tensor, dim, indexes.primitive, values) B::int_index_select_dim_assign(tensor, dim, indexes.primitive, values)
} }
fn index_select<const D: usize>(
tensor: Self::Primitive<D>,
indexes: Tensor<B, D, Int>,
) -> Self::Primitive<D> {
B::int_index_select(tensor, indexes.primitive)
}
fn index_select_assign<const D: usize>(
tensor: Self::Primitive<D>,
indexes: Tensor<B, D, Int>,
values: Self::Primitive<D>,
) -> Self::Primitive<D> {
B::int_index_select_assign(tensor, indexes.primitive, values)
}
} }
impl<B: Backend> Numeric<B> for Float { impl<B: Backend> Numeric<B> for Float {
@ -533,6 +602,21 @@ impl<B: Backend> Numeric<B> for Float {
) -> Self::Primitive<D1> { ) -> Self::Primitive<D1> {
B::index_select_dim_assign(tensor, dim, indexes.primitive, values) B::index_select_dim_assign(tensor, dim, indexes.primitive, values)
} }
fn index_select<const D: usize>(
tensor: Self::Primitive<D>,
indexes: Tensor<B, D, Int>,
) -> Self::Primitive<D> {
B::index_select(tensor, indexes.primitive)
}
fn index_select_assign<const D: usize>(
tensor: Self::Primitive<D>,
indexes: Tensor<B, D, Int>,
values: Self::Primitive<D>,
) -> Self::Primitive<D> {
B::index_select_assign(tensor, indexes.primitive, values)
}
} }
impl<B, const D: usize, K> core::ops::Add<Self> for Tensor<B, D, K> impl<B, const D: usize, K> core::ops::Add<Self> for Tensor<B, D, K>

View File

@ -34,6 +34,15 @@ pub trait IntTensorOps<B: Backend> {
indexes: [Range<usize>; D2], indexes: [Range<usize>; D2],
value: B::IntTensorPrimitive<D1>, value: B::IntTensorPrimitive<D1>,
) -> B::IntTensorPrimitive<D1>; ) -> B::IntTensorPrimitive<D1>;
fn int_index_select<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
indexes: B::IntTensorPrimitive<D>,
) -> B::IntTensorPrimitive<D>;
fn int_index_select_assign<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
indexes: B::IntTensorPrimitive<D>,
value: B::IntTensorPrimitive<D>,
) -> B::IntTensorPrimitive<D>;
fn int_index_select_dim<const D: usize>( fn int_index_select_dim<const D: usize>(
tensor: B::IntTensorPrimitive<D>, tensor: B::IntTensorPrimitive<D>,
dim: usize, dim: usize,

View File

@ -115,6 +115,15 @@ pub trait TensorOps<B: Backend> {
tensor: B::TensorPrimitive<D1>, tensor: B::TensorPrimitive<D1>,
shape: Shape<D2>, shape: Shape<D2>,
) -> B::TensorPrimitive<D2>; ) -> B::TensorPrimitive<D2>;
fn index_select<const D: usize>(
tensor: B::TensorPrimitive<D>,
indexes: B::IntTensorPrimitive<D>,
) -> B::TensorPrimitive<D>;
fn index_select_assign<const D: usize>(
tensor: B::TensorPrimitive<D>,
indexes: B::IntTensorPrimitive<D>,
value: B::TensorPrimitive<D>,
) -> B::TensorPrimitive<D>;
fn index_select_dim<const D: usize>( fn index_select_dim<const D: usize>(
tensor: B::TensorPrimitive<D>, tensor: B::TensorPrimitive<D>,
dim: usize, dim: usize,

View File

@ -28,8 +28,9 @@ macro_rules! testgen_all {
burn_tensor::testgen_exp!(); burn_tensor::testgen_exp!();
burn_tensor::testgen_log!(); burn_tensor::testgen_log!();
burn_tensor::testgen_log1p!(); burn_tensor::testgen_log1p!();
burn_tensor::testgen_index_select_dim!();
burn_tensor::testgen_index!(); burn_tensor::testgen_index!();
burn_tensor::testgen_index_select!();
burn_tensor::testgen_index_select_dim!();
burn_tensor::testgen_map_comparison!(); burn_tensor::testgen_map_comparison!();
burn_tensor::testgen_mask!(); burn_tensor::testgen_mask!();
burn_tensor::testgen_matmul!(); burn_tensor::testgen_matmul!();

View File

@ -0,0 +1,63 @@
#[burn_tensor_testgen::testgen(index_select)]
mod tests {
use super::*;
use burn_tensor::{Data, Tensor};
#[test]
fn should_select_1d() {
let tensor = TestTensor::from_data(Data::from([0.0, 1.0, 2.0]));
let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));
let output = tensor.index_select(indexes);
assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0]));
}
#[test]
fn should_select_2d() {
let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
let indexes = TestTensorInt::from_data(Data::from([[2, 1, 0, 0], [2, 0, 1, 2]]));
let output = tensor.index_select(indexes);
assert_eq!(
output.into_data(),
Data::from([[2.0, 1.0, 0.0, 0.0], [5.0, 3.0, 4.0, 5.0]])
);
}
#[test]
fn should_select_2d_only_1dim() {
let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
let indexes = TestTensorInt::from_data(Data::from([[1, 2]])).reshape([2, 1]);
let output = tensor.index_select(indexes);
assert_eq!(output.into_data(), Data::from([[1.0], [5.0]]));
}
#[test]
fn should_select_assign_1d() {
let tensor = TestTensor::from_data(Data::from([0.0, 0.0, 0.0]));
let values = TestTensor::from_data(Data::from([5.0, 4.0, 3.0]));
let indexes = TestTensorInt::from_data(Data::from([1, 0, 2]));
let output = tensor.index_select_assign(indexes, values);
assert_eq!(output.into_data(), Data::from([4.0, 5.0, 3.0]));
}
#[test]
fn should_select_assign_2d() {
let tensor = TestTensor::from_data(Data::from([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]));
let values = TestTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]));
let indexes = TestTensorInt::from_data(Data::from([[1, 0, 2], [1, 2, 0]]));
let output = tensor.index_select_assign(indexes, values);
assert_eq!(
output.into_data(),
Data::from([[2.0, 1.0, 3.0], [6.0, 4.0, 5.0]])
);
}
}

View File

@ -6,6 +6,7 @@ mod div;
mod erf; mod erf;
mod exp; mod exp;
mod index; mod index;
mod index_select;
mod index_select_dim; mod index_select_dim;
mod log; mod log;
mod log1p; mod log1p;