mirror of https://github.com/tracel-ai/burn.git
Implement `Element` for `bool` (#1878)
* Element already implements One * Add element module * Add our own traits for Zero, One and ToPrimitive to support bool Element * Fix typo * Add basic tests for ToPrimitive with expected values * The most important change of all * Remove One + Zero identities * Move zero/one outside mapv + refactor ToPrimitive -> ToElement trait * Add num-traits to NOTICES.md
This commit is contained in:
parent
b71c300638
commit
525244062f
33
NOTICES.md
33
NOTICES.md
|
@ -303,3 +303,36 @@ SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
|
|||
FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
|
||||
ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
|
||||
|
||||
## num-traits
|
||||
|
||||
**Source:** https://github.com/rust-num/num-traits/blob/master/src/cast.rs
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2014 The Rust Project Developers
|
||||
|
||||
Permission is hereby granted, free of charge, to any
|
||||
person obtaining a copy of this software and associated
|
||||
documentation files (the "Software"), to deal in the
|
||||
Software without restriction, including without
|
||||
limitation the rights to use, copy, modify, merge,
|
||||
publish, distribute, sublicense, and/or sell copies of
|
||||
the Software, and to permit persons to whom the Software
|
||||
is furnished to do so, subject to the following
|
||||
conditions:
|
||||
|
||||
The above copyright notice and this permission notice
|
||||
shall be included in all copies or substantial portions
|
||||
of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
|
||||
ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
|
||||
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
|
||||
PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
|
||||
SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
|
||||
CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
|
||||
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
|
||||
IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
|
|
|
@ -21,8 +21,7 @@ impl<E: JitElement> PoolStrategy for MaxPool<E> {
|
|||
|
||||
fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator {
|
||||
let max_val = scope.create_local(item);
|
||||
let max_initial =
|
||||
Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), item.elem());
|
||||
let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), item.elem());
|
||||
cpa!(scope, max_val = max_initial);
|
||||
max_val
|
||||
}
|
||||
|
@ -68,8 +67,7 @@ impl<E: JitElement> PoolStrategy for MaxPoolWithIndices<E> {
|
|||
|
||||
fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator {
|
||||
let max_val = scope.create_local(item);
|
||||
let max_initial =
|
||||
Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), item.elem());
|
||||
let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), item.elem());
|
||||
cpa!(scope, max_val = max_initial);
|
||||
let max_index = scope.create_local(Elem::UInt);
|
||||
(max_val, max_index)
|
||||
|
|
|
@ -16,8 +16,7 @@ impl<E: JitElement> ReduceDimNaive<E> for Argmax {
|
|||
) -> Self::Accumulator {
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
let max = scope.create_local(input_item);
|
||||
let max_initial =
|
||||
Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), input_item.elem());
|
||||
let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), input_item.elem());
|
||||
cpa!(scope, max = max_initial);
|
||||
|
||||
(max, index)
|
||||
|
|
|
@ -17,8 +17,7 @@ impl<E: JitElement> ReduceDimNaive<E> for Argmin {
|
|||
) -> Self::Accumulator {
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
let min = scope.create_local(input_item);
|
||||
let min_initial =
|
||||
Variable::ConstantScalar(E::maximum_value().to_f64().unwrap(), input_item.elem());
|
||||
let min_initial = Variable::ConstantScalar(E::maximum_value().to_f64(), input_item.elem());
|
||||
cpa!(scope, min = min_initial);
|
||||
|
||||
(min, index)
|
||||
|
|
|
@ -18,7 +18,7 @@ impl<E: JitElement> ReduceDimShared<E> for Argmax {
|
|||
let value_shared_memory = scope.create_shared(input_item, shared_memory_size);
|
||||
let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size);
|
||||
|
||||
let max = Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), input_item.elem());
|
||||
let max = Variable::ConstantScalar(E::minimum_value().to_f64(), input_item.elem());
|
||||
cpa!(scope, value_shared_memory[write_position] = max);
|
||||
(value_shared_memory, index_shared_memory)
|
||||
}
|
||||
|
|
|
@ -19,7 +19,7 @@ impl<E: JitElement> ReduceDimShared<E> for Argmin {
|
|||
let value_shared_memory = scope.create_shared(input_item, shared_memory_size);
|
||||
let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size);
|
||||
|
||||
let min = Variable::ConstantScalar(E::maximum_value().to_f64().unwrap(), input_item.elem());
|
||||
let min = Variable::ConstantScalar(E::maximum_value().to_f64(), input_item.elem());
|
||||
cpa!(scope, value_shared_memory[write_position] = min);
|
||||
(value_shared_memory, index_shared_memory)
|
||||
}
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
use burn_tensor::Element;
|
||||
use ndarray::LinalgScalar;
|
||||
use num_traits::One;
|
||||
use num_traits::Signed;
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
|
@ -20,7 +19,6 @@ where
|
|||
/// A general element for ndarray backend.
|
||||
pub trait NdArrayElement:
|
||||
Element
|
||||
+ One
|
||||
+ ndarray::LinalgScalar
|
||||
+ ndarray::ScalarOperand
|
||||
+ ExpElement
|
||||
|
|
|
@ -563,16 +563,18 @@ where
|
|||
where
|
||||
E: Signed,
|
||||
{
|
||||
let zero = 0.elem();
|
||||
let one = 1.elem::<E>();
|
||||
NdArrayTensor::new(
|
||||
tensor
|
||||
.array
|
||||
.mapv(|x| {
|
||||
if x > E::zero() {
|
||||
E::one()
|
||||
} else if x < E::zero() {
|
||||
-E::one()
|
||||
if x > zero {
|
||||
one
|
||||
} else if x < zero {
|
||||
-one
|
||||
} else {
|
||||
E::zero()
|
||||
zero
|
||||
}
|
||||
})
|
||||
.into_shared(),
|
||||
|
|
|
@ -34,6 +34,7 @@ macro_rules! keepdim {
|
|||
}};
|
||||
}
|
||||
|
||||
use burn_tensor::ElementConversion;
|
||||
pub(crate) use keepdim;
|
||||
use ndarray::Axis;
|
||||
|
||||
|
@ -63,7 +64,7 @@ pub(crate) fn prod_dim<E: NdArrayElement, const D1: usize, const D2: usize>(
|
|||
) -> NdArrayTensor<E, D2> {
|
||||
let array = tensor
|
||||
.array
|
||||
.fold_axis(Axis(dim), E::one(), |acc, &x| acc.mul(x.elem()))
|
||||
.fold_axis(Axis(dim), 1.elem::<E>(), |acc, &x| acc.mul(x.elem()))
|
||||
.into_shared();
|
||||
|
||||
NdArrayTensor { array }
|
||||
|
|
|
@ -406,7 +406,7 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
|
|||
fn float_cos<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
|
||||
let array = tensor
|
||||
.array
|
||||
.mapv_into(|a| (a.to_f64().unwrap()).cos().elem())
|
||||
.mapv_into(|a| (a.to_f64()).cos().elem())
|
||||
.into_shared();
|
||||
|
||||
NdArrayTensor::new(array)
|
||||
|
@ -415,7 +415,7 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
|
|||
fn float_sin<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
|
||||
let array = tensor
|
||||
.array
|
||||
.mapv_into(|a| (a.to_f64().unwrap()).sin().elem())
|
||||
.mapv_into(|a| (a.to_f64()).sin().elem())
|
||||
.into_shared();
|
||||
|
||||
NdArrayTensor::new(array)
|
||||
|
@ -424,7 +424,7 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
|
|||
fn float_tanh<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
|
||||
let array = tensor
|
||||
.array
|
||||
.mapv_into(|a| (a.to_f64().unwrap()).tanh().elem())
|
||||
.mapv_into(|a| (a.to_f64()).tanh().elem())
|
||||
.into_shared();
|
||||
|
||||
NdArrayTensor::new(array)
|
||||
|
@ -433,7 +433,7 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
|
|||
fn float_erf<const D: usize>(tensor: NdArrayTensor<E, D>) -> NdArrayTensor<E, D> {
|
||||
let array = tensor
|
||||
.array
|
||||
.mapv_into(|a| erf(a.to_f64().unwrap()).elem())
|
||||
.mapv_into(|a| erf(a.to_f64()).elem())
|
||||
.into_shared();
|
||||
|
||||
NdArrayTensor::new(array)
|
||||
|
@ -473,7 +473,7 @@ impl<E: FloatNdArrayElement> FloatTensorOps<Self> for NdArray<E> {
|
|||
lhs: NdArrayTensor<E, D>,
|
||||
rhs: NdArrayTensor<E, D>,
|
||||
) -> NdArrayTensor<E, D> {
|
||||
NdArrayMathOps::elementwise_op(lhs, rhs, |a, b| a.powf_elem(b.to_f32().unwrap()))
|
||||
NdArrayMathOps::elementwise_op(lhs, rhs, |a, b| a.powf_elem(b.to_f32()))
|
||||
}
|
||||
|
||||
fn float_permute<const D: usize>(
|
||||
|
|
|
@ -20,6 +20,7 @@ use serde::{Serialize, Serializer};
|
|||
use crate::check::TensorCheck;
|
||||
use crate::tensor::api::chunk::chunk;
|
||||
use crate::tensor::api::narrow::narrow;
|
||||
use crate::Element;
|
||||
use crate::{backend::Backend, check, Bool, Data, DataSerialize, Float, Int, Shape, TensorKind};
|
||||
|
||||
/// A tensor with a given backend, shape and data type.
|
||||
|
@ -1213,7 +1214,7 @@ impl<B: Backend, const D: usize> core::ops::BitXor<T> for Tensor<B, D> {
|
|||
/// This is an internal trait, use the public API provided by [tensor struct](Tensor).
|
||||
pub trait BasicOps<B: Backend>: TensorKind<B> {
|
||||
/// The type of the tensor elements.
|
||||
type Elem: 'static + Copy;
|
||||
type Elem: Element;
|
||||
|
||||
/// Creates an empty tensor with the given shape.
|
||||
///
|
||||
|
|
|
@ -6,7 +6,6 @@ use crate::{
|
|||
backend::Backend, check, check::TensorCheck, BasicOps, Bool, Distribution, Element,
|
||||
ElementConversion, Float, Int, Shape, Tensor, TensorKind,
|
||||
};
|
||||
use num_traits::Zero;
|
||||
|
||||
impl<B, const D: usize, K> Tensor<B, D, K>
|
||||
where
|
||||
|
@ -656,7 +655,7 @@ where
|
|||
///
|
||||
/// A boolean tensor with the same shape as the input tensor.
|
||||
pub fn bool(self) -> Tensor<B, D, Bool> {
|
||||
K::not_equal_elem::<D>(self.primitive, K::Elem::zero())
|
||||
K::not_equal_elem::<D>(self.primitive, 0.elem())
|
||||
}
|
||||
|
||||
/// Create a random tensor of the given shape on the given device where each element is
|
||||
|
|
|
@ -192,18 +192,6 @@ impl<E: Element> DataSerialize<E> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<const D: usize> Data<bool, D> {
|
||||
/// Converts the data to a different element type.
|
||||
pub fn convert<E: Element>(self) -> Data<E, D> {
|
||||
let value: Vec<E> = self.value.into_iter().map(|a| (a as i64).elem()).collect();
|
||||
|
||||
Data {
|
||||
value,
|
||||
shape: self.shape,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<E: Element, const D: usize> Data<E, D> {
|
||||
/// Populates the data with random values.
|
||||
pub fn random<R: RngCore>(shape: Shape<D>, distribution: Distribution, rng: &mut R) -> Self {
|
||||
|
|
|
@ -1,16 +1,13 @@
|
|||
use core::cmp::Ordering;
|
||||
|
||||
use crate::Distribution;
|
||||
use crate::{cast::ToElement, Distribution};
|
||||
use half::{bf16, f16};
|
||||
use num_traits::{identities::Zero, One, ToPrimitive};
|
||||
use rand::RngCore;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Element trait for tensor.
|
||||
pub trait Element:
|
||||
ToPrimitive
|
||||
+ Zero
|
||||
+ One
|
||||
ToElement
|
||||
+ ElementRandom
|
||||
+ ElementConversion
|
||||
+ ElementPrecision
|
||||
|
@ -38,7 +35,7 @@ pub trait ElementConversion {
|
|||
/// # Returns
|
||||
///
|
||||
/// The converted element.
|
||||
fn from_elem<E: ToPrimitive>(elem: E) -> Self;
|
||||
fn from_elem<E: ToElement>(elem: E) -> Self;
|
||||
|
||||
/// Converts and returns the converted element.
|
||||
fn elem<E: Element>(self) -> E;
|
||||
|
@ -105,7 +102,7 @@ macro_rules! make_element {
|
|||
}
|
||||
|
||||
impl ElementConversion for $type {
|
||||
fn from_elem<E: ToPrimitive>(elem: E) -> Self {
|
||||
fn from_elem<E: ToElement>(elem: E) -> Self {
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
$convert(&elem)
|
||||
}
|
||||
|
@ -140,7 +137,7 @@ macro_rules! make_element {
|
|||
|
||||
make_element!(
|
||||
ty f64 Precision::Double,
|
||||
convert |elem: &dyn ToPrimitive| elem.to_f64().unwrap(),
|
||||
convert |elem: &dyn ToElement| elem.to_f64(),
|
||||
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
|
||||
cmp |a: &f64, b: &f64| a.total_cmp(b),
|
||||
dtype DType::F64
|
||||
|
@ -148,7 +145,7 @@ make_element!(
|
|||
|
||||
make_element!(
|
||||
ty f32 Precision::Full,
|
||||
convert |elem: &dyn ToPrimitive| elem.to_f32().unwrap(),
|
||||
convert |elem: &dyn ToElement| elem.to_f32(),
|
||||
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
|
||||
cmp |a: &f32, b: &f32| a.total_cmp(b),
|
||||
dtype DType::F32
|
||||
|
@ -156,7 +153,7 @@ make_element!(
|
|||
|
||||
make_element!(
|
||||
ty i64 Precision::Double,
|
||||
convert |elem: &dyn ToPrimitive| elem.to_i64().unwrap(),
|
||||
convert |elem: &dyn ToElement| elem.to_i64(),
|
||||
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
|
||||
cmp |a: &i64, b: &i64| Ord::cmp(a, b),
|
||||
dtype DType::I64
|
||||
|
@ -164,7 +161,7 @@ make_element!(
|
|||
|
||||
make_element!(
|
||||
ty i32 Precision::Full,
|
||||
convert |elem: &dyn ToPrimitive| elem.to_i32().unwrap(),
|
||||
convert |elem: &dyn ToElement| elem.to_i32(),
|
||||
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
|
||||
cmp |a: &i32, b: &i32| Ord::cmp(a, b),
|
||||
dtype DType::I32
|
||||
|
@ -172,7 +169,7 @@ make_element!(
|
|||
|
||||
make_element!(
|
||||
ty u32 Precision::Full,
|
||||
convert |elem: &dyn ToPrimitive| elem.to_u32().unwrap(),
|
||||
convert |elem: &dyn ToElement| elem.to_u32(),
|
||||
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
|
||||
cmp |a: &u32, b: &u32| Ord::cmp(a, b),
|
||||
dtype DType::U32
|
||||
|
@ -180,7 +177,7 @@ make_element!(
|
|||
|
||||
make_element!(
|
||||
ty i16 Precision::Half,
|
||||
convert |elem: &dyn ToPrimitive| elem.to_i16().unwrap(),
|
||||
convert |elem: &dyn ToElement| elem.to_i16(),
|
||||
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
|
||||
cmp |a: &i16, b: &i16| Ord::cmp(a, b),
|
||||
dtype DType::I16
|
||||
|
@ -188,7 +185,7 @@ make_element!(
|
|||
|
||||
make_element!(
|
||||
ty i8 Precision::Other,
|
||||
convert |elem: &dyn ToPrimitive| elem.to_i8().unwrap(),
|
||||
convert |elem: &dyn ToElement| elem.to_i8(),
|
||||
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
|
||||
cmp |a: &i8, b: &i8| Ord::cmp(a, b),
|
||||
dtype DType::I8
|
||||
|
@ -196,7 +193,7 @@ make_element!(
|
|||
|
||||
make_element!(
|
||||
ty u8 Precision::Other,
|
||||
convert |elem: &dyn ToPrimitive| elem.to_u8().unwrap(),
|
||||
convert |elem: &dyn ToElement| elem.to_u8(),
|
||||
random |distribution: Distribution, rng: &mut R| distribution.sampler(rng).sample(),
|
||||
cmp |a: &u8, b: &u8| Ord::cmp(a, b),
|
||||
dtype DType::U8
|
||||
|
@ -204,7 +201,7 @@ make_element!(
|
|||
|
||||
make_element!(
|
||||
ty f16 Precision::Half,
|
||||
convert |elem: &dyn ToPrimitive| f16::from_f32(elem.to_f32().unwrap()),
|
||||
convert |elem: &dyn ToElement| f16::from_f32(elem.to_f32()),
|
||||
random |distribution: Distribution, rng: &mut R| {
|
||||
let sample: f32 = distribution.sampler(rng).sample();
|
||||
f16::from_elem(sample)
|
||||
|
@ -214,7 +211,7 @@ make_element!(
|
|||
);
|
||||
make_element!(
|
||||
ty bf16 Precision::Half,
|
||||
convert |elem: &dyn ToPrimitive| bf16::from_f32(elem.to_f32().unwrap()),
|
||||
convert |elem: &dyn ToElement| bf16::from_f32(elem.to_f32()),
|
||||
random |distribution: Distribution, rng: &mut R| {
|
||||
let sample: f32 = distribution.sampler(rng).sample();
|
||||
bf16::from_elem(sample)
|
||||
|
@ -223,6 +220,17 @@ make_element!(
|
|||
dtype DType::BF16
|
||||
);
|
||||
|
||||
make_element!(
|
||||
ty bool Precision::Other,
|
||||
convert |elem: &dyn ToElement| elem.to_u8() != 0,
|
||||
random |distribution: Distribution, rng: &mut R| {
|
||||
let sample: u8 = distribution.sampler(rng).sample();
|
||||
bool::from_elem(sample)
|
||||
},
|
||||
cmp |a: &bool, b: &bool| Ord::cmp(a, b),
|
||||
dtype DType::Bool
|
||||
);
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub enum DType {
|
|
@ -0,0 +1,579 @@
|
|||
use core::mem::size_of;
|
||||
|
||||
use half::{bf16, f16};
|
||||
|
||||
/// A generic trait for converting a value to a number.
|
||||
/// Adapted from [num_traits::ToPrimitive] to support [bool].
|
||||
///
|
||||
/// A value can be represented by the target type when it lies within
|
||||
/// the range of scalars supported by the target type.
|
||||
/// For example, a negative integer cannot be represented by an unsigned
|
||||
/// integer type, and an `i64` with a very high magnitude might not be
|
||||
/// convertible to an `i32`.
|
||||
/// On the other hand, conversions with possible precision loss or truncation
|
||||
/// are admitted, like an `f32` with a decimal part to an integer type, or
|
||||
/// even a large `f64` saturating to `f32` infinity.
|
||||
///
|
||||
/// The methods *panic* when the value cannot be represented by the target type.
|
||||
pub trait ToElement {
|
||||
/// Converts the value of `self` to an `isize`.
|
||||
#[inline]
|
||||
fn to_isize(&self) -> isize {
|
||||
ToElement::to_isize(&self.to_i64())
|
||||
}
|
||||
|
||||
/// Converts the value of `self` to an `i8`.
|
||||
#[inline]
|
||||
fn to_i8(&self) -> i8 {
|
||||
ToElement::to_i8(&self.to_i64())
|
||||
}
|
||||
|
||||
/// Converts the value of `self` to an `i16`.
|
||||
#[inline]
|
||||
fn to_i16(&self) -> i16 {
|
||||
ToElement::to_i16(&self.to_i64())
|
||||
}
|
||||
|
||||
/// Converts the value of `self` to an `i32`.
|
||||
#[inline]
|
||||
fn to_i32(&self) -> i32 {
|
||||
ToElement::to_i32(&self.to_i64())
|
||||
}
|
||||
|
||||
/// Converts the value of `self` to an `i64`.
|
||||
fn to_i64(&self) -> i64;
|
||||
|
||||
/// Converts the value of `self` to an `i128`.
|
||||
///
|
||||
/// The default implementation converts through `to_i64()`. Types implementing
|
||||
/// this trait should override this method if they can represent a greater range.
|
||||
#[inline]
|
||||
fn to_i128(&self) -> i128 {
|
||||
i128::from(self.to_i64())
|
||||
}
|
||||
|
||||
/// Converts the value of `self` to a `usize`.
|
||||
#[inline]
|
||||
fn to_usize(&self) -> usize {
|
||||
ToElement::to_usize(&self.to_u64())
|
||||
}
|
||||
|
||||
/// Converts the value of `self` to a `u8`.
|
||||
#[inline]
|
||||
fn to_u8(&self) -> u8 {
|
||||
ToElement::to_u8(&self.to_u64())
|
||||
}
|
||||
|
||||
/// Converts the value of `self` to a `u16`.
|
||||
#[inline]
|
||||
fn to_u16(&self) -> u16 {
|
||||
ToElement::to_u16(&self.to_u64())
|
||||
}
|
||||
|
||||
/// Converts the value of `self` to a `u32`.
|
||||
#[inline]
|
||||
fn to_u32(&self) -> u32 {
|
||||
ToElement::to_u32(&self.to_u64())
|
||||
}
|
||||
|
||||
/// Converts the value of `self` to a `u64`.
|
||||
fn to_u64(&self) -> u64;
|
||||
|
||||
/// Converts the value of `self` to a `u128`.
|
||||
///
|
||||
/// The default implementation converts through `to_u64()`. Types implementing
|
||||
/// this trait should override this method if they can represent a greater range.
|
||||
#[inline]
|
||||
fn to_u128(&self) -> u128 {
|
||||
u128::from(self.to_u64())
|
||||
}
|
||||
|
||||
/// Converts the value of `self` to an `f32`. Overflows may map to positive
|
||||
/// or negative infinity.
|
||||
#[inline]
|
||||
fn to_f32(&self) -> f32 {
|
||||
ToElement::to_f32(&self.to_f64())
|
||||
}
|
||||
|
||||
/// Converts the value of `self` to an `f64`. Overflows may map to positive
|
||||
/// or negative infinity.
|
||||
///
|
||||
/// The default implementation tries to convert through `to_i64()`, and
|
||||
/// failing that through `to_u64()`. Types implementing this trait should
|
||||
/// override this method if they can represent a greater range.
|
||||
#[inline]
|
||||
fn to_f64(&self) -> f64 {
|
||||
ToElement::to_f64(&self.to_u64())
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! impl_to_element_int_to_int {
|
||||
($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
|
||||
#[inline]
|
||||
$(#[$cfg])*
|
||||
fn $method(&self) -> $DstT {
|
||||
let min = $DstT::MIN as $SrcT;
|
||||
let max = $DstT::MAX as $SrcT;
|
||||
if size_of::<$SrcT>() <= size_of::<$DstT>() || (min <= *self && *self <= max) {
|
||||
*self as $DstT
|
||||
} else {
|
||||
panic!("Element cannot be represented in the target type")
|
||||
}
|
||||
}
|
||||
)*}
|
||||
}
|
||||
|
||||
macro_rules! impl_to_element_int_to_uint {
|
||||
($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
|
||||
#[inline]
|
||||
$(#[$cfg])*
|
||||
fn $method(&self) -> $DstT {
|
||||
let max = $DstT::MAX as $SrcT;
|
||||
if 0 <= *self && (size_of::<$SrcT>() <= size_of::<$DstT>() || *self <= max) {
|
||||
*self as $DstT
|
||||
} else {
|
||||
panic!("Element cannot be represented in the target type")
|
||||
}
|
||||
}
|
||||
)*}
|
||||
}
|
||||
|
||||
macro_rules! impl_to_element_int {
|
||||
($T:ident) => {
|
||||
impl ToElement for $T {
|
||||
impl_to_element_int_to_int! { $T:
|
||||
fn to_isize -> isize;
|
||||
fn to_i8 -> i8;
|
||||
fn to_i16 -> i16;
|
||||
fn to_i32 -> i32;
|
||||
fn to_i64 -> i64;
|
||||
fn to_i128 -> i128;
|
||||
}
|
||||
|
||||
impl_to_element_int_to_uint! { $T:
|
||||
fn to_usize -> usize;
|
||||
fn to_u8 -> u8;
|
||||
fn to_u16 -> u16;
|
||||
fn to_u32 -> u32;
|
||||
fn to_u64 -> u64;
|
||||
fn to_u128 -> u128;
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn to_f32(&self) -> f32 {
|
||||
*self as f32
|
||||
}
|
||||
#[inline]
|
||||
fn to_f64(&self) -> f64 {
|
||||
*self as f64
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_to_element_int!(isize);
|
||||
impl_to_element_int!(i8);
|
||||
impl_to_element_int!(i16);
|
||||
impl_to_element_int!(i32);
|
||||
impl_to_element_int!(i64);
|
||||
impl_to_element_int!(i128);
|
||||
|
||||
macro_rules! impl_to_element_uint_to_int {
|
||||
($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
|
||||
#[inline]
|
||||
$(#[$cfg])*
|
||||
fn $method(&self) -> $DstT {
|
||||
let max = $DstT::MAX as $SrcT;
|
||||
if size_of::<$SrcT>() < size_of::<$DstT>() || *self <= max {
|
||||
*self as $DstT
|
||||
} else {
|
||||
panic!("Element cannot be represented in the target type")
|
||||
}
|
||||
}
|
||||
)*}
|
||||
}
|
||||
|
||||
macro_rules! impl_to_element_uint_to_uint {
|
||||
($SrcT:ident : $( $(#[$cfg:meta])* fn $method:ident -> $DstT:ident ; )*) => {$(
|
||||
#[inline]
|
||||
$(#[$cfg])*
|
||||
fn $method(&self) -> $DstT {
|
||||
let max = $DstT::MAX as $SrcT;
|
||||
if size_of::<$SrcT>() <= size_of::<$DstT>() || *self <= max {
|
||||
*self as $DstT
|
||||
} else {
|
||||
panic!("Element cannot be represented in the target type")
|
||||
}
|
||||
}
|
||||
)*}
|
||||
}
|
||||
|
||||
macro_rules! impl_to_element_uint {
|
||||
($T:ident) => {
|
||||
impl ToElement for $T {
|
||||
impl_to_element_uint_to_int! { $T:
|
||||
fn to_isize -> isize;
|
||||
fn to_i8 -> i8;
|
||||
fn to_i16 -> i16;
|
||||
fn to_i32 -> i32;
|
||||
fn to_i64 -> i64;
|
||||
fn to_i128 -> i128;
|
||||
}
|
||||
|
||||
impl_to_element_uint_to_uint! { $T:
|
||||
fn to_usize -> usize;
|
||||
fn to_u8 -> u8;
|
||||
fn to_u16 -> u16;
|
||||
fn to_u32 -> u32;
|
||||
fn to_u64 -> u64;
|
||||
fn to_u128 -> u128;
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn to_f32(&self) -> f32 {
|
||||
*self as f32
|
||||
}
|
||||
#[inline]
|
||||
fn to_f64(&self) -> f64 {
|
||||
*self as f64
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_to_element_uint!(usize);
|
||||
impl_to_element_uint!(u8);
|
||||
impl_to_element_uint!(u16);
|
||||
impl_to_element_uint!(u32);
|
||||
impl_to_element_uint!(u64);
|
||||
impl_to_element_uint!(u128);
|
||||
|
||||
macro_rules! impl_to_element_float_to_float {
|
||||
($SrcT:ident : $( fn $method:ident -> $DstT:ident ; )*) => {$(
|
||||
#[inline]
|
||||
fn $method(&self) -> $DstT {
|
||||
// We can safely cast all values, whether NaN, +-inf, or finite.
|
||||
// Finite values that are reducing size may saturate to +-inf.
|
||||
*self as $DstT
|
||||
}
|
||||
)*}
|
||||
}
|
||||
|
||||
macro_rules! float_to_int_unchecked {
|
||||
// SAFETY: Must not be NaN or infinite; must be representable as the integer after truncating.
|
||||
// We already checked that the float is in the exclusive range `(MIN-1, MAX+1)`.
|
||||
($float:expr => $int:ty) => {
|
||||
unsafe { $float.to_int_unchecked::<$int>() }
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! impl_to_element_float_to_signed_int {
|
||||
($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $i:ident ; )*) => {$(
|
||||
#[inline]
|
||||
$(#[$cfg])*
|
||||
fn $method(&self) -> $i {
|
||||
// Float as int truncates toward zero, so we want to allow values
|
||||
// in the exclusive range `(MIN-1, MAX+1)`.
|
||||
if size_of::<$f>() > size_of::<$i>() {
|
||||
// With a larger size, we can represent the range exactly.
|
||||
const MIN_M1: $f = $i::MIN as $f - 1.0;
|
||||
const MAX_P1: $f = $i::MAX as $f + 1.0;
|
||||
if *self > MIN_M1 && *self < MAX_P1 {
|
||||
return float_to_int_unchecked!(*self => $i);
|
||||
}
|
||||
} else {
|
||||
// We can't represent `MIN-1` exactly, but there's no fractional part
|
||||
// at this magnitude, so we can just use a `MIN` inclusive boundary.
|
||||
const MIN: $f = $i::MIN as $f;
|
||||
// We can't represent `MAX` exactly, but it will round up to exactly
|
||||
// `MAX+1` (a power of two) when we cast it.
|
||||
const MAX_P1: $f = $i::MAX as $f;
|
||||
if *self >= MIN && *self < MAX_P1 {
|
||||
return float_to_int_unchecked!(*self => $i);
|
||||
}
|
||||
}
|
||||
panic!("Float cannot be represented in the target signed int type")
|
||||
}
|
||||
)*}
|
||||
}
|
||||
|
||||
macro_rules! impl_to_element_float_to_unsigned_int {
|
||||
($f:ident : $( $(#[$cfg:meta])* fn $method:ident -> $u:ident ; )*) => {$(
|
||||
#[inline]
|
||||
$(#[$cfg])*
|
||||
fn $method(&self) -> $u {
|
||||
// Float as int truncates toward zero, so we want to allow values
|
||||
// in the exclusive range `(-1, MAX+1)`.
|
||||
if size_of::<$f>() > size_of::<$u>() {
|
||||
// With a larger size, we can represent the range exactly.
|
||||
const MAX_P1: $f = $u::MAX as $f + 1.0;
|
||||
if *self > -1.0 && *self < MAX_P1 {
|
||||
return float_to_int_unchecked!(*self => $u);
|
||||
}
|
||||
} else {
|
||||
// We can't represent `MAX` exactly, but it will round up to exactly
|
||||
// `MAX+1` (a power of two) when we cast it.
|
||||
// (`u128::MAX as f32` is infinity, but this is still ok.)
|
||||
const MAX_P1: $f = $u::MAX as $f;
|
||||
if *self > -1.0 && *self < MAX_P1 {
|
||||
return float_to_int_unchecked!(*self => $u);
|
||||
}
|
||||
}
|
||||
panic!("Float cannot be represented in the target unsigned int type")
|
||||
}
|
||||
)*}
|
||||
}
|
||||
|
||||
macro_rules! impl_to_element_float {
|
||||
($T:ident) => {
|
||||
impl ToElement for $T {
|
||||
impl_to_element_float_to_signed_int! { $T:
|
||||
fn to_isize -> isize;
|
||||
fn to_i8 -> i8;
|
||||
fn to_i16 -> i16;
|
||||
fn to_i32 -> i32;
|
||||
fn to_i64 -> i64;
|
||||
fn to_i128 -> i128;
|
||||
}
|
||||
|
||||
impl_to_element_float_to_unsigned_int! { $T:
|
||||
fn to_usize -> usize;
|
||||
fn to_u8 -> u8;
|
||||
fn to_u16 -> u16;
|
||||
fn to_u32 -> u32;
|
||||
fn to_u64 -> u64;
|
||||
fn to_u128 -> u128;
|
||||
}
|
||||
|
||||
impl_to_element_float_to_float! { $T:
|
||||
fn to_f32 -> f32;
|
||||
fn to_f64 -> f64;
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl_to_element_float!(f32);
|
||||
impl_to_element_float!(f64);
|
||||
|
||||
impl ToElement for f16 {
|
||||
#[inline]
|
||||
fn to_i64(&self) -> i64 {
|
||||
Self::to_f32(*self).to_i64()
|
||||
}
|
||||
#[inline]
|
||||
fn to_u64(&self) -> u64 {
|
||||
Self::to_f32(*self).to_u64()
|
||||
}
|
||||
#[inline]
|
||||
fn to_i8(&self) -> i8 {
|
||||
Self::to_f32(*self).to_i8()
|
||||
}
|
||||
#[inline]
|
||||
fn to_u8(&self) -> u8 {
|
||||
Self::to_f32(*self).to_u8()
|
||||
}
|
||||
#[inline]
|
||||
fn to_i16(&self) -> i16 {
|
||||
Self::to_f32(*self).to_i16()
|
||||
}
|
||||
#[inline]
|
||||
fn to_u16(&self) -> u16 {
|
||||
Self::to_f32(*self).to_u16()
|
||||
}
|
||||
#[inline]
|
||||
fn to_i32(&self) -> i32 {
|
||||
Self::to_f32(*self).to_i32()
|
||||
}
|
||||
#[inline]
|
||||
fn to_u32(&self) -> u32 {
|
||||
Self::to_f32(*self).to_u32()
|
||||
}
|
||||
#[inline]
|
||||
fn to_f32(&self) -> f32 {
|
||||
Self::to_f32(*self)
|
||||
}
|
||||
#[inline]
|
||||
fn to_f64(&self) -> f64 {
|
||||
Self::to_f64(*self)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToElement for bf16 {
|
||||
#[inline]
|
||||
fn to_i64(&self) -> i64 {
|
||||
Self::to_f32(*self).to_i64()
|
||||
}
|
||||
#[inline]
|
||||
fn to_u64(&self) -> u64 {
|
||||
Self::to_f32(*self).to_u64()
|
||||
}
|
||||
#[inline]
|
||||
fn to_i8(&self) -> i8 {
|
||||
Self::to_f32(*self).to_i8()
|
||||
}
|
||||
#[inline]
|
||||
fn to_u8(&self) -> u8 {
|
||||
Self::to_f32(*self).to_u8()
|
||||
}
|
||||
#[inline]
|
||||
fn to_i16(&self) -> i16 {
|
||||
Self::to_f32(*self).to_i16()
|
||||
}
|
||||
#[inline]
|
||||
fn to_u16(&self) -> u16 {
|
||||
Self::to_f32(*self).to_u16()
|
||||
}
|
||||
#[inline]
|
||||
fn to_i32(&self) -> i32 {
|
||||
Self::to_f32(*self).to_i32()
|
||||
}
|
||||
#[inline]
|
||||
fn to_u32(&self) -> u32 {
|
||||
Self::to_f32(*self).to_u32()
|
||||
}
|
||||
#[inline]
|
||||
fn to_f32(&self) -> f32 {
|
||||
Self::to_f32(*self)
|
||||
}
|
||||
#[inline]
|
||||
fn to_f64(&self) -> f64 {
|
||||
Self::to_f64(*self)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToElement for bool {
|
||||
#[inline]
|
||||
fn to_i64(&self) -> i64 {
|
||||
*self as i64
|
||||
}
|
||||
#[inline]
|
||||
fn to_u64(&self) -> u64 {
|
||||
*self as u64
|
||||
}
|
||||
#[inline]
|
||||
fn to_i8(&self) -> i8 {
|
||||
*self as i8
|
||||
}
|
||||
#[inline]
|
||||
fn to_u8(&self) -> u8 {
|
||||
*self as u8
|
||||
}
|
||||
#[inline]
|
||||
fn to_i16(&self) -> i16 {
|
||||
*self as i16
|
||||
}
|
||||
#[inline]
|
||||
fn to_u16(&self) -> u16 {
|
||||
*self as u16
|
||||
}
|
||||
#[inline]
|
||||
fn to_i32(&self) -> i32 {
|
||||
*self as i32
|
||||
}
|
||||
#[inline]
|
||||
fn to_u32(&self) -> u32 {
|
||||
*self as u32
|
||||
}
|
||||
#[inline]
|
||||
fn to_f32(&self) -> f32 {
|
||||
self.to_u8() as f32
|
||||
}
|
||||
#[inline]
|
||||
fn to_f64(&self) -> f64 {
|
||||
self.to_u8() as f64
|
||||
}
|
||||
}
|
||||
|
||||
mod tests {
|
||||
#[allow(unused_imports)]
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn to_element_float() {
|
||||
let f32_toolarge = 1e39f64;
|
||||
assert_eq!(f32_toolarge.to_f32(), f32::INFINITY);
|
||||
assert_eq!((-f32_toolarge).to_f32(), f32::NEG_INFINITY);
|
||||
assert_eq!((f32::MAX as f64).to_f32(), f32::MAX);
|
||||
assert_eq!((-f32::MAX as f64).to_f32(), -f32::MAX);
|
||||
assert_eq!(f64::INFINITY.to_f32(), f32::INFINITY);
|
||||
assert_eq!((f64::NEG_INFINITY).to_f32(), f32::NEG_INFINITY);
|
||||
assert!((f64::NAN).to_f32().is_nan());
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn to_element_signed_to_u8_underflow() {
|
||||
let _x = (-1i8).to_u8();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn to_element_signed_to_u16_underflow() {
|
||||
let _x = (-1i8).to_u16();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn to_element_signed_to_u32_underflow() {
|
||||
let _x = (-1i8).to_u32();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn to_element_signed_to_u64_underflow() {
|
||||
let _x = (-1i8).to_u64();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn to_element_signed_to_u128_underflow() {
|
||||
let _x = (-1i8).to_u128();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn to_element_signed_to_usize_underflow() {
|
||||
let _x = (-1i8).to_usize();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn to_element_unsigned_to_u8_overflow() {
|
||||
let _x = 256.to_u8();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn to_element_unsigned_to_u16_overflow() {
|
||||
let _x = 65_536.to_u16();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn to_element_unsigned_to_u32_overflow() {
|
||||
let _x = 4_294_967_296u64.to_u32();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn to_element_unsigned_to_u64_overflow() {
|
||||
let _x = 18_446_744_073_709_551_616u128.to_u64();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_element_int_to_float() {
|
||||
assert_eq!((-1).to_f32(), -1.0);
|
||||
assert_eq!((-1).to_f64(), -1.0);
|
||||
assert_eq!(255.to_f32(), 255.0);
|
||||
assert_eq!(65_535.to_f64(), 65_535.0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn to_element_float_to_int() {
|
||||
assert_eq!((-1.0).to_i8(), -1);
|
||||
assert_eq!(1.0.to_u8(), 1);
|
||||
assert_eq!(1.8.to_u16(), 1);
|
||||
assert_eq!(123.456.to_u32(), 123);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
mod base;
|
||||
|
||||
/// Tensor element casting.
|
||||
pub mod cast;
|
||||
|
||||
pub use base::*;
|
|
@ -1,13 +1,13 @@
|
|||
use super::cat::cat_with_slice_assign;
|
||||
use super::repeat::repeat_with_slice_assign;
|
||||
use super::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
|
||||
use crate::tensor::cast::ToElement;
|
||||
use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion, Int};
|
||||
use crate::{cartesian_grid, Tensor};
|
||||
use crate::{tensor::api::chunk, tensor::api::narrow};
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::reader::Reader;
|
||||
use core::ops::Range;
|
||||
use num_traits::ToPrimitive;
|
||||
|
||||
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
|
||||
use crate::{argsort, sort, sort_with_indices};
|
||||
|
@ -536,10 +536,7 @@ pub trait IntTensorOps<B: Backend> {
|
|||
///
|
||||
/// The elements of `lhs` raised to the value of `rhs`.
|
||||
fn int_powi_scalar<const D: usize>(lhs: IntTensor<B, D>, rhs: IntElem<B>) -> IntTensor<B, D> {
|
||||
B::float_into_int(B::float_powf_scalar(
|
||||
B::int_into_float(lhs),
|
||||
rhs.to_f32().unwrap(),
|
||||
))
|
||||
B::float_into_int(B::float_powf_scalar(B::int_into_float(lhs), rhs.to_f32()))
|
||||
}
|
||||
|
||||
/// Element-wise power with a floatTensor.
|
||||
|
|
|
@ -2,13 +2,13 @@ use super::cat::cat_with_slice_assign;
|
|||
use super::repeat::repeat_with_slice_assign;
|
||||
use super::{BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntElem, IntTensor};
|
||||
use crate::backend::BackendBridge;
|
||||
use crate::tensor::cast::ToElement;
|
||||
use crate::Tensor;
|
||||
use crate::{backend::Backend, tensor::Shape, Data, Distribution, ElementConversion, Float};
|
||||
use crate::{tensor::api::chunk, tensor::api::narrow};
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::reader::Reader;
|
||||
use core::ops::Range;
|
||||
use num_traits::ToPrimitive;
|
||||
|
||||
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
|
||||
use crate::{argsort, sort, sort_with_indices};
|
||||
|
@ -1005,7 +1005,7 @@ pub trait FloatTensorOps<B: Backend> {
|
|||
lhs: FloatTensor<B, D>,
|
||||
rhs: IntElem<B>,
|
||||
) -> FloatTensor<B, D> {
|
||||
Self::float_powf_scalar(lhs, rhs.to_f32().unwrap())
|
||||
Self::float_powf_scalar(lhs, rhs.to_f32())
|
||||
}
|
||||
|
||||
/// Returns a new tensor with values raised to the power of float `value`.
|
||||
|
|
Loading…
Reference in New Issue