mirror of https://github.com/tracel-ai/burn.git
Refactor/backend bool tensor (#192)
This commit is contained in:
parent
ffd3d35176
commit
15ec42dd6f
|
@ -0,0 +1,83 @@
|
|||
use crate::{
|
||||
tensor::{BoolTensor, IntTensor},
|
||||
ADBackendDecorator,
|
||||
};
|
||||
|
||||
use burn_tensor::{backend::Backend, ops::BoolTensorOps, Data, Shape};
|
||||
|
||||
impl<B: Backend> BoolTensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
||||
fn bool_from_data<const D: usize>(data: Data<bool, D>, device: &B::Device) -> BoolTensor<B, D> {
|
||||
B::bool_from_data(data, device)
|
||||
}
|
||||
|
||||
fn bool_shape<const D: usize>(tensor: &BoolTensor<B, D>) -> Shape<D> {
|
||||
B::bool_shape(tensor)
|
||||
}
|
||||
|
||||
fn bool_to_data<const D: usize>(tensor: &BoolTensor<B, D>) -> Data<bool, D> {
|
||||
B::bool_to_data(tensor)
|
||||
}
|
||||
|
||||
fn bool_into_data<const D: usize>(tensor: BoolTensor<B, D>) -> Data<bool, D> {
|
||||
B::bool_into_data(tensor)
|
||||
}
|
||||
|
||||
fn bool_into_int<const D: usize>(tensor: BoolTensor<B, D>) -> IntTensor<B, D> {
|
||||
B::bool_into_int(tensor)
|
||||
}
|
||||
|
||||
fn bool_to_device<const D: usize>(
|
||||
tensor: BoolTensor<B, D>,
|
||||
device: &B::Device,
|
||||
) -> BoolTensor<B, D> {
|
||||
B::bool_to_device(tensor, device)
|
||||
}
|
||||
|
||||
fn bool_device<const D: usize>(tensor: &BoolTensor<B, D>) -> B::Device {
|
||||
B::bool_device(tensor)
|
||||
}
|
||||
|
||||
fn bool_reshape<const D1: usize, const D2: usize>(
|
||||
tensor: BoolTensor<B, D1>,
|
||||
shape: Shape<D2>,
|
||||
) -> BoolTensor<B, D2> {
|
||||
B::bool_reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn bool_index<const D1: usize, const D2: usize>(
|
||||
tensor: BoolTensor<B, D1>,
|
||||
indexes: [std::ops::Range<usize>; D2],
|
||||
) -> BoolTensor<B, D1> {
|
||||
B::bool_index(tensor, indexes)
|
||||
}
|
||||
|
||||
fn bool_empty<const D: usize>(
|
||||
shape: Shape<D>,
|
||||
device: &<ADBackendDecorator<B> as Backend>::Device,
|
||||
) -> BoolTensor<B, D> {
|
||||
B::bool_empty(shape, device)
|
||||
}
|
||||
|
||||
fn bool_index_assign<const D1: usize, const D2: usize>(
|
||||
tensor: <ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D1>,
|
||||
indexes: [std::ops::Range<usize>; D2],
|
||||
value: <ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D1>,
|
||||
) -> <ADBackendDecorator<B> as Backend>::BoolTensorPrimitive<D1> {
|
||||
B::bool_index_assign(tensor, indexes, value)
|
||||
}
|
||||
|
||||
fn bool_cat<const D: usize>(tensors: Vec<BoolTensor<B, D>>, dim: usize) -> BoolTensor<B, D> {
|
||||
B::bool_cat(tensors, dim)
|
||||
}
|
||||
|
||||
fn bool_equal<const D: usize>(
|
||||
lhs: BoolTensor<B, D>,
|
||||
rhs: BoolTensor<B, D>,
|
||||
) -> BoolTensor<B, D> {
|
||||
B::bool_equal(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bool_equal_elem<const D: usize>(lhs: BoolTensor<B, D>, rhs: bool) -> BoolTensor<B, D> {
|
||||
B::bool_equal_elem(lhs, rhs)
|
||||
}
|
||||
}
|
|
@ -1,5 +1,6 @@
|
|||
mod backward;
|
||||
mod base;
|
||||
mod bool_tensor;
|
||||
mod module;
|
||||
mod tensor;
|
||||
|
||||
|
|
|
@ -19,10 +19,6 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
ADTensor::new(B::from_data(data, device))
|
||||
}
|
||||
|
||||
fn from_data_bool<const D: usize>(data: Data<bool, D>, device: &B::Device) -> BoolTensor<B, D> {
|
||||
B::from_data_bool(data, device)
|
||||
}
|
||||
|
||||
fn random<const D: usize>(
|
||||
shape: Shape<D>,
|
||||
distribution: burn_tensor::Distribution<FloatElem<B>>,
|
||||
|
@ -51,47 +47,6 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
|
|||
B::into_data(tensor.primitive)
|
||||
}
|
||||
|
||||
fn bool_shape<const D: usize>(tensor: &BoolTensor<B, D>) -> Shape<D> {
|
||||
B::bool_shape(tensor)
|
||||
}
|
||||
|
||||
fn bool_to_data<const D: usize>(tensor: &BoolTensor<B, D>) -> Data<bool, D> {
|
||||
B::bool_to_data(tensor)
|
||||
}
|
||||
|
||||
fn bool_into_data<const D: usize>(tensor: BoolTensor<B, D>) -> Data<bool, D> {
|
||||
B::bool_into_data(tensor)
|
||||
}
|
||||
|
||||
fn bool_into_int<const D: usize>(tensor: BoolTensor<B, D>) -> IntTensor<B, D> {
|
||||
B::bool_into_int(tensor)
|
||||
}
|
||||
|
||||
fn bool_to_device<const D: usize>(
|
||||
tensor: BoolTensor<B, D>,
|
||||
device: &B::Device,
|
||||
) -> BoolTensor<B, D> {
|
||||
B::bool_to_device(tensor, device)
|
||||
}
|
||||
|
||||
fn bool_device<const D: usize>(tensor: &BoolTensor<B, D>) -> B::Device {
|
||||
B::bool_device(tensor)
|
||||
}
|
||||
|
||||
fn bool_reshape<const D1: usize, const D2: usize>(
|
||||
tensor: BoolTensor<B, D1>,
|
||||
shape: Shape<D2>,
|
||||
) -> BoolTensor<B, D2> {
|
||||
B::bool_reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn bool_index<const D1: usize, const D2: usize>(
|
||||
tensor: BoolTensor<B, D1>,
|
||||
indexes: [std::ops::Range<usize>; D2],
|
||||
) -> BoolTensor<B, D1> {
|
||||
B::bool_index(tensor, indexes)
|
||||
}
|
||||
|
||||
fn device<const D: usize>(tensor: &ADTensor<B, D>) -> B::Device {
|
||||
B::device(&tensor.primitive)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
use alloc::vec::Vec;
|
||||
use core::{marker::PhantomData, ops::Range};
|
||||
|
||||
use burn_tensor::Shape;
|
||||
use ndarray::Axis;
|
||||
use ndarray::Dim;
|
||||
use ndarray::IxDyn;
|
||||
use ndarray::SliceInfoElem;
|
||||
|
||||
use crate::{tensor::NdArrayTensor, to_nd_array_tensor};
|
||||
|
||||
pub struct NdArrayOps<E> {
|
||||
e: PhantomData<E>,
|
||||
}
|
||||
|
||||
impl<E> NdArrayOps<E>
|
||||
where
|
||||
E: Copy,
|
||||
{
|
||||
pub fn index<const D1: usize, const D2: usize>(
|
||||
tensor: NdArrayTensor<E, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
) -> NdArrayTensor<E, D1> {
|
||||
let slices = Self::to_slice_args::<D1, D2>(indexes);
|
||||
let array = tensor.array.slice_move(slices.as_slice()).into_shared();
|
||||
|
||||
NdArrayTensor { array }
|
||||
}
|
||||
|
||||
pub fn index_assign<const D1: usize, const D2: usize>(
|
||||
tensor: NdArrayTensor<E, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
value: NdArrayTensor<E, D1>,
|
||||
) -> NdArrayTensor<E, D1> {
|
||||
let slices = Self::to_slice_args::<D1, D2>(indexes);
|
||||
let mut array = tensor.array.to_owned();
|
||||
array.slice_mut(slices.as_slice()).assign(&value.array);
|
||||
let array = array.into_owned().into_shared();
|
||||
|
||||
NdArrayTensor { array }
|
||||
}
|
||||
|
||||
pub fn reshape<const D1: usize, const D2: usize>(
|
||||
tensor: NdArrayTensor<E, D1>,
|
||||
shape: Shape<D2>,
|
||||
) -> NdArrayTensor<E, D2> {
|
||||
match D2 {
|
||||
1 => to_nd_array_tensor!(1, shape, tensor.array),
|
||||
2 => to_nd_array_tensor!(2, shape, tensor.array),
|
||||
3 => to_nd_array_tensor!(3, shape, tensor.array),
|
||||
4 => to_nd_array_tensor!(4, shape, tensor.array),
|
||||
5 => to_nd_array_tensor!(5, shape, tensor.array),
|
||||
6 => to_nd_array_tensor!(6, shape, tensor.array),
|
||||
_ => panic!("NdArrayTensor support only 6 dimensions."),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cat<const D: usize>(
|
||||
tensors: Vec<NdArrayTensor<E, D>>,
|
||||
dim: usize,
|
||||
) -> NdArrayTensor<E, D> {
|
||||
let arrays: Vec<ndarray::ArrayView<E, IxDyn>> =
|
||||
tensors.iter().map(|t| t.array.view()).collect();
|
||||
let array = ndarray::concatenate(Axis(dim), &arrays)
|
||||
.unwrap()
|
||||
.into_shared();
|
||||
|
||||
NdArrayTensor { array }
|
||||
}
|
||||
|
||||
fn to_slice_args<const D1: usize, const D2: usize>(
|
||||
indexes: [Range<usize>; D2],
|
||||
) -> [SliceInfoElem; D1] {
|
||||
let mut slices = [SliceInfoElem::NewAxis; D1];
|
||||
for i in 0..D1 {
|
||||
if i >= D2 {
|
||||
slices[i] = SliceInfoElem::Slice {
|
||||
start: 0,
|
||||
end: None,
|
||||
step: 1,
|
||||
}
|
||||
} else {
|
||||
slices[i] = SliceInfoElem::Slice {
|
||||
start: indexes[i].start as isize,
|
||||
end: Some(indexes[i].end as isize),
|
||||
step: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
slices
|
||||
}
|
||||
}
|
|
@ -0,0 +1,120 @@
|
|||
// Language
|
||||
use alloc::vec;
|
||||
use alloc::vec::Vec;
|
||||
use burn_tensor::ops::BoolTensorOps;
|
||||
use core::ops::Range;
|
||||
|
||||
// Current crate
|
||||
use crate::element::NdArrayElement;
|
||||
use crate::NdArrayDevice;
|
||||
use crate::{tensor::NdArrayTensor, NdArrayBackend};
|
||||
|
||||
// Workspace crates
|
||||
use burn_tensor::{backend::Backend, ops::TensorOps, Data, Shape};
|
||||
|
||||
use super::NdArrayOps;
|
||||
|
||||
impl<E: NdArrayElement> BoolTensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
||||
fn bool_from_data<const D: usize>(
|
||||
data: Data<bool, D>,
|
||||
_device: &NdArrayDevice,
|
||||
) -> NdArrayTensor<bool, D> {
|
||||
NdArrayTensor::from_data(data)
|
||||
}
|
||||
|
||||
fn bool_shape<const D: usize>(
|
||||
tensor: &<NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
|
||||
) -> Shape<D> {
|
||||
tensor.shape()
|
||||
}
|
||||
|
||||
fn bool_to_data<const D: usize>(
|
||||
tensor: &<NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
|
||||
) -> Data<bool, D> {
|
||||
let values = tensor.array.iter().map(Clone::clone).collect();
|
||||
Data::new(values, tensor.shape())
|
||||
}
|
||||
|
||||
fn bool_into_data<const D: usize>(
|
||||
tensor: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
|
||||
) -> Data<bool, D> {
|
||||
let shape = tensor.shape();
|
||||
let values = tensor.array.into_iter().collect();
|
||||
Data::new(values, shape)
|
||||
}
|
||||
|
||||
fn bool_to_device<const D: usize>(
|
||||
tensor: NdArrayTensor<bool, D>,
|
||||
_device: &NdArrayDevice,
|
||||
) -> NdArrayTensor<bool, D> {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn bool_reshape<const D1: usize, const D2: usize>(
|
||||
tensor: NdArrayTensor<bool, D1>,
|
||||
shape: Shape<D2>,
|
||||
) -> NdArrayTensor<bool, D2> {
|
||||
NdArrayOps::reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn bool_index<const D1: usize, const D2: usize>(
|
||||
tensor: NdArrayTensor<bool, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
) -> NdArrayTensor<bool, D1> {
|
||||
NdArrayOps::index(tensor, indexes)
|
||||
}
|
||||
|
||||
fn bool_into_int<const D: usize>(
|
||||
tensor: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
|
||||
) -> <<NdArrayBackend<E> as Backend>::IntegerBackend as Backend>::TensorPrimitive<D> {
|
||||
let data = Self::bool_into_data(tensor);
|
||||
NdArrayBackend::<i64>::from_data(data.convert(), &NdArrayDevice::Cpu)
|
||||
}
|
||||
|
||||
fn bool_device<const D: usize>(
|
||||
_tensor: &<NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
|
||||
) -> <NdArrayBackend<E> as Backend>::Device {
|
||||
NdArrayDevice::Cpu
|
||||
}
|
||||
|
||||
fn bool_empty<const D: usize>(
|
||||
shape: Shape<D>,
|
||||
_device: &<NdArrayBackend<E> as Backend>::Device,
|
||||
) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
|
||||
let values = vec![false; shape.num_elements()];
|
||||
NdArrayTensor::from_data(Data::new(values, shape))
|
||||
}
|
||||
|
||||
fn bool_index_assign<const D1: usize, const D2: usize>(
|
||||
tensor: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
value: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D1>,
|
||||
) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D1> {
|
||||
NdArrayOps::index_assign(tensor, indexes, value)
|
||||
}
|
||||
|
||||
fn bool_cat<const D: usize>(
|
||||
tensors: Vec<<NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>>,
|
||||
dim: usize,
|
||||
) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
|
||||
NdArrayOps::cat(tensors, dim)
|
||||
}
|
||||
|
||||
fn bool_equal<const D: usize>(
|
||||
lhs: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
|
||||
rhs: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
|
||||
) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
|
||||
let mut array = lhs.array;
|
||||
array.zip_mut_with(&rhs.array, |a, b| *a = *a && *b);
|
||||
|
||||
NdArrayTensor { array }
|
||||
}
|
||||
|
||||
fn bool_equal_elem<const D: usize>(
|
||||
lhs: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
|
||||
rhs: bool,
|
||||
) -> <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D> {
|
||||
let array = lhs.array.mapv(|a| a == rhs).into_shared();
|
||||
NdArrayTensor { array }
|
||||
}
|
||||
}
|
|
@ -1,6 +1,10 @@
|
|||
mod base;
|
||||
mod bool_tensor;
|
||||
mod module;
|
||||
mod tensor;
|
||||
|
||||
pub(crate) mod conv;
|
||||
pub(crate) mod maxpool;
|
||||
pub(crate) mod padding;
|
||||
|
||||
pub(crate) use base::*;
|
||||
|
|
|
@ -6,7 +6,7 @@ use core::ops::Range;
|
|||
// Current crate
|
||||
use crate::tensor::BatchMatrix;
|
||||
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
|
||||
use crate::{to_nd_array_tensor, NdArrayDevice, SEED};
|
||||
use crate::{NdArrayDevice, SEED};
|
||||
|
||||
// Workspace crates
|
||||
use burn_common::rand::get_seeded_rng;
|
||||
|
@ -15,7 +15,9 @@ use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Sha
|
|||
|
||||
// External crates
|
||||
use libm::{cos, erf, sin, tanh};
|
||||
use ndarray::{Axis, Dim, IxDyn, SliceInfoElem};
|
||||
use ndarray::Axis;
|
||||
|
||||
use super::NdArrayOps;
|
||||
|
||||
macro_rules! keepdim {
|
||||
(
|
||||
|
@ -47,13 +49,6 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
|||
NdArrayTensor::from_data(data)
|
||||
}
|
||||
|
||||
fn from_data_bool<const D: usize>(
|
||||
data: Data<bool, D>,
|
||||
_device: &NdArrayDevice,
|
||||
) -> NdArrayTensor<bool, D> {
|
||||
NdArrayTensor::from_data(data)
|
||||
}
|
||||
|
||||
fn random<const D: usize>(
|
||||
shape: Shape<D>,
|
||||
distribution: Distribution<E>,
|
||||
|
@ -91,49 +86,6 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
|||
Data::new(values, shape)
|
||||
}
|
||||
|
||||
fn bool_shape<const D: usize>(
|
||||
tensor: &<NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
|
||||
) -> Shape<D> {
|
||||
tensor.shape()
|
||||
}
|
||||
|
||||
fn bool_to_data<const D: usize>(
|
||||
tensor: &<NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
|
||||
) -> Data<bool, D> {
|
||||
let values = tensor.array.iter().map(Clone::clone).collect();
|
||||
Data::new(values, tensor.shape())
|
||||
}
|
||||
|
||||
fn bool_into_data<const D: usize>(
|
||||
tensor: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
|
||||
) -> Data<bool, D> {
|
||||
let shape = tensor.shape();
|
||||
let values = tensor.array.into_iter().collect();
|
||||
Data::new(values, shape)
|
||||
}
|
||||
|
||||
fn bool_to_device<const D: usize>(
|
||||
tensor: NdArrayTensor<bool, D>,
|
||||
_device: &NdArrayDevice,
|
||||
) -> NdArrayTensor<bool, D> {
|
||||
tensor
|
||||
}
|
||||
|
||||
fn bool_reshape<const D1: usize, const D2: usize>(
|
||||
tensor: NdArrayTensor<bool, D1>,
|
||||
shape: Shape<D2>,
|
||||
) -> NdArrayTensor<bool, D2> {
|
||||
match D2 {
|
||||
1 => to_nd_array_tensor!(bool, 1, shape, tensor.array),
|
||||
2 => to_nd_array_tensor!(bool, 2, shape, tensor.array),
|
||||
3 => to_nd_array_tensor!(bool, 3, shape, tensor.array),
|
||||
4 => to_nd_array_tensor!(bool, 4, shape, tensor.array),
|
||||
5 => to_nd_array_tensor!(bool, 5, shape, tensor.array),
|
||||
6 => to_nd_array_tensor!(bool, 6, shape, tensor.array),
|
||||
_ => panic!("NdArrayTensor support only 6 dimensions."),
|
||||
}
|
||||
}
|
||||
|
||||
fn device<const D: usize>(_tensor: &NdArrayTensor<E, D>) -> NdArrayDevice {
|
||||
NdArrayDevice::Cpu
|
||||
}
|
||||
|
@ -263,35 +215,14 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
|||
tensor: NdArrayTensor<E, D1>,
|
||||
shape: Shape<D2>,
|
||||
) -> NdArrayTensor<E, D2> {
|
||||
match D2 {
|
||||
1 => to_nd_array_tensor!(1, shape, tensor.array),
|
||||
2 => to_nd_array_tensor!(2, shape, tensor.array),
|
||||
3 => to_nd_array_tensor!(3, shape, tensor.array),
|
||||
4 => to_nd_array_tensor!(4, shape, tensor.array),
|
||||
5 => to_nd_array_tensor!(5, shape, tensor.array),
|
||||
6 => to_nd_array_tensor!(6, shape, tensor.array),
|
||||
_ => panic!("NdArrayTensor support only 6 dimensions."),
|
||||
}
|
||||
}
|
||||
|
||||
fn bool_index<const D1: usize, const D2: usize>(
|
||||
tensor: NdArrayTensor<bool, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
) -> NdArrayTensor<bool, D1> {
|
||||
let slices = to_slice_args::<D1, D2>(indexes);
|
||||
let array = tensor.array.slice_move(slices.as_slice()).into_shared();
|
||||
|
||||
NdArrayTensor { array }
|
||||
NdArrayOps::reshape(tensor, shape)
|
||||
}
|
||||
|
||||
fn index<const D1: usize, const D2: usize>(
|
||||
tensor: NdArrayTensor<E, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
) -> NdArrayTensor<E, D1> {
|
||||
let slices = to_slice_args::<D1, D2>(indexes);
|
||||
let array = tensor.array.slice_move(slices.as_slice()).into_shared();
|
||||
|
||||
NdArrayTensor { array }
|
||||
NdArrayOps::index(tensor, indexes)
|
||||
}
|
||||
|
||||
fn index_assign<const D1: usize, const D2: usize>(
|
||||
|
@ -299,12 +230,7 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
|||
indexes: [Range<usize>; D2],
|
||||
value: NdArrayTensor<E, D1>,
|
||||
) -> NdArrayTensor<E, D1> {
|
||||
let slices = to_slice_args::<D1, D2>(indexes);
|
||||
let mut array = tensor.array.to_owned();
|
||||
array.slice_mut(slices.as_slice()).assign(&value.array);
|
||||
let array = array.into_owned().into_shared();
|
||||
|
||||
NdArrayTensor { array }
|
||||
NdArrayOps::index_assign(tensor, indexes, value)
|
||||
}
|
||||
|
||||
fn mask_fill<const D: usize>(
|
||||
|
@ -532,13 +458,7 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
|||
}
|
||||
|
||||
fn cat<const D: usize>(tensors: Vec<NdArrayTensor<E, D>>, dim: usize) -> NdArrayTensor<E, D> {
|
||||
let arrays: Vec<ndarray::ArrayView<E, IxDyn>> =
|
||||
tensors.iter().map(|t| t.array.view()).collect();
|
||||
let array = ndarray::concatenate(Axis(dim), &arrays)
|
||||
.unwrap()
|
||||
.into_shared();
|
||||
|
||||
NdArrayTensor { array }
|
||||
NdArrayOps::cat(tensors, dim)
|
||||
}
|
||||
|
||||
fn relu<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
|
||||
|
@ -553,41 +473,6 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
|||
|
||||
NdArrayTensor { array }
|
||||
}
|
||||
|
||||
fn bool_into_int<const D: usize>(
|
||||
tensor: <NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
|
||||
) -> <<NdArrayBackend<E> as Backend>::IntegerBackend as Backend>::TensorPrimitive<D> {
|
||||
let data = Self::bool_into_data(tensor);
|
||||
NdArrayBackend::<i64>::from_data(data.convert(), &NdArrayDevice::Cpu)
|
||||
}
|
||||
|
||||
fn bool_device<const D: usize>(
|
||||
_tensor: &<NdArrayBackend<E> as Backend>::BoolTensorPrimitive<D>,
|
||||
) -> <NdArrayBackend<E> as Backend>::Device {
|
||||
NdArrayDevice::Cpu
|
||||
}
|
||||
}
|
||||
|
||||
fn to_slice_args<const D1: usize, const D2: usize>(
|
||||
indexes: [Range<usize>; D2],
|
||||
) -> [SliceInfoElem; D1] {
|
||||
let mut slices = [SliceInfoElem::NewAxis; D1];
|
||||
for i in 0..D1 {
|
||||
if i >= D2 {
|
||||
slices[i] = SliceInfoElem::Slice {
|
||||
start: 0,
|
||||
end: None,
|
||||
step: 1,
|
||||
}
|
||||
} else {
|
||||
slices[i] = SliceInfoElem::Slice {
|
||||
start: indexes[i].start as isize,
|
||||
end: Some(indexes[i].end as isize),
|
||||
step: 1,
|
||||
}
|
||||
}
|
||||
}
|
||||
slices
|
||||
}
|
||||
|
||||
fn mean_dim<E: NdArrayElement, const D1: usize, const D2: usize>(
|
||||
|
|
|
@ -0,0 +1,75 @@
|
|||
use std::{marker::PhantomData, ops::Range, sync::Arc};
|
||||
|
||||
use crate::{to_tensor, TchShape, TchTensor};
|
||||
|
||||
pub struct TchOps<E: tch::kind::Element + Copy + Default> {
|
||||
e: PhantomData<E>,
|
||||
}
|
||||
|
||||
impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
|
||||
pub fn index<const D1: usize, const D2: usize>(
|
||||
tensor: TchTensor<E, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
) -> TchTensor<E, D1> {
|
||||
let kind = tensor.kind;
|
||||
|
||||
let mut tensor = tensor.tensor.shallow_clone();
|
||||
|
||||
for (i, index) in indexes.iter().enumerate().take(D2) {
|
||||
let start = index.start as i64;
|
||||
let length = (index.end - index.start) as i64;
|
||||
tensor = tensor.narrow(i as i64, start, length);
|
||||
}
|
||||
let tensor = Arc::new(tensor);
|
||||
|
||||
TchTensor { kind, tensor }
|
||||
}
|
||||
|
||||
pub fn index_assign<const D1: usize, const D2: usize>(
|
||||
tensor: TchTensor<E, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
value: TchTensor<E, D1>,
|
||||
) -> TchTensor<E, D1> {
|
||||
let kind = tensor.kind;
|
||||
let tensor_original = tensor.tensor.copy();
|
||||
let tch_shape = TchShape::from(tensor.shape());
|
||||
|
||||
let mut tensor = tensor_original.view_(&tch_shape.dims);
|
||||
|
||||
for (i, index) in indexes.into_iter().enumerate().take(D2) {
|
||||
let start = index.start as i64;
|
||||
let length = (index.end - index.start) as i64;
|
||||
|
||||
tensor = tensor.narrow(i as i64, start, length);
|
||||
}
|
||||
|
||||
tensor.copy_(&value.tensor);
|
||||
|
||||
TchTensor {
|
||||
kind,
|
||||
tensor: Arc::new(tensor_original),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn cat<const D: usize>(tensors: Vec<TchTensor<E, D>>, dim: usize) -> TchTensor<E, D> {
|
||||
let tensors: Vec<tch::Tensor> = tensors
|
||||
.into_iter()
|
||||
.map(|t| t.tensor.shallow_clone())
|
||||
.collect();
|
||||
let tensor = tch::Tensor::cat(&tensors, dim as i64);
|
||||
|
||||
to_tensor(tensor)
|
||||
}
|
||||
|
||||
pub fn equal<const D: usize>(lhs: TchTensor<E, D>, rhs: TchTensor<E, D>) -> TchTensor<bool, D> {
|
||||
let tensor = TchTensor::binary_ops_tensor(
|
||||
lhs,
|
||||
rhs,
|
||||
|lhs, rhs| lhs.eq_tensor_(rhs).to_kind(tch::Kind::Bool),
|
||||
|lhs, rhs| rhs.eq_tensor_(lhs).to_kind(tch::Kind::Bool),
|
||||
|lhs, rhs| lhs.eq_tensor(rhs),
|
||||
);
|
||||
|
||||
to_tensor(tensor)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,123 @@
|
|||
use std::{ops::Range, sync::Arc};
|
||||
|
||||
use burn_tensor::{backend::Backend, ops::BoolTensorOps, Data, Shape};
|
||||
|
||||
use crate::{element::TchElement, to_tensor, TchBackend, TchDevice, TchKind, TchShape, TchTensor};
|
||||
|
||||
use super::TchOps;
|
||||
|
||||
impl<E: TchElement> BoolTensorOps<TchBackend<E>> for TchBackend<E> {
|
||||
fn bool_from_data<const D: usize>(
|
||||
data: Data<bool, D>,
|
||||
device: &TchDevice,
|
||||
) -> TchTensor<bool, D> {
|
||||
TchTensor::from_data(data, (*device).into())
|
||||
}
|
||||
|
||||
fn bool_shape<const D: usize>(tensor: &TchTensor<bool, D>) -> Shape<D> {
|
||||
tensor.shape()
|
||||
}
|
||||
|
||||
fn bool_to_data<const D: usize>(tensor: &TchTensor<bool, D>) -> Data<bool, D> {
|
||||
let values: Vec<bool> = tensor.tensor.shallow_clone().into();
|
||||
Data::new(values, tensor.shape())
|
||||
}
|
||||
|
||||
fn bool_into_data<const D: usize>(tensor: TchTensor<bool, D>) -> Data<bool, D> {
|
||||
let shape = tensor.shape();
|
||||
let values: Vec<bool> = tensor.unary_ops(
|
||||
|tensor| tensor.into(),
|
||||
|tensor| tensor.shallow_clone().into(),
|
||||
);
|
||||
Data::new(values, shape)
|
||||
}
|
||||
|
||||
fn bool_to_device<const D: usize>(
|
||||
tensor: TchTensor<bool, D>,
|
||||
device: &TchDevice,
|
||||
) -> TchTensor<bool, D> {
|
||||
TchTensor {
|
||||
kind: tensor.kind,
|
||||
tensor: Arc::new(tensor.tensor.to((*device).into())),
|
||||
}
|
||||
}
|
||||
|
||||
fn bool_reshape<const D1: usize, const D2: usize>(
|
||||
tensor: TchTensor<bool, D1>,
|
||||
shape: Shape<D2>,
|
||||
) -> TchTensor<bool, D2> {
|
||||
let shape_tch: TchShape<D2> = shape.into();
|
||||
let tensor = tensor.unary_ops(
|
||||
|mut tensor| tensor.resize_(&shape_tch.dims),
|
||||
|tensor| tensor.reshape(&shape_tch.dims),
|
||||
);
|
||||
|
||||
TchTensor {
|
||||
tensor: Arc::new(tensor),
|
||||
kind: TchKind::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn bool_device<const D: usize>(tensor: &TchTensor<bool, D>) -> TchDevice {
|
||||
tensor.tensor.device().into()
|
||||
}
|
||||
|
||||
fn bool_empty<const D: usize>(
|
||||
shape: Shape<D>,
|
||||
device: &<TchBackend<E> as Backend>::Device,
|
||||
) -> TchTensor<bool, D> {
|
||||
let tensor = tch::Tensor::empty(
|
||||
&shape.dims.map(|a| a as i64),
|
||||
(tch::Kind::Bool, (*device).into()),
|
||||
);
|
||||
|
||||
to_tensor(tensor)
|
||||
}
|
||||
|
||||
fn bool_index<const D1: usize, const D2: usize>(
|
||||
tensor: TchTensor<bool, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
) -> TchTensor<bool, D1> {
|
||||
TchOps::index(tensor, indexes)
|
||||
}
|
||||
fn bool_index_assign<const D1: usize, const D2: usize>(
|
||||
tensor: TchTensor<bool, D1>,
|
||||
indexes: [std::ops::Range<usize>; D2],
|
||||
value: TchTensor<bool, D1>,
|
||||
) -> TchTensor<bool, D1> {
|
||||
TchOps::index_assign(tensor, indexes, value)
|
||||
}
|
||||
|
||||
fn bool_cat<const D: usize>(
|
||||
tensors: Vec<TchTensor<bool, D>>,
|
||||
dim: usize,
|
||||
) -> TchTensor<bool, D> {
|
||||
TchOps::cat(tensors, dim)
|
||||
}
|
||||
|
||||
fn bool_equal<const D: usize>(
|
||||
lhs: TchTensor<bool, D>,
|
||||
rhs: TchTensor<bool, D>,
|
||||
) -> TchTensor<bool, D> {
|
||||
TchOps::equal(lhs, rhs)
|
||||
}
|
||||
|
||||
fn bool_equal_elem<const D: usize>(lhs: TchTensor<bool, D>, rhs: bool) -> TchTensor<bool, D> {
|
||||
let rhs = match rhs {
|
||||
true => 1,
|
||||
false => 0,
|
||||
};
|
||||
let tensor = lhs.unary_ops(
|
||||
|mut tensor| tensor.eq_(rhs).to_kind(tch::Kind::Bool),
|
||||
|tensor| tensor.eq(rhs),
|
||||
);
|
||||
to_tensor(tensor)
|
||||
}
|
||||
|
||||
fn bool_into_int<const D: usize>(
|
||||
tensor: TchTensor<bool, D>,
|
||||
) -> <<TchBackend<E> as Backend>::IntegerBackend as Backend>::TensorPrimitive<D> {
|
||||
let tensor = tensor.tensor.to_kind(TchKind::<i64>::new().kind());
|
||||
to_tensor(tensor)
|
||||
}
|
||||
}
|
|
@ -1,2 +1,6 @@
|
|||
mod base;
|
||||
mod bool_tensor;
|
||||
mod module;
|
||||
mod tensor;
|
||||
|
||||
pub(crate) use base::*;
|
||||
|
|
|
@ -2,6 +2,8 @@ use crate::{element::TchElement, to_tensor, TchBackend, TchDevice, TchKind, TchS
|
|||
use burn_tensor::{backend::Backend, ops::TensorOps, Data, Distribution, ElementConversion, Shape};
|
||||
use std::{ops::Range, sync::Arc};
|
||||
|
||||
use super::TchOps;
|
||||
|
||||
macro_rules! run_scalar {
|
||||
(
|
||||
$scalar:ident,
|
||||
|
@ -26,13 +28,6 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
TchTensor::from_data(data, (*device).into())
|
||||
}
|
||||
|
||||
fn from_data_bool<const D: usize>(
|
||||
data: Data<bool, D>,
|
||||
device: &TchDevice,
|
||||
) -> TchTensor<bool, D> {
|
||||
TchTensor::from_data(data, (*device).into())
|
||||
}
|
||||
|
||||
fn random<const D: usize>(
|
||||
shape: Shape<D>,
|
||||
distribution: Distribution<E>,
|
||||
|
@ -114,60 +109,6 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
Data::new(values, shape)
|
||||
}
|
||||
|
||||
fn bool_shape<const D: usize>(
|
||||
tensor: &<TchBackend<E> as Backend>::BoolTensorPrimitive<D>,
|
||||
) -> Shape<D> {
|
||||
tensor.shape()
|
||||
}
|
||||
|
||||
fn bool_to_data<const D: usize>(
|
||||
tensor: &<TchBackend<E> as Backend>::BoolTensorPrimitive<D>,
|
||||
) -> Data<bool, D> {
|
||||
let values: Vec<bool> = tensor.tensor.shallow_clone().into();
|
||||
Data::new(values, tensor.shape())
|
||||
}
|
||||
|
||||
fn bool_into_data<const D: usize>(
|
||||
tensor: <TchBackend<E> as Backend>::BoolTensorPrimitive<D>,
|
||||
) -> Data<bool, D> {
|
||||
let shape = tensor.shape();
|
||||
let values: Vec<bool> = tensor.unary_ops(
|
||||
|tensor| tensor.into(),
|
||||
|tensor| tensor.shallow_clone().into(),
|
||||
);
|
||||
Data::new(values, shape)
|
||||
}
|
||||
|
||||
fn bool_to_device<const D: usize>(
|
||||
tensor: TchTensor<bool, D>,
|
||||
device: &TchDevice,
|
||||
) -> TchTensor<bool, D> {
|
||||
TchTensor {
|
||||
kind: tensor.kind,
|
||||
tensor: Arc::new(tensor.tensor.to((*device).into())),
|
||||
}
|
||||
}
|
||||
|
||||
fn bool_reshape<const D1: usize, const D2: usize>(
|
||||
tensor: TchTensor<bool, D1>,
|
||||
shape: Shape<D2>,
|
||||
) -> TchTensor<bool, D2> {
|
||||
let shape_tch: TchShape<D2> = shape.into();
|
||||
let tensor = tensor.unary_ops(
|
||||
|mut tensor| tensor.resize_(&shape_tch.dims),
|
||||
|tensor| tensor.reshape(&shape_tch.dims),
|
||||
);
|
||||
|
||||
TchTensor {
|
||||
tensor: Arc::new(tensor),
|
||||
kind: TchKind::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn bool_device<const D: usize>(tensor: &TchTensor<bool, D>) -> TchDevice {
|
||||
tensor.tensor.device().into()
|
||||
}
|
||||
|
||||
fn device<const D: usize>(tensor: &TchTensor<E, D>) -> TchDevice {
|
||||
tensor.tensor.device().into()
|
||||
}
|
||||
|
@ -315,18 +256,11 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
to_tensor(tensor)
|
||||
}
|
||||
|
||||
fn bool_index<const D1: usize, const D2: usize>(
|
||||
tensor: TchTensor<bool, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
) -> TchTensor<bool, D1> {
|
||||
index(&tensor, indexes)
|
||||
}
|
||||
|
||||
fn index<const D1: usize, const D2: usize>(
|
||||
tensor: TchTensor<E, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
) -> TchTensor<E, D1> {
|
||||
index(&tensor, indexes)
|
||||
TchOps::index(tensor, indexes)
|
||||
}
|
||||
|
||||
fn index_assign<const D1: usize, const D2: usize>(
|
||||
|
@ -334,25 +268,7 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
indexes: [Range<usize>; D2],
|
||||
value: TchTensor<E, D1>,
|
||||
) -> <TchBackend<E> as Backend>::TensorPrimitive<D1> {
|
||||
let kind = tensor.kind;
|
||||
let tensor_original = tensor.tensor.copy();
|
||||
let tch_shape = TchShape::from(tensor.shape());
|
||||
|
||||
let mut tensor = tensor_original.view_(&tch_shape.dims);
|
||||
|
||||
for (i, index) in indexes.into_iter().enumerate().take(D2) {
|
||||
let start = index.start as i64;
|
||||
let length = (index.end - index.start) as i64;
|
||||
|
||||
tensor = tensor.narrow(i as i64, start, length);
|
||||
}
|
||||
|
||||
tensor.copy_(&value.tensor);
|
||||
|
||||
TchTensor {
|
||||
kind,
|
||||
tensor: Arc::new(tensor_original),
|
||||
}
|
||||
TchOps::index_assign(tensor, indexes, value)
|
||||
}
|
||||
|
||||
fn mask_fill<const D: usize>(
|
||||
|
@ -370,15 +286,7 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
}
|
||||
|
||||
fn equal<const D: usize>(lhs: TchTensor<E, D>, rhs: TchTensor<E, D>) -> TchTensor<bool, D> {
|
||||
let tensor = TchTensor::binary_ops_tensor(
|
||||
lhs,
|
||||
rhs,
|
||||
|lhs, rhs| lhs.eq_tensor_(rhs).to_kind(tch::Kind::Bool),
|
||||
|lhs, rhs| rhs.eq_tensor_(lhs).to_kind(tch::Kind::Bool),
|
||||
|lhs, rhs| lhs.eq_tensor(rhs),
|
||||
);
|
||||
|
||||
to_tensor(tensor)
|
||||
TchOps::equal(lhs, rhs)
|
||||
}
|
||||
|
||||
fn equal_scalar<const D: usize>(lhs: TchTensor<E, D>, rhs: E) -> TchTensor<bool, D> {
|
||||
|
@ -570,40 +478,10 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
|||
}
|
||||
|
||||
fn cat<const D: usize>(tensors: Vec<TchTensor<E, D>>, dim: usize) -> TchTensor<E, D> {
|
||||
let tensors: Vec<tch::Tensor> = tensors
|
||||
.into_iter()
|
||||
.map(|t| t.tensor.shallow_clone())
|
||||
.collect();
|
||||
let tensor = tch::Tensor::cat(&tensors, dim as i64);
|
||||
to_tensor(tensor)
|
||||
TchOps::cat(tensors, dim)
|
||||
}
|
||||
|
||||
fn relu<const D: usize>(tensor: TchTensor<E, D>) -> TchTensor<E, D> {
|
||||
to_tensor(tensor.unary_ops(|mut tensor| tensor.relu_(), |tensor| tensor.relu()))
|
||||
}
|
||||
|
||||
fn bool_into_int<const D: usize>(
|
||||
tensor: <TchBackend<E> as Backend>::BoolTensorPrimitive<D>,
|
||||
) -> <<TchBackend<E> as Backend>::IntegerBackend as Backend>::TensorPrimitive<D> {
|
||||
let tensor = tensor.tensor.to_kind(TchKind::<i64>::new().kind());
|
||||
to_tensor(tensor)
|
||||
}
|
||||
}
|
||||
|
||||
fn index<const D1: usize, const D2: usize, E: tch::kind::Element + Copy>(
|
||||
tensor: &TchTensor<E, D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
) -> TchTensor<E, D1> {
|
||||
let kind = tensor.kind;
|
||||
|
||||
let mut tensor = tensor.tensor.shallow_clone();
|
||||
|
||||
for (i, index) in indexes.iter().enumerate().take(D2) {
|
||||
let start = index.start as i64;
|
||||
let length = (index.end - index.start) as i64;
|
||||
tensor = tensor.narrow(i as i64, start, length);
|
||||
}
|
||||
let tensor = Arc::new(tensor);
|
||||
|
||||
TchTensor { kind, tensor }
|
||||
}
|
||||
|
|
|
@ -382,11 +382,11 @@ impl<B: Backend> BasicOps<B> for Bool {
|
|||
}
|
||||
|
||||
fn index_assign<const D1: usize, const D2: usize>(
|
||||
_tensor: Self::Primitive<D1>,
|
||||
_indexes: [Range<usize>; D2],
|
||||
_value: Self::Primitive<D1>,
|
||||
tensor: Self::Primitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
value: Self::Primitive<D1>,
|
||||
) -> Self::Primitive<D1> {
|
||||
todo!("Index assigned is not yet implemented for bool tensor")
|
||||
B::bool_index_assign(tensor, indexes, value)
|
||||
}
|
||||
|
||||
fn device<const D: usize>(tensor: &Self::Primitive<D>) -> <B as Backend>::Device {
|
||||
|
@ -408,31 +408,32 @@ impl<B: Backend> BasicOps<B> for Bool {
|
|||
data: Data<Self::Elem, D>,
|
||||
device: &B::Device,
|
||||
) -> Self::Primitive<D> {
|
||||
B::from_data_bool(data, device)
|
||||
B::bool_from_data(data, device)
|
||||
}
|
||||
|
||||
fn repeat<const D: usize>(
|
||||
_tensor: Self::Primitive<D>,
|
||||
_dim: usize,
|
||||
_times: usize,
|
||||
tensor: Self::Primitive<D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
) -> Self::Primitive<D> {
|
||||
todo!("Repeat operation is not yet implemented for bool tensor");
|
||||
B::bool_repeat(tensor, dim, times)
|
||||
}
|
||||
|
||||
fn equal<const D: usize>(
|
||||
_lhs: Self::Primitive<D>,
|
||||
_rhs: Self::Primitive<D>,
|
||||
lhs: Self::Primitive<D>,
|
||||
rhs: Self::Primitive<D>,
|
||||
) -> Tensor<B, D, Bool> {
|
||||
todo!("Equal operation is not yet implemented for bool tensor");
|
||||
Tensor::new(B::bool_equal(lhs, rhs))
|
||||
}
|
||||
|
||||
fn equal_scalar<const D: usize>(
|
||||
_lhs: Self::Primitive<D>,
|
||||
_rhs: Self::Elem,
|
||||
lhs: Self::Primitive<D>,
|
||||
rhs: Self::Elem,
|
||||
) -> Tensor<B, D, Bool> {
|
||||
todo!("Equal scalar operation is not yet implemented for bool tensor");
|
||||
Tensor::new(B::bool_equal_elem(lhs, rhs))
|
||||
}
|
||||
fn cat<const D: usize>(_vectors: Vec<Self::Primitive<D>>, _dim: usize) -> Self::Primitive<D> {
|
||||
todo!("Cat vectors operation is not yet implemented for bool tensor");
|
||||
|
||||
fn cat<const D: usize>(vectors: Vec<Self::Primitive<D>>, dim: usize) -> Self::Primitive<D> {
|
||||
B::bool_cat(vectors, dim)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -6,12 +6,12 @@ where
|
|||
{
|
||||
/// Create a boolean tensor from data.
|
||||
pub fn from_bool(data: Data<bool, D>) -> Self {
|
||||
Self::new(B::from_data_bool(data, &B::Device::default()))
|
||||
Self::new(B::bool_from_data(data, &B::Device::default()))
|
||||
}
|
||||
|
||||
/// Create a boolean tensor from data on the given device.
|
||||
pub fn from_bool_device(data: Data<bool, D>, device: &B::Device) -> Self {
|
||||
Self::new(B::from_data_bool(data, device))
|
||||
Self::new(B::bool_from_data(data, device))
|
||||
}
|
||||
|
||||
/// Convert the bool tensor into an int tensor.
|
||||
|
|
|
@ -5,6 +5,7 @@ use crate::tensor::Element;
|
|||
|
||||
pub trait Backend:
|
||||
TensorOps<Self>
|
||||
+ BoolTensorOps<Self>
|
||||
+ ModuleOps<Self>
|
||||
+ Clone
|
||||
+ Sized
|
||||
|
|
|
@ -0,0 +1,80 @@
|
|||
use alloc::vec::Vec;
|
||||
use core::ops::Range;
|
||||
|
||||
use crate::{backend::Backend, tensor::Shape, Data};
|
||||
|
||||
pub trait BoolTensorOps<B: Backend> {
|
||||
/// Bool version of empty, see [tensor](crate::Tensor).
|
||||
fn bool_empty<const D: usize>(shape: Shape<D>, device: &B::Device)
|
||||
-> B::BoolTensorPrimitive<D>;
|
||||
fn bool_shape<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> Shape<D>;
|
||||
fn bool_into_data<const D: usize>(tensor: B::BoolTensorPrimitive<D>) -> Data<bool, D>;
|
||||
fn bool_to_data<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> Data<bool, D> {
|
||||
Self::bool_into_data(tensor.clone())
|
||||
}
|
||||
fn bool_from_data<const D: usize>(
|
||||
data: Data<bool, D>,
|
||||
device: &B::Device,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
fn bool_into_int<const D: usize>(
|
||||
tensor: B::BoolTensorPrimitive<D>,
|
||||
) -> <B::IntegerBackend as Backend>::TensorPrimitive<D>;
|
||||
fn bool_device<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> B::Device;
|
||||
fn bool_to_device<const D: usize>(
|
||||
tensor: B::BoolTensorPrimitive<D>,
|
||||
device: &B::Device,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
fn bool_reshape<const D1: usize, const D2: usize>(
|
||||
tensor: B::BoolTensorPrimitive<D1>,
|
||||
shape: Shape<D2>,
|
||||
) -> B::BoolTensorPrimitive<D2>;
|
||||
fn bool_index<const D1: usize, const D2: usize>(
|
||||
tensor: B::BoolTensorPrimitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
) -> B::BoolTensorPrimitive<D1>;
|
||||
fn bool_index_assign<const D1: usize, const D2: usize>(
|
||||
tensor: B::BoolTensorPrimitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
value: B::BoolTensorPrimitive<D1>,
|
||||
) -> B::BoolTensorPrimitive<D1>;
|
||||
fn bool_repeat<const D: usize>(
|
||||
tensor: B::BoolTensorPrimitive<D>,
|
||||
dim: usize,
|
||||
times: usize,
|
||||
) -> B::BoolTensorPrimitive<D> {
|
||||
let mut shape = Self::bool_shape(&tensor);
|
||||
if shape.dims[dim] != 1 {
|
||||
panic!("Can only repeat dimension with dim=1");
|
||||
}
|
||||
shape.dims[dim] = times;
|
||||
|
||||
let mut i = 0;
|
||||
let indexes_select_all = [0; D].map(|_| {
|
||||
let start = 0;
|
||||
let end = shape.dims[i];
|
||||
i += 1;
|
||||
start..end
|
||||
});
|
||||
|
||||
let mut tensor_output = Self::bool_empty(shape, &Self::bool_device(&tensor));
|
||||
for i in 0..times {
|
||||
let mut indexes = indexes_select_all.clone();
|
||||
indexes[dim] = i..i + 1;
|
||||
tensor_output = Self::bool_index_assign(tensor_output, indexes, tensor.clone());
|
||||
}
|
||||
|
||||
tensor_output
|
||||
}
|
||||
fn bool_cat<const D: usize>(
|
||||
tensors: Vec<B::BoolTensorPrimitive<D>>,
|
||||
dim: usize,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
fn bool_equal<const D: usize>(
|
||||
lhs: B::BoolTensorPrimitive<D>,
|
||||
rhs: B::BoolTensorPrimitive<D>,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
fn bool_equal_elem<const D: usize>(
|
||||
lhs: B::BoolTensorPrimitive<D>,
|
||||
rhs: bool,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
}
|
|
@ -0,0 +1 @@
|
|||
|
|
@ -1,5 +1,9 @@
|
|||
mod bool_tensor;
|
||||
mod int_tensor;
|
||||
mod modules;
|
||||
mod tensor;
|
||||
|
||||
pub use bool_tensor::*;
|
||||
pub use int_tensor::*;
|
||||
pub use modules::*;
|
||||
pub use tensor::*;
|
||||
|
|
|
@ -8,10 +8,6 @@ pub trait TensorOps<B: Backend> {
|
|||
data: Data<B::FloatElem, D>,
|
||||
device: &B::Device,
|
||||
) -> B::TensorPrimitive<D>;
|
||||
fn from_data_bool<const D: usize>(
|
||||
data: Data<bool, D>,
|
||||
device: &B::Device,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
fn random<const D: usize>(
|
||||
shape: Shape<D>,
|
||||
distribution: Distribution<B::FloatElem>,
|
||||
|
@ -26,25 +22,6 @@ pub trait TensorOps<B: Backend> {
|
|||
fn shape<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Shape<D>;
|
||||
fn to_data<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Data<B::FloatElem, D>;
|
||||
fn into_data<const D: usize>(tensor: B::TensorPrimitive<D>) -> Data<B::FloatElem, D>;
|
||||
fn bool_shape<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> Shape<D>;
|
||||
fn bool_to_data<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> Data<bool, D>;
|
||||
fn bool_into_data<const D: usize>(tensor: B::BoolTensorPrimitive<D>) -> Data<bool, D>;
|
||||
fn bool_into_int<const D: usize>(
|
||||
tensor: B::BoolTensorPrimitive<D>,
|
||||
) -> <B::IntegerBackend as Backend>::TensorPrimitive<D>;
|
||||
fn bool_device<const D: usize>(tensor: &B::BoolTensorPrimitive<D>) -> B::Device;
|
||||
fn bool_to_device<const D: usize>(
|
||||
tensor: B::BoolTensorPrimitive<D>,
|
||||
device: &B::Device,
|
||||
) -> B::BoolTensorPrimitive<D>;
|
||||
fn bool_reshape<const D1: usize, const D2: usize>(
|
||||
tensor: B::BoolTensorPrimitive<D1>,
|
||||
shape: Shape<D2>,
|
||||
) -> B::BoolTensorPrimitive<D2>;
|
||||
fn bool_index<const D1: usize, const D2: usize>(
|
||||
tensor: B::BoolTensorPrimitive<D1>,
|
||||
indexes: [Range<usize>; D2],
|
||||
) -> B::BoolTensorPrimitive<D1>;
|
||||
fn device<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::Device;
|
||||
fn to_device<const D: usize>(
|
||||
tensor: B::TensorPrimitive<D>,
|
||||
|
|
Loading…
Reference in New Issue