mirror of https://github.com/tracel-ai/burn.git
Feat/index_select (#227)
This commit is contained in:
parent
9655b74b22
commit
d09ab44979
|
@ -195,6 +195,21 @@ impl<B: Backend> IntTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
||||||
B::int_lower_equal_elem(lhs, rhs)
|
B::int_lower_equal_elem(lhs, rhs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn int_index_select<const D: usize>(
|
||||||
|
tensor: IntTensor<B, D>,
|
||||||
|
indexes: IntTensor<B, D>,
|
||||||
|
) -> IntTensor<B, D> {
|
||||||
|
B::int_index_select(tensor, indexes)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn int_index_select_assign<const D: usize>(
|
||||||
|
tensor: IntTensor<B, D>,
|
||||||
|
indexes: IntTensor<B, D>,
|
||||||
|
value: IntTensor<B, D>,
|
||||||
|
) -> IntTensor<B, D> {
|
||||||
|
B::int_index_select_assign(tensor, indexes, value)
|
||||||
|
}
|
||||||
|
|
||||||
fn int_index_select_dim<const D: usize>(
|
fn int_index_select_dim<const D: usize>(
|
||||||
tensor: IntTensor<B, D>,
|
tensor: IntTensor<B, D>,
|
||||||
dim: usize,
|
dim: usize,
|
||||||
|
|
|
@ -420,6 +420,94 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn index_select<const D: usize>(
|
||||||
|
tensor: ADTensor<B, D>,
|
||||||
|
indexes: IntTensor<B, D>,
|
||||||
|
) -> ADTensor<B, D> {
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct IndexSelect;
|
||||||
|
|
||||||
|
impl<B: Backend, const D: usize> Backward<B, D, 1> for IndexSelect {
|
||||||
|
type State = (IntTensor<B, D>, Shape<D>, B::Device);
|
||||||
|
|
||||||
|
fn backward(self, ops: Ops<Self::State, 1>, grads: &mut Gradients) {
|
||||||
|
let (indexes, shape, device) = ops.state;
|
||||||
|
|
||||||
|
unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
|
||||||
|
let zeros = B::zeros(shape, &device);
|
||||||
|
B::index_select_assign(zeros, indexes, grad)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match IndexSelect
|
||||||
|
.prepare([tensor.node], [tensor.graph])
|
||||||
|
.statefull()
|
||||||
|
{
|
||||||
|
OpsKind::Tracked(prep) => prep.finish(
|
||||||
|
(
|
||||||
|
indexes.clone(),
|
||||||
|
B::shape(&tensor.primitive),
|
||||||
|
B::device(&tensor.primitive),
|
||||||
|
),
|
||||||
|
B::index_select(tensor.primitive, indexes),
|
||||||
|
),
|
||||||
|
OpsKind::UnTracked(prep) => prep.finish(B::index_select(tensor.primitive, indexes)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn index_select_assign<const D: usize>(
|
||||||
|
tensor: ADTensor<B, D>,
|
||||||
|
indexes: IntTensor<B, D>,
|
||||||
|
value: ADTensor<B, D>,
|
||||||
|
) -> ADTensor<B, D> {
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct IndexSelectAssign;
|
||||||
|
|
||||||
|
impl<B: Backend, const D: usize> Backward<B, D, 2> for IndexSelectAssign {
|
||||||
|
type State = (IntTensor<B, D>, Shape<D>, Shape<D>, B::Device);
|
||||||
|
|
||||||
|
fn backward(self, ops: Ops<Self::State, 2>, grads: &mut Gradients) {
|
||||||
|
let (indexes, shape_lhs, shape_rhs, device) = ops.state;
|
||||||
|
let [indexes_4lhs, indexes_4rhs] = duplicate(&ops.parents, Some(indexes));
|
||||||
|
|
||||||
|
binary::<B, D, D, D, _, _>(
|
||||||
|
ops.parents,
|
||||||
|
ops.node,
|
||||||
|
grads,
|
||||||
|
|grad| {
|
||||||
|
let zeros = B::zeros(shape_lhs, &device);
|
||||||
|
B::index_select_assign(grad, indexes_4lhs.unwrap(), zeros)
|
||||||
|
},
|
||||||
|
|grad| {
|
||||||
|
let zeros = B::zeros(shape_rhs, &device);
|
||||||
|
B::index_select_assign(zeros, indexes_4rhs.unwrap(), grad)
|
||||||
|
},
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match IndexSelectAssign
|
||||||
|
.prepare([tensor.node, value.node], [tensor.graph, value.graph])
|
||||||
|
.statefull()
|
||||||
|
{
|
||||||
|
OpsKind::Tracked(prep) => prep.finish(
|
||||||
|
(
|
||||||
|
indexes.clone(),
|
||||||
|
B::shape(&tensor.primitive),
|
||||||
|
B::shape(&value.primitive),
|
||||||
|
B::device(&value.primitive),
|
||||||
|
),
|
||||||
|
B::index_select_assign(tensor.primitive, indexes, value.primitive),
|
||||||
|
),
|
||||||
|
OpsKind::UnTracked(prep) => prep.finish(B::index_select_assign(
|
||||||
|
tensor.primitive,
|
||||||
|
indexes,
|
||||||
|
value.primitive,
|
||||||
|
)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn index_select_dim<const D: usize>(
|
fn index_select_dim<const D: usize>(
|
||||||
tensor: ADTensor<B, D>,
|
tensor: ADTensor<B, D>,
|
||||||
dim: usize,
|
dim: usize,
|
||||||
|
|
|
@ -0,0 +1,54 @@
|
||||||
|
#[burn_tensor_testgen::testgen(ad_index_select)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use burn_tensor::Data;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_index_select_grad() {
|
||||||
|
let tensor_1 =
|
||||||
|
TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
|
||||||
|
let indexes = TestADTensor::from_data(Data::from([[2, 1, 0, 1, 2], [1, 0, 2, 1, 0]]));
|
||||||
|
|
||||||
|
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
|
||||||
|
let tensor_3 = tensor_1.clone().index_select(indexes);
|
||||||
|
let tensor_4 = tensor_2.matmul(tensor_3);
|
||||||
|
|
||||||
|
let grads = tensor_4.backward();
|
||||||
|
|
||||||
|
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
grad_1.into_data(),
|
||||||
|
Data::from([[94., 150., 187.], [242., 305., 304.]])
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_index_select_assign_grad() {
|
||||||
|
let tensor_1 =
|
||||||
|
TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
|
||||||
|
let values =
|
||||||
|
TestADTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])).require_grad();
|
||||||
|
let indexes = TestADTensor::from_data(Data::from([[2, 1, 0], [2, 0, 1]]));
|
||||||
|
|
||||||
|
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
|
||||||
|
let tensor_3 = tensor_1
|
||||||
|
.clone()
|
||||||
|
.index_select_assign(indexes, values.clone());
|
||||||
|
let tensor_4 = tensor_2.matmul(tensor_3);
|
||||||
|
|
||||||
|
let grads = tensor_4.backward();
|
||||||
|
|
||||||
|
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||||
|
let grad_2 = values.grad(&grads).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
grad_1.into_data(),
|
||||||
|
Data::from([[127., 181., 235.], [226., 316., 406.]])
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
grad_2.into_data(),
|
||||||
|
Data::from([[19., 19., 19.], [64., 64., 64.]])
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
|
@ -4,7 +4,27 @@ mod tests {
|
||||||
use burn_tensor::Data;
|
use burn_tensor::Data;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_select_grad() {
|
fn test_index_select_dim_grad() {
|
||||||
|
let tensor_1 =
|
||||||
|
TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
|
||||||
|
let indexes = TestADTensor::from_data(Data::from([1, 0]));
|
||||||
|
|
||||||
|
let tensor_2 = tensor_1.clone().matmul(tensor_1.clone().transpose());
|
||||||
|
let tensor_3 = tensor_1.clone().index_select_dim(0, indexes);
|
||||||
|
let tensor_4 = tensor_2.matmul(tensor_3);
|
||||||
|
|
||||||
|
let grads = tensor_4.backward();
|
||||||
|
|
||||||
|
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
grad_1.into_data(),
|
||||||
|
Data::from([[109., 148., 187.], [37., 58., 79.]])
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_index_select_dim_assign_grad() {
|
||||||
let tensor_1 =
|
let tensor_1 =
|
||||||
TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
|
TestADTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])).require_grad();
|
||||||
let values =
|
let values =
|
||||||
|
|
|
@ -11,6 +11,7 @@ mod div;
|
||||||
mod erf;
|
mod erf;
|
||||||
mod exp;
|
mod exp;
|
||||||
mod index;
|
mod index;
|
||||||
|
mod index_select;
|
||||||
mod index_select_dim;
|
mod index_select_dim;
|
||||||
mod log;
|
mod log;
|
||||||
mod log1p;
|
mod log1p;
|
||||||
|
@ -53,8 +54,9 @@ macro_rules! testgen_all {
|
||||||
burn_autodiff::testgen_ad_div!();
|
burn_autodiff::testgen_ad_div!();
|
||||||
burn_autodiff::testgen_ad_erf!();
|
burn_autodiff::testgen_ad_erf!();
|
||||||
burn_autodiff::testgen_ad_exp!();
|
burn_autodiff::testgen_ad_exp!();
|
||||||
burn_autodiff::testgen_ad_index_select_dim!();
|
|
||||||
burn_autodiff::testgen_ad_index!();
|
burn_autodiff::testgen_ad_index!();
|
||||||
|
burn_autodiff::testgen_ad_index_select!();
|
||||||
|
burn_autodiff::testgen_ad_index_select_dim!();
|
||||||
burn_autodiff::testgen_ad_log!();
|
burn_autodiff::testgen_ad_log!();
|
||||||
burn_autodiff::testgen_ad_log1p!();
|
burn_autodiff::testgen_ad_log1p!();
|
||||||
burn_autodiff::testgen_ad_mask!();
|
burn_autodiff::testgen_ad_mask!();
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
use alloc::vec::Vec;
|
use alloc::vec::Vec;
|
||||||
use burn_tensor::Data;
|
use burn_tensor::Data;
|
||||||
use core::{marker::PhantomData, ops::Range};
|
use core::{marker::PhantomData, ops::Range};
|
||||||
|
use ndarray::s;
|
||||||
|
use ndarray::Array2;
|
||||||
|
|
||||||
use burn_tensor::Shape;
|
use burn_tensor::Shape;
|
||||||
use ndarray::Axis;
|
use ndarray::Axis;
|
||||||
|
@ -204,6 +206,85 @@ where
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn index_select<const D: usize>(
|
||||||
|
tensor: NdArrayTensor<E, D>,
|
||||||
|
indexes: NdArrayTensor<i64, D>,
|
||||||
|
) -> NdArrayTensor<E, D> {
|
||||||
|
let (shape_tensor, shape_indexes) = (tensor.shape(), indexes.shape());
|
||||||
|
let (size_tensor, size_index) = (shape_tensor.dims[D - 1], shape_indexes.dims[D - 1]);
|
||||||
|
let batch_size = Self::index_select_batch_size(&shape_tensor, &shape_indexes);
|
||||||
|
|
||||||
|
let indexes = NdArrayOps::reshape(indexes, Shape::new([batch_size, size_index])).array;
|
||||||
|
let tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array;
|
||||||
|
let mut output = Array2::zeros((batch_size, size_index));
|
||||||
|
|
||||||
|
for b in 0..batch_size {
|
||||||
|
let indexes = indexes.slice(s!(b, ..));
|
||||||
|
|
||||||
|
for (i, index) in indexes.iter().enumerate() {
|
||||||
|
output[[b, i]] = tensor[[b, *index as usize]];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
NdArrayOps::reshape(
|
||||||
|
NdArrayTensor::<E, 2>::new(output.into_shared().into_dyn()),
|
||||||
|
shape_indexes,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn index_select_assign<const D: usize>(
|
||||||
|
tensor: NdArrayTensor<E, D>,
|
||||||
|
indexes: NdArrayTensor<i64, D>,
|
||||||
|
value: NdArrayTensor<E, D>,
|
||||||
|
) -> NdArrayTensor<E, D> {
|
||||||
|
let (shape_tensor, shape_indexes, shape_value) =
|
||||||
|
(tensor.shape(), indexes.shape(), value.shape());
|
||||||
|
let (size_tensor, size_index, size_value) = (
|
||||||
|
shape_tensor.dims[D - 1],
|
||||||
|
shape_indexes.dims[D - 1],
|
||||||
|
shape_value.dims[D - 1],
|
||||||
|
);
|
||||||
|
let batch_size = Self::index_select_batch_size(&shape_tensor, &shape_indexes);
|
||||||
|
|
||||||
|
if shape_value != shape_indexes {
|
||||||
|
panic!("Invalid dimension: the shape of the index tensor should be the same as the value tensor: Index {:?} value {:?}", shape_indexes.dims, shape_value.dims);
|
||||||
|
}
|
||||||
|
|
||||||
|
let indexes = NdArrayOps::reshape(indexes, Shape::new([batch_size, size_index])).array;
|
||||||
|
let value = NdArrayOps::reshape(value, Shape::new([batch_size, size_value])).array;
|
||||||
|
let mut tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array;
|
||||||
|
|
||||||
|
for b in 0..batch_size {
|
||||||
|
let indexes = indexes.slice(s!(b, ..));
|
||||||
|
|
||||||
|
for (i, index) in indexes.iter().enumerate() {
|
||||||
|
let index = *index as usize;
|
||||||
|
tensor[[b, index]] = tensor[[b, index]] + value[[b, i]];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
NdArrayOps::reshape(
|
||||||
|
NdArrayTensor::<E, 2>::new(tensor.into_shared().into_dyn()),
|
||||||
|
shape_tensor,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn index_select_batch_size<const D: usize>(
|
||||||
|
shape_tensor: &Shape<D>,
|
||||||
|
shape_indexes: &Shape<D>,
|
||||||
|
) -> usize {
|
||||||
|
let mut batch_size = 1;
|
||||||
|
|
||||||
|
for i in 0..D - 1 {
|
||||||
|
if shape_tensor.dims[i] != shape_indexes.dims[i] {
|
||||||
|
panic!("Unsupported dimension, only the last dimension can differ: Tensor {:?} Index {:?}", shape_tensor.dims, shape_indexes.dims);
|
||||||
|
}
|
||||||
|
batch_size *= shape_indexes.dims[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
batch_size
|
||||||
|
}
|
||||||
|
|
||||||
pub fn index_select_dim<const D: usize>(
|
pub fn index_select_dim<const D: usize>(
|
||||||
tensor: NdArrayTensor<E, D>,
|
tensor: NdArrayTensor<E, D>,
|
||||||
dim: usize,
|
dim: usize,
|
||||||
|
|
|
@ -264,6 +264,21 @@ impl<E: FloatNdArrayElement> IntTensorOps<NdArrayBackend<E>> for NdArrayBackend<
|
||||||
NdArrayMathOps::mean_dim(tensor, dim)
|
NdArrayMathOps::mean_dim(tensor, dim)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn int_index_select<const D: usize>(
|
||||||
|
tensor: NdArrayTensor<i64, D>,
|
||||||
|
indexes: NdArrayTensor<i64, D>,
|
||||||
|
) -> NdArrayTensor<i64, D> {
|
||||||
|
NdArrayMathOps::index_select(tensor, indexes)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn int_index_select_assign<const D: usize>(
|
||||||
|
tensor: NdArrayTensor<i64, D>,
|
||||||
|
indexes: NdArrayTensor<i64, D>,
|
||||||
|
value: NdArrayTensor<i64, D>,
|
||||||
|
) -> NdArrayTensor<i64, D> {
|
||||||
|
NdArrayMathOps::index_select_assign(tensor, indexes, value)
|
||||||
|
}
|
||||||
|
|
||||||
fn int_index_select_dim<const D: usize>(
|
fn int_index_select_dim<const D: usize>(
|
||||||
tensor: NdArrayTensor<i64, D>,
|
tensor: NdArrayTensor<i64, D>,
|
||||||
dim: usize,
|
dim: usize,
|
||||||
|
|
|
@ -151,6 +151,21 @@ impl<E: FloatNdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E>
|
||||||
NdArrayOps::reshape(tensor, shape)
|
NdArrayOps::reshape(tensor, shape)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn index_select<const D: usize>(
|
||||||
|
tensor: NdArrayTensor<E, D>,
|
||||||
|
indexes: NdArrayTensor<i64, D>,
|
||||||
|
) -> NdArrayTensor<E, D> {
|
||||||
|
NdArrayMathOps::index_select(tensor, indexes)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn index_select_assign<const D: usize>(
|
||||||
|
tensor: NdArrayTensor<E, D>,
|
||||||
|
indexes: NdArrayTensor<i64, D>,
|
||||||
|
value: NdArrayTensor<E, D>,
|
||||||
|
) -> NdArrayTensor<E, D> {
|
||||||
|
NdArrayMathOps::index_select_assign(tensor, indexes, value)
|
||||||
|
}
|
||||||
|
|
||||||
fn index_select_dim<const D: usize>(
|
fn index_select_dim<const D: usize>(
|
||||||
tensor: NdArrayTensor<E, D>,
|
tensor: NdArrayTensor<E, D>,
|
||||||
dim: usize,
|
dim: usize,
|
||||||
|
|
|
@ -58,7 +58,12 @@ macro_rules! reshape {
|
||||||
) => {{
|
) => {{
|
||||||
let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
|
let dim = $crate::to_typed_dims!($n, $shape.dims, justdim);
|
||||||
let safe_into_shape =
|
let safe_into_shape =
|
||||||
$array.is_standard_layout() || $array.raw_view().reversed_axes().is_standard_layout();
|
$array.is_standard_layout() ||
|
||||||
|
(
|
||||||
|
$array.ndim() > 1 &&
|
||||||
|
$array.raw_view().reversed_axes().is_standard_layout()
|
||||||
|
);
|
||||||
|
|
||||||
let array: ndarray::ArcArray<$ty, Dim<[usize; $n]>> = match safe_into_shape {
|
let array: ndarray::ArcArray<$ty, Dim<[usize; $n]>> = match safe_into_shape {
|
||||||
true => $array
|
true => $array
|
||||||
.into_shape(dim)
|
.into_shape(dim)
|
||||||
|
|
|
@ -45,6 +45,25 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
|
||||||
TchTensor::new(tensor_original)
|
TchTensor::new(tensor_original)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn index_select<const D: usize>(
|
||||||
|
tensor: TchTensor<E, D>,
|
||||||
|
indexes: TchTensor<i64, D>,
|
||||||
|
) -> TchTensor<E, D> {
|
||||||
|
let tensor = tensor.tensor.gather((D - 1) as i64, &indexes.tensor, false);
|
||||||
|
TchTensor::new(tensor)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn index_select_assign<const D: usize>(
|
||||||
|
tensor: TchTensor<E, D>,
|
||||||
|
indexes: TchTensor<i64, D>,
|
||||||
|
value: TchTensor<E, D>,
|
||||||
|
) -> TchTensor<E, D> {
|
||||||
|
let tensor = tensor
|
||||||
|
.tensor
|
||||||
|
.scatter_add((D - 1) as i64, &indexes.tensor, &value.tensor);
|
||||||
|
TchTensor::new(tensor)
|
||||||
|
}
|
||||||
|
|
||||||
pub fn index_select_dim<const D: usize>(
|
pub fn index_select_dim<const D: usize>(
|
||||||
tensor: TchTensor<E, D>,
|
tensor: TchTensor<E, D>,
|
||||||
dim: usize,
|
dim: usize,
|
||||||
|
|
|
@ -243,6 +243,20 @@ impl<E: TchElement> IntTensorOps<TchBackend<E>> for TchBackend<E> {
|
||||||
fn int_mean_dim<const D: usize>(tensor: TchTensor<i64, D>, dim: usize) -> TchTensor<i64, D> {
|
fn int_mean_dim<const D: usize>(tensor: TchTensor<i64, D>, dim: usize) -> TchTensor<i64, D> {
|
||||||
TchOps::mean_dim(tensor, dim)
|
TchOps::mean_dim(tensor, dim)
|
||||||
}
|
}
|
||||||
|
fn int_index_select<const D: usize>(
|
||||||
|
tensor: TchTensor<i64, D>,
|
||||||
|
indexes: TchTensor<i64, D>,
|
||||||
|
) -> TchTensor<i64, D> {
|
||||||
|
TchOps::index_select(tensor, indexes)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn int_index_select_assign<const D: usize>(
|
||||||
|
tensor: TchTensor<i64, D>,
|
||||||
|
indexes: TchTensor<i64, D>,
|
||||||
|
value: TchTensor<i64, D>,
|
||||||
|
) -> TchTensor<i64, D> {
|
||||||
|
TchOps::index_select_assign(tensor, indexes, value)
|
||||||
|
}
|
||||||
|
|
||||||
fn int_index_select_dim<const D: usize>(
|
fn int_index_select_dim<const D: usize>(
|
||||||
tensor: TchTensor<i64, D>,
|
tensor: TchTensor<i64, D>,
|
||||||
|
|
|
@ -190,6 +190,21 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
||||||
TchTensor::new(tensor)
|
TchTensor::new(tensor)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn index_select<const D: usize>(
|
||||||
|
tensor: TchTensor<E, D>,
|
||||||
|
indexes: TchTensor<i64, D>,
|
||||||
|
) -> TchTensor<E, D> {
|
||||||
|
TchOps::index_select(tensor, indexes)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn index_select_assign<const D: usize>(
|
||||||
|
tensor: TchTensor<E, D>,
|
||||||
|
indexes: TchTensor<i64, D>,
|
||||||
|
value: TchTensor<E, D>,
|
||||||
|
) -> TchTensor<E, D> {
|
||||||
|
TchOps::index_select_assign(tensor, indexes, value)
|
||||||
|
}
|
||||||
|
|
||||||
fn index_select_dim<const D: usize>(
|
fn index_select_dim<const D: usize>(
|
||||||
tensor: TchTensor<E, D>,
|
tensor: TchTensor<E, D>,
|
||||||
dim: usize,
|
dim: usize,
|
||||||
|
|
|
@ -342,27 +342,6 @@ where
|
||||||
self.reshape(shape)
|
self.reshape(shape)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Index the tensor along the given dimension using the given indexes.
|
|
||||||
pub fn index_select_dim(self, dim: usize, indexes: Tensor<B, 1, Int>) -> Self {
|
|
||||||
Self::new(B::index_select_dim(self.primitive, dim, indexes.primitive))
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Return a new tensor with the same dimension, but with the values added to
|
|
||||||
/// the original tensor using the corresponding indexes provided along the given dimension.
|
|
||||||
pub fn index_select_dim_assign<const D2: usize>(
|
|
||||||
self,
|
|
||||||
dim: usize,
|
|
||||||
indexes: Tensor<B, 1, Int>,
|
|
||||||
values: Tensor<B, D2>,
|
|
||||||
) -> Self {
|
|
||||||
Self::new(B::index_select_dim_assign(
|
|
||||||
self.primitive,
|
|
||||||
dim,
|
|
||||||
indexes.primitive,
|
|
||||||
values.primitive,
|
|
||||||
))
|
|
||||||
}
|
|
||||||
|
|
||||||
pub(crate) fn relu(self) -> Self {
|
pub(crate) fn relu(self) -> Self {
|
||||||
Self::new(B::relu(self.primitive))
|
Self::new(B::relu(self.primitive))
|
||||||
}
|
}
|
||||||
|
|
|
@ -170,6 +170,52 @@ where
|
||||||
pub fn lower_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
|
pub fn lower_equal_elem<E: ElementConversion>(self, other: E) -> Tensor<B, D, Bool> {
|
||||||
K::lower_equal_elem(self.primitive, other.elem())
|
K::lower_equal_elem(self.primitive, other.elem())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Select the tensor elements corresponding to the given indexes.
|
||||||
|
///
|
||||||
|
/// # Notes
|
||||||
|
///
|
||||||
|
/// The index tensor shoud have the same shape as the original tensor except for the last
|
||||||
|
/// dimension.
|
||||||
|
pub fn index_select(self, indexes: Tensor<B, D, Int>) -> Self {
|
||||||
|
Self::new(K::index_select(self.primitive, indexes))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Assign the selected elements corresponding to the given indexes from the value tensor
|
||||||
|
/// to the original tensor using sum reduction.
|
||||||
|
///
|
||||||
|
/// # Notes
|
||||||
|
///
|
||||||
|
/// The index tensor shoud have the same shape as the original tensor except for the last
|
||||||
|
/// dimension. The value and index tensors should have the same shape.
|
||||||
|
pub fn index_select_assign(self, indexes: Tensor<B, D, Int>, values: Self) -> Self {
|
||||||
|
Self::new(K::index_select_assign(
|
||||||
|
self.primitive,
|
||||||
|
indexes,
|
||||||
|
values.primitive,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Select the tensor elements along the given dimension corresponding to the given indexes.
|
||||||
|
pub fn index_select_dim(self, dim: usize, indexes: Tensor<B, 1, Int>) -> Self {
|
||||||
|
Self::new(K::index_select_dim(self.primitive, dim, indexes))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Assign the selected elements along the given dimension corresponding to the given indexes
|
||||||
|
/// from the value tensor to the original tensor using sum reduction.
|
||||||
|
pub fn index_select_dim_assign<const D2: usize>(
|
||||||
|
self,
|
||||||
|
dim: usize,
|
||||||
|
indexes: Tensor<B, 1, Int>,
|
||||||
|
values: Tensor<B, D2, K>,
|
||||||
|
) -> Self {
|
||||||
|
Self::new(K::index_select_dim_assign(
|
||||||
|
self.primitive,
|
||||||
|
dim,
|
||||||
|
indexes,
|
||||||
|
values.primitive,
|
||||||
|
))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trait that list all operations that can be applied on all numerical tensors.
|
/// Trait that list all operations that can be applied on all numerical tensors.
|
||||||
|
@ -234,6 +280,15 @@ pub trait Numeric<B: Backend>: TensorKind<B> {
|
||||||
lhs: Self::Primitive<D>,
|
lhs: Self::Primitive<D>,
|
||||||
rhs: Self::Elem,
|
rhs: Self::Elem,
|
||||||
) -> Tensor<B, D, Bool>;
|
) -> Tensor<B, D, Bool>;
|
||||||
|
fn index_select<const D: usize>(
|
||||||
|
tensor: Self::Primitive<D>,
|
||||||
|
indexes: Tensor<B, D, Int>,
|
||||||
|
) -> Self::Primitive<D>;
|
||||||
|
fn index_select_assign<const D: usize>(
|
||||||
|
tensor: Self::Primitive<D>,
|
||||||
|
indexes: Tensor<B, D, Int>,
|
||||||
|
values: Self::Primitive<D>,
|
||||||
|
) -> Self::Primitive<D>;
|
||||||
fn index_select_dim<const D: usize>(
|
fn index_select_dim<const D: usize>(
|
||||||
tensor: Self::Primitive<D>,
|
tensor: Self::Primitive<D>,
|
||||||
dim: usize,
|
dim: usize,
|
||||||
|
@ -389,6 +444,20 @@ impl<B: Backend> Numeric<B> for Int {
|
||||||
) -> Self::Primitive<D1> {
|
) -> Self::Primitive<D1> {
|
||||||
B::int_index_select_dim_assign(tensor, dim, indexes.primitive, values)
|
B::int_index_select_dim_assign(tensor, dim, indexes.primitive, values)
|
||||||
}
|
}
|
||||||
|
fn index_select<const D: usize>(
|
||||||
|
tensor: Self::Primitive<D>,
|
||||||
|
indexes: Tensor<B, D, Int>,
|
||||||
|
) -> Self::Primitive<D> {
|
||||||
|
B::int_index_select(tensor, indexes.primitive)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn index_select_assign<const D: usize>(
|
||||||
|
tensor: Self::Primitive<D>,
|
||||||
|
indexes: Tensor<B, D, Int>,
|
||||||
|
values: Self::Primitive<D>,
|
||||||
|
) -> Self::Primitive<D> {
|
||||||
|
B::int_index_select_assign(tensor, indexes.primitive, values)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B: Backend> Numeric<B> for Float {
|
impl<B: Backend> Numeric<B> for Float {
|
||||||
|
@ -533,6 +602,21 @@ impl<B: Backend> Numeric<B> for Float {
|
||||||
) -> Self::Primitive<D1> {
|
) -> Self::Primitive<D1> {
|
||||||
B::index_select_dim_assign(tensor, dim, indexes.primitive, values)
|
B::index_select_dim_assign(tensor, dim, indexes.primitive, values)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn index_select<const D: usize>(
|
||||||
|
tensor: Self::Primitive<D>,
|
||||||
|
indexes: Tensor<B, D, Int>,
|
||||||
|
) -> Self::Primitive<D> {
|
||||||
|
B::index_select(tensor, indexes.primitive)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn index_select_assign<const D: usize>(
|
||||||
|
tensor: Self::Primitive<D>,
|
||||||
|
indexes: Tensor<B, D, Int>,
|
||||||
|
values: Self::Primitive<D>,
|
||||||
|
) -> Self::Primitive<D> {
|
||||||
|
B::index_select_assign(tensor, indexes.primitive, values)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<B, const D: usize, K> core::ops::Add<Self> for Tensor<B, D, K>
|
impl<B, const D: usize, K> core::ops::Add<Self> for Tensor<B, D, K>
|
||||||
|
|
|
@ -34,6 +34,15 @@ pub trait IntTensorOps<B: Backend> {
|
||||||
indexes: [Range<usize>; D2],
|
indexes: [Range<usize>; D2],
|
||||||
value: B::IntTensorPrimitive<D1>,
|
value: B::IntTensorPrimitive<D1>,
|
||||||
) -> B::IntTensorPrimitive<D1>;
|
) -> B::IntTensorPrimitive<D1>;
|
||||||
|
fn int_index_select<const D: usize>(
|
||||||
|
tensor: B::IntTensorPrimitive<D>,
|
||||||
|
indexes: B::IntTensorPrimitive<D>,
|
||||||
|
) -> B::IntTensorPrimitive<D>;
|
||||||
|
fn int_index_select_assign<const D: usize>(
|
||||||
|
tensor: B::IntTensorPrimitive<D>,
|
||||||
|
indexes: B::IntTensorPrimitive<D>,
|
||||||
|
value: B::IntTensorPrimitive<D>,
|
||||||
|
) -> B::IntTensorPrimitive<D>;
|
||||||
fn int_index_select_dim<const D: usize>(
|
fn int_index_select_dim<const D: usize>(
|
||||||
tensor: B::IntTensorPrimitive<D>,
|
tensor: B::IntTensorPrimitive<D>,
|
||||||
dim: usize,
|
dim: usize,
|
||||||
|
|
|
@ -115,6 +115,15 @@ pub trait TensorOps<B: Backend> {
|
||||||
tensor: B::TensorPrimitive<D1>,
|
tensor: B::TensorPrimitive<D1>,
|
||||||
shape: Shape<D2>,
|
shape: Shape<D2>,
|
||||||
) -> B::TensorPrimitive<D2>;
|
) -> B::TensorPrimitive<D2>;
|
||||||
|
fn index_select<const D: usize>(
|
||||||
|
tensor: B::TensorPrimitive<D>,
|
||||||
|
indexes: B::IntTensorPrimitive<D>,
|
||||||
|
) -> B::TensorPrimitive<D>;
|
||||||
|
fn index_select_assign<const D: usize>(
|
||||||
|
tensor: B::TensorPrimitive<D>,
|
||||||
|
indexes: B::IntTensorPrimitive<D>,
|
||||||
|
value: B::TensorPrimitive<D>,
|
||||||
|
) -> B::TensorPrimitive<D>;
|
||||||
fn index_select_dim<const D: usize>(
|
fn index_select_dim<const D: usize>(
|
||||||
tensor: B::TensorPrimitive<D>,
|
tensor: B::TensorPrimitive<D>,
|
||||||
dim: usize,
|
dim: usize,
|
||||||
|
|
|
@ -28,8 +28,9 @@ macro_rules! testgen_all {
|
||||||
burn_tensor::testgen_exp!();
|
burn_tensor::testgen_exp!();
|
||||||
burn_tensor::testgen_log!();
|
burn_tensor::testgen_log!();
|
||||||
burn_tensor::testgen_log1p!();
|
burn_tensor::testgen_log1p!();
|
||||||
burn_tensor::testgen_index_select_dim!();
|
|
||||||
burn_tensor::testgen_index!();
|
burn_tensor::testgen_index!();
|
||||||
|
burn_tensor::testgen_index_select!();
|
||||||
|
burn_tensor::testgen_index_select_dim!();
|
||||||
burn_tensor::testgen_map_comparison!();
|
burn_tensor::testgen_map_comparison!();
|
||||||
burn_tensor::testgen_mask!();
|
burn_tensor::testgen_mask!();
|
||||||
burn_tensor::testgen_matmul!();
|
burn_tensor::testgen_matmul!();
|
||||||
|
|
|
@ -0,0 +1,63 @@
|
||||||
|
#[burn_tensor_testgen::testgen(index_select)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use burn_tensor::{Data, Tensor};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_select_1d() {
|
||||||
|
let tensor = TestTensor::from_data(Data::from([0.0, 1.0, 2.0]));
|
||||||
|
let indexes = TestTensorInt::from_data(Data::from([1, 1, 0, 1, 2]));
|
||||||
|
|
||||||
|
let output = tensor.index_select(indexes);
|
||||||
|
|
||||||
|
assert_eq!(output.into_data(), Data::from([1.0, 1.0, 0.0, 1.0, 2.0]));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_select_2d() {
|
||||||
|
let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
|
||||||
|
let indexes = TestTensorInt::from_data(Data::from([[2, 1, 0, 0], [2, 0, 1, 2]]));
|
||||||
|
|
||||||
|
let output = tensor.index_select(indexes);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
output.into_data(),
|
||||||
|
Data::from([[2.0, 1.0, 0.0, 0.0], [5.0, 3.0, 4.0, 5.0]])
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_select_2d_only_1dim() {
|
||||||
|
let tensor = TestTensor::from_data(Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]));
|
||||||
|
let indexes = TestTensorInt::from_data(Data::from([[1, 2]])).reshape([2, 1]);
|
||||||
|
|
||||||
|
let output = tensor.index_select(indexes);
|
||||||
|
|
||||||
|
assert_eq!(output.into_data(), Data::from([[1.0], [5.0]]));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_select_assign_1d() {
|
||||||
|
let tensor = TestTensor::from_data(Data::from([0.0, 0.0, 0.0]));
|
||||||
|
let values = TestTensor::from_data(Data::from([5.0, 4.0, 3.0]));
|
||||||
|
let indexes = TestTensorInt::from_data(Data::from([1, 0, 2]));
|
||||||
|
|
||||||
|
let output = tensor.index_select_assign(indexes, values);
|
||||||
|
|
||||||
|
assert_eq!(output.into_data(), Data::from([4.0, 5.0, 3.0]));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_select_assign_2d() {
|
||||||
|
let tensor = TestTensor::from_data(Data::from([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]));
|
||||||
|
let values = TestTensor::from_data(Data::from([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]));
|
||||||
|
let indexes = TestTensorInt::from_data(Data::from([[1, 0, 2], [1, 2, 0]]));
|
||||||
|
|
||||||
|
let output = tensor.index_select_assign(indexes, values);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
output.into_data(),
|
||||||
|
Data::from([[2.0, 1.0, 3.0], [6.0, 4.0, 5.0]])
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
|
@ -6,6 +6,7 @@ mod div;
|
||||||
mod erf;
|
mod erf;
|
||||||
mod exp;
|
mod exp;
|
||||||
mod index;
|
mod index;
|
||||||
|
mod index_select;
|
||||||
mod index_select_dim;
|
mod index_select_dim;
|
||||||
mod log;
|
mod log;
|
||||||
mod log1p;
|
mod log1p;
|
||||||
|
|
Loading…
Reference in New Issue