mirror of https://github.com/tracel-ai/burn.git
feat: repeat (#63)
This commit is contained in:
parent
a78886d51e
commit
0c4c657854
|
@ -1,6 +1,9 @@
|
|||
use super::unary_ops_wrapper;
|
||||
use crate::{
|
||||
backend::{autodiff::ADBackendDecorator, Backend},
|
||||
backend::{
|
||||
autodiff::{ADBackendDecorator, ADTensor},
|
||||
Backend,
|
||||
},
|
||||
graph::ops::{UnaryOps, UnaryOpsNodeState},
|
||||
ops::TensorOps,
|
||||
Data, Shape,
|
||||
|
@ -76,4 +79,11 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
|
||||
unary_ops_wrapper(input, output, ops)
|
||||
}
|
||||
|
||||
fn empty<const D: usize>(
|
||||
shape: Shape<D>,
|
||||
device: <ADBackendDecorator<B> as Backend>::Device,
|
||||
) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
|
||||
ADTensor::from_tensor(B::empty(shape, device))
|
||||
}
|
||||
}
|
||||
|
|
|
@ -55,4 +55,11 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
|||
) -> NdArrayTensor<E, D> {
|
||||
tensor.clone()
|
||||
}
|
||||
|
||||
fn empty<const D: usize>(
|
||||
shape: Shape<D>,
|
||||
device: <NdArrayBackend<E> as Backend>::Device,
|
||||
) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
|
||||
NdArrayBackend::<E>::zeros(shape, device)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use super::{TchBackend, TchDevice, TchTensor};
|
||||
use super::{TchBackend, TchDevice, TchKind, TchTensor};
|
||||
use crate::{backend::Backend, ops::TensorOps, Data, Shape, TchElement};
|
||||
|
||||
impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
||||
|
@ -57,4 +57,19 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
shape: tensor.shape,
|
||||
}
|
||||
}
|
||||
|
||||
fn empty<const D: usize>(
|
||||
shape: Shape<D>,
|
||||
device: <TchBackend<E> as Backend>::Device,
|
||||
) -> <TchBackend<E> as Backend>::TensorPrimitive<D> {
|
||||
let kind = TchKind::new();
|
||||
let tensor =
|
||||
tch::Tensor::empty(&shape.dims.map(|a| a as i64), (kind.kind(), device.into()));
|
||||
|
||||
TchTensor {
|
||||
kind,
|
||||
tensor,
|
||||
shape,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -559,6 +559,15 @@ where
|
|||
self.reshape(shape)
|
||||
}
|
||||
|
||||
/// Repeat the tensor along the given dimension.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// If the selected dimension more than one item.
|
||||
pub fn repeat(&self, dim: usize, times: usize) -> Self {
|
||||
Self::new(B::repeat(&self.value, dim, times))
|
||||
}
|
||||
|
||||
pub(crate) fn relu(&self) -> Self {
|
||||
Self::new(self.value.relu())
|
||||
}
|
||||
|
|
|
@ -36,6 +36,35 @@ pub trait TensorOps<B: Backend> {
|
|||
let data = Data::new(value, shape);
|
||||
<B::IntegerBackend as Backend>::from_data(data, device)
|
||||
}
|
||||
fn empty<const D: usize>(shape: Shape<D>, device: B::Device) -> B::TensorPrimitive<D>;
|
||||
fn repeat<const D: usize>(
|
||||
tensor: &B::TensorPrimitive<D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
) -> B::TensorPrimitive<D> {
|
||||
let mut shape = *B::shape(tensor);
|
||||
if shape.dims[dim] != 1 {
|
||||
panic!("Can only repeat dimension with dim=1");
|
||||
}
|
||||
shape.dims[dim] = times;
|
||||
|
||||
let mut i = 0;
|
||||
let indexes_select_all = [0; D].map(|_| {
|
||||
let start = 0;
|
||||
let end = shape.dims[i];
|
||||
i += 1;
|
||||
start..end
|
||||
});
|
||||
|
||||
let mut tensor_output = B::empty(shape, B::device(tensor));
|
||||
for i in 0..times {
|
||||
let mut indexes = indexes_select_all.clone();
|
||||
indexes[dim] = i..i + 1;
|
||||
tensor_output = tensor_output.index_assign(indexes, tensor);
|
||||
}
|
||||
|
||||
tensor_output
|
||||
}
|
||||
}
|
||||
|
||||
pub trait TensorOpsAdd<E, const D: usize>: std::ops::Add<Self, Output = Self>
|
||||
|
|
|
@ -11,6 +11,7 @@ mod matmul;
|
|||
mod mul;
|
||||
mod neg;
|
||||
mod powf;
|
||||
mod repeat;
|
||||
mod reshape;
|
||||
mod sub;
|
||||
mod transpose;
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
use super::super::TestBackend;
|
||||
use burn_tensor::{Data, Tensor};
|
||||
|
||||
#[test]
|
||||
fn should_support_repeat_ops() {
|
||||
let data = Data::from([[0.0, 1.0, 2.0]]);
|
||||
let tensor = Tensor::<TestBackend, 2>::from_data(data);
|
||||
|
||||
let data_actual = tensor.repeat(0, 4).into_data();
|
||||
|
||||
let data_expected = Data::from([
|
||||
[0.0, 1.0, 2.0],
|
||||
[0.0, 1.0, 2.0],
|
||||
[0.0, 1.0, 2.0],
|
||||
[0.0, 1.0, 2.0],
|
||||
]);
|
||||
assert_eq!(data_expected, data_actual);
|
||||
}
|
Loading…
Reference in New Issue