fix powf bugs

This commit is contained in:
louisfd 2024-01-31 11:06:25 -05:00
parent 07c69bd29a
commit 803ba5d92f
2 changed files with 3 additions and 3 deletions

View File

@ -447,8 +447,8 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
TchTensor::binary_ops_tensor( TchTensor::binary_ops_tensor(
tensor, tensor,
exponent, exponent,
|lhs, rhs| lhs.f_pow_tensor_(rhs).unwrap(),
|lhs, rhs| lhs.f_pow(rhs).unwrap(), |lhs, rhs| lhs.f_pow(rhs).unwrap(),
|lhs, rhs| rhs.f_pow(lhs).unwrap(),
|lhs, rhs| lhs.f_pow(rhs).unwrap(), |lhs, rhs| lhs.f_pow(rhs).unwrap(),
) )
} }

View File

@ -7,10 +7,10 @@ mod tests {
fn should_support_powf_ops() { fn should_support_powf_ops() {
let data = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let data = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default()); let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
let pow = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]); let pow = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 2.0]]);
let tensor_pow = Tensor::<TestBackend, 2>::from_data(pow, &Default::default()); let tensor_pow = Tensor::<TestBackend, 2>::from_data(pow, &Default::default());
let data_actual = tensor.powf(tensor_pow).into_data(); let data_actual = tensor.powf(tensor_pow).into_data();
let data_expected = Data::from([[1.0, 1.0, 4.0], [27.0, 256.0, 3125.0]]); let data_expected = Data::from([[1.0, 1.0, 4.0], [27.0, 256.0, 25.0]]);
data_expected.assert_approx_eq(&data_actual, 3); data_expected.assert_approx_eq(&data_actual, 3);
} }