mirror of https://github.com/tracel-ai/burn.git
Change ndarray mask_where implementation to correctly deal with NaNs (#2272)
* Change ndarray mask_where implementation to correctly deal with NaNs * Add test
This commit is contained in:
parent
2fbad48f64
commit
6f0e61aa4f
|
@ -409,17 +409,14 @@ where
|
||||||
mask: NdArrayTensor<bool, D>,
|
mask: NdArrayTensor<bool, D>,
|
||||||
source: NdArrayTensor<E, D>,
|
source: NdArrayTensor<E, D>,
|
||||||
) -> NdArrayTensor<E, D> {
|
) -> NdArrayTensor<E, D> {
|
||||||
let mask_mul_4tensor = mask.array.mapv(|x| match x {
|
let tensor = tensor.array.broadcast(mask.array.dim()).unwrap();
|
||||||
true => 0.elem(),
|
let source = source.array.broadcast(mask.array.dim()).unwrap();
|
||||||
false => 1.elem(),
|
let output = Zip::from(&tensor)
|
||||||
});
|
.and(&mask.array)
|
||||||
let mask_mul_4source = mask.array.mapv(|x| match x {
|
.and(&source)
|
||||||
true => 1.elem(),
|
.map_collect(|&x, &mask_val, &y| if mask_val { y } else { x })
|
||||||
false => 0.elem(),
|
.into_shared();
|
||||||
});
|
NdArrayTensor::new(output)
|
||||||
let array = (tensor.array * mask_mul_4tensor) + (source.array * mask_mul_4source);
|
|
||||||
|
|
||||||
NdArrayTensor::new(array)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn mask_fill<const D: usize>(
|
pub fn mask_fill<const D: usize>(
|
||||||
|
|
|
@ -22,6 +22,40 @@ mod tests {
|
||||||
output.into_data().assert_eq(&expected, false);
|
output.into_data().assert_eq(&expected, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_handle_mask_where_nans() {
|
||||||
|
let device = Default::default();
|
||||||
|
let tensor = TestTensor::from_data(
|
||||||
|
[
|
||||||
|
[f32::NAN, f32::NAN, f32::NAN],
|
||||||
|
[f32::NAN, f32::NAN, f32::NAN],
|
||||||
|
[f32::NAN, f32::NAN, f32::NAN],
|
||||||
|
],
|
||||||
|
&device,
|
||||||
|
);
|
||||||
|
let mask = Tensor::<TestBackend, 2, Bool>::from_bool(
|
||||||
|
TensorData::from([
|
||||||
|
[true, true, true],
|
||||||
|
[true, true, false],
|
||||||
|
[false, false, false],
|
||||||
|
]),
|
||||||
|
&device,
|
||||||
|
);
|
||||||
|
let value = Tensor::<TestBackend, 2>::from_data(
|
||||||
|
TensorData::from([[0.9, 0.8, 0.7], [0.6, 0.5, 0.4], [0.3, 0.2, 0.1]]),
|
||||||
|
&device,
|
||||||
|
);
|
||||||
|
|
||||||
|
let output = tensor.mask_where(mask, value);
|
||||||
|
let expected = TensorData::from([
|
||||||
|
[0.9, 0.8, 0.7],
|
||||||
|
[0.6, 0.5, f32::NAN],
|
||||||
|
[f32::NAN, f32::NAN, f32::NAN],
|
||||||
|
]);
|
||||||
|
|
||||||
|
output.into_data().assert_eq(&expected, false);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_support_mask_fill_ops() {
|
fn should_support_mask_fill_ops() {
|
||||||
let device = Default::default();
|
let device = Default::default();
|
||||||
|
|
Loading…
Reference in New Issue