From e6d86b081980196745e5f0b0eda8ce5334c0ff67 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 13 Jan 2024 20:24:06 +0100 Subject: [PATCH] Add the pow operator. (#1583) * Add the pow operator. * Support the pow operation in onnx. --- candle-core/src/tensor.rs | 12 +++++++++++- candle-core/tests/tensor_tests.rs | 16 ++++++++++++++-- candle-onnx/src/eval.rs | 6 ++++++ 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/candle-core/src/tensor.rs b/candle-core/src/tensor.rs index 54f9fa2b..3100c6e8 100644 --- a/candle-core/src/tensor.rs +++ b/candle-core/src/tensor.rs @@ -2578,11 +2578,21 @@ impl Tensor { } /// Returns log(sum(exp(tensor), dim)). - pub fn logsumexp(&self, sum_dims: D) -> Result { + pub fn log_sum_exp(&self, sum_dims: D) -> Result { let exp = self.exp()?; let sum = exp.sum(sum_dims)?; sum.log() } + + /// Pointwise pow operation. + pub fn pow(&self, rhs: &Tensor) -> Result { + rhs.mul(&self.log()?)?.exp() + } + + /// Broadcasting version of `pow`. + pub fn broadcast_pow(&self, rhs: &Tensor) -> Result { + rhs.broadcast_mul(&self.log()?)?.exp() + } } macro_rules! bin_trait { diff --git a/candle-core/tests/tensor_tests.rs b/candle-core/tests/tensor_tests.rs index e83fb55b..33bab1b6 100644 --- a/candle-core/tests/tensor_tests.rs +++ b/candle-core/tests/tensor_tests.rs @@ -1245,11 +1245,23 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> { } #[test] -fn logsumexp() -> Result<()> { +fn log_sum_exp() -> Result<()> { let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?; - let output = input.logsumexp(D::Minus1)?; + let output = input.log_sum_exp(D::Minus1)?; // The expectations obtained from pytorch. let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?; assert_close(&output, &expected, 0.00001)?; Ok(()) } + +#[test] +fn pow() -> Result<()> { + let lhs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?; + let rhs = (&lhs - 2.)?; + let res = lhs.pow(&rhs)?; + assert_eq!( + test_utils::to_vec2_round(&res, 4)?, + [[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]] + ); + Ok(()) +} diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 684776c2..c0ad8668 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -254,6 +254,12 @@ pub fn simple_eval( let output = input0.broadcast_div(input1)?; values.insert(node.output[0].clone(), output); } + "Pow" => { + let input0 = get(&node.input[0])?; + let input1 = get(&node.input[1])?; + let output = input0.broadcast_pow(input1)?; + values.insert(node.output[0].clone(), output); + } "Equal" => { let input0 = get(&node.input[0])?; let input1 = get(&node.input[1])?;