feat: repeat (#63)

Nathaniel Simard 2022-10-24 18:25:53 -04:00 committed by GitHub
parent a78886d51e
commit 0c4c657854
7 changed files with 91 additions and 2 deletions

View File

@@ -1,6 +1,9 @@
use super::unary_ops_wrapper;
use crate::{
-    backend::{autodiff::ADBackendDecorator, Backend},
+    backend::{
+        autodiff::{ADBackendDecorator, ADTensor},
+        Backend,
+    },
    graph::ops::{UnaryOps, UnaryOpsNodeState},
    ops::TensorOps,
    Data, Shape,
@@ -76,4 +79,11 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
        unary_ops_wrapper(input, output, ops)
    }

    fn empty<const D: usize>(
        shape: Shape<D>,
        device: <ADBackendDecorator<B> as Backend>::Device,
    ) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
        ADTensor::from_tensor(B::empty(shape, device))
    }
}

View File

@@ -55,4 +55,11 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
    ) -> NdArrayTensor<E, D> {
        tensor.clone()
    }

    fn empty<const D: usize>(
        shape: Shape<D>,
        device: <NdArrayBackend<E> as Backend>::Device,
    ) -> <NdArrayBackend<E> as Backend>::TensorPrimitive<D> {
        NdArrayBackend::<E>::zeros(shape, device)
    }
}

View File

@@ -1,4 +1,4 @@
-use super::{TchBackend, TchDevice, TchTensor};
+use super::{TchBackend, TchDevice, TchKind, TchTensor};
use crate::{backend::Backend, ops::TensorOps, Data, Shape, TchElement};

impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
@@ -57,4 +57,19 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
            shape: tensor.shape,
        }
    }

    fn empty<const D: usize>(
        shape: Shape<D>,
        device: <TchBackend<E> as Backend>::Device,
    ) -> <TchBackend<E> as Backend>::TensorPrimitive<D> {
        let kind = TchKind::new();
        let tensor =
            tch::Tensor::empty(&shape.dims.map(|a| a as i64), (kind.kind(), device.into()));

        TchTensor {
            kind,
            tensor,
            shape,
        }
    }
}
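
Taken together, the three `empty` implementations only promise an allocation with the right shape on the right device: the autodiff backend wraps the inner backend's buffer, ndarray chooses to zero-fill, and tch leaves the memory uninitialized, so callers are expected to overwrite every element before reading. A minimal generic sketch of how that contract is typically consumed (the `empty_like` helper and the `burn_tensor::{backend, ops}` module paths are assumptions for illustration, not part of this commit):

```rust
use burn_tensor::{backend::Backend, ops::TensorOps, Shape};

/// Hypothetical helper: allocate a scratch buffer with the same shape and
/// device as an existing tensor, relying only on the `empty` contract.
fn empty_like<B: Backend, const D: usize>(
    tensor: &B::TensorPrimitive<D>,
) -> B::TensorPrimitive<D> {
    let shape: Shape<D> = *B::shape(tensor);
    B::empty(shape, B::device(tensor))
}
```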

View File

@@ -559,6 +559,15 @@ where
        self.reshape(shape)
    }

    /// Repeat the tensor along the given dimension.
    ///
    /// # Panics
    ///
    /// Panics if the selected dimension has more than one element.
    pub fn repeat(&self, dim: usize, times: usize) -> Self {
        Self::new(B::repeat(&self.value, dim, times))
    }

    pub(crate) fn relu(&self) -> Self {
        Self::new(self.value.relu())
    }

View File

@@ -36,6 +36,35 @@ pub trait TensorOps<B: Backend> {
        let data = Data::new(value, shape);
        <B::IntegerBackend as Backend>::from_data(data, device)
    }

    fn empty<const D: usize>(shape: Shape<D>, device: B::Device) -> B::TensorPrimitive<D>;

    fn repeat<const D: usize>(
        tensor: &B::TensorPrimitive<D>,
        dim: usize,
        times: usize,
    ) -> B::TensorPrimitive<D> {
        // Only a dimension of size 1 can be repeated by this default implementation.
        let mut shape = *B::shape(tensor);
        if shape.dims[dim] != 1 {
            panic!("Can only repeat a dimension with a size of 1");
        }
        shape.dims[dim] = times;

        // Build an index that selects the full range along every dimension of the output.
        let mut i = 0;
        let indexes_select_all = [0; D].map(|_| {
            let start = 0;
            let end = shape.dims[i];
            i += 1;
            start..end
        });

        // Allocate the output once, then write one copy of the input into each
        // slot i..i + 1 along the repeated dimension.
        let mut tensor_output = B::empty(shape, B::device(tensor));
        for i in 0..times {
            let mut indexes = indexes_select_all.clone();
            indexes[dim] = i..i + 1;
            tensor_output = tensor_output.index_assign(indexes, tensor);
        }

        tensor_output
    }
}
pub trait TensorOpsAdd<E, const D: usize>: std::ops::Add<Self, Output = Self>
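
To make the index bookkeeping in the default `repeat` above concrete, here is a standalone sketch in plain Rust (no burn types; `repeat_ranges` is a hypothetical helper, not part of this commit) that computes the ranges passed to `index_assign` for each copy:

```rust
use std::ops::Range;

// Hypothetical helper: compute the index ranges the default `repeat`
// uses to place each copy of the input into the output buffer.
fn repeat_ranges<const D: usize>(
    dims: [usize; D],
    dim: usize,
    times: usize,
) -> Vec<[Range<usize>; D]> {
    assert_eq!(dims[dim], 1, "can only repeat a dimension of size 1");
    (0..times)
        .map(|i| {
            // Select everything along every axis...
            let mut axis = 0;
            let mut ranges = [0; D].map(|_| {
                let full = 0..dims[axis];
                axis += 1;
                full
            });
            // ...then narrow the repeated axis to the slot for copy `i`.
            ranges[dim] = i..i + 1;
            ranges
        })
        .collect()
}

// For dims = [1, 3], dim = 0, times = 4 this yields:
// [0..1, 0..3], [1..2, 0..3], [2..3, 0..3], [3..4, 0..3]
```

Allocating the output once with `empty` and assigning each slot avoids building intermediate concatenations.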

View File

@@ -11,6 +11,7 @@ mod matmul;
mod mul;
mod neg;
mod powf;
mod repeat;
mod reshape;
mod sub;
mod transpose;

View File

@@ -0,0 +1,18 @@
use super::super::TestBackend;
use burn_tensor::{Data, Tensor};

#[test]
fn should_support_repeat_ops() {
    let data = Data::from([[0.0, 1.0, 2.0]]);
    let tensor = Tensor::<TestBackend, 2>::from_data(data);

    let data_actual = tensor.repeat(0, 4).into_data();

    let data_expected = Data::from([
        [0.0, 1.0, 2.0],
        [0.0, 1.0, 2.0],
        [0.0, 1.0, 2.0],
        [0.0, 1.0, 2.0],
    ]);
    assert_eq!(data_expected, data_actual);
}
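
A companion test for the panic path could look like this (a hypothetical addition, not part of this commit, reusing the imports above):

```rust
#[test]
#[should_panic]
fn should_panic_when_repeated_dim_is_not_one() {
    let data = Data::from([[0.0, 1.0], [2.0, 3.0]]);
    let tensor = Tensor::<TestBackend, 2>::from_data(data);

    // dim 0 has size 2, so the default `repeat` implementation panics.
    let _ = tensor.repeat(0, 2);
}
```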