mirror of https://github.com/tracel-ai/burn.git
Perf/ndarray data (#2439)
* Improve ndarray data manip * Remove print * Perf: TensorData * Cleanup
This commit is contained in:
parent
c0e975326c
commit
d3834785a3
|
@ -34,9 +34,58 @@ mod utils {
|
||||||
E: FloatNdArrayElement,
|
E: FloatNdArrayElement,
|
||||||
{
|
{
|
||||||
let shape = self.shape();
|
let shape = self.shape();
|
||||||
let values = self.array.into_iter().collect();
|
|
||||||
|
|
||||||
TensorData::new(values, shape)
|
let vec = if self.is_contiguous() {
|
||||||
|
match self.array.try_into_owned_nocopy() {
|
||||||
|
Ok(owned) => {
|
||||||
|
let (vec, _offset) = owned.into_raw_vec_and_offset();
|
||||||
|
vec
|
||||||
|
}
|
||||||
|
Err(array) => array.into_iter().collect(),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
self.array.into_iter().collect()
|
||||||
|
};
|
||||||
|
|
||||||
|
TensorData::new(vec, shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub(crate) fn is_contiguous(&self) -> bool {
|
||||||
|
let shape = self.array.shape();
|
||||||
|
let strides = self.array.strides();
|
||||||
|
|
||||||
|
if shape.is_empty() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
if shape.len() == 1 {
|
||||||
|
return strides[0] == 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut prev_stride = 1;
|
||||||
|
let mut current_num_elems_shape = 1;
|
||||||
|
|
||||||
|
for (i, (stride, shape)) in strides.iter().zip(shape).rev().enumerate() {
|
||||||
|
let stride = if *stride <= 0 {
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
*stride as usize
|
||||||
|
};
|
||||||
|
if i > 0 {
|
||||||
|
if current_num_elems_shape != stride {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if prev_stride >= stride {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
current_num_elems_shape *= shape;
|
||||||
|
prev_stride = stride;
|
||||||
|
}
|
||||||
|
|
||||||
|
true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -104,8 +153,17 @@ where
|
||||||
/// Create a new [ndarray tensor](NdArrayTensor) from [data](TensorData).
|
/// Create a new [ndarray tensor](NdArrayTensor) from [data](TensorData).
|
||||||
pub fn from_data(data: TensorData) -> NdArrayTensor<E> {
|
pub fn from_data(data: TensorData) -> NdArrayTensor<E> {
|
||||||
let shape: Shape = data.shape.clone().into();
|
let shape: Shape = data.shape.clone().into();
|
||||||
let to_array = |data: TensorData| Array::from_iter(data.iter()).into_shared();
|
let into_array = |data: TensorData| match data.into_vec::<E>() {
|
||||||
let array = to_array(data);
|
Ok(vec) => Array::from_vec(vec).into_shared(),
|
||||||
|
Err(err) => panic!("Data should have the same element type as the tensor {err:?}"),
|
||||||
|
};
|
||||||
|
let to_array = |data: TensorData| Array::from_iter(data.iter::<E>()).into_shared();
|
||||||
|
|
||||||
|
let array = if data.dtype == E::dtype() {
|
||||||
|
into_array(data)
|
||||||
|
} else {
|
||||||
|
to_array(data)
|
||||||
|
};
|
||||||
let ndims = shape.num_dims();
|
let ndims = shape.num_dims();
|
||||||
|
|
||||||
reshape!(
|
reshape!(
|
||||||
|
|
|
@ -4,6 +4,7 @@ use alloc::boxed::Box;
|
||||||
use alloc::format;
|
use alloc::format;
|
||||||
use alloc::string::String;
|
use alloc::string::String;
|
||||||
use alloc::vec::Vec;
|
use alloc::vec::Vec;
|
||||||
|
use bytemuck::AnyBitPattern;
|
||||||
use half::{bf16, f16};
|
use half::{bf16, f16};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -128,6 +129,31 @@ impl TensorData {
|
||||||
Ok(self.as_slice()?.to_vec())
|
Ok(self.as_slice()?.to_vec())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Returns the tensor data as a vector of scalar values.
|
||||||
|
pub fn into_vec<E: Element>(mut self) -> Result<Vec<E>, DataError> {
|
||||||
|
if E::dtype() != self.dtype {
|
||||||
|
return Err(DataError::TypeMismatch(format!(
|
||||||
|
"Invalid target element type (expected {:?}, got {:?})",
|
||||||
|
self.dtype,
|
||||||
|
E::dtype()
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
let capacity_bytes = self.bytes.capacity();
|
||||||
|
let length_bytes = self.bytes.len();
|
||||||
|
let size_elem = core::mem::size_of::<E>();
|
||||||
|
|
||||||
|
let capacity = capacity_bytes / size_elem;
|
||||||
|
let length = length_bytes / size_elem;
|
||||||
|
|
||||||
|
unsafe {
|
||||||
|
let ptr = self.bytes.as_mut_ptr();
|
||||||
|
core::mem::forget(self.bytes);
|
||||||
|
|
||||||
|
Ok(Vec::from_raw_parts(ptr.cast::<E>(), length, capacity))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns an iterator over the values of the tensor data.
|
/// Returns an iterator over the values of the tensor data.
|
||||||
pub fn iter<E: Element>(&self) -> Box<dyn Iterator<Item = E> + '_> {
|
pub fn iter<E: Element>(&self) -> Box<dyn Iterator<Item = E> + '_> {
|
||||||
if E::dtype() == self.dtype {
|
if E::dtype() == self.dtype {
|
||||||
|
@ -273,11 +299,47 @@ impl TensorData {
|
||||||
pub fn convert<E: Element>(self) -> Self {
|
pub fn convert<E: Element>(self) -> Self {
|
||||||
if E::dtype() == self.dtype {
|
if E::dtype() == self.dtype {
|
||||||
self
|
self
|
||||||
|
} else if core::mem::size_of::<E>() == self.dtype.size()
|
||||||
|
&& !matches!(self.dtype, DType::Bool | DType::QFloat(_))
|
||||||
|
{
|
||||||
|
match self.dtype {
|
||||||
|
DType::F64 => self.convert_inplace::<f64, E>(),
|
||||||
|
DType::F32 => self.convert_inplace::<f32, E>(),
|
||||||
|
DType::F16 => self.convert_inplace::<f16, E>(),
|
||||||
|
DType::BF16 => self.convert_inplace::<bf16, E>(),
|
||||||
|
DType::I64 => self.convert_inplace::<i64, E>(),
|
||||||
|
DType::I32 => self.convert_inplace::<i32, E>(),
|
||||||
|
DType::I16 => self.convert_inplace::<i16, E>(),
|
||||||
|
DType::I8 => self.convert_inplace::<i8, E>(),
|
||||||
|
DType::U64 => self.convert_inplace::<u64, E>(),
|
||||||
|
DType::U32 => self.convert_inplace::<u32, E>(),
|
||||||
|
DType::U8 => self.convert_inplace::<u8, E>(),
|
||||||
|
DType::Bool | DType::QFloat(_) => unreachable!(),
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
TensorData::new(self.iter::<E>().collect(), self.shape)
|
TensorData::new(self.iter::<E>().collect(), self.shape)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn convert_inplace<Current: Element + AnyBitPattern, Target: Element>(mut self) -> Self {
|
||||||
|
let step = core::mem::size_of::<Current>();
|
||||||
|
|
||||||
|
for offset in 0..(self.bytes.len() / step) {
|
||||||
|
let start = offset * step;
|
||||||
|
let end = start + step;
|
||||||
|
|
||||||
|
let slice_old = &mut self.bytes[start..end];
|
||||||
|
let val: Current = *bytemuck::from_bytes(slice_old);
|
||||||
|
let val = &val.elem::<Target>();
|
||||||
|
let slice_new = bytemuck::bytes_of(val);
|
||||||
|
|
||||||
|
slice_old.clone_from_slice(slice_new);
|
||||||
|
}
|
||||||
|
self.dtype = Target::dtype();
|
||||||
|
|
||||||
|
self
|
||||||
|
}
|
||||||
|
|
||||||
/// Returns the data as a slice of bytes.
|
/// Returns the data as a slice of bytes.
|
||||||
pub fn as_bytes(&self) -> &[u8] {
|
pub fn as_bytes(&self) -> &[u8] {
|
||||||
self.bytes.as_slice()
|
self.bytes.as_slice()
|
||||||
|
@ -1021,6 +1083,34 @@ mod tests {
|
||||||
use alloc::vec;
|
use alloc::vec;
|
||||||
use rand::{rngs::StdRng, SeedableRng};
|
use rand::{rngs::StdRng, SeedableRng};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn into_vec_should_yield_same_value_as_iter() {
|
||||||
|
let shape = Shape::new([3, 5, 6]);
|
||||||
|
let data = TensorData::random::<f32, _, _>(
|
||||||
|
shape,
|
||||||
|
Distribution::Default,
|
||||||
|
&mut StdRng::from_entropy(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let expected = data.iter::<f32>().collect::<Vec<f32>>();
|
||||||
|
let actual = data.into_vec::<f32>().unwrap();
|
||||||
|
|
||||||
|
assert_eq!(expected, actual);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
#[should_panic]
|
||||||
|
fn into_vec_should_assert_wrong_dtype() {
|
||||||
|
let shape = Shape::new([3, 5, 6]);
|
||||||
|
let data = TensorData::random::<f32, _, _>(
|
||||||
|
shape,
|
||||||
|
Distribution::Default,
|
||||||
|
&mut StdRng::from_entropy(),
|
||||||
|
);
|
||||||
|
|
||||||
|
data.into_vec::<i32>().unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn should_have_right_num_elements() {
|
fn should_have_right_num_elements() {
|
||||||
let shape = Shape::new([3, 5, 6]);
|
let shape = Shape::new([3, 5, 6]);
|
||||||
|
@ -1084,4 +1174,25 @@ mod tests {
|
||||||
assert_eq!(data1.bytes.len(), 2 * factor);
|
assert_eq!(data1.bytes.len(), 2 * factor);
|
||||||
assert_eq!(data1.bytes.capacity(), 5 * factor);
|
assert_eq!(data1.bytes.capacity(), 5 * factor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn should_convert_bytes_correctly_inplace() {
|
||||||
|
fn test_precision<E: Element>() {
|
||||||
|
let data = TensorData::new((0..32).collect(), [32]);
|
||||||
|
for (i, val) in data
|
||||||
|
.clone()
|
||||||
|
.convert::<E>()
|
||||||
|
.into_vec::<E>()
|
||||||
|
.unwrap()
|
||||||
|
.into_iter()
|
||||||
|
.enumerate()
|
||||||
|
{
|
||||||
|
assert_eq!(i as u32, val.elem::<u32>())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
test_precision::<f32>();
|
||||||
|
test_precision::<f16>();
|
||||||
|
test_precision::<i64>();
|
||||||
|
test_precision::<i32>();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -260,6 +260,27 @@ pub enum DType {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DType {
|
impl DType {
|
||||||
|
/// Returns the size of a type in bytes.
|
||||||
|
pub const fn size(&self) -> usize {
|
||||||
|
match self {
|
||||||
|
DType::F64 => core::mem::size_of::<f64>(),
|
||||||
|
DType::F32 => core::mem::size_of::<f32>(),
|
||||||
|
DType::F16 => core::mem::size_of::<f16>(),
|
||||||
|
DType::BF16 => core::mem::size_of::<bf16>(),
|
||||||
|
DType::I64 => core::mem::size_of::<i64>(),
|
||||||
|
DType::I32 => core::mem::size_of::<i32>(),
|
||||||
|
DType::I16 => core::mem::size_of::<i16>(),
|
||||||
|
DType::I8 => core::mem::size_of::<i8>(),
|
||||||
|
DType::U64 => core::mem::size_of::<u64>(),
|
||||||
|
DType::U32 => core::mem::size_of::<u32>(),
|
||||||
|
DType::U8 => core::mem::size_of::<u8>(),
|
||||||
|
DType::Bool => core::mem::size_of::<bool>(),
|
||||||
|
DType::QFloat(strategy) => match strategy {
|
||||||
|
QuantizationStrategy::PerTensorAffineInt8(_) => core::mem::size_of::<u8>(),
|
||||||
|
QuantizationStrategy::PerTensorSymmetricInt8(_) => core::mem::size_of::<u8>(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
/// Returns true if the data type is a floating point type.
|
/// Returns true if the data type is a floating point type.
|
||||||
pub fn is_float(&self) -> bool {
|
pub fn is_float(&self) -> bool {
|
||||||
matches!(self, DType::F64 | DType::F32 | DType::F16 | DType::BF16)
|
matches!(self, DType::F64 | DType::F32 | DType::F16 | DType::BF16)
|
||||||
|
|
Loading…
Reference in New Issue