mirror of https://github.com/tracel-ai/burn.git
Repeat operation (#2090)
* renaming repeat to repeat_dim
* implementing repeat function
* renaming repeat files to repeat_dim
* renaming part 2
* renaming part 3
* renaming part 4
* renaming part 5
* adding test file
* adding unit test
* adding rust book documentation
* adding function args doc
* fixing tests
* changing repeat api to match pytorch equivalent
* fixing clippy error
This commit is contained in:
parent bb13729b20
commit f7639bd35a
@@ -131,40 +131,41 @@ for the sake of simplicity, we ignore type signatures. For more details, refer t
 Those operations are available for all tensor kinds: `Int`, `Float`, and `Bool`.

 | Burn                                  | PyTorch Equivalent                   |
-| ------------------------------------- | ------------------------------------ |
+| ------------------------------------- | ------------------------------------------------------------------------ |
 | `Tensor::cat(tensors, dim)`           | `torch.cat(tensors, dim)`            |
 | `Tensor::empty(shape, device)`        | `torch.empty(shape, device=device)`  |
 | `Tensor::from_primitive(primitive)`   | N/A                                  |
 | `Tensor::stack(tensors, dim)`         | `torch.stack(tensors, dim)`          |
 | `tensor.all()`                        | `tensor.all()`                       |
 | `tensor.all_dim(dim)`                 | `tensor.all(dim)`                    |
 | `tensor.any()`                        | `tensor.any()`                       |
 | `tensor.any_dim(dim)`                 | `tensor.any(dim)`                    |
 | `tensor.chunk(num_chunks, dim)`       | `tensor.chunk(num_chunks, dim)`      |
 | `tensor.device()`                     | `tensor.device`                      |
 | `tensor.dims()`                       | `tensor.size()`                      |
 | `tensor.equal(other)`                 | `x == y`                             |
 | `tensor.expand(shape)`                | `tensor.expand(shape)`               |
 | `tensor.flatten(start_dim, end_dim)`  | `tensor.flatten(start_dim, end_dim)` |
 | `tensor.flip(axes)`                   | `tensor.flip(axes)`                  |
 | `tensor.into_data()`                  | N/A                                  |
 | `tensor.into_primitive()`             | N/A                                  |
 | `tensor.into_scalar()`                | `tensor.item()`                      |
 | `tensor.narrow(dim, start, length)`   | `tensor.narrow(dim, start, length)`  |
 | `tensor.not_equal(other)`             | `x != y`                             |
 | `tensor.permute(axes)`                | `tensor.permute(axes)`               |
 | `tensor.movedim(src, dst)`            | `tensor.movedim(src, dst)`           |
-| `tensor.repeat(2, 4)`                 | `tensor.repeat([1, 1, 4])`           |
+| `tensor.repeat_dim(dim, times)`       | `tensor.repeat(*[times if i == dim else 1 for i in range(tensor.dim())])` |
+| `tensor.repeat(sizes)`                | `tensor.repeat(sizes)`               |
 | `tensor.reshape(shape)`               | `tensor.view(shape)`                 |
 | `tensor.shape()`                      | `tensor.shape`                       |
 | `tensor.slice(ranges)`                | `tensor[(*ranges,)]`                 |
 | `tensor.slice_assign(ranges, values)` | `tensor[(*ranges,)] = values`        |
 | `tensor.squeeze(dim)`                 | `tensor.squeeze(dim)`                |
 | `tensor.to_data()`                    | N/A                                  |
 | `tensor.to_device(device)`            | `tensor.to(device)`                  |
 | `tensor.unsqueeze()`                  | `tensor.unsqueeze(0)`                |
 | `tensor.unsqueeze_dim(dim)`           | `tensor.unsqueeze(dim)`              |

 ### Numeric Operations
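The new API splits the old two-argument `repeat` into `repeat_dim(dim, times)` for a single dimension and a PyTorch-style `repeat(sizes)` for all dimensions at once. A minimal sketch of the difference, assuming a generic backend `B` and a `device` in scope (shapes and values are hypothetical, not from this diff):

```rust
// Hedged sketch contrasting the two calls on a [2, 3] tensor.
let tensor = Tensor::<B, 2>::from_floats([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], &device);

// Repeat a single dimension: [2, 3] -> [4, 3].
let by_dim = tensor.clone().repeat_dim(0, 2);

// Repeat every dimension at once, mirroring PyTorch's `tensor.repeat(2, 3)`:
// [2, 3] -> [4, 9].
let by_sizes = tensor.repeat(&[2, 3]);
```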
@@ -132,11 +132,11 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
         B::bool_expand(tensor, shape)
     }

-    fn bool_repeat<const D: usize>(
+    fn bool_repeat_dim<const D: usize>(
         tensor: BoolTensor<B, D>,
         dim: usize,
         times: usize,
     ) -> BoolTensor<B, D> {
-        B::bool_repeat(tensor, dim, times)
+        B::bool_repeat_dim(tensor, dim, times)
     }
 }

@@ -162,12 +162,12 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
         B::int_mean_dim(tensor, dim)
     }

-    fn int_repeat<const D: usize>(
+    fn int_repeat_dim<const D: usize>(
         tensor: IntTensor<B, D>,
         dim: usize,
         times: usize,
     ) -> IntTensor<B, D> {
-        B::int_repeat(tensor, dim, times)
+        B::int_repeat_dim(tensor, dim, times)
     }

     fn int_greater<const D: usize>(lhs: IntTensor<B, D>, rhs: IntTensor<B, D>) -> BoolTensor<B, D> {
@@ -2418,7 +2418,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
         B::float_argsort(tensor.primitive, dim, descending)
     }

-    fn float_repeat<const D: usize>(
+    fn float_repeat_dim<const D: usize>(
         tensor: FloatTensor<Self, D>,
         dim: usize,
         times: usize,

@@ -2437,7 +2437,7 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
 impl<B: Backend, const D: usize> RetroForward for RetroRepeat<B, D> {
     fn forward(&self, states: &mut BackwardStates, out_node: NodeID) {
         let tensor = states.get_state::<B::FloatTensorPrimitive<D>>(&self.tensor_id);
-        let out = B::float_repeat(tensor, self.dim, self.times);
+        let out = B::float_repeat_dim(tensor, self.dim, self.times);
         states.save(out_node, out)
     }
 }

@@ -2467,9 +2467,11 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
             .stateful()
         {
             OpsKind::Tracked(prep) => {
-                prep.finish(dim, B::float_repeat(tensor.primitive, dim, times))
+                prep.finish(dim, B::float_repeat_dim(tensor.primitive, dim, times))
             }
-            OpsKind::UnTracked(prep) => prep.finish(B::float_repeat(tensor.primitive, dim, times)),
+            OpsKind::UnTracked(prep) => {
+                prep.finish(B::float_repeat_dim(tensor.primitive, dim, times))
+            }
         }
     }
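On the tracked path, `prep.finish(dim, ...)` stores the repeated dimension for the backward pass. Mathematically, the gradient of a repeat is the element-wise sum over the repeated copies. A CPU illustration of that reduction for a repeat along the leading axis of a flattened buffer (a hypothetical helper for intuition, not the Autodiff implementation):

```rust
// Hedged reference: if `out = repeat_dim(x, 0, times)` stacks `times` copies of `x`,
// then `grad_x` is the sum of the corresponding slices of `grad_out`.
fn repeat_dim0_backward(grad_out: &[f32], input_len: usize, times: usize) -> Vec<f32> {
    let mut grad_in = vec![0.0f32; input_len];
    for t in 0..times {
        for i in 0..input_len {
            grad_in[i] += grad_out[t * input_len + i];
        }
    }
    grad_in
}
```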
@@ -47,7 +47,7 @@ mod permute;
 mod pow;
 mod recip;
 mod relu;
-mod repeat;
+mod repeat_dim;
 mod reshape;
 mod select;
 mod sigmoid;

@@ -133,6 +133,6 @@ macro_rules! testgen_all {
         burn_autodiff::testgen_ad_sign!();
         burn_autodiff::testgen_ad_expand!();
         burn_autodiff::testgen_ad_sort!();
-        burn_autodiff::testgen_ad_repeat!();
+        burn_autodiff::testgen_ad_repeat_dim!();
     };
 }
@@ -1,4 +1,4 @@
-#[burn_tensor_testgen::testgen(ad_repeat)]
+#[burn_tensor_testgen::testgen(ad_repeat_dim)]
 mod tests {
     use super::*;
     use burn_tensor::{activation, TensorData};

@@ -12,7 +12,7 @@ mod tests {
         let tensor_1 = TestAutodiffTensor::<2>::from_data(data_1, &device).require_grad();
         let tensor_2 = TestAutodiffTensor::from_data(data_2, &device).require_grad();

-        let tensor_3 = tensor_2.clone().repeat(1, 3);
+        let tensor_3 = tensor_2.clone().repeat_dim(1, 3);

         let tensor_3 = tensor_1.matmul(tensor_3);
         let grads = tensor_3.backward();
@@ -94,7 +94,7 @@ mod tests {
     // burn_tensor::testgen_powf!();

     burn_tensor::testgen_random!();
-    burn_tensor::testgen_repeat!();
+    burn_tensor::testgen_repeat_dim!();
     burn_tensor::testgen_reshape!();
     burn_tensor::testgen_select!();
     burn_tensor::testgen_sin!();
@@ -19,7 +19,7 @@ pub fn generate_autoregressive_mask<B: Backend>(
         mask = mask.slice_assign([0..1, i..i + 1, i + 1..seq_length], values);
     }

-    mask = mask.repeat(0, batch_size);
+    mask = mask.repeat_dim(0, batch_size);

     mask.equal_elem(1_i64.elem::<i64>())
 }
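Here `repeat_dim(0, batch_size)` broadcasts the single `[1, seq_length, seq_length]` causal mask to every batch element. A shape-only sketch with hypothetical sizes (not part of the diff):

```rust
// Hedged sketch: hypothetical batch_size = 8, seq_length = 128.
let mask: Tensor<B, 3, Int> = Tensor::zeros([1, 128, 128], &device);
// ... strictly-upper-triangular entries set to 1 via slice_assign, as in the loop above ...
let mask = mask.repeat_dim(0, 8);                // [1, 128, 128] -> [8, 128, 128]
let mask = mask.equal_elem(1_i64.elem::<i64>()); // Bool: true where attention is masked
```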
@@ -152,7 +152,7 @@ impl<B: Backend> CrossEntropyLoss<B> {
             * weights
                 .clone()
                 .reshape([1, nr_classes])
-                .repeat(0, batch_size);
+                .repeat_dim(0, batch_size);
         let weights = weights.clone().gather(0, targets);
         let tensor = Self::apply_mask_2d(tensor, mask);
         tensor.sum().neg() / weights.sum()

@@ -224,7 +224,7 @@ impl<B: Backend> CrossEntropyLoss<B> {
     fn apply_mask_2d(mut tensor: Tensor<B, 2>, mask: Option<Tensor<B, 1, Bool>>) -> Tensor<B, 2> {
         if let Some(mask) = mask {
             let [batch_size, nr_classes] = tensor.dims();
-            tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat(1, nr_classes), 0);
+            tensor = tensor.mask_fill(mask.reshape([batch_size, 1]).repeat_dim(1, nr_classes), 0);
         }

         tensor
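Both call sites use the same reshape-then-`repeat_dim` idiom to broadcast a 1-D vector into a matrix. A standalone illustration with hypothetical class weights (values invented for the example):

```rust
// Hedged sketch: a per-class weight vector broadcast over a batch of 2.
let weights = Tensor::<B, 1>::from_floats([0.5, 1.0, 2.0], &device);
let weight_matrix = weights
    .reshape([1, 3])   // [3] -> [1, 3]
    .repeat_dim(0, 2); // [1, 3] -> [2, 3], one identical row per batch element
```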
@@ -312,7 +312,7 @@ mod tests {
             * targets_logits
             * Tensor::<TestBackend, 1>::from_floats(weights.as_slice(), &device)
                 .unsqueeze()
-                .repeat(0, 4);
+                .repeat_dim(0, 4);
         let loss_2 = loss_2.sum().neg() / (1. + 2. + 3. + 5.);
         loss_1.into_data().assert_approx_eq(&loss_2.into_data(), 3);
     }
@@ -58,7 +58,7 @@ impl RotaryEncodingConfig {
             .float()
             .unsqueeze()
             .transpose()
-            .repeat(1, self.d_model / 2)
+            .repeat_dim(1, self.d_model / 2)
             * theta_i.unsqueeze();

         // Convert frequency values to complex numbers (polar form)

@@ -71,7 +71,7 @@ impl RotaryEncodingConfig {
             .reshape([self.max_sequence_length, 2, self.d_model / 2])
             .transpose()
             .unsqueeze_dim::<4>(2)
-            .repeat(2, 2)
+            .repeat_dim(2, 2)
             .reshape([self.max_sequence_length, self.d_model, 2]);

         RotaryEncoding {
@@ -14,7 +14,7 @@ use burn_tensor::{
         BaseOperationDescription, BinaryOperationDescription, BoolOperationDescription,
         CatOperationDescription, ExpandOperationDescription, FlipOperationDescription,
         HandleContainer, OperationDescription, PermuteOperationDescription,
-        RepeatOperationDescription, ReshapeDescription, SliceAssignOperationDescription,
+        RepeatDimOperationDescription, ReshapeDescription, SliceAssignOperationDescription,
         SliceOperationDescription, SwapDimsDescription, UnaryOperationDescription,
     },
     Device, Shape,
@@ -575,22 +575,22 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
         out
     }

-    fn bool_repeat<const D: usize>(
+    fn bool_repeat_dim<const D: usize>(
         tensor: BoolTensor<Self, D>,
         dim: usize,
         times: usize,
     ) -> BoolTensor<Self, D> {
         #[derive(new)]
-        struct RepeatOps<B: FusionBackend, const D: usize> {
-            desc: RepeatOperationDescription,
+        struct RepeatDimOps<B: FusionBackend, const D: usize> {
+            desc: RepeatDimOperationDescription,
             _b: PhantomData<B>,
         }

-        impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatOps<B, D> {
+        impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B, D> {
             fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
                 let tensor = handles.get_bool_tensor::<B, D>(&self.desc.tensor);

-                let output = B::bool_repeat::<D>(tensor, self.desc.dim, self.desc.times);
+                let output = B::bool_repeat_dim::<D>(tensor, self.desc.dim, self.desc.times);

                 handles.register_bool_tensor::<B, D>(&self.desc.out.id, output);
             }

@@ -601,7 +601,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
         shape[dim] *= times;
         let out = tensor.client.tensor_uninitialized(shape, DType::Bool);

-        let desc = RepeatOperationDescription {
+        let desc = RepeatDimOperationDescription {
             tensor: tensor.into_description(),
             dim,
             times,

@@ -609,8 +609,8 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
         };
         out.client.register(
             vec![stream],
-            OperationDescription::BaseBool(BaseOperationDescription::Repeat(desc.clone())),
-            RepeatOps::<B, D>::new(desc),
+            OperationDescription::BaseBool(BaseOperationDescription::RepeatDim(desc.clone())),
+            RepeatDimOps::<B, D>::new(desc),
         );

         out
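The fusion backend does not execute the repeat eagerly: it records a `RepeatDimOperationDescription`, allocates an uninitialized output of the repeated shape, and registers a deferred `Operation` that replays `bool_repeat_dim` on the concrete backend later. The shape bookkeeping is the only arithmetic done up front; a minimal sketch with hypothetical numbers:

```rust
// Hedged sketch of the eager part of the fusion path: only `dim` grows.
let mut shape = vec![2usize, 3, 5]; // hypothetical input shape
let (dim, times) = (1usize, 4usize);
shape[dim] *= times;
assert_eq!(shape, vec![2, 12, 5]);
```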
@@ -1696,22 +1696,22 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
         out
     }

-    fn float_repeat<const D: usize>(
+    fn float_repeat_dim<const D: usize>(
         tensor: FloatTensor<Self, D>,
         dim: usize,
         times: usize,
     ) -> FloatTensor<Self, D> {
         #[derive(new)]
-        struct RepeatOps<B: FusionBackend, const D: usize> {
-            desc: RepeatOperationDescription,
+        struct RepeatDimOps<B: FusionBackend, const D: usize> {
+            desc: RepeatDimOperationDescription,
             _b: PhantomData<B>,
         }

-        impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatOps<B, D> {
+        impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B, D> {
             fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
                 let tensor = handles.get_float_tensor::<B, D>(&self.desc.tensor);

-                let output = B::float_repeat::<D>(tensor, self.desc.dim, self.desc.times);
+                let output = B::float_repeat_dim::<D>(tensor, self.desc.dim, self.desc.times);

                 handles.register_float_tensor::<B, D>(&self.desc.out.id, output);
             }

@@ -1724,7 +1724,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
             .client
             .tensor_uninitialized(shape, B::FloatElem::dtype());

-        let desc = RepeatOperationDescription {
+        let desc = RepeatDimOperationDescription {
             tensor: tensor.into_description(),
             dim,
             times,

@@ -1732,8 +1732,8 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
         };
         out.client.register(
             vec![stream],
-            OperationDescription::BaseFloat(BaseOperationDescription::Repeat(desc.clone())),
-            RepeatOps::<B, D>::new(desc),
+            OperationDescription::BaseFloat(BaseOperationDescription::RepeatDim(desc.clone())),
+            RepeatDimOps::<B, D>::new(desc),
         );

         out
@@ -1755,22 +1755,22 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
         out
     }

-    fn int_repeat<const D: usize>(
+    fn int_repeat_dim<const D: usize>(
         tensor: IntTensor<Self, D>,
         dim: usize,
         times: usize,
     ) -> IntTensor<Self, D> {
         #[derive(new)]
-        struct RepeatOps<B: FusionBackend, const D: usize> {
-            desc: RepeatOperationDescription,
+        struct RepeatDimOps<B: FusionBackend, const D: usize> {
+            desc: RepeatDimOperationDescription,
             _b: PhantomData<B>,
         }

-        impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatOps<B, D> {
+        impl<const D: usize, B: FusionBackend> Operation<B::FusionRuntime> for RepeatDimOps<B, D> {
             fn execute(self: Box<Self>, handles: &mut HandleContainer<B::Handle>) {
                 let tensor = handles.get_int_tensor::<B, D>(&self.desc.tensor);

-                let output = B::int_repeat::<D>(tensor, self.desc.dim, self.desc.times);
+                let output = B::int_repeat_dim::<D>(tensor, self.desc.dim, self.desc.times);

                 handles.register_int_tensor::<B, D>(&self.desc.out.id, output);
             }

@@ -1783,7 +1783,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
             .client
             .tensor_uninitialized(shape, B::IntElem::dtype());

-        let desc = RepeatOperationDescription {
+        let desc = RepeatDimOperationDescription {
             tensor: tensor.into_description(),
             dim,
             times,

@@ -1791,8 +1791,8 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
         };
         out.client.register(
             vec![stream],
-            OperationDescription::BaseInt(BaseOperationDescription::Repeat(desc.clone())),
-            RepeatOps::<B, D>::new(desc),
+            OperationDescription::BaseInt(BaseOperationDescription::RepeatDim(desc.clone())),
+            RepeatDimOps::<B, D>::new(desc),
         );

         out
@@ -848,8 +848,8 @@ impl RelativeOps for BaseOperationDescription {
                     out: desc.out.to_relative(converter),
                 })
             }
-            BaseOperationDescription::Repeat(desc) => {
-                BaseOperationDescription::Repeat(RepeatOperationDescription {
+            BaseOperationDescription::RepeatDim(desc) => {
+                BaseOperationDescription::RepeatDim(RepeatDimOperationDescription {
                     tensor: desc.tensor.to_relative(converter),
                     dim: desc.dim,
                     times: desc.times,
@@ -1,6 +1,6 @@
 mod flip;
 mod gather;
-mod repeat;
+mod repeat_dim;
 mod scatter;
 mod select;
 mod select_assign;

@@ -8,7 +8,7 @@ mod slice;
 mod slice_assign;

 pub use flip::*;
-pub use repeat::*;
+pub use repeat_dim::*;
 pub use select::*;
 pub use select_assign::*;
 pub use slice::*;
@@ -95,7 +95,7 @@ impl<R: JitRuntime, E: JitElement> Kernel for RepeatEagerKernel<R, E> {
     }
 }

-pub(crate) fn repeat<R: JitRuntime, E: JitElement, const D1: usize>(
+pub(crate) fn repeat_dim<R: JitRuntime, E: JitElement, const D1: usize>(
     input: JitTensor<R, E, D1>,
     dim: usize,
     times: usize,
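The body of the JIT kernel is unchanged by this hunk, but the index mapping any repeat-dim kernel computes is worth spelling out: every output coordinate along `dim` folds back into the input with a modulo. A CPU reference sketch under row-major layout (a hypothetical helper for intuition, not the eager kernel itself):

```rust
// Hedged CPU reference for repeat_dim on a row-major buffer:
// out[..., i_dim, ...] = in[..., i_dim % in_shape[dim], ...]
fn repeat_dim_ref(input: &[f32], in_shape: &[usize], dim: usize, times: usize) -> Vec<f32> {
    let mut out_shape = in_shape.to_vec();
    out_shape[dim] *= times;
    let out_len: usize = out_shape.iter().product();
    let mut out = vec![0.0f32; out_len];

    // Row-major strides for a given shape.
    let strides = |shape: &[usize]| {
        let mut s = vec![1usize; shape.len()];
        for i in (0..shape.len().saturating_sub(1)).rev() {
            s[i] = s[i + 1] * shape[i + 1];
        }
        s
    };
    let (in_s, out_s) = (strides(in_shape), strides(&out_shape));

    for o in 0..out_len {
        // Decode the output coordinate, folding the repeated dimension.
        let mut i = 0;
        for d in 0..out_shape.len() {
            let mut coord = (o / out_s[d]) % out_shape[d];
            if d == dim {
                coord %= in_shape[d];
            }
            i += coord * in_s[d];
        }
        out[o] = input[i];
    }
    out
}
```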
@@ -94,12 +94,12 @@ where
         tensor
     }

-    fn bool_repeat<const D: usize>(
+    fn bool_repeat_dim<const D: usize>(
         tensor: BoolTensor<Self, D>,
         dim: usize,
         times: usize,
     ) -> BoolTensor<Self, D> {
-        kernel::repeat(tensor, dim, times)
+        kernel::repeat_dim(tensor, dim, times)
     }

     fn bool_permute<const D: usize>(
@@ -467,12 +467,12 @@ where
         })
     }

-    fn float_repeat<const D: usize>(
+    fn float_repeat_dim<const D: usize>(
         tensor: FloatTensor<Self, D>,
         dim: usize,
         times: usize,
     ) -> FloatTensor<Self, D> {
-        kernel::repeat(tensor, dim, times)
+        kernel::repeat_dim(tensor, dim, times)
     }

     fn float_powf<const D: usize>(
@@ -318,12 +318,12 @@ where
         tensor
     }

-    fn int_repeat<const D: usize>(
+    fn int_repeat_dim<const D: usize>(
         tensor: IntTensor<Self, D>,
         dim: usize,
         times: usize,
     ) -> IntTensor<Self, D> {
-        kernel::repeat(tensor, dim, times)
+        kernel::repeat_dim(tensor, dim, times)
     }

     fn int_random<const D: usize>(
@@ -17,7 +17,7 @@ mod max_pool2d;
 mod max_pool2d_backward;
 mod normal;
 mod reduce;
-mod repeat;
+mod repeat_dim;
 mod scatter;
 mod select;
 mod select_assign;
@@ -48,7 +48,7 @@ macro_rules! testgen_all {
         burn_jit::testgen_conv_transpose2d!();
         burn_jit::testgen_conv_transpose3d!();

-        burn_jit::testgen_repeat!();
+        burn_jit::testgen_repeat_dim!();
         burn_jit::testgen_gather!();
         burn_jit::testgen_scatter!();
@@ -1,4 +1,4 @@
-#[burn_tensor_testgen::testgen(repeat)]
+#[burn_tensor_testgen::testgen(repeat_dim)]
 mod tests {
     use super::*;
     use burn_tensor::{Distribution, Tensor};
@@ -12,8 +12,8 @@ mod tests {
         let tensor_ref =
             Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default());

-        let actual = tensor.repeat(dim, times);
-        let expected = tensor_ref.repeat(dim, times);
+        let actual = tensor.repeat_dim(dim, times);
+        let expected = tensor_ref.repeat_dim(dim, times);

         expected
             .into_data()

@@ -29,8 +29,8 @@ mod tests {
         let tensor_ref =
             Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default());

-        let actual = tensor.repeat(dim, times);
-        let expected = tensor_ref.repeat(dim, times);
+        let actual = tensor.repeat_dim(dim, times);
+        let expected = tensor_ref.repeat_dim(dim, times);

         expected
             .into_data()

@@ -46,8 +46,8 @@ mod tests {
         let tensor_ref =
             Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default());

-        let actual = tensor.repeat(dim, times);
-        let expected = tensor_ref.repeat(dim, times);
+        let actual = tensor.repeat_dim(dim, times);
+        let expected = tensor_ref.repeat_dim(dim, times);

         expected
             .into_data()

@@ -66,8 +66,8 @@ mod tests {
         let tensor_ref =
             Tensor::<ReferenceBackend, 3>::from_data(tensor.to_data(), &Default::default());

-        let actual = tensor.repeat(dim, times);
-        let expected = tensor_ref.repeat(dim, times);
+        let actual = tensor.repeat_dim(dim, times);
+        let expected = tensor_ref.repeat_dim(dim, times);

         expected
             .into_data()
@@ -33,7 +33,7 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
         TchTensor::from_existing(tensor.tensor.reshape(shape_tch.dims), tensor.storage)
     }

-    pub fn repeat<const D: usize>(
+    pub fn repeat_dim<const D: usize>(
         tensor: TchTensor<E, D>,
         dim: usize,
         times: usize,
@@ -15,12 +15,12 @@ impl<E: TchElement, Q: QuantElement> BoolTensorOps<Self> for LibTorch<E, Q> {
         tensor.shape()
     }

-    fn bool_repeat<const D: usize>(
+    fn bool_repeat_dim<const D: usize>(
         tensor: TchTensor<bool, D>,
         dim: usize,
         times: usize,
     ) -> TchTensor<bool, D> {
-        TchOps::repeat(tensor, dim, times)
+        TchOps::repeat_dim(tensor, dim, times)
     }

     async fn bool_into_data<const D: usize>(tensor: TchTensor<bool, D>) -> TensorData {
@@ -18,12 +18,12 @@ impl<E: TchElement, Q: QuantElement> IntTensorOps<Self> for LibTorch<E, Q> {
         tensor.shape()
     }

-    fn int_repeat<const D: usize>(
+    fn int_repeat_dim<const D: usize>(
         tensor: TchTensor<i64, D>,
         dim: usize,
         times: usize,
     ) -> TchTensor<i64, D> {
-        TchOps::repeat(tensor, dim, times)
+        TchOps::repeat_dim(tensor, dim, times)
     }

     async fn int_into_data<const D: usize>(tensor: TchTensor<i64, D>) -> TensorData {
@@ -42,12 +42,12 @@ impl<E: TchElement, Q: QuantElement> FloatTensorOps<Self> for LibTorch<E, Q> {
         }
     }

-    fn float_repeat<const D: usize>(
+    fn float_repeat_dim<const D: usize>(
         tensor: TchTensor<E, D>,
         dim: usize,
         times: usize,
     ) -> TchTensor<E, D> {
-        TchOps::repeat(tensor, dim, times)
+        TchOps::repeat_dim(tensor, dim, times)
     }

     fn float_zeros<const D: usize>(shape: Shape<D>, device: &LibTorchDevice) -> TchTensor<E, D> {
@@ -189,10 +189,10 @@ pub enum BaseOperationDescription {
     Equal(BinaryOperationDescription),
     /// Operation corresponding to:
     ///
-    /// Float => [repeat](crate::ops::FloatTensorOps::float_repeat).
-    /// Int => [repeat](crate::ops::IntTensorOps::int_repeat).
-    /// Bool => [repeat](crate::ops::BoolTensorOps::bool_repeat).
-    Repeat(RepeatOperationDescription),
+    /// Float => [repeat dim](crate::ops::FloatTensorOps::float_repeat_dim).
+    /// Int => [repeat dim](crate::ops::IntTensorOps::int_repeat_dim).
+    /// Bool => [repeat dim](crate::ops::BoolTensorOps::bool_repeat_dim).
+    RepeatDim(RepeatDimOperationDescription),
     /// Operation corresponding to:
     ///
     /// Float => [cat](crate::ops::FloatTensorOps::float_cat).
@@ -627,7 +627,7 @@ pub struct ClampOperationDescription<E> {

 #[derive(Clone, Debug, Hash, PartialEq, Serialize, Deserialize)]
 #[allow(missing_docs)]
-pub struct RepeatOperationDescription {
+pub struct RepeatDimOperationDescription {
     pub tensor: TensorDescription,
     pub dim: usize,
     pub times: usize,
@@ -1189,7 +1189,7 @@ impl BaseOperationDescription {
             BaseOperationDescription::Equal(desc) => {
                 vec![&desc.lhs, &desc.rhs, &desc.out]
             }
-            BaseOperationDescription::Repeat(desc) => {
+            BaseOperationDescription::RepeatDim(desc) => {
                 vec![&desc.tensor, &desc.out]
             }
             BaseOperationDescription::Cat(desc) => desc.tensors.iter().collect(),
@@ -722,8 +722,21 @@ where
     }

     /// Repeat the tensor along the given dimension.
-    pub fn repeat(self, dim: usize, times: usize) -> Self {
-        Self::new(K::repeat(self.primitive, dim, times))
+    pub fn repeat_dim(self, dim: usize, times: usize) -> Self {
+        Self::new(K::repeat_dim(self.primitive, dim, times))
+    }
+
+    /// Repeat the tensor along the given dimensions.
+    ///
+    /// # Arguments
+    ///
+    /// - `sizes`: Borrowed slice of the number of times to repeat each dimension.
+    pub fn repeat(self, sizes: &[usize]) -> Self {
+        let mut tensor = self;
+        for (dim, &times) in sizes.iter().enumerate() {
+            if times > 1 {
+                tensor = tensor.repeat_dim(dim, times);
+            }
+        }
+        tensor
     }

     /// Applies element-wise equal comparison and returns a boolean tensor.
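`repeat` is a fold over `repeat_dim`: each entry of `sizes` repeats one dimension, and entries equal to 1 are skipped since repeating once is a no-op. A short usage sketch with hypothetical shapes (the sketch assumes one entry per existing dimension):

```rust
// Hedged sketch: hypothetical [2, 3] tensor.
let t = Tensor::<B, 2>::ones([2, 3], &device);
// Only dim 0 triggers an actual repeat_dim call; the `1` entry is a no-op.
let r = t.repeat(&[2, 1]); // [2, 3] -> [4, 3]
```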
@@ -1504,9 +1517,9 @@ pub trait BasicOps<B: Backend>: TensorKind<B> {
     /// with static dispatch. It is not designed for direct usage by users, and not recommended to import
     /// or use this function directly.
     ///
-    /// For repeating a tensor, users should prefer the [Tensor::repeat](Tensor::repeat) function,
+    /// For repeating a tensor, users should prefer the [Tensor::repeat_dim](Tensor::repeat_dim) function,
     /// which is more high-level and designed for public use.
-    fn repeat<const D: usize>(
+    fn repeat_dim<const D: usize>(
         tensor: Self::Primitive<D>,
         dim: usize,
         times: usize,
@@ -1763,12 +1776,12 @@ impl<B: Backend> BasicOps<B> for Float {
         }
     }

-    fn repeat<const D: usize>(
+    fn repeat_dim<const D: usize>(
         tensor: Self::Primitive<D>,
         dim: usize,
         times: usize,
     ) -> Self::Primitive<D> {
-        TensorPrimitive::Float(B::float_repeat(tensor.tensor(), dim, times))
+        TensorPrimitive::Float(B::float_repeat_dim(tensor.tensor(), dim, times))
     }

     fn cat<const D: usize>(vectors: Vec<Self::Primitive<D>>, dim: usize) -> Self::Primitive<D> {
@@ -1888,12 +1901,12 @@ impl<B: Backend> BasicOps<B> for Int {
         B::int_from_data(data, device)
     }

-    fn repeat<const D: usize>(
+    fn repeat_dim<const D: usize>(
         tensor: Self::Primitive<D>,
         dim: usize,
         times: usize,
     ) -> Self::Primitive<D> {
-        B::int_repeat(tensor, dim, times)
+        B::int_repeat_dim(tensor, dim, times)
     }

     fn equal<const D: usize>(
@@ -2010,12 +2023,12 @@ impl<B: Backend> BasicOps<B> for Bool {
         B::bool_from_data(data, device)
     }

-    fn repeat<const D: usize>(
+    fn repeat_dim<const D: usize>(
         tensor: Self::Primitive<D>,
         dim: usize,
         times: usize,
     ) -> Self::Primitive<D> {
-        B::bool_repeat(tensor, dim, times)
+        B::bool_repeat_dim(tensor, dim, times)
     }

     fn equal<const D: usize>(
@@ -46,7 +46,7 @@ pub fn cartesian_grid<B: Backend, S: Into<Shape<D>>, const D: usize, const D2: u
             if i == dim {
                 continue;
             }
             dim_range = dim_range.repeat_dim(i, item);
         }

         indices.push(dim_range);
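`cartesian_grid` relies on the same tiling idea as NumPy's `meshgrid`: each axis range gets repeated along every other axis until all index tensors share the full grid shape. A plain-Rust illustration of the 2-D case (an analogy only, not the tensor implementation):

```rust
// Hedged 2-D illustration: row/column index grids for a 2x3 grid, flattened row-major.
let (rows, cols) = (2usize, 3usize);
let mut row_idx = Vec::new();
let mut col_idx = Vec::new();
for r in 0..rows {
    for c in 0..cols {
        row_idx.push(r); // each row index repeated along the column axis
        col_idx.push(c); // the column range repeated along the row axis
    }
}
assert_eq!(row_idx, vec![0, 0, 0, 1, 1, 1]);
assert_eq!(col_idx, vec![0, 1, 2, 0, 1, 2]);
```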
@@ -1,6 +1,6 @@
 use super::{
-    cat::cat_with_slice_assign, repeat::repeat_with_slice_assign, BoolTensor, Device, FloatTensor,
-    IntTensor,
+    cat::cat_with_slice_assign, repeat_dim::repeat_with_slice_assign, BoolTensor, Device,
+    FloatTensor, IntTensor,
 };
 use crate::{
     argwhere_data, backend::Backend, chunk, narrow, tensor::Shape, Bool, ElementConversion, Tensor,
@@ -157,7 +157,7 @@ pub trait BoolTensorOps<B: Backend> {
     /// # Returns
     ///
     /// The tensor with the dimension repeated.
-    fn bool_repeat<const D: usize>(
+    fn bool_repeat_dim<const D: usize>(
         tensor: BoolTensor<B, D>,
         dim: usize,
         times: usize,
@@ -1,5 +1,5 @@
 use super::cat::cat_with_slice_assign;
-use super::repeat::repeat_with_slice_assign;
+use super::repeat_dim::repeat_with_slice_assign;
 use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
 use crate::cast::ToElement;
 use crate::{backend::Backend, tensor::Shape, Distribution, ElementConversion, Int, TensorData};
@@ -251,7 +251,7 @@ pub trait IntTensorOps<B: Backend> {
     /// # Returns
     ///
     /// The tensor with the given dimension repeated the given number of times.
-    fn int_repeat<const D: usize>(
+    fn int_repeat_dim<const D: usize>(
         tensor: IntTensor<B, D>,
         dim: usize,
         times: usize,
@@ -4,7 +4,7 @@ pub mod conv;
 /// Module with cat operation
 pub(crate) mod cat;
 /// Module with repeat operation
-pub(crate) mod repeat;
+pub(crate) mod repeat_dim;
 /// Module with unfold operations.
 pub(crate) mod unfold;
@@ -1,5 +1,5 @@
 use super::cat::cat_with_slice_assign;
-use super::repeat::repeat_with_slice_assign;
+use super::repeat_dim::repeat_with_slice_assign;
 use super::{BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntElem, IntTensor};
 use crate::backend::BackendBridge;
 use crate::tensor::cast::ToElement;
@@ -174,7 +174,7 @@ pub trait FloatTensorOps<B: Backend> {
     /// # Returns
     ///
     /// The tensor with the given dimension repeated.
-    fn float_repeat<const D: usize>(
+    fn float_repeat_dim<const D: usize>(
         tensor: FloatTensor<B, D>,
         dim: usize,
         times: usize,
@@ -349,7 +349,7 @@ mod tests {
     clone_invariance_test!(
         unary: Repeat,
         ops_float: |tensor: TestTensor<2>| {
-            tensor.reshape([1, 32, 32]).repeat(0, 4).reshape([4 * 32, 32])
+            tensor.reshape([1, 32, 32]).repeat_dim(0, 4).reshape([4 * 32, 32])
         }
     );
     clone_invariance_test!(
@@ -633,7 +633,7 @@ mod tests {
     clone_invariance_test!(
         unary: Repeat,
         ops_int: |tensor: TestTensorInt<2>| {
-            tensor.reshape([1, 32, 32]).repeat(0, 4).reshape([4 * 32, 32])
+            tensor.reshape([1, 32, 32]).repeat_dim(0, 4).reshape([4 * 32, 32])
         }
     );
     clone_invariance_test!(
@@ -74,6 +74,7 @@ macro_rules! testgen_all {
         burn_tensor::testgen_powf_scalar!();
         burn_tensor::testgen_random!();
         burn_tensor::testgen_recip!();
+        burn_tensor::testgen_repeat_dim!();
         burn_tensor::testgen_repeat!();
         burn_tensor::testgen_reshape!();
         burn_tensor::testgen_select!();
@@ -45,6 +45,7 @@ mod random;
 mod recip;
 mod remainder;
 mod repeat;
+mod repeat_dim;
 mod reshape;
 mod select;
 mod sign;
@@ -4,27 +4,27 @@ mod tests {
     use burn_tensor::{backend::Backend, Bool, Int, Tensor, TensorData};

     #[test]
-    fn should_support_repeat_ops() {
-        let data = TensorData::from([[0.0, 1.0, 2.0]]);
+    fn should_support_repeat_ops_one_dimension() {
+        let data = TensorData::from([[0.0f32, 1.0f32, 2.0f32]]);
         let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());

-        let output = tensor.repeat(0, 4);
+        let output = tensor.repeat(&[4, 1, 1]);
         let expected = TensorData::from([
-            [0.0, 1.0, 2.0],
-            [0.0, 1.0, 2.0],
-            [0.0, 1.0, 2.0],
-            [0.0, 1.0, 2.0],
+            [0.0f32, 1.0f32, 2.0f32],
+            [0.0f32, 1.0f32, 2.0f32],
+            [0.0f32, 1.0f32, 2.0f32],
+            [0.0f32, 1.0f32, 2.0f32],
         ]);

         output.into_data().assert_eq(&expected, false);
     }

     #[test]
-    fn should_support_bool_repeat_ops() {
+    fn should_support_bool_repeat_ops_one_dimension() {
         let data = TensorData::from([[true, false, false]]);
         let tensor = Tensor::<TestBackend, 2, Bool>::from_data(data, &Default::default());

-        let output = tensor.repeat(0, 4);
+        let output = tensor.repeat(&[4, 1, 1]);
         let expected = TensorData::from([
             [true, false, false],
             [true, false, false],
@@ -35,70 +35,226 @@ mod tests {
     }

     #[test]
-    fn should_support_int_repeat_ops() {
-        let data = TensorData::from([[0, 1, 2]]);
+    fn should_support_int_repeat_ops_one_dimension() {
+        let data = TensorData::from([[0i32, 1i32, 2i32]]);
         let tensor = Tensor::<TestBackend, 2, Int>::from_data(data, &Default::default());

-        let output = tensor.repeat(0, 4);
-        let expected = TensorData::from([[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]);
+        let output = tensor.repeat(&[4, 1, 1]);
+        let expected = TensorData::from([
+            [0i32, 1i32, 2i32],
+            [0i32, 1i32, 2i32],
+            [0i32, 1i32, 2i32],
+            [0i32, 1i32, 2i32],
+        ]);

         output.into_data().assert_eq(&expected, false);
     }

     #[test]
-    fn should_support_float_repeat_on_dims_larger_than_1() {
+    fn should_support_float_repeat_repeating_on_many_dimensions() {
         let data = TensorData::from([
-            [[1.0, 2.0], [3.0, 4.0]],
-            [[5.0, 6.0], [7.0, 8.0]],
-            [[9.0, 10.0], [11.0, 12.0]],
-            [[13.0, 14.0], [15.0, 16.0]],
+            [[1.0f32, 2.0f32], [3.0f32, 4.0f32]],
+            [[5.0f32, 6.0f32], [7.0f32, 8.0f32]],
+            [[9.0f32, 10.0f32], [11.0f32, 12.0f32]],
+            [[13.0f32, 14.0f32], [15.0f32, 16.0f32]],
         ]);
         let tensor = Tensor::<TestBackend, 3>::from_data(data, &Default::default());

-        let output = tensor.repeat(2, 2);
+        let output = tensor.repeat(&[2, 3, 2]);
         let expected = TensorData::from([
-            [[1.0, 2.0, 1.0, 2.0], [3.0, 4.0, 3.0, 4.0]],
-            [[5.0, 6.0, 5.0, 6.0], [7.0, 8.0, 7.0, 8.0]],
-            [[9.0, 10.0, 9.0, 10.0], [11.0, 12.0, 11.0, 12.0]],
-            [[13.0, 14.0, 13.0, 14.0], [15.0, 16.0, 15.0, 16.0]],
+            [[1.0f32, 2.0f32, 1.0f32, 2.0f32], [3.0f32, 4.0f32, 3.0f32, 4.0f32], [1.0f32, 2.0f32, 1.0f32, 2.0f32], [3.0f32, 4.0f32, 3.0f32, 4.0f32], [1.0f32, 2.0f32, 1.0f32, 2.0f32], [3.0f32, 4.0f32, 3.0f32, 4.0f32]],
+            [[5.0f32, 6.0f32, 5.0f32, 6.0f32], [7.0f32, 8.0f32, 7.0f32, 8.0f32], [5.0f32, 6.0f32, 5.0f32, 6.0f32], [7.0f32, 8.0f32, 7.0f32, 8.0f32], [5.0f32, 6.0f32, 5.0f32, 6.0f32], [7.0f32, 8.0f32, 7.0f32, 8.0f32]],
+            [[9.0f32, 10.0f32, 9.0f32, 10.0f32], [11.0f32, 12.0f32, 11.0f32, 12.0f32], [9.0f32, 10.0f32, 9.0f32, 10.0f32], [11.0f32, 12.0f32, 11.0f32, 12.0f32], [9.0f32, 10.0f32, 9.0f32, 10.0f32], [11.0f32, 12.0f32, 11.0f32, 12.0f32]],
+            [[13.0f32, 14.0f32, 13.0f32, 14.0f32], [15.0f32, 16.0f32, 15.0f32, 16.0f32], [13.0f32, 14.0f32, 13.0f32, 14.0f32], [15.0f32, 16.0f32, 15.0f32, 16.0f32], [13.0f32, 14.0f32, 13.0f32, 14.0f32], [15.0f32, 16.0f32, 15.0f32, 16.0f32]],
+            [[1.0f32, 2.0f32, 1.0f32, 2.0f32], [3.0f32, 4.0f32, 3.0f32, 4.0f32], [1.0f32, 2.0f32, 1.0f32, 2.0f32], [3.0f32, 4.0f32, 3.0f32, 4.0f32], [1.0f32, 2.0f32, 1.0f32, 2.0f32], [3.0f32, 4.0f32, 3.0f32, 4.0f32]],
+            [[5.0f32, 6.0f32, 5.0f32, 6.0f32], [7.0f32, 8.0f32, 7.0f32, 8.0f32], [5.0f32, 6.0f32, 5.0f32, 6.0f32], [7.0f32, 8.0f32, 7.0f32, 8.0f32], [5.0f32, 6.0f32, 5.0f32, 6.0f32], [7.0f32, 8.0f32, 7.0f32, 8.0f32]],
+            [[9.0f32, 10.0f32, 9.0f32, 10.0f32], [11.0f32, 12.0f32, 11.0f32, 12.0f32], [9.0f32, 10.0f32, 9.0f32, 10.0f32], [11.0f32, 12.0f32, 11.0f32, 12.0f32], [9.0f32, 10.0f32, 9.0f32, 10.0f32], [11.0f32, 12.0f32, 11.0f32, 12.0f32]],
+            [[13.0f32, 14.0f32, 13.0f32, 14.0f32], [15.0f32, 16.0f32, 15.0f32, 16.0f32], [13.0f32, 14.0f32, 13.0f32, 14.0f32], [15.0f32, 16.0f32, 15.0f32, 16.0f32], [13.0f32, 14.0f32, 13.0f32, 14.0f32], [15.0f32, 16.0f32, 15.0f32, 16.0f32]],
         ]);

         output.into_data().assert_eq(&expected, false);
     }

     #[test]
-    fn should_support_int_repeat_on_dims_larger_than_1() {
+    fn should_support_int_repeat_on_many_dims() {
         let data = TensorData::from([
-            [[1, 2], [3, 4]],
-            [[5, 6], [7, 8]],
-            [[9, 10], [11, 12]],
-            [[13, 14], [15, 16]],
+            [[1i32, 2i32], [3i32, 4i32]],
+            [[5i32, 6i32], [7i32, 8i32]],
+            [[9i32, 10i32], [11i32, 12i32]],
+            [[13i32, 14i32], [15i32, 16i32]],
         ]);
         let tensor = Tensor::<TestBackend, 3, Int>::from_data(data, &Default::default());

-        let output = tensor.repeat(2, 3);
+        let output = tensor.repeat(&[2, 3, 2]);

         let expected = TensorData::from([
-            [[1, 2, 1, 2, 1, 2], [3, 4, 3, 4, 3, 4]],
-            [[5, 6, 5, 6, 5, 6], [7, 8, 7, 8, 7, 8]],
-            [[9, 10, 9, 10, 9, 10], [11, 12, 11, 12, 11, 12]],
-            [[13, 14, 13, 14, 13, 14], [15, 16, 15, 16, 15, 16]],
+            [[1i32, 2i32, 1i32, 2i32], [3i32, 4i32, 3i32, 4i32], [1i32, 2i32, 1i32, 2i32], [3i32, 4i32, 3i32, 4i32], [1i32, 2i32, 1i32, 2i32], [3i32, 4i32, 3i32, 4i32]],
+            [[5i32, 6i32, 5i32, 6i32], [7i32, 8i32, 7i32, 8i32], [5i32, 6i32, 5i32, 6i32], [7i32, 8i32, 7i32, 8i32], [5i32, 6i32, 5i32, 6i32], [7i32, 8i32, 7i32, 8i32]],
+            [[9i32, 10i32, 9i32, 10i32], [11i32, 12i32, 11i32, 12i32], [9i32, 10i32, 9i32, 10i32], [11i32, 12i32, 11i32, 12i32], [9i32, 10i32, 9i32, 10i32], [11i32, 12i32, 11i32, 12i32]],
+            [[13i32, 14i32, 13i32, 14i32], [15i32, 16i32, 15i32, 16i32], [13i32, 14i32, 13i32, 14i32], [15i32, 16i32, 15i32, 16i32], [13i32, 14i32, 13i32, 14i32], [15i32, 16i32, 15i32, 16i32]],
+            [[1i32, 2i32, 1i32, 2i32], [3i32, 4i32, 3i32, 4i32], [1i32, 2i32, 1i32, 2i32], [3i32, 4i32, 3i32, 4i32], [1i32, 2i32, 1i32, 2i32], [3i32, 4i32, 3i32, 4i32]],
+            [[5i32, 6i32, 5i32, 6i32], [7i32, 8i32, 7i32, 8i32], [5i32, 6i32, 5i32, 6i32], [7i32, 8i32, 7i32, 8i32], [5i32, 6i32, 5i32, 6i32], [7i32, 8i32, 7i32, 8i32]],
+            [[9i32, 10i32, 9i32, 10i32], [11i32, 12i32, 11i32, 12i32], [9i32, 10i32, 9i32, 10i32], [11i32, 12i32, 11i32, 12i32], [9i32, 10i32, 9i32, 10i32], [11i32, 12i32, 11i32, 12i32]],
+            [[13i32, 14i32, 13i32, 14i32], [15i32, 16i32, 15i32, 16i32], [13i32, 14i32, 13i32, 14i32], [15i32, 16i32, 15i32, 16i32], [13i32, 14i32, 13i32, 14i32], [15i32, 16i32, 15i32, 16i32]],
         ]);

         output.into_data().assert_eq(&expected, false);
     }

     #[test]
-    fn should_support_bool_repeat_on_dims_larger_than_1() {
+    fn should_support_bool_repeat_on_many_dimension() {
         let data = TensorData::from([
             [[false, true], [true, false]],
             [[true, true], [false, false]],
         ]);
         let tensor = Tensor::<TestBackend, 3, Bool>::from_data(data, &Default::default());

-        let output = tensor.repeat(1, 2);
+        let output = tensor.repeat(&[2, 3, 2]);
         let expected = TensorData::from([
-            [[false, true], [true, false], [false, true], [true, false]],
-            [[true, true], [false, false], [true, true], [false, false]],
+            [[false, true, false, true], [true, false, true, false], [false, true, false, true], [true, false, true, false], [false, true, false, true], [true, false, true, false]],
+            [[true, true, true, true], [false, false, false, false], [true, true, true, true], [false, false, false, false], [true, true, true, true], [false, false, false, false]],
+            [[false, true, false, true], [true, false, true, false], [false, true, false, true], [true, false, true, false], [false, true, false, true], [true, false, true, false]],
+            [[true, true, true, true], [false, false, false, false], [true, true, true, true], [false, false, false, false], [true, true, true, true], [false, false, false, false]],
         ]);

         output.into_data().assert_eq(&expected, true);
@@ -0,0 +1,130 @@
+#[burn_tensor_testgen::testgen(repeat_dim)]
+mod tests {
+    use super::*;
+    use burn_tensor::{backend::Backend, Bool, Int, Tensor, TensorData};
+
+    #[test]
+    fn should_support_repeat_ops() {
+        let data = TensorData::from([[0.0f64, 1.0f64, 2.0f64]]);
+        let tensor = Tensor::<TestBackend, 2>::from_data(data.clone(), &Default::default());
+
+        let output = tensor.repeat_dim(0, 4);
+        let expected = TensorData::from([
+            [0.0f32, 1.0f32, 2.0f32],
+            [0.0f32, 1.0f32, 2.0f32],
+            [0.0f32, 1.0f32, 2.0f32],
+            [0.0f32, 1.0f32, 2.0f32],
+        ]);
+
+        output.into_data().assert_eq(&expected, false);
+    }
+
+    #[test]
+    fn should_support_bool_repeat_ops() {
+        let data = TensorData::from([[true, false, false]]);
+        let tensor = Tensor::<TestBackend, 2, Bool>::from_data(data, &Default::default());
+
+        let output = tensor.repeat_dim(0, 4);
+        let expected = TensorData::from([
+            [true, false, false],
+            [true, false, false],
+            [true, false, false],
+            [true, false, false],
+        ]);
+        output.into_data().assert_eq(&expected, true);
+    }
+
+    #[test]
+    fn should_support_int_repeat_ops() {
+        let data = TensorData::from([[0, 1, 2]]);
+        let tensor = Tensor::<TestBackend, 2, Int>::from_data(data, &Default::default());
+
+        let output = tensor.repeat_dim(0, 4);
+        let expected = TensorData::from([[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]);
+
+        output.into_data().assert_eq(&expected, false);
+    }
+
+    #[test]
+    fn should_support_float_repeat_on_dims_larger_than_1() {
+        let data = TensorData::from([
+            [[1.0f32, 2.0f32], [3.0f32, 4.0f32]],
+            [[5.0f32, 6.0f32], [7.0f32, 8.0f32]],
+            [[9.0f32, 10.0f32], [11.0f32, 12.0f32]],
+            [[13.0f32, 14.0f32], [15.0f32, 16.0f32]],
+        ]);
+        let tensor = Tensor::<TestBackend, 3>::from_data(data, &Default::default());
+
+        let output = tensor.repeat_dim(2, 2);
+        let expected = TensorData::from([
+            [[1.0f32, 2.0f32, 1.0f32, 2.0f32], [3.0f32, 4.0f32, 3.0f32, 4.0f32]],
+            [[5.0f32, 6.0f32, 5.0f32, 6.0f32], [7.0f32, 8.0f32, 7.0f32, 8.0f32]],
+            [[9.0f32, 10.0f32, 9.0f32, 10.0f32], [11.0f32, 12.0f32, 11.0f32, 12.0f32]],
+            [[13.0f32, 14.0f32, 13.0f32, 14.0f32], [15.0f32, 16.0f32, 15.0f32, 16.0f32]],
+        ]);
+
+        output.into_data().assert_eq(&expected, false);
+    }
+
+    #[test]
+    fn should_support_int_repeat_on_dims_larger_than_1() {
+        let data = TensorData::from([
+            [[1i32, 2i32], [3i32, 4i32]],
+            [[5i32, 6i32], [7i32, 8i32]],
+            [[9i32, 10i32], [11i32, 12i32]],
+            [[13i32, 14i32], [15i32, 16i32]],
+        ]);
+        let tensor = Tensor::<TestBackend, 3, Int>::from_data(data, &Default::default());
+
+        let output = tensor.repeat_dim(2, 3);
+        let expected = TensorData::from([
+            [[1i32, 2i32, 1i32, 2i32, 1i32, 2i32], [3i32, 4i32, 3i32, 4i32, 3i32, 4i32]],
+            [[5i32, 6i32, 5i32, 6i32, 5i32, 6i32], [7i32, 8i32, 7i32, 8i32, 7i32, 8i32]],
+            [[9i32, 10i32, 9i32, 10i32, 9i32, 10i32], [11i32, 12i32, 11i32, 12i32, 11i32, 12i32]],
+            [[13i32, 14i32, 13i32, 14i32, 13i32, 14i32], [15i32, 16i32, 15i32, 16i32, 15i32, 16i32]],
+        ]);
+
+        output.into_data().assert_eq(&expected, false);
+    }
+
+    #[test]
+    fn should_support_bool_repeat_on_dims_larger_than_1() {
+        let data = TensorData::from([
+            [[false, true], [true, false]],
+            [[true, true], [false, false]],
+        ]);
+        let tensor = Tensor::<TestBackend, 3, Bool>::from_data(data, &Default::default());
+
+        let output = tensor.repeat_dim(1, 2);
+        let expected = TensorData::from([
+            [[false, true], [true, false], [false, true], [true, false]],
+            [[true, true], [false, false], [true, true], [false, false]],
+        ]);
+
+        output.into_data().assert_eq(&expected, true);
+    }
+}
@@ -73,7 +73,7 @@ impl<B: Backend> TextClassificationModel<B> {
         // Calculate token and position embeddings, and combine them
         let index_positions = Tensor::arange(0..seq_length as i64, device)
             .reshape([1, seq_length])
-            .repeat(0, batch_size);
+            .repeat_dim(0, batch_size);
         let embedding_positions = self.embedding_pos.forward(index_positions);
         let embedding_tokens = self.embedding_token.forward(tokens);
         let embedding = (embedding_positions + embedding_tokens) / 2;

@@ -113,7 +113,7 @@ impl<B: Backend> TextClassificationModel<B> {
         // Calculate token and position embeddings, and combine them
         let index_positions = Tensor::arange(0..seq_length as i64, device)
             .reshape([1, seq_length])
-            .repeat(0, batch_size);
+            .repeat_dim(0, batch_size);
         let embedding_positions = self.embedding_pos.forward(index_positions);
         let embedding_tokens = self.embedding_token.forward(tokens);
         let embedding = (embedding_positions + embedding_tokens) / 2;
@@ -64,7 +64,7 @@ impl<B: Backend> TextGenerationModel<B> {

         let index_positions = Tensor::arange(0..seq_length as i64, device)
             .reshape([1, seq_length])
-            .repeat(0, batch_size);
+            .repeat_dim(0, batch_size);

         let embedding_positions = self.embedding_pos.forward(index_positions);
         let embedding_tokens = self.embedding_token.forward(inputs);
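All three model call sites share the `arange` → `reshape` → `repeat_dim` idiom for positional indices: every sequence in the batch gets the same `[0, 1, ..., seq_length - 1]` row. A shape sketch with hypothetical sizes (not from the diff):

```rust
// Hedged sketch: hypothetical seq_length = 5, batch_size = 3.
let index_positions = Tensor::<B, 1, Int>::arange(0..5, &device)
    .reshape([1, 5])   // [5] -> [1, 5]
    .repeat_dim(0, 3); // [1, 5] -> [3, 5], identical position row per batch item
```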