mirror of https://github.com/tracel-ai/burn.git
Dynamic memory management preset + updated wgpu buffer memory management (#1962)
--------- Co-authored-by: mepatrick73 <pameu17@ulaval.ca>
This commit is contained in:
parent
5236e12c81
commit
51aea94a30
|
@ -1,15 +1,22 @@
|
|||
use std::collections::LinkedList;
|
||||
|
||||
use burn_compute::{
|
||||
memory_management::{dynamic::DynamicMemoryManagement, MemoryManagement},
|
||||
memory_management::{
|
||||
dynamic::{DynamicMemoryManagement, DynamicMemoryManagementOptions},
|
||||
MemoryManagement,
|
||||
},
|
||||
storage::BytesStorage,
|
||||
};
|
||||
|
||||
const MB: usize = 1024 * 1024;
|
||||
|
||||
fn main() {
|
||||
let start = std::time::Instant::now();
|
||||
let storage = BytesStorage::default();
|
||||
let mut mm = DynamicMemoryManagement::new(storage);
|
||||
let mut mm = DynamicMemoryManagement::new(
|
||||
storage,
|
||||
DynamicMemoryManagementOptions::preset(2048 * MB, 32),
|
||||
);
|
||||
let mut handles = LinkedList::new();
|
||||
for _ in 0..100 * 2048 {
|
||||
if handles.len() >= 4000 {
|
||||
|
|
|
@ -3,39 +3,107 @@ use super::memory_pool::{
|
|||
SmallMemoryPool,
|
||||
};
|
||||
use crate::storage::ComputeStorage;
|
||||
use alloc::vec::Vec;
|
||||
|
||||
use super::MemoryManagement;
|
||||
|
||||
/// Reserves and keeps track of chunks of memory in the storage, and slices upon these chunks.
|
||||
pub struct DynamicMemoryManagement<Storage> {
|
||||
min_chunk_alignment_offset: usize,
|
||||
small_memory_pool: SmallMemoryPool,
|
||||
small_medium_memory_pool: MemoryPool,
|
||||
medium_memory_pool: MemoryPool,
|
||||
main_memory_pool: MemoryPool,
|
||||
pools: Vec<MemoryPool>,
|
||||
options: Vec<MemoryPoolOptions>,
|
||||
storage: Storage,
|
||||
}
|
||||
|
||||
/// Options to initialize a [dynamic memory management](DynamicMemoryManagement).
///
/// Build it either field by field through the derived `new` constructor or from
/// device limits with `preset`.
|
||||
#[derive(new, Debug)]
|
||||
pub struct DynamicMemoryManagementOptions {
|
||||
/// One entry per memory pool to create; the manager sorts them by
/// `slice_max_size` before building the pools.
pools: Vec<MemoryPoolOptions>,
|
||||
/// Alignment, in bytes, handed to every pool as its buffer alignment.
/// Requests no larger than this value are served by the small memory pool.
min_chunk_alignment_offset: usize,
|
||||
}
|
||||
|
||||
/// Options to create a memory pool.
///
/// Each instance describes one pool of the dynamic memory management.
|
||||
#[derive(Debug)]
|
||||
pub struct MemoryPoolOptions {
|
||||
/// The amount of bytes used for each chunk in the memory pool.
|
||||
// The pool is built with a fixed rounding strategy of this size, so every
// chunk it allocates has exactly this size (before alignment padding).
pub chunk_size: usize,
|
||||
/// The number of chunks allocated directly at creation.
|
||||
///
|
||||
/// Useful when you know in advance how much memory you'll need.
|
||||
pub chunk_num_prealloc: usize,
|
||||
/// The max size in bytes a slice can take in the pool.
|
||||
// Requests are routed to the first pool, in ascending order of this value,
// whose limit is greater than or equal to the requested size.
pub slice_max_size: usize,
|
||||
}
|
||||
|
||||
impl DynamicMemoryManagementOptions {
|
||||
/// Creates the options from device limits.
|
||||
pub fn preset(max_chunk_size: usize, min_chunk_alignment_offset: usize) -> Self {
|
||||
// Rounding down to a factor of 8.
|
||||
let max_chunk_size = (max_chunk_size / 8) * 8;
|
||||
|
||||
const MB: usize = 1024 * 1024;
|
||||
|
||||
let mut pools = Vec::new();
|
||||
|
||||
pools.push(MemoryPoolOptions {
|
||||
chunk_size: max_chunk_size,
|
||||
chunk_num_prealloc: 0,
|
||||
slice_max_size: max_chunk_size,
|
||||
});
|
||||
|
||||
let mut current = max_chunk_size;
|
||||
|
||||
while current >= 32 * MB {
|
||||
current /= 4;
|
||||
|
||||
pools.push(MemoryPoolOptions {
|
||||
chunk_size: current,
|
||||
chunk_num_prealloc: 0,
|
||||
// Creating max slices lower than the chunk size reduces fragmentation.
|
||||
slice_max_size: current / 2usize.pow(pools.len() as u32),
|
||||
});
|
||||
}
|
||||
|
||||
Self {
|
||||
pools,
|
||||
min_chunk_alignment_offset,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<Storage: ComputeStorage> DynamicMemoryManagement<Storage> {
|
||||
/// Creates a new instance using the given storage and memory management options.
|
||||
pub fn new(storage: Storage) -> Self {
|
||||
let main_memory_pool = MemoryPool::new(
|
||||
MemoryExtensionStrategy::new_period_tick(10),
|
||||
RoundingStrategy::FixedAmount(1024 * 1024 * 1024),
|
||||
);
|
||||
let medium_memory_pool = MemoryPool::new(
|
||||
MemoryExtensionStrategy::Never,
|
||||
RoundingStrategy::FixedAmount(1024 * 1024 * 200),
|
||||
);
|
||||
let small_medium_memory_pool = MemoryPool::new(
|
||||
MemoryExtensionStrategy::Never,
|
||||
RoundingStrategy::FixedAmount(1024 * 1024 * 2),
|
||||
);
|
||||
let small_memory_pool = SmallMemoryPool::new();
|
||||
pub fn new(mut storage: Storage, mut options: DynamicMemoryManagementOptions) -> Self {
|
||||
options
|
||||
.pools
|
||||
.sort_by(|pool1, pool2| usize::cmp(&pool1.slice_max_size, &pool2.slice_max_size));
|
||||
|
||||
let min_chunk_alignment_offset = options.min_chunk_alignment_offset;
|
||||
|
||||
let pools = options
|
||||
.pools
|
||||
.iter()
|
||||
.map(|option| {
|
||||
let mut pool = MemoryPool::new(
|
||||
MemoryExtensionStrategy::Never,
|
||||
RoundingStrategy::FixedAmount(option.chunk_size),
|
||||
min_chunk_alignment_offset,
|
||||
);
|
||||
|
||||
for _ in 0..option.chunk_num_prealloc {
|
||||
pool.alloc(&mut storage, option.chunk_size, || {});
|
||||
}
|
||||
|
||||
pool
|
||||
})
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
small_memory_pool,
|
||||
small_medium_memory_pool,
|
||||
main_memory_pool,
|
||||
medium_memory_pool,
|
||||
min_chunk_alignment_offset,
|
||||
small_memory_pool: SmallMemoryPool::new(min_chunk_alignment_offset),
|
||||
pools,
|
||||
options: options.pools,
|
||||
storage,
|
||||
}
|
||||
}
|
||||
|
@ -62,50 +130,45 @@ impl<Storage: ComputeStorage> MemoryManagement<Storage> for DynamicMemoryManagem
|
|||
return handle;
|
||||
}
|
||||
|
||||
if let Some(handle) = self
|
||||
.small_medium_memory_pool
|
||||
.get(&mut self.storage, &binding)
|
||||
{
|
||||
return handle;
|
||||
for pool in &mut self.pools {
|
||||
if let Some(handle) = pool.get(&mut self.storage, &binding) {
|
||||
return handle;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(handle) = self.medium_memory_pool.get(&mut self.storage, &binding) {
|
||||
return handle;
|
||||
}
|
||||
|
||||
if let Some(handle) = self.main_memory_pool.get(&mut self.storage, &binding) {
|
||||
return handle;
|
||||
}
|
||||
|
||||
panic!("No handle found in the small and main memory pool");
|
||||
panic!("No handle found in memory pools");
|
||||
}
|
||||
|
||||
fn reserve<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
|
||||
if size <= 32 {
|
||||
self.small_memory_pool
|
||||
.reserve(&mut self.storage, size, sync)
|
||||
} else if size <= 2 * 1024 * 1024 {
|
||||
self.small_medium_memory_pool
|
||||
.reserve(&mut self.storage, size, sync)
|
||||
} else if size < 200 * 1024 * 1024 {
|
||||
self.medium_memory_pool
|
||||
.reserve(&mut self.storage, size, sync)
|
||||
} else {
|
||||
self.main_memory_pool.reserve(&mut self.storage, size, sync)
|
||||
if size <= self.min_chunk_alignment_offset {
|
||||
return self
|
||||
.small_memory_pool
|
||||
.reserve(&mut self.storage, size, sync);
|
||||
}
|
||||
|
||||
for (index, option) in self.options.iter().enumerate() {
|
||||
if size <= option.slice_max_size {
|
||||
let pool = &mut self.pools[index];
|
||||
return pool.reserve(&mut self.storage, size, sync);
|
||||
}
|
||||
}
|
||||
|
||||
panic!("No memory pool big enough to reserve {size} bytes.");
|
||||
}
|
||||
|
||||
fn alloc<Sync: FnOnce()>(&mut self, size: usize, sync: Sync) -> Self::Handle {
|
||||
if size <= 32 {
|
||||
self.small_memory_pool.alloc(&mut self.storage, size, sync)
|
||||
} else if size <= 2 * 1024 * 1024 {
|
||||
self.small_medium_memory_pool
|
||||
.alloc(&mut self.storage, size, sync)
|
||||
} else if size <= 200 * 1024 * 1024 {
|
||||
self.medium_memory_pool.alloc(&mut self.storage, size, sync)
|
||||
} else {
|
||||
self.main_memory_pool.alloc(&mut self.storage, size, sync)
|
||||
if size <= self.min_chunk_alignment_offset {
|
||||
return self.small_memory_pool.alloc(&mut self.storage, size, sync);
|
||||
}
|
||||
|
||||
for (index, option) in self.options.iter().enumerate() {
|
||||
if size <= option.slice_max_size {
|
||||
let pool = &mut self.pools[index];
|
||||
return pool.alloc(&mut self.storage, size, sync);
|
||||
}
|
||||
}
|
||||
|
||||
panic!("No memory pool big enough to alloc {size} bytes.");
|
||||
}
|
||||
|
||||
fn dealloc(&mut self, _binding: Self::Binding) {
|
||||
|
|
|
@ -17,6 +17,7 @@ pub struct MemoryPool {
|
|||
ring: RingBuffer<Chunk, Slice>,
|
||||
recently_added_chunks: Vec<ChunkId>,
|
||||
recently_allocated_size: usize,
|
||||
buffer_alignment: usize,
|
||||
}
|
||||
|
||||
struct SliceUpdate {
|
||||
|
@ -111,33 +112,16 @@ impl Slice {
|
|||
}
|
||||
|
||||
const MIN_SIZE_NEEDED_TO_OFFSET: usize = 16;
|
||||
const BUFFER_ALIGNMENT: usize = 32;
|
||||
const MB: usize = 1024 * 1024;
|
||||
|
||||
pub enum RoundingStrategy {
|
||||
FixedAmount(usize),
|
||||
#[allow(unused)]
|
||||
RoundUp,
|
||||
#[allow(unused)]
|
||||
None,
|
||||
}
|
||||
|
||||
impl RoundingStrategy {
|
||||
fn alloc_size(&self, size: usize) -> usize {
|
||||
match self {
|
||||
RoundingStrategy::RoundUp => {
|
||||
if size < BUFFER_ALIGNMENT {
|
||||
return BUFFER_ALIGNMENT;
|
||||
}
|
||||
if size < MB {
|
||||
2 * MB
|
||||
} else if size < 10 * MB {
|
||||
return 20 * MB;
|
||||
} else {
|
||||
let factor = (size + (2 * MB - 1)) / (2 * MB);
|
||||
factor * 2 * MB
|
||||
}
|
||||
}
|
||||
RoundingStrategy::FixedAmount(chunk_size) => {
|
||||
assert!(*chunk_size >= size);
|
||||
*chunk_size
|
||||
|
@ -148,6 +132,7 @@ impl RoundingStrategy {
|
|||
}
|
||||
|
||||
/// The strategy defines the frequency at which merging of free slices (defragmentation) occurs
|
||||
#[allow(unused)]
|
||||
#[derive(Debug)]
|
||||
pub enum MemoryExtensionStrategy {
|
||||
/// Once every n calls to reserve.
|
||||
|
@ -161,6 +146,7 @@ pub enum MemoryExtensionStrategy {
|
|||
Never,
|
||||
}
|
||||
|
||||
#[allow(unused)]
|
||||
impl MemoryExtensionStrategy {
|
||||
/// Create a new strategy with the given period.
|
||||
pub fn new_period_tick(period: usize) -> Self {
|
||||
|
@ -183,6 +169,7 @@ impl MemoryPool {
|
|||
pub fn new(
|
||||
merging_strategy: MemoryExtensionStrategy,
|
||||
alloc_strategy: RoundingStrategy,
|
||||
buffer_alignment: usize,
|
||||
) -> Self {
|
||||
Self {
|
||||
chunks: HashMap::new(),
|
||||
|
@ -190,9 +177,10 @@ impl MemoryPool {
|
|||
memory_extension_strategy: merging_strategy,
|
||||
rounding: alloc_strategy,
|
||||
chunk_index: SearchIndex::new(),
|
||||
ring: RingBuffer::new(),
|
||||
ring: RingBuffer::new(buffer_alignment),
|
||||
recently_added_chunks: Vec::new(),
|
||||
recently_allocated_size: 0,
|
||||
buffer_alignment,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -335,7 +323,7 @@ impl MemoryPool {
|
|||
return None;
|
||||
}
|
||||
|
||||
let padding = calculate_padding(size);
|
||||
let padding = calculate_padding(size, self.buffer_alignment);
|
||||
let effective_size = size + padding;
|
||||
|
||||
let slice_id =
|
||||
|
@ -364,9 +352,10 @@ impl MemoryPool {
|
|||
/// Creates a slice of size `size` upon the given chunk with the given offset.
|
||||
fn create_slice(&self, offset: usize, size: usize, handle_chunk: ChunkHandle) -> Slice {
|
||||
assert_eq!(
|
||||
offset % BUFFER_ALIGNMENT,
|
||||
offset % self.buffer_alignment,
|
||||
0,
|
||||
"slice with offset {offset} needs to be a multiple of {BUFFER_ALIGNMENT}"
|
||||
"slice with offset {offset} needs to be a multiple of {}",
|
||||
self.buffer_alignment
|
||||
);
|
||||
if offset > 0 && size < MIN_SIZE_NEEDED_TO_OFFSET {
|
||||
panic!("tried to create slice of size {size} with an offset while the size needs to atleast be of size {MIN_SIZE_NEEDED_TO_OFFSET} for offset support");
|
||||
|
@ -379,7 +368,7 @@ impl MemoryPool {
|
|||
utilization: StorageUtilization::Slice { offset, size },
|
||||
};
|
||||
|
||||
let padding = calculate_padding(size);
|
||||
let padding = calculate_padding(size, self.buffer_alignment);
|
||||
|
||||
Slice::new(storage, handle, chunk.handle.clone(), padding)
|
||||
}
|
||||
|
@ -390,7 +379,7 @@ impl MemoryPool {
|
|||
storage: &mut Storage,
|
||||
size: usize,
|
||||
) -> ChunkHandle {
|
||||
let padding = calculate_padding(size);
|
||||
let padding = calculate_padding(size, self.buffer_alignment);
|
||||
let effective_size = size + padding;
|
||||
|
||||
let storage = storage.alloc(effective_size);
|
||||
|
@ -496,7 +485,7 @@ impl MemoryPool {
|
|||
}
|
||||
let chunk_size = chunk.storage.size();
|
||||
let last_slice_size = chunk_size - offset;
|
||||
assert_eq!(last_slice_size % BUFFER_ALIGNMENT, 0);
|
||||
assert_eq!(last_slice_size % self.buffer_alignment, 0);
|
||||
if last_slice_size != 0 {
|
||||
self.create_slice(offset, last_slice_size, chunk_handle);
|
||||
}
|
||||
|
@ -505,10 +494,10 @@ impl MemoryPool {
|
|||
}
|
||||
}
|
||||
|
||||
fn calculate_padding(size: usize) -> usize {
|
||||
let remainder = size % BUFFER_ALIGNMENT;
|
||||
fn calculate_padding(size: usize, buffer_alignment: usize) -> usize {
|
||||
let remainder = size % buffer_alignment;
|
||||
if remainder != 0 {
|
||||
BUFFER_ALIGNMENT - remainder
|
||||
buffer_alignment - remainder
|
||||
} else {
|
||||
0
|
||||
}
|
||||
|
@ -523,7 +512,7 @@ impl MemorySlice for Slice {
|
|||
self.effective_size()
|
||||
}
|
||||
|
||||
fn split(&mut self, offset_slice: usize) -> Option<Self> {
|
||||
fn split(&mut self, offset_slice: usize, buffer_alignment: usize) -> Option<Self> {
|
||||
let size_new = self.effective_size() - offset_slice;
|
||||
let offset_new = self.storage.offset() + offset_slice;
|
||||
let old_size = self.effective_size();
|
||||
|
@ -544,22 +533,22 @@ impl MemorySlice for Slice {
|
|||
if offset_new > 0 && size_new < MIN_SIZE_NEEDED_TO_OFFSET {
|
||||
panic!("tried to create slice of size {size_new} with an offset while the size needs to atleast be of size {MIN_SIZE_NEEDED_TO_OFFSET} for offset support");
|
||||
}
|
||||
if offset_new % BUFFER_ALIGNMENT != 0 {
|
||||
panic!("slice with offset {offset_new} needs to be a multiple of {BUFFER_ALIGNMENT}");
|
||||
if offset_new % buffer_alignment != 0 {
|
||||
panic!("slice with offset {offset_new} needs to be a multiple of {buffer_alignment}");
|
||||
}
|
||||
let handle = SliceHandle::new();
|
||||
if size_new < BUFFER_ALIGNMENT {
|
||||
if size_new < buffer_alignment {
|
||||
self.padding = old_size - offset_slice;
|
||||
assert_eq!(self.effective_size(), old_size);
|
||||
return None;
|
||||
}
|
||||
|
||||
assert!(
|
||||
size_new >= BUFFER_ALIGNMENT,
|
||||
"Size new > {BUFFER_ALIGNMENT}"
|
||||
size_new >= buffer_alignment,
|
||||
"Size new > {buffer_alignment}"
|
||||
);
|
||||
self.padding = 0;
|
||||
let padding = calculate_padding(size_new - BUFFER_ALIGNMENT);
|
||||
let padding = calculate_padding(size_new - buffer_alignment, buffer_alignment);
|
||||
Some(Slice::new(storage_new, handle, self.chunk.clone(), padding))
|
||||
}
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ pub struct RingBuffer<C: MemoryChunk<S>, S: MemorySlice> {
|
|||
cursor_chunk: usize,
|
||||
_s: PhantomData<S>,
|
||||
_c: PhantomData<C>,
|
||||
buffer_alignment: usize,
|
||||
}
|
||||
|
||||
pub trait MemoryChunk<S: MemorySlice> {
|
||||
|
@ -24,13 +25,13 @@ pub trait MemoryChunk<S: MemorySlice> {
|
|||
pub trait MemorySlice: Sized {
|
||||
fn is_free(&self) -> bool;
|
||||
fn size(&self) -> usize;
|
||||
fn split(&mut self, offset: usize) -> Option<Self>;
|
||||
fn split(&mut self, offset: usize, buffer_alignment: usize) -> Option<Self>;
|
||||
fn id(&self) -> SliceId;
|
||||
fn next_slice_position(&self) -> usize;
|
||||
}
|
||||
|
||||
impl<C: MemoryChunk<S>, S: MemorySlice> RingBuffer<C, S> {
|
||||
pub fn new() -> Self {
|
||||
pub fn new(buffer_alignment: usize) -> Self {
|
||||
Self {
|
||||
queue: Vec::new(),
|
||||
chunk_positions: HashMap::new(),
|
||||
|
@ -38,6 +39,7 @@ impl<C: MemoryChunk<S>, S: MemorySlice> RingBuffer<C, S> {
|
|||
cursor_chunk: 0,
|
||||
_s: PhantomData,
|
||||
_c: PhantomData,
|
||||
buffer_alignment,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -56,6 +58,8 @@ impl<C: MemoryChunk<S>, S: MemorySlice> RingBuffer<C, S> {
|
|||
for (pos, id) in self.queue.iter().enumerate() {
|
||||
self.chunk_positions.insert(*id, pos);
|
||||
}
|
||||
self.cursor_chunk = 0;
|
||||
self.cursor_slice = 0;
|
||||
}
|
||||
|
||||
pub fn find_free_slice(
|
||||
|
@ -93,7 +97,7 @@ impl<C: MemoryChunk<S>, S: MemorySlice> RingBuffer<C, S> {
|
|||
|
||||
if is_big_enough && is_free {
|
||||
if slice.size() > size {
|
||||
if let Some(new_slice) = slice.split(size) {
|
||||
if let Some(new_slice) = slice.split(size, self.buffer_alignment) {
|
||||
let new_slice_id = new_slice.id();
|
||||
chunk.insert_slice(slice.next_slice_position(), new_slice, slices);
|
||||
slices.get(&new_slice_id).unwrap();
|
||||
|
@ -163,7 +167,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn simple_1() {
|
||||
let mut ring = RingBuffer::<TestChunk, TestSlice>::new();
|
||||
let mut ring = RingBuffer::<TestChunk, TestSlice>::new(0);
|
||||
|
||||
let slice_1 = new_slice(0, 100, 0);
|
||||
let slice_2 = new_slice(1, 200, 1);
|
||||
|
@ -184,7 +188,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn simple_2() {
|
||||
let mut ring = RingBuffer::<TestChunk, TestSlice>::new();
|
||||
let mut ring = RingBuffer::<TestChunk, TestSlice>::new(0);
|
||||
|
||||
let slice_1 = new_slice(0, 100, 0);
|
||||
let slice_2 = new_slice(1, 200, 1);
|
||||
|
@ -205,7 +209,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn multiple_chunks() {
|
||||
let mut ring = RingBuffer::<TestChunk, TestSlice>::new();
|
||||
let mut ring = RingBuffer::<TestChunk, TestSlice>::new(0);
|
||||
|
||||
let slice_1 = new_slice(0, 100, 0);
|
||||
let slice_2 = new_slice(1, 200, 1);
|
||||
|
@ -240,7 +244,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn find_free_slice_with_exact_fit() {
|
||||
let mut ring = RingBuffer::<TestChunk, TestSlice>::new();
|
||||
let mut ring = RingBuffer::<TestChunk, TestSlice>::new(0);
|
||||
|
||||
let slice_1 = new_slice(0, 100, 0);
|
||||
let slice_2 = new_slice(1, 200, 1);
|
||||
|
@ -264,7 +268,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn find_free_slice_with_merging() {
|
||||
let mut ring = RingBuffer::<TestChunk, TestSlice>::new();
|
||||
let mut ring = RingBuffer::<TestChunk, TestSlice>::new(0);
|
||||
|
||||
let slice_1 = new_slice(0, 100, 0);
|
||||
let slice_2 = new_slice(1, 50, 1);
|
||||
|
@ -294,7 +298,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn find_free_slice_with_multiple_chunks_and_merging() {
|
||||
let mut ring = RingBuffer::<TestChunk, TestSlice>::new();
|
||||
let mut ring = RingBuffer::<TestChunk, TestSlice>::new(0);
|
||||
|
||||
let slice_1 = new_slice(0, 50, 0);
|
||||
let slice_2 = new_slice(1, 50, 1);
|
||||
|
@ -372,7 +376,7 @@ mod stub {
|
|||
self.size
|
||||
}
|
||||
|
||||
fn split(&mut self, offset: usize) -> Option<Self> {
|
||||
fn split(&mut self, offset: usize, _buffer_alignment: usize) -> Option<Self> {
|
||||
let size_remained = self.size - offset;
|
||||
self.size = offset;
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ pub struct SmallMemoryPool {
|
|||
slices: HashMap<SliceId, SmallSlice>,
|
||||
ring_buffer: Vec<ChunkId>,
|
||||
index: usize,
|
||||
buffer_storage_alignment_offset: usize,
|
||||
}
|
||||
|
||||
#[derive(new, Debug)]
|
||||
|
@ -43,15 +44,14 @@ impl SmallSlice {
|
|||
}
|
||||
}
|
||||
|
||||
const BUFFER_ALIGNMENT: usize = 32;
|
||||
|
||||
impl SmallMemoryPool {
|
||||
pub fn new() -> Self {
|
||||
pub fn new(buffer_storage_alignment_offset: usize) -> Self {
|
||||
Self {
|
||||
chunks: HashMap::new(),
|
||||
slices: HashMap::new(),
|
||||
ring_buffer: Vec::new(),
|
||||
index: 0,
|
||||
buffer_storage_alignment_offset,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -77,7 +77,7 @@ impl SmallMemoryPool {
|
|||
size: usize,
|
||||
sync: Sync,
|
||||
) -> MemoryPoolHandle {
|
||||
assert!(size <= BUFFER_ALIGNMENT);
|
||||
assert!(size <= self.buffer_storage_alignment_offset);
|
||||
let slice = self.get_free_slice(size);
|
||||
|
||||
match slice {
|
||||
|
@ -94,7 +94,7 @@ impl SmallMemoryPool {
|
|||
size: usize,
|
||||
_sync: Sync,
|
||||
) -> MemoryPoolHandle {
|
||||
assert!(size <= BUFFER_ALIGNMENT);
|
||||
assert!(size <= self.buffer_storage_alignment_offset);
|
||||
|
||||
self.alloc_slice(storage, size)
|
||||
}
|
||||
|
@ -104,7 +104,7 @@ impl SmallMemoryPool {
|
|||
storage: &mut Storage,
|
||||
slice_size: usize,
|
||||
) -> MemoryPoolHandle {
|
||||
let handle_chunk = self.create_chunk(storage, BUFFER_ALIGNMENT);
|
||||
let handle_chunk = self.create_chunk(storage, self.buffer_storage_alignment_offset);
|
||||
let chunk_id = *handle_chunk.id();
|
||||
let slice = self.allocate_slice(handle_chunk.clone(), slice_size);
|
||||
|
||||
|
@ -120,7 +120,7 @@ impl SmallMemoryPool {
|
|||
let slice = self.create_slice(0, slice_size, handle_chunk.clone());
|
||||
|
||||
let effective_size = slice.effective_size();
|
||||
assert_eq!(effective_size, BUFFER_ALIGNMENT);
|
||||
assert_eq!(effective_size, self.buffer_storage_alignment_offset);
|
||||
|
||||
slice
|
||||
}
|
||||
|
@ -184,7 +184,7 @@ impl SmallMemoryPool {
|
|||
utilization: StorageUtilization::Slice { offset, size },
|
||||
};
|
||||
|
||||
let padding = calculate_padding(size);
|
||||
let padding = calculate_padding(size, self.buffer_storage_alignment_offset);
|
||||
|
||||
SmallSlice::new(storage, handle, chunk.handle.clone(), padding)
|
||||
}
|
||||
|
@ -195,7 +195,7 @@ impl SmallMemoryPool {
|
|||
storage: &mut Storage,
|
||||
size: usize,
|
||||
) -> ChunkHandle {
|
||||
let padding = calculate_padding(size);
|
||||
let padding = calculate_padding(size, self.buffer_storage_alignment_offset);
|
||||
let effective_size = size + padding;
|
||||
|
||||
let storage = storage.alloc(effective_size);
|
||||
|
@ -216,10 +216,10 @@ impl SmallMemoryPool {
|
|||
}
|
||||
}
|
||||
|
||||
fn calculate_padding(size: usize) -> usize {
|
||||
let remainder = size % BUFFER_ALIGNMENT;
|
||||
fn calculate_padding(size: usize, buffer_storage_alignment_offset: usize) -> usize {
|
||||
let remainder = size % buffer_storage_alignment_offset;
|
||||
if remainder != 0 {
|
||||
BUFFER_ALIGNMENT - remainder
|
||||
buffer_storage_alignment_offset - remainder
|
||||
} else {
|
||||
0
|
||||
}
|
||||
|
|
|
@ -8,7 +8,7 @@ use burn_common::stub::RwLock;
|
|||
use burn_compute::{
|
||||
channel::MutexComputeChannel,
|
||||
client::ComputeClient,
|
||||
memory_management::simple::{DeallocStrategy, SimpleMemoryManagement, SliceStrategy},
|
||||
memory_management::dynamic::{DynamicMemoryManagement, DynamicMemoryManagementOptions},
|
||||
tune::Tuner,
|
||||
ComputeRuntime,
|
||||
};
|
||||
|
@ -25,20 +25,20 @@ pub struct WgpuRuntime {}
|
|||
|
||||
impl JitRuntime for WgpuRuntime {
|
||||
type JitDevice = WgpuDevice;
|
||||
type JitServer = WgpuServer<SimpleMemoryManagement<WgpuStorage>>;
|
||||
type JitServer = WgpuServer<DynamicMemoryManagement<WgpuStorage>>;
|
||||
}
|
||||
|
||||
/// The compute instance is shared across all [wgpu runtimes](WgpuRuntime).
|
||||
static RUNTIME: ComputeRuntime<WgpuDevice, Server, MutexComputeChannel<Server>> =
|
||||
ComputeRuntime::new();
|
||||
|
||||
type Server = WgpuServer<SimpleMemoryManagement<WgpuStorage>>;
|
||||
type Server = WgpuServer<DynamicMemoryManagement<WgpuStorage>>;
|
||||
|
||||
impl Runtime for WgpuRuntime {
|
||||
type Compiler = wgsl::WgslCompiler;
|
||||
type Server = WgpuServer<SimpleMemoryManagement<WgpuStorage>>;
|
||||
type Server = WgpuServer<DynamicMemoryManagement<WgpuStorage>>;
|
||||
|
||||
type Channel = MutexComputeChannel<WgpuServer<SimpleMemoryManagement<WgpuStorage>>>;
|
||||
type Channel = MutexComputeChannel<WgpuServer<DynamicMemoryManagement<WgpuStorage>>>;
|
||||
type Device = WgpuDevice;
|
||||
|
||||
fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
|
||||
|
@ -79,10 +79,6 @@ impl DeviceOps for WgpuDevice {
|
|||
|
||||
/// The values that control how a WGPU Runtime will perform its calculations.
|
||||
pub struct RuntimeOptions {
|
||||
/// How the buffers are deallocated.
|
||||
pub dealloc_strategy: DeallocStrategy,
|
||||
/// Control the slicing strategy.
|
||||
pub slice_strategy: SliceStrategy,
|
||||
/// Control the amount of compute tasks to be aggregated into a single GPU command.
|
||||
pub tasks_max: usize,
|
||||
}
|
||||
|
@ -98,11 +94,7 @@ impl Default for RuntimeOptions {
|
|||
Err(_) => DEFAULT_MAX_TASKS,
|
||||
};
|
||||
|
||||
Self {
|
||||
dealloc_strategy: DeallocStrategy::new_period_tick(tasks_max * 2),
|
||||
slice_strategy: SliceStrategy::Ratio(0.8),
|
||||
tasks_max,
|
||||
}
|
||||
Self { tasks_max }
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -162,12 +154,18 @@ fn create_client(
|
|||
features: Arc<FeatureSet>,
|
||||
options: RuntimeOptions,
|
||||
) -> ComputeClient<
|
||||
WgpuServer<SimpleMemoryManagement<WgpuStorage>>,
|
||||
MutexComputeChannel<WgpuServer<SimpleMemoryManagement<WgpuStorage>>>,
|
||||
WgpuServer<DynamicMemoryManagement<WgpuStorage>>,
|
||||
MutexComputeChannel<WgpuServer<DynamicMemoryManagement<WgpuStorage>>>,
|
||||
> {
|
||||
let limits = device_wgpu.limits();
|
||||
let storage = WgpuStorage::new(device_wgpu.clone(), queue.clone());
|
||||
let memory_management =
|
||||
SimpleMemoryManagement::new(storage, options.dealloc_strategy, options.slice_strategy);
|
||||
let memory_management = DynamicMemoryManagement::new(
|
||||
storage,
|
||||
DynamicMemoryManagementOptions::preset(
|
||||
limits.max_storage_buffer_binding_size as usize,
|
||||
limits.min_storage_buffer_offset_alignment as usize,
|
||||
),
|
||||
);
|
||||
let server = WgpuServer::new(memory_management, device_wgpu, queue, options.tasks_max);
|
||||
let channel = MutexComputeChannel::new(server);
|
||||
let tuner_device_id = tuner_device_id(adapter.get_info());
|
||||
|
|
Loading…
Reference in New Issue