From f5be04f44b07f1e8665391384575deea9c4d44f1 Mon Sep 17 00:00:00 2001 From: mepatrick73 <114622680+mepatrick73@users.noreply.github.com> Date: Wed, 3 Jul 2024 08:06:05 -0400 Subject: [PATCH] Feat/fixed chunk alloc by class (#1960) --- .../src/memory_management/dynamic.rs | 29 +++- .../src/memory_management/memory_pool/base.rs | 143 +++++++++++++++--- .../src/memory_management/memory_pool/ring.rs | 1 - 3 files changed, 149 insertions(+), 24 deletions(-) diff --git a/crates/burn-compute/src/memory_management/dynamic.rs b/crates/burn-compute/src/memory_management/dynamic.rs index ea61de595..4e449c449 100644 --- a/crates/burn-compute/src/memory_management/dynamic.rs +++ b/crates/burn-compute/src/memory_management/dynamic.rs @@ -9,6 +9,7 @@ use super::MemoryManagement; /// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks. pub struct DynamicMemoryManagement { small_memory_pool: SmallMemoryPool, + small_medium_memory_pool: MemoryPool, medium_memory_pool: MemoryPool, main_memory_pool: MemoryPool, storage: Storage, @@ -19,17 +20,20 @@ impl DynamicMemoryManagement { pub fn new(storage: Storage) -> Self { let main_memory_pool = MemoryPool::new( MemoryExtensionStrategy::new_period_tick(10), - RoundingStrategy::RoundUp, - 1024 * 1024 * 1024 * 2, + RoundingStrategy::FixedAmount(1024 * 1024 * 1024), ); let medium_memory_pool = MemoryPool::new( MemoryExtensionStrategy::Never, - RoundingStrategy::None, - 1024 * 1024 * 512, + RoundingStrategy::FixedAmount(1024 * 1024 * 200), + ); + let small_medium_memory_pool = MemoryPool::new( + MemoryExtensionStrategy::Never, + RoundingStrategy::FixedAmount(1024 * 1024 * 2), ); let small_memory_pool = SmallMemoryPool::new(); Self { small_memory_pool, + small_medium_memory_pool, main_memory_pool, medium_memory_pool, storage, @@ -58,6 +62,13 @@ impl MemoryManagement for DynamicMemoryManagem return handle; } + if let Some(handle) = self + .small_medium_memory_pool + .get(&mut self.storage, &binding) + { + return handle; + } + if let Some(handle) = self.medium_memory_pool.get(&mut self.storage, &binding) { return handle; } @@ -73,7 +84,10 @@ impl MemoryManagement for DynamicMemoryManagem if size <= 32 { self.small_memory_pool .reserve(&mut self.storage, size, sync) - } else if size < 512 { + } else if size <= 2 * 1024 * 1024 { + self.small_medium_memory_pool + .reserve(&mut self.storage, size, sync) + } else if size < 200 * 1024 * 1024 { self.medium_memory_pool .reserve(&mut self.storage, size, sync) } else { @@ -84,7 +98,10 @@ impl MemoryManagement for DynamicMemoryManagem fn alloc(&mut self, size: usize, sync: Sync) -> Self::Handle { if size <= 32 { self.small_memory_pool.alloc(&mut self.storage, size, sync) - } else if size < 512 { + } else if size <= 2 * 1024 * 1024 { + self.small_medium_memory_pool + .alloc(&mut self.storage, size, sync) + } else if size <= 200 * 1024 * 1024 { self.medium_memory_pool.alloc(&mut self.storage, size, sync) } else { self.main_memory_pool.alloc(&mut self.storage, size, sync) diff --git a/crates/burn-compute/src/memory_management/memory_pool/base.rs b/crates/burn-compute/src/memory_management/memory_pool/base.rs index 6da5a66dc..9b94764ed 100644 --- a/crates/burn-compute/src/memory_management/memory_pool/base.rs +++ b/crates/burn-compute/src/memory_management/memory_pool/base.rs @@ -4,6 +4,7 @@ use super::{ RingBuffer, SliceHandle, SliceId, }; use crate::storage::{ComputeStorage, StorageHandle, StorageUtilization}; +use alloc::vec::Vec; use hashbrown::{HashMap, HashSet}; pub struct MemoryPool { @@ -13,9 +14,14 @@ pub struct MemoryPool { memory_extension_strategy: MemoryExtensionStrategy, rounding: RoundingStrategy, chunk_index: SearchIndex, - #[allow(unused)] // will be used when we rewrite memory extension - max_chunk_size: usize, ring: RingBuffer, + recently_added_chunks: Vec, + recently_allocated_size: usize, +} + +struct SliceUpdate { + slice_id: SliceId, + size: usize, } #[derive(new, Debug)] @@ -78,8 +84,15 @@ impl MemoryPage { slice_id.copied() } - fn insert_slice(&mut self, address: usize, slice: &Slice) { - self.slices.insert(address, slice.id()); + fn insert_slice(&mut self, address: usize, slice: SliceId) { + self.slices.insert(address, slice); + } + + fn slices_sorted_by_address(&self) -> Vec { + let mut entries: Vec<(usize, SliceId)> = self.slices.clone().into_iter().collect(); + entries.sort_by_key(|&(key, _)| key); + let sorted_slices: Vec = entries.into_iter().map(|(_, values)| values).collect(); + sorted_slices } } @@ -102,7 +115,10 @@ const BUFFER_ALIGNMENT: usize = 32; const MB: usize = 1024 * 1024; pub enum RoundingStrategy { + FixedAmount(usize), + #[allow(unused)] RoundUp, + #[allow(unused)] None, } @@ -122,6 +138,10 @@ impl RoundingStrategy { factor * 2 * MB } } + RoundingStrategy::FixedAmount(chunk_size) => { + assert!(*chunk_size >= size); + *chunk_size + } RoundingStrategy::None => size, } } @@ -163,16 +183,16 @@ impl MemoryPool { pub fn new( merging_strategy: MemoryExtensionStrategy, alloc_strategy: RoundingStrategy, - max_chunk_size: usize, ) -> Self { Self { chunks: HashMap::new(), slices: HashMap::new(), memory_extension_strategy: merging_strategy, rounding: alloc_strategy, - max_chunk_size, chunk_index: SearchIndex::new(), ring: RingBuffer::new(), + recently_added_chunks: Vec::new(), + recently_allocated_size: 0, } } @@ -212,12 +232,8 @@ impl MemoryPool { &mut self, storage: &mut Storage, size: usize, - _sync: Sync, + #[allow(unused)] sync: Sync, ) -> MemoryPoolHandle { - if let Some(handle) = self.get_free_slice(size) { - return MemoryPoolHandle { slice: handle }; - } - let alloc_size = self.rounding.alloc_size(size); self.alloc_slice(storage, alloc_size, size) } @@ -228,10 +244,15 @@ impl MemoryPool { alloc_size: usize, slice_size: usize, ) -> MemoryPoolHandle { - let handle_chunk = self.create_chunk(storage, alloc_size); + let chunk_size = self.rounding.alloc_size(alloc_size); + let handle_chunk = self.create_chunk(storage, chunk_size); + let chunk_size = self.chunks.get(handle_chunk.id()).unwrap().storage.size(); + self.recently_added_chunks.push(*handle_chunk.id()); + self.recently_allocated_size += chunk_size; + let chunk_id = *handle_chunk.id(); let (slice, extra_slice) = - self.allocate_slices(handle_chunk.clone(), alloc_size, slice_size); + self.allocate_slices(handle_chunk.clone(), chunk_size, slice_size); let handle_slice = slice.handle.clone(); self.update_chunk_metadata(chunk_id, slice, extra_slice); @@ -386,13 +407,101 @@ impl MemoryPool { handle } + #[allow(unused)] + fn extend_max_memory(&mut self, storage: &mut Storage) { + let mut slices = Vec::::new(); + + let mut deallocations = HashSet::::new(); + + let mut chunks_total_size: usize = 0; + + for chunk_id in &self.recently_added_chunks { + let chunk = self.chunks.get(chunk_id).unwrap(); + let chunk_id = *chunk.handle.id(); + let sorted_slice = chunk.slices.slices_sorted_by_address(); + for slice_id in sorted_slice { + let slice = self.slices.get(&slice_id).unwrap(); + let size = slice.storage.size(); + + slices.push(SliceUpdate { slice_id, size }); + } + chunks_total_size += chunk.storage.size(); + deallocations.insert(chunk_id); + } + + if !slices.is_empty() { + self.move_to_new_chunk(chunks_total_size, storage, &mut slices, &mut deallocations); + } else { + self.deallocate(storage, &mut deallocations); + } + } + fn deallocate( &mut self, - _storage: &mut Storage, - _deallocations: &mut HashSet, + storage: &mut Storage, + deallocations: &mut HashSet, ) { - todo!() + for id in deallocations.drain() { + let mut chunk = self.chunks.remove(&id).unwrap(); + self.ring.remove_chunk(id); + + for (_address, slice_id) in chunk.slices.slices.drain() { + let slice = self.slices.get(&slice_id).unwrap(); + let chunk_id = *slice.chunk.id(); + + assert_ne!(chunk_id, id, "Chunk id should be updated"); + } + + self.chunk_index.remove(&id); + storage.dealloc(chunk.storage.id); + } + } + + fn move_to_new_chunk( + &mut self, + alloc_size: usize, + storage: &mut Storage, + slices: &mut Vec, + deallocations: &mut HashSet, + ) { + let chunk = self.create_chunk(storage, alloc_size); + let storage_id = self.chunks.get(chunk.id()).unwrap().storage.id.clone(); + let mut offset = 0; + let mut slices_ids: Vec<(usize, SliceId)> = Vec::new(); + + for update in slices.drain(..) { + let slice_id = update.slice_id; + + let slice = self.slices.get_mut(&slice_id).unwrap(); + let old_storage = slice.storage.clone(); + + slice.chunk = chunk.clone(); + slice.storage = StorageHandle { + id: storage_id.clone(), + utilization: StorageUtilization::Slice { + offset, + size: update.size, + }, + }; + storage.copy(&old_storage, &slice.storage); + slices_ids.push((offset, slice_id)); + offset += slice.effective_size(); + } + + let chunk = self.chunks.get_mut(chunk.id()).unwrap(); + let chunk_handle = chunk.handle.clone(); + for (address, slice_id) in slices_ids.drain(..) { + chunk.slices.insert_slice(address, slice_id); + } + let chunk_size = chunk.storage.size(); + let last_slice_size = chunk_size - offset; + assert_eq!(last_slice_size % BUFFER_ALIGNMENT, 0); + if last_slice_size != 0 { + self.create_slice(offset, last_slice_size, chunk_handle); + } + + self.deallocate(storage, deallocations); } } @@ -482,7 +591,7 @@ impl MemoryChunk for Chunk { slice: Slice, slices: &mut HashMap, ) { - self.slices.insert_slice(position, &slice); + self.slices.insert_slice(position, slice.id()); slices.insert(slice.id(), slice); } } diff --git a/crates/burn-compute/src/memory_management/memory_pool/ring.rs b/crates/burn-compute/src/memory_management/memory_pool/ring.rs index 55c28d1c6..df9cb8c5f 100644 --- a/crates/burn-compute/src/memory_management/memory_pool/ring.rs +++ b/crates/burn-compute/src/memory_management/memory_pool/ring.rs @@ -46,7 +46,6 @@ impl, S: MemorySlice> RingBuffer { self.chunk_positions.insert(chunk_id, self.queue.len() - 1); } - #[allow(unused)] pub fn remove_chunk(&mut self, chunk_id: ChunkId) { if let Some(position) = self.chunk_positions.remove(&chunk_id) { self.queue.remove(position);