diff --git a/burn-autodiff/src/backend.rs b/burn-autodiff/src/backend.rs index fb2f4e450..567bd7396 100644 --- a/burn-autodiff/src/backend.rs +++ b/burn-autodiff/src/backend.rs @@ -1,7 +1,6 @@ use crate::graph::grad::Grads; use crate::tensor::ADTensor; use burn_tensor::backend::{ADBackend, Backend}; -use burn_tensor::{Data, Distribution, Shape}; #[derive(Clone, Copy, Debug, Default)] pub struct ADBackendDecorator { @@ -17,41 +16,10 @@ impl Backend for ADBackendDecorator { type TensorPrimitive = ADTensor; type BoolTensorPrimitive = B::BoolTensorPrimitive; - fn from_data( - data: Data, - device: Self::Device, - ) -> Self::TensorPrimitive { - let tensor = B::from_data(data, device); - ADTensor::from_tensor(tensor) - } - - fn from_data_bool( - data: Data, - device: Self::Device, - ) -> Self::BoolTensorPrimitive { - B::from_data_bool(data, device) - } - - fn random( - shape: Shape, - distribution: Distribution, - device: Self::Device, - ) -> Self::TensorPrimitive { - Self::from_inner(B::random(shape, distribution, device)) - } - fn ad_enabled() -> bool { true } - fn zeros(shape: Shape, device: Self::Device) -> Self::TensorPrimitive { - Self::from_inner(B::zeros(shape, device)) - } - - fn ones(shape: Shape, device: Self::Device) -> Self::TensorPrimitive { - Self::from_inner(B::ones(shape, device)) - } - fn name() -> String { format!("autodiff<{}>", B::name()) } diff --git a/burn-autodiff/src/ops/tensor.rs b/burn-autodiff/src/ops/tensor.rs index 5bf5c8321..466a31642 100644 --- a/burn-autodiff/src/ops/tensor.rs +++ b/burn-autodiff/src/ops/tensor.rs @@ -6,7 +6,7 @@ use crate::ops::unary_ops_wrapper_explicit; use crate::tensor::ADTensor; use crate::ADBackendDecorator; use burn_tensor::backend::Backend; -use burn_tensor::{ops::*, Data, ElementConversion, Shape, Tensor}; +use burn_tensor::{ops::*, Data, Distribution, ElementConversion, Shape, Tensor}; use std::ops::Range; use std::sync::Arc; @@ -19,6 +19,34 @@ impl std::ops::Add> for ADTensor TensorOps> for ADBackendDecorator { + fn from_data(data: Data, device: B::Device) -> ADTensor { + let tensor = B::from_data(data, device); + ADTensor::from_tensor(tensor) + } + + fn from_data_bool( + data: Data, + device: B::Device, + ) -> B::BoolTensorPrimitive { + B::from_data_bool(data, device) + } + + fn random( + shape: Shape, + distribution: Distribution, + device: B::Device, + ) -> ADTensor { + ADTensor::from_tensor(B::random(shape, distribution, device)) + } + + fn zeros(shape: Shape, device: B::Device) -> ADTensor { + ADTensor::from_tensor(B::zeros(shape, device)) + } + + fn ones(shape: Shape, device: B::Device) -> ADTensor { + ADTensor::from_tensor(B::ones(shape, device)) + } + fn shape( tensor: & as Backend>::TensorPrimitive, ) -> &Shape { diff --git a/burn-ndarray/src/backend.rs b/burn-ndarray/src/backend.rs index e1b80af7e..c7b0bc1dd 100644 --- a/burn-ndarray/src/backend.rs +++ b/burn-ndarray/src/backend.rs @@ -1,13 +1,11 @@ use super::element::NdArrayElement; use super::NdArrayTensor; use burn_tensor::backend::Backend; -use burn_tensor::Data; -use burn_tensor::{Distribution, Shape}; use rand::rngs::StdRng; use rand::SeedableRng; use std::sync::Mutex; -static SEED: Mutex> = Mutex::new(None); +pub(crate) static SEED: Mutex> = Mutex::new(None); #[derive(Clone, Copy, Debug)] pub enum NdArrayDevice { @@ -34,39 +32,10 @@ impl Backend for NdArrayBackend { type TensorPrimitive = NdArrayTensor; type BoolTensorPrimitive = NdArrayTensor; - fn from_data( - data: Data, - _device: Self::Device, - ) -> NdArrayTensor { - NdArrayTensor::from_data(data) - } - - fn from_data_bool( - data: Data, - _device: Self::Device, - ) -> Self::BoolTensorPrimitive { - NdArrayTensor::from_data(data) - } - fn ad_enabled() -> bool { false } - fn random( - shape: Shape, - distribution: Distribution, - device: Self::Device, - ) -> Self::TensorPrimitive { - let mut seed = SEED.lock().unwrap(); - let mut rng: StdRng = match seed.as_ref() { - Some(rng) => rng.clone(), - None => StdRng::from_entropy(), - }; - let tensor = Self::from_data(Data::random(shape, distribution, &mut rng), device); - *seed = Some(rng); - tensor - } - fn name() -> String { "ndarray".to_string() } diff --git a/burn-ndarray/src/lib.rs b/burn-ndarray/src/lib.rs index ce51ef3f7..7f81d75a5 100644 --- a/burn-ndarray/src/lib.rs +++ b/burn-ndarray/src/lib.rs @@ -10,10 +10,8 @@ extern crate blas_src; mod backend; mod element; -mod module_ops; mod ops; mod tensor; -mod tensor_ops; pub use backend::*; pub(crate) use tensor::*; diff --git a/burn-ndarray/src/ops/mod.rs b/burn-ndarray/src/ops/mod.rs index 7a7356db9..66111d0fd 100644 --- a/burn-ndarray/src/ops/mod.rs +++ b/burn-ndarray/src/ops/mod.rs @@ -1 +1,3 @@ mod creation; +mod module; +mod tensor; diff --git a/burn-ndarray/src/module_ops.rs b/burn-ndarray/src/ops/module.rs similarity index 96% rename from burn-ndarray/src/module_ops.rs rename to burn-ndarray/src/ops/module.rs index ae3ea862a..5af05c63e 100644 --- a/burn-ndarray/src/module_ops.rs +++ b/burn-ndarray/src/ops/module.rs @@ -1,4 +1,4 @@ -use super::{element::NdArrayElement, NdArrayBackend, NdArrayTensor}; +use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend}; use burn_tensor::{ops::*, Shape}; use std::ops::Add; diff --git a/burn-ndarray/src/tensor_ops.rs b/burn-ndarray/src/ops/tensor.rs similarity index 94% rename from burn-ndarray/src/tensor_ops.rs rename to burn-ndarray/src/ops/tensor.rs index 664e8f58c..592b94ab2 100644 --- a/burn-ndarray/src/tensor_ops.rs +++ b/burn-ndarray/src/ops/tensor.rs @@ -1,8 +1,14 @@ -use super::{element::NdArrayElement, BatchMatrix, NdArrayBackend, NdArrayTensor}; -use crate::{to_nd_array_tensor, NdArrayDevice}; +use std::cmp::Ordering; +use std::ops::Range; + +use crate::tensor::BatchMatrix; +use crate::{element::NdArrayElement, tensor::NdArrayTensor, NdArrayBackend}; +use crate::{to_nd_array_tensor, NdArrayDevice, SEED}; +use burn_tensor::Distribution; use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Shape}; use ndarray::{Axis, Dim, IxDyn, SliceInfoElem}; -use std::{cmp::Ordering, ops::Range}; +use rand::rngs::StdRng; +use rand::SeedableRng; macro_rules! keepdim { ( @@ -30,6 +36,32 @@ macro_rules! keepdim { } impl TensorOps> for NdArrayBackend { + fn from_data(data: Data, _device: NdArrayDevice) -> NdArrayTensor { + NdArrayTensor::from_data(data) + } + + fn from_data_bool( + data: Data, + _device: NdArrayDevice, + ) -> NdArrayTensor { + NdArrayTensor::from_data(data) + } + + fn random( + shape: Shape, + distribution: Distribution, + device: NdArrayDevice, + ) -> NdArrayTensor { + let mut seed = SEED.lock().unwrap(); + let mut rng: StdRng = match seed.as_ref() { + Some(rng) => rng.clone(), + None => StdRng::from_entropy(), + }; + let tensor = Self::from_data(Data::random(shape, distribution, &mut rng), device); + *seed = Some(rng); + tensor + } + fn shape( tensor: & as Backend>::TensorPrimitive, ) -> &Shape { diff --git a/burn-tch/src/backend.rs b/burn-tch/src/backend.rs index 7db5ce17f..382e607f7 100644 --- a/burn-tch/src/backend.rs +++ b/burn-tch/src/backend.rs @@ -1,7 +1,6 @@ use super::element::TchElement; use super::TchTensor; use burn_tensor::backend::Backend; -use burn_tensor::{Data, Distribution, Shape}; #[derive(Clone, Copy, Debug)] /// The device struct when using the `tch` backend. @@ -51,75 +50,10 @@ impl Backend for TchBackend { type TensorPrimitive = TchTensor; type BoolTensorPrimitive = TchTensor; - fn from_data( - data: Data, - device: Self::Device, - ) -> TchTensor { - let device = match device { - TchDevice::Cpu => tch::Device::Cpu, - TchDevice::Cuda(num) => tch::Device::Cuda(num), - }; - TchTensor::from_data(data, device) - } - - fn from_data_bool( - data: Data, - device: Self::Device, - ) -> Self::BoolTensorPrimitive { - let device = match device { - TchDevice::Cpu => tch::Device::Cpu, - TchDevice::Cuda(num) => tch::Device::Cuda(num), - }; - TchTensor::from_data(data, device) - } - - fn random( - shape: Shape, - distribution: Distribution, - device: Self::Device, - ) -> Self::TensorPrimitive { - match distribution { - Distribution::Standard => { - let mut tensor = TchTensor::::empty(shape, device); - tensor.tensor = tensor.tensor.normal_(0.0, 1.0); - tensor - } - Distribution::Bernoulli(prob) => { - let mut tensor = TchTensor::::empty(shape, device); - tensor.tensor = tensor.tensor.f_bernoulli_float_(prob).unwrap(); - tensor - } - Distribution::Uniform(from, to) => { - let mut tensor = TchTensor::::empty(shape, device); - tensor.tensor = tensor - .tensor - .uniform_(from.to_f64().unwrap(), to.to_f64().unwrap()); - tensor - } - Distribution::Normal(mean, std) => { - let mut tensor = TchTensor::::empty(shape, device); - tensor.tensor = tensor.tensor.normal(mean, std); - tensor - } - } - } - - fn zeros(shape: Shape, device: Self::Device) -> Self::TensorPrimitive { - let mut tensor = TchTensor::::empty(shape, device); - tensor.tensor = tensor.tensor.zero_(); - tensor - } - fn seed(seed: u64) { tch::manual_seed(seed as i64); } - fn ones(shape: Shape, device: Self::Device) -> Self::TensorPrimitive { - let mut tensor = TchTensor::::empty(shape, device); - tensor.tensor = tensor.tensor.ones_like(); - tensor - } - fn ad_enabled() -> bool { false } diff --git a/burn-tch/src/lib.rs b/burn-tch/src/lib.rs index 0c0ee1c01..4e8e4452a 100644 --- a/burn-tch/src/lib.rs +++ b/burn-tch/src/lib.rs @@ -1,13 +1,10 @@ mod backend; mod element; -mod module_ops; mod ops; mod tensor; -mod tensor_ops; pub use backend::*; pub use tensor::*; -pub use tensor_ops::*; #[cfg(test)] mod tests { diff --git a/burn-tch/src/ops/mod.rs b/burn-tch/src/ops/mod.rs index 7a7356db9..66111d0fd 100644 --- a/burn-tch/src/ops/mod.rs +++ b/burn-tch/src/ops/mod.rs @@ -1 +1,3 @@ mod creation; +mod module; +mod tensor; diff --git a/burn-tch/src/module_ops.rs b/burn-tch/src/ops/module.rs similarity index 94% rename from burn-tch/src/module_ops.rs rename to burn-tch/src/ops/module.rs index feac825be..07139e3e3 100644 --- a/burn-tch/src/module_ops.rs +++ b/burn-tch/src/ops/module.rs @@ -1,4 +1,4 @@ -use super::{element::TchElement, TchBackend, TchTensor}; +use crate::{element::TchElement, TchBackend, TchTensor}; use burn_tensor::{ops::ModuleOps, Shape}; impl ModuleOps> for TchBackend { diff --git a/burn-tch/src/tensor_ops.rs b/burn-tch/src/ops/tensor.rs similarity index 84% rename from burn-tch/src/tensor_ops.rs rename to burn-tch/src/ops/tensor.rs index 653b0f55d..2c6336efd 100644 --- a/burn-tch/src/tensor_ops.rs +++ b/burn-tch/src/ops/tensor.rs @@ -1,8 +1,70 @@ -use super::{element::TchElement, TchBackend, TchDevice, TchKind, TchShape, TchTensor}; -use burn_tensor::{backend::Backend, ops::TensorOps, Data, ElementConversion, Shape}; +use crate::{element::TchElement, TchBackend, TchDevice, TchKind, TchShape, TchTensor}; +use burn_tensor::{backend::Backend, ops::TensorOps, Data, Distribution, ElementConversion, Shape}; use std::ops::{Add, Div, Mul, Range, Sub}; impl TensorOps> for TchBackend { + fn from_data(data: Data, device: TchDevice) -> TchTensor { + let device = match device { + TchDevice::Cpu => tch::Device::Cpu, + TchDevice::Cuda(num) => tch::Device::Cuda(num), + }; + TchTensor::from_data(data, device) + } + + fn from_data_bool( + data: Data, + device: TchDevice, + ) -> TchTensor { + let device = match device { + TchDevice::Cpu => tch::Device::Cpu, + TchDevice::Cuda(num) => tch::Device::Cuda(num), + }; + TchTensor::from_data(data, device) + } + + fn random( + shape: Shape, + distribution: Distribution, + device: TchDevice, + ) -> TchTensor { + match distribution { + Distribution::Standard => { + let mut tensor = TchTensor::::empty(shape, device); + tensor.tensor = tensor.tensor.normal_(0.0, 1.0); + tensor + } + Distribution::Bernoulli(prob) => { + let mut tensor = TchTensor::::empty(shape, device); + tensor.tensor = tensor.tensor.f_bernoulli_float_(prob).unwrap(); + tensor + } + Distribution::Uniform(from, to) => { + let mut tensor = TchTensor::::empty(shape, device); + tensor.tensor = tensor + .tensor + .uniform_(from.to_f64().unwrap(), to.to_f64().unwrap()); + tensor + } + Distribution::Normal(mean, std) => { + let mut tensor = TchTensor::::empty(shape, device); + tensor.tensor = tensor.tensor.normal(mean, std); + tensor + } + } + } + + fn zeros(shape: Shape, device: TchDevice) -> TchTensor { + let mut tensor = TchTensor::::empty(shape, device); + tensor.tensor = tensor.tensor.zero_(); + tensor + } + + fn ones(shape: Shape, device: TchDevice) -> TchTensor { + let mut tensor = TchTensor::::empty(shape, device); + tensor.tensor = tensor.tensor.ones_like(); + tensor + } + fn shape(tensor: & as Backend>::TensorPrimitive) -> &Shape { &tensor.shape } diff --git a/burn-tensor/src/tensor/backend/base.rs b/burn-tensor/src/tensor/backend/base.rs index abc1700c5..c76038a70 100644 --- a/burn-tensor/src/tensor/backend/base.rs +++ b/burn-tensor/src/tensor/backend/base.rs @@ -1,8 +1,6 @@ +use super::Gradients; use crate::ops::*; use crate::tensor::Element; -use crate::tensor::{Data, Distribution, Shape}; - -use super::Gradients; pub trait Backend: TensorOps @@ -38,33 +36,9 @@ pub trait Backend: + std::fmt::Debug + From<::BoolTensorPrimitive>; - fn from_data( - data: Data, - device: Self::Device, - ) -> Self::TensorPrimitive; - - fn from_data_bool( - data: Data, - device: Self::Device, - ) -> Self::BoolTensorPrimitive; - fn ad_enabled() -> bool; fn name() -> String; fn seed(seed: u64); - - fn random( - shape: Shape, - distribution: Distribution, - device: Self::Device, - ) -> Self::TensorPrimitive; - - fn zeros(shape: Shape, device: Self::Device) -> Self::TensorPrimitive { - Self::from_data(Data::zeros(shape), device) - } - - fn ones(shape: Shape, device: Self::Device) -> Self::TensorPrimitive { - Self::from_data(Data::ones(shape), device) - } } pub(crate) type ADBackendTensorPrimitive = diff --git a/burn-tensor/src/tensor/ops/base.rs b/burn-tensor/src/tensor/ops/base.rs index 8179d7f6e..5f865c133 100644 --- a/burn-tensor/src/tensor/ops/base.rs +++ b/burn-tensor/src/tensor/ops/base.rs @@ -1,4 +1,4 @@ -use crate::{backend::Backend, tensor::Shape, Data, ElementConversion}; +use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion}; use std::ops::Range; pub trait ModuleOps { @@ -14,6 +14,25 @@ pub trait ModuleOps { } pub trait TensorOps { + fn from_data( + data: Data, + device: B::Device, + ) -> B::TensorPrimitive; + fn from_data_bool( + data: Data, + device: B::Device, + ) -> B::BoolTensorPrimitive; + fn random( + shape: Shape, + distribution: Distribution, + device: B::Device, + ) -> B::TensorPrimitive; + fn zeros(shape: Shape, device: B::Device) -> B::TensorPrimitive { + Self::from_data(Data::zeros(shape), device) + } + fn ones(shape: Shape, device: B::Device) -> B::TensorPrimitive { + Self::from_data(Data::ones(shape), device) + } fn shape(tensor: &B::TensorPrimitive) -> &Shape; fn to_data(tensor: &B::TensorPrimitive) -> Data; fn into_data(tensor: B::TensorPrimitive) -> Data; @@ -43,7 +62,7 @@ pub trait TensorOps { .map(|i| (i as i64).to_elem()) .collect::::Elem>>(); let data = Data::new(value, shape); - ::from_data(data, device) + >::from_data(data, device) } fn empty(shape: Shape, device: B::Device) -> B::TensorPrimitive; fn repeat(