Perf/ndarray data (#2439)

* Improve ndarray data manip

* Remove print

* Perf: TensorData

* Cleanup
Nathaniel Simard 2024-10-30 08:28:21 -04:00 committed by GitHub
parent c0e975326c
commit d3834785a3
3 changed files with 194 additions and 4 deletions


@@ -34,9 +34,58 @@ mod utils {
E: FloatNdArrayElement,
{
let shape = self.shape();
let values = self.array.into_iter().collect();
TensorData::new(values, shape)
let vec = if self.is_contiguous() {
match self.array.try_into_owned_nocopy() {
Ok(owned) => {
let (vec, _offset) = owned.into_raw_vec_and_offset();
vec
}
Err(array) => array.into_iter().collect(),
}
} else {
self.array.into_iter().collect()
};
TensorData::new(vec, shape)
}
pub(crate) fn is_contiguous(&self) -> bool {
let shape = self.array.shape();
let strides = self.array.strides();
if shape.is_empty() {
return true;
}
if shape.len() == 1 {
return strides[0] == 1;
}
let mut prev_stride = 1;
let mut current_num_elems_shape = 1;
for (i, (stride, shape)) in strides.iter().zip(shape).rev().enumerate() {
let stride = if *stride <= 0 {
return false;
} else {
*stride as usize
};
if i > 0 {
if current_num_elems_shape != stride {
return false;
}
if prev_stride >= stride {
return false;
}
}
current_num_elems_shape *= shape;
prev_stride = stride;
}
true
}
}
}
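For reference, the rule the new is_contiguous check encodes is standard row-major contiguity: walking the dimensions from innermost to outermost, each stride has to equal the number of elements spanned by the dimensions inside it, and strides have to grow monotonically. A minimal standalone sketch of the same idea, written against plain slices rather than NdArrayTensor (illustration only, and slightly stricter than the check above in that it also requires the innermost stride to be exactly 1):

// Standalone illustration, not backend code: row-major contiguity expressed
// purely in terms of shapes and strides.
fn is_row_major_contiguous(shape: &[usize], strides: &[isize]) -> bool {
    let mut expected = 1usize;
    for (&stride, &dim) in strides.iter().zip(shape.iter()).rev() {
        if stride < 0 || stride as usize != expected {
            return false;
        }
        expected *= dim;
    }
    true
}

#[test]
fn row_major_strides() {
    // A freshly allocated [2, 3, 4] array has strides [12, 4, 1]: contiguous,
    // so to_data can hand the buffer over without iterating.
    assert!(is_row_major_contiguous(&[2, 3, 4], &[12, 4, 1]));
    // A transposed view has strides [1, 4, 12]: not contiguous, so to_data
    // falls back to the element-by-element iterator.
    assert!(!is_row_major_contiguous(&[4, 3, 2], &[1, 4, 12]));
}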
@@ -104,8 +153,17 @@ where
/// Create a new [ndarray tensor](NdArrayTensor) from [data](TensorData).
pub fn from_data(data: TensorData) -> NdArrayTensor<E> {
let shape: Shape = data.shape.clone().into();
let to_array = |data: TensorData| Array::from_iter(data.iter()).into_shared();
let array = to_array(data);
let into_array = |data: TensorData| match data.into_vec::<E>() {
Ok(vec) => Array::from_vec(vec).into_shared(),
Err(err) => panic!("Data should have the same element type as the tensor {err:?}"),
};
let to_array = |data: TensorData| Array::from_iter(data.iter::<E>()).into_shared();
let array = if data.dtype == E::dtype() {
into_array(data)
} else {
to_array(data)
};
let ndims = shape.num_dims();
reshape!(

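A hypothetical usage sketch of the new dispatch in from_data (assumes NdArrayTensor and TensorData are in scope; not part of the diff):

// dtype is F32 and E = f32: takes `into_array`, which moves the Vec returned by
// `into_vec` straight into the ndarray buffer, with no element-wise copy.
let data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], [2, 2]);
let _fast = NdArrayTensor::<f32>::from_data(data);

// dtype is I64 but E = f32: falls back to `to_array`, which still converts
// through the `iter::<E>()` adapter.
let data = TensorData::new(vec![1i64, 2, 3, 4], [2, 2]);
let _converted = NdArrayTensor::<f32>::from_data(data);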

@@ -4,6 +4,7 @@ use alloc::boxed::Box;
use alloc::format;
use alloc::string::String;
use alloc::vec::Vec;
use bytemuck::AnyBitPattern;
use half::{bf16, f16};
use crate::{
@@ -128,6 +129,31 @@ impl TensorData {
Ok(self.as_slice()?.to_vec())
}
/// Returns the tensor data as a vector of scalar values.
pub fn into_vec<E: Element>(mut self) -> Result<Vec<E>, DataError> {
if E::dtype() != self.dtype {
return Err(DataError::TypeMismatch(format!(
"Invalid target element type (expected {:?}, got {:?})",
self.dtype,
E::dtype()
)));
}
let capacity_bytes = self.bytes.capacity();
let length_bytes = self.bytes.len();
let size_elem = core::mem::size_of::<E>();
let capacity = capacity_bytes / size_elem;
let length = length_bytes / size_elem;
unsafe {
let ptr = self.bytes.as_mut_ptr();
core::mem::forget(self.bytes);
Ok(Vec::from_raw_parts(ptr.cast::<E>(), length, capacity))
}
}
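One way to see what into_vec buys (illustration only, leaning on the pointer cast above): the returned Vec reuses the allocation TensorData already owned instead of copying it.

let data = TensorData::new(vec![1.0f32, 2.0, 3.0], [3]);
let bytes_ptr = data.as_bytes().as_ptr();
let values = data.into_vec::<f32>().unwrap();
// Same allocation, now typed as f32.
assert_eq!(values.as_ptr().cast::<u8>(), bytes_ptr);
assert_eq!(values, vec![1.0, 2.0, 3.0]);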
/// Returns an iterator over the values of the tensor data.
pub fn iter<E: Element>(&self) -> Box<dyn Iterator<Item = E> + '_> {
if E::dtype() == self.dtype {
@@ -273,11 +299,47 @@ impl TensorData {
pub fn convert<E: Element>(self) -> Self {
if E::dtype() == self.dtype {
self
} else if core::mem::size_of::<E>() == self.dtype.size()
&& !matches!(self.dtype, DType::Bool | DType::QFloat(_))
{
match self.dtype {
DType::F64 => self.convert_inplace::<f64, E>(),
DType::F32 => self.convert_inplace::<f32, E>(),
DType::F16 => self.convert_inplace::<f16, E>(),
DType::BF16 => self.convert_inplace::<bf16, E>(),
DType::I64 => self.convert_inplace::<i64, E>(),
DType::I32 => self.convert_inplace::<i32, E>(),
DType::I16 => self.convert_inplace::<i16, E>(),
DType::I8 => self.convert_inplace::<i8, E>(),
DType::U64 => self.convert_inplace::<u64, E>(),
DType::U32 => self.convert_inplace::<u32, E>(),
DType::U8 => self.convert_inplace::<u8, E>(),
DType::Bool | DType::QFloat(_) => unreachable!(),
}
} else {
TensorData::new(self.iter::<E>().collect(), self.shape)
}
}
fn convert_inplace<Current: Element + AnyBitPattern, Target: Element>(mut self) -> Self {
let step = core::mem::size_of::<Current>();
for offset in 0..(self.bytes.len() / step) {
let start = offset * step;
let end = start + step;
let slice_old = &mut self.bytes[start..end];
let val: Current = *bytemuck::from_bytes(slice_old);
let val = &val.elem::<Target>();
let slice_new = bytemuck::bytes_of(val);
slice_old.clone_from_slice(slice_new);
}
self.dtype = Target::dtype();
self
}
/// Returns the data as a slice of bytes.
pub fn as_bytes(&self) -> &[u8] {
self.bytes.as_slice()
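A usage sketch of the two convert paths (illustration, assuming TensorData and DType are in scope): f32 and i32 share a 4-byte width, so that conversion is rewritten in place, while f32 to i64 changes the element size and takes the iterator fallback.

// Same width (4 bytes): `convert_inplace` rewrites each element in the existing buffer.
let data = TensorData::new(vec![1.0f32, 2.0, 3.0], [3]);
let converted = data.convert::<i32>();
assert_eq!(converted.dtype, DType::I32);
assert_eq!(converted.into_vec::<i32>().unwrap(), vec![1, 2, 3]);

// Different width (4 vs 8 bytes): falls back to `iter::<i64>().collect()`.
let data = TensorData::new(vec![1.0f32, 2.0, 3.0], [3]);
assert_eq!(data.convert::<i64>().into_vec::<i64>().unwrap(), vec![1, 2, 3]);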
@@ -1021,6 +1083,34 @@ mod tests {
use alloc::vec;
use rand::{rngs::StdRng, SeedableRng};
#[test]
fn into_vec_should_yield_same_value_as_iter() {
let shape = Shape::new([3, 5, 6]);
let data = TensorData::random::<f32, _, _>(
shape,
Distribution::Default,
&mut StdRng::from_entropy(),
);
let expected = data.iter::<f32>().collect::<Vec<f32>>();
let actual = data.into_vec::<f32>().unwrap();
assert_eq!(expected, actual);
}
#[test]
#[should_panic]
fn into_vec_should_assert_wrong_dtype() {
let shape = Shape::new([3, 5, 6]);
let data = TensorData::random::<f32, _, _>(
shape,
Distribution::Default,
&mut StdRng::from_entropy(),
);
data.into_vec::<i32>().unwrap();
}
#[test]
fn should_have_right_num_elements() {
let shape = Shape::new([3, 5, 6]);
@@ -1084,4 +1174,25 @@ mod tests {
assert_eq!(data1.bytes.len(), 2 * factor);
assert_eq!(data1.bytes.capacity(), 5 * factor);
}
#[test]
fn should_convert_bytes_correctly_inplace() {
fn test_precision<E: Element>() {
let data = TensorData::new((0..32).collect(), [32]);
for (i, val) in data
.clone()
.convert::<E>()
.into_vec::<E>()
.unwrap()
.into_iter()
.enumerate()
{
assert_eq!(i as u32, val.elem::<u32>())
}
}
test_precision::<f32>();
test_precision::<f16>();
test_precision::<i64>();
test_precision::<i32>();
}
}


@@ -260,6 +260,27 @@ pub enum DType {
}
impl DType {
/// Returns the size of a type in bytes.
pub const fn size(&self) -> usize {
match self {
DType::F64 => core::mem::size_of::<f64>(),
DType::F32 => core::mem::size_of::<f32>(),
DType::F16 => core::mem::size_of::<f16>(),
DType::BF16 => core::mem::size_of::<bf16>(),
DType::I64 => core::mem::size_of::<i64>(),
DType::I32 => core::mem::size_of::<i32>(),
DType::I16 => core::mem::size_of::<i16>(),
DType::I8 => core::mem::size_of::<i8>(),
DType::U64 => core::mem::size_of::<u64>(),
DType::U32 => core::mem::size_of::<u32>(),
DType::U8 => core::mem::size_of::<u8>(),
DType::Bool => core::mem::size_of::<bool>(),
DType::QFloat(strategy) => match strategy {
QuantizationStrategy::PerTensorAffineInt8(_) => core::mem::size_of::<u8>(),
QuantizationStrategy::PerTensorSymmetricInt8(_) => core::mem::size_of::<u8>(),
},
}
}
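size() is per element, which is exactly what the same-width check in convert relies on; a tensor's byte length is its element count times this value. A small sketch:

assert_eq!(DType::F32.size(), 4);
assert_eq!(DType::BF16.size(), 2);
// A [3, 5, 6] F32 tensor occupies 3 * 5 * 6 * DType::F32.size() = 360 bytes.
assert_eq!(3 * 5 * 6 * DType::F32.size(), 360);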
/// Returns true if the data type is a floating point type.
pub fn is_float(&self) -> bool {
matches!(self, DType::F64 | DType::F32 | DType::F16 | DType::BF16)