mirror of https://github.com/tracel-ai/burn.git
working memory extension, but not fast
This commit is contained in:
parent
9e49cc9e58
commit
a09ceab28b
|
@ -16,6 +16,13 @@ pub struct MemoryPool {
|
|||
#[allow(unused)] // will be used when we rewrite memory extension
|
||||
max_chunk_size: usize,
|
||||
ring: RingBuffer<Chunk, Slice>,
|
||||
recently_added_chunks: Vec<ChunkId>,
|
||||
recently_alloced_size: usize,
|
||||
}
|
||||
|
||||
struct SliceUpdate {
|
||||
slice_id: SliceId,
|
||||
size: usize,
|
||||
}
|
||||
|
||||
#[derive(new, Debug)]
|
||||
|
@ -78,8 +85,15 @@ impl MemoryPage {
|
|||
slice_id.copied()
|
||||
}
|
||||
|
||||
fn insert_slice(&mut self, address: usize, slice: &Slice) {
|
||||
self.slices.insert(address, slice.id());
|
||||
fn insert_slice(&mut self, address: usize, slice: SliceId) {
|
||||
self.slices.insert(address, slice);
|
||||
}
|
||||
|
||||
fn slices_sorted_by_address(&self) -> Vec<SliceId> {
|
||||
let mut entries: Vec<(usize, SliceId)> = self.slices.clone().into_iter().collect();
|
||||
entries.sort_by_key(|&(key, _)| key);
|
||||
let sorted_slices: Vec<SliceId> = entries.into_iter().map(|(_, values)| values).collect();
|
||||
sorted_slices
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -173,6 +187,8 @@ impl MemoryPool {
|
|||
max_chunk_size,
|
||||
chunk_index: SearchIndex::new(),
|
||||
ring: RingBuffer::new(),
|
||||
recently_added_chunks: Vec::new(),
|
||||
recently_alloced_size: 0,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -198,6 +214,11 @@ impl MemoryPool {
|
|||
size: usize,
|
||||
sync: Sync,
|
||||
) -> MemoryPoolHandle {
|
||||
if self.recently_alloced_size > (0.8 * self.max_chunk_size as f64) as usize {
|
||||
self.extend_max_memory(storage);
|
||||
self.recently_added_chunks = Vec::new();
|
||||
self.recently_alloced_size = 0;
|
||||
}
|
||||
let slice = self.get_free_slice(size);
|
||||
|
||||
match slice {
|
||||
|
@ -212,12 +233,15 @@ impl MemoryPool {
|
|||
&mut self,
|
||||
storage: &mut Storage,
|
||||
size: usize,
|
||||
_sync: Sync,
|
||||
sync: Sync,
|
||||
) -> MemoryPoolHandle {
|
||||
if self.recently_alloced_size > (0.8 * self.max_chunk_size as f64) as usize {
|
||||
sync();
|
||||
self.extend_max_memory(storage);
|
||||
if let Some(handle) = self.get_free_slice(size) {
|
||||
return MemoryPoolHandle { slice: handle };
|
||||
}
|
||||
|
||||
}
|
||||
let alloc_size = self.rounding.alloc_size(size);
|
||||
self.alloc_slice(storage, alloc_size, size)
|
||||
}
|
||||
|
@ -229,6 +253,10 @@ impl MemoryPool {
|
|||
slice_size: usize,
|
||||
) -> MemoryPoolHandle {
|
||||
let handle_chunk = self.create_chunk(storage, alloc_size);
|
||||
let chunk_size = self.chunks.get(handle_chunk.id()).unwrap().storage.size();
|
||||
self.recently_added_chunks.push(*handle_chunk.id());
|
||||
self.recently_alloced_size += chunk_size;
|
||||
|
||||
let chunk_id = *handle_chunk.id();
|
||||
let (slice, extra_slice) =
|
||||
self.allocate_slices(handle_chunk.clone(), alloc_size, slice_size);
|
||||
|
@ -391,13 +419,103 @@ impl MemoryPool {
|
|||
|
||||
handle
|
||||
}
|
||||
#[allow(unused)]
|
||||
|
||||
fn extend_max_memory<Storage: ComputeStorage>(&mut self, storage: &mut Storage) {
|
||||
log::info!("Extend max memory ...");
|
||||
|
||||
let mut slices = Vec::<SliceUpdate>::new();
|
||||
|
||||
let mut deallocations = HashSet::<ChunkId>::new();
|
||||
|
||||
let mut chunks_total_size: usize = 0;
|
||||
|
||||
for chunk_id in &self.recently_added_chunks {
|
||||
let chunk = self.chunks.get(chunk_id).unwrap();
|
||||
let chunk_id = *chunk.handle.id();
|
||||
let sorted_slice = chunk.slices.slices_sorted_by_address();
|
||||
for slice_id in sorted_slice {
|
||||
let slice = self.slices.get(&slice_id).unwrap();
|
||||
let size = slice.storage.size();
|
||||
|
||||
slices.push(SliceUpdate { slice_id, size });
|
||||
}
|
||||
chunks_total_size += chunk.storage.size();
|
||||
deallocations.insert(chunk_id);
|
||||
}
|
||||
|
||||
if !slices.is_empty() {
|
||||
self.move_to_new_chunk(chunks_total_size, storage, &mut slices, &mut deallocations);
|
||||
} else {
|
||||
self.deallocate(storage, &mut deallocations);
|
||||
}
|
||||
}
|
||||
|
||||
fn deallocate<Storage: ComputeStorage>(
|
||||
&mut self,
|
||||
_storage: &mut Storage,
|
||||
_deallocations: &mut HashSet<ChunkId>,
|
||||
storage: &mut Storage,
|
||||
deallocations: &mut HashSet<ChunkId>,
|
||||
) {
|
||||
todo!()
|
||||
for id in deallocations.drain() {
|
||||
let mut chunk = self.chunks.remove(&id).unwrap();
|
||||
self.ring.remove_chunk(id);
|
||||
|
||||
for (_address, slice_id) in chunk.slices.slices.drain() {
|
||||
let slice = self.slices.get(&slice_id).unwrap();
|
||||
let chunk_id = *slice.chunk.id();
|
||||
|
||||
assert_ne!(chunk_id, id, "Chunk id should be updated");
|
||||
}
|
||||
|
||||
self.chunk_index.remove(&id);
|
||||
storage.dealloc(chunk.storage.id);
|
||||
}
|
||||
}
|
||||
|
||||
fn move_to_new_chunk<Storage: ComputeStorage>(
|
||||
&mut self,
|
||||
alloc_size: usize,
|
||||
storage: &mut Storage,
|
||||
slices: &mut Vec<SliceUpdate>,
|
||||
deallocations: &mut HashSet<ChunkId>,
|
||||
) {
|
||||
let chunk = self.create_chunk(storage, alloc_size);
|
||||
let storage_id = self.chunks.get(chunk.id()).unwrap().storage.id.clone();
|
||||
let mut offset = 0;
|
||||
let mut slices_ids: Vec<(usize, SliceId)> = Vec::new();
|
||||
|
||||
for update in slices.drain(..) {
|
||||
let slice_id = update.slice_id;
|
||||
|
||||
let slice = self.slices.get_mut(&slice_id).unwrap();
|
||||
let old_storage = slice.storage.clone();
|
||||
|
||||
slice.chunk = chunk.clone();
|
||||
slice.storage = StorageHandle {
|
||||
id: storage_id.clone(),
|
||||
utilization: StorageUtilization::Slice {
|
||||
offset,
|
||||
size: update.size,
|
||||
},
|
||||
};
|
||||
storage.copy(&old_storage, &slice.storage);
|
||||
slices_ids.push((offset, slice_id));
|
||||
log::info!("slice with id {:?} and offset {offset}", slice_id);
|
||||
offset += slice.effective_size();
|
||||
}
|
||||
|
||||
let chunk = self.chunks.get_mut(chunk.id()).unwrap();
|
||||
let chunk_handle = chunk.handle.clone();
|
||||
for (address, slice_id) in slices_ids.drain(..) {
|
||||
chunk.slices.insert_slice(address, slice_id);
|
||||
}
|
||||
let chunk_size = chunk.storage.size();
|
||||
let last_slice_size = chunk_size - offset;
|
||||
assert_eq!(last_slice_size % BUFFER_ALIGNMENT, 0);
|
||||
if last_slice_size != 0 {
|
||||
self.create_slice(offset, last_slice_size, chunk_handle);
|
||||
}
|
||||
|
||||
self.deallocate(storage, deallocations);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -487,7 +605,7 @@ impl MemoryChunk<Slice> for Chunk {
|
|||
slice: Slice,
|
||||
slices: &mut HashMap<SliceId, Slice>,
|
||||
) {
|
||||
self.slices.insert_slice(position, &slice);
|
||||
self.slices.insert_slice(position, slice.id());
|
||||
slices.insert(slice.id(), slice);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -46,7 +46,6 @@ impl<C: MemoryChunk<S>, S: MemorySlice> RingBuffer<C, S> {
|
|||
self.chunk_positions.insert(chunk_id, self.queue.len() - 1);
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn remove_chunk(&mut self, chunk_id: ChunkId) {
|
||||
if let Some(position) = self.chunk_positions.remove(&chunk_id) {
|
||||
self.queue.remove(position);
|
||||
|
|
Loading…
Reference in New Issue