Repeat operation (#2090)

* renaming repeat to repeat_dim

* implementing repeat function

* renaming repeat files to repeat_dim

* renaming part 2

* renaming part 3

* renaming part 4

* renaming part 5

* adding test file

* adding unit test

* adding rust book documentation

* adding function args doc

* fixing tests

* changing repeat api to match pytorch equivalent

* fixing clippy error
This commit is contained in:
mepatrick73 2024-08-02 20:33:47 -04:00 committed by GitHub
parent bb13729b20
commit f7639bd35a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
40 changed files with 478 additions and 174 deletions

View File

@ -131,40 +131,41 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t
Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
| Burn | PyTorch Equivalent |
| ------------------------------------- | ------------------------------------ |
| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` |
| `Tensor::from_primitive(primitive)` | N/A |
| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` |
| `tensor.all()` | `tensor.all()` |
| `tensor.all_dim(dim)` | `tensor.all(dim)` |
| `tensor.any()` | `tensor.any()` |
| `tensor.any_dim(dim)` | `tensor.any(dim)` |
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
| `tensor.device()` | `tensor.device` |
| `tensor.dims()` | `tensor.size()` |
| `tensor.equal(other)` | `x == y` |
| `tensor.expand(shape)` | `tensor.expand(shape)` |
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
| `tensor.flip(axes)` | `tensor.flip(axes)` |
| `tensor.into_data()` | N/A |
| `tensor.into_primitive()` | N/A |
| `tensor.into_scalar()` | `tensor.item()` |
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
| `tensor.not_equal(other)` | `x != y` |
| `tensor.permute(axes)` | `tensor.permute(axes)` |
| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` |
| `tensor.repeat(2, 4)` | `tensor.repeat([1, 1, 4])` |
| `tensor.reshape(shape)` | `tensor.view(shape)` |
| `tensor.shape()` | `tensor.shape` |
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
| `tensor.to_data()` | N/A |
| `tensor.to_device(device)` | `tensor.to(device)` |
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
| Burn | PyTorch Equivalent |
| ------------------------------------- | ------------------------------------------------------------------------ |
| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` |
| `Tensor::from_primitive(primitive)` | N/A |
| `Tensor::stack(tensors, dim)` | `torch.stack(tensors, dim)` |
| `tensor.all()` | `tensor.all()` |
| `tensor.all_dim(dim)` | `tensor.all(dim)` |
| `tensor.any()` | `tensor.any()` |
| `tensor.any_dim(dim)` | `tensor.any(dim)` |
| `tensor.chunk(num_chunks, dim)` | `tensor.chunk(num_chunks, dim)` |
| `tensor.device()` | `tensor.device` |
| `tensor.dims()` | `tensor.size()` |
| `tensor.equal(other)` | `x == y` |
| `tensor.expand(shape)` | `tensor.expand(shape)` |
| `tensor.flatten(start_dim, end_dim)` | `tensor.flatten(start_dim, end_dim)` |
| `tensor.flip(axes)` | `tensor.flip(axes)` |
| `tensor.into_data()` | N/A |
| `tensor.into_primitive()` | N/A |
| `tensor.into_scalar()` | `tensor.item()` |
| `tensor.narrow(dim, start, length)` | `tensor.narrow(dim, start, length)` |
| `tensor.not_equal(other)` | `x != y` |
| `tensor.permute(axes)` | `tensor.permute(axes)` |
| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` |
| `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])`|
| `tensor.repeat(sizes)` | `tensor.repeat(sizes)` |
| `tensor.reshape(shape)` | `tensor.view(shape)` |
| `tensor.shape()` | `tensor.shape` |
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
| `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values` |
| `tensor.squeeze(dim)` | `tensor.squeeze(dim)` |
| `tensor.to_data()` | N/A |
| `tensor.to_device(device)` | `tensor.to(device)` |
| `tensor.unsqueeze()` | `tensor.unsqueeze(0)` |
| `tensor.unsqueeze_dim(dim)` | `tensor.unsqueeze(dim)` |
### Numeric Operations

View File

@ -132,11 +132,11 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
B::bool_expand(tensor, shape)
}
fn bool_repeat<const D: usize>(
fn bool_repeat_dim<const D: usize>(
tensor: BoolTensor<B, D>,
dim: usize,
times: usize,
) -> BoolTensor<B, D> {
B::bool_repeat(tensor, dim, times)
B::bool_repeat_dim(tensor, dim, times)
}
}

View File

@ -162,12 +162,12 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
B::int_mean_dim(tensor, dim)
}
fn int_repeat<const D: usize>(
fn int_repeat_dim<const D: usize>(
tensor: IntTensor<B, D>,
dim: usize,
times: usize,
) -> IntTensor<B, D> {
B::int_repeat(tensor, dim, times)
B::int_repeat_dim(tensor, dim, times)
}
fn int_greater<const D: usize>(lhs: IntTensor<B, D>, rhs: IntTensor<B, D>) -> BoolTensor<B, D> {

View File

@ -2418,7 +2418,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
B::float_argsort(tensor.primitive, dim, descending)
}
fn float_repeat<const D: usize>(
fn float_repeat_dim<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
times: usize,
@ -2437,7 +2437,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
impl<B: Backend, const D: usize> RetroForward for RetroRepeat<B, D> {
fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
let tensor = states.get_state::<B::FloatTensorPrimitive<D>>(&self.tensor_id);
let out = B::float_repeat(tensor, self.dim, self.times);
let out = B::float_repeat_dim(tensor, self.dim, self.times);
states.save(out_node, out)
}
}
@ -2467,9 +2467,11 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
.stateful()
{
OpsKind::Tracked(prep) => {
prep.finish(dim, B::float_repeat(tensor.primitive, dim, times))
prep.finish(dim, B::float_repeat_dim(tensor.primitive, dim, times))
}
OpsKind::UnTracked(prep) => {
prep.finish(B::float_repeat_dim(tensor.primitive, dim, times))
}
OpsKind::UnTracked(prep) => prep.finish(B::float_repeat(tensor.primitive, dim, times)),
}
}

View File

@ -47,7 +47,7 @@ mod permute;
mod pow;
mod recip;
mod relu;
mod repeat;
mod repeat_dim;
mod reshape;
mod select;
mod sigmoid;
@ -133,6 +133,6 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_sign!();
burn_autodiff::testgen_ad_expand!();
burn_autodiff::testgen_ad_sort!();
burn_autodiff::testgen_ad_repeat!();
burn_autodiff::testgen_ad_repeat_dim!();
};
}

View File

@ -1,4 +1,4 @@
#[burn_tensor_testgen::testgen(ad_repeat)]
#[burn_tensor_testgen::testgen(ad_repeat_dim)]
mod tests {
use super::*;
use burn_tensor::{activation, TensorData};
@ -12,7 +12,7 @@ mod tests {
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
let tensor_3 = tensor_2.clone().repeat(1, 3);
let tensor_3 = tensor_2.clone().repeat_dim(1, 3);
let tensor_3 = tensor_1.matmul(tensor_3);
let grads = tensor_3.backward();

View File

@ -94,7 +94,7 @@ mod tests {
// burn_tensor::testgen_powf!();
burn_tensor::testgen_random!();
burn_tensor::testgen_repeat!();
burn_tensor::testgen_repeat_dim!();
burn_tensor::testgen_reshape!();
burn_tensor::testgen_select!();
burn_tensor::testgen_sin!();

View File

@ -19,7 +19,7 @@ pub fn generate_autoregressive_mask<B: Backend>(
mask = mask.slice_assign([0..1, i..i + 1, i + 1..seq_length], values);
}
mask = mask.repeat(0, batch_size);
mask = mask.repeat_dim(0, batch_size);
mask.equal_elem(1_i64.elem::<i64>())
}

View File

@ -152,7 +152,7 @@ impl<B: Backend> CrossEntropyLoss<B> {
* weights
.clone()
.reshape([1, nr_classes])
.repeat(0, batch_size);
.repeat_dim(0, batch_size);
let weights = weights.clone().gather(0, targets);
let tensor = Self::apply_mask_2d(tensor, mask);
tensor.sum().neg() / weights.sum()
@ -224,7 +224,7 @@ impl<B: Backend> CrossEntropyLoss<B> {
fn apply_mask_2d(mut tensor: Tensor<B, 2>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 2> {
if let Some(mask) = mask {
let [batch_size, nr_classes] = tensor.dims();
tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat(1, nr_classes), 0);
tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat_dim(1, nr_classes), 0);
}
tensor
@ -312,7 +312,7 @@ mod tests {
* targets_logits
* Tensor::<TestBackend, 1>::from_floats(weights.as_slice(), &device)
.unsqueeze()
.repeat(0, 4);
.repeat_dim(0, 4);
let loss_2 = loss_2.sum().neg() / (1. + 2. + 3. + 5.);
loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3);
}

View File

@ -58,7 +58,7 @@ impl RotaryEncodingConfig {
.float()
.unsqueeze()
.transpose()
.repeat(1, self.d_model / 2)
.repeat_dim(1, self.d_model / 2)
* theta_i.unsqueeze();
// Convert frequency values to complex numbers (polar form)
@ -71,7 +71,7 @@ impl RotaryEncodingConfig {
.reshape([self.max_sequence_length, 2, self.d_model / 2])
.transpose()
.unsqueeze_dim::<4>(2)
.repeat(2, 2)
.repeat_dim(2, 2)
.reshape([self.max_sequence_length, self.d_model, 2]);
RotaryEncoding {

View File

@ -14,7 +14,7 @@ use burn_tensor::{
BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
CatOperationDescription, ExpandOperationDescription, FlipOperationDescription,
HandleContainer, OperationDescription, PermuteOperationDescription,
RepeatOperationDescription, ReshapeDescription, SliceAssignOperationDescription,
RepeatDimOperationDescription, ReshapeDescription, SliceAssignOperationDescription,
SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription,
},
Device, Shape,
@ -575,22 +575,22 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
out
}
fn bool_repeat<const D: usize>(
fn bool_repeat_dim<const D: usize>(
tensor: BoolTensor<Self, D>,
dim: usize,
times: usize,
) -> BoolTensor<Self, D> {
#[derive(new)]
struct RepeatOps<B: FusionBackend, const D: usize> {
desc: RepeatOperationDescription,
struct RepeatDimOps<B: FusionBackend, const D: usize> {
desc: RepeatDimOperationDescription,
_b: PhantomData<B>,
}
impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatOps<B, D> {
impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B, D> {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
let tensor = handles.get_bool_tensor::<B, D>(&self.desc.tensor);
let output = B::bool_repeat::<D>(tensor, self.desc.dim, self.desc.times);
let output = B::bool_repeat_dim::<D>(tensor, self.desc.dim, self.desc.times);
handles.register_bool_tensor::<B, D>(&self.desc.out.id, output);
}
@ -601,7 +601,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
shape[dim] *= times;
let out = tensor.client.tensor_uninitialized(shape, DType::Bool);
let desc = RepeatOperationDescription {
let desc = RepeatDimOperationDescription {
tensor: tensor.into_description(),
dim,
times,
@ -609,8 +609,8 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::BaseBool(BaseOperationDescription::Repeat(desc.clone())),
RepeatOps::<B, D>::new(desc),
OperationDescription::BaseBool(BaseOperationDescription::RepeatDim(desc.clone())),
RepeatDimOps::<B, D>::new(desc),
);
out

View File

@ -1696,22 +1696,22 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
out
}
fn float_repeat<const D: usize>(
fn float_repeat_dim<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
times: usize,
) -> FloatTensor<Self, D> {
#[derive(new)]
struct RepeatOps<B: FusionBackend, const D: usize> {
desc: RepeatOperationDescription,
struct RepeatDimOps<B: FusionBackend, const D: usize> {
desc: RepeatDimOperationDescription,
_b: PhantomData<B>,
}
impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatOps<B, D> {
impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B, D> {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
let tensor = handles.get_float_tensor::<B, D>(&self.desc.tensor);
let output = B::float_repeat::<D>(tensor, self.desc.dim, self.desc.times);
let output = B::float_repeat_dim::<D>(tensor, self.desc.dim, self.desc.times);
handles.register_float_tensor::<B, D>(&self.desc.out.id, output);
}
@ -1724,7 +1724,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
.client
.tensor_uninitialized(shape, B::FloatElem::dtype());
let desc = RepeatOperationDescription {
let desc = RepeatDimOperationDescription {
tensor: tensor.into_description(),
dim,
times,
@ -1732,8 +1732,8 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::BaseFloat(BaseOperationDescription::Repeat(desc.clone())),
RepeatOps::<B, D>::new(desc),
OperationDescription::BaseFloat(BaseOperationDescription::RepeatDim(desc.clone())),
RepeatDimOps::<B, D>::new(desc),
);
out

View File

@ -1755,22 +1755,22 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
out
}
fn int_repeat<const D: usize>(
fn int_repeat_dim<const D: usize>(
tensor: IntTensor<Self, D>,
dim: usize,
times: usize,
) -> IntTensor<Self, D> {
#[derive(new)]
struct RepeatOps<B: FusionBackend, const D: usize> {
desc: RepeatOperationDescription,
struct RepeatDimOps<B: FusionBackend, const D: usize> {
desc: RepeatDimOperationDescription,
_b: PhantomData<B>,
}
impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatOps<B, D> {
impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B, D> {
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
let tensor = handles.get_int_tensor::<B, D>(&self.desc.tensor);
let output = B::int_repeat::<D>(tensor, self.desc.dim, self.desc.times);
let output = B::int_repeat_dim::<D>(tensor, self.desc.dim, self.desc.times);
handles.register_int_tensor::<B, D>(&self.desc.out.id, output);
}
@ -1783,7 +1783,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
.client
.tensor_uninitialized(shape, B::IntElem::dtype());
let desc = RepeatOperationDescription {
let desc = RepeatDimOperationDescription {
tensor: tensor.into_description(),
dim,
times,
@ -1791,8 +1791,8 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
};
out.client.register(
vec![stream],
OperationDescription::BaseInt(BaseOperationDescription::Repeat(desc.clone())),
RepeatOps::<B, D>::new(desc),
OperationDescription::BaseInt(BaseOperationDescription::RepeatDim(desc.clone())),
RepeatDimOps::<B, D>::new(desc),
);
out

View File

@ -848,8 +848,8 @@ impl RelativeOps for BaseOperationDescription {
out: desc.out.to_relative(converter),
})
}
BaseOperationDescription::Repeat(desc) => {
BaseOperationDescription::Repeat(RepeatOperationDescription {
BaseOperationDescription::RepeatDim(desc) => {
BaseOperationDescription::RepeatDim(RepeatDimOperationDescription {
tensor: desc.tensor.to_relative(converter),
dim: desc.dim,
times: desc.times,

View File

@ -1,6 +1,6 @@
mod flip;
mod gather;
mod repeat;
mod repeat_dim;
mod scatter;
mod select;
mod select_assign;
@ -8,7 +8,7 @@ mod slice;
mod slice_assign;
pub use flip::*;
pub use repeat::*;
pub use repeat_dim::*;
pub use select::*;
pub use select_assign::*;
pub use slice::*;

View File

@ -95,7 +95,7 @@ impl<R: JitRuntime, E: JitElement> Kernel for RepeatEagerKernel<R, E> {
}
}
pub(crate) fn repeat<R: JitRuntime, E: JitElement, const D1: usize>(
pub(crate) fn repeat_dim<R: JitRuntime, E: JitElement, const D1: usize>(
input: JitTensor<R, E, D1>,
dim: usize,
times: usize,

View File

@ -94,12 +94,12 @@ where
tensor
}
fn bool_repeat<const D: usize>(
fn bool_repeat_dim<const D: usize>(
tensor: BoolTensor<Self, D>,
dim: usize,
times: usize,
) -> BoolTensor<Self, D> {
kernel::repeat(tensor, dim, times)
kernel::repeat_dim(tensor, dim, times)
}
fn bool_permute<const D: usize>(

View File

@ -467,12 +467,12 @@ where
})
}
fn float_repeat<const D: usize>(
fn float_repeat_dim<const D: usize>(
tensor: FloatTensor<Self, D>,
dim: usize,
times: usize,
) -> FloatTensor<Self, D> {
kernel::repeat(tensor, dim, times)
kernel::repeat_dim(tensor, dim, times)
}
fn float_powf<const D: usize>(

View File

@ -318,12 +318,12 @@ where
tensor
}
fn int_repeat<const D: usize>(
fn int_repeat_dim<const D: usize>(
tensor: IntTensor<Self, D>,
dim: usize,
times: usize,
) -> IntTensor<Self, D> {
kernel::repeat(tensor, dim, times)
kernel::repeat_dim(tensor, dim, times)
}
fn int_random<const D: usize>(

View File

@ -17,7 +17,7 @@ mod max_pool2d;
mod max_pool2d_backward;
mod normal;
mod reduce;
mod repeat;
mod repeat_dim;
mod scatter;
mod select;
mod select_assign;
@ -48,7 +48,7 @@ macro_rules! testgen_all {
burn_jit::testgen_conv_transpose2d!();
burn_jit::testgen_conv_transpose3d!();
burn_jit::testgen_repeat!();
burn_jit::testgen_repeat_dim!();
burn_jit::testgen_gather!();
burn_jit::testgen_scatter!();

View File

@ -1,4 +1,4 @@
#[burn_tensor_testgen::testgen(repeat)]
#[burn_tensor_testgen::testgen(repeat_dim)]
mod tests {
use super::*;
use burn_tensor::{Distribution, Tensor};
@ -12,8 +12,8 @@ mod tests {
let tensor_ref =
Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default());
let actual = tensor.repeat(dim, times);
let expected = tensor_ref.repeat(dim, times);
let actual = tensor.repeat_dim(dim, times);
let expected = tensor_ref.repeat_dim(dim, times);
expected
.into_data()
@ -29,8 +29,8 @@ mod tests {
let tensor_ref =
Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default());
let actual = tensor.repeat(dim, times);
let expected = tensor_ref.repeat(dim, times);
let actual = tensor.repeat_dim(dim, times);
let expected = tensor_ref.repeat_dim(dim, times);
expected
.into_data()
@ -46,8 +46,8 @@ mod tests {
let tensor_ref =
Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default());
let actual = tensor.repeat(dim, times);
let expected = tensor_ref.repeat(dim, times);
let actual = tensor.repeat_dim(dim, times);
let expected = tensor_ref.repeat_dim(dim, times);
expected
.into_data()
@ -66,8 +66,8 @@ mod tests {
let tensor_ref =
Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default());
let actual = tensor.repeat(dim, times);
let expected = tensor_ref.repeat(dim, times);
let actual = tensor.repeat_dim(dim, times);
let expected = tensor_ref.repeat_dim(dim, times);
expected
.into_data()

View File

@ -33,7 +33,7 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
TchTensor::from_existing(tensor.tensor.reshape(shape_tch.dims), tensor.storage)
}
pub fn repeat<const D: usize>(
pub fn repeat_dim<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,
times: usize,

View File

@ -15,12 +15,12 @@ impl<E: TchElement, Q: QuantElement> BoolTensorOps<Self> for LibTorch<E, Q> {
tensor.shape()
}
fn bool_repeat<const D: usize>(
fn bool_repeat_dim<const D: usize>(
tensor: TchTensor<bool, D>,
dim: usize,
times: usize,
) -> TchTensor<bool, D> {
TchOps::repeat(tensor, dim, times)
TchOps::repeat_dim(tensor, dim, times)
}
async fn bool_into_data<const D: usize>(tensor: TchTensor<bool, D>) -> TensorData {

View File

@ -18,12 +18,12 @@ impl<E: TchElement, Q: QuantElement> IntTensorOps<Self> for LibTorch<E, Q> {
tensor.shape()
}
fn int_repeat<const D: usize>(
fn int_repeat_dim<const D: usize>(
tensor: TchTensor<i64, D>,
dim: usize,
times: usize,
) -> TchTensor<i64, D> {
TchOps::repeat(tensor, dim, times)
TchOps::repeat_dim(tensor, dim, times)
}
async fn int_into_data<const D: usize>(tensor: TchTensor<i64, D>) -> TensorData {

View File

@ -42,12 +42,12 @@ impl<E: TchElement, Q: QuantElement> FloatTensorOps<Self> for LibTorch<E, Q> {
}
}
fn float_repeat<const D: usize>(
fn float_repeat_dim<const D: usize>(
tensor: TchTensor<E, D>,
dim: usize,
times: usize,
) -> TchTensor<E, D> {
TchOps::repeat(tensor, dim, times)
TchOps::repeat_dim(tensor, dim, times)
}
fn float_zeros<const D: usize>(shape: Shape<D>, device: &LibTorchDevice) -> TchTensor<E, D> {

View File

@ -189,10 +189,10 @@ pub enum BaseOperationDescription {
Equal(BinaryOperationDescription),
/// Operation corresponding to:
///
/// Float => [repeat](crate::ops::FloatTensorOps::float_repeat).
/// Int => [repeat](crate::ops::IntTensorOps::int_repeat).
/// Bool => [repeat](crate::ops::BoolTensorOps::bool_repeat).
Repeat(RepeatOperationDescription),
/// Float => [repeat dim](crate::ops::FloatTensorOps::float_repeat_dim).
/// Int => [repeat dim](crate::ops::IntTensorOps::int_repeat_dim).
/// Bool => [repeat dim](crate::ops::BoolTensorOps::bool_repeat_dim).
RepeatDim(RepeatDimOperationDescription),
/// Operation corresponding to:
///
/// Float => [cat](crate::ops::FloatTensorOps::float_cat).
@ -627,7 +627,7 @@ pub struct ClampOperationDescription<E> {
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
#[allow(missing_docs)]
pub struct RepeatOperationDescription {
pub struct RepeatDimOperationDescription {
pub tensor: TensorDescription,
pub dim: usize,
pub times: usize,
@ -1189,7 +1189,7 @@ impl BaseOperationDescription {
BaseOperationDescription::Equal(desc) => {
vec![&desc.lhs, &desc.rhs, &desc.out]
}
BaseOperationDescription::Repeat(desc) => {
BaseOperationDescription::RepeatDim(desc) => {
vec![&desc.tensor, &desc.out]
}
BaseOperationDescription::Cat(desc) => desc.tensors.iter().collect(),

View File

@ -722,8 +722,21 @@ where
}
/// Repeat the tensor along the given dimension.
pub fn repeat(self, dim: usize, times: usize) -> Self {
Self::new(K::repeat(self.primitive, dim, times))
pub fn repeat_dim(self, dim: usize, times: usize) -> Self {
Self::new(K::repeat_dim(self.primitive, dim, times))
}
/// Repeat the tensor along the given dimensions.
/// # Arguments
/// - `sizes`: Borrowed slice of the number of times to repeat each dimension.
pub fn repeat(self, sizes: &[usize]) -> Self {
let mut tensor = self;
for (dim, &times) in sizes.iter().enumerate() {
if times > 1 {
tensor = tensor.repeat_dim(dim, times);
}
}
tensor
}
/// Applies element-wise equal comparison and returns a boolean tensor.
@ -1504,9 +1517,9 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
/// or use this function directly.
///
/// For repeating a tensor, users should prefer the [Tensor::repeat](Tensor::repeat) function,
/// For repeating a tensor, users should prefer the [Tensor::repeat_dim](Tensor::repeat_dim) function,
/// which is more high-level and designed for public use.
fn repeat<const D: usize>(
fn repeat_dim<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
times: usize,
@ -1763,12 +1776,12 @@ impl<B: Backend> BasicOps<B> for Float {
}
}
fn repeat<const D: usize>(
fn repeat_dim<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
times: usize,
) -> Self::Primitive<D> {
TensorPrimitive::Float(B::float_repeat(tensor.tensor(), dim, times))
TensorPrimitive::Float(B::float_repeat_dim(tensor.tensor(), dim, times))
}
fn cat<const D: usize>(vectors: Vec<Self::Primitive<D>>, dim: usize) -> Self::Primitive<D> {
@ -1888,12 +1901,12 @@ impl<B: Backend> BasicOps<B> for Int {
B::int_from_data(data, device)
}
fn repeat<const D: usize>(
fn repeat_dim<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
times: usize,
) -> Self::Primitive<D> {
B::int_repeat(tensor, dim, times)
B::int_repeat_dim(tensor, dim, times)
}
fn equal<const D: usize>(
@ -2010,12 +2023,12 @@ impl<B: Backend> BasicOps<B> for Bool {
B::bool_from_data(data, device)
}
fn repeat<const D: usize>(
fn repeat_dim<const D: usize>(
tensor: Self::Primitive<D>,
dim: usize,
times: usize,
) -> Self::Primitive<D> {
B::bool_repeat(tensor, dim, times)
B::bool_repeat_dim(tensor, dim, times)
}
fn equal<const D: usize>(

View File

@ -46,7 +46,7 @@ pub fn cartesian_grid<B: Backend, S: Into<Shape<D>>, const D: usize, const D2: u
if i == dim {
continue;
}
dim_range = dim_range.repeat(i, item);
dim_range = dim_range.repeat_dim(i, item);
}
indices.push(dim_range);

View File

@ -1,6 +1,6 @@
use super::{
cat::cat_with_slice_assign, repeat::repeat_with_slice_assign, BoolTensor, Device, FloatTensor,
IntTensor,
cat::cat_with_slice_assign, repeat_dim::repeat_with_slice_assign, BoolTensor, Device,
FloatTensor, IntTensor,
};
use crate::{
argwhere_data, backend::Backend, chunk, narrow, tensor::Shape, Bool, ElementConversion, Tensor,
@ -157,7 +157,7 @@ pub trait BoolTensorOps<B: Backend> {
/// # Returns
///
/// The tensor with the dimension repeated.
fn bool_repeat<const D: usize>(
fn bool_repeat_dim<const D: usize>(
tensor: BoolTensor<B, D>,
dim: usize,
times: usize,

View File

@ -1,5 +1,5 @@
use super::cat::cat_with_slice_assign;
use super::repeat::repeat_with_slice_assign;
use super::repeat_dim::repeat_with_slice_assign;
use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
use crate::cast::ToElement;
use crate::{backend::Backend, tensor::Shape, Distribution, ElementConversion, Int, TensorData};
@ -251,7 +251,7 @@ pub trait IntTensorOps<B: Backend> {
/// # Returns
///
/// The tensor with the given dimension repeated the given number of times.
fn int_repeat<const D: usize>(
fn int_repeat_dim<const D: usize>(
tensor: IntTensor<B, D>,
dim: usize,
times: usize,

View File

@ -4,7 +4,7 @@ pub mod conv;
/// Module with cat operation
pub(crate) mod cat;
/// Module with repeat operation
pub(crate) mod repeat;
pub(crate) mod repeat_dim;
/// Module with unfold operations.
pub(crate) mod unfold;

View File

@ -1,5 +1,5 @@
use super::cat::cat_with_slice_assign;
use super::repeat::repeat_with_slice_assign;
use super::repeat_dim::repeat_with_slice_assign;
use super::{BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntElem, IntTensor};
use crate::backend::BackendBridge;
use crate::tensor::cast::ToElement;
@ -174,7 +174,7 @@ pub trait FloatTensorOps<B: Backend> {
/// # Returns
///
/// The tensor with the given dimension repeated.
fn float_repeat<const D: usize>(
fn float_repeat_dim<const D: usize>(
tensor: FloatTensor<B, D>,
dim: usize,
times: usize,

View File

@ -349,7 +349,7 @@ mod tests {
clone_invariance_test!(
unary: Repeat,
ops_float: |tensor: TestTensor<2>| {
tensor.reshape([1, 32, 32]).repeat(0, 4).reshape([4 * 32, 32])
tensor.reshape([1, 32, 32]).repeat_dim(0, 4).reshape([4 * 32, 32])
}
);
clone_invariance_test!(
@ -633,7 +633,7 @@ mod tests {
clone_invariance_test!(
unary: Repeat,
ops_int: |tensor: TestTensorInt<2>| {
tensor.reshape([1, 32, 32]).repeat(0, 4).reshape([4 * 32, 32])
tensor.reshape([1, 32, 32]).repeat_dim(0, 4).reshape([4 * 32, 32])
}
);
clone_invariance_test!(

View File

@ -74,6 +74,7 @@ macro_rules! testgen_all {
burn_tensor::testgen_powf_scalar!();
burn_tensor::testgen_random!();
burn_tensor::testgen_recip!();
burn_tensor::testgen_repeat_dim!();
burn_tensor::testgen_repeat!();
burn_tensor::testgen_reshape!();
burn_tensor::testgen_select!();

View File

@ -45,6 +45,7 @@ mod random;
mod recip;
mod remainder;
mod repeat;
mod repeat_dim;
mod reshape;
mod select;
mod sign;

View File

@ -4,27 +4,27 @@ mod tests {
use burn_tensor::{backend::Backend, Bool, Int, Tensor, TensorData};
#[test]
fn should_support_repeat_ops() {
let data = TensorData::from([[0.0, 1.0, 2.0]]);
fn should_support_repeat_ops_one_dimension() {
let data = TensorData::from([[0.0f32, 1.0f32, 2.0f32]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
let output = tensor.repeat(0, 4);
let output = tensor.repeat(&[4, 1, 1]);
let expected = TensorData::from([
[0.0, 1.0, 2.0],
[0.0, 1.0, 2.0],
[0.0, 1.0, 2.0],
[0.0, 1.0, 2.0],
[0.0f32, 1.0f32, 2.0f32],
[0.0f32, 1.0f32, 2.0f32],
[0.0f32, 1.0f32, 2.0f32],
[0.0f32, 1.0f32, 2.0f32],
]);
output.into_data().assert_eq(&expected, false);
}
#[test]
fn should_support_bool_repeat_ops() {
fn should_support_bool_repeat_ops_one_dimension() {
let data = TensorData::from([[true, false, false]]);
let tensor = Tensor::<TestBackend, 2, Bool>::from_data(data, &Default::default());
let output = tensor.repeat(0, 4);
let output = tensor.repeat(&[4, 1, 1]);
let expected = TensorData::from([
[true, false, false],
[true, false, false],
@ -35,70 +35,226 @@ mod tests {
}
#[test]
fn should_support_int_repeat_ops() {
let data = TensorData::from([[0, 1, 2]]);
fn should_support_int_repeat_ops_one_dimension() {
let data = TensorData::from([[0i32, 1i32, 2i32]]);
let tensor = Tensor::<TestBackend, 2, Int>::from_data(data, &Default::default());
let output = tensor.repeat(0, 4);
let expected = TensorData::from([[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]);
let output = tensor.repeat(&[4, 1, 1]);
let expected = TensorData::from([
[0i32, 1i32, 2i32],
[0i32, 1i32, 2i32],
[0i32, 1i32, 2i32],
[0i32, 1i32, 2i32],
]);
output.into_data().assert_eq(&expected, false);
}
#[test]
fn should_support_float_repeat_on_dims_larger_than_1() {
fn should_support_float_repeat_repeating_on_many_dimensions() {
let data = TensorData::from([
[[1.0, 2.0], [3.0, 4.0]],
[[5.0, 6.0], [7.0, 8.0]],
[[9.0, 10.0], [11.0, 12.0]],
[[13.0, 14.0], [15.0, 16.0]],
[[1.0f32, 2.0f32], [3.0f32, 4.0f32]],
[[5.0f32, 6.0f32], [7.0f32, 8.0f32]],
[[9.0f32, 10.0f32], [11.0f32, 12.0f32]],
[[13.0f32, 14.0f32], [15.0f32, 16.0f32]],
]);
let tensor = Tensor::<TestBackend, 3>::from_data(data, &Default::default());
let output = tensor.repeat(2, 2);
let output = tensor.repeat(&[2, 3, 2]);
let expected = TensorData::from([
[[1.0, 2.0, 1.0, 2.0], [3.0, 4.0, 3.0, 4.0]],
[[5.0, 6.0, 5.0, 6.0], [7.0, 8.0, 7.0, 8.0]],
[[9.0, 10.0, 9.0, 10.0], [11.0, 12.0, 11.0, 12.0]],
[[13.0, 14.0, 13.0, 14.0], [15.0, 16.0, 15.0, 16.0]],
[
[1.0f32, 2.0f32, 1.0f32, 2.0f32],
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
[1.0f32, 2.0f32, 1.0f32, 2.0f32],
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
[1.0f32, 2.0f32, 1.0f32, 2.0f32],
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
],
[
[5.0f32, 6.0f32, 5.0f32, 6.0f32],
[7.0f32, 8.0f32, 7.0f32, 8.0f32],
[5.0f32, 6.0f32, 5.0f32, 6.0f32],
[7.0f32, 8.0f32, 7.0f32, 8.0f32],
[5.0f32, 6.0f32, 5.0f32, 6.0f32],
[7.0f32, 8.0f32, 7.0f32, 8.0f32],
],
[
[9.0f32, 10.0f32, 9.0f32, 10.0f32],
[11.0f32, 12.0f32, 11.0f32, 12.0f32],
[9.0f32, 10.0f32, 9.0f32, 10.0f32],
[11.0f32, 12.0f32, 11.0f32, 12.0f32],
[9.0f32, 10.0f32, 9.0f32, 10.0f32],
[11.0f32, 12.0f32, 11.0f32, 12.0f32],
],
[
[13.0f32, 14.0f32, 13.0f32, 14.0f32],
[15.0f32, 16.0f32, 15.0f32, 16.0f32],
[13.0f32, 14.0f32, 13.0f32, 14.0f32],
[15.0f32, 16.0f32, 15.0f32, 16.0f32],
[13.0f32, 14.0f32, 13.0f32, 14.0f32],
[15.0f32, 16.0f32, 15.0f32, 16.0f32],
],
[
[1.0f32, 2.0f32, 1.0f32, 2.0f32],
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
[1.0f32, 2.0f32, 1.0f32, 2.0f32],
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
[1.0f32, 2.0f32, 1.0f32, 2.0f32],
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
],
[
[5.0f32, 6.0f32, 5.0f32, 6.0f32],
[7.0f32, 8.0f32, 7.0f32, 8.0f32],
[5.0f32, 6.0f32, 5.0f32, 6.0f32],
[7.0f32, 8.0f32, 7.0f32, 8.0f32],
[5.0f32, 6.0f32, 5.0f32, 6.0f32],
[7.0f32, 8.0f32, 7.0f32, 8.0f32],
],
[
[9.0f32, 10.0f32, 9.0f32, 10.0f32],
[11.0f32, 12.0f32, 11.0f32, 12.0f32],
[9.0f32, 10.0f32, 9.0f32, 10.0f32],
[11.0f32, 12.0f32, 11.0f32, 12.0f32],
[9.0f32, 10.0f32, 9.0f32, 10.0f32],
[11.0f32, 12.0f32, 11.0f32, 12.0f32],
],
[
[13.0f32, 14.0f32, 13.0f32, 14.0f32],
[15.0f32, 16.0f32, 15.0f32, 16.0f32],
[13.0f32, 14.0f32, 13.0f32, 14.0f32],
[15.0f32, 16.0f32, 15.0f32, 16.0f32],
[13.0f32, 14.0f32, 13.0f32, 14.0f32],
[15.0f32, 16.0f32, 15.0f32, 16.0f32],
],
]);
output.into_data().assert_eq(&expected, false);
}
#[test]
fn should_support_int_repeat_on_dims_larger_than_1() {
fn should_support_int_repeat_on_many_dims() {
let data = TensorData::from([
[[1, 2], [3, 4]],
[[5, 6], [7, 8]],
[[9, 10], [11, 12]],
[[13, 14], [15, 16]],
[[1i32, 2i32], [3i32, 4i32]],
[[5i32, 6i32], [7i32, 8i32]],
[[9i32, 10i32], [11i32, 12i32]],
[[13i32, 14i32], [15i32, 16i32]],
]);
let tensor = Tensor::<TestBackend, 3, Int>::from_data(data, &Default::default());
let output = tensor.repeat(2, 3);
let output = tensor.repeat(&[2, 3, 2]);
let expected = TensorData::from([
[[1, 2, 1, 2, 1, 2], [3, 4, 3, 4, 3, 4]],
[[5, 6, 5, 6, 5, 6], [7, 8, 7, 8, 7, 8]],
[[9, 10, 9, 10, 9, 10], [11, 12, 11, 12, 11, 12]],
[[13, 14, 13, 14, 13, 14], [15, 16, 15, 16, 15, 16]],
[
[1i32, 2i32, 1i32, 2i32],
[3i32, 4i32, 3i32, 4i32],
[1i32, 2i32, 1i32, 2i32],
[3i32, 4i32, 3i32, 4i32],
[1i32, 2i32, 1i32, 2i32],
[3i32, 4i32, 3i32, 4i32],
],
[
[5i32, 6i32, 5i32, 6i32],
[7i32, 8i32, 7i32, 8i32],
[5i32, 6i32, 5i32, 6i32],
[7i32, 8i32, 7i32, 8i32],
[5i32, 6i32, 5i32, 6i32],
[7i32, 8i32, 7i32, 8i32],
],
[
[9i32, 10i32, 9i32, 10i32],
[11i32, 12i32, 11i32, 12i32],
[9i32, 10i32, 9i32, 10i32],
[11i32, 12i32, 11i32, 12i32],
[9i32, 10i32, 9i32, 10i32],
[11i32, 12i32, 11i32, 12i32],
],
[
[13i32, 14i32, 13i32, 14i32],
[15i32, 16i32, 15i32, 16i32],
[13i32, 14i32, 13i32, 14i32],
[15i32, 16i32, 15i32, 16i32],
[13i32, 14i32, 13i32, 14i32],
[15i32, 16i32, 15i32, 16i32],
],
[
[1i32, 2i32, 1i32, 2i32],
[3i32, 4i32, 3i32, 4i32],
[1i32, 2i32, 1i32, 2i32],
[3i32, 4i32, 3i32, 4i32],
[1i32, 2i32, 1i32, 2i32],
[3i32, 4i32, 3i32, 4i32],
],
[
[5i32, 6i32, 5i32, 6i32],
[7i32, 8i32, 7i32, 8i32],
[5i32, 6i32, 5i32, 6i32],
[7i32, 8i32, 7i32, 8i32],
[5i32, 6i32, 5i32, 6i32],
[7i32, 8i32, 7i32, 8i32],
],
[
[9i32, 10i32, 9i32, 10i32],
[11i32, 12i32, 11i32, 12i32],
[9i32, 10i32, 9i32, 10i32],
[11i32, 12i32, 11i32, 12i32],
[9i32, 10i32, 9i32, 10i32],
[11i32, 12i32, 11i32, 12i32],
],
[
[13i32, 14i32, 13i32, 14i32],
[15i32, 16i32, 15i32, 16i32],
[13i32, 14i32, 13i32, 14i32],
[15i32, 16i32, 15i32, 16i32],
[13i32, 14i32, 13i32, 14i32],
[15i32, 16i32, 15i32, 16i32],
],
]);
output.into_data().assert_eq(&expected, false);
}
#[test]
fn should_support_bool_repeat_on_dims_larger_than_1() {
fn should_support_bool_repeat_on_many_dimension() {
let data = TensorData::from([
[[false, true], [true, false]],
[[true, true], [false, false]],
]);
let tensor = Tensor::<TestBackend, 3, Bool>::from_data(data, &Default::default());
let output = tensor.repeat(1, 2);
let output = tensor.repeat(&[2, 3, 2]);
let expected = TensorData::from([
[[false, true], [true, false], [false, true], [true, false]],
[[true, true], [false, false], [true, true], [false, false]],
[
[false, true, false, true],
[true, false, true, false],
[false, true, false, true],
[true, false, true, false],
[false, true, false, true],
[true, false, true, false],
],
[
[true, true, true, true],
[false, false, false, false],
[true, true, true, true],
[false, false, false, false],
[true, true, true, true],
[false, false, false, false],
],
[
[false, true, false, true],
[true, false, true, false],
[false, true, false, true],
[true, false, true, false],
[false, true, false, true],
[true, false, true, false],
],
[
[true, true, true, true],
[false, false, false, false],
[true, true, true, true],
[false, false, false, false],
[true, true, true, true],
[false, false, false, false],
],
]);
output.into_data().assert_eq(&expected, true);

View File

@ -0,0 +1,130 @@
#[burn_tensor_testgen::testgen(repeat_dim)]
mod tests {
use super::*;
use burn_tensor::{backend::Backend, Bool, Int, Tensor, TensorData};
#[test]
fn should_support_repeat_ops() {
let data = TensorData::from([[0.0f64, 1.0f64, 2.0f64]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data.clone(), &Default::default());
let output = tensor.repeat_dim(0, 4);
let expected = TensorData::from([
[0.0f32, 1.0f32, 2.0f32],
[0.0f32, 1.0f32, 2.0f32],
[0.0f32, 1.0f32, 2.0f32],
[0.0f32, 1.0f32, 2.0f32],
]);
output.into_data().assert_eq(&expected, false);
}
#[test]
fn should_support_bool_repeat_ops() {
let data = TensorData::from([[true, false, false]]);
let tensor = Tensor::<TestBackend, 2, Bool>::from_data(data, &Default::default());
let output = tensor.repeat_dim(0, 4);
let expected = TensorData::from([
[true, false, false],
[true, false, false],
[true, false, false],
[true, false, false],
]);
output.into_data().assert_eq(&expected, true);
}
#[test]
fn should_support_int_repeat_ops() {
let data = TensorData::from([[0, 1, 2]]);
let tensor = Tensor::<TestBackend, 2, Int>::from_data(data, &Default::default());
let output = tensor.repeat_dim(0, 4);
let expected = TensorData::from([[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]);
output.into_data().assert_eq(&expected, false);
}
#[test]
fn should_support_float_repeat_on_dims_larger_than_1() {
let data = TensorData::from([
[[1.0f32, 2.0f32], [3.0f32, 4.0f32]],
[[5.0f32, 6.0f32], [7.0f32, 8.0f32]],
[[9.0f32, 10.0f32], [11.0f32, 12.0f32]],
[[13.0f32, 14.0f32], [15.0f32, 16.0f32]],
]);
let tensor = Tensor::<TestBackend, 3>::from_data(data, &Default::default());
let output = tensor.repeat_dim(2, 2);
let expected = TensorData::from([
[
[1.0f32, 2.0f32, 1.0f32, 2.0f32],
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
],
[
[5.0f32, 6.0f32, 5.0f32, 6.0f32],
[7.0f32, 8.0f32, 7.0f32, 8.0f32],
],
[
[9.0f32, 10.0f32, 9.0f32, 10.0f32],
[11.0f32, 12.0f32, 11.0f32, 12.0f32],
],
[
[13.0f32, 14.0f32, 13.0f32, 14.0f32],
[15.0f32, 16.0f32, 15.0f32, 16.0f32],
],
]);
output.into_data().assert_eq(&expected, false);
}
#[test]
fn should_support_int_repeat_on_dims_larger_than_1() {
let data = TensorData::from([
[[1i32, 2i32], [3i32, 4i32]],
[[5i32, 6i32], [7i32, 8i32]],
[[9i32, 10i32], [11i32, 12i32]],
[[13i32, 14i32], [15i32, 16i32]],
]);
let tensor = Tensor::<TestBackend, 3, Int>::from_data(data, &Default::default());
let output = tensor.repeat_dim(2, 3);
let expected = TensorData::from([
[
[1i32, 2i32, 1i32, 2i32, 1i32, 2i32],
[3i32, 4i32, 3i32, 4i32, 3i32, 4i32],
],
[
[5i32, 6i32, 5i32, 6i32, 5i32, 6i32],
[7i32, 8i32, 7i32, 8i32, 7i32, 8i32],
],
[
[9i32, 10i32, 9i32, 10i32, 9i32, 10i32],
[11i32, 12i32, 11i32, 12i32, 11i32, 12i32],
],
[
[13i32, 14i32, 13i32, 14i32, 13i32, 14i32],
[15i32, 16i32, 15i32, 16i32, 15i32, 16i32],
],
]);
output.into_data().assert_eq(&expected, false);
}
#[test]
fn should_support_bool_repeat_on_dims_larger_than_1() {
let data = TensorData::from([
[[false, true], [true, false]],
[[true, true], [false, false]],
]);
let tensor = Tensor::<TestBackend, 3, Bool>::from_data(data, &Default::default());
let output = tensor.repeat_dim(1, 2);
let expected = TensorData::from([
[[false, true], [true, false], [false, true], [true, false]],
[[true, true], [false, false], [true, true], [false, false]],
]);
output.into_data().assert_eq(&expected, true);
}
}

View File

@ -73,7 +73,7 @@ impl<B: Backend> TextClassificationModel<B> {
// Calculate token and position embeddings, and combine them
let index_positions = Tensor::arange(0..seq_length as i64, device)
.reshape([1, seq_length])
.repeat(0, batch_size);
.repeat_dim(0, batch_size);
let embedding_positions = self.embedding_pos.forward(index_positions);
let embedding_tokens = self.embedding_token.forward(tokens);
let embedding = (embedding_positions + embedding_tokens) / 2;
@ -113,7 +113,7 @@ impl<B: Backend> TextClassificationModel<B> {
// Calculate token and position embeddings, and combine them
let index_positions = Tensor::arange(0..seq_length as i64, device)
.reshape([1, seq_length])
.repeat(0, batch_size);
.repeat_dim(0, batch_size);
let embedding_positions = self.embedding_pos.forward(index_positions);
let embedding_tokens = self.embedding_token.forward(tokens);
let embedding = (embedding_positions + embedding_tokens) / 2;

View File

@ -64,7 +64,7 @@ impl<B: Backend> TextGenerationModel<B> {
let index_positions = Tensor::arange(0..seq_length as i64, device)
.reshape([1, seq_length])
.repeat(0, batch_size);
.repeat_dim(0, batch_size);
let embedding_positions = self.embedding_pos.forward(index_positions);
let embedding_tokens = self.embedding_token.forward(inputs);