refactor: backends (#124)

This commit is contained in:
Nathaniel Simard 2022-12-02 19:28:34 -05:00 committed by GitHub
parent eee90a5c9e
commit b99b23e1a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 157 additions and 172 deletions

View File

@ -1,7 +1,6 @@
use crate::graph::grad::Grads;
use crate::tensor::ADTensor;
use burn_tensor::backend::{ADBackend, Backend};
use burn_tensor::{Data, Distribution, Shape};
#[derive(Clone, Copy, Debug, Default)]
pub struct ADBackendDecorator<B> {
@ -17,41 +16,10 @@ impl<B: Backend> Backend for ADBackendDecorator<B> {
type TensorPrimitive<const D: usize> = ADTensor<D, B>;
type BoolTensorPrimitive<const D: usize> = B::BoolTensorPrimitive<D>;
fn from_data<const D: usize>(
data: Data<Self::Elem, D>,
device: Self::Device,
) -> Self::TensorPrimitive<D> {
let tensor = B::from_data(data, device);
ADTensor::from_tensor(tensor)
}
fn from_data_bool<const D: usize>(
data: Data<bool, D>,
device: Self::Device,
) -> Self::BoolTensorPrimitive<D> {
B::from_data_bool(data, device)
}
fn random<const D: usize>(
shape: Shape<D>,
distribution: Distribution<Self::Elem>,
device: Self::Device,
) -> Self::TensorPrimitive<D> {
Self::from_inner(B::random(shape, distribution, device))
}
fn ad_enabled() -> bool {
true
}
fn zeros<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D> {
Self::from_inner(B::zeros(shape, device))
}
fn ones<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D> {
Self::from_inner(B::ones(shape, device))
}
fn name() -> String {
format!("autodiff<{}>", B::name())
}

View File

@ -6,7 +6,7 @@ use crate::ops::unary_ops_wrapper_explicit;
use crate::tensor::ADTensor;
use crate::ADBackendDecorator;
use burn_tensor::backend::Backend;
use burn_tensor::{ops::*, Data, ElementConversion, Shape, Tensor};
use burn_tensor::{ops::*, Data, Distribution, ElementConversion, Shape, Tensor};
use std::ops::Range;
use std::sync::Arc;
@ -19,6 +19,34 @@ impl<B: Backend, const D: usize> std::ops::Add<ADTensor<D, B>> for ADTensor<D, B
}
impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
fn from_data<const D: usize>(data: Data<B::Elem, D>, device: B::Device) -> ADTensor<D, B> {
let tensor = B::from_data(data, device);
ADTensor::from_tensor(tensor)
}
fn from_data_bool<const D: usize>(
data: Data<bool, D>,
device: B::Device,
) -> B::BoolTensorPrimitive<D> {
B::from_data_bool(data, device)
}
fn random<const D: usize>(
shape: Shape<D>,
distribution: Distribution<B::Elem>,
device: B::Device,
) -> ADTensor<D, B> {
ADTensor::from_tensor(B::random(shape, distribution, device))
}
fn zeros<const D: usize>(shape: Shape<D>, device: B::Device) -> ADTensor<D, B> {
ADTensor::from_tensor(B::zeros(shape, device))
}
fn ones<const D: usize>(shape: Shape<D>, device: B::Device) -> ADTensor<D, B> {
ADTensor::from_tensor(B::ones(shape, device))
}
fn shape<const D: usize>(
tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
) -> &Shape<D> {

View File

@ -1,13 +1,11 @@
use super::element::NdArrayElement;
use super::NdArrayTensor;
use burn_tensor::backend::Backend;
use burn_tensor::Data;
use burn_tensor::{Distribution, Shape};
use rand::rngs::StdRng;
use rand::SeedableRng;
use std::sync::Mutex;
static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
#[derive(Clone, Copy, Debug)]
pub enum NdArrayDevice {
@ -34,39 +32,10 @@ impl<E: NdArrayElement> Backend for NdArrayBackend<E> {
type TensorPrimitive<const D: usize> = NdArrayTensor<E, D>;
type BoolTensorPrimitive<const D: usize> = NdArrayTensor<bool, D>;
fn from_data<const D: usize>(
data: Data<Self::Elem, D>,
_device: Self::Device,
) -> NdArrayTensor<E, D> {
NdArrayTensor::from_data(data)
}
fn from_data_bool<const D: usize>(
data: Data<bool, D>,
_device: Self::Device,
) -> Self::BoolTensorPrimitive<D> {
NdArrayTensor::from_data(data)
}
fn ad_enabled() -> bool {
false
}
fn random<const D: usize>(
shape: Shape<D>,
distribution: Distribution<Self::Elem>,
device: Self::Device,
) -> Self::TensorPrimitive<D> {
let mut seed = SEED.lock().unwrap();
let mut rng: StdRng = match seed.as_ref() {
Some(rng) => rng.clone(),
None => StdRng::from_entropy(),
};
let tensor = Self::from_data(Data::random(shape, distribution, &mut rng), device);
*seed = Some(rng);
tensor
}
fn name() -> String {
"ndarray".to_string()
}

View File

@ -10,10 +10,8 @@ extern crate blas_src;
mod backend;
mod element;
mod module_ops;
mod ops;
mod tensor;
mod tensor_ops;
pub use backend::*;
pub(crate) use tensor::*;

View File

@ -1 +1,3 @@
mod creation;
mod module;
mod tensor;

View File

@ -1,4 +1,4 @@
use super::{element::NdArrayElement, NdArrayBackend, NdArrayTensor};
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
use burn_tensor::{ops::*, Shape};
use std::ops::Add;

View File

@ -1,8 +1,14 @@
use super::{element::NdArrayElement, BatchMatrix, NdArrayBackend, NdArrayTensor};
use crate::{to_nd_array_tensor, NdArrayDevice};
use std::cmp::Ordering;
use std::ops::Range;
use crate::tensor::BatchMatrix;
use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
use crate::{to_nd_array_tensor, NdArrayDevice, SEED};
use burn_tensor::Distribution;
use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Shape};
use ndarray::{Axis, Dim, IxDyn, SliceInfoElem};
use std::{cmp::Ordering, ops::Range};
use rand::rngs::StdRng;
use rand::SeedableRng;
macro_rules! keepdim {
(
@ -30,6 +36,32 @@ macro_rules! keepdim {
}
impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
fn from_data<const D: usize>(data: Data<E, D>, _device: NdArrayDevice) -> NdArrayTensor<E, D> {
NdArrayTensor::from_data(data)
}
fn from_data_bool<const D: usize>(
data: Data<bool, D>,
_device: NdArrayDevice,
) -> NdArrayTensor<bool, D> {
NdArrayTensor::from_data(data)
}
fn random<const D: usize>(
shape: Shape<D>,
distribution: Distribution<E>,
device: NdArrayDevice,
) -> NdArrayTensor<E, D> {
let mut seed = SEED.lock().unwrap();
let mut rng: StdRng = match seed.as_ref() {
Some(rng) => rng.clone(),
None => StdRng::from_entropy(),
};
let tensor = Self::from_data(Data::random(shape, distribution, &mut rng), device);
*seed = Some(rng);
tensor
}
fn shape<const D: usize>(
tensor: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
) -> &Shape<D> {

View File

@ -1,7 +1,6 @@
use super::element::TchElement;
use super::TchTensor;
use burn_tensor::backend::Backend;
use burn_tensor::{Data, Distribution, Shape};
#[derive(Clone, Copy, Debug)]
/// The device struct when using the `tch` backend.
@ -51,75 +50,10 @@ impl<E: TchElement> Backend for TchBackend<E> {
type TensorPrimitive<const D: usize> = TchTensor<E, D>;
type BoolTensorPrimitive<const D: usize> = TchTensor<bool, D>;
fn from_data<const D: usize>(
data: Data<Self::Elem, D>,
device: Self::Device,
) -> TchTensor<E, D> {
let device = match device {
TchDevice::Cpu => tch::Device::Cpu,
TchDevice::Cuda(num) => tch::Device::Cuda(num),
};
TchTensor::from_data(data, device)
}
fn from_data_bool<const D: usize>(
data: Data<bool, D>,
device: Self::Device,
) -> Self::BoolTensorPrimitive<D> {
let device = match device {
TchDevice::Cpu => tch::Device::Cpu,
TchDevice::Cuda(num) => tch::Device::Cuda(num),
};
TchTensor::from_data(data, device)
}
fn random<const D: usize>(
shape: Shape<D>,
distribution: Distribution<Self::Elem>,
device: Self::Device,
) -> Self::TensorPrimitive<D> {
match distribution {
Distribution::Standard => {
let mut tensor = TchTensor::<Self::Elem, D>::empty(shape, device);
tensor.tensor = tensor.tensor.normal_(0.0, 1.0);
tensor
}
Distribution::Bernoulli(prob) => {
let mut tensor = TchTensor::<Self::Elem, D>::empty(shape, device);
tensor.tensor = tensor.tensor.f_bernoulli_float_(prob).unwrap();
tensor
}
Distribution::Uniform(from, to) => {
let mut tensor = TchTensor::<Self::Elem, D>::empty(shape, device);
tensor.tensor = tensor
.tensor
.uniform_(from.to_f64().unwrap(), to.to_f64().unwrap());
tensor
}
Distribution::Normal(mean, std) => {
let mut tensor = TchTensor::<Self::Elem, D>::empty(shape, device);
tensor.tensor = tensor.tensor.normal(mean, std);
tensor
}
}
}
fn zeros<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D> {
let mut tensor = TchTensor::<Self::Elem, D>::empty(shape, device);
tensor.tensor = tensor.tensor.zero_();
tensor
}
fn seed(seed: u64) {
tch::manual_seed(seed as i64);
}
fn ones<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D> {
let mut tensor = TchTensor::<Self::Elem, D>::empty(shape, device);
tensor.tensor = tensor.tensor.ones_like();
tensor
}
fn ad_enabled() -> bool {
false
}

View File

@ -1,13 +1,10 @@
mod backend;
mod element;
mod module_ops;
mod ops;
mod tensor;
mod tensor_ops;
pub use backend::*;
pub use tensor::*;
pub use tensor_ops::*;
#[cfg(test)]
mod tests {

View File

@ -1 +1,3 @@
mod creation;
mod module;
mod tensor;

View File

@ -1,4 +1,4 @@
use super::{element::TchElement, TchBackend, TchTensor};
use crate::{element::TchElement, TchBackend, TchTensor};
use burn_tensor::{ops::ModuleOps, Shape};
impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {

View File

@ -1,8 +1,70 @@
use super::{element::TchElement, TchBackend, TchDevice, TchKind, TchShape, TchTensor};
use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Shape};
use crate::{element::TchElement, TchBackend, TchDevice, TchKind, TchShape, TchTensor};
use burn_tensor::{backend::Backend, ops::TensorOps, Data, Distribution, ElementConversion, Shape};
use std::ops::{Add, Div, Mul, Range, Sub};
impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
fn from_data<const D: usize>(data: Data<E, D>, device: TchDevice) -> TchTensor<E, D> {
let device = match device {
TchDevice::Cpu => tch::Device::Cpu,
TchDevice::Cuda(num) => tch::Device::Cuda(num),
};
TchTensor::from_data(data, device)
}
fn from_data_bool<const D: usize>(
data: Data<bool, D>,
device: TchDevice,
) -> TchTensor<bool, D> {
let device = match device {
TchDevice::Cpu => tch::Device::Cpu,
TchDevice::Cuda(num) => tch::Device::Cuda(num),
};
TchTensor::from_data(data, device)
}
fn random<const D: usize>(
shape: Shape<D>,
distribution: Distribution<E>,
device: TchDevice,
) -> TchTensor<E, D> {
match distribution {
Distribution::Standard => {
let mut tensor = TchTensor::<E, D>::empty(shape, device);
tensor.tensor = tensor.tensor.normal_(0.0, 1.0);
tensor
}
Distribution::Bernoulli(prob) => {
let mut tensor = TchTensor::<E, D>::empty(shape, device);
tensor.tensor = tensor.tensor.f_bernoulli_float_(prob).unwrap();
tensor
}
Distribution::Uniform(from, to) => {
let mut tensor = TchTensor::<E, D>::empty(shape, device);
tensor.tensor = tensor
.tensor
.uniform_(from.to_f64().unwrap(), to.to_f64().unwrap());
tensor
}
Distribution::Normal(mean, std) => {
let mut tensor = TchTensor::<E, D>::empty(shape, device);
tensor.tensor = tensor.tensor.normal(mean, std);
tensor
}
}
}
fn zeros<const D: usize>(shape: Shape<D>, device: TchDevice) -> TchTensor<E, D> {
let mut tensor = TchTensor::<E, D>::empty(shape, device);
tensor.tensor = tensor.tensor.zero_();
tensor
}
fn ones<const D: usize>(shape: Shape<D>, device: TchDevice) -> TchTensor<E, D> {
let mut tensor = TchTensor::<E, D>::empty(shape, device);
tensor.tensor = tensor.tensor.ones_like();
tensor
}
fn shape<const D: usize>(tensor: &<TchBackend<E> as Backend>::TensorPrimitive<D>) -> &Shape<D> {
&tensor.shape
}

View File

@ -1,8 +1,6 @@
use super::Gradients;
use crate::ops::*;
use crate::tensor::Element;
use crate::tensor::{Data, Distribution, Shape};
use super::Gradients;
pub trait Backend:
TensorOps<Self>
@ -38,33 +36,9 @@ pub trait Backend:
+ std::fmt::Debug
+ From<<Self::IntegerBackend as Backend>::BoolTensorPrimitive<D>>;
fn from_data<const D: usize>(
data: Data<Self::Elem, D>,
device: Self::Device,
) -> Self::TensorPrimitive<D>;
fn from_data_bool<const D: usize>(
data: Data<bool, D>,
device: Self::Device,
) -> Self::BoolTensorPrimitive<D>;
fn ad_enabled() -> bool;
fn name() -> String;
fn seed(seed: u64);
fn random<const D: usize>(
shape: Shape<D>,
distribution: Distribution<Self::Elem>,
device: Self::Device,
) -> Self::TensorPrimitive<D>;
fn zeros<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D> {
Self::from_data(Data::zeros(shape), device)
}
fn ones<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D> {
Self::from_data(Data::ones(shape), device)
}
}
pub(crate) type ADBackendTensorPrimitive<const D: usize, B> =

View File

@ -1,4 +1,4 @@
use crate::{backend::Backend, tensor::Shape, Data, ElementConversion};
use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion};
use std::ops::Range;
pub trait ModuleOps<B: Backend> {
@ -14,6 +14,25 @@ pub trait ModuleOps<B: Backend> {
}
pub trait TensorOps<B: Backend> {
fn from_data<const D: usize>(
data: Data<B::Elem, D>,
device: B::Device,
) -> B::TensorPrimitive<D>;
fn from_data_bool<const D: usize>(
data: Data<bool, D>,
device: B::Device,
) -> B::BoolTensorPrimitive<D>;
fn random<const D: usize>(
shape: Shape<D>,
distribution: Distribution<B::Elem>,
device: B::Device,
) -> B::TensorPrimitive<D>;
fn zeros<const D: usize>(shape: Shape<D>, device: B::Device) -> B::TensorPrimitive<D> {
Self::from_data(Data::zeros(shape), device)
}
fn ones<const D: usize>(shape: Shape<D>, device: B::Device) -> B::TensorPrimitive<D> {
Self::from_data(Data::ones(shape), device)
}
fn shape<const D: usize>(tensor: &B::TensorPrimitive<D>) -> &Shape<D>;
fn to_data<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Data<B::Elem, D>;
fn into_data<const D: usize>(tensor: B::TensorPrimitive<D>) -> Data<B::Elem, D>;
@ -43,7 +62,7 @@ pub trait TensorOps<B: Backend> {
.map(|i| (i as i64).to_elem())
.collect::<Vec<<B::IntegerBackend as Backend>::Elem>>();
let data = Data::new(value, shape);
<B::IntegerBackend as Backend>::from_data(data, device)
<B::IntegerBackend as TensorOps<B::IntegerBackend>>::from_data(data, device)
}
fn empty<const D: usize>(shape: Shape<D>, device: B::Device) -> B::TensorPrimitive<D>;
fn repeat<const D: usize>(