Add the pow operator. (#1583)
* Add the pow operator. * Support the pow operation in onnx.
This commit is contained in:
parent
88618255cb
commit
e6d86b0819
|
@ -2578,11 +2578,21 @@ impl Tensor {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns log(sum(exp(tensor), dim)).
|
/// Returns log(sum(exp(tensor), dim)).
|
||||||
pub fn logsumexp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
|
pub fn log_sum_exp<D: Dims>(&self, sum_dims: D) -> Result<Self> {
|
||||||
let exp = self.exp()?;
|
let exp = self.exp()?;
|
||||||
let sum = exp.sum(sum_dims)?;
|
let sum = exp.sum(sum_dims)?;
|
||||||
sum.log()
|
sum.log()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Pointwise pow operation.
|
||||||
|
pub fn pow(&self, rhs: &Tensor) -> Result<Self> {
|
||||||
|
rhs.mul(&self.log()?)?.exp()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Broadcasting version of `pow`.
|
||||||
|
pub fn broadcast_pow(&self, rhs: &Tensor) -> Result<Self> {
|
||||||
|
rhs.broadcast_mul(&self.log()?)?.exp()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
macro_rules! bin_trait {
|
macro_rules! bin_trait {
|
||||||
|
|
|
@ -1245,11 +1245,23 @@ fn assert_close(a: &Tensor, b: &Tensor, epsilon: f64) -> Result<()> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn logsumexp() -> Result<()> {
|
fn log_sum_exp() -> Result<()> {
|
||||||
let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
let input = Tensor::new(&[[1f64, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
||||||
let output = input.logsumexp(D::Minus1)?;
|
let output = input.log_sum_exp(D::Minus1)?;
|
||||||
// The expectations obtained from pytorch.
|
// The expectations obtained from pytorch.
|
||||||
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
|
let expected = Tensor::new(&[3.4076, 6.4076], &Device::Cpu)?;
|
||||||
assert_close(&output, &expected, 0.00001)?;
|
assert_close(&output, &expected, 0.00001)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn pow() -> Result<()> {
|
||||||
|
let lhs = Tensor::new(&[[1f32, 2., 3.], [4., 5., 6.]], &Device::Cpu)?;
|
||||||
|
let rhs = (&lhs - 2.)?;
|
||||||
|
let res = lhs.pow(&rhs)?;
|
||||||
|
assert_eq!(
|
||||||
|
test_utils::to_vec2_round(&res, 4)?,
|
||||||
|
[[1.0, 1.0, 3.0], [16.0, 125.0, 1296.0001]]
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
|
@ -254,6 +254,12 @@ pub fn simple_eval(
|
||||||
let output = input0.broadcast_div(input1)?;
|
let output = input0.broadcast_div(input1)?;
|
||||||
values.insert(node.output[0].clone(), output);
|
values.insert(node.output[0].clone(), output);
|
||||||
}
|
}
|
||||||
|
"Pow" => {
|
||||||
|
let input0 = get(&node.input[0])?;
|
||||||
|
let input1 = get(&node.input[1])?;
|
||||||
|
let output = input0.broadcast_pow(input1)?;
|
||||||
|
values.insert(node.output[0].clone(), output);
|
||||||
|
}
|
||||||
"Equal" => {
|
"Equal" => {
|
||||||
let input0 = get(&node.input[0])?;
|
let input0 = get(&node.input[0])?;
|
||||||
let input1 = get(&node.input[1])?;
|
let input1 = get(&node.input[1])?;
|
||||||
|
|
Loading…
Reference in New Issue