mirror of https://github.com/tracel-ai/burn.git
Feat/gather scatter (#367)
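Summary (as read from the diff below): the element-wise `index_select`/`index_select_assign` tensor ops are renamed to `gather`/`scatter` and gain an explicit `dim` argument, while the former `index_select_dim`/`index_select_dim_assign` ops take over the `index_select`/`index_select_assign` names. The commit also adds shape checks (`TensorCheck::gather`/`TensorCheck::scatter`), a `Tensor::from_ints` constructor for building index tensors, new `gather_scatter` test suites, and updates the autodiff, ndarray, and tch backends plus the cross-entropy loss accordingly.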
This commit is contained in:
parent 41e54f741b
commit 2a4ba5a6ab
@@ -195,19 +195,21 @@ impl<B: Backend> IntTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
         B::int_lower_equal_elem(lhs, rhs)
     }

-    fn int_index_select<const D: usize>(
+    fn int_gather<const D: usize>(
+        dim: usize,
         tensor: IntTensor<B, D>,
         indexes: IntTensor<B, D>,
     ) -> IntTensor<B, D> {
-        B::int_index_select(tensor, indexes)
+        B::int_gather(dim, tensor, indexes)
     }

-    fn int_index_select_assign<const D: usize>(
+    fn int_scatter<const D: usize>(
+        dim: usize,
         tensor: IntTensor<B, D>,
         indexes: IntTensor<B, D>,
         value: IntTensor<B, D>,
     ) -> IntTensor<B, D> {
-        B::int_index_select_assign(tensor, indexes, value)
+        B::int_scatter(dim, tensor, indexes, value)
     }

     fn int_index_select_dim<const D: usize>(
@@ -14,7 +14,7 @@ impl<B: Backend, const D: usize> Backward<B, D, 1> for MaxMinDim {
             let device = B::device(&grad);
             let zeros = B::zeros(shape, &device);

-            B::index_select_assign(zeros, indexes, grad)
+            B::scatter(D - 1, zeros, indexes, grad)
         });
     }
 }
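The change above is the gradient routing for `max_dim`/`min_dim`: the incoming gradient must flow only to the positions that produced the max/min, which the diff implements by scatter-adding the gradient into a zero tensor along the last dimension. A minimal sketch of the same routing for a single row, in plain Rust with a hypothetical helper (no burn types):

```rust
// Sketch of what `B::scatter(D - 1, zeros, indexes, grad)` does per row:
// the gradient lands on the argmax index, every other position stays zero.
fn route_max_grad(len: usize, argmax: usize, grad: f32) -> Vec<f32> {
    let mut zeros = vec![0.0; len];
    zeros[argmax] += grad; // scatter-add of one value into a zero buffer
    zeros
}

fn main() {
    // Forward: max of [1.0, 5.0, 3.0] is 5.0 at index 1.
    // Backward: an incoming gradient of 2.0 flows only to index 1.
    assert_eq!(route_max_grad(3, 1, 2.0), vec![0.0, 2.0, 0.0]);
}
```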
@@ -438,55 +438,55 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
         }
     }

-    fn index_select<const D: usize>(
+    fn gather<const D: usize>(
+        dim: usize,
         tensor: ADTensor<B, D>,
         indexes: IntTensor<B, D>,
     ) -> ADTensor<B, D> {
         #[derive(Debug)]
-        struct IndexSelect;
+        struct Gather;

-        impl<B: Backend, const D: usize> Backward<B, D, 1> for IndexSelect {
-            type State = (IntTensor<B, D>, Shape<D>, B::Device);
+        impl<B: Backend, const D: usize> Backward<B, D, 1> for Gather {
+            type State = (usize, IntTensor<B, D>, Shape<D>, B::Device);

             fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
-                let (indexes, shape, device) = ops.state;
+                let (dim, indexes, shape, device) = ops.state;

                 unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
                     let zeros = B::zeros(shape, &device);
-                    B::index_select_assign(zeros, indexes, grad)
+                    B::scatter(dim, zeros, indexes, grad)
                 });
             }
         }

-        match IndexSelect
-            .prepare([tensor.node], [tensor.graph])
-            .statefull()
-        {
+        match Gather.prepare([tensor.node], [tensor.graph]).statefull() {
             OpsKind::Tracked(prep) => prep.finish(
                 (
+                    dim,
                     indexes.clone(),
                     B::shape(&tensor.primitive),
                     B::device(&tensor.primitive),
                 ),
-                B::index_select(tensor.primitive, indexes),
+                B::gather(dim, tensor.primitive, indexes),
             ),
-            OpsKind::UnTracked(prep) => prep.finish(B::index_select(tensor.primitive, indexes)),
+            OpsKind::UnTracked(prep) => prep.finish(B::gather(dim, tensor.primitive, indexes)),
         }
     }

-    fn index_select_assign<const D: usize>(
+    fn scatter<const D: usize>(
+        dim: usize,
         tensor: ADTensor<B, D>,
         indexes: IntTensor<B, D>,
         value: ADTensor<B, D>,
     ) -> ADTensor<B, D> {
         #[derive(Debug)]
-        struct IndexSelectAssign;
+        struct Scatter;

-        impl<B: Backend, const D: usize> Backward<B, D, 2> for IndexSelectAssign {
-            type State = (IntTensor<B, D>, Shape<D>, Shape<D>, B::Device);
+        impl<B: Backend, const D: usize> Backward<B, D, 2> for Scatter {
+            type State = (usize, IntTensor<B, D>, Shape<D>, Shape<D>, B::Device);

             fn backward(self, ops: Ops<Self::State, 2>, grads: &mut Gradients) {
-                let (indexes, shape_lhs, shape_rhs, device) = ops.state;
+                let (dim, indexes, shape_lhs, shape_rhs, device) = ops.state;
                 let [indexes_4lhs, indexes_4rhs] = duplicate(&ops.parents, Some(indexes));

                 binary::<B, D, D, D, _, _>(
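The renamed `Gather` op now stores `dim` in its autodiff state because the backward of a gather along `dim` is a scatter-add along the same `dim`: each output element read its value from one source position, so its gradient must be added back at that position, and a repeated index accumulates multiple contributions. A minimal 1-D illustration in plain Rust, assuming nothing from burn:

```rust
// 1-D illustration of why gather's backward is scatter-add.
// Forward: out[i] = input[indexes[i]].
fn gather1d(input: &[f32], indexes: &[usize]) -> Vec<f32> {
    indexes.iter().map(|&i| input[i]).collect()
}

// Backward: grad_input[indexes[i]] += grad_out[i]; repeated indexes accumulate.
fn gather1d_backward(input_len: usize, indexes: &[usize], grad_out: &[f32]) -> Vec<f32> {
    let mut grad_input = vec![0.0; input_len];
    for (&i, &g) in indexes.iter().zip(grad_out) {
        grad_input[i] += g;
    }
    grad_input
}

fn main() {
    let input = [10.0, 20.0, 30.0];
    let indexes = [1, 1, 2];
    assert_eq!(gather1d(&input, &indexes), vec![20.0, 20.0, 30.0]);
    // Index 1 was read twice, so it receives two gradient contributions.
    assert_eq!(gather1d_backward(3, &indexes, &[1.0, 1.0, 1.0]), vec![0.0, 2.0, 1.0]);
}
```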
@@ -495,38 +495,37 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
                     grads,
                     |grad| {
                         let zeros = B::zeros(shape_lhs, &device);
-                        B::index_select_assign(grad, indexes_4lhs.unwrap(), zeros)
+                        B::scatter(dim, grad, indexes_4lhs.unwrap(), zeros)
                     },
                     |grad| {
                         let zeros = B::zeros(shape_rhs, &device);
-                        B::index_select_assign(zeros, indexes_4rhs.unwrap(), grad)
+                        B::scatter(dim, zeros, indexes_4rhs.unwrap(), grad)
                     },
                 );
             }
         }

-        match IndexSelectAssign
+        match Scatter
             .prepare([tensor.node, value.node], [tensor.graph, value.graph])
             .statefull()
         {
             OpsKind::Tracked(prep) => prep.finish(
                 (
+                    dim,
                     indexes.clone(),
                     B::shape(&tensor.primitive),
                     B::shape(&value.primitive),
                     B::device(&value.primitive),
                 ),
-                B::index_select_assign(tensor.primitive, indexes, value.primitive),
+                B::scatter(dim, tensor.primitive, indexes, value.primitive),
             ),
-            OpsKind::UnTracked(prep) => prep.finish(B::index_select_assign(
-                tensor.primitive,
-                indexes,
-                value.primitive,
-            )),
+            OpsKind::UnTracked(prep) => {
+                prep.finish(B::scatter(dim, tensor.primitive, indexes, value.primitive))
+            }
         }
     }

-    fn index_select_dim<const D: usize>(
+    fn index_select<const D: usize>(
         tensor: ADTensor<B, D>,
         dim: usize,
         indexes: IntTensor<B, 1>,
@@ -542,7 +541,7 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {

                 unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
                     let zeros = B::zeros(shape, &device);
-                    B::index_select_dim_assign(zeros, dim, indexes, grad)
+                    B::index_select_assign(zeros, dim, indexes, grad)
                 });
             }
         }
@@ -558,15 +557,15 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
                     B::shape(&tensor.primitive),
                     B::device(&tensor.primitive),
                 ),
-                B::index_select_dim(tensor.primitive, dim, indexes),
+                B::index_select(tensor.primitive, dim, indexes),
             ),
             OpsKind::UnTracked(prep) => {
-                prep.finish(B::index_select_dim(tensor.primitive, dim, indexes))
+                prep.finish(B::index_select(tensor.primitive, dim, indexes))
             }
         }
     }

-    fn index_select_dim_assign<const D1: usize, const D2: usize>(
+    fn index_select_assign<const D1: usize, const D2: usize>(
         tensor: ADTensor<B, D1>,
         dim: usize,
         indexes: IntTensor<B, 1>,
@@ -588,11 +587,11 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
                     grads,
                     |grad| {
                         let zeros = B::zeros(shape_lhs, &device);
-                        B::index_select_dim_assign(grad, dim, indexes_4lhs.unwrap(), zeros)
+                        B::index_select_assign(grad, dim, indexes_4lhs.unwrap(), zeros)
                     },
                     |grad| {
                         let zeros = B::zeros(shape_rhs, &device);
-                        B::index_select_dim_assign(zeros, dim, indexes_4rhs.unwrap(), grad)
+                        B::index_select_assign(zeros, dim, indexes_4rhs.unwrap(), grad)
                     },
                 );
             }
@@ -610,9 +609,9 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
                     B::shape(&value.primitive),
                     B::device(&value.primitive),
                 ),
-                B::index_select_dim_assign(tensor.primitive, dim, indexes, value.primitive),
+                B::index_select_assign(tensor.primitive, dim, indexes, value.primitive),
             ),
-            OpsKind::UnTracked(prep) => prep.finish(B::index_select_dim_assign(
+            OpsKind::UnTracked(prep) => prep.finish(B::index_select_assign(
                 tensor.primitive,
                 dim,
                 indexes,
@@ -1,16 +1,16 @@
-#[burn_tensor_testgen::testgen(ad_index_select_dim)]
+#[burn_tensor_testgen::testgen(ad_gather_scatter)]
 mod tests {
     use super::*;
     use burn_tensor::Data;

     #[test]
-    fn test_index_select_dim_grad() {
+    fn test_gather_grad() {
         let tensor_1 =
             TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
-        let indexes = TestADTensor::from_data(Data::from([1, 0]));
+        let indexes = TestADTensor::from_data(Data::from([[2, 1, 0, 1, 2], [1, 0, 2, 1, 0]]));

         let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
-        let tensor_3 = tensor_1.clone().index_select_dim(0, indexes);
+        let tensor_3 = tensor_1.clone().gather(1, indexes);
         let tensor_4 = tensor_2.matmul(tensor_3);

         let grads = tensor_4.backward();
@@ -19,22 +19,20 @@ mod tests {

         assert_eq!(
             grad_1.into_data(),
-            Data::from([[109., 148., 187.], [37., 58., 79.]])
+            Data::from([[94., 150., 187.], [242., 305., 304.]])
         );
     }

     #[test]
-    fn test_index_select_dim_assign_grad() {
+    fn test_scatter_grad() {
         let tensor_1 =
             TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
         let values =
             TestADTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])).require_grad();
-        let indexes = TestADTensor::from_data(Data::from([1, 0]));
+        let indexes = TestADTensor::from_data(Data::from([[2, 1, 0], [2, 0, 1]]));

         let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
-        let tensor_3 = tensor_1
-            .clone()
-            .index_select_dim_assign(0, indexes, values.clone());
+        let tensor_3 = tensor_1.clone().scatter(1, indexes, values.clone());
         let tensor_4 = tensor_2.matmul(tensor_3);

         let grads = tensor_4.backward();
@@ -44,11 +42,11 @@ mod tests {

         assert_eq!(
             grad_1.into_data(),
-            Data::from([[127., 199., 271.], [172., 244., 316.]])
+            Data::from([[127., 181., 235.], [226., 316., 406.]])
         );
         assert_eq!(
             grad_2.into_data(),
-            Data::from([[64., 64., 64.], [19., 19., 19.]])
+            Data::from([[19., 19., 19.], [64., 64., 64.]])
         );
     }
 }
@@ -7,10 +7,10 @@ mod tests {
     fn test_index_select_grad() {
         let tensor_1 =
             TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
-        let indexes = TestADTensor::from_data(Data::from([[2, 1, 0, 1, 2], [1, 0, 2, 1, 0]]));
+        let indexes = TestADTensor::from_data(Data::from([1, 0]));

         let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
-        let tensor_3 = tensor_1.clone().index_select(indexes);
+        let tensor_3 = tensor_1.clone().index_select(0, indexes);
         let tensor_4 = tensor_2.matmul(tensor_3);

         let grads = tensor_4.backward();
@@ -19,7 +19,7 @@ mod tests {

         assert_eq!(
             grad_1.into_data(),
-            Data::from([[94., 150., 187.], [242., 305., 304.]])
+            Data::from([[109., 148., 187.], [37., 58., 79.]])
         );
     }

@@ -29,12 +29,12 @@ mod tests {
             TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
         let values =
             TestADTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])).require_grad();
-        let indexes = TestADTensor::from_data(Data::from([[2, 1, 0], [2, 0, 1]]));
+        let indexes = TestADTensor::from_data(Data::from([1, 0]));

         let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
         let tensor_3 = tensor_1
             .clone()
-            .index_select_assign(indexes, values.clone());
+            .index_select_assign(0, indexes, values.clone());
         let tensor_4 = tensor_2.matmul(tensor_3);

         let grads = tensor_4.backward();
@@ -44,11 +44,11 @@ mod tests {

         assert_eq!(
             grad_1.into_data(),
-            Data::from([[127., 181., 235.], [226., 316., 406.]])
+            Data::from([[127., 199., 271.], [172., 244., 316.]])
         );
         assert_eq!(
             grad_2.into_data(),
-            Data::from([[19., 19., 19.], [64., 64., 64.]])
+            Data::from([[64., 64., 64.], [19., 19., 19.]])
         );
     }
 }
@@ -13,10 +13,10 @@ mod cross_entropy;
 mod div;
 mod erf;
 mod exp;
+mod gather_scatter;
 mod gelu;
 mod index;
 mod index_select;
-mod index_select_dim;
 mod log;
 mod log1p;
 mod mask;
@@ -70,8 +70,8 @@ macro_rules! testgen_all {
         burn_autodiff::testgen_ad_erf!();
         burn_autodiff::testgen_ad_exp!();
         burn_autodiff::testgen_ad_index!();
+        burn_autodiff::testgen_ad_gather_scatter!();
         burn_autodiff::testgen_ad_index_select!();
-        burn_autodiff::testgen_ad_index_select_dim!();
         burn_autodiff::testgen_ad_log!();
         burn_autodiff::testgen_ad_log1p!();
         burn_autodiff::testgen_ad_mask!();
@@ -29,7 +29,7 @@ impl<B: Backend> CrossEntropyLoss<B> {

         let mask = self.padding_mask(&targets);
         let tensor = activation::log_softmax(logits, 1);
-        let tensor = tensor.index_select(targets.reshape([batch_size, 1]));
+        let tensor = tensor.gather(1, targets.reshape([batch_size, 1]));
         let tensor = self.apply_mask(tensor.reshape([batch_size]), mask);

         tensor.mean().neg()
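Cross-entropy is a natural gather: for each batch row, the loss keeps only the log-probability of that row's target class, which is exactly what `gather(1, targets.reshape([batch_size, 1]))` selects. A plain-Rust sketch of the same per-row selection (log-probabilities assumed already computed by `log_softmax`):

```rust
// Per-row selection behind `tensor.gather(1, targets.reshape([batch_size, 1]))`:
// keep the log-probability of the target class, then mean().neg().
fn nll_loss(log_probs: &[Vec<f32>], targets: &[usize]) -> f32 {
    let picked: f32 = log_probs
        .iter()
        .zip(targets)
        .map(|(row, &t)| row[t]) // gather along dim 1, one index per row
        .sum();
    -picked / log_probs.len() as f32
}

fn main() {
    let log_probs = vec![vec![-0.1_f32, -2.5, -3.0], vec![-1.8, -0.2, -2.9]];
    let targets = vec![0, 1];
    println!("loss = {}", nll_loss(&log_probs, &targets)); // 0.15
}
```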
@@ -208,13 +208,18 @@ where
         }
     }

-    pub fn index_select<const D: usize>(
-        tensor: NdArrayTensor<E, D>,
-        indexes: NdArrayTensor<i64, D>,
+    pub fn gather<const D: usize>(
+        dim: usize,
+        mut tensor: NdArrayTensor<E, D>,
+        mut indexes: NdArrayTensor<i64, D>,
     ) -> NdArrayTensor<E, D> {
+        if dim != D - 1 {
+            tensor.array.swap_axes(D - 1, dim);
+            indexes.array.swap_axes(D - 1, dim);
+        }
         let (shape_tensor, shape_indexes) = (tensor.shape(), indexes.shape());
         let (size_tensor, size_index) = (shape_tensor.dims[D - 1], shape_indexes.dims[D - 1]);
-        let batch_size = Self::index_select_batch_size(&shape_tensor, &shape_indexes);
+        let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indexes);

         let indexes = NdArrayOps::reshape(indexes, Shape::new([batch_size, size_index])).array;
         let tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array;
@@ -228,17 +233,30 @@ where
             }
         }

-        NdArrayOps::reshape(
+        let mut output = NdArrayOps::reshape(
             NdArrayTensor::<E, 2>::new(output.into_shared().into_dyn()),
             shape_indexes,
-        )
+        );
+
+        if dim != D - 1 {
+            output.array.swap_axes(D - 1, dim);
+        }
+
+        output
     }

-    pub fn index_select_assign<const D: usize>(
-        tensor: NdArrayTensor<E, D>,
-        indexes: NdArrayTensor<i64, D>,
-        value: NdArrayTensor<E, D>,
+    pub fn scatter<const D: usize>(
+        dim: usize,
+        mut tensor: NdArrayTensor<E, D>,
+        mut indexes: NdArrayTensor<i64, D>,
+        mut value: NdArrayTensor<E, D>,
     ) -> NdArrayTensor<E, D> {
+        if dim != D - 1 {
+            tensor.array.swap_axes(D - 1, dim);
+            indexes.array.swap_axes(D - 1, dim);
+            value.array.swap_axes(D - 1, dim);
+        }
+
         let (shape_tensor, shape_indexes, shape_value) =
             (tensor.shape(), indexes.shape(), value.shape());
         let (size_tensor, size_index, size_value) = (
@@ -246,7 +264,7 @@ where
             shape_indexes.dims[D - 1],
             shape_value.dims[D - 1],
         );
-        let batch_size = Self::index_select_batch_size(&shape_tensor, &shape_indexes);
+        let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indexes);

         if shape_value != shape_indexes {
             panic!("Invalid dimension: the shape of the index tensor should be the same as the value tensor: Index {:?} value {:?}", shape_indexes.dims, shape_value.dims);
@@ -265,10 +283,14 @@ where
             }
         }

-        NdArrayOps::reshape(
+        let mut output = NdArrayOps::reshape(
             NdArrayTensor::<E, 2>::new(tensor.into_shared().into_dyn()),
             shape_tensor,
-        )
+        );
+        if dim != D - 1 {
+            output.array.swap_axes(D - 1, dim);
+        }
+        output
     }

     pub fn mask_scatter<const D: usize>(
@@ -307,7 +329,7 @@ where
         NdArrayTensor::new(array)
     }

-    fn index_select_batch_size<const D: usize>(
+    fn gather_batch_size<const D: usize>(
         shape_tensor: &Shape<D>,
         shape_indexes: &Shape<D>,
     ) -> usize {
@@ -323,7 +345,7 @@ where
         batch_size
     }

-    pub fn index_select_dim<const D: usize>(
+    pub fn index_select<const D: usize>(
         tensor: NdArrayTensor<E, D>,
         dim: usize,
         indexes: NdArrayTensor<i64, 1>,
@@ -340,7 +362,7 @@ where
         NdArrayTensor::new(array.into_shared())
     }

-    pub fn index_select_dim_assign<const D1: usize, const D2: usize>(
+    pub fn index_select_assign<const D1: usize, const D2: usize>(
         tensor: NdArrayTensor<E, D1>,
         dim: usize,
         indexes: NdArrayTensor<i64, 1>,
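The burn-ndarray implementation above normalizes every gather/scatter to one canonical case: when `dim != D - 1` it swaps axes so the target dimension is last, flattens the leading dimensions into a batch, does the per-row indexing, and swaps back. A plain-Rust sketch of the 2-D kernel that remains after that normalization (hypothetical standalone code, not the NdArrayMathOps API):

```rust
// The kernel left after the swap-axes + flatten normalization: `dim` is the
// last axis and the leading axes are collapsed into rows, so gather is just
// per-row indexing.
fn gather_rows(tensor: &[Vec<f32>], indexes: &[Vec<usize>]) -> Vec<Vec<f32>> {
    tensor
        .iter()
        .zip(indexes)
        .map(|(row, idx)| idx.iter().map(|&i| row[i]).collect())
        .collect()
}

fn main() {
    let tensor = vec![vec![0.0, 1.0, 2.0], vec![3.0, 4.0, 5.0]];
    let indexes = vec![vec![2, 0], vec![1, 1]];
    assert_eq!(gather_rows(&tensor, &indexes), vec![vec![2.0, 0.0], vec![4.0, 4.0]]);
}
```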
@@ -280,19 +280,21 @@ impl<E: FloatNdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<
         NdArrayMathOps::mean_dim(tensor, dim)
     }

-    fn int_index_select<const D: usize>(
+    fn int_gather<const D: usize>(
+        dim: usize,
         tensor: NdArrayTensor<i64, D>,
         indexes: NdArrayTensor<i64, D>,
     ) -> NdArrayTensor<i64, D> {
-        NdArrayMathOps::index_select(tensor, indexes)
+        NdArrayMathOps::gather(dim, tensor, indexes)
     }

-    fn int_index_select_assign<const D: usize>(
+    fn int_scatter<const D: usize>(
+        dim: usize,
         tensor: NdArrayTensor<i64, D>,
         indexes: NdArrayTensor<i64, D>,
         value: NdArrayTensor<i64, D>,
     ) -> NdArrayTensor<i64, D> {
-        NdArrayMathOps::index_select_assign(tensor, indexes, value)
+        NdArrayMathOps::scatter(dim, tensor, indexes, value)
     }

     fn int_index_select_dim<const D: usize>(
@@ -300,7 +302,7 @@ impl<E: FloatNdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<
         dim: usize,
         indexes: NdArrayTensor<i64, 1>,
     ) -> NdArrayTensor<i64, D> {
-        NdArrayMathOps::index_select_dim(tensor, dim, indexes)
+        NdArrayMathOps::index_select(tensor, dim, indexes)
     }

     fn int_index_select_dim_assign<const D1: usize, const D2: usize>(
@@ -309,7 +311,7 @@ impl<E: FloatNdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<
         indexes: NdArrayTensor<i64, 1>,
         value: NdArrayTensor<i64, D2>,
     ) -> NdArrayTensor<i64, D1> {
-        NdArrayMathOps::index_select_dim_assign(tensor, dim, indexes, value)
+        NdArrayMathOps::index_select_assign(tensor, dim, indexes, value)
     }
     fn int_argmax<const D: usize>(
         tensor: NdArrayTensor<i64, D>,
@@ -150,36 +150,38 @@ impl<E: FloatNdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E>
         NdArrayOps::reshape(tensor, shape)
     }

-    fn index_select<const D: usize>(
+    fn gather<const D: usize>(
+        dim: usize,
         tensor: NdArrayTensor<E, D>,
         indexes: NdArrayTensor<i64, D>,
     ) -> NdArrayTensor<E, D> {
-        NdArrayMathOps::index_select(tensor, indexes)
+        NdArrayMathOps::gather(dim, tensor, indexes)
     }

-    fn index_select_assign<const D: usize>(
+    fn scatter<const D: usize>(
+        dim: usize,
         tensor: NdArrayTensor<E, D>,
         indexes: NdArrayTensor<i64, D>,
         value: NdArrayTensor<E, D>,
     ) -> NdArrayTensor<E, D> {
-        NdArrayMathOps::index_select_assign(tensor, indexes, value)
+        NdArrayMathOps::scatter(dim, tensor, indexes, value)
     }

-    fn index_select_dim<const D: usize>(
+    fn index_select<const D: usize>(
         tensor: NdArrayTensor<E, D>,
         dim: usize,
         indexes: NdArrayTensor<i64, 1>,
     ) -> NdArrayTensor<E, D> {
-        NdArrayMathOps::index_select_dim(tensor, dim, indexes)
+        NdArrayMathOps::index_select(tensor, dim, indexes)
     }

-    fn index_select_dim_assign<const D1: usize, const D2: usize>(
+    fn index_select_assign<const D1: usize, const D2: usize>(
         tensor: NdArrayTensor<E, D1>,
         dim: usize,
         indexes: NdArrayTensor<i64, 1>,
         value: NdArrayTensor<E, D2>,
     ) -> NdArrayTensor<E, D1> {
-        NdArrayMathOps::index_select_dim_assign(tensor, dim, indexes, value)
+        NdArrayMathOps::index_select_assign(tensor, dim, indexes, value)
     }

     fn index<const D1: usize, const D2: usize>(
@@ -56,17 +56,19 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
         TchTensor::new(tensor_original)
     }

-    pub fn index_select<const D: usize>(
+    pub fn gather<const D: usize>(
+        dim: usize,
         tensor: TchTensor<E, D>,
         indexes: TchTensor<i64, D>,
     ) -> TchTensor<E, D> {
         let storage = tensor.storage.clone();
-        let tensor = tensor.tensor.gather((D - 1) as i64, &indexes.tensor, false);
+        let tensor = tensor.tensor.gather(dim as i64, &indexes.tensor, false);

         TchTensor::from_existing(tensor, storage)
     }

-    pub fn index_select_assign<const D: usize>(
+    pub fn scatter<const D: usize>(
+        dim: usize,
         tensor: TchTensor<E, D>,
         indexes: TchTensor<i64, D>,
         value: TchTensor<E, D>,
@@ -74,7 +76,7 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
         let storage = tensor.storage.clone();
         let tensor = tensor
             .tensor
-            .scatter_add((D - 1) as i64, &indexes.tensor, &value.tensor);
+            .scatter_add(dim as i64, &indexes.tensor, &value.tensor);

         TchTensor::from_existing(tensor, storage)
     }
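The tch backend maps straight onto LibTorch primitives: `Tensor::gather` for reads and `Tensor::scatter_add` for sum-reduction writes, now with the caller's `dim` instead of the hard-coded last dimension. A usage sketch with the `tch` crate; the `gather`/`scatter_add` signatures are the ones the diff itself calls, while the constructor names (`from_slice`, `zeros`) vary slightly across tch versions and libtorch must be installed to run this:

```rust
use tch::{Device, Kind, Tensor};

fn main() {
    let input = Tensor::from_slice(&[0.0f32, 1.0, 2.0, 3.0, 4.0, 5.0]).view([2, 3]);
    let indexes = Tensor::from_slice(&[2i64, 0, 1, 1]).view([2, 2]);

    // Same call as `tensor.gather(dim as i64, &indexes.tensor, false)` in the diff.
    let gathered = input.gather(1, &indexes, false);
    gathered.print(); // [[2, 0], [4, 4]]

    // Same call as `scatter_add(dim as i64, &indexes.tensor, &value.tensor)`.
    let zeros = Tensor::zeros(&[2, 3], (Kind::Float, Device::Cpu));
    let scattered = zeros.scatter_add(1, &indexes, &gathered);
    scattered.print(); // [[0, 0, 2], [0, 8, 0]] -- repeated index 1 accumulates
}
```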
@@ -231,19 +231,21 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
     fn int_mean_dim<const D: usize>(tensor: TchTensor<i64, D>, dim: usize) -> TchTensor<i64, D> {
         TchOps::mean_dim(tensor, dim)
     }
-    fn int_index_select<const D: usize>(
+    fn int_gather<const D: usize>(
+        dim: usize,
         tensor: TchTensor<i64, D>,
         indexes: TchTensor<i64, D>,
     ) -> TchTensor<i64, D> {
-        TchOps::index_select(tensor, indexes)
+        TchOps::gather(dim, tensor, indexes)
     }

-    fn int_index_select_assign<const D: usize>(
+    fn int_scatter<const D: usize>(
+        dim: usize,
         tensor: TchTensor<i64, D>,
         indexes: TchTensor<i64, D>,
         value: TchTensor<i64, D>,
     ) -> TchTensor<i64, D> {
-        TchOps::index_select_assign(tensor, indexes, value)
+        TchOps::scatter(dim, tensor, indexes, value)
     }

     fn int_index_select_dim<const D: usize>(
@@ -176,22 +176,24 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
         TchOps::reshape(tensor, shape)
     }

-    fn index_select<const D: usize>(
+    fn gather<const D: usize>(
+        dim: usize,
         tensor: TchTensor<E, D>,
         indexes: TchTensor<i64, D>,
     ) -> TchTensor<E, D> {
-        TchOps::index_select(tensor, indexes)
+        TchOps::gather(dim, tensor, indexes)
     }

-    fn index_select_assign<const D: usize>(
+    fn scatter<const D: usize>(
+        dim: usize,
         tensor: TchTensor<E, D>,
         indexes: TchTensor<i64, D>,
         value: TchTensor<E, D>,
     ) -> TchTensor<E, D> {
-        TchOps::index_select_assign(tensor, indexes, value)
+        TchOps::scatter(dim, tensor, indexes, value)
     }

-    fn index_select_dim<const D: usize>(
+    fn index_select<const D: usize>(
         tensor: TchTensor<E, D>,
         dim: usize,
         indexes: TchTensor<i64, 1>,
@@ -199,7 +201,7 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
         TchOps::index_select_dim(tensor, dim, indexes)
     }

-    fn index_select_dim_assign<const D1: usize, const D2: usize>(
+    fn index_select_assign<const D1: usize, const D2: usize>(
         tensor: TchTensor<E, D1>,
         dim: usize,
         indexes: TchTensor<i64, 1>,
@@ -7,7 +7,7 @@ use core::ops::Range;

 /// The struct should always be used with the [check](crate::check) macro.
 ///
-/// This is a simple public crate data structure that efficiently checks tensor operations and
+/// This is a simple pub(crate) data structure that efficiently checks tensor operations and
 /// formats clear error messages. It's crucial that the checks are really fast, but it doesn't matter
 /// when a failed check is discovered since the program will panic.
 ///
@@ -32,14 +32,14 @@ use core::ops::Range;
 /// such as the `index_select` operation. The downside of that approach is that all backend
 /// implementation might re-implement the same checks, which may result in uncessary code
 /// duplication. Maybe a combination of both strategies could help to cover all usecases.
-pub enum TensorCheck {
+pub(crate) enum TensorCheck {
     Ok,
     Failed(FailedTensorCheck),
 }

 impl TensorCheck {
     /// Checks device and shape compatibility for element wise binary operations.
-    pub fn binary_ops_ew<B: Backend, const D: usize, K: BasicOps<B>>(
+    pub(crate) fn binary_ops_ew<B: Backend, const D: usize, K: BasicOps<B>>(
         ops: &str,
         lhs: &Tensor<B, D, K>,
         rhs: &Tensor<B, D, K>,
@@ -49,7 +49,7 @@ impl TensorCheck {
             .binary_ops_ew_shape(ops, &lhs.shape(), &rhs.shape())
     }

-    pub fn into_scalar<const D: usize>(shape: &Shape<D>) -> Self {
+    pub(crate) fn into_scalar<const D: usize>(shape: &Shape<D>) -> Self {
         let mut check = Self::Ok;

         if shape.num_elements() != 1 {
@@ -66,7 +66,7 @@ impl TensorCheck {
         check
     }

-    pub fn dim_ops<const D: usize>(ops: &str, dim: usize) -> Self {
+    pub(crate) fn dim_ops<const D: usize>(ops: &str, dim: usize) -> Self {
         let mut check = Self::Ok;

         if dim >= D {
@@ -80,7 +80,7 @@ impl TensorCheck {
         check
     }

-    pub fn reshape<const D1: usize, const D2: usize>(
+    pub(crate) fn reshape<const D1: usize, const D2: usize>(
         original: &Shape<D1>,
         target: &Shape<D2>,
     ) -> Self {
@@ -99,7 +99,10 @@ impl TensorCheck {
         check
     }

-    pub fn flatten<const D1: usize, const D2: usize>(start_dim: usize, end_dim: usize) -> Self {
+    pub(crate) fn flatten<const D1: usize, const D2: usize>(
+        start_dim: usize,
+        end_dim: usize,
+    ) -> Self {
         let mut check = Self::Ok;

         if start_dim > end_dim {
@@ -130,7 +133,7 @@ impl TensorCheck {
         check
     }

-    pub fn unsqueeze<const D1: usize, const D2: usize>() -> Self {
+    pub(crate) fn unsqueeze<const D1: usize, const D2: usize>() -> Self {
         let mut check = Self::Ok;
         if D2 < D1 {
             check = check.register(
@@ -144,7 +147,7 @@ impl TensorCheck {
         check
     }

-    pub fn swap_dims<const D: usize>(dim1: usize, dim2: usize) -> Self {
+    pub(crate) fn swap_dims<const D: usize>(dim1: usize, dim2: usize) -> Self {
         let mut check = Self::Ok;

         if dim1 > D || dim2 > D {
@@ -160,7 +163,10 @@ impl TensorCheck {
         check
     }

-    pub fn matmul<B: Backend, const D: usize>(lhs: &Tensor<B, D>, rhs: &Tensor<B, D>) -> Self {
+    pub(crate) fn matmul<B: Backend, const D: usize>(
+        lhs: &Tensor<B, D>,
+        rhs: &Tensor<B, D>,
+    ) -> Self {
         let mut check = Self::Ok;

         check = check.binary_ops_device("Matmul", &lhs.device(), &rhs.device());
@@ -191,7 +197,7 @@ impl TensorCheck {
         check
     }

-    pub fn cat<B: Backend, const D: usize, K: BasicOps<B>>(
+    pub(crate) fn cat<B: Backend, const D: usize, K: BasicOps<B>>(
         tensors: &[Tensor<B, D, K>],
         dim: usize,
     ) -> Self {
@@ -241,7 +247,7 @@ impl TensorCheck {
         check
     }

-    pub fn index<const D1: usize, const D2: usize>(
+    pub(crate) fn index<const D1: usize, const D2: usize>(
         shape: &Shape<D1>,
         indexes: &[Range<usize>; D2],
     ) -> Self {
@@ -298,7 +304,7 @@ impl TensorCheck {
         check
     }

-    pub fn index_assign<const D1: usize, const D2: usize>(
+    pub(crate) fn index_assign<const D1: usize, const D2: usize>(
         shape: &Shape<D1>,
         shape_value: &Shape<D1>,
         indexes: &[Range<usize>; D2],
@@ -374,8 +380,104 @@ impl TensorCheck {
         check
     }

+    pub(crate) fn gather<const D: usize>(
+        dim: usize,
+        shape: &Shape<D>,
+        shape_indexes: &Shape<D>,
+    ) -> Self {
+        Self::check_gather_scatter_indexes(Self::Ok, "Gather", dim, shape, shape_indexes)
+    }
+
+    pub(crate) fn scatter<const D: usize>(
+        dim: usize,
+        shape: &Shape<D>,
+        shape_indexes: &Shape<D>,
+        shape_value: &Shape<D>,
+    ) -> Self {
+        let ops = "Scatter";
+        let mut check =
+            Self::check_gather_scatter_indexes(Self::Ok, ops, dim, shape, shape_indexes);
+
+        if shape_indexes != shape_value {
+            check = check.register(
+                ops,
+                TensorError::new(
+                    "Indexes tensor shape should be the same as the value tensor shape."
+                        .to_string(),
+                )
+                .details(format!(
+                    "The shape differs: {:?} != {:?}",
+                    shape_indexes.dims, shape_value.dims
+                )),
+            );
+        }
+
+        check
+    }
+
+    pub(crate) fn index_select<const D: usize>(dim: usize) -> Self {
+        Self::check_index_select_basic::<D>(Self::Ok, "index_select", dim)
+    }
+
+    pub(crate) fn index_select_assign<const D: usize>(dim: usize) -> Self {
+        Self::check_index_select_basic::<D>(Self::Ok, "index_select_assign", dim)
+    }
+
+    fn check_index_select_basic<const D: usize>(mut check: Self, ops: &str, dim: usize) -> Self {
+        if dim > D {
+            check = check.register(
+                ops,
+                TensorError::new(format!(
+                    "Can't index a tensor with ({D}) dimensions on axis ({dim})"
+                )),
+            );
+        }
+
+        check
+    }
+    fn check_gather_scatter_indexes<const D: usize>(
+        mut check: Self,
+        ops: &str,
+        dim: usize,
+        shape: &Shape<D>,
+        shape_indexes: &Shape<D>,
+    ) -> Self {
+        if dim > D {
+            check = check.register(
+                ops,
+                TensorError::new(format!(
+                    "Can't index a tensor with ({D}) dimensions on axis ({dim})"
+                )),
+            );
+        }
+
+        for i in 0..D {
+            if i == dim {
+                continue;
+            }
+
+            let tensor_dim_i = shape.dims[i];
+            let indexes_dim_i = shape_indexes.dims[i];
+
+            if tensor_dim_i != indexes_dim_i {
+                check = check.register(
+                    ops,
+                    TensorError::new(
+                        "The tensor shape should be the same as the index tensor shape."
+                            .to_string(),
+                    )
+                    .details(format!(
+                        "The shape differs at dimension {i}: {tensor_dim_i} != {indexes_dim_i}"
+                    )),
+                );
+            }
+        }
+
+        check
+    }

     /// Checks aggregate dimension such as mean and sum.
-    pub fn aggregate_dim<const D: usize>(ops: &str, dim: usize) -> Self {
+    pub(crate) fn aggregate_dim<const D: usize>(ops: &str, dim: usize) -> Self {
         let mut check = Self::Ok;

         if dim > D {
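The contract these new checks enforce: for both gather and scatter, the index tensor must match the input on every dimension except `dim`, and for scatter the value tensor must additionally match the index tensor exactly. A condensed standalone restatement of the shape rule in plain Rust (hypothetical helper; burn's version accumulates `TensorError` values instead of returning a `Result`):

```rust
// Same rule as `check_gather_scatter_indexes`: shapes agree on all axes
// except the gathered/scattered one.
fn validate_gather(dim: usize, shape: &[usize], shape_indexes: &[usize]) -> Result<(), String> {
    if dim >= shape.len() {
        return Err(format!("Can't index a {}-d tensor on axis {dim}", shape.len()));
    }
    for (i, (&t, &ix)) in shape.iter().zip(shape_indexes).enumerate() {
        if i != dim && t != ix {
            return Err(format!("Shapes differ at dimension {i}: {t} != {ix}"));
        }
    }
    Ok(())
}

fn main() {
    assert!(validate_gather(1, &[2, 3], &[2, 5]).is_ok()); // free along dim 1
    assert!(validate_gather(1, &[2, 3], &[4, 5]).is_err()); // mismatch on dim 0
}
```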
@@ -409,7 +511,7 @@ impl TensorCheck {
     }

     /// Checks if shapes are compatible for element wise operations supporting broadcasting.
-    pub fn binary_ops_ew_shape<const D: usize>(
+    pub(crate) fn binary_ops_ew_shape<const D: usize>(
         self,
         ops: &str,
         lhs: &Shape<D>,
@@ -464,14 +566,14 @@ impl TensorCheck {
     }
 }

-pub struct FailedTensorCheck {
+pub(crate) struct FailedTensorCheck {
     ops: String,
     errors: Vec<TensorError>,
 }

 impl FailedTensorCheck {
     /// Format all the checks into a single message ready to be printed by a [panic](core::panic).
-    pub fn format(self) -> String {
+    pub(crate) fn format(self) -> String {
         self.errors.into_iter().enumerate().fold(
             format!(
                 "=== Tensor Operation Error ===\n Operation: '{}'\n Reason:",
@@ -488,14 +590,14 @@ struct TensorError {
 }

 impl TensorError {
-    pub fn new<S: Into<String>>(description: S) -> Self {
+    pub(crate) fn new<S: Into<String>>(description: S) -> Self {
         TensorError {
             description: description.into(),
             details: None,
         }
     }

-    pub fn details<S: Into<String>>(mut self, details: S) -> Self {
+    pub(crate) fn details<S: Into<String>>(mut self, details: S) -> Self {
         self.details = Some(details.into());
         self
     }
@@ -1,4 +1,4 @@
-use crate::{backend::Backend, Int, Tensor};
+use crate::{backend::Backend, Data, Int, Tensor};
 use core::ops::Range;

 impl<B> Tensor<B, 1, Int>
@@ -14,3 +14,25 @@ where
         Tensor::new(B::arange(range, device))
     }
 }
+
+impl<const D: usize, B> Tensor<B, D, Int>
+where
+    B: Backend,
+{
+    /// Create a tensor from integers (i32).
+    ///
+    /// # Example
+    ///
+    /// ```rust
+    /// use burn_tensor::backend::Backend;
+    /// use burn_tensor::{Tensor, Int};
+    ///
+    /// fn example<B: Backend>() {
+    ///     let _x: Tensor<B, 1, Int> = Tensor::from_ints([1, 2]);
+    ///     let _y: Tensor<B, 2, Int> = Tensor::from_ints([[1, 2], [3, 4]]);
+    /// }
+    /// ```
+    pub fn from_ints<A: Into<Data<i32, D>>>(ints: A) -> Self {
+        Self::from_data(ints.into().convert())
+    }
+}
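This constructor exists so index tensors can be written inline: the new `gather_scatter` tests later in this diff build their index tensors with `TestTensorInt::from_ints([[0, 1, 0], [1, 0, 1]])` rather than going through `Data::from` by hand.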
@@ -203,45 +203,83 @@ where
         Self::new(K::mask_fill(self.primitive, mask, value.elem()))
     }

-    /// Select the tensor elements corresponding to the given indexes.
+    /// Gather tensor elements corresponding to the given indexes from the specified dim.
     ///
+    /// Example using a 3D tensor:
+    ///
+    /// `output[i, j, k] = input[indexes[i, j, k], j, k]; // dim = 0`
+    /// `output[i, j, k] = input[i, indexes[i, j, k], k]; // dim = 1`
+    /// `output[i, j, k] = input[i, j, indexes[i, j, k]]; // dim = 2`
+    ///
     /// # Notes
     ///
-    /// The index tensor shoud have the same shape as the original tensor except for the last
-    /// dimension.
-    pub fn index_select(self, indexes: Tensor<B, D, Int>) -> Self {
-        Self::new(K::index_select(self.primitive, indexes))
+    /// The index tensor should have the same shape as the original tensor except for the dim
+    /// specified.
+    pub fn gather(self, dim: usize, indexes: Tensor<B, D, Int>) -> Self {
+        check!(TensorCheck::gather::<D>(
+            dim,
+            &self.shape(),
+            &indexes.shape()
+        ));
+
+        Self::new(K::gather(dim, self.primitive, indexes))
     }

-    /// Assign the selected elements corresponding to the given indexes from the value tensor
-    /// to the original tensor using sum reduction.
+    /// Assign the gathered elements corresponding to the given indexes along the specified dimension
+    /// from the value tensor to the original tensor using sum reduction.
     ///
+    /// Example using a 3D tensor:
+    ///
+    /// `input[indexes[i, j, k], j, k] += values[i, j, k]; // dim = 0`
+    /// `input[i, indexes[i, j, k], k] += values[i, j, k]; // dim = 1`
+    /// `input[i, j, indexes[i, j, k]] += values[i, j, k]; // dim = 2`
+    ///
     /// # Notes
     ///
-    /// The index tensor shoud have the same shape as the original tensor except for the last
+    /// The index tensor should have the same shape as the original tensor except for the specified
     /// dimension. The value and index tensors should have the same shape.
-    pub fn index_select_assign(self, indexes: Tensor<B, D, Int>, values: Self) -> Self {
-        Self::new(K::index_select_assign(
-            self.primitive,
-            indexes,
-            values.primitive,
-        ))
+    ///
+    /// Other references to the input tensor will not be modified by this operation.
+    pub fn scatter(self, dim: usize, indexes: Tensor<B, D, Int>, values: Self) -> Self {
+        check!(TensorCheck::scatter::<D>(
+            dim,
+            &self.shape(),
+            &indexes.shape(),
+            &values.shape()
+        ));
+
+        Self::new(K::scatter(dim, self.primitive, indexes, values.primitive))
     }

     /// Select the tensor elements along the given dimension corresponding to the given indexes.
-    pub fn index_select_dim(self, dim: usize, indexes: Tensor<B, 1, Int>) -> Self {
-        Self::new(K::index_select_dim(self.primitive, dim, indexes))
+    ///
+    /// Example using a 3D tensor:
+    ///
+    /// `output[i, j, k] = input[indexes[i], j, k]; // dim = 0`
+    /// `output[i, j, k] = input[i, indexes[j], k]; // dim = 1`
+    /// `output[i, j, k] = input[i, j, indexes[k]]; // dim = 2`
+    pub fn index_select(self, dim: usize, indexes: Tensor<B, 1, Int>) -> Self {
+        check!(TensorCheck::index_select::<D>(dim));
+        Self::new(K::index_select(self.primitive, dim, indexes))
     }

     /// Assign the selected elements along the given dimension corresponding to the given indexes
     /// from the value tensor to the original tensor using sum reduction.
-    pub fn index_select_dim_assign<const D2: usize>(
+    ///
+    /// Example using a 3D tensor:
+    ///
+    /// `input[indexes[i], j, k] += values[i, j, k]; // dim = 0`
+    /// `input[i, indexes[j], k] += values[i, j, k]; // dim = 1`
+    /// `input[i, j, indexes[k]] += values[i, j, k]; // dim = 2`
+    pub fn index_select_assign<const D2: usize>(
         self,
         dim: usize,
         indexes: Tensor<B, 1, Int>,
         values: Tensor<B, D2, K>,
     ) -> Self {
-        Self::new(K::index_select_dim_assign(
+        check!(TensorCheck::index_select_assign::<D>(dim));
+
+        Self::new(K::index_select_assign(
             self.primitive,
             dim,
             indexes,
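The indexing formulas in the new doc comments are easiest to read with concrete numbers. Here is the `dim = 0` case worked by hand on a 2x3 matrix, in plain Rust mirroring the 2-D instance of the formula, `output[i, j] = input[indexes[i, j], j]`:

```rust
// Worked instance of the gather doc-comment formula for dim = 0 on a 2-D tensor.
fn gather_dim0(input: &[[f32; 3]; 2], indexes: &[[usize; 3]; 2]) -> [[f32; 3]; 2] {
    let mut output = [[0.0; 3]; 2];
    for i in 0..2 {
        for j in 0..3 {
            output[i][j] = input[indexes[i][j]][j];
        }
    }
    output
}

fn main() {
    let input = [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]];
    let indexes = [[1, 0, 1], [0, 0, 1]];
    assert_eq!(
        gather_dim0(&input, &indexes),
        [[3.0, 1.0, 5.0], [0.0, 1.0, 5.0]]
    );
}
```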
@@ -412,21 +450,23 @@ where
         mask: Tensor<B, D, Bool>,
         value: Self::Elem,
     ) -> Self::Primitive<D>;
-    fn index_select<const D: usize>(
+    fn gather<const D: usize>(
+        dim: usize,
         tensor: Self::Primitive<D>,
         indexes: Tensor<B, D, Int>,
     ) -> Self::Primitive<D>;
-    fn index_select_assign<const D: usize>(
+    fn scatter<const D: usize>(
+        dim: usize,
         tensor: Self::Primitive<D>,
         indexes: Tensor<B, D, Int>,
         values: Self::Primitive<D>,
     ) -> Self::Primitive<D>;
-    fn index_select_dim<const D: usize>(
+    fn index_select<const D: usize>(
         tensor: Self::Primitive<D>,
         dim: usize,
         indexes: Tensor<B, 1, Int>,
     ) -> Self::Primitive<D>;
-    fn index_select_dim_assign<const D1: usize, const D2: usize>(
+    fn index_select_assign<const D1: usize, const D2: usize>(
         tensor: Self::Primitive<D1>,
         dim: usize,
         indexes: Tensor<B, 1, Int>,
@@ -588,7 +628,7 @@ impl<B: Backend> Numeric<B> for Int {
         B::int_mask_fill(tensor, mask.primitive, value)
     }

-    fn index_select_dim<const D: usize>(
+    fn index_select<const D: usize>(
         tensor: Self::Primitive<D>,
         dim: usize,
         indexes: Tensor<B, 1, Int>,
@@ -596,7 +636,7 @@ impl<B: Backend> Numeric<B> for Int {
         B::int_index_select_dim(tensor, dim, indexes.primitive)
     }

-    fn index_select_dim_assign<const D1: usize, const D2: usize>(
+    fn index_select_assign<const D1: usize, const D2: usize>(
         tensor: Self::Primitive<D1>,
         dim: usize,
         indexes: Tensor<B, 1, Int>,
@@ -604,19 +644,21 @@ impl<B: Backend> Numeric<B> for Int {
     ) -> Self::Primitive<D1> {
         B::int_index_select_dim_assign(tensor, dim, indexes.primitive, values)
     }
-    fn index_select<const D: usize>(
+    fn gather<const D: usize>(
+        dim: usize,
         tensor: Self::Primitive<D>,
         indexes: Tensor<B, D, Int>,
     ) -> Self::Primitive<D> {
-        B::int_index_select(tensor, indexes.primitive)
+        B::int_gather(dim, tensor, indexes.primitive)
     }

-    fn index_select_assign<const D: usize>(
+    fn scatter<const D: usize>(
+        dim: usize,
         tensor: Self::Primitive<D>,
         indexes: Tensor<B, D, Int>,
         values: Self::Primitive<D>,
     ) -> Self::Primitive<D> {
-        B::int_index_select_assign(tensor, indexes.primitive, values)
+        B::int_scatter(dim, tensor, indexes.primitive, values)
     }

     fn argmax<const D: usize>(
@@ -804,36 +846,38 @@ impl<B: Backend> Numeric<B> for Float {
         B::mask_fill(tensor, mask.primitive, value)
     }

-    fn index_select_dim<const D: usize>(
+    fn index_select<const D: usize>(
         tensor: Self::Primitive<D>,
         dim: usize,
         indexes: Tensor<B, 1, Int>,
     ) -> Self::Primitive<D> {
-        B::index_select_dim(tensor, dim, indexes.primitive)
+        B::index_select(tensor, dim, indexes.primitive)
     }

-    fn index_select_dim_assign<const D1: usize, const D2: usize>(
+    fn index_select_assign<const D1: usize, const D2: usize>(
         tensor: Self::Primitive<D1>,
         dim: usize,
         indexes: Tensor<B, 1, Int>,
         values: Self::Primitive<D2>,
     ) -> Self::Primitive<D1> {
-        B::index_select_dim_assign(tensor, dim, indexes.primitive, values)
+        B::index_select_assign(tensor, dim, indexes.primitive, values)
     }

-    fn index_select<const D: usize>(
+    fn gather<const D: usize>(
+        dim: usize,
         tensor: Self::Primitive<D>,
         indexes: Tensor<B, D, Int>,
     ) -> Self::Primitive<D> {
-        B::index_select(tensor, indexes.primitive)
+        B::gather(dim, tensor, indexes.primitive)
     }

-    fn index_select_assign<const D: usize>(
+    fn scatter<const D: usize>(
+        dim: usize,
         tensor: Self::Primitive<D>,
         indexes: Tensor<B, D, Int>,
         values: Self::Primitive<D>,
     ) -> Self::Primitive<D> {
-        B::index_select_assign(tensor, indexes.primitive, values)
+        B::scatter(dim, tensor, indexes.primitive, values)
     }

     fn argmax<const D: usize>(
@@ -44,11 +44,13 @@ pub trait IntTensorOps<B: Backend> {
         mask: B::BoolTensorPrimitive<D>,
         value: B::IntElem,
     ) -> B::IntTensorPrimitive<D>;
-    fn int_index_select<const D: usize>(
+    fn int_gather<const D: usize>(
+        dim: usize,
         tensor: B::IntTensorPrimitive<D>,
         indexes: B::IntTensorPrimitive<D>,
     ) -> B::IntTensorPrimitive<D>;
-    fn int_index_select_assign<const D: usize>(
+    fn int_scatter<const D: usize>(
+        dim: usize,
         tensor: B::IntTensorPrimitive<D>,
         indexes: B::IntTensorPrimitive<D>,
         value: B::IntTensorPrimitive<D>,
@@ -204,14 +206,14 @@ pub trait IntTensorOps<B: Backend> {
     ) -> B::IntTensorPrimitive<D> {
         let index = B::int_argmax(tensor.clone(), dim);

-        B::int_index_select(tensor, index)
+        B::int_gather(D - 1, tensor, index)
     }
     fn int_max_dim_with_indexes<const D: usize>(
         tensor: B::IntTensorPrimitive<D>,
         dim: usize,
     ) -> (B::IntTensorPrimitive<D>, B::IntTensorPrimitive<D>) {
         let index = B::int_argmax(tensor.clone(), dim);
-        let values = B::int_index_select(tensor, index.clone());
+        let values = B::int_gather(D - 1, tensor, index.clone());

         (values, index)
     }
@@ -227,14 +229,14 @@ pub trait IntTensorOps<B: Backend> {
     ) -> B::IntTensorPrimitive<D> {
         let index = B::int_argmin(tensor.clone(), dim);

-        B::int_index_select(tensor, index)
+        B::int_gather(D - 1, tensor, index)
     }
     fn int_min_dim_with_indexes<const D: usize>(
         tensor: B::IntTensorPrimitive<D>,
         dim: usize,
     ) -> (B::IntTensorPrimitive<D>, B::IntTensorPrimitive<D>) {
         let index = B::int_argmin(tensor.clone(), dim);
-        let values = B::int_index_select(tensor, index.clone());
+        let values = B::int_gather(D - 1, tensor, index.clone());

         (values, index)
     }
@@ -115,21 +115,23 @@ pub trait TensorOps<B: Backend> {
         tensor: B::TensorPrimitive<D1>,
         shape: Shape<D2>,
     ) -> B::TensorPrimitive<D2>;
-    fn index_select<const D: usize>(
+    fn gather<const D: usize>(
+        dim: usize,
         tensor: B::TensorPrimitive<D>,
         indexes: B::IntTensorPrimitive<D>,
     ) -> B::TensorPrimitive<D>;
-    fn index_select_assign<const D: usize>(
+    fn scatter<const D: usize>(
+        dim: usize,
         tensor: B::TensorPrimitive<D>,
         indexes: B::IntTensorPrimitive<D>,
         value: B::TensorPrimitive<D>,
     ) -> B::TensorPrimitive<D>;
-    fn index_select_dim<const D: usize>(
+    fn index_select<const D: usize>(
         tensor: B::TensorPrimitive<D>,
         dim: usize,
         indexes: B::IntTensorPrimitive<1>,
     ) -> B::TensorPrimitive<D>;
-    fn index_select_dim_assign<const D1: usize, const D2: usize>(
+    fn index_select_assign<const D1: usize, const D2: usize>(
         tensor: B::TensorPrimitive<D1>,
         dim: usize,
         indexes: B::IntTensorPrimitive<1>,
@@ -251,14 +253,14 @@ pub trait TensorOps<B: Backend> {
     fn max_dim<const D: usize>(tensor: B::TensorPrimitive<D>, dim: usize) -> B::TensorPrimitive<D> {
         let index = B::argmax(tensor.clone(), dim);

-        B::index_select(tensor, index)
+        B::gather(D - 1, tensor, index)
     }
     fn max_dim_with_indexes<const D: usize>(
         tensor: B::TensorPrimitive<D>,
         dim: usize,
     ) -> (B::TensorPrimitive<D>, B::IntTensorPrimitive<D>) {
         let index = B::argmax(tensor.clone(), dim);
-        let values = B::index_select(tensor, index.clone());
+        let values = B::gather(D - 1, tensor, index.clone());

         (values, index)
     }
@@ -271,14 +273,14 @@ pub trait TensorOps<B: Backend> {
     fn min_dim<const D: usize>(tensor: B::TensorPrimitive<D>, dim: usize) -> B::TensorPrimitive<D> {
         let index = B::argmin(tensor.clone(), dim);

-        B::index_select(tensor, index)
+        B::gather(D - 1, tensor, index)
     }
     fn min_dim_with_indexes<const D: usize>(
         tensor: B::TensorPrimitive<D>,
         dim: usize,
     ) -> (B::TensorPrimitive<D>, B::IntTensorPrimitive<D>) {
         let index = B::argmin(tensor.clone(), dim);
-        let values = B::index_select(tensor, index.clone());
+        let values = B::gather(D - 1, tensor, index.clone());

         (values, index)
     }
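These default implementations show the composition that motivates gather as a primitive: `max_dim`/`min_dim` find the winning index with argmax/argmin and then gather the value back out with that index, so the value and index stay paired. A plain-Rust sketch of that composition for the last-dimension case written in the diff:

```rust
// Composition used by the default max_dim: argmax per row, then gather the
// winning value back out with the same index.
fn argmax_row(row: &[f32]) -> usize {
    row.iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i)
        .unwrap()
}

fn max_dim_last(rows: &[Vec<f32>]) -> Vec<(f32, usize)> {
    rows.iter()
        .map(|row| {
            let index = argmax_row(row); // B::argmax
            (row[index], index) // B::gather(D - 1, ...) pairs value with index
        })
        .collect()
}

fn main() {
    let t = vec![vec![1.0, 5.0, 3.0], vec![9.0, 2.0, 4.0]];
    assert_eq!(max_dim_last(&t), vec![(5.0, 1), (9.0, 0)]);
}
```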
@@ -34,8 +34,8 @@ macro_rules! testgen_all {
         burn_tensor::testgen_log!();
         burn_tensor::testgen_log1p!();
         burn_tensor::testgen_index!();
+        burn_tensor::testgen_gather_scatter!();
         burn_tensor::testgen_index_select!();
-        burn_tensor::testgen_index_select_dim!();
         burn_tensor::testgen_map_comparison!();
         burn_tensor::testgen_mask!();
         burn_tensor::testgen_matmul!();
@@ -0,0 +1,132 @@
+#[burn_tensor_testgen::testgen(gather_scatter)]
+mod tests {
+    use super::*;
+    use burn_tensor::{Data, Tensor};
+
+    #[test]
+    fn should_gather_1d_dim0() {
+        let tensor = TestTensor::from_floats([0.0, 1.0, 2.0]);
+        let indexes = TestTensorInt::from_ints([1, 1, 0, 1, 2]);
+
+        let output = tensor.gather(0, indexes);
+
+        assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0]));
+    }
+
+    #[test]
+    fn should_gather_2d_dim0() {
+        let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
+        let indexes = TestTensorInt::from_ints([[0, 1, 0], [1, 0, 1]]);
+
+        let output = tensor.gather(0, indexes);
+
+        assert_eq!(
+            output.into_data(),
+            Data::from([[0.0, 4.0, 2.0], [3.0, 1.0, 5.0]])
+        );
+    }
+
+    #[test]
+    fn should_gather_2d_dim1() {
+        let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
+        let indexes = TestTensorInt::from_ints([[2, 1, 0, 0], [2, 0, 1, 2]]);
+
+        let output = tensor.gather(1, indexes);
+
+        assert_eq!(
+            output.into_data(),
+            Data::from([[2.0, 1.0, 0.0, 0.0], [5.0, 3.0, 4.0, 5.0]])
+        );
+    }
+
+    #[test]
+    fn should_gather_3d_dim1() {
+        let tensor = TestTensor::from_floats([
+            [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
+            [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
+        ]);
+        let indexes = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]]);
+
+        let output = tensor.gather(1, indexes);
+
+        assert_eq!(
+            output.into_data(),
+            Data::from([
+                [[3.0, 1.0, 2.0], [0.0, 4.0, 2.0]],
+                [[6.0, 7.0, 11.0], [6.0, 10.0, 11.0]]
+            ])
+        );
+    }
+
+    #[test]
+    fn should_gather_2d_only_1dim() {
+        let tensor = TestTensor::from_floats([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
+        let indexes = TestTensorInt::from_ints([[1, 2]]).reshape([2, 1]);
+
+        let output = tensor.gather(1, indexes);
+
+        assert_eq!(output.into_data(), Data::from([[1.0], [5.0]]));
+    }
+
+    #[test]
+    fn should_scatter_1d() {
+        let tensor = TestTensor::from_floats([0.0, 0.0, 0.0]);
+        let values = TestTensor::from_floats([5.0, 4.0, 3.0]);
+        let indexes = TestTensorInt::from_ints([1, 0, 2]);
+
+        let output = tensor.scatter(0, indexes, values);
+
+        assert_eq!(output.into_data(), Data::from([4.0, 5.0, 3.0]));
+    }
+
+    #[test]
+    fn should_scatter_2d_dim0() {
+        let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]);
+        let values = TestTensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
+        let indexes = TestTensorInt::from_ints([[1, 0, 1], [1, 1, 0]]);
+
+        let output = tensor.scatter(0, indexes, values);
+
+        assert_eq!(
+            output.into_data(),
+            Data::from([[0.0, 2.0, 6.0], [5.0, 5.0, 3.0]])
+        );
+    }
+
+    #[test]
+    fn should_scatter_2d_dim1() {
+        let tensor = TestTensor::from_floats([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]);
+        let values = TestTensor::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]);
+        let indexes = TestTensorInt::from_ints([[1, 0, 2], [1, 2, 0]]);
+
+        let output = tensor.scatter(1, indexes, values);
+
+        assert_eq!(
+            output.into_data(),
+            Data::from([[2.0, 1.0, 3.0], [6.0, 4.0, 5.0]])
+        );
+    }
+
+    #[test]
+    fn should_scatter_3d_dim1() {
+        let tensor = TestTensor::from_floats([
+            [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
+            [[6.0, 7.0, 8.0], [9.0, 10.0, 11.0]],
+        ]);
+        let values = TestTensor::from_floats([
+            [[12.0, 13.0, 14.0], [15.0, 16.0, 17.0]],
+            [[18.0, 19.0, 20.0], [21.0, 22.0, 23.0]],
+        ]);
+        let indexes = TestTensorInt::from_ints([[[1, 0, 0], [0, 1, 0]], [[0, 0, 1], [0, 1, 1]]]);
+
+        let output = tensor.scatter(1, indexes, values);
+
+        assert_eq!(
+            output.into_data(),
+            Data::from([
+                [[15.0, 14.0, 33.0], [15.0, 20.0, 5.0]],
+                [[45.0, 26.0, 8.0], [9.0, 32.0, 54.0]]
+            ])
+        );
+    }
+}
@@ -8,56 +8,77 @@ mod tests {
         let tensor = TestTensor::from_data(Data::from([0.0, 1.0, 2.0]));
         let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));

-        let output = tensor.index_select(indexes);
+        let output = tensor.index_select(0, indexes);

         assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0]));
     }

     #[test]
-    fn should_select_2d() {
+    fn should_select_2d_dim0_same_num_dim() {
         let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
-        let indexes = TestTensorInt::from_data(Data::from([[2, 1, 0, 0], [2, 0, 1, 2]]));
+        let indexes = TestTensorInt::from_data(Data::from([1, 0]));

-        let output = tensor.index_select(indexes);
+        let output = tensor.index_select(0, indexes);

         assert_eq!(
             output.into_data(),
-            Data::from([[2.0, 1.0, 0.0, 0.0], [5.0, 3.0, 4.0, 5.0]])
+            Data::from([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]])
         );
     }

     #[test]
-    fn should_select_2d_only_1dim() {
+    fn should_select_2d_dim0_more_num_dim() {
         let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
-        let indexes = TestTensorInt::from_data(Data::from([[1, 2]])).reshape([2, 1]);
+        let indexes = TestTensorInt::from_data(Data::from([1, 0, 1, 1]));

-        let output = tensor.index_select(indexes);
+        let output = tensor.index_select(0, indexes);

-        assert_eq!(output.into_data(), Data::from([[1.0], [5.0]]));
+        assert_eq!(
+            output.into_data(),
+            Data::from([
+                [3.0, 4.0, 5.0],
+                [0.0, 1.0, 2.0],
+                [3.0, 4.0, 5.0],
+                [3.0, 4.0, 5.0]
+            ])
+        );
     }

     #[test]
+    fn should_select_2d_dim1() {
+        let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
+        let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));
+
+        let output = tensor.index_select(1, indexes);
+
+        assert_eq!(
+            output.into_data(),
+            Data::from([[1.0, 1.0, 0.0, 1.0, 2.0], [4.0, 4.0, 3.0, 4.0, 5.0]])
+        );
+    }
+
+    #[test]
     fn should_select_assign_1d() {
-        let tensor = TestTensor::from_data(Data::from([0.0, 0.0, 0.0]));
-        let values = TestTensor::from_data(Data::from([5.0, 4.0, 3.0]));
-        let indexes = TestTensorInt::from_data(Data::from([1, 0, 2]));
+        let tensor = TestTensor::from_data(Data::from([0.0, 1.0, 2.0]));
+        let values = TestTensor::from_data(Data::from([5.0, 4.0, 3.0, 2.0, 1.0]));
+        let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));

-        let output = tensor.index_select_assign(indexes, values);
+        let output = tensor.index_select_assign(0, indexes, values);

-        assert_eq!(output.into_data(), Data::from([4.0, 5.0, 3.0]));
+        assert_eq!(output.into_data(), Data::from([3.0, 12.0, 3.0]));
     }

     #[test]
-    fn should_select_assign_2d() {
-        let tensor = TestTensor::from_data(Data::from([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]));
+    fn should_select_assign_2d_dim0() {
+        let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
         let values = TestTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]));
-        let indexes = TestTensorInt::from_data(Data::from([[1, 0, 2], [1, 2, 0]]));
+        let indexes = TestTensorInt::from_data(Data::from([1, 0]));

-        let output = tensor.index_select_assign(indexes, values);
+        let output = tensor.index_select_assign(0, indexes, values);

         assert_eq!(
             output.into_data(),
-            Data::from([[2.0, 1.0, 3.0], [6.0, 4.0, 5.0]])
+            Data::from([[4.0, 6.0, 8.0], [4.0, 6.0, 8.0]])
         );
     }
 }
@@ -1,84 +0,0 @@
-#[burn_tensor_testgen::testgen(index_select_dim)]
-mod tests {
-    use super::*;
-    use burn_tensor::{Data, Tensor};
-
-    #[test]
-    fn should_select_1d() {
-        let tensor = TestTensor::from_data(Data::from([0.0, 1.0, 2.0]));
-        let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));
-
-        let output = tensor.index_select_dim(0, indexes);
-
-        assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0]));
-    }
-
-    #[test]
-    fn should_select_2d_dim0_same_num_dim() {
-        let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
-        let indexes = TestTensorInt::from_data(Data::from([1, 0]));
-
-        let output = tensor.index_select_dim(0, indexes);
-
-        assert_eq!(
-            output.into_data(),
-            Data::from([[3.0, 4.0, 5.0], [0.0, 1.0, 2.0]])
-        );
-    }
-
-    #[test]
-    fn should_select_2d_dim0_more_num_dim() {
-        let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
-        let indexes = TestTensorInt::from_data(Data::from([1, 0, 1, 1]));
-
-        let output = tensor.index_select_dim(0, indexes);
-
-        assert_eq!(
-            output.into_data(),
-            Data::from([
-                [3.0, 4.0, 5.0],
-                [0.0, 1.0, 2.0],
-                [3.0, 4.0, 5.0],
-                [3.0, 4.0, 5.0]
-            ])
-        );
-    }
-
-    #[test]
-    fn should_select_2d_dim1() {
-        let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
-        let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));
-
-        let output = tensor.index_select_dim(1, indexes);
-
-        assert_eq!(
-            output.into_data(),
-            Data::from([[1.0, 1.0, 0.0, 1.0, 2.0], [4.0, 4.0, 3.0, 4.0, 5.0]])
-        );
-    }
-
-    #[test]
-    fn should_select_assign_1d() {
-        let tensor = TestTensor::from_data(Data::from([0.0, 1.0, 2.0]));
-        let values = TestTensor::from_data(Data::from([5.0, 4.0, 3.0, 2.0, 1.0]));
-        let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));
-
-        let output = tensor.index_select_dim_assign(0, indexes, values);
-
-        assert_eq!(output.into_data(), Data::from([3.0, 12.0, 3.0]));
-    }
-
-    #[test]
-    fn should_select_assign_2d_dim0() {
-        let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
-        let values = TestTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]));
-        let indexes = TestTensorInt::from_data(Data::from([1, 0]));
-
-        let output = tensor.index_select_dim_assign(0, indexes, values);
-
-        assert_eq!(
-            output.into_data(),
-            Data::from([[4.0, 6.0, 8.0], [4.0, 6.0, 8.0]])
-        );
-    }
-}
@@ -6,9 +6,9 @@ mod div;
 mod erf;
 mod exp;
 mod flatten;
+mod gather_scatter;
 mod index;
 mod index_select;
-mod index_select_dim;
 mod log;
 mod log1p;
 mod map_comparison;