mirror of https://github.com/tracel-ai/burn.git
Remainder operator (#1726)
* Adds remainder ops implementation for Tensor. * Adds test for % operator.
This commit is contained in:
parent
99e1ba4864
commit
fba1e27e0c
|
@ -2868,6 +2868,20 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
impl<E, const D: usize, B, K> core::ops::Rem<E> for Tensor<B, D, K>
|
||||
where
|
||||
E: ElementConversion,
|
||||
B: Backend,
|
||||
K: Numeric<B>,
|
||||
K::Elem: Element,
|
||||
{
|
||||
type Output = Self;
|
||||
|
||||
fn rem(self, other: E) -> Self {
|
||||
Tensor::remainder_scalar(self, other)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, const D: usize, K> core::ops::Mul<Tensor<B, D, K>> for Tensor<B, D, K>
|
||||
where
|
||||
B: Backend,
|
||||
|
|
|
@ -95,4 +95,17 @@ mod tests {
|
|||
let data_expected = Data::from([9.0, 1.0]);
|
||||
data_expected.assert_approx_eq(&data_actual, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_support_remainder_op() {
|
||||
let data = Data::from([-3.0, -2.0, -1.0, 1.0, 2.0, 3.0]);
|
||||
let device = Default::default();
|
||||
let tensor = Tensor::<TestBackend, 1>::from_data(data, &device);
|
||||
|
||||
let output = tensor % 2.0;
|
||||
|
||||
let data_actual = output.into_data();
|
||||
let data_expected = Data::from([1.0, 0.0, 1.0, 1.0, 0.0, 1.0]);
|
||||
data_expected.assert_approx_eq(&data_actual, 3);
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue