mirror of https://github.com/tracel-ai/burn.git
fix powf bugs
This commit is contained in:
parent
07c69bd29a
commit
803ba5d92f
|
@ -447,8 +447,8 @@ impl<E: tch::kind::Element + Copy + Default> TchOps<E> {
|
||||||
TchTensor::binary_ops_tensor(
|
TchTensor::binary_ops_tensor(
|
||||||
tensor,
|
tensor,
|
||||||
exponent,
|
exponent,
|
||||||
|
|lhs, rhs| lhs.f_pow_tensor_(rhs).unwrap(),
|
||||||
|lhs, rhs| lhs.f_pow(rhs).unwrap(),
|
|lhs, rhs| lhs.f_pow(rhs).unwrap(),
|
||||||
|lhs, rhs| rhs.f_pow(lhs).unwrap(),
|
|
||||||
|lhs, rhs| lhs.f_pow(rhs).unwrap(),
|
|lhs, rhs| lhs.f_pow(rhs).unwrap(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,10 +7,10 @@ mod tests {
|
||||||
fn should_support_powf_ops() {
|
fn should_support_powf_ops() {
|
||||||
let data = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
let data = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
||||||
let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
|
let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
|
||||||
let pow = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
|
let pow = Data::from([[1.0, 1.0, 2.0], [3.0, 4.0, 2.0]]);
|
||||||
let tensor_pow = Tensor::<TestBackend, 2>::from_data(pow, &Default::default());
|
let tensor_pow = Tensor::<TestBackend, 2>::from_data(pow, &Default::default());
|
||||||
let data_actual = tensor.powf(tensor_pow).into_data();
|
let data_actual = tensor.powf(tensor_pow).into_data();
|
||||||
let data_expected = Data::from([[1.0, 1.0, 4.0], [27.0, 256.0, 3125.0]]);
|
let data_expected = Data::from([[1.0, 1.0, 4.0], [27.0, 256.0, 25.0]]);
|
||||||
data_expected.assert_approx_eq(&data_actual, 3);
|
data_expected.assert_approx_eq(&data_actual, 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue