From d3834785a364ed3d9519873df5c8523635147d83 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Wed, 30 Oct 2024 08:28:21 -0400 Subject: [PATCH] Perf/ndarray data (#2439) * Improve ndarray data manip * Remove print * Perf: TensorData * Cleanup --- crates/burn-ndarray/src/tensor.rs | 66 ++++++++++- crates/burn-tensor/src/tensor/data.rs | 111 ++++++++++++++++++ crates/burn-tensor/src/tensor/element/base.rs | 21 ++++ 3 files changed, 194 insertions(+), 4 deletions(-) diff --git a/crates/burn-ndarray/src/tensor.rs b/crates/burn-ndarray/src/tensor.rs index 22141d14d..f9a4acd2e 100644 --- a/crates/burn-ndarray/src/tensor.rs +++ b/crates/burn-ndarray/src/tensor.rs @@ -34,9 +34,58 @@ mod utils { E: FloatNdArrayElement, { let shape = self.shape(); - let values = self.array.into_iter().collect(); - TensorData::new(values, shape) + let vec = if self.is_contiguous() { + match self.array.try_into_owned_nocopy() { + Ok(owned) => { + let (vec, _offset) = owned.into_raw_vec_and_offset(); + vec + } + Err(array) => array.into_iter().collect(), + } + } else { + self.array.into_iter().collect() + }; + + TensorData::new(vec, shape) + } + + pub(crate) fn is_contiguous(&self) -> bool { + let shape = self.array.shape(); + let strides = self.array.strides(); + + if shape.is_empty() { + return true; + } + + if shape.len() == 1 { + return strides[0] == 1; + } + + let mut prev_stride = 1; + let mut current_num_elems_shape = 1; + + for (i, (stride, shape)) in strides.iter().zip(shape).rev().enumerate() { + let stride = if *stride <= 0 { + return false; + } else { + *stride as usize + }; + if i > 0 { + if current_num_elems_shape != stride { + return false; + } + + if prev_stride >= stride { + return false; + } + } + + current_num_elems_shape *= shape; + prev_stride = stride; + } + + true } } } @@ -104,8 +153,17 @@ where /// Create a new [ndarray tensor](NdArrayTensor) from [data](TensorData). 
pub fn from_data(data: TensorData) -> NdArrayTensor<E> { let shape: Shape = data.shape.clone().into(); - let to_array = |data: TensorData| Array::from_iter(data.iter()).into_shared(); - let array = to_array(data); + let into_array = |data: TensorData| match data.into_vec::<E>() { + Ok(vec) => Array::from_vec(vec).into_shared(), + Err(err) => panic!("Data should have the same element type as the tensor {err:?}"), + }; + let to_array = |data: TensorData| Array::from_iter(data.iter::<E>()).into_shared(); + + let array = if data.dtype == E::dtype() { + into_array(data) + } else { + to_array(data) + }; let ndims = shape.num_dims(); reshape!( diff --git a/crates/burn-tensor/src/tensor/data.rs b/crates/burn-tensor/src/tensor/data.rs index cf0233d33..36df4163e 100644 --- a/crates/burn-tensor/src/tensor/data.rs +++ b/crates/burn-tensor/src/tensor/data.rs @@ -4,6 +4,7 @@ use alloc::boxed::Box; use alloc::format; use alloc::string::String; use alloc::vec::Vec; +use bytemuck::AnyBitPattern; use half::{bf16, f16}; use crate::{ @@ -128,6 +129,31 @@ impl TensorData { Ok(self.as_slice()?.to_vec()) } + /// Returns the tensor data as a vector of scalar values. + pub fn into_vec<E: Element>(mut self) -> Result<Vec<E>, DataError> { + if E::dtype() != self.dtype { + return Err(DataError::TypeMismatch(format!( + "Invalid target element type (expected {:?}, got {:?})", + self.dtype, + E::dtype() + ))); + } + + let capacity_bytes = self.bytes.capacity(); + let length_bytes = self.bytes.len(); + let size_elem = core::mem::size_of::<E>(); + + let capacity = capacity_bytes / size_elem; + let length = length_bytes / size_elem; + + unsafe { + let ptr = self.bytes.as_mut_ptr(); + core::mem::forget(self.bytes); + + Ok(Vec::from_raw_parts(ptr.cast::<E>(), length, capacity)) + } + } + /// Returns an iterator over the values of the tensor data. 
pub fn iter<E: Element>(&self) -> Box<dyn Iterator<Item = E> + '_> { if E::dtype() == self.dtype { @@ -273,11 +299,47 @@ impl TensorData { pub fn convert<E: Element>(self) -> Self { if E::dtype() == self.dtype { self + } else if core::mem::size_of::<E>() == self.dtype.size() + && !matches!(self.dtype, DType::Bool | DType::QFloat(_)) + { + match self.dtype { + DType::F64 => self.convert_inplace::<f64, E>(), + DType::F32 => self.convert_inplace::<f32, E>(), + DType::F16 => self.convert_inplace::<f16, E>(), + DType::BF16 => self.convert_inplace::<bf16, E>(), + DType::I64 => self.convert_inplace::<i64, E>(), + DType::I32 => self.convert_inplace::<i32, E>(), + DType::I16 => self.convert_inplace::<i16, E>(), + DType::I8 => self.convert_inplace::<i8, E>(), + DType::U64 => self.convert_inplace::<u64, E>(), + DType::U32 => self.convert_inplace::<u32, E>(), + DType::U8 => self.convert_inplace::<u8, E>(), + DType::Bool | DType::QFloat(_) => unreachable!(), + } } else { TensorData::new(self.iter::<E>().collect(), self.shape) } } + fn convert_inplace<Current: Element + AnyBitPattern, Target: Element>(mut self) -> Self { + let step = core::mem::size_of::<Current>(); + + for offset in 0..(self.bytes.len() / step) { + let start = offset * step; + let end = start + step; + + let slice_old = &mut self.bytes[start..end]; + let val: Current = *bytemuck::from_bytes(slice_old); + let val = &val.elem::<Target>(); + let slice_new = bytemuck::bytes_of(val); + + slice_old.clone_from_slice(slice_new); + } + self.dtype = Target::dtype(); + + self + } + /// Returns the data as a slice of bytes. 
pub fn as_bytes(&self) -> &[u8] { self.bytes.as_slice() @@ -1021,6 +1083,34 @@ mod tests { use alloc::vec; use rand::{rngs::StdRng, SeedableRng}; + #[test] + fn into_vec_should_yield_same_value_as_iter() { + let shape = Shape::new([3, 5, 6]); + let data = TensorData::random::<f32, _, _>( + shape, + Distribution::Default, + &mut StdRng::from_entropy(), + ); + + let expected = data.iter::<f32>().collect::<Vec<f32>>(); + let actual = data.into_vec::<f32>().unwrap(); + + assert_eq!(expected, actual); + } + + #[test] + #[should_panic] + fn into_vec_should_assert_wrong_dtype() { + let shape = Shape::new([3, 5, 6]); + let data = TensorData::random::<f32, _, _>( + shape, + Distribution::Default, + &mut StdRng::from_entropy(), + ); + + data.into_vec::<i32>().unwrap(); + } + #[test] fn should_have_right_num_elements() { let shape = Shape::new([3, 5, 6]); @@ -1084,4 +1174,25 @@ mod tests { assert_eq!(data1.bytes.len(), 2 * factor); assert_eq!(data1.bytes.capacity(), 5 * factor); } + + #[test] + fn should_convert_bytes_correctly_inplace() { + fn test_precision<E: Element>() { + let data = TensorData::new((0..32).collect(), [32]); + for (i, val) in data + .clone() + .convert::<E>() + .into_vec::<E>() + .unwrap() + .into_iter() + .enumerate() + { + assert_eq!(i as u32, val.elem::<u32>()) + } + } + test_precision::<f32>(); + test_precision::<f64>(); + test_precision::<i32>(); + test_precision::<i64>(); + } } diff --git a/crates/burn-tensor/src/tensor/element/base.rs b/crates/burn-tensor/src/tensor/element/base.rs index f7b062d2d..bf08e6ad1 100644 --- a/crates/burn-tensor/src/tensor/element/base.rs +++ b/crates/burn-tensor/src/tensor/element/base.rs @@ -260,6 +260,27 @@ pub enum DType { } impl DType { + /// Returns the size of a type in bytes. 
+ pub const fn size(&self) -> usize { + match self { + DType::F64 => core::mem::size_of::<f64>(), + DType::F32 => core::mem::size_of::<f32>(), + DType::F16 => core::mem::size_of::<f16>(), + DType::BF16 => core::mem::size_of::<bf16>(), + DType::I64 => core::mem::size_of::<i64>(), + DType::I32 => core::mem::size_of::<i32>(), + DType::I16 => core::mem::size_of::<i16>(), + DType::I8 => core::mem::size_of::<i8>(), + DType::U64 => core::mem::size_of::<u64>(), + DType::U32 => core::mem::size_of::<u32>(), + DType::U8 => core::mem::size_of::<u8>(), + DType::Bool => core::mem::size_of::<bool>(), + DType::QFloat(strategy) => match strategy { + QuantizationStrategy::PerTensorAffineInt8(_) => core::mem::size_of::<i8>(), + QuantizationStrategy::PerTensorSymmetricInt8(_) => core::mem::size_of::<i8>(), + }, + } + } /// Returns true if the data type is a floating point type. pub fn is_float(&self) -> bool { matches!(self, DType::F64 | DType::F32 | DType::F16 | DType::BF16)