Repeat operation (#2090)

* renaming repeat to repeat_dim

* implementing repeat function

* renaming repeat files to repeat_dim

* renaming part 2

* renaming part 3

* renaming part 4

* renaming part 5

* adding test file

* adding unit test

* adding rust book documentation

* adding function args doc

* fixing tests

* changing repeat api to match pytorch equivalent

* fixing clippy error
mepatrick73 2024-08-02 20:33:47 -04:00 committed by GitHub
parent bb13729b20
commit f7639bd35a
40 changed files with 478 additions and 174 deletions
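In short: the former single-dimension `repeat(dim, times)` is renamed to `repeat_dim(dim, times)`, and `repeat` now takes a slice of per-dimension counts, mirroring PyTorch's `Tensor.repeat`. A minimal before/after sketch (generic over the backend; shapes in the comments are illustrative):

```rust
use burn::tensor::{backend::Backend, Tensor};

// Hedged sketch of the user-facing API change introduced by this commit.
fn migrate<B: Backend>(x: Tensor<B, 2>) {
    // x: [1, 3]

    // Before: repeat along a single dimension.
    // let y = x.clone().repeat(0, 4);          // old name, now removed

    // After: the same single-dimension operation is `repeat_dim`.
    let y = x.clone().repeat_dim(0, 4);         // [4, 3]

    // New: repeat along several dimensions at once, PyTorch-style.
    let z = x.repeat(&[4, 2]);                  // [4, 6]

    let _ = (y, z);
}
```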

View File

@ -131,40 +131,41 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t
Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.

| Burn                                  | PyTorch Equivalent                                                         |
| ------------------------------------- | -------------------------------------------------------------------------- |
| `Tensor::cat(tensors, dim)`           | `torch.cat(tensors, dim)`                                                   |
| `Tensor::empty(shape, device)`        | `torch.empty(shape, device=device)`                                         |
| `Tensor::from_primitive(primitive)`   | N/A                                                                         |
| `Tensor::stack(tensors, dim)`         | `torch.stack(tensors, dim)`                                                 |
| `tensor.all()`                        | `tensor.all()`                                                              |
| `tensor.all_dim(dim)`                 | `tensor.all(dim)`                                                           |
| `tensor.any()`                        | `tensor.any()`                                                              |
| `tensor.any_dim(dim)`                 | `tensor.any(dim)`                                                           |
| `tensor.chunk(num_chunks, dim)`       | `tensor.chunk(num_chunks, dim)`                                             |
| `tensor.device()`                     | `tensor.device`                                                             |
| `tensor.dims()`                       | `tensor.size()`                                                             |
| `tensor.equal(other)`                 | `x == y`                                                                    |
| `tensor.expand(shape)`                | `tensor.expand(shape)`                                                      |
| `tensor.flatten(start_dim, end_dim)`  | `tensor.flatten(start_dim, end_dim)`                                        |
| `tensor.flip(axes)`                   | `tensor.flip(axes)`                                                         |
| `tensor.into_data()`                  | N/A                                                                         |
| `tensor.into_primitive()`             | N/A                                                                         |
| `tensor.into_scalar()`                | `tensor.item()`                                                             |
| `tensor.narrow(dim, start, length)`   | `tensor.narrow(dim, start, length)`                                         |
| `tensor.not_equal(other)`             | `x != y`                                                                    |
| `tensor.permute(axes)`                | `tensor.permute(axes)`                                                      |
| `tensor.movedim(src, dst)`            | `tensor.movedim(src, dst)`                                                  |
| `tensor.repeat_dim(dim, times)`       | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])`   |
| `tensor.repeat(sizes)`                | `tensor.repeat(sizes)`                                                      |
| `tensor.reshape(shape)`               | `tensor.view(shape)`                                                        |
| `tensor.shape()`                      | `tensor.shape`                                                              |
| `tensor.slice(ranges)`                | `tensor[(*ranges,)]`                                                        |
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values`                                               |
| `tensor.squeeze(dim)`                 | `tensor.squeeze(dim)`                                                       |
| `tensor.to_data()`                    | N/A                                                                         |
| `tensor.to_device(device)`            | `tensor.to(device)`                                                         |
| `tensor.unsqueeze()`                  | `tensor.unsqueeze(0)`                                                       |
| `tensor.unsqueeze_dim(dim)`           | `tensor.unsqueeze(dim)`                                                     |
### Numeric Operations

View File

@ -132,11 +132,11 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
B::bool_expand(tensor, shape) B::bool_expand(tensor, shape)
} }
fn bool_repeat<const D: usize>( fn bool_repeat_dim<const D: usize>(
tensor: BoolTensor<B, D>, tensor: BoolTensor<B, D>,
dim: usize, dim: usize,
times: usize, times: usize,
) -> BoolTensor<B, D> { ) -> BoolTensor<B, D> {
B::bool_repeat(tensor, dim, times) B::bool_repeat_dim(tensor, dim, times)
} }
} }

View File

@ -162,12 +162,12 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_mean_dim(tensor, dim) B::int_mean_dim(tensor, dim)
} }
fn int_repeat<const D: usize>( fn int_repeat_dim<const D: usize>(
tensor: IntTensor<B, D>, tensor: IntTensor<B, D>,
dim: usize, dim: usize,
times: usize, times: usize,
) -> IntTensor<B, D> { ) -> IntTensor<B, D> {
B::int_repeat(tensor, dim, times) B::int_repeat_dim(tensor, dim, times)
} }
fn int_greater<const D: usize>(lhs: IntTensor<B, D>, rhs: IntTensor<B, D>) -> BoolTensor<B, D> { fn int_greater<const D: usize>(lhs: IntTensor<B, D>, rhs: IntTensor<B, D>) -> BoolTensor<B, D> {

View File

@ -2418,7 +2418,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
B::float_argsort(tensor.primitive, dim, descending) B::float_argsort(tensor.primitive, dim, descending)
} }
fn float_repeat<const D: usize>( fn float_repeat_dim<const D: usize>(
tensor: FloatTensor<Self, D>, tensor: FloatTensor<Self, D>,
dim: usize, dim: usize,
times: usize, times: usize,
@ -2437,7 +2437,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
impl<B: Backend, const D: usize> RetroForward for RetroRepeat<B, D> { impl<B: Backend, const D: usize> RetroForward for RetroRepeat<B, D> {
fn forward(&self, states: &mut BackwardStates, out_node: NodeID) { fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
let tensor = states.get_state::<B::FloatTensorPrimitive<D>>(&self.tensor_id); let tensor = states.get_state::<B::FloatTensorPrimitive<D>>(&self.tensor_id);
let out = B::float_repeat(tensor, self.dim, self.times); let out = B::float_repeat_dim(tensor, self.dim, self.times);
states.save(out_node, out) states.save(out_node, out)
} }
} }
@ -2467,9 +2467,11 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
.stateful() .stateful()
{ {
OpsKind::Tracked(prep) => { OpsKind::Tracked(prep) => {
prep.finish(dim, B::float_repeat(tensor.primitive, dim, times)) prep.finish(dim, B::float_repeat_dim(tensor.primitive, dim, times))
}
OpsKind::UnTracked(prep) => {
prep.finish(B::float_repeat_dim(tensor.primitive, dim, times))
} }
OpsKind::UnTracked(prep) => prep.finish(B::float_repeat(tensor.primitive, dim, times)),
} }
} }
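Since `repeat_dim` copies the input along one dimension, its backward pass accumulates the gradient over the repeated copies. A hedged end-to-end check (the `Autodiff<NdArray>` backend alias and enabled features are assumptions for illustration):

```rust
use burn::backend::{Autodiff, NdArray};
use burn::tensor::Tensor;

fn main() {
    type B = Autodiff<NdArray>;
    let device = Default::default();

    let x = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0]], &device).require_grad();
    let y = x.clone().repeat_dim(0, 4); // [4, 3]: four copies of the row
    let grads = y.sum().backward();

    // Each input element appears in 4 copies, so its gradient should be 4.0.
    let grad_x = x.grad(&grads).unwrap();
    println!("{grad_x}");
}
```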

View File

@ -47,7 +47,7 @@ mod permute;
mod pow; mod pow;
mod recip; mod recip;
mod relu; mod relu;
mod repeat; mod repeat_dim;
mod reshape; mod reshape;
mod select; mod select;
mod sigmoid; mod sigmoid;
@ -133,6 +133,6 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_sign!(); burn_autodiff::testgen_ad_sign!();
burn_autodiff::testgen_ad_expand!(); burn_autodiff::testgen_ad_expand!();
burn_autodiff::testgen_ad_sort!(); burn_autodiff::testgen_ad_sort!();
burn_autodiff::testgen_ad_repeat!(); burn_autodiff::testgen_ad_repeat_dim!();
}; };
} }

View File

@ -1,4 +1,4 @@
#[burn_tensor_testgen::testgen(ad_repeat)] #[burn_tensor_testgen::testgen(ad_repeat_dim)]
mod tests { mod tests {
use super::*; use super::*;
use burn_tensor::{activation, TensorData}; use burn_tensor::{activation, TensorData};
@ -12,7 +12,7 @@ mod tests {
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad(); let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad(); let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_2.clone().repeat(1, 3); let tensor_3 = tensor_2.clone().repeat_dim(1, 3);
let tensor_3 = tensor_1.matmul(tensor_3); let tensor_3 = tensor_1.matmul(tensor_3);
let grads = tensor_3.backward(); let grads = tensor_3.backward();

View File

@ -94,7 +94,7 @@ mod tests {
// burn_tensor::testgen_powf!(); // burn_tensor::testgen_powf!();
burn_tensor::testgen_random!(); burn_tensor::testgen_random!();
burn_tensor::testgen_repeat!(); burn_tensor::testgen_repeat_dim!();
burn_tensor::testgen_reshape!(); burn_tensor::testgen_reshape!();
burn_tensor::testgen_select!(); burn_tensor::testgen_select!();
burn_tensor::testgen_sin!(); burn_tensor::testgen_sin!();

View File

@ -19,7 +19,7 @@ pub fn generate_autoregressive_mask<B: Backend>(
mask = mask.slice_assign([0..1, i..i + 1, i + 1..seq_length], values); mask = mask.slice_assign([0..1, i..i + 1, i + 1..seq_length], values);
} }
mask = mask.repeat(0, batch_size); mask = mask.repeat_dim(0, batch_size);
mask.equal_elem(1_i64.elem::<i64>()) mask.equal_elem(1_i64.elem::<i64>())
} }
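The pattern above builds the causal mask once with a batch dimension of 1 and then uses `repeat_dim(0, batch_size)` to broadcast it across the batch. The same idea in isolation (a hedged sketch, backend left generic):

```rust
use burn::tensor::{backend::Backend, Bool, Tensor};

// Expand a per-sequence mask of shape [1, seq, seq] to one mask per
// batch element, [batch, seq, seq].
fn broadcast_mask<B: Backend>(mask: Tensor<B, 3, Bool>, batch_size: usize) -> Tensor<B, 3, Bool> {
    mask.repeat_dim(0, batch_size)
}
```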

View File

@ -152,7 +152,7 @@ impl<B: Backend> CrossEntropyLoss<B> {
* weights * weights
.clone() .clone()
.reshape([1, nr_classes]) .reshape([1, nr_classes])
.repeat(0, batch_size); .repeat_dim(0, batch_size);
let weights = weights.clone().gather(0, targets); let weights = weights.clone().gather(0, targets);
let tensor = Self::apply_mask_2d(tensor, mask); let tensor = Self::apply_mask_2d(tensor, mask);
tensor.sum().neg() / weights.sum() tensor.sum().neg() / weights.sum()
@ -224,7 +224,7 @@ impl<B: Backend> CrossEntropyLoss<B> {
fn apply_mask_2d(mut tensor: Tensor<B, 2>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 2> { fn apply_mask_2d(mut tensor: Tensor<B, 2>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 2> {
if let Some(mask) = mask { if let Some(mask) = mask {
let [batch_size, nr_classes] = tensor.dims(); let [batch_size, nr_classes] = tensor.dims();
tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat(1, nr_classes), 0); tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat_dim(1, nr_classes), 0);
} }
tensor tensor
@ -312,7 +312,7 @@ mod tests {
* targets_logits * targets_logits
* Tensor::<TestBackend, 1>::from_floats(weights.as_slice(), &device) * Tensor::<TestBackend, 1>::from_floats(weights.as_slice(), &device)
.unsqueeze() .unsqueeze()
.repeat(0, 4); .repeat_dim(0, 4);
let loss_2 = loss_2.sum().neg() / (1. + 2. + 3. + 5.); let loss_2 = loss_2.sum().neg() / (1. + 2. + 3. + 5.);
loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3); loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3);
} }
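Here `repeat_dim` broadcasts in both directions: the `[nr_classes]` weight vector is tiled across the batch, and the `[batch_size]` padding mask is tiled across the classes. A small standalone sketch of those two broadcasts (shapes and helper name are illustrative):

```rust
use burn::tensor::{backend::Backend, Bool, Tensor};

// Hedged sketch of the two broadcasts used by the cross-entropy loss above.
fn broadcasts<B: Backend>(
    weights: Tensor<B, 1>,     // [nr_classes]
    mask: Tensor<B, 1, Bool>,  // [batch_size]
    batch_size: usize,
    nr_classes: usize,
) -> (Tensor<B, 2>, Tensor<B, 2, Bool>) {
    // Per-class weights, repeated for every sample in the batch.
    let weights = weights.unsqueeze::<2>().repeat_dim(0, batch_size); // [batch, classes]
    // Per-sample mask, repeated for every class column.
    let mask = mask.reshape([batch_size, 1]).repeat_dim(1, nr_classes); // [batch, classes]
    (weights, mask)
}
```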

View File

@ -58,7 +58,7 @@ impl RotaryEncodingConfig {
.float() .float()
.unsqueeze() .unsqueeze()
.transpose() .transpose()
.repeat(1, self.d_model / 2) .repeat_dim(1, self.d_model / 2)
* theta_i.unsqueeze(); * theta_i.unsqueeze();
// Convert frequency values to complex numbers (polar form) // Convert frequency values to complex numbers (polar form)
@ -71,7 +71,7 @@ impl RotaryEncodingConfig {
.reshape([self.max_sequence_length, 2, self.d_model / 2]) .reshape([self.max_sequence_length, 2, self.d_model / 2])
.transpose() .transpose()
.unsqueeze_dim::<4>(2) .unsqueeze_dim::<4>(2)
.repeat(2, 2) .repeat_dim(2, 2)
.reshape([self.max_sequence_length, self.d_model, 2]); .reshape([self.max_sequence_length, self.d_model, 2]);
RotaryEncoding { RotaryEncoding {
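Both calls build grids by repetition: a column of positions `[max_sequence_length, 1]` is repeated along dim 1 so every position is paired with every frequency, and the later `repeat_dim(2, 2)` duplicates entries along a freshly unsqueezed axis before the final reshape. A tiny sketch of the first pattern (shapes are illustrative):

```rust
use burn::tensor::{backend::Backend, Tensor};

// Hedged sketch: a [seq_len, n_freq] grid of position * frequency, built by
// repeating a column of positions across the frequency dimension.
fn position_frequency_grid<B: Backend>(
    positions: Tensor<B, 2>, // [seq_len, 1], e.g. 0.0, 1.0, 2.0, ...
    freqs: Tensor<B, 1>,     // [n_freq]
    n_freq: usize,
) -> Tensor<B, 2> {
    positions.repeat_dim(1, n_freq) * freqs.unsqueeze::<2>() // [seq_len, n_freq]
}
```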

View File

@ -14,7 +14,7 @@ use burn_tensor::{
BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription, BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
CatOperationDescription, ExpandOperationDescription, FlipOperationDescription, CatOperationDescription, ExpandOperationDescription, FlipOperationDescription,
HandleContainer, OperationDescription, PermuteOperationDescription, HandleContainer, OperationDescription, PermuteOperationDescription,
RepeatOperationDescription, ReshapeDescription, SliceAssignOperationDescription, RepeatDimOperationDescription, ReshapeDescription, SliceAssignOperationDescription,
SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription, SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription,
}, },
Device, Shape, Device, Shape,
@ -575,22 +575,22 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
out out
} }
fn bool_repeat<const D: usize>( fn bool_repeat_dim<const D: usize>(
tensor: BoolTensor<Self, D>, tensor: BoolTensor<Self, D>,
dim: usize, dim: usize,
times: usize, times: usize,
) -> BoolTensor<Self, D> { ) -> BoolTensor<Self, D> {
#[derive(new)] #[derive(new)]
struct RepeatOps<B: FusionBackend, const D: usize> { struct RepeatDimOps<B: FusionBackend, const D: usize> {
desc: RepeatOperationDescription, desc: RepeatDimOperationDescription,
_b: PhantomData<B>, _b: PhantomData<B>,
} }
impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatOps<B, D> { impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B, D> {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) { fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
let tensor = handles.get_bool_tensor::<B, D>(&self.desc.tensor); let tensor = handles.get_bool_tensor::<B, D>(&self.desc.tensor);
let output = B::bool_repeat::<D>(tensor, self.desc.dim, self.desc.times); let output = B::bool_repeat_dim::<D>(tensor, self.desc.dim, self.desc.times);
handles.register_bool_tensor::<B, D>(&self.desc.out.id, output); handles.register_bool_tensor::<B, D>(&self.desc.out.id, output);
} }
@ -601,7 +601,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
shape[dim] *= times; shape[dim] *= times;
let out = tensor.client.tensor_uninitialized(shape, DType::Bool); let out = tensor.client.tensor_uninitialized(shape, DType::Bool);
let desc = RepeatOperationDescription { let desc = RepeatDimOperationDescription {
tensor: tensor.into_description(), tensor: tensor.into_description(),
dim, dim,
times, times,
@ -609,8 +609,8 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
}; };
out.client.register( out.client.register(
vec![stream], vec![stream],
OperationDescription::BaseBool(BaseOperationDescription::Repeat(desc.clone())), OperationDescription::BaseBool(BaseOperationDescription::RepeatDim(desc.clone())),
RepeatOps::<B, D>::new(desc), RepeatDimOps::<B, D>::new(desc),
); );
out out

View File

@ -1696,22 +1696,22 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
out out
} }
fn float_repeat<const D: usize>( fn float_repeat_dim<const D: usize>(
tensor: FloatTensor<Self, D>, tensor: FloatTensor<Self, D>,
dim: usize, dim: usize,
times: usize, times: usize,
) -> FloatTensor<Self, D> { ) -> FloatTensor<Self, D> {
#[derive(new)] #[derive(new)]
struct RepeatOps<B: FusionBackend, const D: usize> { struct RepeatDimOps<B: FusionBackend, const D: usize> {
desc: RepeatOperationDescription, desc: RepeatDimOperationDescription,
_b: PhantomData<B>, _b: PhantomData<B>,
} }
impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatOps<B, D> { impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B, D> {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) { fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
let tensor = handles.get_float_tensor::<B, D>(&self.desc.tensor); let tensor = handles.get_float_tensor::<B, D>(&self.desc.tensor);
let output = B::float_repeat::<D>(tensor, self.desc.dim, self.desc.times); let output = B::float_repeat_dim::<D>(tensor, self.desc.dim, self.desc.times);
handles.register_float_tensor::<B, D>(&self.desc.out.id, output); handles.register_float_tensor::<B, D>(&self.desc.out.id, output);
} }
@ -1724,7 +1724,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
.client .client
.tensor_uninitialized(shape, B::FloatElem::dtype()); .tensor_uninitialized(shape, B::FloatElem::dtype());
let desc = RepeatOperationDescription { let desc = RepeatDimOperationDescription {
tensor: tensor.into_description(), tensor: tensor.into_description(),
dim, dim,
times, times,
@ -1732,8 +1732,8 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
}; };
out.client.register( out.client.register(
vec![stream], vec![stream],
OperationDescription::BaseFloat(BaseOperationDescription::Repeat(desc.clone())), OperationDescription::BaseFloat(BaseOperationDescription::RepeatDim(desc.clone())),
RepeatOps::<B, D>::new(desc), RepeatDimOps::<B, D>::new(desc),
); );
out out

View File

@ -1755,22 +1755,22 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
out out
} }
fn int_repeat<const D: usize>( fn int_repeat_dim<const D: usize>(
tensor: IntTensor<Self, D>, tensor: IntTensor<Self, D>,
dim: usize, dim: usize,
times: usize, times: usize,
) -> IntTensor<Self, D> { ) -> IntTensor<Self, D> {
#[derive(new)] #[derive(new)]
struct RepeatOps<B: FusionBackend, const D: usize> { struct RepeatDimOps<B: FusionBackend, const D: usize> {
desc: RepeatOperationDescription, desc: RepeatDimOperationDescription,
_b: PhantomData<B>, _b: PhantomData<B>,
} }
impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatOps<B, D> { impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B, D> {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) { fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
let tensor = handles.get_int_tensor::<B, D>(&self.desc.tensor); let tensor = handles.get_int_tensor::<B, D>(&self.desc.tensor);
let output = B::int_repeat::<D>(tensor, self.desc.dim, self.desc.times); let output = B::int_repeat_dim::<D>(tensor, self.desc.dim, self.desc.times);
handles.register_int_tensor::<B, D>(&self.desc.out.id, output); handles.register_int_tensor::<B, D>(&self.desc.out.id, output);
} }
@ -1783,7 +1783,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
.client .client
.tensor_uninitialized(shape, B::IntElem::dtype()); .tensor_uninitialized(shape, B::IntElem::dtype());
let desc = RepeatOperationDescription { let desc = RepeatDimOperationDescription {
tensor: tensor.into_description(), tensor: tensor.into_description(),
dim, dim,
times, times,
@ -1791,8 +1791,8 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
}; };
out.client.register( out.client.register(
vec![stream], vec![stream],
OperationDescription::BaseInt(BaseOperationDescription::Repeat(desc.clone())), OperationDescription::BaseInt(BaseOperationDescription::RepeatDim(desc.clone())),
RepeatOps::<B, D>::new(desc), RepeatDimOps::<B, D>::new(desc),
); );
out out

View File

@ -848,8 +848,8 @@ impl RelativeOps for BaseOperationDescription {
out: desc.out.to_relative(converter), out: desc.out.to_relative(converter),
}) })
} }
BaseOperationDescription::Repeat(desc) => { BaseOperationDescription::RepeatDim(desc) => {
BaseOperationDescription::Repeat(RepeatOperationDescription { BaseOperationDescription::RepeatDim(RepeatDimOperationDescription {
tensor: desc.tensor.to_relative(converter), tensor: desc.tensor.to_relative(converter),
dim: desc.dim, dim: desc.dim,
times: desc.times, times: desc.times,

View File

@ -1,6 +1,6 @@
mod flip; mod flip;
mod gather; mod gather;
mod repeat; mod repeat_dim;
mod scatter; mod scatter;
mod select; mod select;
mod select_assign; mod select_assign;
@ -8,7 +8,7 @@ mod slice;
mod slice_assign; mod slice_assign;
pub use flip::*; pub use flip::*;
pub use repeat::*; pub use repeat_dim::*;
pub use select::*; pub use select::*;
pub use select_assign::*; pub use select_assign::*;
pub use slice::*; pub use slice::*;

View File

@ -95,7 +95,7 @@ impl<R: JitRuntime, E: JitElement> Kernel for RepeatEagerKernel<R, E> {
} }
} }
pub(crate) fn repeat<R: JitRuntime, E: JitElement, const D1: usize>( pub(crate) fn repeat_dim<R: JitRuntime, E: JitElement, const D1: usize>(
input: JitTensor<R, E, D1>, input: JitTensor<R, E, D1>,
dim: usize, dim: usize,
times: usize, times: usize,

View File

@ -94,12 +94,12 @@ where
tensor tensor
} }
fn bool_repeat<const D: usize>( fn bool_repeat_dim<const D: usize>(
tensor: BoolTensor<Self, D>, tensor: BoolTensor<Self, D>,
dim: usize, dim: usize,
times: usize, times: usize,
) -> BoolTensor<Self, D> { ) -> BoolTensor<Self, D> {
kernel::repeat(tensor, dim, times) kernel::repeat_dim(tensor, dim, times)
} }
fn bool_permute<const D: usize>( fn bool_permute<const D: usize>(

View File

@ -467,12 +467,12 @@ where
}) })
} }
fn float_repeat<const D: usize>( fn float_repeat_dim<const D: usize>(
tensor: FloatTensor<Self, D>, tensor: FloatTensor<Self, D>,
dim: usize, dim: usize,
times: usize, times: usize,
) -> FloatTensor<Self, D> { ) -> FloatTensor<Self, D> {
kernel::repeat(tensor, dim, times) kernel::repeat_dim(tensor, dim, times)
} }
fn float_powf<const D: usize>( fn float_powf<const D: usize>(

View File

@ -318,12 +318,12 @@ where
tensor tensor
} }
fn int_repeat<const D: usize>( fn int_repeat_dim<const D: usize>(
tensor: IntTensor<Self, D>, tensor: IntTensor<Self, D>,
dim: usize, dim: usize,
times: usize, times: usize,
) -> IntTensor<Self, D> { ) -> IntTensor<Self, D> {
kernel::repeat(tensor, dim, times) kernel::repeat_dim(tensor, dim, times)
} }
fn int_random<const D: usize>( fn int_random<const D: usize>(

View File

@ -17,7 +17,7 @@ mod max_pool2d;
mod max_pool2d_backward; mod max_pool2d_backward;
mod normal; mod normal;
mod reduce; mod reduce;
mod repeat; mod repeat_dim;
mod scatter; mod scatter;
mod select; mod select;
mod select_assign; mod select_assign;
@ -48,7 +48,7 @@ macro_rules! testgen_all {
burn_jit::testgen_conv_transpose2d!(); burn_jit::testgen_conv_transpose2d!();
burn_jit::testgen_conv_transpose3d!(); burn_jit::testgen_conv_transpose3d!();
burn_jit::testgen_repeat!(); burn_jit::testgen_repeat_dim!();
burn_jit::testgen_gather!(); burn_jit::testgen_gather!();
burn_jit::testgen_scatter!(); burn_jit::testgen_scatter!();

View File

@ -1,4 +1,4 @@
#[burn_tensor_testgen::testgen(repeat)] #[burn_tensor_testgen::testgen(repeat_dim)]
mod tests { mod tests {
use super::*; use super::*;
use burn_tensor::{Distribution, Tensor}; use burn_tensor::{Distribution, Tensor};
@ -12,8 +12,8 @@ mod tests {
let tensor_ref = let tensor_ref =
Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default()); Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default());
let actual = tensor.repeat(dim, times); let actual = tensor.repeat_dim(dim, times);
let expected = tensor_ref.repeat(dim, times); let expected = tensor_ref.repeat_dim(dim, times);
expected expected
.into_data() .into_data()
@ -29,8 +29,8 @@ mod tests {
let tensor_ref = let tensor_ref =
Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default()); Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default());
let actual = tensor.repeat(dim, times); let actual = tensor.repeat_dim(dim, times);
let expected = tensor_ref.repeat(dim, times); let expected = tensor_ref.repeat_dim(dim, times);
expected expected
.into_data() .into_data()
@ -46,8 +46,8 @@ mod tests {
let tensor_ref = let tensor_ref =
Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default()); Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default());
let actual = tensor.repeat(dim, times); let actual = tensor.repeat_dim(dim, times);
let expected = tensor_ref.repeat(dim, times); let expected = tensor_ref.repeat_dim(dim, times);
expected expected
.into_data() .into_data()
@ -66,8 +66,8 @@ mod tests {
let tensor_ref = let tensor_ref =
Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default()); Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default());
let actual = tensor.repeat(dim, times); let actual = tensor.repeat_dim(dim, times);
let expected = tensor_ref.repeat(dim, times); let expected = tensor_ref.repeat_dim(dim, times);
expected expected
.into_data() .into_data()

View File

@ -33,7 +33,7 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
TchTensor::from_existing(tensor.tensor.reshape(shape_tch.dims), tensor.storage) TchTensor::from_existing(tensor.tensor.reshape(shape_tch.dims), tensor.storage)
} }
pub fn repeat<const D: usize>( pub fn repeat_dim<const D: usize>(
tensor: TchTensor<E, D>, tensor: TchTensor<E, D>,
dim: usize, dim: usize,
times: usize, times: usize,

View File

@ -15,12 +15,12 @@ impl<E: TchElement, Q: QuantElement> BoolTensorOps<Self> for LibTorch<E, Q> {
tensor.shape() tensor.shape()
} }
fn bool_repeat<const D: usize>( fn bool_repeat_dim<const D: usize>(
tensor: TchTensor<bool, D>, tensor: TchTensor<bool, D>,
dim: usize, dim: usize,
times: usize, times: usize,
) -> TchTensor<bool, D> { ) -> TchTensor<bool, D> {
TchOps::repeat(tensor, dim, times) TchOps::repeat_dim(tensor, dim, times)
} }
async fn bool_into_data<const D: usize>(tensor: TchTensor<bool, D>) -> TensorData { async fn bool_into_data<const D: usize>(tensor: TchTensor<bool, D>) -> TensorData {

View File

@ -18,12 +18,12 @@ impl<E: TchElement, Q: QuantElement> IntTensorOps<Self> for LibTorch<E, Q> {
tensor.shape() tensor.shape()
} }
fn int_repeat<const D: usize>( fn int_repeat_dim<const D: usize>(
tensor: TchTensor<i64, D>, tensor: TchTensor<i64, D>,
dim: usize, dim: usize,
times: usize, times: usize,
) -> TchTensor<i64, D> { ) -> TchTensor<i64, D> {
TchOps::repeat(tensor, dim, times) TchOps::repeat_dim(tensor, dim, times)
} }
async fn int_into_data<const D: usize>(tensor: TchTensor<i64, D>) -> TensorData { async fn int_into_data<const D: usize>(tensor: TchTensor<i64, D>) -> TensorData {

View File

@ -42,12 +42,12 @@ impl<E: TchElement, Q: QuantElement> FloatTensorOps<Self> for LibTorch<E, Q> {
} }
} }
fn float_repeat<const D: usize>( fn float_repeat_dim<const D: usize>(
tensor: TchTensor<E, D>, tensor: TchTensor<E, D>,
dim: usize, dim: usize,
times: usize, times: usize,
) -> TchTensor<E, D> { ) -> TchTensor<E, D> {
TchOps::repeat(tensor, dim, times) TchOps::repeat_dim(tensor, dim, times)
} }
fn float_zeros<const D: usize>(shape: Shape<D>, device: &LibTorchDevice) -> TchTensor<E, D> { fn float_zeros<const D: usize>(shape: Shape<D>, device: &LibTorchDevice) -> TchTensor<E, D> {

View File

@ -189,10 +189,10 @@ pub enum BaseOperationDescription {
Equal(BinaryOperationDescription), Equal(BinaryOperationDescription),
/// Operation corresponding to: /// Operation corresponding to:
/// ///
/// Float => [repeat](crate::ops::FloatTensorOps::float_repeat). /// Float => [repeat dim](crate::ops::FloatTensorOps::float_repeat_dim).
/// Int => [repeat](crate::ops::IntTensorOps::int_repeat). /// Int => [repeat dim](crate::ops::IntTensorOps::int_repeat_dim).
/// Bool => [repeat](crate::ops::BoolTensorOps::bool_repeat). /// Bool => [repeat dim](crate::ops::BoolTensorOps::bool_repeat_dim).
Repeat(RepeatOperationDescription), RepeatDim(RepeatDimOperationDescription),
/// Operation corresponding to: /// Operation corresponding to:
/// ///
/// Float => [cat](crate::ops::FloatTensorOps::float_cat). /// Float => [cat](crate::ops::FloatTensorOps::float_cat).
@ -627,7 +627,7 @@ pub struct ClampOperationDescription<E> {
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)] #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)] #[allow(missing_docs)]
pub struct RepeatOperationDescription { pub struct RepeatDimOperationDescription {
pub tensor: TensorDescription, pub tensor: TensorDescription,
pub dim: usize, pub dim: usize,
pub times: usize, pub times: usize,
@ -1189,7 +1189,7 @@ impl BaseOperationDescription {
BaseOperationDescription::Equal(desc) => { BaseOperationDescription::Equal(desc) => {
vec![&desc.lhs, &desc.rhs, &desc.out] vec![&desc.lhs, &desc.rhs, &desc.out]
} }
BaseOperationDescription::Repeat(desc) => { BaseOperationDescription::RepeatDim(desc) => {
vec![&desc.tensor, &desc.out] vec![&desc.tensor, &desc.out]
} }
BaseOperationDescription::Cat(desc) => desc.tensors.iter().collect(), BaseOperationDescription::Cat(desc) => desc.tensors.iter().collect(),

View File

@ -722,8 +722,21 @@ where
} }
/// Repeat the tensor along the given dimension. /// Repeat the tensor along the given dimension.
pub fn repeat(self, dim: usize, times: usize) -> Self { pub fn repeat_dim(self, dim: usize, times: usize) -> Self {
Self::new(K::repeat(self.primitive, dim, times)) Self::new(K::repeat_dim(self.primitive, dim, times))
}
/// Repeat the tensor along the given dimensions.
/// # Arguments
/// - `sizes`: Borrowed slice of the number of times to repeat each dimension.
pub fn repeat(self, sizes: &[usize]) -> Self {
let mut tensor = self;
for (dim, &times) in sizes.iter().enumerate() {
if times > 1 {
tensor = tensor.repeat_dim(dim, times);
}
}
tensor
} }
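The new `repeat` is deliberately a thin wrapper: it walks `sizes` and applies `repeat_dim` once per dimension, skipping entries equal to 1. One consequence, visible in the updated tests, is that trailing `1` entries beyond the tensor's rank are harmless no-ops; unlike PyTorch's `Tensor.repeat`, it does not prepend new leading dimensions. A hedged usage sketch:

```rust
use burn::tensor::{backend::Backend, Tensor};

// `repeat` composed from `repeat_dim`; shapes in the comments are illustrative.
fn demo<B: Backend>(x: Tensor<B, 2>) {
    // x: [2, 3]
    let a = x.clone().repeat(&[2, 3]);                   // [4, 9]
    let b = x.clone().repeat_dim(0, 2).repeat_dim(1, 3); // [4, 9], equivalent
    let c = x.repeat(&[2, 1]);                           // [4, 3]; the 1 is a no-op
    let _ = (a, b, c);
}
```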
/// Applies element-wise equal comparison and returns a boolean tensor. /// Applies element-wise equal comparison and returns a boolean tensor.
@ -1504,9 +1517,9 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly. /// or use this function directly.
/// ///
/// For repeating a tensor, users should prefer the [Tensor::repeat](Tensor::repeat) function, /// For repeating a tensor, users should prefer the [Tensor::repeat_dim](Tensor::repeat_dim) function,
/// which is more high-level and designed for public use. /// which is more high-level and designed for public use.
fn repeat<const D: usize>( fn repeat_dim<const D: usize>(
tensor: Self::Primitive<D>, tensor: Self::Primitive<D>,
dim: usize, dim: usize,
times: usize, times: usize,
@ -1763,12 +1776,12 @@ impl<B: Backend> BasicOps<B> for Float {
} }
} }
fn repeat<const D: usize>( fn repeat_dim<const D: usize>(
tensor: Self::Primitive<D>, tensor: Self::Primitive<D>,
dim: usize, dim: usize,
times: usize, times: usize,
) -> Self::Primitive<D> { ) -> Self::Primitive<D> {
TensorPrimitive::Float(B::float_repeat(tensor.tensor(), dim, times)) TensorPrimitive::Float(B::float_repeat_dim(tensor.tensor(), dim, times))
} }
fn cat<const D: usize>(vectors: Vec<Self::Primitive<D>>, dim: usize) -> Self::Primitive<D> { fn cat<const D: usize>(vectors: Vec<Self::Primitive<D>>, dim: usize) -> Self::Primitive<D> {
@ -1888,12 +1901,12 @@ impl<B: Backend> BasicOps<B> for Int {
B::int_from_data(data, device) B::int_from_data(data, device)
} }
fn repeat<const D: usize>( fn repeat_dim<const D: usize>(
tensor: Self::Primitive<D>, tensor: Self::Primitive<D>,
dim: usize, dim: usize,
times: usize, times: usize,
) -> Self::Primitive<D> { ) -> Self::Primitive<D> {
B::int_repeat(tensor, dim, times) B::int_repeat_dim(tensor, dim, times)
} }
fn equal<const D: usize>( fn equal<const D: usize>(
@ -2010,12 +2023,12 @@ impl<B: Backend> BasicOps<B> for Bool {
B::bool_from_data(data, device) B::bool_from_data(data, device)
} }
fn repeat<const D: usize>( fn repeat_dim<const D: usize>(
tensor: Self::Primitive<D>, tensor: Self::Primitive<D>,
dim: usize, dim: usize,
times: usize, times: usize,
) -> Self::Primitive<D> { ) -> Self::Primitive<D> {
B::bool_repeat(tensor, dim, times) B::bool_repeat_dim(tensor, dim, times)
} }
fn equal<const D: usize>( fn equal<const D: usize>(

View File

@ -46,7 +46,7 @@ pub fn cartesian_grid<B: Backend, S: Into<Shape<D>>, const D: usize, const D2: u
if i == dim { if i == dim {
continue; continue;
} }
dim_range = dim_range.repeat(i, item); dim_range = dim_range.repeat_dim(i, item);
} }
indices.push(dim_range); indices.push(dim_range);

View File

@ -1,6 +1,6 @@
use super::{ use super::{
cat::cat_with_slice_assign, repeat::repeat_with_slice_assign, BoolTensor, Device, FloatTensor, cat::cat_with_slice_assign, repeat_dim::repeat_with_slice_assign, BoolTensor, Device,
IntTensor, FloatTensor, IntTensor,
}; };
use crate::{ use crate::{
argwhere_data, backend::Backend, chunk, narrow, tensor::Shape, Bool, ElementConversion, Tensor, argwhere_data, backend::Backend, chunk, narrow, tensor::Shape, Bool, ElementConversion, Tensor,
@ -157,7 +157,7 @@ pub trait BoolTensorOps<B: Backend> {
/// # Returns /// # Returns
/// ///
/// The tensor with the dimension repeated. /// The tensor with the dimension repeated.
fn bool_repeat<const D: usize>( fn bool_repeat_dim<const D: usize>(
tensor: BoolTensor<B, D>, tensor: BoolTensor<B, D>,
dim: usize, dim: usize,
times: usize, times: usize,

View File

@ -1,5 +1,5 @@
use super::cat::cat_with_slice_assign; use super::cat::cat_with_slice_assign;
use super::repeat::repeat_with_slice_assign; use super::repeat_dim::repeat_with_slice_assign;
use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor}; use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
use crate::cast::ToElement; use crate::cast::ToElement;
use crate::{backend::Backend, tensor::Shape, Distribution, ElementConversion, Int, TensorData}; use crate::{backend::Backend, tensor::Shape, Distribution, ElementConversion, Int, TensorData};
@ -251,7 +251,7 @@ pub trait IntTensorOps<B: Backend> {
/// # Returns /// # Returns
/// ///
/// The tensor with the given dimension repeated the given number of times. /// The tensor with the given dimension repeated the given number of times.
fn int_repeat<const D: usize>( fn int_repeat_dim<const D: usize>(
tensor: IntTensor<B, D>, tensor: IntTensor<B, D>,
dim: usize, dim: usize,
times: usize, times: usize,

View File

@ -4,7 +4,7 @@ pub mod conv;
/// Module with cat operation /// Module with cat operation
pub(crate) mod cat; pub(crate) mod cat;
/// Module with repeat operation /// Module with repeat operation
pub(crate) mod repeat; pub(crate) mod repeat_dim;
/// Module with unfold operations. /// Module with unfold operations.
pub(crate) mod unfold; pub(crate) mod unfold;
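The renamed `repeat_dim` module still hosts `repeat_with_slice_assign`, the backend-agnostic fallback imported by the default trait implementations. A minimal sketch of that idea expressed with the public `Tensor` API (the real helper operates on backend primitives and may differ in details):

```rust
use burn::tensor::{backend::Backend, Tensor};

// Hedged sketch: repeat one dimension by allocating the larger output and
// slice-assigning the input into each segment.
fn repeat_dim_fallback<B: Backend, const D: usize>(
    tensor: Tensor<B, D>,
    dim: usize,
    times: usize,
) -> Tensor<B, D> {
    let dims = tensor.dims();
    let mut out_shape = dims;
    out_shape[dim] *= times;

    let mut out = Tensor::<B, D>::zeros(out_shape, &tensor.device());
    for i in 0..times {
        // Full range on every axis except `dim`, which selects the i-th segment.
        let ranges: [core::ops::Range<usize>; D] = core::array::from_fn(|d| {
            if d == dim {
                i * dims[dim]..(i + 1) * dims[dim]
            } else {
                0..dims[d]
            }
        });
        out = out.slice_assign(ranges, tensor.clone());
    }
    out
}
```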

View File

@ -1,5 +1,5 @@
use super::cat::cat_with_slice_assign; use super::cat::cat_with_slice_assign;
use super::repeat::repeat_with_slice_assign; use super::repeat_dim::repeat_with_slice_assign;
use super::{BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntElem, IntTensor}; use super::{BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntElem, IntTensor};
use crate::backend::BackendBridge; use crate::backend::BackendBridge;
use crate::tensor::cast::ToElement; use crate::tensor::cast::ToElement;
@ -174,7 +174,7 @@ pub trait FloatTensorOps<B: Backend> {
/// # Returns /// # Returns
/// ///
/// The tensor with the given dimension repeated. /// The tensor with the given dimension repeated.
fn float_repeat<const D: usize>( fn float_repeat_dim<const D: usize>(
tensor: FloatTensor<B, D>, tensor: FloatTensor<B, D>,
dim: usize, dim: usize,
times: usize, times: usize,

View File

@ -349,7 +349,7 @@ mod tests {
clone_invariance_test!( clone_invariance_test!(
unary: Repeat, unary: Repeat,
ops_float: |tensor: TestTensor<2>| { ops_float: |tensor: TestTensor<2>| {
tensor.reshape([1, 32, 32]).repeat(0, 4).reshape([4 * 32, 32]) tensor.reshape([1, 32, 32]).repeat_dim(0, 4).reshape([4 * 32, 32])
} }
); );
clone_invariance_test!( clone_invariance_test!(
@ -633,7 +633,7 @@ mod tests {
clone_invariance_test!( clone_invariance_test!(
unary: Repeat, unary: Repeat,
ops_int: |tensor: TestTensorInt<2>| { ops_int: |tensor: TestTensorInt<2>| {
tensor.reshape([1, 32, 32]).repeat(0, 4).reshape([4 * 32, 32]) tensor.reshape([1, 32, 32]).repeat_dim(0, 4).reshape([4 * 32, 32])
} }
); );
clone_invariance_test!( clone_invariance_test!(

View File

@ -74,6 +74,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_powf_scalar!(); burn_tensor::testgen_powf_scalar!();
burn_tensor::testgen_random!(); burn_tensor::testgen_random!();
burn_tensor::testgen_recip!(); burn_tensor::testgen_recip!();
burn_tensor::testgen_repeat_dim!();
burn_tensor::testgen_repeat!(); burn_tensor::testgen_repeat!();
burn_tensor::testgen_reshape!(); burn_tensor::testgen_reshape!();
burn_tensor::testgen_select!(); burn_tensor::testgen_select!();

View File

@ -45,6 +45,7 @@ mod random;
mod recip; mod recip;
mod remainder; mod remainder;
mod repeat; mod repeat;
mod repeat_dim;
mod reshape; mod reshape;
mod select; mod select;
mod sign; mod sign;

View File

@ -4,27 +4,27 @@ mod tests {
use burn_tensor::{backend::Backend, Bool, Int, Tensor, TensorData}; use burn_tensor::{backend::Backend, Bool, Int, Tensor, TensorData};
#[test] #[test]
fn should_support_repeat_ops() { fn should_support_repeat_ops_one_dimension() {
let data = TensorData::from([[0.0, 1.0, 2.0]]); let data = TensorData::from([[0.0f32, 1.0f32, 2.0f32]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default()); let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
let output = tensor.repeat(0, 4); let output = tensor.repeat(&[4, 1, 1]);
let expected = TensorData::from([ let expected = TensorData::from([
[0.0, 1.0, 2.0], [0.0f32, 1.0f32, 2.0f32],
[0.0, 1.0, 2.0], [0.0f32, 1.0f32, 2.0f32],
[0.0, 1.0, 2.0], [0.0f32, 1.0f32, 2.0f32],
[0.0, 1.0, 2.0], [0.0f32, 1.0f32, 2.0f32],
]); ]);
output.into_data().assert_eq(&expected, false); output.into_data().assert_eq(&expected, false);
} }
#[test] #[test]
fn should_support_bool_repeat_ops() { fn should_support_bool_repeat_ops_one_dimension() {
let data = TensorData::from([[true, false, false]]); let data = TensorData::from([[true, false, false]]);
let tensor = Tensor::<TestBackend, 2, Bool>::from_data(data, &Default::default()); let tensor = Tensor::<TestBackend, 2, Bool>::from_data(data, &Default::default());
let output = tensor.repeat(0, 4); let output = tensor.repeat(&[4, 1, 1]);
let expected = TensorData::from([ let expected = TensorData::from([
[true, false, false], [true, false, false],
[true, false, false], [true, false, false],
@ -35,70 +35,226 @@ mod tests {
} }
#[test] #[test]
fn should_support_int_repeat_ops() { fn should_support_int_repeat_ops_one_dimension() {
let data = TensorData::from([[0, 1, 2]]); let data = TensorData::from([[0i32, 1i32, 2i32]]);
let tensor = Tensor::<TestBackend, 2, Int>::from_data(data, &Default::default()); let tensor = Tensor::<TestBackend, 2, Int>::from_data(data, &Default::default());
let output = tensor.repeat(0, 4); let output = tensor.repeat(&[4, 1, 1]);
let expected = TensorData::from([[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]); let expected = TensorData::from([
[0i32, 1i32, 2i32],
[0i32, 1i32, 2i32],
[0i32, 1i32, 2i32],
[0i32, 1i32, 2i32],
]);
output.into_data().assert_eq(&expected, false); output.into_data().assert_eq(&expected, false);
} }
#[test] #[test]
fn should_support_float_repeat_on_dims_larger_than_1() { fn should_support_float_repeat_repeating_on_many_dimensions() {
let data = TensorData::from([ let data = TensorData::from([
[[1.0, 2.0], [3.0, 4.0]], [[1.0f32, 2.0f32], [3.0f32, 4.0f32]],
[[5.0, 6.0], [7.0, 8.0]], [[5.0f32, 6.0f32], [7.0f32, 8.0f32]],
[[9.0, 10.0], [11.0, 12.0]], [[9.0f32, 10.0f32], [11.0f32, 12.0f32]],
[[13.0, 14.0], [15.0, 16.0]], [[13.0f32, 14.0f32], [15.0f32, 16.0f32]],
]); ]);
let tensor = Tensor::<TestBackend, 3>::from_data(data, &Default::default()); let tensor = Tensor::<TestBackend, 3>::from_data(data, &Default::default());
let output = tensor.repeat(2, 2); let output = tensor.repeat(&[2, 3, 2]);
let expected = TensorData::from([ let expected = TensorData::from([
[[1.0, 2.0, 1.0, 2.0], [3.0, 4.0, 3.0, 4.0]], [
[[5.0, 6.0, 5.0, 6.0], [7.0, 8.0, 7.0, 8.0]], [1.0f32, 2.0f32, 1.0f32, 2.0f32],
[[9.0, 10.0, 9.0, 10.0], [11.0, 12.0, 11.0, 12.0]], [3.0f32, 4.0f32, 3.0f32, 4.0f32],
[[13.0, 14.0, 13.0, 14.0], [15.0, 16.0, 15.0, 16.0]], [1.0f32, 2.0f32, 1.0f32, 2.0f32],
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
[1.0f32, 2.0f32, 1.0f32, 2.0f32],
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
],
[
[5.0f32, 6.0f32, 5.0f32, 6.0f32],
[7.0f32, 8.0f32, 7.0f32, 8.0f32],
[5.0f32, 6.0f32, 5.0f32, 6.0f32],
[7.0f32, 8.0f32, 7.0f32, 8.0f32],
[5.0f32, 6.0f32, 5.0f32, 6.0f32],
[7.0f32, 8.0f32, 7.0f32, 8.0f32],
],
[
[9.0f32, 10.0f32, 9.0f32, 10.0f32],
[11.0f32, 12.0f32, 11.0f32, 12.0f32],
[9.0f32, 10.0f32, 9.0f32, 10.0f32],
[11.0f32, 12.0f32, 11.0f32, 12.0f32],
[9.0f32, 10.0f32, 9.0f32, 10.0f32],
[11.0f32, 12.0f32, 11.0f32, 12.0f32],
],
[
[13.0f32, 14.0f32, 13.0f32, 14.0f32],
[15.0f32, 16.0f32, 15.0f32, 16.0f32],
[13.0f32, 14.0f32, 13.0f32, 14.0f32],
[15.0f32, 16.0f32, 15.0f32, 16.0f32],
[13.0f32, 14.0f32, 13.0f32, 14.0f32],
[15.0f32, 16.0f32, 15.0f32, 16.0f32],
],
[
[1.0f32, 2.0f32, 1.0f32, 2.0f32],
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
[1.0f32, 2.0f32, 1.0f32, 2.0f32],
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
[1.0f32, 2.0f32, 1.0f32, 2.0f32],
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
],
[
[5.0f32, 6.0f32, 5.0f32, 6.0f32],
[7.0f32, 8.0f32, 7.0f32, 8.0f32],
[5.0f32, 6.0f32, 5.0f32, 6.0f32],
[7.0f32, 8.0f32, 7.0f32, 8.0f32],
[5.0f32, 6.0f32, 5.0f32, 6.0f32],
[7.0f32, 8.0f32, 7.0f32, 8.0f32],
],
[
[9.0f32, 10.0f32, 9.0f32, 10.0f32],
[11.0f32, 12.0f32, 11.0f32, 12.0f32],
[9.0f32, 10.0f32, 9.0f32, 10.0f32],
[11.0f32, 12.0f32, 11.0f32, 12.0f32],
[9.0f32, 10.0f32, 9.0f32, 10.0f32],
[11.0f32, 12.0f32, 11.0f32, 12.0f32],
],
[
[13.0f32, 14.0f32, 13.0f32, 14.0f32],
[15.0f32, 16.0f32, 15.0f32, 16.0f32],
[13.0f32, 14.0f32, 13.0f32, 14.0f32],
[15.0f32, 16.0f32, 15.0f32, 16.0f32],
[13.0f32, 14.0f32, 13.0f32, 14.0f32],
[15.0f32, 16.0f32, 15.0f32, 16.0f32],
],
]); ]);
output.into_data().assert_eq(&expected, false); output.into_data().assert_eq(&expected, false);
} }
#[test] #[test]
fn should_support_int_repeat_on_dims_larger_than_1() { fn should_support_int_repeat_on_many_dims() {
let data = TensorData::from([ let data = TensorData::from([
[[1, 2], [3, 4]], [[1i32, 2i32], [3i32, 4i32]],
[[5, 6], [7, 8]], [[5i32, 6i32], [7i32, 8i32]],
[[9, 10], [11, 12]], [[9i32, 10i32], [11i32, 12i32]],
[[13, 14], [15, 16]], [[13i32, 14i32], [15i32, 16i32]],
]); ]);
let tensor = Tensor::<TestBackend, 3, Int>::from_data(data, &Default::default()); let tensor = Tensor::<TestBackend, 3, Int>::from_data(data, &Default::default());
let output = tensor.repeat(2, 3); let output = tensor.repeat(&[2, 3, 2]);
let expected = TensorData::from([ let expected = TensorData::from([
[[1, 2, 1, 2, 1, 2], [3, 4, 3, 4, 3, 4]], [
[[5, 6, 5, 6, 5, 6], [7, 8, 7, 8, 7, 8]], [1i32, 2i32, 1i32, 2i32],
[[9, 10, 9, 10, 9, 10], [11, 12, 11, 12, 11, 12]], [3i32, 4i32, 3i32, 4i32],
[[13, 14, 13, 14, 13, 14], [15, 16, 15, 16, 15, 16]], [1i32, 2i32, 1i32, 2i32],
[3i32, 4i32, 3i32, 4i32],
[1i32, 2i32, 1i32, 2i32],
[3i32, 4i32, 3i32, 4i32],
],
[
[5i32, 6i32, 5i32, 6i32],
[7i32, 8i32, 7i32, 8i32],
[5i32, 6i32, 5i32, 6i32],
[7i32, 8i32, 7i32, 8i32],
[5i32, 6i32, 5i32, 6i32],
[7i32, 8i32, 7i32, 8i32],
],
[
[9i32, 10i32, 9i32, 10i32],
[11i32, 12i32, 11i32, 12i32],
[9i32, 10i32, 9i32, 10i32],
[11i32, 12i32, 11i32, 12i32],
[9i32, 10i32, 9i32, 10i32],
[11i32, 12i32, 11i32, 12i32],
],
[
[13i32, 14i32, 13i32, 14i32],
[15i32, 16i32, 15i32, 16i32],
[13i32, 14i32, 13i32, 14i32],
[15i32, 16i32, 15i32, 16i32],
[13i32, 14i32, 13i32, 14i32],
[15i32, 16i32, 15i32, 16i32],
],
[
[1i32, 2i32, 1i32, 2i32],
[3i32, 4i32, 3i32, 4i32],
[1i32, 2i32, 1i32, 2i32],
[3i32, 4i32, 3i32, 4i32],
[1i32, 2i32, 1i32, 2i32],
[3i32, 4i32, 3i32, 4i32],
],
[
[5i32, 6i32, 5i32, 6i32],
[7i32, 8i32, 7i32, 8i32],
[5i32, 6i32, 5i32, 6i32],
[7i32, 8i32, 7i32, 8i32],
[5i32, 6i32, 5i32, 6i32],
[7i32, 8i32, 7i32, 8i32],
],
[
[9i32, 10i32, 9i32, 10i32],
[11i32, 12i32, 11i32, 12i32],
[9i32, 10i32, 9i32, 10i32],
[11i32, 12i32, 11i32, 12i32],
[9i32, 10i32, 9i32, 10i32],
[11i32, 12i32, 11i32, 12i32],
],
[
[13i32, 14i32, 13i32, 14i32],
[15i32, 16i32, 15i32, 16i32],
[13i32, 14i32, 13i32, 14i32],
[15i32, 16i32, 15i32, 16i32],
[13i32, 14i32, 13i32, 14i32],
[15i32, 16i32, 15i32, 16i32],
],
]); ]);
output.into_data().assert_eq(&expected, false); output.into_data().assert_eq(&expected, false);
} }
#[test] #[test]
fn should_support_bool_repeat_on_dims_larger_than_1() { fn should_support_bool_repeat_on_many_dimension() {
let data = TensorData::from([ let data = TensorData::from([
[[false, true], [true, false]], [[false, true], [true, false]],
[[true, true], [false, false]], [[true, true], [false, false]],
]); ]);
let tensor = Tensor::<TestBackend, 3, Bool>::from_data(data, &Default::default()); let tensor = Tensor::<TestBackend, 3, Bool>::from_data(data, &Default::default());
let output = tensor.repeat(1, 2); let output = tensor.repeat(&[2, 3, 2]);
let expected = TensorData::from([ let expected = TensorData::from([
[[false, true], [true, false], [false, true], [true, false]], [
[[true, true], [false, false], [true, true], [false, false]], [false, true, false, true],
[true, false, true, false],
[false, true, false, true],
[true, false, true, false],
[false, true, false, true],
[true, false, true, false],
],
[
[true, true, true, true],
[false, false, false, false],
[true, true, true, true],
[false, false, false, false],
[true, true, true, true],
[false, false, false, false],
],
[
[false, true, false, true],
[true, false, true, false],
[false, true, false, true],
[true, false, true, false],
[false, true, false, true],
[true, false, true, false],
],
[
[true, true, true, true],
[false, false, false, false],
[true, true, true, true],
[false, false, false, false],
[true, true, true, true],
[false, false, false, false],
],
]); ]);
output.into_data().assert_eq(&expected, true); output.into_data().assert_eq(&expected, true);

View File

@ -0,0 +1,130 @@
#[burn_tensor_testgen::testgen(repeat_dim)]
mod tests {
use super::*;
use burn_tensor::{backend::Backend, Bool, Int, Tensor, TensorData};
#[test]
fn should_support_repeat_ops() {
let data = TensorData::from([[0.0f64, 1.0f64, 2.0f64]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data.clone(), &Default::default());
let output = tensor.repeat_dim(0, 4);
let expected = TensorData::from([
[0.0f32, 1.0f32, 2.0f32],
[0.0f32, 1.0f32, 2.0f32],
[0.0f32, 1.0f32, 2.0f32],
[0.0f32, 1.0f32, 2.0f32],
]);
output.into_data().assert_eq(&expected, false);
}
#[test]
fn should_support_bool_repeat_ops() {
let data = TensorData::from([[true, false, false]]);
let tensor = Tensor::<TestBackend, 2, Bool>::from_data(data, &Default::default());
let output = tensor.repeat_dim(0, 4);
let expected = TensorData::from([
[true, false, false],
[true, false, false],
[true, false, false],
[true, false, false],
]);
output.into_data().assert_eq(&expected, true);
}
#[test]
fn should_support_int_repeat_ops() {
let data = TensorData::from([[0, 1, 2]]);
let tensor = Tensor::<TestBackend, 2, Int>::from_data(data, &Default::default());
let output = tensor.repeat_dim(0, 4);
let expected = TensorData::from([[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]);
output.into_data().assert_eq(&expected, false);
}
#[test]
fn should_support_float_repeat_on_dims_larger_than_1() {
let data = TensorData::from([
[[1.0f32, 2.0f32], [3.0f32, 4.0f32]],
[[5.0f32, 6.0f32], [7.0f32, 8.0f32]],
[[9.0f32, 10.0f32], [11.0f32, 12.0f32]],
[[13.0f32, 14.0f32], [15.0f32, 16.0f32]],
]);
let tensor = Tensor::<TestBackend, 3>::from_data(data, &Default::default());
let output = tensor.repeat_dim(2, 2);
let expected = TensorData::from([
[
[1.0f32, 2.0f32, 1.0f32, 2.0f32],
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
],
[
[5.0f32, 6.0f32, 5.0f32, 6.0f32],
[7.0f32, 8.0f32, 7.0f32, 8.0f32],
],
[
[9.0f32, 10.0f32, 9.0f32, 10.0f32],
[11.0f32, 12.0f32, 11.0f32, 12.0f32],
],
[
[13.0f32, 14.0f32, 13.0f32, 14.0f32],
[15.0f32, 16.0f32, 15.0f32, 16.0f32],
],
]);
output.into_data().assert_eq(&expected, false);
}
#[test]
fn should_support_int_repeat_on_dims_larger_than_1() {
let data = TensorData::from([
[[1i32, 2i32], [3i32, 4i32]],
[[5i32, 6i32], [7i32, 8i32]],
[[9i32, 10i32], [11i32, 12i32]],
[[13i32, 14i32], [15i32, 16i32]],
]);
let tensor = Tensor::<TestBackend, 3, Int>::from_data(data, &Default::default());
let output = tensor.repeat_dim(2, 3);
let expected = TensorData::from([
[
[1i32, 2i32, 1i32, 2i32, 1i32, 2i32],
[3i32, 4i32, 3i32, 4i32, 3i32, 4i32],
],
[
[5i32, 6i32, 5i32, 6i32, 5i32, 6i32],
[7i32, 8i32, 7i32, 8i32, 7i32, 8i32],
],
[
[9i32, 10i32, 9i32, 10i32, 9i32, 10i32],
[11i32, 12i32, 11i32, 12i32, 11i32, 12i32],
],
[
[13i32, 14i32, 13i32, 14i32, 13i32, 14i32],
[15i32, 16i32, 15i32, 16i32, 15i32, 16i32],
],
]);
output.into_data().assert_eq(&expected, false);
}
#[test]
fn should_support_bool_repeat_on_dims_larger_than_1() {
let data = TensorData::from([
[[false, true], [true, false]],
[[true, true], [false, false]],
]);
let tensor = Tensor::<TestBackend, 3, Bool>::from_data(data, &Default::default());
let output = tensor.repeat_dim(1, 2);
let expected = TensorData::from([
[[false, true], [true, false], [false, true], [true, false]],
[[true, true], [false, false], [true, true], [false, false]],
]);
output.into_data().assert_eq(&expected, true);
}
}

View File

@ -73,7 +73,7 @@ impl<B: Backend> TextClassificationModel<B> {
// Calculate token and position embeddings, and combine them // Calculate token and position embeddings, and combine them
let index_positions = Tensor::arange(0..seq_length as i64, device) let index_positions = Tensor::arange(0..seq_length as i64, device)
.reshape([1, seq_length]) .reshape([1, seq_length])
.repeat(0, batch_size); .repeat_dim(0, batch_size);
let embedding_positions = self.embedding_pos.forward(index_positions); let embedding_positions = self.embedding_pos.forward(index_positions);
let embedding_tokens = self.embedding_token.forward(tokens); let embedding_tokens = self.embedding_token.forward(tokens);
let embedding = (embedding_positions + embedding_tokens) / 2; let embedding = (embedding_positions + embedding_tokens) / 2;
@ -113,7 +113,7 @@ impl<B: Backend> TextClassificationModel<B> {
// Calculate token and position embeddings, and combine them // Calculate token and position embeddings, and combine them
let index_positions = Tensor::arange(0..seq_length as i64, device) let index_positions = Tensor::arange(0..seq_length as i64, device)
.reshape([1, seq_length]) .reshape([1, seq_length])
.repeat(0, batch_size); .repeat_dim(0, batch_size);
let embedding_positions = self.embedding_pos.forward(index_positions); let embedding_positions = self.embedding_pos.forward(index_positions);
let embedding_tokens = self.embedding_token.forward(tokens); let embedding_tokens = self.embedding_token.forward(tokens);
let embedding = (embedding_positions + embedding_tokens) / 2; let embedding = (embedding_positions + embedding_tokens) / 2;
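The same `arange → reshape → repeat_dim` recipe gives every batch element its own row of position indices. In isolation (a hedged sketch, backend left generic):

```rust
use burn::tensor::{backend::Backend, Int, Tensor};

// One row of positions 0..seq_length, repeated once per batch item.
fn position_ids<B: Backend>(
    batch_size: usize,
    seq_length: usize,
    device: &B::Device,
) -> Tensor<B, 2, Int> {
    Tensor::arange(0..seq_length as i64, device)
        .reshape([1, seq_length])
        .repeat_dim(0, batch_size)
}
```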

View File

@ -64,7 +64,7 @@ impl<B: Backend> TextGenerationModel<B> {
let index_positions = Tensor::arange(0..seq_length as i64, device) let index_positions = Tensor::arange(0..seq_length as i64, device)
.reshape([1, seq_length]) .reshape([1, seq_length])
.repeat(0, batch_size); .repeat_dim(0, batch_size);
let embedding_positions = self.embedding_pos.forward(index_positions); let embedding_positions = self.embedding_pos.forward(index_positions);
let embedding_tokens = self.embedding_token.forward(inputs); let embedding_tokens = self.embedding_token.forward(inputs);