From b5c49c5bf7b8376b2753b479350e4063c6b9bdb2 Mon Sep 17 00:00:00 2001 From: Nathaniel Simard Date: Mon, 18 Dec 2023 12:16:08 -0500 Subject: [PATCH] Fusion wgpu compilation cache (#1069) * Refactor fusion in the wgpu backend * WIP * Refactor * WIP * Fix inplace ops * Works ish * Cleanup Output * Refactoring * Refactor Clamp * Cleanup * Cleanup * Updates * Fix CI * Code review --- burn-fusion/src/backend.rs | 2 +- burn-fusion/src/graph/base.rs | 2 +- burn-fusion/src/graph/context.rs | 14 - burn-fusion/src/graph/execution.rs | 10 +- burn-fusion/src/graph/ops.rs | 18 - burn-fusion/src/graph/path/base.rs | 25 +- burn-fusion/src/ops/float.rs | 42 -- burn-fusion/src/ops/int.rs | 42 -- burn-wgpu/src/{fusion => }/codegen/body.rs | 2 +- .../src/{fusion => }/codegen/function.rs | 2 +- burn-wgpu/src/codegen/kernel.rs | 359 +++++++++++++++ burn-wgpu/src/codegen/mod.rs | 13 + .../src/{fusion => }/codegen/operator.rs | 49 ++- burn-wgpu/src/{fusion => }/codegen/shader.rs | 33 +- .../src/{fusion => }/codegen/variable.rs | 2 +- burn-wgpu/src/element.rs | 18 +- burn-wgpu/src/fusion/base.rs | 8 - burn-wgpu/src/fusion/cache.rs | 57 +++ burn-wgpu/src/fusion/codegen/mod.rs | 11 - burn-wgpu/src/fusion/elemwise/builder.rs | 27 +- burn-wgpu/src/fusion/elemwise/optimization.rs | 100 ++++- burn-wgpu/src/fusion/kernel.rs | 409 +++--------------- burn-wgpu/src/fusion/mod.rs | 2 +- burn-wgpu/src/kernel/clamp.rs | 101 +---- burn-wgpu/src/kernel/mod.rs | 2 - burn-wgpu/src/kernel/unary.rs | 303 +++++++------ burn-wgpu/src/kernel/unary_scalar.rs | 220 ---------- burn-wgpu/src/lib.rs | 2 + burn-wgpu/src/ops/activation_ops.rs | 16 +- burn-wgpu/src/ops/float_ops.rs | 212 ++++----- burn-wgpu/src/ops/int_ops.rs | 34 +- burn-wgpu/src/ops/numeric.rs | 87 ++-- burn-wgpu/src/template/clamp/clamp.wgsl | 25 -- .../src/template/clamp/clamp_inplace.wgsl | 21 - burn-wgpu/src/template/erf.wgsl | 25 -- burn-wgpu/src/template/powf.wgsl | 14 - burn-wgpu/src/template/safe_tanh.wgsl | 8 - burn-wgpu/src/template/unary.wgsl | 17 - burn-wgpu/src/template/unary_inplace.wgsl | 13 - burn-wgpu/src/template/unary_scalar.wgsl | 21 - .../src/template/unary_scalar_inplace.wgsl | 17 - burn-wgpu/src/tensor/base.rs | 13 +- 42 files changed, 1040 insertions(+), 1358 deletions(-) rename burn-wgpu/src/{fusion => }/codegen/body.rs (97%) rename burn-wgpu/src/{fusion => }/codegen/function.rs (98%) create mode 100644 burn-wgpu/src/codegen/kernel.rs create mode 100644 burn-wgpu/src/codegen/mod.rs rename burn-wgpu/src/{fusion => }/codegen/operator.rs (78%) rename burn-wgpu/src/{fusion => }/codegen/shader.rs (86%) rename burn-wgpu/src/{fusion => }/codegen/variable.rs (96%) create mode 100644 burn-wgpu/src/fusion/cache.rs delete mode 100644 burn-wgpu/src/fusion/codegen/mod.rs delete mode 100644 burn-wgpu/src/kernel/unary_scalar.rs delete mode 100644 burn-wgpu/src/template/clamp/clamp.wgsl delete mode 100644 burn-wgpu/src/template/clamp/clamp_inplace.wgsl delete mode 100644 burn-wgpu/src/template/erf.wgsl delete mode 100644 burn-wgpu/src/template/powf.wgsl delete mode 100644 burn-wgpu/src/template/safe_tanh.wgsl delete mode 100644 burn-wgpu/src/template/unary.wgsl delete mode 100644 burn-wgpu/src/template/unary_inplace.wgsl delete mode 100644 burn-wgpu/src/template/unary_scalar.wgsl delete mode 100644 burn-wgpu/src/template/unary_scalar_inplace.wgsl diff --git a/burn-fusion/src/backend.rs b/burn-fusion/src/backend.rs index 03f4e792b..a02dba548 100644 --- a/burn-fusion/src/backend.rs +++ b/burn-fusion/src/backend.rs @@ -96,7 +96,7 @@ pub trait OptimizationBuilder: Send { /// The operation created from the [builder](OptimizationBuilder). pub trait Optimization: Send { /// Execute the operation. - fn execute(&self, context: &mut Context<'_, B>); + fn execute(&mut self, context: &mut Context<'_, B>); /// The number of registered operations in this optimization. fn len(&self) -> usize; /// If the current optimization is empty. diff --git a/burn-fusion/src/graph/base.rs b/burn-fusion/src/graph/base.rs index 17f4c305e..7fc6b3231 100644 --- a/burn-fusion/src/graph/base.rs +++ b/burn-fusion/src/graph/base.rs @@ -53,7 +53,7 @@ impl Graph { pub(crate) fn execute_optimization( &mut self, handles: &mut HandleContainer, - optimization: &dyn Optimization, + optimization: &mut dyn Optimization, ) { let num_keep = optimization.len(); let mut context = self.converter.context(handles); diff --git a/burn-fusion/src/graph/context.rs b/burn-fusion/src/graph/context.rs index 3b271c0bd..5d449af76 100644 --- a/burn-fusion/src/graph/context.rs +++ b/burn-fusion/src/graph/context.rs @@ -682,20 +682,6 @@ impl NumericOpsDescription { out: desc.out.to_relative(converter), }) } - NumericOpsDescription::ClampMax(desc) => { - NumericOpsDescription::ClampMax(ScalarOpsDescription { - lhs: desc.lhs.to_relative(converter), - rhs: local_elem(converter, &desc.rhs), - out: desc.out.to_relative(converter), - }) - } - NumericOpsDescription::ClampMin(desc) => { - NumericOpsDescription::ClampMin(ScalarOpsDescription { - lhs: desc.lhs.to_relative(converter), - rhs: local_elem(converter, &desc.rhs), - out: desc.out.to_relative(converter), - }) - } } } } diff --git a/burn-fusion/src/graph/execution.rs b/burn-fusion/src/graph/execution.rs index d17891333..1d180c373 100644 --- a/burn-fusion/src/graph/execution.rs +++ b/burn-fusion/src/graph/execution.rs @@ -71,7 +71,7 @@ impl GraphExecution { }; } CacheResult::Found(ops) => { - graph.execute_optimization(handles, ops.as_ref()); + graph.execute_optimization(handles, ops.as_mut()); self.reset(graph); } }; @@ -107,14 +107,14 @@ impl GraphExecution { } } - match find_best_optimization_index(&self.optimizations) { + match find_best_optimization_index(&mut self.optimizations) { Some(index) => { let (relative, next_ops) = Self::split_relative_graph_owned(graph, mode); let optimization = &self.optimizations[index]; let ops = self .optimization_cache .complete(optimization, relative, next_ops); - BuildAction::ExecuteOptimization(ops.as_ref()) + BuildAction::ExecuteOptimization(ops.as_mut()) } None => { // TODO: Cache this result too. @@ -184,7 +184,7 @@ impl GraphExecution { } enum BuildAction<'a, B: FusionBackend> { - ExecuteOptimization(&'a dyn Optimization), + ExecuteOptimization(&'a mut dyn Optimization), ExecuteOperations, ContinueBuilding, } @@ -202,7 +202,7 @@ fn still_optimizing(optimizations: &[Box( - optimizations: &[Box>], + optimizations: &mut [Box>], ) -> Option { let mut best_index = None; let mut best_score = 0; diff --git a/burn-fusion/src/graph/ops.rs b/burn-fusion/src/graph/ops.rs index 73cafd6cf..74d8fa5ae 100644 --- a/burn-fusion/src/graph/ops.rs +++ b/burn-fusion/src/graph/ops.rs @@ -379,16 +379,6 @@ pub enum NumericOpsDescription { /// Float => [clamp](burn_tensor::ops::TensorOps::clamp). /// Int => [clamp](burn_tensor::ops::IntTensorOps::int_clamp). Clamp(ClampOpsDescription), - /// Operation corresponding to: - /// - /// Float => [clamp max](burn_tensor::ops::TensorOps::clamp_max). - /// Int => [clamp max](burn_tensor::ops::IntTensorOps::int_clamp_max). - ClampMax(ScalarOpsDescription), - /// Operation corresponding to: - /// - /// Float => [clamp min](burn_tensor::ops::TensorOps::clamp_min). - /// Int => [cleamp min](burn_tensor::ops::IntTensorOps::int_clamp_min). - ClampMin(ScalarOpsDescription), } /// Operation description specific to an int tensor. @@ -900,12 +890,6 @@ impl NumericOpsDescription { NumericOpsDescription::Clamp(desc) => { vec![&desc.tensor, &desc.out] } - NumericOpsDescription::ClampMin(desc) => { - vec![&desc.lhs, &desc.out] - } - NumericOpsDescription::ClampMax(desc) => { - vec![&desc.lhs, &desc.out] - } NumericOpsDescription::Abs(desc) => { vec![&desc.input, &desc.out] } @@ -1144,8 +1128,6 @@ impl core::hash::Hash for NumericOpsDescription { NumericOpsDescription::MaxDim(desc) => desc.hash(state), NumericOpsDescription::MinDim(desc) => desc.hash(state), NumericOpsDescription::Clamp(desc) => desc.hash(state), - NumericOpsDescription::ClampMax(desc) => desc.hash(state), - NumericOpsDescription::ClampMin(desc) => desc.hash(state), } } } diff --git a/burn-fusion/src/graph/path/base.rs b/burn-fusion/src/graph/path/base.rs index 5293ccd91..c95e72d0b 100644 --- a/burn-fusion/src/graph/path/base.rs +++ b/burn-fusion/src/graph/path/base.rs @@ -60,16 +60,13 @@ impl OptimizationCache { } if let Some(candidate) = self.found { - return CacheResult::Found(&self.optimizations.get(candidate).unwrap().value); + return CacheResult::Found(&mut self.optimizations.get_mut(candidate).unwrap().value); } // Invalidate candidates. let mut invalidated_candidate = Vec::new(); for id in self.candidates.iter() { - let item = match self.optimizations.get(*id) { - Some(item) => item, - None => panic!("Should have an optimization"), - }; + let item = &self.optimizations[*id]; let next_ops = graph.last().expect("Validated earlier"); let next_ops_index = graph.len() - 1; let next_ops_candidate = match item.graph.get(next_ops_index) { @@ -93,13 +90,13 @@ impl OptimizationCache { Condition::NextOps(ops) => ops, Condition::Sync => { self.found = Some(*id); - return CacheResult::Found(&item.value); + break; } }; if item.end_conditions.contains(ops) { self.found = Some(*id); - return CacheResult::Found(&item.value); + break; } else { self.availables.push((*id, graph.len())); invalidated_candidate.push(*id); @@ -107,6 +104,10 @@ impl OptimizationCache { } } + if let Some(id) = self.found { + return CacheResult::Found(&mut self.optimizations[id].value); + } + let mut updated_candidates = Vec::new(); core::mem::swap(&mut updated_candidates, &mut self.candidates); @@ -136,7 +137,7 @@ impl OptimizationCache { factory: &Factory, graph: Vec, next_ops: Option, - ) -> &'a O { + ) -> &'a mut O { let existing_optim = self .availables .iter() @@ -149,7 +150,7 @@ impl OptimizationCache { optimization.end_conditions.push(ops) }; - return &optimization.value; + return &mut optimization.value; }; self.starters @@ -164,7 +165,9 @@ impl OptimizationCache { }; self.optimizations.push(optimization); - &self.optimizations.last().unwrap().value + + let last_index = self.optimizations.len() - 1; + &mut self.optimizations[last_index].value } // Signal that a new path will begin. @@ -188,7 +191,7 @@ pub enum CacheResult<'a, T> { /// happens. OnPath, /// An optimization has been found, and the best action is to execute it! - Found(&'a T), + Found(&'a mut T), } /// When checking if an optimization is possible, a start or an end condition ensures that this optimization is diff --git a/burn-fusion/src/ops/float.rs b/burn-fusion/src/ops/float.rs index 3688f97bd..3ef997aa5 100644 --- a/burn-fusion/src/ops/float.rs +++ b/burn-fusion/src/ops/float.rs @@ -265,48 +265,6 @@ impl TensorOps for Fusion { out } - fn clamp_min( - tensor: FloatTensor, - min: FloatElem, - ) -> FloatTensor { - scalar_float_ops!(ClampMinOps, B::clamp_min); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - let desc = ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: min.elem(), - out: out.to_description_out(), - }; - out.client.register( - TensorOpsDescription::NumericOpsFloat(NumericOpsDescription::ClampMin(desc.clone())), - ClampMinOps::::new(desc), - ); - - out - } - - fn clamp_max( - tensor: FloatTensor, - max: FloatElem, - ) -> FloatTensor { - scalar_float_ops!(ClampMaxOps, B::clamp_max); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - let desc = ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: max.elem(), - out: out.to_description_out(), - }; - out.client.register( - TensorOpsDescription::NumericOpsFloat(NumericOpsDescription::ClampMax(desc.clone())), - ClampMaxOps::::new(desc), - ); - - out - } - fn clamp( tensor: FloatTensor, min: FloatElem, diff --git a/burn-fusion/src/ops/int.rs b/burn-fusion/src/ops/int.rs index 37a80dacc..33289fea3 100644 --- a/burn-fusion/src/ops/int.rs +++ b/burn-fusion/src/ops/int.rs @@ -1034,48 +1034,6 @@ impl IntTensorOps for Fusion { out } - fn int_clamp_min( - tensor: IntTensor, - min: IntElem, - ) -> IntTensor { - scalar_int_ops!(ClampMinOps, B::int_clamp_min); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - let desc = ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: min.elem(), - out: out.to_description_out(), - }; - out.client.register( - TensorOpsDescription::NumericOpsInt(NumericOpsDescription::ClampMin(desc.clone())), - ClampMinOps::::new(desc), - ); - - out - } - - fn int_clamp_max( - tensor: IntTensor, - max: IntElem, - ) -> IntTensor { - scalar_int_ops!(ClampMaxOps, B::int_clamp_max); - - let out = tensor.client.tensor_uninitialized(tensor.shape.clone()); - - let desc = ScalarOpsDescription { - lhs: tensor.into_description(), - rhs: max.elem(), - out: out.to_description_out(), - }; - out.client.register( - TensorOpsDescription::NumericOpsInt(NumericOpsDescription::ClampMax(desc.clone())), - ClampMaxOps::::new(desc), - ); - - out - } - fn int_clamp( tensor: IntTensor, min: IntElem, diff --git a/burn-wgpu/src/fusion/codegen/body.rs b/burn-wgpu/src/codegen/body.rs similarity index 97% rename from burn-wgpu/src/fusion/codegen/body.rs rename to burn-wgpu/src/codegen/body.rs index cab35bf75..fd8e2fb1e 100644 --- a/burn-wgpu/src/fusion/codegen/body.rs +++ b/burn-wgpu/src/codegen/body.rs @@ -5,7 +5,7 @@ use std::fmt::Display; /// /// Note that the body assumes that the kernel will run on a 2D grid defined by the workgroup size /// X and Y, but with Z=1. -#[derive(Hash, new)] +#[derive(new)] pub struct Body { operators: Vec, } diff --git a/burn-wgpu/src/fusion/codegen/function.rs b/burn-wgpu/src/codegen/function.rs similarity index 98% rename from burn-wgpu/src/fusion/codegen/function.rs rename to burn-wgpu/src/codegen/function.rs index fceae4e39..ceddb95e0 100644 --- a/burn-wgpu/src/fusion/codegen/function.rs +++ b/burn-wgpu/src/codegen/function.rs @@ -2,7 +2,7 @@ use super::Elem; use std::fmt::Display; /// Not all functions are native to WGSL, so this struct allows to support more functions. -#[derive(Hash, PartialEq, Eq, Clone)] +#[derive(PartialEq, Eq, Clone)] pub enum Function { Powf(Elem), Erf(Elem), diff --git a/burn-wgpu/src/codegen/kernel.rs b/burn-wgpu/src/codegen/kernel.rs new file mode 100644 index 000000000..3d9dea1ba --- /dev/null +++ b/burn-wgpu/src/codegen/kernel.rs @@ -0,0 +1,359 @@ +use crate::codegen::{ + Binding, Body, ComputeShader, Elem, Function, Location, Operator, Variable, Visibility, + WorkgroupSize, +}; +use crate::compute::{StaticKernel, WgpuComputeClient, WgpuHandle}; +use crate::element::WgpuElement; +use crate::kernel::{elemwise_workgroup, StaticKernelSource, WORKGROUP_DEFAULT}; +use std::marker::PhantomData; + +/// Kernel creation input phase, see [kernel codegen](ElemWiseKernelCodegen) for more details. +pub struct InputPhase; +/// Kernel creation body phase, see [kernel codegen](ElemWiseKernelCodegen) for more details. +pub struct BodyPhase; +/// Kernel creation output phase, see [kernel codegen](ElemWiseKernelCodegen) for more details. +pub struct OutputPhase; +/// Kernel compilation phase, see [kernel codegen](ElemWiseKernelCodegen) for more details. +pub struct CompilationPhase; + +/// Allows to create custom wgsl kernels based on configured inputs, body and outputs. +/// +/// This type has 4 phases that must be executed in order, but no worry the type system won't allow +/// you to make mistakes. +/// +/// 1. [Input Phase](InputPhase) +/// This phase focuses on registering the input arrays and scalars that are going to be used by +/// the kernel. +/// 2. [Body Phase](BodyPhase) +/// After the input phase is done, all the operations that happen in the body must be +/// registered. +/// 3. [Output Phase](OutputPhase) +/// This step focuses on registering all output arrays or inputs that the kernel needs to write to. +/// 4. [Compilation Phase](CompilationPhase) +/// Now that all other phases are completed, we can actually compile the kernel. +pub struct ElemWiseKernelCodegen { + operations: Vec, + input_bindings: Vec, + output_bindings: Vec, + named_bindings: Vec<(String, Binding)>, + functions: Vec, + _phase: PhantomData, +} + +pub enum Input { + Array { + elem: Elem, + visibility: Visibility, + strategy: ReadingStrategy, + }, + Scalar { + elem: Elem, + size: usize, + }, +} + +pub enum ReadingStrategy { + IntoContiguous, + Plain, +} + +pub enum Output { + Array { elem: Elem, local: u16 }, + Input { elem: Elem, input: u16, local: u16 }, +} + +impl ElemWiseKernelCodegen { + /// Create a new fusion kernel on the given device. + pub fn new() -> Self { + Self { + operations: Vec::new(), + input_bindings: Vec::new(), + output_bindings: Vec::new(), + named_bindings: Vec::new(), + functions: Vec::new(), + _phase: PhantomData, + } + } + + /// Register the inputs used by the kernel. + pub fn inputs(mut self, inputs: &[Input]) -> ElemWiseKernelCodegen { + let mut index: u16 = 0; + + let first_output_index = inputs + .iter() + .filter(|input| match input { + Input::Array { + elem: _, + visibility: _, + strategy: _, + } => true, + Input::Scalar { elem: _, size: _ } => false, + }) + .count(); + + for input in inputs { + match input { + Input::Array { + elem, + visibility, + strategy, + } => { + self.input_bindings.push(Binding { + elem: bool_elem(*elem), + visibility: *visibility, + location: Location::Storage, + size: None, + }); + + match strategy { + ReadingStrategy::IntoContiguous => { + self.operations.push(Operator::ReadGlobalIntoContiguous { + variable: Variable::Input(index, *elem), + position: index as usize, + position_out: first_output_index, // First output + }); + } + ReadingStrategy::Plain => { + self.operations.push(Operator::ReadGlobal { + variable: Variable::Input(index, *elem), + }); + } + } + + index += 1; + } + Input::Scalar { elem, size } => { + let elem = bool_elem(*elem); + + self.named_bindings.push(( + format!("scalars_{}", elem), + Binding { + elem, + visibility: Visibility::Read, + location: Location::Storage, + size: Some(*size), + }, + )); + } + } + } + + ElemWiseKernelCodegen { + operations: self.operations, + input_bindings: self.input_bindings, + output_bindings: self.output_bindings, + named_bindings: self.named_bindings, + functions: self.functions, + _phase: PhantomData, + } + } +} + +impl ElemWiseKernelCodegen { + /// Register the [operators](Operator) that the kernel must execute in the order provided. + pub fn body(mut self, operators: &[Operator]) -> ElemWiseKernelCodegen { + let mut register_function = |function: Function| { + if !self.functions.contains(&function) { + self.functions.push(function); + } + }; + + // Since not all operators are native to WGSL, we need to add the custom ones. + for ops in operators.iter() { + match ops { + Operator::Powf { + lhs: _, + rhs: _, + out: _, + } => { + register_function(Function::Powf(Elem::F32)); + } + Operator::Erf { input: _, out: _ } => { + register_function(Function::Erf(Elem::F32)); + } + _ => {} + } + self.operations.push(ops.clone()); + } + + ElemWiseKernelCodegen { + operations: self.operations, + input_bindings: self.input_bindings, + output_bindings: self.output_bindings, + named_bindings: self.named_bindings, + functions: self.functions, + _phase: PhantomData, + } + } +} + +impl ElemWiseKernelCodegen { + /// Register the outputs with their local variable index. + /// + /// Note that the index corresponds to the registered [operator](Operator) number at the + /// [body phase](BodyPhase). + /// So the 4th operator registered creates the local variable 3 (N-1, since the 1th index is 0). + pub fn outputs(mut self, outputs: &[Output]) -> ElemWiseKernelCodegen { + let mut index = 0; + + for array in outputs { + match array { + Output::Array { elem, local } => { + let elem_adapted = bool_elem(*elem); + + self.output_bindings.push(Binding { + elem: elem_adapted, + visibility: Visibility::ReadWrite, + location: Location::Storage, + size: None, + }); + self.operations.push(Operator::AssignGlobal { + input: Variable::Local(*local, *elem), + out: Variable::Output(index, elem_adapted), + }); + index += 1; + } + Output::Input { elem, input, local } => { + self.operations.push(Operator::AssignGlobal { + input: Variable::Local(*local, *elem), + out: Variable::Input(*input, bool_elem(*elem)), + }); + } + } + } + + ElemWiseKernelCodegen { + operations: self.operations, + input_bindings: self.input_bindings, + output_bindings: self.output_bindings, + named_bindings: self.named_bindings, + functions: self.functions, + _phase: PhantomData, + } + } +} + +impl ElemWiseKernelCodegen { + /// Compile the kernel into a [compute shader](ComputeShader). + pub fn compile(self) -> ComputeShader { + let inputs = self.input_bindings; + let outputs = self.output_bindings; + let mut named = Vec::with_capacity(2); + + named.push(( + "info".to_string(), + Binding { + elem: Elem::U32, + visibility: Visibility::Read, + location: Location::Storage, + size: None, // We avoid putting the length here since it will force a new kernel + // for each tensor rank. + }, + )); + + for (name, binding) in self.named_bindings.into_iter() { + named.push((name, binding)); + } + + ComputeShader { + inputs, + outputs, + named, + workgroup_size: WorkgroupSize::default(), + body: Body::new(self.operations), + num_workgroups: true, + global_invocation_id: true, + functions: self.functions, + } + } +} + +#[derive(new)] +pub struct StaticHandle<'a> { + handle: &'a WgpuHandle, + strides: &'a [usize], + shape: &'a [usize], +} + +/// Execute a static kernel. +/// +/// +/// The limitation from this method is that you can't launch a kernel with multiple types of +/// scalar. +pub fn execute_static( + inputs: &[StaticHandle], + outputs: &[StaticHandle], + scalar_elems: Option<&[E]>, + client: WgpuComputeClient, +) where + K: StaticKernelSource + 'static, +{ + let mut info = Vec::new(); + let mut handles = Vec::with_capacity(inputs.len() + outputs.len() + 2); + + // Inner function to fill the info buffer. + let mut register_info_tensor = |strides: &[usize], shape: &[usize]| { + if info.is_empty() { + info.push(strides.len() as u32); + } + + for s in strides.iter() { + info.push(*s as u32); + } + for s in shape.iter() { + info.push(*s as u32); + } + }; + + // We start by registering the inputs. + for input in inputs.iter() { + register_info_tensor(input.strides, input.shape); + handles.push(input.handle); + } + + let mut num_elems_output = 0; + + // Then we follow with the outputs. + for output in outputs.iter() { + let num_elems = calculate_num_elems_dyn_rank(output.shape); + if num_elems > num_elems_output { + num_elems_output = num_elems; + } + register_info_tensor(output.strides, output.shape); + handles.push(output.handle); + } + + let info = &client.create(bytemuck::cast_slice(&info)); + handles.push(info); + + // Finally we finish with the named bindings. + let mut scalars = None; + if let Some(values) = &scalar_elems { + scalars = Some(client.create(bytemuck::cast_slice(values))); + } + + if let Some(scalars) = scalars.as_ref() { + handles.push(scalars); + } + + let workgroup = elemwise_workgroup(num_elems_output, WORKGROUP_DEFAULT); + let kernel = Box::new(StaticKernel::::new(workgroup)); + + client.execute(kernel, &handles); +} + +pub(crate) fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize { + let mut num_elems = 1; + for i in shape.iter() { + num_elems *= i; + } + num_elems +} + +fn bool_elem(elem: Elem) -> Elem { + match elem { + // I32 are used for bool tensors + Elem::Bool => Elem::I32, + _ => elem, + } +} diff --git a/burn-wgpu/src/codegen/mod.rs b/burn-wgpu/src/codegen/mod.rs new file mode 100644 index 000000000..a205841dc --- /dev/null +++ b/burn-wgpu/src/codegen/mod.rs @@ -0,0 +1,13 @@ +mod body; +mod function; +mod kernel; +mod operator; +mod shader; +mod variable; + +pub(crate) use body::*; +pub(crate) use function::*; +pub(crate) use kernel::*; +pub(crate) use operator::*; +pub(crate) use shader::*; +pub(crate) use variable::*; diff --git a/burn-wgpu/src/fusion/codegen/operator.rs b/burn-wgpu/src/codegen/operator.rs similarity index 78% rename from burn-wgpu/src/fusion/codegen/operator.rs rename to burn-wgpu/src/codegen/operator.rs index 2341b8ecf..e34f2e7e9 100644 --- a/burn-wgpu/src/fusion/codegen/operator.rs +++ b/burn-wgpu/src/codegen/operator.rs @@ -1,8 +1,9 @@ -use super::Variable; +use super::variable::Variable; use std::fmt::Display; /// All operators that can be fused in a WGSL compute shader. -#[derive(Debug, Hash, Clone)] +#[derive(Debug, Clone)] +#[allow(dead_code)] // Some variants might not be used with different flags pub enum Operator { Add { lhs: Variable, @@ -57,6 +58,10 @@ pub enum Operator { rhs: Variable, out: Variable, }, + Sqrt { + input: Variable, + out: Variable, + }, Erf { input: Variable, out: Variable, @@ -75,6 +80,12 @@ pub enum Operator { rhs: Variable, out: Variable, }, + Clamp { + input: Variable, + min_value: Variable, + max_value: Variable, + out: Variable, + }, Greater { lhs: Variable, rhs: Variable, @@ -100,8 +111,15 @@ pub enum Operator { input: Variable, out: Variable, }, + AssignLocal { + input: Variable, + out: Variable, + }, ReadGlobal { variable: Variable, + }, + ReadGlobalIntoContiguous { + variable: Variable, position: usize, position_out: usize, }, @@ -125,9 +143,20 @@ impl Display for Operator { Operator::Abs { input, out } => f.write_fmt(format_args!("let {out} = abs({input});")), Operator::Exp { input, out } => f.write_fmt(format_args!("let {out} = exp({input});")), Operator::Log { input, out } => f.write_fmt(format_args!("let {out} = log({input});")), + Operator::Clamp { + input, + min_value, + max_value, + out, + } => f.write_fmt(format_args!( + "let {out} = clamp({input}, {min_value}, {max_value});" + )), Operator::Powf { lhs, rhs, out } => { f.write_fmt(format_args!("let {out} = powf({lhs}, {rhs});")) } + Operator::Sqrt { input, out } => { + f.write_fmt(format_args!("let {out} = sqrt({input});")) + } Operator::Log1p { input, out } => { f.write_fmt(format_args!("let {out} = log({input} + 1.0);")) } @@ -159,7 +188,21 @@ impl Display for Operator { let elem = out.elem(); f.write_fmt(format_args!("{out}_global[id] = {elem}({input});")) } - Operator::ReadGlobal { + Operator::AssignLocal { input, out } => { + let elem = out.elem(); + f.write_fmt(format_args!("let {out} = {elem}({input});")) + } + Operator::ReadGlobal { variable } => match variable { + Variable::Input(number, _elem) => f.write_fmt(format_args!( + "let input_{number} = input_{number}_global[id];" + )), + Variable::Local(_, _) => panic!("can't read global local variable."), + Variable::Output(number, _elem) => f.write_fmt(format_args!( + "let output_{number} = output_{number}_global[id];" + )), + Variable::Scalar(_, _) => panic!("Can't read global scalar variable."), + }, + Operator::ReadGlobalIntoContiguous { variable, position, position_out, diff --git a/burn-wgpu/src/fusion/codegen/shader.rs b/burn-wgpu/src/codegen/shader.rs similarity index 86% rename from burn-wgpu/src/fusion/codegen/shader.rs rename to burn-wgpu/src/codegen/shader.rs index 62a80ea30..963317218 100644 --- a/burn-wgpu/src/fusion/codegen/shader.rs +++ b/burn-wgpu/src/codegen/shader.rs @@ -1,34 +1,29 @@ use super::{Body, Function}; -use crate::kernel::{DynamicKernelSource, SourceTemplate, WORKGROUP_DEFAULT}; -use std::{ - collections::hash_map::DefaultHasher, - fmt::Display, - hash::{Hash, Hasher}, -}; +use crate::kernel::WORKGROUP_DEFAULT; +use std::fmt::Display; -#[derive(Hash, PartialEq, Eq)] +#[derive(PartialEq, Eq, Clone, Copy)] pub enum Location { Storage, #[allow(dead_code)] Workgroup, } -#[derive(Hash, PartialEq, Eq)] +#[derive(PartialEq, Eq, Clone, Copy)] pub enum Visibility { Read, ReadWrite, } -#[derive(Debug, Clone, Hash, PartialEq, Eq, Copy)] +#[derive(Debug, Clone, PartialEq, Eq, Copy)] pub enum Elem { F32, - #[allow(dead_code)] I32, U32, Bool, } -#[derive(Hash, PartialEq, Eq)] +#[derive(PartialEq, Eq, Clone)] pub struct Binding { pub location: Location, pub visibility: Visibility, @@ -36,7 +31,7 @@ pub struct Binding { pub size: Option, } -#[derive(Hash, PartialEq, Eq)] +#[derive(PartialEq, Eq)] pub struct WorkgroupSize { pub x: usize, pub y: usize, @@ -53,7 +48,6 @@ impl Default for WorkgroupSize { } } -#[derive(Hash)] pub struct ComputeShader { pub inputs: Vec, pub outputs: Vec, @@ -65,19 +59,6 @@ pub struct ComputeShader { pub functions: Vec, } -impl DynamicKernelSource for ComputeShader { - fn source(&self) -> SourceTemplate { - SourceTemplate::new(self.to_string()) - } - - fn id(&self) -> String { - let mut s = DefaultHasher::new(); - self.hash(&mut s); - - s.finish().to_string() - } -} - impl Display for ComputeShader { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { Self::format_bindings(f, "input", &self.inputs, 0)?; diff --git a/burn-wgpu/src/fusion/codegen/variable.rs b/burn-wgpu/src/codegen/variable.rs similarity index 96% rename from burn-wgpu/src/fusion/codegen/variable.rs rename to burn-wgpu/src/codegen/variable.rs index 837d41646..7331c8b67 100644 --- a/burn-wgpu/src/fusion/codegen/variable.rs +++ b/burn-wgpu/src/codegen/variable.rs @@ -1,7 +1,7 @@ use super::Elem; use std::fmt::Display; -#[derive(Debug, Hash, Clone)] +#[derive(Debug, Clone)] pub enum Variable { Input(u16, Elem), Scalar(u16, Elem), diff --git a/burn-wgpu/src/element.rs b/burn-wgpu/src/element.rs index 61f9ed3f9..59609e842 100644 --- a/burn-wgpu/src/element.rs +++ b/burn-wgpu/src/element.rs @@ -9,8 +9,7 @@ where fn type_name() -> &'static str; fn as_bytes(slice: &[Self]) -> &[u8]; fn from_bytes(bytes: &[u8]) -> &[Self]; - #[cfg(any(feature = "fusion", test))] - fn elem_type() -> crate::fusion::codegen::Elem; + fn elem_type() -> crate::codegen::Elem; } /// The float element type for the wgpu backend. @@ -29,9 +28,8 @@ impl WgpuElement for u32 { fn from_bytes(bytes: &[u8]) -> &[Self] { bytemuck::cast_slice(bytes) } - #[cfg(any(feature = "fusion", test))] - fn elem_type() -> crate::fusion::codegen::Elem { - crate::fusion::codegen::Elem::U32 + fn elem_type() -> crate::codegen::Elem { + crate::codegen::Elem::U32 } } @@ -45,9 +43,8 @@ impl WgpuElement for i32 { fn from_bytes(bytes: &[u8]) -> &[Self] { bytemuck::cast_slice(bytes) } - #[cfg(any(feature = "fusion", test))] - fn elem_type() -> crate::fusion::codegen::Elem { - crate::fusion::codegen::Elem::I32 + fn elem_type() -> crate::codegen::Elem { + crate::codegen::Elem::I32 } } @@ -62,9 +59,8 @@ impl WgpuElement for f32 { bytemuck::cast_slice(bytes) } - #[cfg(any(feature = "fusion", test))] - fn elem_type() -> crate::fusion::codegen::Elem { - crate::fusion::codegen::Elem::F32 + fn elem_type() -> crate::codegen::Elem { + crate::codegen::Elem::F32 } } diff --git a/burn-wgpu/src/fusion/base.rs b/burn-wgpu/src/fusion/base.rs index 5be86ebc5..975e94d1e 100644 --- a/burn-wgpu/src/fusion/base.rs +++ b/burn-wgpu/src/fusion/base.rs @@ -81,14 +81,6 @@ pub fn strides_dyn_rank(shape: &[usize]) -> Vec { strides } -pub fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize { - let mut num_elems = 1; - for i in shape.iter() { - num_elems *= i; - } - num_elems -} - #[derive(new, Debug, Clone)] /// Handle to be used when fusing operations. pub struct WgpuFusionHandle { diff --git a/burn-wgpu/src/fusion/cache.rs b/burn-wgpu/src/fusion/cache.rs new file mode 100644 index 000000000..6c7bc63b1 --- /dev/null +++ b/burn-wgpu/src/fusion/cache.rs @@ -0,0 +1,57 @@ +use crate::{ + codegen::ComputeShader, + kernel::{DynamicKernelSource, SourceTemplate}, +}; +use hashbrown::HashSet; + +/// This cache ensures that the generation of the source code is only done once when the kernel is +/// executed for the first time. Following, we only include the ID in the dynamic kernel source, +/// since we rely on the compilation cache of the WGPU compute server. +/// +/// If it ever causes problems, we could cache the compute shader and put it into an Arc to avoid deep +/// cloning. +#[derive(Default, Debug)] +pub struct KernelCompilationCache { + already_compiled_ids: HashSet, +} + +#[derive(new)] +pub enum FusedKernelSource { + AlreadyCompiled { id: String }, + NewKernel { id: String, shader: ComputeShader }, +} + +impl DynamicKernelSource for FusedKernelSource { + fn source(&self) -> SourceTemplate { + match self { + FusedKernelSource::AlreadyCompiled { id: _ } => { + panic!("Can't get the source of an already compiled kernel.") + } + FusedKernelSource::NewKernel { + id: _, + shader: source, + } => SourceTemplate::new(source.to_string()), + } + } + + fn id(&self) -> String { + match self { + FusedKernelSource::AlreadyCompiled { id } => id.clone(), + FusedKernelSource::NewKernel { id, shader: _ } => id.clone(), + } + } +} + +impl KernelCompilationCache { + pub fn get(&self, id: &str) -> Option { + if self.already_compiled_ids.contains(id) { + return Some(FusedKernelSource::AlreadyCompiled { id: id.to_string() }); + } + + None + } + + pub fn insert(&mut self, id: String) { + self.already_compiled_ids.insert(id); + } +} diff --git a/burn-wgpu/src/fusion/codegen/mod.rs b/burn-wgpu/src/fusion/codegen/mod.rs deleted file mode 100644 index b9b568837..000000000 --- a/burn-wgpu/src/fusion/codegen/mod.rs +++ /dev/null @@ -1,11 +0,0 @@ -mod body; -mod function; -mod operator; -mod shader; -mod variable; - -pub use body::*; -pub use function::*; -pub use operator::*; -pub use shader::*; -pub use variable::*; diff --git a/burn-wgpu/src/fusion/elemwise/builder.rs b/burn-wgpu/src/fusion/elemwise/builder.rs index b8954cc79..5a0a813a5 100644 --- a/burn-wgpu/src/fusion/elemwise/builder.rs +++ b/burn-wgpu/src/fusion/elemwise/builder.rs @@ -1,8 +1,10 @@ use crate::{ + codegen::{Elem, Operator, Variable}, element::WgpuElement, - fusion::codegen::{Elem, Operator, Variable}, + fusion::cache::KernelCompilationCache, FloatElement, GraphicsApi, IntElement, Wgpu, }; +use burn_common::id::IdGenerator; use burn_fusion::{ graph::{ BaseOpsDescription, BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription, @@ -84,12 +86,14 @@ where .collect::>(); Box::new(FloatElementWise { + id: IdGenerator::generate(), inputs, outputs, locals, operators: self.operators.clone(), scalars_f32: self.scalars_f32, device: self.device.clone(), + cache: KernelCompilationCache::default(), }) } @@ -183,13 +187,19 @@ where Operator::AssignGlobal { input: _, out: _ } => { // Nothing to do here. } - Operator::ReadGlobal { + Operator::AssignLocal { input: _, out: _ } => { + // Nothing to do here. + } + Operator::ReadGlobalIntoContiguous { variable: _, position: _, position_out: _, } => { // Nothing to do here. } + Operator::ReadGlobal { variable: _ } => { + // Nothing to do here. + } Operator::Add { lhs, rhs, out } => { mark(lhs, &mut local_tensor_ids_input); mark(rhs, &mut local_tensor_ids_input); @@ -242,6 +252,15 @@ where mark(input, &mut local_tensor_ids_input); mark(out, &mut local_tensor_ids_output); } + Operator::Clamp { + input, + min_value: _, + max_value: _, + out, + } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } Operator::Powf { lhs, rhs, out } => { mark(lhs, &mut local_tensor_ids_input); mark(rhs, &mut local_tensor_ids_input); @@ -287,6 +306,10 @@ where mark(rhs, &mut local_tensor_ids_input); mark(out, &mut local_tensor_ids_output); } + Operator::Sqrt { input, out } => { + mark(input, &mut local_tensor_ids_input); + mark(out, &mut local_tensor_ids_output); + } } } diff --git a/burn-wgpu/src/fusion/elemwise/optimization.rs b/burn-wgpu/src/fusion/elemwise/optimization.rs index ff8f401c3..588187010 100644 --- a/burn-wgpu/src/fusion/elemwise/optimization.rs +++ b/burn-wgpu/src/fusion/elemwise/optimization.rs @@ -1,24 +1,73 @@ use crate::{ - fusion::codegen::{Elem, Operator}, - fusion::kernel::FusionKernel, + codegen::{ + ComputeShader, Elem, ElemWiseKernelCodegen, Input, Operator, Output, ReadingStrategy, + Visibility, + }, + fusion::{ + cache::{FusedKernelSource, KernelCompilationCache}, + kernel, + }, FloatElement, GraphicsApi, IntElement, Wgpu, }; use burn_fusion::{graph::Context, Optimization, TensorDescription}; use burn_tensor::Device; -#[derive(Clone)] pub(crate) struct FloatElementWise where G: GraphicsApi, F: FloatElement, I: IntElement, { + pub(crate) id: String, pub(crate) inputs: Vec<(TensorDescription, Elem)>, pub(crate) outputs: Vec<(TensorDescription, Elem)>, pub(crate) locals: Vec, pub(crate) operators: Vec, pub(crate) scalars_f32: usize, pub(crate) device: Device>, + pub(crate) cache: KernelCompilationCache, +} + +impl FloatElementWise +where + G: GraphicsApi, + F: FloatElement, + I: IntElement, +{ + pub fn compile(&mut self) -> ComputeShader { + let mut inputs = self + .inputs + .iter() + .map(|(_tensor, elem)| Input::Array { + elem: *elem, + visibility: Visibility::Read, + strategy: ReadingStrategy::IntoContiguous, + }) + .collect::>(); + + let outputs = self + .outputs + .iter() + .zip(self.locals.iter()) + .map(|((_tensor, elem), local)| Output::Array { + elem: *elem, + local: *local, + }) + .collect::>(); + + if self.scalars_f32 > 0 { + inputs.push(Input::Scalar { + elem: Elem::F32, + size: self.scalars_f32, + }) + } + + ElemWiseKernelCodegen::new() + .inputs(&inputs) + .body(&self.operators) + .outputs(&outputs) + .compile() + } } impl Optimization> for FloatElementWise @@ -27,27 +76,33 @@ where F: FloatElement, I: IntElement, { - fn execute(&self, context: &mut Context<'_, Wgpu>) { - let inputs = self - .inputs - .iter() - .map(|(tensor, elem)| (context.tensors.get(&tensor.id).unwrap(), *elem)) - .collect::>(); + fn execute(&mut self, context: &mut Context<'_, Wgpu>) { + if let Some(kernel) = self.cache.get(&self.id) { + kernel::execute_fusion( + &self.inputs.iter().map(|a| &a.0).collect::>(), + &self.outputs.iter().map(|a| &a.0).collect::>(), + self.scalars_f32, + kernel, + context, + self.device.clone(), + ); + } else { + let shader = self.compile(); - let outputs = self - .outputs - .iter() - .map(|(tensor, elem)| (context.tensors.get(&tensor.id).unwrap(), *elem)) - .collect::>(); + kernel::execute_fusion( + &self.inputs.iter().map(|a| &a.0).collect::>(), + &self.outputs.iter().map(|a| &a.0).collect::>(), + self.scalars_f32, + FusedKernelSource::NewKernel { + id: self.id.to_string(), + shader, + }, + context, + self.device.clone(), + ); - // The context may contain scalars for the end condition, which may vary. - let scalars_f32 = &context.scalar_floats[0..self.scalars_f32]; - - FusionKernel::new(&self.device) - .inputs(&inputs, scalars_f32) - .body(&self.operators) - .outputs(&outputs, &self.locals) - .execute(context.handles); + self.cache.insert(self.id.clone()); + } } fn len(&self) -> usize { @@ -144,6 +199,7 @@ mod tests { Variant1, Variant2, } + fn execute( data_1: Data, data_2: Data, diff --git a/burn-wgpu/src/fusion/kernel.rs b/burn-wgpu/src/fusion/kernel.rs index 4e9059cd4..8075c0824 100644 --- a/burn-wgpu/src/fusion/kernel.rs +++ b/burn-wgpu/src/fusion/kernel.rs @@ -1,355 +1,82 @@ -use super::codegen::Body; -use crate::compute::{compute_client, DynamicKernel, WgpuComputeClient}; -use crate::fusion::codegen::Function; -use crate::fusion::{calculate_num_elems_dyn_rank, strides_dyn_rank}; -use crate::fusion::{ - codegen::{ - Binding, ComputeShader, Elem, Location, Operator, Variable, Visibility, WorkgroupSize, - }, - WgpuFusionHandle, -}; +use super::cache::FusedKernelSource; +use crate::codegen::calculate_num_elems_dyn_rank; +use crate::compute::{compute_client, DynamicKernel}; +use crate::fusion::strides_dyn_rank; +use crate::fusion::WgpuFusionHandle; use crate::kernel::{elemwise_workgroup, WORKGROUP_DEFAULT}; use crate::{FloatElement, GraphicsApi, IntElement, Wgpu}; -use burn_fusion::{HandleContainer, TensorDescription}; +use burn_fusion::graph::Context; +use burn_fusion::TensorDescription; use burn_tensor::Device; -use std::marker::PhantomData; -/// Kernel creation input phase, see [fusion kernel](FusionKernel) for more details. -pub struct InputPhase; -/// Kernel creation body phase, see [fusion kernel](FusionKernel) for more details. -pub struct BodyPhase; -/// Kernel creation output phase, see [fusion kernel](FusionKernel) for more details. -pub struct OutputPhase; -/// Kernel execution phase, see [fusion kernel](FusionKernel) for more details. -pub struct ExecutionPhase; - -/// Allows to create custom wgsl kernels based on configured inputs, body and outputs. -/// -/// This type has 4 phases that must be executed in order, but no worry the type system won't allow -/// you to make mistakes. -/// -/// 1. [Input Phase](InputPhase) -/// This phase focuses on registering the input tensor descriptions that are going to be used by -/// the fused kernel. -/// 2. [Body Phase](BodyPhase) -/// After the input phase is done, all the operations that happen in the body must be -/// registered. -/// 3. [Output Phase](OutputPhase) -/// This step focuses on registering all tensor descriptions that the kernel needs to write to. -/// 4. [Execution Phase](ExecutionPhase) -/// Now that all other phases are completed, we can actually run the kernel on the given -/// [handles](HandleContainer). Note that the actual chosen kernel may vary based on the -/// handles provided. -pub struct FusionKernel -where - G: GraphicsApi, - F: FloatElement, - I: IntElement, -{ - operations: Vec, - input_bindings: Vec<(Binding, TensorDescription)>, - output_bindings: Vec<(Binding, TensorDescription)>, - named_bindings: Vec<(String, Binding, DataBuffer)>, - functions: Vec, - num_elems_output: usize, +pub fn execute_fusion( + inputs: &[&TensorDescription], + outputs: &[&TensorDescription], + scalars_f32: usize, + kernel: FusedKernelSource, + context: &mut Context<'_, Wgpu>, device: Device>, - client: WgpuComputeClient, - _phase: PhantomData, -} +) { + let client = compute_client::(&device); + let mut info = Vec::new(); + let mut handles = Vec::with_capacity(inputs.len() + outputs.len() + 2); -enum DataBuffer { - F32(Vec), - U32(Vec), -} + // Inner function to fill the info buffer. + let mut register_info_tensor = |tensor: &TensorDescription, handle: &WgpuFusionHandle| { + if info.is_empty() { + info.push(handle.strides.len() as u32); + } -impl FusionKernel { - /// Create a new fusion kernel on the given device. - pub fn new(device: &Device>) -> Self { - let client = compute_client::(device); + for s in handle.strides.iter() { + info.push(*s as u32); + } + for s in tensor.shape.iter() { + info.push(*s as u32); + } + }; - Self { - operations: Vec::new(), - input_bindings: Vec::new(), - output_bindings: Vec::new(), - named_bindings: Vec::new(), - functions: Vec::new(), - num_elems_output: 0, + // We start by registering the inputs. + for tensor in inputs.iter() { + let tensor = context.tensors.get(&tensor.id).unwrap(); + let handle = context.handles.get_handle(tensor); + + register_info_tensor(tensor, &handle); + handles.push(handle.handle); + } + + let mut num_elems_output = 0; + + // Then we follow with the outputs. + for tensor in outputs.iter() { + let tensor = context.tensors.get(&tensor.id).unwrap(); + + let num_elems = calculate_num_elems_dyn_rank(&tensor.shape); + if num_elems > num_elems_output { + num_elems_output = num_elems; + } + let handle_fusion = WgpuFusionHandle { + client: client.clone(), device: device.clone(), - client, - _phase: PhantomData, - } - } - - /// Register the inputs used by the kernel. - pub fn inputs( - mut self, - inputs_tensor: &[(&TensorDescription, Elem)], - inputs_scalar_f32: &[f32], - ) -> FusionKernel { - for (i, (input, elem)) in inputs_tensor.iter().enumerate() { - if elem != &Elem::Bool { - self.input_bindings.push(( - Binding { - elem: *elem, - visibility: Visibility::Read, - location: Location::Storage, - size: None, - }, - (*input).clone(), - )); - - self.operations.push(Operator::ReadGlobal { - variable: Variable::Input(i as u16, *elem), - position: i, - position_out: inputs_tensor.len(), // First output - }); - } else { - self.input_bindings.push(( - Binding { - elem: Elem::I32, - visibility: Visibility::Read, - location: Location::Storage, - size: None, - }, - (*input).clone(), - )); - - self.operations.push(Operator::ReadGlobal { - variable: Variable::Input(i as u16, *elem), - position: i, - position_out: inputs_tensor.len(), // First output - }); - } - } - - if !inputs_scalar_f32.is_empty() { - self.named_bindings.push(( - "scalars_f32".to_string(), - Binding { - elem: Elem::F32, - visibility: Visibility::Read, - location: Location::Storage, - size: Some(inputs_scalar_f32.len()), - }, - DataBuffer::F32(inputs_scalar_f32.to_vec()), - )); - } - - FusionKernel { - operations: self.operations, - input_bindings: self.input_bindings, - output_bindings: self.output_bindings, - named_bindings: self.named_bindings, - functions: self.functions, - num_elems_output: self.num_elems_output, - device: self.device, - client: self.client, - _phase: PhantomData, - } - } -} - -impl FusionKernel { - /// Register the [operators](Operator) that the kernel must execute in the order provided. - pub fn body(mut self, operators: &[Operator]) -> FusionKernel { - let mut register_function = |function: Function| { - if !self.functions.contains(&function) { - self.functions.push(function); - } + strides: strides_dyn_rank(&tensor.shape), + handle: client.empty(core::mem::size_of::() * num_elems), }; - // Since not all operators are native to WGSL, we need to add the custom ones. - for ops in operators.iter() { - match ops { - Operator::Powf { - lhs: _, - rhs: _, - out: _, - } => { - register_function(Function::Powf(Elem::F32)); - } - Operator::Erf { input: _, out: _ } => { - register_function(Function::Erf(Elem::F32)); - } - _ => {} - } - self.operations.push(ops.clone()); - } + register_info_tensor(tensor, &handle_fusion); - FusionKernel { - operations: self.operations, - input_bindings: self.input_bindings, - output_bindings: self.output_bindings, - named_bindings: self.named_bindings, - functions: self.functions, - num_elems_output: self.num_elems_output, - device: self.device, - client: self.client, - _phase: PhantomData, - } - } -} - -impl FusionKernel { - /// Register the outputs with their local variable index. - /// - /// Note that the index corresponds to the registered [operator](Operator) number at the - /// [body phase](BodyPhase). - /// So the 4th operator registered creates the local variable 3 (N-1, since the 1th index is 0). - pub fn outputs( - mut self, - outputs: &[(&TensorDescription, Elem)], - locals: &[u16], - ) -> FusionKernel { - let mut num_elems_launch_option = 0; - - for (i, ((output, elem), local)) in outputs.iter().zip(locals).enumerate() { - let num_elems_output = calculate_num_elems_dyn_rank(&output.shape); - if num_elems_output > num_elems_launch_option { - num_elems_launch_option = num_elems_output; - } - - if elem != &Elem::Bool { - self.output_bindings.push(( - Binding { - elem: *elem, - visibility: Visibility::ReadWrite, - location: Location::Storage, - size: None, - }, - (*output).clone(), - )); - - self.operations.push(Operator::AssignGlobal { - input: Variable::Local(*local, *elem), - out: Variable::Output(i as u16, *elem), - }); - } else { - self.output_bindings.push(( - Binding { - elem: Elem::I32, // I32 are used for bool tensors - visibility: Visibility::ReadWrite, - location: Location::Storage, - size: None, - }, - (*output).clone(), - )); - - self.operations.push(Operator::AssignGlobal { - input: Variable::Local(*local, *elem), - out: Variable::Output(i as u16, Elem::I32), - }); - } - } - - self.num_elems_output = num_elems_launch_option; - - FusionKernel { - operations: self.operations, - input_bindings: self.input_bindings, - output_bindings: self.output_bindings, - named_bindings: self.named_bindings, - functions: self.functions, - num_elems_output: self.num_elems_output, - device: self.device, - client: self.client, - _phase: PhantomData, - } - } -} - -impl FusionKernel { - /// Execute the kernel on the provided [handles](HandleContainer). - pub fn execute(mut self, handle_container: &mut HandleContainer>) { - let mut inputs = Vec::with_capacity(self.input_bindings.len()); - let mut outputs = Vec::with_capacity(self.output_bindings.len()); - let mut named = Vec::with_capacity(2); - let mut info = Vec::new(); - let mut handles = - Vec::with_capacity(inputs.capacity() + outputs.capacity() + named.capacity()); - - // Inner function to fill the info buffer. - let mut register_info_tensor = |tensor: &TensorDescription, handle: &WgpuFusionHandle| { - if info.is_empty() { - info.push(handle.strides.len() as u32); - } - - for s in handle.strides.iter() { - info.push(*s as u32); - } - for s in tensor.shape.iter() { - info.push(*s as u32); - } - }; - - // We start by registering the inputs. - for (binding, tensor) in self.input_bindings.into_iter() { - let handle = handle_container.get_handle(&tensor); - register_info_tensor(&tensor, &handle); - - inputs.push(binding); - handles.push(handle.handle); - } - - // Then we follow with the outputs. - for (binding, tensor) in self.output_bindings { - let num_elems = calculate_num_elems_dyn_rank(&tensor.shape); - let handle_fusion = WgpuFusionHandle { - client: self.client.clone(), - device: self.device.clone(), - strides: strides_dyn_rank(&tensor.shape), - handle: self.client.empty(core::mem::size_of::() * num_elems), - }; - register_info_tensor(&tensor, &handle_fusion); - - handles.push(handle_fusion.handle.clone()); - handle_container.register_handle(tensor.id, handle_fusion); - outputs.push(binding); - } - - // Now we can create the info handle. - Self::build_info_handle(&mut self.named_bindings, info); - - // Finally we finish with the named bindings. - for (name, binding, data) in self.named_bindings { - let handle = self.client.create(match &data { - DataBuffer::F32(values) => bytemuck::cast_slice(values), - DataBuffer::U32(values) => bytemuck::cast_slice(values), - }); - named.push((name, binding)); - handles.push(handle); - } - - // We create the shader codegen type and launch the kernel. - let kernel = ComputeShader { - inputs, - outputs, - named, - workgroup_size: WorkgroupSize::default(), - body: Body::new(self.operations), - num_workgroups: true, - global_invocation_id: true, - functions: self.functions, - }; - - let workgroup = elemwise_workgroup(self.num_elems_output, WORKGROUP_DEFAULT); - let kernel = Box::new(DynamicKernel::new(kernel, workgroup)); - - self.client - .execute(kernel, &handles.iter().collect::>()); - } - - fn build_info_handle(named_bindings: &mut Vec<(String, Binding, DataBuffer)>, info: Vec) { - named_bindings.push(( - "info".to_string(), - Binding { - elem: Elem::U32, - visibility: Visibility::Read, - location: Location::Storage, - size: None, // We avoid putting the length here since it will force a new kernel - // for each tensor rank. - }, - DataBuffer::U32(info), - )); + handles.push(handle_fusion.handle.clone()); + context + .handles + .register_handle(tensor.id.clone(), handle_fusion); } + + handles.push(client.create(bytemuck::cast_slice(&info))); + + // Finally we finish with the named bindings. + if scalars_f32 > 0 { + handles.push(client.create(bytemuck::cast_slice(&context.scalar_floats[0..scalars_f32]))); + } + + let workgroup = elemwise_workgroup(num_elems_output, WORKGROUP_DEFAULT); + let kernel = Box::new(DynamicKernel::new(kernel, workgroup)); + client.execute(kernel, &handles.iter().collect::>()); } diff --git a/burn-wgpu/src/fusion/mod.rs b/burn-wgpu/src/fusion/mod.rs index 5e1fdd978..62af32e97 100644 --- a/burn-wgpu/src/fusion/mod.rs +++ b/burn-wgpu/src/fusion/mod.rs @@ -1,7 +1,7 @@ mod base; mod elemwise; -pub(crate) mod codegen; +pub(crate) mod cache; pub(crate) mod kernel; pub use base::*; diff --git a/burn-wgpu/src/kernel/clamp.rs b/burn-wgpu/src/kernel/clamp.rs index dcd774d8e..45f1fc946 100644 --- a/burn-wgpu/src/kernel/clamp.rs +++ b/burn-wgpu/src/kernel/clamp.rs @@ -1,78 +1,27 @@ +use super::unary; use crate::{ - compute::StaticKernel, + codegen::{Operator, Variable}, element::WgpuElement, - kernel::{unary_scalar, unary_scalar_inplace_default, WORKGROUP_DEFAULT}, - kernel_wgsl, - ops::numeric::empty_device, tensor::WgpuTensor, - unary_scalar, unary_scalar_inplace, + unary, }; -use super::{elemwise_workgroup, KernelSettings}; - -kernel_wgsl!(Clamp, "../template/clamp/clamp.wgsl"); -kernel_wgsl!(ClampInplace, "../template/clamp/clamp_inplace.wgsl"); - -pub(crate) fn clamp_min( - input: WgpuTensor, - min_value: E, -) -> WgpuTensor { - unary_scalar!(ClampMin, func "max"); - unary_scalar_inplace!(ClampMinInplace, func "max"); - - if input.can_mut() { - return unary_scalar_inplace_default::(input, min_value); - } - - unary_scalar::(input, min_value) -} - -pub(crate) fn clamp_max( - input: WgpuTensor, - max_value: E, -) -> WgpuTensor { - unary_scalar!(ClampMax, func "min"); - unary_scalar_inplace!(ClampMaxInPlace, func "min"); - - if input.can_mut() { - return unary_scalar_inplace_default::(input, max_value); - } - - unary_scalar::(input, max_value) -} +unary!( + |elem| Operator::Clamp { + input: Variable::Input(0, elem), + min_value: Variable::Scalar(0, elem), + max_value: Variable::Scalar(1, elem), + out: Variable::Local(0, elem), + }, + scalar 2 +); pub(crate) fn clamp( input: WgpuTensor, min_value: E, max_value: E, ) -> WgpuTensor { - let num_elems = input.shape.num_elements(); - let min_handle = input.client.create(E::as_bytes(&[min_value])); - let max_handle = input.client.create(E::as_bytes(&[max_value])); - - if input.can_mut() { - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - - input - .client - .execute(Box::new(kernel), &[&input.handle, &min_handle, &max_handle]); - - return input; - } - - let output = empty_device(input.client.clone(), input.device.clone(), input.shape); - let kernel = StaticKernel::< - KernelSettings, - >::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT)); - - input.client.execute( - Box::new(kernel), - &[&input.handle, &output.handle, &min_handle, &max_handle], - ); - - output + unary::, OpsInplace, E, D>(input, Some(&[min_value, max_value])) } #[cfg(test)] @@ -80,30 +29,6 @@ mod tests { use crate::tests::{ReferenceBackend, TestBackend}; use burn_tensor::{Distribution, Tensor}; - #[test] - fn clamp_min_should_match_reference() { - let input = Tensor::::random([1, 5, 32, 32], Distribution::Default); - let input_ref = Tensor::::from_data(input.to_data()); - - let output = input.clamp_min(0.5); - - output - .into_data() - .assert_approx_eq(&input_ref.clamp_min(0.5).into_data(), 3); - } - - #[test] - fn clamp_max_should_match_reference() { - let input = Tensor::::random([1, 5, 32, 32], Distribution::Default); - let input_ref = Tensor::::from_data(input.to_data()); - - let output = input.clamp_max(0.5); - - output - .into_data() - .assert_approx_eq(&input_ref.clamp_max(0.5).into_data(), 3); - } - #[test] fn clamp_should_match_reference() { let input = Tensor::::random([1, 5, 32, 32], Distribution::Default); diff --git a/burn-wgpu/src/kernel/mod.rs b/burn-wgpu/src/kernel/mod.rs index e30f100ca..07e732872 100644 --- a/burn-wgpu/src/kernel/mod.rs +++ b/burn-wgpu/src/kernel/mod.rs @@ -8,14 +8,12 @@ mod index; mod mask; mod source; mod unary; -mod unary_scalar; pub use base::*; pub use binary_elemwise::*; pub use cast::*; pub use source::*; pub use unary::*; -pub use unary_scalar::*; /// Convolution kernels pub mod conv; diff --git a/burn-wgpu/src/kernel/unary.rs b/burn-wgpu/src/kernel/unary.rs index e1f28dac1..baf740ec0 100644 --- a/burn-wgpu/src/kernel/unary.rs +++ b/burn-wgpu/src/kernel/unary.rs @@ -1,169 +1,220 @@ -use super::{elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT}; -use crate::{compute::StaticKernel, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor}; - -kernel_wgsl!(UnaryRaw, "../template/unary.wgsl"); -kernel_wgsl!(UnaryInplaceRaw, "../template/unary_inplace.wgsl"); +use super::StaticKernelSource; +use crate::{ + codegen::{execute_static, StaticHandle}, + element::WgpuElement, + tensor::WgpuTensor, +}; /// Creates a unary kernel. #[macro_export] macro_rules! unary { ( - $struct:ident, - func $func:expr - ) => { - pub struct $struct; + operator: $ops:expr, + input: $input:expr, + elem: $elem:ty + ) => {{ + unary!($ops); - impl $crate::kernel::StaticKernelSource for $struct { + $crate::kernel::unary::, OpsInplace<$elem>, $elem, D>($input, None) + }}; + ( + operator: $ops:expr, + input: $input:expr; $scalar:expr, + elem: $elem:ty + ) => {{ + unary!($ops, scalar 1); + + $crate::kernel::unary::, OpsInplace<$elem>, $elem, D>($input, Some(&[$scalar])) + }}; + + ( + $ops:expr + ) => { + pub struct Ops { + _e: core::marker::PhantomData, + } + pub struct OpsInplace { + _e: core::marker::PhantomData, + } + + #[allow(clippy::redundant_closure_call)] + impl $crate::kernel::StaticKernelSource for Ops { fn source() -> $crate::kernel::SourceTemplate { - let source = $crate::kernel::UnaryRaw::source(); - source.register("body", format!("output[id] = {}(input[id]);", $func)) + let shader = $crate::codegen::ElemWiseKernelCodegen::new() + .inputs(&[$crate::codegen::Input::Array { + elem: E::elem_type(), + visibility: $crate::codegen::Visibility::Read, + strategy: $crate::codegen::ReadingStrategy::IntoContiguous, + }]) + .body(&[$ops(E::elem_type())]) + .outputs(&[$crate::codegen::Output::Array { + elem: E::elem_type(), + local: 0, + }]) + .compile(); + + $crate::kernel::SourceTemplate::new(shader.to_string()) + } + } + + #[allow(clippy::redundant_closure_call)] + impl $crate::kernel::StaticKernelSource for OpsInplace { + fn source() -> $crate::kernel::SourceTemplate { + let shader = $crate::codegen::ElemWiseKernelCodegen::new() + .inputs(&[$crate::codegen::Input::Array { + elem: E::elem_type(), + visibility: $crate::codegen::Visibility::ReadWrite, + strategy: $crate::codegen::ReadingStrategy::Plain, + }]) + .body(&[$ops(E::elem_type())]) + .outputs(&[$crate::codegen::Output::Input { + elem: E::elem_type(), + input: 0, + local: 0, + }]) + .compile(); + + $crate::kernel::SourceTemplate::new(shader.to_string()) } } }; ( - $struct:ident, - body $body:expr + $ops:expr, + scalar $num:expr ) => { - pub struct $struct; + pub struct Ops { + _e: core::marker::PhantomData, + } + pub struct OpsInplace { + _e: core::marker::PhantomData, + } - impl $crate::kernel::StaticKernelSource for $struct { + #[allow(clippy::redundant_closure_call)] + impl $crate::kernel::StaticKernelSource for Ops { fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryRaw::source().register("body", $body) + let shader = $crate::codegen::ElemWiseKernelCodegen::new() + .inputs(&[ + $crate::codegen::Input::Array { + elem: E::elem_type(), + visibility: $crate::codegen::Visibility::Read, + strategy: $crate::codegen::ReadingStrategy::IntoContiguous, + }, + $crate::codegen::Input::Scalar { + elem: E::elem_type(), + size: $num, + }, + ]) + .body(&[$ops(E::elem_type())]) + .outputs(&[$crate::codegen::Output::Array { + elem: E::elem_type(), + local: 0, + }]) + .compile(); + + $crate::kernel::SourceTemplate::new(shader.to_string()) } } - }; - ( - $struct:ident, - func $func:expr, - include $file:expr - ) => { - pub struct $struct; - impl $crate::kernel::StaticKernelSource for $struct { + #[allow(clippy::redundant_closure_call)] + impl $crate::kernel::StaticKernelSource for OpsInplace { fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryRaw::source() - .register("body", format!("output[id] = {}(input[id]);", $func)) - .add_template(include_str!($file)) + let shader = $crate::codegen::ElemWiseKernelCodegen::new() + .inputs(&[ + $crate::codegen::Input::Array { + elem: E::elem_type(), + visibility: $crate::codegen::Visibility::ReadWrite, + strategy: $crate::codegen::ReadingStrategy::Plain, + }, + $crate::codegen::Input::Scalar { + elem: E::elem_type(), + size: $num, + }, + ]) + .body(&[$ops(E::elem_type())]) + .outputs(&[$crate::codegen::Output::Input { + elem: E::elem_type(), + input: 0, + local: 0, + }]) + .compile(); + + $crate::kernel::SourceTemplate::new(shader.to_string()) } } }; } -/// Creates a unary inplace kernel. -#[macro_export] -macro_rules! unary_inplace { - ( - $struct:ident, - func $func:expr - ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryInplaceRaw::source() - .register("body", format!("input[id] = {}(input[id]);", $func)) - } - } - }; - ( - $struct:ident, - body $body:expr - ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryInplaceRaw::source().register("body", $body) - } - } - }; - ( - $struct:ident, - func $func:expr, - include $file:expr - ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryInplaceRaw::source() - .register("body", format!("input[id] = {}(input[id]);", $func)) - .add_template(include_str!($file)) - } - } - }; -} - -/// Execute a unary kernel using the default settings. -pub fn unary_default( - input: WgpuTensor, -) -> WgpuTensor { - unary::(input) -} - -/// Execute a unary inplace kernel using the default settings. -pub fn unary_inplace_default( - input: WgpuTensor, -) -> WgpuTensor { - unary_inplace::(input) -} - -/// Execute a unary inplace kernel using the provided WORKGROUP. -pub fn unary_inplace< +/// Launch an unary operation. +pub fn unary( + tensor: WgpuTensor, + scalars: Option<&[E]>, +) -> WgpuTensor +where K: StaticKernelSource, + KI: StaticKernelSource, E: WgpuElement, - const D: usize, - const WORKGROUP: usize, ->( - input: WgpuTensor, -) -> WgpuTensor { - let num_elems = input.shape.num_elements(); - let kernel = StaticKernel::>::new( - elemwise_workgroup(num_elems, WORKGROUP), - ); +{ + if !tensor.can_mut() { + let num_elems = tensor.shape.num_elements(); + let buffer = tensor.client.empty(num_elems * core::mem::size_of::()); + let output = WgpuTensor::new( + tensor.client.clone(), + tensor.device, + tensor.shape.clone(), + buffer, + ); - input.client.execute(Box::new(kernel), &[&input.handle]); + execute_static::( + &[StaticHandle::new( + &tensor.handle, + &tensor.strides, + &tensor.shape.dims, + )], + &[StaticHandle::new( + &output.handle, + &output.strides, + &output.shape.dims, + )], + scalars, + tensor.client, + ); - input -} + output + } else { + execute_static::( + &[], + &[StaticHandle::new( + &tensor.handle, + &tensor.strides, + &tensor.shape.dims, + )], + scalars, + tensor.client.clone(), + ); -/// Execute a unary kernel using the provided WORKGROUP. -pub fn unary( - input: WgpuTensor, -) -> WgpuTensor { - let num_elems = input.shape.num_elements(); - let buffer = input.client.empty(num_elems * core::mem::size_of::()); - let mut output = WgpuTensor::new(input.client.clone(), input.device, input.shape, buffer); - // Since we don't handle the stride inside the kernel, the output tensor have the same strides - // as the input tensor. It might not be in the default format. - output.strides = input.strides; - - let kernel = StaticKernel::>::new( - elemwise_workgroup(num_elems, WORKGROUP), - ); - input - .client - .execute(Box::new(kernel), &[&input.handle, &output.handle]); - - output + tensor + } } #[cfg(test)] mod tests { use super::*; + use crate::codegen::{Operator, Variable}; use crate::tests::{ReferenceBackend, TestBackend}; use burn_tensor::{Distribution, Tensor}; - unary!(TestKernel, func "log"); - unary_inplace!(TestKernelInplace, func "log"); + unary!(|elem| Operator::Tanh { + input: Variable::Input(0, elem), + out: Variable::Local(0, elem), + }); #[test] fn unary_should_work_with_multiple_invocations() { let tensor = Tensor::::random([6, 256], Distribution::Default); let tensor_ref = Tensor::::from_data(tensor.to_data()); - let actual = unary::(tensor.into_primitive()); - let expected = tensor_ref.log(); + let actual = unary::, OpsInplace, f32, 2>(tensor.into_primitive(), None); + let expected = tensor_ref.tanh(); expected.into_data().assert_approx_eq( &Tensor::::from_primitive(actual).into_data(), @@ -176,8 +227,8 @@ mod tests { let tensor = Tensor::::random([6, 256], Distribution::Default); let tensor_ref = Tensor::::from_data(tensor.to_data()); - let actual = unary_inplace::(tensor.into_primitive()); - let expected = tensor_ref.log(); + let actual = unary::, OpsInplace, f32, 2>(tensor.into_primitive(), None); + let expected = tensor_ref.tanh(); expected.into_data().assert_approx_eq( &Tensor::::from_primitive(actual).into_data(), diff --git a/burn-wgpu/src/kernel/unary_scalar.rs b/burn-wgpu/src/kernel/unary_scalar.rs deleted file mode 100644 index dc68443df..000000000 --- a/burn-wgpu/src/kernel/unary_scalar.rs +++ /dev/null @@ -1,220 +0,0 @@ -use super::{elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT}; -use crate::{compute::StaticKernel, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor}; - -kernel_wgsl!(UnaryScalarRaw, "../template/unary_scalar.wgsl"); -kernel_wgsl!( - UnaryScalarInplaceRaw, - "../template/unary_scalar_inplace.wgsl" -); - -/// Creates a unary scalar kernel. -#[macro_export] -macro_rules! unary_scalar { - ( - $struct:ident, - ops $ops:expr - ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarRaw::source() - .register("body", format!("output[id] = lhs[id] {} rhs;", $ops)) - } - } - }; - - ( - $struct:ident, - func $func:expr - ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarRaw::source() - .register("body", format!("output[id] = {}(lhs[id], rhs);", $func)) - } - } - }; - - ( - $struct:ident, - func $func:expr, - include $file:expr - ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarRaw::source() - .register("body", format!("output[id] = {}(lhs[id], rhs);", $func)) - .add_template(include_str!($file)) - } - } - }; -} - -/// Creates a unary scalar inplace kernel. -#[macro_export] -macro_rules! unary_scalar_inplace { - ( - $struct:ident, - ops $ops:expr - ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarInplaceRaw::source() - .register("body", format!("lhs[id] = lhs[id] {} rhs;", $ops)) - } - } - }; - - ( - $struct:ident, - body $body:expr - ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarInplaceRaw::source().register("body", $body) - } - } - }; - - ( - $struct:ident, - func $func:expr - ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarInplaceRaw::source() - .register("body", format!("lhs[id] = {}(lhs[id], rhs);", $func)) - } - } - }; - - ( - $struct:ident, - func $func:expr, - include $file:expr - ) => { - pub struct $struct; - - impl $crate::kernel::StaticKernelSource for $struct { - fn source() -> $crate::kernel::SourceTemplate { - $crate::kernel::UnaryScalarInplaceRaw::source() - .register("body", format!("lhs[id] = {}(lhs[id], rhs);", $func)) - .add_template(include_str!($file)) - } - } - }; -} - -/// Execute a unary scalar kernel using the default settings. -pub fn unary_scalar_default( - lhs: WgpuTensor, - scalar: E, -) -> WgpuTensor { - unary_scalar::(lhs, scalar) -} - -/// Execute a unary scalar kernel using the provided WORKGROUP. -pub fn unary_scalar< - K: StaticKernelSource, - E: WgpuElement, - const D: usize, - const WORKGROUP: usize, ->( - lhs: WgpuTensor, - scalar: E, -) -> WgpuTensor { - let num_elems = lhs.shape.num_elements(); - let buffer = lhs.client.empty(num_elems * core::mem::size_of::()); - let output = WgpuTensor::new(lhs.client.clone(), lhs.device, lhs.shape, buffer); - let kernel = StaticKernel::>::new( - elemwise_workgroup(num_elems, WORKGROUP), - ); - let rhs_handle = lhs.client.create(E::as_bytes(&[scalar])); - - lhs.client.execute( - Box::new(kernel), - &[&lhs.handle, &rhs_handle, &output.handle], - ); - - output -} - -/// Execute a unary scalar inplace kernel using the default settings. -pub fn unary_scalar_inplace_default( - lhs: WgpuTensor, - scalar: E, -) -> WgpuTensor { - unary_scalar_inplace::(lhs, scalar) -} - -/// Execute a unary scalar inplace kernel using the provided WORKGROUP. -pub fn unary_scalar_inplace< - K: StaticKernelSource, - E: WgpuElement, - const D: usize, - const WORKGROUP: usize, ->( - lhs: WgpuTensor, - scalar: E, -) -> WgpuTensor { - let num_elems = lhs.shape.num_elements(); - let kernel = StaticKernel::>::new( - elemwise_workgroup(num_elems, WORKGROUP), - ); - let rhs_handle = lhs.client.create(E::as_bytes(&[scalar])); - - lhs.client - .execute(Box::new(kernel), &[&lhs.handle, &rhs_handle]); - - lhs -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::tests::{ReferenceBackend, TestBackend}; - use burn_tensor::{Distribution, Tensor}; - - unary_scalar!(TestKernel, ops "*"); - unary_scalar_inplace!(TestKernelInplace, ops "*"); - - #[test] - fn unary_scalar_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([6, 256], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - - let actual = unary_scalar::(tensor.into_primitive(), 5.0); - let expected = tensor_ref.mul_scalar(5.0); - - expected.into_data().assert_approx_eq( - &Tensor::::from_primitive(actual).into_data(), - 3, - ); - } - - #[test] - fn unary_scalar_inplace_should_work_with_multiple_invocations() { - let tensor = Tensor::::random([6, 256], Distribution::Default); - let tensor_ref = Tensor::::from_data(tensor.to_data()); - - let actual = - unary_scalar_inplace::(tensor.into_primitive(), 5.0); - let expected = tensor_ref.mul_scalar(5.0); - - expected.into_data().assert_approx_eq( - &Tensor::::from_primitive(actual).into_data(), - 3, - ); - } -} diff --git a/burn-wgpu/src/lib.rs b/burn-wgpu/src/lib.rs index d04b282ed..72d022863 100644 --- a/burn-wgpu/src/lib.rs +++ b/burn-wgpu/src/lib.rs @@ -15,6 +15,8 @@ pub mod kernel; /// Tensor module. pub mod tensor; +pub(crate) mod codegen; + mod element; pub use element::{FloatElement, IntElement}; diff --git a/burn-wgpu/src/ops/activation_ops.rs b/burn-wgpu/src/ops/activation_ops.rs index 225662860..c318d3ea4 100644 --- a/burn-wgpu/src/ops/activation_ops.rs +++ b/burn-wgpu/src/ops/activation_ops.rs @@ -1,10 +1,8 @@ -use burn_tensor::ops::{ActivationOps, FloatTensor}; - use crate::{ element::{FloatElement, IntElement}, - kernel::{unary_default, unary_inplace_default}, - unary, unary_inplace, GraphicsApi, Wgpu, + GraphicsApi, Wgpu, }; +use burn_tensor::ops::ActivationOps; impl ActivationOps> for Wgpu where @@ -12,14 +10,4 @@ where F: FloatElement, I: IntElement, { - fn relu(tensor: FloatTensor) -> FloatTensor { - unary!(Relu, body "output[id] = max(input[id], 0.0);"); - unary_inplace!(ReluInplace, body "input[id] = max(input[id], 0.0);"); - - if tensor.can_mut() { - return unary_inplace_default::(tensor); - } - - unary_default::(tensor) - } } diff --git a/burn-wgpu/src/ops/float_ops.rs b/burn-wgpu/src/ops/float_ops.rs index 7e91b86ef..0c2d916ad 100644 --- a/burn-wgpu/src/ops/float_ops.rs +++ b/burn-wgpu/src/ops/float_ops.rs @@ -1,4 +1,5 @@ use super::numeric; +use crate::codegen::{Elem, Operator, Variable}; #[cfg(not(feature = "autotune"))] use crate::kernel::matmul::init_matmul_output; #[cfg(feature = "autotune")] @@ -8,18 +9,14 @@ use crate::kernel::matmul::vec4::matmul_tiling_2d_vec4; use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform}; #[cfg(not(feature = "autotune"))] use crate::kernel::reduce::init_reduce_output; -use crate::kernel::{ - self, reduce, unary_default, unary_inplace_default, unary_scalar_default, - unary_scalar_inplace_default, -}; -use crate::{unary, unary_inplace, unary_scalar, FloatElement, GraphicsApi, IntElement, Wgpu}; -use crate::{unary_scalar_inplace, WgpuDevice}; +use crate::kernel::{self, reduce}; +use crate::WgpuDevice; +use crate::{unary, FloatElement, GraphicsApi, IntElement, Wgpu}; use burn_tensor::ops::{ BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor, }; use burn_tensor::{ops::TensorOps, Data, Distribution, Shape}; use burn_tensor::{ElementConversion, Reader}; - use std::ops::Range; impl TensorOps> for Wgpu @@ -357,122 +354,115 @@ where kernel::cast(tensor) } - fn exp(lhs: FloatTensor) -> FloatTensor { - unary!(Exp, func "exp"); - unary_inplace!(ExpInplace, func "exp"); - - if lhs.can_mut() { - return unary_inplace_default::(lhs); - } - - unary_default::(lhs) + fn exp(tensor: FloatTensor) -> FloatTensor { + unary!( + operator: |elem: Elem| Operator::Exp { + input: Variable::Input(0, elem), + out: Variable::Local(0, elem), + }, + input: tensor, + elem: F + ) } fn log(tensor: FloatTensor) -> FloatTensor { - unary!(Log, func "log"); - unary_inplace!(LogInplace, func "log"); - - if tensor.can_mut() { - return unary_inplace_default::(tensor); - } - - unary_default::(tensor) + unary!( + operator: |elem: Elem| Operator::Log { + input: Variable::Input(0, elem), + out: Variable::Local(0, elem), + }, + input: tensor, + elem: F + ) } fn log1p(tensor: FloatTensor) -> FloatTensor { - unary!(Log1p, body "output[id] = log(1.0 + input[id]);"); - unary_inplace!(Log1pInplace, body "input[id] = log(1.0 + input[id]);"); - - if tensor.can_mut() { - return unary_inplace_default::(tensor); - } - - unary_default::(tensor) + unary!( + operator: |elem: Elem| Operator::Log1p { + input: Variable::Input(0, elem), + out: Variable::Local(0, elem), + }, + input: tensor, + elem: F + ) } fn powf(lhs: FloatTensor, rhs: f32) -> FloatTensor { - unary_scalar!(Powf, func "powf", include "../template/powf.wgsl"); - unary_scalar_inplace!(PowfInplace, func "powf", include "../template/powf.wgsl"); - - if lhs.can_mut() { - return unary_scalar_inplace_default::(lhs, rhs.elem()); - } - - unary_scalar_default::(lhs, rhs.elem()) + unary!( + operator: |elem: Elem| Operator::Powf { + lhs: Variable::Input(0, elem), + rhs: Variable::Scalar(0, elem), + out: Variable::Local(0, elem), + }, + input: lhs; rhs.elem(), + elem: F + ) } fn sqrt(tensor: FloatTensor) -> FloatTensor { - unary!(Sqrt, func "sqrt"); - unary_inplace!(SqrtInplace, func "sqrt"); - - if tensor.can_mut() { - return unary_inplace_default::(tensor); - } - - unary_default::(tensor) + unary!( + operator: |elem: Elem| Operator::Sqrt { + input: Variable::Input(0, elem), + out: Variable::Local(0, elem), + }, + input: tensor, + elem: F + ) } fn abs(tensor: FloatTensor) -> FloatTensor { - unary!(Abs, func "abs"); - unary_inplace!(AbsInplace, func "abs"); - - if tensor.can_mut() { - return unary_inplace_default::(tensor); - } - - unary_default::(tensor) + unary!( + operator: |elem: Elem| Operator::Abs { + input: Variable::Input(0, elem), + out: Variable::Local(0, elem), + }, + input: tensor, + elem: F + ) } fn cos(tensor: FloatTensor) -> FloatTensor { - unary!(Cos, func "cos"); - unary_inplace!(CosInplace, func "cos"); - - if tensor.can_mut() { - return unary_inplace_default::(tensor); - } - - unary_default::(tensor) + unary!( + operator: |elem: Elem| Operator::Cos { + input: Variable::Input(0, elem), + out: Variable::Local(0, elem), + }, + input: tensor, + elem: F + ) } fn sin(tensor: FloatTensor) -> FloatTensor { - unary!(Sin, func "sin"); - unary_inplace!(SinInplace, func "sin"); - - if tensor.can_mut() { - return unary_inplace_default::(tensor); - } - - unary_default::(tensor) + unary!( + operator: |elem: Elem| Operator::Sin { + input: Variable::Input(0, elem), + out: Variable::Local(0, elem), + }, + input: tensor, + elem: F + ) } fn tanh(tensor: FloatTensor) -> FloatTensor { - // Metal has a weird numerical behaviour with tanh which require a new function - #[cfg(target_os = "macos")] - unary!(Tanh, func "safe_tanh", include "../template/safe_tanh.wgsl"); - #[cfg(target_os = "macos")] - unary_inplace!(TanhInplace, func "safe_tanh", include "../template/safe_tanh.wgsl"); - - #[cfg(not(target_os = "macos"))] - unary!(Tanh, func "tanh"); - #[cfg(not(target_os = "macos"))] - unary_inplace!(TanhInplace, func "tanh"); - - if tensor.can_mut() { - return unary_inplace_default::(tensor); - } - - unary_default::(tensor) + unary!( + operator: |elem: Elem| Operator::Tanh { + input: Variable::Input(0, elem), + out: Variable::Local(0, elem), + }, + input: tensor, + elem: F + ) } fn erf(tensor: FloatTensor) -> FloatTensor { - unary!(Erf, func "erf", include "../template/erf.wgsl"); - unary_inplace!(ErfInplace, func "erf", include "../template/erf.wgsl"); - - if tensor.can_mut() { - return unary_inplace_default::(tensor); - } - - unary_default::(tensor) + unary!( + operator: |elem: Elem| Operator::Erf { + input: Variable::Input(0, elem), + out: Variable::Local(0, elem), + }, + input: tensor, + elem: F + ) } fn cat(tensors: Vec>, dim: usize) -> FloatTensor { @@ -491,20 +481,6 @@ where kernel::cast(tensor) } - fn clamp_min( - tensor: FloatTensor, - min: FloatElem, - ) -> FloatTensor { - kernel::clamp_min(tensor, min) - } - - fn clamp_max( - tensor: FloatTensor, - max: FloatElem, - ) -> FloatTensor { - kernel::clamp_max(tensor, max) - } - fn clamp( tensor: FloatTensor, min: FloatElem, @@ -516,14 +492,14 @@ where fn recip( tensor: FloatTensor, D>, ) -> FloatTensor, D> { - unary!(Recip, func "1.0 /"); - unary_inplace!(RecipInplace, func "1.0 /"); - - if tensor.can_mut() { - return unary_inplace_default::(tensor); - } - - unary_default::(tensor) + unary!( + operator: |elem: Elem| Operator::Recip { + input: Variable::Input(0, elem), + out: Variable::Local(0, elem), + }, + input: tensor, + elem: F + ) } fn repeat( diff --git a/burn-wgpu/src/ops/int_ops.rs b/burn-wgpu/src/ops/int_ops.rs index 9aa748036..5172789ae 100644 --- a/burn-wgpu/src/ops/int_ops.rs +++ b/burn-wgpu/src/ops/int_ops.rs @@ -1,10 +1,10 @@ use super::numeric; +use crate::codegen::{Elem, Operator, Variable}; use crate::kernel::reduce::{self, init_reduce_output}; -use crate::kernel::{unary_default, unary_inplace_default}; use crate::{ element::{FloatElement, IntElement}, - kernel, unary, unary_inplace, GraphicsApi, Wgpu, + kernel, unary, GraphicsApi, Wgpu, }; use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor}; @@ -280,20 +280,6 @@ where kernel::reduce::argmin(tensor, dim) } - fn int_clamp_min( - tensor: IntTensor, - min: IntElem, - ) -> IntTensor { - kernel::clamp_min(tensor, min) - } - - fn int_clamp_max( - tensor: IntTensor, - max: IntElem, - ) -> IntTensor { - kernel::clamp_max(tensor, max) - } - fn int_clamp( tensor: IntTensor, min: IntElem, @@ -303,14 +289,14 @@ where } fn int_abs(tensor: IntTensor) -> IntTensor { - unary!(IntAbs, func "abs"); - unary_inplace!(IntAbsInplace, func "abs"); - - if tensor.can_mut() { - return unary_inplace_default::(tensor); - } - - unary_default::(tensor) + unary!( + operator: |elem: Elem| Operator::Abs { + input: Variable::Input(0, elem), + out: Variable::Local(0, elem), + }, + input: tensor, + elem: I + ) } fn int_into_float(tensor: IntTensor) -> FloatTensor { diff --git a/burn-wgpu/src/ops/numeric.rs b/burn-wgpu/src/ops/numeric.rs index f4b57e83d..95fb80709 100644 --- a/burn-wgpu/src/ops/numeric.rs +++ b/burn-wgpu/src/ops/numeric.rs @@ -1,11 +1,8 @@ +use crate::codegen::{Elem, Operator, Variable}; use crate::compute::{compute_client, WgpuComputeClient}; -use crate::kernel::{ - binary_elemwise_default, binary_elemwise_inplace_default, unary_scalar_default, - unary_scalar_inplace_default, -}; +use crate::kernel::{binary_elemwise_default, binary_elemwise_inplace_default}; use crate::{ - binary_elemwise, binary_elemwise_inplace, element::WgpuElement, tensor::WgpuTensor, - unary_scalar, unary_scalar_inplace, + binary_elemwise, binary_elemwise_inplace, element::WgpuElement, tensor::WgpuTensor, unary, }; use crate::{GraphicsApi, WgpuDevice}; use burn_tensor::{Element, ElementConversion, Shape}; @@ -28,8 +25,14 @@ pub fn full_device( ) -> WgpuTensor { let empty = empty_device(client, device, shape); - unary_scalar_inplace!(Full, body "lhs[id] = rhs;"); - unary_scalar_inplace_default::(empty, value) + unary!( + operator: |elem: Elem| Operator::AssignLocal { + input: Variable::Scalar(0, elem), + out: Variable::Local(0, elem), + }, + input: empty; value, + elem: E + ) } pub fn zeros( @@ -98,14 +101,15 @@ pub fn add_scalar( lhs: WgpuTensor, rhs: E, ) -> WgpuTensor { - unary_scalar!(AddScalar, ops "+"); - unary_scalar_inplace!(AddScalarInplace, ops "+"); - - if lhs.can_mut() { - return unary_scalar_inplace_default::(lhs, rhs); - } - - unary_scalar_default::(lhs, rhs) + unary!( + operator: |elem: Elem| Operator::Add { + lhs: Variable::Input(0, elem), + rhs: Variable::Scalar(0, elem), + out: Variable::Local(0, elem), + }, + input: lhs; rhs, + elem: E + ) } pub fn sub( @@ -126,14 +130,15 @@ pub fn sub_scalar( lhs: WgpuTensor, rhs: E, ) -> WgpuTensor { - unary_scalar!(SubScalar, ops "-"); - unary_scalar_inplace!(SubScalarInplace, ops "-"); - - if lhs.can_mut() { - return unary_scalar_inplace_default::(lhs, rhs); - } - - unary_scalar_default::(lhs, rhs) + unary!( + operator: |elem: Elem| Operator::Sub { + lhs: Variable::Input(0, elem), + rhs: Variable::Scalar(0, elem), + out: Variable::Local(0, elem), + }, + input: lhs; rhs, + elem: E + ) } pub fn mul( @@ -158,14 +163,15 @@ pub fn mul_scalar( lhs: WgpuTensor, rhs: E, ) -> WgpuTensor { - unary_scalar!(MulScalar, ops "*"); - unary_scalar_inplace!(MulScalarInplace, ops "*"); - - if lhs.can_mut() { - return unary_scalar_inplace_default::(lhs, rhs); - } - - unary_scalar_default::(lhs, rhs) + unary!( + operator: |elem: Elem| Operator::Mul { + lhs: Variable::Input(0, elem), + rhs: Variable::Scalar(0, elem), + out: Variable::Local(0, elem), + }, + input: lhs; rhs, + elem: E + ) } pub fn div( @@ -186,12 +192,13 @@ pub fn div_scalar( lhs: WgpuTensor, rhs: E, ) -> WgpuTensor { - unary_scalar!(DivScalar, ops "/"); - unary_scalar_inplace!(DivScalarInplace, ops "/"); - - if lhs.can_mut() { - return unary_scalar_inplace_default::(lhs, rhs); - } - - unary_scalar_default::(lhs, rhs) + unary!( + operator: |elem: Elem| Operator::Div { + lhs: Variable::Input(0, elem), + rhs: Variable::Scalar(0, elem), + out: Variable::Local(0, elem), + }, + input: lhs; rhs, + elem: E + ) } diff --git a/burn-wgpu/src/template/clamp/clamp.wgsl b/burn-wgpu/src/template/clamp/clamp.wgsl deleted file mode 100644 index 100278b5b..000000000 --- a/burn-wgpu/src/template/clamp/clamp.wgsl +++ /dev/null @@ -1,25 +0,0 @@ -@group(0) -@binding(0) -var input: array<{{ elem }}>; - -@group(0) -@binding(1) -var output: array<{{ elem }}>; - -@group(0) -@binding(2) -var min_value: {{ elem }}; - -@group(0) -@binding(3) -var max_value: {{ elem }}; - -@compute -@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1) -fn main( - @builtin(global_invocation_id) global_id: vec3, - @builtin(num_workgroups) num_workgroups: vec3, -) { - let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x; - output[id] = clamp(input[id], min_value, max_value); -} \ No newline at end of file diff --git a/burn-wgpu/src/template/clamp/clamp_inplace.wgsl b/burn-wgpu/src/template/clamp/clamp_inplace.wgsl deleted file mode 100644 index 04ac1409b..000000000 --- a/burn-wgpu/src/template/clamp/clamp_inplace.wgsl +++ /dev/null @@ -1,21 +0,0 @@ -@group(0) -@binding(0) -var input: array<{{ elem }}>; - -@group(0) -@binding(1) -var min_value: {{ elem }}; - -@group(0) -@binding(2) -var max_value: {{ elem }}; - -@compute -@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1) -fn main( - @builtin(global_invocation_id) global_id: vec3, - @builtin(num_workgroups) num_workgroups: vec3, -) { - let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x; - input[id] = clamp(input[id], min_value, max_value); -} diff --git a/burn-wgpu/src/template/erf.wgsl b/burn-wgpu/src/template/erf.wgsl deleted file mode 100644 index ef6c940e8..000000000 --- a/burn-wgpu/src/template/erf.wgsl +++ /dev/null @@ -1,25 +0,0 @@ -/// An approximation of the error function: https://en.wikipedia.org/wiki/Error_function#Numerical_approximations -/// -/// > (maximum error: 1.5×10−7) -/// > All of these approximations are valid for x ≥ 0. To use these approximations for negative x, use the fact that erf x is an odd function, so erf x = −erf(−x). -fn erf_positive(x: {{ elem }}) -> {{ elem }} { - let p = 0.3275911; - let a1 = 0.254829592; - let a2 = -0.284496736; - let a3 = 1.421413741; - let a4 = -1.453152027; - let a5 = 1.061405429; - - let t = 1.0 / (1.0 + p * abs(x)); - let tmp = ((((a5 * t + a4) * t) + a3) * t + a2) * t + a1; - - return 1.0 - (tmp * t * exp(-x * x)); -} - -fn erf(x: {{ elem }}) -> {{ elem }} { - if (x < 0.0) { - return -1.0 * erf_positive(-1.0 * x); - } - - return erf_positive(x); -} diff --git a/burn-wgpu/src/template/powf.wgsl b/burn-wgpu/src/template/powf.wgsl deleted file mode 100644 index 82317a4c5..000000000 --- a/burn-wgpu/src/template/powf.wgsl +++ /dev/null @@ -1,14 +0,0 @@ -fn powf(lhs: {{ elem }}, rhs: {{ elem }}) -> {{ elem }} { - let modulo = rhs % 2.0; - - if (modulo == 0.0) { - // Even number - return pow(abs(lhs), rhs); - } else if (modulo == 1.0 && lhs < 0.0) { - // Odd number - return -1.0 * pow(-1.0 * lhs, rhs); - } else { - // Float number - return pow(lhs, rhs); - } -} diff --git a/burn-wgpu/src/template/safe_tanh.wgsl b/burn-wgpu/src/template/safe_tanh.wgsl deleted file mode 100644 index b41e728c0..000000000 --- a/burn-wgpu/src/template/safe_tanh.wgsl +++ /dev/null @@ -1,8 +0,0 @@ -/// Metal has a weird numerical behaviour with tanh for inputs over 43.0 -fn safe_tanh(x: {{ elem }}) -> {{ elem }} { - if x > 43.0 { - return 1.0; - } else { - return tanh(x); - } -} diff --git a/burn-wgpu/src/template/unary.wgsl b/burn-wgpu/src/template/unary.wgsl deleted file mode 100644 index 06770cdd3..000000000 --- a/burn-wgpu/src/template/unary.wgsl +++ /dev/null @@ -1,17 +0,0 @@ -@group(0) -@binding(0) -var input: array<{{ elem }}>; - -@group(0) -@binding(1) -var output: array<{{ elem }}>; - -@compute -@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1) -fn main( - @builtin(global_invocation_id) global_id: vec3, - @builtin(num_workgroups) num_workgroups: vec3, -) { - let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x; - {{ body }} -} diff --git a/burn-wgpu/src/template/unary_inplace.wgsl b/burn-wgpu/src/template/unary_inplace.wgsl deleted file mode 100644 index 862d3d135..000000000 --- a/burn-wgpu/src/template/unary_inplace.wgsl +++ /dev/null @@ -1,13 +0,0 @@ -@group(0) -@binding(0) -var input: array<{{ elem }}>; - -@compute -@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1) -fn main( - @builtin(global_invocation_id) global_id: vec3, - @builtin(num_workgroups) num_workgroups: vec3, -) { - let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x; - {{ body }} -} diff --git a/burn-wgpu/src/template/unary_scalar.wgsl b/burn-wgpu/src/template/unary_scalar.wgsl deleted file mode 100644 index 1e6719a70..000000000 --- a/burn-wgpu/src/template/unary_scalar.wgsl +++ /dev/null @@ -1,21 +0,0 @@ -@group(0) -@binding(0) -var lhs: array<{{ elem }}>; - -@group(0) -@binding(1) -var rhs: {{ elem }}; - -@group(0) -@binding(2) -var output: array<{{ elem }}>; - -@compute -@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1) -fn main( - @builtin(global_invocation_id) global_id: vec3, - @builtin(num_workgroups) num_workgroups: vec3, -) { - let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x; - {{ body }} -} diff --git a/burn-wgpu/src/template/unary_scalar_inplace.wgsl b/burn-wgpu/src/template/unary_scalar_inplace.wgsl deleted file mode 100644 index 46e87c954..000000000 --- a/burn-wgpu/src/template/unary_scalar_inplace.wgsl +++ /dev/null @@ -1,17 +0,0 @@ -@group(0) -@binding(0) -var lhs: array<{{ elem }}>; - -@group(0) -@binding(1) -var rhs: {{ elem }}; - -@compute -@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1) -fn main( - @builtin(global_invocation_id) global_id: vec3, - @builtin(num_workgroups) num_workgroups: vec3, -) { - let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x; - {{ body }} -} diff --git a/burn-wgpu/src/tensor/base.rs b/burn-wgpu/src/tensor/base.rs index f939a0fcd..b88fdf69e 100644 --- a/burn-wgpu/src/tensor/base.rs +++ b/burn-wgpu/src/tensor/base.rs @@ -1,8 +1,9 @@ +use crate::codegen::{Elem, Operator, Variable}; +use crate::element::WgpuElement; use crate::{ compute::{WgpuComputeClient, WgpuHandle}, unary, WgpuDevice, }; -use crate::{element::WgpuElement, kernel::unary_default}; use burn_tensor::Shape; use std::marker::PhantomData; @@ -96,8 +97,14 @@ impl WgpuTensor { // slowdowns. // // The solution is just to use a simple unary compute shader. - unary!(CopyBuffer, body "output[id] = input[id];"); - unary_default::(self.clone()) + unary!( + operator: |elem: Elem| Operator::AssignLocal { + input: Variable::Input(0, elem), + out: Variable::Local(0, elem), + }, + input: self.clone(), + elem: E + ) } /// Check if the tensor is safe to mutate.