diff --git a/crates/burn-compute/src/memory_management/dynamic.rs b/crates/burn-compute/src/memory_management/dynamic.rs
index 1cba7c596..7d3beb6b2 100644
--- a/crates/burn-compute/src/memory_management/dynamic.rs
+++ b/crates/burn-compute/src/memory_management/dynamic.rs
@@ -144,6 +144,7 @@ impl Slice {
 
 const MIN_SIZE_NEEDED_TO_OFFSET: usize = 16;
 const BUFFER_ALIGNMENT: usize = 32;
+const CHUNK_ROUNDING: usize = 2 * 1024 * 1024;
 
 /// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks.
 pub struct DynamicMemoryManagement<Storage> {
@@ -216,6 +217,13 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> for DynamicMemoryManagement<Storage> {
     ///
     /// Also cleans up, merging free slices together if permitted by the merging strategy
     fn reserve(&mut self, size: usize) -> Self::Handle {
+        // log::info!("Number of chunks {}", self.chunks.len());
+        // log::info!("Number of slices {}", self.slices.len());
+        // log::info!(
+        //     "used memory {} , total memory {}",
+        //     self.calculate_used_memory(),
+        //     self.calculate_total_memory()
+        // );
         let handle = self.reserve_algorithm(size);
 
         if self.merging_strategy.should_perform_defragmentation() {
@@ -226,18 +234,50 @@
     }
 
     fn alloc(&mut self, size: usize) -> Self::Handle {
-        let handle_chunk = self.create_chunk(size);
-        let chunk_id = *handle_chunk.id();
+        log::info!(
+            "biggest free slice {} MB requested allocation {} MB",
+            (self.find_biggest_free_slice() as f64) / (1024.0 * 1024.0),
+            (size as f64) / (1024.0 * 1024.0)
+        );
+        if size < MIN_SIZE_NEEDED_TO_OFFSET {
+            let handle_chunk = self.create_chunk(size);
+            let chunk_id = *handle_chunk.id();
 
-        let slice = self.create_slice(0, size, handle_chunk);
-        let handle_slice = slice.handle.clone();
+            let slice = self.create_slice(0, size, handle_chunk);
+            let handle_slice = slice.handle.clone();
 
-        let chunk = self.chunks.get_mut(&chunk_id).unwrap();
-        chunk.slices.push(*handle_slice.id());
+            let chunk = self.chunks.get_mut(&chunk_id).unwrap();
+            chunk.slices.push(*handle_slice.id());
 
-        self.slices.insert(*handle_slice.id(), slice);
+            self.slices.insert(*handle_slice.id(), slice);
 
-        DynamicHandle::Slice(handle_slice)
+            return DynamicHandle::Slice(handle_slice);
+        } else {
+            let chunk_size = Self::round_up_chunk(size);
+
+            let handle_chunk = self.create_chunk(chunk_size);
+            let chunk_id = *handle_chunk.id();
+
+            let slice = self.create_slice(0, size, handle_chunk.clone());
+            let first_slice_size = slice.effective_size();
+            let handle_slice = slice.handle.clone();
+            let chunk = self.chunks.get_mut(&chunk_id).unwrap();
+            chunk.slices.push(*handle_slice.id());
+            self.slices.insert(*handle_slice.id(), slice);
+
+            let second_slice_size = chunk_size - first_slice_size;
+            assert_eq!(second_slice_size % BUFFER_ALIGNMENT, 0);
+            if second_slice_size >= BUFFER_ALIGNMENT {
+                let second_slice =
+                    self.create_slice(first_slice_size, second_slice_size, handle_chunk);
+                let handle_slice = second_slice.handle.clone();
+                let chunk = self.chunks.get_mut(&chunk_id).unwrap();
+                chunk.slices.push(*handle_slice.id());
+                self.slices.insert(*handle_slice.id(), second_slice);
+            }
+
+            DynamicHandle::Slice(handle_slice)
+        }
     }
 
     fn dealloc(&mut self, binding: Self::Binding) {
@@ -480,8 +520,11 @@ impl<Storage: ComputeStorage> DynamicMemoryManagement<Storage> {
                 slices_ids.push(*slice_id);
                 num_merge += 1;
                 offset += slice.effective_size();
-                continue;
-            } else if num_merge > 1 {
+                if i < chunk.slices.len() - 1 {
+                    continue;
+                }
+            }
+            if num_merge > 1 {
                 let mut empty = Vec::new();
                 core::mem::swap(&mut slices_ids, &mut empty);
                 let merging = Merging {
@@ -535,6 +578,35 @@ impl<Storage: ComputeStorage> DynamicMemoryManagement<Storage> {
         core::mem::swap(&mut chunk.slices, &mut slices_updated);
     }
 
+    fn calculate_used_memory(&self) -> u64 {
+        let mut used_memory = 0;
+        for (.., slice) in self.slices.iter() {
+            if !slice.handle.is_free() {
+                used_memory += slice.storage.size() as u64;
+            }
+        }
+        used_memory
+    }
+
+    fn calculate_total_memory(&self) -> u64 {
+        let mut total_memory = 0;
+        for (.., chunk) in self.chunks.iter() {
+            total_memory += chunk.storage.size() as u64;
+        }
+        total_memory
+    }
+
+    fn find_biggest_free_slice(&self) -> usize {
+        let mut current_biggest_slice: usize = 0;
+        for (.., slice) in self.slices.iter() {
+            if slice.handle.is_free() {
+                current_biggest_slice =
+                    std::cmp::max(current_biggest_slice, slice.effective_size());
+            }
+        }
+        current_biggest_slice
+    }
+
     // Merge all contiguous free slices together; assumes that slices are in sorted order.
     fn defragmentation(&mut self) {
         let mut chunk_to_merged_slice: HashMap<ChunkId, Vec<Merging>> = HashMap::new();
@@ -555,13 +627,17 @@ impl<Storage: ComputeStorage> DynamicMemoryManagement<Storage> {
             0
         }
     }
+
+    fn round_up_chunk(size: usize) -> usize {
+        ((size + (CHUNK_ROUNDING - 1)) / CHUNK_ROUNDING) * CHUNK_ROUNDING
+    }
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
     use crate::{
-        memory_management::{MemoryHandle, MemoryManagement},
+        memory_management::{self, MemoryHandle, MemoryManagement},
         storage::BytesStorage,
     };
@@ -599,6 +675,72 @@ mod tests {
         assert!(!x.can_mut())
     }
 
+    #[test]
+    fn when_size_not_big_enough_to_offset_should_alloc_not_rounded_chunk() {
+        let alloc_size = MIN_SIZE_NEEDED_TO_OFFSET - 1;
+        let mut memory_management = DynamicMemoryManagement::new(
+            BytesStorage::default(),
+            MergingStrategy::Never,
+            SliceStrategy::Never,
+        );
+
+        memory_management.alloc(alloc_size);
+
+        assert_eq!(memory_management.chunks.len(), 1);
+        assert_eq!(memory_management.slices.len(), 1);
+        for (.., slice) in memory_management.slices.iter() {
+            assert_eq!(slice.storage.size(), alloc_size);
+        }
+        for (.., chunk) in memory_management.chunks.iter() {
+            assert_eq!(chunk.storage.size(), BUFFER_ALIGNMENT);
+        }
+    }
+
+    #[test]
+    fn when_alloc_is_too_big_for_second_slice_should_only_have_one_slice() {
+        let alloc_size = CHUNK_ROUNDING - 2;
+        let mut memory_management = DynamicMemoryManagement::new(
+            BytesStorage::default(),
+            MergingStrategy::Never,
+            SliceStrategy::Never,
+        );
+
+        memory_management.alloc(alloc_size);
+        assert_eq!(memory_management.chunks.len(), 1);
+        assert_eq!(memory_management.slices.len(), 1);
+        for (.., slice) in memory_management.slices.iter() {
+            assert_eq!(slice.storage.size(), alloc_size);
+        }
+        for (.., chunk) in memory_management.chunks.iter() {
+            assert_eq!(chunk.storage.size(), CHUNK_ROUNDING);
+        }
+    }
+
+    #[test]
+    fn when_alloc_is_small_enough_for_second_slice_should_have_two_slice() {
+        let alloc_size = MIN_SIZE_NEEDED_TO_OFFSET;
+        let padded_first_slice_size =
+            DynamicMemoryManagement::<BytesStorage>::calculate_padding(alloc_size) + alloc_size;
+        let mut memory_management = DynamicMemoryManagement::new(
+            BytesStorage::default(),
+            MergingStrategy::Never,
+            SliceStrategy::Never,
+        );
+
+        memory_management.alloc(alloc_size);
+        assert_eq!(memory_management.chunks.len(), 1);
+        assert_eq!(memory_management.slices.len(), 2);
+        for (.., slice) in memory_management.slices.iter() {
+            assert!(
+                slice.storage.size() == alloc_size
+                    || slice.storage.size() == CHUNK_ROUNDING - padded_first_slice_size
+            );
+        }
+        for (.., chunk) in memory_management.chunks.iter() {
+            assert_eq!(chunk.storage.size(), CHUNK_ROUNDING);
+        }
+    }
+
     #[test]
     fn when_non_empty_chunk_exists_and_other_one_created_there_should_be_two() {
         let mut memory_management = DynamicMemoryManagement::new(
@@ -618,21 +760,22 @@
         let mut memory_management = DynamicMemoryManagement::new(
             BytesStorage::default(),
             MergingStrategy::Never,
-            SliceStrategy::Ratio(0.2),
+            SliceStrategy::MinimumSize(0),
         );
-        let big_slice_size = 32 * 3;
-        let small_slice_size = 32;
+        const NUMBER_OF_USED_SLICES: usize = 3;
+        const BIG_SLICE_SIZE: usize = CHUNK_ROUNDING - BUFFER_ALIGNMENT * NUMBER_OF_USED_SLICES;
 
-        let big_slice = memory_management.reserve(big_slice_size);
-        drop(big_slice);
-        let _small_slice_1 = memory_management.reserve(small_slice_size);
-        let _small_slice_2 = memory_management.reserve(small_slice_size);
-        let _small_slice_3 = memory_management.reserve(small_slice_size);
+        let first_slice_to_be_dropped = memory_management.reserve(BUFFER_ALIGNMENT);
+        drop(first_slice_to_be_dropped);
+        let mut slice_holder = Vec::new();
+        for _ in 0..NUMBER_OF_USED_SLICES {
+            slice_holder.push(memory_management.reserve(BUFFER_ALIGNMENT));
+        }
 
         assert_eq!(memory_management.chunks.len(), 1);
-        assert_eq!(memory_management.slices.len(), 3);
+        assert_eq!(memory_management.slices.len(), NUMBER_OF_USED_SLICES + 1);
         for (.., slice) in memory_management.slices.iter() {
-            assert_eq!(slice.storage.size(), 32);
+            assert!(slice.storage.size() == 32 || slice.storage.size() == BIG_SLICE_SIZE);
         }
     }
 
@@ -796,14 +939,15 @@ mod tests {
         let mut memory_management = DynamicMemoryManagement::new(
             BytesStorage::default(),
             MergingStrategy::Never,
-            SliceStrategy::Ratio(0.2),
+            SliceStrategy::MinimumSize(0),
         );
 
         let handle = memory_management.reserve(100);
         core::mem::drop(handle);
-        let _slice_1 = memory_management.reserve(30);
-        let _slice_2 = memory_management.reserve(30);
-        let _slice_3 = memory_management.reserve(30);
+        let _slice_1 = memory_management.reserve(BUFFER_ALIGNMENT);
+        let _slice_2 = memory_management.reserve(BUFFER_ALIGNMENT);
+        let _slice_3 = memory_management.reserve(BUFFER_ALIGNMENT);
+        memory_management.defragmentation();
 
         assert_eq!(memory_management.chunks.len(), 1);
         assert_eq!(memory_management.slices.len(), 4);
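
A self-contained sketch of the size arithmetic this patch relies on. `CHUNK_ROUNDING` and `BUFFER_ALIGNMENT` are copied from the diff; `padded` is a hypothetical stand-in for the first slice's effective size (`calculate_padding(size) + size`, assumed here to round up to `BUFFER_ALIGNMENT`), so treat this as an illustration under those assumptions rather than the crate's actual API:

```rust
const BUFFER_ALIGNMENT: usize = 32; // copied from the diff
const CHUNK_ROUNDING: usize = 2 * 1024 * 1024; // 2 MiB, copied from the diff

/// Mirrors `round_up_chunk`: round `size` up to the next multiple of CHUNK_ROUNDING.
fn round_up_chunk(size: usize) -> usize {
    ((size + (CHUNK_ROUNDING - 1)) / CHUNK_ROUNDING) * CHUNK_ROUNDING
}

/// Hypothetical stand-in for the first slice's effective size
/// (`calculate_padding(size) + size`), assumed to round up to BUFFER_ALIGNMENT.
fn padded(size: usize) -> usize {
    ((size + (BUFFER_ALIGNMENT - 1)) / BUFFER_ALIGNMENT) * BUFFER_ALIGNMENT
}

fn main() {
    for &size in &[16usize, 100, 1024 * 1024, CHUNK_ROUNDING - 2] {
        let chunk_size = round_up_chunk(size);
        let second_slice_size = chunk_size - padded(size);
        // Both terms are multiples of BUFFER_ALIGNMENT (CHUNK_ROUNDING is 65536 * 32),
        // so the leftover is too; this is what the `assert_eq!` in `alloc` checks.
        assert_eq!(second_slice_size % BUFFER_ALIGNMENT, 0);
        println!("request {size:>8} B -> chunk {chunk_size:>8} B, leftover {second_slice_size:>8} B");
    }
}
```

Under these assumptions, a request of `CHUNK_ROUNDING - 2` pads up to exactly one full chunk and leaves a zero-byte remainder, which is why `when_alloc_is_too_big_for_second_slice_should_only_have_one_slice` expects a single slice: the `second_slice_size >= BUFFER_ALIGNMENT` guard skips creating the second one.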