mirror of https://github.com/tracel-ai/burn.git
fix powf bugs
This commit is contained in:
parent
07c69bd29a
commit
803ba5d92f
|
@ -447,8 +447,8 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
|
|||
TchTensor::binary_ops_tensor(
|
||||
tensor,
|
||||
exponent,
|
||||
|lhs, rhs| lhs.f_pow_tensor_(rhs).unwrap(),
|
||||
|lhs, rhs| lhs.f_pow(rhs).unwrap(),
|
||||
|lhs, rhs| rhs.f_pow(lhs).unwrap(),
|
||||
|lhs, rhs| lhs.f_pow(rhs).unwrap(),
|
||||
)
|
||||
}
|
||||
|
|
|
@ -7,10 +7,10 @@ mod tests {
|
|||
fn should_support_powf_ops() {
|
||||
let data = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
|
||||
let pow = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||
let pow = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 2.0]]);
|
||||
let tensor_pow = Tensor::<TestBackend, 2>::from_data(pow, &Default::default());
|
||||
let data_actual = tensor.powf(tensor_pow).into_data();
|
||||
let data_expected = Data::from([[1.0, 1.0, 4.0], [27.0, 256.0, 3125.0]]);
|
||||
let data_expected = Data::from([[1.0, 1.0, 4.0], [27.0, 256.0, 25.0]]);
|
||||
data_expected.assert_approx_eq(&data_actual, 3);
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue