mirror of https://github.com/tracel-ai/burn.git
feat: rounding-up + merging fix
implemented rounding up fixed a bug where if the a merging happened at the end of a chunk, it wouldn't be defragmented
This commit is contained in:
parent
b89691a359
commit
0ed66f627f
|
@ -144,6 +144,7 @@ impl Slice {
|
|||
|
||||
const MIN_SIZE_NEEDED_TO_OFFSET: usize = 16;
|
||||
const BUFFER_ALIGNMENT: usize = 32;
|
||||
const CHUNK_ROUNDING: usize = 2 * 1024 * 1024;
|
||||
|
||||
/// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks.
|
||||
pub struct DynamicMemoryManagement<Storage> {
|
||||
|
@ -216,6 +217,13 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> for DynamicMemoryManagem
|
|||
///
|
||||
/// Also clean ups, merging free slices together if permitted by the merging strategy
|
||||
fn reserve(&mut self, size: usize) -> Self::Handle {
|
||||
// log::info!("Number of chunks {}", self.chunks.len());
|
||||
// log::info!("Number of slices {}", self.slices.len());
|
||||
// log::info!(
|
||||
// "used memory {} , total memory {}",
|
||||
// self.calculate_used_memory(),
|
||||
// self.calculate_total_memory()
|
||||
// );
|
||||
let handle = self.reserve_algorithm(size);
|
||||
|
||||
if self.merging_strategy.should_perform_defragmentation() {
|
||||
|
@ -226,18 +234,50 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> for DynamicMemoryManagem
|
|||
}
|
||||
|
||||
fn alloc(&mut self, size: usize) -> Self::Handle {
|
||||
let handle_chunk = self.create_chunk(size);
|
||||
let chunk_id = *handle_chunk.id();
|
||||
log::info!(
|
||||
"biggest free slice {} MB requested allocation {} MB",
|
||||
(self.find_biggest_free_slice() as f64) / (1024.0 * 1024.0),
|
||||
(size as f64) / (1024.0 * 1024.0)
|
||||
);
|
||||
if size < MIN_SIZE_NEEDED_TO_OFFSET {
|
||||
let handle_chunk = self.create_chunk(size);
|
||||
let chunk_id = *handle_chunk.id();
|
||||
|
||||
let slice = self.create_slice(0, size, handle_chunk);
|
||||
let handle_slice = slice.handle.clone();
|
||||
let slice = self.create_slice(0, size, handle_chunk);
|
||||
let handle_slice = slice.handle.clone();
|
||||
|
||||
let chunk = self.chunks.get_mut(&chunk_id).unwrap();
|
||||
chunk.slices.push(*handle_slice.id());
|
||||
let chunk = self.chunks.get_mut(&chunk_id).unwrap();
|
||||
chunk.slices.push(*handle_slice.id());
|
||||
|
||||
self.slices.insert(*handle_slice.id(), slice);
|
||||
self.slices.insert(*handle_slice.id(), slice);
|
||||
|
||||
DynamicHandle::Slice(handle_slice)
|
||||
return DynamicHandle::Slice(handle_slice);
|
||||
} else {
|
||||
let chunk_size = Self::round_up_chunk(size);
|
||||
|
||||
let handle_chunk = self.create_chunk(chunk_size);
|
||||
let chunk_id = *handle_chunk.id();
|
||||
|
||||
let slice = self.create_slice(0, size, handle_chunk.clone());
|
||||
let first_slice_size = slice.effective_size();
|
||||
let handle_slice = slice.handle.clone();
|
||||
let chunk = self.chunks.get_mut(&chunk_id).unwrap();
|
||||
chunk.slices.push(*handle_slice.id());
|
||||
self.slices.insert(*handle_slice.id(), slice);
|
||||
|
||||
let second_slice_size = chunk_size - first_slice_size;
|
||||
assert_eq!(second_slice_size % BUFFER_ALIGNMENT, 0);
|
||||
if second_slice_size >= BUFFER_ALIGNMENT {
|
||||
let second_slice =
|
||||
self.create_slice(first_slice_size, second_slice_size, handle_chunk);
|
||||
let handle_slice = second_slice.handle.clone();
|
||||
let chunk = self.chunks.get_mut(&chunk_id).unwrap();
|
||||
chunk.slices.push(*handle_slice.id());
|
||||
self.slices.insert(*handle_slice.id(), second_slice);
|
||||
}
|
||||
|
||||
DynamicHandle::Slice(handle_slice)
|
||||
}
|
||||
}
|
||||
|
||||
fn dealloc(&mut self, binding: Self::Binding) {
|
||||
|
@ -480,8 +520,11 @@ impl<Storage: ComputeStorage> DynamicMemoryManagement<Storage> {
|
|||
slices_ids.push(*slice_id);
|
||||
num_merge += 1;
|
||||
offset += slice.effective_size();
|
||||
continue;
|
||||
} else if num_merge > 1 {
|
||||
if i < chunk.slices.len() - 1 {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
if num_merge > 1 {
|
||||
let mut empty = Vec::new();
|
||||
core::mem::swap(&mut slices_ids, &mut empty);
|
||||
let merging = Merging {
|
||||
|
@ -535,6 +578,35 @@ impl<Storage: ComputeStorage> DynamicMemoryManagement<Storage> {
|
|||
core::mem::swap(&mut chunk.slices, &mut slices_updated);
|
||||
}
|
||||
|
||||
fn calculate_used_memory(&self) -> u64 {
|
||||
let mut used_memory = 0;
|
||||
for (.., slice) in self.slices.iter() {
|
||||
if !slice.handle.is_free() {
|
||||
used_memory += slice.storage.size() as u64;
|
||||
}
|
||||
}
|
||||
used_memory
|
||||
}
|
||||
|
||||
fn calculate_total_memory(&self) -> u64 {
|
||||
let mut total_memory = 0;
|
||||
for (.., chunk) in self.chunks.iter() {
|
||||
total_memory += chunk.storage.size() as u64;
|
||||
}
|
||||
total_memory
|
||||
}
|
||||
|
||||
fn find_biggest_free_slice(&self) -> usize {
|
||||
let mut current_biggest_slice: usize = 0;
|
||||
for (.., slice) in self.slices.iter() {
|
||||
if slice.handle.is_free() {
|
||||
current_biggest_slice =
|
||||
std::cmp::max(current_biggest_slice, slice.effective_size());
|
||||
}
|
||||
}
|
||||
current_biggest_slice
|
||||
}
|
||||
|
||||
// Merge all contiguous free_slices together, assumes that slices are in sorted order.
|
||||
fn defragmentation(&mut self) {
|
||||
let mut chunk_to_merged_slice: HashMap<ChunkId, Vec<Merging>> = HashMap::new();
|
||||
|
@ -555,13 +627,17 @@ impl<Storage: ComputeStorage> DynamicMemoryManagement<Storage> {
|
|||
0
|
||||
}
|
||||
}
|
||||
|
||||
fn round_up_chunk(size: usize) -> usize {
|
||||
return ((size + (CHUNK_ROUNDING - 1)) / CHUNK_ROUNDING) * CHUNK_ROUNDING;
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::{
|
||||
memory_management::{MemoryHandle, MemoryManagement},
|
||||
memory_management::{self, MemoryHandle, MemoryManagement},
|
||||
storage::BytesStorage,
|
||||
};
|
||||
|
||||
|
@ -599,6 +675,72 @@ mod tests {
|
|||
assert!(!x.can_mut())
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn when_size_not_big_enough_to_offset_should_alloc_not_rounded_chunk() {
|
||||
let alloc_size = MIN_SIZE_NEEDED_TO_OFFSET - 1;
|
||||
let mut memory_management = DynamicMemoryManagement::new(
|
||||
BytesStorage::default(),
|
||||
MergingStrategy::Never,
|
||||
SliceStrategy::Never,
|
||||
);
|
||||
|
||||
memory_management.alloc(alloc_size);
|
||||
|
||||
assert_eq!(memory_management.chunks.len(), 1);
|
||||
assert_eq!(memory_management.slices.len(), 1);
|
||||
for (.., slice) in memory_management.slices.iter() {
|
||||
assert_eq!(slice.storage.size(), alloc_size);
|
||||
}
|
||||
for (.., chunk) in memory_management.chunks.iter() {
|
||||
assert_eq!(chunk.storage.size(), BUFFER_ALIGNMENT);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn when_alloc_is_too_big_for_second_slice_should_only_have_one_slice() {
|
||||
let alloc_size = CHUNK_ROUNDING - 2;
|
||||
let mut memory_management = DynamicMemoryManagement::new(
|
||||
BytesStorage::default(),
|
||||
MergingStrategy::Never,
|
||||
SliceStrategy::Never,
|
||||
);
|
||||
|
||||
memory_management.alloc(alloc_size);
|
||||
assert_eq!(memory_management.chunks.len(), 1);
|
||||
assert_eq!(memory_management.slices.len(), 1);
|
||||
for (.., slice) in memory_management.slices.iter() {
|
||||
assert_eq!(slice.storage.size(), alloc_size);
|
||||
}
|
||||
for (.., chunk) in memory_management.chunks.iter() {
|
||||
assert_eq!(chunk.storage.size(), CHUNK_ROUNDING);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn when_alloc_is_small_enough_for_second_slice_should_have_two_slice() {
|
||||
let alloc_size = MIN_SIZE_NEEDED_TO_OFFSET;
|
||||
let padded_first_slice_size =
|
||||
DynamicMemoryManagement::<BytesStorage>::calculate_padding(alloc_size) + alloc_size;
|
||||
let mut memory_management = DynamicMemoryManagement::new(
|
||||
BytesStorage::default(),
|
||||
MergingStrategy::Never,
|
||||
SliceStrategy::Never,
|
||||
);
|
||||
|
||||
memory_management.alloc(alloc_size);
|
||||
assert_eq!(memory_management.chunks.len(), 1);
|
||||
assert_eq!(memory_management.slices.len(), 2);
|
||||
for (.., slice) in memory_management.slices.iter() {
|
||||
assert!(
|
||||
slice.storage.size() == alloc_size
|
||||
|| slice.storage.size() == CHUNK_ROUNDING - padded_first_slice_size
|
||||
);
|
||||
}
|
||||
for (.., chunk) in memory_management.chunks.iter() {
|
||||
assert_eq!(chunk.storage.size(), CHUNK_ROUNDING);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn when_non_empty_chunk_exists_and_other_one_created_there_should_be_two() {
|
||||
let mut memory_management = DynamicMemoryManagement::new(
|
||||
|
@ -618,21 +760,22 @@ mod tests {
|
|||
let mut memory_management = DynamicMemoryManagement::new(
|
||||
BytesStorage::default(),
|
||||
MergingStrategy::Never,
|
||||
SliceStrategy::Ratio(0.2),
|
||||
SliceStrategy::MinimumSize(0),
|
||||
);
|
||||
let big_slice_size = 32 * 3;
|
||||
let small_slice_size = 32;
|
||||
const NUMBER_OF_USED_SLICES: usize = 3;
|
||||
const BIG_SLICE_SIZE: usize = CHUNK_ROUNDING - BUFFER_ALIGNMENT * NUMBER_OF_USED_SLICES;
|
||||
|
||||
let big_slice = memory_management.reserve(big_slice_size);
|
||||
drop(big_slice);
|
||||
let _small_slice_1 = memory_management.reserve(small_slice_size);
|
||||
let _small_slice_2 = memory_management.reserve(small_slice_size);
|
||||
let _small_slice_3 = memory_management.reserve(small_slice_size);
|
||||
let first_slice_to_be_dropped = memory_management.reserve(BUFFER_ALIGNMENT);
|
||||
drop(first_slice_to_be_dropped);
|
||||
let mut slice_holder = Vec::new();
|
||||
for _ in 0..NUMBER_OF_USED_SLICES {
|
||||
slice_holder.push(memory_management.reserve(BUFFER_ALIGNMENT));
|
||||
}
|
||||
|
||||
assert_eq!(memory_management.chunks.len(), 1);
|
||||
assert_eq!(memory_management.slices.len(), 3);
|
||||
assert_eq!(memory_management.slices.len(), NUMBER_OF_USED_SLICES + 1);
|
||||
for (.., slice) in memory_management.slices.iter() {
|
||||
assert_eq!(slice.storage.size(), 32);
|
||||
assert!(slice.storage.size() == 32 || slice.storage.size() == BIG_SLICE_SIZE);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -796,14 +939,15 @@ mod tests {
|
|||
let mut memory_management = DynamicMemoryManagement::new(
|
||||
BytesStorage::default(),
|
||||
MergingStrategy::Never,
|
||||
SliceStrategy::Ratio(0.2),
|
||||
SliceStrategy::MinimumSize(0),
|
||||
);
|
||||
let handle = memory_management.reserve(100);
|
||||
core::mem::drop(handle);
|
||||
|
||||
let _slice_1 = memory_management.reserve(30);
|
||||
let _slice_2 = memory_management.reserve(30);
|
||||
let _slice_3 = memory_management.reserve(30);
|
||||
let _slice_1 = memory_management.reserve(BUFFER_ALIGNMENT);
|
||||
let _slice_2 = memory_management.reserve(BUFFER_ALIGNMENT);
|
||||
let _slice_3 = memory_management.reserve(BUFFER_ALIGNMENT);
|
||||
memory_management.defragmentation();
|
||||
|
||||
assert_eq!(memory_management.chunks.len(), 1);
|
||||
assert_eq!(memory_management.slices.len(), 4);
|
||||
|
|
Loading…
Reference in New Issue