Perf/dynamic mm (#1906)

Nathaniel Simard 2024-06-18 08:41:07 -04:00 committed by GitHub
parent 8071b637b8
commit 4f6db974a1
18 changed files with 1285 additions and 864 deletions

View File

@ -97,23 +97,17 @@ where
#[macro_export(local_inner_macros)]
/// Create new memory ID types.
macro_rules! memory_id_type {
($id:ident, $handle:ident, $binding:ident) => {
($id:ident, $handle:ident) => {
/// Memory Handle.
#[derive(Clone, Debug)]
pub struct $handle {
value: $crate::id::HandleRef<$id>,
}
/// Binding of a memory handle.
#[derive(Clone, Debug)]
pub struct $binding {
value: $crate::id::BindingRef<$id>,
}
/// Memory ID.
#[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)]
pub struct $id {
value: usize,
pub(crate) value: usize,
}
impl $handle {
@ -125,12 +119,6 @@ macro_rules! memory_id_type {
}
}
pub(crate) fn binding(self) -> $binding {
$binding {
value: self.value.binding(),
}
}
fn gen_id() -> usize {
static COUNTER: core::sync::atomic::AtomicUsize =
core::sync::atomic::AtomicUsize::new(0);
@ -152,6 +140,30 @@ macro_rules! memory_id_type {
}
}
impl Default for $handle {
fn default() -> Self {
Self::new()
}
}
};
($id:ident, $handle:ident, $binding:ident) => {
memory_id_type!($id, $handle);
/// Binding of a memory handle.
#[derive(Clone, Debug)]
pub struct $binding {
value: $crate::id::BindingRef<$id>,
}
impl $handle {
pub(crate) fn binding(self) -> $binding {
$binding {
value: self.value.binding(),
}
}
}
impl core::ops::Deref for $binding {
type Target = $crate::id::BindingRef<$id>;
@ -159,11 +171,5 @@ macro_rules! memory_id_type {
&self.value
}
}
impl Default for $handle {
fn default() -> Self {
Self::new()
}
}
};
}
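For readers unfamiliar with the pattern, the first macro arm generates an ID type plus a reference-counted handle whose clone count drives can_mut and is_free; the second arm layers a binding type on top of it. Below is a minimal, self-contained sketch of that handle pattern (illustrative only, not part of the commit; the real HandleRef/BindingRef types live in the crate's id module, and the sketch assumes the pool keeps one internal clone of every handle):

use std::sync::Arc;

// Illustrative stand-in for a generated memory ID type.
#[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)]
struct ExampleId {
    value: usize,
}

// Illustrative stand-in for a generated handle: cloning it bumps the reference
// count, which is what `can_mut` and `is_free` inspect.
#[derive(Clone, Debug)]
struct ExampleHandle {
    value: Arc<ExampleId>,
}

impl ExampleHandle {
    fn new(value: usize) -> Self {
        Self {
            value: Arc::new(ExampleId { value }),
        }
    }

    // Mutable access is allowed while at most one user handle exists
    // (one extra clone is assumed to be held internally by the pool).
    fn can_mut(&self) -> bool {
        Arc::strong_count(&self.value) <= 2
    }

    // The memory can be reused once the pool holds the only reference.
    fn is_free(&self) -> bool {
        Arc::strong_count(&self.value) <= 1
    }
}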

View File

@ -27,14 +27,14 @@ pub trait MemoryManagement<Storage: ComputeStorage>: Send + core::fmt::Debug {
fn get(&mut self, binding: Self::Binding) -> Storage::Resource;
/// Finds a spot in memory for a resource with the given size in bytes, and returns a handle to it
fn reserve(&mut self, size: usize) -> Self::Handle;
fn reserve<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle;
/// Bypass the memory allocation algorithm to allocate data directly.
///
/// # Notes
///
/// Can be useful for servers that want specific control over memory.
fn alloc(&mut self, size: usize) -> Self::Handle;
fn alloc<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle;
/// Bypass the memory allocation algorithm to deallocate data directly.
///
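
The new Sync: FnOnce() parameter lets the memory manager ask its caller to synchronize the device before memory gets reorganized (for example before slices are copied to a new chunk). A hypothetical caller, not part of this commit, could wrap it like this:

// Hypothetical helper showing how a backend passes a synchronization closure to
// the new `reserve` signature; the closure body is backend specific.
fn reserve_with_sync<Storage, MM>(memory: &mut MM, size: usize) -> MM::Handle
where
    Storage: ComputeStorage,
    MM: MemoryManagement<Storage>,
{
    memory.reserve(size, || {
        // Wait for in-flight device work so buffers can safely be moved or reused.
    })
}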

View File

@ -1,165 +1,46 @@
use crate::{
memory_id_type,
storage::{ComputeStorage, StorageHandle, StorageUtilization},
use super::memory_pool::{
MemoryExtensionStrategy, MemoryPool, MemoryPoolBinding, MemoryPoolHandle, RoundingStrategy,
};
use alloc::vec::Vec;
use hashbrown::HashMap;
use crate::storage::ComputeStorage;
#[cfg(all(not(target_family = "wasm"), feature = "std"))]
use std::time;
#[cfg(all(target_family = "wasm", feature = "std"))]
use web_time as time;
use super::{MemoryBinding, MemoryHandle, MemoryManagement};
// The ChunkId keeps track of how many references there are to a specific chunk.
memory_id_type!(ChunkId, ChunkHandle, ChunkBinding);
// The SliceId keeps track of how many references there are to a specific slice.
memory_id_type!(SliceId, SliceHandle, SliceBinding);
/// A tensor memory handle, referring to either a chunk or a slice.
#[derive(Debug, Clone)]
pub enum DynamicHandle {
/// A whole chunk of memory.
Chunk(ChunkHandle),
/// A slice of a chunk of memory.
Slice(SliceHandle),
}
/// Binding of the [dynamic handle](DynamicHandle).
#[derive(Debug, Clone)]
pub enum DynamicBinding {
/// Binding of the [chunk handle](ChunkHandle).
Chunk(ChunkBinding),
/// Binding of the [slice handle](SliceHandle)
Slice(SliceBinding),
}
/// The strategy defines the frequency at which merging of free slices (defragmentation) occurs
#[derive(Debug)]
pub enum MergingStrategy {
/// Once every n calls to reserve.
PeriodTick {
/// Number of calls to be executed before triggering the defragmentation.
period: usize,
/// Current state. Should start at zero.
state: usize,
},
#[cfg(feature = "std")]
/// Once every period of time
PeriodTime {
/// Amount of time before triggering the defragmentation.
period: time::Duration,
/// Current state. Should start at now.
state: time::Instant,
},
/// Never defragment.
Never,
}
/// The strategy defines when to reuse a chunk with slices.
#[derive(Debug)]
pub enum SliceStrategy {
/// Never use slices.
Never,
/// Ratio needed before the chunk can be used as a slice. Between 0 and 1.
Ratio(f32),
/// When the reserved memory is at least {} bytes.
MinimumSize(usize),
/// When the reserved memory is less than {} bytes.
MaximumSize(usize),
}
impl SliceStrategy {
/// If the chunk can be used with a slice.
pub fn can_use_chunk(&self, chunk_size: usize, reserved_size: usize) -> bool {
if chunk_size < reserved_size {
return false;
}
match self {
SliceStrategy::Never => false,
SliceStrategy::Ratio(ratio) => (reserved_size as f32 / chunk_size as f32) >= *ratio,
SliceStrategy::MinimumSize(bytes) => reserved_size >= *bytes,
SliceStrategy::MaximumSize(bytes) => reserved_size <= *bytes,
}
}
}
impl MergingStrategy {
/// Create a new strategy with the given period.
pub fn new_period_tick(period: usize) -> Self {
MergingStrategy::PeriodTick { period, state: 0 }
}
fn should_perform_defragmentation(&mut self) -> bool {
match self {
MergingStrategy::PeriodTick { period, state } => {
*state = (*state + 1) % *period;
*state == 0
}
#[cfg(feature = "std")]
MergingStrategy::PeriodTime { period, state } => {
if &state.elapsed() > period {
*state = time::Instant::now();
true
} else {
false
}
}
MergingStrategy::Never => false,
}
}
}
#[derive(new, Debug)]
struct Chunk {
storage: StorageHandle,
handle: ChunkHandle,
slices: Vec<SliceId>,
}
#[derive(new, Debug)]
struct Slice {
storage: StorageHandle,
handle: SliceHandle,
chunk: ChunkHandle,
padding: usize,
}
#[derive(Debug)]
struct Merging {
start: usize,
end: usize,
offset: usize,
size: usize,
slice_ids: Vec<SliceId>,
}
impl Slice {
pub fn effective_size(&self) -> usize {
self.storage.size() + self.padding
}
}
const MIN_SIZE_NEEDED_TO_OFFSET: usize = 16;
const BUFFER_ALIGNMENT: usize = 32;
use super::MemoryManagement;
/// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks.
pub struct DynamicMemoryManagement<Storage> {
chunks: HashMap<ChunkId, Chunk>,
slices: HashMap<SliceId, Slice>,
merging_strategy: MergingStrategy,
slice_strategy: SliceStrategy,
small_memory_pool: MemoryPool,
main_memory_pool: MemoryPool,
storage: Storage,
}
impl<Storage: ComputeStorage> DynamicMemoryManagement<Storage> {
/// Creates a new instance using the given storage.
pub fn new(storage: Storage) -> Self {
let main_memory_pool = MemoryPool::new(
MemoryExtensionStrategy::new_period_tick(10),
RoundingStrategy::RoundUp,
1024 * 1024 * 1024 * 2,
true,
);
let small_memory_pool = MemoryPool::new(
MemoryExtensionStrategy::Never,
RoundingStrategy::None,
1024 * 1024 * 512,
false,
);
Self {
main_memory_pool,
small_memory_pool,
storage,
}
}
}
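A rough usage sketch of the new two-pool manager (illustrative only, with BytesStorage from the tests and the MemoryManagement/MemoryHandle traits in scope; the empty closure is the sync callback, which a CPU storage does not need):

let mut memory_management = DynamicMemoryManagement::new(BytesStorage::default());
// 1024 bytes is above the 512-byte threshold, so this request goes to the main pool.
let handle = memory_management.reserve(1024, || {});
let resource = memory_management.get(handle.binding());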
impl<Storage> core::fmt::Debug for DynamicMemoryManagement<Storage> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
f.write_str(
alloc::format!(
"DynamicMemoryManagement {:?} - {:?}",
self.merging_strategy,
"DynamicMemoryManagement {:?}",
core::any::type_name::<Storage>(),
)
.as_str(),
@ -167,677 +48,44 @@ impl<Storage> core::fmt::Debug for DynamicMemoryManagement<Storage> {
}
}
impl MemoryBinding for DynamicBinding {}
impl MemoryHandle<DynamicBinding> for DynamicHandle {
fn can_mut(&self) -> bool {
match &self {
DynamicHandle::Chunk(id) => id.can_mut(),
DynamicHandle::Slice(id) => id.can_mut(),
}
}
fn binding(self) -> DynamicBinding {
match self {
Self::Chunk(handle) => DynamicBinding::Chunk(handle.binding()),
Self::Slice(handle) => DynamicBinding::Slice(handle.binding()),
}
}
}
impl<Storage: ComputeStorage> MemoryManagement<Storage> for DynamicMemoryManagement<Storage> {
type Handle = DynamicHandle;
type Binding = DynamicBinding;
type Handle = MemoryPoolHandle;
type Binding = MemoryPoolBinding;
/// Returns the resource from the storage, for the specified handle.
fn get(&mut self, binding: Self::Binding) -> Storage::Resource {
let storage = match binding {
DynamicBinding::Chunk(chunk) => {
&self
.chunks
.get(chunk.id())
.expect("Storage found for the given execution buffer handle")
.storage
}
DynamicBinding::Slice(slice) => {
&self
.slices
.get(slice.id())
.expect("Storage found for the given execution buffer handle")
.storage
}
};
self.storage.get(storage)
if let Some(handle) = self.small_memory_pool.get(&mut self.storage, &binding) {
return handle;
}
/// Reserves memory of the specified size using the reserve algorithm, and returns
/// a handle to the reserved memory.
///
/// Also cleans up, merging free slices together if permitted by the merging strategy.
fn reserve(&mut self, size: usize) -> Self::Handle {
let handle = self.reserve_algorithm(size);
if self.merging_strategy.should_perform_defragmentation() {
self.defragmentation();
if let Some(handle) = self.main_memory_pool.get(&mut self.storage, &binding) {
return handle;
}
handle
panic!("No handle found in the small and main memory pool");
}
fn alloc(&mut self, size: usize) -> Self::Handle {
let handle_chunk = self.create_chunk(size);
let chunk_id = *handle_chunk.id();
let slice = self.create_slice(0, size, handle_chunk);
let handle_slice = slice.handle.clone();
let chunk = self.chunks.get_mut(&chunk_id).unwrap();
chunk.slices.push(*handle_slice.id());
self.slices.insert(*handle_slice.id(), slice);
DynamicHandle::Slice(handle_slice)
}
fn dealloc(&mut self, binding: Self::Binding) {
match binding {
DynamicBinding::Chunk(chunk) => {
if let Some(chunk) = self.chunks.remove(chunk.id()) {
self.storage.dealloc(chunk.storage.id);
fn reserve<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
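// Requests below 512 bytes are served by the small pool (no rounding, never reorganized);
// everything else goes to the main pool.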
if size < 512 {
self.small_memory_pool
.reserve(&mut self.storage, size, sync)
} else {
self.main_memory_pool.reserve(&mut self.storage, size, sync)
}
}
DynamicBinding::Slice(_) => panic!("Can't dealloc slice manually"),
fn alloc<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
if size < 512 {
self.small_memory_pool.alloc(&mut self.storage, size, sync)
} else {
self.main_memory_pool.alloc(&mut self.storage, size, sync)
}
}
fn dealloc(&mut self, _binding: Self::Binding) {
// Can't dealloc slices.
}
fn storage(&mut self) -> &mut Storage {
&mut self.storage
}
}
impl<Storage: ComputeStorage> DynamicMemoryManagement<Storage> {
/// Creates a new instance using the given storage, merging strategy and slice strategy.
pub fn new(
storage: Storage,
merging_strategy: MergingStrategy,
slice_strategy: SliceStrategy,
) -> Self {
Self {
chunks: HashMap::new(),
slices: HashMap::new(),
merging_strategy,
slice_strategy,
storage,
}
}
fn reserve_algorithm(&mut self, size: usize) -> DynamicHandle {
// Looks for a large enough, existing but unused slice of memory.
let slice = self.get_free_slice(size);
match slice {
Some(slice) => DynamicHandle::Slice(slice.clone()),
None => self.alloc(size),
}
}
fn find_free_slice_best_fit(
&self,
size: usize,
effective_size: usize,
) -> Option<(SliceId, usize)> {
let mut size_diff_current = usize::MAX;
let mut found = None;
for (__, chunk) in self.chunks.iter() {
if size < MIN_SIZE_NEEDED_TO_OFFSET && chunk.slices.len() > 1 {
continue;
}
if !self
.slice_strategy
.can_use_chunk(chunk.storage.size(), effective_size)
{
continue;
}
for slice_id in chunk.slices.iter() {
let slice = self.slices.get(slice_id).unwrap();
let slice_can_be_reused =
slice.handle.is_free() && slice.effective_size() >= effective_size;
if slice_can_be_reused {
let size_diff = slice.effective_size() - effective_size;
if size_diff < size_diff_current {
size_diff_current = size_diff;
found = Some((*slice_id, size_diff));
}
}
}
}
found
}
/// Tries to split a slice in two, with the first slice being of the specified size.
/// If there is not enough space for two slices, only one slice is used.
/// Returns the handle of the first slice.
fn split_slice_in_two(
&mut self,
slice_to_split_id: &SliceId,
first_slice_size: usize,
) -> Option<SliceHandle> {
let slice_to_split = self.slices.get(slice_to_split_id).unwrap();
let slice_to_split_effective_size = slice_to_split.effective_size();
let chunk = self.chunks.get_mut(slice_to_split.chunk.id()).unwrap();
let current_slice_chunk_handle = chunk.handle.clone();
let mut slices = Vec::with_capacity(chunk.slices.len() + 1);
let mut offset = 0;
let mut slices_old = Vec::new();
core::mem::swap(&mut slices_old, &mut chunk.slices);
let mut handle = None;
for slice_id in slices_old.into_iter() {
// Assumes that all slices are contiguous in a chunk.
let slice = self.slices.get(&slice_id).unwrap();
if slice_id != *slice_to_split_id {
slices.push(slice_id);
offset += slice.effective_size();
} else {
let first_slice =
self.create_slice(offset, first_slice_size, current_slice_chunk_handle.clone());
let first_slice_id = *first_slice.handle.id();
offset += first_slice.effective_size();
let second_slice_size =
slice_to_split_effective_size - first_slice.effective_size();
let slice_end = self.create_slice(
offset,
second_slice_size,
current_slice_chunk_handle.clone(),
);
let slice_end_id = *slice_end.handle.id();
offset += slice_end.effective_size();
let created_offset = first_slice.effective_size() + slice_end.effective_size();
assert_eq!(created_offset, slice.effective_size());
handle = Some(first_slice.handle.clone());
self.slices.insert(first_slice_id, first_slice);
self.slices.insert(slice_end_id, slice_end);
slices.push(first_slice_id);
slices.push(slice_end_id);
}
}
self.slices.remove(slice_to_split_id);
let chunk = self
.chunks
.get_mut(current_slice_chunk_handle.id())
.unwrap();
chunk.slices = slices;
handle
}
/// Finds the smallest free slice that is large enough to fit `size`.
/// Returns the slice's handle, if any.
fn get_free_slice(&mut self, size: usize) -> Option<SliceHandle> {
let padding = Self::calculate_padding(size);
let effective_size = size + padding;
let found = self.find_free_slice_best_fit(size, effective_size);
let (slice_id, size_diff_current) = match found {
Some(val) => val,
None => {
return None;
}
};
// If the size is the same, reuse the slice.
if size_diff_current == 0 {
let slice = self.slices.get_mut(&slice_id).unwrap();
let offset = match slice.storage.utilization {
StorageUtilization::Full(_) => 0,
StorageUtilization::Slice { offset, size: _ } => offset,
};
slice.storage.utilization = StorageUtilization::Slice { offset, size };
slice.padding = padding;
return Some(self.slices.get(&slice_id).unwrap().handle.clone());
}
assert_eq!(size_diff_current % BUFFER_ALIGNMENT, 0);
// split into 2 if needed
let handle = self.split_slice_in_two(&slice_id, size);
if handle.is_none() {
panic!("split should have returned a handle");
}
handle
}
/// Creates a slice of size `size` upon the given chunk with the given offset.
fn create_slice(&self, offset: usize, size: usize, handle_chunk: ChunkHandle) -> Slice {
if offset > 0 && size < MIN_SIZE_NEEDED_TO_OFFSET {
panic!("tried to create slice of size {size} with an offset while the size needs to atleast be of size {MIN_SIZE_NEEDED_TO_OFFSET} for offset support");
}
if offset % BUFFER_ALIGNMENT != 0 {
panic!("slice with offset {offset} needs to be a multiple of {BUFFER_ALIGNMENT}");
}
let chunk = self.chunks.get(handle_chunk.id()).unwrap();
let handle = SliceHandle::new();
let storage = StorageHandle {
id: chunk.storage.id.clone(),
utilization: StorageUtilization::Slice { offset, size },
};
let padding = Self::calculate_padding(size);
Slice::new(storage, handle, chunk.handle.clone(), padding)
}
#[cfg(test)]
fn insert_slice(&mut self, slice: Slice, chunk_id: ChunkId) {
let slice_id = *slice.handle.id();
self.slices.insert(*slice.handle.id(), slice);
let chunk = self.chunks.get_mut(&chunk_id).unwrap();
chunk.slices.push(slice_id);
}
/// Creates a chunk of given size by allocating on the storage.
fn create_chunk(&mut self, size: usize) -> ChunkHandle {
let padding = Self::calculate_padding(size);
let effective_size = size + padding;
let storage = self.storage.alloc(effective_size);
let handle = ChunkHandle::new();
self.chunks.insert(
*handle.id(),
Chunk::new(storage, handle.clone(), Vec::new()),
);
handle
}
// Generates and adds all the merging information of a given chunk to a hash map.
fn generate_mergings(&self, chunk: &Chunk, merging_map: &mut HashMap<ChunkId, Vec<Merging>>) {
let mut to_merge: Vec<Merging> = Vec::new();
let mut start_index: usize = 0;
let mut num_merge = 0;
let mut offset_current = 0;
let mut offset = 0;
let mut slices_ids = Vec::new();
for (i, slice_id) in chunk.slices.iter().enumerate() {
let slice = self.slices.get(slice_id).unwrap();
if slice.handle.is_free() {
slices_ids.push(*slice_id);
num_merge += 1;
offset += slice.effective_size();
continue;
} else if num_merge > 1 {
let mut empty = Vec::new();
core::mem::swap(&mut slices_ids, &mut empty);
let merging = Merging {
start: start_index,
end: start_index + num_merge - 1,
offset: offset_current,
size: offset - offset_current,
slice_ids: empty,
};
to_merge.push(merging);
}
offset += slice.effective_size();
start_index = i + 1;
num_merge = 0;
offset_current = offset;
slices_ids.clear();
}
if !to_merge.is_empty() {
merging_map.insert(*chunk.handle.id(), to_merge);
}
}
// Merges all free slices together using the merging metadata.
fn merge_contiguous_slices(&mut self, chunk_id: ChunkId, mergings: &Vec<Merging>) {
let chunk = self.chunks.get(&chunk_id).unwrap();
let chunk_handle = chunk.handle.clone();
let slices = chunk.slices.clone();
let mut slices_updated = Vec::new();
let mut index = 0;
for merging in mergings {
let slice = self.create_slice(merging.offset, merging.size, chunk_handle.clone());
let slice_id = *slice.handle.id();
self.slices.insert(slice_id, slice);
for i in index..merging.start {
slices_updated.push(*slices.get(i).unwrap());
}
index = merging.end + 1;
slices_updated.push(slice_id);
for slice_id_to_remove in merging.slice_ids.iter() {
self.slices.remove(slice_id_to_remove);
}
}
for i in index..slices.len() {
slices_updated.push(*slices.get(i).unwrap());
}
let chunk = self.chunks.get_mut(&chunk_id).unwrap();
core::mem::swap(&mut chunk.slices, &mut slices_updated);
}
// Merges all contiguous free slices together; assumes that slices are in sorted order.
fn defragmentation(&mut self) {
let mut chunk_to_merged_slice: HashMap<ChunkId, Vec<Merging>> = HashMap::new();
for (.., chunk) in self.chunks.iter() {
self.generate_mergings(chunk, &mut chunk_to_merged_slice);
}
for (chunk_id, mergings) in chunk_to_merged_slice.into_iter() {
self.merge_contiguous_slices(chunk_id, &mergings);
}
}
fn calculate_padding(size: usize) -> usize {
let rem = size % BUFFER_ALIGNMENT;
if rem != 0 {
BUFFER_ALIGNMENT - rem
} else {
0
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
memory_management::{MemoryHandle, MemoryManagement},
storage::BytesStorage,
};
#[test]
fn can_mut_with_single_tensor_reference() {
let mut memory_management = DynamicMemoryManagement::new(
BytesStorage::default(),
MergingStrategy::Never,
SliceStrategy::Never,
);
let chunk_size = 4;
let simple_handle = memory_management.create_chunk(chunk_size);
let x = simple_handle.clone();
core::mem::drop(simple_handle);
assert!(x.can_mut());
}
#[test]
fn two_tensor_references_remove_mutability() {
let mut memory_management = DynamicMemoryManagement::new(
BytesStorage::default(),
MergingStrategy::Never,
SliceStrategy::Never,
);
let chunk_size = 4;
let simple_handle = memory_management.create_chunk(chunk_size);
let x = simple_handle.clone();
assert!(!simple_handle.can_mut());
assert!(!x.can_mut())
}
#[test]
fn when_non_empty_chunk_exists_and_other_one_created_there_should_be_two() {
let mut memory_management = DynamicMemoryManagement::new(
BytesStorage::default(),
MergingStrategy::Never,
SliceStrategy::Never,
);
let chunk_size = 4;
let _chunk_handle = memory_management.reserve(chunk_size);
let _new_handle = memory_management.reserve(chunk_size);
assert_eq!(memory_management.chunks.len(), 2);
}
#[test]
fn when_big_chunk_is_freed_should_be_filled_with_smaller_slices() {
let mut memory_management = DynamicMemoryManagement::new(
BytesStorage::default(),
MergingStrategy::Never,
SliceStrategy::Ratio(0.2),
);
let big_slice_size = 32 * 3;
let small_slice_size = 32;
let big_slice = memory_management.reserve(big_slice_size);
drop(big_slice);
let _small_slice_1 = memory_management.reserve(small_slice_size);
let _small_slice_2 = memory_management.reserve(small_slice_size);
let _small_slice_3 = memory_management.reserve(small_slice_size);
assert_eq!(memory_management.chunks.len(), 1);
assert_eq!(memory_management.slices.len(), 3);
for (.., slice) in memory_management.slices.iter() {
assert_eq!(slice.storage.size(), 32);
}
}
#[test]
fn when_defragmentation_called_if_two_slices_free_should_merge_into_bigger_slice() {
let mut memory_management = DynamicMemoryManagement::new(
BytesStorage::default(),
MergingStrategy::new_period_tick(1),
SliceStrategy::Ratio(0.2),
);
let chunk_handle = memory_management.create_chunk(32 + 32);
let slice = memory_management.create_slice(0, 32 + 32, chunk_handle.clone());
memory_management.insert_slice(slice, *chunk_handle.id());
let _slice_1 = memory_management.reserve(32);
let _slice_2 = memory_management.reserve(32);
assert_eq!(memory_management.chunks.len(), 1);
assert_eq!(memory_management.slices.len(), 2);
for (.., slice) in memory_management.slices.iter() {
assert_eq!(slice.storage.size(), 32);
}
}
#[test]
fn when_defragmentation_called_should_merge_contiguous_free_slice_together() {
let mut memory_management = DynamicMemoryManagement::new(
BytesStorage::default(),
MergingStrategy::new_period_tick(1),
SliceStrategy::Ratio(0.1),
);
// The chunk will be separated into 7 slices: slice 1 free, slices 2-3 not free, slices 4-6 free, and slice 7 not free.
let slice_size = 32;
let num_of_slice = 7;
let chunk_handle = memory_management.create_chunk(slice_size * num_of_slice);
let chunk_id = *chunk_handle.id();
let slice = memory_management.create_slice(0, slice_size * num_of_slice, chunk_handle);
memory_management.insert_slice(slice, chunk_id);
let _slice_1 = memory_management.reserve(slice_size);
let _slice_2 = memory_management.reserve(slice_size);
let _slice_3 = memory_management.reserve(slice_size);
let _slice_4 = memory_management.reserve(slice_size);
let _slice_5 = memory_management.reserve(slice_size);
let _slice_6 = memory_management.reserve(slice_size);
let _slice_7 = memory_management.reserve(slice_size);
drop(_slice_1);
drop(_slice_4);
drop(_slice_5);
drop(_slice_6);
memory_management.defragmentation();
assert_eq!(memory_management.chunks.len(), 1);
assert_eq!(memory_management.slices.len(), 5);
let chunk = memory_management.chunks.get(&chunk_id).unwrap();
let slices = &chunk.slices;
// first slice test
let first_slice_id = slices.first().unwrap();
let first_slice = memory_management.slices.get(first_slice_id).unwrap();
assert!(first_slice.handle.is_free());
assert_eq!(first_slice.storage.size(), slice_size);
assert_eq!(first_slice.storage.offset(), 0);
// second slice test
let first_slice_id = slices.get(1).unwrap();
let first_slice = memory_management.slices.get(first_slice_id).unwrap();
assert!(!first_slice.handle.is_free());
assert_eq!(first_slice.storage.size(), slice_size);
assert_eq!(first_slice.storage.offset(), slice_size);
// third slice test
let first_slice_id = slices.get(2).unwrap();
let first_slice = memory_management.slices.get(first_slice_id).unwrap();
assert!(!first_slice.handle.is_free());
assert_eq!(first_slice.storage.size(), slice_size);
assert_eq!(first_slice.storage.offset(), slice_size * 2);
// fourth slice test (456 merged)
let first_slice_id = slices.get(3).unwrap();
let first_slice = memory_management.slices.get(first_slice_id).unwrap();
assert!(first_slice.handle.is_free());
assert_eq!(first_slice.storage.size(), slice_size * 3);
assert_eq!(first_slice.storage.offset(), slice_size * 3);
// fifth slice test
let first_slice_id = slices.get(4).unwrap();
let first_slice = memory_management.slices.get(first_slice_id).unwrap();
assert!(!first_slice.handle.is_free());
assert_eq!(first_slice.storage.size(), slice_size);
assert_eq!(first_slice.storage.offset(), slice_size * 6);
}
#[test]
fn never_dealloc_strategy_never_deallocs() {
let mut never_dealloc = MergingStrategy::Never;
for _ in 0..20 {
assert!(!never_dealloc.should_perform_defragmentation())
}
}
#[test]
fn period_tick_dealloc_strategy_should_dealloc_after_period() {
let period = 3;
let mut period_tick_dealloc = MergingStrategy::new_period_tick(period);
for _ in 0..3 {
for _ in 0..period - 1 {
assert!(!period_tick_dealloc.should_perform_defragmentation());
}
assert!(period_tick_dealloc.should_perform_defragmentation());
}
}
#[test]
fn slice_strategy_minimum_bytes() {
let strategy = SliceStrategy::MinimumSize(100);
assert!(strategy.can_use_chunk(200, 101));
assert!(!strategy.can_use_chunk(200, 99));
}
#[test]
fn slice_strategy_maximum_bytes() {
let strategy = SliceStrategy::MaximumSize(100);
assert!(strategy.can_use_chunk(200, 99));
assert!(!strategy.can_use_chunk(200, 101));
}
#[test]
fn slice_strategy_ratio() {
let strategy = SliceStrategy::Ratio(0.9);
assert!(strategy.can_use_chunk(200, 180));
assert!(!strategy.can_use_chunk(200, 179));
}
#[test]
fn test_handle_mutability() {
let mut memory_management = DynamicMemoryManagement::new(
BytesStorage::default(),
MergingStrategy::Never,
SliceStrategy::Ratio(0.5),
);
let handle = memory_management.reserve(10);
let other_ref = handle.clone();
assert!(!handle.can_mut(), "Handle can't be mut when multiple ref.");
drop(other_ref);
assert!(handle.can_mut(), "Handle should be mut when only one ref.");
}
#[test]
fn support_multiple_slices_for_a_chunk() {
let mut memory_management = DynamicMemoryManagement::new(
BytesStorage::default(),
MergingStrategy::Never,
SliceStrategy::Ratio(0.2),
);
let handle = memory_management.reserve(100);
core::mem::drop(handle);
let _slice_1 = memory_management.reserve(30);
let _slice_2 = memory_management.reserve(30);
let _slice_3 = memory_management.reserve(30);
assert_eq!(memory_management.chunks.len(), 1);
assert_eq!(memory_management.slices.len(), 4);
}
#[test]
fn test_slice_mutability() {
let mut memory_management = DynamicMemoryManagement::new(
BytesStorage::default(),
MergingStrategy::Never,
SliceStrategy::Ratio(0.5),
);
let first_slice = memory_management.reserve(10);
drop(first_slice);
let slice = memory_management.reserve(8);
if let super::DynamicHandle::Slice(slice) = slice {
let other_ref = slice.clone();
assert!(
!slice.can_mut(),
"Slice can't be mut when multiple ref to the same handle."
);
drop(other_ref);
assert!(
slice.can_mut(),
"Slice should be mut when only one ref to the same handle."
);
assert!(
!slice.is_free(),
"Slice can't be reallocated when one ref still exist."
);
}
}
}

View File

@ -0,0 +1,541 @@
use super::index::SearchIndex;
use super::{
ChunkHandle, ChunkId, MemoryChunk, MemoryPoolBinding, MemoryPoolHandle, MemorySlice,
RingBuffer, SliceHandle, SliceId,
};
use crate::storage::{ComputeStorage, StorageHandle, StorageUtilization};
use alloc::vec::Vec;
use hashbrown::{HashMap, HashSet};
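/// A pool that allocates storage in chunks (sized by the rounding strategy) and hands out
/// slices of those chunks. Free slices are found through a ring buffer, and chunks can
/// periodically be consolidated into bigger ones according to the memory extension strategy.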
pub struct MemoryPool {
chunks: HashMap<ChunkId, Chunk>,
slices: HashMap<SliceId, Slice>,
memory_extension_strategy: MemoryExtensionStrategy,
rounding: RoundingStrategy,
chunk_index: SearchIndex<ChunkId>,
max_chunk_size: usize,
ring: RingBuffer<Chunk, Slice>,
debug: bool,
}
#[derive(new, Debug)]
pub struct Chunk {
pub storage: StorageHandle,
pub handle: ChunkHandle,
pub slices: Vec<SliceId>,
}
#[derive(new, Debug)]
pub struct Slice {
pub storage: StorageHandle,
pub handle: SliceHandle,
pub chunk: ChunkHandle,
pub padding: usize,
}
impl Slice {
pub fn effective_size(&self) -> usize {
self.storage.size() + self.padding
}
}
const MIN_SIZE_NEEDED_TO_OFFSET: usize = 16;
const BUFFER_ALIGNMENT: usize = 32;
const MB: usize = 1024 * 1024;
pub enum RoundingStrategy {
RoundUp,
None,
}
impl RoundingStrategy {
fn alloc_size(&self, size: usize) -> usize {
match self {
RoundingStrategy::RoundUp => {
if size < BUFFER_ALIGNMENT {
BUFFER_ALIGNMENT
} else if size < MB {
2 * MB
} else if size < 10 * MB {
20 * MB
} else {
let factor = (size + (2 * MB - 1)) / (2 * MB);
factor * 2 * MB
}
}
RoundingStrategy::None => size,
}
}
}
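As a quick sanity check of the rounding rule above, a hypothetical test (not part of the commit) would assert the following:

#[cfg(test)]
mod rounding_examples {
    use super::*;

    // Hypothetical examples of the RoundUp behaviour; the values follow directly from alloc_size above.
    #[test]
    fn round_up_examples() {
        let strategy = RoundingStrategy::RoundUp;
        assert_eq!(strategy.alloc_size(16), BUFFER_ALIGNMENT); // below the alignment
        assert_eq!(strategy.alloc_size(100), 2 * MB); // anything under 1 MiB gets 2 MiB
        assert_eq!(strategy.alloc_size(5 * MB), 20 * MB); // under 10 MiB gets 20 MiB
        assert_eq!(strategy.alloc_size(25 * MB), 26 * MB); // otherwise the next multiple of 2 MiB
    }
}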
/// The strategy defines the frequency at which the pool extends its maximum memory by
/// consolidating existing chunks into bigger ones.
#[derive(Debug)]
pub enum MemoryExtensionStrategy {
/// Once every n allocations.
PeriodTick {
/// Number of allocations to be executed before triggering the extension.
period: usize,
/// Current state. Should start at zero.
state: usize,
},
/// Never extend.
Never,
}
impl MemoryExtensionStrategy {
/// Create a new strategy with the given period.
pub fn new_period_tick(period: usize) -> Self {
MemoryExtensionStrategy::PeriodTick { period, state: 0 }
}
fn should_extend_max_memory(&mut self) -> bool {
match self {
MemoryExtensionStrategy::PeriodTick { period, state } => {
*state = (*state + 1) % *period;
*state == 0
}
MemoryExtensionStrategy::Never => false,
}
}
}
struct SliceUpdate {
slice_id: SliceId,
size: usize,
}
impl MemoryPool {
pub fn new(
merging_strategy: MemoryExtensionStrategy,
alloc_strategy: RoundingStrategy,
max_chunk_size: usize,
debug: bool,
) -> Self {
Self {
chunks: HashMap::new(),
slices: HashMap::new(),
memory_extension_strategy: merging_strategy,
rounding: alloc_strategy,
max_chunk_size,
chunk_index: SearchIndex::new(),
ring: RingBuffer::new(),
debug,
}
}
/// Returns the resource from the storage, for the specified handle.
pub fn get<Storage: ComputeStorage>(
&mut self,
storage: &mut Storage,
binding: &MemoryPoolBinding,
) -> Option<Storage::Resource> {
self.slices
.get(binding.slice.id())
.map(|s| &s.storage)
.map(|h| storage.get(h))
}
/// Reserves memory of the specified size using the reserve algorithm, and returns
/// a handle to the reserved memory.
///
/// Also cleans up, merging free slices together if permitted by the merging strategy.
pub fn reserve<Storage: ComputeStorage, Sync: FnOnce()>(
&mut self,
storage: &mut Storage,
size: usize,
sync: Sync,
) -> MemoryPoolHandle {
// Looks for a large enough, existing but unused slice of memory.
let slice = self.get_free_slice(size);
match slice {
Some(slice) => MemoryPoolHandle {
slice: slice.clone(),
},
None => self.alloc(storage, size, sync),
}
}
pub fn alloc<Storage: ComputeStorage, Sync: FnOnce()>(
&mut self,
storage: &mut Storage,
size: usize,
sync: Sync,
) -> MemoryPoolHandle {
if self.memory_extension_strategy.should_extend_max_memory() {
sync();
self.extend_max_memory(storage, size);
if let Some(handle) = self.get_free_slice(size) {
return MemoryPoolHandle { slice: handle };
}
}
let alloc_size = self.rounding.alloc_size(size);
self.alloc_slice(storage, alloc_size, size)
}
fn alloc_slice<Storage: ComputeStorage>(
&mut self,
storage: &mut Storage,
alloc_size: usize,
slice_size: usize,
) -> MemoryPoolHandle {
let handle_chunk = self.create_chunk(storage, alloc_size);
let chunk_id = *handle_chunk.id();
let (slice, extra_slice) =
self.allocate_slices(handle_chunk.clone(), alloc_size, slice_size);
let handle_slice = slice.handle.clone();
self.update_chunk_metadata(chunk_id, slice, extra_slice);
MemoryPoolHandle {
slice: handle_slice,
}
}
fn allocate_slices(
&self,
handle_chunk: ChunkHandle,
alloc_size: usize,
slice_size: usize,
) -> (Slice, Option<Slice>) {
let slice = self.create_slice(0, slice_size, handle_chunk.clone());
let effective_size = slice.effective_size();
let extra_slice = if effective_size < alloc_size {
Some(self.create_slice(effective_size, alloc_size - effective_size, handle_chunk))
} else {
None
};
(slice, extra_slice)
}
fn update_chunk_metadata(
&mut self,
chunk_id: ChunkId,
slice: Slice,
extra_slice: Option<Slice>,
) {
let slice_id = *slice.handle.id();
self.slices.insert(slice_id, slice);
self.chunks
.get_mut(&chunk_id)
.unwrap()
.slices
.push(slice_id);
if let Some(extra_slice) = extra_slice {
let extra_slice_id = *extra_slice.handle.id();
self.slices.insert(extra_slice_id, extra_slice);
self.chunks
.get_mut(&chunk_id)
.unwrap()
.slices
.push(extra_slice_id);
}
}
#[allow(unused)]
fn display_memory_usage(&self) {
let total_memory_usage: f64 = self
.chunks
.values()
.map(|chunk| chunk.storage.size() as f64)
.sum();
let effective_memory_usage: f64 = self
.slices
.values()
.filter(|slice| slice.handle.is_free())
.map(|slice| slice.storage.size() as f64)
.sum();
let ratio = 100.0 * effective_memory_usage / total_memory_usage;
log::info!("the memory usage is {ratio}");
}
/// Finds a free slice that can contain the given size.
/// Returns the handle of the slice, if any.
fn get_free_slice(&mut self, size: usize) -> Option<SliceHandle> {
if size < MIN_SIZE_NEEDED_TO_OFFSET {
return None;
}
let padding = calculate_padding(size);
let effective_size = size + padding;
let slice_id =
self.ring
.find_free_slice(effective_size, &mut self.chunks, &mut self.slices);
let slice_id = match slice_id {
Some(val) => val,
None => return None,
};
let slice = self.slices.get_mut(&slice_id).unwrap();
let offset = match slice.storage.utilization {
StorageUtilization::Full(_) => 0,
StorageUtilization::Slice { offset, size: _ } => offset,
};
slice.storage.utilization = StorageUtilization::Slice { offset, size };
slice.padding = padding;
Some(slice.handle.clone())
}
/// Creates a slice of size `size` upon the given chunk with the given offset.
fn create_slice(&self, offset: usize, size: usize, handle_chunk: ChunkHandle) -> Slice {
assert_eq!(
offset % BUFFER_ALIGNMENT,
0,
"slice with offset {offset} needs to be a multiple of {BUFFER_ALIGNMENT}"
);
if offset > 0 && size < MIN_SIZE_NEEDED_TO_OFFSET {
panic!("tried to create slice of size {size} with an offset while the size needs to atleast be of size {MIN_SIZE_NEEDED_TO_OFFSET} for offset support");
}
let chunk = self.chunks.get(handle_chunk.id()).unwrap();
let handle = SliceHandle::new();
let storage = StorageHandle {
id: chunk.storage.id.clone(),
utilization: StorageUtilization::Slice { offset, size },
};
let padding = calculate_padding(size);
Slice::new(storage, handle, chunk.handle.clone(), padding)
}
/// Creates a chunk of given size by allocating on the storage.
fn create_chunk<Storage: ComputeStorage>(
&mut self,
storage: &mut Storage,
size: usize,
) -> ChunkHandle {
let padding = calculate_padding(size);
let effective_size = size + padding;
let storage = storage.alloc(effective_size);
let handle = ChunkHandle::new();
let id = *handle.id();
self.ring.push_chunk(id);
self.chunks
.insert(id, Chunk::new(storage, handle.clone(), Vec::new()));
self.chunk_index.insert(id, size);
handle
}
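// Consolidates the slices of every chunk smaller than `max_chunk_size` into freshly allocated,
// bigger chunks, copying the data over and deallocating the old chunks. The `sync` callback runs
// before this (see `alloc`) so that in-flight work on the old buffers has completed.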
fn extend_max_memory<Storage: ComputeStorage>(&mut self, storage: &mut Storage, size: usize) {
if self.debug {
log::info!("Extend max memory ...");
}
let mut slices = Vec::<SliceUpdate>::new();
let mut current_size = size;
let chunks_sorted = self
.chunk_index
.find_by_size(0..self.max_chunk_size - 1)
.map(Clone::clone)
.collect::<Vec<_>>();
let mut deallocations = HashSet::<ChunkId>::new();
for chunk_id in chunks_sorted {
let chunk = self.chunks.get(&chunk_id).unwrap();
let chunk_id = *chunk.handle.id();
let slices_ids = chunk.slices.clone();
for slice_id in slices_ids {
let slice = self.slices.get(&slice_id).unwrap();
let size = slice.storage.size();
let effective_size = slice.effective_size();
current_size += effective_size;
if current_size >= self.max_chunk_size {
let alloc_size = current_size - effective_size;
// let alloc_size = self.max_chunk_size;
self.move_to_new_chunk(alloc_size, storage, &mut slices, &mut deallocations);
current_size = effective_size;
}
slices.push(SliceUpdate { slice_id, size });
}
deallocations.insert(chunk_id);
}
if !slices.is_empty() {
self.move_to_new_chunk(current_size, storage, &mut slices, &mut deallocations);
} else {
self.deallocate(storage, &mut deallocations);
}
}
fn deallocate<Storage: ComputeStorage>(
&mut self,
storage: &mut Storage,
deallocations: &mut HashSet<ChunkId>,
) {
for id in deallocations.drain() {
let mut chunk = self.chunks.remove(&id).unwrap();
self.ring.remove_chunk(id);
for slice in chunk.slices.drain(..) {
let slice = self.slices.get(&slice).unwrap();
let chunk_id = *slice.chunk.id();
assert_ne!(chunk_id, id, "Chunk id should be updated");
}
self.chunk_index.remove(&id);
storage.dealloc(chunk.storage.id);
}
}
fn move_to_new_chunk<Storage: ComputeStorage>(
&mut self,
alloc_size: usize,
storage: &mut Storage,
slices: &mut Vec<SliceUpdate>,
deallocations: &mut HashSet<ChunkId>,
) {
let chunk = self.create_chunk(storage, alloc_size);
let storage_id = self.chunks.get(chunk.id()).unwrap().storage.id.clone();
let mut offset = 0;
let mut slices_ids = Vec::new();
for update in slices.drain(..) {
let slice_id = update.slice_id;
let slice = self.slices.get_mut(&slice_id).unwrap();
let old_storage = slice.storage.clone();
slice.chunk = chunk.clone();
slice.storage = StorageHandle {
id: storage_id.clone(),
utilization: StorageUtilization::Slice {
offset,
size: update.size,
},
};
storage.copy(&old_storage, &slice.storage);
slices_ids.push(slice_id);
offset += slice.effective_size();
}
let chunk = self.chunks.get_mut(chunk.id()).unwrap();
chunk.slices = slices_ids;
self.deallocate(storage, deallocations);
}
}
fn calculate_padding(size: usize) -> usize {
let remainder = size % BUFFER_ALIGNMENT;
if remainder != 0 {
BUFFER_ALIGNMENT - remainder
} else {
0
}
}
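// For example, calculate_padding(100) == 28, so a 100-byte slice occupies 128 bytes (4 * BUFFER_ALIGNMENT).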
impl MemorySlice for Slice {
fn is_free(&self) -> bool {
self.handle.is_free()
}
fn size(&self) -> usize {
self.effective_size()
}
fn split(&mut self, offset_slice: usize) -> Self {
let size_new = self.effective_size() - offset_slice;
let offset_new = self.storage.offset() + offset_slice;
let storage_new = StorageHandle {
id: self.storage.id.clone(),
utilization: StorageUtilization::Slice {
offset: offset_new,
size: size_new,
},
};
self.storage.utilization = StorageUtilization::Slice {
offset: self.storage.offset(),
size: offset_slice,
};
if offset_new > 0 && size_new < MIN_SIZE_NEEDED_TO_OFFSET {
panic!("tried to create slice of size {size_new} with an offset while the size needs to atleast be of size {MIN_SIZE_NEEDED_TO_OFFSET} for offset support");
}
if offset_new % BUFFER_ALIGNMENT != 0 {
panic!("slice with offset {offset_new} needs to be a multiple of {BUFFER_ALIGNMENT}");
}
let handle = SliceHandle::new();
assert!(
size_new >= BUFFER_ALIGNMENT,
"Size new > {BUFFER_ALIGNMENT}"
);
let padding = calculate_padding(size_new - BUFFER_ALIGNMENT);
Slice::new(storage_new, handle, self.chunk.clone(), padding)
}
fn id(&self) -> SliceId {
*self.handle.id()
}
}
impl MemoryChunk<Slice> for Chunk {
fn merge_next_slice(
&mut self,
from_slice_index: usize,
slices: &mut HashMap<SliceId, Slice>,
) -> bool {
let slice_id_current = self.slices.get(from_slice_index).unwrap();
let slice_id_next = self.slices.get(from_slice_index + 1);
let slice_id_next = match slice_id_next {
Some(val) => val,
None => return false,
};
let slice_next = slices.get(slice_id_next).unwrap();
let is_free = slice_next.is_free();
let size = slice_next.effective_size();
let slice_current = slices.get_mut(slice_id_current).unwrap();
if is_free {
slice_current.storage.utilization = StorageUtilization::Slice {
size: slice_current.effective_size() + size,
offset: slice_current.storage.offset(),
};
slices.remove(slice_id_next);
self.slices.remove(from_slice_index + 1);
return true;
}
false
}
fn slice(&self, index: usize) -> Option<SliceId> {
self.slices.get(index).copied()
}
fn insert_slice(&mut self, position: usize, slice_id: SliceId) {
self.slices.insert(position, slice_id);
}
}

View File

@ -0,0 +1,33 @@
use crate::memory_id_type;
use crate::memory_management::{MemoryBinding, MemoryHandle};
// The ChunkId keeps track of how many references there are to a specific chunk.
memory_id_type!(ChunkId, ChunkHandle);
// The SliceId keeps track of how many references there are to a specific slice.
memory_id_type!(SliceId, SliceHandle, SliceBinding);
/// A tensor memory handle, referring to a slice of a memory chunk.
#[derive(Debug, Clone)]
pub struct MemoryPoolHandle {
pub slice: SliceHandle,
}
/// Binding of the [memory pool handle](MemoryPoolHandle).
#[derive(Debug, Clone)]
pub struct MemoryPoolBinding {
pub slice: SliceBinding,
}
impl MemoryBinding for MemoryPoolBinding {}
impl MemoryHandle<MemoryPoolBinding> for MemoryPoolHandle {
fn can_mut(&self) -> bool {
self.slice.can_mut()
}
fn binding(self) -> MemoryPoolBinding {
MemoryPoolBinding {
slice: self.slice.binding(),
}
}
}

View File

@ -0,0 +1,64 @@
use alloc::collections::BTreeMap;
use alloc::vec;
use alloc::vec::Vec;
use core::hash::Hash;
use hashbrown::HashMap;
/// Data structure that allows searching for items by size efficiently.
pub struct SearchIndex<T> {
items_per_size: BTreeMap<usize, Vec<T>>,
sizes_per_item: HashMap<T, usize>,
}
impl<T: PartialEq + Eq + Hash + Clone> SearchIndex<T> {
/// Create a new item search index.
pub fn new() -> Self {
Self {
items_per_size: BTreeMap::new(),
sizes_per_item: HashMap::new(),
}
}
/// Insert a new sized item into the search index.
pub fn insert(&mut self, item: T, size: usize) {
self.remove(&item);
if let Some(values) = self.items_per_size.get_mut(&size) {
values.push(item.clone())
} else {
self.items_per_size.insert(size, vec![item.clone()]);
}
self.sizes_per_item.insert(item, size);
}
/// Find the item by size range.
pub fn find_by_size(
&self,
range: core::ops::Range<usize>,
) -> impl DoubleEndedIterator<Item = &T> {
self.items_per_size.range(range).flat_map(|a| a.1)
}
/// Remove an item from the index.
pub fn remove(&mut self, item: &T) {
let size = match self.sizes_per_item.remove(item) {
Some(size) => size,
None => return,
};
if let Some(values) = self.items_per_size.get_mut(&size) {
let mut removed_index = None;
for (i, v) in values.iter().enumerate() {
if v == item {
removed_index = Some(i);
break;
}
}
if let Some(index) = removed_index {
values.remove(index);
}
}
}
}
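An illustrative usage sketch (not part of the commit), indexing plain integers by size:

// Illustrative only: index three items by size and query them back in size order.
let mut index = SearchIndex::<u32>::new();
index.insert(1, 1024);
index.insert(2, 4096);
index.insert(3, 2048);
// `find_by_size` walks a BTreeMap range, so results come back smallest first.
let found: Vec<&u32> = index.find_by_size(0..4096).collect();
assert_eq!(found, vec![&1, &3]); // 4096 itself is excluded: the range is half-open
index.remove(&3);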

View File

@ -0,0 +1,9 @@
pub(crate) mod index;
mod ring;
mod base;
mod handle;
pub use base::*;
pub use handle::*;
pub use ring::*;

View File

@ -0,0 +1,416 @@
use alloc::vec::Vec;
use core::marker::PhantomData;
use hashbrown::HashMap;
use super::{ChunkId, SliceId};
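/// Ring buffer over the chunks of a memory pool, used to search for free slices.
///
/// The cursors remember where the previous search stopped, so successive calls to
/// `find_free_slice` rotate through the chunks (wrapping around at most once) instead
/// of always rescanning from the first chunk, merging adjacent free slices and
/// splitting oversized ones along the way.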
#[derive(Debug)]
pub struct RingBuffer<C: MemoryChunk<S>, S: MemorySlice> {
queue: Vec<ChunkId>,
chunk_positions: HashMap<ChunkId, usize>,
cursor_slice: usize,
cursor_chunk: usize,
_s: PhantomData<S>,
_c: PhantomData<C>,
}
pub trait MemoryChunk<S: MemorySlice> {
fn merge_next_slice(&mut self, slice_position: usize, slices: &mut HashMap<SliceId, S>)
-> bool;
fn slice(&self, index: usize) -> Option<SliceId>;
fn insert_slice(&mut self, position: usize, slice_id: SliceId);
}
pub trait MemorySlice {
fn is_free(&self) -> bool;
fn size(&self) -> usize;
fn split(&mut self, offset: usize) -> Self;
fn id(&self) -> SliceId;
}
impl<C: MemoryChunk<S>, S: MemorySlice> RingBuffer<C, S> {
pub fn new() -> Self {
Self {
queue: Vec::new(),
chunk_positions: HashMap::new(),
cursor_slice: 0,
cursor_chunk: 0,
_s: PhantomData,
_c: PhantomData,
}
}
pub fn push_chunk(&mut self, chunk_id: ChunkId) {
self.queue.push(chunk_id);
self.chunk_positions.insert(chunk_id, self.queue.len() - 1);
}
pub fn remove_chunk(&mut self, chunk_id: ChunkId) {
if let Some(position) = self.chunk_positions.remove(&chunk_id) {
self.queue.remove(position);
}
self.chunk_positions.clear();
for (pos, id) in self.queue.iter().enumerate() {
self.chunk_positions.insert(*id, pos);
}
}
pub fn find_free_slice(
&mut self,
size: usize,
chunks: &mut HashMap<ChunkId, C>,
slices: &mut HashMap<SliceId, S>,
) -> Option<SliceId> {
let max_second = self.cursor_chunk;
let result = self.find_free_slice_in_all_chunks(size, chunks, slices, self.queue.len());
if result.is_some() {
return result;
}
self.cursor_chunk = 0;
self.cursor_slice = 0;
self.find_free_slice_in_all_chunks(size, chunks, slices, max_second)
}
fn find_free_slice_in_chunk(
&mut self,
size: usize,
chunk: &mut C,
slices: &mut HashMap<SliceId, S>,
mut slice_index: usize,
) -> Option<(usize, SliceId)> {
while let Some(slice_id) = chunk.slice(slice_index) {
let slice = slices.get_mut(&slice_id).unwrap();
let is_big_enough = slice.size() >= size;
let is_free = slice.is_free();
if is_big_enough && is_free {
if slice.size() > size {
let new_slice = slice.split(size);
chunk.insert_slice(slice_index + 1, new_slice.id());
slices.insert(new_slice.id(), new_slice);
}
return Some((slice_index, slice_id));
}
if is_free && chunk.merge_next_slice(slice_index, slices) {
continue;
}
slice_index += 1;
}
None
}
fn find_free_slice_in_all_chunks(
&mut self,
size: usize,
chunks: &mut HashMap<ChunkId, C>,
slices: &mut HashMap<SliceId, S>,
max_cursor_position: usize,
) -> Option<SliceId> {
let start = self.cursor_chunk;
let end = usize::min(self.queue.len(), max_cursor_position);
let mut slice_index = self.cursor_slice;
for chunk_index in start..end {
if chunk_index > start {
slice_index = 0;
}
if let Some(id) = self.queue.get(chunk_index) {
let chunk = chunks.get_mut(id).unwrap();
let result = self.find_free_slice_in_chunk(size, chunk, slices, slice_index);
if let Some((cursor_slice, slice)) = result {
self.cursor_slice = cursor_slice + 1;
self.cursor_chunk = chunk_index;
return Some(slice);
}
}
self.cursor_chunk = chunk_index;
self.cursor_slice = 0;
}
None
}
}
#[cfg(test)]
mod tests {
use super::stub::*;
use super::*;
use alloc::vec;
#[test]
fn simple_1() {
let mut ring = RingBuffer::<TestChunk, TestSlice>::new();
let slice_1 = new_slice(0, 100);
let slice_2 = new_slice(1, 200);
let chunk_1 = new_chunk(0, vec![0, 1]);
let mut slices = HashMap::from([(slice_1.id, slice_1), (slice_2.id, slice_2)]);
let mut chunks = HashMap::from([(chunk_1.id, chunk_1)]);
ring.push_chunk(ChunkId { value: 0 });
let slice = ring.find_free_slice(50, &mut chunks, &mut slices).unwrap();
assert_eq!(slice, SliceId { value: 0 });
assert_eq!(slices.get(&slice).unwrap().size, 50);
assert_eq!(slices.len(), 3);
assert_eq!(chunks.values().last().unwrap().slices.len(), 3);
}
#[test]
fn simple_2() {
let mut ring = RingBuffer::<TestChunk, TestSlice>::new();
let slice_1 = new_slice(0, 100);
let slice_2 = new_slice(1, 200);
let chunk_1 = new_chunk(0, vec![0, 1]);
let mut slices = HashMap::from([(slice_1.id, slice_1), (slice_2.id, slice_2)]);
let mut chunks = HashMap::from([(chunk_1.id, chunk_1)]);
ring.push_chunk(ChunkId { value: 0 });
let slice = ring.find_free_slice(150, &mut chunks, &mut slices).unwrap();
assert_eq!(slice, SliceId { value: 0 });
assert_eq!(slices.get(&slice).unwrap().size, 150);
assert_eq!(slices.len(), 2);
assert_eq!(chunks.values().last().unwrap().slices.len(), 2);
}
#[test]
fn multiple_chunks() {
let mut ring = RingBuffer::<TestChunk, TestSlice>::new();
let slice_1 = new_slice(0, 100);
let slice_2 = new_slice(1, 200);
let slice_3 = new_slice(2, 200);
let slice_4 = new_slice(3, 200);
let chunk_1 = new_chunk(0, vec![0, 1]);
let chunk_2 = new_chunk(1, vec![2, 3]);
let mut slices = HashMap::from([
(slice_1.id, slice_1),
(slice_2.id, slice_2),
(slice_3.id, slice_3),
(slice_4.id, slice_4),
]);
let mut chunks = HashMap::from([(chunk_1.id, chunk_1), (chunk_2.id, chunk_2)]);
ring.push_chunk(ChunkId { value: 0 });
ring.push_chunk(ChunkId { value: 1 });
slices.get_mut(&SliceId { value: 0 }).unwrap().is_free = true;
slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = false;
slices.get_mut(&SliceId { value: 3 }).unwrap().is_free = false;
let slice = ring.find_free_slice(200, &mut chunks, &mut slices).unwrap();
assert_eq!(slice, SliceId { value: 2 });
let slice = ring.find_free_slice(100, &mut chunks, &mut slices).unwrap();
assert_eq!(slice, SliceId { value: 0 });
}
#[test]
fn find_free_slice_with_exact_fit() {
let mut ring = RingBuffer::<TestChunk, TestSlice>::new();
let slice_1 = new_slice(0, 100);
let slice_2 = new_slice(1, 200);
let chunk_1 = new_chunk(0, vec![0, 1]);
let mut slices = HashMap::from([(slice_1.id, slice_1), (slice_2.id, slice_2)]);
let mut chunks = HashMap::from([(chunk_1.id, chunk_1)]);
ring.push_chunk(ChunkId { value: 0 });
slices.get_mut(&SliceId { value: 0 }).unwrap().is_free = false;
slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = true;
let slice = ring.find_free_slice(200, &mut chunks, &mut slices).unwrap();
assert_eq!(slice, SliceId { value: 1 });
assert_eq!(slices.get(&slice).unwrap().size, 200);
assert_eq!(slices.len(), 2);
assert_eq!(chunks.values().last().unwrap().slices.len(), 2);
}
#[test]
fn find_free_slice_with_merging() {
let mut ring = RingBuffer::<TestChunk, TestSlice>::new();
let slice_1 = new_slice(0, 100);
let slice_2 = new_slice(1, 50);
let slice_3 = new_slice(2, 100);
let chunk_1 = new_chunk(0, vec![0, 1, 2]);
let mut slices = HashMap::from([
(slice_1.id, slice_1),
(slice_2.id, slice_2),
(slice_3.id, slice_3),
]);
let mut chunks = HashMap::from([(chunk_1.id, chunk_1)]);
ring.push_chunk(ChunkId { value: 0 });
slices.get_mut(&SliceId { value: 0 }).unwrap().is_free = true;
slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = true;
slices.get_mut(&SliceId { value: 2 }).unwrap().is_free = true;
let slice = ring.find_free_slice(250, &mut chunks, &mut slices).unwrap();
assert_eq!(slice, SliceId { value: 0 });
assert_eq!(slices.get(&slice).unwrap().size, 250);
assert_eq!(slices.len(), 1);
assert_eq!(chunks.values().last().unwrap().slices.len(), 1);
}
#[test]
fn find_free_slice_with_multiple_chunks_and_merging() {
let mut ring = RingBuffer::<TestChunk, TestSlice>::new();
let slice_1 = new_slice(0, 50);
let slice_2 = new_slice(1, 50);
let chunk_1 = new_chunk(0, vec![0, 1]);
let slice_3 = new_slice(2, 100);
let slice_4 = new_slice(3, 50);
let chunk_2 = new_chunk(1, vec![2, 3]);
let mut slices = HashMap::from([
(slice_1.id, slice_1),
(slice_2.id, slice_2),
(slice_3.id, slice_3),
(slice_4.id, slice_4),
]);
let mut chunks = HashMap::from([(chunk_1.id, chunk_1), (chunk_2.id, chunk_2)]);
ring.push_chunk(ChunkId { value: 0 });
ring.push_chunk(ChunkId { value: 1 });
slices.get_mut(&SliceId { value: 0 }).unwrap().is_free = true;
slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = true;
slices.get_mut(&SliceId { value: 2 }).unwrap().is_free = true;
slices.get_mut(&SliceId { value: 3 }).unwrap().is_free = true;
let slice = ring.find_free_slice(150, &mut chunks, &mut slices).unwrap();
assert_eq!(slices.get(&slice).unwrap().size, 150);
assert_eq!(slices.len(), 2);
assert_eq!(chunks.values().last().unwrap().slices.len(), 1);
}
fn new_slice(id: usize, size: usize) -> TestSlice {
TestSlice {
id: SliceId { value: id },
is_free: true,
size,
}
}
fn new_chunk(id: usize, slices: Vec<usize>) -> TestChunk {
TestChunk {
id: ChunkId { value: id },
slices: slices.into_iter().map(|i| SliceId { value: i }).collect(),
}
}
}
#[cfg(test)]
mod stub {
use super::*;
use burn_common::*;
#[derive(Debug)]
pub struct TestChunk {
pub id: ChunkId,
pub slices: Vec<SliceId>,
}
#[derive(Debug)]
pub struct TestSlice {
pub id: SliceId,
pub is_free: bool,
pub size: usize,
}
impl MemorySlice for TestSlice {
fn is_free(&self) -> bool {
self.is_free
}
fn size(&self) -> usize {
self.size
}
fn split(&mut self, offset: usize) -> Self {
let size_remained = self.size - offset;
self.size = offset;
Self {
id: SliceId {
value: rand::gen_random(),
},
is_free: true,
size: size_remained,
}
}
fn id(&self) -> SliceId {
self.id
}
}
impl MemoryChunk<TestSlice> for TestChunk {
fn merge_next_slice(
&mut self,
from_slice_index: usize,
slices: &mut HashMap<SliceId, TestSlice>,
) -> bool {
let slice_id_current = self.slices.get(from_slice_index).unwrap();
let slice_id_next = self.slices.get(from_slice_index + 1);
let slice_id_next = match slice_id_next {
Some(val) => val,
None => return false,
};
let slice_next = slices.get(slice_id_next).unwrap();
let is_free = slice_next.is_free;
let size = slice_next.size;
let slice_current = slices.get_mut(slice_id_current).unwrap();
if is_free {
slice_current.size += size;
slices.remove(slice_id_next);
self.slices.remove(from_slice_index + 1);
return true;
}
false
}
fn slice(&self, index: usize) -> Option<SliceId> {
self.slices.get(index).copied()
}
fn insert_slice(&mut self, position: usize, slice_id: SliceId) {
self.slices.insert(position, slice_id);
}
}
}

View File

@ -1,4 +1,7 @@
pub(crate) mod memory_pool;
mod base;
pub use base::*;
/// Dynamic memory management strategy.

View File

@ -200,7 +200,7 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> for SimpleMemoryManageme
/// a handle to the reserved memory.
///
/// Also cleans up, removing unused slices and chunks if permitted by the deallocation strategy.
fn reserve(&mut self, size: usize) -> Self::Handle {
fn reserve<Sync: FnOnce()>(&mut self, size: usize, _sync: Sync) -> Self::Handle {
self.cleanup_slices();
let handle = self.reserve_algorithm(size);
@ -212,7 +212,7 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> for SimpleMemoryManageme
handle
}
fn alloc(&mut self, size: usize) -> Self::Handle {
fn alloc<Sync: FnOnce()>(&mut self, size: usize, _sync: Sync) -> Self::Handle {
self.create_chunk(size)
}
@ -387,6 +387,12 @@ mod tests {
storage::BytesStorage,
};
impl<Storage: ComputeStorage> SimpleMemoryManagement<Storage> {
fn reserve_no_sync(&mut self, size: usize) -> SimpleHandle {
self.reserve(size, || {})
}
}
#[test]
fn can_mut_with_single_tensor_reference() {
let mut memory_management = SimpleMemoryManagement::new(
@ -429,8 +435,8 @@ mod tests {
SliceStrategy::Never,
);
let chunk_size = 4;
let _chunk_handle = memory_management.reserve(chunk_size);
let _new_handle = memory_management.reserve(chunk_size);
let _chunk_handle = memory_management.reserve_no_sync(chunk_size);
let _new_handle = memory_management.reserve_no_sync(chunk_size);
assert_eq!(memory_management.chunks.len(), 2);
}
@ -443,7 +449,7 @@ mod tests {
SliceStrategy::Never,
);
let chunk_size = 4;
let chunk_handle = memory_management.reserve(chunk_size);
let chunk_handle = memory_management.reserve_no_sync(chunk_size);
drop(chunk_handle);
memory_management.cleanup_chunks();
@ -502,7 +508,7 @@ mod tests {
DeallocStrategy::Never,
SliceStrategy::Ratio(0.5),
);
let handle = memory_management.reserve(10);
let handle = memory_management.reserve_no_sync(10);
let other_ref = handle.clone();
@ -518,7 +524,7 @@ mod tests {
DeallocStrategy::Never,
SliceStrategy::Ratio(0.5),
);
let chunk = memory_management.reserve(10);
let chunk = memory_management.reserve_no_sync(10);
if let super::SimpleHandle::Slice(_) = chunk {
panic!("Should be a chunk.")
@ -526,7 +532,7 @@ mod tests {
drop(chunk);
let slice = memory_management.reserve(8);
let slice = memory_management.reserve_no_sync(8);
if let super::SimpleHandle::Chunk(_) = &slice {
panic!("Should be a slice.")

View File

@ -18,7 +18,7 @@ pub enum StorageUtilization {
}
/// Contains the [storage id](StorageId) of a resource and the way it is used.
#[derive(new, Debug)]
#[derive(new, Clone, Debug)]
pub struct StorageHandle {
/// Storage id.
pub id: StorageId,
@ -58,4 +58,7 @@ pub trait ComputeStorage: Send {
/// Deallocates the memory pointed by the given storage id.
fn dealloc(&mut self, id: StorageId);
/// Copy the memory from the `from` handle to the `to` handle.
fn copy(&mut self, from: &StorageHandle, to: &StorageHandle);
}

View File

@ -90,6 +90,23 @@ impl ComputeStorage for BytesStorage {
}
}
}
fn copy(&mut self, from: &StorageHandle, to: &StorageHandle) {
assert_eq!(from.size(), to.size());
let input = self.get(from);
let output = self.get(to);
for i in 0..from.size() {
// Read from the source handle and write into the destination handle.
let ptr_in = input.ptr.wrapping_add(i + from.offset());
let ptr_out = output.ptr.wrapping_add(i + to.offset());
unsafe { *ptr_out = *ptr_in }
}
}
}
#[cfg(test)]

View File

@ -37,7 +37,7 @@ where
}
fn create(&mut self, data: &[u8]) -> Handle<Self> {
let handle = self.memory_management.reserve(data.len());
let handle = self.memory_management.reserve(data.len(), || {});
let resource = self.memory_management.get(handle.clone().binding());
let bytes = resource.write();
@ -50,7 +50,7 @@ where
}
fn empty(&mut self, size: usize) -> Handle<Self> {
Handle::new(self.memory_management.reserve(size))
Handle::new(self.memory_management.reserve(size, || {}))
}
fn execute(&mut self, kernel: Self::Kernel, bindings: Vec<Binding<Self>>) {

View File

@ -74,7 +74,9 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
fn create(&mut self, data: &[u8]) -> server::Handle<Self> {
let ctx = self.get_context();
let handle = ctx.memory_management.reserve(data.len());
let handle = ctx.memory_management.reserve(data.len(), || unsafe {
cudarc::driver::result::stream::synchronize(ctx.stream).unwrap();
});
let handle = server::Handle::new(handle);
let binding = handle.clone().binding().memory;
let resource = ctx.memory_management.get(binding);
@ -88,7 +90,9 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {
fn empty(&mut self, size: usize) -> server::Handle<Self> {
let ctx = self.get_context();
let handle = ctx.memory_management.reserve(size);
let handle = ctx.memory_management.reserve(size, || unsafe {
cudarc::driver::result::stream::synchronize(ctx.stream).unwrap();
});
server::Handle::new(handle)
}

View File

@ -115,4 +115,18 @@ impl ComputeStorage for CudaStorage {
fn dealloc(&mut self, id: StorageId) {
self.deallocations.push(id);
}
fn copy(&mut self, from: &StorageHandle, to: &StorageHandle) {
let num_bytes = from.size();
unsafe {
cudarc::driver::result::memcpy_dtod_async(
self.get(to).ptr,
self.get(from).ptr,
num_bytes,
self.stream,
)
.unwrap();
}
}
}

View File

@ -214,7 +214,17 @@ where
/// This is important, otherwise the compute passes are going to be too small and we won't be able to
/// fully utilize the GPU.
fn create(&mut self, data: &[u8]) -> server::Handle<Self> {
let handle = server::Handle::new(self.memory_management.reserve(data.len()));
let handle = server::Handle::new(self.memory_management.reserve(data.len(), || {
flush_tasks(
&mut self.encoder,
&self.queue,
&self.device,
&mut self.tasks_count,
&mut self.staging_belt,
);
self.device.poll(wgpu::Maintain::Wait);
}));
let non_zero_len = NonZeroU64::new(data.len() as u64);
// If there's nothing to copy, don't need to do any work here.
@ -230,7 +240,7 @@ where
let mut write_buf = self.staging_belt.write_buffer(
&mut self.encoder,
&resource.buffer,
0,
resource.offset(),
len,
&self.device,
);
@ -256,7 +266,16 @@ where
}
fn empty(&mut self, size: usize) -> server::Handle<Self> {
server::Handle::new(self.memory_management.reserve(size))
server::Handle::new(self.memory_management.reserve(size, || {
flush_tasks(
&mut self.encoder,
&self.queue,
&self.device,
&mut self.tasks_count,
&mut self.staging_belt,
);
self.device.poll(wgpu::Maintain::Wait);
}))
}
fn execute(&mut self, kernel: Self::Kernel, bindings: Vec<server::Binding<Self>>) {
@ -292,24 +311,41 @@ where
}
fn sync(&mut self, sync_type: SyncType) {
// Flush commands to the queue.
self.staging_belt.finish();
let mut new_encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
core::mem::swap(&mut new_encoder, &mut self.encoder);
self.queue.submit(Some(new_encoder.finish()));
self.tasks_count = 0;
flush_tasks(
&mut self.encoder,
&self.queue,
&self.device,
&mut self.tasks_count,
&mut self.staging_belt,
);
// Cleanup allocations and deallocations.
self.memory_management.storage().perform_deallocations();
self.staging_belt.recall();
if sync_type == SyncType::Wait {
self.device.poll(wgpu::Maintain::Wait);
}
}
}
/// Flush tasks using the [command encoder](CommandEncoder).
///
/// This implementation is decoupled from both the [server](WgpuServer) and [memory management](MemoryManagement).
/// This decoupling allows for safe usage within sync callbacks when allocating memory buffers.
fn flush_tasks(
encoder: &mut CommandEncoder,
queue: &wgpu::Queue,
device: &wgpu::Device,
tasks_count: &mut usize,
staging_belt: &mut StagingBelt,
) {
staging_belt.finish();
let mut new_encoder =
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
core::mem::swap(&mut new_encoder, encoder);
queue.submit(Some(new_encoder.finish()));
*tasks_count = 0;
staging_belt.recall();
}

View File

@ -7,6 +7,7 @@ pub struct WgpuStorage {
memory: HashMap<StorageId, Arc<wgpu::Buffer>>,
deallocations: Vec<StorageId>,
device: Arc<wgpu::Device>,
queue: Arc<wgpu::Queue>,
}
impl core::fmt::Debug for WgpuStorage {
@ -67,11 +68,12 @@ pub enum WgpuResourceKind {
/// Keeps actual wgpu buffer references in a hashmap with ids as key.
impl WgpuStorage {
/// Create a new storage on the given [device](wgpu::Device).
pub fn new(device: Arc<wgpu::Device>) -> Self {
pub fn new(device: Arc<wgpu::Device>, queue: Arc<wgpu::Queue>) -> Self {
Self {
memory: HashMap::new(),
deallocations: Vec::new(),
device,
queue,
}
}
@ -121,4 +123,23 @@ impl ComputeStorage for WgpuStorage {
fn dealloc(&mut self, id: StorageId) {
self.deallocations.push(id);
}
fn copy(&mut self, from: &StorageHandle, to: &StorageHandle) {
let mut encoder = self
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
let from = self.get(from);
let to = self.get(to);
encoder.copy_buffer_to_buffer(
&from.buffer,
from.offset(),
&to.buffer,
to.offset(),
to.size(),
);
self.queue.submit(Some(encoder.finish()));
}
}

View File

@ -156,7 +156,7 @@ fn create_client(
WgpuServer<SimpleMemoryManagement<WgpuStorage>>,
MutexComputeChannel<WgpuServer<SimpleMemoryManagement<WgpuStorage>>>,
> {
let storage = WgpuStorage::new(device_wgpu.clone());
let storage = WgpuStorage::new(device_wgpu.clone(), queue.clone());
let memory_management =
SimpleMemoryManagement::new(storage, options.dealloc_strategy, options.slice_strategy);
let server = WgpuServer::new(memory_management, device_wgpu, queue, options.tasks_max);