diff --git a/burn-ndarray/src/element.rs b/burn-ndarray/src/element.rs index d08a7f170..be08c7622 100644 --- a/burn-ndarray/src/element.rs +++ b/burn-ndarray/src/element.rs @@ -3,13 +3,15 @@ use libm::{exp, fabs, log, log1p, pow, sqrt}; use libm::{expf, fabsf, log1pf, logf, powf, sqrtf}; use ndarray::LinalgScalar; -pub(crate) trait FloatNdArrayElement: NdArrayElement + LinalgScalar +/// A float element for ndarray backend. +pub trait FloatNdArrayElement: NdArrayElement + LinalgScalar where Self: Sized, { } -pub(crate) trait NdArrayElement: +/// A general element for ndarray backend. +pub trait NdArrayElement: Element + ndarray::LinalgScalar + ndarray::ScalarOperand @@ -21,7 +23,8 @@ pub(crate) trait NdArrayElement: { } -pub(crate) trait ExpElement { +/// A element for ndarray backend that supports exp ops. +pub trait ExpElement { fn exp_elem(self) -> Self; fn log_elem(self) -> Self; fn log1p_elem(self) -> Self; diff --git a/burn-ndarray/src/lib.rs b/burn-ndarray/src/lib.rs index 20ffad339..3492f93e3 100644 --- a/burn-ndarray/src/lib.rs +++ b/burn-ndarray/src/lib.rs @@ -21,6 +21,7 @@ mod sharing; mod tensor; pub use backend::*; +pub use element::FloatNdArrayElement; pub(crate) use sharing::*; pub(crate) use tensor::*; diff --git a/burn-tch/src/element.rs b/burn-tch/src/element.rs index efc13ab82..8317321c0 100644 --- a/burn-tch/src/element.rs +++ b/burn-tch/src/element.rs @@ -1,6 +1,7 @@ use burn_tensor::Element; use half::{bf16, f16}; +/// The element type for the tch backend. pub trait TchElement: Element + tch::kind::Element {} impl TchElement for f64 {} diff --git a/burn-tch/src/lib.rs b/burn-tch/src/lib.rs index 6d01154a7..008f12e0f 100644 --- a/burn-tch/src/lib.rs +++ b/burn-tch/src/lib.rs @@ -9,6 +9,7 @@ mod ops; mod tensor; pub use backend::*; +pub use element::*; pub use tensor::*; #[cfg(test)] diff --git a/burn-wgpu/src/element.rs b/burn-wgpu/src/element.rs index 816f399af..a1bf0446e 100644 --- a/burn-wgpu/src/element.rs +++ b/burn-wgpu/src/element.rs @@ -1,5 +1,6 @@ use burn_tensor::Element; +/// The base element trait for the wgou backend. pub trait WgpuElement: core::fmt::Debug + Send + Sync + 'static + Clone where Self: Sized, @@ -9,8 +10,10 @@ where fn from_bytes(bytes: &[u8]) -> &[Self]; } +/// The float element type for the wgpu backend. pub trait FloatElement: WgpuElement + Element {} +/// The int element type for the wgpu backend. pub trait IntElement: WgpuElement + Element {} impl WgpuElement for u32 { diff --git a/burn-wgpu/src/lib.rs b/burn-wgpu/src/lib.rs index 4f0475b2e..a32941b01 100644 --- a/burn-wgpu/src/lib.rs +++ b/burn-wgpu/src/lib.rs @@ -13,11 +13,13 @@ pub mod benchmark; pub mod kernel; pub(crate) mod context; -pub(crate) mod element; pub(crate) mod pool; pub(crate) mod tensor; pub(crate) mod tune; +mod element; +pub use element::{FloatElement, IntElement}; + mod device; pub use device::*;