mirror of https://github.com/tracel-ai/burn.git
Perf/ndarray data (#2439)
* Improve ndarray data manip * Remove print * Perf: TensorData * Cleanup
This commit is contained in:
parent
c0e975326c
commit
d3834785a3
|
@ -34,9 +34,58 @@ mod utils {
|
|||
E: FloatNdArrayElement,
|
||||
{
|
||||
let shape = self.shape();
|
||||
let values = self.array.into_iter().collect();
|
||||
|
||||
TensorData::new(values, shape)
|
||||
let vec = if self.is_contiguous() {
|
||||
match self.array.try_into_owned_nocopy() {
|
||||
Ok(owned) => {
|
||||
let (vec, _offset) = owned.into_raw_vec_and_offset();
|
||||
vec
|
||||
}
|
||||
Err(array) => array.into_iter().collect(),
|
||||
}
|
||||
} else {
|
||||
self.array.into_iter().collect()
|
||||
};
|
||||
|
||||
TensorData::new(vec, shape)
|
||||
}
|
||||
|
||||
pub(crate) fn is_contiguous(&self) -> bool {
|
||||
let shape = self.array.shape();
|
||||
let strides = self.array.strides();
|
||||
|
||||
if shape.is_empty() {
|
||||
return true;
|
||||
}
|
||||
|
||||
if shape.len() == 1 {
|
||||
return strides[0] == 1;
|
||||
}
|
||||
|
||||
let mut prev_stride = 1;
|
||||
let mut current_num_elems_shape = 1;
|
||||
|
||||
for (i, (stride, shape)) in strides.iter().zip(shape).rev().enumerate() {
|
||||
let stride = if *stride <= 0 {
|
||||
return false;
|
||||
} else {
|
||||
*stride as usize
|
||||
};
|
||||
if i > 0 {
|
||||
if current_num_elems_shape != stride {
|
||||
return false;
|
||||
}
|
||||
|
||||
if prev_stride >= stride {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
current_num_elems_shape *= shape;
|
||||
prev_stride = stride;
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -104,8 +153,17 @@ where
|
|||
/// Create a new [ndarray tensor](NdArrayTensor) from [data](TensorData).
|
||||
pub fn from_data(data: TensorData) -> NdArrayTensor<E> {
|
||||
let shape: Shape = data.shape.clone().into();
|
||||
let to_array = |data: TensorData| Array::from_iter(data.iter()).into_shared();
|
||||
let array = to_array(data);
|
||||
let into_array = |data: TensorData| match data.into_vec::<E>() {
|
||||
Ok(vec) => Array::from_vec(vec).into_shared(),
|
||||
Err(err) => panic!("Data should have the same element type as the tensor {err:?}"),
|
||||
};
|
||||
let to_array = |data: TensorData| Array::from_iter(data.iter::<E>()).into_shared();
|
||||
|
||||
let array = if data.dtype == E::dtype() {
|
||||
into_array(data)
|
||||
} else {
|
||||
to_array(data)
|
||||
};
|
||||
let ndims = shape.num_dims();
|
||||
|
||||
reshape!(
|
||||
|
|
|
@ -4,6 +4,7 @@ use alloc::boxed::Box;
|
|||
use alloc::format;
|
||||
use alloc::string::String;
|
||||
use alloc::vec::Vec;
|
||||
use bytemuck::AnyBitPattern;
|
||||
use half::{bf16, f16};
|
||||
|
||||
use crate::{
|
||||
|
@ -128,6 +129,31 @@ impl TensorData {
|
|||
Ok(self.as_slice()?.to_vec())
|
||||
}
|
||||
|
||||
/// Returns the tensor data as a vector of scalar values.
|
||||
pub fn into_vec<E: Element>(mut self) -> Result<Vec<E>, DataError> {
|
||||
if E::dtype() != self.dtype {
|
||||
return Err(DataError::TypeMismatch(format!(
|
||||
"Invalid target element type (expected {:?}, got {:?})",
|
||||
self.dtype,
|
||||
E::dtype()
|
||||
)));
|
||||
}
|
||||
|
||||
let capacity_bytes = self.bytes.capacity();
|
||||
let length_bytes = self.bytes.len();
|
||||
let size_elem = core::mem::size_of::<E>();
|
||||
|
||||
let capacity = capacity_bytes / size_elem;
|
||||
let length = length_bytes / size_elem;
|
||||
|
||||
unsafe {
|
||||
let ptr = self.bytes.as_mut_ptr();
|
||||
core::mem::forget(self.bytes);
|
||||
|
||||
Ok(Vec::from_raw_parts(ptr.cast::<E>(), length, capacity))
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns an iterator over the values of the tensor data.
|
||||
pub fn iter<E: Element>(&self) -> Box<dyn Iterator<Item = E> + '_> {
|
||||
if E::dtype() == self.dtype {
|
||||
|
@ -273,11 +299,47 @@ impl TensorData {
|
|||
pub fn convert<E: Element>(self) -> Self {
|
||||
if E::dtype() == self.dtype {
|
||||
self
|
||||
} else if core::mem::size_of::<E>() == self.dtype.size()
|
||||
&& !matches!(self.dtype, DType::Bool | DType::QFloat(_))
|
||||
{
|
||||
match self.dtype {
|
||||
DType::F64 => self.convert_inplace::<f64, E>(),
|
||||
DType::F32 => self.convert_inplace::<f32, E>(),
|
||||
DType::F16 => self.convert_inplace::<f16, E>(),
|
||||
DType::BF16 => self.convert_inplace::<bf16, E>(),
|
||||
DType::I64 => self.convert_inplace::<i64, E>(),
|
||||
DType::I32 => self.convert_inplace::<i32, E>(),
|
||||
DType::I16 => self.convert_inplace::<i16, E>(),
|
||||
DType::I8 => self.convert_inplace::<i8, E>(),
|
||||
DType::U64 => self.convert_inplace::<u64, E>(),
|
||||
DType::U32 => self.convert_inplace::<u32, E>(),
|
||||
DType::U8 => self.convert_inplace::<u8, E>(),
|
||||
DType::Bool | DType::QFloat(_) => unreachable!(),
|
||||
}
|
||||
} else {
|
||||
TensorData::new(self.iter::<E>().collect(), self.shape)
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_inplace<Current: Element + AnyBitPattern, Target: Element>(mut self) -> Self {
|
||||
let step = core::mem::size_of::<Current>();
|
||||
|
||||
for offset in 0..(self.bytes.len() / step) {
|
||||
let start = offset * step;
|
||||
let end = start + step;
|
||||
|
||||
let slice_old = &mut self.bytes[start..end];
|
||||
let val: Current = *bytemuck::from_bytes(slice_old);
|
||||
let val = &val.elem::<Target>();
|
||||
let slice_new = bytemuck::bytes_of(val);
|
||||
|
||||
slice_old.clone_from_slice(slice_new);
|
||||
}
|
||||
self.dtype = Target::dtype();
|
||||
|
||||
self
|
||||
}
|
||||
|
||||
/// Returns the data as a slice of bytes.
|
||||
pub fn as_bytes(&self) -> &[u8] {
|
||||
self.bytes.as_slice()
|
||||
|
@ -1021,6 +1083,34 @@ mod tests {
|
|||
use alloc::vec;
|
||||
use rand::{rngs::StdRng, SeedableRng};
|
||||
|
||||
#[test]
|
||||
fn into_vec_should_yield_same_value_as_iter() {
|
||||
let shape = Shape::new([3, 5, 6]);
|
||||
let data = TensorData::random::<f32, _, _>(
|
||||
shape,
|
||||
Distribution::Default,
|
||||
&mut StdRng::from_entropy(),
|
||||
);
|
||||
|
||||
let expected = data.iter::<f32>().collect::<Vec<f32>>();
|
||||
let actual = data.into_vec::<f32>().unwrap();
|
||||
|
||||
assert_eq!(expected, actual);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn into_vec_should_assert_wrong_dtype() {
|
||||
let shape = Shape::new([3, 5, 6]);
|
||||
let data = TensorData::random::<f32, _, _>(
|
||||
shape,
|
||||
Distribution::Default,
|
||||
&mut StdRng::from_entropy(),
|
||||
);
|
||||
|
||||
data.into_vec::<i32>().unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_have_right_num_elements() {
|
||||
let shape = Shape::new([3, 5, 6]);
|
||||
|
@ -1084,4 +1174,25 @@ mod tests {
|
|||
assert_eq!(data1.bytes.len(), 2 * factor);
|
||||
assert_eq!(data1.bytes.capacity(), 5 * factor);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn should_convert_bytes_correctly_inplace() {
|
||||
fn test_precision<E: Element>() {
|
||||
let data = TensorData::new((0..32).collect(), [32]);
|
||||
for (i, val) in data
|
||||
.clone()
|
||||
.convert::<E>()
|
||||
.into_vec::<E>()
|
||||
.unwrap()
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
{
|
||||
assert_eq!(i as u32, val.elem::<u32>())
|
||||
}
|
||||
}
|
||||
test_precision::<f32>();
|
||||
test_precision::<f16>();
|
||||
test_precision::<i64>();
|
||||
test_precision::<i32>();
|
||||
}
|
||||
}
|
||||
|
|
|
@ -260,6 +260,27 @@ pub enum DType {
|
|||
}
|
||||
|
||||
impl DType {
|
||||
/// Returns the size of a type in bytes.
|
||||
pub const fn size(&self) -> usize {
|
||||
match self {
|
||||
DType::F64 => core::mem::size_of::<f64>(),
|
||||
DType::F32 => core::mem::size_of::<f32>(),
|
||||
DType::F16 => core::mem::size_of::<f16>(),
|
||||
DType::BF16 => core::mem::size_of::<bf16>(),
|
||||
DType::I64 => core::mem::size_of::<i64>(),
|
||||
DType::I32 => core::mem::size_of::<i32>(),
|
||||
DType::I16 => core::mem::size_of::<i16>(),
|
||||
DType::I8 => core::mem::size_of::<i8>(),
|
||||
DType::U64 => core::mem::size_of::<u64>(),
|
||||
DType::U32 => core::mem::size_of::<u32>(),
|
||||
DType::U8 => core::mem::size_of::<u8>(),
|
||||
DType::Bool => core::mem::size_of::<bool>(),
|
||||
DType::QFloat(strategy) => match strategy {
|
||||
QuantizationStrategy::PerTensorAffineInt8(_) => core::mem::size_of::<u8>(),
|
||||
QuantizationStrategy::PerTensorSymmetricInt8(_) => core::mem::size_of::<u8>(),
|
||||
},
|
||||
}
|
||||
}
|
||||
/// Returns true if the data type is a floating point type.
|
||||
pub fn is_float(&self) -> bool {
|
||||
matches!(self, DType::F64 | DType::F32 | DType::F16 | DType::BF16)
|
||||
|
|
Loading…
Reference in New Issue