Feat/gather scatter (#367)

This commit is contained in:
Nathaniel Simard 2023-05-27 11:40:04 -04:00 committed by GitHub
parent 41e54f741b
commit 2a4ba5a6ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 553 additions and 281 deletions

View File

@ -195,19 +195,21 @@ impl<B: Backend> IntTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
B::int_lower_equal_elem(lhs, rhs)
}
fn int_index_select<const D: usize>(
fn int_gather<const D: usize>(
dim: usize,
tensor: IntTensor<B, D>,
indexes: IntTensor<B, D>,
) -> IntTensor<B, D> {
B::int_index_select(tensor, indexes)
B::int_gather(dim, tensor, indexes)
}
fn int_index_select_assign<const D: usize>(
fn int_scatter<const D: usize>(
dim: usize,
tensor: IntTensor<B, D>,
indexes: IntTensor<B, D>,
value: IntTensor<B, D>,
) -> IntTensor<B, D> {
B::int_index_select_assign(tensor, indexes, value)
B::int_scatter(dim, tensor, indexes, value)
}
fn int_index_select_dim<const D: usize>(

View File

@ -14,7 +14,7 @@ impl<B: Backend, const D: usize> Backward<B, D, 1> for MaxMinDim {
let device = B::device(&grad);
let zeros = B::zeros(shape, &device);
B::index_select_assign(zeros, indexes, grad)
B::scatter(D - 1, zeros, indexes, grad)
});
}
}

View File

@ -438,55 +438,55 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
}
}
fn index_select<const D: usize>(
fn gather<const D: usize>(
dim: usize,
tensor: ADTensor<B, D>,
indexes: IntTensor<B, D>,
) -> ADTensor<B, D> {
#[derive(Debug)]
struct IndexSelect;
struct Gather;
impl<B: Backend, const D: usize> Backward<B, D, 1> for IndexSelect {
type State = (IntTensor<B, D>, Shape<D>, B::Device);
impl<B: Backend, const D: usize> Backward<B, D, 1> for Gather {
type State = (usize, IntTensor<B, D>, Shape<D>, B::Device);
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
let (indexes, shape, device) = ops.state;
let (dim, indexes, shape, device) = ops.state;
unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
let zeros = B::zeros(shape, &device);
B::index_select_assign(zeros, indexes, grad)
B::scatter(dim, zeros, indexes, grad)
});
}
}
match IndexSelect
.prepare([tensor.node], [tensor.graph])
.statefull()
{
match Gather.prepare([tensor.node], [tensor.graph]).statefull() {
OpsKind::Tracked(prep) => prep.finish(
(
dim,
indexes.clone(),
B::shape(&tensor.primitive),
B::device(&tensor.primitive),
),
B::index_select(tensor.primitive, indexes),
B::gather(dim, tensor.primitive, indexes),
),
OpsKind::UnTracked(prep) => prep.finish(B::index_select(tensor.primitive, indexes)),
OpsKind::UnTracked(prep) => prep.finish(B::gather(dim, tensor.primitive, indexes)),
}
}
fn index_select_assign<const D: usize>(
fn scatter<const D: usize>(
dim: usize,
tensor: ADTensor<B, D>,
indexes: IntTensor<B, D>,
value: ADTensor<B, D>,
) -> ADTensor<B, D> {
#[derive(Debug)]
struct IndexSelectAssign;
struct Scatter;
impl<B: Backend, const D: usize> Backward<B, D, 2> for IndexSelectAssign {
type State = (IntTensor<B, D>, Shape<D>, Shape<D>, B::Device);
impl<B: Backend, const D: usize> Backward<B, D, 2> for Scatter {
type State = (usize, IntTensor<B, D>, Shape<D>, Shape<D>, B::Device);
fn backward(self, ops: Ops<Self::State, 2>, grads: &mut Gradients) {
let (indexes, shape_lhs, shape_rhs, device) = ops.state;
let (dim, indexes, shape_lhs, shape_rhs, device) = ops.state;
let [indexes_4lhs, indexes_4rhs] = duplicate(&ops.parents, Some(indexes));
binary::<B, D, D, D, _, _>(
@ -495,38 +495,37 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
grads,
|grad| {
let zeros = B::zeros(shape_lhs, &device);
B::index_select_assign(grad, indexes_4lhs.unwrap(), zeros)
B::scatter(dim, grad, indexes_4lhs.unwrap(), zeros)
},
|grad| {
let zeros = B::zeros(shape_rhs, &device);
B::index_select_assign(zeros, indexes_4rhs.unwrap(), grad)
B::scatter(dim, zeros, indexes_4rhs.unwrap(), grad)
},
);
}
}
match IndexSelectAssign
match Scatter
.prepare([tensor.node, value.node], [tensor.graph, value.graph])
.statefull()
{
OpsKind::Tracked(prep) => prep.finish(
(
dim,
indexes.clone(),
B::shape(&tensor.primitive),
B::shape(&value.primitive),
B::device(&value.primitive),
),
B::index_select_assign(tensor.primitive, indexes, value.primitive),
B::scatter(dim, tensor.primitive, indexes, value.primitive),
),
OpsKind::UnTracked(prep) => prep.finish(B::index_select_assign(
tensor.primitive,
indexes,
value.primitive,
)),
OpsKind::UnTracked(prep) => {
prep.finish(B::scatter(dim, tensor.primitive, indexes, value.primitive))
}
}
}
fn index_select_dim<const D: usize>(
fn index_select<const D: usize>(
tensor: ADTensor<B, D>,
dim: usize,
indexes: IntTensor<B, 1>,
@ -542,7 +541,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
let zeros = B::zeros(shape, &device);
B::index_select_dim_assign(zeros, dim, indexes, grad)
B::index_select_assign(zeros, dim, indexes, grad)
});
}
}
@ -558,15 +557,15 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
B::shape(&tensor.primitive),
B::device(&tensor.primitive),
),
B::index_select_dim(tensor.primitive, dim, indexes),
B::index_select(tensor.primitive, dim, indexes),
),
OpsKind::UnTracked(prep) => {
prep.finish(B::index_select_dim(tensor.primitive, dim, indexes))
prep.finish(B::index_select(tensor.primitive, dim, indexes))
}
}
}
fn index_select_dim_assign<const D1: usize, const D2: usize>(
fn index_select_assign<const D1: usize, const D2: usize>(
tensor: ADTensor<B, D1>,
dim: usize,
indexes: IntTensor<B, 1>,
@ -588,11 +587,11 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
grads,
|grad| {
let zeros = B::zeros(shape_lhs, &device);
B::index_select_dim_assign(grad, dim, indexes_4lhs.unwrap(), zeros)
B::index_select_assign(grad, dim, indexes_4lhs.unwrap(), zeros)
},
|grad| {
let zeros = B::zeros(shape_rhs, &device);
B::index_select_dim_assign(zeros, dim, indexes_4rhs.unwrap(), grad)
B::index_select_assign(zeros, dim, indexes_4rhs.unwrap(), grad)
},
);
}
@ -610,9 +609,9 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
B::shape(&value.primitive),
B::device(&value.primitive),
),
B::index_select_dim_assign(tensor.primitive, dim, indexes, value.primitive),
B::index_select_assign(tensor.primitive, dim, indexes, value.primitive),
),
OpsKind::UnTracked(prep) => prep.finish(B::index_select_dim_assign(
OpsKind::UnTracked(prep) => prep.finish(B::index_select_assign(
tensor.primitive,
dim,
indexes,

View File

@ -1,16 +1,16 @@
#[burn_tensor_testgen::testgen(ad_index_select_dim)]
#[burn_tensor_testgen::testgen(ad_gather_scatter)]
mod tests {
use super::*;
use burn_tensor::Data;
#[test]
fn test_index_select_dim_grad() {
fn test_gather_grad() {
let tensor_1 =
TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
let indexes = TestADTensor::from_data(Data::from([1, 0]));
let indexes = TestADTensor::from_data(Data::from([[2, 1, 0, 1, 2], [1, 0, 2, 1, 0]]));
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
let tensor_3 = tensor_1.clone().index_select_dim(0, indexes);
let tensor_3 = tensor_1.clone().gather(1, indexes);
let tensor_4 = tensor_2.matmul(tensor_3);
let grads = tensor_4.backward();
@ -19,22 +19,20 @@ mod tests {
assert_eq!(
grad_1.into_data(),
Data::from([[109., 148., 187.], [37., 58., 79.]])
Data::from([[94., 150., 187.], [242., 305., 304.]])
);
}
#[test]
fn test_index_select_dim_assign_grad() {
fn test_scatter_grad() {
let tensor_1 =
TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
let values =
TestADTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])).require_grad();
let indexes = TestADTensor::from_data(Data::from([1, 0]));
let indexes = TestADTensor::from_data(Data::from([[2, 1, 0], [2, 0, 1]]));
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
let tensor_3 = tensor_1
.clone()
.index_select_dim_assign(0, indexes, values.clone());
let tensor_3 = tensor_1.clone().scatter(1, indexes, values.clone());
let tensor_4 = tensor_2.matmul(tensor_3);
let grads = tensor_4.backward();
@ -44,11 +42,11 @@ mod tests {
assert_eq!(
grad_1.into_data(),
Data::from([[127., 199., 271.], [172., 244., 316.]])
Data::from([[127., 181., 235.], [226., 316., 406.]])
);
assert_eq!(
grad_2.into_data(),
Data::from([[64., 64., 64.], [19., 19., 19.]])
Data::from([[19., 19., 19.], [64., 64., 64.]])
);
}
}

View File

@ -7,10 +7,10 @@ mod tests {
fn test_index_select_grad() {
let tensor_1 =
TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
let indexes = TestADTensor::from_data(Data::from([[2, 1, 0, 1, 2], [1, 0, 2, 1, 0]]));
let indexes = TestADTensor::from_data(Data::from([1, 0]));
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
let tensor_3 = tensor_1.clone().index_select(indexes);
let tensor_3 = tensor_1.clone().index_select(0, indexes);
let tensor_4 = tensor_2.matmul(tensor_3);
let grads = tensor_4.backward();
@ -19,7 +19,7 @@ mod tests {
assert_eq!(
grad_1.into_data(),
Data::from([[94., 150., 187.], [242., 305., 304.]])
Data::from([[109., 148., 187.], [37., 58., 79.]])
);
}
@ -29,12 +29,12 @@ mod tests {
TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
let values =
TestADTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])).require_grad();
let indexes = TestADTensor::from_data(Data::from([[2, 1, 0], [2, 0, 1]]));
let indexes = TestADTensor::from_data(Data::from([1, 0]));
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
let tensor_3 = tensor_1
.clone()
.index_select_assign(indexes, values.clone());
.index_select_assign(0, indexes, values.clone());
let tensor_4 = tensor_2.matmul(tensor_3);
let grads = tensor_4.backward();
@ -44,11 +44,11 @@ mod tests {
assert_eq!(
grad_1.into_data(),
Data::from([[127., 181., 235.], [226., 316., 406.]])
Data::from([[127., 199., 271.], [172., 244., 316.]])
);
assert_eq!(
grad_2.into_data(),
Data::from([[19., 19., 19.], [64., 64., 64.]])
Data::from([[64., 64., 64.], [19., 19., 19.]])
);
}
}

View File

@ -13,10 +13,10 @@ mod cross_entropy;
mod div;
mod erf;
mod exp;
mod gather_scatter;
mod gelu;
mod index;
mod index_select;
mod index_select_dim;
mod log;
mod log1p;
mod mask;
@ -70,8 +70,8 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_erf!();
burn_autodiff::testgen_ad_exp!();
burn_autodiff::testgen_ad_index!();
burn_autodiff::testgen_ad_gather_scatter!();
burn_autodiff::testgen_ad_index_select!();
burn_autodiff::testgen_ad_index_select_dim!();
burn_autodiff::testgen_ad_log!();
burn_autodiff::testgen_ad_log1p!();
burn_autodiff::testgen_ad_mask!();

View File

@ -29,7 +29,7 @@ impl<B: Backend> CrossEntropyLoss<B> {
let mask = self.padding_mask(&targets);
let tensor = activation::log_softmax(logits, 1);
let tensor = tensor.index_select(targets.reshape([batch_size, 1]));
let tensor = tensor.gather(1, targets.reshape([batch_size, 1]));
let tensor = self.apply_mask(tensor.reshape([batch_size]), mask);
tensor.mean().neg()

View File

@ -208,13 +208,18 @@ where
}
}
pub fn index_select<const D: usize>(
tensor: NdArrayTensor<E, D>,
indexes: NdArrayTensor<i64, D>,
pub fn gather<const D: usize>(
dim: usize,
mut tensor: NdArrayTensor<E, D>,
mut indexes: NdArrayTensor<i64, D>,
) -> NdArrayTensor<E, D> {
if dim != D - 1 {
tensor.array.swap_axes(D - 1, dim);
indexes.array.swap_axes(D - 1, dim);
}
let (shape_tensor, shape_indexes) = (tensor.shape(), indexes.shape());
let (size_tensor, size_index) = (shape_tensor.dims[D - 1], shape_indexes.dims[D - 1]);
let batch_size = Self::index_select_batch_size(&shape_tensor, &shape_indexes);
let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indexes);
let indexes = NdArrayOps::reshape(indexes, Shape::new([batch_size, size_index])).array;
let tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array;
@ -228,17 +233,30 @@ where
}
}
NdArrayOps::reshape(
let mut output = NdArrayOps::reshape(
NdArrayTensor::<E, 2>::new(output.into_shared().into_dyn()),
shape_indexes,
)
);
if dim != D - 1 {
output.array.swap_axes(D - 1, dim);
}
output
}
pub fn index_select_assign<const D: usize>(
tensor: NdArrayTensor<E, D>,
indexes: NdArrayTensor<i64, D>,
value: NdArrayTensor<E, D>,
pub fn scatter<const D: usize>(
dim: usize,
mut tensor: NdArrayTensor<E, D>,
mut indexes: NdArrayTensor<i64, D>,
mut value: NdArrayTensor<E, D>,
) -> NdArrayTensor<E, D> {
if dim != D - 1 {
tensor.array.swap_axes(D - 1, dim);
indexes.array.swap_axes(D - 1, dim);
value.array.swap_axes(D - 1, dim);
}
let (shape_tensor, shape_indexes, shape_value) =
(tensor.shape(), indexes.shape(), value.shape());
let (size_tensor, size_index, size_value) = (
@ -246,7 +264,7 @@ where
shape_indexes.dims[D - 1],
shape_value.dims[D - 1],
);
let batch_size = Self::index_select_batch_size(&shape_tensor, &shape_indexes);
let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indexes);
if shape_value != shape_indexes {
panic!("Invalid dimension: the shape of the index tensor should be the same as the value tensor: Index {:?} value {:?}", shape_indexes.dims, shape_value.dims);
@ -265,10 +283,14 @@ where
}
}
NdArrayOps::reshape(
let mut output = NdArrayOps::reshape(
NdArrayTensor::<E, 2>::new(tensor.into_shared().into_dyn()),
shape_tensor,
)
);
if dim != D - 1 {
output.array.swap_axes(D - 1, dim);
}
output
}
pub fn mask_scatter<const D: usize>(
@ -307,7 +329,7 @@ where
NdArrayTensor::new(array)
}
fn index_select_batch_size<const D: usize>(
fn gather_batch_size<const D: usize>(
shape_tensor: &Shape<D>,
shape_indexes: &Shape<D>,
) -> usize {
@ -323,7 +345,7 @@ where
batch_size
}
pub fn index_select_dim<const D: usize>(
pub fn index_select<const D: usize>(
tensor: NdArrayTensor<E, D>,
dim: usize,
indexes: NdArrayTensor<i64, 1>,
@ -340,7 +362,7 @@ where
NdArrayTensor::new(array.into_shared())
}
pub fn index_select_dim_assign<const D1: usize, const D2: usize>(
pub fn index_select_assign<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<E, D1>,
dim: usize,
indexes: NdArrayTensor<i64, 1>,

View File

@ -280,19 +280,21 @@ impl<E: FloatNdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<
NdArrayMathOps::mean_dim(tensor, dim)
}
fn int_index_select<const D: usize>(
fn int_gather<const D: usize>(
dim: usize,
tensor: NdArrayTensor<i64, D>,
indexes: NdArrayTensor<i64, D>,
) -> NdArrayTensor<i64, D> {
NdArrayMathOps::index_select(tensor, indexes)
NdArrayMathOps::gather(dim, tensor, indexes)
}
fn int_index_select_assign<const D: usize>(
fn int_scatter<const D: usize>(
dim: usize,
tensor: NdArrayTensor<i64, D>,
indexes: NdArrayTensor<i64, D>,
value: NdArrayTensor<i64, D>,
) -> NdArrayTensor<i64, D> {
NdArrayMathOps::index_select_assign(tensor, indexes, value)
NdArrayMathOps::scatter(dim, tensor, indexes, value)
}
fn int_index_select_dim<const D: usize>(
@ -300,7 +302,7 @@ impl<E: FloatNdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<
dim: usize,
indexes: NdArrayTensor<i64, 1>,
) -> NdArrayTensor<i64, D> {
NdArrayMathOps::index_select_dim(tensor, dim, indexes)
NdArrayMathOps::index_select(tensor, dim, indexes)
}
fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
@ -309,7 +311,7 @@ impl<E: FloatNdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<
indexes: NdArrayTensor<i64, 1>,
value: NdArrayTensor<i64, D2>,
) -> NdArrayTensor<i64, D1> {
NdArrayMathOps::index_select_dim_assign(tensor, dim, indexes, value)
NdArrayMathOps::index_select_assign(tensor, dim, indexes, value)
}
fn int_argmax<const D: usize>(
tensor: NdArrayTensor<i64, D>,

View File

@ -150,36 +150,38 @@ impl<E: FloatNdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E>
NdArrayOps::reshape(tensor, shape)
}
fn index_select<const D: usize>(
fn gather<const D: usize>(
dim: usize,
tensor: NdArrayTensor<E, D>,
indexes: NdArrayTensor<i64, D>,
) -> NdArrayTensor<E, D> {
NdArrayMathOps::index_select(tensor, indexes)
NdArrayMathOps::gather(dim, tensor, indexes)
}
fn index_select_assign<const D: usize>(
fn scatter<const D: usize>(
dim: usize,
tensor: NdArrayTensor<E, D>,
indexes: NdArrayTensor<i64, D>,
value: NdArrayTensor<E, D>,
) -> NdArrayTensor<E, D> {
NdArrayMathOps::index_select_assign(tensor, indexes, value)
NdArrayMathOps::scatter(dim, tensor, indexes, value)
}
fn index_select_dim<const D: usize>(
fn index_select<const D: usize>(
tensor: NdArrayTensor<E, D>,
dim: usize,
indexes: NdArrayTensor<i64, 1>,
) -> NdArrayTensor<E, D> {
NdArrayMathOps::index_select_dim(tensor, dim, indexes)
NdArrayMathOps::index_select(tensor, dim, indexes)
}
fn index_select_dim_assign<const D1: usize, const D2: usize>(
fn index_select_assign<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<E, D1>,
dim: usize,
indexes: NdArrayTensor<i64, 1>,
value: NdArrayTensor<E, D2>,
) -> NdArrayTensor<E, D1> {
NdArrayMathOps::index_select_dim_assign(tensor, dim, indexes, value)
NdArrayMathOps::index_select_assign(tensor, dim, indexes, value)
}
fn index<const D1: usize, const D2: usize>(

View File

@ -56,17 +56,19 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
TchTensor::new(tensor_original)
}
pub fn index_select<const D: usize>(
pub fn gather<const D: usize>(
dim: usize,
tensor: TchTensor<E, D>,
indexes: TchTensor<i64, D>,
) -> TchTensor<E, D> {
let storage = tensor.storage.clone();
let tensor = tensor.tensor.gather((D - 1) as i64, &indexes.tensor, false);
let tensor = tensor.tensor.gather(dim as i64, &indexes.tensor, false);
TchTensor::from_existing(tensor, storage)
}
pub fn index_select_assign<const D: usize>(
pub fn scatter<const D: usize>(
dim: usize,
tensor: TchTensor<E, D>,
indexes: TchTensor<i64, D>,
value: TchTensor<E, D>,
@ -74,7 +76,7 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
let storage = tensor.storage.clone();
let tensor = tensor
.tensor
.scatter_add((D - 1) as i64, &indexes.tensor, &value.tensor);
.scatter_add(dim as i64, &indexes.tensor, &value.tensor);
TchTensor::from_existing(tensor, storage)
}

View File

@ -231,19 +231,21 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
fn int_mean_dim<const D: usize>(tensor: TchTensor<i64, D>, dim: usize) -> TchTensor<i64, D> {
TchOps::mean_dim(tensor, dim)
}
fn int_index_select<const D: usize>(
fn int_gather<const D: usize>(
dim: usize,
tensor: TchTensor<i64, D>,
indexes: TchTensor<i64, D>,
) -> TchTensor<i64, D> {
TchOps::index_select(tensor, indexes)
TchOps::gather(dim, tensor, indexes)
}
fn int_index_select_assign<const D: usize>(
fn int_scatter<const D: usize>(
dim: usize,
tensor: TchTensor<i64, D>,
indexes: TchTensor<i64, D>,
value: TchTensor<i64, D>,
) -> TchTensor<i64, D> {
TchOps::index_select_assign(tensor, indexes, value)
TchOps::scatter(dim, tensor, indexes, value)
}
fn int_index_select_dim<const D: usize>(

View File

@ -176,22 +176,24 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
TchOps::reshape(tensor, shape)
}
fn index_select<const D: usize>(
fn gather<const D: usize>(
dim: usize,
tensor: TchTensor<E, D>,
indexes: TchTensor<i64, D>,
) -> TchTensor<E, D> {
TchOps::index_select(tensor, indexes)
TchOps::gather(dim, tensor, indexes)
}
fn index_select_assign<const D: usize>(
fn scatter<const D: usize>(
dim: usize,
tensor: TchTensor<E, D>,
indexes: TchTensor<i64, D>,
value: TchTensor<E, D>,
) -> TchTensor<E, D> {
TchOps::index_select_assign(tensor, indexes, value)
TchOps::scatter(dim, tensor, indexes, value)
}
fn index_select_dim<const D: usize>(
fn index_select<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,
indexes: TchTensor<i64, 1>,
@ -199,7 +201,7 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
TchOps::index_select_dim(tensor, dim, indexes)
}
fn index_select_dim_assign<const D1: usize, const D2: usize>(
fn index_select_assign<const D1: usize, const D2: usize>(
tensor: TchTensor<E, D1>,
dim: usize,
indexes: TchTensor<i64, 1>,

View File

@ -7,7 +7,7 @@ use core::ops::Range;
/// The struct should always be used with the [check](crate::check) macro.
///
/// This is a simple public crate data structure that efficiently checks tensor operations and
/// This is a simple pub(crate) data structure that efficiently checks tensor operations and
/// formats clear error messages. It's crucial that the checks are really fast, but it doesn't matter
/// when a failed check is discovered since the program will panic.
///
@ -32,14 +32,14 @@ use core::ops::Range;
/// such as the `index_select` operation. The downside of that approach is that all backend
/// implementation might re-implement the same checks, which may result in uncessary code
/// duplication. Maybe a combination of both strategies could help to cover all usecases.
pub enum TensorCheck {
pub(crate) enum TensorCheck {
Ok,
Failed(FailedTensorCheck),
}
impl TensorCheck {
/// Checks device and shape compatibility for element wise binary operations.
pub fn binary_ops_ew<B: Backend, const D: usize, K: BasicOps<B>>(
pub(crate) fn binary_ops_ew<B: Backend, const D: usize, K: BasicOps<B>>(
ops: &str,
lhs: &Tensor<B, D, K>,
rhs: &Tensor<B, D, K>,
@ -49,7 +49,7 @@ impl TensorCheck {
.binary_ops_ew_shape(ops, &lhs.shape(), &rhs.shape())
}
pub fn into_scalar<const D: usize>(shape: &Shape<D>) -> Self {
pub(crate) fn into_scalar<const D: usize>(shape: &Shape<D>) -> Self {
let mut check = Self::Ok;
if shape.num_elements() != 1 {
@ -66,7 +66,7 @@ impl TensorCheck {
check
}
pub fn dim_ops<const D: usize>(ops: &str, dim: usize) -> Self {
pub(crate) fn dim_ops<const D: usize>(ops: &str, dim: usize) -> Self {
let mut check = Self::Ok;
if dim >= D {
@ -80,7 +80,7 @@ impl TensorCheck {
check
}
pub fn reshape<const D1: usize, const D2: usize>(
pub(crate) fn reshape<const D1: usize, const D2: usize>(
original: &Shape<D1>,
target: &Shape<D2>,
) -> Self {
@ -99,7 +99,10 @@ impl TensorCheck {
check
}
pub fn flatten<const D1: usize, const D2: usize>(start_dim: usize, end_dim: usize) -> Self {
pub(crate) fn flatten<const D1: usize, const D2: usize>(
start_dim: usize,
end_dim: usize,
) -> Self {
let mut check = Self::Ok;
if start_dim > end_dim {
@ -130,7 +133,7 @@ impl TensorCheck {
check
}
pub fn unsqueeze<const D1: usize, const D2: usize>() -> Self {
pub(crate) fn unsqueeze<const D1: usize, const D2: usize>() -> Self {
let mut check = Self::Ok;
if D2 < D1 {
check = check.register(
@ -144,7 +147,7 @@ impl TensorCheck {
check
}
pub fn swap_dims<const D: usize>(dim1: usize, dim2: usize) -> Self {
pub(crate) fn swap_dims<const D: usize>(dim1: usize, dim2: usize) -> Self {
let mut check = Self::Ok;
if dim1 > D || dim2 > D {
@ -160,7 +163,10 @@ impl TensorCheck {
check
}
pub fn matmul<B: Backend, const D: usize>(lhs: &Tensor<B, D>, rhs: &Tensor<B, D>) -> Self {
pub(crate) fn matmul<B: Backend, const D: usize>(
lhs: &Tensor<B, D>,
rhs: &Tensor<B, D>,
) -> Self {
let mut check = Self::Ok;
check = check.binary_ops_device("Matmul", &lhs.device(), &rhs.device());
@ -191,7 +197,7 @@ impl TensorCheck {
check
}
pub fn cat<B: Backend, const D: usize, K: BasicOps<B>>(
pub(crate) fn cat<B: Backend, const D: usize, K: BasicOps<B>>(
tensors: &[Tensor<B, D, K>],
dim: usize,
) -> Self {
@ -241,7 +247,7 @@ impl TensorCheck {
check
}
pub fn index<const D1: usize, const D2: usize>(
pub(crate) fn index<const D1: usize, const D2: usize>(
shape: &Shape<D1>,
indexes: &[Range<usize>; D2],
) -> Self {
@ -298,7 +304,7 @@ impl TensorCheck {
check
}
pub fn index_assign<const D1: usize, const D2: usize>(
pub(crate) fn index_assign<const D1: usize, const D2: usize>(
shape: &Shape<D1>,
shape_value: &Shape<D1>,
indexes: &[Range<usize>; D2],
@ -374,8 +380,104 @@ impl TensorCheck {
check
}
pub(crate) fn gather<const D: usize>(
dim: usize,
shape: &Shape<D>,
shape_indexes: &Shape<D>,
) -> Self {
Self::check_gather_scatter_indexes(Self::Ok, "Gather", dim, shape, shape_indexes)
}
pub(crate) fn scatter<const D: usize>(
dim: usize,
shape: &Shape<D>,
shape_indexes: &Shape<D>,
shape_value: &Shape<D>,
) -> Self {
let ops = "Scatter";
let mut check =
Self::check_gather_scatter_indexes(Self::Ok, ops, dim, shape, shape_indexes);
if shape_indexes != shape_value {
check = check.register(
ops,
TensorError::new(
"Indexes tensor shape should be the same as the value tensor shape."
.to_string(),
)
.details(format!(
"The shape differs: {:?} != {:?}",
shape_indexes.dims, shape_value.dims
)),
);
}
check
}
pub(crate) fn index_select<const D: usize>(dim: usize) -> Self {
Self::check_index_select_basic::<D>(Self::Ok, "index_select", dim)
}
pub(crate) fn index_select_assign<const D: usize>(dim: usize) -> Self {
Self::check_index_select_basic::<D>(Self::Ok, "index_select_assign", dim)
}
fn check_index_select_basic<const D: usize>(mut check: Self, ops: &str, dim: usize) -> Self {
if dim > D {
check = check.register(
ops,
TensorError::new(format!(
"Can't index a tensor with ({D}) dimensions on axis ({dim})"
)),
);
}
check
}
fn check_gather_scatter_indexes<const D: usize>(
mut check: Self,
ops: &str,
dim: usize,
shape: &Shape<D>,
shape_indexes: &Shape<D>,
) -> Self {
if dim > D {
check = check.register(
ops,
TensorError::new(format!(
"Can't index a tensor with ({D}) dimensions on axis ({dim})"
)),
);
}
for i in 0..D {
if i == dim {
continue;
}
let tensor_dim_i = shape.dims[i];
let indexes_dim_i = shape_indexes.dims[i];
if tensor_dim_i != indexes_dim_i {
check = check.register(
ops,
TensorError::new(
"The tensor shape should be the same as the index tensor shape."
.to_string(),
)
.details(format!(
"The shape differs at dimension {i}: {tensor_dim_i} != {indexes_dim_i}"
)),
);
}
}
check
}
/// Checks aggregate dimension such as mean and sum.
pub fn aggregate_dim<const D: usize>(ops: &str, dim: usize) -> Self {
pub(crate) fn aggregate_dim<const D: usize>(ops: &str, dim: usize) -> Self {
let mut check = Self::Ok;
if dim > D {
@ -409,7 +511,7 @@ impl TensorCheck {
}
/// Checks if shapes are compatible for element wise operations supporting broadcasting.
pub fn binary_ops_ew_shape<const D: usize>(
pub(crate) fn binary_ops_ew_shape<const D: usize>(
self,
ops: &str,
lhs: &Shape<D>,
@ -464,14 +566,14 @@ impl TensorCheck {
}
}
pub struct FailedTensorCheck {
pub(crate) struct FailedTensorCheck {
ops: String,
errors: Vec<TensorError>,
}
impl FailedTensorCheck {
/// Format all the checks into a single message ready to be printed by a [panic](core::panic).
pub fn format(self) -> String {
pub(crate) fn format(self) -> String {
self.errors.into_iter().enumerate().fold(
format!(
"=== Tensor Operation Error ===\n Operation: '{}'\n Reason:",
@ -488,14 +590,14 @@ struct TensorError {
}
impl TensorError {
pub fn new<S: Into<String>>(description: S) -> Self {
pub(crate) fn new<S: Into<String>>(description: S) -> Self {
TensorError {
description: description.into(),
details: None,
}
}
pub fn details<S: Into<String>>(mut self, details: S) -> Self {
pub(crate) fn details<S: Into<String>>(mut self, details: S) -> Self {
self.details = Some(details.into());
self
}

View File

@ -1,4 +1,4 @@
use crate::{backend::Backend, Int, Tensor};
use crate::{backend::Backend, Data, Int, Tensor};
use core::ops::Range;
impl<B> Tensor<B, 1, Int>
@ -14,3 +14,25 @@ where
Tensor::new(B::arange(range, device))
}
}
impl<const D: usize, B> Tensor<B, D, Int>
where
B: Backend,
{
/// Create a tensor from integers (i32).
///
/// # Example
///
/// ```rust
/// use burn_tensor::backend::Backend;
/// use burn_tensor::{Tensor, Int};
///
/// fn example<B: Backend>() {
/// let _x: Tensor<B, 1, Int> = Tensor::from_ints([1, 2]);
/// let _y: Tensor<B, 2, Int> = Tensor::from_ints([[1, 2], [3, 4]]);
/// }
/// ```
pub fn from_ints<A: Into<Data<i32, D>>>(ints: A) -> Self {
Self::from_data(ints.into().convert())
}
}

View File

@ -203,45 +203,83 @@ where
Self::new(K::mask_fill(self.primitive, mask, value.elem()))
}
/// Select the tensor elements corresponding to the given indexes.
/// Gather tensor elements corresponding to the given indexes from the specified dim.
///
/// Example using a 3D tensor:
///
/// `output[i, j, k] = input[indexes[i, j, k], j, k]; // dim = 0`
/// `output[i, j, k] = input[i, indexes[i, j, k], k]; // dim = 1`
/// `output[i, j, k] = input[i, j, indexes[i, j, k]]; // dim = 2`
///
/// # Notes
///
/// The index tensor shoud have the same shape as the original tensor except for the last
/// dimension.
pub fn index_select(self, indexes: Tensor<B, D, Int>) -> Self {
Self::new(K::index_select(self.primitive, indexes))
/// The index tensor shoud have the same shape as the original tensor except for the dim
/// specified.
pub fn gather(self, dim: usize, indexes: Tensor<B, D, Int>) -> Self {
check!(TensorCheck::gather::<D>(
dim,
&self.shape(),
&indexes.shape()
));
Self::new(K::gather(dim, self.primitive, indexes))
}
/// Assign the selected elements corresponding to the given indexes from the value tensor
/// to the original tensor using sum reduction.
/// Assign the gathered elements corresponding to the given indexes along the speficied dimension
/// from the value tensor to the original tensor using sum reduction.
///
/// Example using a 3D tensor:
///
/// `input[indexes[i, j, k], j, k] += values[i, j, k]; // dim = 0`
/// `input[i, indexes[i, j, k], k] += values[i, j, k]; // dim = 1`
/// `input[i, j, indexes[i, j, k]] += values[i, j, k]; // dim = 2`
///
/// # Notes
///
/// The index tensor shoud have the same shape as the original tensor except for the last
/// The index tensor shoud have the same shape as the original tensor except for the speficied
/// dimension. The value and index tensors should have the same shape.
pub fn index_select_assign(self, indexes: Tensor<B, D, Int>, values: Self) -> Self {
Self::new(K::index_select_assign(
self.primitive,
indexes,
values.primitive,
))
///
/// Other references to the input tensor will not be modified by this operation.
pub fn scatter(self, dim: usize, indexes: Tensor<B, D, Int>, values: Self) -> Self {
check!(TensorCheck::scatter::<D>(
dim,
&self.shape(),
&indexes.shape(),
&values.shape()
));
Self::new(K::scatter(dim, self.primitive, indexes, values.primitive))
}
/// Select the tensor elements along the given dimension corresponding to the given indexes.
pub fn index_select_dim(self, dim: usize, indexes: Tensor<B, 1, Int>) -> Self {
Self::new(K::index_select_dim(self.primitive, dim, indexes))
///
/// Example using a 3D tensor:
///
/// `output[i, j, k] = input[indexes[i], j, k]; // dim = 0`
/// `output[i, j, k] = input[i, indexes[j], k]; // dim = 1`
/// `output[i, j, k] = input[i, j, indexes[k]]; // dim = 2`
pub fn index_select(self, dim: usize, indexes: Tensor<B, 1, Int>) -> Self {
check!(TensorCheck::index_select::<D>(dim));
Self::new(K::index_select(self.primitive, dim, indexes))
}
/// Assign the selected elements along the given dimension corresponding to the given indexes
/// from the value tensor to the original tensor using sum reduction.
pub fn index_select_dim_assign<const D2: usize>(
///
/// Example using a 3D tensor:
///
/// `input[indexes[i], j, k] += values[i, j, k]; // dim = 0`
/// `input[i, indexes[j], k] += values[i, j, k]; // dim = 1`
/// `input[i, j, indexes[k]] += values[i, j, k]; // dim = 2`
pub fn index_select_assign<const D2: usize>(
self,
dim: usize,
indexes: Tensor<B, 1, Int>,
values: Tensor<B, D2, K>,
) -> Self {
Self::new(K::index_select_dim_assign(
check!(TensorCheck::index_select_assign::<D>(dim));
Self::new(K::index_select_assign(
self.primitive,
dim,
indexes,
@ -412,21 +450,23 @@ where
mask: Tensor<B, D, Bool>,
value: Self::Elem,
) -> Self::Primitive<D>;
fn index_select<const D: usize>(
fn gather<const D: usize>(
dim: usize,
tensor: Self::Primitive<D>,
indexes: Tensor<B, D, Int>,
) -> Self::Primitive<D>;
fn index_select_assign<const D: usize>(
fn scatter<const D: usize>(
dim: usize,
tensor: Self::Primitive<D>,
indexes: Tensor<B, D, Int>,
values: Self::Primitive<D>,
) -> Self::Primitive<D>;
fn index_select_dim<const D: usize>(
fn index_select<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
indexes: Tensor<B, 1, Int>,
) -> Self::Primitive<D>;
fn index_select_dim_assign<const D1: usize, const D2: usize>(
fn index_select_assign<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
dim: usize,
indexes: Tensor<B, 1, Int>,
@ -588,7 +628,7 @@ impl<B: Backend> Numeric<B> for Int {
B::int_mask_fill(tensor, mask.primitive, value)
}
fn index_select_dim<const D: usize>(
fn index_select<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
indexes: Tensor<B, 1, Int>,
@ -596,7 +636,7 @@ impl<B: Backend> Numeric<B> for Int {
B::int_index_select_dim(tensor, dim, indexes.primitive)
}
fn index_select_dim_assign<const D1: usize, const D2: usize>(
fn index_select_assign<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
dim: usize,
indexes: Tensor<B, 1, Int>,
@ -604,19 +644,21 @@ impl<B: Backend> Numeric<B> for Int {
) -> Self::Primitive<D1> {
B::int_index_select_dim_assign(tensor, dim, indexes.primitive, values)
}
fn index_select<const D: usize>(
fn gather<const D: usize>(
dim: usize,
tensor: Self::Primitive<D>,
indexes: Tensor<B, D, Int>,
) -> Self::Primitive<D> {
B::int_index_select(tensor, indexes.primitive)
B::int_gather(dim, tensor, indexes.primitive)
}
fn index_select_assign<const D: usize>(
fn scatter<const D: usize>(
dim: usize,
tensor: Self::Primitive<D>,
indexes: Tensor<B, D, Int>,
values: Self::Primitive<D>,
) -> Self::Primitive<D> {
B::int_index_select_assign(tensor, indexes.primitive, values)
B::int_scatter(dim, tensor, indexes.primitive, values)
}
fn argmax<const D: usize>(
@ -804,36 +846,38 @@ impl<B: Backend> Numeric<B> for Float {
B::mask_fill(tensor, mask.primitive, value)
}
fn index_select_dim<const D: usize>(
fn index_select<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
indexes: Tensor<B, 1, Int>,
) -> Self::Primitive<D> {
B::index_select_dim(tensor, dim, indexes.primitive)
B::index_select(tensor, dim, indexes.primitive)
}
fn index_select_dim_assign<const D1: usize, const D2: usize>(
fn index_select_assign<const D1: usize, const D2: usize>(
tensor: Self::Primitive<D1>,
dim: usize,
indexes: Tensor<B, 1, Int>,
values: Self::Primitive<D2>,
) -> Self::Primitive<D1> {
B::index_select_dim_assign(tensor, dim, indexes.primitive, values)
B::index_select_assign(tensor, dim, indexes.primitive, values)
}
fn index_select<const D: usize>(
fn gather<const D: usize>(
dim: usize,
tensor: Self::Primitive<D>,
indexes: Tensor<B, D, Int>,
) -> Self::Primitive<D> {
B::index_select(tensor, indexes.primitive)
B::gather(dim, tensor, indexes.primitive)
}
fn index_select_assign<const D: usize>(
fn scatter<const D: usize>(
dim: usize,
tensor: Self::Primitive<D>,
indexes: Tensor<B, D, Int>,
values: Self::Primitive<D>,
) -> Self::Primitive<D> {
B::index_select_assign(tensor, indexes.primitive, values)
B::scatter(dim, tensor, indexes.primitive, values)
}
fn argmax<const D: usize>(

View File

@ -44,11 +44,13 @@ pub trait IntTensorOps<B: Backend> {
mask: B::BoolTensorPrimitive<D>,
value: B::IntElem,
) -> B::IntTensorPrimitive<D>;
fn int_index_select<const D: usize>(
fn int_gather<const D: usize>(
dim: usize,
tensor: B::IntTensorPrimitive<D>,
indexes: B::IntTensorPrimitive<D>,
) -> B::IntTensorPrimitive<D>;
fn int_index_select_assign<const D: usize>(
fn int_scatter<const D: usize>(
dim: usize,
tensor: B::IntTensorPrimitive<D>,
indexes: B::IntTensorPrimitive<D>,
value: B::IntTensorPrimitive<D>,
@ -204,14 +206,14 @@ pub trait IntTensorOps<B: Backend> {
) -> B::IntTensorPrimitive<D> {
let index = B::int_argmax(tensor.clone(), dim);
B::int_index_select(tensor, index)
B::int_gather(D - 1, tensor, index)
}
fn int_max_dim_with_indexes<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
dim: usize,
) -> (B::IntTensorPrimitive<D>, B::IntTensorPrimitive<D>) {
let index = B::int_argmax(tensor.clone(), dim);
let values = B::int_index_select(tensor, index.clone());
let values = B::int_gather(D - 1, tensor, index.clone());
(values, index)
}
@ -227,14 +229,14 @@ pub trait IntTensorOps<B: Backend> {
) -> B::IntTensorPrimitive<D> {
let index = B::int_argmin(tensor.clone(), dim);
B::int_index_select(tensor, index)
B::int_gather(D - 1, tensor, index)
}
fn int_min_dim_with_indexes<const D: usize>(
tensor: B::IntTensorPrimitive<D>,
dim: usize,
) -> (B::IntTensorPrimitive<D>, B::IntTensorPrimitive<D>) {
let index = B::int_argmin(tensor.clone(), dim);
let values = B::int_index_select(tensor, index.clone());
let values = B::int_gather(D - 1, tensor, index.clone());
(values, index)
}

View File

@ -115,21 +115,23 @@ pub trait TensorOps<B: Backend> {
tensor: B::TensorPrimitive<D1>,
shape: Shape<D2>,
) -> B::TensorPrimitive<D2>;
fn index_select<const D: usize>(
fn gather<const D: usize>(
dim: usize,
tensor: B::TensorPrimitive<D>,
indexes: B::IntTensorPrimitive<D>,
) -> B::TensorPrimitive<D>;
fn index_select_assign<const D: usize>(
fn scatter<const D: usize>(
dim: usize,
tensor: B::TensorPrimitive<D>,
indexes: B::IntTensorPrimitive<D>,
value: B::TensorPrimitive<D>,
) -> B::TensorPrimitive<D>;
fn index_select_dim<const D: usize>(
fn index_select<const D: usize>(
tensor: B::TensorPrimitive<D>,
dim: usize,
indexes: B::IntTensorPrimitive<1>,
) -> B::TensorPrimitive<D>;
fn index_select_dim_assign<const D1: usize, const D2: usize>(
fn index_select_assign<const D1: usize, const D2: usize>(
tensor: B::TensorPrimitive<D1>,
dim: usize,
indexes: B::IntTensorPrimitive<1>,
@ -251,14 +253,14 @@ pub trait TensorOps<B: Backend> {
fn max_dim<const D: usize>(tensor: B::TensorPrimitive<D>, dim: usize) -> B::TensorPrimitive<D> {
let index = B::argmax(tensor.clone(), dim);
B::index_select(tensor, index)
B::gather(D - 1, tensor, index)
}
fn max_dim_with_indexes<const D: usize>(
tensor: B::TensorPrimitive<D>,
dim: usize,
) -> (B::TensorPrimitive<D>, B::IntTensorPrimitive<D>) {
let index = B::argmax(tensor.clone(), dim);
let values = B::index_select(tensor, index.clone());
let values = B::gather(D - 1, tensor, index.clone());
(values, index)
}
@ -271,14 +273,14 @@ pub trait TensorOps<B: Backend> {
fn min_dim<const D: usize>(tensor: B::TensorPrimitive<D>, dim: usize) -> B::TensorPrimitive<D> {
let index = B::argmin(tensor.clone(), dim);
B::index_select(tensor, index)
B::gather(D - 1, tensor, index)
}
fn min_dim_with_indexes<const D: usize>(
tensor: B::TensorPrimitive<D>,
dim: usize,
) -> (B::TensorPrimitive<D>, B::IntTensorPrimitive<D>) {
let index = B::argmin(tensor.clone(), dim);
let values = B::index_select(tensor, index.clone());
let values = B::gather(D - 1, tensor, index.clone());
(values, index)
}

View File

@ -34,8 +34,8 @@ macro_rules! testgen_all {
burn_tensor::testgen_log!();
burn_tensor::testgen_log1p!();
burn_tensor::testgen_index!();
burn_tensor::testgen_gather_scatter!();
burn_tensor::testgen_index_select!();
burn_tensor::testgen_index_select_dim!();
burn_tensor::testgen_map_comparison!();
burn_tensor::testgen_mask!();
burn_tensor::testgen_matmul!();

View File

@ -0,0 +1,132 @@
#[burn_tensor_testgen::testgen(gather_scatter)]
mod tests {
use super::*;
use burn_tensor::{Data, Tensor};
#[test]
fn should_gather_1d_dim0() {
let tensor = TestTensor::from_floats([0.0, 1.0, 2.0]);
let indexes = TestTensorInt::from_ints([1, 1, 0, 1, 2]);
let output = tensor.gather(0, indexes);
assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0]));
}
#[test]
fn should_gather_2d_dim0() {
let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let indexes = TestTensorInt::from_ints([[0, 1, 0], [1, 0, 1]]);
let output = tensor.gather(0, indexes);
assert_eq!(
output.into_data(),
Data::from([[0.0, 4.0, 2.0], [3.0, 1.0, 5.0]])
);
}
#[test]
fn should_gather_2d_dim1() {
let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let indexes = TestTensorInt::from_ints([[2, 1, 0, 0], [2, 0, 1, 2]]);
let output = tensor.gather(1, indexes);
assert_eq!(
output.into_data(),
Data::from([[2.0, 1.0, 0.0, 0.0], [5.0, 3.0, 4.0, 5.0]])
);
}
#[test]
fn should_gather_3d_dim1() {
let tensor = TestTensor::from_floats([
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
[[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
]);
let indexes = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]]);
let output = tensor.gather(1, indexes);
assert_eq!(
output.into_data(),
Data::from([
[[3.0, 1.0, 2.0], [0.0, 4.0, 2.0]],
[[6.0, 7.0, 11.0], [6.0, 10.0, 11.0]]
])
);
}
#[test]
fn should_gather_2d_only_1dim() {
let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let indexes = TestTensorInt::from_ints([[1, 2]]).reshape([2, 1]);
let output = tensor.gather(1, indexes);
assert_eq!(output.into_data(), Data::from([[1.0], [5.0]]));
}
#[test]
fn should_scatter_1d() {
let tensor = TestTensor::from_floats([0.0, 0.0, 0.0]);
let values = TestTensor::from_floats([5.0, 4.0, 3.0]);
let indexes = TestTensorInt::from_ints([1, 0, 2]);
let output = tensor.scatter(0, indexes, values);
assert_eq!(output.into_data(), Data::from([4.0, 5.0, 3.0]));
}
#[test]
fn should_scatter_2d_dim0() {
let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]);
let values = TestTensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
let indexes = TestTensorInt::from_ints([[1, 0, 1], [1, 1, 0]]);
let output = tensor.scatter(0, indexes, values);
assert_eq!(
output.into_data(),
Data::from([[0.0, 2.0, 6.0], [5.0, 5.0, 3.0]])
);
}
#[test]
fn should_scatter_2d_dim1() {
let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]);
let values = TestTensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
let indexes = TestTensorInt::from_ints([[1, 0, 2], [1, 2, 0]]);
let output = tensor.scatter(1, indexes, values);
assert_eq!(
output.into_data(),
Data::from([[2.0, 1.0, 3.0], [6.0, 4.0, 5.0]])
);
}
#[test]
fn should_scatter_3d_dim1() {
let tensor = TestTensor::from_floats([
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
[[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
]);
let values = TestTensor::from_floats([
[[12.0, 13.0, 14.0], [15.0, 16.0, 17.0]],
[[18.0, 19.0, 20.0], [21.0, 22.0, 23.0]],
]);
let indexes = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]]);
let output = tensor.scatter(1, indexes, values);
assert_eq!(
output.into_data(),
Data::from([
[[15.0, 14.0, 33.0], [15.0, 20.0, 5.0]],
[[45.0, 26.0, 8.0], [9.0, 32.0, 54.0]]
])
);
}
}

View File

@ -8,56 +8,77 @@ mod tests {
let tensor = TestTensor::from_data(Data::from([0.0, 1.0, 2.0]));
let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));
let output = tensor.index_select(indexes);
let output = tensor.index_select(0, indexes);
assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0]));
}
#[test]
fn should_select_2d() {
fn should_select_2d_dim0_same_num_dim() {
let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
let indexes = TestTensorInt::from_data(Data::from([[2, 1, 0, 0], [2, 0, 1, 2]]));
let indexes = TestTensorInt::from_data(Data::from([1, 0]));
let output = tensor.index_select(indexes);
let output = tensor.index_select(0, indexes);
assert_eq!(
output.into_data(),
Data::from([[2.0, 1.0, 0.0, 0.0], [5.0, 3.0, 4.0, 5.0]])
Data::from([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]])
);
}
#[test]
fn should_select_2d_only_1dim() {
fn should_select_2d_dim0_more_num_dim() {
let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
let indexes = TestTensorInt::from_data(Data::from([[1, 2]])).reshape([2, 1]);
let indexes = TestTensorInt::from_data(Data::from([1, 0, 1, 1]));
let output = tensor.index_select(indexes);
let output = tensor.index_select(0, indexes);
assert_eq!(output.into_data(), Data::from([[1.0], [5.0]]));
assert_eq!(
output.into_data(),
Data::from([
[3.0, 4.0, 5.0],
[0.0, 1.0, 2.0],
[3.0, 4.0, 5.0],
[3.0, 4.0, 5.0]
])
);
}
#[test]
fn should_select_2d_dim1() {
let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));
let output = tensor.index_select(1, indexes);
assert_eq!(
output.into_data(),
Data::from([[1.0, 1.0, 0.0, 1.0, 2.0], [4.0, 4.0, 3.0, 4.0, 5.0]])
);
}
#[test]
fn should_select_assign_1d() {
let tensor = TestTensor::from_data(Data::from([0.0, 0.0, 0.0]));
let values = TestTensor::from_data(Data::from([5.0, 4.0, 3.0]));
let indexes = TestTensorInt::from_data(Data::from([1, 0, 2]));
let tensor = TestTensor::from_data(Data::from([0.0, 1.0, 2.0]));
let values = TestTensor::from_data(Data::from([5.0, 4.0, 3.0, 2.0, 1.0]));
let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));
let output = tensor.index_select_assign(indexes, values);
let output = tensor.index_select_assign(0, indexes, values);
assert_eq!(output.into_data(), Data::from([4.0, 5.0, 3.0]));
assert_eq!(output.into_data(), Data::from([3.0, 12.0, 3.0]));
}
#[test]
fn should_select_assign_2d() {
let tensor = TestTensor::from_data(Data::from([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]));
fn should_select_assign_2d_dim0() {
let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
let values = TestTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]));
let indexes = TestTensorInt::from_data(Data::from([[1, 0, 2], [1, 2, 0]]));
let indexes = TestTensorInt::from_data(Data::from([1, 0]));
let output = tensor.index_select_assign(indexes, values);
let output = tensor.index_select_assign(0, indexes, values);
assert_eq!(
output.into_data(),
Data::from([[2.0, 1.0, 3.0], [6.0, 4.0, 5.0]])
Data::from([[4.0, 6.0, 8.0], [4.0, 6.0, 8.0]])
);
}
}

View File

@ -1,84 +0,0 @@
#[burn_tensor_testgen::testgen(index_select_dim)]
mod tests {
use super::*;
use burn_tensor::{Data, Tensor};
#[test]
fn should_select_1d() {
let tensor = TestTensor::from_data(Data::from([0.0, 1.0, 2.0]));
let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));
let output = tensor.index_select_dim(0, indexes);
assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0]));
}
#[test]
fn should_select_2d_dim0_same_num_dim() {
let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
let indexes = TestTensorInt::from_data(Data::from([1, 0]));
let output = tensor.index_select_dim(0, indexes);
assert_eq!(
output.into_data(),
Data::from([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]])
);
}
#[test]
fn should_select_2d_dim0_more_num_dim() {
let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
let indexes = TestTensorInt::from_data(Data::from([1, 0, 1, 1]));
let output = tensor.index_select_dim(0, indexes);
assert_eq!(
output.into_data(),
Data::from([
[3.0, 4.0, 5.0],
[0.0, 1.0, 2.0],
[3.0, 4.0, 5.0],
[3.0, 4.0, 5.0]
])
);
}
#[test]
fn should_select_2d_dim1() {
let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));
let output = tensor.index_select_dim(1, indexes);
assert_eq!(
output.into_data(),
Data::from([[1.0, 1.0, 0.0, 1.0, 2.0], [4.0, 4.0, 3.0, 4.0, 5.0]])
);
}
#[test]
fn should_select_assign_1d() {
let tensor = TestTensor::from_data(Data::from([0.0, 1.0, 2.0]));
let values = TestTensor::from_data(Data::from([5.0, 4.0, 3.0, 2.0, 1.0]));
let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));
let output = tensor.index_select_dim_assign(0, indexes, values);
assert_eq!(output.into_data(), Data::from([3.0, 12.0, 3.0]));
}
#[test]
fn should_select_assign_2d_dim0() {
let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
let values = TestTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]));
let indexes = TestTensorInt::from_data(Data::from([1, 0]));
let output = tensor.index_select_dim_assign(0, indexes, values);
assert_eq!(
output.into_data(),
Data::from([[4.0, 6.0, 8.0], [4.0, 6.0, 8.0]])
);
}
}

View File

@ -6,9 +6,9 @@ mod div;
mod erf;
mod exp;
mod flatten;
mod gather_scatter;
mod index;
mod index_select;
mod index_select_dim;
mod log;
mod log1p;
mod map_comparison;