diff --git a/burn-tch/src/ops/base.rs b/burn-tch/src/ops/base.rs index a61e0bcd5..67f8312c4 100644 --- a/burn-tch/src/ops/base.rs +++ b/burn-tch/src/ops/base.rs @@ -447,8 +447,8 @@ impl TchOps { TchTensor::binary_ops_tensor( tensor, exponent, + |lhs, rhs| lhs.f_pow_tensor_(rhs).unwrap(), |lhs, rhs| lhs.f_pow(rhs).unwrap(), - |lhs, rhs| rhs.f_pow(lhs).unwrap(), |lhs, rhs| lhs.f_pow(rhs).unwrap(), ) } diff --git a/burn-tensor/src/tests/ops/powf.rs b/burn-tensor/src/tests/ops/powf.rs index 4868aec74..90c2c1b9f 100644 --- a/burn-tensor/src/tests/ops/powf.rs +++ b/burn-tensor/src/tests/ops/powf.rs @@ -7,10 +7,10 @@ mod tests { fn should_support_powf_ops() { let data = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let tensor = Tensor::::from_data(data, &Default::default()); - let pow = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); + let pow = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 2.0]]); let tensor_pow = Tensor::::from_data(pow, &Default::default()); let data_actual = tensor.powf(tensor_pow).into_data(); - let data_expected = Data::from([[1.0, 1.0, 4.0], [27.0, 256.0, 3125.0]]); + let data_expected = Data::from([[1.0, 1.0, 4.0], [27.0, 256.0, 25.0]]); data_expected.assert_approx_eq(&data_actual, 3); }