mirror of https://github.com/tracel-ai/burn.git
Remove GraphicsAPI generic for WgpuRuntime (#1888)
This commit is contained in:
parent
eead748e90
commit
ac9f942a46
|
@ -62,14 +62,9 @@ macro_rules! bench_on_backend {
|
|||
|
||||
#[cfg(feature = "wgpu")]
|
||||
{
|
||||
use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
|
||||
use burn::backend::wgpu::{Wgpu, WgpuDevice};
|
||||
|
||||
bench::<Wgpu<AutoGraphicsApi, f32, i32>>(
|
||||
&WgpuDevice::default(),
|
||||
feature_name,
|
||||
url,
|
||||
token,
|
||||
);
|
||||
bench::<Wgpu<f32, i32>>(&WgpuDevice::default(), feature_name, url, token);
|
||||
}
|
||||
|
||||
#[cfg(feature = "tch-gpu")]
|
||||
|
|
|
@ -198,7 +198,7 @@ the raw `WgpuBackend` type.
|
|||
|
||||
```rust, ignore
|
||||
/// Implement our custom backend trait for the existing backend `WgpuBackend`.
|
||||
impl<G: GraphicsApi, F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime<G>, F, I> {
|
||||
impl<F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime, F, I> {
|
||||
fn fused_matmul_add_relu<const D: usize>(
|
||||
lhs: FloatTensor<Self, D>,
|
||||
rhs: FloatTensor<Self, D>,
|
||||
|
|
|
@ -11,13 +11,13 @@ entrypoint of our program, namely the `main` function defined in `src/main.rs`.
|
|||
#
|
||||
use crate::{model::ModelConfig, training::TrainingConfig};
|
||||
use burn::{
|
||||
backend::{wgpu::AutoGraphicsApi, Autodiff, Wgpu},
|
||||
backend::{Autodiff, Wgpu},
|
||||
# data::dataset::Dataset,
|
||||
optim::AdamConfig,
|
||||
};
|
||||
|
||||
fn main() {
|
||||
type MyBackend = Wgpu<AutoGraphicsApi, f32, i32>;
|
||||
type MyBackend = Wgpu<f32, i32>;
|
||||
type MyAutodiffBackend = Autodiff<MyBackend>;
|
||||
|
||||
let device = burn::backend::wgpu::WgpuDevice::default();
|
||||
|
@ -32,10 +32,9 @@ fn main() {
|
|||
|
||||
In this example, we use the `Wgpu` backend which is compatible with any operating system and will
|
||||
use the GPU. For other options, see the Burn README. This backend type takes the graphics API, the
|
||||
float type and the int type as generic arguments that will be used during the training. By leaving
|
||||
the graphics API as `AutoGraphicsApi`, it should automatically use an API available on your machine.
|
||||
The autodiff backend is simply the same backend, wrapped within the `Autodiff` struct which imparts
|
||||
differentiability to any backend.
|
||||
float type and the int type as generic arguments that will be used during the training. The autodiff
|
||||
backend is simply the same backend, wrapped within the `Autodiff` struct which imparts differentiability \
|
||||
to any backend.
|
||||
|
||||
We call the `train` function defined earlier with a directory for artifacts, the configuration of
|
||||
the model (the number of digit classes is 10 and the hidden dimension is 512), the optimizer
|
||||
|
|
|
@ -56,13 +56,13 @@ Add the call to `infer` to the `main.rs` file after the `train` function call:
|
|||
#
|
||||
# use crate::{model::ModelConfig, training::TrainingConfig};
|
||||
# use burn::{
|
||||
# backend::{wgpu::AutoGraphicsApi, Autodiff, Wgpu},
|
||||
# backend::{Autodiff, Wgpu},
|
||||
# data::dataset::Dataset,
|
||||
# optim::AdamConfig,
|
||||
# };
|
||||
#
|
||||
# fn main() {
|
||||
# type MyBackend = Wgpu<AutoGraphicsApi, f32, i32>;
|
||||
# type MyBackend = Wgpu<f32, i32>;
|
||||
# type MyAutodiffBackend = Autodiff<MyBackend>;
|
||||
#
|
||||
# let device = burn::backend::wgpu::WgpuDevice::default();
|
||||
|
|
|
@ -16,12 +16,12 @@ The backend supports Vulkan, Metal, DirectX11/12, OpenGL, WebGPU.
|
|||
#[cfg(feature = "wgpu")]
|
||||
mod wgpu {
|
||||
use burn_autodiff::Autodiff;
|
||||
use burn_wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
|
||||
use burn_wgpu::{Wgpu, WgpuDevice};
|
||||
use mnist::training;
|
||||
|
||||
pub fn run() {
|
||||
let device = WgpuDevice::default();
|
||||
training::run::<Autodiff<Wgpu<AutoGraphicsApi, f32, i32>>>(device);
|
||||
training::run::<Autodiff<Wgpu<f32, i32>>>(device);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
|
|
@ -37,6 +37,21 @@ pub use burn_jit::{tensor::JitTensor, JitBackend};
|
|||
/// - [Metal] on Apple hardware.
|
||||
/// - [WebGPU](crate::WebGpu) on supported browsers and `wasm` runtimes.
|
||||
///
|
||||
/// To configure the wgpu backend, eg. to select what graphics API to use or what memory strategy to use,
|
||||
/// you have to manually initialize the runtime. For example:
|
||||
///
|
||||
/// ```rust, ignore
|
||||
/// fn custom_init() {
|
||||
/// let device = Default::default();
|
||||
/// burn::backend::wgpu::init_sync::<burn::backend::wgpu::Vulkan>(
|
||||
/// &device,
|
||||
/// Default::default(),
|
||||
/// );
|
||||
/// }
|
||||
/// ```
|
||||
/// will mean the given device (in this case the default) will be initialized to use Vulkan as the graphics API.
|
||||
/// It's also possible to use an existing wgpu device, by using `init_existing_device`.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// This version of the [wgpu] backend uses [burn_fusion] to compile and optimize streams of tensor
|
||||
|
@ -44,8 +59,7 @@ pub use burn_jit::{tensor::JitTensor, JitBackend};
|
|||
///
|
||||
/// You can disable the `fusion` feature flag to remove that functionality, which might be
|
||||
/// necessary on `wasm` for now.
|
||||
pub type Wgpu<G = AutoGraphicsApi, F = f32, I = i32> =
|
||||
burn_fusion::Fusion<JitBackend<WgpuRuntime<G>, F, I>>;
|
||||
pub type Wgpu<F = f32, I = i32> = burn_fusion::Fusion<JitBackend<WgpuRuntime, F, I>>;
|
||||
|
||||
#[cfg(not(feature = "fusion"))]
|
||||
/// Tensor backend that uses the [wgpu] crate for executing GPU compute shaders.
|
||||
|
@ -57,6 +71,21 @@ pub type Wgpu<G = AutoGraphicsApi, F = f32, I = i32> =
|
|||
/// - [Metal] on Apple hardware.
|
||||
/// - [WebGPU](crate::WebGpu) on supported browsers and `wasm` runtimes.
|
||||
///
|
||||
/// To configure the wgpu backend, eg. to select what graphics API to use or what memory strategy to use,
|
||||
/// you have to manually initialize the runtime. For example:
|
||||
///
|
||||
/// ```rust, ignore
|
||||
/// fn custom_init() {
|
||||
/// let device = Default::default();
|
||||
/// burn::backend::wgpu::init_sync::<burn::backend::wgpu::Vulkan>(
|
||||
/// &device,
|
||||
/// Default::default(),
|
||||
/// );
|
||||
/// }
|
||||
/// ```
|
||||
/// will mean the given device (in this case the default) will be initialized to use Vulkan as the graphics API.
|
||||
/// It's also possible to use an existing wgpu device, by using `init_existing_device`.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// This version of the [wgpu] backend doesn't use [burn_fusion] to compile and optimize streams of tensor
|
||||
|
@ -64,13 +93,13 @@ pub type Wgpu<G = AutoGraphicsApi, F = f32, I = i32> =
|
|||
///
|
||||
/// You can enable the `fusion` feature flag to add that functionality, which might improve
|
||||
/// performance.
|
||||
pub type Wgpu<G = AutoGraphicsApi, F = f32, I = i32> = JitBackend<WgpuRuntime<G>, F, I>;
|
||||
pub type Wgpu<F = f32, I = i32> = JitBackend<WgpuRuntime, F, I>;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
pub type TestRuntime = crate::WgpuRuntime<AutoGraphicsApi>;
|
||||
pub type TestRuntime = crate::WgpuRuntime;
|
||||
|
||||
burn_jit::testgen_all!();
|
||||
burn_cube::testgen_all!();
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use crate::{
|
||||
compiler::wgsl,
|
||||
compute::{WgpuServer, WgpuStorage},
|
||||
GraphicsApi, WgpuDevice,
|
||||
AutoGraphicsApi, GraphicsApi, WgpuDevice,
|
||||
};
|
||||
use alloc::sync::Arc;
|
||||
use burn_common::stub::RwLock;
|
||||
|
@ -15,21 +15,16 @@ use burn_compute::{
|
|||
use burn_cube::Runtime;
|
||||
use burn_jit::JitRuntime;
|
||||
use burn_tensor::backend::{DeviceId, DeviceOps};
|
||||
use std::{
|
||||
marker::PhantomData,
|
||||
sync::atomic::{AtomicBool, Ordering},
|
||||
};
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use wgpu::{AdapterInfo, DeviceDescriptor};
|
||||
|
||||
/// Runtime that uses the [wgpu] crate with the wgsl compiler.
|
||||
///
|
||||
/// The [graphics api](GraphicsApi) type is passed as generic.
|
||||
/// Runtime that uses the [wgpu] crate with the wgsl compiler. This is used in the Wgpu backend.
|
||||
/// For advanced configuration, use [`init_sync`] to pass in runtime options or to select a
|
||||
/// specific graphics API.
|
||||
#[derive(Debug)]
|
||||
pub struct WgpuRuntime<G: GraphicsApi> {
|
||||
_g: PhantomData<G>,
|
||||
}
|
||||
pub struct WgpuRuntime {}
|
||||
|
||||
impl<G: GraphicsApi> JitRuntime for WgpuRuntime<G> {
|
||||
impl JitRuntime for WgpuRuntime {
|
||||
type JitDevice = WgpuDevice;
|
||||
type JitServer = WgpuServer<SimpleMemoryManagement<WgpuStorage>>;
|
||||
}
|
||||
|
@ -42,7 +37,7 @@ type Server = WgpuServer<SimpleMemoryManagement<WgpuStorage>>;
|
|||
|
||||
static SUBGROUP: AtomicBool = AtomicBool::new(false);
|
||||
|
||||
impl<G: GraphicsApi> Runtime for WgpuRuntime<G> {
|
||||
impl Runtime for WgpuRuntime {
|
||||
type Compiler = wgsl::WgslCompiler;
|
||||
type Server = WgpuServer<SimpleMemoryManagement<WgpuStorage>>;
|
||||
|
||||
|
@ -51,7 +46,8 @@ impl<G: GraphicsApi> Runtime for WgpuRuntime<G> {
|
|||
|
||||
fn client(device: &Self::Device) -> ComputeClient<Self::Server, Self::Channel> {
|
||||
RUNTIME.client(device, move || {
|
||||
let (adapter, device_wgpu, queue) = pollster::block_on(create_wgpu_setup::<G>(device));
|
||||
let (adapter, device_wgpu, queue) =
|
||||
pollster::block_on(create_wgpu_setup::<AutoGraphicsApi>(device));
|
||||
create_client(adapter, device_wgpu, queue, RuntimeOptions::default())
|
||||
})
|
||||
}
|
||||
|
@ -125,14 +121,13 @@ pub fn init_existing_device(
|
|||
device_id
|
||||
}
|
||||
|
||||
/// Init the client sync, useful to configure the runtime options.
|
||||
/// Initialize a client on the given device with the given options. This function is useful to configure the runtime options
|
||||
/// or to pick a different graphics API. On wasm, it is necessary to use [`init_async`] instead.
|
||||
pub fn init_sync<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) {
|
||||
let (adapter, device_wgpu, queue) = pollster::block_on(create_wgpu_setup::<G>(device));
|
||||
let client = create_client(adapter, device_wgpu, queue, options);
|
||||
RUNTIME.register(device, client)
|
||||
pollster::block_on(init_async::<G>(device, options));
|
||||
}
|
||||
|
||||
/// Init the client async, necessary for wasm.
|
||||
/// Like [`init_sync`], but async, necessary for wasm.
|
||||
pub async fn init_async<G: GraphicsApi>(device: &WgpuDevice, options: RuntimeOptions) {
|
||||
let (adapter, device_wgpu, queue) = create_wgpu_setup::<G>(device).await;
|
||||
let client = create_client(adapter, device_wgpu, queue, options);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use burn::{
|
||||
backend::wgpu::{AutoGraphicsApi, WgpuRuntime},
|
||||
backend::wgpu::WgpuRuntime,
|
||||
tensor::{Distribution, Tensor},
|
||||
};
|
||||
use custom_wgpu_kernel::{
|
||||
|
@ -71,7 +71,7 @@ fn autodiff<B: AutodiffBackend>(device: &B::Device) {
|
|||
}
|
||||
|
||||
fn main() {
|
||||
type MyBackend = burn::backend::wgpu::JitBackend<WgpuRuntime<AutoGraphicsApi>, f32, i32>;
|
||||
type MyBackend = burn::backend::wgpu::JitBackend<WgpuRuntime, f32, i32>;
|
||||
type MyAutodiffBackend = burn::backend::Autodiff<MyBackend>;
|
||||
let device = Default::default();
|
||||
inference::<MyBackend>(&device);
|
||||
|
|
|
@ -9,15 +9,12 @@ use burn::{
|
|||
ops::{broadcast_shape, Backward, Ops, OpsKind},
|
||||
Autodiff, NodeID,
|
||||
},
|
||||
wgpu::{FloatElement, GraphicsApi, IntElement, JitBackend, WgpuRuntime},
|
||||
wgpu::{FloatElement, IntElement, JitBackend, WgpuRuntime},
|
||||
},
|
||||
tensor::Shape,
|
||||
};
|
||||
|
||||
impl<G: GraphicsApi, F: FloatElement, I: IntElement> AutodiffBackend
|
||||
for Autodiff<JitBackend<WgpuRuntime<G>, F, I>>
|
||||
{
|
||||
}
|
||||
impl<F: FloatElement, I: IntElement> AutodiffBackend for Autodiff<JitBackend<WgpuRuntime, F, I>> {}
|
||||
|
||||
// Implement our custom backend trait for any backend that also implements our custom backend trait.
|
||||
//
|
||||
|
|
|
@ -3,8 +3,8 @@ use crate::FloatTensor;
|
|||
use super::Backend;
|
||||
use burn::{
|
||||
backend::wgpu::{
|
||||
build_info, into_contiguous, kernel_wgsl, CubeCount, CubeDim, FloatElement, GraphicsApi,
|
||||
IntElement, JitBackend, JitTensor, KernelSource, SourceKernel, SourceTemplate, WgpuRuntime,
|
||||
build_info, into_contiguous, kernel_wgsl, CubeCount, CubeDim, FloatElement, IntElement,
|
||||
JitBackend, JitTensor, KernelSource, SourceKernel, SourceTemplate, WgpuRuntime,
|
||||
},
|
||||
tensor::Shape,
|
||||
};
|
||||
|
@ -36,7 +36,7 @@ impl<E: FloatElement> KernelSource for FusedMatmulAddRelu<E> {
|
|||
}
|
||||
|
||||
/// Implement our custom backend trait for the existing backend `WgpuBackend`.
|
||||
impl<G: GraphicsApi, F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime<G>, F, I> {
|
||||
impl<F: FloatElement, I: IntElement> Backend for JitBackend<WgpuRuntime, F, I> {
|
||||
fn fused_matmul_add_relu<const D: usize>(
|
||||
lhs: FloatTensor<Self, D>,
|
||||
rhs: FloatTensor<Self, D>,
|
||||
|
|
|
@ -5,13 +5,13 @@ mod training;
|
|||
|
||||
use crate::{model::ModelConfig, training::TrainingConfig};
|
||||
use burn::{
|
||||
backend::{wgpu::AutoGraphicsApi, Autodiff, Wgpu},
|
||||
backend::{Autodiff, Wgpu},
|
||||
data::dataset::Dataset,
|
||||
optim::AdamConfig,
|
||||
};
|
||||
|
||||
fn main() {
|
||||
type MyBackend = Wgpu<AutoGraphicsApi, f32, i32>;
|
||||
type MyBackend = Wgpu<f32, i32>;
|
||||
type MyAutodiffBackend = Autodiff<MyBackend>;
|
||||
|
||||
let device = burn::backend::wgpu::WgpuDevice::default();
|
||||
|
|
|
@ -34,7 +34,7 @@ pub enum ModelType {
|
|||
WithNdArrayBackend(Model<NdArray<f32>>),
|
||||
|
||||
/// The model is loaded to the Wgpu backend
|
||||
WithWgpuBackend(Model<Wgpu<AutoGraphicsApi, f32, i32>>),
|
||||
WithWgpuBackend(Model<Wgpu<f32, i32>>),
|
||||
}
|
||||
|
||||
/// The image is 224x224 pixels with 3 channels (RGB)
|
||||
|
|
|
@ -8,7 +8,7 @@ use burn::{
|
|||
use burn::backend::wgpu::{init_async, AutoGraphicsApi, Wgpu, WgpuDevice};
|
||||
|
||||
#[cfg(feature = "wgpu")]
|
||||
pub type Backend = Wgpu<AutoGraphicsApi, f32, i32>;
|
||||
pub type Backend = Wgpu<f32, i32>;
|
||||
|
||||
#[cfg(all(feature = "ndarray", not(feature = "wgpu")))]
|
||||
pub type Backend = burn::backend::ndarray::NdArray<f32>;
|
||||
|
|
|
@ -74,10 +74,10 @@ mod tch_cpu {
|
|||
#[cfg(feature = "wgpu")]
|
||||
mod wgpu {
|
||||
use crate::{launch, ElemType};
|
||||
use burn::backend::wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice};
|
||||
use burn::backend::wgpu::{Wgpu, WgpuDevice};
|
||||
|
||||
pub fn run() {
|
||||
launch::<Wgpu<AutoGraphicsApi, ElemType, i32>>(WgpuDevice::default());
|
||||
launch::<Wgpu<ElemType, i32>>(WgpuDevice::default());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -85,12 +85,12 @@ mod tch_cpu {
|
|||
mod wgpu {
|
||||
use crate::{launch, ElemType};
|
||||
use burn::backend::{
|
||||
wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice},
|
||||
wgpu::{Wgpu, WgpuDevice},
|
||||
Autodiff,
|
||||
};
|
||||
|
||||
pub fn run() {
|
||||
launch::<Autodiff<Wgpu<AutoGraphicsApi, ElemType, i32>>>(vec![WgpuDevice::default()]);
|
||||
launch::<Autodiff<Wgpu<ElemType, i32>>>(vec![WgpuDevice::default()]);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -82,14 +82,14 @@ mod tch_cpu {
|
|||
#[cfg(feature = "wgpu")]
|
||||
mod wgpu {
|
||||
use burn::backend::{
|
||||
wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice},
|
||||
wgpu::{Wgpu, WgpuDevice},
|
||||
Autodiff,
|
||||
};
|
||||
|
||||
use crate::{launch, ElemType};
|
||||
|
||||
pub fn run() {
|
||||
launch::<Autodiff<Wgpu<AutoGraphicsApi, ElemType, i32>>>(WgpuDevice::default());
|
||||
launch::<Autodiff<Wgpu<ElemType, i32>>>(WgpuDevice::default());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -81,14 +81,14 @@ mod tch_cpu {
|
|||
#[cfg(feature = "wgpu")]
|
||||
mod wgpu {
|
||||
use burn::backend::{
|
||||
wgpu::{AutoGraphicsApi, Wgpu, WgpuDevice},
|
||||
wgpu::{Wgpu, WgpuDevice},
|
||||
Autodiff,
|
||||
};
|
||||
|
||||
use crate::{launch, ElemType};
|
||||
|
||||
pub fn run() {
|
||||
launch::<Autodiff<Wgpu<AutoGraphicsApi, ElemType, i32>>>(vec![WgpuDevice::default()]);
|
||||
launch::<Autodiff<Wgpu<ElemType, i32>>>(vec![WgpuDevice::default()]);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue