mirror of https://github.com/tracel-ai/burn.git
Refactor/zeros ones elems (#102)
This commit is contained in:
parent
1a45368878
commit
23677b8e89
|
@ -15,7 +15,7 @@ impl Forward2BackwardGraphConverter {
|
|||
state: HashMap::new(),
|
||||
}
|
||||
}
|
||||
pub fn from<T: Clone + 'static + Zeros<T>>(
|
||||
pub fn from<T: Clone + 'static + Zeros>(
|
||||
&mut self,
|
||||
node: &ForwardNodeRef<T>,
|
||||
) -> BackwardNodeRef<T> {
|
||||
|
|
|
@ -21,7 +21,7 @@ impl Gradients {
|
|||
|
||||
pub fn register<T>(&mut self, node: &BackwardNode<T>)
|
||||
where
|
||||
T: Zeros<T> + Clone + Add<Output = T>,
|
||||
T: Zeros + Clone + Add<Output = T>,
|
||||
T: std::fmt::Debug + 'static + Send + Sync,
|
||||
{
|
||||
let grad = node.state.grad();
|
||||
|
@ -37,7 +37,7 @@ impl Gradients {
|
|||
|
||||
pub fn from<T>(node: &BackwardNode<T>) -> Self
|
||||
where
|
||||
T: Zeros<T> + Clone + Add<Output = T>,
|
||||
T: Zeros + Clone + Add<Output = T>,
|
||||
T: std::fmt::Debug + 'static + Send + Sync,
|
||||
{
|
||||
let mut grads = Self::empty();
|
||||
|
|
|
@ -19,7 +19,7 @@ pub struct BackwardNode<Out> {
|
|||
}
|
||||
pub type BackwardNodeRef<Out> = Arc<BackwardNode<Out>>;
|
||||
|
||||
impl<Out: Clone + Zeros<Out>> BackwardNode<Out> {
|
||||
impl<Out: Clone + Zeros> BackwardNode<Out> {
|
||||
pub fn from_node(
|
||||
node: &ForwardNodeRef<Out>,
|
||||
converter: &mut Forward2BackwardGraphConverter,
|
||||
|
@ -35,7 +35,7 @@ impl<Out: Clone + Zeros<Out>> BackwardNode<Out> {
|
|||
|
||||
impl<Out> BackwardNode<Out>
|
||||
where
|
||||
Out: Zeros<Out> + Ones<Out> + Clone + Add<Output = Out>,
|
||||
Out: Zeros + Ones + Clone + Add<Output = Out>,
|
||||
Out: std::fmt::Debug + 'static + Send + Sync,
|
||||
{
|
||||
pub fn backward(&mut self) -> Gradients {
|
||||
|
@ -74,7 +74,7 @@ where
|
|||
|
||||
impl<T> RecordedOpsParent for BackwardNode<T>
|
||||
where
|
||||
T: Zeros<T> + Clone + Add<Output = T>,
|
||||
T: Zeros + Clone + Add<Output = T>,
|
||||
T: std::fmt::Debug + 'static + Send + Sync,
|
||||
{
|
||||
fn backward_step(&self) {
|
||||
|
|
|
@ -23,7 +23,7 @@ pub struct BackwardNodeState<Out> {
|
|||
pub grad: RefCell<Out>,
|
||||
}
|
||||
|
||||
impl<Out: Zeros<Out>> BackwardNodeState<Out> {
|
||||
impl<Out: Zeros> BackwardNodeState<Out> {
|
||||
pub fn new(value: Out) -> Self {
|
||||
let grad = value.zeros();
|
||||
let grad = RefCell::new(grad);
|
||||
|
@ -42,7 +42,7 @@ where
|
|||
|
||||
impl<Out> BackwardNodeState<Out>
|
||||
where
|
||||
Out: Zeros<Out> + Clone + Add<Output = Out>,
|
||||
Out: Zeros + Clone + Add<Output = Out>,
|
||||
Out: std::fmt::Debug,
|
||||
{
|
||||
pub fn grad(&self) -> Out {
|
||||
|
|
|
@ -32,9 +32,9 @@ pub struct BackwardBinaryRecordedOps<Lhs, Rhs, Ops> {
|
|||
|
||||
impl<Lhs, Rhs, Out, Ops> ForwardRecordedOps<Out> for ForwardBinaryRecordedOps<Lhs, Rhs, Ops>
|
||||
where
|
||||
Lhs: Clone + Zeros<Lhs> + Add<Output = Lhs> + std::fmt::Debug + 'static + Send + Sync,
|
||||
Rhs: Clone + Zeros<Rhs> + Add<Output = Rhs> + std::fmt::Debug + 'static + Send + Sync,
|
||||
Out: Clone + Zeros<Out> + Add<Output = Out> + std::fmt::Debug + 'static,
|
||||
Lhs: Clone + Zeros + Add<Output = Lhs> + std::fmt::Debug + 'static + Send + Sync,
|
||||
Rhs: Clone + Zeros + Add<Output = Rhs> + std::fmt::Debug + 'static + Send + Sync,
|
||||
Out: Clone + Zeros + Add<Output = Out> + std::fmt::Debug + 'static,
|
||||
Ops: BinaryOps<Lhs, Rhs, Out> + std::fmt::Debug + 'static + Send + Sync,
|
||||
{
|
||||
fn to_backward(
|
||||
|
@ -51,9 +51,9 @@ where
|
|||
|
||||
impl<Lhs, Rhs, Out, Ops> BackwardRecordedOps<Out> for BackwardBinaryRecordedOps<Lhs, Rhs, Ops>
|
||||
where
|
||||
Lhs: Clone + Zeros<Lhs> + Add<Output = Lhs> + std::fmt::Debug + 'static + Send + Sync,
|
||||
Rhs: Clone + Zeros<Rhs> + Add<Output = Rhs> + std::fmt::Debug + 'static + Send + Sync,
|
||||
Out: Clone + Zeros<Out> + Add<Output = Out> + std::fmt::Debug + 'static,
|
||||
Lhs: Clone + Zeros + Add<Output = Lhs> + std::fmt::Debug + 'static + Send + Sync,
|
||||
Rhs: Clone + Zeros + Add<Output = Rhs> + std::fmt::Debug + 'static + Send + Sync,
|
||||
Out: Clone + Zeros + Add<Output = Out> + std::fmt::Debug + 'static,
|
||||
Ops: BinaryOps<Lhs, Rhs, Out> + std::fmt::Debug + 'static,
|
||||
{
|
||||
fn backward_step(&self, state: &BackwardNodeState<Out>) {
|
||||
|
|
|
@ -10,7 +10,7 @@ pub struct InitRecordedOps {}
|
|||
|
||||
impl<Out> BackwardRecordedOps<Out> for InitRecordedOps
|
||||
where
|
||||
Out: Clone + Zeros<Out> + Add<Output = Out> + std::fmt::Debug + 'static,
|
||||
Out: Clone + Zeros + Add<Output = Out> + std::fmt::Debug + 'static,
|
||||
{
|
||||
fn backward_step(&self, _: &BackwardNodeState<Out>) {}
|
||||
fn backward_parents(&self) -> Vec<RecordedOpsParentRef> {
|
||||
|
@ -20,7 +20,7 @@ where
|
|||
|
||||
impl<Out> ForwardRecordedOps<Out> for InitRecordedOps
|
||||
where
|
||||
Out: Clone + Zeros<Out> + Add<Output = Out> + std::fmt::Debug + 'static,
|
||||
Out: Clone + Zeros + Add<Output = Out> + std::fmt::Debug + 'static,
|
||||
{
|
||||
fn to_backward(
|
||||
&self,
|
||||
|
|
|
@ -26,8 +26,8 @@ pub struct BackwareUnaryRecordedOps<In, Ops> {
|
|||
|
||||
impl<In, Out, Ops> ForwardRecordedOps<Out> for ForwardUnaryRecordedOps<In, Ops>
|
||||
where
|
||||
In: Clone + Zeros<In> + Add<Output = In> + std::fmt::Debug + 'static + Send + Sync,
|
||||
Out: Clone + Zeros<Out> + Add<Output = Out> + std::fmt::Debug + 'static,
|
||||
In: Clone + Zeros + Add<Output = In> + std::fmt::Debug + 'static + Send + Sync,
|
||||
Out: Clone + Zeros + Add<Output = Out> + std::fmt::Debug + 'static,
|
||||
Ops: UnaryOps<In, Out> + std::fmt::Debug + 'static,
|
||||
{
|
||||
fn to_backward(
|
||||
|
@ -43,8 +43,8 @@ where
|
|||
|
||||
impl<In, Out, Ops> BackwardRecordedOps<Out> for BackwareUnaryRecordedOps<In, Ops>
|
||||
where
|
||||
In: Clone + Zeros<In> + Add<Output = In> + std::fmt::Debug + 'static + Send + Sync,
|
||||
Out: Clone + Zeros<Out> + Add<Output = Out> + std::fmt::Debug + 'static,
|
||||
In: Clone + Zeros + Add<Output = In> + std::fmt::Debug + 'static + Send + Sync,
|
||||
Out: Clone + Zeros + Add<Output = Out> + std::fmt::Debug + 'static,
|
||||
Ops: UnaryOps<In, Out> + std::fmt::Debug + 'static,
|
||||
{
|
||||
fn backward_step(&self, state: &BackwardNodeState<Out>) {
|
||||
|
|
|
@ -3,13 +3,13 @@ use crate::tensor::{
|
|||
ops::*,
|
||||
};
|
||||
|
||||
impl<B: Backend, const D: usize> Zeros<Self> for ADTensor<D, B> {
|
||||
impl<B: Backend, const D: usize> Zeros for ADTensor<D, B> {
|
||||
fn zeros(&self) -> Self {
|
||||
ADTensor::from_tensor(self.tensor().zeros())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend, const D: usize> Ones<Self> for ADTensor<D, B> {
|
||||
impl<B: Backend, const D: usize> Ones for ADTensor<D, B> {
|
||||
fn ones(&self) -> Self {
|
||||
ADTensor::from_tensor(self.tensor().ones())
|
||||
}
|
||||
|
|
|
@ -20,8 +20,8 @@ pub trait Backend:
|
|||
type FullPrecisionBackend: Backend<Elem = Self::FullPrecisionElem, Device = Self::Device>;
|
||||
type IntegerBackend: Backend<Elem = i64, Device = Self::Device>;
|
||||
type TensorPrimitive<const D: usize>: std::ops::Add<Self::TensorPrimitive<D>, Output = Self::TensorPrimitive<D>>
|
||||
+ Zeros<Self::TensorPrimitive<D>>
|
||||
+ Ones<Self::TensorPrimitive<D>>
|
||||
+ Zeros
|
||||
+ Ones
|
||||
+ Clone
|
||||
+ Send
|
||||
+ Sync
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
use super::element::NdArrayElement;
|
||||
use super::NdArrayTensor;
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::Data;
|
||||
use crate::tensor::{backend::Backend, NdArrayElement};
|
||||
use crate::{Distribution, Shape};
|
||||
use rand::rngs::StdRng;
|
||||
use rand::SeedableRng;
|
||||
|
|
|
@ -0,0 +1,62 @@
|
|||
use crate::Element;
|
||||
|
||||
pub(crate) trait NdArrayElement:
|
||||
Element + ndarray::LinalgScalar + ndarray::ScalarOperand + ExpElement + num_traits::FromPrimitive
|
||||
{
|
||||
}
|
||||
|
||||
pub(crate) trait ExpElement {
|
||||
fn exp_elem(self) -> Self;
|
||||
fn log_elem(self) -> Self;
|
||||
fn pow_elem(self, value: f32) -> Self;
|
||||
}
|
||||
|
||||
macro_rules! impl_exp_elem {
|
||||
($elem:ident) => {
|
||||
impl ExpElement for $elem {
|
||||
fn exp_elem(self) -> Self {
|
||||
$elem::exp(self)
|
||||
}
|
||||
fn log_elem(self) -> Self {
|
||||
$elem::ln(self)
|
||||
}
|
||||
fn pow_elem(self, value: f32) -> Self {
|
||||
$elem::powf(self, value.into())
|
||||
}
|
||||
}
|
||||
};
|
||||
($elem:ident, $tmp:ident) => {
|
||||
impl ExpElement for $elem {
|
||||
fn exp_elem(self) -> Self {
|
||||
let tmp = $tmp::exp(self as $tmp);
|
||||
tmp as $elem
|
||||
}
|
||||
fn log_elem(self) -> Self {
|
||||
let tmp = $tmp::ln(self as $tmp);
|
||||
tmp as $elem
|
||||
}
|
||||
fn pow_elem(self, value: f32) -> Self {
|
||||
let tmp = $tmp::powf(self as $tmp, value as $tmp);
|
||||
tmp as $elem
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl NdArrayElement for f64 {}
|
||||
impl_exp_elem!(f64);
|
||||
|
||||
impl NdArrayElement for f32 {}
|
||||
impl_exp_elem!(f32);
|
||||
|
||||
impl NdArrayElement for i64 {}
|
||||
impl_exp_elem!(i64, f64);
|
||||
|
||||
impl NdArrayElement for i32 {}
|
||||
impl_exp_elem!(i32, f32);
|
||||
|
||||
impl NdArrayElement for i16 {}
|
||||
impl_exp_elem!(i16, f32);
|
||||
|
||||
impl NdArrayElement for u8 {}
|
||||
impl_exp_elem!(u8, f32);
|
|
@ -1,4 +1,5 @@
|
|||
mod backend;
|
||||
mod element;
|
||||
mod module_ops;
|
||||
mod ops;
|
||||
mod shape;
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
use super::{element::NdArrayElement, NdArrayBackend, NdArrayTensor};
|
||||
use crate::{ops::*, Shape};
|
||||
use std::ops::Add;
|
||||
|
||||
use super::{NdArrayBackend, NdArrayTensor};
|
||||
use crate::{ops::*, NdArrayElement, Shape};
|
||||
|
||||
impl<E: NdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E> {
|
||||
fn embedding(
|
||||
weights: &NdArrayTensor<E, 2>,
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
use crate::tensor::{backend::ndarray::NdArrayTensor, ops::*, Data};
|
||||
|
||||
impl<P, const D: usize> Zeros<NdArrayTensor<P, D>> for NdArrayTensor<P, D>
|
||||
impl<P, const D: usize> Zeros for NdArrayTensor<P, D>
|
||||
where
|
||||
P: Default + Clone + Zeros<P> + std::fmt::Debug,
|
||||
P: Default + Clone + Zeros + std::fmt::Debug,
|
||||
{
|
||||
fn zeros(&self) -> NdArrayTensor<P, D> {
|
||||
let data = Data::<P, D>::zeros(self.shape);
|
||||
|
@ -10,9 +10,9 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<P, const D: usize> Ones<NdArrayTensor<P, D>> for NdArrayTensor<P, D>
|
||||
impl<P, const D: usize> Ones for NdArrayTensor<P, D>
|
||||
where
|
||||
P: Default + Clone + Ones<P> + std::fmt::Debug,
|
||||
P: Default + Clone + Ones + std::fmt::Debug,
|
||||
{
|
||||
fn ones(&self) -> NdArrayTensor<P, D> {
|
||||
let data = Data::<P, D>::ones(self.shape);
|
||||
|
|
|
@ -1,8 +1,7 @@
|
|||
use super::NdArrayBackend;
|
||||
use super::{element::NdArrayElement, NdArrayBackend};
|
||||
use crate::{
|
||||
ops::TensorOps,
|
||||
tensor::{Data, Shape},
|
||||
NdArrayElement,
|
||||
};
|
||||
use ndarray::{s, ArcArray, Array, Axis, Dim, Ix2, Ix3, IxDyn};
|
||||
|
||||
|
@ -22,9 +21,9 @@ impl<E: NdArrayElement, const D: usize> std::ops::Add for NdArrayTensor<E, D> {
|
|||
|
||||
#[cfg(test)]
|
||||
mod utils {
|
||||
use crate::{backend::NdArrayBackend, ops::TensorOps, NdArrayElement};
|
||||
|
||||
use super::*;
|
||||
use crate::{backend::NdArrayBackend, ops::TensorOps};
|
||||
|
||||
impl<E, const D: usize> NdArrayTensor<E, D>
|
||||
where
|
||||
E: Default + Clone,
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
use super::{BatchMatrix, NdArrayBackend, NdArrayTensor};
|
||||
use super::{element::NdArrayElement, BatchMatrix, NdArrayBackend, NdArrayTensor};
|
||||
use crate::{
|
||||
backend::{Backend, NdArrayDevice},
|
||||
ops::TensorOps,
|
||||
to_nd_array_tensor, Data, ElementConversion, NdArrayElement, Shape,
|
||||
to_nd_array_tensor, Data, ElementConversion, Shape,
|
||||
};
|
||||
use ndarray::{Axis, Dim, IxDyn, SliceInfoElem};
|
||||
use std::{cmp::Ordering, ops::Range};
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
use super::element::TchElement;
|
||||
use super::TchTensor;
|
||||
use crate::tensor::{backend::Backend, TchElement};
|
||||
use crate::tensor::backend::Backend;
|
||||
use crate::tensor::{Data, Distribution, Shape};
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
use crate::ops::{Ones, Zeros};
|
||||
use crate::{
|
||||
make_element, Distribution, Element, ElementConversion, ElementPrecision, ElementRandom,
|
||||
ElementValue, Precision,
|
||||
};
|
||||
use half::f16;
|
||||
use num_traits::ToPrimitive;
|
||||
use rand::rngs::StdRng;
|
||||
|
||||
pub(crate) trait TchElement: Element + tch::kind::Element {}
|
||||
|
||||
impl TchElement for f64 {}
|
||||
impl TchElement for f32 {}
|
||||
impl TchElement for f16 {}
|
||||
|
||||
impl TchElement for i64 {}
|
||||
impl TchElement for i32 {}
|
||||
impl TchElement for i16 {}
|
||||
|
||||
impl TchElement for u8 {}
|
||||
|
||||
make_element!(
|
||||
ty f16 Precision::Half,
|
||||
zero <f16 as num_traits::Zero>::zero(),
|
||||
one <f16 as num_traits::One>::one(),
|
||||
convert |elem: &dyn ToPrimitive| f16::from_f32(elem.to_f32().unwrap()),
|
||||
random |distribution: Distribution<f16>, rng: &mut StdRng| {
|
||||
let distribution: Distribution<f32> = distribution.convert();
|
||||
let sample = distribution.sampler(rng).sample();
|
||||
f16::from_elem(sample)
|
||||
}
|
||||
);
|
|
@ -1,4 +1,5 @@
|
|||
mod backend;
|
||||
mod element;
|
||||
mod module_ops;
|
||||
mod ops;
|
||||
mod tensor;
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use super::{TchBackend, TchTensor};
|
||||
use crate::{ops::ModuleOps, Shape, TchElement};
|
||||
use super::{element::TchElement, TchBackend, TchTensor};
|
||||
use crate::{ops::ModuleOps, Shape};
|
||||
|
||||
impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
|
||||
fn embedding(weights: &TchTensor<E, 2>, indexes: &TchTensor<i64, 2>) -> TchTensor<E, 3> {
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::tensor::{backend::tch::TchTensor, ops::*};
|
||||
|
||||
impl<P, const D: usize> Zeros<TchTensor<P, D>> for TchTensor<P, D>
|
||||
impl<P, const D: usize> Zeros for TchTensor<P, D>
|
||||
where
|
||||
P: tch::kind::Element,
|
||||
{
|
||||
|
@ -17,7 +17,7 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<P, const D: usize> Ones<TchTensor<P, D>> for TchTensor<P, D>
|
||||
impl<P, const D: usize> Ones for TchTensor<P, D>
|
||||
where
|
||||
P: tch::kind::Element,
|
||||
{
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
use super::element::TchElement;
|
||||
use crate::{
|
||||
backend::{TchBackend, TchDevice},
|
||||
ops::TensorOps,
|
||||
tensor::{Data, Shape},
|
||||
TchElement,
|
||||
};
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
|
@ -99,7 +99,7 @@ impl<P: tch::kind::Element + Default, const D: usize> TchTensor<P, D> {
|
|||
#[cfg(test)]
|
||||
mod utils {
|
||||
use super::*;
|
||||
use crate::{backend::TchBackend, ops::TensorOps, TchElement};
|
||||
use crate::{backend::TchBackend, ops::TensorOps};
|
||||
|
||||
impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
|
||||
pub(crate) fn into_data(self) -> Data<P, D>
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use super::{TchBackend, TchDevice, TchKind, TchShape, TchTensor};
|
||||
use crate::{backend::Backend, ops::TensorOps, Data, ElementConversion, Shape, TchElement};
|
||||
use super::{element::TchElement, TchBackend, TchDevice, TchKind, TchShape, TchTensor};
|
||||
use crate::{backend::Backend, ops::TensorOps, Data, ElementConversion, Shape};
|
||||
use std::ops::{Add, Div, Mul, Range, Sub};
|
||||
|
||||
impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
|
||||
|
|
|
@ -147,7 +147,7 @@ impl<P: Element, const D: usize> Data<P, D> {
|
|||
}
|
||||
impl<P: std::fmt::Debug, const D: usize> Data<P, D>
|
||||
where
|
||||
P: Zeros<P> + Default,
|
||||
P: Zeros + Default,
|
||||
{
|
||||
pub fn zeros(shape: Shape<D>) -> Data<P, D> {
|
||||
let elem = P::default();
|
||||
|
@ -167,7 +167,7 @@ where
|
|||
|
||||
impl<P: std::fmt::Debug, const D: usize> Data<P, D>
|
||||
where
|
||||
P: Ones<P> + Default,
|
||||
P: Ones + Default,
|
||||
{
|
||||
pub fn ones(shape: Shape<D>) -> Data<P, D> {
|
||||
let elem = P::default();
|
||||
|
|
|
@ -1,43 +1,35 @@
|
|||
use crate::{tensor::ops::*, Distribution};
|
||||
use half::f16;
|
||||
use num_traits::ToPrimitive;
|
||||
use rand::prelude::StdRng;
|
||||
|
||||
pub trait Element:
|
||||
Zeros<Self>
|
||||
Zeros
|
||||
+ ToPrimitive
|
||||
+ ElementRandom<Self>
|
||||
+ ElementRandom
|
||||
+ ElementConversion
|
||||
+ ElementPrecision
|
||||
+ ElementValue
|
||||
+ Ones<Self>
|
||||
+ Ones
|
||||
+ std::ops::Mul<Self, Output = Self>
|
||||
+ std::fmt::Debug
|
||||
+ Default
|
||||
+ 'static
|
||||
+ Send
|
||||
+ Sync
|
||||
+ Copy
|
||||
+ std::cmp::PartialOrd<Self>
|
||||
+ 'static
|
||||
{
|
||||
}
|
||||
|
||||
#[cfg(feature = "tch")]
|
||||
pub(crate) trait TchElement: Element + tch::kind::Element {}
|
||||
|
||||
pub(crate) trait ExpElement {
|
||||
fn exp_elem(self) -> Self;
|
||||
fn log_elem(self) -> Self;
|
||||
fn pow_elem(self, value: f32) -> Self;
|
||||
}
|
||||
|
||||
pub trait ElementConversion {
|
||||
fn from_elem<E: ToPrimitive>(elem: E) -> Self;
|
||||
fn to_elem<E: Element>(&self) -> E;
|
||||
}
|
||||
|
||||
pub trait ElementRandom<T> {
|
||||
fn random(distribution: Distribution<T>, rng: &mut StdRng) -> T;
|
||||
pub trait ElementRandom {
|
||||
fn random(distribution: Distribution<Self>, rng: &mut StdRng) -> Self
|
||||
where
|
||||
Self: Sized;
|
||||
}
|
||||
|
||||
pub trait ElementValue {
|
||||
|
@ -60,30 +52,31 @@ pub trait ElementPrecision {
|
|||
fn precision() -> Precision;
|
||||
}
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
pub(crate) trait NdArrayElement:
|
||||
Element + ndarray::LinalgScalar + ndarray::ScalarOperand + ExpElement + num_traits::FromPrimitive
|
||||
{
|
||||
}
|
||||
|
||||
macro_rules! ad_items {
|
||||
#[macro_export]
|
||||
macro_rules! make_element {
|
||||
(
|
||||
ty $float:ident $precision:expr,
|
||||
ty $type:ident $precision:expr,
|
||||
zero $zero:expr,
|
||||
one $one:expr,
|
||||
convert $convert:expr,
|
||||
random $random:expr
|
||||
|
||||
) => {
|
||||
impl Element for $float {}
|
||||
impl Element for $type {}
|
||||
|
||||
impl Zeros<$float> for $float {
|
||||
fn zeros(&self) -> $float {
|
||||
impl Zeros for $type {
|
||||
fn zeros(&self) -> $type {
|
||||
$zero
|
||||
}
|
||||
}
|
||||
|
||||
impl ElementConversion for $float {
|
||||
impl Ones for $type {
|
||||
fn ones(&self) -> $type {
|
||||
$one
|
||||
}
|
||||
}
|
||||
|
||||
impl ElementConversion for $type {
|
||||
fn from_elem<E: ToPrimitive>(elem: E) -> Self {
|
||||
$convert(&elem)
|
||||
}
|
||||
|
@ -92,7 +85,7 @@ macro_rules! ad_items {
|
|||
}
|
||||
}
|
||||
|
||||
impl ElementValue for $float {
|
||||
impl ElementValue for $type {
|
||||
fn inf() -> Self {
|
||||
Self::from_elem(f64::INFINITY)
|
||||
}
|
||||
|
@ -110,30 +103,25 @@ macro_rules! ad_items {
|
|||
}
|
||||
}
|
||||
|
||||
impl ElementPrecision for $float {
|
||||
impl ElementPrecision for $type {
|
||||
fn precision() -> Precision {
|
||||
$precision
|
||||
}
|
||||
}
|
||||
|
||||
impl ElementRandom<$float> for $float {
|
||||
fn random(distribution: Distribution<$float>, rng: &mut StdRng) -> $float {
|
||||
impl ElementRandom for $type {
|
||||
fn random(distribution: Distribution<Self>, rng: &mut StdRng) -> Self {
|
||||
$random(distribution, rng)
|
||||
}
|
||||
}
|
||||
|
||||
impl Ones<$float> for $float {
|
||||
fn ones(&self) -> $float {
|
||||
$one
|
||||
}
|
||||
}
|
||||
};
|
||||
};
|
||||
(
|
||||
float $float:ident $precision:expr,
|
||||
convert $convert:expr,
|
||||
random $random:expr
|
||||
) => {
|
||||
ad_items!(
|
||||
make_element!(
|
||||
ty $float $precision,
|
||||
zero 0.0,
|
||||
one 1.0,
|
||||
|
@ -146,7 +134,7 @@ macro_rules! ad_items {
|
|||
convert $convert:expr,
|
||||
random $random:expr
|
||||
) => {
|
||||
ad_items!(
|
||||
make_element!(
|
||||
ty $int $precision,
|
||||
zero 0,
|
||||
one 1,
|
||||
|
@ -156,122 +144,41 @@ macro_rules! ad_items {
|
|||
};
|
||||
}
|
||||
|
||||
ad_items!(
|
||||
make_element!(
|
||||
float f64 Precision::Double,
|
||||
convert |elem: &dyn ToPrimitive| elem.to_f64().unwrap(),
|
||||
random |distribution: Distribution<f64>, rng: &mut StdRng| distribution.sampler(rng).sample()
|
||||
);
|
||||
|
||||
ad_items!(
|
||||
ty f16 Precision::Half,
|
||||
zero <f16 as num_traits::Zero>::zero(),
|
||||
one <f16 as num_traits::One>::one(),
|
||||
convert |elem: &dyn ToPrimitive| f16::from_f32(elem.to_f32().unwrap()),
|
||||
random |distribution: Distribution<f16>, rng: &mut StdRng| {
|
||||
let distribution: Distribution<f32> = distribution.convert();
|
||||
let sample = distribution.sampler(rng).sample();
|
||||
f16::from_elem(sample)
|
||||
}
|
||||
);
|
||||
ad_items!(
|
||||
make_element!(
|
||||
float f32 Precision::Full,
|
||||
convert |elem: &dyn ToPrimitive| elem.to_f32().unwrap(),
|
||||
random |distribution: Distribution<f32>, rng: &mut StdRng| distribution.sampler(rng).sample()
|
||||
);
|
||||
|
||||
ad_items!(
|
||||
make_element!(
|
||||
int i64 Precision::Double,
|
||||
convert |elem: &dyn ToPrimitive| elem.to_i64().unwrap(),
|
||||
random |distribution: Distribution<i64>, rng: &mut StdRng| distribution.sampler(rng).sample()
|
||||
);
|
||||
ad_items!(
|
||||
make_element!(
|
||||
int i32 Precision::Full,
|
||||
convert |elem: &dyn ToPrimitive| elem.to_i32().unwrap(),
|
||||
random |distribution: Distribution<i32>, rng: &mut StdRng| distribution.sampler(rng).sample()
|
||||
);
|
||||
ad_items!(
|
||||
make_element!(
|
||||
int i16 Precision::Half,
|
||||
convert |elem: &dyn ToPrimitive| elem.to_i16().unwrap(),
|
||||
random |distribution: Distribution<i16>, rng: &mut StdRng| distribution.sampler(rng).sample()
|
||||
);
|
||||
ad_items!(
|
||||
make_element!(
|
||||
int i8 Precision::Other,
|
||||
convert |elem: &dyn ToPrimitive| elem.to_i8().unwrap(),
|
||||
random |distribution: Distribution<i8>, rng: &mut StdRng| distribution.sampler(rng).sample()
|
||||
);
|
||||
|
||||
ad_items!(
|
||||
make_element!(
|
||||
int u8 Precision::Other,
|
||||
convert |elem: &dyn ToPrimitive| elem.to_u8().unwrap(),
|
||||
random |distribution: Distribution<u8>, rng: &mut StdRng| distribution.sampler(rng).sample()
|
||||
);
|
||||
|
||||
#[cfg(feature = "tch")]
|
||||
mod tch_elem {
|
||||
use super::*;
|
||||
|
||||
impl TchElement for f64 {}
|
||||
impl TchElement for f32 {}
|
||||
impl TchElement for f16 {}
|
||||
|
||||
impl TchElement for i64 {}
|
||||
impl TchElement for i32 {}
|
||||
impl TchElement for i16 {}
|
||||
|
||||
impl TchElement for u8 {}
|
||||
}
|
||||
|
||||
#[cfg(feature = "ndarray")]
|
||||
mod ndarray_elem {
|
||||
use super::*;
|
||||
|
||||
macro_rules! impl_exp_elem {
|
||||
($elem:ident) => {
|
||||
impl ExpElement for $elem {
|
||||
fn exp_elem(self) -> Self {
|
||||
$elem::exp(self)
|
||||
}
|
||||
fn log_elem(self) -> Self {
|
||||
$elem::ln(self)
|
||||
}
|
||||
fn pow_elem(self, value: f32) -> Self {
|
||||
$elem::powf(self, value.into())
|
||||
}
|
||||
}
|
||||
};
|
||||
($elem:ident, $tmp:ident) => {
|
||||
impl ExpElement for $elem {
|
||||
fn exp_elem(self) -> Self {
|
||||
let tmp = $tmp::exp(self as $tmp);
|
||||
tmp as $elem
|
||||
}
|
||||
fn log_elem(self) -> Self {
|
||||
let tmp = $tmp::ln(self as $tmp);
|
||||
tmp as $elem
|
||||
}
|
||||
fn pow_elem(self, value: f32) -> Self {
|
||||
let tmp = $tmp::powf(self as $tmp, value as $tmp);
|
||||
tmp as $elem
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl NdArrayElement for f64 {}
|
||||
impl_exp_elem!(f64);
|
||||
|
||||
impl NdArrayElement for f32 {}
|
||||
impl_exp_elem!(f32);
|
||||
|
||||
impl NdArrayElement for i64 {}
|
||||
impl_exp_elem!(i64, f64);
|
||||
|
||||
impl NdArrayElement for i32 {}
|
||||
impl_exp_elem!(i32, f32);
|
||||
|
||||
impl NdArrayElement for i16 {}
|
||||
impl_exp_elem!(i16, f32);
|
||||
|
||||
impl NdArrayElement for u8 {}
|
||||
impl_exp_elem!(u8, f32);
|
||||
}
|
||||
|
|
|
@ -200,10 +200,10 @@ pub trait TensorOps<B: Backend> {
|
|||
fn relu<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
|
||||
}
|
||||
|
||||
pub trait Zeros<T> {
|
||||
fn zeros(&self) -> T;
|
||||
pub trait Zeros {
|
||||
fn zeros(&self) -> Self;
|
||||
}
|
||||
|
||||
pub trait Ones<T> {
|
||||
fn ones(&self) -> T;
|
||||
pub trait Ones {
|
||||
fn ones(&self) -> Self;
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue