mirror of https://github.com/tracel-ai/burn.git
Feat/fixed chunk alloc by class (#1960)
This commit is contained in:
parent
d696d74e3d
commit
f5be04f44b
|
@ -9,6 +9,7 @@ use super::MemoryManagement;
|
|||
/// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks.
|
||||
pub struct DynamicMemoryManagement<Storage> {
|
||||
small_memory_pool: SmallMemoryPool,
|
||||
small_medium_memory_pool: MemoryPool,
|
||||
medium_memory_pool: MemoryPool,
|
||||
main_memory_pool: MemoryPool,
|
||||
storage: Storage,
|
||||
|
@ -19,17 +20,20 @@ impl<Storage: ComputeStorage> DynamicMemoryManagement<Storage> {
|
|||
pub fn new(storage: Storage) -> Self {
|
||||
let main_memory_pool = MemoryPool::new(
|
||||
MemoryExtensionStrategy::new_period_tick(10),
|
||||
RoundingStrategy::RoundUp,
|
||||
1024 * 1024 * 1024 * 2,
|
||||
RoundingStrategy::FixedAmount(1024 * 1024 * 1024),
|
||||
);
|
||||
let medium_memory_pool = MemoryPool::new(
|
||||
MemoryExtensionStrategy::Never,
|
||||
RoundingStrategy::None,
|
||||
1024 * 1024 * 512,
|
||||
RoundingStrategy::FixedAmount(1024 * 1024 * 200),
|
||||
);
|
||||
let small_medium_memory_pool = MemoryPool::new(
|
||||
MemoryExtensionStrategy::Never,
|
||||
RoundingStrategy::FixedAmount(1024 * 1024 * 2),
|
||||
);
|
||||
let small_memory_pool = SmallMemoryPool::new();
|
||||
Self {
|
||||
small_memory_pool,
|
||||
small_medium_memory_pool,
|
||||
main_memory_pool,
|
||||
medium_memory_pool,
|
||||
storage,
|
||||
|
@ -58,6 +62,13 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> for DynamicMemoryManagem
|
|||
return handle;
|
||||
}
|
||||
|
||||
if let Some(handle) = self
|
||||
.small_medium_memory_pool
|
||||
.get(&mut self.storage, &binding)
|
||||
{
|
||||
return handle;
|
||||
}
|
||||
|
||||
if let Some(handle) = self.medium_memory_pool.get(&mut self.storage, &binding) {
|
||||
return handle;
|
||||
}
|
||||
|
@ -73,7 +84,10 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> for DynamicMemoryManagem
|
|||
if size <= 32 {
|
||||
self.small_memory_pool
|
||||
.reserve(&mut self.storage, size, sync)
|
||||
} else if size < 512 {
|
||||
} else if size <= 2 * 1024 * 1024 {
|
||||
self.small_medium_memory_pool
|
||||
.reserve(&mut self.storage, size, sync)
|
||||
} else if size < 200 * 1024 * 1024 {
|
||||
self.medium_memory_pool
|
||||
.reserve(&mut self.storage, size, sync)
|
||||
} else {
|
||||
|
@ -84,7 +98,10 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> for DynamicMemoryManagem
|
|||
fn alloc<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
|
||||
if size <= 32 {
|
||||
self.small_memory_pool.alloc(&mut self.storage, size, sync)
|
||||
} else if size < 512 {
|
||||
} else if size <= 2 * 1024 * 1024 {
|
||||
self.small_medium_memory_pool
|
||||
.alloc(&mut self.storage, size, sync)
|
||||
} else if size <= 200 * 1024 * 1024 {
|
||||
self.medium_memory_pool.alloc(&mut self.storage, size, sync)
|
||||
} else {
|
||||
self.main_memory_pool.alloc(&mut self.storage, size, sync)
|
||||
|
|
|
@ -4,6 +4,7 @@ use super::{
|
|||
RingBuffer, SliceHandle, SliceId,
|
||||
};
|
||||
use crate::storage::{ComputeStorage, StorageHandle, StorageUtilization};
|
||||
use alloc::vec::Vec;
|
||||
use hashbrown::{HashMap, HashSet};
|
||||
|
||||
pub struct MemoryPool {
|
||||
|
@ -13,9 +14,14 @@ pub struct MemoryPool {
|
|||
memory_extension_strategy: MemoryExtensionStrategy,
|
||||
rounding: RoundingStrategy,
|
||||
chunk_index: SearchIndex<ChunkId>,
|
||||
#[allow(unused)] // will be used when we rewrite memory extension
|
||||
max_chunk_size: usize,
|
||||
ring: RingBuffer<Chunk, Slice>,
|
||||
recently_added_chunks: Vec<ChunkId>,
|
||||
recently_allocated_size: usize,
|
||||
}
|
||||
|
||||
struct SliceUpdate {
|
||||
slice_id: SliceId,
|
||||
size: usize,
|
||||
}
|
||||
|
||||
#[derive(new, Debug)]
|
||||
|
@ -78,8 +84,15 @@ impl MemoryPage {
|
|||
slice_id.copied()
|
||||
}
|
||||
|
||||
fn insert_slice(&mut self, address: usize, slice: &Slice) {
|
||||
self.slices.insert(address, slice.id());
|
||||
fn insert_slice(&mut self, address: usize, slice: SliceId) {
|
||||
self.slices.insert(address, slice);
|
||||
}
|
||||
|
||||
fn slices_sorted_by_address(&self) -> Vec<SliceId> {
|
||||
let mut entries: Vec<(usize, SliceId)> = self.slices.clone().into_iter().collect();
|
||||
entries.sort_by_key(|&(key, _)| key);
|
||||
let sorted_slices: Vec<SliceId> = entries.into_iter().map(|(_, values)| values).collect();
|
||||
sorted_slices
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -102,7 +115,10 @@ const BUFFER_ALIGNMENT: usize = 32;
|
|||
const MB: usize = 1024 * 1024;
|
||||
|
||||
pub enum RoundingStrategy {
|
||||
FixedAmount(usize),
|
||||
#[allow(unused)]
|
||||
RoundUp,
|
||||
#[allow(unused)]
|
||||
None,
|
||||
}
|
||||
|
||||
|
@ -122,6 +138,10 @@ impl RoundingStrategy {
|
|||
factor * 2 * MB
|
||||
}
|
||||
}
|
||||
RoundingStrategy::FixedAmount(chunk_size) => {
|
||||
assert!(*chunk_size >= size);
|
||||
*chunk_size
|
||||
}
|
||||
RoundingStrategy::None => size,
|
||||
}
|
||||
}
|
||||
|
@ -163,16 +183,16 @@ impl MemoryPool {
|
|||
pub fn new(
|
||||
merging_strategy: MemoryExtensionStrategy,
|
||||
alloc_strategy: RoundingStrategy,
|
||||
max_chunk_size: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
chunks: HashMap::new(),
|
||||
slices: HashMap::new(),
|
||||
memory_extension_strategy: merging_strategy,
|
||||
rounding: alloc_strategy,
|
||||
max_chunk_size,
|
||||
chunk_index: SearchIndex::new(),
|
||||
ring: RingBuffer::new(),
|
||||
recently_added_chunks: Vec::new(),
|
||||
recently_allocated_size: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -212,12 +232,8 @@ impl MemoryPool {
|
|||
&mut self,
|
||||
storage: &mut Storage,
|
||||
size: usize,
|
||||
_sync: Sync,
|
||||
#[allow(unused)] sync: Sync,
|
||||
) -> MemoryPoolHandle {
|
||||
if let Some(handle) = self.get_free_slice(size) {
|
||||
return MemoryPoolHandle { slice: handle };
|
||||
}
|
||||
|
||||
let alloc_size = self.rounding.alloc_size(size);
|
||||
self.alloc_slice(storage, alloc_size, size)
|
||||
}
|
||||
|
@ -228,10 +244,15 @@ impl MemoryPool {
|
|||
alloc_size: usize,
|
||||
slice_size: usize,
|
||||
) -> MemoryPoolHandle {
|
||||
let handle_chunk = self.create_chunk(storage, alloc_size);
|
||||
let chunk_size = self.rounding.alloc_size(alloc_size);
|
||||
let handle_chunk = self.create_chunk(storage, chunk_size);
|
||||
let chunk_size = self.chunks.get(handle_chunk.id()).unwrap().storage.size();
|
||||
self.recently_added_chunks.push(*handle_chunk.id());
|
||||
self.recently_allocated_size += chunk_size;
|
||||
|
||||
let chunk_id = *handle_chunk.id();
|
||||
let (slice, extra_slice) =
|
||||
self.allocate_slices(handle_chunk.clone(), alloc_size, slice_size);
|
||||
self.allocate_slices(handle_chunk.clone(), chunk_size, slice_size);
|
||||
|
||||
let handle_slice = slice.handle.clone();
|
||||
self.update_chunk_metadata(chunk_id, slice, extra_slice);
|
||||
|
@ -386,13 +407,101 @@ impl MemoryPool {
|
|||
|
||||
handle
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
fn extend_max_memory<Storage: ComputeStorage>(&mut self, storage: &mut Storage) {
|
||||
let mut slices = Vec::<SliceUpdate>::new();
|
||||
|
||||
let mut deallocations = HashSet::<ChunkId>::new();
|
||||
|
||||
let mut chunks_total_size: usize = 0;
|
||||
|
||||
for chunk_id in &self.recently_added_chunks {
|
||||
let chunk = self.chunks.get(chunk_id).unwrap();
|
||||
let chunk_id = *chunk.handle.id();
|
||||
let sorted_slice = chunk.slices.slices_sorted_by_address();
|
||||
for slice_id in sorted_slice {
|
||||
let slice = self.slices.get(&slice_id).unwrap();
|
||||
let size = slice.storage.size();
|
||||
|
||||
slices.push(SliceUpdate { slice_id, size });
|
||||
}
|
||||
chunks_total_size += chunk.storage.size();
|
||||
deallocations.insert(chunk_id);
|
||||
}
|
||||
|
||||
if !slices.is_empty() {
|
||||
self.move_to_new_chunk(chunks_total_size, storage, &mut slices, &mut deallocations);
|
||||
} else {
|
||||
self.deallocate(storage, &mut deallocations);
|
||||
}
|
||||
}
|
||||
|
||||
fn deallocate<Storage: ComputeStorage>(
|
||||
&mut self,
|
||||
_storage: &mut Storage,
|
||||
_deallocations: &mut HashSet<ChunkId>,
|
||||
storage: &mut Storage,
|
||||
deallocations: &mut HashSet<ChunkId>,
|
||||
) {
|
||||
todo!()
|
||||
for id in deallocations.drain() {
|
||||
let mut chunk = self.chunks.remove(&id).unwrap();
|
||||
self.ring.remove_chunk(id);
|
||||
|
||||
for (_address, slice_id) in chunk.slices.slices.drain() {
|
||||
let slice = self.slices.get(&slice_id).unwrap();
|
||||
let chunk_id = *slice.chunk.id();
|
||||
|
||||
assert_ne!(chunk_id, id, "Chunk id should be updated");
|
||||
}
|
||||
|
||||
self.chunk_index.remove(&id);
|
||||
storage.dealloc(chunk.storage.id);
|
||||
}
|
||||
}
|
||||
|
||||
fn move_to_new_chunk<Storage: ComputeStorage>(
|
||||
&mut self,
|
||||
alloc_size: usize,
|
||||
storage: &mut Storage,
|
||||
slices: &mut Vec<SliceUpdate>,
|
||||
deallocations: &mut HashSet<ChunkId>,
|
||||
) {
|
||||
let chunk = self.create_chunk(storage, alloc_size);
|
||||
let storage_id = self.chunks.get(chunk.id()).unwrap().storage.id.clone();
|
||||
let mut offset = 0;
|
||||
let mut slices_ids: Vec<(usize, SliceId)> = Vec::new();
|
||||
|
||||
for update in slices.drain(..) {
|
||||
let slice_id = update.slice_id;
|
||||
|
||||
let slice = self.slices.get_mut(&slice_id).unwrap();
|
||||
let old_storage = slice.storage.clone();
|
||||
|
||||
slice.chunk = chunk.clone();
|
||||
slice.storage = StorageHandle {
|
||||
id: storage_id.clone(),
|
||||
utilization: StorageUtilization::Slice {
|
||||
offset,
|
||||
size: update.size,
|
||||
},
|
||||
};
|
||||
storage.copy(&old_storage, &slice.storage);
|
||||
slices_ids.push((offset, slice_id));
|
||||
offset += slice.effective_size();
|
||||
}
|
||||
|
||||
let chunk = self.chunks.get_mut(chunk.id()).unwrap();
|
||||
let chunk_handle = chunk.handle.clone();
|
||||
for (address, slice_id) in slices_ids.drain(..) {
|
||||
chunk.slices.insert_slice(address, slice_id);
|
||||
}
|
||||
let chunk_size = chunk.storage.size();
|
||||
let last_slice_size = chunk_size - offset;
|
||||
assert_eq!(last_slice_size % BUFFER_ALIGNMENT, 0);
|
||||
if last_slice_size != 0 {
|
||||
self.create_slice(offset, last_slice_size, chunk_handle);
|
||||
}
|
||||
|
||||
self.deallocate(storage, deallocations);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -482,7 +591,7 @@ impl MemoryChunk<Slice> for Chunk {
|
|||
slice: Slice,
|
||||
slices: &mut HashMap<SliceId, Slice>,
|
||||
) {
|
||||
self.slices.insert_slice(position, &slice);
|
||||
self.slices.insert_slice(position, slice.id());
|
||||
slices.insert(slice.id(), slice);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -46,7 +46,6 @@ impl<C: MemoryChunk<S>, S: MemorySlice> RingBuffer<C, S> {
|
|||
self.chunk_positions.insert(chunk_id, self.queue.len() - 1);
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn remove_chunk(&mut self, chunk_id: ChunkId) {
|
||||
if let Some(position) = self.chunk_positions.remove(&chunk_id) {
|
||||
self.queue.remove(position);
|
||||
|
|
Loading…
Reference in New Issue