mirror of https://github.com/tracel-ai/burn.git
refactor: pow ops (#98)
This commit is contained in:
parent 8c050c2904
commit ef01a4ed3f
@@ -12,7 +12,7 @@ register_ops!(
     name ADTensorErfOps,
     partial |state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>|{
         let value = state.input.value();
-        let exponent = B::neg(&value.powf(2.0.to_elem()));
+        let exponent = B::neg(&B::powf(&value, 2.0));
         let numerator = B::mul_scalar(&B::exp(&exponent), &2.0.to_elem());
         let denominator = std::f64::consts::PI.sqrt().to_elem();
         let value = B::div_scalar(&numerator, &denominator);
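For context (not part of the diff): the `partial` closure above is the derivative of the error function; the refactored line only changes how the square is computed (backend function instead of tensor method). The quantity assembled via `B::powf`, `B::exp`, `B::mul_scalar`, and `B::div_scalar` is

    \frac{d}{dx}\,\operatorname{erf}(x) = \frac{2}{\sqrt{\pi}}\, e^{-x^{2}}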
@@ -3,7 +3,6 @@ mod cat;
 mod creation;
 mod erf;
 mod module;
-mod pow;
 mod tensor;
 
 mod macros;
@@ -1,61 +0,0 @@
-use crate::tensor::backend::Backend;
-use crate::ElementConversion;
-use crate::{
-    execute_ops,
-    graph::ops::{UnaryOps, UnaryOpsNodeState},
-    register_ops,
-    tensor::{backend::autodiff::ADTensor, ops::*},
-};
-
-register_ops!(
-    ops UnaryOps,
-    name ADTensorPowOps state f32,
-    partial |
-        value: &f32,
-        state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
-    | {
-        let value = B::mul_scalar(&state.input
-            .value()
-            .powf(value - 1.0)
-        , &value.clone().to_elem());
-        B::mul(&state.output.grad(), &value)
-    },
-);
-
-impl<B: Backend, const D: usize> TensorOpsPow<B::Elem, D> for ADTensor<D, B> {
-    fn powf(&self, value: f32) -> Self {
-        execute_ops!(
-            input self.node.clone(),
-            out TensorOpsPow::powf(&self.tensor(), value),
-            ops ADTensorPowOps::<B, D>::new(value),
-        )
-    }
-}
-
-#[cfg(test)]
-mod tests {
-    use crate::tensor::{backend::autodiff::helper::TestADTensor, Data};
-
-    #[test]
-    fn should_diff_powf() {
-        let data_1 = Data::<f64, 2>::from([[0.0, 1.0], [3.0, 4.0]]);
-        let data_2 = Data::<f64, 2>::from([[6.0, 7.0], [9.0, 10.0]]);
-
-        let tensor_1 = TestADTensor::from_data(data_1);
-        let tensor_2 = TestADTensor::from_data(data_2);
-
-        let tensor_3 = tensor_1.matmul(&tensor_2.powf(0.4));
-        let tensor_4 = tensor_3.matmul(&tensor_2);
-        let grads = tensor_4.backward();
-
-        let grad_1 = tensor_1.grad(&grads).unwrap();
-        let grad_2 = tensor_2.grad(&grads).unwrap();
-
-        grad_1
-            .to_data()
-            .assert_approx_eq(&Data::from([[68.0, 79.0328], [68.0, 79.0328]]), 3);
-        grad_2
-            .to_data()
-            .assert_approx_eq(&Data::from([[23.5081, 25.2779], [26.0502, 28.6383]]), 3);
-    }
-}
@@ -983,4 +983,35 @@ impl<B: Backend> TensorOps<ADBackendDecorator<B>> for ADBackendDecorator<B> {
 
         unary_ops_wrapper(tensor.node.clone(), output, ops)
     }
+
+    fn powf<const D: usize>(
+        tensor: &<ADBackendDecorator<B> as Backend>::TensorPrimitive<D>,
+        value: f32,
+    ) -> <ADBackendDecorator<B> as Backend>::TensorPrimitive<D> {
+        #[derive(new, Debug)]
+        struct Backward<B: Backend, const D: usize> {
+            value: f32,
+            _b: B,
+        }
+
+        impl<B: Backend, const D: usize> UnaryOps<B::TensorPrimitive<D>, B::TensorPrimitive<D>>
+            for Backward<B, D>
+        {
+            fn partial(
+                &self,
+                state: &UnaryOpsNodeState<B::TensorPrimitive<D>, B::TensorPrimitive<D>>,
+            ) -> B::TensorPrimitive<D> {
+                let value = B::mul_scalar(
+                    &B::powf(&state.input.value(), self.value - 1.0),
+                    &self.value.clone().to_elem(),
+                );
+                B::mul(&state.output.grad(), &value)
+            }
+        }
+
+        let output = B::powf(tensor.tensor_ref(), value);
+        let ops = Backward::<B, D>::new(value, B::default());
+
+        unary_ops_wrapper(tensor.node.clone(), output, ops)
+    }
 }
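For context (not part of the diff): `Backward::partial` implements the power rule, chaining the upstream gradient with the local derivative of y = x^a:

    \frac{\partial L}{\partial x} = \frac{\partial L}{\partial y} \cdot a\, x^{a-1}

Here `B::powf(&state.input.value(), self.value - 1.0)` is x^(a-1), the `mul_scalar` scales it by a, and the final `B::mul` multiplies in `state.output.grad()`.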
@@ -25,7 +25,6 @@ pub trait Backend:
     + Ones<Self::TensorPrimitive<D>>
     + TensorOpsCat<Self::Elem, D>
     + TensorOpsErf<Self::Elem, D>
-    + TensorOpsPow<Self::Elem, D>
     + ReLU<Self::Elem, D>
     + Clone
     + Send
@@ -1,4 +1,3 @@
 mod cat;
 mod creation;
 mod erf;
-mod pow;
@@ -1,16 +0,0 @@
-use crate::{
-    tensor::{backend::ndarray::NdArrayTensor, ops::*},
-    NdArrayElement,
-};
-
-impl<E, const D: usize> TensorOpsPow<E, D> for NdArrayTensor<E, D>
-where
-    E: NdArrayElement,
-{
-    fn powf(&self, value: f32) -> Self {
-        let array = self.array.mapv(|a| a.pow_elem(value)).into_shared();
-        let shape = self.shape;
-
-        Self { array, shape }
-    }
-}
@@ -467,6 +467,13 @@ impl<E: NdArrayElement> TensorOps<NdArrayBackend<E>> for NdArrayBackend<E> {
 
         NdArrayTensor { array, shape }
     }
+
+    fn powf<const D: usize>(tensor: &NdArrayTensor<E, D>, value: f32) -> NdArrayTensor<E, D> {
+        let array = tensor.array.mapv(|a| a.pow_elem(value)).into_shared();
+        let shape = tensor.shape;
+
+        NdArrayTensor { array, shape }
+    }
 }
 
 fn to_slice_args<const D1: usize, const D2: usize>(
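Aside (illustration only, not from the diff): the ndarray implementation relies on `mapv` to apply the element-level pow across the whole array. A minimal standalone sketch of the same pattern, substituting std's `f32::powf` for burn's `pow_elem` element helper (requires the `ndarray` crate):

use ndarray::array;

fn main() {
    let a = array![[1.0_f32, 2.0], [3.0, 4.0]];
    // `mapv` returns a new owned array of the same shape; each element is
    // raised to the scalar exponent, mirroring `tensor.array.mapv(|a| a.pow_elem(value))`.
    let b = a.mapv(|x| x.powf(0.4));
    println!("{b:?}");
}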
@@ -1,4 +1,3 @@
 mod cat;
 mod creation;
 mod erf;
-mod pow;
@@ -1,21 +0,0 @@
-use crate::{
-    tensor::{backend::tch::TchTensor, ops::*},
-    TchElement,
-};
-
-impl<E, const D: usize> TensorOpsPow<E, D> for TchTensor<E, D>
-where
-    E: TchElement,
-{
-    fn powf(&self, value: f32) -> Self {
-        let tensor = self.tensor.pow_tensor_scalar(value as f64);
-        let kind = self.kind;
-        let shape = self.shape;
-
-        Self {
-            tensor,
-            shape,
-            kind,
-        }
-    }
-}
@@ -374,6 +374,10 @@ impl<E: TchElement> TensorOps<TchBackend<E>> for TchBackend<E> {
     fn log<const D: usize>(tensor: &TchTensor<E, D>) -> TchTensor<E, D> {
         to_tensor(tensor.tensor.log())
     }
+
+    fn powf<const D: usize>(tensor: &TchTensor<E, D>, value: f32) -> TchTensor<E, D> {
+        to_tensor(tensor.tensor.pow_tensor_scalar(value as f64))
+    }
 }
 
 fn to_tensor<const D: usize, E: TchElement>(tensor: tch::Tensor) -> TchTensor<E, D> {
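Aside (illustration only, not from the diff): `pow_tensor_scalar` is the tch-rs binding for scalar-exponent `torch.pow`, which is why the exponent is widened to `f64` before the call. A minimal sketch assuming a recent `tch` crate (older releases spell the constructor `of_slice`):

use tch::Tensor;

fn main() {
    let t = Tensor::from_slice(&[1.0_f64, 2.0, 3.0]);
    // Elementwise t^2; equivalent to torch.pow(t, 2.0) in Python.
    let p = t.pow_tensor_scalar(2.0);
    p.print();
}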
@@ -82,7 +82,7 @@ where
     ///
     /// `y = x^a`
     pub fn powf(&self, value: f32) -> Self {
-        Self::new(self.value.powf(value))
+        Self::new(B::powf(&self.value, value))
     }
 
     /// Returns the shape of the current tensor.
@@ -194,16 +194,13 @@ pub trait TensorOps<B: Backend> {
     ) -> <B::IntegerBackend as Backend>::TensorPrimitive<D>;
     fn exp<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
     fn log<const D: usize>(tensor: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D>;
+    fn powf<const D: usize>(tensor: &B::TensorPrimitive<D>, value: f32) -> B::TensorPrimitive<D>;
 }
 
 pub trait TensorOpsCat<E, const D: usize> {
     fn cat(tensors: Vec<&Self>, dim: usize) -> Self;
 }
 
-pub trait TensorOpsPow<E, const D: usize> {
-    fn powf(&self, value: f32) -> Self;
-}
-
 pub trait TensorOpsErf<E, const D: usize> {
     fn erf(&self) -> Self;
 }
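Aside (illustration only): the net effect of this hunk is that `powf` moves from its own per-tensor trait (`TensorOpsPow`, deleted above) into the backend-level `TensorOps` trait, alongside `exp` and `log`. A hedged sketch of what that changes at a generic call site, assuming burn's `Backend` trait is in scope and, as elsewhere in this diff, exposes the `TensorOps` functions:

// Old style (deleted): the tensor type itself implemented TensorOpsPow,
// so generic code called x.powf(2.0).
// New style: dispatch through the backend, uniform with B::exp / B::log.
fn square<B: Backend, const D: usize>(x: &B::TensorPrimitive<D>) -> B::TensorPrimitive<D> {
    B::powf(x, 2.0)
}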
@@ -9,6 +9,7 @@ mod mask;
 mod matmul;
 mod mul;
 mod neg;
+mod pow;
 mod reshape;
 mod softmax;
 mod sub;
@@ -0,0 +1,25 @@
+use crate::tensor::TestADTensor;
+use burn_tensor::Data;
+
+#[test]
+fn should_diff_powf() {
+    let data_1 = Data::<f32, 2>::from([[0.0, 1.0], [3.0, 4.0]]);
+    let data_2 = Data::<f32, 2>::from([[6.0, 7.0], [9.0, 10.0]]);
+
+    let tensor_1 = TestADTensor::from_data(data_1);
+    let tensor_2 = TestADTensor::from_data(data_2);
+
+    let tensor_3 = tensor_1.matmul(&tensor_2.powf(0.4));
+    let tensor_4 = tensor_3.matmul(&tensor_2);
+    let grads = tensor_4.backward();
+
+    let grad_1 = tensor_1.grad(&grads).unwrap();
+    let grad_2 = tensor_2.grad(&grads).unwrap();
+
+    grad_1
+        .to_data()
+        .assert_approx_eq(&Data::from([[68.0, 79.0328], [68.0, 79.0328]]), 3);
+    grad_2
+        .to_data()
+        .assert_approx_eq(&Data::from([[23.5081, 25.2779], [26.0502, 28.6383]]), 3);
+}
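Aside (illustration only, not from the diff): the `grad_1` expectations can be verified by hand. With tensor_4 = (t1 · t2^0.4) · t2 and unit upstream gradients seeded by `backward()` (an assumption about the test harness, consistent with the numbers), the gradient w.r.t. t1 is ones · mᵀ where m = t2^0.4 · t2 (elementwise pow, then matmul). Entry (i, j) is therefore the j-th row sum of m, independent of i, which is why both rows of the expected `grad_1` are identical. A plain-Rust check with no burn dependency:

fn main() {
    let t2 = [[6.0_f64, 7.0], [9.0, 10.0]];

    // m = (t2 elementwise raised to 0.4) matrix-multiplied by t2
    let mut m = [[0.0_f64; 2]; 2];
    for i in 0..2 {
        for j in 0..2 {
            for k in 0..2 {
                m[i][j] += t2[i][k].powf(0.4) * t2[k][j];
            }
        }
    }

    // grad_1[i][j] = sum over k of m[j][k]; prints ≈ 68.0000 and 79.0328
    for j in 0..2 {
        let row_sum: f64 = m[j].iter().sum();
        println!("grad_1[*][{j}] = {row_sum:.4}");
    }
}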