Refactor/zeros ones elems (#102)

This commit is contained in:
Nathaniel Simard 2022-11-14 18:41:26 -05:00 committed by GitHub
parent 1a45368878
commit 23677b8e89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
26 changed files with 184 additions and 181 deletions

View File

@ -15,7 +15,7 @@ impl Forward2BackwardGraphConverter {
state: HashMap::new(),
}
}
pub fn from<T: Clone + 'static + Zeros<T>>(
pub fn from<T: Clone + 'static + Zeros>(
&mut self,
node: &ForwardNodeRef<T>,
) -> BackwardNodeRef<T> {

View File

@ -21,7 +21,7 @@ impl Gradients {
pub fn register<T>(&mut self, node: &BackwardNode<T>)
where
T: Zeros<T> + Clone + Add<Output = T>,
T: Zeros + Clone + Add<Output = T>,
T: std::fmt::Debug + 'static + Send + Sync,
{
let grad = node.state.grad();
@ -37,7 +37,7 @@ impl Gradients {
pub fn from<T>(node: &BackwardNode<T>) -> Self
where
T: Zeros<T> + Clone + Add<Output = T>,
T: Zeros + Clone + Add<Output = T>,
T: std::fmt::Debug + 'static + Send + Sync,
{
let mut grads = Self::empty();

View File

@ -19,7 +19,7 @@ pub struct BackwardNode<Out> {
}
pub type BackwardNodeRef<Out> = Arc<BackwardNode<Out>>;
impl<Out: Clone + Zeros<Out>> BackwardNode<Out> {
impl<Out: Clone + Zeros> BackwardNode<Out> {
pub fn from_node(
node: &ForwardNodeRef<Out>,
converter: &mut Forward2BackwardGraphConverter,
@ -35,7 +35,7 @@ impl<Out: Clone + Zeros<Out>> BackwardNode<Out> {
impl<Out> BackwardNode<Out>
where
Out: Zeros<Out> + Ones<Out> + Clone + Add<Output = Out>,
Out: Zeros + Ones + Clone + Add<Output = Out>,
Out: std::fmt::Debug + 'static + Send + Sync,
{
pub fn backward(&mut self) -> Gradients {
@ -74,7 +74,7 @@ where
impl<T> RecordedOpsParent for BackwardNode<T>
where
T: Zeros<T> + Clone + Add<Output = T>,
T: Zeros + Clone + Add<Output = T>,
T: std::fmt::Debug + 'static + Send + Sync,
{
fn backward_step(&self) {

View File

@ -23,7 +23,7 @@ pub struct BackwardNodeState<Out> {
pub grad: RefCell<Out>,
}
impl<Out: Zeros<Out>> BackwardNodeState<Out> {
impl<Out: Zeros> BackwardNodeState<Out> {
pub fn new(value: Out) -> Self {
let grad = value.zeros();
let grad = RefCell::new(grad);
@ -42,7 +42,7 @@ where
impl<Out> BackwardNodeState<Out>
where
Out: Zeros<Out> + Clone + Add<Output = Out>,
Out: Zeros + Clone + Add<Output = Out>,
Out: std::fmt::Debug,
{
pub fn grad(&self) -> Out {

View File

@ -32,9 +32,9 @@ pub struct BackwardBinaryRecordedOps<Lhs, Rhs, Ops> {
impl<Lhs, Rhs, Out, Ops> ForwardRecordedOps<Out> for ForwardBinaryRecordedOps<Lhs, Rhs, Ops>
where
Lhs: Clone + Zeros<Lhs> + Add<Output = Lhs> + std::fmt::Debug + 'static + Send + Sync,
Rhs: Clone + Zeros<Rhs> + Add<Output = Rhs> + std::fmt::Debug + 'static + Send + Sync,
Out: Clone + Zeros<Out> + Add<Output = Out> + std::fmt::Debug + 'static,
Lhs: Clone + Zeros + Add<Output = Lhs> + std::fmt::Debug + 'static + Send + Sync,
Rhs: Clone + Zeros + Add<Output = Rhs> + std::fmt::Debug + 'static + Send + Sync,
Out: Clone + Zeros + Add<Output = Out> + std::fmt::Debug + 'static,
Ops: BinaryOps<Lhs, Rhs, Out> + std::fmt::Debug + 'static + Send + Sync,
{
fn to_backward(
@ -51,9 +51,9 @@ where
impl<Lhs, Rhs, Out, Ops> BackwardRecordedOps<Out> for BackwardBinaryRecordedOps<Lhs, Rhs, Ops>
where
Lhs: Clone + Zeros<Lhs> + Add<Output = Lhs> + std::fmt::Debug + 'static + Send + Sync,
Rhs: Clone + Zeros<Rhs> + Add<Output = Rhs> + std::fmt::Debug + 'static + Send + Sync,
Out: Clone + Zeros<Out> + Add<Output = Out> + std::fmt::Debug + 'static,
Lhs: Clone + Zeros + Add<Output = Lhs> + std::fmt::Debug + 'static + Send + Sync,
Rhs: Clone + Zeros + Add<Output = Rhs> + std::fmt::Debug + 'static + Send + Sync,
Out: Clone + Zeros + Add<Output = Out> + std::fmt::Debug + 'static,
Ops: BinaryOps<Lhs, Rhs, Out> + std::fmt::Debug + 'static,
{
fn backward_step(&self, state: &BackwardNodeState<Out>) {

View File

@ -10,7 +10,7 @@ pub struct InitRecordedOps {}
impl<Out> BackwardRecordedOps<Out> for InitRecordedOps
where
Out: Clone + Zeros<Out> + Add<Output = Out> + std::fmt::Debug + 'static,
Out: Clone + Zeros + Add<Output = Out> + std::fmt::Debug + 'static,
{
fn backward_step(&self, _: &BackwardNodeState<Out>) {}
fn backward_parents(&self) -> Vec<RecordedOpsParentRef> {
@ -20,7 +20,7 @@ where
impl<Out> ForwardRecordedOps<Out> for InitRecordedOps
where
Out: Clone + Zeros<Out> + Add<Output = Out> + std::fmt::Debug + 'static,
Out: Clone + Zeros + Add<Output = Out> + std::fmt::Debug + 'static,
{
fn to_backward(
&self,

View File

@ -26,8 +26,8 @@ pub struct BackwareUnaryRecordedOps<In, Ops> {
impl<In, Out, Ops> ForwardRecordedOps<Out> for ForwardUnaryRecordedOps<In, Ops>
where
In: Clone + Zeros<In> + Add<Output = In> + std::fmt::Debug + 'static + Send + Sync,
Out: Clone + Zeros<Out> + Add<Output = Out> + std::fmt::Debug + 'static,
In: Clone + Zeros + Add<Output = In> + std::fmt::Debug + 'static + Send + Sync,
Out: Clone + Zeros + Add<Output = Out> + std::fmt::Debug + 'static,
Ops: UnaryOps<In, Out> + std::fmt::Debug + 'static,
{
fn to_backward(
@ -43,8 +43,8 @@ where
impl<In, Out, Ops> BackwardRecordedOps<Out> for BackwareUnaryRecordedOps<In, Ops>
where
In: Clone + Zeros<In> + Add<Output = In> + std::fmt::Debug + 'static + Send + Sync,
Out: Clone + Zeros<Out> + Add<Output = Out> + std::fmt::Debug + 'static,
In: Clone + Zeros + Add<Output = In> + std::fmt::Debug + 'static + Send + Sync,
Out: Clone + Zeros + Add<Output = Out> + std::fmt::Debug + 'static,
Ops: UnaryOps<In, Out> + std::fmt::Debug + 'static,
{
fn backward_step(&self, state: &BackwardNodeState<Out>) {

View File

@ -3,13 +3,13 @@ use crate::tensor::{
ops::*,
};
impl<B: Backend, const D: usize> Zeros<Self> for ADTensor<D, B> {
impl<B: Backend, const D: usize> Zeros for ADTensor<D, B> {
fn zeros(&self) -> Self {
ADTensor::from_tensor(self.tensor().zeros())
}
}
impl<B: Backend, const D: usize> Ones<Self> for ADTensor<D, B> {
impl<B: Backend, const D: usize> Ones for ADTensor<D, B> {
fn ones(&self) -> Self {
ADTensor::from_tensor(self.tensor().ones())
}

View File

@ -20,8 +20,8 @@ pub trait Backend:
type FullPrecisionBackend: Backend<Elem = Self::FullPrecisionElem, Device = Self::Device>;
type IntegerBackend: Backend<Elem = i64, Device = Self::Device>;
type TensorPrimitive<const D: usize>: std::ops::Add<Self::TensorPrimitive<D>, Output = Self::TensorPrimitive<D>>
+ Zeros<Self::TensorPrimitive<D>>
+ Ones<Self::TensorPrimitive<D>>
+ Zeros
+ Ones
+ Clone
+ Send
+ Sync

View File

@ -1,6 +1,7 @@
use super::element::NdArrayElement;
use super::NdArrayTensor;
use crate::tensor::backend::Backend;
use crate::tensor::Data;
use crate::tensor::{backend::Backend, NdArrayElement};
use crate::{Distribution, Shape};
use rand::rngs::StdRng;
use rand::SeedableRng;

View File

@ -0,0 +1,62 @@
use crate::Element;
pub(crate) trait NdArrayElement:
Element + ndarray::LinalgScalar + ndarray::ScalarOperand + ExpElement + num_traits::FromPrimitive
{
}
pub(crate) trait ExpElement {
fn exp_elem(self) -> Self;
fn log_elem(self) -> Self;
fn pow_elem(self, value: f32) -> Self;
}
macro_rules! impl_exp_elem {
($elem:ident) => {
impl ExpElement for $elem {
fn exp_elem(self) -> Self {
$elem::exp(self)
}
fn log_elem(self) -> Self {
$elem::ln(self)
}
fn pow_elem(self, value: f32) -> Self {
$elem::powf(self, value.into())
}
}
};
($elem:ident, $tmp:ident) => {
impl ExpElement for $elem {
fn exp_elem(self) -> Self {
let tmp = $tmp::exp(self as $tmp);
tmp as $elem
}
fn log_elem(self) -> Self {
let tmp = $tmp::ln(self as $tmp);
tmp as $elem
}
fn pow_elem(self, value: f32) -> Self {
let tmp = $tmp::powf(self as $tmp, value as $tmp);
tmp as $elem
}
}
};
}
impl NdArrayElement for f64 {}
impl_exp_elem!(f64);
impl NdArrayElement for f32 {}
impl_exp_elem!(f32);
impl NdArrayElement for i64 {}
impl_exp_elem!(i64, f64);
impl NdArrayElement for i32 {}
impl_exp_elem!(i32, f32);
impl NdArrayElement for i16 {}
impl_exp_elem!(i16, f32);
impl NdArrayElement for u8 {}
impl_exp_elem!(u8, f32);

View File

@ -1,4 +1,5 @@
mod backend;
mod element;
mod module_ops;
mod ops;
mod shape;

View File

@ -1,8 +1,7 @@
use super::{element::NdArrayElement, NdArrayBackend, NdArrayTensor};
use crate::{ops::*, Shape};
use std::ops::Add;
use super::{NdArrayBackend, NdArrayTensor};
use crate::{ops::*, NdArrayElement, Shape};
impl<E: NdArrayElement> ModuleOps<NdArrayBackend<E>> for NdArrayBackend<E> {
fn embedding(
weights: &NdArrayTensor<E, 2>,

View File

@ -1,8 +1,8 @@
use crate::tensor::{backend::ndarray::NdArrayTensor, ops::*, Data};
impl<P, const D: usize> Zeros<NdArrayTensor<P, D>> for NdArrayTensor<P, D>
impl<P, const D: usize> Zeros for NdArrayTensor<P, D>
where
P: Default + Clone + Zeros<P> + std::fmt::Debug,
P: Default + Clone + Zeros + std::fmt::Debug,
{
fn zeros(&self) -> NdArrayTensor<P, D> {
let data = Data::<P, D>::zeros(self.shape);
@ -10,9 +10,9 @@ where
}
}
impl<P, const D: usize> Ones<NdArrayTensor<P, D>> for NdArrayTensor<P, D>
impl<P, const D: usize> Ones for NdArrayTensor<P, D>
where
P: Default + Clone + Ones<P> + std::fmt::Debug,
P: Default + Clone + Ones + std::fmt::Debug,
{
fn ones(&self) -> NdArrayTensor<P, D> {
let data = Data::<P, D>::ones(self.shape);

View File

@ -1,8 +1,7 @@
use super::NdArrayBackend;
use super::{element::NdArrayElement, NdArrayBackend};
use crate::{
ops::TensorOps,
tensor::{Data, Shape},
NdArrayElement,
};
use ndarray::{s, ArcArray, Array, Axis, Dim, Ix2, Ix3, IxDyn};
@ -22,9 +21,9 @@ impl<E: NdArrayElement, const D: usize> std::ops::Add for NdArrayTensor<E, D> {
#[cfg(test)]
mod utils {
use crate::{backend::NdArrayBackend, ops::TensorOps, NdArrayElement};
use super::*;
use crate::{backend::NdArrayBackend, ops::TensorOps};
impl<E, const D: usize> NdArrayTensor<E, D>
where
E: Default + Clone,

View File

@ -1,8 +1,8 @@
use super::{BatchMatrix, NdArrayBackend, NdArrayTensor};
use super::{element::NdArrayElement, BatchMatrix, NdArrayBackend, NdArrayTensor};
use crate::{
backend::{Backend, NdArrayDevice},
ops::TensorOps,
to_nd_array_tensor, Data, ElementConversion, NdArrayElement, Shape,
to_nd_array_tensor, Data, ElementConversion, Shape,
};
use ndarray::{Axis, Dim, IxDyn, SliceInfoElem};
use std::{cmp::Ordering, ops::Range};

View File

@ -1,5 +1,6 @@
use super::element::TchElement;
use super::TchTensor;
use crate::tensor::{backend::Backend, TchElement};
use crate::tensor::backend::Backend;
use crate::tensor::{Data, Distribution, Shape};
#[derive(Clone, Copy, Debug)]

View File

@ -0,0 +1,32 @@
use crate::ops::{Ones, Zeros};
use crate::{
make_element, Distribution, Element, ElementConversion, ElementPrecision, ElementRandom,
ElementValue, Precision,
};
use half::f16;
use num_traits::ToPrimitive;
use rand::rngs::StdRng;
pub(crate) trait TchElement: Element + tch::kind::Element {}
impl TchElement for f64 {}
impl TchElement for f32 {}
impl TchElement for f16 {}
impl TchElement for i64 {}
impl TchElement for i32 {}
impl TchElement for i16 {}
impl TchElement for u8 {}
make_element!(
ty f16 Precision::Half,
zero <f16 as num_traits::Zero>::zero(),
one <f16 as num_traits::One>::one(),
convert |elem: &dyn ToPrimitive| f16::from_f32(elem.to_f32().unwrap()),
random |distribution: Distribution<f16>, rng: &mut StdRng| {
let distribution: Distribution<f32> = distribution.convert();
let sample = distribution.sampler(rng).sample();
f16::from_elem(sample)
}
);

View File

@ -1,4 +1,5 @@
mod backend;
mod element;
mod module_ops;
mod ops;
mod tensor;

View File

@ -1,5 +1,5 @@
use super::{TchBackend, TchTensor};
use crate::{ops::ModuleOps, Shape, TchElement};
use super::{element::TchElement, TchBackend, TchTensor};
use crate::{ops::ModuleOps, Shape};
impl<E: TchElement> ModuleOps<TchBackend<E>> for TchBackend<E> {
fn embedding(weights: &TchTensor<E, 2>, indexes: &TchTensor<i64, 2>) -> TchTensor<E, 3> {

View File

@ -1,6 +1,6 @@
use crate::tensor::{backend::tch::TchTensor, ops::*};
impl<P, const D: usize> Zeros<TchTensor<P, D>> for TchTensor<P, D>
impl<P, const D: usize> Zeros for TchTensor<P, D>
where
P: tch::kind::Element,
{
@ -17,7 +17,7 @@ where
}
}
impl<P, const D: usize> Ones<TchTensor<P, D>> for TchTensor<P, D>
impl<P, const D: usize> Ones for TchTensor<P, D>
where
P: tch::kind::Element,
{

View File

@ -1,8 +1,8 @@
use super::element::TchElement;
use crate::{
backend::{TchBackend, TchDevice},
ops::TensorOps,
tensor::{Data, Shape},
TchElement,
};
lazy_static::lazy_static! {
@ -99,7 +99,7 @@ impl<P: tch::kind::Element + Default, const D: usize> TchTensor<P, D> {
#[cfg(test)]
mod utils {
use super::*;
use crate::{backend::TchBackend, ops::TensorOps, TchElement};
use crate::{backend::TchBackend, ops::TensorOps};
impl<P: tch::kind::Element, const D: usize> TchTensor<P, D> {
pub(crate) fn into_data(self) -> Data<P, D>

View File

@ -1,5 +1,5 @@
use super::{TchBackend, TchDevice, TchKind, TchShape, TchTensor};
use crate::{backend::Backend, ops::TensorOps, Data, ElementConversion, Shape, TchElement};
use super::{element::TchElement, TchBackend, TchDevice, TchKind, TchShape, TchTensor};
use crate::{backend::Backend, ops::TensorOps, Data, ElementConversion, Shape};
use std::ops::{Add, Div, Mul, Range, Sub};
impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {

View File

@ -147,7 +147,7 @@ impl<P: Element, const D: usize> Data<P, D> {
}
impl<P: std::fmt::Debug, const D: usize> Data<P, D>
where
P: Zeros<P> + Default,
P: Zeros + Default,
{
pub fn zeros(shape: Shape<D>) -> Data<P, D> {
let elem = P::default();
@ -167,7 +167,7 @@ where
impl<P: std::fmt::Debug, const D: usize> Data<P, D>
where
P: Ones<P> + Default,
P: Ones + Default,
{
pub fn ones(shape: Shape<D>) -> Data<P, D> {
let elem = P::default();

View File

@ -1,43 +1,35 @@
use crate::{tensor::ops::*, Distribution};
use half::f16;
use num_traits::ToPrimitive;
use rand::prelude::StdRng;
pub trait Element:
Zeros<Self>
Zeros
+ ToPrimitive
+ ElementRandom<Self>
+ ElementRandom
+ ElementConversion
+ ElementPrecision
+ ElementValue
+ Ones<Self>
+ Ones
+ std::ops::Mul<Self, Output = Self>
+ std::fmt::Debug
+ Default
+ 'static
+ Send
+ Sync
+ Copy
+ std::cmp::PartialOrd<Self>
+ 'static
{
}
#[cfg(feature = "tch")]
pub(crate) trait TchElement: Element + tch::kind::Element {}
pub(crate) trait ExpElement {
fn exp_elem(self) -> Self;
fn log_elem(self) -> Self;
fn pow_elem(self, value: f32) -> Self;
}
pub trait ElementConversion {
fn from_elem<E: ToPrimitive>(elem: E) -> Self;
fn to_elem<E: Element>(&self) -> E;
}
pub trait ElementRandom<T> {
fn random(distribution: Distribution<T>, rng: &mut StdRng) -> T;
pub trait ElementRandom {
fn random(distribution: Distribution<Self>, rng: &mut StdRng) -> Self
where
Self: Sized;
}
pub trait ElementValue {
@ -60,30 +52,31 @@ pub trait ElementPrecision {
fn precision() -> Precision;
}
#[cfg(feature = "ndarray")]
pub(crate) trait NdArrayElement:
Element + ndarray::LinalgScalar + ndarray::ScalarOperand + ExpElement + num_traits::FromPrimitive
{
}
macro_rules! ad_items {
#[macro_export]
macro_rules! make_element {
(
ty $float:ident $precision:expr,
ty $type:ident $precision:expr,
zero $zero:expr,
one $one:expr,
convert $convert:expr,
random $random:expr
) => {
impl Element for $float {}
impl Element for $type {}
impl Zeros<$float> for $float {
fn zeros(&self) -> $float {
impl Zeros for $type {
fn zeros(&self) -> $type {
$zero
}
}
impl ElementConversion for $float {
impl Ones for $type {
fn ones(&self) -> $type {
$one
}
}
impl ElementConversion for $type {
fn from_elem<E: ToPrimitive>(elem: E) -> Self {
$convert(&elem)
}
@ -92,7 +85,7 @@ macro_rules! ad_items {
}
}
impl ElementValue for $float {
impl ElementValue for $type {
fn inf() -> Self {
Self::from_elem(f64::INFINITY)
}
@ -110,30 +103,25 @@ macro_rules! ad_items {
}
}
impl ElementPrecision for $float {
impl ElementPrecision for $type {
fn precision() -> Precision {
$precision
}
}
impl ElementRandom<$float> for $float {
fn random(distribution: Distribution<$float>, rng: &mut StdRng) -> $float {
impl ElementRandom for $type {
fn random(distribution: Distribution<Self>, rng: &mut StdRng) -> Self {
$random(distribution, rng)
}
}
impl Ones<$float> for $float {
fn ones(&self) -> $float {
$one
}
}
};
};
(
float $float:ident $precision:expr,
convert $convert:expr,
random $random:expr
) => {
ad_items!(
make_element!(
ty $float $precision,
zero 0.0,
one 1.0,
@ -146,7 +134,7 @@ macro_rules! ad_items {
convert $convert:expr,
random $random:expr
) => {
ad_items!(
make_element!(
ty $int $precision,
zero 0,
one 1,
@ -156,122 +144,41 @@ macro_rules! ad_items {
};
}
ad_items!(
make_element!(
float f64 Precision::Double,
convert |elem: &dyn ToPrimitive| elem.to_f64().unwrap(),
random |distribution: Distribution<f64>, rng: &mut StdRng| distribution.sampler(rng).sample()
);
ad_items!(
ty f16 Precision::Half,
zero <f16 as num_traits::Zero>::zero(),
one <f16 as num_traits::One>::one(),
convert |elem: &dyn ToPrimitive| f16::from_f32(elem.to_f32().unwrap()),
random |distribution: Distribution<f16>, rng: &mut StdRng| {
let distribution: Distribution<f32> = distribution.convert();
let sample = distribution.sampler(rng).sample();
f16::from_elem(sample)
}
);
ad_items!(
make_element!(
float f32 Precision::Full,
convert |elem: &dyn ToPrimitive| elem.to_f32().unwrap(),
random |distribution: Distribution<f32>, rng: &mut StdRng| distribution.sampler(rng).sample()
);
ad_items!(
make_element!(
int i64 Precision::Double,
convert |elem: &dyn ToPrimitive| elem.to_i64().unwrap(),
random |distribution: Distribution<i64>, rng: &mut StdRng| distribution.sampler(rng).sample()
);
ad_items!(
make_element!(
int i32 Precision::Full,
convert |elem: &dyn ToPrimitive| elem.to_i32().unwrap(),
random |distribution: Distribution<i32>, rng: &mut StdRng| distribution.sampler(rng).sample()
);
ad_items!(
make_element!(
int i16 Precision::Half,
convert |elem: &dyn ToPrimitive| elem.to_i16().unwrap(),
random |distribution: Distribution<i16>, rng: &mut StdRng| distribution.sampler(rng).sample()
);
ad_items!(
make_element!(
int i8 Precision::Other,
convert |elem: &dyn ToPrimitive| elem.to_i8().unwrap(),
random |distribution: Distribution<i8>, rng: &mut StdRng| distribution.sampler(rng).sample()
);
ad_items!(
make_element!(
int u8 Precision::Other,
convert |elem: &dyn ToPrimitive| elem.to_u8().unwrap(),
random |distribution: Distribution<u8>, rng: &mut StdRng| distribution.sampler(rng).sample()
);
#[cfg(feature = "tch")]
mod tch_elem {
use super::*;
impl TchElement for f64 {}
impl TchElement for f32 {}
impl TchElement for f16 {}
impl TchElement for i64 {}
impl TchElement for i32 {}
impl TchElement for i16 {}
impl TchElement for u8 {}
}
#[cfg(feature = "ndarray")]
mod ndarray_elem {
use super::*;
macro_rules! impl_exp_elem {
($elem:ident) => {
impl ExpElement for $elem {
fn exp_elem(self) -> Self {
$elem::exp(self)
}
fn log_elem(self) -> Self {
$elem::ln(self)
}
fn pow_elem(self, value: f32) -> Self {
$elem::powf(self, value.into())
}
}
};
($elem:ident, $tmp:ident) => {
impl ExpElement for $elem {
fn exp_elem(self) -> Self {
let tmp = $tmp::exp(self as $tmp);
tmp as $elem
}
fn log_elem(self) -> Self {
let tmp = $tmp::ln(self as $tmp);
tmp as $elem
}
fn pow_elem(self, value: f32) -> Self {
let tmp = $tmp::powf(self as $tmp, value as $tmp);
tmp as $elem
}
}
};
}
impl NdArrayElement for f64 {}
impl_exp_elem!(f64);
impl NdArrayElement for f32 {}
impl_exp_elem!(f32);
impl NdArrayElement for i64 {}
impl_exp_elem!(i64, f64);
impl NdArrayElement for i32 {}
impl_exp_elem!(i32, f32);
impl NdArrayElement for i16 {}
impl_exp_elem!(i16, f32);
impl NdArrayElement for u8 {}
impl_exp_elem!(u8, f32);
}

View File

@ -200,10 +200,10 @@ pub trait TensorOps<B: Backend> {
fn relu<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
}
pub trait Zeros<T> {
fn zeros(&self) -> T;
pub trait Zeros {
fn zeros(&self) -> Self;
}
pub trait Ones<T> {
fn ones(&self) -> T;
pub trait Ones {
fn ones(&self) -> Self;
}