Add the pow operator. (#1583)

* Add the pow operator.

* Support the pow operation in onnx.
This commit is contained in:
Laurent Mazare 2024-01-13 20:24:06 +01:00 committed by GitHub
parent 88618255cb
commit e6d86b0819
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 31 additions and 3 deletions

View File

@ -2578,11 +2578,21 @@ impl Tensor {
} }
/// Returns log(sum(exp(tensor), dim)). /// Returns log(sum(exp(tensor), dim)).
pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> { pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
let exp = self.exp()?; let exp = self.exp()?;
let sum = exp.sum(sum_dims)?; let sum = exp.sum(sum_dims)?;
sum.log() sum.log()
} }
/// Pointwise pow operation.
pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
rhs.mul(&self.log()?)?.exp()
}
/// Broadcasting version of `pow`.
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
rhs.broadcast_mul(&self.log()?)?.exp()
}
} }
macro_rules! bin_trait { macro_rules! bin_trait {

View File

@ -1245,11 +1245,23 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
} }
#[test] #[test]
fn logsumexp() -> Result<()> { fn log_sum_exp() -> Result<()> {
let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?; let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
let output = input.logsumexp(D::Minus1)?; let output = input.log_sum_exp(D::Minus1)?;
// The expectations obtained from pytorch. // The expectations obtained from pytorch.
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?; let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
assert_close(&output, &expected, 0.00001)?; assert_close(&output, &expected, 0.00001)?;
Ok(()) Ok(())
} }
#[test]
fn pow() -> Result<()> {
let lhs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
let rhs = (&lhs - 2.)?;
let res = lhs.pow(&rhs)?;
assert_eq!(
test_utils::to_vec2_round(&res, 4)?,
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
);
Ok(())
}

View File

@ -254,6 +254,12 @@ pub fn simple_eval(
let output = input0.broadcast_div(input1)?; let output = input0.broadcast_div(input1)?;
values.insert(node.output[0].clone(), output); values.insert(node.output[0].clone(), output);
} }
"Pow" => {
let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?;
let output = input0.broadcast_pow(input1)?;
values.insert(node.output[0].clone(), output);
}
"Equal" => { "Equal" => {
let input0 = get(&node.input[0])?; let input0 = get(&node.input[0])?;
let input1 = get(&node.input[1])?; let input1 = get(&node.input[1])?;