mirror of https://github.com/tracel-ai/burn.git
Add `flip` tensor operator (#1468)
This commit is contained in:
parent
6e58663cc1
commit
8911093b88
|
@ -149,6 +149,7 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
|
|||
| `tensor.dims()` | `tensor.size()` |
|
||||
| `tensor.equal(other)` | `x == y` |
|
||||
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
|
||||
| `tensor.flip(axes)` | `tensor.flip(axes)` |
|
||||
| `tensor.into_data()` | N/A |
|
||||
| `tensor.into_primitive()` | N/A |
|
||||
| `tensor.into_scalar()` | `tensor.item()` |
|
||||
|
|
|
@ -117,6 +117,10 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
|
|||
B::bool_permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn bool_flip<const D: usize>(tensor: BoolTensor<B, D>, axes: &[usize]) -> BoolTensor<B, D> {
|
||||
B::bool_flip(tensor, axes)
|
||||
}
|
||||
|
||||
fn bool_argwhere<const D: usize>(tensor: BoolTensor<B, D>) -> IntTensor<B, 2> {
|
||||
B::bool_argwhere(tensor)
|
||||
}
|
||||
|
|
|
@ -353,6 +353,10 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
|
|||
B::int_permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn int_flip<const D: usize>(tensor: IntTensor<Self, D>, axes: &[usize]) -> IntTensor<Self, D> {
|
||||
B::int_flip(tensor, axes)
|
||||
}
|
||||
|
||||
fn int_sign<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, D> {
|
||||
B::int_sign(tensor)
|
||||
}
|
||||
|
|
|
@ -738,6 +738,62 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
}
|
||||
}
|
||||
|
||||
fn float_flip<const D: usize>(
|
||||
tensor: FloatTensor<Self, D>,
|
||||
axes: &[usize],
|
||||
) -> FloatTensor<Self, D> {
|
||||
#[derive(Debug)]
|
||||
struct FlipDim;
|
||||
|
||||
#[derive(new, Debug)]
|
||||
struct RetroFlipDims<B: Backend, const D: usize> {
|
||||
input_id: NodeID,
|
||||
axes: Vec<usize>,
|
||||
_backend: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> RetroForward for RetroFlipDims<B, D> {
|
||||
fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
|
||||
let input = states.get_state::<B::FloatTensorPrimitive<D>>(&self.input_id);
|
||||
let out = B::float_flip(input, &self.axes);
|
||||
states.save(out_node, out)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Backward<B, D, 1> for FlipDim {
|
||||
type State = Vec<usize>;
|
||||
|
||||
fn backward(
|
||||
self,
|
||||
ops: Ops<Self::State, 1>,
|
||||
grads: &mut Gradients,
|
||||
_checkpointer: &mut Checkpointer,
|
||||
) {
|
||||
let axes = ops.state;
|
||||
|
||||
unary::<B, D, D, _>(ops.parents, ops.node, grads, |grad| {
|
||||
B::float_flip(grad, &axes)
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
match FlipDim
|
||||
.prepare::<C>([tensor.node.clone()], [tensor.graph.clone()])
|
||||
.memory_bound()
|
||||
.retro_forward(RetroFlipDims::<B, D>::new(
|
||||
tensor.node.id.clone(),
|
||||
axes.to_vec(),
|
||||
))
|
||||
.parents([&tensor])
|
||||
.stateful()
|
||||
{
|
||||
OpsKind::Tracked(prep) => {
|
||||
prep.finish(axes.to_vec(), B::float_flip(tensor.primitive, axes))
|
||||
}
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::float_flip(tensor.primitive, axes)),
|
||||
}
|
||||
}
|
||||
|
||||
fn float_reshape<const D1: usize, const D2: usize>(
|
||||
tensor: FloatTensor<Self, D1>,
|
||||
shape: Shape<D2>,
|
||||
|
|
|
@ -0,0 +1,28 @@
|
|||
#[burn_tensor_testgen::testgen(ad_flip)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::Data;
|
||||
|
||||
#[test]
|
||||
fn should_diff_flip() {
|
||||
let data_1: Data<f32, 3> = Data::from([[[1.0, 7.0], [2.0, 3.0]]]); // 1x2x2
|
||||
let data_2: Data<f32, 3> = Data::from([[[3.0, 2.0, 7.0], [3.0, 3.2, 1.0]]]); // 1x2x3
|
||||
|
||||
let device = Default::default();
|
||||
let tensor_1 = TestAutodiffTensor::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_2.clone().flip([1, 2]);
|
||||
let tensor_4 = tensor_1.clone().matmul(tensor_3);
|
||||
let grads = tensor_4.backward();
|
||||
|
||||
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||
let grad_2 = tensor_2.grad(&grads).unwrap();
|
||||
|
||||
assert_eq!(grad_1.to_data(), Data::from([[[7.2, 12.0], [7.2, 12.0]]])); // 1x2x2
|
||||
assert_eq!(
|
||||
grad_2.to_data(),
|
||||
Data::from([[[10.0, 10.0, 10.0], [3.0, 3.0, 3.0]]]) // 1x2x3
|
||||
);
|
||||
}
|
||||
}
|
|
@ -21,6 +21,7 @@ mod cross_entropy;
|
|||
mod div;
|
||||
mod erf;
|
||||
mod exp;
|
||||
mod flip;
|
||||
mod gather_scatter;
|
||||
mod gelu;
|
||||
mod gradients;
|
||||
|
@ -114,6 +115,7 @@ macro_rules! testgen_all {
|
|||
burn_autodiff::testgen_ad_sigmoid!();
|
||||
burn_autodiff::testgen_ad_transpose!();
|
||||
burn_autodiff::testgen_ad_permute!();
|
||||
burn_autodiff::testgen_ad_flip!();
|
||||
burn_autodiff::testgen_ad_nonzero!();
|
||||
burn_autodiff::testgen_ad_sign!();
|
||||
};
|
||||
|
|
|
@ -81,6 +81,7 @@ mod tests {
|
|||
burn_tensor::testgen_mul!();
|
||||
burn_tensor::testgen_neg!();
|
||||
burn_tensor::testgen_permute!();
|
||||
burn_tensor::testgen_flip!();
|
||||
burn_tensor::testgen_argwhere_nonzero!();
|
||||
burn_tensor::testgen_sign!();
|
||||
|
||||
|
|
|
@ -23,7 +23,6 @@ pub fn from_data<E: CandleElement, const D: usize>(
|
|||
) -> CandleTensor<E, D> {
|
||||
CandleTensor::from_data(data, *device)
|
||||
}
|
||||
|
||||
pub fn into_data<E: CandleElement, const D: usize>(tensor: CandleTensor<E, D>) -> Data<E, D> {
|
||||
Data::new(
|
||||
tensor.tensor.flatten_all().unwrap().to_vec1().unwrap(),
|
||||
|
@ -60,6 +59,26 @@ pub fn permute<E: CandleElement, const D: usize>(
|
|||
CandleTensor::new(tensor.tensor.permute(axes).unwrap())
|
||||
}
|
||||
|
||||
pub fn flip<E: CandleElement, const D: usize>(
|
||||
tensor: CandleTensor<E, D>,
|
||||
axes: &[usize],
|
||||
) -> CandleTensor<E, D> {
|
||||
// FIXME: Replace with an appropriate method when Candle provides one.
|
||||
let mut tensor = tensor.tensor;
|
||||
for &axis in axes {
|
||||
let indexes = candle_core::Tensor::arange_step(
|
||||
tensor.dim(axis).unwrap() as i64 - 1,
|
||||
-1,
|
||||
-1,
|
||||
tensor.device(),
|
||||
)
|
||||
.unwrap();
|
||||
tensor = tensor.index_select(&indexes, axis).unwrap();
|
||||
}
|
||||
|
||||
CandleTensor::new(tensor)
|
||||
}
|
||||
|
||||
pub fn reshape<E: CandleElement, const D1: usize, const D2: usize>(
|
||||
tensor: CandleTensor<E, D1>,
|
||||
shape: Shape<D2>,
|
||||
|
|
|
@ -8,8 +8,6 @@ use crate::{
|
|||
Candle, CandleTensor,
|
||||
};
|
||||
|
||||
use super::base::permute;
|
||||
|
||||
impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<F, I> {
|
||||
fn bool_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> BoolTensor<Self, D> {
|
||||
super::base::empty(shape, device)
|
||||
|
@ -133,6 +131,13 @@ impl<F: FloatCandleElement, I: IntCandleElement> BoolTensorOps<Self> for Candle<
|
|||
tensor: BoolTensor<Self, D>,
|
||||
axes: [usize; D],
|
||||
) -> BoolTensor<Self, D> {
|
||||
permute(tensor, axes)
|
||||
super::base::permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn bool_flip<const D: usize>(
|
||||
tensor: BoolTensor<Self, D>,
|
||||
axes: &[usize],
|
||||
) -> BoolTensor<Self, D> {
|
||||
super::base::flip(tensor, axes)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,8 +8,6 @@ use crate::{
|
|||
Candle, CandleTensor,
|
||||
};
|
||||
|
||||
use super::base::permute;
|
||||
|
||||
impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F, I> {
|
||||
fn int_empty<const D: usize>(shape: Shape<D>, device: &Device<Self>) -> IntTensor<Self, D> {
|
||||
super::base::empty(shape, device)
|
||||
|
@ -425,7 +423,11 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
|
|||
tensor: IntTensor<Self, D>,
|
||||
axes: [usize; D],
|
||||
) -> IntTensor<Self, D> {
|
||||
permute(tensor, axes)
|
||||
super::base::permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn int_flip<const D: usize>(tensor: IntTensor<Self, D>, axes: &[usize]) -> IntTensor<Self, D> {
|
||||
super::base::flip(tensor, axes)
|
||||
}
|
||||
|
||||
// TODO add sign operator once Candle supports it:
|
||||
|
|
|
@ -11,8 +11,6 @@ use crate::{
|
|||
Candle, CandleTensor,
|
||||
};
|
||||
|
||||
use super::base::permute;
|
||||
|
||||
impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle<F, I> {
|
||||
fn float_from_data<const D: usize>(
|
||||
data: Data<F, D>,
|
||||
|
@ -522,7 +520,14 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
|
|||
tensor: FloatTensor<Self, D>,
|
||||
axes: [usize; D],
|
||||
) -> FloatTensor<Self, D> {
|
||||
permute(tensor, axes)
|
||||
super::base::permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn float_flip<const D: usize>(
|
||||
tensor: FloatTensor<Self, D>,
|
||||
axes: &[usize],
|
||||
) -> FloatTensor<Self, D> {
|
||||
super::base::flip(tensor, axes)
|
||||
}
|
||||
|
||||
// TODO add sign operator once Candle supports it:
|
||||
|
|
|
@ -4,9 +4,9 @@ use crate::{
|
|||
ops::binary::binary_ops_shape,
|
||||
stream::{
|
||||
BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
|
||||
CatOperationDescription, Operation, OperationDescription, PermuteOperationDescription,
|
||||
ReshapeDescription, SliceAssignOperationDescription, SliceOperationDescription, StreamId,
|
||||
SwapDimsDescription, UnaryOperationDescription,
|
||||
CatOperationDescription, FlipOperationDescription, Operation, OperationDescription,
|
||||
PermuteOperationDescription, ReshapeDescription, SliceAssignOperationDescription,
|
||||
SliceOperationDescription, StreamId, SwapDimsDescription, UnaryOperationDescription,
|
||||
},
|
||||
Fusion, FusionBackend,
|
||||
};
|
||||
|
@ -466,4 +466,39 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
|
||||
out
|
||||
}
|
||||
|
||||
fn bool_flip<const D: usize>(
|
||||
tensor: BoolTensor<Self, D>,
|
||||
axes: &[usize],
|
||||
) -> BoolTensor<Self, D> {
|
||||
#[derive(new)]
|
||||
struct FlipOps<const D: usize> {
|
||||
desc: FlipOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for FlipOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let input = handles.get_bool_tensor::<D>(&self.desc.input);
|
||||
let output = B::bool_flip(input, self.desc.axes.as_slice());
|
||||
handles.register_bool_tensor(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let stream = tensor.stream;
|
||||
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
|
||||
|
||||
let desc = FlipOperationDescription {
|
||||
input: tensor.into_description(),
|
||||
out: out.to_description_out(),
|
||||
axes: axes.to_vec(),
|
||||
};
|
||||
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::BaseBool(BaseOperationDescription::Flip(desc.clone())),
|
||||
FlipOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
out
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,13 +6,13 @@ use crate::{
|
|||
scalar_float2int_ops, scalar_float_cmp_ops, scalar_float_ops,
|
||||
stream::{
|
||||
BaseOperationDescription, BinaryOperationDescription, CatOperationDescription,
|
||||
ClampOperationDescription, FloatOperationDescription, GatherOperationDescription,
|
||||
MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription,
|
||||
Operation, OperationDescription, PermuteOperationDescription, RandomOperationDescription,
|
||||
ReduceDimWithIndicesDescription, ReshapeDescription, ScalarOperationDescription,
|
||||
ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription,
|
||||
SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription,
|
||||
UnaryOperationDescription,
|
||||
ClampOperationDescription, FlipOperationDescription, FloatOperationDescription,
|
||||
GatherOperationDescription, MaskFillOperationDescription, MaskWhereOperationDescription,
|
||||
NumericOperationDescription, Operation, OperationDescription, PermuteOperationDescription,
|
||||
RandomOperationDescription, ReduceDimWithIndicesDescription, ReshapeDescription,
|
||||
ScalarOperationDescription, ScatterOperationDescription, SelectAssignOperationDescription,
|
||||
SelectOperationDescription, SliceAssignOperationDescription, SliceOperationDescription,
|
||||
StreamId, SwapDimsDescription, UnaryOperationDescription,
|
||||
},
|
||||
unary_float_ops, Fusion, FusionBackend, TensorDescription,
|
||||
};
|
||||
|
@ -1846,4 +1846,39 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
|
||||
out
|
||||
}
|
||||
|
||||
fn float_flip<const D: usize>(
|
||||
tensor: FloatTensor<Self, D>,
|
||||
axes: &[usize],
|
||||
) -> FloatTensor<Self, D> {
|
||||
#[derive(new)]
|
||||
struct FlipOps<const D: usize> {
|
||||
desc: FlipOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for FlipOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let input = handles.get_float_tensor::<D>(&self.desc.input);
|
||||
let output = B::float_flip(input, &self.desc.axes);
|
||||
handles.register_float_tensor(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let stream = tensor.stream;
|
||||
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
|
||||
|
||||
let desc = FlipOperationDescription {
|
||||
input: tensor.into_description(),
|
||||
axes: axes.to_vec(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::BaseInt(BaseOperationDescription::Flip(desc.clone())),
|
||||
FlipOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
out
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,9 +6,9 @@ use crate::{
|
|||
scalar_int_cmp_ops, scalar_int_ops,
|
||||
stream::{
|
||||
self, BaseOperationDescription, BinaryOperationDescription, CatOperationDescription,
|
||||
ClampOperationDescription, GatherOperationDescription, MaskFillOperationDescription,
|
||||
MaskWhereOperationDescription, NumericOperationDescription, Operation,
|
||||
OperationDescription, PermuteOperationDescription, RandomOperationDescription,
|
||||
ClampOperationDescription, FlipOperationDescription, GatherOperationDescription,
|
||||
MaskFillOperationDescription, MaskWhereOperationDescription, NumericOperationDescription,
|
||||
Operation, OperationDescription, PermuteOperationDescription, RandomOperationDescription,
|
||||
ReduceDimWithIndicesDescription, ReshapeDescription, ScalarOperationDescription,
|
||||
ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription,
|
||||
SliceAssignOperationDescription, SliceOperationDescription, StreamId, SwapDimsDescription,
|
||||
|
@ -1552,4 +1552,38 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
out
|
||||
}
|
||||
|
||||
fn int_flip<const D: usize>(tensor: IntTensor<Self, D>, axes: &[usize]) -> IntTensor<Self, D> {
|
||||
#[derive(new)]
|
||||
struct FlipDimsOps<const D: usize> {
|
||||
desc: FlipOperationDescription,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B> for FlipDimsOps<D> {
|
||||
fn execute(self: Box<Self>, handles: &mut crate::HandleContainer<B>) {
|
||||
let input = handles.get_int_tensor::<D>(&self.desc.input);
|
||||
let axes = &self.desc.axes;
|
||||
let output = B::int_flip(input, axes);
|
||||
handles.register_int_tensor(&self.desc.out.id, output);
|
||||
}
|
||||
}
|
||||
|
||||
let stream = tensor.stream;
|
||||
|
||||
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
|
||||
|
||||
let desc = FlipOperationDescription {
|
||||
input: tensor.into_description(),
|
||||
axes: axes.to_vec(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::BaseInt(BaseOperationDescription::Flip(desc.clone())),
|
||||
FlipDimsOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
out
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,16 +4,16 @@ use super::{
|
|||
AvgPool2dBackwardDescription, AvgPool2dDescription, BaseOperationDescription,
|
||||
BinaryOperationDescription, BoolOperationDescription, ClampOperationDescription,
|
||||
Conv1dDescription, Conv2dDescription, ConvTranspose1dDescription, ConvTranspose2dDescription,
|
||||
EmbeddingBackwardDescription, EmbeddingDescription, FloatOperationDescription,
|
||||
GatherOperationDescription, IntOperationDescription, InterpolateBackwardDescription,
|
||||
InterpolateDescription, MaskFillOperationDescription, MaskWhereOperationDescription,
|
||||
MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription, MaxPool1dWithIndicesDescription,
|
||||
MaxPool2dDescription, MaxPool2dWithIndicesBackwardDescription, MaxPool2dWithIndicesDescription,
|
||||
ModuleOperationDescription, NumericOperationDescription, OperationDescription,
|
||||
PermuteOperationDescription, RandomOperationDescription, ReduceDimWithIndicesDescription,
|
||||
ReshapeDescription, ScalarOperationDescription, ScatterOperationDescription,
|
||||
SelectAssignOperationDescription, SelectOperationDescription, SliceOperationDescription,
|
||||
SwapDimsDescription, UnaryOperationDescription,
|
||||
EmbeddingBackwardDescription, EmbeddingDescription, FlipOperationDescription,
|
||||
FloatOperationDescription, GatherOperationDescription, IntOperationDescription,
|
||||
InterpolateBackwardDescription, InterpolateDescription, MaskFillOperationDescription,
|
||||
MaskWhereOperationDescription, MaxPool1dDescription, MaxPool1dWithIndicesBackwardDescription,
|
||||
MaxPool1dWithIndicesDescription, MaxPool2dDescription, MaxPool2dWithIndicesBackwardDescription,
|
||||
MaxPool2dWithIndicesDescription, ModuleOperationDescription, NumericOperationDescription,
|
||||
OperationDescription, PermuteOperationDescription, RandomOperationDescription,
|
||||
ReduceDimWithIndicesDescription, ReshapeDescription, ScalarOperationDescription,
|
||||
ScatterOperationDescription, SelectAssignOperationDescription, SelectOperationDescription,
|
||||
SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription,
|
||||
};
|
||||
use crate::{FusionBackend, HandleContainer, TensorDescription, TensorId};
|
||||
use burn_tensor::{Element, ElementConversion};
|
||||
|
@ -798,6 +798,13 @@ impl BaseOperationDescription {
|
|||
axes: desc.axes.clone(),
|
||||
})
|
||||
}
|
||||
BaseOperationDescription::Flip(desc) => {
|
||||
BaseOperationDescription::Flip(FlipOperationDescription {
|
||||
input: desc.input.to_relative(converter),
|
||||
out: desc.out.to_relative(converter),
|
||||
axes: desc.axes.clone(),
|
||||
})
|
||||
}
|
||||
BaseOperationDescription::Slice(desc) => {
|
||||
BaseOperationDescription::Slice(SliceOperationDescription {
|
||||
tensor: desc.tensor.to_relative(converter),
|
||||
|
|
|
@ -155,6 +155,12 @@ pub enum BaseOperationDescription {
|
|||
/// Bool => [permute](burn_tensor::ops::BoolTensorOps::bool_permute).
|
||||
Permute(PermuteOperationDescription),
|
||||
|
||||
/// Operation corresponding to:
|
||||
/// Float => [flip](burn_tensor::ops::FloatTensorOps::float_flip).
|
||||
/// Int => [flip](burn_tensor::ops::IntTensorOps::int_flip).
|
||||
/// Bool => [flip](burn_tensor::ops::BoolTensorOps::bool_flip).
|
||||
Flip(FlipOperationDescription),
|
||||
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [slice](burn_tensor::ops::FloatTensorOps::float_slice).
|
||||
|
@ -456,6 +462,17 @@ pub struct PermuteOperationDescription {
|
|||
pub axes: Vec<usize>,
|
||||
}
|
||||
|
||||
/// Flip operation description.
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
|
||||
pub struct FlipOperationDescription {
|
||||
/// Input tensor description.
|
||||
pub input: TensorDescription,
|
||||
/// Output tensor description.
|
||||
pub out: TensorDescription,
|
||||
/// The dimensions to flip.
|
||||
pub axes: Vec<usize>,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct RandomOperationDescription {
|
||||
|
@ -1034,10 +1051,12 @@ impl BaseOperationDescription {
|
|||
BaseOperationDescription::SwapDims(desc) => {
|
||||
vec![&desc.input, &desc.out]
|
||||
}
|
||||
|
||||
BaseOperationDescription::Permute(desc) => {
|
||||
vec![&desc.input, &desc.out]
|
||||
}
|
||||
BaseOperationDescription::Flip(desc) => {
|
||||
vec![&desc.input, &desc.out]
|
||||
}
|
||||
BaseOperationDescription::Slice(desc) => {
|
||||
vec![&desc.tensor, &desc.out]
|
||||
}
|
||||
|
|
|
@ -0,0 +1,157 @@
|
|||
use crate::{
|
||||
codegen::{
|
||||
dialect::gpu::{gpu, Elem, Scope, Variable, Visibility},
|
||||
execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
|
||||
InputInfo, OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
element::JitElement,
|
||||
kernel::{DynamicKernelSource, SourceTemplate},
|
||||
ops::numeric::empty_device,
|
||||
tensor::JitTensor,
|
||||
Runtime, RuntimeInt,
|
||||
};
|
||||
use burn_tensor::ElementConversion;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
#[derive(new)]
|
||||
struct FlipEagerKernel<R: Runtime, E: JitElement> {
|
||||
rank: usize,
|
||||
_runtime: PhantomData<R>,
|
||||
_elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
pub struct FlipComputeShader {
|
||||
input: Variable,
|
||||
output: Variable,
|
||||
rank: usize,
|
||||
}
|
||||
|
||||
impl FlipComputeShader {
|
||||
pub fn expand(self, scope: &mut Scope) {
|
||||
let input = self.input;
|
||||
let output = self.output;
|
||||
let id = Variable::Id;
|
||||
|
||||
let offset_input = scope.zero(Elem::UInt);
|
||||
let offset_local = scope.create_local(Elem::UInt);
|
||||
|
||||
let stride = scope.create_local(Elem::UInt);
|
||||
let shape = scope.create_local(Elem::UInt);
|
||||
let flip = scope.create_local(Elem::UInt);
|
||||
let flip_bool = scope.create_local(Elem::Bool);
|
||||
|
||||
for i in 0..self.rank {
|
||||
gpu!(scope, stride = stride(input, i));
|
||||
gpu!(scope, shape = shape(output, i));
|
||||
gpu!(
|
||||
scope,
|
||||
flip = cast(Variable::GlobalScalar(i as u16, Elem::UInt))
|
||||
);
|
||||
gpu!(scope, flip_bool = flip == 1u32);
|
||||
|
||||
gpu!(scope, offset_local = id / stride);
|
||||
gpu!(scope, offset_local = offset_local % shape);
|
||||
|
||||
gpu!(scope, if(flip_bool).then(|scope| {
|
||||
gpu!(scope, offset_local = shape - offset_local);
|
||||
gpu!(scope, offset_local = offset_local - 1u32);
|
||||
}));
|
||||
gpu!(scope, offset_local = offset_local * stride);
|
||||
|
||||
gpu!(scope, offset_input += offset_local);
|
||||
}
|
||||
|
||||
let result = scope.create_local(input.item());
|
||||
gpu!(scope, result = input[offset_input]);
|
||||
gpu!(scope, output[id] = result);
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime, E: JitElement> DynamicKernelSource for FlipEagerKernel<R, E> {
|
||||
fn source(&self) -> crate::kernel::SourceTemplate {
|
||||
let mut scope = Scope::root();
|
||||
let item = E::gpu_elem().into();
|
||||
|
||||
let input = Variable::GlobalInputArray(0, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
FlipComputeShader {
|
||||
input,
|
||||
output,
|
||||
rank: self.rank,
|
||||
}
|
||||
.expand(&mut scope);
|
||||
|
||||
let input = InputInfo::Array {
|
||||
item,
|
||||
visibility: Visibility::Read,
|
||||
};
|
||||
let flip_dims = InputInfo::Scalar {
|
||||
elem: Elem::UInt,
|
||||
size: self.rank,
|
||||
};
|
||||
let output = OutputInfo::Array { item };
|
||||
|
||||
let info = CompilationInfo {
|
||||
inputs: vec![input, flip_dims],
|
||||
outputs: vec![output],
|
||||
scope,
|
||||
};
|
||||
|
||||
let settings = CompilationSettings::default();
|
||||
let shader = Compilation::new(info).compile(settings);
|
||||
let shader = <R::Compiler as Compiler>::compile(shader);
|
||||
SourceTemplate::new(shader.to_string())
|
||||
}
|
||||
|
||||
fn id(&self) -> String {
|
||||
format!("{:?}-rank={:?}", core::any::TypeId::of::<Self>(), self.rank)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn flip<R: Runtime, E: JitElement, const D: usize>(
|
||||
tensor: JitTensor<R, E, D>,
|
||||
indices: &[usize],
|
||||
) -> JitTensor<R, E, D> {
|
||||
let output = empty_device(
|
||||
tensor.client.clone(),
|
||||
tensor.device.clone(),
|
||||
tensor.shape.clone(),
|
||||
);
|
||||
flip_on_output(tensor, output, indices)
|
||||
}
|
||||
|
||||
pub(crate) fn flip_on_output<R: Runtime, E: JitElement, const D: usize>(
|
||||
tensor: JitTensor<R, E, D>,
|
||||
output: JitTensor<R, E, D>,
|
||||
indices: &[usize],
|
||||
) -> JitTensor<R, E, D> {
|
||||
let mut scalars = Vec::with_capacity(D);
|
||||
|
||||
for i in 0..D {
|
||||
scalars.push((indices.contains(&i) as u32).elem());
|
||||
}
|
||||
|
||||
let kernel = FlipEagerKernel::new(D);
|
||||
|
||||
execute_dynamic::<R, FlipEagerKernel<R, E>, RuntimeInt<R>>(
|
||||
&[EagerHandle::new(
|
||||
&tensor.handle,
|
||||
&tensor.strides,
|
||||
&tensor.shape.dims,
|
||||
)],
|
||||
&[EagerHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
Some(&scalars),
|
||||
kernel,
|
||||
WorkgroupLaunch::Output { pos: 0 },
|
||||
tensor.client,
|
||||
);
|
||||
|
||||
output
|
||||
}
|
|
@ -1,3 +1,4 @@
|
|||
mod flip;
|
||||
mod gather;
|
||||
mod repeat;
|
||||
mod scatter;
|
||||
|
@ -6,6 +7,7 @@ mod select_assign;
|
|||
mod slice;
|
||||
mod slice_assign;
|
||||
|
||||
pub use flip::*;
|
||||
pub use repeat::*;
|
||||
pub use select::*;
|
||||
pub use select_assign::*;
|
||||
|
|
|
@ -113,4 +113,11 @@ impl<R: Runtime> BoolTensorOps<Self> for JitBackend<R> {
|
|||
) -> BoolTensor<Self, D> {
|
||||
permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn bool_flip<const D: usize>(
|
||||
tensor: BoolTensor<Self, D>,
|
||||
axes: &[usize],
|
||||
) -> BoolTensor<Self, D> {
|
||||
kernel::flip(tensor, axes)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -520,4 +520,11 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
|
|||
) -> FloatTensor<Self, D> {
|
||||
permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn float_flip<const D: usize>(
|
||||
tensor: FloatTensor<Self, D>,
|
||||
axes: &[usize],
|
||||
) -> FloatTensor<Self, D> {
|
||||
kernel::flip(tensor, axes)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -339,4 +339,8 @@ impl<R: Runtime> IntTensorOps<Self> for JitBackend<R> {
|
|||
) -> IntTensor<Self, D> {
|
||||
permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn int_flip<const D: usize>(tensor: IntTensor<Self, D>, axes: &[usize]) -> IntTensor<Self, D> {
|
||||
kernel::flip(tensor, axes)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ use burn_tensor::ElementConversion;
|
|||
use core::{marker::PhantomData, ops::Range};
|
||||
use ndarray::s;
|
||||
use ndarray::Array2;
|
||||
use ndarray::SliceInfo;
|
||||
use ndarray::Zip;
|
||||
use num_traits::Signed;
|
||||
|
||||
|
@ -111,6 +112,34 @@ where
|
|||
|
||||
NdArrayTensor::new(array)
|
||||
}
|
||||
|
||||
pub fn flip<const D: usize>(
|
||||
tensor: NdArrayTensor<E, D>,
|
||||
axes: &[usize],
|
||||
) -> NdArrayTensor<E, D> {
|
||||
let slice_items: Vec<_> = (0..D)
|
||||
.map(|i| {
|
||||
if axes.contains(&i) {
|
||||
SliceInfoElem::Slice {
|
||||
start: 0,
|
||||
end: None,
|
||||
step: -1,
|
||||
}
|
||||
} else {
|
||||
SliceInfoElem::Slice {
|
||||
start: 0,
|
||||
end: None,
|
||||
step: 1,
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let slice_info =
|
||||
SliceInfo::<Vec<SliceInfoElem>, IxDyn, IxDyn>::try_from(slice_items).unwrap();
|
||||
let array = tensor.array.slice(slice_info).into_owned().into_shared();
|
||||
|
||||
NdArrayTensor::new(array)
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> NdArrayMathOps<E>
|
||||
|
|
|
@ -137,4 +137,11 @@ impl<E: FloatNdArrayElement> BoolTensorOps<Self> for NdArray<E> {
|
|||
let array = tensor.array.permuted_axes(axes.into_dimension());
|
||||
NdArrayTensor { array }
|
||||
}
|
||||
|
||||
fn bool_flip<const D: usize>(
|
||||
tensor: burn_tensor::ops::BoolTensor<Self, D>,
|
||||
axes: &[usize],
|
||||
) -> burn_tensor::ops::BoolTensor<Self, D> {
|
||||
NdArrayOps::flip(tensor, axes)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -446,6 +446,13 @@ impl<E: FloatNdArrayElement> IntTensorOps<Self> for NdArray<E> {
|
|||
NdArrayTensor { array }
|
||||
}
|
||||
|
||||
fn int_flip<const D: usize>(
|
||||
tensor: burn_tensor::ops::IntTensor<Self, D>,
|
||||
axes: &[usize],
|
||||
) -> burn_tensor::ops::IntTensor<Self, D> {
|
||||
NdArrayOps::flip(tensor, axes)
|
||||
}
|
||||
|
||||
fn int_sign<const D: usize>(tensor: NdArrayTensor<i64, D>) -> NdArrayTensor<i64, D> {
|
||||
NdArrayMathOps::sign_op(tensor)
|
||||
}
|
||||
|
|
|
@ -494,6 +494,13 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
|
|||
NdArrayTensor { array }
|
||||
}
|
||||
|
||||
fn float_flip<const D: usize>(
|
||||
tensor: burn_tensor::ops::FloatTensor<Self, D>,
|
||||
axes: &[usize],
|
||||
) -> burn_tensor::ops::FloatTensor<Self, D> {
|
||||
NdArrayOps::flip(tensor, axes)
|
||||
}
|
||||
|
||||
fn float_sign<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
|
||||
NdArrayMathOps::sign_op(tensor)
|
||||
}
|
||||
|
|
|
@ -446,6 +446,12 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
|
|||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
pub fn flip<const D: usize>(tensor: TchTensor<E, D>, axes: &[usize]) -> TchTensor<E, D> {
|
||||
let dims = axes.iter().map(|x| *x as i64).collect::<Vec<_>>();
|
||||
let tensor = tensor.tensor.flip(dims);
|
||||
TchTensor::new(tensor)
|
||||
}
|
||||
|
||||
pub fn narrow<const D: usize>(
|
||||
tensor: TchTensor<E, D>,
|
||||
dim: usize,
|
||||
|
|
|
@ -139,6 +139,10 @@ impl<E: TchElement> BoolTensorOps<Self> for LibTorch<E> {
|
|||
TchOps::permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn bool_flip<const D: usize>(tensor: TchTensor<bool, D>, axes: &[usize]) -> TchTensor<bool, D> {
|
||||
TchOps::flip(tensor, axes)
|
||||
}
|
||||
|
||||
fn bool_argwhere<const D: usize>(
|
||||
tensor: <LibTorch<E> as Backend>::BoolTensorPrimitive<D>,
|
||||
) -> TchTensor<i64, 2> {
|
||||
|
|
|
@ -474,6 +474,13 @@ impl<E: TchElement> IntTensorOps<Self> for LibTorch<E> {
|
|||
TchOps::permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn int_flip<const D: usize>(
|
||||
tensor: burn_tensor::ops::IntTensor<Self, D>,
|
||||
axes: &[usize],
|
||||
) -> burn_tensor::ops::IntTensor<Self, D> {
|
||||
TchOps::flip(tensor, axes)
|
||||
}
|
||||
|
||||
fn int_sign<const D: usize>(
|
||||
tensor: <LibTorch<E> as Backend>::IntTensorPrimitive<D>,
|
||||
) -> <LibTorch<E> as Backend>::IntTensorPrimitive<D> {
|
||||
|
|
|
@ -483,6 +483,13 @@ impl<E: TchElement> FloatTensorOps<Self> for LibTorch<E> {
|
|||
TchOps::permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn float_flip<const D: usize>(
|
||||
tensor: burn_tensor::ops::FloatTensor<Self, D>,
|
||||
axes: &[usize],
|
||||
) -> burn_tensor::ops::FloatTensor<Self, D> {
|
||||
TchOps::flip(tensor, axes)
|
||||
}
|
||||
|
||||
fn float_sign<const D: usize>(
|
||||
tensor: <LibTorch<E> as Backend>::FloatTensorPrimitive<D>,
|
||||
) -> <LibTorch<E> as Backend>::FloatTensorPrimitive<D> {
|
||||
|
|
|
@ -170,6 +170,33 @@ where
|
|||
Tensor::new(K::permute(self.primitive, transformed_axes))
|
||||
}
|
||||
|
||||
/// Reverse the order of elements in the tensor along the given dimensions.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `axes` - The dimensions to reverse. The values must be unique and in the range of the number of dimensions.
|
||||
/// The values can be negative, in which case they are used as an offset from the end.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the axes flipped.
|
||||
pub fn flip<const N: usize>(self, axes: [isize; N]) -> Tensor<B, D, K> {
|
||||
// Convert the axes to usize and handle negative values without using vector
|
||||
let mut transformed_axes: [usize; N] = [0; N];
|
||||
for (i, &x) in axes.iter().enumerate() {
|
||||
transformed_axes[i] = if x < 0 {
|
||||
(D as isize + x) as usize
|
||||
} else {
|
||||
x as usize
|
||||
};
|
||||
}
|
||||
|
||||
// Check if the axes are valid
|
||||
check!(TensorCheck::flip(D, &transformed_axes));
|
||||
|
||||
Tensor::new(K::flip(self.primitive, &transformed_axes))
|
||||
}
|
||||
|
||||
/// Flatten the tensor along a given range of dimensions.
|
||||
///
|
||||
/// This function collapses the specified range of dimensions into a single dimension,
|
||||
|
@ -1130,6 +1157,18 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
|
|||
/// The tensor with the dimensions permuted.
|
||||
fn permute<const D: usize>(tensor: Self::Primitive<D>, axes: [usize; D]) -> Self::Primitive<D>;
|
||||
|
||||
/// Flips the tensor along the given axes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to flip.
|
||||
/// * `axes` - The axes to flip the tensor along.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The tensor with the axes flipped.
|
||||
fn flip<const D: usize>(tensor: Self::Primitive<D>, axes: &[usize]) -> Self::Primitive<D>;
|
||||
|
||||
/// Select tensor elements corresponding for the given ranges.
|
||||
///
|
||||
/// # Arguments
|
||||
|
@ -1558,6 +1597,10 @@ impl<B: Backend> BasicOps<B> for Float {
|
|||
fn permute<const D: usize>(tensor: Self::Primitive<D>, axes: [usize; D]) -> Self::Primitive<D> {
|
||||
B::float_permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn flip<const D: usize>(tensor: Self::Primitive<D>, axes: &[usize]) -> Self::Primitive<D> {
|
||||
B::float_flip(tensor, axes)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> BasicOps<B> for Int {
|
||||
|
@ -1672,6 +1715,10 @@ impl<B: Backend> BasicOps<B> for Int {
|
|||
fn permute<const D: usize>(tensor: Self::Primitive<D>, axes: [usize; D]) -> Self::Primitive<D> {
|
||||
B::int_permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn flip<const D: usize>(tensor: Self::Primitive<D>, axes: &[usize]) -> Self::Primitive<D> {
|
||||
B::int_flip(tensor, axes)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> BasicOps<B> for Bool {
|
||||
|
@ -1786,6 +1833,10 @@ impl<B: Backend> BasicOps<B> for Bool {
|
|||
fn permute<const D: usize>(tensor: Self::Primitive<D>, axes: [usize; D]) -> Self::Primitive<D> {
|
||||
B::bool_permute(tensor, axes)
|
||||
}
|
||||
|
||||
fn flip<const D: usize>(tensor: Self::Primitive<D>, axes: &[usize]) -> Self::Primitive<D> {
|
||||
B::bool_flip(tensor, axes)
|
||||
}
|
||||
}
|
||||
|
||||
/// Trait used for reshape arguments.
|
||||
|
|
|
@ -349,6 +349,34 @@ impl TensorCheck {
|
|||
check
|
||||
}
|
||||
|
||||
pub(crate) fn flip(rank: usize, axes: &[usize]) -> Self {
|
||||
let check = Self::Ok;
|
||||
|
||||
// Check if the axes are within the tensor dimensions
|
||||
if let Some(axis) = axes.iter().find(|&x| *x >= rank) {
|
||||
return check.register(
|
||||
"flip",
|
||||
TensorError::new("The axes must be smaller than the tensor dimension.").details(
|
||||
format!("The '{axis}' axis is greater than {rank} dimensions."),
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
// Check if the axes are unique
|
||||
let mut dedup = axes.to_vec();
|
||||
dedup.sort_unstable();
|
||||
dedup.dedup();
|
||||
if dedup.len() != axes.len() {
|
||||
return check.register(
|
||||
"flip",
|
||||
TensorError::new("The axes must be unique.")
|
||||
.details(format!("The axes '{axes:?}' are not unique.")),
|
||||
);
|
||||
}
|
||||
|
||||
check
|
||||
}
|
||||
|
||||
pub(crate) fn matmul<B: Backend, const D: usize>(
|
||||
lhs: &Tensor<B, D>,
|
||||
rhs: &Tensor<B, D>,
|
||||
|
|
|
@ -300,6 +300,16 @@ pub trait BoolTensorOps<B: Backend> {
|
|||
fn bool_permute<const D: usize>(tensor: BoolTensor<B, D>, axes: [usize; D])
|
||||
-> BoolTensor<B, D>;
|
||||
|
||||
/// Reverse the order of elements in a tensor along the given axes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to reverse.
|
||||
/// * `axes` - The axes to reverse.
|
||||
///
|
||||
/// The tensor with the elements reversed.
|
||||
fn bool_flip<const D: usize>(tensor: BoolTensor<B, D>, axes: &[usize]) -> BoolTensor<B, D>;
|
||||
|
||||
/// Returns a new tensor with the given dimension narrowed to the given range.
|
||||
///
|
||||
/// # Arguments
|
||||
|
|
|
@ -996,6 +996,16 @@ pub trait IntTensorOps<B: Backend> {
|
|||
/// The tensor with the dimensions permuted.
|
||||
fn int_permute<const D: usize>(tensor: IntTensor<B, D>, axes: [usize; D]) -> IntTensor<B, D>;
|
||||
|
||||
/// Reverse the order of elements in a tensor along the given axes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to reverse.
|
||||
/// * `axes` - The axes to reverse.
|
||||
///
|
||||
/// The tensor with the elements reversed.
|
||||
fn int_flip<const D: usize>(tensor: IntTensor<B, D>, axes: &[usize]) -> IntTensor<B, D>;
|
||||
|
||||
/// Returns a new tensor with the given dimension narrowed to the given range.
|
||||
///
|
||||
/// # Arguments
|
||||
|
|
|
@ -449,6 +449,16 @@ pub trait FloatTensorOps<B: Backend> {
|
|||
axes: [usize; D],
|
||||
) -> FloatTensor<B, D>;
|
||||
|
||||
/// Reverse the order of elements in a tensor along the given axes.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `tensor` - The tensor to reverse.
|
||||
/// * `axes` - The axes to reverse.
|
||||
///
|
||||
/// The tensor with the elements reversed.
|
||||
fn float_flip<const D: usize>(tensor: FloatTensor<B, D>, axes: &[usize]) -> FloatTensor<B, D>;
|
||||
|
||||
/// Reshapes a tensor.
|
||||
///
|
||||
/// # Arguments
|
||||
|
|
|
@ -86,6 +86,7 @@ macro_rules! testgen_all {
|
|||
burn_tensor::testgen_any!();
|
||||
burn_tensor::testgen_all_op!();
|
||||
burn_tensor::testgen_permute!();
|
||||
burn_tensor::testgen_flip!();
|
||||
burn_tensor::testgen_bool!();
|
||||
burn_tensor::testgen_argwhere_nonzero!();
|
||||
burn_tensor::testgen_sign!();
|
||||
|
|
|
@ -0,0 +1,105 @@
|
|||
#[burn_tensor_testgen::testgen(flip)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::backend::Backend;
|
||||
use burn_tensor::{Data, Device, Int, Shape, Tensor};
|
||||
|
||||
#[test]
|
||||
fn normal_int() {
|
||||
let device = Default::default();
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device).reshape([2, 3, 4]);
|
||||
|
||||
let flipped = tensor.clone().flip([0, 2]);
|
||||
|
||||
// from pytorch:
|
||||
// import torch; torch.arange(0, 24).reshape(2, 3, 4).flip((0, 2))
|
||||
let data_expected = Data::from([
|
||||
[[15, 14, 13, 12], [19, 18, 17, 16], [23, 22, 21, 20]],
|
||||
[[3, 2, 1, 0], [7, 6, 5, 4], [11, 10, 9, 8]],
|
||||
]);
|
||||
|
||||
assert_eq!(data_expected, flipped.into_data());
|
||||
|
||||
// Test with no flip
|
||||
let flipped = tensor.clone().flip([]);
|
||||
assert_eq!(tensor.into_data(), flipped.into_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normal_float() {
|
||||
let device = Default::default();
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device)
|
||||
.reshape([2, 3, 4])
|
||||
.float();
|
||||
|
||||
let flipped = tensor.clone().flip([0, 2]);
|
||||
|
||||
// from pytorch:
|
||||
// import torch; torch.arange(0, 24).reshape(2, 3, 4).flip((0, 2)).float()
|
||||
let data_expected = Data::from([
|
||||
[
|
||||
[15., 14., 13., 12.],
|
||||
[19., 18., 17., 16.],
|
||||
[23., 22., 21., 20.],
|
||||
],
|
||||
[[3., 2., 1., 0.], [7., 6., 5., 4.], [11., 10., 9., 8.]],
|
||||
]);
|
||||
|
||||
assert_eq!(data_expected, flipped.into_data());
|
||||
|
||||
// Test with no flip
|
||||
let flipped = tensor.clone().flip([]);
|
||||
assert_eq!(tensor.into_data(), flipped.into_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn normal_bool() {
|
||||
let device = Default::default();
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device)
|
||||
.reshape([2, 3, 4])
|
||||
.greater_elem(10);
|
||||
|
||||
let flipped = tensor.clone().flip([0, 2]);
|
||||
|
||||
// from pytorch:
|
||||
// import torch; torch.arange(0, 24).reshape(2, 3, 4).flip((0, 2)).gt(10)
|
||||
let data_expected = Data::from([
|
||||
[
|
||||
[true, true, true, true],
|
||||
[true, true, true, true],
|
||||
[true, true, true, true],
|
||||
],
|
||||
[
|
||||
[false, false, false, false],
|
||||
[false, false, false, false],
|
||||
[true, false, false, false],
|
||||
],
|
||||
]);
|
||||
|
||||
assert_eq!(data_expected, flipped.into_data());
|
||||
|
||||
// Test with no flip
|
||||
let flipped = tensor.clone().flip([]);
|
||||
assert_eq!(tensor.into_data(), flipped.into_data());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn edge_duplicated_axes() {
|
||||
let device = Default::default();
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device).reshape([2, 3, 4]);
|
||||
|
||||
// Test with a duplicated axis
|
||||
let _ = tensor.clone().flip([0, 0, 1]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn edge_out_of_bound_axis() {
|
||||
let device = Default::default();
|
||||
let tensor = Tensor::<TestBackend, 1, Int>::arange(0..24, &device).reshape([2, 3, 4]);
|
||||
|
||||
// Test with an out of bound axis
|
||||
let _ = tensor.clone().flip([3, 0, 1]);
|
||||
}
|
||||
}
|
|
@ -19,6 +19,7 @@ mod div;
|
|||
mod erf;
|
||||
mod exp;
|
||||
mod flatten;
|
||||
mod flip;
|
||||
mod full;
|
||||
mod gather_scatter;
|
||||
mod init;
|
||||
|
|
Loading…
Reference in New Issue