Implement `Element` for `bool` (#1878)

* Element already implements One

* Add element module

* Add our own traits for Zero, One and ToPrimitive to support bool Element

* Fix typo

* Add basic tests for ToPrimitive with expected values

* The most important change of all

* Remove One + Zero identities

* Move zero/one outside mapv + refactor ToPrimitive -> ToElement trait

* Add num-traits to NOTICES.md
This commit is contained in:
Guillaume Lagrange 2024-06-14 09:02:38 -04:00 committed by GitHub
parent b71c300638
commit 525244062f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 670 additions and 62 deletions

View File

@ -303,3 +303,36 @@ SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
## num-traits
**Source:** https://github.com/rust-num/num-traits/blob/master/src/cast.rs
MIT License
Copyright (c) 2014 The Rust Project Developers
Permission is hereby granted, free of charge, to any
person obtaining a copy of this software and associated
documentation files (the "Software"), to deal in the
Software without restriction, including without
limitation the rights to use, copy, modify, merge,
publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software
is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice
shall be included in all copies or substantial portions
of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.

View File

@ -21,8 +21,7 @@ impl<E: JitElement> PoolStrategy for MaxPool<E> {
fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator {
let max_val = scope.create_local(item);
let max_initial =
Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), item.elem());
let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), item.elem());
cpa!(scope, max_val = max_initial);
max_val
}
@ -68,8 +67,7 @@ impl<E: JitElement> PoolStrategy for MaxPoolWithIndices<E> {
fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator {
let max_val = scope.create_local(item);
let max_initial =
Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), item.elem());
let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), item.elem());
cpa!(scope, max_val = max_initial);
let max_index = scope.create_local(Elem::UInt);
(max_val, max_index)

View File

@ -16,8 +16,7 @@ impl<E: JitElement> ReduceDimNaive<E> for Argmax {
) -> Self::Accumulator {
let index = scope.create_local(Elem::UInt);
let max = scope.create_local(input_item);
let max_initial =
Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), input_item.elem());
let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), input_item.elem());
cpa!(scope, max = max_initial);
(max, index)

View File

@ -17,8 +17,7 @@ impl<E: JitElement> ReduceDimNaive<E> for Argmin {
) -> Self::Accumulator {
let index = scope.create_local(Elem::UInt);
let min = scope.create_local(input_item);
let min_initial =
Variable::ConstantScalar(E::maximum_value().to_f64().unwrap(), input_item.elem());
let min_initial = Variable::ConstantScalar(E::maximum_value().to_f64(), input_item.elem());
cpa!(scope, min = min_initial);
(min, index)

View File

@ -18,7 +18,7 @@ impl<E: JitElement> ReduceDimShared<E> for Argmax {
let value_shared_memory = scope.create_shared(input_item, shared_memory_size);
let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size);
let max = Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), input_item.elem());
let max = Variable::ConstantScalar(E::minimum_value().to_f64(), input_item.elem());
cpa!(scope, value_shared_memory[write_position] = max);
(value_shared_memory, index_shared_memory)
}

View File

@ -19,7 +19,7 @@ impl<E: JitElement> ReduceDimShared<E> for Argmin {
let value_shared_memory = scope.create_shared(input_item, shared_memory_size);
let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size);
let min = Variable::ConstantScalar(E::maximum_value().to_f64().unwrap(), input_item.elem());
let min = Variable::ConstantScalar(E::maximum_value().to_f64(), input_item.elem());
cpa!(scope, value_shared_memory[write_position] = min);
(value_shared_memory, index_shared_memory)
}

View File

@ -1,6 +1,5 @@
use burn_tensor::Element;
use ndarray::LinalgScalar;
use num_traits::One;
use num_traits::Signed;
#[cfg(not(feature = "std"))]
@ -20,7 +19,6 @@ where
/// A general element for ndarray backend.
pub trait NdArrayElement:
Element
+ One
+ ndarray::LinalgScalar
+ ndarray::ScalarOperand
+ ExpElement

View File

@ -563,16 +563,18 @@ where
where
E: Signed,
{
let zero = 0.elem();
let one = 1.elem::<E>();
NdArrayTensor::new(
tensor
.array
.mapv(|x| {
if x > E::zero() {
E::one()
} else if x < E::zero() {
-E::one()
if x > zero {
one
} else if x < zero {
-one
} else {
E::zero()
zero
}
})
.into_shared(),

View File

@ -34,6 +34,7 @@ macro_rules! keepdim {
}};
}
use burn_tensor::ElementConversion;
pub(crate) use keepdim;
use ndarray::Axis;
@ -63,7 +64,7 @@ pub(crate) fn prod_dim<E: NdArrayElement, const D1: usize, const D2: usize>(
) -> NdArrayTensor<E, D2> {
let array = tensor
.array
.fold_axis(Axis(dim), E::one(), |acc, &x| acc.mul(x.elem()))
.fold_axis(Axis(dim), 1.elem::<E>(), |acc, &x| acc.mul(x.elem()))
.into_shared();
NdArrayTensor { array }

View File

@ -406,7 +406,7 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
fn float_cos<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let array = tensor
.array
.mapv_into(|a| (a.to_f64().unwrap()).cos().elem())
.mapv_into(|a| (a.to_f64()).cos().elem())
.into_shared();
NdArrayTensor::new(array)
@ -415,7 +415,7 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
fn float_sin<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let array = tensor
.array
.mapv_into(|a| (a.to_f64().unwrap()).sin().elem())
.mapv_into(|a| (a.to_f64()).sin().elem())
.into_shared();
NdArrayTensor::new(array)
@ -424,7 +424,7 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
fn float_tanh<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let array = tensor
.array
.mapv_into(|a| (a.to_f64().unwrap()).tanh().elem())
.mapv_into(|a| (a.to_f64()).tanh().elem())
.into_shared();
NdArrayTensor::new(array)
@ -433,7 +433,7 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
fn float_erf<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
let array = tensor
.array
.mapv_into(|a| erf(a.to_f64().unwrap()).elem())
.mapv_into(|a| erf(a.to_f64()).elem())
.into_shared();
NdArrayTensor::new(array)
@ -473,7 +473,7 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
lhs: NdArrayTensor<E, D>,
rhs: NdArrayTensor<E, D>,
) -> NdArrayTensor<E, D> {
NdArrayMathOps::elementwise_op(lhs, rhs, |a, b| a.powf_elem(b.to_f32().unwrap()))
NdArrayMathOps::elementwise_op(lhs, rhs, |a, b| a.powf_elem(b.to_f32()))
}
fn float_permute<const D: usize>(

View File

@ -20,6 +20,7 @@ use serde::{Serialize, Serializer};
use crate::check::TensorCheck;
use crate::tensor::api::chunk::chunk;
use crate::tensor::api::narrow::narrow;
use crate::Element;
use crate::{backend::Backend, check, Bool, Data, DataSerialize, Float, Int, Shape, TensorKind};
/// A tensor with a given backend, shape and data type.
@ -1213,7 +1214,7 @@ impl<B: Backend, const D: usize> core::ops::BitXor<T> for Tensor<B, D> {
/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
pub trait BasicOps<B: Backend>: TensorKind<B> {
/// The type of the tensor elements.
type Elem: 'static + Copy;
type Elem: Element;
/// Creates an empty tensor with the given shape.
///

View File

@ -6,7 +6,6 @@ use crate::{
backend::Backend, check, check::TensorCheck, BasicOps, Bool, Distribution, Element,
ElementConversion, Float, Int, Shape, Tensor, TensorKind,
};
use num_traits::Zero;
impl<B, const D: usize, K> Tensor<B, D, K>
where
@ -656,7 +655,7 @@ where
///
/// A boolean tensor with the same shape as the input tensor.
pub fn bool(self) -> Tensor<B, D, Bool> {
K::not_equal_elem::<D>(self.primitive, K::Elem::zero())
K::not_equal_elem::<D>(self.primitive, 0.elem())
}
/// Create a random tensor of the given shape on the given device where each element is

View File

@ -192,18 +192,6 @@ impl<E: Element> DataSerialize<E> {
}
}
impl<const D: usize> Data<bool, D> {
/// Converts the data to a different element type.
pub fn convert<E: Element>(self) -> Data<E, D> {
let value: Vec<E> = self.value.into_iter().map(|a| (a as i64).elem()).collect();
Data {
value,
shape: self.shape,
}
}
}
impl<E: Element, const D: usize> Data<E, D> {
/// Populates the data with random values.
pub fn random<R: RngCore>(shape: Shape<D>, distribution: Distribution, rng: &mut R) -> Self {

View File

@ -1,16 +1,13 @@
use core::cmp::Ordering;
use crate::Distribution;
use crate::{cast::ToElement, Distribution};
use half::{bf16, f16};
use num_traits::{identities::Zero, One, ToPrimitive};
use rand::RngCore;
use serde::{Deserialize, Serialize};
/// Element trait for tensor.
pub trait Element:
ToPrimitive
+ Zero
+ One
ToElement
+ ElementRandom
+ ElementConversion
+ ElementPrecision
@ -38,7 +35,7 @@ pub trait ElementConversion {
/// # Returns
///
/// The converted element.
fn from_elem<E: ToPrimitive>(elem: E) -> Self;
fn from_elem<E: ToElement>(elem: E) -> Self;
/// Converts and returns the converted element.
fn elem<E: Element>(self) -> E;
@ -105,7 +102,7 @@ macro_rules! make_element {
}
impl ElementConversion for $type {
fn from_elem<E: ToPrimitive>(elem: E) -> Self {
fn from_elem<E: ToElement>(elem: E) -> Self {
#[allow(clippy::redundant_closure_call)]
$convert(&elem)
}
@ -140,7 +137,7 @@ macro_rules! make_element {
make_element!(
ty f64 Precision::Double,
convert |elem: &dyn ToPrimitive| elem.to_f64().unwrap(),
convert |elem: &dyn ToElement| elem.to_f64(),
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
cmp |a: &f64, b: &f64| a.total_cmp(b),
dtype DType::F64
@ -148,7 +145,7 @@ make_element!(
make_element!(
ty f32 Precision::Full,
convert |elem: &dyn ToPrimitive| elem.to_f32().unwrap(),
convert |elem: &dyn ToElement| elem.to_f32(),
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
cmp |a: &f32, b: &f32| a.total_cmp(b),
dtype DType::F32
@ -156,7 +153,7 @@ make_element!(
make_element!(
ty i64 Precision::Double,
convert |elem: &dyn ToPrimitive| elem.to_i64().unwrap(),
convert |elem: &dyn ToElement| elem.to_i64(),
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
cmp |a: &i64, b: &i64| Ord::cmp(a, b),
dtype DType::I64
@ -164,7 +161,7 @@ make_element!(
make_element!(
ty i32 Precision::Full,
convert |elem: &dyn ToPrimitive| elem.to_i32().unwrap(),
convert |elem: &dyn ToElement| elem.to_i32(),
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
cmp |a: &i32, b: &i32| Ord::cmp(a, b),
dtype DType::I32
@ -172,7 +169,7 @@ make_element!(
make_element!(
ty u32 Precision::Full,
convert |elem: &dyn ToPrimitive| elem.to_u32().unwrap(),
convert |elem: &dyn ToElement| elem.to_u32(),
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
cmp |a: &u32, b: &u32| Ord::cmp(a, b),
dtype DType::U32
@ -180,7 +177,7 @@ make_element!(
make_element!(
ty i16 Precision::Half,
convert |elem: &dyn ToPrimitive| elem.to_i16().unwrap(),
convert |elem: &dyn ToElement| elem.to_i16(),
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
cmp |a: &i16, b: &i16| Ord::cmp(a, b),
dtype DType::I16
@ -188,7 +185,7 @@ make_element!(
make_element!(
ty i8 Precision::Other,
convert |elem: &dyn ToPrimitive| elem.to_i8().unwrap(),
convert |elem: &dyn ToElement| elem.to_i8(),
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
cmp |a: &i8, b: &i8| Ord::cmp(a, b),
dtype DType::I8
@ -196,7 +193,7 @@ make_element!(
make_element!(
ty u8 Precision::Other,
convert |elem: &dyn ToPrimitive| elem.to_u8().unwrap(),
convert |elem: &dyn ToElement| elem.to_u8(),
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
cmp |a: &u8, b: &u8| Ord::cmp(a, b),
dtype DType::U8
@ -204,7 +201,7 @@ make_element!(
make_element!(
ty f16 Precision::Half,
convert |elem: &dyn ToPrimitive| f16::from_f32(elem.to_f32().unwrap()),
convert |elem: &dyn ToElement| f16::from_f32(elem.to_f32()),
random |distribution: Distribution, rng: &mut R| {
let sample: f32 = distribution.sampler(rng).sample();
f16::from_elem(sample)
@ -214,7 +211,7 @@ make_element!(
);
make_element!(
ty bf16 Precision::Half,
convert |elem: &dyn ToPrimitive| bf16::from_f32(elem.to_f32().unwrap()),
convert |elem: &dyn ToElement| bf16::from_f32(elem.to_f32()),
random |distribution: Distribution, rng: &mut R| {
let sample: f32 = distribution.sampler(rng).sample();
bf16::from_elem(sample)
@ -223,6 +220,17 @@ make_element!(
dtype DType::BF16
);
make_element!(
ty bool Precision::Other,
convert |elem: &dyn ToElement| elem.to_u8() != 0,
random |distribution: Distribution, rng: &mut R| {
let sample: u8 = distribution.sampler(rng).sample();
bool::from_elem(sample)
},
cmp |a: &bool, b: &bool| Ord::cmp(a, b),
dtype DType::Bool
);
#[allow(missing_docs)]
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum DType {

View File

@ -0,0 +1,579 @@
use core::mem::size_of;
use half::{bf16, f16};
/// A generic trait for converting a value to a number.
/// Adapted from [num_traits::ToPrimitive] to support [bool].
///
/// A value can be represented by the target type when it lies within
/// the range of scalars supported by the target type.
/// For example, a negative integer cannot be represented by an unsigned
/// integer type, and an `i64` with a very high magnitude might not be
/// convertible to an `i32`.
/// On the other hand, conversions with possible precision loss or truncation
/// are admitted, like an `f32` with a decimal part to an integer type, or
/// even a large `f64` saturating to `f32` infinity.
///
/// The methods *panic* when the value cannot be represented by the target type.
pub trait ToElement {
/// Converts the value of `self` to an `isize`.
#[inline]
fn to_isize(&self) -> isize {
ToElement::to_isize(&self.to_i64())
}
/// Converts the value of `self` to an `i8`.
#[inline]
fn to_i8(&self) -> i8 {
ToElement::to_i8(&self.to_i64())
}
/// Converts the value of `self` to an `i16`.
#[inline]
fn to_i16(&self) -> i16 {
ToElement::to_i16(&self.to_i64())
}
/// Converts the value of `self` to an `i32`.
#[inline]
fn to_i32(&self) -> i32 {
ToElement::to_i32(&self.to_i64())
}
/// Converts the value of `self` to an `i64`.
fn to_i64(&self) -> i64;
/// Converts the value of `self` to an `i128`.
///
/// The default implementation converts through `to_i64()`. Types implementing
/// this trait should override this method if they can represent a greater range.
#[inline]
fn to_i128(&self) -> i128 {
i128::from(self.to_i64())
}
/// Converts the value of `self` to a `usize`.
#[inline]
fn to_usize(&self) -> usize {
ToElement::to_usize(&self.to_u64())
}
/// Converts the value of `self` to a `u8`.
#[inline]
fn to_u8(&self) -> u8 {
ToElement::to_u8(&self.to_u64())
}
/// Converts the value of `self` to a `u16`.
#[inline]
fn to_u16(&self) -> u16 {
ToElement::to_u16(&self.to_u64())
}
/// Converts the value of `self` to a `u32`.
#[inline]
fn to_u32(&self) -> u32 {
ToElement::to_u32(&self.to_u64())
}
/// Converts the value of `self` to a `u64`.
fn to_u64(&self) -> u64;
/// Converts the value of `self` to a `u128`.
///
/// The default implementation converts through `to_u64()`. Types implementing
/// this trait should override this method if they can represent a greater range.
#[inline]
fn to_u128(&self) -> u128 {
u128::from(self.to_u64())
}
/// Converts the value of `self` to an `f32`. Overflows may map to positive
/// or negative infinity.
#[inline]
fn to_f32(&self) -> f32 {
ToElement::to_f32(&self.to_f64())
}
/// Converts the value of `self` to an `f64`. Overflows may map to positive
/// or negative infinity.
///
/// The default implementation tries to convert through `to_i64()`, and
/// failing that through `to_u64()`. Types implementing this trait should
/// override this method if they can represent a greater range.
#[inline]
fn to_f64(&self) -> f64 {
ToElement::to_f64(&self.to_u64())
}
}
macro_rules! impl_to_element_int_to_int {
($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
#[inline]
$(#[$cfg])*
fn $method(&self) -> $DstT {
let min = $DstT::MIN as $SrcT;
let max = $DstT::MAX as $SrcT;
if size_of::<$SrcT>() <= size_of::<$DstT>() || (min <= *self && *self <= max) {
*self as $DstT
} else {
panic!("Element cannot be represented in the target type")
}
}
)*}
}
macro_rules! impl_to_element_int_to_uint {
($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
#[inline]
$(#[$cfg])*
fn $method(&self) -> $DstT {
let max = $DstT::MAX as $SrcT;
if 0 <= *self && (size_of::<$SrcT>() <= size_of::<$DstT>() || *self <= max) {
*self as $DstT
} else {
panic!("Element cannot be represented in the target type")
}
}
)*}
}
macro_rules! impl_to_element_int {
($T:ident) => {
impl ToElement for $T {
impl_to_element_int_to_int! { $T:
fn to_isize -> isize;
fn to_i8 -> i8;
fn to_i16 -> i16;
fn to_i32 -> i32;
fn to_i64 -> i64;
fn to_i128 -> i128;
}
impl_to_element_int_to_uint! { $T:
fn to_usize -> usize;
fn to_u8 -> u8;
fn to_u16 -> u16;
fn to_u32 -> u32;
fn to_u64 -> u64;
fn to_u128 -> u128;
}
#[inline]
fn to_f32(&self) -> f32 {
*self as f32
}
#[inline]
fn to_f64(&self) -> f64 {
*self as f64
}
}
};
}
impl_to_element_int!(isize);
impl_to_element_int!(i8);
impl_to_element_int!(i16);
impl_to_element_int!(i32);
impl_to_element_int!(i64);
impl_to_element_int!(i128);
macro_rules! impl_to_element_uint_to_int {
($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
#[inline]
$(#[$cfg])*
fn $method(&self) -> $DstT {
let max = $DstT::MAX as $SrcT;
if size_of::<$SrcT>() < size_of::<$DstT>() || *self <= max {
*self as $DstT
} else {
panic!("Element cannot be represented in the target type")
}
}
)*}
}
macro_rules! impl_to_element_uint_to_uint {
($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
#[inline]
$(#[$cfg])*
fn $method(&self) -> $DstT {
let max = $DstT::MAX as $SrcT;
if size_of::<$SrcT>() <= size_of::<$DstT>() || *self <= max {
*self as $DstT
} else {
panic!("Element cannot be represented in the target type")
}
}
)*}
}
macro_rules! impl_to_element_uint {
($T:ident) => {
impl ToElement for $T {
impl_to_element_uint_to_int! { $T:
fn to_isize -> isize;
fn to_i8 -> i8;
fn to_i16 -> i16;
fn to_i32 -> i32;
fn to_i64 -> i64;
fn to_i128 -> i128;
}
impl_to_element_uint_to_uint! { $T:
fn to_usize -> usize;
fn to_u8 -> u8;
fn to_u16 -> u16;
fn to_u32 -> u32;
fn to_u64 -> u64;
fn to_u128 -> u128;
}
#[inline]
fn to_f32(&self) -> f32 {
*self as f32
}
#[inline]
fn to_f64(&self) -> f64 {
*self as f64
}
}
};
}
impl_to_element_uint!(usize);
impl_to_element_uint!(u8);
impl_to_element_uint!(u16);
impl_to_element_uint!(u32);
impl_to_element_uint!(u64);
impl_to_element_uint!(u128);
macro_rules! impl_to_element_float_to_float {
($SrcT:ident : $( fn $method:ident -> $DstT:ident ; )*) => {$(
#[inline]
fn $method(&self) -> $DstT {
// We can safely cast all values, whether NaN, +-inf, or finite.
// Finite values that are reducing size may saturate to +-inf.
*self as $DstT
}
)*}
}
macro_rules! float_to_int_unchecked {
// SAFETY: Must not be NaN or infinite; must be representable as the integer after truncating.
// We already checked that the float is in the exclusive range `(MIN-1, MAX+1)`.
($float:expr => $int:ty) => {
unsafe { $float.to_int_unchecked::<$int>() }
};
}
macro_rules! impl_to_element_float_to_signed_int {
($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $i:ident ; )*) => {$(
#[inline]
$(#[$cfg])*
fn $method(&self) -> $i {
// Float as int truncates toward zero, so we want to allow values
// in the exclusive range `(MIN-1, MAX+1)`.
if size_of::<$f>() > size_of::<$i>() {
// With a larger size, we can represent the range exactly.
const MIN_M1: $f = $i::MIN as $f - 1.0;
const MAX_P1: $f = $i::MAX as $f + 1.0;
if *self > MIN_M1 && *self < MAX_P1 {
return float_to_int_unchecked!(*self => $i);
}
} else {
// We can't represent `MIN-1` exactly, but there's no fractional part
// at this magnitude, so we can just use a `MIN` inclusive boundary.
const MIN: $f = $i::MIN as $f;
// We can't represent `MAX` exactly, but it will round up to exactly
// `MAX+1` (a power of two) when we cast it.
const MAX_P1: $f = $i::MAX as $f;
if *self >= MIN && *self < MAX_P1 {
return float_to_int_unchecked!(*self => $i);
}
}
panic!("Float cannot be represented in the target signed int type")
}
)*}
}
macro_rules! impl_to_element_float_to_unsigned_int {
($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $u:ident ; )*) => {$(
#[inline]
$(#[$cfg])*
fn $method(&self) -> $u {
// Float as int truncates toward zero, so we want to allow values
// in the exclusive range `(-1, MAX+1)`.
if size_of::<$f>() > size_of::<$u>() {
// With a larger size, we can represent the range exactly.
const MAX_P1: $f = $u::MAX as $f + 1.0;
if *self > -1.0 && *self < MAX_P1 {
return float_to_int_unchecked!(*self => $u);
}
} else {
// We can't represent `MAX` exactly, but it will round up to exactly
// `MAX+1` (a power of two) when we cast it.
// (`u128::MAX as f32` is infinity, but this is still ok.)
const MAX_P1: $f = $u::MAX as $f;
if *self > -1.0 && *self < MAX_P1 {
return float_to_int_unchecked!(*self => $u);
}
}
panic!("Float cannot be represented in the target unsigned int type")
}
)*}
}
macro_rules! impl_to_element_float {
($T:ident) => {
impl ToElement for $T {
impl_to_element_float_to_signed_int! { $T:
fn to_isize -> isize;
fn to_i8 -> i8;
fn to_i16 -> i16;
fn to_i32 -> i32;
fn to_i64 -> i64;
fn to_i128 -> i128;
}
impl_to_element_float_to_unsigned_int! { $T:
fn to_usize -> usize;
fn to_u8 -> u8;
fn to_u16 -> u16;
fn to_u32 -> u32;
fn to_u64 -> u64;
fn to_u128 -> u128;
}
impl_to_element_float_to_float! { $T:
fn to_f32 -> f32;
fn to_f64 -> f64;
}
}
};
}
impl_to_element_float!(f32);
impl_to_element_float!(f64);
impl ToElement for f16 {
#[inline]
fn to_i64(&self) -> i64 {
Self::to_f32(*self).to_i64()
}
#[inline]
fn to_u64(&self) -> u64 {
Self::to_f32(*self).to_u64()
}
#[inline]
fn to_i8(&self) -> i8 {
Self::to_f32(*self).to_i8()
}
#[inline]
fn to_u8(&self) -> u8 {
Self::to_f32(*self).to_u8()
}
#[inline]
fn to_i16(&self) -> i16 {
Self::to_f32(*self).to_i16()
}
#[inline]
fn to_u16(&self) -> u16 {
Self::to_f32(*self).to_u16()
}
#[inline]
fn to_i32(&self) -> i32 {
Self::to_f32(*self).to_i32()
}
#[inline]
fn to_u32(&self) -> u32 {
Self::to_f32(*self).to_u32()
}
#[inline]
fn to_f32(&self) -> f32 {
Self::to_f32(*self)
}
#[inline]
fn to_f64(&self) -> f64 {
Self::to_f64(*self)
}
}
impl ToElement for bf16 {
#[inline]
fn to_i64(&self) -> i64 {
Self::to_f32(*self).to_i64()
}
#[inline]
fn to_u64(&self) -> u64 {
Self::to_f32(*self).to_u64()
}
#[inline]
fn to_i8(&self) -> i8 {
Self::to_f32(*self).to_i8()
}
#[inline]
fn to_u8(&self) -> u8 {
Self::to_f32(*self).to_u8()
}
#[inline]
fn to_i16(&self) -> i16 {
Self::to_f32(*self).to_i16()
}
#[inline]
fn to_u16(&self) -> u16 {
Self::to_f32(*self).to_u16()
}
#[inline]
fn to_i32(&self) -> i32 {
Self::to_f32(*self).to_i32()
}
#[inline]
fn to_u32(&self) -> u32 {
Self::to_f32(*self).to_u32()
}
#[inline]
fn to_f32(&self) -> f32 {
Self::to_f32(*self)
}
#[inline]
fn to_f64(&self) -> f64 {
Self::to_f64(*self)
}
}
impl ToElement for bool {
#[inline]
fn to_i64(&self) -> i64 {
*self as i64
}
#[inline]
fn to_u64(&self) -> u64 {
*self as u64
}
#[inline]
fn to_i8(&self) -> i8 {
*self as i8
}
#[inline]
fn to_u8(&self) -> u8 {
*self as u8
}
#[inline]
fn to_i16(&self) -> i16 {
*self as i16
}
#[inline]
fn to_u16(&self) -> u16 {
*self as u16
}
#[inline]
fn to_i32(&self) -> i32 {
*self as i32
}
#[inline]
fn to_u32(&self) -> u32 {
*self as u32
}
#[inline]
fn to_f32(&self) -> f32 {
self.to_u8() as f32
}
#[inline]
fn to_f64(&self) -> f64 {
self.to_u8() as f64
}
}
mod tests {
#[allow(unused_imports)]
use super::*;
#[test]
fn to_element_float() {
let f32_toolarge = 1e39f64;
assert_eq!(f32_toolarge.to_f32(), f32::INFINITY);
assert_eq!((-f32_toolarge).to_f32(), f32::NEG_INFINITY);
assert_eq!((f32::MAX as f64).to_f32(), f32::MAX);
assert_eq!((-f32::MAX as f64).to_f32(), -f32::MAX);
assert_eq!(f64::INFINITY.to_f32(), f32::INFINITY);
assert_eq!((f64::NEG_INFINITY).to_f32(), f32::NEG_INFINITY);
assert!((f64::NAN).to_f32().is_nan());
}
#[test]
#[should_panic]
fn to_element_signed_to_u8_underflow() {
let _x = (-1i8).to_u8();
}
#[test]
#[should_panic]
fn to_element_signed_to_u16_underflow() {
let _x = (-1i8).to_u16();
}
#[test]
#[should_panic]
fn to_element_signed_to_u32_underflow() {
let _x = (-1i8).to_u32();
}
#[test]
#[should_panic]
fn to_element_signed_to_u64_underflow() {
let _x = (-1i8).to_u64();
}
#[test]
#[should_panic]
fn to_element_signed_to_u128_underflow() {
let _x = (-1i8).to_u128();
}
#[test]
#[should_panic]
fn to_element_signed_to_usize_underflow() {
let _x = (-1i8).to_usize();
}
#[test]
#[should_panic]
fn to_element_unsigned_to_u8_overflow() {
let _x = 256.to_u8();
}
#[test]
#[should_panic]
fn to_element_unsigned_to_u16_overflow() {
let _x = 65_536.to_u16();
}
#[test]
#[should_panic]
fn to_element_unsigned_to_u32_overflow() {
let _x = 4_294_967_296u64.to_u32();
}
#[test]
#[should_panic]
fn to_element_unsigned_to_u64_overflow() {
let _x = 18_446_744_073_709_551_616u128.to_u64();
}
#[test]
fn to_element_int_to_float() {
assert_eq!((-1).to_f32(), -1.0);
assert_eq!((-1).to_f64(), -1.0);
assert_eq!(255.to_f32(), 255.0);
assert_eq!(65_535.to_f64(), 65_535.0);
}
#[test]
fn to_element_float_to_int() {
assert_eq!((-1.0).to_i8(), -1);
assert_eq!(1.0.to_u8(), 1);
assert_eq!(1.8.to_u16(), 1);
assert_eq!(123.456.to_u32(), 123);
}
}

View File

@ -0,0 +1,6 @@
mod base;
/// Tensor element casting.
pub mod cast;
pub use base::*;

View File

@ -1,13 +1,13 @@
use super::cat::cat_with_slice_assign;
use super::repeat::repeat_with_slice_assign;
use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
use crate::tensor::cast::ToElement;
use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion, Int};
use crate::{cartesian_grid, Tensor};
use crate::{tensor::api::chunk, tensor::api::narrow};
use alloc::vec::Vec;
use burn_common::reader::Reader;
use core::ops::Range;
use num_traits::ToPrimitive;
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
use crate::{argsort, sort, sort_with_indices};
@ -536,10 +536,7 @@ pub trait IntTensorOps<B: Backend> {
///
/// The elements of `lhs` raised to the value of `rhs`.
fn int_powi_scalar<const D: usize>(lhs: IntTensor<B, D>, rhs: IntElem<B>) -> IntTensor<B, D> {
B::float_into_int(B::float_powf_scalar(
B::int_into_float(lhs),
rhs.to_f32().unwrap(),
))
B::float_into_int(B::float_powf_scalar(B::int_into_float(lhs), rhs.to_f32()))
}
/// Element-wise power with a floatTensor.

View File

@ -2,13 +2,13 @@ use super::cat::cat_with_slice_assign;
use super::repeat::repeat_with_slice_assign;
use super::{BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntElem, IntTensor};
use crate::backend::BackendBridge;
use crate::tensor::cast::ToElement;
use crate::Tensor;
use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion, Float};
use crate::{tensor::api::chunk, tensor::api::narrow};
use alloc::vec::Vec;
use burn_common::reader::Reader;
use core::ops::Range;
use num_traits::ToPrimitive;
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
use crate::{argsort, sort, sort_with_indices};
@ -1005,7 +1005,7 @@ pub trait FloatTensorOps<B: Backend> {
lhs: FloatTensor<B, D>,
rhs: IntElem<B>,
) -> FloatTensor<B, D> {
Self::float_powf_scalar(lhs, rhs.to_f32().unwrap())
Self::float_powf_scalar(lhs, rhs.to_f32())
}
/// Returns a new tensor with values raised to the power of float `value`.