Add `round`, `floor`, `ceil` for float tensor (#2372)

* Add round, floor, ceil into FloatTensorOps trait; Impl round, floor, ceil for tensor, tch, ndarray, candle; Add tests for autodiff

* Add test for round, floor, ceil in burn tensor backend

* Add test for round, floor, ceil in burn candle backend

* Impl round, floor, ceil for burn fusion backend

* Update burn book

* Fix round gradient != 0 issue

* Add tests for halfway cases

* Use `round_ties_even` for `round` in ndarray backend

* Impl round to even for candle

* Add round, floor, ceil for burn router

* Add round, floor, ceil for jit backend; Upgrade cubecl

* Add round_ties_even for no-std in ndarray backend

* Be explicit about what rounding strategy is used

---------

Co-authored-by: Guillaume Lagrange <lagrange.guillaume.1@gmail.com>
med1844 2024-10-21 09:18:43 -07:00 committed by GitHub
parent f3968cbeac
commit 6515a081aa
24 changed files with 611 additions and 6 deletions

Cargo.lock (generated)

@@ -513,7 +513,7 @@ dependencies = [
name = "burn-common"
version = "0.15.0"
dependencies = [
- "cubecl-common",
+ "cubecl-common 0.2.0 (git+https://github.com/tracel-ai/cubecl?rev=7854b343975c990dd8bb1b4b68b3bc9bda488c1d)",
"dashmap",
"getrandom",
"indicatif",
@@ -1452,7 +1452,7 @@ version = "0.2.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
dependencies = [
"bytemuck",
- "cubecl-common",
+ "cubecl-common 0.2.0 (git+https://github.com/tracel-ai/cubecl?rev=547c3fe249deae7d0efe13a47d368c34ce9a736c)",
"cubecl-macros",
"cubecl-runtime",
"derive-new",
@@ -1483,7 +1483,7 @@ version = "0.2.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
dependencies = [
"bytemuck",
- "cubecl-common",
+ "cubecl-common 0.2.0 (git+https://github.com/tracel-ai/cubecl?rev=547c3fe249deae7d0efe13a47d368c34ce9a736c)",
"cubecl-core",
"cubecl-cpp",
"cubecl-runtime",
@@ -1534,7 +1534,7 @@ name = "cubecl-macros"
version = "0.2.0"
source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd011907b4e662388c#e3fdc96ec2d68dcdde8135bd011907b4e662388c"
dependencies = [
- "cubecl-common",
+ "cubecl-common 0.2.0 (git+https://github.com/tracel-ai/cubecl?rev=547c3fe249deae7d0efe13a47d368c34ce9a736c)",
"darling",
"derive-new",
"ident_case",
@@ -1566,7 +1566,7 @@ source = "git+https://github.com/tracel-ai/cubecl?rev=e3fdc96ec2d68dcdde8135bd01
dependencies = [
"async-channel",
"cfg_aliases 0.2.1",
- "cubecl-common",
+ "cubecl-common 0.2.0 (git+https://github.com/tracel-ai/cubecl?rev=547c3fe249deae7d0efe13a47d368c34ce9a736c)",
"derive-new",
"dirs 5.0.1",
"hashbrown 0.14.5",
@@ -1601,7 +1601,7 @@ dependencies = [
"async-channel",
"bytemuck",
"cfg_aliases 0.2.1",
- "cubecl-common",
+ "cubecl-common 0.2.0 (git+https://github.com/tracel-ai/cubecl?rev=547c3fe249deae7d0efe13a47d368c34ce9a736c)",
"cubecl-core",
"cubecl-runtime",
"cubecl-spirv",


@@ -250,9 +250,11 @@ Those operations are only available for `Float` tensors.
| Burn API                                     | PyTorch Equivalent                 |
| -------------------------------------------- | ---------------------------------- |
| `tensor.ceil()`                              | `tensor.ceil()`                    |
| `tensor.cos()`                               | `tensor.cos()`                     |
| `tensor.erf()`                               | `tensor.erf()`                     |
| `tensor.exp()`                               | `tensor.exp()`                     |
| `tensor.floor()`                             | `tensor.floor()`                   |
| `tensor.from_floats(floats, device)`         | N/A                                |
| `tensor.from_full_precision(tensor)`         | N/A                                |
| `tensor.int()`                               | Similar to `tensor.to(torch.long)` |
@@ -264,6 +266,7 @@ Those operations are only available for `Float` tensors.
| `tensor.random(shape, distribution, device)` | N/A                                |
| `tensor.random_like(distribution)`           | `torch.rand_like()` only uniform   |
| `tensor.recip()`                             | `tensor.reciprocal()`              |
| `tensor.round()`                             | `tensor.round()`                   |
| `tensor.sin()`                               | `tensor.sin()`                     |
| `tensor.sqrt()`                              | `tensor.sqrt()`                    |
| `tensor.swap_dims(dim1, dim2)`               | `tensor.transpose(dim1, dim2)`     |
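For context, a minimal usage sketch of the three new entries (assuming the `ndarray` backend feature is enabled; the output comments follow the round-half-to-even strategy used by `round`):

```rust
use burn::backend::NdArray;
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    let t = Tensor::<NdArray, 1>::from_floats([1.5, 2.5, 0.4, 0.7], &device);

    // Halfway cases round to the nearest even integer: 1.5 -> 2.0, 2.5 -> 2.0.
    println!("{}", t.clone().round()); // [2.0, 2.0, 0.0, 1.0]
    println!("{}", t.clone().floor()); // [1.0, 2.0, 0.0, 0.0]
    println!("{}", t.ceil());          // [2.0, 3.0, 1.0, 1.0]
}
```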


@@ -1885,6 +1885,123 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
}
}
fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
#[derive(Debug)]
struct Round;
retro_unary!(RetroRound, B::float_round);
impl<B: Backend> Backward<B, 1> for Round {
type State = (Shape, B::Device);
fn backward(
self,
ops: Ops<Self::State, 1>,
grads: &mut Gradients,
_checkpointer: &mut Checkpointer,
) {
let (shape, device) = ops.state;
unary::<B, _>(ops.parents, ops.node, grads, |_grad| {
B::float_zeros(shape, &device)
})
}
}
match Round
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroRound::<B>::new(tensor.node.id))
.parents([&tensor])
.stateful()
{
OpsKind::Tracked(preps) => preps.finish(
(
B::float_shape(&tensor.primitive),
B::float_device(&tensor.primitive),
),
B::float_round(tensor.primitive),
),
OpsKind::UnTracked(preps) => preps.finish(B::float_round(tensor.primitive)),
}
}
fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
#[derive(Debug)]
struct Floor;
retro_unary!(RetroFloor, B::float_floor);
impl<B: Backend> Backward<B, 1> for Floor {
type State = (Shape, B::Device);
fn backward(
self,
ops: Ops<Self::State, 1>,
grads: &mut Gradients,
_checkpointer: &mut Checkpointer,
) {
let (shape, device) = ops.state;
unary::<B, _>(ops.parents, ops.node, grads, |_grad| {
B::float_zeros(shape, &device)
})
}
}
match Floor
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroFloor::<B>::new(tensor.node.id))
.parents([&tensor])
.stateful()
{
OpsKind::Tracked(preps) => preps.finish(
(
B::float_shape(&tensor.primitive),
B::float_device(&tensor.primitive),
),
B::float_floor(tensor.primitive),
),
OpsKind::UnTracked(preps) => preps.finish(B::float_floor(tensor.primitive)),
}
}
fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
#[derive(Debug)]
struct Ceil;
retro_unary!(RetroCeil, B::float_ceil);
impl<B: Backend> Backward<B, 1> for Ceil {
type State = (Shape, B::Device);
fn backward(
self,
ops: Ops<Self::State, 1>,
grads: &mut Gradients,
_checkpointer: &mut Checkpointer,
) {
let (shape, device) = ops.state;
unary::<B, _>(ops.parents, ops.node, grads, |_grad| {
B::float_zeros(shape, &device)
})
}
}
match Ceil
.prepare::<C>([tensor.node.clone()])
.memory_bound()
.retro_forward(RetroCeil::<B>::new(tensor.node.id))
.parents([&tensor])
.stateful()
{
OpsKind::Tracked(preps) => preps.finish(
(
B::float_shape(&tensor.primitive),
B::float_device(&tensor.primitive),
),
B::float_ceil(tensor.primitive),
),
OpsKind::UnTracked(preps) => preps.finish(B::float_ceil(tensor.primitive)),
}
}
fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
#[derive(Debug)]
struct Erf;
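Round, floor, and ceil are piecewise constant, so their derivative is zero almost everywhere; that is why each `backward` above writes `float_zeros` into the gradients instead of propagating `_grad`. A minimal sketch of the observable behavior (assuming the `Autodiff<NdArray>` backend):

```rust
use burn::backend::{Autodiff, NdArray};
use burn::tensor::Tensor;

fn main() {
    let device = Default::default();
    let x = Tensor::<Autodiff<NdArray>, 1>::from_floats([0.4, 1.6], &device)
        .require_grad();

    // Forward applies rounding; backward yields an all-zero gradient.
    let grads = x.clone().round().backward();
    let grad = x.grad(&grads).unwrap();
    println!("{grad}"); // [0.0, 0.0]
}
```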


@@ -0,0 +1,25 @@
#[burn_tensor_testgen::testgen(ad_ceil)]
mod tests {
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_ceil() {
let data = TensorData::from([
[-1.9751, 0.0714, 0.0643, 0.2406],
[-1.3172, 0.1252, -0.1119, -0.0127],
]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();
let tensor_2 = tensor_1.clone().ceil();
let grads = tensor_2.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
grad_1.to_data().assert_eq(
&TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),
false,
);
}
}


@@ -0,0 +1,25 @@
#[burn_tensor_testgen::testgen(ad_floor)]
mod tests {
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_floor() {
let data = TensorData::from([
[-1.9751, 0.0714, 0.0643, 0.2406],
[-1.3172, 0.1252, -0.1119, -0.0127],
]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data, &device).require_grad();
let tensor_2 = tensor_1.clone().floor();
let grads = tensor_2.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
grad_1.to_data().assert_eq(
&TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),
false,
);
}
}


@@ -11,6 +11,7 @@ mod backward;
mod bridge;
mod broadcast;
mod cat;
mod ceil;
mod checkpoint;
mod complex;
mod conv1d;
@@ -27,6 +28,7 @@ mod erf;
mod exp;
mod expand;
mod flip;
mod floor;
mod gather_scatter;
mod gelu;
mod gradients;
@@ -50,6 +52,7 @@ mod recip;
mod relu;
mod repeat_dim;
mod reshape;
mod round;
mod select;
mod sigmoid;
mod sign;
@@ -127,6 +130,9 @@ macro_rules! testgen_all {
burn_autodiff::testgen_ad_abs!();
burn_autodiff::testgen_ad_sub!();
burn_autodiff::testgen_ad_tanh!();
burn_autodiff::testgen_ad_round!();
burn_autodiff::testgen_ad_floor!();
burn_autodiff::testgen_ad_ceil!();
burn_autodiff::testgen_ad_sigmoid!();
burn_autodiff::testgen_ad_log_sigmoid!();
burn_autodiff::testgen_ad_transpose!();


@@ -0,0 +1,23 @@
#[burn_tensor_testgen::testgen(ad_round)]
mod tests {
use super::*;
use burn_tensor::TensorData;
#[test]
fn should_diff_round() {
let data = TensorData::from([
[-1.9751, 0.0714, 0.0643, 0.2406],
[-1.3172, 0.1252, -0.1119, -0.0127],
]);
let device = Default::default();
let tensor_1 = TestAutodiffTensor::<2>::from_data(data.clone(), &device).require_grad();
let tensor_2 = tensor_1.clone().round();
let grads = tensor_2.backward();
let grad_1 = tensor_1.grad(&grads).unwrap();
grad_1.to_data().assert_eq(
&TensorData::from([[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),
false,
);
}
}


@@ -90,6 +90,9 @@ mod tests {
burn_tensor::testgen_argwhere_nonzero!();
burn_tensor::testgen_sign!();
burn_tensor::testgen_nan!();
burn_tensor::testgen_round!();
burn_tensor::testgen_floor!();
burn_tensor::testgen_ceil!();
// TODO: https://github.com/tracel-ai/burn/issues/1237
//
@@ -165,4 +168,7 @@ mod tests {
burn_autodiff::testgen_ad_tanh!();
burn_autodiff::testgen_ad_transpose!();
burn_autodiff::testgen_ad_expand!();
burn_autodiff::testgen_ad_round!();
burn_autodiff::testgen_ad_floor!();
burn_autodiff::testgen_ad_ceil!();
}


@@ -333,6 +333,36 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
CandleTensor::new(tensor.tensor.tanh().unwrap())
}
fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
let inner = |tensor: FloatTensor<Self>| -> candle_core::Result<FloatTensor<Self>> {
// implements round_to_even for consistent behavior vs libtorch
// https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/runtime/register_ops_utils.h#L65-L67
let floor_a = tensor.tensor.floor()?;
let frac_part = tensor.tensor.sub(&floor_a)?;
let half = (candle_core::Tensor::ones_like(&tensor.tensor)? * 0.5)?;
let mask_half = frac_part.eq(&half)?;
let half_tensor = tensor.tensor.mul(&half)?;
let rounded_half = half_tensor.round()?;
let doubled =
rounded_half.mul(&(candle_core::Tensor::ones_like(&tensor.tensor)? * 2.0)?)?;
let standard_round = tensor.tensor.round()?;
Ok(CandleTensor::new(
mask_half.where_cond(&doubled, &standard_round)?,
))
};
inner(tensor).unwrap()
}
fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
CandleTensor::new(tensor.tensor.floor().unwrap())
}
fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
CandleTensor::new(tensor.tensor.ceil().unwrap())
}
fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
CandleTensor::new(tensor.tensor.erf().unwrap())
}
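The masking above works around candle's default `round`, which (as the comment implies) does not tie to even the way libtorch does. A scalar sketch of the same trick in plain Rust (a hypothetical helper, not part of this diff):

```rust
// For exact .5 fractions, round(x * 0.5) * 2.0 lands on the nearest even
// integer; everything else falls through to the standard `round`.
fn round_half_to_even(x: f64) -> f64 {
    if x - x.floor() == 0.5 {
        (x * 0.5).round() * 2.0
    } else {
        x.round()
    }
}

fn main() {
    assert_eq!(round_half_to_even(2.5), 2.0);
    assert_eq!(round_half_to_even(3.5), 4.0);
    assert_eq!(round_half_to_even(-2.5), -2.0);
    assert_eq!(round_half_to_even(0.3), 0.0);
}
```

This is also the shape of the no-std fallback the ndarray backend adds further down.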


@@ -2097,4 +2097,76 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
out
}
fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float_ops!(RoundOps, B::float_round);
let stream = tensor.stream;
let out = tensor
.client
.tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype());
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::Float(
FloatElem::<Self>::dtype(),
FloatOperationDescription::Round(desc.clone()),
),
RoundOps::<B>::new(desc),
);
out
}
fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float_ops!(FloorOps, B::float_floor);
let stream = tensor.stream;
let out = tensor
.client
.tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype());
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::Float(
FloatElem::<Self>::dtype(),
FloatOperationDescription::Floor(desc.clone()),
),
FloorOps::<B>::new(desc),
);
out
}
fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_float_ops!(CeilOps, B::float_ceil);
let stream = tensor.stream;
let out = tensor
.client
.tensor_uninitialized(tensor.shape.clone(), B::FloatElem::dtype());
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
out.client.register(
vec![stream],
OperationDescription::Float(
FloatElem::<Self>::dtype(),
FloatOperationDescription::Ceil(desc.clone()),
),
CeilOps::<B>::new(desc),
);
out
}
}


@@ -520,6 +520,24 @@ impl RelativeOpsScalar<f32> for FloatOperationDescription {
out: desc.out.to_relative(converter),
})
}
FloatOperationDescription::Round(desc) => {
FloatOperationDescription::Round(UnaryOperationDescription {
input: desc.input.to_relative(converter),
out: desc.out.to_relative(converter),
})
}
FloatOperationDescription::Floor(desc) => {
FloatOperationDescription::Floor(UnaryOperationDescription {
input: desc.input.to_relative(converter),
out: desc.out.to_relative(converter),
})
}
FloatOperationDescription::Ceil(desc) => {
FloatOperationDescription::Ceil(UnaryOperationDescription {
input: desc.input.to_relative(converter),
out: desc.out.to_relative(converter),
})
}
}
}
}


@@ -333,6 +333,36 @@ where
})
}
fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_op!(float(tensor) => |context, tensor| {
#[cube]
fn execute<C: Float>(input: Line<C>) -> Line<C> {
Line::round(input)
}
execute::expand::<C>(context, tensor)
})
}
fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_op!(float(tensor) => |context, tensor| {
#[cube]
fn execute<C: Float>(input: Line<C>) -> Line<C> {
Line::floor(input)
}
execute::expand::<C>(context, tensor)
})
}
fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_op!(float(tensor) => |context, tensor| {
#[cube]
fn execute<C: Float>(input: Line<C>) -> Line<C> {
Line::ceil(input)
}
execute::expand::<C>(context, tensor)
})
}
fn float_erf(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
unary_op!(float(tensor) => |context, tensor| {
#[cube]


@@ -20,6 +20,22 @@ use num_traits::Float;
use libm::erf;
#[cfg(feature = "std")]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
x.round_ties_even()
}
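// `f64::round_ties_even` is std-only; the fallback below halves, rounds,
// and doubles exact .5 fractions so ties land on the even neighbor, and
// defers to plain `round` (via `num_traits::Float`) otherwise.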
#[cfg(not(feature = "std"))]
#[allow(dead_code)]
fn round_ties_even_wrapper(x: f64) -> f64 {
if (x - x.floor()) == 0.5 {
(x * 0.5).round() * 2.0
} else {
x.round()
}
}
impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorOps<Self>
for NdArray<E, I, Q>
{
@@ -351,6 +367,34 @@ impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorO
NdArrayTensor::new(array)
}
fn float_round(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
let array = tensor
.array
.mapv_into(|a| round_ties_even_wrapper(a.to_f64()).elem())
.into_shared();
NdArrayTensor::new(array)
}
fn float_floor(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
let array = tensor
.array
.mapv_into(|a| (a.to_f64()).floor().elem())
.into_shared();
NdArrayTensor::new(array)
}
fn float_ceil(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
let array = tensor
.array
.mapv_into(|a| (a.to_f64()).ceil().elem())
.into_shared();
NdArrayTensor::new(array)
}
fn float_erf(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
let array = tensor
.array


@@ -1076,6 +1076,60 @@ impl<R: RunnerChannel> FloatTensorOps<Self> for BackendRouter<R> {
out
}
fn float_round(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
let client = tensor.client.clone();
let dtype = tensor.dtype;
let out = client.register_empty_tensor(tensor.shape.clone(), dtype);
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
client.register(OperationDescription::Float(
dtype,
FloatOperationDescription::Round(desc),
));
out
}
fn float_floor(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
let client = tensor.client.clone();
let dtype = tensor.dtype;
let out = client.register_empty_tensor(tensor.shape.clone(), dtype);
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
client.register(OperationDescription::Float(
dtype,
FloatOperationDescription::Floor(desc),
));
out
}
fn float_ceil(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
let client = tensor.client.clone();
let dtype = tensor.dtype;
let out = client.register_empty_tensor(tensor.shape.clone(), dtype);
let desc = UnaryOperationDescription {
input: tensor.into_description(),
out: out.to_description_out(),
};
client.register(OperationDescription::Float(
dtype,
FloatOperationDescription::Ceil(desc),
));
out
}
fn float_recip(tensor: FloatTensor<Self>) -> FloatTensor<Self> {
let client = tensor.client.clone();
let dtype = tensor.dtype;


@@ -826,6 +826,15 @@ where
FloatOperationDescription::Tanh(desc) => {
unary_float_ops!(handles, desc, B::float_tanh)
}
FloatOperationDescription::Round(desc) => {
unary_float_ops!(handles, desc, B::float_round)
}
FloatOperationDescription::Floor(desc) => {
unary_float_ops!(handles, desc, B::float_floor)
}
FloatOperationDescription::Ceil(desc) => {
unary_float_ops!(handles, desc, B::float_ceil)
}
FloatOperationDescription::IntoInt(desc) => {
let tensor = handles.get_float_tensor::<B>(&desc.input);


@@ -357,6 +357,18 @@ impl<E: TchElement, Q: QuantElement> FloatTensorOps<Self> for LibTorch<E, Q> {
tensor.unary_ops(|mut tensor| tensor.tanh_(), |tensor| tensor.tanh())
}
fn float_round(tensor: TchTensor<E>) -> TchTensor<E> {
tensor.unary_ops(|mut tensor| tensor.round_(), |tensor| tensor.round())
}
fn float_floor(tensor: TchTensor<E>) -> TchTensor<E> {
tensor.unary_ops(|mut tensor| tensor.floor_(), |tensor| tensor.floor())
}
fn float_ceil(tensor: TchTensor<E>) -> TchTensor<E> {
tensor.unary_ops(|mut tensor| tensor.ceil_(), |tensor| tensor.ceil())
}
fn float_erf(tensor: TchTensor<E>) -> TchTensor<E> {
tensor.unary_ops(|mut tensor| tensor.erf_(), |tensor| tensor.erf())
}
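Both closures map to the same libtorch kernels; `unary_ops` presumably applies the in-place `round_`/`floor_`/`ceil_` variant when the tensor's storage is not shared, falling back to the allocating variant otherwise. No extra masking is needed here: libtorch's `round` already rounds half to even, which is exactly the behavior the other backends emulate.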


@@ -59,6 +59,12 @@ pub enum FloatOperationDescription {
Sin(UnaryOperationDescription),
/// Operation corresponding to [tanh](crate::ops::FloatTensorOps::float_tanh).
Tanh(UnaryOperationDescription),
/// Operation corresponding to [round](crate::ops::FloatTensorOps::float_round).
Round(UnaryOperationDescription),
/// Operation corresponding to [floor](crate::ops::FloatTensorOps::float_floor).
Floor(UnaryOperationDescription),
/// Operation corresponding to [ceil](crate::ops::FloatTensorOps::float_ceil).
Ceil(UnaryOperationDescription),
/// Operation corresponding to [into_int](crate::ops::FloatTensorOps::float_into_int).
IntoInt(UnaryOperationDescription),
/// Operation corresponding to [matmul](crate::ops::FloatTensorOps::float_matmul).
@@ -1454,6 +1460,9 @@ impl FloatOperationDescription {
FloatOperationDescription::Cos(desc) => vec![&desc.input, &desc.out],
FloatOperationDescription::Sin(desc) => vec![&desc.input, &desc.out],
FloatOperationDescription::Tanh(desc) => vec![&desc.input, &desc.out],
FloatOperationDescription::Round(desc) => vec![&desc.input, &desc.out],
FloatOperationDescription::Floor(desc) => vec![&desc.input, &desc.out],
FloatOperationDescription::Ceil(desc) => vec![&desc.input, &desc.out],
FloatOperationDescription::IntoInt(desc) => vec![&desc.input, &desc.out],
FloatOperationDescription::Quantize(desc) => {
if let Some(offset) = &desc.qparams.offset {


@@ -104,6 +104,30 @@ where
)))
}
/// Applies the element-wise round operation.
///
/// This function implements the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
/// strategy, with halfway cases rounded to the nearest even integer.
pub fn round(self) -> Self {
Self::new(TensorPrimitive::Float(B::float_round(
self.primitive.tensor(),
)))
}
/// Applies the element-wise floor operation.
pub fn floor(self) -> Self {
Self::new(TensorPrimitive::Float(B::float_floor(
self.primitive.tensor(),
)))
}
/// Applies the element-wise ceil operation.
pub fn ceil(self) -> Self {
Self::new(TensorPrimitive::Float(B::float_ceil(
self.primitive.tensor(),
)))
}
/// Create a tensor from floats (f32) on a given device.
///
/// # Example


@@ -931,6 +931,42 @@ pub trait FloatTensorOps<B: Backend> {
/// A tensor with the same shape as `tensor` with tangent values.
fn float_tanh(tensor: FloatTensor<B>) -> FloatTensor<B>;
/// Returns a new tensor with rounded values.
///
/// This function should implement the [round half to even](https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even)
/// strategy, with halfway cases rounded to the nearest even integer.
///
/// # Arguments
///
/// * `tensor` - The tensor to be rounded.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with rounded values.
fn float_round(tensor: FloatTensor<B>) -> FloatTensor<B>;
/// Returns a new tensor with floored values.
///
/// # Arguments
///
/// * `tensor` - The tensor to be floored.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with floored values.
fn float_floor(tensor: FloatTensor<B>) -> FloatTensor<B>;
/// Returns a new tensor with ceiled values.
///
/// # Arguments
///
/// * `tensor` - The tensor to be ceiled.
///
/// # Returns
///
/// A tensor with the same shape as `tensor` with ceiled values.
fn float_ceil(tensor: FloatTensor<B>) -> FloatTensor<B>;
/// Returns a new tensor with the error function values.
///
/// # Arguments


@@ -106,6 +106,9 @@ macro_rules! testgen_all {
burn_tensor::testgen_remainder!();
burn_tensor::testgen_cartesian_grid!();
burn_tensor::testgen_nan!();
burn_tensor::testgen_round!();
burn_tensor::testgen_floor!();
burn_tensor::testgen_ceil!();
// test stats
burn_tensor::testgen_var!();


@@ -0,0 +1,16 @@
#[burn_tensor_testgen::testgen(ceil)]
mod tests {
use super::*;
use burn_tensor::{Tensor, TensorData};
#[test]
fn should_support_ceil_ops() {
let data = TensorData::from([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
let output = tensor.ceil();
let expected = TensorData::from([[25., 88., 77.], [60., 44., 95.]]);
output.into_data().assert_approx_eq(&expected, 3);
}
}


@@ -0,0 +1,16 @@
#[burn_tensor_testgen::testgen(floor)]
mod tests {
use super::*;
use burn_tensor::{Tensor, TensorData};
#[test]
fn should_support_floor_ops() {
let data = TensorData::from([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
let output = tensor.floor();
let expected = TensorData::from([[24., 87., 76.], [59., 43., 94.]]);
output.into_data().assert_approx_eq(&expected, 3);
}
}


@@ -11,6 +11,7 @@ mod bool;
mod cartesian_grid;
mod cast;
mod cat;
mod ceil;
mod chunk;
mod clamp;
mod close;
@@ -22,6 +23,7 @@ mod exp;
mod expand;
mod flatten;
mod flip;
mod floor;
mod full;
mod gather_scatter;
mod init;
@@ -48,6 +50,7 @@ mod remainder;
mod repeat;
mod repeat_dim;
mod reshape;
mod round;
mod select;
mod sign;
mod sin;


@@ -0,0 +1,24 @@
#[burn_tensor_testgen::testgen(round)]
mod tests {
use super::*;
use burn_tensor::{Tensor, TensorData};
#[test]
fn should_support_round_ops() {
let data = TensorData::from([[24.0423, 87.9478, 76.1838], [59.6929, 43.8169, 94.8826]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data, &Default::default());
let output = tensor.round();
let expected = TensorData::from([[24., 88., 76.], [60., 44., 95.]]);
output.into_data().assert_approx_eq(&expected, 3);
let data = TensorData::from([1.5, 2.5, 3.5, 4.5, 5.5, 6.5]);
let tensor = Tensor::<TestBackend, 1>::from_data(data, &Default::default());
let output = tensor.round();
let expected = TensorData::from([2., 2., 4., 4., 6., 6.]);
output.into_data().assert_approx_eq(&expected, 3);
}
}