mirror of https://github.com/tracel-ai/burn.git
Fix mask_where broadcasted input (#2381)
This commit is contained in:
parent
296c526551
commit
604dbae57d
|
@ -134,3 +134,22 @@ pub fn expand<E: CandleElement>(tensor: CandleTensor<E>, shape: Shape) -> Candle
|
||||||
pub fn sign<E: CandleElement>(tensor: CandleTensor<E>) -> CandleTensor<E> {
|
pub fn sign<E: CandleElement>(tensor: CandleTensor<E>) -> CandleTensor<E> {
|
||||||
CandleTensor::new(tensor.tensor.sign().unwrap())
|
CandleTensor::new(tensor.tensor.sign().unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn mask_where_broadcasted<E: CandleElement>(
|
||||||
|
tensor: CandleTensor<E>,
|
||||||
|
mask: CandleTensor<u8>,
|
||||||
|
value: CandleTensor<E>,
|
||||||
|
) -> CandleTensor<E> {
|
||||||
|
let shape = tensor
|
||||||
|
.tensor
|
||||||
|
.shape()
|
||||||
|
.broadcast_shape_binary_op(mask.tensor.shape(), "where_cond")
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let mut tensor = tensor.tensor;
|
||||||
|
if shape != *tensor.shape() {
|
||||||
|
tensor = tensor.broadcast_as(shape).unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
CandleTensor::new(mask.tensor.where_cond(&value.tensor, &tensor).unwrap())
|
||||||
|
}
|
||||||
|
|
|
@ -60,11 +60,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
|
||||||
mask: BoolTensor<Self>,
|
mask: BoolTensor<Self>,
|
||||||
source: IntTensor<Self>,
|
source: IntTensor<Self>,
|
||||||
) -> IntTensor<Self> {
|
) -> IntTensor<Self> {
|
||||||
CandleTensor::new(
|
super::base::mask_where_broadcasted(tensor, mask, source)
|
||||||
mask.tensor
|
|
||||||
.where_cond(&source.tensor, &tensor.tensor)
|
|
||||||
.unwrap(),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn int_mask_fill(
|
fn int_mask_fill(
|
||||||
|
|
|
@ -203,11 +203,7 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
|
||||||
mask: BoolTensor<Self>,
|
mask: BoolTensor<Self>,
|
||||||
value: FloatTensor<Self>,
|
value: FloatTensor<Self>,
|
||||||
) -> FloatTensor<Self> {
|
) -> FloatTensor<Self> {
|
||||||
CandleTensor::new(
|
super::base::mask_where_broadcasted(tensor, mask, value)
|
||||||
mask.tensor
|
|
||||||
.where_cond(&value.tensor, &tensor.tensor)
|
|
||||||
.unwrap(),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn float_mask_fill(
|
fn float_mask_fill(
|
||||||
|
|
|
@ -941,7 +941,7 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
||||||
let stream_1 = tensor.stream;
|
let stream_1 = tensor.stream;
|
||||||
let stream_2 = mask.stream;
|
let stream_2 = mask.stream;
|
||||||
let stream_3 = value.stream;
|
let stream_3 = value.stream;
|
||||||
let shape: Vec<usize> = tensor.shape.clone();
|
let shape = binary_ops_shape(&tensor.shape, &mask.shape);
|
||||||
let out = tensor
|
let out = tensor
|
||||||
.client
|
.client
|
||||||
.tensor_uninitialized(shape, B::FloatElem::dtype());
|
.tensor_uninitialized(shape, B::FloatElem::dtype());
|
||||||
|
|
|
@ -217,7 +217,7 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
||||||
let stream_1 = tensor.stream;
|
let stream_1 = tensor.stream;
|
||||||
let stream_2 = mask.stream;
|
let stream_2 = mask.stream;
|
||||||
let stream_3 = value.stream;
|
let stream_3 = value.stream;
|
||||||
let shape: Vec<usize> = tensor.shape.clone();
|
let shape = binary_ops_shape(&tensor.shape, &mask.shape);
|
||||||
let out = tensor
|
let out = tensor
|
||||||
.client
|
.client
|
||||||
.tensor_uninitialized(shape, B::IntElem::dtype());
|
.tensor_uninitialized(shape, B::IntElem::dtype());
|
||||||
|
|
|
@ -22,6 +22,60 @@ mod tests {
|
||||||
output.into_data().assert_eq(&expected, false);
|
output.into_data().assert_eq(&expected, false);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_support_mask_where_broadcast_int() {
|
||||||
|
let device = Default::default();
|
||||||
|
// When broadcasted, the input [[2, 3], [4, 5]] is repeated 4 times
|
||||||
|
let tensor = Tensor::<TestBackend, 1, Int>::arange(2..6, &device).reshape([1, 2, 2]);
|
||||||
|
let mask = Tensor::<TestBackend, 3, Bool>::from_bool(
|
||||||
|
TensorData::from([
|
||||||
|
[[true, false], [false, true]],
|
||||||
|
[[false, true], [true, false]],
|
||||||
|
[[false, false], [false, false]],
|
||||||
|
[[true, true], [true, true]],
|
||||||
|
]),
|
||||||
|
&device,
|
||||||
|
);
|
||||||
|
let value = Tensor::<TestBackend, 3, Int>::ones([4, 2, 2], &device);
|
||||||
|
|
||||||
|
let output = tensor.mask_where(mask, value);
|
||||||
|
let expected = TensorData::from([
|
||||||
|
[[1, 3], [4, 1]],
|
||||||
|
[[2, 1], [1, 5]],
|
||||||
|
[[2, 3], [4, 5]],
|
||||||
|
[[1, 1], [1, 1]],
|
||||||
|
]);
|
||||||
|
|
||||||
|
output.into_data().assert_eq(&expected, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_support_mask_where_broadcast() {
|
||||||
|
let device = Default::default();
|
||||||
|
// When broadcasted, the input [[2, 3], [4, 5]] is repeated 4 times
|
||||||
|
let tensor = Tensor::<TestBackend, 1, Int>::arange(2..6, &device).reshape([1, 2, 2]);
|
||||||
|
let mask = Tensor::<TestBackend, 3, Bool>::from_bool(
|
||||||
|
TensorData::from([
|
||||||
|
[[true, false], [false, true]],
|
||||||
|
[[false, true], [true, false]],
|
||||||
|
[[false, false], [false, false]],
|
||||||
|
[[true, true], [true, true]],
|
||||||
|
]),
|
||||||
|
&device,
|
||||||
|
);
|
||||||
|
let value = Tensor::<TestBackend, 3>::ones([4, 2, 2], &device);
|
||||||
|
|
||||||
|
let output = tensor.float().mask_where(mask, value);
|
||||||
|
let expected = TensorData::from([
|
||||||
|
[[1., 3.], [4., 1.]],
|
||||||
|
[[2., 1.], [1., 5.]],
|
||||||
|
[[2., 3.], [4., 5.]],
|
||||||
|
[[1., 1.], [1., 1.]],
|
||||||
|
]);
|
||||||
|
|
||||||
|
output.into_data().assert_eq(&expected, false);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_handle_mask_where_nans() {
|
fn should_handle_mask_where_nans() {
|
||||||
let device = Default::default();
|
let device = Default::default();
|
||||||
|
|
Loading…
Reference in New Issue