mirror of https://github.com/tracel-ai/burn.git
Make powf broadcastable (#2398)
This commit is contained in:
parent
6515a081aa
commit
1df0704ff3
|
@ -480,12 +480,13 @@ where
|
||||||
rhs: NdArrayTensor<OtherE>,
|
rhs: NdArrayTensor<OtherE>,
|
||||||
var_name: impl FnMut(&E, &OtherE) -> E,
|
var_name: impl FnMut(&E, &OtherE) -> E,
|
||||||
) -> NdArrayTensor<E> {
|
) -> NdArrayTensor<E> {
|
||||||
NdArrayTensor::new(
|
let lhs = lhs
|
||||||
Zip::from(lhs.array.view())
|
.array
|
||||||
.and(rhs.array.view())
|
.broadcast(rhs.array.dim())
|
||||||
.map_collect(var_name)
|
.unwrap_or(lhs.array.view());
|
||||||
.into_shared(),
|
let rhs = rhs.array.broadcast(lhs.dim()).unwrap_or(rhs.array.view());
|
||||||
)
|
|
||||||
|
NdArrayTensor::new(Zip::from(lhs).and(rhs).map_collect(var_name).into_shared())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn elementwise_op_scalar(
|
pub(crate) fn elementwise_op_scalar(
|
||||||
|
|
|
@ -54,4 +54,21 @@ mod tests {
|
||||||
|
|
||||||
output.into_data().assert_approx_eq(&expected, 3);
|
output.into_data().assert_approx_eq(&expected, 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_support_powf_broadcasted() {
|
||||||
|
let device = Default::default();
|
||||||
|
let tensor_1 = Tensor::<TestBackend, 1>::from_floats([2.0, 3.0, 4.0], &device);
|
||||||
|
let tensor_2 = Tensor::from_floats([1.0], &device);
|
||||||
|
|
||||||
|
// Broadcast rhs
|
||||||
|
let output = tensor_1.clone().powf(tensor_2.clone());
|
||||||
|
output.into_data().assert_approx_eq(&tensor_1.to_data(), 3);
|
||||||
|
|
||||||
|
// Broadcast lhs
|
||||||
|
let output = tensor_2.powf(tensor_1);
|
||||||
|
output
|
||||||
|
.into_data()
|
||||||
|
.assert_approx_eq(&TensorData::from([1.0, 1.0, 1.0]), 3);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue