diff --git a/Cargo.lock b/Cargo.lock index 7ae6c0475..2c545aed8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -513,7 +513,7 @@ dependencies = [ name = "burn-common" version = "0.15.0" dependencies = [ - "cubecl-common", + "cubecl-common 0.2.0 (git+https://github.com/tracel-ai/cubecl?rev=7854b343975c990dd8bb1b4b68b3bc9bda488c1d)", "dashmap", "getrandom", "indicatif", @@ -1452,7 +1452,7 @@ version = "0.2.0" source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c" dependencies = [ "bytemuck", - "cubecl-common", + "cubecl-common 0.2.0 (git+https://github.com/tracel-ai/cubecl?rev=547c3fe249deae7d0efe13a47d368c34ce9a736c)", "cubecl-macros", "cubecl-runtime", "derive-new", @@ -1483,7 +1483,7 @@ version = "0.2.0" source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c" dependencies = [ "bytemuck", - "cubecl-common", + "cubecl-common 0.2.0 (git+https://github.com/tracel-ai/cubecl?rev=547c3fe249deae7d0efe13a47d368c34ce9a736c)", "cubecl-core", "cubecl-cpp", "cubecl-runtime", @@ -1534,7 +1534,7 @@ name = "cubecl-macros" version = "0.2.0" source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c" dependencies = [ - "cubecl-common", + "cubecl-common 0.2.0 (git+https://github.com/tracel-ai/cubecl?rev=547c3fe249deae7d0efe13a47d368c34ce9a736c)", "darling", "derive-new", "ident_case", @@ -1566,7 +1566,7 @@ source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd01 dependencies = [ "async-channel", "cfg_aliases 0.2.1", - "cubecl-common", + "cubecl-common 0.2.0 (git+https://github.com/tracel-ai/cubecl?rev=547c3fe249deae7d0efe13a47d368c34ce9a736c)", "derive-new", "dirs 5.0.1", "hashbrown 0.14.5", @@ -1601,7 +1601,7 @@ dependencies = [ "async-channel", "bytemuck", "cfg_aliases 0.2.1", - "cubecl-common", + "cubecl-common 0.2.0 (git+https://github.com/tracel-ai/cubecl?rev=547c3fe249deae7d0efe13a47d368c34ce9a736c)", "cubecl-core", "cubecl-runtime", "cubecl-spirv", diff --git a/burn-book/src/building-blocks/tensor.md b/burn-book/src/building-blocks/tensor.md index ad9a76ae7..5b81a9010 100644 --- a/burn-book/src/building-blocks/tensor.md +++ b/burn-book/src/building-blocks/tensor.md @@ -250,9 +250,11 @@ Those operations are only available for `Float` tensors. | Burn API | PyTorch Equivalent | | -------------------------------------------- | ---------------------------------- | +| `tensor.ceil()` | `tensor.ceil()` | | `tensor.cos()` | `tensor.cos()` | | `tensor.erf()` | `tensor.erf()` | | `tensor.exp()` | `tensor.exp()` | +| `tensor.floor()` | `tensor.floor()` | | `tensor.from_floats(floats, device)` | N/A | | `tensor.from_full_precision(tensor)` | N/A | | `tensor.int()` | Similar to `tensor.to(torch.long)` | @@ -264,6 +266,7 @@ Those operations are only available for `Float` tensors. | `tensor.random(shape, distribution, device)` | N/A | | `tensor.random_like(distribution)` | `torch.rand_like()` only uniform | | `tensor.recip()` | `tensor.reciprocal()` | +| `tensor.round()` | `tensor.round()` | | `tensor.sin()` | `tensor.sin()` | | `tensor.sqrt()` | `tensor.sqrt()` | | `tensor.swap_dims(dim1, dim2)` | `tensor.transpose(dim1, dim2)` | diff --git a/crates/burn-autodiff/src/ops/tensor.rs b/crates/burn-autodiff/src/ops/tensor.rs index 979aa3113..7b3ceb33e 100644 --- a/crates/burn-autodiff/src/ops/tensor.rs +++ b/crates/burn-autodiff/src/ops/tensor.rs @@ -1885,6 +1885,123 @@ impl FloatTensorOps for Autodiff } } + fn float_round(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Round; + retro_unary!(RetroRound, B::float_round); + + impl Backward for Round { + type State = (Shape, B::Device); + + fn backward( + self, + ops: Ops, + grads: &mut Gradients, + _checkpointer: &mut Checkpointer, + ) { + let (shape, device) = ops.state; + unary::(ops.parents, ops.node, grads, |_grad| { + B::float_zeros(shape, &device) + }) + } + } + + match Round + .prepare::([tensor.node.clone()]) + .memory_bound() + .retro_forward(RetroRound::::new(tensor.node.id)) + .parents([&tensor]) + .stateful() + { + OpsKind::Tracked(preps) => preps.finish( + ( + B::float_shape(&tensor.primitive), + B::float_device(&tensor.primitive), + ), + B::float_round(tensor.primitive), + ), + OpsKind::UnTracked(preps) => preps.finish(B::float_round(tensor.primitive)), + } + } + + fn float_floor(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Floor; + retro_unary!(RetroFloor, B::float_floor); + + impl Backward for Floor { + type State = (Shape, B::Device); + + fn backward( + self, + ops: Ops, + grads: &mut Gradients, + _checkpointer: &mut Checkpointer, + ) { + let (shape, device) = ops.state; + unary::(ops.parents, ops.node, grads, |_grad| { + B::float_zeros(shape, &device) + }) + } + } + + match Floor + .prepare::([tensor.node.clone()]) + .memory_bound() + .retro_forward(RetroFloor::::new(tensor.node.id)) + .parents([&tensor]) + .stateful() + { + OpsKind::Tracked(preps) => preps.finish( + ( + B::float_shape(&tensor.primitive), + B::float_device(&tensor.primitive), + ), + B::float_floor(tensor.primitive), + ), + OpsKind::UnTracked(preps) => preps.finish(B::float_floor(tensor.primitive)), + } + } + + fn float_ceil(tensor: FloatTensor) -> FloatTensor { + #[derive(Debug)] + struct Ceil; + retro_unary!(RetroCeil, B::float_ceil); + + impl Backward for Ceil { + type State = (Shape, B::Device); + + fn backward( + self, + ops: Ops, + grads: &mut Gradients, + _checkpointer: &mut Checkpointer, + ) { + let (shape, device) = ops.state; + unary::(ops.parents, ops.node, grads, |_grad| { + B::float_zeros(shape, &device) + }) + } + } + + match Ceil + .prepare::([tensor.node.clone()]) + .memory_bound() + .retro_forward(RetroCeil::::new(tensor.node.id)) + .parents([&tensor]) + .stateful() + { + OpsKind::Tracked(preps) => preps.finish( + ( + B::float_shape(&tensor.primitive), + B::float_device(&tensor.primitive), + ), + B::float_floor(tensor.primitive), + ), + OpsKind::UnTracked(preps) => preps.finish(B::float_floor(tensor.primitive)), + } + } + fn float_erf(tensor: FloatTensor) -> FloatTensor { #[derive(Debug)] struct Erf; diff --git a/crates/burn-autodiff/src/tests/ceil.rs b/crates/burn-autodiff/src/tests/ceil.rs new file mode 100644 index 000000000..6b9883c6e --- /dev/null +++ b/crates/burn-autodiff/src/tests/ceil.rs @@ -0,0 +1,25 @@ +#[burn_tensor_testgen::testgen(ad_ceil)] +mod tests { + use super::*; + use burn_tensor::TensorData; + + #[test] + fn should_diff_ceil() { + let data = TensorData::from([ + [-1.9751, 0.0714, 0.0643, 0.2406], + [-1.3172, 0.1252, -0.1119, -0.0127], + ]); + let device = Default::default(); + let tensor_1 = TestAutodiffTensor::<2>::from_data(data, &device).require_grad(); + let tensor_2 = tensor_1.clone().ceil(); + let grads = tensor_2.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let shape = grad_1.to_data().shape; + + grad_1.to_data().assert_eq( + &TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]), + false, + ); + } +} diff --git a/crates/burn-autodiff/src/tests/floor.rs b/crates/burn-autodiff/src/tests/floor.rs new file mode 100644 index 000000000..ca417a63f --- /dev/null +++ b/crates/burn-autodiff/src/tests/floor.rs @@ -0,0 +1,25 @@ +#[burn_tensor_testgen::testgen(ad_floor)] +mod tests { + use super::*; + use burn_tensor::TensorData; + + #[test] + fn should_diff_floor() { + let data = TensorData::from([ + [-1.9751, 0.0714, 0.0643, 0.2406], + [-1.3172, 0.1252, -0.1119, -0.0127], + ]); + let device = Default::default(); + let tensor_1 = TestAutodiffTensor::<2>::from_data(data, &device).require_grad(); + let tensor_2 = tensor_1.clone().floor(); + let grads = tensor_2.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + let shape = grad_1.to_data().shape; + + grad_1.to_data().assert_eq( + &TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]), + false, + ); + } +} diff --git a/crates/burn-autodiff/src/tests/mod.rs b/crates/burn-autodiff/src/tests/mod.rs index 84a0eee27..6fdafc1be 100644 --- a/crates/burn-autodiff/src/tests/mod.rs +++ b/crates/burn-autodiff/src/tests/mod.rs @@ -11,6 +11,7 @@ mod backward; mod bridge; mod broadcast; mod cat; +mod ceil; mod checkpoint; mod complex; mod conv1d; @@ -27,6 +28,7 @@ mod erf; mod exp; mod expand; mod flip; +mod floor; mod gather_scatter; mod gelu; mod gradients; @@ -50,6 +52,7 @@ mod recip; mod relu; mod repeat_dim; mod reshape; +mod round; mod select; mod sigmoid; mod sign; @@ -127,6 +130,9 @@ macro_rules! testgen_all { burn_autodiff::testgen_ad_abs!(); burn_autodiff::testgen_ad_sub!(); burn_autodiff::testgen_ad_tanh!(); + burn_autodiff::testgen_ad_round!(); + burn_autodiff::testgen_ad_floor!(); + burn_autodiff::testgen_ad_ceil!(); burn_autodiff::testgen_ad_sigmoid!(); burn_autodiff::testgen_ad_log_sigmoid!(); burn_autodiff::testgen_ad_transpose!(); diff --git a/crates/burn-autodiff/src/tests/round.rs b/crates/burn-autodiff/src/tests/round.rs new file mode 100644 index 000000000..338a35c3d --- /dev/null +++ b/crates/burn-autodiff/src/tests/round.rs @@ -0,0 +1,23 @@ +#[burn_tensor_testgen::testgen(ad_round)] +mod tests { + use super::*; + use burn_tensor::TensorData; + + #[test] + fn should_diff_round() { + let data = TensorData::from([ + [-1.9751, 0.0714, 0.0643, 0.2406], + [-1.3172, 0.1252, -0.1119, -0.0127], + ]); + let device = Default::default(); + let tensor_1 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad(); + let tensor_2 = tensor_1.clone().round(); + let grads = tensor_2.backward(); + + let grad_1 = tensor_1.grad(&grads).unwrap(); + grad_1.to_data().assert_eq( + &TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]), + false, + ); + } +} diff --git a/crates/burn-candle/src/lib.rs b/crates/burn-candle/src/lib.rs index d82440efa..8f0d496de 100644 --- a/crates/burn-candle/src/lib.rs +++ b/crates/burn-candle/src/lib.rs @@ -90,6 +90,9 @@ mod tests { burn_tensor::testgen_argwhere_nonzero!(); burn_tensor::testgen_sign!(); burn_tensor::testgen_nan!(); + burn_tensor::testgen_round!(); + burn_tensor::testgen_floor!(); + burn_tensor::testgen_ceil!(); // TODO: https://github.com/tracel-ai/burn/issues/1237 // @@ -165,4 +168,7 @@ mod tests { burn_autodiff::testgen_ad_tanh!(); burn_autodiff::testgen_ad_transpose!(); burn_autodiff::testgen_ad_expand!(); + burn_autodiff::testgen_ad_round!(); + burn_autodiff::testgen_ad_floor!(); + burn_autodiff::testgen_ad_ceil!(); } diff --git a/crates/burn-candle/src/ops/tensor.rs b/crates/burn-candle/src/ops/tensor.rs index 1471922ec..278bf1787 100644 --- a/crates/burn-candle/src/ops/tensor.rs +++ b/crates/burn-candle/src/ops/tensor.rs @@ -333,6 +333,36 @@ impl FloatTensorOps for Candle CandleTensor::new(tensor.tensor.tanh().unwrap()) } + fn float_round(tensor: FloatTensor) -> FloatTensor { + let inner = |tensor: FloatTensor| -> candle_core::Result> { + // implements round_to_even for consistent behavior vs libtorch + // https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/runtime/register_ops_utils.h#L65-L67 + + let floor_a = tensor.tensor.floor()?; + let frac_part = tensor.tensor.sub(&floor_a)?; + + let half = (candle_core::Tensor::ones_like(&tensor.tensor)? * 0.5)?; + let mask_half = frac_part.eq(&half)?; + let half_tensor = tensor.tensor.mul(&half)?; + let rounded_half = half_tensor.round()?; + let doubled = + rounded_half.mul(&(candle_core::Tensor::ones_like(&tensor.tensor)? * 2.0)?)?; + let standard_round = tensor.tensor.round()?; + Ok(CandleTensor::new( + mask_half.where_cond(&doubled, &standard_round)?, + )) + }; + inner(tensor).unwrap() + } + + fn float_floor(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.floor().unwrap()) + } + + fn float_ceil(tensor: FloatTensor) -> FloatTensor { + CandleTensor::new(tensor.tensor.ceil().unwrap()) + } + fn float_erf(tensor: FloatTensor) -> FloatTensor { CandleTensor::new(tensor.tensor.erf().unwrap()) } diff --git a/crates/burn-fusion/src/ops/float.rs b/crates/burn-fusion/src/ops/float.rs index 1be7cd884..4bcb2e411 100644 --- a/crates/burn-fusion/src/ops/float.rs +++ b/crates/burn-fusion/src/ops/float.rs @@ -2097,4 +2097,76 @@ impl FloatTensorOps for Fusion { out } + + fn float_round(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(RoundOps, B::float_round); + + let stream = tensor.stream; + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + OperationDescription::Float( + FloatElem::::dtype(), + FloatOperationDescription::Round(desc.clone()), + ), + RoundOps::::new(desc), + ); + + out + } + + fn float_floor(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(FloorOps, B::float_floor); + + let stream = tensor.stream; + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + OperationDescription::Float( + FloatElem::::dtype(), + FloatOperationDescription::Floor(desc.clone()), + ), + FloorOps::::new(desc), + ); + + out + } + + fn float_ceil(tensor: FloatTensor) -> FloatTensor { + unary_float_ops!(CeilOps, B::float_ceil); + + let stream = tensor.stream; + let out = tensor + .client + .tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype()); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + out.client.register( + vec![stream], + OperationDescription::Float( + FloatElem::::dtype(), + FloatOperationDescription::Ceil(desc.clone()), + ), + CeilOps::::new(desc), + ); + + out + } } diff --git a/crates/burn-fusion/src/stream/context.rs b/crates/burn-fusion/src/stream/context.rs index 28293c926..b6b8886db 100644 --- a/crates/burn-fusion/src/stream/context.rs +++ b/crates/burn-fusion/src/stream/context.rs @@ -520,6 +520,24 @@ impl RelativeOpsScalar for FloatOperationDescription { out: desc.out.to_relative(converter), }) } + FloatOperationDescription::Round(desc) => { + FloatOperationDescription::Round(UnaryOperationDescription { + input: desc.input.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + FloatOperationDescription::Floor(desc) => { + FloatOperationDescription::Floor(UnaryOperationDescription { + input: desc.input.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } + FloatOperationDescription::Ceil(desc) => { + FloatOperationDescription::Ceil(UnaryOperationDescription { + input: desc.input.to_relative(converter), + out: desc.out.to_relative(converter), + }) + } } } } diff --git a/crates/burn-jit/src/ops/float_ops.rs b/crates/burn-jit/src/ops/float_ops.rs index e7fa38f5a..5e3f2104c 100644 --- a/crates/burn-jit/src/ops/float_ops.rs +++ b/crates/burn-jit/src/ops/float_ops.rs @@ -333,6 +333,36 @@ where }) } + fn float_round(tensor: FloatTensor) -> FloatTensor { + unary_op!(float(tensor) => |context, tensor| { + #[cube] + fn execute(input: Line) -> Line { + Line::round(input) + } + execute::expand::(context, tensor) + }) + } + + fn float_floor(tensor: FloatTensor) -> FloatTensor { + unary_op!(float(tensor) => |context, tensor| { + #[cube] + fn execute(input: Line) -> Line { + Line::floor(input) + } + execute::expand::(context, tensor) + }) + } + + fn float_ceil(tensor: FloatTensor) -> FloatTensor { + unary_op!(float(tensor) => |context, tensor| { + #[cube] + fn execute(input: Line) -> Line { + Line::ceil(input) + } + execute::expand::(context, tensor) + }) + } + fn float_erf(tensor: FloatTensor) -> FloatTensor { unary_op!(float(tensor) => |context, tensor| { #[cube] diff --git a/crates/burn-ndarray/src/ops/tensor.rs b/crates/burn-ndarray/src/ops/tensor.rs index 49d038d7b..866e522f5 100644 --- a/crates/burn-ndarray/src/ops/tensor.rs +++ b/crates/burn-ndarray/src/ops/tensor.rs @@ -20,6 +20,22 @@ use num_traits::Float; use libm::erf; +#[cfg(feature = "std")] +#[allow(dead_code)] +fn round_ties_even_wrapper(x: f64) -> f64 { + x.round_ties_even() +} + +#[cfg(not(feature = "std"))] +#[allow(dead_code)] +fn round_ties_even_wrapper(x: f64) -> f64 { + if (x - x.floor()) == 0.5 { + (x * 0.5).round() * 2.0 + } else { + x.round() + } +} + impl FloatTensorOps for NdArray { @@ -351,6 +367,34 @@ impl FloatTensorO NdArrayTensor::new(array) } + fn float_round(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor + .array + // .mapv_into(|a| (a.to_f64()).round_ties_even().elem()) + .mapv_into(|a| round_ties_even_wrapper(a.to_f64()).elem()) + .into_shared(); + + NdArrayTensor::new(array) + } + + fn float_floor(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor + .array + .mapv_into(|a| (a.to_f64()).floor().elem()) + .into_shared(); + + NdArrayTensor::new(array) + } + + fn float_ceil(tensor: NdArrayTensor) -> NdArrayTensor { + let array = tensor + .array + .mapv_into(|a| (a.to_f64()).ceil().elem()) + .into_shared(); + + NdArrayTensor::new(array) + } + fn float_erf(tensor: NdArrayTensor) -> NdArrayTensor { let array = tensor .array diff --git a/crates/burn-router/src/ops/op_float.rs b/crates/burn-router/src/ops/op_float.rs index 1859e525d..76107a104 100644 --- a/crates/burn-router/src/ops/op_float.rs +++ b/crates/burn-router/src/ops/op_float.rs @@ -1076,6 +1076,60 @@ impl FloatTensorOps for BackendRouter { out } + fn float_round(tensor: FloatTensor) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Float( + dtype, + FloatOperationDescription::Round(desc), + )); + + out + } + + fn float_floor(tensor: FloatTensor) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Float( + dtype, + FloatOperationDescription::Floor(desc), + )); + + out + } + + fn float_ceil(tensor: FloatTensor) -> FloatTensor { + let client = tensor.client.clone(); + let dtype = tensor.dtype; + let out = client.register_empty_tensor(tensor.shape.clone(), dtype); + + let desc = UnaryOperationDescription { + input: tensor.into_description(), + out: out.to_description_out(), + }; + + client.register(OperationDescription::Float( + dtype, + FloatOperationDescription::Ceil(desc), + )); + + out + } + fn float_recip(tensor: FloatTensor) -> FloatTensor { let client = tensor.client.clone(); let dtype = tensor.dtype; diff --git a/crates/burn-router/src/runner.rs b/crates/burn-router/src/runner.rs index 87aa5b358..2f408155b 100644 --- a/crates/burn-router/src/runner.rs +++ b/crates/burn-router/src/runner.rs @@ -826,6 +826,15 @@ where FloatOperationDescription::Tanh(desc) => { unary_float_ops!(handles, desc, B::float_tanh) } + FloatOperationDescription::Round(desc) => { + unary_float_ops!(handles, desc, B::float_round) + } + FloatOperationDescription::Floor(desc) => { + unary_float_ops!(handles, desc, B::float_floor) + } + FloatOperationDescription::Ceil(desc) => { + unary_float_ops!(handles, desc, B::float_ceil) + } FloatOperationDescription::IntoInt(desc) => { let tensor = handles.get_float_tensor::(&desc.input); diff --git a/crates/burn-tch/src/ops/tensor.rs b/crates/burn-tch/src/ops/tensor.rs index ab5af6303..656b4182c 100644 --- a/crates/burn-tch/src/ops/tensor.rs +++ b/crates/burn-tch/src/ops/tensor.rs @@ -357,6 +357,18 @@ impl FloatTensorOps for LibTorch { tensor.unary_ops(|mut tensor| tensor.tanh_(), |tensor| tensor.tanh()) } + fn float_round(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.round_(), |tensor| tensor.round()) + } + + fn float_floor(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.floor_(), |tensor| tensor.floor()) + } + + fn float_ceil(tensor: TchTensor) -> TchTensor { + tensor.unary_ops(|mut tensor| tensor.ceil_(), |tensor| tensor.ceil()) + } + fn float_erf(tensor: TchTensor) -> TchTensor { tensor.unary_ops(|mut tensor| tensor.erf_(), |tensor| tensor.erf()) } diff --git a/crates/burn-tensor/src/repr/operation.rs b/crates/burn-tensor/src/repr/operation.rs index 9102b4ac7..7304134de 100644 --- a/crates/burn-tensor/src/repr/operation.rs +++ b/crates/burn-tensor/src/repr/operation.rs @@ -59,6 +59,12 @@ pub enum FloatOperationDescription { Sin(UnaryOperationDescription), /// Operation corresponding to [tanh](crate::ops::FloatTensorOps::float_tanh). Tanh(UnaryOperationDescription), + /// Operation corresponding to [round](crate::ops::FloatTensorOps::float_round). + Round(UnaryOperationDescription), + /// Operation corresponding to [floor](crate::ops::FloatTensorOps::float_floor). + Floor(UnaryOperationDescription), + /// Operation corresponding to [ceil](crate::ops::FloatTensorOps::float_ceil). + Ceil(UnaryOperationDescription), /// Operation corresponding to [into_int](crate::ops::FloatTensorOps::float_into_int). IntoInt(UnaryOperationDescription), /// Operation corresponding to [matmul](crate::ops::FloatTensorOps::float_matmul). @@ -1454,6 +1460,9 @@ impl FloatOperationDescription { FloatOperationDescription::Cos(desc) => vec![&desc.input, &desc.out], FloatOperationDescription::Sin(desc) => vec![&desc.input, &desc.out], FloatOperationDescription::Tanh(desc) => vec![&desc.input, &desc.out], + FloatOperationDescription::Round(desc) => vec![&desc.input, &desc.out], + FloatOperationDescription::Floor(desc) => vec![&desc.input, &desc.out], + FloatOperationDescription::Ceil(desc) => vec![&desc.input, &desc.out], FloatOperationDescription::IntoInt(desc) => vec![&desc.input, &desc.out], FloatOperationDescription::Quantize(desc) => { if let Some(offset) = &desc.qparams.offset { diff --git a/crates/burn-tensor/src/tensor/api/float.rs b/crates/burn-tensor/src/tensor/api/float.rs index 7d6d4c20d..1ba8506bd 100644 --- a/crates/burn-tensor/src/tensor/api/float.rs +++ b/crates/burn-tensor/src/tensor/api/float.rs @@ -104,6 +104,30 @@ where ))) } + /// Applies element wise round operation. + /// + /// This function implements the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even) + /// strategy, with halfway cases rounded to the nearest integer value. + pub fn round(self) -> Self { + Self::new(TensorPrimitive::Float(B::float_round( + self.primitive.tensor(), + ))) + } + + /// Applies element wise floor operation. + pub fn floor(self) -> Self { + Self::new(TensorPrimitive::Float(B::float_floor( + self.primitive.tensor(), + ))) + } + + /// Applies element wise ceil operation. + pub fn ceil(self) -> Self { + Self::new(TensorPrimitive::Float(B::float_ceil( + self.primitive.tensor(), + ))) + } + /// Create a tensor from floats (f32) on a given device. /// /// # Example diff --git a/crates/burn-tensor/src/tensor/ops/tensor.rs b/crates/burn-tensor/src/tensor/ops/tensor.rs index d33fb694e..8b4f4b241 100644 --- a/crates/burn-tensor/src/tensor/ops/tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/tensor.rs @@ -931,6 +931,42 @@ pub trait FloatTensorOps { /// A tensor with the same shape as `tensor` with tangent values. fn float_tanh(tensor: FloatTensor) -> FloatTensor; + /// Returns a new tensor with rounded values. + /// + /// This function should implemented the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even) + /// strategy, with halfway cases rounded to the nearest integer value. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to be rounded. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with rounded values. + fn float_round(tensor: FloatTensor) -> FloatTensor; + + /// Returns a new tensor with floored values. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to be floored. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with floored values. + fn float_floor(tensor: FloatTensor) -> FloatTensor; + + /// Returns a new tensor with ceiled values. + /// + /// # Arguments + /// + /// * `tensor` - The tensor to be ceiled. + /// + /// # Returns + /// + /// A tensor with the same shape as `tensor` with ceiled values. + fn float_ceil(tensor: FloatTensor) -> FloatTensor; + /// Returns a new tensor with the error function values. /// /// # Arguments diff --git a/crates/burn-tensor/src/tests/mod.rs b/crates/burn-tensor/src/tests/mod.rs index e5a0acc80..a4834049d 100644 --- a/crates/burn-tensor/src/tests/mod.rs +++ b/crates/burn-tensor/src/tests/mod.rs @@ -106,6 +106,9 @@ macro_rules! testgen_all { burn_tensor::testgen_remainder!(); burn_tensor::testgen_cartesian_grid!(); burn_tensor::testgen_nan!(); + burn_tensor::testgen_round!(); + burn_tensor::testgen_floor!(); + burn_tensor::testgen_ceil!(); // test stats burn_tensor::testgen_var!(); diff --git a/crates/burn-tensor/src/tests/ops/ceil.rs b/crates/burn-tensor/src/tests/ops/ceil.rs new file mode 100644 index 000000000..7d5e53455 --- /dev/null +++ b/crates/burn-tensor/src/tests/ops/ceil.rs @@ -0,0 +1,16 @@ +#[burn_tensor_testgen::testgen(ceil)] +mod tests { + use super::*; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn should_support_ceil_ops() { + let data = TensorData::from([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]); + let tensor = Tensor::::from_data(data, &Default::default()); + + let output = tensor.ceil(); + let expected = TensorData::from([[25., 88., 77.], [60., 44., 95.]]); + + output.into_data().assert_approx_eq(&expected, 3); + } +} diff --git a/crates/burn-tensor/src/tests/ops/floor.rs b/crates/burn-tensor/src/tests/ops/floor.rs new file mode 100644 index 000000000..5913244e6 --- /dev/null +++ b/crates/burn-tensor/src/tests/ops/floor.rs @@ -0,0 +1,16 @@ +#[burn_tensor_testgen::testgen(floor)] +mod tests { + use super::*; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn should_support_floor_ops() { + let data = TensorData::from([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]); + let tensor = Tensor::::from_data(data, &Default::default()); + + let output = tensor.floor(); + let expected = TensorData::from([[24., 87., 76.], [59., 43., 94.]]); + + output.into_data().assert_approx_eq(&expected, 3); + } +} diff --git a/crates/burn-tensor/src/tests/ops/mod.rs b/crates/burn-tensor/src/tests/ops/mod.rs index 533ca7e71..01bf61225 100644 --- a/crates/burn-tensor/src/tests/ops/mod.rs +++ b/crates/burn-tensor/src/tests/ops/mod.rs @@ -11,6 +11,7 @@ mod bool; mod cartesian_grid; mod cast; mod cat; +mod ceil; mod chunk; mod clamp; mod close; @@ -22,6 +23,7 @@ mod exp; mod expand; mod flatten; mod flip; +mod floor; mod full; mod gather_scatter; mod init; @@ -48,6 +50,7 @@ mod remainder; mod repeat; mod repeat_dim; mod reshape; +mod round; mod select; mod sign; mod sin; diff --git a/crates/burn-tensor/src/tests/ops/round.rs b/crates/burn-tensor/src/tests/ops/round.rs new file mode 100644 index 000000000..01e108951 --- /dev/null +++ b/crates/burn-tensor/src/tests/ops/round.rs @@ -0,0 +1,24 @@ +#[burn_tensor_testgen::testgen(round)] +mod tests { + use super::*; + use burn_tensor::{Tensor, TensorData}; + + #[test] + fn should_support_round_ops() { + let data = TensorData::from([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]); + let tensor = Tensor::::from_data(data, &Default::default()); + + let output = tensor.round(); + let expected = TensorData::from([[24., 88., 76.], [60., 44., 95.]]); + + output.into_data().assert_approx_eq(&expected, 3); + + let data = TensorData::from([1.5, 2.5, 3.5, 4.5, 5.5, 6.5]); + let tensor = Tensor::::from_data(data, &Default::default()); + + let output = tensor.round(); + let expected = TensorData::from([2., 2., 4., 4., 6., 6.]); + + output.into_data().assert_approx_eq(&expected, 3); + } +}