From ce2429eb107dc09a997a3302721631722c385b60 Mon Sep 17 00:00:00 2001
From: Guillaume Lagrange
Date: Fri, 26 Apr 2024 08:53:55 -0400
Subject: [PATCH] Refactor element type to be decoupled from runtime (#1693)

---
 crates/burn-jit/src/backend.rs                | 25 ++++++----
 crates/burn-jit/src/bridge.rs                 | 39 +++++++++------
 crates/burn-jit/src/codegen/compiler.rs       | 11 -----
 .../src/codegen/dialect/gpu/macros.rs         |  4 +-
 .../src/codegen/dialect/gpu/shader.rs         | 23 +++++++--
 crates/burn-jit/src/codegen/kernel.rs         |  4 +-
 crates/burn-jit/src/element.rs                |  4 +-
 crates/burn-jit/src/fusion/base.rs            | 18 ++++---
 .../burn-jit/src/fusion/elemwise/builder.rs   | 31 ++++++++----
 .../src/fusion/elemwise/optimization.rs       | 19 ++++---
 crates/burn-jit/src/fusion/kernel.rs          | 26 +++++++---
 crates/burn-jit/src/fusion/tracing/builder.rs |  4 +-
 crates/burn-jit/src/fusion/tracing/trace.rs   |  5 +-
 .../src/kernel/conv/conv_transpose2d.rs       | 20 ++++----
 crates/burn-jit/src/kernel/index/gather.rs    |  6 +--
 crates/burn-jit/src/kernel/index/scatter.rs   |  4 +-
 crates/burn-jit/src/kernel/index/select.rs    |  4 +-
 .../src/kernel/index/select_assign.rs         |  4 +-
 .../src/kernel/interpolate/bicubic.rs         | 29 ++++++-----
 .../src/kernel/interpolate/nearest.rs         | 23 ++++++---
 .../kernel/interpolate/nearest_backward.rs    | 22 ++++++---
 crates/burn-jit/src/kernel/matmul/simple.rs   | 29 ++++++-----
 crates/burn-jit/src/kernel/matmul/tiling2d.rs | 31 +++++++-----
 .../pool/adaptive_avg_pool2d_backward.rs      | 22 ++++++---
 .../src/kernel/pool/adaptive_pool2d_shader.rs | 10 ++--
 .../src/kernel/pool/avg_pool2d_backward.rs    | 34 ++++++-------
 .../src/kernel/pool/max_pool2d_backward.rs    | 40 +++++++--------
 .../burn-jit/src/kernel/pool/pool2d_shader.rs |  9 ++--
 crates/burn-jit/src/kernel/prng/bernoulli.rs  |  2 +-
 crates/burn-jit/src/kernel/prng/normal.rs     |  7 +--
 crates/burn-jit/src/kernel/prng/uniform.rs    |  7 +--
 crates/burn-jit/src/ops/activation_ops.rs     |  4 +-
 crates/burn-jit/src/ops/bool_ops.rs           |  9 +++-
 crates/burn-jit/src/ops/float_ops.rs          |  9 +++-
 crates/burn-jit/src/ops/int_ops.rs            |  9 +++-
 crates/burn-jit/src/ops/module_ops.rs         |  9 +++-
 crates/burn-jit/src/runtime.rs                | 14 ------
 crates/burn-jit/src/tests/mod.rs              |  4 +-
 .../burn-wgpu/src/compiler/wgsl/compiler.rs   | 49 +++++--------------
 crates/burn-wgpu/src/lib.rs                   |  6 +--
 crates/burn-wgpu/src/runtime.rs               | 14 ++----
 .../examples/custom-wgpu-kernel.rs            |  2 +-
 examples/custom-wgpu-kernel/src/backward.rs   |  2 +-
 examples/custom-wgpu-kernel/src/forward.rs    |  2 +-
 44 files changed, 365 insertions(+), 284 deletions(-)

diff --git a/crates/burn-jit/src/backend.rs b/crates/burn-jit/src/backend.rs
index e0784358f..3fd650fcc 100644
--- a/crates/burn-jit/src/backend.rs
+++ b/crates/burn-jit/src/backend.rs
@@ -1,4 +1,4 @@
-use crate::{codegen::Compiler, tensor::JitTensor, PrecisionBridge, Runtime};
+use crate::{tensor::JitTensor, FloatElement, IntElement, PrecisionBridge, Runtime};
 use burn_tensor::backend::Backend;
 use rand::{rngs::StdRng, SeedableRng};
 use std::{marker::PhantomData, sync::Mutex};
@@ -7,16 +7,23 @@ pub(crate) static SEED: Mutex<Option<StdRng>> = Mutex::new(None);
 
 /// Generic tensor backend that can be compiled just-in-time to any shader runtime
 #[derive(new)]
-pub struct JitBackend<R: Runtime> {
+pub struct JitBackend<R: Runtime, F: FloatElement, I: IntElement> {
     _runtime: PhantomData<R>,
+    _float_elem: PhantomData<F>,
+    _int_elem: PhantomData<I>,
 }
 
-impl<R: Runtime> Backend for JitBackend<R> {
+impl<R, F, I> Backend for JitBackend<R, F, I>
+where
+    R: Runtime,
+    F: FloatElement,
+    I: IntElement,
+{
     type Device = R::Device;
-    type FullPrecisionBridge = PrecisionBridge<R::FullPrecisionRuntime>;
-    type FloatElem = <R::Compiler as Compiler>::Float;
-    type IntElem = <R::Compiler as Compiler>::Int;
+    type
FullPrecisionBridge = PrecisionBridge; + type FloatElem = F; + type IntElem = I; type FloatTensorPrimitive = JitTensor; type IntTensorPrimitive = JitTensor; @@ -42,19 +49,19 @@ impl Backend for JitBackend { } } -impl core::fmt::Debug for JitBackend { +impl core::fmt::Debug for JitBackend { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_fmt(format_args!("JitBackend {{ runtime: {}}}", R::name())) } } -impl Clone for JitBackend { +impl Clone for JitBackend { fn clone(&self) -> Self { Self::new() } } -impl Default for JitBackend { +impl Default for JitBackend { fn default() -> Self { Self::new() } diff --git a/crates/burn-jit/src/bridge.rs b/crates/burn-jit/src/bridge.rs index 4e1cf7e89..ba0f32d29 100644 --- a/crates/burn-jit/src/bridge.rs +++ b/crates/burn-jit/src/bridge.rs @@ -1,4 +1,6 @@ -use crate::{kernel, ops::to_device, tensor::JitTensor, JitBackend, Runtime}; +use crate::{ + kernel, ops::to_device, tensor::JitTensor, FloatElement, IntElement, JitBackend, Runtime, +}; use burn_tensor::{ backend::BackendBridge, ops::{FloatElem, FloatTensor}, @@ -7,26 +9,31 @@ use core::marker::PhantomData; /// Handle precision conversion for the jit backend. #[derive(Debug)] -pub struct PrecisionBridge { +pub struct PrecisionBridge { _runtime: PhantomData, + _float_elem: PhantomData, + _int_elem: PhantomData, } -impl BackendBridge> for PrecisionBridge +impl BackendBridge> + for PrecisionBridge where - ROrigin: Runtime, - RTarget: - Runtime, + R: Runtime, + FOrigin: FloatElement, + IOrigin: IntElement, + FTarget: FloatElement, + ITarget: IntElement, { - type Target = JitBackend; + type Target = JitBackend; fn into_target( - tensor: FloatTensor, D>, + tensor: FloatTensor, D>, device: Option>, ) -> FloatTensor { let tensor = kernel::cast::< - ROrigin, - FloatElem>, - FloatElem>, + R, + FloatElem>, + FloatElem>, D, >(tensor); @@ -42,12 +49,12 @@ where fn from_target( tensor: FloatTensor, - device: Option>>, - ) -> FloatTensor, D> { + device: Option>>, + ) -> FloatTensor, D> { let tensor = kernel::cast::< - RTarget, - FloatElem>, - FloatElem>, + R, + FloatElem>, + FloatElem>, D, >(tensor); // The line below does the backend type cast. diff --git a/crates/burn-jit/src/codegen/compiler.rs b/crates/burn-jit/src/codegen/compiler.rs index 250fecfe9..4b3a59d56 100644 --- a/crates/burn-jit/src/codegen/compiler.rs +++ b/crates/burn-jit/src/codegen/compiler.rs @@ -1,5 +1,4 @@ use super::dialect::gpu; -use crate::{FloatElement, IntElement}; use std::fmt::Display; /// Compiles the [gpu representation](gpu::ComputeShader) into its own representation that can be @@ -7,16 +6,6 @@ use std::fmt::Display; pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug { /// The representation for the compiled code. type Representation: Display; - /// The float element type used for compilation. - type Float: FloatElement; - /// The int element type used for compilation. - type Int: IntElement; - /// The compiler that can be used to generate full precision shaders. - type FullPrecisionCompiler: Compiler< - Representation = Self::Representation, - Float = f32, - Int = i32, - >; /// Compiles the [gpu shader](gpu::ComputeShader) into the compiler's representation. 
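// Taken together, the removals above leave the Compiler trait element-agnostic.
// A condensed sketch of how the trait reads after this patch (any items not
// shown in the hunk are untouched by it):
//
//     pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug {
//         /// The representation for the compiled code.
//         type Representation: Display;
//         /// Compiles the [gpu shader](gpu::ComputeShader) into the compiler's representation.
//         fn compile(shader: gpu::ComputeShader) -> Self::Representation;
//     }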
fn compile(shader: gpu::ComputeShader) -> Self::Representation; diff --git a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs index 4feef5d0e..836d23d40 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/macros.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/macros.rs @@ -390,13 +390,13 @@ impl From for Variable { impl From for Variable { fn from(value: i32) -> Self { - Self::ConstantScalar(value as f64, super::Elem::Int) + Self::ConstantScalar(value as f64, super::Elem::Int(super::IntKind::I32)) } } impl From for Variable { fn from(value: f32) -> Self { - Self::ConstantScalar(value as f64, super::Elem::Float) + Self::ConstantScalar(value as f64, super::Elem::Float(super::FloatKind::F32)) } } diff --git a/crates/burn-jit/src/codegen/dialect/gpu/shader.rs b/crates/burn-jit/src/codegen/dialect/gpu/shader.rs index 7fbba14a3..a4651f029 100644 --- a/crates/burn-jit/src/codegen/dialect/gpu/shader.rs +++ b/crates/burn-jit/src/codegen/dialect/gpu/shader.rs @@ -17,11 +17,25 @@ pub enum Visibility { ReadWrite, } +#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize)] +#[allow(missing_docs)] +pub enum FloatKind { + F32, + F64, +} + +#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize)] +#[allow(missing_docs)] +pub enum IntKind { + I32, + I64, +} + #[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize)] #[allow(missing_docs)] pub enum Elem { - Float, - Int, + Float(FloatKind), + Int(IntKind), UInt, Bool, } @@ -35,8 +49,9 @@ impl From for Item { impl Display for Elem { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { - Self::Float => f.write_str("float"), - Self::Int => f.write_str("int"), + // NOTE: we'll eventually want to differentiate between int/float types + Self::Float(_) => f.write_str("float"), + Self::Int(_) => f.write_str("int"), Self::UInt => f.write_str("uint"), Self::Bool => f.write_str("bool"), } diff --git a/crates/burn-jit/src/codegen/kernel.rs b/crates/burn-jit/src/codegen/kernel.rs index 0370658f9..c74acfa87 100644 --- a/crates/burn-jit/src/codegen/kernel.rs +++ b/crates/burn-jit/src/codegen/kernel.rs @@ -312,8 +312,8 @@ fn create_scalar_handles Vec> { // It is crucial that scalars follow this order: float, int, uint let element_priority = |elem: Elem| match elem { - Elem::Float => 0, - Elem::Int => 1, + Elem::Float(_) => 0, + Elem::Int(_) => 1, Elem::UInt => 2, Elem::Bool => panic!("Bool scalars are not supported"), }; diff --git a/crates/burn-jit/src/element.rs b/crates/burn-jit/src/element.rs index ecd6089ab..5fd78e484 100644 --- a/crates/burn-jit/src/element.rs +++ b/crates/burn-jit/src/element.rs @@ -59,7 +59,7 @@ impl JitElement for i32 { bytemuck::cast_slice(bytes) } fn gpu_elem() -> gpu::Elem { - gpu::Elem::Int + gpu::Elem::Int(gpu::IntKind::I32) } fn maximum_value() -> Self { // Seems to cause problem for some GPU @@ -82,7 +82,7 @@ impl JitElement for f32 { bytemuck::cast_slice(bytes) } fn gpu_elem() -> gpu::Elem { - gpu::Elem::Float + gpu::Elem::Float(gpu::FloatKind::F32) } fn maximum_value() -> Self { f32::MAX diff --git a/crates/burn-jit/src/fusion/base.rs b/crates/burn-jit/src/fusion/base.rs index 1cc9cb107..f1a49290e 100644 --- a/crates/burn-jit/src/fusion/base.rs +++ b/crates/burn-jit/src/fusion/base.rs @@ -1,6 +1,7 @@ use super::{ElementWise, ElementWiseState}; use crate::{ - element::JitElement, fusion::ElementWiseBuilder, tensor::JitTensor, JitBackend, Runtime, + element::JitElement, fusion::ElementWiseBuilder, 
tensor::JitTensor, FloatElement, IntElement, + JitBackend, Runtime, }; use burn_compute::client::ComputeClient; use burn_fusion::{client::MutexFusionClient, FusionBackend}; @@ -25,8 +26,13 @@ pub enum JitOptimizationState { ElementWise(ElementWiseState), } -impl burn_fusion::Optimization> for JitOptimization { - fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, JitBackend>) { +impl burn_fusion::Optimization> for JitOptimization +where + R: Runtime, + F: FloatElement, + I: IntElement, +{ + fn execute(&mut self, context: &mut burn_fusion::stream::Context<'_, JitBackend>) { match self { Self::ElementWise(op) => op.execute(context), } @@ -53,7 +59,7 @@ impl burn_fusion::Optimization> for JitOptimization } } -impl ReprBackend for JitBackend { +impl ReprBackend for JitBackend { type Handle = JitFusionHandle; fn float_tensor( @@ -96,7 +102,7 @@ impl ReprBackend for JitBackend { } } -impl FusionBackend for JitBackend { +impl FusionBackend for JitBackend { type OptimizationState = JitOptimizationState; type Optimization = JitOptimization; type FusionClient = MutexFusionClient; @@ -104,7 +110,7 @@ impl FusionBackend for JitBackend { fn optimizations( device: R::Device, ) -> Vec>> { - vec![Box::new(ElementWiseBuilder::new(device))] + vec![Box::new(ElementWiseBuilder::::new(device))] } } diff --git a/crates/burn-jit/src/fusion/elemwise/builder.rs b/crates/burn-jit/src/fusion/elemwise/builder.rs index 7958fe7e4..0cd346af7 100644 --- a/crates/burn-jit/src/fusion/elemwise/builder.rs +++ b/crates/burn-jit/src/fusion/elemwise/builder.rs @@ -1,3 +1,5 @@ +use core::marker::PhantomData; + use super::{optimization::ElementWise, CompilationPhase}; use crate::{ codegen::dialect::gpu::{ @@ -5,7 +7,7 @@ use crate::{ }, element::JitElement, fusion::{tracing::TraceBuilder, JitOptimization}, - JitBackend, Runtime, + FloatElement, IntElement, JitBackend, Runtime, }; use burn_fusion::{OptimizationBuilder, OptimizationProperties, OptimizationStatus}; use burn_tensor::{ @@ -19,15 +21,22 @@ use burn_tensor::{ }; /// Fused element wise operations that are normally memory bound. 
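// The recurring pattern in this patch: the float/int element types ride along
// as type parameters and are stored as zero-sized markers, so the carrying
// structs stay zero-cost. Minimal self-contained sketch (names hypothetical):
//
//     use core::marker::PhantomData;
//
//     pub struct WithElems<F, I> {
//         _float_elem: PhantomData<F>,
//         _int_elem: PhantomData<I>,
//     }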
-pub(crate) struct ElementWiseBuilder { +pub(crate) struct ElementWiseBuilder { builder: TraceBuilder, current_output_shape: Vec, status: OptimizationStatus, num_added: usize, device: R::Device, + _float_elem: PhantomData, + _int_elem: PhantomData, } -impl OptimizationBuilder> for ElementWiseBuilder { +impl OptimizationBuilder> for ElementWiseBuilder +where + R: Runtime, + F: FloatElement, + I: IntElement, +{ fn register(&mut self, ops: &OperationDescription) { if let OptimizationStatus::Closed = self.status { return; @@ -35,31 +44,31 @@ impl OptimizationBuilder> for ElementWiseBuilder< match ops { OperationDescription::BaseFloat(ops) => { - if !self.register_base::>>(ops) { + if !self.register_base::>>(ops) { self.status = OptimizationStatus::Closed; return; } } OperationDescription::BaseInt(ops) => { - if !self.register_base::>>(ops) { + if !self.register_base::>>(ops) { self.status = OptimizationStatus::Closed; return; } } OperationDescription::Float(ops) => { - if !self.register_float::>>(ops) { + if !self.register_float::>>(ops) { self.status = OptimizationStatus::Closed; return; } } OperationDescription::NumericFloat(ops) => { - if !self.register_numeric::>, _>(ops) { + if !self.register_numeric::>, _>(ops) { self.status = OptimizationStatus::Closed; return; } } OperationDescription::NumericInt(ops) => { - if !self.register_numeric::>, _>(ops) { + if !self.register_numeric::>, _>(ops) { self.status = OptimizationStatus::Closed; return; } @@ -110,14 +119,16 @@ impl OptimizationBuilder> for ElementWiseBuilder< } } -impl ElementWiseBuilder { - pub fn new(device: Device>) -> Self { +impl ElementWiseBuilder { + pub fn new(device: Device>) -> Self { Self { builder: TraceBuilder::new(), num_added: 0, current_output_shape: Vec::new(), status: OptimizationStatus::Open, device, + _float_elem: PhantomData, + _int_elem: PhantomData, } } diff --git a/crates/burn-jit/src/fusion/elemwise/optimization.rs b/crates/burn-jit/src/fusion/elemwise/optimization.rs index d76957e71..ac6010bbd 100644 --- a/crates/burn-jit/src/fusion/elemwise/optimization.rs +++ b/crates/burn-jit/src/fusion/elemwise/optimization.rs @@ -8,7 +8,7 @@ use crate::{ codegen::dialect::gpu::WorkgroupSize, compute::JitAutotuneKey, fusion::{kernel::FusionKernel, tracing::Trace}, - JitBackend, Runtime, + FloatElement, IntElement, JitBackend, Runtime, }; use burn_common::id::IdGenerator; use burn_compute::client::ComputeClient; @@ -66,7 +66,10 @@ impl ElementWise { } impl ElementWise> { - pub(crate) fn execute(&mut self, context: &mut Context<'_, JitBackend>) { + pub(crate) fn execute( + &mut self, + context: &mut Context<'_, JitBackend>, + ) { let client = R::client(&self.device); let key = JitAutotuneKey::FusionElemWise(FusionElemWiseAutotuneKey::new( @@ -81,9 +84,9 @@ impl ElementWise> { } } - fn run_kernel( + fn run_kernel( &mut self, - context: &mut Context<'_, JitBackend>, + context: &mut Context<'_, JitBackend>, client: ComputeClient, fastest_set_index: usize, ) { @@ -106,9 +109,9 @@ impl ElementWise> { kernel.execute(); } - fn run_autotune( + fn run_autotune( &mut self, - context: &mut Context<'_, JitBackend>, + context: &mut Context<'_, JitBackend>, client: ComputeClient, key: JitAutotuneKey, ) { @@ -152,9 +155,9 @@ impl ElementWise> { } /// The first output is chosen when possible, otherwise the first input is chosen. 
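// Around here the element generics move from the type to its methods:
// ElementWise<R> itself stays element-free, and each call site supplies F/I
// through the fusion context. Rough sketch only, with the phase parameter and
// exact signatures elided:
//
//     impl<R: Runtime> ElementWise<R /* , Phase */> {
//         fn run<F: FloatElement, I: IntElement>(
//             &mut self,
//             context: &mut Context<'_, JitBackend<R, F, I>>,
//         ) { /* launch the fused element-wise kernel */ }
//     }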
- pub(crate) fn autotune_shape<'a>( + pub(crate) fn autotune_shape<'a, F: FloatElement, I: IntElement>( &self, - context: &mut Context<'a, JitBackend>, + context: &mut Context<'a, JitBackend>, ) -> &'a [usize] { let info = self.trace.running(); diff --git a/crates/burn-jit/src/fusion/kernel.rs b/crates/burn-jit/src/fusion/kernel.rs index 18fc8fe8a..0304bfeb4 100644 --- a/crates/burn-jit/src/fusion/kernel.rs +++ b/crates/burn-jit/src/fusion/kernel.rs @@ -9,6 +9,8 @@ use crate::fusion::strides_dyn_rank; use crate::fusion::JitFusionHandle; use crate::gpu::ComputeShader; use crate::kernel::GpuComputeShaderPhase; +use crate::FloatElement; +use crate::IntElement; use crate::JitBackend; use crate::Runtime; use burn_compute::client::ComputeClient; @@ -106,14 +108,19 @@ impl From> for AutotunableKernel { } impl FusionKernel { - pub fn create>( + pub fn create( factory: &K, running_info: &ExecutionInfo<'_>, - context: &mut Context<'_, JitBackend>, - device: Device>, + context: &mut Context<'_, JitBackend>, + device: Device>, client: ComputeClient, stateful: bool, - ) -> ExecutableKernel { + ) -> ExecutableKernel + where + K: FusionKernelFactory, + F: FloatElement, + I: IntElement, + { let (handles_input, inputs_description_updated, outputs_description_updated) = process_inputs_outputs( &running_info.inputs, @@ -266,16 +273,21 @@ fn register_info_tensor( } } -fn process_inputs_outputs<'a, R: Runtime>( +fn process_inputs_outputs<'a, R, F, I>( inputs: &[&TensorDescription], outputs: &[&TensorDescription], - context: &'a mut Context<'_, JitBackend>, + context: &'a mut Context<'_, JitBackend>, stateful: bool, ) -> ( Vec>, Vec<&'a TensorDescription>, Vec<&'a TensorDescription>, -) { +) +where + R: Runtime, + F: FloatElement, + I: IntElement, +{ let mut inputs_description_updated = Vec::with_capacity(inputs.len()); let mut outputs_description_updated = Vec::with_capacity(outputs.len()); let mut handles_input = Vec::new(); diff --git a/crates/burn-jit/src/fusion/tracing/builder.rs b/crates/burn-jit/src/fusion/tracing/builder.rs index 0e62ae380..0377e19e7 100644 --- a/crates/burn-jit/src/fusion/tracing/builder.rs +++ b/crates/burn-jit/src/fusion/tracing/builder.rs @@ -90,14 +90,14 @@ impl TraceBuilder { /// Create a variable from an input [scalar](Element). 
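// With Elem::Float(FloatKind) and Elem::Int(IntKind), code that only cares
// about the element family (float vs. int) now matches the kind with a
// wildcard, as the scalar-counting arms below do.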
pub fn scalar(&mut self, _value: &E, elem_type: gpu::Elem) -> gpu::Variable { match elem_type { - gpu::Elem::Float => { + gpu::Elem::Float(_) => { let var = self .scope .read_scalar(self.scalars.num_float as u16, elem_type); self.scalars.num_float += 1; var } - gpu::Elem::Int => { + gpu::Elem::Int(_) => { let var = self .scope .read_scalar(self.scalars.num_int as u16, elem_type); diff --git a/crates/burn-jit/src/fusion/tracing/trace.rs b/crates/burn-jit/src/fusion/tracing/trace.rs index 4bda66d85..ed229218a 100644 --- a/crates/burn-jit/src/fusion/tracing/trace.rs +++ b/crates/burn-jit/src/fusion/tracing/trace.rs @@ -57,9 +57,10 @@ impl Trace { }) .collect::>(); + // NOTE: we might want to pass a struct including all inputs/outputs metadata instead of 3 arrays if self.scalars.num_float > 0 { inputs.push(InputInfo::Scalar { - elem: gpu::Elem::Float, + elem: gpu::Elem::Float(gpu::FloatKind::F32), size: self.scalars.num_float, }) } @@ -73,7 +74,7 @@ impl Trace { if self.scalars.num_int > 0 { inputs.push(InputInfo::Scalar { - elem: gpu::Elem::Int, + elem: gpu::Elem::Int(gpu::IntKind::I32), size: self.scalars.num_int, }) } diff --git a/crates/burn-jit/src/kernel/conv/conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/conv_transpose2d.rs index 41fb96b66..d4896cbc0 100644 --- a/crates/burn-jit/src/kernel/conv/conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv_transpose2d.rs @@ -6,7 +6,7 @@ use crate::{ OutputInfo, WorkgroupLaunch, }, element::JitElement, - gpu::{gpu, ComputeShader, Elem, Scope, Variable, Visibility}, + gpu::{gpu, ComputeShader, Elem, IntKind, Scope, Variable, Visibility}, kernel::{self, GpuComputeShaderPhase}, ops::{ numeric::{empty_device, zeros_device}, @@ -99,8 +99,8 @@ impl Conv2dTransposeComputeShader { let padding_1 = Variable::GlobalScalar(5, Elem::UInt); let groups = Variable::GlobalScalar(6, Elem::UInt); - let stride_0_i = scope.create_local(Elem::Int); - let stride_1_i = scope.create_local(Elem::Int); + let stride_0_i = scope.create_local(Elem::Int(IntKind::I32)); + let stride_1_i = scope.create_local(Elem::Int(IntKind::I32)); gpu!(scope, stride_0_i = cast(conv_stride_0)); gpu!(scope, stride_1_i = cast(conv_stride_1)); @@ -139,15 +139,15 @@ impl Conv2dTransposeComputeShader { gpu!(scope, ic_end = ic_start + ic_tmp); let tmp_u = scope.create_local(Elem::UInt); - let tmp_i = scope.create_local(Elem::Int); - let zero_i = scope.zero(Elem::Int); - let one_i = scope.create_with_value(1, Elem::Int); + let tmp_i = scope.create_local(Elem::Int(IntKind::I32)); + let zero_i = scope.zero(Elem::Int(IntKind::I32)); + let one_i = scope.create_with_value(1, Elem::Int(IntKind::I32)); let kms_u = scope.create_local(Elem::UInt); - let kms_0 = scope.create_local(Elem::Int); - let kms_1 = scope.create_local(Elem::Int); - let ih_start_tmp = scope.create_local(Elem::Int); - let iw_start_tmp = scope.create_local(Elem::Int); + let kms_0 = scope.create_local(Elem::Int(IntKind::I32)); + let kms_1 = scope.create_local(Elem::Int(IntKind::I32)); + let ih_start_tmp = scope.create_local(Elem::Int(IntKind::I32)); + let iw_start_tmp = scope.create_local(Elem::Int(IntKind::I32)); let ih_start = scope.create_local(Elem::UInt); let iw_start = scope.create_local(Elem::UInt); let ih_end = scope.create_local(Elem::UInt); diff --git a/crates/burn-jit/src/kernel/index/gather.rs b/crates/burn-jit/src/kernel/index/gather.rs index d73caa132..3622a70fc 100644 --- a/crates/burn-jit/src/kernel/index/gather.rs +++ b/crates/burn-jit/src/kernel/index/gather.rs @@ -1,6 +1,6 @@ use 
crate::codegen::dialect::gpu::{gpu, Elem, Scope, Variable}; use crate::codegen::Execution; -use crate::gpu::ComputeShader; +use crate::gpu::{ComputeShader, IntKind}; use crate::{ codegen::{ dialect::gpu, Compilation, CompilationInfo, CompilationSettings, EagerHandle, InputInfo, @@ -80,7 +80,7 @@ impl GpuComputeShaderPhase for GatherEagerKernel ComputeShader { let mut scope = gpu::Scope::root(); let item_tensor = E::gpu_elem().into(); - let item_indices: gpu::Item = gpu::Elem::Int.into(); + let item_indices: gpu::Item = gpu::Elem::Int(IntKind::I32).into(); let tensor = gpu::Variable::GlobalInputArray(0, item_tensor); let indices = scope.read_array(1, item_indices); @@ -103,7 +103,7 @@ impl GpuComputeShaderPhase for GatherEagerKernel GpuComputeShaderPhase for ScatterEagerKernel ComputeShader { let mut scope = gpu::Scope::root(); let item_value = E::gpu_elem().into(); - let item_indices: gpu::Item = gpu::Elem::Int.into(); + let item_indices: gpu::Item = gpu::Elem::Int(gpu::IntKind::I32).into(); let input_output = gpu::Variable::GlobalInputArray(0, item_value); - let indices = gpu::Variable::GlobalInputArray(1, Elem::Int.into()); + let indices = gpu::Variable::GlobalInputArray(1, Elem::Int(gpu::IntKind::I32).into()); let value = gpu::Variable::GlobalInputArray(2, item_value); scope.write_global_custom(input_output); diff --git a/crates/burn-jit/src/kernel/index/select.rs b/crates/burn-jit/src/kernel/index/select.rs index 719f0ea9f..403fe8b49 100644 --- a/crates/burn-jit/src/kernel/index/select.rs +++ b/crates/burn-jit/src/kernel/index/select.rs @@ -1,6 +1,6 @@ use crate::{ codegen::{ - dialect::gpu::{gpu, Elem, Item, Scope, Variable, Visibility}, + dialect::gpu::{gpu, Elem, IntKind, Item, Scope, Variable, Visibility}, Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo, OutputInfo, WorkgroupLaunch, }, @@ -74,7 +74,7 @@ impl GpuComputeShaderPhase for SelectEagerKernel ComputeShader { let mut scope = Scope::root(); let item = E::gpu_elem().into(); - let item_indices: Item = Elem::Int.into(); + let item_indices: Item = Elem::Int(IntKind::I32).into(); let input = Variable::GlobalInputArray(0, item); let indices = Variable::GlobalInputArray(1, item_indices); diff --git a/crates/burn-jit/src/kernel/index/select_assign.rs b/crates/burn-jit/src/kernel/index/select_assign.rs index 7218b2bd7..ca6be7ae3 100644 --- a/crates/burn-jit/src/kernel/index/select_assign.rs +++ b/crates/burn-jit/src/kernel/index/select_assign.rs @@ -1,6 +1,6 @@ use crate::{ codegen::{ - dialect::gpu::{gpu, Branch, Elem, Item, Scope, Variable, Visibility}, + dialect::gpu::{gpu, Branch, Elem, IntKind, Item, Scope, Variable, Visibility}, Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo, WorkgroupLaunch, }, @@ -132,7 +132,7 @@ impl GpuComputeShaderPhase for SelectAssignEagerKerne fn compile(&self) -> ComputeShader { let mut scope = Scope::root(); let item = E::gpu_elem().into(); - let item_indices: Item = Elem::Int.into(); + let item_indices: Item = Elem::Int(IntKind::I32).into(); let tensor = Variable::GlobalInputArray(0, item); let value = Variable::GlobalInputArray(1, item); diff --git a/crates/burn-jit/src/kernel/interpolate/bicubic.rs b/crates/burn-jit/src/kernel/interpolate/bicubic.rs index c64007807..d00fbd1ab 100644 --- a/crates/burn-jit/src/kernel/interpolate/bicubic.rs +++ b/crates/burn-jit/src/kernel/interpolate/bicubic.rs @@ -17,16 +17,18 @@ struct InterpolateBicubicEagerKernel { _elem: PhantomData, } -struct InterpolateBicubicShader { +struct 
InterpolateBicubicShader { input: Variable, output: Variable, + _elem: PhantomData, } -impl InterpolateBicubicShader { +impl InterpolateBicubicShader { pub(crate) fn expand(self, scope: &mut Scope) { let input = self.input; let output = self.output; let id = Variable::Id; + let elem = E::gpu_elem(); let input_stride_0 = scope.create_local(Elem::UInt); let input_stride_1 = scope.create_local(Elem::UInt); @@ -83,20 +85,20 @@ impl InterpolateBicubicShader { let input_height = scope.create_local(Elem::UInt); let output_height = scope.create_local(Elem::UInt); - let output_height_float = scope.create_local(Elem::Float); + let output_height_float = scope.create_local(elem); let input_width = scope.create_local(Elem::UInt); let output_width = scope.create_local(Elem::UInt); - let output_width_float = scope.create_local(Elem::Float); + let output_width_float = scope.create_local(elem); - let frac = scope.create_local(Elem::Float); + let frac = scope.create_local(elem); let numerator = scope.create_local(Elem::UInt); - let numerator_float = scope.create_local(Elem::Float); + let numerator_float = scope.create_local(elem); let not_zero = scope.create_local(Elem::Bool); - let y_in_float = scope.create_local(Elem::Float); + let y_in_float = scope.create_local(elem); let y_in = scope.create_local(Elem::UInt); - let yw = scope.create_local(Elem::Float); + let yw = scope.create_local(elem); let y_tmp = scope.create_local(Elem::UInt); gpu!(scope, input_height = input_shape_2 - 1u32); @@ -123,9 +125,9 @@ impl InterpolateBicubicShader { gpu!(scope, y_tmp = y_in + 2u32); let y3 = Self::min(scope, y_tmp, input_height); - let x_in_float = scope.create_local(Elem::Float); + let x_in_float = scope.create_local(elem); let x_in = scope.create_local(Elem::UInt); - let xw = scope.create_local(Elem::Float); + let xw = scope.create_local(elem); let x_tmp = scope.create_local(Elem::UInt); gpu!(scope, input_width = input_shape_3 - 1u32); @@ -374,7 +376,12 @@ impl GpuComputeShaderPhase for InterpolateBicubicEage let input = Variable::GlobalInputArray(0, item); let output = Variable::GlobalOutputArray(0, item); - InterpolateBicubicShader { input, output }.expand(&mut scope); + InterpolateBicubicShader { + input, + output, + _elem: PhantomData::, + } + .expand(&mut scope); scope.write_global_custom(output); diff --git a/crates/burn-jit/src/kernel/interpolate/nearest.rs b/crates/burn-jit/src/kernel/interpolate/nearest.rs index f32dab99a..776a58a2e 100644 --- a/crates/burn-jit/src/kernel/interpolate/nearest.rs +++ b/crates/burn-jit/src/kernel/interpolate/nearest.rs @@ -17,16 +17,18 @@ struct InterpolateNearestEagerKernel { _elem: PhantomData, } -struct InterpolateNearestShader { +struct InterpolateNearestShader { input: Variable, output: Variable, + _elem: PhantomData, } -impl InterpolateNearestShader { +impl InterpolateNearestShader { pub(crate) fn expand(self, scope: &mut Scope) { let input = self.input; let output = self.output; let id = Variable::Id; + let elem = E::gpu_elem(); let input_stride_0 = scope.create_local(Elem::UInt); let input_stride_1 = scope.create_local(Elem::UInt); @@ -81,11 +83,11 @@ impl InterpolateNearestShader { gpu!(scope, w = id / output_stride_3); gpu!(scope, w = w % output_shape_3); - let factor_float = scope.create_local(Elem::Float); - let numerator_float = scope.create_local(Elem::Float); - let denominator_float = scope.create_local(Elem::Float); - let x = scope.create_local(Elem::Float); - let y = scope.create_local(Elem::Float); + let factor_float = scope.create_local(elem); + let 
numerator_float = scope.create_local(elem); + let denominator_float = scope.create_local(elem); + let x = scope.create_local(elem); + let y = scope.create_local(elem); let xu = scope.create_local(Elem::UInt); let yu = scope.create_local(Elem::UInt); @@ -130,7 +132,12 @@ impl GpuComputeShaderPhase for InterpolateNearestEage let input = Variable::GlobalInputArray(0, item); let output = Variable::GlobalOutputArray(0, item); - InterpolateNearestShader { input, output }.expand(&mut scope); + InterpolateNearestShader { + input, + output, + _elem: PhantomData::, + } + .expand(&mut scope); scope.write_global_custom(output); diff --git a/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs b/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs index 7bfa1b1c6..d52438f60 100644 --- a/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs +++ b/crates/burn-jit/src/kernel/interpolate/nearest_backward.rs @@ -17,12 +17,13 @@ struct InterpolateNearestBackwardEagerKernel { _elem: PhantomData, } -struct InterpolateNearestBackwardShader { +struct InterpolateNearestBackwardShader { out_grad: Variable, output: Variable, + _elem: PhantomData, } -impl InterpolateNearestBackwardShader { +impl InterpolateNearestBackwardShader { fn expand(self, scope: &mut Scope) { let grad = self.out_grad; let output = self.output; @@ -134,8 +135,9 @@ impl InterpolateNearestBackwardShader { output_size: Variable, input_size: Variable, ) -> Variable { - let numerator_float = scope.create_local(Elem::Float); - let div = scope.create_local(Elem::Float); + let elem = E::gpu_elem(); + let numerator_float = scope.create_local(elem); + let div = scope.create_local(elem); let index = scope.create_local(Elem::UInt); gpu!(scope, index = input_index * output_size); @@ -154,8 +156,9 @@ impl InterpolateNearestBackwardShader { output_size: Variable, input_size: Variable, ) -> Variable { - let numerator_float = scope.create_local(Elem::Float); - let div = scope.create_local(Elem::Float); + let elem = E::gpu_elem(); + let numerator_float = scope.create_local(elem); + let div = scope.create_local(elem); let index = scope.create_local(Elem::UInt); let min = scope.create_local(Elem::Bool); let end_index = scope.create_local(Elem::UInt); @@ -189,7 +192,12 @@ impl GpuComputeShaderPhase let out_grad = Variable::GlobalInputArray(0, item); let output = Variable::GlobalOutputArray(0, item); - InterpolateNearestBackwardShader { out_grad, output }.expand(&mut scope); + InterpolateNearestBackwardShader { + out_grad, + output, + _elem: PhantomData::, + } + .expand(&mut scope); scope.write_global_custom(output); diff --git a/crates/burn-jit/src/kernel/matmul/simple.rs b/crates/burn-jit/src/kernel/matmul/simple.rs index c93445ee5..f3ba6843f 100644 --- a/crates/burn-jit/src/kernel/matmul/simple.rs +++ b/crates/burn-jit/src/kernel/matmul/simple.rs @@ -18,10 +18,11 @@ use std::marker::PhantomData; use super::simple_launch_options; #[derive(new, Debug)] -struct MatmulEagerKernel { +struct MatmulEagerKernel { workgroup_size_x: usize, workgroup_size_y: usize, _runtime: PhantomData, + _elem: PhantomData, } struct MatmulComputeShader { @@ -151,7 +152,7 @@ impl MatmulComputeShader { } } -impl GpuComputeShaderPhase for MatmulEagerKernel { +impl GpuComputeShaderPhase for MatmulEagerKernel { fn compile(&self) -> ComputeShader { assert_eq!( self.workgroup_size_x, self.workgroup_size_y, @@ -159,9 +160,17 @@ impl GpuComputeShaderPhase for MatmulEagerKernel { ); let mut scope = gpu::Scope::root(); - let lhs = gpu::Variable::GlobalInputArray(0, 
gpu::Elem::Float.into()); - let rhs = gpu::Variable::GlobalInputArray(1, gpu::Elem::Float.into()); - let out = gpu::Variable::GlobalOutputArray(0, gpu::Elem::Float.into()); + let elem = E::gpu_elem(); + assert!( + elem == gpu::Elem::Float(gpu::FloatKind::F32) + || elem == gpu::Elem::Float(gpu::FloatKind::F64), + "Only float elements are supported." + ); + let item = elem.into(); + + let lhs = gpu::Variable::GlobalInputArray(0, item); + let rhs = gpu::Variable::GlobalInputArray(1, item); + let out = gpu::Variable::GlobalOutputArray(0, item); scope.write_global_custom(out); @@ -172,16 +181,14 @@ impl GpuComputeShaderPhase for MatmulEagerKernel { .expand(&mut scope); let lhs = InputInfo::Array { - item: gpu::Elem::Float.into(), + item, visibility: gpu::Visibility::Read, }; let rhs = InputInfo::Array { - item: gpu::Elem::Float.into(), + item, visibility: gpu::Visibility::Read, }; - let out = OutputInfo::Array { - item: gpu::Elem::Float.into(), - }; + let out = OutputInfo::Array { item }; let info = CompilationInfo { inputs: vec![lhs, rhs], @@ -236,7 +243,7 @@ pub fn matmul_simple( workgroup_size_y, ); - let kernel = MatmulEagerKernel::::new(workgroup_size_x, workgroup_size_y); + let kernel = MatmulEagerKernel::::new(workgroup_size_x, workgroup_size_y); Execution::start(kernel, rhs.client) .inputs(&[ diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d.rs b/crates/burn-jit/src/kernel/matmul/tiling2d.rs index 8691743f9..3250fb54c 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d.rs @@ -26,18 +26,27 @@ struct MatmulTiling2d { } #[derive(new, Debug)] -struct MatmulTiling2dEagerKernel { +struct MatmulTiling2dEagerKernel { config: Tiling2dConfig, bounds_check_required: bool, _runtime: PhantomData, + _elem: PhantomData, } -impl GpuComputeShaderPhase for MatmulTiling2dEagerKernel { +impl GpuComputeShaderPhase for MatmulTiling2dEagerKernel { fn compile(&self) -> ComputeShader { let mut scope = gpu::Scope::root(); - let lhs = gpu::Variable::GlobalInputArray(0, gpu::Elem::Float.into()); - let rhs = gpu::Variable::GlobalInputArray(1, gpu::Elem::Float.into()); - let out = gpu::Variable::GlobalOutputArray(0, gpu::Elem::Float.into()); + let elem = E::gpu_elem(); + assert!( + elem == gpu::Elem::Float(gpu::FloatKind::F32) + || elem == gpu::Elem::Float(gpu::FloatKind::F64), + "Only float elements are supported." 
+ ); + let item = elem.into(); + + let lhs = gpu::Variable::GlobalInputArray(0, item); + let rhs = gpu::Variable::GlobalInputArray(1, item); + let out = gpu::Variable::GlobalOutputArray(0, item); scope.write_global_custom(out); @@ -49,16 +58,14 @@ impl GpuComputeShaderPhase for MatmulTiling2dEagerKernel { .expand(&mut scope); let lhs = InputInfo::Array { - item: gpu::Elem::Float.into(), + item, visibility: gpu::Visibility::Read, }; let rhs = InputInfo::Array { - item: gpu::Elem::Float.into(), + item, visibility: gpu::Visibility::Read, }; - let out = OutputInfo::Array { - item: gpu::Elem::Float.into(), - }; + let out = OutputInfo::Array { item }; let info = CompilationInfo { inputs: vec![lhs, rhs], @@ -94,7 +101,7 @@ pub fn matmul_tiling_2d( ) -> JitTensor { let bounds_check_required = check_bound_requirement(&lhs.shape, &rhs.shape, &config); - let kernel = MatmulTiling2dEagerKernel::::new(config.clone(), bounds_check_required); + let kernel = MatmulTiling2dEagerKernel::::new(config.clone(), bounds_check_required); let client = lhs.client.clone(); let lhs = match lhs.batch_swapped_with_row_col() { @@ -126,7 +133,7 @@ pub fn matmul_tiling_2d_padded, config: Tiling2dConfig, ) -> JitTensor { - let kernel = MatmulTiling2dEagerKernel::::new(config.clone(), false); + let kernel = MatmulTiling2dEagerKernel::::new(config.clone(), false); let client = lhs.client.clone(); // A tensor may need to be padded, in which case it will implicitly become contiguous diff --git a/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs index 042ef4357..e09060bb7 100644 --- a/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs @@ -18,12 +18,13 @@ struct AdaptiveAvgPool2dBackwardEagerKernel { _elem: PhantomData, } -struct AdaptiveAvgPool2dBackwardComputeShader { +struct AdaptiveAvgPool2dBackwardComputeShader { grad: Variable, output: Variable, + _elem: PhantomData, } -impl AdaptiveAvgPool2dBackwardComputeShader { +impl AdaptiveAvgPool2dBackwardComputeShader { fn expand(self, scope: &mut Scope) { let grad = self.grad; let output = self.output; @@ -158,8 +159,9 @@ impl AdaptiveAvgPool2dBackwardComputeShader { output_size: Variable, input_size: Variable, ) -> Variable { - let numerator_float = scope.create_local(Elem::Float); - let div = scope.create_local(Elem::Float); + let elem = E::gpu_elem(); + let numerator_float = scope.create_local(elem); + let div = scope.create_local(elem); let index = scope.create_local(Elem::UInt); gpu!(scope, index = output_size_index * input_size); @@ -177,8 +179,9 @@ impl AdaptiveAvgPool2dBackwardComputeShader { output_size: Variable, input_size: Variable, ) -> Variable { - let numerator_float = scope.create_local(Elem::Float); - let div = scope.create_local(Elem::Float); + let elem = E::gpu_elem(); + let numerator_float = scope.create_local(elem); + let div = scope.create_local(elem); let index = scope.create_local(Elem::UInt); let min = scope.create_local(Elem::Bool); let end_index = scope.create_local(Elem::UInt); @@ -213,7 +216,12 @@ impl GpuComputeShaderPhase scope.write_global_custom(output); - AdaptiveAvgPool2dBackwardComputeShader { grad, output }.expand(&mut scope); + AdaptiveAvgPool2dBackwardComputeShader { + grad, + output, + _elem: PhantomData::, + } + .expand(&mut scope); let grad = InputInfo::Array { item, diff --git a/crates/burn-jit/src/kernel/pool/adaptive_pool2d_shader.rs 
b/crates/burn-jit/src/kernel/pool/adaptive_pool2d_shader.rs index 6d71d93db..ec9669e53 100644 --- a/crates/burn-jit/src/kernel/pool/adaptive_pool2d_shader.rs +++ b/crates/burn-jit/src/kernel/pool/adaptive_pool2d_shader.rs @@ -137,8 +137,9 @@ impl AdaptivePool2dComputeShader { output_size: Variable, input_size: Variable, ) -> Variable { - let numerator_float = scope.create_local(Elem::Float); - let div = scope.create_local(Elem::Float); + let elem = E::gpu_elem(); + let numerator_float = scope.create_local(elem); + let div = scope.create_local(elem); let index = scope.create_local(Elem::UInt); gpu!(scope, index = output_size_index * input_size); @@ -156,8 +157,9 @@ impl AdaptivePool2dComputeShader { output_size: Variable, input_size: Variable, ) -> Variable { - let numerator_float = scope.create_local(Elem::Float); - let div = scope.create_local(Elem::Float); + let elem = E::gpu_elem(); + let numerator_float = scope.create_local(elem); + let div = scope.create_local(elem); let index = scope.create_local(Elem::UInt); let min = scope.create_local(Elem::Bool); let end_index = scope.create_local(Elem::UInt); diff --git a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs index fde8f0e1b..1b8cac774 100644 --- a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs @@ -1,6 +1,6 @@ use crate::{ codegen::{ - dialect::gpu::{gpu, Elem, Scope, Variable, Visibility}, + dialect::gpu::{gpu, Elem, IntKind, Scope, Variable, Visibility}, Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo, OutputInfo, WorkgroupLaunch, }, @@ -223,17 +223,17 @@ impl AvgPool2dBackwardComputeShader { let [kernel_size_0, kernel_size_1] = self.kernel_size; - let signed_ih = scope.create_local(Elem::Int); - let signed_iw = scope.create_local(Elem::Int); + let signed_ih = scope.create_local(Elem::Int(IntKind::I32)); + let signed_iw = scope.create_local(Elem::Int(IntKind::I32)); - let signed_pool_stride_0 = scope.create_local(Elem::Int); - let signed_pool_stride_1 = scope.create_local(Elem::Int); - let signed_dilation_0 = scope.create_local(Elem::Int); - let signed_dilation_1 = scope.create_local(Elem::Int); - let signed_padding_0 = scope.create_local(Elem::Int); - let signed_padding_1 = scope.create_local(Elem::Int); - let signed_kernel_size_0 = scope.create_local(Elem::Int); - let signed_kernel_size_1 = scope.create_local(Elem::Int); + let signed_pool_stride_0 = scope.create_local(Elem::Int(IntKind::I32)); + let signed_pool_stride_1 = scope.create_local(Elem::Int(IntKind::I32)); + let signed_dilation_0 = scope.create_local(Elem::Int(IntKind::I32)); + let signed_dilation_1 = scope.create_local(Elem::Int(IntKind::I32)); + let signed_padding_0 = scope.create_local(Elem::Int(IntKind::I32)); + let signed_padding_1 = scope.create_local(Elem::Int(IntKind::I32)); + let signed_kernel_size_0 = scope.create_local(Elem::Int(IntKind::I32)); + let signed_kernel_size_1 = scope.create_local(Elem::Int(IntKind::I32)); gpu!(scope, signed_pool_stride_0 = cast(pool_stride_0)); gpu!(scope, signed_pool_stride_1 = cast(pool_stride_1)); @@ -248,8 +248,8 @@ impl AvgPool2dBackwardComputeShader { gpu!(scope, signed_ih = cast(ih)); gpu!(scope, signed_iw = cast(iw)); - let kms_0 = scope.create_local(Elem::Int); - let kms_1 = scope.create_local(Elem::Int); + let kms_0 = scope.create_local(Elem::Int(IntKind::I32)); + let kms_1 = scope.create_local(Elem::Int(IntKind::I32)); gpu!(scope, kms_0 = 
signed_dilation_0 * signed_kernel_size_0); gpu!(scope, kms_0 = kms_0 - signed_pool_stride_0); @@ -257,8 +257,8 @@ impl AvgPool2dBackwardComputeShader { gpu!(scope, kms_1 = signed_dilation_1 * signed_kernel_size_1); gpu!(scope, kms_1 = kms_1 - signed_pool_stride_1); - let oh_start_tmp = scope.create_local(Elem::Int); - let ow_start_tmp = scope.create_local(Elem::Int); + let oh_start_tmp = scope.create_local(Elem::Int(IntKind::I32)); + let ow_start_tmp = scope.create_local(Elem::Int(IntKind::I32)); gpu!(scope, oh_start_tmp = signed_ih + signed_padding_0); gpu!(scope, oh_start_tmp = oh_start_tmp - kms_0); @@ -277,8 +277,8 @@ impl AvgPool2dBackwardComputeShader { gpu!(scope, oh_start = cast(oh_start_tmp)); gpu!(scope, ow_start = cast(ow_start_tmp)); - let oh_end_tmp = scope.create_local(Elem::Int); - let ow_end_tmp = scope.create_local(Elem::Int); + let oh_end_tmp = scope.create_local(Elem::Int(IntKind::I32)); + let ow_end_tmp = scope.create_local(Elem::Int(IntKind::I32)); gpu!(scope, oh_end_tmp = max(kms_0, 0i32)); gpu!(scope, ow_end_tmp = max(kms_1, 0i32)); diff --git a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs index 850b15a6c..9b86c5e56 100644 --- a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs @@ -1,6 +1,6 @@ use crate::{ codegen::{ - dialect::gpu::{gpu, Elem, Item, Scope, Variable, Visibility}, + dialect::gpu::{gpu, Elem, IntKind, Item, Scope, Variable, Visibility}, Compilation, CompilationInfo, CompilationSettings, EagerHandle, Execution, InputInfo, OutputInfo, WorkgroupLaunch, }, @@ -94,7 +94,7 @@ impl MaxPool2dBackwardComputeShader { gpu!(scope, index_current_tmp = iw * output_stride_3); gpu!(scope, index_current += index_current_tmp); - let index_select = scope.create_local(Elem::Int); + let index_select = scope.create_local(Elem::Int(IntKind::I32)); let index_max = scope.create_local(Elem::UInt); let is_max = scope.create_local(Elem::Bool); @@ -169,17 +169,17 @@ impl MaxPool2dBackwardComputeShader { let [kernel_size_0, kernel_size_1] = self.kernel_size; - let signed_ih = scope.create_local(Elem::Int); - let signed_iw = scope.create_local(Elem::Int); + let signed_ih = scope.create_local(Elem::Int(IntKind::I32)); + let signed_iw = scope.create_local(Elem::Int(IntKind::I32)); - let signed_pool_stride_0 = scope.create_local(Elem::Int); - let signed_pool_stride_1 = scope.create_local(Elem::Int); - let signed_dilation_0 = scope.create_local(Elem::Int); - let signed_dilation_1 = scope.create_local(Elem::Int); - let signed_padding_0 = scope.create_local(Elem::Int); - let signed_padding_1 = scope.create_local(Elem::Int); - let signed_kernel_size_0 = scope.create_local(Elem::Int); - let signed_kernel_size_1 = scope.create_local(Elem::Int); + let signed_pool_stride_0 = scope.create_local(Elem::Int(IntKind::I32)); + let signed_pool_stride_1 = scope.create_local(Elem::Int(IntKind::I32)); + let signed_dilation_0 = scope.create_local(Elem::Int(IntKind::I32)); + let signed_dilation_1 = scope.create_local(Elem::Int(IntKind::I32)); + let signed_padding_0 = scope.create_local(Elem::Int(IntKind::I32)); + let signed_padding_1 = scope.create_local(Elem::Int(IntKind::I32)); + let signed_kernel_size_0 = scope.create_local(Elem::Int(IntKind::I32)); + let signed_kernel_size_1 = scope.create_local(Elem::Int(IntKind::I32)); gpu!(scope, signed_pool_stride_0 = cast(pool_stride_0)); gpu!(scope, signed_pool_stride_1 = cast(pool_stride_1)); @@ -194,8 +194,8 @@ impl 
MaxPool2dBackwardComputeShader { gpu!(scope, signed_ih = cast(ih)); gpu!(scope, signed_iw = cast(iw)); - let kms_0 = scope.create_local(Elem::Int); - let kms_1 = scope.create_local(Elem::Int); + let kms_0 = scope.create_local(Elem::Int(IntKind::I32)); + let kms_1 = scope.create_local(Elem::Int(IntKind::I32)); gpu!(scope, kms_0 = signed_dilation_0 * signed_kernel_size_0); gpu!(scope, kms_0 = kms_0 - signed_pool_stride_0); @@ -203,8 +203,8 @@ impl MaxPool2dBackwardComputeShader { gpu!(scope, kms_1 = signed_dilation_1 * signed_kernel_size_1); gpu!(scope, kms_1 = kms_1 - signed_pool_stride_1); - let oh_start_tmp = scope.create_local(Elem::Int); - let ow_start_tmp = scope.create_local(Elem::Int); + let oh_start_tmp = scope.create_local(Elem::Int(IntKind::I32)); + let ow_start_tmp = scope.create_local(Elem::Int(IntKind::I32)); gpu!(scope, oh_start_tmp = signed_ih + signed_padding_0); gpu!(scope, oh_start_tmp = oh_start_tmp - kms_0); @@ -223,8 +223,8 @@ impl MaxPool2dBackwardComputeShader { gpu!(scope, oh_start = cast(oh_start_tmp)); gpu!(scope, ow_start = cast(ow_start_tmp)); - let oh_end_tmp = scope.create_local(Elem::Int); - let ow_end_tmp = scope.create_local(Elem::Int); + let oh_end_tmp = scope.create_local(Elem::Int(IntKind::I32)); + let ow_end_tmp = scope.create_local(Elem::Int(IntKind::I32)); gpu!(scope, oh_end_tmp = max(kms_0, 0i32)); gpu!(scope, ow_end_tmp = max(kms_1, 0i32)); @@ -268,7 +268,7 @@ impl GpuComputeShaderPhase let mut scope = Scope::root(); let item = E::gpu_elem().into(); - let indices = Variable::GlobalInputArray(0, Item::Scalar(Elem::Int)); + let indices = Variable::GlobalInputArray(0, Item::Scalar(Elem::Int(IntKind::I32))); let grad = Variable::GlobalInputArray(1, item); let output = Variable::GlobalOutputArray(0, item); @@ -283,7 +283,7 @@ impl GpuComputeShaderPhase .expand(&mut scope); let indices = InputInfo::Array { - item: Item::Scalar(Elem::Int), + item: Item::Scalar(Elem::Int(IntKind::I32)), visibility: Visibility::Read, }; diff --git a/crates/burn-jit/src/kernel/pool/pool2d_shader.rs b/crates/burn-jit/src/kernel/pool/pool2d_shader.rs index dabd064e8..64101b714 100644 --- a/crates/burn-jit/src/kernel/pool/pool2d_shader.rs +++ b/crates/burn-jit/src/kernel/pool/pool2d_shader.rs @@ -2,7 +2,7 @@ use std::marker::PhantomData; use crate::{ codegen::{Compilation, CompilationInfo, CompilationSettings, InputInfo, OutputInfo}, - gpu::{gpu, ComputeShader, Elem, Item, Scope, Variable, Visibility}, + gpu::{gpu, ComputeShader, Elem, IntKind, Item, Scope, Variable, Visibility}, kernel::GpuComputeShaderPhase, JitElement, Runtime, }; @@ -182,7 +182,10 @@ impl GpuComputeShaderPhase let input = Variable::GlobalInputArray(0, item); let output = Variable::GlobalOutputArray(0, item); let indices = if P::with_indices() { - Some(Variable::GlobalOutputArray(1, Item::Scalar(Elem::Int))) + Some(Variable::GlobalOutputArray( + 1, + Item::Scalar(Elem::Int(IntKind::I32)), + )) } else { None }; @@ -213,7 +216,7 @@ impl GpuComputeShaderPhase vec![ output, OutputInfo::Array { - item: Item::Scalar(Elem::Int), + item: Item::Scalar(Elem::Int(IntKind::I32)), }, ] } else { diff --git a/crates/burn-jit/src/kernel/prng/bernoulli.rs b/crates/burn-jit/src/kernel/prng/bernoulli.rs index 36cfe1f19..df21278ef 100644 --- a/crates/burn-jit/src/kernel/prng/bernoulli.rs +++ b/crates/burn-jit/src/kernel/prng/bernoulli.rs @@ -44,7 +44,7 @@ impl Prng for Bernoulli { gpu!(scope, int_random = int_random ^ state_2); gpu!(scope, int_random = int_random ^ state_3); - let float_random = scope.create_local(Elem::Float); + 
let float_random = scope.create_local(E::gpu_elem()); cast_uint_to_float(scope, int_random, float_random); let bernoulli = scope.create_local(Elem::Bool); diff --git a/crates/burn-jit/src/kernel/prng/normal.rs b/crates/burn-jit/src/kernel/prng/normal.rs index a0c1ec762..b8f62fb1d 100644 --- a/crates/burn-jit/src/kernel/prng/normal.rs +++ b/crates/burn-jit/src/kernel/prng/normal.rs @@ -33,10 +33,11 @@ impl Prng for Normal { state_3: Variable, output: Variable, ) { + let elem = E::gpu_elem(); let item = output.item(); let mean = args[0]; let std = args[1]; - let two_pi = scope.create_with_value(2. * PI, Elem::Float); + let two_pi = scope.create_with_value(2. * PI, elem); let t_neg = scope.create_with_value(-2.0, item); let two: Variable = 2u32.into(); @@ -55,7 +56,7 @@ impl Prng for Normal { gpu!(scope, int_random = int_random ^ state_2); gpu!(scope, int_random = int_random ^ state_3); - let unit_0 = scope.create_local(Elem::Float); + let unit_0 = scope.create_local(elem); cast_uint_to_float(scope, int_random, unit_0); // Second random uniform integer @@ -68,7 +69,7 @@ impl Prng for Normal { gpu!(scope, int_random = int_random ^ state_2); gpu!(scope, int_random = int_random ^ state_3); - let unit_1 = scope.create_local(Elem::Float); + let unit_1 = scope.create_local(elem); cast_uint_to_float(scope, int_random, unit_1); // Box-Muller transform diff --git a/crates/burn-jit/src/kernel/prng/uniform.rs b/crates/burn-jit/src/kernel/prng/uniform.rs index f9b778c33..06d206cf1 100644 --- a/crates/burn-jit/src/kernel/prng/uniform.rs +++ b/crates/burn-jit/src/kernel/prng/uniform.rs @@ -31,6 +31,7 @@ impl Prng for Uniform { state_3: Variable, output: Variable, ) { + let elem = E::gpu_elem(); let item = output.item(); let lower_bound = args[0]; let upper_bound = args[1]; @@ -50,12 +51,12 @@ impl Prng for Uniform { gpu!(scope, int_random = int_random ^ state_2); gpu!(scope, int_random = int_random ^ state_3); - let float_random = scope.create_local(Elem::Float); - let float_scale = scope.create_local(Elem::Float); + let float_random = scope.create_local(elem); + let float_scale = scope.create_local(elem); cast_uint_to_float(scope, int_random, float_random); gpu!(scope, float_scale = cast(scale)); - let uniform_float = scope.create_local(Elem::Float); + let uniform_float = scope.create_local(elem); let uniform = scope.create_local(item); gpu!(scope, uniform_float = float_random * float_scale); gpu!(scope, uniform = cast(uniform_float)); diff --git a/crates/burn-jit/src/ops/activation_ops.rs b/crates/burn-jit/src/ops/activation_ops.rs index 4a4c1f592..592f84d58 100644 --- a/crates/burn-jit/src/ops/activation_ops.rs +++ b/crates/burn-jit/src/ops/activation_ops.rs @@ -1,4 +1,4 @@ -use crate::{JitBackend, Runtime}; +use crate::{FloatElement, IntElement, JitBackend, Runtime}; use burn_tensor::ops::ActivationOps; -impl ActivationOps for JitBackend {} +impl ActivationOps for JitBackend {} diff --git a/crates/burn-jit/src/ops/bool_ops.rs b/crates/burn-jit/src/ops/bool_ops.rs index 2d807fae1..0ba248f90 100644 --- a/crates/burn-jit/src/ops/bool_ops.rs +++ b/crates/burn-jit/src/ops/bool_ops.rs @@ -1,4 +1,4 @@ -use crate::{kernel, JitBackend, Runtime}; +use crate::{kernel, FloatElement, IntElement, JitBackend, Runtime}; use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntTensor}; use burn_tensor::Reader; use burn_tensor::{ops::BoolTensorOps, Data, Shape}; @@ -6,7 +6,12 @@ use std::ops::Range; use super::{expand, permute}; -impl BoolTensorOps for JitBackend { +impl BoolTensorOps for JitBackend +where + R: 
Runtime, + F: FloatElement, + I: IntElement, +{ fn bool_empty(shape: Shape, device: &Device) -> BoolTensor { super::empty(shape, device) } diff --git a/crates/burn-jit/src/ops/float_ops.rs b/crates/burn-jit/src/ops/float_ops.rs index 3b6afbcca..d533bf232 100644 --- a/crates/burn-jit/src/ops/float_ops.rs +++ b/crates/burn-jit/src/ops/float_ops.rs @@ -3,14 +3,19 @@ use crate::codegen::dialect::gpu::{BinaryOperator, Elem, Operator, Scope, UnaryO use crate::kernel::matmul::{matmul, MatmulStrategy}; use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform}; use crate::kernel::{self, reduce}; -use crate::Runtime; use crate::{unary, JitBackend}; +use crate::{FloatElement, IntElement, Runtime}; use burn_tensor::ops::{BoolTensor, Device, FloatElem, FloatTensor, IntTensor}; use burn_tensor::{ops::FloatTensorOps, Data, Distribution, Shape}; use burn_tensor::{ElementConversion, Reader}; use std::ops::Range; -impl FloatTensorOps for JitBackend { +impl FloatTensorOps for JitBackend +where + R: Runtime, + F: FloatElement, + I: IntElement, +{ fn float_from_data( data: Data, D>, device: &Device, diff --git a/crates/burn-jit/src/ops/int_ops.rs b/crates/burn-jit/src/ops/int_ops.rs index 74ef0977a..2b357099e 100644 --- a/crates/burn-jit/src/ops/int_ops.rs +++ b/crates/burn-jit/src/ops/int_ops.rs @@ -1,12 +1,17 @@ use super::{expand, numeric, permute}; use crate::codegen::dialect::gpu::{Elem, Item, Operator, Scope, UnaryOperator}; use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform}; -use crate::{kernel, unary, JitBackend, Runtime}; +use crate::{kernel, unary, FloatElement, IntElement, JitBackend, Runtime}; use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor}; use burn_tensor::{ops::IntTensorOps, Data, Distribution, ElementConversion, Reader, Shape}; use std::ops::Range; -impl IntTensorOps for JitBackend { +impl IntTensorOps for JitBackend +where + R: Runtime, + F: FloatElement, + I: IntElement, +{ fn int_empty(shape: Shape, device: &Device) -> IntTensor { super::empty(shape, device) } diff --git a/crates/burn-jit/src/ops/module_ops.rs b/crates/burn-jit/src/ops/module_ops.rs index a66822f55..19248d24e 100644 --- a/crates/burn-jit/src/ops/module_ops.rs +++ b/crates/burn-jit/src/ops/module_ops.rs @@ -1,11 +1,16 @@ -use crate::{kernel, JitBackend, Runtime}; +use crate::{kernel, FloatElement, IntElement, JitBackend, Runtime}; use burn_tensor::ops::{ ConvOptions, ConvTransposeOptions, InterpolateOptions, MaxPool2dBackward, MaxPool2dWithIndices, ModuleOps, }; use burn_tensor::ops::{FloatTensor, IntTensor}; -impl ModuleOps for JitBackend { +impl ModuleOps for JitBackend +where + R: Runtime, + F: FloatElement, + I: IntElement, +{ fn conv2d( x: FloatTensor, weight: FloatTensor, diff --git a/crates/burn-jit/src/runtime.rs b/crates/burn-jit/src/runtime.rs index 6954df913..91a194be9 100644 --- a/crates/burn-jit/src/runtime.rs +++ b/crates/burn-jit/src/runtime.rs @@ -4,9 +4,6 @@ use crate::{ }; use burn_compute::{channel::ComputeChannel, client::ComputeClient, server::ComputeServer}; -/// Type alias to the runtime signed int element type. -pub type RuntimeInt = <::Compiler as Compiler>::Int; - /// Runtime for the [just-in-time backend](crate::JitBackend). pub trait Runtime: Send + Sync + 'static + core::fmt::Debug { /// The compiler used to compile the inner representation into tokens. @@ -26,17 +23,6 @@ pub trait Runtime: Send + Sync + 'static + core::fmt::Debug { + Sync + Send; - /// A version of the runtime that supports full precision. 
- /// - /// Note that the runtime should share all other runtime components. - /// This way, it's possible to share the same handles for both runtimes and reduce data copies to a minimum. - type FullPrecisionRuntime: Runtime< - Compiler = ::FullPrecisionCompiler, - Device = Self::Device, - Server = Self::Server, - Channel = Self::Channel, - >; - /// Retrieve the compute client from the runtime device. fn client(device: &Self::Device) -> ComputeClient; diff --git a/crates/burn-jit/src/tests/mod.rs b/crates/burn-jit/src/tests/mod.rs index 7c2ceca6e..9f51c7d7d 100644 --- a/crates/burn-jit/src/tests/mod.rs +++ b/crates/burn-jit/src/tests/mod.rs @@ -84,7 +84,7 @@ macro_rules! testgen_jit { use super::*; use burn_jit::tests::{burn_autodiff, burn_ndarray, burn_tensor, serial_test}; - pub type TestBackend = JitBackend; + pub type TestBackend = JitBackend; pub type ReferenceBackend = burn_ndarray::NdArray; pub type TestTensor = burn_tensor::Tensor; @@ -106,7 +106,7 @@ macro_rules! testgen_jit_fusion { use super::*; use burn_jit::tests::{burn_autodiff, burn_fusion, burn_ndarray, burn_tensor}; - pub type TestBackend = burn_fusion::Fusion>; + pub type TestBackend = burn_fusion::Fusion>; pub type ReferenceBackend = burn_ndarray::NdArray; pub type TestTensor = burn_tensor::Tensor; diff --git a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs index dd711942f..87d2a31b9 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs @@ -1,13 +1,11 @@ use super::LocalArray; use super::{shader::ComputeShader, Item, SharedMemory}; use crate::compiler::wgsl; -use crate::{FloatElement, IntElement}; use burn_jit::gpu; -use std::marker::PhantomData; /// Wgsl Compiler. -#[derive(Clone)] -pub struct WgslCompiler { +#[derive(Clone, Default)] +pub struct WgslCompiler { num_inputs: usize, num_outputs: usize, local_invocation_index: bool, @@ -21,43 +19,16 @@ pub struct WgslCompiler { num_workgroups: bool, shared_memories: Vec, local_arrays: Vec, - _float: PhantomData, - _int: PhantomData, } -impl core::fmt::Debug for WgslCompiler { +impl core::fmt::Debug for WgslCompiler { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.write_str("WgslCompiler") } } -impl Default for WgslCompiler { - fn default() -> Self { - Self { - num_inputs: 0, - num_outputs: 0, - local_invocation_index: false, - local_invocation_id: false, - global_invocation_id: false, - workgroup_id: false, - rank: false, - id: false, - stride: false, - shape: false, - num_workgroups: false, - shared_memories: Vec::default(), - local_arrays: Vec::default(), - _float: PhantomData, - _int: PhantomData, - } - } -} - -impl burn_jit::Compiler for WgslCompiler { +impl burn_jit::Compiler for WgslCompiler { type Representation = ComputeShader; - type Float = F; - type Int = I; - type FullPrecisionCompiler = WgslCompiler; fn compile(shader: gpu::ComputeShader) -> Self::Representation { let mut compiler = Self::default(); @@ -73,7 +44,7 @@ impl burn_jit::Compiler for WgslCompiler { } } -impl WgslCompiler { +impl WgslCompiler { fn compile_shader(&mut self, mut value: gpu::ComputeShader) -> wgsl::ComputeShader { self.num_inputs = value.inputs.len(); self.num_outputs = value.outputs.len(); @@ -128,8 +99,14 @@ impl WgslCompiler { fn compile_elem(value: gpu::Elem) -> wgsl::Elem { match value { - gpu::Elem::Float => F::wgpu_elem(), - gpu::Elem::Int => I::wgpu_elem(), + gpu::Elem::Float(f) => match f { + gpu::FloatKind::F32 => wgsl::Elem::F32, + 
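// WGSL has no 64-bit scalar types, so the wgpu compiler can only reject the
// F64/I64 kinds here; a runtime whose target language has 64-bit scalars
// could map these kinds instead of panicking.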
gpu::FloatKind::F64 => panic!("f64 is not a valid WgpuElement"), + }, + gpu::Elem::Int(i) => match i { + gpu::IntKind::I32 => wgsl::Elem::I32, + gpu::IntKind::I64 => panic!("i64 is not a valid WgpuElement"), + }, gpu::Elem::UInt => wgsl::Elem::U32, gpu::Elem::Bool => wgsl::Elem::Bool, } diff --git a/crates/burn-wgpu/src/lib.rs b/crates/burn-wgpu/src/lib.rs index f34791347..59a246840 100644 --- a/crates/burn-wgpu/src/lib.rs +++ b/crates/burn-wgpu/src/lib.rs @@ -45,7 +45,7 @@ pub use burn_jit::{tensor::JitTensor, JitBackend}; /// You can disable the `fusion` feature flag to remove that functionality, which might be /// necessary on `wasm` for now. pub type Wgpu = - burn_fusion::Fusion>>; + burn_fusion::Fusion, F, I>>; #[cfg(not(feature = "fusion"))] /// Tensor backend that uses the [wgpu] crate for executing GPU compute shaders. @@ -64,13 +64,13 @@ pub type Wgpu = /// /// You can enable the `fusion` feature flag to add that functionality, which might improve /// performance. -pub type Wgpu = JitBackend>; +pub type Wgpu = JitBackend, F, I>; #[cfg(test)] mod tests { use super::*; - pub type TestRuntime = crate::WgpuRuntime; + pub type TestRuntime = crate::WgpuRuntime; burn_jit::testgen_all!(); } diff --git a/crates/burn-wgpu/src/runtime.rs b/crates/burn-wgpu/src/runtime.rs index e9004e608..e3192f223 100644 --- a/crates/burn-wgpu/src/runtime.rs +++ b/crates/burn-wgpu/src/runtime.rs @@ -1,7 +1,7 @@ use crate::{ compiler::wgsl, compute::{WgpuServer, WgpuStorage}, - FloatElement, GraphicsApi, IntElement, WgpuDevice, + GraphicsApi, WgpuDevice, }; use alloc::sync::Arc; use burn_common::stub::RwLock; @@ -19,13 +19,10 @@ use wgpu::{AdapterInfo, DeviceDescriptor}; /// Runtime that uses the [wgpu] crate with the wgsl compiler. /// -/// The [graphics api](GraphicsApi), the [float element](FloatElement) and the -/// [int element](IntElement) types are passed as generic. +/// The [graphics api](GraphicsApi) type is passed as generic. #[derive(Debug)] -pub struct WgpuRuntime { +pub struct WgpuRuntime { _g: PhantomData, - _f: PhantomData, - _i: PhantomData, } /// The compute instance is shared across all [wgpu runtimes](WgpuRuntime). 
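// With the runtime reduced to a graphics-api choice, precision is now picked
// where the backend alias is instantiated. Usage sketch (alias name on the
// left is hypothetical), mirroring the reworked `Wgpu` alias above:
//
//     use burn_wgpu::{AutoGraphicsApi, Wgpu};
//
//     // f32 floats and i32 ints on the wgpu runtime:
//     type Backend = Wgpu<AutoGraphicsApi, f32, i32>;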
@@ -34,9 +31,8 @@ static RUNTIME: ComputeRuntime> type Server = WgpuServer>; -impl Runtime for WgpuRuntime { - type FullPrecisionRuntime = WgpuRuntime; - type Compiler = wgsl::WgslCompiler; +impl Runtime for WgpuRuntime { + type Compiler = wgsl::WgslCompiler; type Server = WgpuServer>; type Channel = MutexComputeChannel>>; diff --git a/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs b/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs index 63aa95286..960fc7647 100644 --- a/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs +++ b/examples/custom-wgpu-kernel/examples/custom-wgpu-kernel.rs @@ -71,7 +71,7 @@ fn autodiff(device: &B::Device) { } fn main() { - type MyBackend = burn::backend::wgpu::JitBackend>; + type MyBackend = burn::backend::wgpu::JitBackend, f32, i32>; type MyAutodiffBackend = burn::backend::Autodiff; let device = Default::default(); inference::(&device); diff --git a/examples/custom-wgpu-kernel/src/backward.rs b/examples/custom-wgpu-kernel/src/backward.rs index fa8234e90..5a2a03129 100644 --- a/examples/custom-wgpu-kernel/src/backward.rs +++ b/examples/custom-wgpu-kernel/src/backward.rs @@ -15,7 +15,7 @@ use burn::{ }; impl AutodiffBackend - for Autodiff>> + for Autodiff, F, I>> { } diff --git a/examples/custom-wgpu-kernel/src/forward.rs b/examples/custom-wgpu-kernel/src/forward.rs index 7d5403f4a..f23b54c6f 100644 --- a/examples/custom-wgpu-kernel/src/forward.rs +++ b/examples/custom-wgpu-kernel/src/forward.rs @@ -37,7 +37,7 @@ impl KernelSource for FusedMatmulAddRelu { } /// Implement our custom backend trait for the existing backend `WgpuBackend`. -impl Backend for JitBackend> { +impl Backend for JitBackend, F, I> { fn fused_matmul_add_relu( lhs: FloatTensor, rhs: FloatTensor,