diff --git a/Cargo.lock b/Cargo.lock index ac3acac8b..1955b6b40 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -721,6 +721,12 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "cfg_aliases" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fd16c4719339c4530435d38e511904438d07cce7950afa3718a84ac36c10e89e" + [[package]] name = "cipher" version = "0.4.4" @@ -803,10 +809,35 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" [[package]] -name = "com-rs" -version = "0.2.1" +name = "com" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf43edc576402991846b093a7ca18a3477e0ef9c588cde84964b5d3e43016642" +checksum = "7e17887fd17353b65b1b2ef1c526c83e26cd72e74f598a8dc1bee13a48f3d9f6" +dependencies = [ + "com_macros", +] + +[[package]] +name = "com_macros" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d375883580a668c7481ea6631fc1a8863e33cc335bf56bfad8d7e6d4b04b13a5" +dependencies = [ + "com_macros_support", + "proc-macro2", + "syn 1.0.109", +] + +[[package]] +name = "com_macros_support" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad899a1087a9296d5644792d7cb72b8e34c1bec8e7d4fbc002230169a6e8710c" +dependencies = [ + "proc-macro2", + "quote", + "syn 1.0.109", +] [[package]] name = "console" @@ -1046,9 +1077,9 @@ dependencies = [ [[package]] name = "d3d12" -version = "0.7.0" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e16e44ab292b1dddfdaf7be62cfd8877df52f2f3fde5858d95bab606be259f20" +checksum = "3e3d747f100290a1ca24b752186f61f6637e1deffe3bf6320de6fcb29510a307" dependencies = [ "bitflags 2.4.2", "libloading 0.8.1", @@ -1359,7 +1390,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "279d3efcc55e19917fff7ab3ddd6c14afb6a90881a0078465196fe2f99d08c56" dependencies = [ "bit_field", - "flume 0.10.14", + "flume", "half", "lebe", "miniz_oxide", @@ -1458,18 +1489,6 @@ dependencies = [ "spin", ] -[[package]] -name = "flume" -version = "0.11.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "55ac459de2512911e4b674ce33cf20befaba382d05b62b008afc1c8b57cbf181" -dependencies = [ - "futures-core", - "futures-sink", - "nanorand", - "spin", -] - [[package]] name = "fnv" version = "1.0.7" @@ -1925,11 +1944,10 @@ dependencies = [ [[package]] name = "gpu-allocator" -version = "0.23.0" +version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40fe17c8a05d60c38c0a4e5a3c802f2f1ceb66b76c67d96ffb34bef0475a7fad" +checksum = "6f56f6318968d03c18e1bcf4857ff88c61157e9da8e47c5f29055d60e1228884" dependencies = [ - "backtrace", "log", "presser", "thiserror", @@ -2037,14 +2055,14 @@ dependencies = [ [[package]] name = "hassle-rs" -version = "0.10.0" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1397650ee315e8891a0df210707f0fc61771b0cc518c3023896064c5407cb3b0" +checksum = "af2a7e73e1f34c48da31fb668a907f250794837e08faa144fd24f0b8b741e890" dependencies = [ - "bitflags 1.3.2", - "com-rs", + "bitflags 2.4.2", + "com", "libc", - "libloading 0.7.4", + "libloading 0.8.1", "thiserror", "widestring", "winapi", @@ -2646,9 +2664,9 @@ dependencies = [ [[package]] name = "naga" -version = "0.14.2" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae585df4b6514cf8842ac0f1ab4992edc975892704835b549cf818dc0191249e" +checksum = "8878eb410fc90853da3908aebfe61d73d26d4437ef850b70050461f939509899" dependencies = [ "bit-set", "bitflags 2.4.2", @@ -3380,9 +3398,9 @@ dependencies = [ [[package]] name = "raw-window-handle" -version = "0.5.2" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2ff9a1f06a88b01621b7ae906ef0211290d1c8a168a15542486a8f61c0833b9" +checksum = "42a9830a0e1b9fb145ebb365b8bc4ccd75f290f98c0247deafbbe2c75cefb544" [[package]] name = "rawpointer" @@ -4039,12 +4057,11 @@ dependencies = [ [[package]] name = "spirv" -version = "0.2.0+1.5.4" +version = "0.3.0+sdk-1.3.268.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "246bfa38fe3db3f1dfc8ca5a2cdeb7348c78be2112740cc0ec8ef18b6d94f830" +checksum = "eda41003dc44290527a59b13432d4a0379379fa074b70174882adfbdfd917844" dependencies = [ - "bitflags 1.3.2", - "num-traits", + "bitflags 2.4.2", ] [[package]] @@ -4833,9 +4850,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.64" +version = "0.3.67" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b85cbef8c220a6abc02aefd892dfc0fc23afb1c6a426316ec33253a3877249b" +checksum = "58cd2333b6e0be7a39605f0e255892fd7418a682d8da8fe042fe25128794d2ed" dependencies = [ "js-sys", "wasm-bindgen", @@ -4865,13 +4882,13 @@ checksum = "53a85b86a771b1c87058196170769dd264f66c0782acf1ae6cc51bfd64b39082" [[package]] name = "wgpu" -version = "0.18.0" +version = "0.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30e7d227c9f961f2061c26f4cb0fbd4df0ef37e056edd0931783599d6c94ef24" +checksum = "0bfe9a310dcf2e6b85f00c46059aaeaf4184caa8e29a1ecd4b7a704c3482332d" dependencies = [ "arrayvec", "cfg-if", - "flume 0.11.0", + "cfg_aliases", "js-sys", "log", "naga", @@ -4890,16 +4907,19 @@ dependencies = [ [[package]] name = "wgpu-core" -version = "0.18.1" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ef91c1d62d1e9e81c79e600131a258edf75c9531cbdbde09c44a011a47312726" +checksum = "6b15e451d4060ada0d99a64df44e4d590213496da7c4f245572d51071e8e30ed" dependencies = [ "arrayvec", "bit-vec", "bitflags 2.4.2", + "cfg_aliases", "codespan-reporting", + "indexmap 2.2.1", "log", "naga", + "once_cell", "parking_lot 0.12.1", "profiling", "raw-window-handle", @@ -4913,9 +4933,9 @@ dependencies = [ [[package]] name = "wgpu-hal" -version = "0.18.1" +version = "0.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b84ecc802da3eb67b4cf3dd9ea6fe45bbb47ef13e6c49c5c3240868a9cc6cdd9" +checksum = "e3bb47856236bfafc0bc591a925eb036ac19cd987624a447ff353e7a7e7e6f72" dependencies = [ "android_system_properties", "arrayvec", @@ -4923,6 +4943,7 @@ dependencies = [ "bit-set", "bitflags 2.4.2", "block", + "cfg_aliases", "core-graphics-types", "d3d12", "glow", @@ -4956,9 +4977,9 @@ dependencies = [ [[package]] name = "wgpu-types" -version = "0.18.0" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d5ed5f0edf0de351fe311c53304986315ce866f394a2e6df0c4b3c70774bcdd" +checksum = "895fcbeb772bfb049eb80b2d6e47f6c9af235284e9703c96fc0218a42ffd5af2" dependencies = [ "bitflags 2.4.2", "js-sys", diff --git a/Cargo.toml b/Cargo.toml index f78a2841f..94630b79f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -105,7 +105,7 @@ crossterm = "0.27" futures-intrusive = "0.5" text_placeholder = "0.5.0" pollster = "0.3" -wgpu = "0.18.0" +wgpu = "0.19.1" bincode = { version = "2.0.0-rc.3", features = [ "alloc", diff --git a/burn-wgpu/src/compute/base.rs b/burn-wgpu/src/compute/base.rs index fa3554d2c..16a4b1643 100644 --- a/burn-wgpu/src/compute/base.rs +++ b/burn-wgpu/src/compute/base.rs @@ -88,8 +88,8 @@ pub async fn select_device( .request_device( &DeviceDescriptor { label: None, - features: wgpu::Features::empty(), - limits, + required_features: wgpu::Features::empty(), + required_limits: limits, }, None, ) @@ -130,6 +130,7 @@ fn select_adapter(device: &WgpuDevice) -> wgpu::Adapter { instance .enumerate_adapters(G::backend().into()) + .into_iter() .for_each(|adapter| { let device_type = adapter.get_info().device_type; diff --git a/burn-wgpu/src/compute/server.rs b/burn-wgpu/src/compute/server.rs index c33286f4d..5cabe1df6 100644 --- a/burn-wgpu/src/compute/server.rs +++ b/burn-wgpu/src/compute/server.rs @@ -9,7 +9,7 @@ use burn_tensor::Reader; use hashbrown::HashMap; use wgpu::{ util::{BufferInitDescriptor, DeviceExt}, - BindGroup, CommandEncoder, ComputePipeline, ShaderModuleDescriptor, + BindGroup, BindGroupLayout, CommandEncoder, ComputePipeline, ShaderModuleDescriptor, }; /// Wgpu compute server. @@ -19,16 +19,22 @@ pub struct WgpuServer> { device: Arc, queue: wgpu::Queue, encoder: CommandEncoder, - pipelines: HashMap>, + kernels: HashMap>, tasks: Vec, max_tasks: usize, manual_available: HashMap>>, manual_taken: Vec<(usize, server::Handle)>, } +#[derive(new, Debug)] +struct CachedKernel { + pipeline: ComputePipeline, + layout: BindGroupLayout, +} + #[derive(new, Debug)] struct ComputeTask { - pipeline: Arc, + pipeline: Arc, bind_group: BindGroup, work_group: WorkGroup, } @@ -94,7 +100,7 @@ where device, queue, encoder, - pipelines: HashMap::new(), + kernels: HashMap::new(), tasks: Vec::new(), max_tasks, manual_available: HashMap::new(), @@ -170,7 +176,7 @@ where }); for task in self.tasks.iter() { - compute.set_pipeline(&task.pipeline); + compute.set_pipeline(&task.pipeline.pipeline); compute.set_bind_group(0, &task.bind_group, &[]); compute.dispatch_workgroups(task.work_group.x, task.work_group.y, task.work_group.z); } @@ -179,34 +185,37 @@ where self.tasks.clear(); } - fn pipeline(&mut self, kernel: Box) -> Arc { + fn kernel(&mut self, kernel: Box) -> Arc { let kernel_id = kernel.id(); - if let Some(pipeline) = self.pipelines.get(&kernel_id) { + if let Some(pipeline) = self.kernels.get(&kernel_id) { return pipeline.clone(); } let source = kernel.source().complete(); let pipeline = self.compile_source(&source); - self.pipelines.insert(kernel_id.clone(), pipeline.clone()); + self.kernels.insert(kernel_id.clone(), pipeline.clone()); pipeline } - fn compile_source(&self, source: &str) -> Arc { + fn compile_source(&self, source: &str) -> Arc { let module = self.device.create_shader_module(ShaderModuleDescriptor { label: None, source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(source)), }); - Arc::new( - self.device - .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { - label: None, - layout: None, - module: &module, - entry_point: "main", - }), - ) + let pipeline = self + .device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: None, + layout: None, + module: &module, + entry_point: "main", + }); + + let bind_group = pipeline.get_bind_group_layout(0); + + Arc::new(CachedKernel::new(pipeline, bind_group)) } fn buffer_reader(&mut self, handle: &server::Handle) -> BufferReader { @@ -332,8 +341,7 @@ where fn execute(&mut self, kernel: Self::Kernel, handles: &[&server::Handle]) { let work_group = kernel.workgroup(); - let pipeline = self.pipeline(kernel); - let group_layout = pipeline.get_bind_group_layout(0); + let kernel = self.kernel(kernel); let handles = handles .iter() @@ -351,12 +359,12 @@ where let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor { label: None, - layout: &group_layout, + layout: &kernel.layout, entries: &entries, }); self.tasks - .push(ComputeTask::new(pipeline, bind_group, work_group)); + .push(ComputeTask::new(kernel, bind_group, work_group)); if self.tasks.len() >= self.max_tasks { self.register_tasks();