mirror of https://github.com/tracel-ai/burn.git
Repeat operation (#2090)
* renaming repeat to repeat_dim * implementing repeat function * renaming repeat files to repeat_dim * renaming part 2 * renaming part 3 * renaming part 4 * renaming part 5 * adding test file * adding unit test * adding rust book documentation * adding function args doc * fixing tests * changing repeat api to match pytorch equivalent * fixing clippy error
This commit is contained in:
parent
bb13729b20
commit
f7639bd35a
|
@ -132,7 +132,7 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t
|
|||
Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
|
||||
|
||||
| Burn | PyTorch Equivalent |
|
||||
| ------------------------------------- | ------------------------------------ |
|
||||
| ------------------------------------- | ------------------------------------------------------------------------ |
|
||||
| `Tensor::cat(tensors, dim)` | `torch.cat(tensors, dim)` |
|
||||
| `Tensor::empty(shape, device)` | `torch.empty(shape, device=device)` |
|
||||
| `Tensor::from_primitive(primitive)` | N/A |
|
||||
|
@ -155,7 +155,8 @@ Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.
|
|||
| `tensor.not_equal(other)` | `x != y` |
|
||||
| `tensor.permute(axes)` | `tensor.permute(axes)` |
|
||||
| `tensor.movedim(src, dst)` | `tensor.movedim(src, dst)` |
|
||||
| `tensor.repeat(2, 4)` | `tensor.repeat([1, 1, 4])` |
|
||||
| `tensor.repeat_dim(dim, times)` | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])`|
|
||||
| `tensor.repeat(sizes)` | `tensor.repeat(sizes)` |
|
||||
| `tensor.reshape(shape)` | `tensor.view(shape)` |
|
||||
| `tensor.shape()` | `tensor.shape` |
|
||||
| `tensor.slice(ranges)` | `tensor[(*ranges,)]` |
|
||||
|
|
|
@ -132,11 +132,11 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
|
|||
B::bool_expand(tensor, shape)
|
||||
}
|
||||
|
||||
fn bool_repeat<const D: usize>(
|
||||
fn bool_repeat_dim<const D: usize>(
|
||||
tensor: BoolTensor<B, D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
) -> BoolTensor<B, D> {
|
||||
B::bool_repeat(tensor, dim, times)
|
||||
B::bool_repeat_dim(tensor, dim, times)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -162,12 +162,12 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
|
|||
B::int_mean_dim(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_repeat<const D: usize>(
|
||||
fn int_repeat_dim<const D: usize>(
|
||||
tensor: IntTensor<B, D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
) -> IntTensor<B, D> {
|
||||
B::int_repeat(tensor, dim, times)
|
||||
B::int_repeat_dim(tensor, dim, times)
|
||||
}
|
||||
|
||||
fn int_greater<const D: usize>(lhs: IntTensor<B, D>, rhs: IntTensor<B, D>) -> BoolTensor<B, D> {
|
||||
|
|
|
@ -2418,7 +2418,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
B::float_argsort(tensor.primitive, dim, descending)
|
||||
}
|
||||
|
||||
fn float_repeat<const D: usize>(
|
||||
fn float_repeat_dim<const D: usize>(
|
||||
tensor: FloatTensor<Self, D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
|
@ -2437,7 +2437,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
impl<B: Backend, const D: usize> RetroForward for RetroRepeat<B, D> {
|
||||
fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
|
||||
let tensor = states.get_state::<B::FloatTensorPrimitive<D>>(&self.tensor_id);
|
||||
let out = B::float_repeat(tensor, self.dim, self.times);
|
||||
let out = B::float_repeat_dim(tensor, self.dim, self.times);
|
||||
states.save(out_node, out)
|
||||
}
|
||||
}
|
||||
|
@ -2467,9 +2467,11 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
|||
.stateful()
|
||||
{
|
||||
OpsKind::Tracked(prep) => {
|
||||
prep.finish(dim, B::float_repeat(tensor.primitive, dim, times))
|
||||
prep.finish(dim, B::float_repeat_dim(tensor.primitive, dim, times))
|
||||
}
|
||||
OpsKind::UnTracked(prep) => {
|
||||
prep.finish(B::float_repeat_dim(tensor.primitive, dim, times))
|
||||
}
|
||||
OpsKind::UnTracked(prep) => prep.finish(B::float_repeat(tensor.primitive, dim, times)),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -47,7 +47,7 @@ mod permute;
|
|||
mod pow;
|
||||
mod recip;
|
||||
mod relu;
|
||||
mod repeat;
|
||||
mod repeat_dim;
|
||||
mod reshape;
|
||||
mod select;
|
||||
mod sigmoid;
|
||||
|
@ -133,6 +133,6 @@ macro_rules! testgen_all {
|
|||
burn_autodiff::testgen_ad_sign!();
|
||||
burn_autodiff::testgen_ad_expand!();
|
||||
burn_autodiff::testgen_ad_sort!();
|
||||
burn_autodiff::testgen_ad_repeat!();
|
||||
burn_autodiff::testgen_ad_repeat_dim!();
|
||||
};
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
#[burn_tensor_testgen::testgen(ad_repeat)]
|
||||
#[burn_tensor_testgen::testgen(ad_repeat_dim)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::{activation, TensorData};
|
||||
|
@ -12,7 +12,7 @@ mod tests {
|
|||
let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
|
||||
let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();
|
||||
|
||||
let tensor_3 = tensor_2.clone().repeat(1, 3);
|
||||
let tensor_3 = tensor_2.clone().repeat_dim(1, 3);
|
||||
|
||||
let tensor_3 = tensor_1.matmul(tensor_3);
|
||||
let grads = tensor_3.backward();
|
|
@ -94,7 +94,7 @@ mod tests {
|
|||
// burn_tensor::testgen_powf!();
|
||||
|
||||
burn_tensor::testgen_random!();
|
||||
burn_tensor::testgen_repeat!();
|
||||
burn_tensor::testgen_repeat_dim!();
|
||||
burn_tensor::testgen_reshape!();
|
||||
burn_tensor::testgen_select!();
|
||||
burn_tensor::testgen_sin!();
|
||||
|
|
|
@ -19,7 +19,7 @@ pub fn generate_autoregressive_mask<B: Backend>(
|
|||
mask = mask.slice_assign([0..1, i..i + 1, i + 1..seq_length], values);
|
||||
}
|
||||
|
||||
mask = mask.repeat(0, batch_size);
|
||||
mask = mask.repeat_dim(0, batch_size);
|
||||
|
||||
mask.equal_elem(1_i64.elem::<i64>())
|
||||
}
|
||||
|
|
|
@ -152,7 +152,7 @@ impl<B: Backend> CrossEntropyLoss<B> {
|
|||
* weights
|
||||
.clone()
|
||||
.reshape([1, nr_classes])
|
||||
.repeat(0, batch_size);
|
||||
.repeat_dim(0, batch_size);
|
||||
let weights = weights.clone().gather(0, targets);
|
||||
let tensor = Self::apply_mask_2d(tensor, mask);
|
||||
tensor.sum().neg() / weights.sum()
|
||||
|
@ -224,7 +224,7 @@ impl<B: Backend> CrossEntropyLoss<B> {
|
|||
fn apply_mask_2d(mut tensor: Tensor<B, 2>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 2> {
|
||||
if let Some(mask) = mask {
|
||||
let [batch_size, nr_classes] = tensor.dims();
|
||||
tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat(1, nr_classes), 0);
|
||||
tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat_dim(1, nr_classes), 0);
|
||||
}
|
||||
|
||||
tensor
|
||||
|
@ -312,7 +312,7 @@ mod tests {
|
|||
* targets_logits
|
||||
* Tensor::<TestBackend, 1>::from_floats(weights.as_slice(), &device)
|
||||
.unsqueeze()
|
||||
.repeat(0, 4);
|
||||
.repeat_dim(0, 4);
|
||||
let loss_2 = loss_2.sum().neg() / (1. + 2. + 3. + 5.);
|
||||
loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3);
|
||||
}
|
||||
|
|
|
@ -58,7 +58,7 @@ impl RotaryEncodingConfig {
|
|||
.float()
|
||||
.unsqueeze()
|
||||
.transpose()
|
||||
.repeat(1, self.d_model / 2)
|
||||
.repeat_dim(1, self.d_model / 2)
|
||||
* theta_i.unsqueeze();
|
||||
|
||||
// Convert frequency values to complex numbers (polar form)
|
||||
|
@ -71,7 +71,7 @@ impl RotaryEncodingConfig {
|
|||
.reshape([self.max_sequence_length, 2, self.d_model / 2])
|
||||
.transpose()
|
||||
.unsqueeze_dim::<4>(2)
|
||||
.repeat(2, 2)
|
||||
.repeat_dim(2, 2)
|
||||
.reshape([self.max_sequence_length, self.d_model, 2]);
|
||||
|
||||
RotaryEncoding {
|
||||
|
|
|
@ -14,7 +14,7 @@ use burn_tensor::{
|
|||
BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
|
||||
CatOperationDescription, ExpandOperationDescription, FlipOperationDescription,
|
||||
HandleContainer, OperationDescription, PermuteOperationDescription,
|
||||
RepeatOperationDescription, ReshapeDescription, SliceAssignOperationDescription,
|
||||
RepeatDimOperationDescription, ReshapeDescription, SliceAssignOperationDescription,
|
||||
SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription,
|
||||
},
|
||||
Device, Shape,
|
||||
|
@ -575,22 +575,22 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
out
|
||||
}
|
||||
|
||||
fn bool_repeat<const D: usize>(
|
||||
fn bool_repeat_dim<const D: usize>(
|
||||
tensor: BoolTensor<Self, D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
) -> BoolTensor<Self, D> {
|
||||
#[derive(new)]
|
||||
struct RepeatOps<B: FusionBackend, const D: usize> {
|
||||
desc: RepeatOperationDescription,
|
||||
struct RepeatDimOps<B: FusionBackend, const D: usize> {
|
||||
desc: RepeatDimOperationDescription,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatOps<B, D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B, D> {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
|
||||
let tensor = handles.get_bool_tensor::<B, D>(&self.desc.tensor);
|
||||
|
||||
let output = B::bool_repeat::<D>(tensor, self.desc.dim, self.desc.times);
|
||||
let output = B::bool_repeat_dim::<D>(tensor, self.desc.dim, self.desc.times);
|
||||
|
||||
handles.register_bool_tensor::<B, D>(&self.desc.out.id, output);
|
||||
}
|
||||
|
@ -601,7 +601,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
shape[dim] *= times;
|
||||
let out = tensor.client.tensor_uninitialized(shape, DType::Bool);
|
||||
|
||||
let desc = RepeatOperationDescription {
|
||||
let desc = RepeatDimOperationDescription {
|
||||
tensor: tensor.into_description(),
|
||||
dim,
|
||||
times,
|
||||
|
@ -609,8 +609,8 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::BaseBool(BaseOperationDescription::Repeat(desc.clone())),
|
||||
RepeatOps::<B, D>::new(desc),
|
||||
OperationDescription::BaseBool(BaseOperationDescription::RepeatDim(desc.clone())),
|
||||
RepeatDimOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
out
|
||||
|
|
|
@ -1696,22 +1696,22 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
out
|
||||
}
|
||||
|
||||
fn float_repeat<const D: usize>(
|
||||
fn float_repeat_dim<const D: usize>(
|
||||
tensor: FloatTensor<Self, D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
) -> FloatTensor<Self, D> {
|
||||
#[derive(new)]
|
||||
struct RepeatOps<B: FusionBackend, const D: usize> {
|
||||
desc: RepeatOperationDescription,
|
||||
struct RepeatDimOps<B: FusionBackend, const D: usize> {
|
||||
desc: RepeatDimOperationDescription,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatOps<B, D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B, D> {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
|
||||
let tensor = handles.get_float_tensor::<B, D>(&self.desc.tensor);
|
||||
|
||||
let output = B::float_repeat::<D>(tensor, self.desc.dim, self.desc.times);
|
||||
let output = B::float_repeat_dim::<D>(tensor, self.desc.dim, self.desc.times);
|
||||
|
||||
handles.register_float_tensor::<B, D>(&self.desc.out.id, output);
|
||||
}
|
||||
|
@ -1724,7 +1724,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
.client
|
||||
.tensor_uninitialized(shape, B::FloatElem::dtype());
|
||||
|
||||
let desc = RepeatOperationDescription {
|
||||
let desc = RepeatDimOperationDescription {
|
||||
tensor: tensor.into_description(),
|
||||
dim,
|
||||
times,
|
||||
|
@ -1732,8 +1732,8 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::BaseFloat(BaseOperationDescription::Repeat(desc.clone())),
|
||||
RepeatOps::<B, D>::new(desc),
|
||||
OperationDescription::BaseFloat(BaseOperationDescription::RepeatDim(desc.clone())),
|
||||
RepeatDimOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
out
|
||||
|
|
|
@ -1755,22 +1755,22 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
out
|
||||
}
|
||||
|
||||
fn int_repeat<const D: usize>(
|
||||
fn int_repeat_dim<const D: usize>(
|
||||
tensor: IntTensor<Self, D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
) -> IntTensor<Self, D> {
|
||||
#[derive(new)]
|
||||
struct RepeatOps<B: FusionBackend, const D: usize> {
|
||||
desc: RepeatOperationDescription,
|
||||
struct RepeatDimOps<B: FusionBackend, const D: usize> {
|
||||
desc: RepeatDimOperationDescription,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatOps<B, D> {
|
||||
impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B, D> {
|
||||
fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
|
||||
let tensor = handles.get_int_tensor::<B, D>(&self.desc.tensor);
|
||||
|
||||
let output = B::int_repeat::<D>(tensor, self.desc.dim, self.desc.times);
|
||||
let output = B::int_repeat_dim::<D>(tensor, self.desc.dim, self.desc.times);
|
||||
|
||||
handles.register_int_tensor::<B, D>(&self.desc.out.id, output);
|
||||
}
|
||||
|
@ -1783,7 +1783,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
.client
|
||||
.tensor_uninitialized(shape, B::IntElem::dtype());
|
||||
|
||||
let desc = RepeatOperationDescription {
|
||||
let desc = RepeatDimOperationDescription {
|
||||
tensor: tensor.into_description(),
|
||||
dim,
|
||||
times,
|
||||
|
@ -1791,8 +1791,8 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
};
|
||||
out.client.register(
|
||||
vec![stream],
|
||||
OperationDescription::BaseInt(BaseOperationDescription::Repeat(desc.clone())),
|
||||
RepeatOps::<B, D>::new(desc),
|
||||
OperationDescription::BaseInt(BaseOperationDescription::RepeatDim(desc.clone())),
|
||||
RepeatDimOps::<B, D>::new(desc),
|
||||
);
|
||||
|
||||
out
|
||||
|
|
|
@ -848,8 +848,8 @@ impl RelativeOps for BaseOperationDescription {
|
|||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
BaseOperationDescription::Repeat(desc) => {
|
||||
BaseOperationDescription::Repeat(RepeatOperationDescription {
|
||||
BaseOperationDescription::RepeatDim(desc) => {
|
||||
BaseOperationDescription::RepeatDim(RepeatDimOperationDescription {
|
||||
tensor: desc.tensor.to_relative(converter),
|
||||
dim: desc.dim,
|
||||
times: desc.times,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
mod flip;
|
||||
mod gather;
|
||||
mod repeat;
|
||||
mod repeat_dim;
|
||||
mod scatter;
|
||||
mod select;
|
||||
mod select_assign;
|
||||
|
@ -8,7 +8,7 @@ mod slice;
|
|||
mod slice_assign;
|
||||
|
||||
pub use flip::*;
|
||||
pub use repeat::*;
|
||||
pub use repeat_dim::*;
|
||||
pub use select::*;
|
||||
pub use select_assign::*;
|
||||
pub use slice::*;
|
||||
|
|
|
@ -95,7 +95,7 @@ impl<R: JitRuntime, E: JitElement> Kernel for RepeatEagerKernel<R, E> {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) fn repeat<R: JitRuntime, E: JitElement, const D1: usize>(
|
||||
pub(crate) fn repeat_dim<R: JitRuntime, E: JitElement, const D1: usize>(
|
||||
input: JitTensor<R, E, D1>,
|
||||
dim: usize,
|
||||
times: usize,
|
|
@ -94,12 +94,12 @@ where
|
|||
tensor
|
||||
}
|
||||
|
||||
fn bool_repeat<const D: usize>(
|
||||
fn bool_repeat_dim<const D: usize>(
|
||||
tensor: BoolTensor<Self, D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
) -> BoolTensor<Self, D> {
|
||||
kernel::repeat(tensor, dim, times)
|
||||
kernel::repeat_dim(tensor, dim, times)
|
||||
}
|
||||
|
||||
fn bool_permute<const D: usize>(
|
||||
|
|
|
@ -467,12 +467,12 @@ where
|
|||
})
|
||||
}
|
||||
|
||||
fn float_repeat<const D: usize>(
|
||||
fn float_repeat_dim<const D: usize>(
|
||||
tensor: FloatTensor<Self, D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
) -> FloatTensor<Self, D> {
|
||||
kernel::repeat(tensor, dim, times)
|
||||
kernel::repeat_dim(tensor, dim, times)
|
||||
}
|
||||
|
||||
fn float_powf<const D: usize>(
|
||||
|
|
|
@ -318,12 +318,12 @@ where
|
|||
tensor
|
||||
}
|
||||
|
||||
fn int_repeat<const D: usize>(
|
||||
fn int_repeat_dim<const D: usize>(
|
||||
tensor: IntTensor<Self, D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
) -> IntTensor<Self, D> {
|
||||
kernel::repeat(tensor, dim, times)
|
||||
kernel::repeat_dim(tensor, dim, times)
|
||||
}
|
||||
|
||||
fn int_random<const D: usize>(
|
||||
|
|
|
@ -17,7 +17,7 @@ mod max_pool2d;
|
|||
mod max_pool2d_backward;
|
||||
mod normal;
|
||||
mod reduce;
|
||||
mod repeat;
|
||||
mod repeat_dim;
|
||||
mod scatter;
|
||||
mod select;
|
||||
mod select_assign;
|
||||
|
@ -48,7 +48,7 @@ macro_rules! testgen_all {
|
|||
burn_jit::testgen_conv_transpose2d!();
|
||||
burn_jit::testgen_conv_transpose3d!();
|
||||
|
||||
burn_jit::testgen_repeat!();
|
||||
burn_jit::testgen_repeat_dim!();
|
||||
burn_jit::testgen_gather!();
|
||||
burn_jit::testgen_scatter!();
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
#[burn_tensor_testgen::testgen(repeat)]
|
||||
#[burn_tensor_testgen::testgen(repeat_dim)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::{Distribution, Tensor};
|
||||
|
@ -12,8 +12,8 @@ mod tests {
|
|||
let tensor_ref =
|
||||
Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default());
|
||||
|
||||
let actual = tensor.repeat(dim, times);
|
||||
let expected = tensor_ref.repeat(dim, times);
|
||||
let actual = tensor.repeat_dim(dim, times);
|
||||
let expected = tensor_ref.repeat_dim(dim, times);
|
||||
|
||||
expected
|
||||
.into_data()
|
||||
|
@ -29,8 +29,8 @@ mod tests {
|
|||
let tensor_ref =
|
||||
Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default());
|
||||
|
||||
let actual = tensor.repeat(dim, times);
|
||||
let expected = tensor_ref.repeat(dim, times);
|
||||
let actual = tensor.repeat_dim(dim, times);
|
||||
let expected = tensor_ref.repeat_dim(dim, times);
|
||||
|
||||
expected
|
||||
.into_data()
|
||||
|
@ -46,8 +46,8 @@ mod tests {
|
|||
let tensor_ref =
|
||||
Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default());
|
||||
|
||||
let actual = tensor.repeat(dim, times);
|
||||
let expected = tensor_ref.repeat(dim, times);
|
||||
let actual = tensor.repeat_dim(dim, times);
|
||||
let expected = tensor_ref.repeat_dim(dim, times);
|
||||
|
||||
expected
|
||||
.into_data()
|
||||
|
@ -66,8 +66,8 @@ mod tests {
|
|||
let tensor_ref =
|
||||
Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default());
|
||||
|
||||
let actual = tensor.repeat(dim, times);
|
||||
let expected = tensor_ref.repeat(dim, times);
|
||||
let actual = tensor.repeat_dim(dim, times);
|
||||
let expected = tensor_ref.repeat_dim(dim, times);
|
||||
|
||||
expected
|
||||
.into_data()
|
|
@ -33,7 +33,7 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
|
|||
TchTensor::from_existing(tensor.tensor.reshape(shape_tch.dims), tensor.storage)
|
||||
}
|
||||
|
||||
pub fn repeat<const D: usize>(
|
||||
pub fn repeat_dim<const D: usize>(
|
||||
tensor: TchTensor<E, D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
|
|
|
@ -15,12 +15,12 @@ impl<E: TchElement, Q: QuantElement> BoolTensorOps<Self> for LibTorch<E, Q> {
|
|||
tensor.shape()
|
||||
}
|
||||
|
||||
fn bool_repeat<const D: usize>(
|
||||
fn bool_repeat_dim<const D: usize>(
|
||||
tensor: TchTensor<bool, D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
) -> TchTensor<bool, D> {
|
||||
TchOps::repeat(tensor, dim, times)
|
||||
TchOps::repeat_dim(tensor, dim, times)
|
||||
}
|
||||
|
||||
async fn bool_into_data<const D: usize>(tensor: TchTensor<bool, D>) -> TensorData {
|
||||
|
|
|
@ -18,12 +18,12 @@ impl<E: TchElement, Q: QuantElement> IntTensorOps<Self> for LibTorch<E, Q> {
|
|||
tensor.shape()
|
||||
}
|
||||
|
||||
fn int_repeat<const D: usize>(
|
||||
fn int_repeat_dim<const D: usize>(
|
||||
tensor: TchTensor<i64, D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
) -> TchTensor<i64, D> {
|
||||
TchOps::repeat(tensor, dim, times)
|
||||
TchOps::repeat_dim(tensor, dim, times)
|
||||
}
|
||||
|
||||
async fn int_into_data<const D: usize>(tensor: TchTensor<i64, D>) -> TensorData {
|
||||
|
|
|
@ -42,12 +42,12 @@ impl<E: TchElement, Q: QuantElement> FloatTensorOps<Self> for LibTorch<E, Q> {
|
|||
}
|
||||
}
|
||||
|
||||
fn float_repeat<const D: usize>(
|
||||
fn float_repeat_dim<const D: usize>(
|
||||
tensor: TchTensor<E, D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
) -> TchTensor<E, D> {
|
||||
TchOps::repeat(tensor, dim, times)
|
||||
TchOps::repeat_dim(tensor, dim, times)
|
||||
}
|
||||
|
||||
fn float_zeros<const D: usize>(shape: Shape<D>, device: &LibTorchDevice) -> TchTensor<E, D> {
|
||||
|
|
|
@ -189,10 +189,10 @@ pub enum BaseOperationDescription {
|
|||
Equal(BinaryOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [repeat](crate::ops::FloatTensorOps::float_repeat).
|
||||
/// Int => [repeat](crate::ops::IntTensorOps::int_repeat).
|
||||
/// Bool => [repeat](crate::ops::BoolTensorOps::bool_repeat).
|
||||
Repeat(RepeatOperationDescription),
|
||||
/// Float => [repeat dim](crate::ops::FloatTensorOps::float_repeat_dim).
|
||||
/// Int => [repeat dim](crate::ops::IntTensorOps::int_repeat_dim).
|
||||
/// Bool => [repeat dim](crate::ops::BoolTensorOps::bool_repeat_dim).
|
||||
RepeatDim(RepeatDimOperationDescription),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [cat](crate::ops::FloatTensorOps::float_cat).
|
||||
|
@ -627,7 +627,7 @@ pub struct ClampOperationDescription<E> {
|
|||
|
||||
#[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct RepeatOperationDescription {
|
||||
pub struct RepeatDimOperationDescription {
|
||||
pub tensor: TensorDescription,
|
||||
pub dim: usize,
|
||||
pub times: usize,
|
||||
|
@ -1189,7 +1189,7 @@ impl BaseOperationDescription {
|
|||
BaseOperationDescription::Equal(desc) => {
|
||||
vec![&desc.lhs, &desc.rhs, &desc.out]
|
||||
}
|
||||
BaseOperationDescription::Repeat(desc) => {
|
||||
BaseOperationDescription::RepeatDim(desc) => {
|
||||
vec![&desc.tensor, &desc.out]
|
||||
}
|
||||
BaseOperationDescription::Cat(desc) => desc.tensors.iter().collect(),
|
||||
|
|
|
@ -722,8 +722,21 @@ where
|
|||
}
|
||||
|
||||
/// Repeat the tensor along the given dimension.
|
||||
pub fn repeat(self, dim: usize, times: usize) -> Self {
|
||||
Self::new(K::repeat(self.primitive, dim, times))
|
||||
pub fn repeat_dim(self, dim: usize, times: usize) -> Self {
|
||||
Self::new(K::repeat_dim(self.primitive, dim, times))
|
||||
}
|
||||
|
||||
/// Repeat the tensor along the given dimensions.
|
||||
/// # Arguments
|
||||
/// - `sizes`: Borrowed slice of the number of times to repeat each dimension.
|
||||
pub fn repeat(self, sizes: &[usize]) -> Self {
|
||||
let mut tensor = self;
|
||||
for (dim, ×) in sizes.iter().enumerate() {
|
||||
if times > 1 {
|
||||
tensor = tensor.repeat_dim(dim, times);
|
||||
}
|
||||
}
|
||||
tensor
|
||||
}
|
||||
|
||||
/// Applies element-wise equal comparison and returns a boolean tensor.
|
||||
|
@ -1504,9 +1517,9 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
|
|||
/// with static dispatch. It is not designed for direct usage by users, and not recommended to import
|
||||
/// or use this function directly.
|
||||
///
|
||||
/// For repeating a tensor, users should prefer the [Tensor::repeat](Tensor::repeat) function,
|
||||
/// For repeating a tensor, users should prefer the [Tensor::repeat_dim](Tensor::repeat_dim) function,
|
||||
/// which is more high-level and designed for public use.
|
||||
fn repeat<const D: usize>(
|
||||
fn repeat_dim<const D: usize>(
|
||||
tensor: Self::Primitive<D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
|
@ -1763,12 +1776,12 @@ impl<B: Backend> BasicOps<B> for Float {
|
|||
}
|
||||
}
|
||||
|
||||
fn repeat<const D: usize>(
|
||||
fn repeat_dim<const D: usize>(
|
||||
tensor: Self::Primitive<D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
) -> Self::Primitive<D> {
|
||||
TensorPrimitive::Float(B::float_repeat(tensor.tensor(), dim, times))
|
||||
TensorPrimitive::Float(B::float_repeat_dim(tensor.tensor(), dim, times))
|
||||
}
|
||||
|
||||
fn cat<const D: usize>(vectors: Vec<Self::Primitive<D>>, dim: usize) -> Self::Primitive<D> {
|
||||
|
@ -1888,12 +1901,12 @@ impl<B: Backend> BasicOps<B> for Int {
|
|||
B::int_from_data(data, device)
|
||||
}
|
||||
|
||||
fn repeat<const D: usize>(
|
||||
fn repeat_dim<const D: usize>(
|
||||
tensor: Self::Primitive<D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
) -> Self::Primitive<D> {
|
||||
B::int_repeat(tensor, dim, times)
|
||||
B::int_repeat_dim(tensor, dim, times)
|
||||
}
|
||||
|
||||
fn equal<const D: usize>(
|
||||
|
@ -2010,12 +2023,12 @@ impl<B: Backend> BasicOps<B> for Bool {
|
|||
B::bool_from_data(data, device)
|
||||
}
|
||||
|
||||
fn repeat<const D: usize>(
|
||||
fn repeat_dim<const D: usize>(
|
||||
tensor: Self::Primitive<D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
) -> Self::Primitive<D> {
|
||||
B::bool_repeat(tensor, dim, times)
|
||||
B::bool_repeat_dim(tensor, dim, times)
|
||||
}
|
||||
|
||||
fn equal<const D: usize>(
|
||||
|
|
|
@ -46,7 +46,7 @@ pub fn cartesian_grid<B: Backend, S: Into<Shape<D>>, const D: usize, const D2: u
|
|||
if i == dim {
|
||||
continue;
|
||||
}
|
||||
dim_range = dim_range.repeat(i, item);
|
||||
dim_range = dim_range.repeat_dim(i, item);
|
||||
}
|
||||
|
||||
indices.push(dim_range);
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use super::{
|
||||
cat::cat_with_slice_assign, repeat::repeat_with_slice_assign, BoolTensor, Device, FloatTensor,
|
||||
IntTensor,
|
||||
cat::cat_with_slice_assign, repeat_dim::repeat_with_slice_assign, BoolTensor, Device,
|
||||
FloatTensor, IntTensor,
|
||||
};
|
||||
use crate::{
|
||||
argwhere_data, backend::Backend, chunk, narrow, tensor::Shape, Bool, ElementConversion, Tensor,
|
||||
|
@ -157,7 +157,7 @@ pub trait BoolTensorOps<B: Backend> {
|
|||
/// # Returns
|
||||
///
|
||||
/// The tensor with the dimension repeated.
|
||||
fn bool_repeat<const D: usize>(
|
||||
fn bool_repeat_dim<const D: usize>(
|
||||
tensor: BoolTensor<B, D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use super::cat::cat_with_slice_assign;
|
||||
use super::repeat::repeat_with_slice_assign;
|
||||
use super::repeat_dim::repeat_with_slice_assign;
|
||||
use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
|
||||
use crate::cast::ToElement;
|
||||
use crate::{backend::Backend, tensor::Shape, Distribution, ElementConversion, Int, TensorData};
|
||||
|
@ -251,7 +251,7 @@ pub trait IntTensorOps<B: Backend> {
|
|||
/// # Returns
|
||||
///
|
||||
/// The tensor with the given dimension repeated the given number of times.
|
||||
fn int_repeat<const D: usize>(
|
||||
fn int_repeat_dim<const D: usize>(
|
||||
tensor: IntTensor<B, D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
|
|
|
@ -4,7 +4,7 @@ pub mod conv;
|
|||
/// Module with cat operation
|
||||
pub(crate) mod cat;
|
||||
/// Module with repeat operation
|
||||
pub(crate) mod repeat;
|
||||
pub(crate) mod repeat_dim;
|
||||
/// Module with unfold operations.
|
||||
pub(crate) mod unfold;
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use super::cat::cat_with_slice_assign;
|
||||
use super::repeat::repeat_with_slice_assign;
|
||||
use super::repeat_dim::repeat_with_slice_assign;
|
||||
use super::{BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntElem, IntTensor};
|
||||
use crate::backend::BackendBridge;
|
||||
use crate::tensor::cast::ToElement;
|
||||
|
@ -174,7 +174,7 @@ pub trait FloatTensorOps<B: Backend> {
|
|||
/// # Returns
|
||||
///
|
||||
/// The tensor with the given dimension repeated.
|
||||
fn float_repeat<const D: usize>(
|
||||
fn float_repeat_dim<const D: usize>(
|
||||
tensor: FloatTensor<B, D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
|
|
|
@ -349,7 +349,7 @@ mod tests {
|
|||
clone_invariance_test!(
|
||||
unary: Repeat,
|
||||
ops_float: |tensor: TestTensor<2>| {
|
||||
tensor.reshape([1, 32, 32]).repeat(0, 4).reshape([4 * 32, 32])
|
||||
tensor.reshape([1, 32, 32]).repeat_dim(0, 4).reshape([4 * 32, 32])
|
||||
}
|
||||
);
|
||||
clone_invariance_test!(
|
||||
|
@ -633,7 +633,7 @@ mod tests {
|
|||
clone_invariance_test!(
|
||||
unary: Repeat,
|
||||
ops_int: |tensor: TestTensorInt<2>| {
|
||||
tensor.reshape([1, 32, 32]).repeat(0, 4).reshape([4 * 32, 32])
|
||||
tensor.reshape([1, 32, 32]).repeat_dim(0, 4).reshape([4 * 32, 32])
|
||||
}
|
||||
);
|
||||
clone_invariance_test!(
|
||||
|
|
|
@ -74,6 +74,7 @@ macro_rules! testgen_all {
|
|||
burn_tensor::testgen_powf_scalar!();
|
||||
burn_tensor::testgen_random!();
|
||||
burn_tensor::testgen_recip!();
|
||||
burn_tensor::testgen_repeat_dim!();
|
||||
burn_tensor::testgen_repeat!();
|
||||
burn_tensor::testgen_reshape!();
|
||||
burn_tensor::testgen_select!();
|
||||
|
|
|
@ -45,6 +45,7 @@ mod random;
|
|||
mod recip;
|
||||
mod remainder;
|
||||
mod repeat;
|
||||
mod repeat_dim;
|
||||
mod reshape;
|
||||
mod select;
|
||||
mod sign;
|
||||
|
|
|
@ -4,27 +4,27 @@ mod tests {
|
|||
use burn_tensor::{backend::Backend, Bool, Int, Tensor, TensorData};
|
||||
|
||||
#[test]
|
||||
fn should_support_repeat_ops() {
|
||||
let data = TensorData::from([[0.0, 1.0, 2.0]]);
|
||||
fn should_support_repeat_ops_one_dimension() {
|
||||
let data = TensorData::from([[0.0f32, 1.0f32, 2.0f32]]);
|
||||
let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
|
||||
|
||||
let output = tensor.repeat(0, 4);
|
||||
let output = tensor.repeat(&[4, 1, 1]);
|
||||
let expected = TensorData::from([
|
||||
[0.0, 1.0, 2.0],
|
||||
[0.0, 1.0, 2.0],
|
||||
[0.0, 1.0, 2.0],
|
||||
[0.0, 1.0, 2.0],
|
||||
[0.0f32, 1.0f32, 2.0f32],
|
||||
[0.0f32, 1.0f32, 2.0f32],
|
||||
[0.0f32, 1.0f32, 2.0f32],
|
||||
[0.0f32, 1.0f32, 2.0f32],
|
||||
]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_bool_repeat_ops() {
|
||||
fn should_support_bool_repeat_ops_one_dimension() {
|
||||
let data = TensorData::from([[true, false, false]]);
|
||||
let tensor = Tensor::<TestBackend, 2, Bool>::from_data(data, &Default::default());
|
||||
|
||||
let output = tensor.repeat(0, 4);
|
||||
let output = tensor.repeat(&[4, 1, 1]);
|
||||
let expected = TensorData::from([
|
||||
[true, false, false],
|
||||
[true, false, false],
|
||||
|
@ -35,70 +35,226 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_int_repeat_ops() {
|
||||
let data = TensorData::from([[0, 1, 2]]);
|
||||
fn should_support_int_repeat_ops_one_dimension() {
|
||||
let data = TensorData::from([[0i32, 1i32, 2i32]]);
|
||||
let tensor = Tensor::<TestBackend, 2, Int>::from_data(data, &Default::default());
|
||||
|
||||
let output = tensor.repeat(0, 4);
|
||||
let expected = TensorData::from([[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]);
|
||||
let output = tensor.repeat(&[4, 1, 1]);
|
||||
let expected = TensorData::from([
|
||||
[0i32, 1i32, 2i32],
|
||||
[0i32, 1i32, 2i32],
|
||||
[0i32, 1i32, 2i32],
|
||||
[0i32, 1i32, 2i32],
|
||||
]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_float_repeat_on_dims_larger_than_1() {
|
||||
fn should_support_float_repeat_repeating_on_many_dimensions() {
|
||||
let data = TensorData::from([
|
||||
[[1.0, 2.0], [3.0, 4.0]],
|
||||
[[5.0, 6.0], [7.0, 8.0]],
|
||||
[[9.0, 10.0], [11.0, 12.0]],
|
||||
[[13.0, 14.0], [15.0, 16.0]],
|
||||
[[1.0f32, 2.0f32], [3.0f32, 4.0f32]],
|
||||
[[5.0f32, 6.0f32], [7.0f32, 8.0f32]],
|
||||
[[9.0f32, 10.0f32], [11.0f32, 12.0f32]],
|
||||
[[13.0f32, 14.0f32], [15.0f32, 16.0f32]],
|
||||
]);
|
||||
let tensor = Tensor::<TestBackend, 3>::from_data(data, &Default::default());
|
||||
|
||||
let output = tensor.repeat(2, 2);
|
||||
let output = tensor.repeat(&[2, 3, 2]);
|
||||
let expected = TensorData::from([
|
||||
[[1.0, 2.0, 1.0, 2.0], [3.0, 4.0, 3.0, 4.0]],
|
||||
[[5.0, 6.0, 5.0, 6.0], [7.0, 8.0, 7.0, 8.0]],
|
||||
[[9.0, 10.0, 9.0, 10.0], [11.0, 12.0, 11.0, 12.0]],
|
||||
[[13.0, 14.0, 13.0, 14.0], [15.0, 16.0, 15.0, 16.0]],
|
||||
[
|
||||
[1.0f32, 2.0f32, 1.0f32, 2.0f32],
|
||||
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
|
||||
[1.0f32, 2.0f32, 1.0f32, 2.0f32],
|
||||
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
|
||||
[1.0f32, 2.0f32, 1.0f32, 2.0f32],
|
||||
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
|
||||
],
|
||||
[
|
||||
[5.0f32, 6.0f32, 5.0f32, 6.0f32],
|
||||
[7.0f32, 8.0f32, 7.0f32, 8.0f32],
|
||||
[5.0f32, 6.0f32, 5.0f32, 6.0f32],
|
||||
[7.0f32, 8.0f32, 7.0f32, 8.0f32],
|
||||
[5.0f32, 6.0f32, 5.0f32, 6.0f32],
|
||||
[7.0f32, 8.0f32, 7.0f32, 8.0f32],
|
||||
],
|
||||
[
|
||||
[9.0f32, 10.0f32, 9.0f32, 10.0f32],
|
||||
[11.0f32, 12.0f32, 11.0f32, 12.0f32],
|
||||
[9.0f32, 10.0f32, 9.0f32, 10.0f32],
|
||||
[11.0f32, 12.0f32, 11.0f32, 12.0f32],
|
||||
[9.0f32, 10.0f32, 9.0f32, 10.0f32],
|
||||
[11.0f32, 12.0f32, 11.0f32, 12.0f32],
|
||||
],
|
||||
[
|
||||
[13.0f32, 14.0f32, 13.0f32, 14.0f32],
|
||||
[15.0f32, 16.0f32, 15.0f32, 16.0f32],
|
||||
[13.0f32, 14.0f32, 13.0f32, 14.0f32],
|
||||
[15.0f32, 16.0f32, 15.0f32, 16.0f32],
|
||||
[13.0f32, 14.0f32, 13.0f32, 14.0f32],
|
||||
[15.0f32, 16.0f32, 15.0f32, 16.0f32],
|
||||
],
|
||||
[
|
||||
[1.0f32, 2.0f32, 1.0f32, 2.0f32],
|
||||
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
|
||||
[1.0f32, 2.0f32, 1.0f32, 2.0f32],
|
||||
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
|
||||
[1.0f32, 2.0f32, 1.0f32, 2.0f32],
|
||||
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
|
||||
],
|
||||
[
|
||||
[5.0f32, 6.0f32, 5.0f32, 6.0f32],
|
||||
[7.0f32, 8.0f32, 7.0f32, 8.0f32],
|
||||
[5.0f32, 6.0f32, 5.0f32, 6.0f32],
|
||||
[7.0f32, 8.0f32, 7.0f32, 8.0f32],
|
||||
[5.0f32, 6.0f32, 5.0f32, 6.0f32],
|
||||
[7.0f32, 8.0f32, 7.0f32, 8.0f32],
|
||||
],
|
||||
[
|
||||
[9.0f32, 10.0f32, 9.0f32, 10.0f32],
|
||||
[11.0f32, 12.0f32, 11.0f32, 12.0f32],
|
||||
[9.0f32, 10.0f32, 9.0f32, 10.0f32],
|
||||
[11.0f32, 12.0f32, 11.0f32, 12.0f32],
|
||||
[9.0f32, 10.0f32, 9.0f32, 10.0f32],
|
||||
[11.0f32, 12.0f32, 11.0f32, 12.0f32],
|
||||
],
|
||||
[
|
||||
[13.0f32, 14.0f32, 13.0f32, 14.0f32],
|
||||
[15.0f32, 16.0f32, 15.0f32, 16.0f32],
|
||||
[13.0f32, 14.0f32, 13.0f32, 14.0f32],
|
||||
[15.0f32, 16.0f32, 15.0f32, 16.0f32],
|
||||
[13.0f32, 14.0f32, 13.0f32, 14.0f32],
|
||||
[15.0f32, 16.0f32, 15.0f32, 16.0f32],
|
||||
],
|
||||
]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_int_repeat_on_dims_larger_than_1() {
|
||||
fn should_support_int_repeat_on_many_dims() {
|
||||
let data = TensorData::from([
|
||||
[[1, 2], [3, 4]],
|
||||
[[5, 6], [7, 8]],
|
||||
[[9, 10], [11, 12]],
|
||||
[[13, 14], [15, 16]],
|
||||
[[1i32, 2i32], [3i32, 4i32]],
|
||||
[[5i32, 6i32], [7i32, 8i32]],
|
||||
[[9i32, 10i32], [11i32, 12i32]],
|
||||
[[13i32, 14i32], [15i32, 16i32]],
|
||||
]);
|
||||
let tensor = Tensor::<TestBackend, 3, Int>::from_data(data, &Default::default());
|
||||
|
||||
let output = tensor.repeat(2, 3);
|
||||
let output = tensor.repeat(&[2, 3, 2]);
|
||||
|
||||
let expected = TensorData::from([
|
||||
[[1, 2, 1, 2, 1, 2], [3, 4, 3, 4, 3, 4]],
|
||||
[[5, 6, 5, 6, 5, 6], [7, 8, 7, 8, 7, 8]],
|
||||
[[9, 10, 9, 10, 9, 10], [11, 12, 11, 12, 11, 12]],
|
||||
[[13, 14, 13, 14, 13, 14], [15, 16, 15, 16, 15, 16]],
|
||||
[
|
||||
[1i32, 2i32, 1i32, 2i32],
|
||||
[3i32, 4i32, 3i32, 4i32],
|
||||
[1i32, 2i32, 1i32, 2i32],
|
||||
[3i32, 4i32, 3i32, 4i32],
|
||||
[1i32, 2i32, 1i32, 2i32],
|
||||
[3i32, 4i32, 3i32, 4i32],
|
||||
],
|
||||
[
|
||||
[5i32, 6i32, 5i32, 6i32],
|
||||
[7i32, 8i32, 7i32, 8i32],
|
||||
[5i32, 6i32, 5i32, 6i32],
|
||||
[7i32, 8i32, 7i32, 8i32],
|
||||
[5i32, 6i32, 5i32, 6i32],
|
||||
[7i32, 8i32, 7i32, 8i32],
|
||||
],
|
||||
[
|
||||
[9i32, 10i32, 9i32, 10i32],
|
||||
[11i32, 12i32, 11i32, 12i32],
|
||||
[9i32, 10i32, 9i32, 10i32],
|
||||
[11i32, 12i32, 11i32, 12i32],
|
||||
[9i32, 10i32, 9i32, 10i32],
|
||||
[11i32, 12i32, 11i32, 12i32],
|
||||
],
|
||||
[
|
||||
[13i32, 14i32, 13i32, 14i32],
|
||||
[15i32, 16i32, 15i32, 16i32],
|
||||
[13i32, 14i32, 13i32, 14i32],
|
||||
[15i32, 16i32, 15i32, 16i32],
|
||||
[13i32, 14i32, 13i32, 14i32],
|
||||
[15i32, 16i32, 15i32, 16i32],
|
||||
],
|
||||
[
|
||||
[1i32, 2i32, 1i32, 2i32],
|
||||
[3i32, 4i32, 3i32, 4i32],
|
||||
[1i32, 2i32, 1i32, 2i32],
|
||||
[3i32, 4i32, 3i32, 4i32],
|
||||
[1i32, 2i32, 1i32, 2i32],
|
||||
[3i32, 4i32, 3i32, 4i32],
|
||||
],
|
||||
[
|
||||
[5i32, 6i32, 5i32, 6i32],
|
||||
[7i32, 8i32, 7i32, 8i32],
|
||||
[5i32, 6i32, 5i32, 6i32],
|
||||
[7i32, 8i32, 7i32, 8i32],
|
||||
[5i32, 6i32, 5i32, 6i32],
|
||||
[7i32, 8i32, 7i32, 8i32],
|
||||
],
|
||||
[
|
||||
[9i32, 10i32, 9i32, 10i32],
|
||||
[11i32, 12i32, 11i32, 12i32],
|
||||
[9i32, 10i32, 9i32, 10i32],
|
||||
[11i32, 12i32, 11i32, 12i32],
|
||||
[9i32, 10i32, 9i32, 10i32],
|
||||
[11i32, 12i32, 11i32, 12i32],
|
||||
],
|
||||
[
|
||||
[13i32, 14i32, 13i32, 14i32],
|
||||
[15i32, 16i32, 15i32, 16i32],
|
||||
[13i32, 14i32, 13i32, 14i32],
|
||||
[15i32, 16i32, 15i32, 16i32],
|
||||
[13i32, 14i32, 13i32, 14i32],
|
||||
[15i32, 16i32, 15i32, 16i32],
|
||||
],
|
||||
]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_bool_repeat_on_dims_larger_than_1() {
|
||||
fn should_support_bool_repeat_on_many_dimension() {
|
||||
let data = TensorData::from([
|
||||
[[false, true], [true, false]],
|
||||
[[true, true], [false, false]],
|
||||
]);
|
||||
let tensor = Tensor::<TestBackend, 3, Bool>::from_data(data, &Default::default());
|
||||
|
||||
let output = tensor.repeat(1, 2);
|
||||
let output = tensor.repeat(&[2, 3, 2]);
|
||||
let expected = TensorData::from([
|
||||
[[false, true], [true, false], [false, true], [true, false]],
|
||||
[[true, true], [false, false], [true, true], [false, false]],
|
||||
[
|
||||
[false, true, false, true],
|
||||
[true, false, true, false],
|
||||
[false, true, false, true],
|
||||
[true, false, true, false],
|
||||
[false, true, false, true],
|
||||
[true, false, true, false],
|
||||
],
|
||||
[
|
||||
[true, true, true, true],
|
||||
[false, false, false, false],
|
||||
[true, true, true, true],
|
||||
[false, false, false, false],
|
||||
[true, true, true, true],
|
||||
[false, false, false, false],
|
||||
],
|
||||
[
|
||||
[false, true, false, true],
|
||||
[true, false, true, false],
|
||||
[false, true, false, true],
|
||||
[true, false, true, false],
|
||||
[false, true, false, true],
|
||||
[true, false, true, false],
|
||||
],
|
||||
[
|
||||
[true, true, true, true],
|
||||
[false, false, false, false],
|
||||
[true, true, true, true],
|
||||
[false, false, false, false],
|
||||
[true, true, true, true],
|
||||
[false, false, false, false],
|
||||
],
|
||||
]);
|
||||
|
||||
output.into_data().assert_eq(&expected, true);
|
||||
|
|
|
@ -0,0 +1,130 @@
|
|||
#[burn_tensor_testgen::testgen(repeat_dim)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::{backend::Backend, Bool, Int, Tensor, TensorData};
|
||||
|
||||
#[test]
|
||||
fn should_support_repeat_ops() {
|
||||
let data = TensorData::from([[0.0f64, 1.0f64, 2.0f64]]);
|
||||
let tensor = Tensor::<TestBackend, 2>::from_data(data.clone(), &Default::default());
|
||||
|
||||
let output = tensor.repeat_dim(0, 4);
|
||||
let expected = TensorData::from([
|
||||
[0.0f32, 1.0f32, 2.0f32],
|
||||
[0.0f32, 1.0f32, 2.0f32],
|
||||
[0.0f32, 1.0f32, 2.0f32],
|
||||
[0.0f32, 1.0f32, 2.0f32],
|
||||
]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_bool_repeat_ops() {
|
||||
let data = TensorData::from([[true, false, false]]);
|
||||
let tensor = Tensor::<TestBackend, 2, Bool>::from_data(data, &Default::default());
|
||||
|
||||
let output = tensor.repeat_dim(0, 4);
|
||||
let expected = TensorData::from([
|
||||
[true, false, false],
|
||||
[true, false, false],
|
||||
[true, false, false],
|
||||
[true, false, false],
|
||||
]);
|
||||
output.into_data().assert_eq(&expected, true);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_int_repeat_ops() {
|
||||
let data = TensorData::from([[0, 1, 2]]);
|
||||
let tensor = Tensor::<TestBackend, 2, Int>::from_data(data, &Default::default());
|
||||
|
||||
let output = tensor.repeat_dim(0, 4);
|
||||
let expected = TensorData::from([[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_float_repeat_on_dims_larger_than_1() {
|
||||
let data = TensorData::from([
|
||||
[[1.0f32, 2.0f32], [3.0f32, 4.0f32]],
|
||||
[[5.0f32, 6.0f32], [7.0f32, 8.0f32]],
|
||||
[[9.0f32, 10.0f32], [11.0f32, 12.0f32]],
|
||||
[[13.0f32, 14.0f32], [15.0f32, 16.0f32]],
|
||||
]);
|
||||
let tensor = Tensor::<TestBackend, 3>::from_data(data, &Default::default());
|
||||
|
||||
let output = tensor.repeat_dim(2, 2);
|
||||
let expected = TensorData::from([
|
||||
[
|
||||
[1.0f32, 2.0f32, 1.0f32, 2.0f32],
|
||||
[3.0f32, 4.0f32, 3.0f32, 4.0f32],
|
||||
],
|
||||
[
|
||||
[5.0f32, 6.0f32, 5.0f32, 6.0f32],
|
||||
[7.0f32, 8.0f32, 7.0f32, 8.0f32],
|
||||
],
|
||||
[
|
||||
[9.0f32, 10.0f32, 9.0f32, 10.0f32],
|
||||
[11.0f32, 12.0f32, 11.0f32, 12.0f32],
|
||||
],
|
||||
[
|
||||
[13.0f32, 14.0f32, 13.0f32, 14.0f32],
|
||||
[15.0f32, 16.0f32, 15.0f32, 16.0f32],
|
||||
],
|
||||
]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_int_repeat_on_dims_larger_than_1() {
|
||||
let data = TensorData::from([
|
||||
[[1i32, 2i32], [3i32, 4i32]],
|
||||
[[5i32, 6i32], [7i32, 8i32]],
|
||||
[[9i32, 10i32], [11i32, 12i32]],
|
||||
[[13i32, 14i32], [15i32, 16i32]],
|
||||
]);
|
||||
let tensor = Tensor::<TestBackend, 3, Int>::from_data(data, &Default::default());
|
||||
|
||||
let output = tensor.repeat_dim(2, 3);
|
||||
let expected = TensorData::from([
|
||||
[
|
||||
[1i32, 2i32, 1i32, 2i32, 1i32, 2i32],
|
||||
[3i32, 4i32, 3i32, 4i32, 3i32, 4i32],
|
||||
],
|
||||
[
|
||||
[5i32, 6i32, 5i32, 6i32, 5i32, 6i32],
|
||||
[7i32, 8i32, 7i32, 8i32, 7i32, 8i32],
|
||||
],
|
||||
[
|
||||
[9i32, 10i32, 9i32, 10i32, 9i32, 10i32],
|
||||
[11i32, 12i32, 11i32, 12i32, 11i32, 12i32],
|
||||
],
|
||||
[
|
||||
[13i32, 14i32, 13i32, 14i32, 13i32, 14i32],
|
||||
[15i32, 16i32, 15i32, 16i32, 15i32, 16i32],
|
||||
],
|
||||
]);
|
||||
|
||||
output.into_data().assert_eq(&expected, false);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_bool_repeat_on_dims_larger_than_1() {
|
||||
let data = TensorData::from([
|
||||
[[false, true], [true, false]],
|
||||
[[true, true], [false, false]],
|
||||
]);
|
||||
let tensor = Tensor::<TestBackend, 3, Bool>::from_data(data, &Default::default());
|
||||
|
||||
let output = tensor.repeat_dim(1, 2);
|
||||
let expected = TensorData::from([
|
||||
[[false, true], [true, false], [false, true], [true, false]],
|
||||
[[true, true], [false, false], [true, true], [false, false]],
|
||||
]);
|
||||
|
||||
output.into_data().assert_eq(&expected, true);
|
||||
}
|
||||
}
|
|
@ -73,7 +73,7 @@ impl<B: Backend> TextClassificationModel<B> {
|
|||
// Calculate token and position embeddings, and combine them
|
||||
let index_positions = Tensor::arange(0..seq_length as i64, device)
|
||||
.reshape([1, seq_length])
|
||||
.repeat(0, batch_size);
|
||||
.repeat_dim(0, batch_size);
|
||||
let embedding_positions = self.embedding_pos.forward(index_positions);
|
||||
let embedding_tokens = self.embedding_token.forward(tokens);
|
||||
let embedding = (embedding_positions + embedding_tokens) / 2;
|
||||
|
@ -113,7 +113,7 @@ impl<B: Backend> TextClassificationModel<B> {
|
|||
// Calculate token and position embeddings, and combine them
|
||||
let index_positions = Tensor::arange(0..seq_length as i64, device)
|
||||
.reshape([1, seq_length])
|
||||
.repeat(0, batch_size);
|
||||
.repeat_dim(0, batch_size);
|
||||
let embedding_positions = self.embedding_pos.forward(index_positions);
|
||||
let embedding_tokens = self.embedding_token.forward(tokens);
|
||||
let embedding = (embedding_positions + embedding_tokens) / 2;
|
||||
|
|
|
@ -64,7 +64,7 @@ impl<B: Backend> TextGenerationModel<B> {
|
|||
|
||||
let index_positions = Tensor::arange(0..seq_length as i64, device)
|
||||
.reshape([1, seq_length])
|
||||
.repeat(0, batch_size);
|
||||
.repeat_dim(0, batch_size);
|
||||
|
||||
let embedding_positions = self.embedding_pos.forward(index_positions);
|
||||
let embedding_tokens = self.embedding_token.forward(inputs);
|
||||
|
|
Loading…
Reference in New Issue