mirror of https://github.com/tracel-ai/burn.git
Fix repeat for dims > 1 (#1713)
This commit is contained in:
parent
3a02a54e55
commit
2e4c82fa64
|
@ -564,7 +564,7 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let stream = tensor.stream;
|
||||
let mut shape = tensor.shape.clone();
|
||||
shape[dim] = times;
|
||||
shape[dim] *= times;
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = RepeatOperationDescription {
|
||||
|
|
|
@ -1620,7 +1620,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let stream = tensor.stream;
|
||||
let mut shape = tensor.shape.clone();
|
||||
shape[dim] = times;
|
||||
shape[dim] *= times;
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = RepeatOperationDescription {
|
||||
|
|
|
@ -1665,7 +1665,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
|
||||
let stream = tensor.stream;
|
||||
let mut shape = tensor.shape.clone();
|
||||
shape[dim] = times;
|
||||
shape[dim] *= times;
|
||||
let out = tensor.client.tensor_uninitialized(shape);
|
||||
|
||||
let desc = RepeatOperationDescription {
|
||||
|
|
|
@ -38,19 +38,21 @@ impl RepeatComputeShader {
|
|||
|
||||
let stride_input = scope.create_local(Elem::UInt);
|
||||
let stride_output = scope.create_local(Elem::UInt);
|
||||
let shape_output = scope.create_local(Elem::UInt);
|
||||
let shape = scope.create_local(Elem::UInt);
|
||||
|
||||
for i in 0..self.rank {
|
||||
gpu!(scope, stride_input = stride(input, i));
|
||||
gpu!(scope, stride_output = stride(output, i));
|
||||
if i != self.dim {
|
||||
gpu!(scope, stride_input = stride(input, i));
|
||||
gpu!(scope, stride_output = stride(output, i));
|
||||
gpu!(scope, shape_output = shape(output, i));
|
||||
|
||||
gpu!(scope, offset_local = id / stride_output);
|
||||
gpu!(scope, offset_local = offset_local % shape_output);
|
||||
gpu!(scope, offset_local = offset_local * stride_input);
|
||||
gpu!(scope, offset_input += offset_local);
|
||||
gpu!(scope, shape = shape(output, i));
|
||||
} else {
|
||||
gpu!(scope, shape = shape(input, i));
|
||||
}
|
||||
|
||||
gpu!(scope, offset_local = id / stride_output);
|
||||
gpu!(scope, offset_local = offset_local % shape);
|
||||
gpu!(scope, offset_local = offset_local * stride_input);
|
||||
gpu!(scope, offset_input += offset_local);
|
||||
}
|
||||
|
||||
let result = scope.create_local(input.item());
|
||||
|
@ -108,12 +110,9 @@ pub(crate) fn repeat<R: Runtime, E: JitElement, const D1: usize>(
|
|||
times: usize,
|
||||
) -> JitTensor<R, E, D1> {
|
||||
let mut shape = input.shape.clone();
|
||||
if shape.dims[dim] != 1 {
|
||||
panic!("Can only repeat dimension with dim=1");
|
||||
}
|
||||
|
||||
// Create output handle
|
||||
shape.dims[dim] = times;
|
||||
shape.dims[dim] *= times;
|
||||
let num_elems_output = shape.num_elements();
|
||||
let handle = input
|
||||
.client
|
||||
|
|
|
@ -564,10 +564,6 @@ where
|
|||
}
|
||||
|
||||
/// Repeat the tensor along the given dimension.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If the selected dimension more than one item.
|
||||
pub fn repeat(self, dim: usize, times: usize) -> Self {
|
||||
Self::new(K::repeat(self.primitive, dim, times))
|
||||
}
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
use super::{cat::cat_with_slice_assign, BoolTensor, Device, FloatTensor, IntTensor};
|
||||
use super::{
|
||||
cat::cat_with_slice_assign, repeat::repeat_with_slice_assign, BoolTensor, Device, FloatTensor,
|
||||
IntTensor,
|
||||
};
|
||||
use crate::{
|
||||
backend::Backend, chunk, narrow, tensor::Shape, Bool, Data, ElementConversion, Tensor,
|
||||
};
|
||||
|
@ -174,28 +177,12 @@ pub trait BoolTensorOps<B: Backend> {
|
|||
dim: usize,
|
||||
times: usize,
|
||||
) -> BoolTensor<B, D> {
|
||||
let mut shape = Self::bool_shape(&tensor);
|
||||
if shape.dims[dim] != 1 {
|
||||
panic!("Can only repeat dimension with dim=1");
|
||||
}
|
||||
shape.dims[dim] = times;
|
||||
|
||||
let mut i = 0;
|
||||
let ranges_select_all = [0; D].map(|_| {
|
||||
let start = 0;
|
||||
let end = shape.dims[i];
|
||||
i += 1;
|
||||
start..end
|
||||
});
|
||||
|
||||
let mut tensor_output = Self::bool_empty(shape, &Self::bool_device(&tensor));
|
||||
for i in 0..times {
|
||||
let mut ranges = ranges_select_all.clone();
|
||||
ranges[dim] = i..i + 1;
|
||||
tensor_output = Self::bool_slice_assign(tensor_output, ranges, tensor.clone());
|
||||
}
|
||||
|
||||
tensor_output
|
||||
repeat_with_slice_assign::<B, D, Bool>(
|
||||
Tensor::<B, D, Bool>::from_primitive(tensor),
|
||||
dim,
|
||||
times,
|
||||
)
|
||||
.into_primitive()
|
||||
}
|
||||
|
||||
/// Concatenates the tensors along the given dimension.
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use super::cat::cat_with_slice_assign;
|
||||
use super::repeat::repeat_with_slice_assign;
|
||||
use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
|
||||
use crate::Tensor;
|
||||
use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion, Int};
|
||||
|
@ -270,28 +271,12 @@ pub trait IntTensorOps<B: Backend> {
|
|||
dim: usize,
|
||||
times: usize,
|
||||
) -> IntTensor<B, D> {
|
||||
let mut shape = Self::int_shape(&tensor);
|
||||
if shape.dims[dim] != 1 {
|
||||
panic!("Can only repeat dimension with dim=1");
|
||||
}
|
||||
shape.dims[dim] = times;
|
||||
|
||||
let mut i = 0;
|
||||
let indices_select_all = [0; D].map(|_| {
|
||||
let start = 0;
|
||||
let end = shape.dims[i];
|
||||
i += 1;
|
||||
start..end
|
||||
});
|
||||
|
||||
let mut tensor_output = Self::int_empty(shape, &Self::int_device(&tensor));
|
||||
for i in 0..times {
|
||||
let mut indices = indices_select_all.clone();
|
||||
indices[dim] = i..i + 1;
|
||||
tensor_output = Self::int_slice_assign(tensor_output, indices, tensor.clone());
|
||||
}
|
||||
|
||||
tensor_output
|
||||
repeat_with_slice_assign::<B, D, Int>(
|
||||
Tensor::<B, D, Int>::from_primitive(tensor),
|
||||
dim,
|
||||
times,
|
||||
)
|
||||
.into_primitive()
|
||||
}
|
||||
|
||||
/// Concatenates the given tensors along the given dimension.
|
||||
|
|
|
@ -19,10 +19,8 @@ pub(crate) fn cat_with_slice_assign<B: Backend, const D: usize, K: TensorKind<B>
|
|||
|
||||
let mut i = 0;
|
||||
let indices_select_all = [0; D].map(|_| {
|
||||
let start = 0;
|
||||
let end = shape.dims[i];
|
||||
i += 1;
|
||||
start..end
|
||||
0..shape.dims[i - 1]
|
||||
});
|
||||
|
||||
let mut output_index = 0;
|
||||
|
|
|
@ -3,6 +3,8 @@ pub mod conv;
|
|||
|
||||
/// Module with cat operation
|
||||
pub(crate) mod cat;
|
||||
/// Module with repeat operation
|
||||
pub(crate) mod repeat;
|
||||
/// Module with unfold operations.
|
||||
pub(crate) mod unfold;
|
||||
|
||||
|
|
|
@ -0,0 +1,36 @@
|
|||
use crate::{backend::Backend, BasicOps, Tensor, TensorKind};
|
||||
|
||||
pub(crate) fn repeat_with_slice_assign<
|
||||
B: Backend,
|
||||
const D: usize,
|
||||
K: TensorKind<B> + BasicOps<B>,
|
||||
>(
|
||||
tensor: Tensor<B, D, K>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
) -> Tensor<B, D, K> {
|
||||
let mut shape = tensor.shape();
|
||||
let device = tensor.device();
|
||||
|
||||
let original_dim_length = shape.dims[dim];
|
||||
shape.dims[dim] *= times;
|
||||
|
||||
let mut tensor_output = Tensor::empty(shape.clone(), &device);
|
||||
|
||||
let mut i = 0;
|
||||
let indices_select_all = [0; D].map(|_| {
|
||||
i += 1;
|
||||
0..shape.dims[i - 1]
|
||||
});
|
||||
|
||||
let mut output_index = 0;
|
||||
for _ in 0..times {
|
||||
let mut indices = indices_select_all.clone();
|
||||
indices[dim] = output_index..output_index + original_dim_length;
|
||||
output_index += original_dim_length;
|
||||
|
||||
tensor_output = tensor_output.slice_assign(indices, tensor.clone());
|
||||
}
|
||||
|
||||
tensor_output
|
||||
}
|
|
@ -1,4 +1,5 @@
|
|||
use super::cat::cat_with_slice_assign;
|
||||
use super::repeat::repeat_with_slice_assign;
|
||||
use super::{BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntElem, IntTensor};
|
||||
use crate::backend::BackendBridge;
|
||||
use crate::Tensor;
|
||||
|
@ -193,28 +194,8 @@ pub trait FloatTensorOps<B: Backend> {
|
|||
dim: usize,
|
||||
times: usize,
|
||||
) -> FloatTensor<B, D> {
|
||||
let mut shape = B::float_shape(&tensor);
|
||||
if shape.dims[dim] != 1 {
|
||||
panic!("Can only repeat dimension with dim=1");
|
||||
}
|
||||
shape.dims[dim] = times;
|
||||
|
||||
let mut i = 0;
|
||||
let indices_select_all = [0; D].map(|_| {
|
||||
let start = 0;
|
||||
let end = shape.dims[i];
|
||||
i += 1;
|
||||
start..end
|
||||
});
|
||||
|
||||
let mut tensor_output = B::float_empty(shape, &B::float_device(&tensor));
|
||||
for i in 0..times {
|
||||
let mut indices = indices_select_all.clone();
|
||||
indices[dim] = i..i + 1;
|
||||
tensor_output = B::float_slice_assign(tensor_output, indices, tensor.clone());
|
||||
}
|
||||
|
||||
tensor_output
|
||||
repeat_with_slice_assign::<B, D, Float>(Tensor::<B, D>::from_primitive(tensor), dim, times)
|
||||
.into_primitive()
|
||||
}
|
||||
|
||||
/// Adds two tensors together.
|
||||
|
|
|
@ -45,4 +45,66 @@ mod tests {
|
|||
let data_expected = Data::from([[0, 1, 2], [0, 1, 2], [0, 1, 2], [0, 1, 2]]);
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_float_repeat_on_dims_larger_than_1() {
|
||||
let data = Data::from([
|
||||
[[1.0, 2.0], [3.0, 4.0]],
|
||||
[[5.0, 6.0], [7.0, 8.0]],
|
||||
[[9.0, 10.0], [11.0, 12.0]],
|
||||
[[13.0, 14.0], [15.0, 16.0]],
|
||||
]);
|
||||
let tensor = Tensor::<TestBackend, 3>::from_data(data, &Default::default());
|
||||
|
||||
let data_actual = tensor.repeat(2, 2).into_data();
|
||||
|
||||
let data_expected = Data::from([
|
||||
[[1.0, 2.0, 1.0, 2.0], [3.0, 4.0, 3.0, 4.0]],
|
||||
[[5.0, 6.0, 5.0, 6.0], [7.0, 8.0, 7.0, 8.0]],
|
||||
[[9.0, 10.0, 9.0, 10.0], [11.0, 12.0, 11.0, 12.0]],
|
||||
[[13.0, 14.0, 13.0, 14.0], [15.0, 16.0, 15.0, 16.0]],
|
||||
]);
|
||||
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_int_repeat_on_dims_larger_than_1() {
|
||||
let data = Data::from([
|
||||
[[1, 2], [3, 4]],
|
||||
[[5, 6], [7, 8]],
|
||||
[[9, 10], [11, 12]],
|
||||
[[13, 14], [15, 16]],
|
||||
]);
|
||||
let tensor = Tensor::<TestBackend, 3, Int>::from_data(data, &Default::default());
|
||||
|
||||
let data_actual = tensor.repeat(2, 3).into_data();
|
||||
|
||||
let data_expected = Data::from([
|
||||
[[1, 2, 1, 2, 1, 2], [3, 4, 3, 4, 3, 4]],
|
||||
[[5, 6, 5, 6, 5, 6], [7, 8, 7, 8, 7, 8]],
|
||||
[[9, 10, 9, 10, 9, 10], [11, 12, 11, 12, 11, 12]],
|
||||
[[13, 14, 13, 14, 13, 14], [15, 16, 15, 16, 15, 16]],
|
||||
]);
|
||||
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_bool_repeat_on_dims_larger_than_1() {
|
||||
let data = Data::from([
|
||||
[[false, true], [true, false]],
|
||||
[[true, true], [false, false]],
|
||||
]);
|
||||
let tensor = Tensor::<TestBackend, 3, Bool>::from_data(data, &Default::default());
|
||||
|
||||
let data_actual = tensor.repeat(1, 2).into_data();
|
||||
|
||||
let data_expected = Data::from([
|
||||
[[false, true], [true, false], [false, true], [true, false]],
|
||||
[[true, true], [false, false], [true, true], [false, false]],
|
||||
]);
|
||||
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue