# Burn WGPU Backend
[Burn](https://github.com/tracel-ai/burn) WGPU backend

[![Current Crates.io Version](https://img.shields.io/crates/v/burn-wgpu.svg)](https://crates.io/crates/burn-wgpu)
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/tracel-ai/burn-wgpu/blob/master/README.md)

This crate provides a WGPU backend for [Burn](https://github.com/tracel-ai/burn) using
[wgpu](https://github.com/gfx-rs/wgpu).

The backend supports Vulkan, Metal, DirectX 11/12, OpenGL, and WebGPU.
## Usage Example
```rust
#[cfg(feature = "wgpu")]
mod wgpu {
    use burn_autodiff::Autodiff;
    use burn_wgpu::{Wgpu, WgpuDevice};
    use mnist::training;

    pub fn run() {
        // Use the default WGPU device.
        let device = WgpuDevice::default();
        // Wrap the WGPU backend (f32 floats, i32 ints) in `Autodiff` for training.
        training::run::<Autodiff<Wgpu<f32, i32>>>(device);
    }
}
```
## Configuration
You can set `BURN_WGPU_MAX_TASKS` to a positive integer to control how many compute tasks are
batched together before being submitted to the graphics API.
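
As a rough illustration of how such a setting might be consumed, the sketch below parses a positive integer with a fallback. It assumes `BURN_WGPU_MAX_TASKS` is read from the process environment. The helper names `parse_max_tasks` and `max_tasks_from_env` and the default of `32` are hypothetical and not part of the `burn-wgpu` API, and the crate's own handling of invalid values may differ.

```rust
use std::env;

/// Illustrative default for this sketch only; not the crate's actual default.
const DEFAULT_MAX_TASKS: usize = 32;

/// Parse a raw setting into a positive task count.
/// Falls back to `default` when the value is missing, not a number, or zero.
fn parse_max_tasks(raw: Option<&str>, default: usize) -> usize {
    match raw.map(str::trim).and_then(|s| s.parse::<usize>().ok()) {
        Some(n) if n > 0 => n,
        _ => default,
    }
}

/// Read `BURN_WGPU_MAX_TASKS` from the environment (assumed source of the setting).
fn max_tasks_from_env() -> usize {
    let raw = env::var("BURN_WGPU_MAX_TASKS").ok();
    parse_max_tasks(raw.as_deref(), DEFAULT_MAX_TASKS)
}
```

Under these assumptions, starting the program with `BURN_WGPU_MAX_TASKS=16` in the environment would make `max_tasks_from_env()` return `16`, while an unset or malformed value would fall back to the sketch's default.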
## Alternative SPIR-V backend
When targeting Vulkan, you can enable the `spirv` feature flag to use the SPIR-V compiler backend,
which performs significantly better than WGSL. This is especially true for matrix multiplication,
where SPIR-V can make use of Tensor Cores and run at `f16` precision, which WGSL does not currently
support.
The compiler can also be selected explicitly by setting the corresponding generic parameter to
either `SpirV` or `Wgsl`.
## Platform Support
| Option    | CPU | GPU | Linux | MacOS | Windows | Android | iOS | WASM |
| :-------- | :-: | :-: | :---: | :---: | :-----: | :-----: | :-: | :--: |
| Metal     | No  | Yes | No    | Yes   | No      | No      | Yes | No   |
| Vulkan    | Yes | Yes | Yes   | Yes   | Yes     | Yes     | Yes | No   |
| OpenGL    | No  | Yes | Yes   | Yes   | Yes     | Yes     | Yes | No   |
| WebGpu    | No  | Yes | No    | No    | No      | No      | No  | Yes  |
| Dx11/Dx12 | No  | Yes | No    | No    | Yes     | No      | No  | No   |