Refactor/backend bool tensor (#192)

Nathaniel Simard 2023-03-05 11:23:46 -05:00 committed by GitHub
parent ffd3d35176
commit 15ec42dd6f
18 changed files with 622 additions and 338 deletions

View File

@@ -0,0 +1,83 @@
use crate::{
tensor::{BoolTensor, IntTensor},
ADBackendDecorator,
};
use burn_tensor::{backend::Backend, ops::BoolTensorOps, Data, Shape};
impl<B: Backend> BoolTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
fn bool_from_data<const D: usize>(data: Data<bool, D>, device: &B::Device) -> BoolTensor<B, D> {
B::bool_from_data(data, device)
}
fn bool_shape<const D: usize>(tensor: &BoolTensor<B, D>) -> Shape<D> {
B::bool_shape(tensor)
}
fn bool_to_data<const D: usize>(tensor: &BoolTensor<B, D>) -> Data<bool, D> {
B::bool_to_data(tensor)
}
fn bool_into_data<const D: usize>(tensor: BoolTensor<B, D>) -> Data<bool, D> {
B::bool_into_data(tensor)
}
fn bool_into_int<const D: usize>(tensor: BoolTensor<B, D>) -> IntTensor<B, D> {
B::bool_into_int(tensor)
}
fn bool_to_device<const D: usize>(
tensor: BoolTensor<B, D>,
device: &B::Device,
) -> BoolTensor<B, D> {
B::bool_to_device(tensor, device)
}
fn bool_device<const D: usize>(tensor: &BoolTensor<B, D>) -> B::Device {
B::bool_device(tensor)
}
fn bool_reshape<const D1: usize, const D2: usize>(
tensor: BoolTensor<B, D1>,
shape: Shape<D2>,
) -> BoolTensor<B, D2> {
B::bool_reshape(tensor, shape)
}
fn bool_index<const D1: usize, const D2: usize>(
tensor: BoolTensor<B, D1>,
indexes: [std::ops::Range<usize>; D2],
) -> BoolTensor<B, D1> {
B::bool_index(tensor, indexes)
}
fn bool_empty<const D: usize>(
shape: Shape<D>,
device: &<ADBackendDecorator<B> as Backend>::Device,
) -> BoolTensor<B, D> {
B::bool_empty(shape, device)
}
fn bool_index_assign<const D1: usize, const D2: usize>(
tensor: <ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D1>,
indexes: [std::ops::Range<usize>; D2],
value: <ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D1>,
) -> <ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D1> {
B::bool_index_assign(tensor, indexes, value)
}
fn bool_cat<const D: usize>(tensors: Vec<BoolTensor<B, D>>, dim: usize) -> BoolTensor<B, D> {
B::bool_cat(tensors, dim)
}
fn bool_equal<const D: usize>(
lhs: BoolTensor<B, D>,
rhs: BoolTensor<B, D>,
) -> BoolTensor<B, D> {
B::bool_equal(lhs, rhs)
}
fn bool_equal_elem<const D: usize>(lhs: BoolTensor<B, D>, rhs: bool) -> BoolTensor<B, D> {
B::bool_equal_elem(lhs, rhs)
}
}
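
Every method in this impl is a straight delegation: bool tensors carry no gradient information, so the autodiff decorator hands them to the inner backend B without wrapping them in ADTensor the way the float ops do. A minimal sketch of what that means for a caller, assuming the NdArray backend as the inner B and the Data/Shape API used above (the test itself is hypothetical, not part of this commit):

use burn_autodiff::ADBackendDecorator;
use burn_ndarray::{NdArrayBackend, NdArrayDevice};
use burn_tensor::{ops::BoolTensorOps, Data, Shape};

type AD = ADBackendDecorator<NdArrayBackend<f32>>;

fn main() {
    let data = Data::new(vec![true, false, false, true], Shape::new([2, 2]));
    let device = NdArrayDevice::Cpu;

    // No tape and no graph node is created here: the decorator simply
    // forwards to NdArrayBackend's BoolTensorOps implementation.
    let tensor = <AD as BoolTensorOps<AD>>::bool_from_data(data.clone(), &device);
    assert_eq!(<AD as BoolTensorOps<AD>>::bool_to_data(&tensor), data);
}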

View File

@@ -1,5 +1,6 @@
mod backward;
mod base;
mod bool_tensor;
mod module;
mod tensor;

View File

@@ -19,10 +19,6 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
ADTensor::new(B::from_data(data, device))
}
fn from_data_bool<const D: usize>(data: Data<bool, D>, device: &B::Device) -> BoolTensor<B, D> {
B::from_data_bool(data, device)
}
fn random<const D: usize>(
shape: Shape<D>,
distribution: burn_tensor::Distribution<FloatElem<B>>,
@@ -51,47 +47,6 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
B::into_data(tensor.primitive)
}
fn bool_shape<const D: usize>(tensor: &BoolTensor<B, D>) -> Shape<D> {
B::bool_shape(tensor)
}
fn bool_to_data<const D: usize>(tensor: &BoolTensor<B, D>) -> Data<bool, D> {
B::bool_to_data(tensor)
}
fn bool_into_data<const D: usize>(tensor: BoolTensor<B, D>) -> Data<bool, D> {
B::bool_into_data(tensor)
}
fn bool_into_int<const D: usize>(tensor: BoolTensor<B, D>) -> IntTensor<B, D> {
B::bool_into_int(tensor)
}
fn bool_to_device<const D: usize>(
tensor: BoolTensor<B, D>,
device: &B::Device,
) -> BoolTensor<B, D> {
B::bool_to_device(tensor, device)
}
fn bool_device<const D: usize>(tensor: &BoolTensor<B, D>) -> B::Device {
B::bool_device(tensor)
}
fn bool_reshape<const D1: usize, const D2: usize>(
tensor: BoolTensor<B, D1>,
shape: Shape<D2>,
) -> BoolTensor<B, D2> {
B::bool_reshape(tensor, shape)
}
fn bool_index<const D1: usize, const D2: usize>(
tensor: BoolTensor<B, D1>,
indexes: [std::ops::Range<usize>; D2],
) -> BoolTensor<B, D1> {
B::bool_index(tensor, indexes)
}
fn device<const D: usize>(tensor: &ADTensor<B, D>) -> B::Device {
B::device(&tensor.primitive)
}

View File

@@ -0,0 +1,92 @@
use alloc::vec::Vec;
use core::{marker::PhantomData, ops::Range};
use burn_tensor::Shape;
use ndarray::Axis;
use ndarray::Dim;
use ndarray::IxDyn;
use ndarray::SliceInfoElem;
use crate::{tensor::NdArrayTensor, to_nd_array_tensor};
pub struct NdArrayOps<E> {
e: PhantomData<E>,
}
impl<E> NdArrayOps<E>
where
E: Copy,
{
pub fn index<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<E, D1>,
indexes: [Range<usize>; D2],
) -> NdArrayTensor<E, D1> {
let slices = Self::to_slice_args::<D1, D2>(indexes);
let array = tensor.array.slice_move(slices.as_slice()).into_shared();
NdArrayTensor { array }
}
pub fn index_assign<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<E, D1>,
indexes: [Range<usize>; D2],
value: NdArrayTensor<E, D1>,
) -> NdArrayTensor<E, D1> {
let slices = Self::to_slice_args::<D1, D2>(indexes);
let mut array = tensor.array.to_owned();
array.slice_mut(slices.as_slice()).assign(&value.array);
let array = array.into_owned().into_shared();
NdArrayTensor { array }
}
pub fn reshape<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<E, D1>,
shape: Shape<D2>,
) -> NdArrayTensor<E, D2> {
match D2 {
1 => to_nd_array_tensor!(1, shape, tensor.array),
2 => to_nd_array_tensor!(2, shape, tensor.array),
3 => to_nd_array_tensor!(3, shape, tensor.array),
4 => to_nd_array_tensor!(4, shape, tensor.array),
5 => to_nd_array_tensor!(5, shape, tensor.array),
6 => to_nd_array_tensor!(6, shape, tensor.array),
_ => panic!("NdArrayTensor support only 6 dimensions."),
}
}
pub fn cat<const D: usize>(
tensors: Vec<NdArrayTensor<E, D>>,
dim: usize,
) -> NdArrayTensor<E, D> {
let arrays: Vec<ndarray::ArrayView<E, IxDyn>> =
tensors.iter().map(|t| t.array.view()).collect();
let array = ndarray::concatenate(Axis(dim), &arrays)
.unwrap()
.into_shared();
NdArrayTensor { array }
}
fn to_slice_args<const D1: usize, const D2: usize>(
indexes: [Range<usize>; D2],
) -> [SliceInfoElem; D1] {
let mut slices = [SliceInfoElem::NewAxis; D1];
for i in 0..D1 {
if i >= D2 {
slices[i] = SliceInfoElem::Slice {
start: 0,
end: None,
step: 1,
}
} else {
slices[i] = SliceInfoElem::Slice {
start: indexes[i].start as isize,
end: Some(indexes[i].end as isize),
step: 1,
}
}
}
slices
}
}
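
The only subtle piece here is to_slice_args: callers may pass fewer ranges than the tensor has dimensions (D2 <= D1), and every trailing dimension is padded with a full slice (start: 0, end: None, step: 1). A hedged sketch of that padding using ndarray directly, mirroring what index builds for indexes = [1..3] on a 2-D tensor (illustrative only; NdArrayTensor itself wraps a dynamic-dimension ArcArray):

use ndarray::{array, SliceInfoElem};

fn main() {
    let a = array![[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10, 11]].into_dyn();

    // Dim 0 gets the explicit range 1..3; dim 1 is padded to a full slice.
    let slices = [
        SliceInfoElem::Slice { start: 1, end: Some(3), step: 1 },
        SliceInfoElem::Slice { start: 0, end: None, step: 1 },
    ];

    let sliced = a.slice(slices.as_slice());
    assert_eq!(sliced, array![[3, 4, 5], [6, 7, 8]].into_dyn());
}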

View File

@@ -0,0 +1,120 @@
// Language
use alloc::vec;
use alloc::vec::Vec;
use burn_tensor::ops::BoolTensorOps;
use core::ops::Range;
// Current crate
use crate::element::NdArrayElement;
use crate::NdArrayDevice;
use crate::{tensor::NdArrayTensor, NdArrayBackend};
// Workspace crates
use burn_tensor::{backend::Backend, ops::TensorOps, Data, Shape};
use super::NdArrayOps;
impl<E: NdArrayElement> BoolTensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
fn bool_from_data<const D: usize>(
data: Data<bool, D>,
_device: &NdArrayDevice,
) -> NdArrayTensor<bool, D> {
NdArrayTensor::from_data(data)
}
fn bool_shape<const D: usize>(
tensor: &<NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
) -> Shape<D> {
tensor.shape()
}
fn bool_to_data<const D: usize>(
tensor: &<NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
) -> Data<bool, D> {
let values = tensor.array.iter().map(Clone::clone).collect();
Data::new(values, tensor.shape())
}
fn bool_into_data<const D: usize>(
tensor: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
) -> Data<bool, D> {
let shape = tensor.shape();
let values = tensor.array.into_iter().collect();
Data::new(values, shape)
}
fn bool_to_device<const D: usize>(
tensor: NdArrayTensor<bool, D>,
_device: &NdArrayDevice,
) -> NdArrayTensor<bool, D> {
tensor
}
fn bool_reshape<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<bool, D1>,
shape: Shape<D2>,
) -> NdArrayTensor<bool, D2> {
NdArrayOps::reshape(tensor, shape)
}
fn bool_index<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<bool, D1>,
indexes: [Range<usize>; D2],
) -> NdArrayTensor<bool, D1> {
NdArrayOps::index(tensor, indexes)
}
fn bool_into_int<const D: usize>(
tensor: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
) -> <<NdArrayBackend<E> as Backend>::IntegerBackend as Backend>::TensorPrimitive<D> {
let data = Self::bool_into_data(tensor);
NdArrayBackend::<i64>::from_data(data.convert(), &NdArrayDevice::Cpu)
}
fn bool_device<const D: usize>(
_tensor: &<NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
) -> <NdArrayBackend<E> as Backend>::Device {
NdArrayDevice::Cpu
}
fn bool_empty<const D: usize>(
shape: Shape<D>,
_device: &<NdArrayBackend<E> as Backend>::Device,
) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
let values = vec![false; shape.num_elements()];
NdArrayTensor::from_data(Data::new(values, shape))
}
fn bool_index_assign<const D1: usize, const D2: usize>(
tensor: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D1>,
indexes: [Range<usize>; D2],
value: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D1>,
) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D1> {
NdArrayOps::index_assign(tensor, indexes, value)
}
fn bool_cat<const D: usize>(
tensors: Vec<<NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>>,
dim: usize,
) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
NdArrayOps::cat(tensors, dim)
}
fn bool_equal<const D: usize>(
lhs: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
rhs: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
let mut array = lhs.array;
array.zip_mut_with(&rhs.array, |a, b| *a = *a == *b);
NdArrayTensor { array }
}
fn bool_equal_elem<const D: usize>(
lhs: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
rhs: bool,
) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
let array = lhs.array.mapv(|a| a == rhs).into_shared();
NdArrayTensor { array }
}
}
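
bool_equal compares element-wise (true wherever both elements match, including when both are false), while bool_equal_elem compares against a scalar. A small hypothetical check, assuming the crate-root re-exports used in this file:

use burn_ndarray::{NdArrayBackend, NdArrayDevice};
use burn_tensor::{ops::BoolTensorOps, Data, Shape};

type B = NdArrayBackend<f32>;

fn main() {
    let device = NdArrayDevice::Cpu;
    let lhs = <B as BoolTensorOps<B>>::bool_from_data(
        Data::new(vec![true, false, false], Shape::new([3])),
        &device,
    );
    let rhs = <B as BoolTensorOps<B>>::bool_from_data(
        Data::new(vec![true, true, false], Shape::new([3])),
        &device,
    );

    // Element-wise equality, not logical AND: false == false is true.
    let eq = <B as BoolTensorOps<B>>::bool_equal(lhs, rhs);
    assert_eq!(
        <B as BoolTensorOps<B>>::bool_into_data(eq).value,
        vec![true, false, true]
    );
}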

View File

@@ -1,6 +1,10 @@
mod base;
mod bool_tensor;
mod module;
mod tensor;
pub(crate) mod conv;
pub(crate) mod maxpool;
pub(crate) mod padding;
pub(crate) use base::*;

View File

@@ -6,7 +6,7 @@ use core::ops::Range;
// Current crate
use crate::tensor::BatchMatrix;
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
use crate::{to_nd_array_tensor, NdArrayDevice, SEED};
use crate::{NdArrayDevice, SEED};
// Workspace crates
use burn_common::rand::get_seeded_rng;
@@ -15,7 +15,9 @@ use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Sha
// External crates
use libm::{cos, erf, sin, tanh};
use ndarray::{Axis, Dim, IxDyn, SliceInfoElem};
use ndarray::Axis;
use super::NdArrayOps;
macro_rules! keepdim {
(
@@ -47,13 +49,6 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
NdArrayTensor::from_data(data)
}
fn from_data_bool<const D: usize>(
data: Data<bool, D>,
_device: &NdArrayDevice,
) -> NdArrayTensor<bool, D> {
NdArrayTensor::from_data(data)
}
fn random<const D: usize>(
shape: Shape<D>,
distribution: Distribution<E>,
@@ -91,49 +86,6 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
Data::new(values, shape)
}
fn bool_shape<const D: usize>(
tensor: &<NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
) -> Shape<D> {
tensor.shape()
}
fn bool_to_data<const D: usize>(
tensor: &<NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
) -> Data<bool, D> {
let values = tensor.array.iter().map(Clone::clone).collect();
Data::new(values, tensor.shape())
}
fn bool_into_data<const D: usize>(
tensor: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
) -> Data<bool, D> {
let shape = tensor.shape();
let values = tensor.array.into_iter().collect();
Data::new(values, shape)
}
fn bool_to_device<const D: usize>(
tensor: NdArrayTensor<bool, D>,
_device: &NdArrayDevice,
) -> NdArrayTensor<bool, D> {
tensor
}
fn bool_reshape<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<bool, D1>,
shape: Shape<D2>,
) -> NdArrayTensor<bool, D2> {
match D2 {
1 => to_nd_array_tensor!(bool, 1, shape, tensor.array),
2 => to_nd_array_tensor!(bool, 2, shape, tensor.array),
3 => to_nd_array_tensor!(bool, 3, shape, tensor.array),
4 => to_nd_array_tensor!(bool, 4, shape, tensor.array),
5 => to_nd_array_tensor!(bool, 5, shape, tensor.array),
6 => to_nd_array_tensor!(bool, 6, shape, tensor.array),
_ => panic!("NdArrayTensor support only 6 dimensions."),
}
}
fn device<const D: usize>(_tensor: &NdArrayTensor<E, D>) -> NdArrayDevice {
NdArrayDevice::Cpu
}
@@ -263,35 +215,14 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
tensor: NdArrayTensor<E, D1>,
shape: Shape<D2>,
) -> NdArrayTensor<E, D2> {
match D2 {
1 => to_nd_array_tensor!(1, shape, tensor.array),
2 => to_nd_array_tensor!(2, shape, tensor.array),
3 => to_nd_array_tensor!(3, shape, tensor.array),
4 => to_nd_array_tensor!(4, shape, tensor.array),
5 => to_nd_array_tensor!(5, shape, tensor.array),
6 => to_nd_array_tensor!(6, shape, tensor.array),
_ => panic!("NdArrayTensor support only 6 dimensions."),
}
}
fn bool_index<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<bool, D1>,
indexes: [Range<usize>; D2],
) -> NdArrayTensor<bool, D1> {
let slices = to_slice_args::<D1, D2>(indexes);
let array = tensor.array.slice_move(slices.as_slice()).into_shared();
NdArrayTensor { array }
NdArrayOps::reshape(tensor, shape)
}
fn index<const D1: usize, const D2: usize>(
tensor: NdArrayTensor<E, D1>,
indexes: [Range<usize>; D2],
) -> NdArrayTensor<E, D1> {
let slices = to_slice_args::<D1, D2>(indexes);
let array = tensor.array.slice_move(slices.as_slice()).into_shared();
NdArrayTensor { array }
NdArrayOps::index(tensor, indexes)
}
fn index_assign<const D1: usize, const D2: usize>(
@@ -299,12 +230,7 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
indexes: [Range<usize>; D2],
value: NdArrayTensor<E, D1>,
) -> NdArrayTensor<E, D1> {
let slices = to_slice_args::<D1, D2>(indexes);
let mut array = tensor.array.to_owned();
array.slice_mut(slices.as_slice()).assign(&value.array);
let array = array.into_owned().into_shared();
NdArrayTensor { array }
NdArrayOps::index_assign(tensor, indexes, value)
}
fn mask_fill<const D: usize>(
@@ -532,13 +458,7 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
}
fn cat<const D: usize>(tensors: Vec<NdArrayTensor<E, D>>, dim: usize) -> NdArrayTensor<E, D> {
let arrays: Vec<ndarray::ArrayView<E, IxDyn>> =
tensors.iter().map(|t| t.array.view()).collect();
let array = ndarray::concatenate(Axis(dim), &arrays)
.unwrap()
.into_shared();
NdArrayTensor { array }
NdArrayOps::cat(tensors, dim)
}
fn relu<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
@@ -553,41 +473,6 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
NdArrayTensor { array }
}
fn bool_into_int<const D: usize>(
tensor: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
) -> <<NdArrayBackend<E> as Backend>::IntegerBackend as Backend>::TensorPrimitive<D> {
let data = Self::bool_into_data(tensor);
NdArrayBackend::<i64>::from_data(data.convert(), &NdArrayDevice::Cpu)
}
fn bool_device<const D: usize>(
_tensor: &<NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
) -> <NdArrayBackend<E> as Backend>::Device {
NdArrayDevice::Cpu
}
}
fn to_slice_args<const D1: usize, const D2: usize>(
indexes: [Range<usize>; D2],
) -> [SliceInfoElem; D1] {
let mut slices = [SliceInfoElem::NewAxis; D1];
for i in 0..D1 {
if i >= D2 {
slices[i] = SliceInfoElem::Slice {
start: 0,
end: None,
step: 1,
}
} else {
slices[i] = SliceInfoElem::Slice {
start: indexes[i].start as isize,
end: Some(indexes[i].end as isize),
step: 1,
}
}
}
slices
}
fn mean_dim<E: NdArrayElement, const D1: usize, const D2: usize>(

burn-tch/src/ops/base.rs (new file, 75 lines)
View File

@@ -0,0 +1,75 @@
use std::{marker::PhantomData, ops::Range, sync::Arc};
use crate::{to_tensor, TchShape, TchTensor};
pub struct TchOps<E: tch::kind::Element + Copy + Default> {
e: PhantomData<E>,
}
impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
pub fn index<const D1: usize, const D2: usize>(
tensor: TchTensor<E, D1>,
indexes: [Range<usize>; D2],
) -> TchTensor<E, D1> {
let kind = tensor.kind;
let mut tensor = tensor.tensor.shallow_clone();
for (i, index) in indexes.iter().enumerate().take(D2) {
let start = index.start as i64;
let length = (index.end - index.start) as i64;
tensor = tensor.narrow(i as i64, start, length);
}
let tensor = Arc::new(tensor);
TchTensor { kind, tensor }
}
pub fn index_assign<const D1: usize, const D2: usize>(
tensor: TchTensor<E, D1>,
indexes: [Range<usize>; D2],
value: TchTensor<E, D1>,
) -> TchTensor<E, D1> {
let kind = tensor.kind;
let tensor_original = tensor.tensor.copy();
let tch_shape = TchShape::from(tensor.shape());
let mut tensor = tensor_original.view_(&tch_shape.dims);
for (i, index) in indexes.into_iter().enumerate().take(D2) {
let start = index.start as i64;
let length = (index.end - index.start) as i64;
tensor = tensor.narrow(i as i64, start, length);
}
tensor.copy_(&value.tensor);
TchTensor {
kind,
tensor: Arc::new(tensor_original),
}
}
pub fn cat<const D: usize>(tensors: Vec<TchTensor<E, D>>, dim: usize) -> TchTensor<E, D> {
let tensors: Vec<tch::Tensor> = tensors
.into_iter()
.map(|t| t.tensor.shallow_clone())
.collect();
let tensor = tch::Tensor::cat(&tensors, dim as i64);
to_tensor(tensor)
}
pub fn equal<const D: usize>(lhs: TchTensor<E, D>, rhs: TchTensor<E, D>) -> TchTensor<bool, D> {
let tensor = TchTensor::binary_ops_tensor(
lhs,
rhs,
|lhs, rhs| lhs.eq_tensor_(rhs).to_kind(tch::Kind::Bool),
|lhs, rhs| rhs.eq_tensor_(lhs).to_kind(tch::Kind::Bool),
|lhs, rhs| lhs.eq_tensor(rhs),
);
to_tensor(tensor)
}
}
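
Both index and index_assign are built from tch's narrow(dim, start, length), which returns a view rather than a copy; index therefore never touches the data, while index_assign copies the original once and writes value through a narrowed view into that copy. A hedged sketch of the narrow chain using tch directly (assuming a tch version of this era; arange, reshape, narrow and size are stable across them):

use tch::{Device, Kind, Tensor};

fn main() {
    let t = Tensor::arange(12, (Kind::Int64, Device::Cpu)).reshape(&[3, 4]);

    // Equivalent of indexes = [1..3, 0..2]: narrow dim 0 to start 1, length 2,
    // then narrow dim 1 to start 0, length 2. Each narrow is a view.
    let sliced = t.narrow(0, 1, 2).narrow(1, 0, 2);

    assert_eq!(sliced.size(), vec![2, 2]);
    let values: Vec<i64> = sliced.reshape(&[4]).into();
    assert_eq!(values, vec![4, 5, 8, 9]);
}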

View File

@@ -0,0 +1,123 @@
use std::{ops::Range, sync::Arc};
use burn_tensor::{backend::Backend, ops::BoolTensorOps, Data, Shape};
use crate::{element::TchElement, to_tensor, TchBackend, TchDevice, TchKind, TchShape, TchTensor};
use super::TchOps;
impl<E: TchElement> BoolTensorOps<TchBackend<E>> for TchBackend<E> {
fn bool_from_data<const D: usize>(
data: Data<bool, D>,
device: &TchDevice,
) -> TchTensor<bool, D> {
TchTensor::from_data(data, (*device).into())
}
fn bool_shape<const D: usize>(tensor: &TchTensor<bool, D>) -> Shape<D> {
tensor.shape()
}
fn bool_to_data<const D: usize>(tensor: &TchTensor<bool, D>) -> Data<bool, D> {
let values: Vec<bool> = tensor.tensor.shallow_clone().into();
Data::new(values, tensor.shape())
}
fn bool_into_data<const D: usize>(tensor: TchTensor<bool, D>) -> Data<bool, D> {
let shape = tensor.shape();
let values: Vec<bool> = tensor.unary_ops(
|tensor| tensor.into(),
|tensor| tensor.shallow_clone().into(),
);
Data::new(values, shape)
}
fn bool_to_device<const D: usize>(
tensor: TchTensor<bool, D>,
device: &TchDevice,
) -> TchTensor<bool, D> {
TchTensor {
kind: tensor.kind,
tensor: Arc::new(tensor.tensor.to((*device).into())),
}
}
fn bool_reshape<const D1: usize, const D2: usize>(
tensor: TchTensor<bool, D1>,
shape: Shape<D2>,
) -> TchTensor<bool, D2> {
let shape_tch: TchShape<D2> = shape.into();
let tensor = tensor.unary_ops(
|mut tensor| tensor.resize_(&shape_tch.dims),
|tensor| tensor.reshape(&shape_tch.dims),
);
TchTensor {
tensor: Arc::new(tensor),
kind: TchKind::new(),
}
}
fn bool_device<const D: usize>(tensor: &TchTensor<bool, D>) -> TchDevice {
tensor.tensor.device().into()
}
fn bool_empty<const D: usize>(
shape: Shape<D>,
device: &<TchBackend<E> as Backend>::Device,
) -> TchTensor<bool, D> {
let tensor = tch::Tensor::empty(
&shape.dims.map(|a| a as i64),
(tch::Kind::Bool, (*device).into()),
);
to_tensor(tensor)
}
fn bool_index<const D1: usize, const D2: usize>(
tensor: TchTensor<bool, D1>,
indexes: [Range<usize>; D2],
) -> TchTensor<bool, D1> {
TchOps::index(tensor, indexes)
}
fn bool_index_assign<const D1: usize, const D2: usize>(
tensor: TchTensor<bool, D1>,
indexes: [std::ops::Range<usize>; D2],
value: TchTensor<bool, D1>,
) -> TchTensor<bool, D1> {
TchOps::index_assign(tensor, indexes, value)
}
fn bool_cat<const D: usize>(
tensors: Vec<TchTensor<bool, D>>,
dim: usize,
) -> TchTensor<bool, D> {
TchOps::cat(tensors, dim)
}
fn bool_equal<const D: usize>(
lhs: TchTensor<bool, D>,
rhs: TchTensor<bool, D>,
) -> TchTensor<bool, D> {
TchOps::equal(lhs, rhs)
}
fn bool_equal_elem<const D: usize>(lhs: TchTensor<bool, D>, rhs: bool) -> TchTensor<bool, D> {
let rhs = match rhs {
true => 1,
false => 0,
};
let tensor = lhs.unary_ops(
|mut tensor| tensor.eq_(rhs).to_kind(tch::Kind::Bool),
|tensor| tensor.eq(rhs),
);
to_tensor(tensor)
}
fn bool_into_int<const D: usize>(
tensor: TchTensor<bool, D>,
) -> <<TchBackend<E> as Backend>::IntegerBackend as Backend>::TensorPrimitive<D> {
let tensor = tensor.tensor.to_kind(TchKind::<i64>::new().kind());
to_tensor(tensor)
}
}
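
bool_equal_elem lowers the bool scalar to 0 or 1 because tch's eq compares against a numeric scalar; comparing a bool tensor against false is therefore an element-wise NOT. A hedged sketch with tch directly (of_slice is assumed to be the constructor name in tch versions of this era):

use tch::Tensor;

fn main() {
    let t = Tensor::of_slice(&[true, false, true]);

    // rhs = false lowers to 0, so eq(0) marks exactly the false elements.
    let negated = t.eq(0);

    let values: Vec<bool> = negated.into();
    assert_eq!(values, vec![false, true, false]);
}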

View File

@@ -1,2 +1,6 @@
mod base;
mod bool_tensor;
mod module;
mod tensor;
pub(crate) use base::*;

View File

@@ -2,6 +2,8 @@ use crate::{element::TchElement, to_tensor, TchBackend, TchDevice, TchKind, TchS
use burn_tensor::{backend::Backend, ops::TensorOps, Data, Distribution, ElementConversion, Shape};
use std::{ops::Range, sync::Arc};
use super::TchOps;
macro_rules! run_scalar {
(
$scalar:ident,
@@ -26,13 +28,6 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
TchTensor::from_data(data, (*device).into())
}
fn from_data_bool<const D: usize>(
data: Data<bool, D>,
device: &TchDevice,
) -> TchTensor<bool, D> {
TchTensor::from_data(data, (*device).into())
}
fn random<const D: usize>(
shape: Shape<D>,
distribution: Distribution<E>,
@@ -114,60 +109,6 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
Data::new(values, shape)
}
fn bool_shape<const D: usize>(
tensor: &<TchBackend<E> as Backend>::BoolTensorPrimitive<D>,
) -> Shape<D> {
tensor.shape()
}
fn bool_to_data<const D: usize>(
tensor: &<TchBackend<E> as Backend>::BoolTensorPrimitive<D>,
) -> Data<bool, D> {
let values: Vec<bool> = tensor.tensor.shallow_clone().into();
Data::new(values, tensor.shape())
}
fn bool_into_data<const D: usize>(
tensor: <TchBackend<E> as Backend>::BoolTensorPrimitive<D>,
) -> Data<bool, D> {
let shape = tensor.shape();
let values: Vec<bool> = tensor.unary_ops(
|tensor| tensor.into(),
|tensor| tensor.shallow_clone().into(),
);
Data::new(values, shape)
}
fn bool_to_device<const D: usize>(
tensor: TchTensor<bool, D>,
device: &TchDevice,
) -> TchTensor<bool, D> {
TchTensor {
kind: tensor.kind,
tensor: Arc::new(tensor.tensor.to((*device).into())),
}
}
fn bool_reshape<const D1: usize, const D2: usize>(
tensor: TchTensor<bool, D1>,
shape: Shape<D2>,
) -> TchTensor<bool, D2> {
let shape_tch: TchShape<D2> = shape.into();
let tensor = tensor.unary_ops(
|mut tensor| tensor.resize_(&shape_tch.dims),
|tensor| tensor.reshape(&shape_tch.dims),
);
TchTensor {
tensor: Arc::new(tensor),
kind: TchKind::new(),
}
}
fn bool_device<const D: usize>(tensor: &TchTensor<bool, D>) -> TchDevice {
tensor.tensor.device().into()
}
fn device<const D: usize>(tensor: &TchTensor<E, D>) -> TchDevice {
tensor.tensor.device().into()
}
@@ -315,18 +256,11 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
to_tensor(tensor)
}
fn bool_index<const D1: usize, const D2: usize>(
tensor: TchTensor<bool, D1>,
indexes: [Range<usize>; D2],
) -> TchTensor<bool, D1> {
index(&tensor, indexes)
}
fn index<const D1: usize, const D2: usize>(
tensor: TchTensor<E, D1>,
indexes: [Range<usize>; D2],
) -> TchTensor<E, D1> {
index(&tensor, indexes)
TchOps::index(tensor, indexes)
}
fn index_assign<const D1: usize, const D2: usize>(
@@ -334,25 +268,7 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
indexes: [Range<usize>; D2],
value: TchTensor<E, D1>,
) -> <TchBackend<E> as Backend>::TensorPrimitive<D1> {
let kind = tensor.kind;
let tensor_original = tensor.tensor.copy();
let tch_shape = TchShape::from(tensor.shape());
let mut tensor = tensor_original.view_(&tch_shape.dims);
for (i, index) in indexes.into_iter().enumerate().take(D2) {
let start = index.start as i64;
let length = (index.end - index.start) as i64;
tensor = tensor.narrow(i as i64, start, length);
}
tensor.copy_(&value.tensor);
TchTensor {
kind,
tensor: Arc::new(tensor_original),
}
TchOps::index_assign(tensor, indexes, value)
}
fn mask_fill<const D: usize>(
@@ -370,15 +286,7 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
}
fn equal<const D: usize>(lhs: TchTensor<E, D>, rhs: TchTensor<E, D>) -> TchTensor<bool, D> {
let tensor = TchTensor::binary_ops_tensor(
lhs,
rhs,
|lhs, rhs| lhs.eq_tensor_(rhs).to_kind(tch::Kind::Bool),
|lhs, rhs| rhs.eq_tensor_(lhs).to_kind(tch::Kind::Bool),
|lhs, rhs| lhs.eq_tensor(rhs),
);
to_tensor(tensor)
TchOps::equal(lhs, rhs)
}
fn equal_scalar<const D: usize>(lhs: TchTensor<E, D>, rhs: E) -> TchTensor<bool, D> {
@@ -570,40 +478,10 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
}
fn cat<const D: usize>(tensors: Vec<TchTensor<E, D>>, dim: usize) -> TchTensor<E, D> {
let tensors: Vec<tch::Tensor> = tensors
.into_iter()
.map(|t| t.tensor.shallow_clone())
.collect();
let tensor = tch::Tensor::cat(&tensors, dim as i64);
to_tensor(tensor)
TchOps::cat(tensors, dim)
}
fn relu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
to_tensor(tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu()))
}
fn bool_into_int<const D: usize>(
tensor: <TchBackend<E> as Backend>::BoolTensorPrimitive<D>,
) -> <<TchBackend<E> as Backend>::IntegerBackend as Backend>::TensorPrimitive<D> {
let tensor = tensor.tensor.to_kind(TchKind::<i64>::new().kind());
to_tensor(tensor)
}
}
fn index<const D1: usize, const D2: usize, E: tch::kind::Element + Copy>(
tensor: &TchTensor<E, D1>,
indexes: [Range<usize>; D2],
) -> TchTensor<E, D1> {
let kind = tensor.kind;
let mut tensor = tensor.tensor.shallow_clone();
for (i, index) in indexes.iter().enumerate().take(D2) {
let start = index.start as i64;
let length = (index.end - index.start) as i64;
tensor = tensor.narrow(i as i64, start, length);
}
let tensor = Arc::new(tensor);
TchTensor { kind, tensor }
}

View File

@@ -382,11 +382,11 @@ impl<B: Backend> BasicOps<B> for Bool {
}
fn index_assign<const D1: usize, const D2: usize>(
_tensor: Self::Primitive<D1>,
_indexes: [Range<usize>; D2],
_value: Self::Primitive<D1>,
tensor: Self::Primitive<D1>,
indexes: [Range<usize>; D2],
value: Self::Primitive<D1>,
) -> Self::Primitive<D1> {
todo!("Index assigned is not yet implemented for bool tensor")
B::bool_index_assign(tensor, indexes, value)
}
fn device<const D: usize>(tensor: &Self::Primitive<D>) -> <B as Backend>::Device {
@@ -408,31 +408,32 @@
data: Data<Self::Elem, D>,
device: &B::Device,
) -> Self::Primitive<D> {
B::from_data_bool(data, device)
B::bool_from_data(data, device)
}
fn repeat<const D: usize>(
_tensor: Self::Primitive<D>,
_dim: usize,
_times: usize,
tensor: Self::Primitive<D>,
dim: usize,
times: usize,
) -> Self::Primitive<D> {
todo!("Repeat operation is not yet implemented for bool tensor");
B::bool_repeat(tensor, dim, times)
}
fn equal<const D: usize>(
_lhs: Self::Primitive<D>,
_rhs: Self::Primitive<D>,
lhs: Self::Primitive<D>,
rhs: Self::Primitive<D>,
) -> Tensor<B, D, Bool> {
todo!("Equal operation is not yet implemented for bool tensor");
Tensor::new(B::bool_equal(lhs, rhs))
}
fn equal_scalar<const D: usize>(
_lhs: Self::Primitive<D>,
_rhs: Self::Elem,
lhs: Self::Primitive<D>,
rhs: Self::Elem,
) -> Tensor<B, D, Bool> {
todo!("Equal scalar operation is not yet implemented for bool tensor");
Tensor::new(B::bool_equal_elem(lhs, rhs))
}
fn cat<const D: usize>(_vectors: Vec<Self::Primitive<D>>, _dim: usize) -> Self::Primitive<D> {
todo!("Cat vectors operation is not yet implemented for bool tensor");
fn cat<const D: usize>(vectors: Vec<Self::Primitive<D>>, dim: usize) -> Self::Primitive<D> {
B::bool_cat(vectors, dim)
}
}
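
These delegations are what make the high-level Tensor<B, D, Bool> API functional: index_assign, repeat, equal, equal_scalar and cat previously panicked with todo!. A hypothetical usage sketch against the NdArray backend (from_bool comes from the bool-tensor API shown below; cat, equal, dims and into_data follow the Tensor method names of this era):

use burn_ndarray::NdArrayBackend;
use burn_tensor::{Bool, Data, Tensor};

type B = NdArrayBackend<f32>;

fn main() {
    let a = Tensor::<B, 1, Bool>::from_bool(Data::from([true, false]));
    let b = Tensor::<B, 1, Bool>::from_bool(Data::from([true, true]));

    // Used to be todo!("Cat vectors operation ..."); now delegates to B::bool_cat.
    let c = Tensor::cat(vec![a.clone(), b.clone()], 0);
    assert_eq!(c.dims(), [4]);

    // Used to be todo!("Equal operation ..."); now delegates to B::bool_equal.
    let mask = a.equal(b);
    assert_eq!(mask.into_data(), Data::from([true, false]));
}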

View File

@@ -6,12 +6,12 @@ where
{
/// Create a boolean tensor from data.
pub fn from_bool(data: Data<bool, D>) -> Self {
Self::new(B::from_data_bool(data, &B::Device::default()))
Self::new(B::bool_from_data(data, &B::Device::default()))
}
/// Create a boolean tensor from data on the given device.
pub fn from_bool_device(data: Data<bool, D>, device: &B::Device) -> Self {
Self::new(B::from_data_bool(data, device))
Self::new(B::bool_from_data(data, device))
}
/// Convert the bool tensor into an int tensor.

View File

@@ -5,6 +5,7 @@ use crate::tensor::Element;
pub trait Backend:
TensorOps<Self>
+ BoolTensorOps<Self>
+ ModuleOps<Self>
+ Clone
+ Sized

View File

@@ -0,0 +1,80 @@
use alloc::vec::Vec;
use core::ops::Range;
use crate::{backend::Backend, tensor::Shape, Data};
pub trait BoolTensorOps<B: Backend> {
/// Bool version of empty, see [tensor](crate::Tensor).
fn bool_empty<const D: usize>(shape: Shape<D>, device: &B::Device)
-> B::BoolTensorPrimitive<D>;
fn bool_shape<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> Shape<D>;
fn bool_into_data<const D: usize>(tensor: B::BoolTensorPrimitive<D>) -> Data<bool, D>;
fn bool_to_data<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> Data<bool, D> {
Self::bool_into_data(tensor.clone())
}
fn bool_from_data<const D: usize>(
data: Data<bool, D>,
device: &B::Device,
) -> B::BoolTensorPrimitive<D>;
fn bool_into_int<const D: usize>(
tensor: B::BoolTensorPrimitive<D>,
) -> <B::IntegerBackend as Backend>::TensorPrimitive<D>;
fn bool_device<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> B::Device;
fn bool_to_device<const D: usize>(
tensor: B::BoolTensorPrimitive<D>,
device: &B::Device,
) -> B::BoolTensorPrimitive<D>;
fn bool_reshape<const D1: usize, const D2: usize>(
tensor: B::BoolTensorPrimitive<D1>,
shape: Shape<D2>,
) -> B::BoolTensorPrimitive<D2>;
fn bool_index<const D1: usize, const D2: usize>(
tensor: B::BoolTensorPrimitive<D1>,
indexes: [Range<usize>; D2],
) -> B::BoolTensorPrimitive<D1>;
fn bool_index_assign<const D1: usize, const D2: usize>(
tensor: B::BoolTensorPrimitive<D1>,
indexes: [Range<usize>; D2],
value: B::BoolTensorPrimitive<D1>,
) -> B::BoolTensorPrimitive<D1>;
fn bool_repeat<const D: usize>(
tensor: B::BoolTensorPrimitive<D>,
dim: usize,
times: usize,
) -> B::BoolTensorPrimitive<D> {
let mut shape = Self::bool_shape(&tensor);
if shape.dims[dim] != 1 {
panic!("Can only repeat dimension with dim=1");
}
shape.dims[dim] = times;
let mut i = 0;
let indexes_select_all = [0; D].map(|_| {
let start = 0;
let end = shape.dims[i];
i += 1;
start..end
});
let mut tensor_output = Self::bool_empty(shape, &Self::bool_device(&tensor));
for i in 0..times {
let mut indexes = indexes_select_all.clone();
indexes[dim] = i..i + 1;
tensor_output = Self::bool_index_assign(tensor_output, indexes, tensor.clone());
}
tensor_output
}
fn bool_cat<const D: usize>(
tensors: Vec<B::BoolTensorPrimitive<D>>,
dim: usize,
) -> B::BoolTensorPrimitive<D>;
fn bool_equal<const D: usize>(
lhs: B::BoolTensorPrimitive<D>,
rhs: B::BoolTensorPrimitive<D>,
) -> B::BoolTensorPrimitive<D>;
fn bool_equal_elem<const D: usize>(
lhs: B::BoolTensorPrimitive<D>,
rhs: bool,
) -> B::BoolTensorPrimitive<D>;
}
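
The only default implementation in the trait is bool_repeat, and it is deliberately restrictive: the repeated dimension must have size 1, the output is allocated with bool_empty, and each width-1 slice along dim is filled via bool_index_assign, so any backend that implements the required methods gets repeat for free. A hypothetical test of the observable behaviour on the NdArray backend:

use burn_ndarray::{NdArrayBackend, NdArrayDevice};
use burn_tensor::{ops::BoolTensorOps, Data, Shape};

type B = NdArrayBackend<f32>;

fn main() {
    let device = NdArrayDevice::Cpu;
    // Shape [2, 1]: repeating along dim 1 is allowed because dims[1] == 1.
    let t = <B as BoolTensorOps<B>>::bool_from_data(
        Data::new(vec![true, false], Shape::new([2, 1])),
        &device,
    );

    let repeated = <B as BoolTensorOps<B>>::bool_repeat(t, 1, 3);
    let data = <B as BoolTensorOps<B>>::bool_into_data(repeated);

    assert_eq!(data.shape, Shape::new([2, 3]));
    assert_eq!(data.value, vec![true, true, true, false, false, false]);
}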

View File

@@ -0,0 +1 @@

View File

@@ -1,5 +1,9 @@
mod bool_tensor;
mod int_tensor;
mod modules;
mod tensor;
pub use bool_tensor::*;
pub use int_tensor::*;
pub use modules::*;
pub use tensor::*;

View File

@@ -8,10 +8,6 @@ pub trait TensorOps<B: Backend> {
data: Data<B::FloatElem, D>,
device: &B::Device,
) -> B::TensorPrimitive<D>;
fn from_data_bool<const D: usize>(
data: Data<bool, D>,
device: &B::Device,
) -> B::BoolTensorPrimitive<D>;
fn random<const D: usize>(
shape: Shape<D>,
distribution: Distribution<B::FloatElem>,
@@ -26,25 +22,6 @@ pub trait TensorOps<B: Backend> {
fn shape<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Shape<D>;
fn to_data<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Data<B::FloatElem, D>;
fn into_data<const D: usize>(tensor: B::TensorPrimitive<D>) -> Data<B::FloatElem, D>;
fn bool_shape<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> Shape<D>;
fn bool_to_data<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> Data<bool, D>;
fn bool_into_data<const D: usize>(tensor: B::BoolTensorPrimitive<D>) -> Data<bool, D>;
fn bool_into_int<const D: usize>(
tensor: B::BoolTensorPrimitive<D>,
) -> <B::IntegerBackend as Backend>::TensorPrimitive<D>;
fn bool_device<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> B::Device;
fn bool_to_device<const D: usize>(
tensor: B::BoolTensorPrimitive<D>,
device: &B::Device,
) -> B::BoolTensorPrimitive<D>;
fn bool_reshape<const D1: usize, const D2: usize>(
tensor: B::BoolTensorPrimitive<D1>,
shape: Shape<D2>,
) -> B::BoolTensorPrimitive<D2>;
fn bool_index<const D1: usize, const D2: usize>(
tensor: B::BoolTensorPrimitive<D1>,
indexes: [Range<usize>; D2],
) -> B::BoolTensorPrimitive<D1>;
fn device<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::Device;
fn to_device<const D: usize>(
tensor: B::TensorPrimitive<D>,