From 525244062f4d1a74568db1f158033c6c8ae05f40 Mon Sep 17 00:00:00 2001 From: Guillaume Lagrange Date: Fri, 14 Jun 2024 09:02:38 -0400 Subject: [PATCH] Implement `Element` for `bool` (#1878) * Element already implements One * Add element module * Add our own traits for Zero, One and ToPrimitive to support bool Element * Fix typo * Add basic tests for ToPrimitive with expected values * The most important change of all * Remove One + Zero identities * Move zero/one outside mapv + refactor ToPrimitive -> ToElement trait * Add num-traits to NOTICES.md --- NOTICES.md | 33 + crates/burn-jit/src/kernel/pool/max_pool2d.rs | 6 +- .../src/kernel/reduce/naive/argmax.rs | 3 +- .../src/kernel/reduce/naive/argmin.rs | 3 +- .../src/kernel/reduce/shared/argmax.rs | 2 +- .../src/kernel/reduce/shared/argmin.rs | 2 +- crates/burn-ndarray/src/element.rs | 2 - crates/burn-ndarray/src/ops/base.rs | 12 +- crates/burn-ndarray/src/ops/macros.rs | 3 +- crates/burn-ndarray/src/ops/tensor.rs | 10 +- crates/burn-tensor/src/tensor/api/base.rs | 3 +- crates/burn-tensor/src/tensor/api/numeric.rs | 3 +- crates/burn-tensor/src/tensor/data.rs | 12 - .../tensor/{element.rs => element/base.rs} | 42 +- crates/burn-tensor/src/tensor/element/cast.rs | 579 ++++++++++++++++++ crates/burn-tensor/src/tensor/element/mod.rs | 6 + .../burn-tensor/src/tensor/ops/int_tensor.rs | 7 +- crates/burn-tensor/src/tensor/ops/tensor.rs | 4 +- 18 files changed, 670 insertions(+), 62 deletions(-) rename crates/burn-tensor/src/tensor/{element.rs => element/base.rs} (84%) create mode 100644 crates/burn-tensor/src/tensor/element/cast.rs create mode 100644 crates/burn-tensor/src/tensor/element/mod.rs diff --git a/NOTICES.md b/NOTICES.md index 8848c7a40..77b97d65a 100644 --- a/NOTICES.md +++ b/NOTICES.md @@ -303,3 +303,36 @@ SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +## num-traits + +**Source:** https://github.com/rust-num/num-traits/blob/master/src/cast.rs + +MIT License + +Copyright (c) 2014 The Rust Project Developers + +Permission is hereby granted, free of charge, to any +person obtaining a copy of this software and associated +documentation files (the "Software"), to deal in the +Software without restriction, including without +limitation the rights to use, copy, modify, merge, +publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software +is furnished to do so, subject to the following +conditions: + +The above copyright notice and this permission notice +shall be included in all copies or substantial portions +of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF +ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED +TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A +PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT +SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION +OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR +IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. diff --git a/crates/burn-jit/src/kernel/pool/max_pool2d.rs b/crates/burn-jit/src/kernel/pool/max_pool2d.rs index cae5c3b42..7e7d5f0bc 100644 --- a/crates/burn-jit/src/kernel/pool/max_pool2d.rs +++ b/crates/burn-jit/src/kernel/pool/max_pool2d.rs @@ -21,8 +21,7 @@ impl PoolStrategy for MaxPool { fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator { let max_val = scope.create_local(item); - let max_initial = - Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), item.elem()); + let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), item.elem()); cpa!(scope, max_val = max_initial); max_val } @@ -68,8 +67,7 @@ impl PoolStrategy for MaxPoolWithIndices { fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator { let max_val = scope.create_local(item); - let max_initial = - Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), item.elem()); + let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), item.elem()); cpa!(scope, max_val = max_initial); let max_index = scope.create_local(Elem::UInt); (max_val, max_index) diff --git a/crates/burn-jit/src/kernel/reduce/naive/argmax.rs b/crates/burn-jit/src/kernel/reduce/naive/argmax.rs index c3847bcc8..e3d3a3c68 100644 --- a/crates/burn-jit/src/kernel/reduce/naive/argmax.rs +++ b/crates/burn-jit/src/kernel/reduce/naive/argmax.rs @@ -16,8 +16,7 @@ impl ReduceDimNaive for Argmax { ) -> Self::Accumulator { let index = scope.create_local(Elem::UInt); let max = scope.create_local(input_item); - let max_initial = - Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), input_item.elem()); + let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), input_item.elem()); cpa!(scope, max = max_initial); (max, index) diff --git a/crates/burn-jit/src/kernel/reduce/naive/argmin.rs b/crates/burn-jit/src/kernel/reduce/naive/argmin.rs index 632f76550..fe5b2d13f 100644 --- a/crates/burn-jit/src/kernel/reduce/naive/argmin.rs +++ b/crates/burn-jit/src/kernel/reduce/naive/argmin.rs @@ -17,8 +17,7 @@ impl ReduceDimNaive for Argmin { ) -> Self::Accumulator { let index = scope.create_local(Elem::UInt); let min = scope.create_local(input_item); - let min_initial = - Variable::ConstantScalar(E::maximum_value().to_f64().unwrap(), input_item.elem()); + let min_initial = Variable::ConstantScalar(E::maximum_value().to_f64(), input_item.elem()); cpa!(scope, min = min_initial); (min, index) diff --git a/crates/burn-jit/src/kernel/reduce/shared/argmax.rs b/crates/burn-jit/src/kernel/reduce/shared/argmax.rs index 54c1594c7..977be5aeb 100644 --- a/crates/burn-jit/src/kernel/reduce/shared/argmax.rs +++ b/crates/burn-jit/src/kernel/reduce/shared/argmax.rs @@ -18,7 +18,7 @@ impl ReduceDimShared for Argmax { let value_shared_memory = scope.create_shared(input_item, shared_memory_size); let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size); - let max = Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), input_item.elem()); + let max = Variable::ConstantScalar(E::minimum_value().to_f64(), input_item.elem()); cpa!(scope, value_shared_memory[write_position] = max); (value_shared_memory, index_shared_memory) } diff --git a/crates/burn-jit/src/kernel/reduce/shared/argmin.rs b/crates/burn-jit/src/kernel/reduce/shared/argmin.rs index 198a89201..aa11511bb 100644 --- a/crates/burn-jit/src/kernel/reduce/shared/argmin.rs +++ b/crates/burn-jit/src/kernel/reduce/shared/argmin.rs @@ -19,7 +19,7 @@ impl ReduceDimShared for Argmin { let value_shared_memory = scope.create_shared(input_item, shared_memory_size); let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size); - let min = Variable::ConstantScalar(E::maximum_value().to_f64().unwrap(), input_item.elem()); + let min = Variable::ConstantScalar(E::maximum_value().to_f64(), input_item.elem()); cpa!(scope, value_shared_memory[write_position] = min); (value_shared_memory, index_shared_memory) } diff --git a/crates/burn-ndarray/src/element.rs b/crates/burn-ndarray/src/element.rs index 56fd1624b..cb4ad14e2 100644 --- a/crates/burn-ndarray/src/element.rs +++ b/crates/burn-ndarray/src/element.rs @@ -1,6 +1,5 @@ use burn_tensor::Element; use ndarray::LinalgScalar; -use num_traits::One; use num_traits::Signed; #[cfg(not(feature = "std"))] @@ -20,7 +19,6 @@ where /// A general element for ndarray backend. pub trait NdArrayElement: Element - + One + ndarray::LinalgScalar + ndarray::ScalarOperand + ExpElement diff --git a/crates/burn-ndarray/src/ops/base.rs b/crates/burn-ndarray/src/ops/base.rs index 1d4ce975f..8c72e61cb 100644 --- a/crates/burn-ndarray/src/ops/base.rs +++ b/crates/burn-ndarray/src/ops/base.rs @@ -563,16 +563,18 @@ where where E: Signed, { + let zero = 0.elem(); + let one = 1.elem::(); NdArrayTensor::new( tensor .array .mapv(|x| { - if x > E::zero() { - E::one() - } else if x < E::zero() { - -E::one() + if x > zero { + one + } else if x < zero { + -one } else { - E::zero() + zero } }) .into_shared(), diff --git a/crates/burn-ndarray/src/ops/macros.rs b/crates/burn-ndarray/src/ops/macros.rs index a16e2dc21..9ec9fbf2b 100644 --- a/crates/burn-ndarray/src/ops/macros.rs +++ b/crates/burn-ndarray/src/ops/macros.rs @@ -34,6 +34,7 @@ macro_rules! keepdim { }}; } +use burn_tensor::ElementConversion; pub(crate) use keepdim; use ndarray::Axis; @@ -63,7 +64,7 @@ pub(crate) fn prod_dim( ) -> NdArrayTensor { let array = tensor .array - .fold_axis(Axis(dim), E::one(), |acc, &x| acc.mul(x.elem())) + .fold_axis(Axis(dim), 1.elem::(), |acc, &x| acc.mul(x.elem())) .into_shared(); NdArrayTensor { array } diff --git a/crates/burn-ndarray/src/ops/tensor.rs b/crates/burn-ndarray/src/ops/tensor.rs index ce638ad66..954b464fb 100644 --- a/crates/burn-ndarray/src/ops/tensor.rs +++ b/crates/burn-ndarray/src/ops/tensor.rs @@ -406,7 +406,7 @@ impl FloatTensorOps for NdArray { fn float_cos(tensor: NdArrayTensor) -> NdArrayTensor { let array = tensor .array - .mapv_into(|a| (a.to_f64().unwrap()).cos().elem()) + .mapv_into(|a| (a.to_f64()).cos().elem()) .into_shared(); NdArrayTensor::new(array) @@ -415,7 +415,7 @@ impl FloatTensorOps for NdArray { fn float_sin(tensor: NdArrayTensor) -> NdArrayTensor { let array = tensor .array - .mapv_into(|a| (a.to_f64().unwrap()).sin().elem()) + .mapv_into(|a| (a.to_f64()).sin().elem()) .into_shared(); NdArrayTensor::new(array) @@ -424,7 +424,7 @@ impl FloatTensorOps for NdArray { fn float_tanh(tensor: NdArrayTensor) -> NdArrayTensor { let array = tensor .array - .mapv_into(|a| (a.to_f64().unwrap()).tanh().elem()) + .mapv_into(|a| (a.to_f64()).tanh().elem()) .into_shared(); NdArrayTensor::new(array) @@ -433,7 +433,7 @@ impl FloatTensorOps for NdArray { fn float_erf(tensor: NdArrayTensor) -> NdArrayTensor { let array = tensor .array - .mapv_into(|a| erf(a.to_f64().unwrap()).elem()) + .mapv_into(|a| erf(a.to_f64()).elem()) .into_shared(); NdArrayTensor::new(array) @@ -473,7 +473,7 @@ impl FloatTensorOps for NdArray { lhs: NdArrayTensor, rhs: NdArrayTensor, ) -> NdArrayTensor { - NdArrayMathOps::elementwise_op(lhs, rhs, |a, b| a.powf_elem(b.to_f32().unwrap())) + NdArrayMathOps::elementwise_op(lhs, rhs, |a, b| a.powf_elem(b.to_f32())) } fn float_permute( diff --git a/crates/burn-tensor/src/tensor/api/base.rs b/crates/burn-tensor/src/tensor/api/base.rs index e3a9de7bc..11b2fff2e 100644 --- a/crates/burn-tensor/src/tensor/api/base.rs +++ b/crates/burn-tensor/src/tensor/api/base.rs @@ -20,6 +20,7 @@ use serde::{Serialize, Serializer}; use crate::check::TensorCheck; use crate::tensor::api::chunk::chunk; use crate::tensor::api::narrow::narrow; +use crate::Element; use crate::{backend::Backend, check, Bool, Data, DataSerialize, Float, Int, Shape, TensorKind}; /// A tensor with a given backend, shape and data type. @@ -1213,7 +1214,7 @@ impl core::ops::BitXor for Tensor { /// This is an internal trait, use the public API provided by [tensor struct](Tensor). pub trait BasicOps: TensorKind { /// The type of the tensor elements. - type Elem: 'static + Copy; + type Elem: Element; /// Creates an empty tensor with the given shape. /// diff --git a/crates/burn-tensor/src/tensor/api/numeric.rs b/crates/burn-tensor/src/tensor/api/numeric.rs index 705d3b007..7bbaa363e 100644 --- a/crates/burn-tensor/src/tensor/api/numeric.rs +++ b/crates/burn-tensor/src/tensor/api/numeric.rs @@ -6,7 +6,6 @@ use crate::{ backend::Backend, check, check::TensorCheck, BasicOps, Bool, Distribution, Element, ElementConversion, Float, Int, Shape, Tensor, TensorKind, }; -use num_traits::Zero; impl Tensor where @@ -656,7 +655,7 @@ where /// /// A boolean tensor with the same shape as the input tensor. pub fn bool(self) -> Tensor { - K::not_equal_elem::(self.primitive, K::Elem::zero()) + K::not_equal_elem::(self.primitive, 0.elem()) } /// Create a random tensor of the given shape on the given device where each element is diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs index 4b3d60aea..8f870a02a 100644 --- a/crates/burn-tensor/src/tensor/data.rs +++ b/crates/burn-tensor/src/tensor/data.rs @@ -192,18 +192,6 @@ impl DataSerialize { } } -impl Data { - /// Converts the data to a different element type. - pub fn convert(self) -> Data { - let value: Vec = self.value.into_iter().map(|a| (a as i64).elem()).collect(); - - Data { - value, - shape: self.shape, - } - } -} - impl Data { /// Populates the data with random values. pub fn random(shape: Shape, distribution: Distribution, rng: &mut R) -> Self { diff --git a/crates/burn-tensor/src/tensor/element.rs b/crates/burn-tensor/src/tensor/element/base.rs similarity index 84% rename from crates/burn-tensor/src/tensor/element.rs rename to crates/burn-tensor/src/tensor/element/base.rs index 7a2c99a9b..7e15822db 100644 --- a/crates/burn-tensor/src/tensor/element.rs +++ b/crates/burn-tensor/src/tensor/element/base.rs @@ -1,16 +1,13 @@ use core::cmp::Ordering; -use crate::Distribution; +use crate::{cast::ToElement, Distribution}; use half::{bf16, f16}; -use num_traits::{identities::Zero, One, ToPrimitive}; use rand::RngCore; use serde::{Deserialize, Serialize}; /// Element trait for tensor. pub trait Element: - ToPrimitive - + Zero - + One + ToElement + ElementRandom + ElementConversion + ElementPrecision @@ -38,7 +35,7 @@ pub trait ElementConversion { /// # Returns /// /// The converted element. - fn from_elem(elem: E) -> Self; + fn from_elem(elem: E) -> Self; /// Converts and returns the converted element. fn elem(self) -> E; @@ -105,7 +102,7 @@ macro_rules! make_element { } impl ElementConversion for $type { - fn from_elem(elem: E) -> Self { + fn from_elem(elem: E) -> Self { #[allow(clippy::redundant_closure_call)] $convert(&elem) } @@ -140,7 +137,7 @@ macro_rules! make_element { make_element!( ty f64 Precision::Double, - convert |elem: &dyn ToPrimitive| elem.to_f64().unwrap(), + convert |elem: &dyn ToElement| elem.to_f64(), random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), cmp |a: &f64, b: &f64| a.total_cmp(b), dtype DType::F64 @@ -148,7 +145,7 @@ make_element!( make_element!( ty f32 Precision::Full, - convert |elem: &dyn ToPrimitive| elem.to_f32().unwrap(), + convert |elem: &dyn ToElement| elem.to_f32(), random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), cmp |a: &f32, b: &f32| a.total_cmp(b), dtype DType::F32 @@ -156,7 +153,7 @@ make_element!( make_element!( ty i64 Precision::Double, - convert |elem: &dyn ToPrimitive| elem.to_i64().unwrap(), + convert |elem: &dyn ToElement| elem.to_i64(), random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), cmp |a: &i64, b: &i64| Ord::cmp(a, b), dtype DType::I64 @@ -164,7 +161,7 @@ make_element!( make_element!( ty i32 Precision::Full, - convert |elem: &dyn ToPrimitive| elem.to_i32().unwrap(), + convert |elem: &dyn ToElement| elem.to_i32(), random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), cmp |a: &i32, b: &i32| Ord::cmp(a, b), dtype DType::I32 @@ -172,7 +169,7 @@ make_element!( make_element!( ty u32 Precision::Full, - convert |elem: &dyn ToPrimitive| elem.to_u32().unwrap(), + convert |elem: &dyn ToElement| elem.to_u32(), random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), cmp |a: &u32, b: &u32| Ord::cmp(a, b), dtype DType::U32 @@ -180,7 +177,7 @@ make_element!( make_element!( ty i16 Precision::Half, - convert |elem: &dyn ToPrimitive| elem.to_i16().unwrap(), + convert |elem: &dyn ToElement| elem.to_i16(), random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), cmp |a: &i16, b: &i16| Ord::cmp(a, b), dtype DType::I16 @@ -188,7 +185,7 @@ make_element!( make_element!( ty i8 Precision::Other, - convert |elem: &dyn ToPrimitive| elem.to_i8().unwrap(), + convert |elem: &dyn ToElement| elem.to_i8(), random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), cmp |a: &i8, b: &i8| Ord::cmp(a, b), dtype DType::I8 @@ -196,7 +193,7 @@ make_element!( make_element!( ty u8 Precision::Other, - convert |elem: &dyn ToPrimitive| elem.to_u8().unwrap(), + convert |elem: &dyn ToElement| elem.to_u8(), random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(), cmp |a: &u8, b: &u8| Ord::cmp(a, b), dtype DType::U8 @@ -204,7 +201,7 @@ make_element!( make_element!( ty f16 Precision::Half, - convert |elem: &dyn ToPrimitive| f16::from_f32(elem.to_f32().unwrap()), + convert |elem: &dyn ToElement| f16::from_f32(elem.to_f32()), random |distribution: Distribution, rng: &mut R| { let sample: f32 = distribution.sampler(rng).sample(); f16::from_elem(sample) @@ -214,7 +211,7 @@ make_element!( ); make_element!( ty bf16 Precision::Half, - convert |elem: &dyn ToPrimitive| bf16::from_f32(elem.to_f32().unwrap()), + convert |elem: &dyn ToElement| bf16::from_f32(elem.to_f32()), random |distribution: Distribution, rng: &mut R| { let sample: f32 = distribution.sampler(rng).sample(); bf16::from_elem(sample) @@ -223,6 +220,17 @@ make_element!( dtype DType::BF16 ); +make_element!( + ty bool Precision::Other, + convert |elem: &dyn ToElement| elem.to_u8() != 0, + random |distribution: Distribution, rng: &mut R| { + let sample: u8 = distribution.sampler(rng).sample(); + bool::from_elem(sample) + }, + cmp |a: &bool, b: &bool| Ord::cmp(a, b), + dtype DType::Bool +); + #[allow(missing_docs)] #[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)] pub enum DType { diff --git a/crates/burn-tensor/src/tensor/element/cast.rs b/crates/burn-tensor/src/tensor/element/cast.rs new file mode 100644 index 000000000..6cb687c3e --- /dev/null +++ b/crates/burn-tensor/src/tensor/element/cast.rs @@ -0,0 +1,579 @@ +use core::mem::size_of; + +use half::{bf16, f16}; + +/// A generic trait for converting a value to a number. +/// Adapted from [num_traits::ToPrimitive] to support [bool]. +/// +/// A value can be represented by the target type when it lies within +/// the range of scalars supported by the target type. +/// For example, a negative integer cannot be represented by an unsigned +/// integer type, and an `i64` with a very high magnitude might not be +/// convertible to an `i32`. +/// On the other hand, conversions with possible precision loss or truncation +/// are admitted, like an `f32` with a decimal part to an integer type, or +/// even a large `f64` saturating to `f32` infinity. +/// +/// The methods *panic* when the value cannot be represented by the target type. +pub trait ToElement { + /// Converts the value of `self` to an `isize`. + #[inline] + fn to_isize(&self) -> isize { + ToElement::to_isize(&self.to_i64()) + } + + /// Converts the value of `self` to an `i8`. + #[inline] + fn to_i8(&self) -> i8 { + ToElement::to_i8(&self.to_i64()) + } + + /// Converts the value of `self` to an `i16`. + #[inline] + fn to_i16(&self) -> i16 { + ToElement::to_i16(&self.to_i64()) + } + + /// Converts the value of `self` to an `i32`. + #[inline] + fn to_i32(&self) -> i32 { + ToElement::to_i32(&self.to_i64()) + } + + /// Converts the value of `self` to an `i64`. + fn to_i64(&self) -> i64; + + /// Converts the value of `self` to an `i128`. + /// + /// The default implementation converts through `to_i64()`. Types implementing + /// this trait should override this method if they can represent a greater range. + #[inline] + fn to_i128(&self) -> i128 { + i128::from(self.to_i64()) + } + + /// Converts the value of `self` to a `usize`. + #[inline] + fn to_usize(&self) -> usize { + ToElement::to_usize(&self.to_u64()) + } + + /// Converts the value of `self` to a `u8`. + #[inline] + fn to_u8(&self) -> u8 { + ToElement::to_u8(&self.to_u64()) + } + + /// Converts the value of `self` to a `u16`. + #[inline] + fn to_u16(&self) -> u16 { + ToElement::to_u16(&self.to_u64()) + } + + /// Converts the value of `self` to a `u32`. + #[inline] + fn to_u32(&self) -> u32 { + ToElement::to_u32(&self.to_u64()) + } + + /// Converts the value of `self` to a `u64`. + fn to_u64(&self) -> u64; + + /// Converts the value of `self` to a `u128`. + /// + /// The default implementation converts through `to_u64()`. Types implementing + /// this trait should override this method if they can represent a greater range. + #[inline] + fn to_u128(&self) -> u128 { + u128::from(self.to_u64()) + } + + /// Converts the value of `self` to an `f32`. Overflows may map to positive + /// or negative infinity. + #[inline] + fn to_f32(&self) -> f32 { + ToElement::to_f32(&self.to_f64()) + } + + /// Converts the value of `self` to an `f64`. Overflows may map to positive + /// or negative infinity. + /// + /// The default implementation tries to convert through `to_i64()`, and + /// failing that through `to_u64()`. Types implementing this trait should + /// override this method if they can represent a greater range. + #[inline] + fn to_f64(&self) -> f64 { + ToElement::to_f64(&self.to_u64()) + } +} + +macro_rules! impl_to_element_int_to_int { + ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$( + #[inline] + $(#[$cfg])* + fn $method(&self) -> $DstT { + let min = $DstT::MIN as $SrcT; + let max = $DstT::MAX as $SrcT; + if size_of::<$SrcT>() <= size_of::<$DstT>() || (min <= *self && *self <= max) { + *self as $DstT + } else { + panic!("Element cannot be represented in the target type") + } + } + )*} +} + +macro_rules! impl_to_element_int_to_uint { + ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$( + #[inline] + $(#[$cfg])* + fn $method(&self) -> $DstT { + let max = $DstT::MAX as $SrcT; + if 0 <= *self && (size_of::<$SrcT>() <= size_of::<$DstT>() || *self <= max) { + *self as $DstT + } else { + panic!("Element cannot be represented in the target type") + } + } + )*} +} + +macro_rules! impl_to_element_int { + ($T:ident) => { + impl ToElement for $T { + impl_to_element_int_to_int! { $T: + fn to_isize -> isize; + fn to_i8 -> i8; + fn to_i16 -> i16; + fn to_i32 -> i32; + fn to_i64 -> i64; + fn to_i128 -> i128; + } + + impl_to_element_int_to_uint! { $T: + fn to_usize -> usize; + fn to_u8 -> u8; + fn to_u16 -> u16; + fn to_u32 -> u32; + fn to_u64 -> u64; + fn to_u128 -> u128; + } + + #[inline] + fn to_f32(&self) -> f32 { + *self as f32 + } + #[inline] + fn to_f64(&self) -> f64 { + *self as f64 + } + } + }; +} + +impl_to_element_int!(isize); +impl_to_element_int!(i8); +impl_to_element_int!(i16); +impl_to_element_int!(i32); +impl_to_element_int!(i64); +impl_to_element_int!(i128); + +macro_rules! impl_to_element_uint_to_int { + ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$( + #[inline] + $(#[$cfg])* + fn $method(&self) -> $DstT { + let max = $DstT::MAX as $SrcT; + if size_of::<$SrcT>() < size_of::<$DstT>() || *self <= max { + *self as $DstT + } else { + panic!("Element cannot be represented in the target type") + } + } + )*} +} + +macro_rules! impl_to_element_uint_to_uint { + ($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$( + #[inline] + $(#[$cfg])* + fn $method(&self) -> $DstT { + let max = $DstT::MAX as $SrcT; + if size_of::<$SrcT>() <= size_of::<$DstT>() || *self <= max { + *self as $DstT + } else { + panic!("Element cannot be represented in the target type") + } + } + )*} +} + +macro_rules! impl_to_element_uint { + ($T:ident) => { + impl ToElement for $T { + impl_to_element_uint_to_int! { $T: + fn to_isize -> isize; + fn to_i8 -> i8; + fn to_i16 -> i16; + fn to_i32 -> i32; + fn to_i64 -> i64; + fn to_i128 -> i128; + } + + impl_to_element_uint_to_uint! { $T: + fn to_usize -> usize; + fn to_u8 -> u8; + fn to_u16 -> u16; + fn to_u32 -> u32; + fn to_u64 -> u64; + fn to_u128 -> u128; + } + + #[inline] + fn to_f32(&self) -> f32 { + *self as f32 + } + #[inline] + fn to_f64(&self) -> f64 { + *self as f64 + } + } + }; +} + +impl_to_element_uint!(usize); +impl_to_element_uint!(u8); +impl_to_element_uint!(u16); +impl_to_element_uint!(u32); +impl_to_element_uint!(u64); +impl_to_element_uint!(u128); + +macro_rules! impl_to_element_float_to_float { + ($SrcT:ident : $( fn $method:ident -> $DstT:ident ; )*) => {$( + #[inline] + fn $method(&self) -> $DstT { + // We can safely cast all values, whether NaN, +-inf, or finite. + // Finite values that are reducing size may saturate to +-inf. + *self as $DstT + } + )*} +} + +macro_rules! float_to_int_unchecked { + // SAFETY: Must not be NaN or infinite; must be representable as the integer after truncating. + // We already checked that the float is in the exclusive range `(MIN-1, MAX+1)`. + ($float:expr => $int:ty) => { + unsafe { $float.to_int_unchecked::<$int>() } + }; +} + +macro_rules! impl_to_element_float_to_signed_int { + ($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $i:ident ; )*) => {$( + #[inline] + $(#[$cfg])* + fn $method(&self) -> $i { + // Float as int truncates toward zero, so we want to allow values + // in the exclusive range `(MIN-1, MAX+1)`. + if size_of::<$f>() > size_of::<$i>() { + // With a larger size, we can represent the range exactly. + const MIN_M1: $f = $i::MIN as $f - 1.0; + const MAX_P1: $f = $i::MAX as $f + 1.0; + if *self > MIN_M1 && *self < MAX_P1 { + return float_to_int_unchecked!(*self => $i); + } + } else { + // We can't represent `MIN-1` exactly, but there's no fractional part + // at this magnitude, so we can just use a `MIN` inclusive boundary. + const MIN: $f = $i::MIN as $f; + // We can't represent `MAX` exactly, but it will round up to exactly + // `MAX+1` (a power of two) when we cast it. + const MAX_P1: $f = $i::MAX as $f; + if *self >= MIN && *self < MAX_P1 { + return float_to_int_unchecked!(*self => $i); + } + } + panic!("Float cannot be represented in the target signed int type") + } + )*} +} + +macro_rules! impl_to_element_float_to_unsigned_int { + ($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $u:ident ; )*) => {$( + #[inline] + $(#[$cfg])* + fn $method(&self) -> $u { + // Float as int truncates toward zero, so we want to allow values + // in the exclusive range `(-1, MAX+1)`. + if size_of::<$f>() > size_of::<$u>() { + // With a larger size, we can represent the range exactly. + const MAX_P1: $f = $u::MAX as $f + 1.0; + if *self > -1.0 && *self < MAX_P1 { + return float_to_int_unchecked!(*self => $u); + } + } else { + // We can't represent `MAX` exactly, but it will round up to exactly + // `MAX+1` (a power of two) when we cast it. + // (`u128::MAX as f32` is infinity, but this is still ok.) + const MAX_P1: $f = $u::MAX as $f; + if *self > -1.0 && *self < MAX_P1 { + return float_to_int_unchecked!(*self => $u); + } + } + panic!("Float cannot be represented in the target unsigned int type") + } + )*} +} + +macro_rules! impl_to_element_float { + ($T:ident) => { + impl ToElement for $T { + impl_to_element_float_to_signed_int! { $T: + fn to_isize -> isize; + fn to_i8 -> i8; + fn to_i16 -> i16; + fn to_i32 -> i32; + fn to_i64 -> i64; + fn to_i128 -> i128; + } + + impl_to_element_float_to_unsigned_int! { $T: + fn to_usize -> usize; + fn to_u8 -> u8; + fn to_u16 -> u16; + fn to_u32 -> u32; + fn to_u64 -> u64; + fn to_u128 -> u128; + } + + impl_to_element_float_to_float! { $T: + fn to_f32 -> f32; + fn to_f64 -> f64; + } + } + }; +} + +impl_to_element_float!(f32); +impl_to_element_float!(f64); + +impl ToElement for f16 { + #[inline] + fn to_i64(&self) -> i64 { + Self::to_f32(*self).to_i64() + } + #[inline] + fn to_u64(&self) -> u64 { + Self::to_f32(*self).to_u64() + } + #[inline] + fn to_i8(&self) -> i8 { + Self::to_f32(*self).to_i8() + } + #[inline] + fn to_u8(&self) -> u8 { + Self::to_f32(*self).to_u8() + } + #[inline] + fn to_i16(&self) -> i16 { + Self::to_f32(*self).to_i16() + } + #[inline] + fn to_u16(&self) -> u16 { + Self::to_f32(*self).to_u16() + } + #[inline] + fn to_i32(&self) -> i32 { + Self::to_f32(*self).to_i32() + } + #[inline] + fn to_u32(&self) -> u32 { + Self::to_f32(*self).to_u32() + } + #[inline] + fn to_f32(&self) -> f32 { + Self::to_f32(*self) + } + #[inline] + fn to_f64(&self) -> f64 { + Self::to_f64(*self) + } +} + +impl ToElement for bf16 { + #[inline] + fn to_i64(&self) -> i64 { + Self::to_f32(*self).to_i64() + } + #[inline] + fn to_u64(&self) -> u64 { + Self::to_f32(*self).to_u64() + } + #[inline] + fn to_i8(&self) -> i8 { + Self::to_f32(*self).to_i8() + } + #[inline] + fn to_u8(&self) -> u8 { + Self::to_f32(*self).to_u8() + } + #[inline] + fn to_i16(&self) -> i16 { + Self::to_f32(*self).to_i16() + } + #[inline] + fn to_u16(&self) -> u16 { + Self::to_f32(*self).to_u16() + } + #[inline] + fn to_i32(&self) -> i32 { + Self::to_f32(*self).to_i32() + } + #[inline] + fn to_u32(&self) -> u32 { + Self::to_f32(*self).to_u32() + } + #[inline] + fn to_f32(&self) -> f32 { + Self::to_f32(*self) + } + #[inline] + fn to_f64(&self) -> f64 { + Self::to_f64(*self) + } +} + +impl ToElement for bool { + #[inline] + fn to_i64(&self) -> i64 { + *self as i64 + } + #[inline] + fn to_u64(&self) -> u64 { + *self as u64 + } + #[inline] + fn to_i8(&self) -> i8 { + *self as i8 + } + #[inline] + fn to_u8(&self) -> u8 { + *self as u8 + } + #[inline] + fn to_i16(&self) -> i16 { + *self as i16 + } + #[inline] + fn to_u16(&self) -> u16 { + *self as u16 + } + #[inline] + fn to_i32(&self) -> i32 { + *self as i32 + } + #[inline] + fn to_u32(&self) -> u32 { + *self as u32 + } + #[inline] + fn to_f32(&self) -> f32 { + self.to_u8() as f32 + } + #[inline] + fn to_f64(&self) -> f64 { + self.to_u8() as f64 + } +} + +mod tests { + #[allow(unused_imports)] + use super::*; + + #[test] + fn to_element_float() { + let f32_toolarge = 1e39f64; + assert_eq!(f32_toolarge.to_f32(), f32::INFINITY); + assert_eq!((-f32_toolarge).to_f32(), f32::NEG_INFINITY); + assert_eq!((f32::MAX as f64).to_f32(), f32::MAX); + assert_eq!((-f32::MAX as f64).to_f32(), -f32::MAX); + assert_eq!(f64::INFINITY.to_f32(), f32::INFINITY); + assert_eq!((f64::NEG_INFINITY).to_f32(), f32::NEG_INFINITY); + assert!((f64::NAN).to_f32().is_nan()); + } + + #[test] + #[should_panic] + fn to_element_signed_to_u8_underflow() { + let _x = (-1i8).to_u8(); + } + + #[test] + #[should_panic] + fn to_element_signed_to_u16_underflow() { + let _x = (-1i8).to_u16(); + } + + #[test] + #[should_panic] + fn to_element_signed_to_u32_underflow() { + let _x = (-1i8).to_u32(); + } + + #[test] + #[should_panic] + fn to_element_signed_to_u64_underflow() { + let _x = (-1i8).to_u64(); + } + + #[test] + #[should_panic] + fn to_element_signed_to_u128_underflow() { + let _x = (-1i8).to_u128(); + } + + #[test] + #[should_panic] + fn to_element_signed_to_usize_underflow() { + let _x = (-1i8).to_usize(); + } + + #[test] + #[should_panic] + fn to_element_unsigned_to_u8_overflow() { + let _x = 256.to_u8(); + } + + #[test] + #[should_panic] + fn to_element_unsigned_to_u16_overflow() { + let _x = 65_536.to_u16(); + } + + #[test] + #[should_panic] + fn to_element_unsigned_to_u32_overflow() { + let _x = 4_294_967_296u64.to_u32(); + } + + #[test] + #[should_panic] + fn to_element_unsigned_to_u64_overflow() { + let _x = 18_446_744_073_709_551_616u128.to_u64(); + } + + #[test] + fn to_element_int_to_float() { + assert_eq!((-1).to_f32(), -1.0); + assert_eq!((-1).to_f64(), -1.0); + assert_eq!(255.to_f32(), 255.0); + assert_eq!(65_535.to_f64(), 65_535.0); + } + + #[test] + fn to_element_float_to_int() { + assert_eq!((-1.0).to_i8(), -1); + assert_eq!(1.0.to_u8(), 1); + assert_eq!(1.8.to_u16(), 1); + assert_eq!(123.456.to_u32(), 123); + } +} diff --git a/crates/burn-tensor/src/tensor/element/mod.rs b/crates/burn-tensor/src/tensor/element/mod.rs new file mode 100644 index 000000000..609966ec8 --- /dev/null +++ b/crates/burn-tensor/src/tensor/element/mod.rs @@ -0,0 +1,6 @@ +mod base; + +/// Tensor element casting. +pub mod cast; + +pub use base::*; diff --git a/crates/burn-tensor/src/tensor/ops/int_tensor.rs b/crates/burn-tensor/src/tensor/ops/int_tensor.rs index be792316a..bf2bf1975 100644 --- a/crates/burn-tensor/src/tensor/ops/int_tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/int_tensor.rs @@ -1,13 +1,13 @@ use super::cat::cat_with_slice_assign; use super::repeat::repeat_with_slice_assign; use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor}; +use crate::tensor::cast::ToElement; use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion, Int}; use crate::{cartesian_grid, Tensor}; use crate::{tensor::api::chunk, tensor::api::narrow}; use alloc::vec::Vec; use burn_common::reader::Reader; use core::ops::Range; -use num_traits::ToPrimitive; #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] use crate::{argsort, sort, sort_with_indices}; @@ -536,10 +536,7 @@ pub trait IntTensorOps { /// /// The elements of `lhs` raised to the value of `rhs`. fn int_powi_scalar(lhs: IntTensor, rhs: IntElem) -> IntTensor { - B::float_into_int(B::float_powf_scalar( - B::int_into_float(lhs), - rhs.to_f32().unwrap(), - )) + B::float_into_int(B::float_powf_scalar(B::int_into_float(lhs), rhs.to_f32())) } /// Element-wise power with a floatTensor. diff --git a/crates/burn-tensor/src/tensor/ops/tensor.rs b/crates/burn-tensor/src/tensor/ops/tensor.rs index 27b28b307..8dfdcfb93 100644 --- a/crates/burn-tensor/src/tensor/ops/tensor.rs +++ b/crates/burn-tensor/src/tensor/ops/tensor.rs @@ -2,13 +2,13 @@ use super::cat::cat_with_slice_assign; use super::repeat::repeat_with_slice_assign; use super::{BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntElem, IntTensor}; use crate::backend::BackendBridge; +use crate::tensor::cast::ToElement; use crate::Tensor; use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion, Float}; use crate::{tensor::api::chunk, tensor::api::narrow}; use alloc::vec::Vec; use burn_common::reader::Reader; use core::ops::Range; -use num_traits::ToPrimitive; #[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))] use crate::{argsort, sort, sort_with_indices}; @@ -1005,7 +1005,7 @@ pub trait FloatTensorOps { lhs: FloatTensor, rhs: IntElem, ) -> FloatTensor { - Self::float_powf_scalar(lhs, rhs.to_f32().unwrap()) + Self::float_powf_scalar(lhs, rhs.to_f32()) } /// Returns a new tensor with values raised to the power of float `value`.