mirror of https://github.com/tracel-ai/burn.git
Perf/dynamic mm (#1906)
parent 8071b637b8
commit 4f6db974a1
@@ -97,23 +97,17 @@ where
#[macro_export(local_inner_macros)]
/// Create new memory ID types.
macro_rules! memory_id_type {
    ($id:ident, $handle:ident, $binding:ident) => {
    ($id:ident, $handle:ident) => {
        /// Memory Handle.
        #[derive(Clone, Debug)]
        pub struct $handle {
            value: $crate::id::HandleRef<$id>,
        }

        /// Binding of a memory handle.
        #[derive(Clone, Debug)]
        pub struct $binding {
            value: $crate::id::BindingRef<$id>,
        }

        /// Memory ID.
        #[derive(Clone, Copy, Hash, PartialEq, Eq, Debug)]
        pub struct $id {
            value: usize,
            pub(crate) value: usize,
        }

        impl $handle {

@@ -125,12 +119,6 @@ macro_rules! memory_id_type {
            }
        }

        pub(crate) fn binding(self) -> $binding {
            $binding {
                value: self.value.binding(),
            }
        }

        fn gen_id() -> usize {
            static COUNTER: core::sync::atomic::AtomicUsize =
                core::sync::atomic::AtomicUsize::new(0);

@@ -152,6 +140,30 @@ macro_rules! memory_id_type {
            }
        }

        impl Default for $handle {
            fn default() -> Self {
                Self::new()
            }
        }
    };

    ($id:ident, $handle:ident, $binding:ident) => {
        memory_id_type!($id, $handle);

        /// Binding of a memory handle.
        #[derive(Clone, Debug)]
        pub struct $binding {
            value: $crate::id::BindingRef<$id>,
        }

        impl $handle {
            pub(crate) fn binding(self) -> $binding {
                $binding {
                    value: self.value.binding(),
                }
            }
        }

        impl core::ops::Deref for $binding {
            type Target = $crate::id::BindingRef<$id>;

@@ -159,11 +171,5 @@ macro_rules! memory_id_type {
                &self.value
            }
        }

        impl Default for $handle {
            fn default() -> Self {
                Self::new()
            }
        }
    };
}
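
A minimal usage sketch (not part of the diff) of the two macro arms after this change: the two-argument arm generates only the ID and handle types, and the three-argument arm adds a binding type by delegating to the two-argument arm.

// These mirror the call sites introduced later in this PR (memory_pool/handle.rs):
memory_id_type!(ChunkId, ChunkHandle);                // ID + handle only
memory_id_type!(SliceId, SliceHandle, SliceBinding);  // ID + handle + binding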
@@ -27,14 +27,14 @@ pub trait MemoryManagement<Storage: ComputeStorage>: Send + core::fmt::Debug {
    fn get(&mut self, binding: Self::Binding) -> Storage::Resource;

    /// Finds a spot in memory for a resource with the given size in bytes, and returns a handle to it
    fn reserve(&mut self, size: usize) -> Self::Handle;
    fn reserve<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle;

    /// Bypass the memory allocation algorithm to allocate data directly.
    ///
    /// # Notes
    ///
    /// Can be useful for servers that want specific control over memory.
    fn alloc(&mut self, size: usize) -> Self::Handle;
    fn alloc<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle;

    /// Bypass the memory allocation algorithm to deallocate data directly.
    ///
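
A hedged sketch of a call site under the new `reserve`/`alloc` signatures: the closure is invoked only if the implementation decides it must synchronize the device before reorganizing memory, so callers with nothing to flush can pass a no-op. `server.sync()` below is a hypothetical hook, not an API from this PR.

// No synchronization needed:
let handle = memory_management.reserve(1024, || {});
// A backend that must flush pending work before slices are moved:
let handle = memory_management.reserve(1024, || server.sync());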
@@ -1,165 +1,46 @@
use crate::{
    memory_id_type,
    storage::{ComputeStorage, StorageHandle, StorageUtilization},
use super::memory_pool::{
    MemoryExtensionStrategy, MemoryPool, MemoryPoolBinding, MemoryPoolHandle, RoundingStrategy,
};
use alloc::vec::Vec;
use hashbrown::HashMap;
use crate::storage::ComputeStorage;

#[cfg(all(not(target_family = "wasm"), feature = "std"))]
use std::time;
#[cfg(all(target_family = "wasm", feature = "std"))]
use web_time as time;

use super::{MemoryBinding, MemoryHandle, MemoryManagement};

// The ChunkId allows keeping track of how many references there are to a specific chunk.
memory_id_type!(ChunkId, ChunkHandle, ChunkBinding);
// The SliceId allows keeping track of how many references there are to a specific slice.
memory_id_type!(SliceId, SliceHandle, SliceBinding);

/// A tensor memory handle, referring to either a chunk or a slice.
#[derive(Debug, Clone)]
pub enum DynamicHandle {
    /// A whole chunk of memory.
    Chunk(ChunkHandle),
    /// A slice of a chunk of memory.
    Slice(SliceHandle),
}

/// Binding of the [dynamic handle](DynamicHandle).
#[derive(Debug, Clone)]
pub enum DynamicBinding {
    /// Binding of the [chunk handle](ChunkHandle).
    Chunk(ChunkBinding),
    /// Binding of the [slice handle](SliceHandle).
    Slice(SliceBinding),
}

/// The strategy defines the frequency at which merging of free slices (defragmentation) occurs.
#[derive(Debug)]
pub enum MergingStrategy {
    /// Once every n calls to reserve.
    PeriodTick {
        /// Number of calls to be executed before triggering the defragmentation.
        period: usize,
        /// Current state. Should start at zero.
        state: usize,
    },
    #[cfg(feature = "std")]
    /// Once every period of time.
    PeriodTime {
        /// Amount of time to wait before triggering the defragmentation.
        period: time::Duration,
        /// Current state. Should start at now.
        state: time::Instant,
    },
    /// Never defragment.
    Never,
}

/// The strategy defines when to reuse a chunk with slices.
#[derive(Debug)]
pub enum SliceStrategy {
    /// Never use slices.
    Never,
    /// Ratio needed before the chunk can be used as a slice. Between 0 and 1.
    Ratio(f32),
    /// When the reserved memory is at least {} bytes.
    MinimumSize(usize),
    /// When the reserved memory is less than {} bytes.
    MaximumSize(usize),
}

impl SliceStrategy {
    /// If the chunk can be used with a slice.
    pub fn can_use_chunk(&self, chunk_size: usize, reserved_size: usize) -> bool {
        if chunk_size < reserved_size {
            return false;
        }

        match self {
            SliceStrategy::Never => false,
            SliceStrategy::Ratio(ratio) => (reserved_size as f32 / chunk_size as f32) >= *ratio,
            SliceStrategy::MinimumSize(bytes) => reserved_size >= *bytes,
            SliceStrategy::MaximumSize(bytes) => reserved_size <= *bytes,
        }
    }
}
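
Worked example (illustrative, not from the diff) of the strategy being removed here: a chunk may back a slice only when the reservation is large enough relative to the chunk, and never when the chunk is too small outright.

let strategy = SliceStrategy::Ratio(0.5);
assert!(strategy.can_use_chunk(200, 100));  // 100 / 200 = 0.5, meets the ratio
assert!(!strategy.can_use_chunk(200, 99));  // 99 / 200 < 0.5
assert!(!strategy.can_use_chunk(50, 100));  // chunk smaller than the request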

impl MergingStrategy {
    /// Create a new strategy with the given period.
    pub fn new_period_tick(period: usize) -> Self {
        MergingStrategy::PeriodTick { period, state: 0 }
    }

    fn should_perform_defragmentation(&mut self) -> bool {
        match self {
            MergingStrategy::PeriodTick { period, state } => {
                *state = (*state + 1) % *period;
                *state == 0
            }
            #[cfg(feature = "std")]
            MergingStrategy::PeriodTime { period, state } => {
                if &state.elapsed() > period {
                    *state = time::Instant::now();
                    true
                } else {
                    false
                }
            }
            MergingStrategy::Never => false,
        }
    }
}

#[derive(new, Debug)]
struct Chunk {
    storage: StorageHandle,
    handle: ChunkHandle,
    slices: Vec<SliceId>,
}

#[derive(new, Debug)]
struct Slice {
    storage: StorageHandle,
    handle: SliceHandle,
    chunk: ChunkHandle,
    padding: usize,
}

#[derive(Debug)]
struct Merging {
    start: usize,
    end: usize,
    offset: usize,
    size: usize,
    slice_ids: Vec<SliceId>,
}

impl Slice {
    pub fn effective_size(&self) -> usize {
        self.storage.size() + self.padding
    }
}

const MIN_SIZE_NEEDED_TO_OFFSET: usize = 16;
const BUFFER_ALIGNMENT: usize = 32;
use super::MemoryManagement;

/// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks.
pub struct DynamicMemoryManagement<Storage> {
    chunks: HashMap<ChunkId, Chunk>,
    slices: HashMap<SliceId, Slice>,
    merging_strategy: MergingStrategy,
    slice_strategy: SliceStrategy,
    small_memory_pool: MemoryPool,
    main_memory_pool: MemoryPool,
    storage: Storage,
}

impl<Storage: ComputeStorage> DynamicMemoryManagement<Storage> {
    /// Creates a new instance using the given storage.
    pub fn new(storage: Storage) -> Self {
        let main_memory_pool = MemoryPool::new(
            MemoryExtensionStrategy::new_period_tick(10),
            RoundingStrategy::RoundUp,
            1024 * 1024 * 1024 * 2,
            true,
        );
        let small_memory_pool = MemoryPool::new(
            MemoryExtensionStrategy::Never,
            RoundingStrategy::None,
            1024 * 1024 * 512,
            false,
        );

        Self {
            main_memory_pool,
            small_memory_pool,
            storage,
        }
    }
}

impl<Storage> core::fmt::Debug for DynamicMemoryManagement<Storage> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(
            alloc::format!(
                "DynamicMemoryManagement {:?} - {:?}",
                self.merging_strategy,
                "DynamicMemoryManagement {:?}",
                core::any::type_name::<Storage>(),
            )
            .as_str(),

@@ -167,677 +48,44 @@ impl<Storage> core::fmt::Debug for DynamicMemoryManagement<Storage> {
        }
    }
}

impl MemoryBinding for DynamicBinding {}

impl MemoryHandle<DynamicBinding> for DynamicHandle {
    fn can_mut(&self) -> bool {
        match &self {
            DynamicHandle::Chunk(id) => id.can_mut(),
            DynamicHandle::Slice(id) => id.can_mut(),
        }
    }

    fn binding(self) -> DynamicBinding {
        match self {
            Self::Chunk(handle) => DynamicBinding::Chunk(handle.binding()),
            Self::Slice(handle) => DynamicBinding::Slice(handle.binding()),
        }
    }
}

impl<Storage: ComputeStorage> MemoryManagement<Storage> for DynamicMemoryManagement<Storage> {
    type Handle = DynamicHandle;
    type Binding = DynamicBinding;
    type Handle = MemoryPoolHandle;
    type Binding = MemoryPoolBinding;

    /// Returns the resource from the storage, for the specified handle.
    fn get(&mut self, binding: Self::Binding) -> Storage::Resource {
        let storage = match binding {
            DynamicBinding::Chunk(chunk) => {
                &self
                    .chunks
                    .get(chunk.id())
                    .expect("Storage found for the given execution buffer handle")
                    .storage
            }
            DynamicBinding::Slice(slice) => {
                &self
                    .slices
                    .get(slice.id())
                    .expect("Storage found for the given execution buffer handle")
                    .storage
            }
        };

        self.storage.get(storage)
        if let Some(handle) = self.small_memory_pool.get(&mut self.storage, &binding) {
            return handle;
        }

    /// Reserves memory of the specified size using the reserve algorithm, and returns
    /// a handle to the reserved memory.
    ///
    /// Also cleans up, merging free slices together if permitted by the merging strategy.
    fn reserve(&mut self, size: usize) -> Self::Handle {
        let handle = self.reserve_algorithm(size);

        if self.merging_strategy.should_perform_defragmentation() {
            self.defragmentation();
        if let Some(handle) = self.main_memory_pool.get(&mut self.storage, &binding) {
            return handle;
        }

        handle
        panic!("No handle found in the small and main memory pool");
    }

    fn alloc(&mut self, size: usize) -> Self::Handle {
        let handle_chunk = self.create_chunk(size);
        let chunk_id = *handle_chunk.id();

        let slice = self.create_slice(0, size, handle_chunk);
        let handle_slice = slice.handle.clone();

        let chunk = self.chunks.get_mut(&chunk_id).unwrap();
        chunk.slices.push(*handle_slice.id());

        self.slices.insert(*handle_slice.id(), slice);

        DynamicHandle::Slice(handle_slice)
    }

    fn dealloc(&mut self, binding: Self::Binding) {
        match binding {
            DynamicBinding::Chunk(chunk) => {
                if let Some(chunk) = self.chunks.remove(chunk.id()) {
                    self.storage.dealloc(chunk.storage.id);
    fn reserve<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
        if size < 512 {
            self.small_memory_pool
                .reserve(&mut self.storage, size, sync)
        } else {
            self.main_memory_pool.reserve(&mut self.storage, size, sync)
        }
    }
            DynamicBinding::Slice(_) => panic!("Can't dealloc slice manually"),

    fn alloc<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
        if size < 512 {
            self.small_memory_pool.alloc(&mut self.storage, size, sync)
        } else {
            self.main_memory_pool.alloc(&mut self.storage, size, sync)
        }
    }

    fn dealloc(&mut self, _binding: Self::Binding) {
        // Can't dealloc slices.
    }

    fn storage(&mut self) -> &mut Storage {
        &mut self.storage
    }
}
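
A minimal sketch (assumptions noted in comments) of the new routing introduced above: reservations under 512 bytes are served by the small pool, everything else by the main pool, with the synchronization closure forwarded to whichever pool runs.

// `BytesStorage` is the test storage used elsewhere in this diff.
let mut memory_management = DynamicMemoryManagement::new(BytesStorage::default());
let small = memory_management.reserve(256, || {});  // routed to `small_memory_pool`
let large = memory_management.reserve(4096, || {}); // routed to `main_memory_pool`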

impl<Storage: ComputeStorage> DynamicMemoryManagement<Storage> {
    /// Creates a new instance using the given storage, merging strategy and slice strategy.
    pub fn new(
        storage: Storage,
        merging_strategy: MergingStrategy,
        slice_strategy: SliceStrategy,
    ) -> Self {
        Self {
            chunks: HashMap::new(),
            slices: HashMap::new(),
            merging_strategy,
            slice_strategy,
            storage,
        }
    }

    fn reserve_algorithm(&mut self, size: usize) -> DynamicHandle {
        // Looks for a large enough, existing but unused chunk of memory.
        let slice = self.get_free_slice(size);

        match slice {
            Some(slice) => DynamicHandle::Slice(slice.clone()),
            None => self.alloc(size),
        }
    }

    fn find_free_slice_best_fit(
        &self,
        size: usize,
        effective_size: usize,
    ) -> Option<(SliceId, usize)> {
        let mut size_diff_current = usize::MAX;
        let mut found = None;
        for (__, chunk) in self.chunks.iter() {
            if size < MIN_SIZE_NEEDED_TO_OFFSET && chunk.slices.len() > 1 {
                continue;
            }
            if !self
                .slice_strategy
                .can_use_chunk(chunk.storage.size(), effective_size)
            {
                continue;
            }
            for slice_id in chunk.slices.iter() {
                let slice = self.slices.get(slice_id).unwrap();
                let slice_can_be_reused =
                    slice.handle.is_free() && slice.effective_size() >= effective_size;

                if slice_can_be_reused {
                    let size_diff = slice.effective_size() - effective_size;
                    if size_diff < size_diff_current {
                        size_diff_current = size_diff;
                        found = Some((*slice_id, size_diff));
                    }
                }
            }
        }
        found
    }

    /// Tries to split a slice in two, with the first slice being of the specified size.
    /// If there is not enough space for two slices, only one slice is used.
    /// Returns the handle of the first slice.
    fn split_slice_in_two(
        &mut self,
        slice_to_split_id: &SliceId,
        first_slice_size: usize,
    ) -> Option<SliceHandle> {
        let slice_to_split = self.slices.get(slice_to_split_id).unwrap();
        let slice_to_split_effective_size = slice_to_split.effective_size();
        let chunk = self.chunks.get_mut(slice_to_split.chunk.id()).unwrap();
        let current_slice_chunk_handle = chunk.handle.clone();

        let mut slices = Vec::with_capacity(chunk.slices.len() + 1);
        let mut offset = 0;

        let mut slices_old = Vec::new();
        core::mem::swap(&mut slices_old, &mut chunk.slices);

        let mut handle = None;
        for slice_id in slices_old.into_iter() {
            // Assumes that all slices are contiguous in a chunk.
            let slice = self.slices.get(&slice_id).unwrap();

            if slice_id != *slice_to_split_id {
                slices.push(slice_id);
                offset += slice.effective_size();
            } else {
                let first_slice =
                    self.create_slice(offset, first_slice_size, current_slice_chunk_handle.clone());
                let first_slice_id = *first_slice.handle.id();
                offset += first_slice.effective_size();

                let second_slice_size =
                    slice_to_split_effective_size - first_slice.effective_size();
                let slice_end = self.create_slice(
                    offset,
                    second_slice_size,
                    current_slice_chunk_handle.clone(),
                );
                let slice_end_id = *slice_end.handle.id();
                offset += slice_end.effective_size();

                let created_offset = first_slice.effective_size() + slice_end.effective_size();
                assert_eq!(created_offset, slice.effective_size());

                handle = Some(first_slice.handle.clone());
                self.slices.insert(first_slice_id, first_slice);
                self.slices.insert(slice_end_id, slice_end);

                slices.push(first_slice_id);
                slices.push(slice_end_id);
            }
        }

        self.slices.remove(slice_to_split_id);
        let chunk = self
            .chunks
            .get_mut(current_slice_chunk_handle.id())
            .unwrap();
        chunk.slices = slices;
        handle
    }

    /// Finds the smallest free slice that is large enough to fit `size`.
    /// Returns the slice's handle.
    fn get_free_slice(&mut self, size: usize) -> Option<SliceHandle> {
        let padding = Self::calculate_padding(size);
        let effective_size = size + padding;

        let found = self.find_free_slice_best_fit(size, effective_size);
        let (slice_id, size_diff_current) = match found {
            Some(val) => val,
            None => {
                return None;
            }
        };

        // If the sizes are equal, reuse the slice.
        if size_diff_current == 0 {
            let slice = self.slices.get_mut(&slice_id).unwrap();
            let offset = match slice.storage.utilization {
                StorageUtilization::Full(_) => 0,
                StorageUtilization::Slice { offset, size: _ } => offset,
            };
            slice.storage.utilization = StorageUtilization::Slice { offset, size };
            slice.padding = padding;

            return Some(self.slices.get(&slice_id).unwrap().handle.clone());
        }

        assert_eq!(size_diff_current % BUFFER_ALIGNMENT, 0);

        // Split into two if needed.
        let handle = self.split_slice_in_two(&slice_id, size);
        if handle.is_none() {
            panic!("split should have returned a handle");
        }

        handle
    }

    /// Creates a slice of size `size` upon the given chunk with the given offset.
    fn create_slice(&self, offset: usize, size: usize, handle_chunk: ChunkHandle) -> Slice {
        if offset > 0 && size < MIN_SIZE_NEEDED_TO_OFFSET {
            panic!("tried to create slice of size {size} with an offset while the size needs to be at least {MIN_SIZE_NEEDED_TO_OFFSET} for offset support");
        }
        if offset % BUFFER_ALIGNMENT != 0 {
            panic!("slice with offset {offset} needs to be a multiple of {BUFFER_ALIGNMENT}");
        }
        let chunk = self.chunks.get(handle_chunk.id()).unwrap();
        let handle = SliceHandle::new();

        let storage = StorageHandle {
            id: chunk.storage.id.clone(),
            utilization: StorageUtilization::Slice { offset, size },
        };

        let padding = Self::calculate_padding(size);

        Slice::new(storage, handle, chunk.handle.clone(), padding)
    }

    #[cfg(test)]
    fn insert_slice(&mut self, slice: Slice, chunk_id: ChunkId) {
        let slice_id = *slice.handle.id();
        self.slices.insert(*slice.handle.id(), slice);
        let chunk = self.chunks.get_mut(&chunk_id).unwrap();
        chunk.slices.push(slice_id);
    }

    /// Creates a chunk of given size by allocating on the storage.
    fn create_chunk(&mut self, size: usize) -> ChunkHandle {
        let padding = Self::calculate_padding(size);
        let effective_size = size + padding;

        let storage = self.storage.alloc(effective_size);
        let handle = ChunkHandle::new();

        self.chunks.insert(
            *handle.id(),
            Chunk::new(storage, handle.clone(), Vec::new()),
        );

        handle
    }

    // Generates and adds all the merging information of a given chunk to a hash map.
    fn generate_mergings(&self, chunk: &Chunk, merging_map: &mut HashMap<ChunkId, Vec<Merging>>) {
        let mut to_merge: Vec<Merging> = Vec::new();

        let mut start_index: usize = 0;
        let mut num_merge = 0;
        let mut offset_current = 0;
        let mut offset = 0;
        let mut slices_ids = Vec::new();

        for (i, slice_id) in chunk.slices.iter().enumerate() {
            let slice = self.slices.get(slice_id).unwrap();

            if slice.handle.is_free() {
                slices_ids.push(*slice_id);
                num_merge += 1;
                offset += slice.effective_size();
                continue;
            } else if num_merge > 1 {
                let mut empty = Vec::new();
                core::mem::swap(&mut slices_ids, &mut empty);
                let merging = Merging {
                    start: start_index,
                    end: start_index + num_merge - 1,
                    offset: offset_current,
                    size: offset - offset_current,
                    slice_ids: empty,
                };
                to_merge.push(merging);
            }
            offset += slice.effective_size();
            start_index = i + 1;
            num_merge = 0;
            offset_current = offset;
            slices_ids.clear();
        }
        if !to_merge.is_empty() {
            merging_map.insert(*chunk.handle.id(), to_merge);
        }
    }

    // Merges all free slices together using the merging metadata.
    fn merge_contiguous_slices(&mut self, chunk_id: ChunkId, mergings: &Vec<Merging>) {
        let chunk = self.chunks.get(&chunk_id).unwrap();
        let chunk_handle = chunk.handle.clone();
        let slices = chunk.slices.clone();
        let mut slices_updated = Vec::new();

        let mut index = 0;

        for merging in mergings {
            let slice = self.create_slice(merging.offset, merging.size, chunk_handle.clone());
            let slice_id = *slice.handle.id();
            self.slices.insert(slice_id, slice);
            for i in index..merging.start {
                slices_updated.push(*slices.get(i).unwrap());
            }
            index = merging.end + 1;
            slices_updated.push(slice_id);

            for slice_id_to_remove in merging.slice_ids.iter() {
                self.slices.remove(slice_id_to_remove);
            }
        }

        for i in index..slices.len() {
            slices_updated.push(*slices.get(i).unwrap());
        }
        let chunk = self.chunks.get_mut(&chunk_id).unwrap();
        core::mem::swap(&mut chunk.slices, &mut slices_updated);
    }

    // Merge all contiguous free slices together; assumes that slices are in sorted order.
    fn defragmentation(&mut self) {
        let mut chunk_to_merged_slice: HashMap<ChunkId, Vec<Merging>> = HashMap::new();
        for (.., chunk) in self.chunks.iter() {
            self.generate_mergings(chunk, &mut chunk_to_merged_slice);
        }

        for (chunk_id, mergings) in chunk_to_merged_slice.into_iter() {
            self.merge_contiguous_slices(chunk_id, &mergings);
        }
    }

    fn calculate_padding(size: usize) -> usize {
        let rem = size % BUFFER_ALIGNMENT;
        if rem != 0 {
            BUFFER_ALIGNMENT - rem
        } else {
            0
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        memory_management::{MemoryHandle, MemoryManagement},
        storage::BytesStorage,
    };

    #[test]
    fn can_mut_with_single_tensor_reference() {
        let mut memory_management = DynamicMemoryManagement::new(
            BytesStorage::default(),
            MergingStrategy::Never,
            SliceStrategy::Never,
        );

        let chunk_size = 4;
        let simple_handle = memory_management.create_chunk(chunk_size);

        let x = simple_handle.clone();
        core::mem::drop(simple_handle);

        assert!(x.can_mut());
    }

    #[test]
    fn two_tensor_references_remove_mutability() {
        let mut memory_management = DynamicMemoryManagement::new(
            BytesStorage::default(),
            MergingStrategy::Never,
            SliceStrategy::Never,
        );

        let chunk_size = 4;
        let simple_handle = memory_management.create_chunk(chunk_size);

        let x = simple_handle.clone();

        assert!(!simple_handle.can_mut());
        assert!(!x.can_mut())
    }

    #[test]
    fn when_non_empty_chunk_exists_and_other_one_created_there_should_be_two() {
        let mut memory_management = DynamicMemoryManagement::new(
            BytesStorage::default(),
            MergingStrategy::Never,
            SliceStrategy::Never,
        );
        let chunk_size = 4;
        let _chunk_handle = memory_management.reserve(chunk_size);
        let _new_handle = memory_management.reserve(chunk_size);

        assert_eq!(memory_management.chunks.len(), 2);
    }

    #[test]
    fn when_big_chunk_is_freed_should_be_filled_with_smaller_slices() {
        let mut memory_management = DynamicMemoryManagement::new(
            BytesStorage::default(),
            MergingStrategy::Never,
            SliceStrategy::Ratio(0.2),
        );
        let big_slice_size = 32 * 3;
        let small_slice_size = 32;

        let big_slice = memory_management.reserve(big_slice_size);
        drop(big_slice);
        let _small_slice_1 = memory_management.reserve(small_slice_size);
        let _small_slice_2 = memory_management.reserve(small_slice_size);
        let _small_slice_3 = memory_management.reserve(small_slice_size);

        assert_eq!(memory_management.chunks.len(), 1);
        assert_eq!(memory_management.slices.len(), 3);
        for (.., slice) in memory_management.slices.iter() {
            assert_eq!(slice.storage.size(), 32);
        }
    }

    #[test]
    fn when_defragmentation_called_if_two_slices_free_should_merge_into_bigger_slice() {
        let mut memory_management = DynamicMemoryManagement::new(
            BytesStorage::default(),
            MergingStrategy::new_period_tick(1),
            SliceStrategy::Ratio(0.2),
        );

        let chunk_handle = memory_management.create_chunk(32 + 32);
        let slice = memory_management.create_slice(0, 32 + 32, chunk_handle.clone());
        memory_management.insert_slice(slice, *chunk_handle.id());

        let _slice_1 = memory_management.reserve(32);
        let _slice_2 = memory_management.reserve(32);

        assert_eq!(memory_management.chunks.len(), 1);
        assert_eq!(memory_management.slices.len(), 2);
        for (.., slice) in memory_management.slices.iter() {
            assert_eq!(slice.storage.size(), 32);
        }
    }

    #[test]
    fn when_defragmentation_called_should_merge_contiguous_free_slice_together() {
        let mut memory_management = DynamicMemoryManagement::new(
            BytesStorage::default(),
            MergingStrategy::new_period_tick(1),
            SliceStrategy::Ratio(0.1),
        );

        // The chunk will be separated into 7 slices: slice 1 free, slices 2-3 not free,
        // slices 4-6 free, and slice 7 not free.
        let slice_size = 32;
        let num_of_slice = 7;

        let chunk_handle = memory_management.create_chunk(slice_size * num_of_slice);
        let chunk_id = *chunk_handle.id();
        let slice = memory_management.create_slice(0, slice_size * num_of_slice, chunk_handle);
        memory_management.insert_slice(slice, chunk_id);

        let _slice_1 = memory_management.reserve(slice_size);
        let _slice_2 = memory_management.reserve(slice_size);
        let _slice_3 = memory_management.reserve(slice_size);
        let _slice_4 = memory_management.reserve(slice_size);
        let _slice_5 = memory_management.reserve(slice_size);
        let _slice_6 = memory_management.reserve(slice_size);
        let _slice_7 = memory_management.reserve(slice_size);
        drop(_slice_1);
        drop(_slice_4);
        drop(_slice_5);
        drop(_slice_6);
        memory_management.defragmentation();

        assert_eq!(memory_management.chunks.len(), 1);
        assert_eq!(memory_management.slices.len(), 5);

        let chunk = memory_management.chunks.get(&chunk_id).unwrap();
        let slices = &chunk.slices;

        // first slice test
        let first_slice_id = slices.first().unwrap();
        let first_slice = memory_management.slices.get(first_slice_id).unwrap();
        assert!(first_slice.handle.is_free());
        assert_eq!(first_slice.storage.size(), slice_size);
        assert_eq!(first_slice.storage.offset(), 0);

        // second slice test
        let first_slice_id = slices.get(1).unwrap();
        let first_slice = memory_management.slices.get(first_slice_id).unwrap();
        assert!(!first_slice.handle.is_free());
        assert_eq!(first_slice.storage.size(), slice_size);
        assert_eq!(first_slice.storage.offset(), slice_size);

        // third slice test
        let first_slice_id = slices.get(2).unwrap();
        let first_slice = memory_management.slices.get(first_slice_id).unwrap();
        assert!(!first_slice.handle.is_free());
        assert_eq!(first_slice.storage.size(), slice_size);
        assert_eq!(first_slice.storage.offset(), slice_size * 2);

        // fourth slice test (slices 4-6 merged)
        let first_slice_id = slices.get(3).unwrap();
        let first_slice = memory_management.slices.get(first_slice_id).unwrap();
        assert!(first_slice.handle.is_free());
        assert_eq!(first_slice.storage.size(), slice_size * 3);
        assert_eq!(first_slice.storage.offset(), slice_size * 3);

        // fifth slice test
        let first_slice_id = slices.get(4).unwrap();
        let first_slice = memory_management.slices.get(first_slice_id).unwrap();
        assert!(!first_slice.handle.is_free());
        assert_eq!(first_slice.storage.size(), slice_size);
        assert_eq!(first_slice.storage.offset(), slice_size * 6);
    }

    #[test]
    fn never_dealloc_strategy_never_deallocs() {
        let mut never_dealloc = MergingStrategy::Never;
        for _ in 0..20 {
            assert!(!never_dealloc.should_perform_defragmentation())
        }
    }

    #[test]
    fn period_tick_dealloc_strategy_should_dealloc_after_period() {
        let period = 3;
        let mut period_tick_dealloc = MergingStrategy::new_period_tick(period);

        for _ in 0..3 {
            for _ in 0..period - 1 {
                assert!(!period_tick_dealloc.should_perform_defragmentation());
            }
            assert!(period_tick_dealloc.should_perform_defragmentation());
        }
    }

    #[test]
    fn slice_strategy_minimum_bytes() {
        let strategy = SliceStrategy::MinimumSize(100);

        assert!(strategy.can_use_chunk(200, 101));
        assert!(!strategy.can_use_chunk(200, 99));
    }

    #[test]
    fn slice_strategy_maximum_bytes() {
        let strategy = SliceStrategy::MaximumSize(100);

        assert!(strategy.can_use_chunk(200, 99));
        assert!(!strategy.can_use_chunk(200, 101));
    }

    #[test]
    fn slice_strategy_ratio() {
        let strategy = SliceStrategy::Ratio(0.9);

        assert!(strategy.can_use_chunk(200, 180));
        assert!(!strategy.can_use_chunk(200, 179));
    }

    #[test]
    fn test_handle_mutability() {
        let mut memory_management = DynamicMemoryManagement::new(
            BytesStorage::default(),
            MergingStrategy::Never,
            SliceStrategy::Ratio(0.5),
        );
        let handle = memory_management.reserve(10);

        let other_ref = handle.clone();

        assert!(!handle.can_mut(), "Handle can't be mut when multiple ref.");
        drop(other_ref);
        assert!(handle.can_mut(), "Handle should be mut when only one ref.");
    }

    #[test]
    fn support_multiple_slices_for_a_chunk() {
        let mut memory_management = DynamicMemoryManagement::new(
            BytesStorage::default(),
            MergingStrategy::Never,
            SliceStrategy::Ratio(0.2),
        );
        let handle = memory_management.reserve(100);
        core::mem::drop(handle);

        let _slice_1 = memory_management.reserve(30);
        let _slice_2 = memory_management.reserve(30);
        let _slice_3 = memory_management.reserve(30);

        assert_eq!(memory_management.chunks.len(), 1);
        assert_eq!(memory_management.slices.len(), 4);
    }

    #[test]
    fn test_slice_mutability() {
        let mut memory_management = DynamicMemoryManagement::new(
            BytesStorage::default(),
            MergingStrategy::Never,
            SliceStrategy::Ratio(0.5),
        );
        let first_slice = memory_management.reserve(10);

        drop(first_slice);

        let slice = memory_management.reserve(8);

        if let super::DynamicHandle::Slice(slice) = slice {
            let other_ref = slice.clone();

            assert!(
                !slice.can_mut(),
                "Slice can't be mut when multiple ref to the same handle."
            );
            drop(other_ref);
            assert!(
                slice.can_mut(),
                "Slice should be mut when only one ref to the same handle."
            );
            assert!(
                !slice.is_free(),
                "Slice can't be reallocated when one ref still exists."
            );
        }
    }
}
@@ -0,0 +1,541 @@
use super::index::SearchIndex;
use super::{
    ChunkHandle, ChunkId, MemoryChunk, MemoryPoolBinding, MemoryPoolHandle, MemorySlice,
    RingBuffer, SliceHandle, SliceId,
};
use crate::storage::{ComputeStorage, StorageHandle, StorageUtilization};
use alloc::vec::Vec;
use hashbrown::{HashMap, HashSet};

pub struct MemoryPool {
    chunks: HashMap<ChunkId, Chunk>,
    slices: HashMap<SliceId, Slice>,
    memory_extension_strategy: MemoryExtensionStrategy,
    rounding: RoundingStrategy,
    chunk_index: SearchIndex<ChunkId>,
    max_chunk_size: usize,
    ring: RingBuffer<Chunk, Slice>,
    debug: bool,
}

#[derive(new, Debug)]
pub struct Chunk {
    pub storage: StorageHandle,
    pub handle: ChunkHandle,
    pub slices: Vec<SliceId>,
}

#[derive(new, Debug)]
pub struct Slice {
    pub storage: StorageHandle,
    pub handle: SliceHandle,
    pub chunk: ChunkHandle,
    pub padding: usize,
}

impl Slice {
    pub fn effective_size(&self) -> usize {
        self.storage.size() + self.padding
    }
}

const MIN_SIZE_NEEDED_TO_OFFSET: usize = 16;
const BUFFER_ALIGNMENT: usize = 32;
const MB: usize = 1024 * 1024;

pub enum RoundingStrategy {
    RoundUp,
    None,
}

impl RoundingStrategy {
    fn alloc_size(&self, size: usize) -> usize {
        match self {
            RoundingStrategy::RoundUp => {
                if size < BUFFER_ALIGNMENT {
                    return BUFFER_ALIGNMENT;
                }

                if size < MB {
                    2 * MB
                } else if size < 10 * MB {
                    return 20 * MB;
                } else {
                    let factor = (size + (2 * MB - 1)) / (2 * MB);
                    factor * 2 * MB
                }
            }
            RoundingStrategy::None => size,
        }
    }
}
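
Worked examples (illustrative, not from the diff) of the rounding rules above, with MB = 1024 * 1024. `alloc_size` is private, so this only compiles inside the module:

let rounding = RoundingStrategy::RoundUp;
assert_eq!(rounding.alloc_size(16), 32);           // below alignment: one aligned buffer
assert_eq!(rounding.alloc_size(100_000), 2 * MB);  // under 1 MB: allocate 2 MB
assert_eq!(rounding.alloc_size(5 * MB), 20 * MB);  // under 10 MB: allocate 20 MB
assert_eq!(rounding.alloc_size(21 * MB), 22 * MB); // otherwise: next multiple of 2 MB
assert_eq!(RoundingStrategy::None.alloc_size(100), 100);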

/// The strategy defines the frequency at which the pool may extend its maximum memory
/// (by consolidating existing chunks into a larger one).
#[derive(Debug)]
pub enum MemoryExtensionStrategy {
    /// Once every n calls to reserve.
    PeriodTick {
        /// Number of calls to be executed before triggering the extension.
        period: usize,
        /// Current state. Should start at zero.
        state: usize,
    },
    /// Never extend.
    Never,
}

impl MemoryExtensionStrategy {
    /// Create a new strategy with the given period.
    pub fn new_period_tick(period: usize) -> Self {
        MemoryExtensionStrategy::PeriodTick { period, state: 0 }
    }

    fn should_extend_max_memory(&mut self) -> bool {
        match self {
            MemoryExtensionStrategy::PeriodTick { period, state } => {
                *state = (*state + 1) % *period;
                *state == 0
            }
            MemoryExtensionStrategy::Never => false,
        }
    }
}
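
The tick counter wraps modulo the period, so exactly one call in every `period` calls reports that the pool may grow. An illustrative trace (module-internal, since the method is private):

let mut strategy = MemoryExtensionStrategy::new_period_tick(3);
assert!(!strategy.should_extend_max_memory()); // state: 1
assert!(!strategy.should_extend_max_memory()); // state: 2
assert!(strategy.should_extend_max_memory());  // state wraps to 0: extend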

struct SliceUpdate {
    slice_id: SliceId,
    size: usize,
}

impl MemoryPool {
    pub fn new(
        merging_strategy: MemoryExtensionStrategy,
        alloc_strategy: RoundingStrategy,
        max_chunk_size: usize,
        debug: bool,
    ) -> Self {
        Self {
            chunks: HashMap::new(),
            slices: HashMap::new(),
            memory_extension_strategy: merging_strategy,
            rounding: alloc_strategy,
            max_chunk_size,
            chunk_index: SearchIndex::new(),
            ring: RingBuffer::new(),
            debug,
        }
    }

    /// Returns the resource from the storage, for the specified handle.
    pub fn get<Storage: ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        binding: &MemoryPoolBinding,
    ) -> Option<Storage::Resource> {
        self.slices
            .get(binding.slice.id())
            .map(|s| &s.storage)
            .map(|h| storage.get(h))
    }

    /// Reserves memory of the specified size using the reserve algorithm, and returns
    /// a handle to the reserved memory.
    ///
    /// Also cleans up, merging free slices together if permitted by the merging strategy.
    pub fn reserve<Storage: ComputeStorage, Sync: FnOnce()>(
        &mut self,
        storage: &mut Storage,
        size: usize,
        sync: Sync,
    ) -> MemoryPoolHandle {
        // Looks for a large enough, existing but unused chunk of memory.
        let slice = self.get_free_slice(size);

        match slice {
            Some(slice) => MemoryPoolHandle {
                slice: slice.clone(),
            },
            None => self.alloc(storage, size, sync),
        }
    }
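
A hedged usage sketch of a standalone pool, mirroring how the small pool is configured earlier in this diff: no rounding, never extended, backed by the crate's `BytesStorage`.

let mut storage = BytesStorage::default();
let mut pool = MemoryPool::new(
    MemoryExtensionStrategy::Never,
    RoundingStrategy::None,
    512 * 1024 * 1024, // max chunk size
    false,             // debug logging off
);
let handle = pool.reserve(&mut storage, 64, || {});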

    pub fn alloc<Storage: ComputeStorage, Sync: FnOnce()>(
        &mut self,
        storage: &mut Storage,
        size: usize,
        sync: Sync,
    ) -> MemoryPoolHandle {
        if self.memory_extension_strategy.should_extend_max_memory() {
            sync();
            self.extend_max_memory(storage, size);

            if let Some(handle) = self.get_free_slice(size) {
                return MemoryPoolHandle { slice: handle };
            }
        }

        let alloc_size = self.rounding.alloc_size(size);
        self.alloc_slice(storage, alloc_size, size)
    }

    fn alloc_slice<Storage: ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        alloc_size: usize,
        slice_size: usize,
    ) -> MemoryPoolHandle {
        let handle_chunk = self.create_chunk(storage, alloc_size);
        let chunk_id = *handle_chunk.id();
        let (slice, extra_slice) =
            self.allocate_slices(handle_chunk.clone(), alloc_size, slice_size);

        let handle_slice = slice.handle.clone();
        self.update_chunk_metadata(chunk_id, slice, extra_slice);

        MemoryPoolHandle {
            slice: handle_slice,
        }
    }

    fn allocate_slices(
        &self,
        handle_chunk: ChunkHandle,
        alloc_size: usize,
        slice_size: usize,
    ) -> (Slice, Option<Slice>) {
        let slice = self.create_slice(0, slice_size, handle_chunk.clone());
        let effective_size = slice.effective_size();

        let extra_slice = if effective_size < alloc_size {
            Some(self.create_slice(effective_size, alloc_size - effective_size, handle_chunk))
        } else {
            None
        };

        (slice, extra_slice)
    }

    fn update_chunk_metadata(
        &mut self,
        chunk_id: ChunkId,
        slice: Slice,
        extra_slice: Option<Slice>,
    ) {
        let slice_id = *slice.handle.id();

        self.slices.insert(slice_id, slice);
        self.chunks
            .get_mut(&chunk_id)
            .unwrap()
            .slices
            .push(slice_id);

        if let Some(extra_slice) = extra_slice {
            let extra_slice_id = *extra_slice.handle.id();
            self.slices.insert(extra_slice_id, extra_slice);
            self.chunks
                .get_mut(&chunk_id)
                .unwrap()
                .slices
                .push(extra_slice_id);
        }
    }

    #[allow(unused)]
    fn display_memory_usage(&self) {
        let total_memory_usage: f64 = self
            .chunks
            .values()
            .map(|chunk| chunk.storage.size() as f64)
            .sum();
        let effective_memory_usage: f64 = self
            .slices
            .values()
            .filter(|slice| slice.handle.is_free())
            .map(|slice| slice.storage.size() as f64)
            .sum();
        let ratio = 100.0 * effective_memory_usage / total_memory_usage;
        log::info!("the memory usage is {ratio}");
    }

    /// Finds a free slice that can contain the given size.
    /// Returns the slice's handle.
    fn get_free_slice(&mut self, size: usize) -> Option<SliceHandle> {
        if size < MIN_SIZE_NEEDED_TO_OFFSET {
            return None;
        }

        let padding = calculate_padding(size);
        let effective_size = size + padding;

        let slice_id =
            self.ring
                .find_free_slice(effective_size, &mut self.chunks, &mut self.slices);

        let slice_id = match slice_id {
            Some(val) => val,
            None => return None,
        };

        let slice = self.slices.get_mut(&slice_id).unwrap();

        let offset = match slice.storage.utilization {
            StorageUtilization::Full(_) => 0,
            StorageUtilization::Slice { offset, size: _ } => offset,
        };
        slice.storage.utilization = StorageUtilization::Slice { offset, size };
        slice.padding = padding;

        Some(slice.handle.clone())
    }

    /// Creates a slice of size `size` upon the given chunk with the given offset.
    fn create_slice(&self, offset: usize, size: usize, handle_chunk: ChunkHandle) -> Slice {
        assert_eq!(
            offset % BUFFER_ALIGNMENT,
            0,
            "slice with offset {offset} needs to be a multiple of {BUFFER_ALIGNMENT}"
        );
        if offset > 0 && size < MIN_SIZE_NEEDED_TO_OFFSET {
            panic!("tried to create slice of size {size} with an offset while the size needs to be at least {MIN_SIZE_NEEDED_TO_OFFSET} for offset support");
        }
        let chunk = self.chunks.get(handle_chunk.id()).unwrap();
        let handle = SliceHandle::new();

        let storage = StorageHandle {
            id: chunk.storage.id.clone(),
            utilization: StorageUtilization::Slice { offset, size },
        };

        let padding = calculate_padding(size);

        Slice::new(storage, handle, chunk.handle.clone(), padding)
    }

    /// Creates a chunk of given size by allocating on the storage.
    fn create_chunk<Storage: ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        size: usize,
    ) -> ChunkHandle {
        let padding = calculate_padding(size);
        let effective_size = size + padding;

        let storage = storage.alloc(effective_size);
        let handle = ChunkHandle::new();
        let id = *handle.id();

        self.ring.push_chunk(id);

        self.chunks
            .insert(id, Chunk::new(storage, handle.clone(), Vec::new()));
        self.chunk_index.insert(id, size);

        handle
    }

    fn extend_max_memory<Storage: ComputeStorage>(&mut self, storage: &mut Storage, size: usize) {
        if self.debug {
            log::info!("Extend max memory ...");
        }

        let mut slices = Vec::<SliceUpdate>::new();
        let mut current_size = size;

        let chunks_sorted = self
            .chunk_index
            .find_by_size(0..self.max_chunk_size - 1)
            .map(Clone::clone)
            .collect::<Vec<_>>();

        let mut deallocations = HashSet::<ChunkId>::new();

        for chunk_id in chunks_sorted {
            let chunk = self.chunks.get(&chunk_id).unwrap();
            let chunk_id = *chunk.handle.id();
            let slices_ids = chunk.slices.clone();

            for slice_id in slices_ids {
                let slice = self.slices.get(&slice_id).unwrap();
                let size = slice.storage.size();

                let effective_size = slice.effective_size();
                current_size += effective_size;

                if current_size >= self.max_chunk_size {
                    let alloc_size = current_size - effective_size;
                    // let alloc_size = self.max_chunk_size;
                    self.move_to_new_chunk(alloc_size, storage, &mut slices, &mut deallocations);
                    current_size = effective_size;
                }

                slices.push(SliceUpdate { slice_id, size });
            }

            deallocations.insert(chunk_id);
        }

        if !slices.is_empty() {
            self.move_to_new_chunk(current_size, storage, &mut slices, &mut deallocations);
        } else {
            self.deallocate(storage, &mut deallocations);
        }
    }

    fn deallocate<Storage: ComputeStorage>(
        &mut self,
        storage: &mut Storage,
        deallocations: &mut HashSet<ChunkId>,
    ) {
        for id in deallocations.drain() {
            let mut chunk = self.chunks.remove(&id).unwrap();
            self.ring.remove_chunk(id);

            for slice in chunk.slices.drain(..) {
                let slice = self.slices.get(&slice).unwrap();
                let chunk_id = *slice.chunk.id();

                assert_ne!(chunk_id, id, "Chunk id should be updated");
            }

            self.chunk_index.remove(&id);
            storage.dealloc(chunk.storage.id);
        }
    }

    fn move_to_new_chunk<Storage: ComputeStorage>(
        &mut self,
        alloc_size: usize,
        storage: &mut Storage,
        slices: &mut Vec<SliceUpdate>,
        deallocations: &mut HashSet<ChunkId>,
    ) {
        let chunk = self.create_chunk(storage, alloc_size);
        let storage_id = self.chunks.get(chunk.id()).unwrap().storage.id.clone();
        let mut offset = 0;
        let mut slices_ids = Vec::new();

        for update in slices.drain(..) {
            let slice_id = update.slice_id;

            let slice = self.slices.get_mut(&slice_id).unwrap();
            let old_storage = slice.storage.clone();

            slice.chunk = chunk.clone();
            slice.storage = StorageHandle {
                id: storage_id.clone(),
                utilization: StorageUtilization::Slice {
                    offset,
                    size: update.size,
                },
            };
            storage.copy(&old_storage, &slice.storage);
            slices_ids.push(slice_id);

            offset += slice.effective_size();
        }

        let chunk = self.chunks.get_mut(chunk.id()).unwrap();
        chunk.slices = slices_ids;

        self.deallocate(storage, deallocations);
    }
}

fn calculate_padding(size: usize) -> usize {
    let remainder = size % BUFFER_ALIGNMENT;
    if remainder != 0 {
        BUFFER_ALIGNMENT - remainder
    } else {
        0
    }
}
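
Worked example: with BUFFER_ALIGNMENT = 32, a 100-byte slice gets 28 bytes of padding so its effective size (100 + 28 = 128) stays 32-aligned, while an already-aligned 128-byte slice needs none.

assert_eq!(calculate_padding(100), 28);
assert_eq!(calculate_padding(128), 0);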

impl MemorySlice for Slice {
    fn is_free(&self) -> bool {
        self.handle.is_free()
    }

    fn size(&self) -> usize {
        self.effective_size()
    }

    fn split(&mut self, offset_slice: usize) -> Self {
        let size_new = self.effective_size() - offset_slice;
        let offset_new = self.storage.offset() + offset_slice;

        let storage_new = StorageHandle {
            id: self.storage.id.clone(),
            utilization: StorageUtilization::Slice {
                offset: offset_new,
                size: size_new,
            },
        };

        self.storage.utilization = StorageUtilization::Slice {
            offset: self.storage.offset(),
            size: offset_slice,
        };

        if offset_new > 0 && size_new < MIN_SIZE_NEEDED_TO_OFFSET {
            panic!("tried to create slice of size {size_new} with an offset while the size needs to be at least {MIN_SIZE_NEEDED_TO_OFFSET} for offset support");
        }
        if offset_new % BUFFER_ALIGNMENT != 0 {
            panic!("slice with offset {offset_new} needs to be a multiple of {BUFFER_ALIGNMENT}");
        }
        let handle = SliceHandle::new();

        assert!(
            size_new >= BUFFER_ALIGNMENT,
            "Size new >= {BUFFER_ALIGNMENT}"
        );
        let padding = calculate_padding(size_new - BUFFER_ALIGNMENT);

        Slice::new(storage_new, handle, self.chunk.clone(), padding)
    }

    fn id(&self) -> SliceId {
        *self.handle.id()
    }
}

impl MemoryChunk<Slice> for Chunk {
    fn merge_next_slice(
        &mut self,
        from_slice_index: usize,
        slices: &mut HashMap<SliceId, Slice>,
    ) -> bool {
        let slice_id_current = self.slices.get(from_slice_index).unwrap();
        let slice_id_next = self.slices.get(from_slice_index + 1);
        let slice_id_next = match slice_id_next {
            Some(val) => val,
            None => return false,
        };

        let slice_next = slices.get(slice_id_next).unwrap();
        let is_free = slice_next.is_free();
        let size = slice_next.effective_size();

        let slice_current = slices.get_mut(slice_id_current).unwrap();

        if is_free {
            slice_current.storage.utilization = StorageUtilization::Slice {
                size: slice_current.effective_size() + size,
                offset: slice_current.storage.offset(),
            };
            slices.remove(slice_id_next);
            self.slices.remove(from_slice_index + 1);

            return true;
        }

        false
    }

    fn slice(&self, index: usize) -> Option<SliceId> {
        self.slices.get(index).copied()
    }

    fn insert_slice(&mut self, position: usize, slice_id: SliceId) {
        self.slices.insert(position, slice_id);
    }
}
@@ -0,0 +1,33 @@
use crate::memory_id_type;
use crate::memory_management::{MemoryBinding, MemoryHandle};

// The ChunkId allows keeping track of how many references there are to a specific chunk.
memory_id_type!(ChunkId, ChunkHandle);
// The SliceId allows keeping track of how many references there are to a specific slice.
memory_id_type!(SliceId, SliceHandle, SliceBinding);

/// A memory pool handle, referring to a slice of a chunk.
#[derive(Debug, Clone)]
pub struct MemoryPoolHandle {
    pub slice: SliceHandle,
}

/// Binding of the [memory pool handle](MemoryPoolHandle).
#[derive(Debug, Clone)]
pub struct MemoryPoolBinding {
    pub slice: SliceBinding,
}

impl MemoryBinding for MemoryPoolBinding {}

impl MemoryHandle<MemoryPoolBinding> for MemoryPoolHandle {
    fn can_mut(&self) -> bool {
        self.slice.can_mut()
    }

    fn binding(self) -> MemoryPoolBinding {
        MemoryPoolBinding {
            slice: self.slice.binding(),
        }
    }
}
@@ -0,0 +1,64 @@
use alloc::collections::BTreeMap;
use alloc::vec;
use alloc::vec::Vec;
use core::hash::Hash;
use hashbrown::HashMap;

/// Data structure that helps search items by size efficiently.
pub struct SearchIndex<T> {
    items_per_size: BTreeMap<usize, Vec<T>>,
    sizes_per_item: HashMap<T, usize>,
}

impl<T: PartialEq + Eq + Hash + Clone> SearchIndex<T> {
    /// Create a new item search index.
    pub fn new() -> Self {
        Self {
            items_per_size: BTreeMap::new(),
            sizes_per_item: HashMap::new(),
        }
    }

    /// Insert a new sized item into the search index.
    pub fn insert(&mut self, item: T, size: usize) {
        self.remove(&item);

        if let Some(values) = self.items_per_size.get_mut(&size) {
            values.push(item.clone())
        } else {
            self.items_per_size.insert(size, vec![item.clone()]);
        }
        self.sizes_per_item.insert(item, size);
    }

    /// Find items within the given size range.
    pub fn find_by_size(
        &self,
        range: core::ops::Range<usize>,
    ) -> impl DoubleEndedIterator<Item = &T> {
        self.items_per_size.range(range).flat_map(|a| a.1)
    }

    /// Remove an item from the index.
    pub fn remove(&mut self, item: &T) {
        let size = match self.sizes_per_item.remove(item) {
            Some(size) => size,
            None => return,
        };

        if let Some(values) = self.items_per_size.get_mut(&size) {
            let mut removed_index = None;

            for (i, v) in values.iter().enumerate() {
                if v == item {
                    removed_index = Some(i);
                    break;
                }
            }

            if let Some(index) = removed_index {
                values.remove(index);
            }
        }
    }
}
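
A usage sketch: because items are grouped by size in a `BTreeMap`, a range query walks matching sizes in ascending order.

let mut index = SearchIndex::new();
index.insert("small", 64);
index.insert("large", 4096);
// Everything strictly below 1024 bytes:
let found: Vec<_> = index.find_by_size(0..1024).collect();
assert_eq!(found, vec![&"small"]);
index.remove(&"small");
assert_eq!(index.find_by_size(0..1024).count(), 0);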
@@ -0,0 +1,9 @@
pub(crate) mod index;
mod ring;

mod base;
mod handle;

pub use base::*;
pub use handle::*;
pub use ring::*;
@@ -0,0 +1,416 @@
use alloc::vec::Vec;
use core::marker::PhantomData;
use hashbrown::HashMap;

use super::{ChunkId, SliceId};

#[derive(Debug)]
pub struct RingBuffer<C: MemoryChunk<S>, S: MemorySlice> {
    queue: Vec<ChunkId>,
    chunk_positions: HashMap<ChunkId, usize>,
    cursor_slice: usize,
    cursor_chunk: usize,
    _s: PhantomData<S>,
    _c: PhantomData<C>,
}

pub trait MemoryChunk<S: MemorySlice> {
    fn merge_next_slice(&mut self, slice_position: usize, slices: &mut HashMap<SliceId, S>)
        -> bool;
    fn slice(&self, index: usize) -> Option<SliceId>;
    fn insert_slice(&mut self, position: usize, slice_id: SliceId);
}

pub trait MemorySlice {
    fn is_free(&self) -> bool;
    fn size(&self) -> usize;
    fn split(&mut self, offset: usize) -> Self;
    fn id(&self) -> SliceId;
}

impl<C: MemoryChunk<S>, S: MemorySlice> RingBuffer<C, S> {
    pub fn new() -> Self {
        Self {
            queue: Vec::new(),
            chunk_positions: HashMap::new(),
            cursor_slice: 0,
            cursor_chunk: 0,
            _s: PhantomData,
            _c: PhantomData,
        }
    }

    pub fn push_chunk(&mut self, chunk_id: ChunkId) {
        self.queue.push(chunk_id);
        self.chunk_positions.insert(chunk_id, self.queue.len() - 1);
    }

    pub fn remove_chunk(&mut self, chunk_id: ChunkId) {
        if let Some(position) = self.chunk_positions.remove(&chunk_id) {
            self.queue.remove(position);
        }

        self.chunk_positions.clear();

        for (pos, id) in self.queue.iter().enumerate() {
            self.chunk_positions.insert(*id, pos);
        }
    }

    pub fn find_free_slice(
        &mut self,
        size: usize,
        chunks: &mut HashMap<ChunkId, C>,
        slices: &mut HashMap<SliceId, S>,
    ) -> Option<SliceId> {
        let max_second = self.cursor_chunk;
        let result = self.find_free_slice_in_all_chunks(size, chunks, slices, self.queue.len());

        if result.is_some() {
            return result;
        }

        self.cursor_chunk = 0;
        self.cursor_slice = 0;
        self.find_free_slice_in_all_chunks(size, chunks, slices, max_second)
    }

    fn find_free_slice_in_chunk(
        &mut self,
        size: usize,
        chunk: &mut C,
        slices: &mut HashMap<SliceId, S>,
        mut slice_index: usize,
    ) -> Option<(usize, SliceId)> {
        while let Some(slice_id) = chunk.slice(slice_index) {
            let slice = slices.get_mut(&slice_id).unwrap();

            let is_big_enough = slice.size() >= size;
            let is_free = slice.is_free();

            if is_big_enough && is_free {
                if slice.size() > size {
                    let new_slice = slice.split(size);
                    chunk.insert_slice(slice_index + 1, new_slice.id());
                    slices.insert(new_slice.id(), new_slice);
                }

                return Some((slice_index, slice_id));
            }

            if is_free && chunk.merge_next_slice(slice_index, slices) {
                continue;
            }

            slice_index += 1;
        }

        None
    }

    fn find_free_slice_in_all_chunks(
        &mut self,
        size: usize,
        chunks: &mut HashMap<ChunkId, C>,
        slices: &mut HashMap<SliceId, S>,
        max_cursor_position: usize,
    ) -> Option<SliceId> {
        let start = self.cursor_chunk;
        let end = usize::min(self.queue.len(), max_cursor_position);
        let mut slice_index = self.cursor_slice;

        for chunk_index in start..end {
            if chunk_index > start {
                slice_index = 0;
            }

            if let Some(id) = self.queue.get(chunk_index) {
                let chunk = chunks.get_mut(id).unwrap();
                let result = self.find_free_slice_in_chunk(size, chunk, slices, slice_index);

                if let Some((cursor_slice, slice)) = result {
                    self.cursor_slice = cursor_slice + 1;
                    self.cursor_chunk = chunk_index;
                    return Some(slice);
                }
            }
            self.cursor_chunk = chunk_index;
            self.cursor_slice = 0;
        }

        None
    }
}
|
||||
|
||||
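// Editor's note: a hedged sketch (not part of this commit) illustrating the
// two-pass ring scan above. It reuses the `TestChunk`/`TestSlice` stubs
// defined at the bottom of this file and shows the cursor advancing past each
// hit, then wrapping back to the first chunk once it reaches the end.
#[cfg(test)]
mod ring_scan_sketch {
    use super::stub::*;
    use super::*;
    use alloc::vec;

    #[test]
    fn wraps_around_to_earlier_chunks() {
        let mut ring = RingBuffer::<TestChunk, TestSlice>::new();

        // Two chunks, each holding a single free 100-byte slice.
        let slice_1 = TestSlice { id: SliceId { value: 0 }, is_free: true, size: 100 };
        let slice_2 = TestSlice { id: SliceId { value: 1 }, is_free: true, size: 100 };
        let chunk_1 = TestChunk { id: ChunkId { value: 0 }, slices: vec![SliceId { value: 0 }] };
        let chunk_2 = TestChunk { id: ChunkId { value: 1 }, slices: vec![SliceId { value: 1 }] };

        let mut slices = HashMap::from([(slice_1.id, slice_1), (slice_2.id, slice_2)]);
        let mut chunks = HashMap::from([(chunk_1.id, chunk_1), (chunk_2.id, chunk_2)]);

        ring.push_chunk(ChunkId { value: 0 });
        ring.push_chunk(ChunkId { value: 1 });

        // First hit lands in chunk 0 and moves the cursor past it.
        assert_eq!(
            ring.find_free_slice(100, &mut chunks, &mut slices).unwrap(),
            SliceId { value: 0 }
        );
        // The next scan resumes at chunk 1 instead of rescanning chunk 0.
        assert_eq!(
            ring.find_free_slice(100, &mut chunks, &mut slices).unwrap(),
            SliceId { value: 1 }
        );
        // With the cursor at the end of the queue, the first pass misses and
        // the search wraps back to chunk 0 on the second pass.
        assert_eq!(
            ring.find_free_slice(100, &mut chunks, &mut slices).unwrap(),
            SliceId { value: 0 }
        );
    }
}
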
#[cfg(test)]
mod tests {
    use super::stub::*;
    use super::*;
    use alloc::vec;

    #[test]
    fn simple_1() {
        let mut ring = RingBuffer::<TestChunk, TestSlice>::new();

        let slice_1 = new_slice(0, 100);
        let slice_2 = new_slice(1, 200);
        let chunk_1 = new_chunk(0, vec![0, 1]);

        let mut slices = HashMap::from([(slice_1.id, slice_1), (slice_2.id, slice_2)]);
        let mut chunks = HashMap::from([(chunk_1.id, chunk_1)]);

        ring.push_chunk(ChunkId { value: 0 });

        let slice = ring.find_free_slice(50, &mut chunks, &mut slices).unwrap();

        assert_eq!(slice, SliceId { value: 0 });
        assert_eq!(slices.get(&slice).unwrap().size, 50);
        assert_eq!(slices.len(), 3);
        assert_eq!(chunks.values().last().unwrap().slices.len(), 3);
    }

    #[test]
    fn simple_2() {
        let mut ring = RingBuffer::<TestChunk, TestSlice>::new();

        let slice_1 = new_slice(0, 100);
        let slice_2 = new_slice(1, 200);
        let chunk_1 = new_chunk(0, vec![0, 1]);

        let mut slices = HashMap::from([(slice_1.id, slice_1), (slice_2.id, slice_2)]);
        let mut chunks = HashMap::from([(chunk_1.id, chunk_1)]);

        ring.push_chunk(ChunkId { value: 0 });

        let slice = ring.find_free_slice(150, &mut chunks, &mut slices).unwrap();

        assert_eq!(slice, SliceId { value: 0 });
        assert_eq!(slices.get(&slice).unwrap().size, 150);
        assert_eq!(slices.len(), 2);
        assert_eq!(chunks.values().last().unwrap().slices.len(), 2);
    }

    #[test]
    fn multiple_chunks() {
        let mut ring = RingBuffer::<TestChunk, TestSlice>::new();

        let slice_1 = new_slice(0, 100);
        let slice_2 = new_slice(1, 200);
        let slice_3 = new_slice(2, 200);
        let slice_4 = new_slice(3, 200);
        let chunk_1 = new_chunk(0, vec![0, 1]);
        let chunk_2 = new_chunk(1, vec![2, 3]);

        let mut slices = HashMap::from([
            (slice_1.id, slice_1),
            (slice_2.id, slice_2),
            (slice_3.id, slice_3),
            (slice_4.id, slice_4),
        ]);
        let mut chunks = HashMap::from([(chunk_1.id, chunk_1), (chunk_2.id, chunk_2)]);

        ring.push_chunk(ChunkId { value: 0 });
        ring.push_chunk(ChunkId { value: 1 });

        slices.get_mut(&SliceId { value: 0 }).unwrap().is_free = true;
        slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = false;
        slices.get_mut(&SliceId { value: 3 }).unwrap().is_free = false;

        let slice = ring.find_free_slice(200, &mut chunks, &mut slices).unwrap();

        assert_eq!(slice, SliceId { value: 2 });

        let slice = ring.find_free_slice(100, &mut chunks, &mut slices).unwrap();

        assert_eq!(slice, SliceId { value: 0 });
    }

    #[test]
    fn find_free_slice_with_exact_fit() {
        let mut ring = RingBuffer::<TestChunk, TestSlice>::new();

        let slice_1 = new_slice(0, 100);
        let slice_2 = new_slice(1, 200);
        let chunk_1 = new_chunk(0, vec![0, 1]);

        let mut slices = HashMap::from([(slice_1.id, slice_1), (slice_2.id, slice_2)]);
        let mut chunks = HashMap::from([(chunk_1.id, chunk_1)]);

        ring.push_chunk(ChunkId { value: 0 });

        slices.get_mut(&SliceId { value: 0 }).unwrap().is_free = false;
        slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = true;

        let slice = ring.find_free_slice(200, &mut chunks, &mut slices).unwrap();

        assert_eq!(slice, SliceId { value: 1 });
        assert_eq!(slices.get(&slice).unwrap().size, 200);
        assert_eq!(slices.len(), 2);
        assert_eq!(chunks.values().last().unwrap().slices.len(), 2);
    }

    #[test]
    fn find_free_slice_with_merging() {
        let mut ring = RingBuffer::<TestChunk, TestSlice>::new();

        let slice_1 = new_slice(0, 100);
        let slice_2 = new_slice(1, 50);
        let slice_3 = new_slice(2, 100);
        let chunk_1 = new_chunk(0, vec![0, 1, 2]);

        let mut slices = HashMap::from([
            (slice_1.id, slice_1),
            (slice_2.id, slice_2),
            (slice_3.id, slice_3),
        ]);
        let mut chunks = HashMap::from([(chunk_1.id, chunk_1)]);

        ring.push_chunk(ChunkId { value: 0 });

        slices.get_mut(&SliceId { value: 0 }).unwrap().is_free = true;
        slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = true;
        slices.get_mut(&SliceId { value: 2 }).unwrap().is_free = true;

        let slice = ring.find_free_slice(250, &mut chunks, &mut slices).unwrap();

        assert_eq!(slice, SliceId { value: 0 });
        assert_eq!(slices.get(&slice).unwrap().size, 250);
        assert_eq!(slices.len(), 1);
        assert_eq!(chunks.values().last().unwrap().slices.len(), 1);
    }

    #[test]
    fn find_free_slice_with_multiple_chunks_and_merging() {
        let mut ring = RingBuffer::<TestChunk, TestSlice>::new();

        let slice_1 = new_slice(0, 50);
        let slice_2 = new_slice(1, 50);
        let chunk_1 = new_chunk(0, vec![0, 1]);

        let slice_3 = new_slice(2, 100);
        let slice_4 = new_slice(3, 50);
        let chunk_2 = new_chunk(1, vec![2, 3]);

        let mut slices = HashMap::from([
            (slice_1.id, slice_1),
            (slice_2.id, slice_2),
            (slice_3.id, slice_3),
            (slice_4.id, slice_4),
        ]);
        let mut chunks = HashMap::from([(chunk_1.id, chunk_1), (chunk_2.id, chunk_2)]);

        ring.push_chunk(ChunkId { value: 0 });
        ring.push_chunk(ChunkId { value: 1 });

        slices.get_mut(&SliceId { value: 0 }).unwrap().is_free = true;
        slices.get_mut(&SliceId { value: 1 }).unwrap().is_free = true;
        slices.get_mut(&SliceId { value: 2 }).unwrap().is_free = true;
        slices.get_mut(&SliceId { value: 3 }).unwrap().is_free = true;

        let slice = ring.find_free_slice(150, &mut chunks, &mut slices).unwrap();

        assert_eq!(slices.get(&slice).unwrap().size, 150);
        assert_eq!(slices.len(), 2);
        assert_eq!(chunks.values().last().unwrap().slices.len(), 1);
    }

    fn new_slice(id: usize, size: usize) -> TestSlice {
        TestSlice {
            id: SliceId { value: id },
            is_free: true,
            size,
        }
    }

    fn new_chunk(id: usize, slices: Vec<usize>) -> TestChunk {
        TestChunk {
            id: ChunkId { value: id },
            slices: slices.into_iter().map(|i| SliceId { value: i }).collect(),
        }
    }
}

#[cfg(test)]
mod stub {
    use super::*;
    use burn_common::*;

    #[derive(Debug)]
    pub struct TestChunk {
        pub id: ChunkId,
        pub slices: Vec<SliceId>,
    }

    #[derive(Debug)]
    pub struct TestSlice {
        pub id: SliceId,
        pub is_free: bool,
        pub size: usize,
    }

    impl MemorySlice for TestSlice {
        fn is_free(&self) -> bool {
            self.is_free
        }

        fn size(&self) -> usize {
            self.size
        }

        fn split(&mut self, offset: usize) -> Self {
            let size_remained = self.size - offset;
            self.size = offset;

            Self {
                id: SliceId {
                    value: rand::gen_random(),
                },
                is_free: true,
                size: size_remained,
            }
        }

        fn id(&self) -> SliceId {
            self.id
        }
    }

    impl MemoryChunk<TestSlice> for TestChunk {
        fn merge_next_slice(
            &mut self,
            from_slice_index: usize,
            slices: &mut HashMap<SliceId, TestSlice>,
        ) -> bool {
            let slice_id_current = self.slices.get(from_slice_index).unwrap();
            let slice_id_next = self.slices.get(from_slice_index + 1);
            let slice_id_next = match slice_id_next {
                Some(val) => val,
                None => return false,
            };

            let slice_next = slices.get(slice_id_next).unwrap();
            let is_free = slice_next.is_free;
            let size = slice_next.size;

            let slice_current = slices.get_mut(slice_id_current).unwrap();

            if is_free {
                slice_current.size += size;
                slices.remove(slice_id_next);
                self.slices.remove(from_slice_index + 1);

                return true;
            }

            false
        }

        fn slice(&self, index: usize) -> Option<SliceId> {
            self.slices.get(index).copied()
        }

        fn insert_slice(&mut self, position: usize, slice_id: SliceId) {
            self.slices.insert(position, slice_id);
        }
    }
}

@@ -1,4 +1,7 @@
pub(crate) mod memory_pool;

mod base;

pub use base::*;

/// Dynamic memory management strategy.

@@ -200,7 +200,7 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> for SimpleMemoryManagement
    /// a handle to the reserved memory.
    ///
    /// Also cleans up, removing unused slices and, if permitted by the deallocation strategy, unused chunks.
    fn reserve(&mut self, size: usize) -> Self::Handle {
    fn reserve<Sync: FnOnce()>(&mut self, size: usize, _sync: Sync) -> Self::Handle {
        self.cleanup_slices();

        let handle = self.reserve_algorithm(size);

@@ -212,7 +212,7 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> for SimpleMemoryManagement
        handle
    }

    fn alloc(&mut self, size: usize) -> Self::Handle {
    fn alloc<Sync: FnOnce()>(&mut self, size: usize, _sync: Sync) -> Self::Handle {
        self.create_chunk(size)
    }

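// Editor's note: a hedged sketch of how callers adapt to the new `Sync`
// parameter. `SimpleMemoryManagement` never invokes the closure (note the
// `_sync` binding above), so CPU-only callers pass a no-op:
//
//     let handle = memory_management.reserve(1024, || {});
//
// GPU servers instead pass a closure that flushes pending work and waits for
// the device, so freed memory is only reused once in-flight kernels are done;
// see the CUDA and wgpu servers later in this commit.
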
@@ -387,6 +387,12 @@ mod tests {
        storage::BytesStorage,
    };

    impl<Storage: ComputeStorage> SimpleMemoryManagement<Storage> {
        fn reserve_no_sync(&mut self, size: usize) -> SimpleHandle {
            self.reserve(size, || {})
        }
    }

    #[test]
    fn can_mut_with_single_tensor_reference() {
        let mut memory_management = SimpleMemoryManagement::new(

@@ -429,8 +435,8 @@ mod tests {
            SliceStrategy::Never,
        );
        let chunk_size = 4;
        let _chunk_handle = memory_management.reserve(chunk_size);
        let _new_handle = memory_management.reserve(chunk_size);
        let _chunk_handle = memory_management.reserve_no_sync(chunk_size);
        let _new_handle = memory_management.reserve_no_sync(chunk_size);

        assert_eq!(memory_management.chunks.len(), 2);
    }

@@ -443,7 +449,7 @@ mod tests {
            SliceStrategy::Never,
        );
        let chunk_size = 4;
        let chunk_handle = memory_management.reserve(chunk_size);
        let chunk_handle = memory_management.reserve_no_sync(chunk_size);
        drop(chunk_handle);
        memory_management.cleanup_chunks();

@@ -502,7 +508,7 @@ mod tests {
            DeallocStrategy::Never,
            SliceStrategy::Ratio(0.5),
        );
        let handle = memory_management.reserve(10);
        let handle = memory_management.reserve_no_sync(10);

        let other_ref = handle.clone();

@@ -518,7 +524,7 @@ mod tests {
            DeallocStrategy::Never,
            SliceStrategy::Ratio(0.5),
        );
        let chunk = memory_management.reserve(10);
        let chunk = memory_management.reserve_no_sync(10);

        if let super::SimpleHandle::Slice(_) = chunk {
            panic!("Should be a chunk.")

@@ -526,7 +532,7 @@ mod tests {

        drop(chunk);

        let slice = memory_management.reserve(8);
        let slice = memory_management.reserve_no_sync(8);

        if let super::SimpleHandle::Chunk(_) = &slice {
            panic!("Should be a slice.")

@@ -18,7 +18,7 @@ pub enum StorageUtilization {
}

/// Contains the [storage id](StorageId) of a resource and the way it is used.
#[derive(new, Debug)]
#[derive(new, Clone, Debug)]
pub struct StorageHandle {
    /// Storage id.
    pub id: StorageId,

@@ -58,4 +58,7 @@ pub trait ComputeStorage: Send {

    /// Deallocates the memory pointed to by the given storage id.
    fn dealloc(&mut self, id: StorageId);

    /// Copies the bytes of the `from` resource into the `to` resource.
    fn copy(&mut self, from: &StorageHandle, to: &StorageHandle);
}

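// Editor's note: a hedged sketch of the intended use of `copy`. A pool that
// defragments can allocate a fresh destination and move bytes device-side
// without a round trip through the host (assuming the caller holds a
// `src: StorageHandle` and a `storage: impl ComputeStorage`):
//
//     let dst = storage.alloc(src.size());
//     storage.copy(&src, &dst);
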
@@ -90,6 +90,23 @@ impl ComputeStorage for BytesStorage {
            }
        }
    }

    fn copy(&mut self, from: &StorageHandle, to: &StorageHandle) {
        assert_eq!(from.size(), to.size());

        let input = self.get(from);
        let output = self.get(to);

        for i in 0..from.size() {
            // Source pointer: base of the `from` resource plus its offset.
            let ptr_from = input.ptr.wrapping_add(i + from.offset());
            // Destination pointer: base of the `to` resource plus its offset.
            let ptr_to = output.ptr.wrapping_add(i + to.offset());

            // Copy one byte from the source into the destination.
            unsafe { *ptr_to = *ptr_from }
        }
    }
}

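// Editor's note: the per-byte loop above is easy to audit; an equivalent bulk
// copy, assuming the two byte ranges never overlap, would be:
//
//     unsafe {
//         core::ptr::copy_nonoverlapping(
//             input.ptr.add(from.offset()),
//             output.ptr.add(to.offset()),
//             from.size(),
//         );
//     }
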
#[cfg(test)]

@@ -37,7 +37,7 @@ where
    }

    fn create(&mut self, data: &[u8]) -> Handle<Self> {
        let handle = self.memory_management.reserve(data.len());
        let handle = self.memory_management.reserve(data.len(), || {});
        let resource = self.memory_management.get(handle.clone().binding());

        let bytes = resource.write();

@@ -50,7 +50,7 @@ where
    }

    fn empty(&mut self, size: usize) -> Handle<Self> {
        Handle::new(self.memory_management.reserve(size))
        Handle::new(self.memory_management.reserve(size, || {}))
    }

    fn execute(&mut self, kernel: Self::Kernel, bindings: Vec<Binding<Self>>) {

@@ -74,7 +74,9 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {

    fn create(&mut self, data: &[u8]) -> server::Handle<Self> {
        let ctx = self.get_context();
        let handle = ctx.memory_management.reserve(data.len());
        let handle = ctx.memory_management.reserve(data.len(), || unsafe {
            cudarc::driver::result::stream::synchronize(ctx.stream).unwrap();
        });
        let handle = server::Handle::new(handle);
        let binding = handle.clone().binding().memory;
        let resource = ctx.memory_management.get(binding);

@@ -88,7 +90,9 @@ impl<MM: MemoryManagement<CudaStorage>> ComputeServer for CudaServer<MM> {

    fn empty(&mut self, size: usize) -> server::Handle<Self> {
        let ctx = self.get_context();
        let handle = ctx.memory_management.reserve(size);
        let handle = ctx.memory_management.reserve(size, || unsafe {
            cudarc::driver::result::stream::synchronize(ctx.stream).unwrap();
        });
        server::Handle::new(handle)
    }

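// Editor's note: the sync closure blocks on the CUDA stream before memory is
// handed back out, so any kernel still reading a buffer that the allocator
// wants to reuse has finished first. The memory-management strategy decides
// when the closure actually runs; `SimpleMemoryManagement` never calls it.
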
@@ -115,4 +115,18 @@ impl ComputeStorage for CudaStorage {
    fn dealloc(&mut self, id: StorageId) {
        self.deallocations.push(id);
    }

    fn copy(&mut self, from: &StorageHandle, to: &StorageHandle) {
        let num_bytes = from.size();

        unsafe {
            cudarc::driver::result::memcpy_dtod_async(
                self.get(to).ptr,
                self.get(from).ptr,
                num_bytes,
                self.stream,
            )
            .unwrap();
        }
    }
}

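// Editor's note: `memcpy_dtod_async` is enqueued on the same stream as the
// compute kernels, so the device-to-device copy is stream-ordered with respect
// to them and needs no explicit synchronization here.
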
@@ -214,7 +214,17 @@ where
    /// This is important, otherwise the compute passes are going to be too small and we won't be able to
    /// fully utilize the GPU.
    fn create(&mut self, data: &[u8]) -> server::Handle<Self> {
        let handle = server::Handle::new(self.memory_management.reserve(data.len()));
        let handle = server::Handle::new(self.memory_management.reserve(data.len(), || {
            flush_tasks(
                &mut self.encoder,
                &self.queue,
                &self.device,
                &mut self.tasks_count,
                &mut self.staging_belt,
            );
            self.device.poll(wgpu::Maintain::Wait);
        }));

        let non_zero_len = NonZeroU64::new(data.len() as u64);

        // If there's nothing to copy, we don't need to do any work here.

@@ -230,7 +240,7 @@ where
        let mut write_buf = self.staging_belt.write_buffer(
            &mut self.encoder,
            &resource.buffer,
            0,
            resource.offset(),
            len,
            &self.device,
        );

@@ -256,7 +266,16 @@ where
    }

    fn empty(&mut self, size: usize) -> server::Handle<Self> {
        server::Handle::new(self.memory_management.reserve(size))
        server::Handle::new(self.memory_management.reserve(size, || {
            flush_tasks(
                &mut self.encoder,
                &self.queue,
                &self.device,
                &mut self.tasks_count,
                &mut self.staging_belt,
            );
            self.device.poll(wgpu::Maintain::Wait);
        }))
    }

    fn execute(&mut self, kernel: Self::Kernel, bindings: Vec<server::Binding<Self>>) {

@@ -292,24 +311,41 @@ where
    }

    fn sync(&mut self, sync_type: SyncType) {
        // Flush commands to the queue.
        self.staging_belt.finish();

        let mut new_encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
        core::mem::swap(&mut new_encoder, &mut self.encoder);

        self.queue.submit(Some(new_encoder.finish()));
        self.tasks_count = 0;
        flush_tasks(
            &mut self.encoder,
            &self.queue,
            &self.device,
            &mut self.tasks_count,
            &mut self.staging_belt,
        );

        // Cleanup allocations and deallocations.
        self.memory_management.storage().perform_deallocations();

        self.staging_belt.recall();

        if sync_type == SyncType::Wait {
            self.device.poll(wgpu::Maintain::Wait);
        }
    }
}

/// Flush tasks using the [command encoder](CommandEncoder).
///
/// This implementation is decoupled from both the [server](WgpuServer) and [memory management](MemoryManagement).
/// This decoupling allows for safe usage within sync callbacks when allocating memory buffers.
fn flush_tasks(
    encoder: &mut CommandEncoder,
    queue: &wgpu::Queue,
    device: &wgpu::Device,
    tasks_count: &mut usize,
    staging_belt: &mut StagingBelt,
) {
    staging_belt.finish();

    let mut new_encoder =
        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
    core::mem::swap(&mut new_encoder, encoder);

    queue.submit(Some(new_encoder.finish()));
    *tasks_count = 0;
    staging_belt.recall();
}

@@ -7,6 +7,7 @@ pub struct WgpuStorage {
    memory: HashMap<StorageId, Arc<wgpu::Buffer>>,
    deallocations: Vec<StorageId>,
    device: Arc<wgpu::Device>,
    queue: Arc<wgpu::Queue>,
}

impl core::fmt::Debug for WgpuStorage {

@@ -67,11 +68,12 @@ pub enum WgpuResourceKind {
/// Keeps actual wgpu buffer references in a hashmap with ids as key.
impl WgpuStorage {
    /// Create a new storage for the given [device](wgpu::Device) and [queue](wgpu::Queue).
    pub fn new(device: Arc<wgpu::Device>) -> Self {
    pub fn new(device: Arc<wgpu::Device>, queue: Arc<wgpu::Queue>) -> Self {
        Self {
            memory: HashMap::new(),
            deallocations: Vec::new(),
            device,
            queue,
        }
    }

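// Editor's note: the storage now owns a queue handle because `copy` below
// records and submits its own command encoder, instead of routing the copy
// through the server's shared encoder.
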
@@ -121,4 +123,23 @@ impl ComputeStorage for WgpuStorage {
    fn dealloc(&mut self, id: StorageId) {
        self.deallocations.push(id);
    }

    fn copy(&mut self, from: &StorageHandle, to: &StorageHandle) {
        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });

        let from = self.get(from);
        let to = self.get(to);

        encoder.copy_buffer_to_buffer(
            &from.buffer,
            from.offset(),
            &to.buffer,
            to.offset(),
            to.size(),
        );

        self.queue.submit(Some(encoder.finish()));
    }
}

@@ -156,7 +156,7 @@ fn create_client(
    WgpuServer<SimpleMemoryManagement<WgpuStorage>>,
    MutexComputeChannel<WgpuServer<SimpleMemoryManagement<WgpuStorage>>>,
> {
    let storage = WgpuStorage::new(device_wgpu.clone());
    let storage = WgpuStorage::new(device_wgpu.clone(), queue.clone());
    let memory_management =
        SimpleMemoryManagement::new(storage, options.dealloc_strategy, options.slice_strategy);
    let server = WgpuServer::new(memory_management, device_wgpu, queue, options.tasks_max);