diff --git a/crates/burn-candle/src/ops/base.rs b/crates/burn-candle/src/ops/base.rs index 4d2a4e00e..7c6d867c2 100644 --- a/crates/burn-candle/src/ops/base.rs +++ b/crates/burn-candle/src/ops/base.rs @@ -134,3 +134,22 @@ pub fn expand(tensor: CandleTensor, shape: Shape) -> Candle pub fn sign(tensor: CandleTensor) -> CandleTensor { CandleTensor::new(tensor.tensor.sign().unwrap()) } + +pub fn mask_where_broadcasted( + tensor: CandleTensor, + mask: CandleTensor, + value: CandleTensor, +) -> CandleTensor { + let shape = tensor + .tensor + .shape() + .broadcast_shape_binary_op(mask.tensor.shape(), "where_cond") + .unwrap(); + + let mut tensor = tensor.tensor; + if shape != *tensor.shape() { + tensor = tensor.broadcast_as(shape).unwrap(); + } + + CandleTensor::new(mask.tensor.where_cond(&value.tensor, &tensor).unwrap()) +} diff --git a/crates/burn-candle/src/ops/int_tensor.rs b/crates/burn-candle/src/ops/int_tensor.rs index 5d8df491d..7a57b7350 100644 --- a/crates/burn-candle/src/ops/int_tensor.rs +++ b/crates/burn-candle/src/ops/int_tensor.rs @@ -60,11 +60,7 @@ impl IntTensorOps for Candle, source: IntTensor, ) -> IntTensor { - CandleTensor::new( - mask.tensor - .where_cond(&source.tensor, &tensor.tensor) - .unwrap(), - ) + super::base::mask_where_broadcasted(tensor, mask, source) } fn int_mask_fill( diff --git a/crates/burn-candle/src/ops/tensor.rs b/crates/burn-candle/src/ops/tensor.rs index 22fa97dd7..1471922ec 100644 --- a/crates/burn-candle/src/ops/tensor.rs +++ b/crates/burn-candle/src/ops/tensor.rs @@ -203,11 +203,7 @@ impl FloatTensorOps for Candle mask: BoolTensor, value: FloatTensor, ) -> FloatTensor { - CandleTensor::new( - mask.tensor - .where_cond(&value.tensor, &tensor.tensor) - .unwrap(), - ) + super::base::mask_where_broadcasted(tensor, mask, value) } fn float_mask_fill( diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index 21dfde2cb..0351830d2 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -941,7 +941,7 @@ impl FloatTensorOps for Fusion { let stream_1 = tensor.stream; let stream_2 = mask.stream; let stream_3 = value.stream; - let shape: Vec = tensor.shape.clone(); + let shape = binary_ops_shape(&tensor.shape, &mask.shape); let out = tensor .client .tensor_uninitialized(shape, B::FloatElem::dtype()); diff --git a/crates/burn-fusion/src/ops/int.rs b/crates/burn-fusion/src/ops/int.rs index 082070025..48352e2db 100644 --- a/crates/burn-fusion/src/ops/int.rs +++ b/crates/burn-fusion/src/ops/int.rs @@ -217,7 +217,7 @@ impl IntTensorOps for Fusion { let stream_1 = tensor.stream; let stream_2 = mask.stream; let stream_3 = value.stream; - let shape: Vec = tensor.shape.clone(); + let shape = binary_ops_shape(&tensor.shape, &mask.shape); let out = tensor .client .tensor_uninitialized(shape, B::IntElem::dtype()); diff --git a/crates/burn-tensor/src/tests/ops/mask.rs b/crates/burn-tensor/src/tests/ops/mask.rs index 677e70960..64d88c617 100644 --- a/crates/burn-tensor/src/tests/ops/mask.rs +++ b/crates/burn-tensor/src/tests/ops/mask.rs @@ -22,6 +22,60 @@ mod tests { output.into_data().assert_eq(&expected, false); } + #[test] + fn should_support_mask_where_broadcast_int() { + let device = Default::default(); + // When broadcasted, the input [[2, 3], [4, 5]] is repeated 4 times + let tensor = Tensor::::arange(2..6, &device).reshape([1, 2, 2]); + let mask = Tensor::::from_bool( + TensorData::from([ + [[true, false], [false, true]], + [[false, true], [true, false]], + [[false, false], [false, false]], + [[true, true], [true, true]], + ]), + &device, + ); + let value = Tensor::::ones([4, 2, 2], &device); + + let output = tensor.mask_where(mask, value); + let expected = TensorData::from([ + [[1, 3], [4, 1]], + [[2, 1], [1, 5]], + [[2, 3], [4, 5]], + [[1, 1], [1, 1]], + ]); + + output.into_data().assert_eq(&expected, false); + } + + #[test] + fn should_support_mask_where_broadcast() { + let device = Default::default(); + // When broadcasted, the input [[2, 3], [4, 5]] is repeated 4 times + let tensor = Tensor::::arange(2..6, &device).reshape([1, 2, 2]); + let mask = Tensor::::from_bool( + TensorData::from([ + [[true, false], [false, true]], + [[false, true], [true, false]], + [[false, false], [false, false]], + [[true, true], [true, true]], + ]), + &device, + ); + let value = Tensor::::ones([4, 2, 2], &device); + + let output = tensor.float().mask_where(mask, value); + let expected = TensorData::from([ + [[1., 3.], [4., 1.]], + [[2., 1.], [1., 5.]], + [[2., 3.], [4., 5.]], + [[1., 1.], [1., 1.]], + ]); + + output.into_data().assert_eq(&expected, false); + } + #[test] fn should_handle_mask_where_nans() { let device = Default::default();