mirror of https://github.com/tracel-ai/burn.git
Speed up client.create for small allocations. (#1858)
* Speed up client.create for small allocations.
This commit is contained in:
parent
675f6b3280
commit
75e26d03c3
|
@ -1,3 +1,5 @@
|
|||
use std::num::NonZeroU64;
|
||||
|
||||
use super::WgpuStorage;
|
||||
use alloc::{borrow::Cow, sync::Arc};
|
||||
use burn_compute::{
|
||||
|
@ -9,10 +11,15 @@ use burn_jit::JitAutotuneKey;
|
|||
use burn_tensor::Reader;
|
||||
use hashbrown::HashMap;
|
||||
use wgpu::{
|
||||
util::{BufferInitDescriptor, DeviceExt},
|
||||
util::{BufferInitDescriptor, DeviceExt, StagingBelt},
|
||||
BindGroup, CommandEncoder, ComputePipeline, ShaderModuleDescriptor,
|
||||
};
|
||||
|
||||
// Allocations with existing data smaller than this (in bytes) can use a staging belt
|
||||
// which speeds up the allocation. A higher number here will catch more
|
||||
// allocations, but can also increase memory usage.
|
||||
const SMALL_ALLOC_SIZE: usize = 512;
|
||||
|
||||
/// Wgpu compute server.
|
||||
#[derive(Debug)]
|
||||
pub struct WgpuServer<MM: MemoryManagement<WgpuStorage>> {
|
||||
|
@ -20,6 +27,7 @@ pub struct WgpuServer<MM: MemoryManagement<WgpuStorage>> {
|
|||
device: Arc<wgpu::Device>,
|
||||
queue: Arc<wgpu::Queue>,
|
||||
encoder: CommandEncoder,
|
||||
staging_belt: StagingBelt,
|
||||
pipelines: HashMap<String, Arc<ComputePipeline>>,
|
||||
tasks_max: usize,
|
||||
tasks_count: usize,
|
||||
|
@ -45,6 +53,7 @@ where
|
|||
device,
|
||||
queue,
|
||||
encoder,
|
||||
staging_belt: StagingBelt::new(SMALL_ALLOC_SIZE as u64),
|
||||
pipelines: HashMap::new(),
|
||||
tasks_max,
|
||||
tasks_count: 0,
|
||||
|
@ -52,6 +61,8 @@ where
|
|||
}
|
||||
|
||||
fn submit(&mut self) {
|
||||
self.staging_belt.finish();
|
||||
|
||||
let mut new_encoder = self
|
||||
.device
|
||||
.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
|
||||
|
@ -62,6 +73,8 @@ where
|
|||
|
||||
// Cleanup allocations and deallocations.
|
||||
self.memory_management.storage().perform_deallocations();
|
||||
|
||||
self.staging_belt.recall();
|
||||
}
|
||||
|
||||
fn register_compute(
|
||||
|
@ -212,24 +225,42 @@ where
|
|||
/// fully utilize the GPU.
|
||||
fn create(&mut self, data: &[u8]) -> server::Handle<Self> {
|
||||
let handle = server::Handle::new(self.memory_management.reserve(data.len()));
|
||||
let binding = handle.clone().binding();
|
||||
let non_zero_len = NonZeroU64::new(data.len() as u64);
|
||||
|
||||
let buffer_src = Arc::new(self.device.create_buffer_init(&BufferInitDescriptor {
|
||||
label: Some("Buffer Src"),
|
||||
contents: data,
|
||||
usage: wgpu::BufferUsages::COPY_SRC,
|
||||
}));
|
||||
// If there's nothing to copy, don't need to do any work here.
|
||||
if let Some(len) = non_zero_len {
|
||||
let binding = handle.clone().binding();
|
||||
let resource = self.memory_management.get(binding.memory);
|
||||
|
||||
let resource = self.memory_management.get(binding.memory);
|
||||
|
||||
self.encoder.copy_buffer_to_buffer(
|
||||
&buffer_src,
|
||||
0,
|
||||
&resource.buffer,
|
||||
resource.offset(),
|
||||
buffer_src.size(),
|
||||
);
|
||||
self.tasks_count += 1;
|
||||
if data.len() < SMALL_ALLOC_SIZE {
|
||||
// Use a staging belt if the allocation is small enough. This is faster than allocating a new buffer.
|
||||
// Ideally, we could use queue.write_buffer_with(), which seems to be the recommended method for performance,
|
||||
// but that doesn't seem to work, as we might re-use a buffer multiple times, and need to schedule this
|
||||
// precisely in the encoder.
|
||||
let mut write_buf = self.staging_belt.write_buffer(
|
||||
&mut self.encoder,
|
||||
&resource.buffer,
|
||||
0,
|
||||
len,
|
||||
&self.device,
|
||||
);
|
||||
write_buf.copy_from_slice(data);
|
||||
} else {
|
||||
let buffer_src = Arc::new(self.device.create_buffer_init(&BufferInitDescriptor {
|
||||
label: Some("Buffer Src"),
|
||||
contents: data,
|
||||
usage: wgpu::BufferUsages::COPY_SRC,
|
||||
}));
|
||||
self.encoder.copy_buffer_to_buffer(
|
||||
&buffer_src,
|
||||
0,
|
||||
&resource.buffer,
|
||||
resource.offset(),
|
||||
buffer_src.size(),
|
||||
);
|
||||
}
|
||||
self.tasks_count += 1;
|
||||
}
|
||||
|
||||
handle
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue