mirror of https://github.com/tracel-ai/burn.git
Expose element traits (#700)
This commit is contained in:
parent
03ee987cda
commit
7f558bdc46
|
@ -3,13 +3,15 @@ use libm::{exp, fabs, log, log1p, pow, sqrt};
|
|||
use libm::{expf, fabsf, log1pf, logf, powf, sqrtf};
|
||||
use ndarray::LinalgScalar;
|
||||
|
||||
pub(crate) trait FloatNdArrayElement: NdArrayElement + LinalgScalar
|
||||
/// A float element for ndarray backend.
|
||||
pub trait FloatNdArrayElement: NdArrayElement + LinalgScalar
|
||||
where
|
||||
Self: Sized,
|
||||
{
|
||||
}
|
||||
|
||||
pub(crate) trait NdArrayElement:
|
||||
/// A general element for ndarray backend.
|
||||
pub trait NdArrayElement:
|
||||
Element
|
||||
+ ndarray::LinalgScalar
|
||||
+ ndarray::ScalarOperand
|
||||
|
@ -21,7 +23,8 @@ pub(crate) trait NdArrayElement:
|
|||
{
|
||||
}
|
||||
|
||||
pub(crate) trait ExpElement {
|
||||
/// A element for ndarray backend that supports exp ops.
|
||||
pub trait ExpElement {
|
||||
fn exp_elem(self) -> Self;
|
||||
fn log_elem(self) -> Self;
|
||||
fn log1p_elem(self) -> Self;
|
||||
|
|
|
@ -21,6 +21,7 @@ mod sharing;
|
|||
mod tensor;
|
||||
|
||||
pub use backend::*;
|
||||
pub use element::FloatNdArrayElement;
|
||||
pub(crate) use sharing::*;
|
||||
pub(crate) use tensor::*;
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use burn_tensor::Element;
|
||||
use half::{bf16, f16};
|
||||
|
||||
/// The element type for the tch backend.
|
||||
pub trait TchElement: Element + tch::kind::Element {}
|
||||
|
||||
impl TchElement for f64 {}
|
||||
|
|
|
@ -9,6 +9,7 @@ mod ops;
|
|||
mod tensor;
|
||||
|
||||
pub use backend::*;
|
||||
pub use element::*;
|
||||
pub use tensor::*;
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use burn_tensor::Element;
|
||||
|
||||
/// The base element trait for the wgou backend.
|
||||
pub trait WgpuElement: core::fmt::Debug + Send + Sync + 'static + Clone
|
||||
where
|
||||
Self: Sized,
|
||||
|
@ -9,8 +10,10 @@ where
|
|||
fn from_bytes(bytes: &[u8]) -> &[Self];
|
||||
}
|
||||
|
||||
/// The float element type for the wgpu backend.
|
||||
pub trait FloatElement: WgpuElement + Element {}
|
||||
|
||||
/// The int element type for the wgpu backend.
|
||||
pub trait IntElement: WgpuElement + Element {}
|
||||
|
||||
impl WgpuElement for u32 {
|
||||
|
|
|
@ -13,11 +13,13 @@ pub mod benchmark;
|
|||
pub mod kernel;
|
||||
|
||||
pub(crate) mod context;
|
||||
pub(crate) mod element;
|
||||
pub(crate) mod pool;
|
||||
pub(crate) mod tensor;
|
||||
pub(crate) mod tune;
|
||||
|
||||
mod element;
|
||||
pub use element::{FloatElement, IntElement};
|
||||
|
||||
mod device;
|
||||
pub use device::*;
|
||||
|
||||
|
|
Loading…
Reference in New Issue