mirror of https://github.com/tracel-ai/burn.git
Feat/fixed chunk alloc by class (#1960)
This commit is contained in:
parent
d696d74e3d
commit
f5be04f44b
|
@ -9,6 +9,7 @@ use super::MemoryManagement;
|
||||||
/// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks.
|
/// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks.
|
||||||
pub struct DynamicMemoryManagement<Storage> {
|
pub struct DynamicMemoryManagement<Storage> {
|
||||||
small_memory_pool: SmallMemoryPool,
|
small_memory_pool: SmallMemoryPool,
|
||||||
|
small_medium_memory_pool: MemoryPool,
|
||||||
medium_memory_pool: MemoryPool,
|
medium_memory_pool: MemoryPool,
|
||||||
main_memory_pool: MemoryPool,
|
main_memory_pool: MemoryPool,
|
||||||
storage: Storage,
|
storage: Storage,
|
||||||
|
@ -19,17 +20,20 @@ impl<Storage: ComputeStorage> DynamicMemoryManagement<Storage> {
|
||||||
pub fn new(storage: Storage) -> Self {
|
pub fn new(storage: Storage) -> Self {
|
||||||
let main_memory_pool = MemoryPool::new(
|
let main_memory_pool = MemoryPool::new(
|
||||||
MemoryExtensionStrategy::new_period_tick(10),
|
MemoryExtensionStrategy::new_period_tick(10),
|
||||||
RoundingStrategy::RoundUp,
|
RoundingStrategy::FixedAmount(1024 * 1024 * 1024),
|
||||||
1024 * 1024 * 1024 * 2,
|
|
||||||
);
|
);
|
||||||
let medium_memory_pool = MemoryPool::new(
|
let medium_memory_pool = MemoryPool::new(
|
||||||
MemoryExtensionStrategy::Never,
|
MemoryExtensionStrategy::Never,
|
||||||
RoundingStrategy::None,
|
RoundingStrategy::FixedAmount(1024 * 1024 * 200),
|
||||||
1024 * 1024 * 512,
|
);
|
||||||
|
let small_medium_memory_pool = MemoryPool::new(
|
||||||
|
MemoryExtensionStrategy::Never,
|
||||||
|
RoundingStrategy::FixedAmount(1024 * 1024 * 2),
|
||||||
);
|
);
|
||||||
let small_memory_pool = SmallMemoryPool::new();
|
let small_memory_pool = SmallMemoryPool::new();
|
||||||
Self {
|
Self {
|
||||||
small_memory_pool,
|
small_memory_pool,
|
||||||
|
small_medium_memory_pool,
|
||||||
main_memory_pool,
|
main_memory_pool,
|
||||||
medium_memory_pool,
|
medium_memory_pool,
|
||||||
storage,
|
storage,
|
||||||
|
@ -58,6 +62,13 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> for DynamicMemoryManagem
|
||||||
return handle;
|
return handle;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(handle) = self
|
||||||
|
.small_medium_memory_pool
|
||||||
|
.get(&mut self.storage, &binding)
|
||||||
|
{
|
||||||
|
return handle;
|
||||||
|
}
|
||||||
|
|
||||||
if let Some(handle) = self.medium_memory_pool.get(&mut self.storage, &binding) {
|
if let Some(handle) = self.medium_memory_pool.get(&mut self.storage, &binding) {
|
||||||
return handle;
|
return handle;
|
||||||
}
|
}
|
||||||
|
@ -73,7 +84,10 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> for DynamicMemoryManagem
|
||||||
if size <= 32 {
|
if size <= 32 {
|
||||||
self.small_memory_pool
|
self.small_memory_pool
|
||||||
.reserve(&mut self.storage, size, sync)
|
.reserve(&mut self.storage, size, sync)
|
||||||
} else if size < 512 {
|
} else if size <= 2 * 1024 * 1024 {
|
||||||
|
self.small_medium_memory_pool
|
||||||
|
.reserve(&mut self.storage, size, sync)
|
||||||
|
} else if size < 200 * 1024 * 1024 {
|
||||||
self.medium_memory_pool
|
self.medium_memory_pool
|
||||||
.reserve(&mut self.storage, size, sync)
|
.reserve(&mut self.storage, size, sync)
|
||||||
} else {
|
} else {
|
||||||
|
@ -84,7 +98,10 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> for DynamicMemoryManagem
|
||||||
fn alloc<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
|
fn alloc<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
|
||||||
if size <= 32 {
|
if size <= 32 {
|
||||||
self.small_memory_pool.alloc(&mut self.storage, size, sync)
|
self.small_memory_pool.alloc(&mut self.storage, size, sync)
|
||||||
} else if size < 512 {
|
} else if size <= 2 * 1024 * 1024 {
|
||||||
|
self.small_medium_memory_pool
|
||||||
|
.alloc(&mut self.storage, size, sync)
|
||||||
|
} else if size <= 200 * 1024 * 1024 {
|
||||||
self.medium_memory_pool.alloc(&mut self.storage, size, sync)
|
self.medium_memory_pool.alloc(&mut self.storage, size, sync)
|
||||||
} else {
|
} else {
|
||||||
self.main_memory_pool.alloc(&mut self.storage, size, sync)
|
self.main_memory_pool.alloc(&mut self.storage, size, sync)
|
||||||
|
|
|
@ -4,6 +4,7 @@ use super::{
|
||||||
RingBuffer, SliceHandle, SliceId,
|
RingBuffer, SliceHandle, SliceId,
|
||||||
};
|
};
|
||||||
use crate::storage::{ComputeStorage, StorageHandle, StorageUtilization};
|
use crate::storage::{ComputeStorage, StorageHandle, StorageUtilization};
|
||||||
|
use alloc::vec::Vec;
|
||||||
use hashbrown::{HashMap, HashSet};
|
use hashbrown::{HashMap, HashSet};
|
||||||
|
|
||||||
pub struct MemoryPool {
|
pub struct MemoryPool {
|
||||||
|
@ -13,9 +14,14 @@ pub struct MemoryPool {
|
||||||
memory_extension_strategy: MemoryExtensionStrategy,
|
memory_extension_strategy: MemoryExtensionStrategy,
|
||||||
rounding: RoundingStrategy,
|
rounding: RoundingStrategy,
|
||||||
chunk_index: SearchIndex<ChunkId>,
|
chunk_index: SearchIndex<ChunkId>,
|
||||||
#[allow(unused)] // will be used when we rewrite memory extension
|
|
||||||
max_chunk_size: usize,
|
|
||||||
ring: RingBuffer<Chunk, Slice>,
|
ring: RingBuffer<Chunk, Slice>,
|
||||||
|
recently_added_chunks: Vec<ChunkId>,
|
||||||
|
recently_allocated_size: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct SliceUpdate {
|
||||||
|
slice_id: SliceId,
|
||||||
|
size: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(new, Debug)]
|
#[derive(new, Debug)]
|
||||||
|
@ -78,8 +84,15 @@ impl MemoryPage {
|
||||||
slice_id.copied()
|
slice_id.copied()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn insert_slice(&mut self, address: usize, slice: &Slice) {
|
fn insert_slice(&mut self, address: usize, slice: SliceId) {
|
||||||
self.slices.insert(address, slice.id());
|
self.slices.insert(address, slice);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn slices_sorted_by_address(&self) -> Vec<SliceId> {
|
||||||
|
let mut entries: Vec<(usize, SliceId)> = self.slices.clone().into_iter().collect();
|
||||||
|
entries.sort_by_key(|&(key, _)| key);
|
||||||
|
let sorted_slices: Vec<SliceId> = entries.into_iter().map(|(_, values)| values).collect();
|
||||||
|
sorted_slices
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -102,7 +115,10 @@ const BUFFER_ALIGNMENT: usize = 32;
|
||||||
const MB: usize = 1024 * 1024;
|
const MB: usize = 1024 * 1024;
|
||||||
|
|
||||||
pub enum RoundingStrategy {
|
pub enum RoundingStrategy {
|
||||||
|
FixedAmount(usize),
|
||||||
|
#[allow(unused)]
|
||||||
RoundUp,
|
RoundUp,
|
||||||
|
#[allow(unused)]
|
||||||
None,
|
None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -122,6 +138,10 @@ impl RoundingStrategy {
|
||||||
factor * 2 * MB
|
factor * 2 * MB
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
RoundingStrategy::FixedAmount(chunk_size) => {
|
||||||
|
assert!(*chunk_size >= size);
|
||||||
|
*chunk_size
|
||||||
|
}
|
||||||
RoundingStrategy::None => size,
|
RoundingStrategy::None => size,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -163,16 +183,16 @@ impl MemoryPool {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
merging_strategy: MemoryExtensionStrategy,
|
merging_strategy: MemoryExtensionStrategy,
|
||||||
alloc_strategy: RoundingStrategy,
|
alloc_strategy: RoundingStrategy,
|
||||||
max_chunk_size: usize,
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
chunks: HashMap::new(),
|
chunks: HashMap::new(),
|
||||||
slices: HashMap::new(),
|
slices: HashMap::new(),
|
||||||
memory_extension_strategy: merging_strategy,
|
memory_extension_strategy: merging_strategy,
|
||||||
rounding: alloc_strategy,
|
rounding: alloc_strategy,
|
||||||
max_chunk_size,
|
|
||||||
chunk_index: SearchIndex::new(),
|
chunk_index: SearchIndex::new(),
|
||||||
ring: RingBuffer::new(),
|
ring: RingBuffer::new(),
|
||||||
|
recently_added_chunks: Vec::new(),
|
||||||
|
recently_allocated_size: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -212,12 +232,8 @@ impl MemoryPool {
|
||||||
&mut self,
|
&mut self,
|
||||||
storage: &mut Storage,
|
storage: &mut Storage,
|
||||||
size: usize,
|
size: usize,
|
||||||
_sync: Sync,
|
#[allow(unused)] sync: Sync,
|
||||||
) -> MemoryPoolHandle {
|
) -> MemoryPoolHandle {
|
||||||
if let Some(handle) = self.get_free_slice(size) {
|
|
||||||
return MemoryPoolHandle { slice: handle };
|
|
||||||
}
|
|
||||||
|
|
||||||
let alloc_size = self.rounding.alloc_size(size);
|
let alloc_size = self.rounding.alloc_size(size);
|
||||||
self.alloc_slice(storage, alloc_size, size)
|
self.alloc_slice(storage, alloc_size, size)
|
||||||
}
|
}
|
||||||
|
@ -228,10 +244,15 @@ impl MemoryPool {
|
||||||
alloc_size: usize,
|
alloc_size: usize,
|
||||||
slice_size: usize,
|
slice_size: usize,
|
||||||
) -> MemoryPoolHandle {
|
) -> MemoryPoolHandle {
|
||||||
let handle_chunk = self.create_chunk(storage, alloc_size);
|
let chunk_size = self.rounding.alloc_size(alloc_size);
|
||||||
|
let handle_chunk = self.create_chunk(storage, chunk_size);
|
||||||
|
let chunk_size = self.chunks.get(handle_chunk.id()).unwrap().storage.size();
|
||||||
|
self.recently_added_chunks.push(*handle_chunk.id());
|
||||||
|
self.recently_allocated_size += chunk_size;
|
||||||
|
|
||||||
let chunk_id = *handle_chunk.id();
|
let chunk_id = *handle_chunk.id();
|
||||||
let (slice, extra_slice) =
|
let (slice, extra_slice) =
|
||||||
self.allocate_slices(handle_chunk.clone(), alloc_size, slice_size);
|
self.allocate_slices(handle_chunk.clone(), chunk_size, slice_size);
|
||||||
|
|
||||||
let handle_slice = slice.handle.clone();
|
let handle_slice = slice.handle.clone();
|
||||||
self.update_chunk_metadata(chunk_id, slice, extra_slice);
|
self.update_chunk_metadata(chunk_id, slice, extra_slice);
|
||||||
|
@ -386,13 +407,101 @@ impl MemoryPool {
|
||||||
|
|
||||||
handle
|
handle
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
|
fn extend_max_memory<Storage: ComputeStorage>(&mut self, storage: &mut Storage) {
|
||||||
|
let mut slices = Vec::<SliceUpdate>::new();
|
||||||
|
|
||||||
|
let mut deallocations = HashSet::<ChunkId>::new();
|
||||||
|
|
||||||
|
let mut chunks_total_size: usize = 0;
|
||||||
|
|
||||||
|
for chunk_id in &self.recently_added_chunks {
|
||||||
|
let chunk = self.chunks.get(chunk_id).unwrap();
|
||||||
|
let chunk_id = *chunk.handle.id();
|
||||||
|
let sorted_slice = chunk.slices.slices_sorted_by_address();
|
||||||
|
for slice_id in sorted_slice {
|
||||||
|
let slice = self.slices.get(&slice_id).unwrap();
|
||||||
|
let size = slice.storage.size();
|
||||||
|
|
||||||
|
slices.push(SliceUpdate { slice_id, size });
|
||||||
|
}
|
||||||
|
chunks_total_size += chunk.storage.size();
|
||||||
|
deallocations.insert(chunk_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
if !slices.is_empty() {
|
||||||
|
self.move_to_new_chunk(chunks_total_size, storage, &mut slices, &mut deallocations);
|
||||||
|
} else {
|
||||||
|
self.deallocate(storage, &mut deallocations);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
fn deallocate<Storage: ComputeStorage>(
|
fn deallocate<Storage: ComputeStorage>(
|
||||||
&mut self,
|
&mut self,
|
||||||
_storage: &mut Storage,
|
storage: &mut Storage,
|
||||||
_deallocations: &mut HashSet<ChunkId>,
|
deallocations: &mut HashSet<ChunkId>,
|
||||||
) {
|
) {
|
||||||
todo!()
|
for id in deallocations.drain() {
|
||||||
|
let mut chunk = self.chunks.remove(&id).unwrap();
|
||||||
|
self.ring.remove_chunk(id);
|
||||||
|
|
||||||
|
for (_address, slice_id) in chunk.slices.slices.drain() {
|
||||||
|
let slice = self.slices.get(&slice_id).unwrap();
|
||||||
|
let chunk_id = *slice.chunk.id();
|
||||||
|
|
||||||
|
assert_ne!(chunk_id, id, "Chunk id should be updated");
|
||||||
|
}
|
||||||
|
|
||||||
|
self.chunk_index.remove(&id);
|
||||||
|
storage.dealloc(chunk.storage.id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn move_to_new_chunk<Storage: ComputeStorage>(
|
||||||
|
&mut self,
|
||||||
|
alloc_size: usize,
|
||||||
|
storage: &mut Storage,
|
||||||
|
slices: &mut Vec<SliceUpdate>,
|
||||||
|
deallocations: &mut HashSet<ChunkId>,
|
||||||
|
) {
|
||||||
|
let chunk = self.create_chunk(storage, alloc_size);
|
||||||
|
let storage_id = self.chunks.get(chunk.id()).unwrap().storage.id.clone();
|
||||||
|
let mut offset = 0;
|
||||||
|
let mut slices_ids: Vec<(usize, SliceId)> = Vec::new();
|
||||||
|
|
||||||
|
for update in slices.drain(..) {
|
||||||
|
let slice_id = update.slice_id;
|
||||||
|
|
||||||
|
let slice = self.slices.get_mut(&slice_id).unwrap();
|
||||||
|
let old_storage = slice.storage.clone();
|
||||||
|
|
||||||
|
slice.chunk = chunk.clone();
|
||||||
|
slice.storage = StorageHandle {
|
||||||
|
id: storage_id.clone(),
|
||||||
|
utilization: StorageUtilization::Slice {
|
||||||
|
offset,
|
||||||
|
size: update.size,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
storage.copy(&old_storage, &slice.storage);
|
||||||
|
slices_ids.push((offset, slice_id));
|
||||||
|
offset += slice.effective_size();
|
||||||
|
}
|
||||||
|
|
||||||
|
let chunk = self.chunks.get_mut(chunk.id()).unwrap();
|
||||||
|
let chunk_handle = chunk.handle.clone();
|
||||||
|
for (address, slice_id) in slices_ids.drain(..) {
|
||||||
|
chunk.slices.insert_slice(address, slice_id);
|
||||||
|
}
|
||||||
|
let chunk_size = chunk.storage.size();
|
||||||
|
let last_slice_size = chunk_size - offset;
|
||||||
|
assert_eq!(last_slice_size % BUFFER_ALIGNMENT, 0);
|
||||||
|
if last_slice_size != 0 {
|
||||||
|
self.create_slice(offset, last_slice_size, chunk_handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
self.deallocate(storage, deallocations);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -482,7 +591,7 @@ impl MemoryChunk<Slice> for Chunk {
|
||||||
slice: Slice,
|
slice: Slice,
|
||||||
slices: &mut HashMap<SliceId, Slice>,
|
slices: &mut HashMap<SliceId, Slice>,
|
||||||
) {
|
) {
|
||||||
self.slices.insert_slice(position, &slice);
|
self.slices.insert_slice(position, slice.id());
|
||||||
slices.insert(slice.id(), slice);
|
slices.insert(slice.id(), slice);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -46,7 +46,6 @@ impl<C: MemoryChunk<S>, S: MemorySlice> RingBuffer<C, S> {
|
||||||
self.chunk_positions.insert(chunk_id, self.queue.len() - 1);
|
self.chunk_positions.insert(chunk_id, self.queue.len() - 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(unused)]
|
|
||||||
pub fn remove_chunk(&mut self, chunk_id: ChunkId) {
|
pub fn remove_chunk(&mut self, chunk_id: ChunkId) {
|
||||||
if let Some(position) = self.chunk_positions.remove(&chunk_id) {
|
if let Some(position) = self.chunk_positions.remove(&chunk_id) {
|
||||||
self.queue.remove(position);
|
self.queue.remove(position);
|
||||||
|
|
Loading…
Reference in New Issue