mirror of https://github.com/tracel-ai/burn.git
refactor: backends (#124)
This commit is contained in:
parent eee90a5c9e
commit b99b23e1a7
@@ -1,7 +1,6 @@
 use crate::graph::grad::Grads;
 use crate::tensor::ADTensor;
 use burn_tensor::backend::{ADBackend, Backend};
-use burn_tensor::{Data, Distribution, Shape};
 
 #[derive(Clone, Copy, Debug, Default)]
 pub struct ADBackendDecorator<B> {
@@ -17,41 +16,10 @@ impl<B: Backend> Backend for ADBackendDecorator<B> {
     type TensorPrimitive<const D: usize> = ADTensor<D, B>;
     type BoolTensorPrimitive<const D: usize> = B::BoolTensorPrimitive<D>;
 
-    fn from_data<const D: usize>(
-        data: Data<Self::Elem, D>,
-        device: Self::Device,
-    ) -> Self::TensorPrimitive<D> {
-        let tensor = B::from_data(data, device);
-        ADTensor::from_tensor(tensor)
-    }
-
-    fn from_data_bool<const D: usize>(
-        data: Data<bool, D>,
-        device: Self::Device,
-    ) -> Self::BoolTensorPrimitive<D> {
-        B::from_data_bool(data, device)
-    }
-
-    fn random<const D: usize>(
-        shape: Shape<D>,
-        distribution: Distribution<Self::Elem>,
-        device: Self::Device,
-    ) -> Self::TensorPrimitive<D> {
-        Self::from_inner(B::random(shape, distribution, device))
-    }
-
     fn ad_enabled() -> bool {
         true
     }
 
-    fn zeros<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D> {
-        Self::from_inner(B::zeros(shape, device))
-    }
-
-    fn ones<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D> {
-        Self::from_inner(B::ones(shape, device))
-    }
-
     fn name() -> String {
         format!("autodiff<{}>", B::name())
     }
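The creation methods deleted from this Backend impl reappear further down on the separate TensorOps trait. A minimal, self-contained sketch of that trait-split pattern, with simplified hypothetical names (Cpu, Primitive) rather than burn's real signatures:

use std::fmt::Debug;

// The ops live on their own trait, parameterized by the backend...
trait TensorOps<B: Backend> {
    fn zeros(len: usize) -> B::Primitive;
}

// ...and the backend trait requires its own ops impl as a supertrait,
// mirroring `pub trait Backend: TensorOps<Self>` later in this diff.
trait Backend: TensorOps<Self> + Sized {
    type Primitive: Debug;
    fn name() -> String;
}

struct Cpu;

impl TensorOps<Cpu> for Cpu {
    fn zeros(len: usize) -> Vec<f32> {
        vec![0.0; len]
    }
}

impl Backend for Cpu {
    type Primitive = Vec<f32>;
    fn name() -> String {
        "cpu".to_string()
    }
}

fn main() {
    // Creation goes through the ops trait; the backend trait keeps only
    // metadata such as `name`.
    let t = <Cpu as TensorOps<Cpu>>::zeros(4);
    println!("{} -> {t:?}", Cpu::name());
}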
@@ -6,7 +6,7 @@ use crate::ops::unary_ops_wrapper_explicit;
 use crate::tensor::ADTensor;
 use crate::ADBackendDecorator;
 use burn_tensor::backend::Backend;
-use burn_tensor::{ops::*, Data, ElementConversion, Shape, Tensor};
+use burn_tensor::{ops::*, Data, Distribution, ElementConversion, Shape, Tensor};
 use std::ops::Range;
 use std::sync::Arc;
 
@@ -19,6 +19,34 @@ impl<B: Backend, const D: usize> std::ops::Add<ADTensor<D, B>> for ADTensor<D, B>
 }
 
 impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
+    fn from_data<const D: usize>(data: Data<B::Elem, D>, device: B::Device) -> ADTensor<D, B> {
+        let tensor = B::from_data(data, device);
+        ADTensor::from_tensor(tensor)
+    }
+
+    fn from_data_bool<const D: usize>(
+        data: Data<bool, D>,
+        device: B::Device,
+    ) -> B::BoolTensorPrimitive<D> {
+        B::from_data_bool(data, device)
+    }
+
+    fn random<const D: usize>(
+        shape: Shape<D>,
+        distribution: Distribution<B::Elem>,
+        device: B::Device,
+    ) -> ADTensor<D, B> {
+        ADTensor::from_tensor(B::random(shape, distribution, device))
+    }
+
+    fn zeros<const D: usize>(shape: Shape<D>, device: B::Device) -> ADTensor<D, B> {
+        ADTensor::from_tensor(B::zeros(shape, device))
+    }
+
+    fn ones<const D: usize>(shape: Shape<D>, device: B::Device) -> ADTensor<D, B> {
+        ADTensor::from_tensor(B::ones(shape, device))
+    }
+
     fn shape<const D: usize>(
         tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
     ) -> &Shape<D> {
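Note the pattern in the new impl: every creation op runs on the inner backend B, and the raw result is wrapped via ADTensor::from_tensor so the autodiff graph can track it. A stand-alone sketch of that decorator idiom, using invented stand-in types (InnerTensor, AdTensor), not burn's:

// Inner backend primitive (stand-in for B::TensorPrimitive).
#[derive(Debug)]
struct InnerTensor(Vec<f32>);

// Inner backend creation op (stand-in for B::zeros).
fn inner_zeros(len: usize) -> InnerTensor {
    InnerTensor(vec![0.0; len])
}

// Autodiff wrapper (stand-in for ADTensor).
#[derive(Debug)]
struct AdTensor {
    inner: InnerTensor,
    tracked: bool,
}

impl AdTensor {
    // Mirrors the `ADTensor::from_tensor(B::zeros(..))` shape of the diff:
    // create on the inner backend, then wrap for gradient tracking.
    fn from_tensor(inner: InnerTensor) -> Self {
        AdTensor { inner, tracked: true }
    }
}

fn main() {
    let t = AdTensor::from_tensor(inner_zeros(3));
    println!("{t:?}");
}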
@@ -1,13 +1,11 @@
 use super::element::NdArrayElement;
 use super::NdArrayTensor;
 use burn_tensor::backend::Backend;
-use burn_tensor::Data;
-use burn_tensor::{Distribution, Shape};
 use rand::rngs::StdRng;
 use rand::SeedableRng;
 use std::sync::Mutex;
 
-static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
+pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
 
 #[derive(Clone, Copy, Debug)]
 pub enum NdArrayDevice {
@@ -34,39 +32,10 @@ impl<E: NdArrayElement> Backend for NdArrayBackend<E> {
     type TensorPrimitive<const D: usize> = NdArrayTensor<E, D>;
     type BoolTensorPrimitive<const D: usize> = NdArrayTensor<bool, D>;
 
-    fn from_data<const D: usize>(
-        data: Data<Self::Elem, D>,
-        _device: Self::Device,
-    ) -> NdArrayTensor<E, D> {
-        NdArrayTensor::from_data(data)
-    }
-
-    fn from_data_bool<const D: usize>(
-        data: Data<bool, D>,
-        _device: Self::Device,
-    ) -> Self::BoolTensorPrimitive<D> {
-        NdArrayTensor::from_data(data)
-    }
-
     fn ad_enabled() -> bool {
         false
     }
 
-    fn random<const D: usize>(
-        shape: Shape<D>,
-        distribution: Distribution<Self::Elem>,
-        device: Self::Device,
-    ) -> Self::TensorPrimitive<D> {
-        let mut seed = SEED.lock().unwrap();
-        let mut rng: StdRng = match seed.as_ref() {
-            Some(rng) => rng.clone(),
-            None => StdRng::from_entropy(),
-        };
-        let tensor = Self::from_data(Data::random(shape, distribution, &mut rng), device);
-        *seed = Some(rng);
-        tensor
-    }
-
     fn name() -> String {
         "ndarray".to_string()
     }
@@ -10,10 +10,8 @@ extern crate blas_src;
 
 mod backend;
 mod element;
-mod module_ops;
 mod ops;
 mod tensor;
-mod tensor_ops;
 
 pub use backend::*;
 pub(crate) use tensor::*;
@@ -1 +1,3 @@
 mod creation;
+mod module;
+mod tensor;
@@ -1,4 +1,4 @@
-use super::{element::NdArrayElement, NdArrayBackend, NdArrayTensor};
+use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
 use burn_tensor::{ops::*, Shape};
 use std::ops::Add;
 
@@ -1,8 +1,14 @@
-use super::{element::NdArrayElement, BatchMatrix, NdArrayBackend, NdArrayTensor};
-use crate::{to_nd_array_tensor, NdArrayDevice};
-use std::cmp::Ordering;
-use std::ops::Range;
+use crate::tensor::BatchMatrix;
+use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend};
+use crate::{to_nd_array_tensor, NdArrayDevice, SEED};
+use burn_tensor::Distribution;
 use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Shape};
 use ndarray::{Axis, Dim, IxDyn, SliceInfoElem};
+use std::{cmp::Ordering, ops::Range};
+use rand::rngs::StdRng;
+use rand::SeedableRng;
 
 macro_rules! keepdim {
     (
@@ -30,6 +36,32 @@ macro_rules! keepdim {
 }
 
 impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
+    fn from_data<const D: usize>(data: Data<E, D>, _device: NdArrayDevice) -> NdArrayTensor<E, D> {
+        NdArrayTensor::from_data(data)
+    }
+
+    fn from_data_bool<const D: usize>(
+        data: Data<bool, D>,
+        _device: NdArrayDevice,
+    ) -> NdArrayTensor<bool, D> {
+        NdArrayTensor::from_data(data)
+    }
+
+    fn random<const D: usize>(
+        shape: Shape<D>,
+        distribution: Distribution<E>,
+        device: NdArrayDevice,
+    ) -> NdArrayTensor<E, D> {
+        let mut seed = SEED.lock().unwrap();
+        let mut rng: StdRng = match seed.as_ref() {
+            Some(rng) => rng.clone(),
+            None => StdRng::from_entropy(),
+        };
+        let tensor = Self::from_data(Data::random(shape, distribution, &mut rng), device);
+        *seed = Some(rng);
+        tensor
+    }
+
     fn shape<const D: usize>(
         tensor: &<NdArrayBackend<E> as Backend>::TensorPrimitive<D>,
     ) -> &Shape<D> {
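The random implementation above round-trips the RNG through the global SEED slot so a seeded sequence continues across calls. A self-contained sketch of that pattern, assuming the rand 0.8 API; random_value is an invented stand-in for the tensor fill:

use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
use std::sync::Mutex;

// Same shape as the `pub(crate) static SEED` in this diff.
static SEED: Mutex<Option<StdRng>> = Mutex::new(None);

fn seed(s: u64) {
    *SEED.lock().unwrap() = Some(StdRng::seed_from_u64(s));
}

fn random_value() -> f64 {
    let mut seed = SEED.lock().unwrap();
    // Clone the stored RNG if one was seeded, otherwise draw fresh entropy.
    let mut rng: StdRng = match seed.as_ref() {
        Some(rng) => rng.clone(),
        None => StdRng::from_entropy(),
    };
    let value = rng.gen::<f64>();
    // Store the advanced state back so the next call continues the stream.
    *seed = Some(rng);
    value
}

fn main() {
    seed(42);
    let a = (random_value(), random_value());
    seed(42);
    let b = (random_value(), random_value());
    assert_eq!(a, b); // reseeding reproduces the same sequence
}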
@@ -1,7 +1,6 @@
 use super::element::TchElement;
 use super::TchTensor;
 use burn_tensor::backend::Backend;
-use burn_tensor::{Data, Distribution, Shape};
 
 #[derive(Clone, Copy, Debug)]
 /// The device struct when using the `tch` backend.
@@ -51,75 +50,10 @@ impl<E: TchElement> Backend for TchBackend<E> {
     type TensorPrimitive<const D: usize> = TchTensor<E, D>;
     type BoolTensorPrimitive<const D: usize> = TchTensor<bool, D>;
 
-    fn from_data<const D: usize>(
-        data: Data<Self::Elem, D>,
-        device: Self::Device,
-    ) -> TchTensor<E, D> {
-        let device = match device {
-            TchDevice::Cpu => tch::Device::Cpu,
-            TchDevice::Cuda(num) => tch::Device::Cuda(num),
-        };
-        TchTensor::from_data(data, device)
-    }
-
-    fn from_data_bool<const D: usize>(
-        data: Data<bool, D>,
-        device: Self::Device,
-    ) -> Self::BoolTensorPrimitive<D> {
-        let device = match device {
-            TchDevice::Cpu => tch::Device::Cpu,
-            TchDevice::Cuda(num) => tch::Device::Cuda(num),
-        };
-        TchTensor::from_data(data, device)
-    }
-
-    fn random<const D: usize>(
-        shape: Shape<D>,
-        distribution: Distribution<Self::Elem>,
-        device: Self::Device,
-    ) -> Self::TensorPrimitive<D> {
-        match distribution {
-            Distribution::Standard => {
-                let mut tensor = TchTensor::<Self::Elem, D>::empty(shape, device);
-                tensor.tensor = tensor.tensor.normal_(0.0, 1.0);
-                tensor
-            }
-            Distribution::Bernoulli(prob) => {
-                let mut tensor = TchTensor::<Self::Elem, D>::empty(shape, device);
-                tensor.tensor = tensor.tensor.f_bernoulli_float_(prob).unwrap();
-                tensor
-            }
-            Distribution::Uniform(from, to) => {
-                let mut tensor = TchTensor::<Self::Elem, D>::empty(shape, device);
-                tensor.tensor = tensor
-                    .tensor
-                    .uniform_(from.to_f64().unwrap(), to.to_f64().unwrap());
-                tensor
-            }
-            Distribution::Normal(mean, std) => {
-                let mut tensor = TchTensor::<Self::Elem, D>::empty(shape, device);
-                tensor.tensor = tensor.tensor.normal(mean, std);
-                tensor
-            }
-        }
-    }
-
-    fn zeros<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D> {
-        let mut tensor = TchTensor::<Self::Elem, D>::empty(shape, device);
-        tensor.tensor = tensor.tensor.zero_();
-        tensor
-    }
-
     fn seed(seed: u64) {
         tch::manual_seed(seed as i64);
     }
 
-    fn ones<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D> {
-        let mut tensor = TchTensor::<Self::Elem, D>::empty(shape, device);
-        tensor.tensor = tensor.tensor.ones_like();
-        tensor
-    }
-
     fn ad_enabled() -> bool {
        false
     }
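Both from_data variants above repeat the same TchDevice to tch::Device translation at the call boundary. A tiny sketch of the idiom, with LibDevice as an invented stand-in for tch::Device:

#[derive(Clone, Copy, Debug)]
enum TchDevice {
    Cpu,
    Cuda(usize),
}

// Stand-in for the external library's device type (tch::Device in the diff).
#[derive(Clone, Copy, Debug)]
enum LibDevice {
    Cpu,
    Cuda(usize),
}

// The backend-owned enum is translated at the boundary, exactly like the
// `match device { .. }` blocks above.
fn to_lib_device(device: TchDevice) -> LibDevice {
    match device {
        TchDevice::Cpu => LibDevice::Cpu,
        TchDevice::Cuda(index) => LibDevice::Cuda(index),
    }
}

fn main() {
    println!("{:?}", to_lib_device(TchDevice::Cuda(0)));
}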
@@ -1,13 +1,10 @@
 mod backend;
 mod element;
-mod module_ops;
 mod ops;
 mod tensor;
-mod tensor_ops;
 
 pub use backend::*;
 pub use tensor::*;
-pub use tensor_ops::*;
 
 #[cfg(test)]
 mod tests {
@@ -1 +1,3 @@
 mod creation;
+mod module;
+mod tensor;
@@ -1,4 +1,4 @@
-use super::{element::TchElement, TchBackend, TchTensor};
+use crate::{element::TchElement, TchBackend, TchTensor};
 use burn_tensor::{ops::ModuleOps, Shape};
 
 impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
@@ -1,8 +1,70 @@
-use super::{element::TchElement, TchBackend, TchDevice, TchKind, TchShape, TchTensor};
-use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Shape};
+use crate::{element::TchElement, TchBackend, TchDevice, TchKind, TchShape, TchTensor};
+use burn_tensor::{backend::Backend, ops::TensorOps, Data, Distribution, ElementConversion, Shape};
 use std::ops::{Add, Div, Mul, Range, Sub};
 
 impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
+    fn from_data<const D: usize>(data: Data<E, D>, device: TchDevice) -> TchTensor<E, D> {
+        let device = match device {
+            TchDevice::Cpu => tch::Device::Cpu,
+            TchDevice::Cuda(num) => tch::Device::Cuda(num),
+        };
+        TchTensor::from_data(data, device)
+    }
+
+    fn from_data_bool<const D: usize>(
+        data: Data<bool, D>,
+        device: TchDevice,
+    ) -> TchTensor<bool, D> {
+        let device = match device {
+            TchDevice::Cpu => tch::Device::Cpu,
+            TchDevice::Cuda(num) => tch::Device::Cuda(num),
+        };
+        TchTensor::from_data(data, device)
+    }
+
+    fn random<const D: usize>(
+        shape: Shape<D>,
+        distribution: Distribution<E>,
+        device: TchDevice,
+    ) -> TchTensor<E, D> {
+        match distribution {
+            Distribution::Standard => {
+                let mut tensor = TchTensor::<E, D>::empty(shape, device);
+                tensor.tensor = tensor.tensor.normal_(0.0, 1.0);
+                tensor
+            }
+            Distribution::Bernoulli(prob) => {
+                let mut tensor = TchTensor::<E, D>::empty(shape, device);
+                tensor.tensor = tensor.tensor.f_bernoulli_float_(prob).unwrap();
+                tensor
+            }
+            Distribution::Uniform(from, to) => {
+                let mut tensor = TchTensor::<E, D>::empty(shape, device);
+                tensor.tensor = tensor
+                    .tensor
+                    .uniform_(from.to_f64().unwrap(), to.to_f64().unwrap());
+                tensor
+            }
+            Distribution::Normal(mean, std) => {
+                let mut tensor = TchTensor::<E, D>::empty(shape, device);
+                tensor.tensor = tensor.tensor.normal(mean, std);
+                tensor
+            }
+        }
+    }
+
+    fn zeros<const D: usize>(shape: Shape<D>, device: TchDevice) -> TchTensor<E, D> {
+        let mut tensor = TchTensor::<E, D>::empty(shape, device);
+        tensor.tensor = tensor.tensor.zero_();
+        tensor
+    }
+
+    fn ones<const D: usize>(shape: Shape<D>, device: TchDevice) -> TchTensor<E, D> {
+        let mut tensor = TchTensor::<E, D>::empty(shape, device);
+        tensor.tensor = tensor.tensor.ones_like();
+        tensor
+    }
+
     fn shape<const D: usize>(tensor: &<TchBackend<E> as Backend>::TensorPrimitive<D>) -> &Shape<D> {
         &tensor.shape
     }
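The random method above dispatches one in-place fill strategy per Distribution variant; note that Standard is filled as a standard normal. A hedged re-sketch of that dispatch against plain rand/rand_distr (0.8-series APIs assumed) rather than tch, with Dist as a simplified stand-in enum:

use rand::prelude::*;
use rand_distr::{Bernoulli, Normal, Uniform};

// Simplified stand-in for burn_tensor::Distribution, fixed to f64.
enum Dist {
    Standard,
    Bernoulli(f64),
    Uniform(f64, f64),
    Normal(f64, f64),
}

fn fill(buf: &mut [f64], dist: Dist, rng: &mut StdRng) {
    match dist {
        // `Standard` fills with N(0, 1), matching `normal_(0.0, 1.0)` above.
        Dist::Standard => {
            let d = Normal::new(0.0, 1.0).unwrap();
            buf.iter_mut().for_each(|x| *x = d.sample(rng));
        }
        Dist::Bernoulli(prob) => {
            let d = Bernoulli::new(prob).unwrap();
            buf.iter_mut().for_each(|x| *x = d.sample(rng) as u8 as f64);
        }
        Dist::Uniform(from, to) => {
            let d = Uniform::new(from, to);
            buf.iter_mut().for_each(|x| *x = d.sample(rng));
        }
        Dist::Normal(mean, std) => {
            let d = Normal::new(mean, std).unwrap();
            buf.iter_mut().for_each(|x| *x = d.sample(rng));
        }
    }
}

fn main() {
    let mut rng = StdRng::seed_from_u64(7);
    let mut buf = [0.0; 4];
    fill(&mut buf, Dist::Uniform(-1.0, 1.0), &mut rng);
    println!("{buf:?}");
}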
@@ -1,8 +1,6 @@
-use super::Gradients;
 use crate::ops::*;
 use crate::tensor::Element;
-use crate::tensor::{Data, Distribution, Shape};
+use super::Gradients;
 
 pub trait Backend:
     TensorOps<Self>
@@ -38,33 +36,9 @@ pub trait Backend:
         + std::fmt::Debug
         + From<<Self::IntegerBackend as Backend>::BoolTensorPrimitive<D>>;
 
-    fn from_data<const D: usize>(
-        data: Data<Self::Elem, D>,
-        device: Self::Device,
-    ) -> Self::TensorPrimitive<D>;
-
-    fn from_data_bool<const D: usize>(
-        data: Data<bool, D>,
-        device: Self::Device,
-    ) -> Self::BoolTensorPrimitive<D>;
-
     fn ad_enabled() -> bool;
     fn name() -> String;
     fn seed(seed: u64);
-
-    fn random<const D: usize>(
-        shape: Shape<D>,
-        distribution: Distribution<Self::Elem>,
-        device: Self::Device,
-    ) -> Self::TensorPrimitive<D>;
-
-    fn zeros<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D> {
-        Self::from_data(Data::zeros(shape), device)
-    }
-
-    fn ones<const D: usize>(shape: Shape<D>, device: Self::Device) -> Self::TensorPrimitive<D> {
-        Self::from_data(Data::ones(shape), device)
-    }
 }
 
 pub(crate) type ADBackendTensorPrimitive<const D: usize, B> =
@@ -1,4 +1,4 @@
-use crate::{backend::Backend, tensor::Shape, Data, ElementConversion};
+use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion};
 use std::ops::Range;
 
 pub trait ModuleOps<B: Backend> {
@@ -14,6 +14,25 @@ pub trait ModuleOps<B: Backend> {
 }
 
 pub trait TensorOps<B: Backend> {
+    fn from_data<const D: usize>(
+        data: Data<B::Elem, D>,
+        device: B::Device,
+    ) -> B::TensorPrimitive<D>;
+    fn from_data_bool<const D: usize>(
+        data: Data<bool, D>,
+        device: B::Device,
+    ) -> B::BoolTensorPrimitive<D>;
+    fn random<const D: usize>(
+        shape: Shape<D>,
+        distribution: Distribution<B::Elem>,
+        device: B::Device,
+    ) -> B::TensorPrimitive<D>;
+    fn zeros<const D: usize>(shape: Shape<D>, device: B::Device) -> B::TensorPrimitive<D> {
+        Self::from_data(Data::zeros(shape), device)
+    }
+    fn ones<const D: usize>(shape: Shape<D>, device: B::Device) -> B::TensorPrimitive<D> {
+        Self::from_data(Data::ones(shape), device)
+    }
     fn shape<const D: usize>(tensor: &B::TensorPrimitive<D>) -> &Shape<D>;
     fn to_data<const D: usize>(tensor: &B::TensorPrimitive<D>) -> Data<B::Elem, D>;
     fn into_data<const D: usize>(tensor: B::TensorPrimitive<D>) -> Data<B::Elem, D>;
@@ -43,7 +62,7 @@ pub trait TensorOps<B: Backend> {
             .map(|i| (i as i64).to_elem())
             .collect::<Vec<<B::IntegerBackend as Backend>::Elem>>();
         let data = Data::new(value, shape);
-        <B::IntegerBackend as Backend>::from_data(data, device)
+        <B::IntegerBackend as TensorOps<B::IntegerBackend>>::from_data(data, device)
     }
     fn empty<const D: usize>(shape: Shape<D>, device: B::Device) -> B::TensorPrimitive<D>;
     fn repeat<const D: usize>(
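This last hunk swaps the trait named in the fully qualified call because from_data now lives on TensorOps rather than Backend. When a method could come from more than one trait in scope, Rust's fully qualified syntax picks the trait explicitly; a toy illustration with invented types, not burn's:

trait TensorOps<B> {
    fn from_data(data: Vec<f32>) -> Vec<f32>;
}

trait Backend {
    fn from_data(data: Vec<f32>) -> Vec<f32>;
}

struct Cpu;

impl TensorOps<Cpu> for Cpu {
    fn from_data(data: Vec<f32>) -> Vec<f32> {
        data // "ops" flavor: pass through
    }
}

impl Backend for Cpu {
    fn from_data(data: Vec<f32>) -> Vec<f32> {
        data // "backend" flavor: also pass through
    }
}

fn main() {
    // Plain `Cpu::from_data(..)` is rejected as ambiguous ("multiple
    // applicable items in scope"); `<T as Trait>::method` selects one.
    let v = <Cpu as TensorOps<Cpu>>::from_data(vec![1.0, 2.0]);
    let w = <Cpu as Backend>::from_data(v);
    println!("{w:?}");
}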