mirror of https://github.com/tracel-ai/burn.git
Add `round`, `floor`, `ceil` for float tensor (#2372)
* Add round, floor, ceil into FloatTensorOps trait; Impl round, floor, ceil for tensor, tch, ndarray, candle; Add tests for autodiff * Add test for round, floor, ceil in burn tensor backend * Add test for round, floor, ceil in burn candle backend * Impl round, floor, ceil for burn fushion backend * Update burn book * Fix round gradient != 0 issue * Add tests for halfway cases * Use `round_ties_even` for `round` in ndarray backend * Impl round to even for candle * Add round, floor, ceil for burn router * Add round, floor, ceil for jit backend; Upgrade cubecl * Add round_ties_even for no-std in ndarray backend * Be explicit about what rounding strategy is used --------- Co-authored-by: Guillaume Lagrange <lagrange.guillaume.1@gmail.com>
This commit is contained in:
parent
f3968cbeac
commit
6515a081aa
|
@ -513,7 +513,7 @@ dependencies = [
|
||||||
name = "burn-common"
|
name = "burn-common"
|
||||||
version = "0.15.0"
|
version = "0.15.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cubecl-common",
|
"cubecl-common 0.2.0 (git+https://github.com/tracel-ai/cubecl?rev=7854b343975c990dd8bb1b4b68b3bc9bda488c1d)",
|
||||||
"dashmap",
|
"dashmap",
|
||||||
"getrandom",
|
"getrandom",
|
||||||
"indicatif",
|
"indicatif",
|
||||||
|
@ -1452,7 +1452,7 @@ version = "0.2.0"
|
||||||
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
|
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bytemuck",
|
"bytemuck",
|
||||||
"cubecl-common",
|
"cubecl-common 0.2.0 (git+https://github.com/tracel-ai/cubecl?rev=547c3fe249deae7d0efe13a47d368c34ce9a736c)",
|
||||||
"cubecl-macros",
|
"cubecl-macros",
|
||||||
"cubecl-runtime",
|
"cubecl-runtime",
|
||||||
"derive-new",
|
"derive-new",
|
||||||
|
@ -1483,7 +1483,7 @@ version = "0.2.0"
|
||||||
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
|
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bytemuck",
|
"bytemuck",
|
||||||
"cubecl-common",
|
"cubecl-common 0.2.0 (git+https://github.com/tracel-ai/cubecl?rev=547c3fe249deae7d0efe13a47d368c34ce9a736c)",
|
||||||
"cubecl-core",
|
"cubecl-core",
|
||||||
"cubecl-cpp",
|
"cubecl-cpp",
|
||||||
"cubecl-runtime",
|
"cubecl-runtime",
|
||||||
|
@ -1534,7 +1534,7 @@ name = "cubecl-macros"
|
||||||
version = "0.2.0"
|
version = "0.2.0"
|
||||||
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
|
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cubecl-common",
|
"cubecl-common 0.2.0 (git+https://github.com/tracel-ai/cubecl?rev=547c3fe249deae7d0efe13a47d368c34ce9a736c)",
|
||||||
"darling",
|
"darling",
|
||||||
"derive-new",
|
"derive-new",
|
||||||
"ident_case",
|
"ident_case",
|
||||||
|
@ -1566,7 +1566,7 @@ source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd01
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-channel",
|
"async-channel",
|
||||||
"cfg_aliases 0.2.1",
|
"cfg_aliases 0.2.1",
|
||||||
"cubecl-common",
|
"cubecl-common 0.2.0 (git+https://github.com/tracel-ai/cubecl?rev=547c3fe249deae7d0efe13a47d368c34ce9a736c)",
|
||||||
"derive-new",
|
"derive-new",
|
||||||
"dirs 5.0.1",
|
"dirs 5.0.1",
|
||||||
"hashbrown 0.14.5",
|
"hashbrown 0.14.5",
|
||||||
|
@ -1601,7 +1601,7 @@ dependencies = [
|
||||||
"async-channel",
|
"async-channel",
|
||||||
"bytemuck",
|
"bytemuck",
|
||||||
"cfg_aliases 0.2.1",
|
"cfg_aliases 0.2.1",
|
||||||
"cubecl-common",
|
"cubecl-common 0.2.0 (git+https://github.com/tracel-ai/cubecl?rev=547c3fe249deae7d0efe13a47d368c34ce9a736c)",
|
||||||
"cubecl-core",
|
"cubecl-core",
|
||||||
"cubecl-runtime",
|
"cubecl-runtime",
|
||||||
"cubecl-spirv",
|
"cubecl-spirv",
|
||||||
|
|
|
@ -250,9 +250,11 @@ Those operations are only available for `Float` tensors.
|
||||||
|
|
||||||
| Burn API | PyTorch Equivalent |
|
| Burn API | PyTorch Equivalent |
|
||||||
| -------------------------------------------- | ---------------------------------- |
|
| -------------------------------------------- | ---------------------------------- |
|
||||||
|
| `tensor.ceil()` | `tensor.ceil()` |
|
||||||
| `tensor.cos()` | `tensor.cos()` |
|
| `tensor.cos()` | `tensor.cos()` |
|
||||||
| `tensor.erf()` | `tensor.erf()` |
|
| `tensor.erf()` | `tensor.erf()` |
|
||||||
| `tensor.exp()` | `tensor.exp()` |
|
| `tensor.exp()` | `tensor.exp()` |
|
||||||
|
| `tensor.floor()` | `tensor.floor()` |
|
||||||
| `tensor.from_floats(floats, device)` | N/A |
|
| `tensor.from_floats(floats, device)` | N/A |
|
||||||
| `tensor.from_full_precision(tensor)` | N/A |
|
| `tensor.from_full_precision(tensor)` | N/A |
|
||||||
| `tensor.int()` | Similar to `tensor.to(torch.long)` |
|
| `tensor.int()` | Similar to `tensor.to(torch.long)` |
|
||||||
|
@ -264,6 +266,7 @@ Those operations are only available for `Float` tensors.
|
||||||
| `tensor.random(shape, distribution, device)` | N/A |
|
| `tensor.random(shape, distribution, device)` | N/A |
|
||||||
| `tensor.random_like(distribution)` | `torch.rand_like()` only uniform |
|
| `tensor.random_like(distribution)` | `torch.rand_like()` only uniform |
|
||||||
| `tensor.recip()` | `tensor.reciprocal()` |
|
| `tensor.recip()` | `tensor.reciprocal()` |
|
||||||
|
| `tensor.round()` | `tensor.round()` |
|
||||||
| `tensor.sin()` | `tensor.sin()` |
|
| `tensor.sin()` | `tensor.sin()` |
|
||||||
| `tensor.sqrt()` | `tensor.sqrt()` |
|
| `tensor.sqrt()` | `tensor.sqrt()` |
|
||||||
| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` |
|
| `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` |
|
||||||
|
|
|
@ -1885,6 +1885,123 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct Round;
|
||||||
|
retro_unary!(RetroRound, B::float_round);
|
||||||
|
|
||||||
|
impl<B: Backend> Backward<B, 1> for Round {
|
||||||
|
type State = (Shape, B::Device);
|
||||||
|
|
||||||
|
fn backward(
|
||||||
|
self,
|
||||||
|
ops: Ops<Self::State, 1>,
|
||||||
|
grads: &mut Gradients,
|
||||||
|
_checkpointer: &mut Checkpointer,
|
||||||
|
) {
|
||||||
|
let (shape, device) = ops.state;
|
||||||
|
unary::<B, _>(ops.parents, ops.node, grads, |_grad| {
|
||||||
|
B::float_zeros(shape, &device)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match Round
|
||||||
|
.prepare::<C>([tensor.node.clone()])
|
||||||
|
.memory_bound()
|
||||||
|
.retro_forward(RetroRound::<B>::new(tensor.node.id))
|
||||||
|
.parents([&tensor])
|
||||||
|
.stateful()
|
||||||
|
{
|
||||||
|
OpsKind::Tracked(preps) => preps.finish(
|
||||||
|
(
|
||||||
|
B::float_shape(&tensor.primitive),
|
||||||
|
B::float_device(&tensor.primitive),
|
||||||
|
),
|
||||||
|
B::float_round(tensor.primitive),
|
||||||
|
),
|
||||||
|
OpsKind::UnTracked(preps) => preps.finish(B::float_round(tensor.primitive)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct Floor;
|
||||||
|
retro_unary!(RetroFloor, B::float_floor);
|
||||||
|
|
||||||
|
impl<B: Backend> Backward<B, 1> for Floor {
|
||||||
|
type State = (Shape, B::Device);
|
||||||
|
|
||||||
|
fn backward(
|
||||||
|
self,
|
||||||
|
ops: Ops<Self::State, 1>,
|
||||||
|
grads: &mut Gradients,
|
||||||
|
_checkpointer: &mut Checkpointer,
|
||||||
|
) {
|
||||||
|
let (shape, device) = ops.state;
|
||||||
|
unary::<B, _>(ops.parents, ops.node, grads, |_grad| {
|
||||||
|
B::float_zeros(shape, &device)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match Floor
|
||||||
|
.prepare::<C>([tensor.node.clone()])
|
||||||
|
.memory_bound()
|
||||||
|
.retro_forward(RetroFloor::<B>::new(tensor.node.id))
|
||||||
|
.parents([&tensor])
|
||||||
|
.stateful()
|
||||||
|
{
|
||||||
|
OpsKind::Tracked(preps) => preps.finish(
|
||||||
|
(
|
||||||
|
B::float_shape(&tensor.primitive),
|
||||||
|
B::float_device(&tensor.primitive),
|
||||||
|
),
|
||||||
|
B::float_floor(tensor.primitive),
|
||||||
|
),
|
||||||
|
OpsKind::UnTracked(preps) => preps.finish(B::float_floor(tensor.primitive)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct Ceil;
|
||||||
|
retro_unary!(RetroCeil, B::float_ceil);
|
||||||
|
|
||||||
|
impl<B: Backend> Backward<B, 1> for Ceil {
|
||||||
|
type State = (Shape, B::Device);
|
||||||
|
|
||||||
|
fn backward(
|
||||||
|
self,
|
||||||
|
ops: Ops<Self::State, 1>,
|
||||||
|
grads: &mut Gradients,
|
||||||
|
_checkpointer: &mut Checkpointer,
|
||||||
|
) {
|
||||||
|
let (shape, device) = ops.state;
|
||||||
|
unary::<B, _>(ops.parents, ops.node, grads, |_grad| {
|
||||||
|
B::float_zeros(shape, &device)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match Ceil
|
||||||
|
.prepare::<C>([tensor.node.clone()])
|
||||||
|
.memory_bound()
|
||||||
|
.retro_forward(RetroCeil::<B>::new(tensor.node.id))
|
||||||
|
.parents([&tensor])
|
||||||
|
.stateful()
|
||||||
|
{
|
||||||
|
OpsKind::Tracked(preps) => preps.finish(
|
||||||
|
(
|
||||||
|
B::float_shape(&tensor.primitive),
|
||||||
|
B::float_device(&tensor.primitive),
|
||||||
|
),
|
||||||
|
B::float_floor(tensor.primitive),
|
||||||
|
),
|
||||||
|
OpsKind::UnTracked(preps) => preps.finish(B::float_floor(tensor.primitive)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
struct Erf;
|
struct Erf;
|
||||||
|
|
|
@ -0,0 +1,25 @@
|
||||||
|
#[burn_tensor_testgen::testgen(ad_ceil)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use burn_tensor::TensorData;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_diff_ceil() {
|
||||||
|
let data = TensorData::from([
|
||||||
|
[-1.9751, 0.0714, 0.0643, 0.2406],
|
||||||
|
[-1.3172, 0.1252, -0.1119, -0.0127],
|
||||||
|
]);
|
||||||
|
let device = Default::default();
|
||||||
|
let tensor_1 = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();
|
||||||
|
let tensor_2 = tensor_1.clone().ceil();
|
||||||
|
let grads = tensor_2.backward();
|
||||||
|
|
||||||
|
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||||
|
let shape = grad_1.to_data().shape;
|
||||||
|
|
||||||
|
grad_1.to_data().assert_eq(
|
||||||
|
&TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
#[burn_tensor_testgen::testgen(ad_floor)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use burn_tensor::TensorData;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_diff_floor() {
|
||||||
|
let data = TensorData::from([
|
||||||
|
[-1.9751, 0.0714, 0.0643, 0.2406],
|
||||||
|
[-1.3172, 0.1252, -0.1119, -0.0127],
|
||||||
|
]);
|
||||||
|
let device = Default::default();
|
||||||
|
let tensor_1 = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();
|
||||||
|
let tensor_2 = tensor_1.clone().floor();
|
||||||
|
let grads = tensor_2.backward();
|
||||||
|
|
||||||
|
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||||
|
let shape = grad_1.to_data().shape;
|
||||||
|
|
||||||
|
grad_1.to_data().assert_eq(
|
||||||
|
&TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
|
@ -11,6 +11,7 @@ mod backward;
|
||||||
mod bridge;
|
mod bridge;
|
||||||
mod broadcast;
|
mod broadcast;
|
||||||
mod cat;
|
mod cat;
|
||||||
|
mod ceil;
|
||||||
mod checkpoint;
|
mod checkpoint;
|
||||||
mod complex;
|
mod complex;
|
||||||
mod conv1d;
|
mod conv1d;
|
||||||
|
@ -27,6 +28,7 @@ mod erf;
|
||||||
mod exp;
|
mod exp;
|
||||||
mod expand;
|
mod expand;
|
||||||
mod flip;
|
mod flip;
|
||||||
|
mod floor;
|
||||||
mod gather_scatter;
|
mod gather_scatter;
|
||||||
mod gelu;
|
mod gelu;
|
||||||
mod gradients;
|
mod gradients;
|
||||||
|
@ -50,6 +52,7 @@ mod recip;
|
||||||
mod relu;
|
mod relu;
|
||||||
mod repeat_dim;
|
mod repeat_dim;
|
||||||
mod reshape;
|
mod reshape;
|
||||||
|
mod round;
|
||||||
mod select;
|
mod select;
|
||||||
mod sigmoid;
|
mod sigmoid;
|
||||||
mod sign;
|
mod sign;
|
||||||
|
@ -127,6 +130,9 @@ macro_rules! testgen_all {
|
||||||
burn_autodiff::testgen_ad_abs!();
|
burn_autodiff::testgen_ad_abs!();
|
||||||
burn_autodiff::testgen_ad_sub!();
|
burn_autodiff::testgen_ad_sub!();
|
||||||
burn_autodiff::testgen_ad_tanh!();
|
burn_autodiff::testgen_ad_tanh!();
|
||||||
|
burn_autodiff::testgen_ad_round!();
|
||||||
|
burn_autodiff::testgen_ad_floor!();
|
||||||
|
burn_autodiff::testgen_ad_ceil!();
|
||||||
burn_autodiff::testgen_ad_sigmoid!();
|
burn_autodiff::testgen_ad_sigmoid!();
|
||||||
burn_autodiff::testgen_ad_log_sigmoid!();
|
burn_autodiff::testgen_ad_log_sigmoid!();
|
||||||
burn_autodiff::testgen_ad_transpose!();
|
burn_autodiff::testgen_ad_transpose!();
|
||||||
|
|
|
@ -0,0 +1,23 @@
|
||||||
|
#[burn_tensor_testgen::testgen(ad_round)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use burn_tensor::TensorData;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_diff_round() {
|
||||||
|
let data = TensorData::from([
|
||||||
|
[-1.9751, 0.0714, 0.0643, 0.2406],
|
||||||
|
[-1.3172, 0.1252, -0.1119, -0.0127],
|
||||||
|
]);
|
||||||
|
let device = Default::default();
|
||||||
|
let tensor_1 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
|
||||||
|
let tensor_2 = tensor_1.clone().round();
|
||||||
|
let grads = tensor_2.backward();
|
||||||
|
|
||||||
|
let grad_1 = tensor_1.grad(&grads).unwrap();
|
||||||
|
grad_1.to_data().assert_eq(
|
||||||
|
&TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),
|
||||||
|
false,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
|
@ -90,6 +90,9 @@ mod tests {
|
||||||
burn_tensor::testgen_argwhere_nonzero!();
|
burn_tensor::testgen_argwhere_nonzero!();
|
||||||
burn_tensor::testgen_sign!();
|
burn_tensor::testgen_sign!();
|
||||||
burn_tensor::testgen_nan!();
|
burn_tensor::testgen_nan!();
|
||||||
|
burn_tensor::testgen_round!();
|
||||||
|
burn_tensor::testgen_floor!();
|
||||||
|
burn_tensor::testgen_ceil!();
|
||||||
|
|
||||||
// TODO: https://github.com/tracel-ai/burn/issues/1237
|
// TODO: https://github.com/tracel-ai/burn/issues/1237
|
||||||
//
|
//
|
||||||
|
@ -165,4 +168,7 @@ mod tests {
|
||||||
burn_autodiff::testgen_ad_tanh!();
|
burn_autodiff::testgen_ad_tanh!();
|
||||||
burn_autodiff::testgen_ad_transpose!();
|
burn_autodiff::testgen_ad_transpose!();
|
||||||
burn_autodiff::testgen_ad_expand!();
|
burn_autodiff::testgen_ad_expand!();
|
||||||
|
burn_autodiff::testgen_ad_round!();
|
||||||
|
burn_autodiff::testgen_ad_floor!();
|
||||||
|
burn_autodiff::testgen_ad_ceil!();
|
||||||
}
|
}
|
||||||
|
|
|
@ -333,6 +333,36 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
|
||||||
CandleTensor::new(tensor.tensor.tanh().unwrap())
|
CandleTensor::new(tensor.tensor.tanh().unwrap())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||||
|
let inner = |tensor: FloatTensor<Self>| -> candle_core::Result<FloatTensor<Self>> {
|
||||||
|
// implements round_to_even for consistent behavior vs libtorch
|
||||||
|
// https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/runtime/register_ops_utils.h#L65-L67
|
||||||
|
|
||||||
|
let floor_a = tensor.tensor.floor()?;
|
||||||
|
let frac_part = tensor.tensor.sub(&floor_a)?;
|
||||||
|
|
||||||
|
let half = (candle_core::Tensor::ones_like(&tensor.tensor)? * 0.5)?;
|
||||||
|
let mask_half = frac_part.eq(&half)?;
|
||||||
|
let half_tensor = tensor.tensor.mul(&half)?;
|
||||||
|
let rounded_half = half_tensor.round()?;
|
||||||
|
let doubled =
|
||||||
|
rounded_half.mul(&(candle_core::Tensor::ones_like(&tensor.tensor)? * 2.0)?)?;
|
||||||
|
let standard_round = tensor.tensor.round()?;
|
||||||
|
Ok(CandleTensor::new(
|
||||||
|
mask_half.where_cond(&doubled, &standard_round)?,
|
||||||
|
))
|
||||||
|
};
|
||||||
|
inner(tensor).unwrap()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||||
|
CandleTensor::new(tensor.tensor.floor().unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||||
|
CandleTensor::new(tensor.tensor.ceil().unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||||
CandleTensor::new(tensor.tensor.erf().unwrap())
|
CandleTensor::new(tensor.tensor.erf().unwrap())
|
||||||
}
|
}
|
||||||
|
|
|
@ -2097,4 +2097,76 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
|
||||||
|
|
||||||
out
|
out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||||
|
unary_float_ops!(RoundOps, B::float_round);
|
||||||
|
|
||||||
|
let stream = tensor.stream;
|
||||||
|
let out = tensor
|
||||||
|
.client
|
||||||
|
.tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype());
|
||||||
|
|
||||||
|
let desc = UnaryOperationDescription {
|
||||||
|
input: tensor.into_description(),
|
||||||
|
out: out.to_description_out(),
|
||||||
|
};
|
||||||
|
out.client.register(
|
||||||
|
vec![stream],
|
||||||
|
OperationDescription::Float(
|
||||||
|
FloatElem::<Self>::dtype(),
|
||||||
|
FloatOperationDescription::Round(desc.clone()),
|
||||||
|
),
|
||||||
|
RoundOps::<B>::new(desc),
|
||||||
|
);
|
||||||
|
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||||
|
unary_float_ops!(FloorOps, B::float_floor);
|
||||||
|
|
||||||
|
let stream = tensor.stream;
|
||||||
|
let out = tensor
|
||||||
|
.client
|
||||||
|
.tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype());
|
||||||
|
|
||||||
|
let desc = UnaryOperationDescription {
|
||||||
|
input: tensor.into_description(),
|
||||||
|
out: out.to_description_out(),
|
||||||
|
};
|
||||||
|
out.client.register(
|
||||||
|
vec![stream],
|
||||||
|
OperationDescription::Float(
|
||||||
|
FloatElem::<Self>::dtype(),
|
||||||
|
FloatOperationDescription::Floor(desc.clone()),
|
||||||
|
),
|
||||||
|
FloorOps::<B>::new(desc),
|
||||||
|
);
|
||||||
|
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||||
|
unary_float_ops!(CeilOps, B::float_ceil);
|
||||||
|
|
||||||
|
let stream = tensor.stream;
|
||||||
|
let out = tensor
|
||||||
|
.client
|
||||||
|
.tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype());
|
||||||
|
|
||||||
|
let desc = UnaryOperationDescription {
|
||||||
|
input: tensor.into_description(),
|
||||||
|
out: out.to_description_out(),
|
||||||
|
};
|
||||||
|
out.client.register(
|
||||||
|
vec![stream],
|
||||||
|
OperationDescription::Float(
|
||||||
|
FloatElem::<Self>::dtype(),
|
||||||
|
FloatOperationDescription::Ceil(desc.clone()),
|
||||||
|
),
|
||||||
|
CeilOps::<B>::new(desc),
|
||||||
|
);
|
||||||
|
|
||||||
|
out
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -520,6 +520,24 @@ impl RelativeOpsScalar<f32> for FloatOperationDescription {
|
||||||
out: desc.out.to_relative(converter),
|
out: desc.out.to_relative(converter),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
FloatOperationDescription::Round(desc) => {
|
||||||
|
FloatOperationDescription::Round(UnaryOperationDescription {
|
||||||
|
input: desc.input.to_relative(converter),
|
||||||
|
out: desc.out.to_relative(converter),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
FloatOperationDescription::Floor(desc) => {
|
||||||
|
FloatOperationDescription::Floor(UnaryOperationDescription {
|
||||||
|
input: desc.input.to_relative(converter),
|
||||||
|
out: desc.out.to_relative(converter),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
FloatOperationDescription::Ceil(desc) => {
|
||||||
|
FloatOperationDescription::Ceil(UnaryOperationDescription {
|
||||||
|
input: desc.input.to_relative(converter),
|
||||||
|
out: desc.out.to_relative(converter),
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -333,6 +333,36 @@ where
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||||
|
unary_op!(float(tensor) => |context, tensor| {
|
||||||
|
#[cube]
|
||||||
|
fn execute<C: Float>(input: Line<C>) -> Line<C> {
|
||||||
|
Line::round(input)
|
||||||
|
}
|
||||||
|
execute::expand::<C>(context, tensor)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||||
|
unary_op!(float(tensor) => |context, tensor| {
|
||||||
|
#[cube]
|
||||||
|
fn execute<C: Float>(input: Line<C>) -> Line<C> {
|
||||||
|
Line::floor(input)
|
||||||
|
}
|
||||||
|
execute::expand::<C>(context, tensor)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||||
|
unary_op!(float(tensor) => |context, tensor| {
|
||||||
|
#[cube]
|
||||||
|
fn execute<C: Float>(input: Line<C>) -> Line<C> {
|
||||||
|
Line::ceil(input)
|
||||||
|
}
|
||||||
|
execute::expand::<C>(context, tensor)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||||
unary_op!(float(tensor) => |context, tensor| {
|
unary_op!(float(tensor) => |context, tensor| {
|
||||||
#[cube]
|
#[cube]
|
||||||
|
|
|
@ -20,6 +20,22 @@ use num_traits::Float;
|
||||||
|
|
||||||
use libm::erf;
|
use libm::erf;
|
||||||
|
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
#[allow(dead_code)]
|
||||||
|
fn round_ties_even_wrapper(x: f64) -> f64 {
|
||||||
|
x.round_ties_even()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "std"))]
|
||||||
|
#[allow(dead_code)]
|
||||||
|
fn round_ties_even_wrapper(x: f64) -> f64 {
|
||||||
|
if (x - x.floor()) == 0.5 {
|
||||||
|
(x * 0.5).round() * 2.0
|
||||||
|
} else {
|
||||||
|
x.round()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
|
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
|
||||||
for NdArray<E, I, Q>
|
for NdArray<E, I, Q>
|
||||||
{
|
{
|
||||||
|
@ -351,6 +367,34 @@ impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorO
|
||||||
NdArrayTensor::new(array)
|
NdArrayTensor::new(array)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn float_round(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
|
||||||
|
let array = tensor
|
||||||
|
.array
|
||||||
|
// .mapv_into(|a| (a.to_f64()).round_ties_even().elem())
|
||||||
|
.mapv_into(|a| round_ties_even_wrapper(a.to_f64()).elem())
|
||||||
|
.into_shared();
|
||||||
|
|
||||||
|
NdArrayTensor::new(array)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn float_floor(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
|
||||||
|
let array = tensor
|
||||||
|
.array
|
||||||
|
.mapv_into(|a| (a.to_f64()).floor().elem())
|
||||||
|
.into_shared();
|
||||||
|
|
||||||
|
NdArrayTensor::new(array)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn float_ceil(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
|
||||||
|
let array = tensor
|
||||||
|
.array
|
||||||
|
.mapv_into(|a| (a.to_f64()).ceil().elem())
|
||||||
|
.into_shared();
|
||||||
|
|
||||||
|
NdArrayTensor::new(array)
|
||||||
|
}
|
||||||
|
|
||||||
fn float_erf(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
|
fn float_erf(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
|
||||||
let array = tensor
|
let array = tensor
|
||||||
.array
|
.array
|
||||||
|
|
|
@ -1076,6 +1076,60 @@ impl<R: RunnerChannel> FloatTensorOps<Self> for BackendRouter<R> {
|
||||||
out
|
out
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||||
|
let client = tensor.client.clone();
|
||||||
|
let dtype = tensor.dtype;
|
||||||
|
let out = client.register_empty_tensor(tensor.shape.clone(), dtype);
|
||||||
|
|
||||||
|
let desc = UnaryOperationDescription {
|
||||||
|
input: tensor.into_description(),
|
||||||
|
out: out.to_description_out(),
|
||||||
|
};
|
||||||
|
|
||||||
|
client.register(OperationDescription::Float(
|
||||||
|
dtype,
|
||||||
|
FloatOperationDescription::Round(desc),
|
||||||
|
));
|
||||||
|
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||||
|
let client = tensor.client.clone();
|
||||||
|
let dtype = tensor.dtype;
|
||||||
|
let out = client.register_empty_tensor(tensor.shape.clone(), dtype);
|
||||||
|
|
||||||
|
let desc = UnaryOperationDescription {
|
||||||
|
input: tensor.into_description(),
|
||||||
|
out: out.to_description_out(),
|
||||||
|
};
|
||||||
|
|
||||||
|
client.register(OperationDescription::Float(
|
||||||
|
dtype,
|
||||||
|
FloatOperationDescription::Floor(desc),
|
||||||
|
));
|
||||||
|
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
|
fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||||
|
let client = tensor.client.clone();
|
||||||
|
let dtype = tensor.dtype;
|
||||||
|
let out = client.register_empty_tensor(tensor.shape.clone(), dtype);
|
||||||
|
|
||||||
|
let desc = UnaryOperationDescription {
|
||||||
|
input: tensor.into_description(),
|
||||||
|
out: out.to_description_out(),
|
||||||
|
};
|
||||||
|
|
||||||
|
client.register(OperationDescription::Float(
|
||||||
|
dtype,
|
||||||
|
FloatOperationDescription::Ceil(desc),
|
||||||
|
));
|
||||||
|
|
||||||
|
out
|
||||||
|
}
|
||||||
|
|
||||||
fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
|
||||||
let client = tensor.client.clone();
|
let client = tensor.client.clone();
|
||||||
let dtype = tensor.dtype;
|
let dtype = tensor.dtype;
|
||||||
|
|
|
@ -826,6 +826,15 @@ where
|
||||||
FloatOperationDescription::Tanh(desc) => {
|
FloatOperationDescription::Tanh(desc) => {
|
||||||
unary_float_ops!(handles, desc, B::float_tanh)
|
unary_float_ops!(handles, desc, B::float_tanh)
|
||||||
}
|
}
|
||||||
|
FloatOperationDescription::Round(desc) => {
|
||||||
|
unary_float_ops!(handles, desc, B::float_round)
|
||||||
|
}
|
||||||
|
FloatOperationDescription::Floor(desc) => {
|
||||||
|
unary_float_ops!(handles, desc, B::float_floor)
|
||||||
|
}
|
||||||
|
FloatOperationDescription::Ceil(desc) => {
|
||||||
|
unary_float_ops!(handles, desc, B::float_ceil)
|
||||||
|
}
|
||||||
FloatOperationDescription::IntoInt(desc) => {
|
FloatOperationDescription::IntoInt(desc) => {
|
||||||
let tensor = handles.get_float_tensor::<B>(&desc.input);
|
let tensor = handles.get_float_tensor::<B>(&desc.input);
|
||||||
|
|
||||||
|
|
|
@ -357,6 +357,18 @@ impl<E: TchElement, Q: QuantElement> FloatTensorOps<Self> for LibTorch<E, Q> {
|
||||||
tensor.unary_ops(|mut tensor| tensor.tanh_(), |tensor| tensor.tanh())
|
tensor.unary_ops(|mut tensor| tensor.tanh_(), |tensor| tensor.tanh())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn float_round(tensor: TchTensor<E>) -> TchTensor<E> {
|
||||||
|
tensor.unary_ops(|mut tensor| tensor.round_(), |tensor| tensor.round())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn float_floor(tensor: TchTensor<E>) -> TchTensor<E> {
|
||||||
|
tensor.unary_ops(|mut tensor| tensor.floor_(), |tensor| tensor.floor())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn float_ceil(tensor: TchTensor<E>) -> TchTensor<E> {
|
||||||
|
tensor.unary_ops(|mut tensor| tensor.ceil_(), |tensor| tensor.ceil())
|
||||||
|
}
|
||||||
|
|
||||||
fn float_erf(tensor: TchTensor<E>) -> TchTensor<E> {
|
fn float_erf(tensor: TchTensor<E>) -> TchTensor<E> {
|
||||||
tensor.unary_ops(|mut tensor| tensor.erf_(), |tensor| tensor.erf())
|
tensor.unary_ops(|mut tensor| tensor.erf_(), |tensor| tensor.erf())
|
||||||
}
|
}
|
||||||
|
|
|
@ -59,6 +59,12 @@ pub enum FloatOperationDescription {
|
||||||
Sin(UnaryOperationDescription),
|
Sin(UnaryOperationDescription),
|
||||||
/// Operation corresponding to [tanh](crate::ops::FloatTensorOps::float_tanh).
|
/// Operation corresponding to [tanh](crate::ops::FloatTensorOps::float_tanh).
|
||||||
Tanh(UnaryOperationDescription),
|
Tanh(UnaryOperationDescription),
|
||||||
|
/// Operation corresponding to [round](crate::ops::FloatTensorOps::float_round).
|
||||||
|
Round(UnaryOperationDescription),
|
||||||
|
/// Operation corresponding to [floor](crate::ops::FloatTensorOps::float_floor).
|
||||||
|
Floor(UnaryOperationDescription),
|
||||||
|
/// Operation corresponding to [ceil](crate::ops::FloatTensorOps::float_ceil).
|
||||||
|
Ceil(UnaryOperationDescription),
|
||||||
/// Operation corresponding to [into_int](crate::ops::FloatTensorOps::float_into_int).
|
/// Operation corresponding to [into_int](crate::ops::FloatTensorOps::float_into_int).
|
||||||
IntoInt(UnaryOperationDescription),
|
IntoInt(UnaryOperationDescription),
|
||||||
/// Operation corresponding to [matmul](crate::ops::FloatTensorOps::float_matmul).
|
/// Operation corresponding to [matmul](crate::ops::FloatTensorOps::float_matmul).
|
||||||
|
@ -1454,6 +1460,9 @@ impl FloatOperationDescription {
|
||||||
FloatOperationDescription::Cos(desc) => vec![&desc.input, &desc.out],
|
FloatOperationDescription::Cos(desc) => vec![&desc.input, &desc.out],
|
||||||
FloatOperationDescription::Sin(desc) => vec![&desc.input, &desc.out],
|
FloatOperationDescription::Sin(desc) => vec![&desc.input, &desc.out],
|
||||||
FloatOperationDescription::Tanh(desc) => vec![&desc.input, &desc.out],
|
FloatOperationDescription::Tanh(desc) => vec![&desc.input, &desc.out],
|
||||||
|
FloatOperationDescription::Round(desc) => vec![&desc.input, &desc.out],
|
||||||
|
FloatOperationDescription::Floor(desc) => vec![&desc.input, &desc.out],
|
||||||
|
FloatOperationDescription::Ceil(desc) => vec![&desc.input, &desc.out],
|
||||||
FloatOperationDescription::IntoInt(desc) => vec![&desc.input, &desc.out],
|
FloatOperationDescription::IntoInt(desc) => vec![&desc.input, &desc.out],
|
||||||
FloatOperationDescription::Quantize(desc) => {
|
FloatOperationDescription::Quantize(desc) => {
|
||||||
if let Some(offset) = &desc.qparams.offset {
|
if let Some(offset) = &desc.qparams.offset {
|
||||||
|
|
|
@ -104,6 +104,30 @@ where
|
||||||
)))
|
)))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Applies element wise round operation.
|
||||||
|
///
|
||||||
|
/// This function implements the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
|
||||||
|
/// strategy, with halfway cases rounded to the nearest integer value.
|
||||||
|
pub fn round(self) -> Self {
|
||||||
|
Self::new(TensorPrimitive::Float(B::float_round(
|
||||||
|
self.primitive.tensor(),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Applies element wise floor operation.
|
||||||
|
pub fn floor(self) -> Self {
|
||||||
|
Self::new(TensorPrimitive::Float(B::float_floor(
|
||||||
|
self.primitive.tensor(),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Applies element wise ceil operation.
|
||||||
|
pub fn ceil(self) -> Self {
|
||||||
|
Self::new(TensorPrimitive::Float(B::float_ceil(
|
||||||
|
self.primitive.tensor(),
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
/// Create a tensor from floats (f32) on a given device.
|
/// Create a tensor from floats (f32) on a given device.
|
||||||
///
|
///
|
||||||
/// # Example
|
/// # Example
|
||||||
|
|
|
@ -931,6 +931,42 @@ pub trait FloatTensorOps<B: Backend> {
|
||||||
/// A tensor with the same shape as `tensor` with tangent values.
|
/// A tensor with the same shape as `tensor` with tangent values.
|
||||||
fn float_tanh(tensor: FloatTensor<B>) -> FloatTensor<B>;
|
fn float_tanh(tensor: FloatTensor<B>) -> FloatTensor<B>;
|
||||||
|
|
||||||
|
/// Returns a new tensor with rounded values.
|
||||||
|
///
|
||||||
|
/// This function should implemented the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
|
||||||
|
/// strategy, with halfway cases rounded to the nearest integer value.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `tensor` - The tensor to be rounded.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// A tensor with the same shape as `tensor` with rounded values.
|
||||||
|
fn float_round(tensor: FloatTensor<B>) -> FloatTensor<B>;
|
||||||
|
|
||||||
|
/// Returns a new tensor with floored values.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `tensor` - The tensor to be floored.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// A tensor with the same shape as `tensor` with floored values.
|
||||||
|
fn float_floor(tensor: FloatTensor<B>) -> FloatTensor<B>;
|
||||||
|
|
||||||
|
/// Returns a new tensor with ceiled values.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
///
|
||||||
|
/// * `tensor` - The tensor to be ceiled.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
///
|
||||||
|
/// A tensor with the same shape as `tensor` with ceiled values.
|
||||||
|
fn float_ceil(tensor: FloatTensor<B>) -> FloatTensor<B>;
|
||||||
|
|
||||||
/// Returns a new tensor with the error function values.
|
/// Returns a new tensor with the error function values.
|
||||||
///
|
///
|
||||||
/// # Arguments
|
/// # Arguments
|
||||||
|
|
|
@ -106,6 +106,9 @@ macro_rules! testgen_all {
|
||||||
burn_tensor::testgen_remainder!();
|
burn_tensor::testgen_remainder!();
|
||||||
burn_tensor::testgen_cartesian_grid!();
|
burn_tensor::testgen_cartesian_grid!();
|
||||||
burn_tensor::testgen_nan!();
|
burn_tensor::testgen_nan!();
|
||||||
|
burn_tensor::testgen_round!();
|
||||||
|
burn_tensor::testgen_floor!();
|
||||||
|
burn_tensor::testgen_ceil!();
|
||||||
|
|
||||||
// test stats
|
// test stats
|
||||||
burn_tensor::testgen_var!();
|
burn_tensor::testgen_var!();
|
||||||
|
|
|
@ -0,0 +1,16 @@
|
||||||
|
#[burn_tensor_testgen::testgen(ceil)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use burn_tensor::{Tensor, TensorData};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_support_ceil_ops() {
|
||||||
|
let data = TensorData::from([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]);
|
||||||
|
let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
|
||||||
|
|
||||||
|
let output = tensor.ceil();
|
||||||
|
let expected = TensorData::from([[25., 88., 77.], [60., 44., 95.]]);
|
||||||
|
|
||||||
|
output.into_data().assert_approx_eq(&expected, 3);
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,16 @@
|
||||||
|
#[burn_tensor_testgen::testgen(floor)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use burn_tensor::{Tensor, TensorData};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_support_floor_ops() {
|
||||||
|
let data = TensorData::from([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]);
|
||||||
|
let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
|
||||||
|
|
||||||
|
let output = tensor.floor();
|
||||||
|
let expected = TensorData::from([[24., 87., 76.], [59., 43., 94.]]);
|
||||||
|
|
||||||
|
output.into_data().assert_approx_eq(&expected, 3);
|
||||||
|
}
|
||||||
|
}
|
|
@ -11,6 +11,7 @@ mod bool;
|
||||||
mod cartesian_grid;
|
mod cartesian_grid;
|
||||||
mod cast;
|
mod cast;
|
||||||
mod cat;
|
mod cat;
|
||||||
|
mod ceil;
|
||||||
mod chunk;
|
mod chunk;
|
||||||
mod clamp;
|
mod clamp;
|
||||||
mod close;
|
mod close;
|
||||||
|
@ -22,6 +23,7 @@ mod exp;
|
||||||
mod expand;
|
mod expand;
|
||||||
mod flatten;
|
mod flatten;
|
||||||
mod flip;
|
mod flip;
|
||||||
|
mod floor;
|
||||||
mod full;
|
mod full;
|
||||||
mod gather_scatter;
|
mod gather_scatter;
|
||||||
mod init;
|
mod init;
|
||||||
|
@ -48,6 +50,7 @@ mod remainder;
|
||||||
mod repeat;
|
mod repeat;
|
||||||
mod repeat_dim;
|
mod repeat_dim;
|
||||||
mod reshape;
|
mod reshape;
|
||||||
|
mod round;
|
||||||
mod select;
|
mod select;
|
||||||
mod sign;
|
mod sign;
|
||||||
mod sin;
|
mod sin;
|
||||||
|
|
|
@ -0,0 +1,24 @@
|
||||||
|
#[burn_tensor_testgen::testgen(round)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use burn_tensor::{Tensor, TensorData};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_support_round_ops() {
|
||||||
|
let data = TensorData::from([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]);
|
||||||
|
let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
|
||||||
|
|
||||||
|
let output = tensor.round();
|
||||||
|
let expected = TensorData::from([[24., 88., 76.], [60., 44., 95.]]);
|
||||||
|
|
||||||
|
output.into_data().assert_approx_eq(&expected, 3);
|
||||||
|
|
||||||
|
let data = TensorData::from([1.5, 2.5, 3.5, 4.5, 5.5, 6.5]);
|
||||||
|
let tensor = Tensor::<TestBackend, 1>::from_data(data, &Default::default());
|
||||||
|
|
||||||
|
let output = tensor.round();
|
||||||
|
let expected = TensorData::from([2., 2., 4., 4., 6., 6.]);
|
||||||
|
|
||||||
|
output.into_data().assert_approx_eq(&expected, 3);
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue