Add `flip` tensor operator (#1468)

carrotflakes 2024-03-19 10:33:39 +09:00 committed by GitHub
parent 6e58663cc1
commit 8911093b88
37 changed files with 758 additions and 34 deletions

View File

@ -149,6 +149,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
| `tensor.dims()` | `tensor.size()` |
| `tensor.equal(other)` | `x == y` |
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
| `tensor.flip(axes)` | `tensor.flip(dims)` |
| `tensor.into_data()` | N/A |
| `tensor.into_primitive()` | N/A |
| `tensor.into_scalar()` | `tensor.item()` |
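
A minimal usage sketch of the new API (a hedged example, assuming a generic backend `B`; `from_ints` is the usual `Int` tensor constructor):

use burn::tensor::{backend::Backend, Int, Tensor};

fn flip_demo<B: Backend>(device: &B::Device) -> Tensor<B, 2, Int> {
    let t = Tensor::<B, 2, Int>::from_ints([[0, 1, 2], [3, 4, 5]], device);
    // Negative axes index from the end: flip([-1]) is equivalent to flip([1]) here.
    t.flip([0, 1]) // yields [[5, 4, 3], [2, 1, 0]]
}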

View File

@ -117,6 +117,10 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
B::bool_permute(tensor, axes)
}
fn bool_flip<const D: usize>(tensor: BoolTensor<B, D>, axes: &[usize]) -> BoolTensor<B, D> {
B::bool_flip(tensor, axes)
}
fn bool_argwhere<const D: usize>(tensor: BoolTensor<B, D>) -> IntTensor<B, 2> {
B::bool_argwhere(tensor)
}

View File

@ -353,6 +353,10 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_permute(tensor, axes)
}
fn int_flip<const D: usize>(tensor: IntTensor<Self, D>, axes: &[usize]) -> IntTensor<Self, D> {
B::int_flip(tensor, axes)
}
fn int_sign<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, D> {
B::int_sign(tensor)
}

View File

@ -738,6 +738,62 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
}
fn float_flip<const D: usize>(
tensor: FloatTensor<Self, D>,
axes: &[usize],
) -> FloatTensor<Self, D> {
#[derive(Debug)]
struct FlipDim;
#[derive(new, Debug)]
struct RetroFlipDims<B: Backend, const D: usize> {
input_id: NodeID,
axes: Vec<usize>,
_backend: PhantomData<B>,
}
impl<B: Backend, const D: usize> RetroForward for RetroFlipDims<B, D> {
fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
let input = states.get_state::<B::FloatTensorPrimitive<D>>(&self.input_id);
let out = B::float_flip(input, &self.axes);
states.save(out_node, out)
}
}
impl<B: Backend, const D: usize> Backward<B, D, 1> for FlipDim {
type State = Vec<usize>;
fn backward(
self,
ops: Ops<Self::State, 1>,
grads: &mut Gradients,
_checkpointer: &mut Checkpointer,
) {
let axes = ops.state;
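// Flip is its own inverse, so the backward pass simply flips the incoming gradient along the same axes.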
unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
B::float_flip(grad, &axes)
});
}
}
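// Memory-bound op: rather than checkpointing the output, the retro-forward recomputes the flipped tensor on demand during the backward pass.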
match FlipDim
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
.memory_bound()
.retro_forward(RetroFlipDims::<B, D>::new(
tensor.node.id.clone(),
axes.to_vec(),
))
.parents([&tensor])
.stateful()
{
OpsKind::Tracked(prep) => {
prep.finish(axes.to_vec(), B::float_flip(tensor.primitive, axes))
}
OpsKind::UnTracked(prep) => prep.finish(B::float_flip(tensor.primitive, axes)),
}
}
fn float_reshape<const D1: usize, const D2: usize>(
tensor: FloatTensor<Self, D1>,
shape: Shape<D2>,

View File

@ -0,0 +1,28 @@
#[burn_tensor_testgen::testgen(ad_flip)]
mod tests {
use super::*;
use burn_tensor::Data;
#[test]
fn should_diff_flip() {
let data_1: Data<f32, 3> = Data::from([[[1.0, 7.0], [2.0, 3.0]]]); // 1x2x2
let data_2: Data<f32, 3> = Data::from([[[3.0, 2.0, 7.0], [3.0, 3.2, 1.0]]]); // 1x2x3
let device = Default::default();
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_2.clone().flip([1, 2]);
let tensor_4 = tensor_1.clone().matmul(tensor_3);
let grads = tensor_4.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
let grad_2 = tensor_2.grad(&grads).unwrap();
assert_eq!(grad_1.to_data(), Data::from([[[7.2, 12.0], [7.2, 12.0]]])); // 1x2x2
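// tensor_2's gradient is the gradient with respect to the flipped view, flipped back along the same axes [1, 2].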
assert_eq!(
grad_2.to_data(),
Data::from([[[10.0, 10.0, 10.0], [3.0, 3.0, 3.0]]]) // 1x2x3
);
}
}

View File

@ -21,6 +21,7 @@ mod cross_entropy;
mod div;
mod erf;
mod exp;
mod flip;
mod gather_scatter;
mod gelu;
mod gradients;
@ -114,6 +115,7 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_sigmoid!();
burn_autodiff::testgen_ad_transpose!();
burn_autodiff::testgen_ad_permute!();
burn_autodiff::testgen_ad_flip!();
burn_autodiff::testgen_ad_nonzero!();
burn_autodiff::testgen_ad_sign!();
};

View File

@ -81,6 +81,7 @@ mod tests {
burn_tensor::testgen_mul!();
burn_tensor::testgen_neg!();
burn_tensor::testgen_permute!();
burn_tensor::testgen_flip!();
burn_tensor::testgen_argwhere_nonzero!();
burn_tensor::testgen_sign!();

View File

@ -23,7 +23,6 @@ pub fn from_data<E: CandleElement, const D: usize>(
) -> CandleTensor<E, D> {
CandleTensor::from_data(data, *device)
}
pub fn into_data<E: CandleElement, const D: usize>(tensor: CandleTensor<E, D>) -> Data<E, D> {
Data::new(
tensor.tensor.flatten_all().unwrap().to_vec1().unwrap(),
@ -60,6 +59,26 @@ pub fn permute<E: CandleElement, const D: usize>(
CandleTensor::new(tensor.tensor.permute(axes).unwrap())
}
pub fn flip<E: CandleElement, const D: usize>(
tensor: CandleTensor<E, D>,
axes: &[usize],
) -> CandleTensor<E, D> {
// FIXME: Replace with an appropriate method when Candle provides one.
let mut tensor = tensor.tensor;
for &axis in axes {
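// Build the reversed index vector [dim - 1, ..., 1, 0]; index_select with it reverses this axis.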
let indexes = candle_core::Tensor::arange_step(
tensor.dim(axis).unwrap() as i64 - 1,
-1,
-1,
tensor.device(),
)
.unwrap();
tensor = tensor.index_select(&indexes, axis).unwrap();
}
CandleTensor::new(tensor)
}
pub fn reshape<E: CandleElement, const D1: usize, const D2: usize>(
tensor: CandleTensor<E, D1>,
shape: Shape<D2>,

View File

@ -8,8 +8,6 @@ use crate::{
Candle, CandleTensor,
};
use super::base::permute;
impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<F, I> {
fn bool_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> BoolTensor<Self, D> {
super::base::empty(shape, device)
@ -133,6 +131,13 @@ impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<
tensor: BoolTensor<Self, D>,
axes: [usize; D],
) -> BoolTensor<Self, D> {
permute(tensor, axes)
super::base::permute(tensor, axes)
}
fn bool_flip<const D: usize>(
tensor: BoolTensor<Self, D>,
axes: &[usize],
) -> BoolTensor<Self, D> {
super::base::flip(tensor, axes)
}
}

View File

@ -8,8 +8,6 @@ use crate::{
Candle, CandleTensor,
};
use super::base::permute;
impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F, I> {
fn int_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
super::base::empty(shape, device)
@ -425,7 +423,11 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
tensor: IntTensor<Self, D>,
axes: [usize; D],
) -> IntTensor<Self, D> {
permute(tensor, axes)
super::base::permute(tensor, axes)
}
fn int_flip<const D: usize>(tensor: IntTensor<Self, D>, axes: &[usize]) -> IntTensor<Self, D> {
super::base::flip(tensor, axes)
}
// TODO add sign operator once Candle supports it:

View File

@ -11,8 +11,6 @@ use crate::{
Candle, CandleTensor,
};
use super::base::permute;
impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle<F, I> {
fn float_from_data<const D: usize>(
data: Data<F, D>,
@ -522,7 +520,14 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
tensor: FloatTensor<Self, D>,
axes: [usize; D],
) -> FloatTensor<Self, D> {
permute(tensor, axes)
super::base::permute(tensor, axes)
}
fn float_flip<const D: usize>(
tensor: FloatTensor<Self, D>,
axes: &[usize],
) -> FloatTensor<Self, D> {
super::base::flip(tensor, axes)
}
// TODO add sign operator once Candle supports it:

View File

@ -4,9 +4,9 @@ use crate::{
ops::binary::binary_ops_shape,
stream::{
BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
CatOperationDescription, Operation, OperationDescription, PermuteOperationDescription,
ReshapeDescription, SliceAssignOperationDescription, SliceOperationDescription, StreamId,
SwapDimsDescription, UnaryOperationDescription,
CatOperationDescription, FlipOperationDescription, Operation, OperationDescription,
PermuteOperationDescription, ReshapeDescription, SliceAssignOperationDescription,
SliceOperationDescription, StreamId, SwapDimsDescription, UnaryOperationDescription,
},
Fusion, FusionBackend,
};
@ -466,4 +466,39 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
out
}
fn bool_flip<const D: usize>(
tensor: BoolTensor<Self, D>,
axes: &[usize],
) -> BoolTensor<Self, D> {
#[derive(new)]
struct FlipOps<const D: usize> {
desc: FlipOperationDescription,
}
impl<const D: usize, B: FusionBackend> Operation<B> for FlipOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_bool_tensor::<D>(&self.desc.input);
let output = B::bool_flip(input, self.desc.axes.as_slice());
handles.register_bool_tensor(&self.desc.out.id, output);
}
}
let stream = tensor.stream;
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
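// Flip preserves the shape; only the element order along the given axes changes.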
let desc = FlipOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
axes: axes.to_vec(),
};
out.client.register(
vec![stream],
OperationDescription::BaseBool(BaseOperationDescription::Flip(desc.clone())),
FlipOps::<D>::new(desc),
);
out
}
}

View File

@ -6,13 +6,13 @@ use crate::{
scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_ops,
stream::{
BaseOperationDescription, BinaryOperationDescription, CatOperationDescription,
ClampOperationDescription, FloatOperationDescription, GatherOperationDescription,
MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription,
Operation, OperationDescription, PermuteOperationDescription, RandomOperationDescription,
ReduceDimWithIndicesDescription, ReshapeDescription, ScalarOperationDescription,
ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription,
SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription,
UnaryOperationDescription,
ClampOperationDescription, FlipOperationDescription, FloatOperationDescription,
GatherOperationDescription, MaskFillOperationDescription, MaskWhereOperationDescription,
NumericOperationDescription, Operation, OperationDescription, PermuteOperationDescription,
RandomOperationDescription, ReduceDimWithIndicesDescription, ReshapeDescription,
ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription,
SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription,
StreamId, SwapDimsDescription, UnaryOperationDescription,
},
unary_float_ops, Fusion, FusionBackend, TensorDescription,
};
@ -1846,4 +1846,39 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
out
}
fn float_flip<const D: usize>(
tensor: FloatTensor<Self, D>,
axes: &[usize],
) -> FloatTensor<Self, D> {
#[derive(new)]
struct FlipOps<const D: usize> {
desc: FlipOperationDescription,
}
impl<const D: usize, B: FusionBackend> Operation<B> for FlipOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_float_tensor::<D>(&self.desc.input);
let output = B::float_flip(input, &self.desc.axes);
handles.register_float_tensor(&self.desc.out.id, output);
}
}
let stream = tensor.stream;
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
let desc = FlipOperationDescription {
input: tensor.into_description(),
axes: axes.to_vec(),
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::BaseFloat(BaseOperationDescription::Flip(desc.clone())),
FlipOps::<D>::new(desc),
);
out
}
}

View File

@ -6,9 +6,9 @@ use crate::{
scalar_int_cmp_ops, scalar_int_ops,
stream::{
self, BaseOperationDescription, BinaryOperationDescription, CatOperationDescription,
ClampOperationDescription, GatherOperationDescription, MaskFillOperationDescription,
MaskWhereOperationDescription, NumericOperationDescription, Operation,
OperationDescription, PermuteOperationDescription, RandomOperationDescription,
ClampOperationDescription, FlipOperationDescription, GatherOperationDescription,
MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription,
Operation, OperationDescription, PermuteOperationDescription, RandomOperationDescription,
ReduceDimWithIndicesDescription, ReshapeDescription, ScalarOperationDescription,
ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription,
SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription,
@ -1552,4 +1552,38 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
out
}
fn int_flip<const D: usize>(tensor: IntTensor<Self, D>, axes: &[usize]) -> IntTensor<Self, D> {
#[derive(new)]
struct FlipDimsOps<const D: usize> {
desc: FlipOperationDescription,
}
impl<const D: usize, B: FusionBackend> Operation<B> for FlipDimsOps<D> {
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
let input = handles.get_int_tensor::<D>(&self.desc.input);
let axes = &self.desc.axes;
let output = B::int_flip(input, axes);
handles.register_int_tensor(&self.desc.out.id, output);
}
}
let stream = tensor.stream;
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
let desc = FlipOperationDescription {
input: tensor.into_description(),
axes: axes.to_vec(),
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::BaseInt(BaseOperationDescription::Flip(desc.clone())),
FlipDimsOps::<D>::new(desc),
);
out
}
}

View File

@ -4,16 +4,16 @@ use super::{
AvgPool2dBackwardDescription, AvgPool2dDescription, BaseOperationDescription,
BinaryOperationDescription, BoolOperationDescription, ClampOperationDescription,
Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription, ConvTranspose2dDescription,
EmbeddingBackwardDescription, EmbeddingDescription, FloatOperationDescription,
GatherOperationDescription, IntOperationDescription, InterpolateBackwardDescription,
InterpolateDescription, MaskFillOperationDescription, MaskWhereOperationDescription,
MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription, MaxPool1dWithIndicesDescription,
MaxPool2dDescription, MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription,
ModuleOperationDescription, NumericOperationDescription, OperationDescription,
PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription,
ReshapeDescription, ScalarOperationDescription, ScatterOperationDescription,
SelectAssignOperationDescription, SelectOperationDescription, SliceOperationDescription,
SwapDimsDescription, UnaryOperationDescription,
EmbeddingBackwardDescription, EmbeddingDescription, FlipOperationDescription,
FloatOperationDescription, GatherOperationDescription, IntOperationDescription,
InterpolateBackwardDescription, InterpolateDescription, MaskFillOperationDescription,
MaskWhereOperationDescription, MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription,
MaxPool1dWithIndicesDescription, MaxPool2dDescription, MaxPool2dWithIndicesBackwardDescription,
MaxPool2dWithIndicesDescription, ModuleOperationDescription, NumericOperationDescription,
OperationDescription, PermuteOperationDescription, RandomOperationDescription,
ReduceDimWithIndicesDescription, ReshapeDescription, ScalarOperationDescription,
ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription,
SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription,
};
use crate::{FusionBackend, HandleContainer, TensorDescription, TensorId};
use burn_tensor::{Element, ElementConversion};
@ -798,6 +798,13 @@ impl BaseOperationDescription {
axes: desc.axes.clone(),
})
}
BaseOperationDescription::Flip(desc) => {
BaseOperationDescription::Flip(FlipOperationDescription {
input: desc.input.to_relative(converter),
out: desc.out.to_relative(converter),
axes: desc.axes.clone(),
})
}
BaseOperationDescription::Slice(desc) => {
BaseOperationDescription::Slice(SliceOperationDescription {
tensor: desc.tensor.to_relative(converter),

View File

@ -155,6 +155,12 @@ pub enum BaseOperationDescription {
/// Bool => [permute](burn_tensor::ops::BoolTensorOps::bool_permute).
Permute(PermuteOperationDescription),
/// Operation corresponding to:
/// Float => [flip](burn_tensor::ops::FloatTensorOps::float_flip).
/// Int => [flip](burn_tensor::ops::IntTensorOps::int_flip).
/// Bool => [flip](burn_tensor::ops::BoolTensorOps::bool_flip).
Flip(FlipOperationDescription),
/// Operation corresponding to:
///
/// Float => [slice](burn_tensor::ops::FloatTensorOps::float_slice).
@ -456,6 +462,17 @@ pub struct PermuteOperationDescription {
pub axes: Vec<usize>,
}
/// Flip operation description.
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
pub struct FlipOperationDescription {
/// Input tensor description.
pub input: TensorDescription,
/// Output tensor description.
pub out: TensorDescription,
/// The dimensions to flip.
pub axes: Vec<usize>,
}
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct RandomOperationDescription {
@ -1034,10 +1051,12 @@ impl BaseOperationDescription {
BaseOperationDescription::SwapDims(desc) => {
vec![&desc.input, &desc.out]
}
BaseOperationDescription::Permute(desc) => {
vec![&desc.input, &desc.out]
}
BaseOperationDescription::Flip(desc) => {
vec![&desc.input, &desc.out]
}
BaseOperationDescription::Slice(desc) => {
vec![&desc.tensor, &desc.out]
}

View File

@ -0,0 +1,157 @@
use crate::{
codegen::{
dialect::gpu::{gpu, Elem, Scope, Variable, Visibility},
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
InputInfo, OutputInfo, WorkgroupLaunch,
},
element::JitElement,
kernel::{DynamicKernelSource, SourceTemplate},
ops::numeric::empty_device,
tensor::JitTensor,
Runtime, RuntimeInt,
};
use burn_tensor::ElementConversion;
use std::marker::PhantomData;
#[derive(new)]
struct FlipEagerKernel<R: Runtime, E: JitElement> {
rank: usize,
_runtime: PhantomData<R>,
_elem: PhantomData<E>,
}
pub struct FlipComputeShader {
input: Variable,
output: Variable,
rank: usize,
}
impl FlipComputeShader {
pub fn expand(self, scope: &mut Scope) {
let input = self.input;
let output = self.output;
let id = Variable::Id;
let offset_input = scope.zero(Elem::UInt);
let offset_local = scope.create_local(Elem::UInt);
let stride = scope.create_local(Elem::UInt);
let shape = scope.create_local(Elem::UInt);
let flip = scope.create_local(Elem::UInt);
let flip_bool = scope.create_local(Elem::Bool);
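// Unrolled over the tensor rank: recover each coordinate from the linear
// invocation id, mirror it when that axis is flagged, and rebuild the
// input offset from the mirrored coordinates.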
for i in 0..self.rank {
gpu!(scope, stride = stride(input, i));
gpu!(scope, shape = shape(output, i));
gpu!(
scope,
flip = cast(Variable::GlobalScalar(i as u16, Elem::UInt))
);
gpu!(scope, flip_bool = flip == 1u32);
gpu!(scope, offset_local = id / stride);
gpu!(scope, offset_local = offset_local % shape);
gpu!(scope, if(flip_bool).then(|scope| {
gpu!(scope, offset_local = shape - offset_local);
gpu!(scope, offset_local = offset_local - 1u32);
}));
gpu!(scope, offset_local = offset_local * stride);
gpu!(scope, offset_input += offset_local);
}
let result = scope.create_local(input.item());
gpu!(scope, result = input[offset_input]);
gpu!(scope, output[id] = result);
}
}
impl<R: Runtime, E: JitElement> DynamicKernelSource for FlipEagerKernel<R, E> {
fn source(&self) -> crate::kernel::SourceTemplate {
let mut scope = Scope::root();
let item = E::gpu_elem().into();
let input = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);
scope.write_global_custom(output);
FlipComputeShader {
input,
output,
rank: self.rank,
}
.expand(&mut scope);
let input = InputInfo::Array {
item,
visibility: Visibility::Read,
};
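// One u32 flag per dimension (size = rank), passed as scalar inputs: 1 flips that axis, 0 leaves it unchanged.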
let flip_dims = InputInfo::Scalar {
elem: Elem::UInt,
size: self.rank,
};
let output = OutputInfo::Array { item };
let info = CompilationInfo {
inputs: vec![input, flip_dims],
outputs: vec![output],
scope,
};
let settings = CompilationSettings::default();
let shader = Compilation::new(info).compile(settings);
let shader = <R::Compiler as Compiler>::compile(shader);
SourceTemplate::new(shader.to_string())
}
fn id(&self) -> String {
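// The kernel source is specialized per rank, so the rank must be part of the cache id.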
format!("{:?}-rank={:?}", core::any::TypeId::of::<Self>(), self.rank)
}
}
pub(crate) fn flip<R: Runtime, E: JitElement, const D: usize>(
tensor: JitTensor<R, E, D>,
indices: &[usize],
) -> JitTensor<R, E, D> {
let output = empty_device(
tensor.client.clone(),
tensor.device.clone(),
tensor.shape.clone(),
);
flip_on_output(tensor, output, indices)
}
pub(crate) fn flip_on_output<R: Runtime, E: JitElement, const D: usize>(
tensor: JitTensor<R, E, D>,
output: JitTensor<R, E, D>,
indices: &[usize],
) -> JitTensor<R, E, D> {
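// Encode which axes to flip as one scalar flag per dimension.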
let mut scalars = Vec::with_capacity(D);
for i in 0..D {
scalars.push((indices.contains(&i) as u32).elem());
}
let kernel = FlipEagerKernel::new(D);
execute_dynamic::<R, FlipEagerKernel<R, E>, RuntimeInt<R>>(
&[EagerHandle::new(
&tensor.handle,
&tensor.strides,
&tensor.shape.dims,
)],
&[EagerHandle::new(
&output.handle,
&output.strides,
&output.shape.dims,
)],
Some(&scalars),
kernel,
WorkgroupLaunch::Output { pos: 0 },
tensor.client,
);
output
}

View File

@ -1,3 +1,4 @@
mod flip;
mod gather;
mod repeat;
mod scatter;
@ -6,6 +7,7 @@ mod select_assign;
mod slice;
mod slice_assign;
pub use flip::*;
pub use repeat::*;
pub use select::*;
pub use select_assign::*;

View File

@ -113,4 +113,11 @@ impl<R: Runtime> BoolTensorOps<Self> for JitBackend<R> {
) -> BoolTensor<Self, D> {
permute(tensor, axes)
}
fn bool_flip<const D: usize>(
tensor: BoolTensor<Self, D>,
axes: &[usize],
) -> BoolTensor<Self, D> {
kernel::flip(tensor, axes)
}
}

View File

@ -520,4 +520,11 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
) -> FloatTensor<Self, D> {
permute(tensor, axes)
}
fn float_flip<const D: usize>(
tensor: FloatTensor<Self, D>,
axes: &[usize],
) -> FloatTensor<Self, D> {
kernel::flip(tensor, axes)
}
}

View File

@ -339,4 +339,8 @@ impl<R: Runtime> IntTensorOps<Self> for JitBackend<R> {
) -> IntTensor<Self, D> {
permute(tensor, axes)
}
fn int_flip<const D: usize>(tensor: IntTensor<Self, D>, axes: &[usize]) -> IntTensor<Self, D> {
kernel::flip(tensor, axes)
}
}

View File

@ -4,6 +4,7 @@ use burn_tensor::ElementConversion;
use core::{marker::PhantomData, ops::Range};
use ndarray::s;
use ndarray::Array2;
use ndarray::{SliceInfo, SliceInfoElem};
use ndarray::Zip;
use num_traits::Signed;
@ -111,6 +112,34 @@ where
NdArrayTensor::new(array)
}
pub fn flip<const D: usize>(
tensor: NdArrayTensor<E, D>,
axes: &[usize],
) -> NdArrayTensor<E, D> {
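// Build one slice spec per dimension: step -1 reverses a flipped axis, step 1 keeps it as-is.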
let slice_items: Vec<_> = (0..D)
.map(|i| {
if axes.contains(&i) {
SliceInfoElem::Slice {
start: 0,
end: None,
step: -1,
}
} else {
SliceInfoElem::Slice {
start: 0,
end: None,
step: 1,
}
}
})
.collect();
let slice_info =
SliceInfo::<Vec<SliceInfoElem>, IxDyn, IxDyn>::try_from(slice_items).unwrap();
let array = tensor.array.slice(slice_info).into_owned().into_shared();
NdArrayTensor::new(array)
}
}
impl<E> NdArrayMathOps<E>

View File

@ -137,4 +137,11 @@ impl<E: FloatNdArrayElement> BoolTensorOps<Self> for NdArray<E> {
let array = tensor.array.permuted_axes(axes.into_dimension());
NdArrayTensor { array }
}
fn bool_flip<const D: usize>(
tensor: burn_tensor::ops::BoolTensor<Self, D>,
axes: &[usize],
) -> burn_tensor::ops::BoolTensor<Self, D> {
NdArrayOps::flip(tensor, axes)
}
}

View File

@ -446,6 +446,13 @@ impl<E: FloatNdArrayElement> IntTensorOps<Self> for NdArray<E> {
NdArrayTensor { array }
}
fn int_flip<const D: usize>(
tensor: burn_tensor::ops::IntTensor<Self, D>,
axes: &[usize],
) -> burn_tensor::ops::IntTensor<Self, D> {
NdArrayOps::flip(tensor, axes)
}
fn int_sign<const D: usize>(tensor: NdArrayTensor<i64, D>) -> NdArrayTensor<i64, D> {
NdArrayMathOps::sign_op(tensor)
}

View File

@ -494,6 +494,13 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
NdArrayTensor { array }
}
fn float_flip<const D: usize>(
tensor: burn_tensor::ops::FloatTensor<Self, D>,
axes: &[usize],
) -> burn_tensor::ops::FloatTensor<Self, D> {
NdArrayOps::flip(tensor, axes)
}
fn float_sign<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
NdArrayMathOps::sign_op(tensor)
}

View File

@ -446,6 +446,12 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
TchTensor::new(tensor)
}
pub fn flip<const D: usize>(tensor: TchTensor<E, D>, axes: &[usize]) -> TchTensor<E, D> {
let dims = axes.iter().map(|x| *x as i64).collect::<Vec<_>>();
let tensor = tensor.tensor.flip(dims);
TchTensor::new(tensor)
}
pub fn narrow<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,

View File

@ -139,6 +139,10 @@ impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
TchOps::permute(tensor, axes)
}
fn bool_flip<const D: usize>(tensor: TchTensor<bool, D>, axes: &[usize]) -> TchTensor<bool, D> {
TchOps::flip(tensor, axes)
}
fn bool_argwhere<const D: usize>(
tensor: <LibTorch<E> as Backend>::BoolTensorPrimitive<D>,
) -> TchTensor<i64, 2> {

View File

@ -474,6 +474,13 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
TchOps::permute(tensor, axes)
}
fn int_flip<const D: usize>(
tensor: burn_tensor::ops::IntTensor<Self, D>,
axes: &[usize],
) -> burn_tensor::ops::IntTensor<Self, D> {
TchOps::flip(tensor, axes)
}
fn int_sign<const D: usize>(
tensor: <LibTorch<E> as Backend>::IntTensorPrimitive<D>,
) -> <LibTorch<E> as Backend>::IntTensorPrimitive<D> {

View File

@ -483,6 +483,13 @@ impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
TchOps::permute(tensor, axes)
}
fn float_flip<const D: usize>(
tensor: burn_tensor::ops::FloatTensor<Self, D>,
axes: &[usize],
) -> burn_tensor::ops::FloatTensor<Self, D> {
TchOps::flip(tensor, axes)
}
fn float_sign<const D: usize>(
tensor: <LibTorch<E> as Backend>::FloatTensorPrimitive<D>,
) -> <LibTorch<E> as Backend>::FloatTensorPrimitive<D> {

View File

@ -170,6 +170,33 @@ where
Tensor::new(K::permute(self.primitive, transformed_axes))
}
/// Reverse the order of elements in the tensor along the given dimensions.
///
/// # Arguments
///
/// * `axes` - The dimensions to reverse. The axes must be unique and within the tensor's rank.
///   Negative values count from the end, so `-1` is the last dimension.
///
/// # Returns
///
/// The tensor with the axes flipped.
pub fn flip<const N: usize>(self, axes: [isize; N]) -> Tensor<B, D, K> {
// Convert the axes to usize, resolving negative values, without allocating a Vec
let mut transformed_axes: [usize; N] = [0; N];
for (i, &x) in axes.iter().enumerate() {
transformed_axes[i] = if x < 0 {
(D as isize + x) as usize
} else {
x as usize
};
}
// Check if the axes are valid
check!(TensorCheck::flip(D, &transformed_axes));
Tensor::new(K::flip(self.primitive, &transformed_axes))
}
/// Flatten the tensor along a given range of dimensions.
///
/// This function collapses the specified range of dimensions into a single dimension,
@ -1130,6 +1157,18 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
/// The tensor with the dimensions permuted.
fn permute<const D: usize>(tensor: Self::Primitive<D>, axes: [usize; D]) -> Self::Primitive<D>;
/// Flips the tensor along the given axes.
///
/// # Arguments
///
/// * `tensor` - The tensor to flip.
/// * `axes` - The axes to flip the tensor along.
///
/// # Returns
///
/// The tensor with the axes flipped.
fn flip<const D: usize>(tensor: Self::Primitive<D>, axes: &[usize]) -> Self::Primitive<D>;
/// Select tensor elements corresponding for the given ranges.
///
/// # Arguments
@ -1558,6 +1597,10 @@ impl<B: Backend> BasicOps<B> for Float {
fn permute<const D: usize>(tensor: Self::Primitive<D>, axes: [usize; D]) -> Self::Primitive<D> {
B::float_permute(tensor, axes)
}
fn flip<const D: usize>(tensor: Self::Primitive<D>, axes: &[usize]) -> Self::Primitive<D> {
B::float_flip(tensor, axes)
}
}
impl<B: Backend> BasicOps<B> for Int {
@ -1672,6 +1715,10 @@ impl<B: Backend> BasicOps<B> for Int {
fn permute<const D: usize>(tensor: Self::Primitive<D>, axes: [usize; D]) -> Self::Primitive<D> {
B::int_permute(tensor, axes)
}
fn flip<const D: usize>(tensor: Self::Primitive<D>, axes: &[usize]) -> Self::Primitive<D> {
B::int_flip(tensor, axes)
}
}
impl<B: Backend> BasicOps<B> for Bool {
@ -1786,6 +1833,10 @@ impl<B: Backend> BasicOps<B> for Bool {
fn permute<const D: usize>(tensor: Self::Primitive<D>, axes: [usize; D]) -> Self::Primitive<D> {
B::bool_permute(tensor, axes)
}
fn flip<const D: usize>(tensor: Self::Primitive<D>, axes: &[usize]) -> Self::Primitive<D> {
B::bool_flip(tensor, axes)
}
}
/// Trait used for reshape arguments.

View File

@ -349,6 +349,34 @@ impl TensorCheck {
check
}
pub(crate) fn flip(rank: usize, axes: &[usize]) -> Self {
let check = Self::Ok;
// Check if the axes are within the tensor dimensions
if let Some(axis) = axes.iter().find(|&x| *x >= rank) {
return check.register(
"flip",
TensorError::new("The axes must be smaller than the tensor's rank.").details(
format!("Axis {axis} is out of bounds for a tensor of rank {rank}."),
),
);
}
// Check if the axes are unique
let mut dedup = axes.to_vec();
dedup.sort_unstable();
dedup.dedup();
if dedup.len() != axes.len() {
return check.register(
"flip",
TensorError::new("The axes must be unique.")
.details(format!("The axes '{axes:?}' are not unique.")),
);
}
check
}
pub(crate) fn matmul<B: Backend, const D: usize>(
lhs: &Tensor<B, D>,
rhs: &Tensor<B, D>,

View File

@ -300,6 +300,16 @@ pub trait BoolTensorOps<B: Backend> {
fn bool_permute<const D: usize>(tensor: BoolTensor<B, D>, axes: [usize; D])
-> BoolTensor<B, D>;
/// Reverse the order of elements in a tensor along the given axes.
///
/// # Arguments
///
/// * `tensor` - The tensor to reverse.
/// * `axes` - The axes to reverse.
///
/// # Returns
///
/// The tensor with the elements reversed.
fn bool_flip<const D: usize>(tensor: BoolTensor<B, D>, axes: &[usize]) -> BoolTensor<B, D>;
/// Returns a new tensor with the given dimension narrowed to the given range.
///
/// # Arguments

View File

@ -996,6 +996,16 @@ pub trait IntTensorOps<B: Backend> {
/// The tensor with the dimensions permuted.
fn int_permute<const D: usize>(tensor: IntTensor<B, D>, axes: [usize; D]) -> IntTensor<B, D>;
/// Reverse the order of elements in a tensor along the given axes.
///
/// # Arguments
///
/// * `tensor` - The tensor to reverse.
/// * `axes` - The axes to reverse.
///
/// # Returns
///
/// The tensor with the elements reversed.
fn int_flip<const D: usize>(tensor: IntTensor<B, D>, axes: &[usize]) -> IntTensor<B, D>;
/// Returns a new tensor with the given dimension narrowed to the given range.
///
/// # Arguments

View File

@ -449,6 +449,16 @@ pub trait FloatTensorOps<B: Backend> {
axes: [usize; D],
) -> FloatTensor<B, D>;
/// Reverse the order of elements in a tensor along the given axes.
///
/// # Arguments
///
/// * `tensor` - The tensor to reverse.
/// * `axes` - The axes to reverse.
///
/// # Returns
///
/// The tensor with the elements reversed.
fn float_flip<const D: usize>(tensor: FloatTensor<B, D>, axes: &[usize]) -> FloatTensor<B, D>;
/// Reshapes a tensor.
///
/// # Arguments

View File

@ -86,6 +86,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_any!();
burn_tensor::testgen_all_op!();
burn_tensor::testgen_permute!();
burn_tensor::testgen_flip!();
burn_tensor::testgen_bool!();
burn_tensor::testgen_argwhere_nonzero!();
burn_tensor::testgen_sign!();

View File

@ -0,0 +1,105 @@
#[burn_tensor_testgen::testgen(flip)]
mod tests {
use super::*;
use burn_tensor::backend::Backend;
use burn_tensor::{Data, Device, Int, Shape, Tensor};
#[test]
fn normal_int() {
let device = Default::default();
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device).reshape([2, 3, 4]);
let flipped = tensor.clone().flip([0, 2]);
// from pytorch:
// import torch; torch.arange(0, 24).reshape(2, 3, 4).flip((0, 2))
let data_expected = Data::from([
[[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]],
[[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]],
]);
assert_eq!(data_expected, flipped.into_data());
// Test with no flip
let flipped = tensor.clone().flip([]);
assert_eq!(tensor.into_data(), flipped.into_data());
}
#[test]
fn normal_float() {
let device = Default::default();
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device)
.reshape([2, 3, 4])
.float();
let flipped = tensor.clone().flip([0, 2]);
// from pytorch:
// import torch; torch.arange(0, 24).reshape(2, 3, 4).flip((0, 2)).float()
let data_expected = Data::from([
[
[15., 14., 13., 12.],
[19., 18., 17., 16.],
[23., 22., 21., 20.],
],
[[3., 2., 1., 0.], [7., 6., 5., 4.], [11., 10., 9., 8.]],
]);
assert_eq!(data_expected, flipped.into_data());
// Test with no flip
let flipped = tensor.clone().flip([]);
assert_eq!(tensor.into_data(), flipped.into_data());
}
#[test]
fn normal_bool() {
let device = Default::default();
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device)
.reshape([2, 3, 4])
.greater_elem(10);
let flipped = tensor.clone().flip([0, 2]);
// from pytorch:
// import torch; torch.arange(0, 24).reshape(2, 3, 4).flip((0, 2)).gt(10)
let data_expected = Data::from([
[
[true, true, true, true],
[true, true, true, true],
[true, true, true, true],
],
[
[false, false, false, false],
[false, false, false, false],
[true, false, false, false],
],
]);
assert_eq!(data_expected, flipped.into_data());
// Test with no flip
let flipped = tensor.clone().flip([]);
assert_eq!(tensor.into_data(), flipped.into_data());
}
#[test]
#[should_panic]
fn edge_duplicated_axes() {
let device = Default::default();
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device).reshape([2, 3, 4]);
// Test with a duplicated axis
let _ = tensor.clone().flip([0, 0, 1]);
}
#[test]
#[should_panic]
fn edge_out_of_bound_axis() {
let device = Default::default();
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device).reshape([2, 3, 4]);
// Test with an out of bound axis
let _ = tensor.clone().flip([3, 0, 1]);
}
}

View File

@ -19,6 +19,7 @@ mod div;
mod erf;
mod exp;
mod flatten;
mod flip;
mod full;
mod gather_scatter;
mod init;