mirror of https://github.com/tracel-ai/burn.git
Make powf broadcastable (#2398)
This commit is contained in:
parent
6515a081aa
commit
1df0704ff3
|
@ -480,12 +480,13 @@ where
|
|||
rhs: NdArrayTensor<OtherE>,
|
||||
var_name: impl FnMut(&E, &OtherE) -> E,
|
||||
) -> NdArrayTensor<E> {
|
||||
NdArrayTensor::new(
|
||||
Zip::from(lhs.array.view())
|
||||
.and(rhs.array.view())
|
||||
.map_collect(var_name)
|
||||
.into_shared(),
|
||||
)
|
||||
let lhs = lhs
|
||||
.array
|
||||
.broadcast(rhs.array.dim())
|
||||
.unwrap_or(lhs.array.view());
|
||||
let rhs = rhs.array.broadcast(lhs.dim()).unwrap_or(rhs.array.view());
|
||||
|
||||
NdArrayTensor::new(Zip::from(lhs).and(rhs).map_collect(var_name).into_shared())
|
||||
}
|
||||
|
||||
pub(crate) fn elementwise_op_scalar(
|
||||
|
|
|
@ -54,4 +54,21 @@ mod tests {
|
|||
|
||||
output.into_data().assert_approx_eq(&expected, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_powf_broadcasted() {
|
||||
let device = Default::default();
|
||||
let tensor_1 = Tensor::<TestBackend, 1>::from_floats([2.0, 3.0, 4.0], &device);
|
||||
let tensor_2 = Tensor::from_floats([1.0], &device);
|
||||
|
||||
// Broadcast rhs
|
||||
let output = tensor_1.clone().powf(tensor_2.clone());
|
||||
output.into_data().assert_approx_eq(&tensor_1.to_data(), 3);
|
||||
|
||||
// Broadcast lhs
|
||||
let output = tensor_2.powf(tensor_1);
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq(&TensorData::from([1.0, 1.0, 1.0]), 3);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue