From ef01a4ed3f76854a2cb26113a1122d2613a13082 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Sat, 12 Nov 2022 12:06:53 -0500 Subject: [PATCH] refactor: pow ops (#98) --- .../src/tensor/backend/autodiff/ops/erf.rs | 2 +- .../src/tensor/backend/autodiff/ops/mod.rs | 1 - .../src/tensor/backend/autodiff/ops/pow.rs | 61 ------------------- .../src/tensor/backend/autodiff/ops/tensor.rs | 31 ++++++++++ burn-tensor/src/tensor/backend/base.rs | 1 - .../src/tensor/backend/ndarray/ops/mod.rs | 1 - .../src/tensor/backend/ndarray/ops/pow.rs | 16 ----- .../src/tensor/backend/ndarray/tensor_ops.rs | 7 +++ burn-tensor/src/tensor/backend/tch/ops/mod.rs | 1 - burn-tensor/src/tensor/backend/tch/ops/pow.rs | 21 ------- .../src/tensor/backend/tch/tensor_ops.rs | 4 ++ burn-tensor/src/tensor/base.rs | 2 +- burn-tensor/src/tensor/ops/base.rs | 5 +- burn-tensor/tests/tensor/grad/mod.rs | 1 + burn-tensor/tests/tensor/grad/pow.rs | 25 ++++++++ 15 files changed, 71 insertions(+), 108 deletions(-) delete mode 100644 burn-tensor/src/tensor/backend/autodiff/ops/pow.rs delete mode 100644 burn-tensor/src/tensor/backend/ndarray/ops/pow.rs delete mode 100644 burn-tensor/src/tensor/backend/tch/ops/pow.rs create mode 100644 burn-tensor/tests/tensor/grad/pow.rs diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/erf.rs b/burn-tensor/src/tensor/backend/autodiff/ops/erf.rs index 28dcc3d62..767780c00 100644 --- a/burn-tensor/src/tensor/backend/autodiff/ops/erf.rs +++ b/burn-tensor/src/tensor/backend/autodiff/ops/erf.rs @@ -12,7 +12,7 @@ register_ops!( name ADTensorErfOps, partial |state: &UnaryOpsNodeState, B::TensorPrimitive>|{ let value = state.input.value(); - let exponent = B::neg(&value.powf(2.0.to_elem())); + let exponent = B::neg(&B::powf(&value, 2.0)); let numerator = B::mul_scalar(&B::exp(&exponent), &2.0.to_elem()); let denominator = std::f64::consts::PI.sqrt().to_elem(); let value = B::div_scalar(&numerator, &denominator); diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs b/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs index 3990997ed..73f147821 100644 --- a/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs +++ b/burn-tensor/src/tensor/backend/autodiff/ops/mod.rs @@ -3,7 +3,6 @@ mod cat; mod creation; mod erf; mod module; -mod pow; mod tensor; mod macros; diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/pow.rs b/burn-tensor/src/tensor/backend/autodiff/ops/pow.rs deleted file mode 100644 index 2c800bf8d..000000000 --- a/burn-tensor/src/tensor/backend/autodiff/ops/pow.rs +++ /dev/null @@ -1,61 +0,0 @@ -use crate::tensor::backend::Backend; -use crate::ElementConversion; -use crate::{ - execute_ops, - graph::ops::{UnaryOps, UnaryOpsNodeState}, - register_ops, - tensor::{backend::autodiff::ADTensor, ops::*}, -}; - -register_ops!( - ops UnaryOps, - name ADTensorPowOps state f32, - partial | - value: &f32, - state: &UnaryOpsNodeState, B::TensorPrimitive> - | { - let value = B::mul_scalar(&state.input - .value() - .powf(value - 1.0) - , &value.clone().to_elem()); - B::mul(&state.output.grad(), &value) - }, -); - -impl TensorOpsPow for ADTensor { - fn powf(&self, value: f32) -> Self { - execute_ops!( - input self.node.clone(), - out TensorOpsPow::powf(&self.tensor(), value), - ops ADTensorPowOps::::new(value), - ) - } -} - -#[cfg(test)] -mod tests { - use crate::tensor::{backend::autodiff::helper::TestADTensor, Data}; - - #[test] - fn should_diff_powf() { - let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); - let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); - - let tensor_1 = TestADTensor::from_data(data_1); - let tensor_2 = TestADTensor::from_data(data_2); - - let tensor_3 = tensor_1.matmul(&tensor_2.powf(0.4)); - let tensor_4 = tensor_3.matmul(&tensor_2); - let grads = tensor_4.backward(); - - let grad_1 = tensor_1.grad(&grads).unwrap(); - let grad_2 = tensor_2.grad(&grads).unwrap(); - - grad_1 - .to_data() - .assert_approx_eq(&Data::from([[68.0, 79.0328], [68.0, 79.0328]]), 3); - grad_2 - .to_data() - .assert_approx_eq(&Data::from([[23.5081, 25.2779], [26.0502, 28.6383]]), 3); - } -} diff --git a/burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs b/burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs index 4befe64f9..4104e0c88 100644 --- a/burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs +++ b/burn-tensor/src/tensor/backend/autodiff/ops/tensor.rs @@ -983,4 +983,35 @@ impl TensorOps> for ADBackendDecorator { unary_ops_wrapper(tensor.node.clone(), output, ops) } + + fn powf( + tensor: & as Backend>::TensorPrimitive, + value: f32, + ) -> as Backend>::TensorPrimitive { + #[derive(new, Debug)] + struct Backward { + value: f32, + _b: B, + } + + impl UnaryOps, B::TensorPrimitive> + for Backward + { + fn partial( + &self, + state: &UnaryOpsNodeState, B::TensorPrimitive>, + ) -> B::TensorPrimitive { + let value = B::mul_scalar( + &B::powf(&state.input.value(), self.value - 1.0), + &self.value.clone().to_elem(), + ); + B::mul(&state.output.grad(), &value) + } + } + + let output = B::powf(tensor.tensor_ref(), value); + let ops = Backward::::new(value, B::default()); + + unary_ops_wrapper(tensor.node.clone(), output, ops) + } } diff --git a/burn-tensor/src/tensor/backend/base.rs b/burn-tensor/src/tensor/backend/base.rs index 780fff2ab..4dd3b11ff 100644 --- a/burn-tensor/src/tensor/backend/base.rs +++ b/burn-tensor/src/tensor/backend/base.rs @@ -25,7 +25,6 @@ pub trait Backend: + Ones> + TensorOpsCat + TensorOpsErf - + TensorOpsPow + ReLU + Clone + Send diff --git a/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs b/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs index c6b0581ad..1020a84f8 100644 --- a/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs +++ b/burn-tensor/src/tensor/backend/ndarray/ops/mod.rs @@ -1,4 +1,3 @@ mod cat; mod creation; mod erf; -mod pow; diff --git a/burn-tensor/src/tensor/backend/ndarray/ops/pow.rs b/burn-tensor/src/tensor/backend/ndarray/ops/pow.rs deleted file mode 100644 index c3a179ec6..000000000 --- a/burn-tensor/src/tensor/backend/ndarray/ops/pow.rs +++ /dev/null @@ -1,16 +0,0 @@ -use crate::{ - tensor::{backend::ndarray::NdArrayTensor, ops::*}, - NdArrayElement, -}; - -impl TensorOpsPow for NdArrayTensor -where - E: NdArrayElement, -{ - fn powf(&self, value: f32) -> Self { - let array = self.array.mapv(|a| a.pow_elem(value)).into_shared(); - let shape = self.shape; - - Self { array, shape } - } -} diff --git a/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs b/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs index 3416cb131..2e30f9c9b 100644 --- a/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs +++ b/burn-tensor/src/tensor/backend/ndarray/tensor_ops.rs @@ -467,6 +467,13 @@ impl TensorOps> for NdArrayBackend { NdArrayTensor { array, shape } } + + fn powf(tensor: &NdArrayTensor, value: f32) -> NdArrayTensor { + let array = tensor.array.mapv(|a| a.pow_elem(value)).into_shared(); + let shape = tensor.shape; + + NdArrayTensor { array, shape } + } } fn to_slice_args( diff --git a/burn-tensor/src/tensor/backend/tch/ops/mod.rs b/burn-tensor/src/tensor/backend/tch/ops/mod.rs index c6b0581ad..1020a84f8 100644 --- a/burn-tensor/src/tensor/backend/tch/ops/mod.rs +++ b/burn-tensor/src/tensor/backend/tch/ops/mod.rs @@ -1,4 +1,3 @@ mod cat; mod creation; mod erf; -mod pow; diff --git a/burn-tensor/src/tensor/backend/tch/ops/pow.rs b/burn-tensor/src/tensor/backend/tch/ops/pow.rs deleted file mode 100644 index 05acc5aaa..000000000 --- a/burn-tensor/src/tensor/backend/tch/ops/pow.rs +++ /dev/null @@ -1,21 +0,0 @@ -use crate::{ - tensor::{backend::tch::TchTensor, ops::*}, - TchElement, -}; - -impl TensorOpsPow for TchTensor -where - E: TchElement, -{ - fn powf(&self, value: f32) -> Self { - let tensor = self.tensor.pow_tensor_scalar(value as f64); - let kind = self.kind; - let shape = self.shape; - - Self { - tensor, - shape, - kind, - } - } -} diff --git a/burn-tensor/src/tensor/backend/tch/tensor_ops.rs b/burn-tensor/src/tensor/backend/tch/tensor_ops.rs index ca363d43c..ac6becc9c 100644 --- a/burn-tensor/src/tensor/backend/tch/tensor_ops.rs +++ b/burn-tensor/src/tensor/backend/tch/tensor_ops.rs @@ -374,6 +374,10 @@ impl TensorOps> for TchBackend { fn log(tensor: &TchTensor) -> TchTensor { to_tensor(tensor.tensor.log()) } + + fn powf(tensor: &TchTensor, value: f32) -> TchTensor { + to_tensor(tensor.tensor.pow_tensor_scalar(value as f64)) + } } fn to_tensor(tensor: tch::Tensor) -> TchTensor { diff --git a/burn-tensor/src/tensor/base.rs b/burn-tensor/src/tensor/base.rs index 9b7bf1a2b..f92ea458c 100644 --- a/burn-tensor/src/tensor/base.rs +++ b/burn-tensor/src/tensor/base.rs @@ -82,7 +82,7 @@ where /// /// `y = x^a` pub fn powf(&self, value: f32) -> Self { - Self::new(self.value.powf(value)) + Self::new(B::powf(&self.value, value)) } /// Returns the shape of the current tensor. diff --git a/burn-tensor/src/tensor/ops/base.rs b/burn-tensor/src/tensor/ops/base.rs index 35b32dc4f..835fcfe59 100644 --- a/burn-tensor/src/tensor/ops/base.rs +++ b/burn-tensor/src/tensor/ops/base.rs @@ -194,16 +194,13 @@ pub trait TensorOps { ) -> ::TensorPrimitive; fn exp(tensor: &B::TensorPrimitive) -> B::TensorPrimitive; fn log(tensor: &B::TensorPrimitive) -> B::TensorPrimitive; + fn powf(tensor: &B::TensorPrimitive, value: f32) -> B::TensorPrimitive; } pub trait TensorOpsCat { fn cat(tensors: Vec<&Self>, dim: usize) -> Self; } -pub trait TensorOpsPow { - fn powf(&self, value: f32) -> Self; -} - pub trait TensorOpsErf { fn erf(&self) -> Self; } diff --git a/burn-tensor/tests/tensor/grad/mod.rs b/burn-tensor/tests/tensor/grad/mod.rs index 55abdc779..47c127623 100644 --- a/burn-tensor/tests/tensor/grad/mod.rs +++ b/burn-tensor/tests/tensor/grad/mod.rs @@ -9,6 +9,7 @@ mod mask; mod matmul; mod mul; mod neg; +mod pow; mod reshape; mod softmax; mod sub; diff --git a/burn-tensor/tests/tensor/grad/pow.rs b/burn-tensor/tests/tensor/grad/pow.rs new file mode 100644 index 000000000..e2a0d3662 --- /dev/null +++ b/burn-tensor/tests/tensor/grad/pow.rs @@ -0,0 +1,25 @@ +use crate::tensor::TestADTensor; +use burn_tensor::Data; + +#[test] +fn should_diff_powf() { + let data_1 = Data::::from([[0.0, 1.0], [3.0, 4.0]]); + let data_2 = Data::::from([[6.0, 7.0], [9.0, 10.0]]); + + let tensor_1 = TestADTensor::from_data(data_1); + let tensor_2 = TestADTensor::from_data(data_2); + + let tensor_3 = tensor_1.matmul(&tensor_2.powf(0.4)); + let tensor_4 = tensor_3.matmul(&tensor_2); + let grads = tensor_4.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let grad_2 = tensor_2.grad(&grads).unwrap(); + + grad_1 + .to_data() + .assert_approx_eq(&Data::from([[68.0, 79.0328], [68.0, 79.0328]]), 3); + grad_2 + .to_data() + .assert_approx_eq(&Data::from([[23.5081, 25.2779], [26.0502, 28.6383]]), 3); +}