mirror of https://github.com/tracel-ai/burn.git
Fusion wgpu compilation cache (#1069)
* Refactor fusion in the wgpu backend * WIP * Refactor * WIP * Fix inplace ops * Works ish * Cleanup Output * Refactoring * Refactor Clamp * Cleanup * Cleanup * Updates * Fix CI * Code review
This commit is contained in:
parent
042454a9db
commit
b5c49c5bf7
|
@ -96,7 +96,7 @@ pub trait OptimizationBuilder<B: FusionBackend>: Send {
|
|||
/// The operation created from the [builder](OptimizationBuilder).
|
||||
pub trait Optimization<B: FusionBackend>: Send {
|
||||
/// Execute the operation.
|
||||
fn execute(&self, context: &mut Context<'_, B>);
|
||||
fn execute(&mut self, context: &mut Context<'_, B>);
|
||||
/// The number of registered operations in this optimization.
|
||||
fn len(&self) -> usize;
|
||||
/// If the current optimization is empty.
|
||||
|
|
|
@ -53,7 +53,7 @@ impl<B: FusionBackend> Graph<B> {
|
|||
pub(crate) fn execute_optimization(
|
||||
&mut self,
|
||||
handles: &mut HandleContainer<B>,
|
||||
optimization: &dyn Optimization<B>,
|
||||
optimization: &mut dyn Optimization<B>,
|
||||
) {
|
||||
let num_keep = optimization.len();
|
||||
let mut context = self.converter.context(handles);
|
||||
|
|
|
@ -682,20 +682,6 @@ impl<E: Element> NumericOpsDescription<E> {
|
|||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::ClampMax(desc) => {
|
||||
NumericOpsDescription::ClampMax(ScalarOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: local_elem(converter, &desc.rhs),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
NumericOpsDescription::ClampMin(desc) => {
|
||||
NumericOpsDescription::ClampMin(ScalarOpsDescription {
|
||||
lhs: desc.lhs.to_relative(converter),
|
||||
rhs: local_elem(converter, &desc.rhs),
|
||||
out: desc.out.to_relative(converter),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -71,7 +71,7 @@ impl<B: FusionBackend> GraphExecution<B> {
|
|||
};
|
||||
}
|
||||
CacheResult::Found(ops) => {
|
||||
graph.execute_optimization(handles, ops.as_ref());
|
||||
graph.execute_optimization(handles, ops.as_mut());
|
||||
self.reset(graph);
|
||||
}
|
||||
};
|
||||
|
@ -107,14 +107,14 @@ impl<B: FusionBackend> GraphExecution<B> {
|
|||
}
|
||||
}
|
||||
|
||||
match find_best_optimization_index(&self.optimizations) {
|
||||
match find_best_optimization_index(&mut self.optimizations) {
|
||||
Some(index) => {
|
||||
let (relative, next_ops) = Self::split_relative_graph_owned(graph, mode);
|
||||
let optimization = &self.optimizations[index];
|
||||
let ops = self
|
||||
.optimization_cache
|
||||
.complete(optimization, relative, next_ops);
|
||||
BuildAction::ExecuteOptimization(ops.as_ref())
|
||||
BuildAction::ExecuteOptimization(ops.as_mut())
|
||||
}
|
||||
None => {
|
||||
// TODO: Cache this result too.
|
||||
|
@ -184,7 +184,7 @@ impl<B: FusionBackend> GraphExecution<B> {
|
|||
}
|
||||
|
||||
enum BuildAction<'a, B: FusionBackend> {
|
||||
ExecuteOptimization(&'a dyn Optimization<B>),
|
||||
ExecuteOptimization(&'a mut dyn Optimization<B>),
|
||||
ExecuteOperations,
|
||||
ContinueBuilding,
|
||||
}
|
||||
|
@ -202,7 +202,7 @@ fn still_optimizing<B: FusionBackend>(optimizations: &[Box<dyn OptimizationBuild
|
|||
}
|
||||
|
||||
fn find_best_optimization_index<B: FusionBackend>(
|
||||
optimizations: &[Box<dyn OptimizationBuilder<B>>],
|
||||
optimizations: &mut [Box<dyn OptimizationBuilder<B>>],
|
||||
) -> Option<usize> {
|
||||
let mut best_index = None;
|
||||
let mut best_score = 0;
|
||||
|
|
|
@ -379,16 +379,6 @@ pub enum NumericOpsDescription<E> {
|
|||
/// Float => [clamp](burn_tensor::ops::TensorOps::clamp).
|
||||
/// Int => [clamp](burn_tensor::ops::IntTensorOps::int_clamp).
|
||||
Clamp(ClampOpsDescription<E>),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [clamp max](burn_tensor::ops::TensorOps::clamp_max).
|
||||
/// Int => [clamp max](burn_tensor::ops::IntTensorOps::int_clamp_max).
|
||||
ClampMax(ScalarOpsDescription<E>),
|
||||
/// Operation corresponding to:
|
||||
///
|
||||
/// Float => [clamp min](burn_tensor::ops::TensorOps::clamp_min).
|
||||
/// Int => [cleamp min](burn_tensor::ops::IntTensorOps::int_clamp_min).
|
||||
ClampMin(ScalarOpsDescription<E>),
|
||||
}
|
||||
|
||||
/// Operation description specific to an int tensor.
|
||||
|
@ -900,12 +890,6 @@ impl<E: Element> NumericOpsDescription<E> {
|
|||
NumericOpsDescription::Clamp(desc) => {
|
||||
vec![&desc.tensor, &desc.out]
|
||||
}
|
||||
NumericOpsDescription::ClampMin(desc) => {
|
||||
vec![&desc.lhs, &desc.out]
|
||||
}
|
||||
NumericOpsDescription::ClampMax(desc) => {
|
||||
vec![&desc.lhs, &desc.out]
|
||||
}
|
||||
NumericOpsDescription::Abs(desc) => {
|
||||
vec![&desc.input, &desc.out]
|
||||
}
|
||||
|
@ -1144,8 +1128,6 @@ impl<E> core::hash::Hash for NumericOpsDescription<E> {
|
|||
NumericOpsDescription::MaxDim(desc) => desc.hash(state),
|
||||
NumericOpsDescription::MinDim(desc) => desc.hash(state),
|
||||
NumericOpsDescription::Clamp(desc) => desc.hash(state),
|
||||
NumericOpsDescription::ClampMax(desc) => desc.hash(state),
|
||||
NumericOpsDescription::ClampMin(desc) => desc.hash(state),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -60,16 +60,13 @@ impl<O> OptimizationCache<O> {
|
|||
}
|
||||
|
||||
if let Some(candidate) = self.found {
|
||||
return CacheResult::Found(&self.optimizations.get(candidate).unwrap().value);
|
||||
return CacheResult::Found(&mut self.optimizations.get_mut(candidate).unwrap().value);
|
||||
}
|
||||
|
||||
// Invalidate candidates.
|
||||
let mut invalidated_candidate = Vec::new();
|
||||
for id in self.candidates.iter() {
|
||||
let item = match self.optimizations.get(*id) {
|
||||
Some(item) => item,
|
||||
None => panic!("Should have an optimization"),
|
||||
};
|
||||
let item = &self.optimizations[*id];
|
||||
let next_ops = graph.last().expect("Validated earlier");
|
||||
let next_ops_index = graph.len() - 1;
|
||||
let next_ops_candidate = match item.graph.get(next_ops_index) {
|
||||
|
@ -93,13 +90,13 @@ impl<O> OptimizationCache<O> {
|
|||
Condition::NextOps(ops) => ops,
|
||||
Condition::Sync => {
|
||||
self.found = Some(*id);
|
||||
return CacheResult::Found(&item.value);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
if item.end_conditions.contains(ops) {
|
||||
self.found = Some(*id);
|
||||
return CacheResult::Found(&item.value);
|
||||
break;
|
||||
} else {
|
||||
self.availables.push((*id, graph.len()));
|
||||
invalidated_candidate.push(*id);
|
||||
|
@ -107,6 +104,10 @@ impl<O> OptimizationCache<O> {
|
|||
}
|
||||
}
|
||||
|
||||
if let Some(id) = self.found {
|
||||
return CacheResult::Found(&mut self.optimizations[id].value);
|
||||
}
|
||||
|
||||
let mut updated_candidates = Vec::new();
|
||||
core::mem::swap(&mut updated_candidates, &mut self.candidates);
|
||||
|
||||
|
@ -136,7 +137,7 @@ impl<O> OptimizationCache<O> {
|
|||
factory: &Factory,
|
||||
graph: Vec<TensorOpsDescription>,
|
||||
next_ops: Option<TensorOpsDescription>,
|
||||
) -> &'a O {
|
||||
) -> &'a mut O {
|
||||
let existing_optim = self
|
||||
.availables
|
||||
.iter()
|
||||
|
@ -149,7 +150,7 @@ impl<O> OptimizationCache<O> {
|
|||
optimization.end_conditions.push(ops)
|
||||
};
|
||||
|
||||
return &optimization.value;
|
||||
return &mut optimization.value;
|
||||
};
|
||||
|
||||
self.starters
|
||||
|
@ -164,7 +165,9 @@ impl<O> OptimizationCache<O> {
|
|||
};
|
||||
|
||||
self.optimizations.push(optimization);
|
||||
&self.optimizations.last().unwrap().value
|
||||
|
||||
let last_index = self.optimizations.len() - 1;
|
||||
&mut self.optimizations[last_index].value
|
||||
}
|
||||
|
||||
// Signal that a new path will begin.
|
||||
|
@ -188,7 +191,7 @@ pub enum CacheResult<'a, T> {
|
|||
/// happens.
|
||||
OnPath,
|
||||
/// An optimization has been found, and the best action is to execute it!
|
||||
Found(&'a T),
|
||||
Found(&'a mut T),
|
||||
}
|
||||
|
||||
/// When checking if an optimization is possible, a start or an end condition ensures that this optimization is
|
||||
|
|
|
@ -265,48 +265,6 @@ impl<B: FusionBackend> TensorOps<Self> for Fusion<B> {
|
|||
out
|
||||
}
|
||||
|
||||
fn clamp_min<const D: usize>(
|
||||
tensor: FloatTensor<Self, D>,
|
||||
min: FloatElem<Self>,
|
||||
) -> FloatTensor<Self, D> {
|
||||
scalar_float_ops!(ClampMinOps, B::clamp_min);
|
||||
|
||||
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
|
||||
|
||||
let desc = ScalarOpsDescription {
|
||||
lhs: tensor.into_description(),
|
||||
rhs: min.elem(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsFloat(NumericOpsDescription::ClampMin(desc.clone())),
|
||||
ClampMinOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn clamp_max<const D: usize>(
|
||||
tensor: FloatTensor<Self, D>,
|
||||
max: FloatElem<Self>,
|
||||
) -> FloatTensor<Self, D> {
|
||||
scalar_float_ops!(ClampMaxOps, B::clamp_max);
|
||||
|
||||
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
|
||||
|
||||
let desc = ScalarOpsDescription {
|
||||
lhs: tensor.into_description(),
|
||||
rhs: max.elem(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsFloat(NumericOpsDescription::ClampMax(desc.clone())),
|
||||
ClampMaxOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn clamp<const D: usize>(
|
||||
tensor: FloatTensor<Self, D>,
|
||||
min: FloatElem<Self>,
|
||||
|
|
|
@ -1034,48 +1034,6 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
|
|||
out
|
||||
}
|
||||
|
||||
fn int_clamp_min<const D: usize>(
|
||||
tensor: IntTensor<Self, D>,
|
||||
min: IntElem<Self>,
|
||||
) -> IntTensor<Self, D> {
|
||||
scalar_int_ops!(ClampMinOps, B::int_clamp_min);
|
||||
|
||||
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
|
||||
|
||||
let desc = ScalarOpsDescription {
|
||||
lhs: tensor.into_description(),
|
||||
rhs: min.elem(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::ClampMin(desc.clone())),
|
||||
ClampMinOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn int_clamp_max<const D: usize>(
|
||||
tensor: IntTensor<Self, D>,
|
||||
max: IntElem<Self>,
|
||||
) -> IntTensor<Self, D> {
|
||||
scalar_int_ops!(ClampMaxOps, B::int_clamp_max);
|
||||
|
||||
let out = tensor.client.tensor_uninitialized(tensor.shape.clone());
|
||||
|
||||
let desc = ScalarOpsDescription {
|
||||
lhs: tensor.into_description(),
|
||||
rhs: max.elem(),
|
||||
out: out.to_description_out(),
|
||||
};
|
||||
out.client.register(
|
||||
TensorOpsDescription::NumericOpsInt(NumericOpsDescription::ClampMax(desc.clone())),
|
||||
ClampMaxOps::<D>::new(desc),
|
||||
);
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
fn int_clamp<const D: usize>(
|
||||
tensor: IntTensor<Self, D>,
|
||||
min: IntElem<Self>,
|
||||
|
|
|
@ -5,7 +5,7 @@ use std::fmt::Display;
|
|||
///
|
||||
/// Note that the body assumes that the kernel will run on a 2D grid defined by the workgroup size
|
||||
/// X and Y, but with Z=1.
|
||||
#[derive(Hash, new)]
|
||||
#[derive(new)]
|
||||
pub struct Body {
|
||||
operators: Vec<Operator>,
|
||||
}
|
|
@ -2,7 +2,7 @@ use super::Elem;
|
|||
use std::fmt::Display;
|
||||
|
||||
/// Not all functions are native to WGSL, so this struct allows to support more functions.
|
||||
#[derive(Hash, PartialEq, Eq, Clone)]
|
||||
#[derive(PartialEq, Eq, Clone)]
|
||||
pub enum Function {
|
||||
Powf(Elem),
|
||||
Erf(Elem),
|
|
@ -0,0 +1,359 @@
|
|||
use crate::codegen::{
|
||||
Binding, Body, ComputeShader, Elem, Function, Location, Operator, Variable, Visibility,
|
||||
WorkgroupSize,
|
||||
};
|
||||
use crate::compute::{StaticKernel, WgpuComputeClient, WgpuHandle};
|
||||
use crate::element::WgpuElement;
|
||||
use crate::kernel::{elemwise_workgroup, StaticKernelSource, WORKGROUP_DEFAULT};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// Kernel creation input phase, see [kernel codegen](ElemWiseKernelCodegen) for more details.
|
||||
pub struct InputPhase;
|
||||
/// Kernel creation body phase, see [kernel codegen](ElemWiseKernelCodegen) for more details.
|
||||
pub struct BodyPhase;
|
||||
/// Kernel creation output phase, see [kernel codegen](ElemWiseKernelCodegen) for more details.
|
||||
pub struct OutputPhase;
|
||||
/// Kernel compilation phase, see [kernel codegen](ElemWiseKernelCodegen) for more details.
|
||||
pub struct CompilationPhase;
|
||||
|
||||
/// Allows to create custom wgsl kernels based on configured inputs, body and outputs.
|
||||
///
|
||||
/// This type has 4 phases that must be executed in order, but no worry the type system won't allow
|
||||
/// you to make mistakes.
|
||||
///
|
||||
/// 1. [Input Phase](InputPhase)
|
||||
/// This phase focuses on registering the input arrays and scalars that are going to be used by
|
||||
/// the kernel.
|
||||
/// 2. [Body Phase](BodyPhase)
|
||||
/// After the input phase is done, all the operations that happen in the body must be
|
||||
/// registered.
|
||||
/// 3. [Output Phase](OutputPhase)
|
||||
/// This step focuses on registering all output arrays or inputs that the kernel needs to write to.
|
||||
/// 4. [Compilation Phase](CompilationPhase)
|
||||
/// Now that all other phases are completed, we can actually compile the kernel.
|
||||
pub struct ElemWiseKernelCodegen<Phase = InputPhase> {
|
||||
operations: Vec<Operator>,
|
||||
input_bindings: Vec<Binding>,
|
||||
output_bindings: Vec<Binding>,
|
||||
named_bindings: Vec<(String, Binding)>,
|
||||
functions: Vec<Function>,
|
||||
_phase: PhantomData<Phase>,
|
||||
}
|
||||
|
||||
pub enum Input {
|
||||
Array {
|
||||
elem: Elem,
|
||||
visibility: Visibility,
|
||||
strategy: ReadingStrategy,
|
||||
},
|
||||
Scalar {
|
||||
elem: Elem,
|
||||
size: usize,
|
||||
},
|
||||
}
|
||||
|
||||
pub enum ReadingStrategy {
|
||||
IntoContiguous,
|
||||
Plain,
|
||||
}
|
||||
|
||||
pub enum Output {
|
||||
Array { elem: Elem, local: u16 },
|
||||
Input { elem: Elem, input: u16, local: u16 },
|
||||
}
|
||||
|
||||
impl ElemWiseKernelCodegen<InputPhase> {
|
||||
/// Create a new fusion kernel on the given device.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
operations: Vec::new(),
|
||||
input_bindings: Vec::new(),
|
||||
output_bindings: Vec::new(),
|
||||
named_bindings: Vec::new(),
|
||||
functions: Vec::new(),
|
||||
_phase: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register the inputs used by the kernel.
|
||||
pub fn inputs(mut self, inputs: &[Input]) -> ElemWiseKernelCodegen<BodyPhase> {
|
||||
let mut index: u16 = 0;
|
||||
|
||||
let first_output_index = inputs
|
||||
.iter()
|
||||
.filter(|input| match input {
|
||||
Input::Array {
|
||||
elem: _,
|
||||
visibility: _,
|
||||
strategy: _,
|
||||
} => true,
|
||||
Input::Scalar { elem: _, size: _ } => false,
|
||||
})
|
||||
.count();
|
||||
|
||||
for input in inputs {
|
||||
match input {
|
||||
Input::Array {
|
||||
elem,
|
||||
visibility,
|
||||
strategy,
|
||||
} => {
|
||||
self.input_bindings.push(Binding {
|
||||
elem: bool_elem(*elem),
|
||||
visibility: *visibility,
|
||||
location: Location::Storage,
|
||||
size: None,
|
||||
});
|
||||
|
||||
match strategy {
|
||||
ReadingStrategy::IntoContiguous => {
|
||||
self.operations.push(Operator::ReadGlobalIntoContiguous {
|
||||
variable: Variable::Input(index, *elem),
|
||||
position: index as usize,
|
||||
position_out: first_output_index, // First output
|
||||
});
|
||||
}
|
||||
ReadingStrategy::Plain => {
|
||||
self.operations.push(Operator::ReadGlobal {
|
||||
variable: Variable::Input(index, *elem),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
index += 1;
|
||||
}
|
||||
Input::Scalar { elem, size } => {
|
||||
let elem = bool_elem(*elem);
|
||||
|
||||
self.named_bindings.push((
|
||||
format!("scalars_{}", elem),
|
||||
Binding {
|
||||
elem,
|
||||
visibility: Visibility::Read,
|
||||
location: Location::Storage,
|
||||
size: Some(*size),
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ElemWiseKernelCodegen {
|
||||
operations: self.operations,
|
||||
input_bindings: self.input_bindings,
|
||||
output_bindings: self.output_bindings,
|
||||
named_bindings: self.named_bindings,
|
||||
functions: self.functions,
|
||||
_phase: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ElemWiseKernelCodegen<BodyPhase> {
|
||||
/// Register the [operators](Operator) that the kernel must execute in the order provided.
|
||||
pub fn body(mut self, operators: &[Operator]) -> ElemWiseKernelCodegen<OutputPhase> {
|
||||
let mut register_function = |function: Function| {
|
||||
if !self.functions.contains(&function) {
|
||||
self.functions.push(function);
|
||||
}
|
||||
};
|
||||
|
||||
// Since not all operators are native to WGSL, we need to add the custom ones.
|
||||
for ops in operators.iter() {
|
||||
match ops {
|
||||
Operator::Powf {
|
||||
lhs: _,
|
||||
rhs: _,
|
||||
out: _,
|
||||
} => {
|
||||
register_function(Function::Powf(Elem::F32));
|
||||
}
|
||||
Operator::Erf { input: _, out: _ } => {
|
||||
register_function(Function::Erf(Elem::F32));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
self.operations.push(ops.clone());
|
||||
}
|
||||
|
||||
ElemWiseKernelCodegen {
|
||||
operations: self.operations,
|
||||
input_bindings: self.input_bindings,
|
||||
output_bindings: self.output_bindings,
|
||||
named_bindings: self.named_bindings,
|
||||
functions: self.functions,
|
||||
_phase: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ElemWiseKernelCodegen<OutputPhase> {
|
||||
/// Register the outputs with their local variable index.
|
||||
///
|
||||
/// Note that the index corresponds to the registered [operator](Operator) number at the
|
||||
/// [body phase](BodyPhase).
|
||||
/// So the 4th operator registered creates the local variable 3 (N-1, since the 1th index is 0).
|
||||
pub fn outputs(mut self, outputs: &[Output]) -> ElemWiseKernelCodegen<CompilationPhase> {
|
||||
let mut index = 0;
|
||||
|
||||
for array in outputs {
|
||||
match array {
|
||||
Output::Array { elem, local } => {
|
||||
let elem_adapted = bool_elem(*elem);
|
||||
|
||||
self.output_bindings.push(Binding {
|
||||
elem: elem_adapted,
|
||||
visibility: Visibility::ReadWrite,
|
||||
location: Location::Storage,
|
||||
size: None,
|
||||
});
|
||||
self.operations.push(Operator::AssignGlobal {
|
||||
input: Variable::Local(*local, *elem),
|
||||
out: Variable::Output(index, elem_adapted),
|
||||
});
|
||||
index += 1;
|
||||
}
|
||||
Output::Input { elem, input, local } => {
|
||||
self.operations.push(Operator::AssignGlobal {
|
||||
input: Variable::Local(*local, *elem),
|
||||
out: Variable::Input(*input, bool_elem(*elem)),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ElemWiseKernelCodegen {
|
||||
operations: self.operations,
|
||||
input_bindings: self.input_bindings,
|
||||
output_bindings: self.output_bindings,
|
||||
named_bindings: self.named_bindings,
|
||||
functions: self.functions,
|
||||
_phase: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ElemWiseKernelCodegen<CompilationPhase> {
|
||||
/// Compile the kernel into a [compute shader](ComputeShader).
|
||||
pub fn compile(self) -> ComputeShader {
|
||||
let inputs = self.input_bindings;
|
||||
let outputs = self.output_bindings;
|
||||
let mut named = Vec::with_capacity(2);
|
||||
|
||||
named.push((
|
||||
"info".to_string(),
|
||||
Binding {
|
||||
elem: Elem::U32,
|
||||
visibility: Visibility::Read,
|
||||
location: Location::Storage,
|
||||
size: None, // We avoid putting the length here since it will force a new kernel
|
||||
// for each tensor rank.
|
||||
},
|
||||
));
|
||||
|
||||
for (name, binding) in self.named_bindings.into_iter() {
|
||||
named.push((name, binding));
|
||||
}
|
||||
|
||||
ComputeShader {
|
||||
inputs,
|
||||
outputs,
|
||||
named,
|
||||
workgroup_size: WorkgroupSize::default(),
|
||||
body: Body::new(self.operations),
|
||||
num_workgroups: true,
|
||||
global_invocation_id: true,
|
||||
functions: self.functions,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub struct StaticHandle<'a> {
|
||||
handle: &'a WgpuHandle,
|
||||
strides: &'a [usize],
|
||||
shape: &'a [usize],
|
||||
}
|
||||
|
||||
/// Execute a static kernel.
|
||||
///
|
||||
///
|
||||
/// The limitation from this method is that you can't launch a kernel with multiple types of
|
||||
/// scalar.
|
||||
pub fn execute_static<K, E: WgpuElement>(
|
||||
inputs: &[StaticHandle],
|
||||
outputs: &[StaticHandle],
|
||||
scalar_elems: Option<&[E]>,
|
||||
client: WgpuComputeClient,
|
||||
) where
|
||||
K: StaticKernelSource + 'static,
|
||||
{
|
||||
let mut info = Vec::new();
|
||||
let mut handles = Vec::with_capacity(inputs.len() + outputs.len() + 2);
|
||||
|
||||
// Inner function to fill the info buffer.
|
||||
let mut register_info_tensor = |strides: &[usize], shape: &[usize]| {
|
||||
if info.is_empty() {
|
||||
info.push(strides.len() as u32);
|
||||
}
|
||||
|
||||
for s in strides.iter() {
|
||||
info.push(*s as u32);
|
||||
}
|
||||
for s in shape.iter() {
|
||||
info.push(*s as u32);
|
||||
}
|
||||
};
|
||||
|
||||
// We start by registering the inputs.
|
||||
for input in inputs.iter() {
|
||||
register_info_tensor(input.strides, input.shape);
|
||||
handles.push(input.handle);
|
||||
}
|
||||
|
||||
let mut num_elems_output = 0;
|
||||
|
||||
// Then we follow with the outputs.
|
||||
for output in outputs.iter() {
|
||||
let num_elems = calculate_num_elems_dyn_rank(output.shape);
|
||||
if num_elems > num_elems_output {
|
||||
num_elems_output = num_elems;
|
||||
}
|
||||
register_info_tensor(output.strides, output.shape);
|
||||
handles.push(output.handle);
|
||||
}
|
||||
|
||||
let info = &client.create(bytemuck::cast_slice(&info));
|
||||
handles.push(info);
|
||||
|
||||
// Finally we finish with the named bindings.
|
||||
let mut scalars = None;
|
||||
if let Some(values) = &scalar_elems {
|
||||
scalars = Some(client.create(bytemuck::cast_slice(values)));
|
||||
}
|
||||
|
||||
if let Some(scalars) = scalars.as_ref() {
|
||||
handles.push(scalars);
|
||||
}
|
||||
|
||||
let workgroup = elemwise_workgroup(num_elems_output, WORKGROUP_DEFAULT);
|
||||
let kernel = Box::new(StaticKernel::<K>::new(workgroup));
|
||||
|
||||
client.execute(kernel, &handles);
|
||||
}
|
||||
|
||||
pub(crate) fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize {
|
||||
let mut num_elems = 1;
|
||||
for i in shape.iter() {
|
||||
num_elems *= i;
|
||||
}
|
||||
num_elems
|
||||
}
|
||||
|
||||
fn bool_elem(elem: Elem) -> Elem {
|
||||
match elem {
|
||||
// I32 are used for bool tensors
|
||||
Elem::Bool => Elem::I32,
|
||||
_ => elem,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,13 @@
|
|||
mod body;
|
||||
mod function;
|
||||
mod kernel;
|
||||
mod operator;
|
||||
mod shader;
|
||||
mod variable;
|
||||
|
||||
pub(crate) use body::*;
|
||||
pub(crate) use function::*;
|
||||
pub(crate) use kernel::*;
|
||||
pub(crate) use operator::*;
|
||||
pub(crate) use shader::*;
|
||||
pub(crate) use variable::*;
|
|
@ -1,8 +1,9 @@
|
|||
use super::Variable;
|
||||
use super::variable::Variable;
|
||||
use std::fmt::Display;
|
||||
|
||||
/// All operators that can be fused in a WGSL compute shader.
|
||||
#[derive(Debug, Hash, Clone)]
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)] // Some variants might not be used with different flags
|
||||
pub enum Operator {
|
||||
Add {
|
||||
lhs: Variable,
|
||||
|
@ -57,6 +58,10 @@ pub enum Operator {
|
|||
rhs: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
Sqrt {
|
||||
input: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
Erf {
|
||||
input: Variable,
|
||||
out: Variable,
|
||||
|
@ -75,6 +80,12 @@ pub enum Operator {
|
|||
rhs: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
Clamp {
|
||||
input: Variable,
|
||||
min_value: Variable,
|
||||
max_value: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
Greater {
|
||||
lhs: Variable,
|
||||
rhs: Variable,
|
||||
|
@ -100,8 +111,15 @@ pub enum Operator {
|
|||
input: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
AssignLocal {
|
||||
input: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
ReadGlobal {
|
||||
variable: Variable,
|
||||
},
|
||||
ReadGlobalIntoContiguous {
|
||||
variable: Variable,
|
||||
position: usize,
|
||||
position_out: usize,
|
||||
},
|
||||
|
@ -125,9 +143,20 @@ impl Display for Operator {
|
|||
Operator::Abs { input, out } => f.write_fmt(format_args!("let {out} = abs({input});")),
|
||||
Operator::Exp { input, out } => f.write_fmt(format_args!("let {out} = exp({input});")),
|
||||
Operator::Log { input, out } => f.write_fmt(format_args!("let {out} = log({input});")),
|
||||
Operator::Clamp {
|
||||
input,
|
||||
min_value,
|
||||
max_value,
|
||||
out,
|
||||
} => f.write_fmt(format_args!(
|
||||
"let {out} = clamp({input}, {min_value}, {max_value});"
|
||||
)),
|
||||
Operator::Powf { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("let {out} = powf({lhs}, {rhs});"))
|
||||
}
|
||||
Operator::Sqrt { input, out } => {
|
||||
f.write_fmt(format_args!("let {out} = sqrt({input});"))
|
||||
}
|
||||
Operator::Log1p { input, out } => {
|
||||
f.write_fmt(format_args!("let {out} = log({input} + 1.0);"))
|
||||
}
|
||||
|
@ -159,7 +188,21 @@ impl Display for Operator {
|
|||
let elem = out.elem();
|
||||
f.write_fmt(format_args!("{out}_global[id] = {elem}({input});"))
|
||||
}
|
||||
Operator::ReadGlobal {
|
||||
Operator::AssignLocal { input, out } => {
|
||||
let elem = out.elem();
|
||||
f.write_fmt(format_args!("let {out} = {elem}({input});"))
|
||||
}
|
||||
Operator::ReadGlobal { variable } => match variable {
|
||||
Variable::Input(number, _elem) => f.write_fmt(format_args!(
|
||||
"let input_{number} = input_{number}_global[id];"
|
||||
)),
|
||||
Variable::Local(_, _) => panic!("can't read global local variable."),
|
||||
Variable::Output(number, _elem) => f.write_fmt(format_args!(
|
||||
"let output_{number} = output_{number}_global[id];"
|
||||
)),
|
||||
Variable::Scalar(_, _) => panic!("Can't read global scalar variable."),
|
||||
},
|
||||
Operator::ReadGlobalIntoContiguous {
|
||||
variable,
|
||||
position,
|
||||
position_out,
|
|
@ -1,34 +1,29 @@
|
|||
use super::{Body, Function};
|
||||
use crate::kernel::{DynamicKernelSource, SourceTemplate, WORKGROUP_DEFAULT};
|
||||
use std::{
|
||||
collections::hash_map::DefaultHasher,
|
||||
fmt::Display,
|
||||
hash::{Hash, Hasher},
|
||||
};
|
||||
use crate::kernel::WORKGROUP_DEFAULT;
|
||||
use std::fmt::Display;
|
||||
|
||||
#[derive(Hash, PartialEq, Eq)]
|
||||
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||
pub enum Location {
|
||||
Storage,
|
||||
#[allow(dead_code)]
|
||||
Workgroup,
|
||||
}
|
||||
|
||||
#[derive(Hash, PartialEq, Eq)]
|
||||
#[derive(PartialEq, Eq, Clone, Copy)]
|
||||
pub enum Visibility {
|
||||
Read,
|
||||
ReadWrite,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq, Copy)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
|
||||
pub enum Elem {
|
||||
F32,
|
||||
#[allow(dead_code)]
|
||||
I32,
|
||||
U32,
|
||||
Bool,
|
||||
}
|
||||
|
||||
#[derive(Hash, PartialEq, Eq)]
|
||||
#[derive(PartialEq, Eq, Clone)]
|
||||
pub struct Binding {
|
||||
pub location: Location,
|
||||
pub visibility: Visibility,
|
||||
|
@ -36,7 +31,7 @@ pub struct Binding {
|
|||
pub size: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Hash, PartialEq, Eq)]
|
||||
#[derive(PartialEq, Eq)]
|
||||
pub struct WorkgroupSize {
|
||||
pub x: usize,
|
||||
pub y: usize,
|
||||
|
@ -53,7 +48,6 @@ impl Default for WorkgroupSize {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Hash)]
|
||||
pub struct ComputeShader {
|
||||
pub inputs: Vec<Binding>,
|
||||
pub outputs: Vec<Binding>,
|
||||
|
@ -65,19 +59,6 @@ pub struct ComputeShader {
|
|||
pub functions: Vec<Function>,
|
||||
}
|
||||
|
||||
impl DynamicKernelSource for ComputeShader {
|
||||
fn source(&self) -> SourceTemplate {
|
||||
SourceTemplate::new(self.to_string())
|
||||
}
|
||||
|
||||
fn id(&self) -> String {
|
||||
let mut s = DefaultHasher::new();
|
||||
self.hash(&mut s);
|
||||
|
||||
s.finish().to_string()
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for ComputeShader {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
Self::format_bindings(f, "input", &self.inputs, 0)?;
|
|
@ -1,7 +1,7 @@
|
|||
use super::Elem;
|
||||
use std::fmt::Display;
|
||||
|
||||
#[derive(Debug, Hash, Clone)]
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Variable {
|
||||
Input(u16, Elem),
|
||||
Scalar(u16, Elem),
|
|
@ -9,8 +9,7 @@ where
|
|||
fn type_name() -> &'static str;
|
||||
fn as_bytes(slice: &[Self]) -> &[u8];
|
||||
fn from_bytes(bytes: &[u8]) -> &[Self];
|
||||
#[cfg(any(feature = "fusion", test))]
|
||||
fn elem_type() -> crate::fusion::codegen::Elem;
|
||||
fn elem_type() -> crate::codegen::Elem;
|
||||
}
|
||||
|
||||
/// The float element type for the wgpu backend.
|
||||
|
@ -29,9 +28,8 @@ impl WgpuElement for u32 {
|
|||
fn from_bytes(bytes: &[u8]) -> &[Self] {
|
||||
bytemuck::cast_slice(bytes)
|
||||
}
|
||||
#[cfg(any(feature = "fusion", test))]
|
||||
fn elem_type() -> crate::fusion::codegen::Elem {
|
||||
crate::fusion::codegen::Elem::U32
|
||||
fn elem_type() -> crate::codegen::Elem {
|
||||
crate::codegen::Elem::U32
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -45,9 +43,8 @@ impl WgpuElement for i32 {
|
|||
fn from_bytes(bytes: &[u8]) -> &[Self] {
|
||||
bytemuck::cast_slice(bytes)
|
||||
}
|
||||
#[cfg(any(feature = "fusion", test))]
|
||||
fn elem_type() -> crate::fusion::codegen::Elem {
|
||||
crate::fusion::codegen::Elem::I32
|
||||
fn elem_type() -> crate::codegen::Elem {
|
||||
crate::codegen::Elem::I32
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -62,9 +59,8 @@ impl WgpuElement for f32 {
|
|||
bytemuck::cast_slice(bytes)
|
||||
}
|
||||
|
||||
#[cfg(any(feature = "fusion", test))]
|
||||
fn elem_type() -> crate::fusion::codegen::Elem {
|
||||
crate::fusion::codegen::Elem::F32
|
||||
fn elem_type() -> crate::codegen::Elem {
|
||||
crate::codegen::Elem::F32
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -81,14 +81,6 @@ pub fn strides_dyn_rank(shape: &[usize]) -> Vec<usize> {
|
|||
strides
|
||||
}
|
||||
|
||||
pub fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize {
|
||||
let mut num_elems = 1;
|
||||
for i in shape.iter() {
|
||||
num_elems *= i;
|
||||
}
|
||||
num_elems
|
||||
}
|
||||
|
||||
#[derive(new, Debug, Clone)]
|
||||
/// Handle to be used when fusing operations.
|
||||
pub struct WgpuFusionHandle {
|
||||
|
|
|
@ -0,0 +1,57 @@
|
|||
use crate::{
|
||||
codegen::ComputeShader,
|
||||
kernel::{DynamicKernelSource, SourceTemplate},
|
||||
};
|
||||
use hashbrown::HashSet;
|
||||
|
||||
/// This cache ensures that the generation of the source code is only done once when the kernel is
|
||||
/// executed for the first time. Following, we only include the ID in the dynamic kernel source,
|
||||
/// since we rely on the compilation cache of the WGPU compute server.
|
||||
///
|
||||
/// If it ever causes problems, we could cache the compute shader and put it into an Arc to avoid deep
|
||||
/// cloning.
|
||||
#[derive(Default, Debug)]
|
||||
pub struct KernelCompilationCache {
|
||||
already_compiled_ids: HashSet<String>,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub enum FusedKernelSource {
|
||||
AlreadyCompiled { id: String },
|
||||
NewKernel { id: String, shader: ComputeShader },
|
||||
}
|
||||
|
||||
impl DynamicKernelSource for FusedKernelSource {
|
||||
fn source(&self) -> SourceTemplate {
|
||||
match self {
|
||||
FusedKernelSource::AlreadyCompiled { id: _ } => {
|
||||
panic!("Can't get the source of an already compiled kernel.")
|
||||
}
|
||||
FusedKernelSource::NewKernel {
|
||||
id: _,
|
||||
shader: source,
|
||||
} => SourceTemplate::new(source.to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn id(&self) -> String {
|
||||
match self {
|
||||
FusedKernelSource::AlreadyCompiled { id } => id.clone(),
|
||||
FusedKernelSource::NewKernel { id, shader: _ } => id.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl KernelCompilationCache {
|
||||
pub fn get(&self, id: &str) -> Option<FusedKernelSource> {
|
||||
if self.already_compiled_ids.contains(id) {
|
||||
return Some(FusedKernelSource::AlreadyCompiled { id: id.to_string() });
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
pub fn insert(&mut self, id: String) {
|
||||
self.already_compiled_ids.insert(id);
|
||||
}
|
||||
}
|
|
@ -1,11 +0,0 @@
|
|||
mod body;
|
||||
mod function;
|
||||
mod operator;
|
||||
mod shader;
|
||||
mod variable;
|
||||
|
||||
pub use body::*;
|
||||
pub use function::*;
|
||||
pub use operator::*;
|
||||
pub use shader::*;
|
||||
pub use variable::*;
|
|
@ -1,8 +1,10 @@
|
|||
use crate::{
|
||||
codegen::{Elem, Operator, Variable},
|
||||
element::WgpuElement,
|
||||
fusion::codegen::{Elem, Operator, Variable},
|
||||
fusion::cache::KernelCompilationCache,
|
||||
FloatElement, GraphicsApi, IntElement, Wgpu,
|
||||
};
|
||||
use burn_common::id::IdGenerator;
|
||||
use burn_fusion::{
|
||||
graph::{
|
||||
BaseOpsDescription, BinaryOpsDescription, FloatOpsDescription, NumericOpsDescription,
|
||||
|
@ -84,12 +86,14 @@ where
|
|||
.collect::<Vec<_>>();
|
||||
|
||||
Box::new(FloatElementWise {
|
||||
id: IdGenerator::generate(),
|
||||
inputs,
|
||||
outputs,
|
||||
locals,
|
||||
operators: self.operators.clone(),
|
||||
scalars_f32: self.scalars_f32,
|
||||
device: self.device.clone(),
|
||||
cache: KernelCompilationCache::default(),
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -183,13 +187,19 @@ where
|
|||
Operator::AssignGlobal { input: _, out: _ } => {
|
||||
// Nothing to do here.
|
||||
}
|
||||
Operator::ReadGlobal {
|
||||
Operator::AssignLocal { input: _, out: _ } => {
|
||||
// Nothing to do here.
|
||||
}
|
||||
Operator::ReadGlobalIntoContiguous {
|
||||
variable: _,
|
||||
position: _,
|
||||
position_out: _,
|
||||
} => {
|
||||
// Nothing to do here.
|
||||
}
|
||||
Operator::ReadGlobal { variable: _ } => {
|
||||
// Nothing to do here.
|
||||
}
|
||||
Operator::Add { lhs, rhs, out } => {
|
||||
mark(lhs, &mut local_tensor_ids_input);
|
||||
mark(rhs, &mut local_tensor_ids_input);
|
||||
|
@ -242,6 +252,15 @@ where
|
|||
mark(input, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::Clamp {
|
||||
input,
|
||||
min_value: _,
|
||||
max_value: _,
|
||||
out,
|
||||
} => {
|
||||
mark(input, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::Powf { lhs, rhs, out } => {
|
||||
mark(lhs, &mut local_tensor_ids_input);
|
||||
mark(rhs, &mut local_tensor_ids_input);
|
||||
|
@ -287,6 +306,10 @@ where
|
|||
mark(rhs, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::Sqrt { input, out } => {
|
||||
mark(input, &mut local_tensor_ids_input);
|
||||
mark(out, &mut local_tensor_ids_output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,24 +1,73 @@
|
|||
use crate::{
|
||||
fusion::codegen::{Elem, Operator},
|
||||
fusion::kernel::FusionKernel,
|
||||
codegen::{
|
||||
ComputeShader, Elem, ElemWiseKernelCodegen, Input, Operator, Output, ReadingStrategy,
|
||||
Visibility,
|
||||
},
|
||||
fusion::{
|
||||
cache::{FusedKernelSource, KernelCompilationCache},
|
||||
kernel,
|
||||
},
|
||||
FloatElement, GraphicsApi, IntElement, Wgpu,
|
||||
};
|
||||
use burn_fusion::{graph::Context, Optimization, TensorDescription};
|
||||
use burn_tensor::Device;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct FloatElementWise<G, F, I>
|
||||
where
|
||||
G: GraphicsApi,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
pub(crate) id: String,
|
||||
pub(crate) inputs: Vec<(TensorDescription, Elem)>,
|
||||
pub(crate) outputs: Vec<(TensorDescription, Elem)>,
|
||||
pub(crate) locals: Vec<u16>,
|
||||
pub(crate) operators: Vec<Operator>,
|
||||
pub(crate) scalars_f32: usize,
|
||||
pub(crate) device: Device<Wgpu<G, F, I>>,
|
||||
pub(crate) cache: KernelCompilationCache,
|
||||
}
|
||||
|
||||
impl<G, F, I> FloatElementWise<G, F, I>
|
||||
where
|
||||
G: GraphicsApi,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
pub fn compile(&mut self) -> ComputeShader {
|
||||
let mut inputs = self
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|(_tensor, elem)| Input::Array {
|
||||
elem: *elem,
|
||||
visibility: Visibility::Read,
|
||||
strategy: ReadingStrategy::IntoContiguous,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let outputs = self
|
||||
.outputs
|
||||
.iter()
|
||||
.zip(self.locals.iter())
|
||||
.map(|((_tensor, elem), local)| Output::Array {
|
||||
elem: *elem,
|
||||
local: *local,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if self.scalars_f32 > 0 {
|
||||
inputs.push(Input::Scalar {
|
||||
elem: Elem::F32,
|
||||
size: self.scalars_f32,
|
||||
})
|
||||
}
|
||||
|
||||
ElemWiseKernelCodegen::new()
|
||||
.inputs(&inputs)
|
||||
.body(&self.operators)
|
||||
.outputs(&outputs)
|
||||
.compile()
|
||||
}
|
||||
}
|
||||
|
||||
impl<G, F, I> Optimization<Wgpu<G, F, I>> for FloatElementWise<G, F, I>
|
||||
|
@ -27,27 +76,33 @@ where
|
|||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
fn execute(&self, context: &mut Context<'_, Wgpu<G, F, I>>) {
|
||||
let inputs = self
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|(tensor, elem)| (context.tensors.get(&tensor.id).unwrap(), *elem))
|
||||
.collect::<Vec<_>>();
|
||||
fn execute(&mut self, context: &mut Context<'_, Wgpu<G, F, I>>) {
|
||||
if let Some(kernel) = self.cache.get(&self.id) {
|
||||
kernel::execute_fusion(
|
||||
&self.inputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
|
||||
&self.outputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
|
||||
self.scalars_f32,
|
||||
kernel,
|
||||
context,
|
||||
self.device.clone(),
|
||||
);
|
||||
} else {
|
||||
let shader = self.compile();
|
||||
|
||||
let outputs = self
|
||||
.outputs
|
||||
.iter()
|
||||
.map(|(tensor, elem)| (context.tensors.get(&tensor.id).unwrap(), *elem))
|
||||
.collect::<Vec<_>>();
|
||||
kernel::execute_fusion(
|
||||
&self.inputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
|
||||
&self.outputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
|
||||
self.scalars_f32,
|
||||
FusedKernelSource::NewKernel {
|
||||
id: self.id.to_string(),
|
||||
shader,
|
||||
},
|
||||
context,
|
||||
self.device.clone(),
|
||||
);
|
||||
|
||||
// The context may contain scalars for the end condition, which may vary.
|
||||
let scalars_f32 = &context.scalar_floats[0..self.scalars_f32];
|
||||
|
||||
FusionKernel::new(&self.device)
|
||||
.inputs(&inputs, scalars_f32)
|
||||
.body(&self.operators)
|
||||
.outputs(&outputs, &self.locals)
|
||||
.execute(context.handles);
|
||||
self.cache.insert(self.id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
|
@ -144,6 +199,7 @@ mod tests {
|
|||
Variant1,
|
||||
Variant2,
|
||||
}
|
||||
|
||||
fn execute<B: Backend>(
|
||||
data_1: Data<f32, 2>,
|
||||
data_2: Data<f32, 2>,
|
||||
|
|
|
@ -1,355 +1,82 @@
|
|||
use super::codegen::Body;
|
||||
use crate::compute::{compute_client, DynamicKernel, WgpuComputeClient};
|
||||
use crate::fusion::codegen::Function;
|
||||
use crate::fusion::{calculate_num_elems_dyn_rank, strides_dyn_rank};
|
||||
use crate::fusion::{
|
||||
codegen::{
|
||||
Binding, ComputeShader, Elem, Location, Operator, Variable, Visibility, WorkgroupSize,
|
||||
},
|
||||
WgpuFusionHandle,
|
||||
};
|
||||
use super::cache::FusedKernelSource;
|
||||
use crate::codegen::calculate_num_elems_dyn_rank;
|
||||
use crate::compute::{compute_client, DynamicKernel};
|
||||
use crate::fusion::strides_dyn_rank;
|
||||
use crate::fusion::WgpuFusionHandle;
|
||||
use crate::kernel::{elemwise_workgroup, WORKGROUP_DEFAULT};
|
||||
use crate::{FloatElement, GraphicsApi, IntElement, Wgpu};
|
||||
use burn_fusion::{HandleContainer, TensorDescription};
|
||||
use burn_fusion::graph::Context;
|
||||
use burn_fusion::TensorDescription;
|
||||
use burn_tensor::Device;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// Kernel creation input phase, see [fusion kernel](FusionKernel) for more details.
|
||||
pub struct InputPhase;
|
||||
/// Kernel creation body phase, see [fusion kernel](FusionKernel) for more details.
|
||||
pub struct BodyPhase;
|
||||
/// Kernel creation output phase, see [fusion kernel](FusionKernel) for more details.
|
||||
pub struct OutputPhase;
|
||||
/// Kernel execution phase, see [fusion kernel](FusionKernel) for more details.
|
||||
pub struct ExecutionPhase;
|
||||
|
||||
/// Allows to create custom wgsl kernels based on configured inputs, body and outputs.
|
||||
///
|
||||
/// This type has 4 phases that must be executed in order, but no worry the type system won't allow
|
||||
/// you to make mistakes.
|
||||
///
|
||||
/// 1. [Input Phase](InputPhase)
|
||||
/// This phase focuses on registering the input tensor descriptions that are going to be used by
|
||||
/// the fused kernel.
|
||||
/// 2. [Body Phase](BodyPhase)
|
||||
/// After the input phase is done, all the operations that happen in the body must be
|
||||
/// registered.
|
||||
/// 3. [Output Phase](OutputPhase)
|
||||
/// This step focuses on registering all tensor descriptions that the kernel needs to write to.
|
||||
/// 4. [Execution Phase](ExecutionPhase)
|
||||
/// Now that all other phases are completed, we can actually run the kernel on the given
|
||||
/// [handles](HandleContainer). Note that the actual chosen kernel may vary based on the
|
||||
/// handles provided.
|
||||
pub struct FusionKernel<G, F, I, Phase = InputPhase>
|
||||
where
|
||||
G: GraphicsApi,
|
||||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
operations: Vec<Operator>,
|
||||
input_bindings: Vec<(Binding, TensorDescription)>,
|
||||
output_bindings: Vec<(Binding, TensorDescription)>,
|
||||
named_bindings: Vec<(String, Binding, DataBuffer)>,
|
||||
functions: Vec<Function>,
|
||||
num_elems_output: usize,
|
||||
pub fn execute_fusion<G: GraphicsApi, F: FloatElement, I: IntElement>(
|
||||
inputs: &[&TensorDescription],
|
||||
outputs: &[&TensorDescription],
|
||||
scalars_f32: usize,
|
||||
kernel: FusedKernelSource,
|
||||
context: &mut Context<'_, Wgpu<G, F, I>>,
|
||||
device: Device<Wgpu<G, F, I>>,
|
||||
client: WgpuComputeClient,
|
||||
_phase: PhantomData<Phase>,
|
||||
}
|
||||
) {
|
||||
let client = compute_client::<G>(&device);
|
||||
let mut info = Vec::new();
|
||||
let mut handles = Vec::with_capacity(inputs.len() + outputs.len() + 2);
|
||||
|
||||
enum DataBuffer {
|
||||
F32(Vec<f32>),
|
||||
U32(Vec<u32>),
|
||||
}
|
||||
// Inner function to fill the info buffer.
|
||||
let mut register_info_tensor = |tensor: &TensorDescription, handle: &WgpuFusionHandle| {
|
||||
if info.is_empty() {
|
||||
info.push(handle.strides.len() as u32);
|
||||
}
|
||||
|
||||
impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, InputPhase> {
|
||||
/// Create a new fusion kernel on the given device.
|
||||
pub fn new(device: &Device<Wgpu<G, F, I>>) -> Self {
|
||||
let client = compute_client::<G>(device);
|
||||
for s in handle.strides.iter() {
|
||||
info.push(*s as u32);
|
||||
}
|
||||
for s in tensor.shape.iter() {
|
||||
info.push(*s as u32);
|
||||
}
|
||||
};
|
||||
|
||||
Self {
|
||||
operations: Vec::new(),
|
||||
input_bindings: Vec::new(),
|
||||
output_bindings: Vec::new(),
|
||||
named_bindings: Vec::new(),
|
||||
functions: Vec::new(),
|
||||
num_elems_output: 0,
|
||||
// We start by registering the inputs.
|
||||
for tensor in inputs.iter() {
|
||||
let tensor = context.tensors.get(&tensor.id).unwrap();
|
||||
let handle = context.handles.get_handle(tensor);
|
||||
|
||||
register_info_tensor(tensor, &handle);
|
||||
handles.push(handle.handle);
|
||||
}
|
||||
|
||||
let mut num_elems_output = 0;
|
||||
|
||||
// Then we follow with the outputs.
|
||||
for tensor in outputs.iter() {
|
||||
let tensor = context.tensors.get(&tensor.id).unwrap();
|
||||
|
||||
let num_elems = calculate_num_elems_dyn_rank(&tensor.shape);
|
||||
if num_elems > num_elems_output {
|
||||
num_elems_output = num_elems;
|
||||
}
|
||||
let handle_fusion = WgpuFusionHandle {
|
||||
client: client.clone(),
|
||||
device: device.clone(),
|
||||
client,
|
||||
_phase: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register the inputs used by the kernel.
|
||||
pub fn inputs(
|
||||
mut self,
|
||||
inputs_tensor: &[(&TensorDescription, Elem)],
|
||||
inputs_scalar_f32: &[f32],
|
||||
) -> FusionKernel<G, F, I, BodyPhase> {
|
||||
for (i, (input, elem)) in inputs_tensor.iter().enumerate() {
|
||||
if elem != &Elem::Bool {
|
||||
self.input_bindings.push((
|
||||
Binding {
|
||||
elem: *elem,
|
||||
visibility: Visibility::Read,
|
||||
location: Location::Storage,
|
||||
size: None,
|
||||
},
|
||||
(*input).clone(),
|
||||
));
|
||||
|
||||
self.operations.push(Operator::ReadGlobal {
|
||||
variable: Variable::Input(i as u16, *elem),
|
||||
position: i,
|
||||
position_out: inputs_tensor.len(), // First output
|
||||
});
|
||||
} else {
|
||||
self.input_bindings.push((
|
||||
Binding {
|
||||
elem: Elem::I32,
|
||||
visibility: Visibility::Read,
|
||||
location: Location::Storage,
|
||||
size: None,
|
||||
},
|
||||
(*input).clone(),
|
||||
));
|
||||
|
||||
self.operations.push(Operator::ReadGlobal {
|
||||
variable: Variable::Input(i as u16, *elem),
|
||||
position: i,
|
||||
position_out: inputs_tensor.len(), // First output
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
if !inputs_scalar_f32.is_empty() {
|
||||
self.named_bindings.push((
|
||||
"scalars_f32".to_string(),
|
||||
Binding {
|
||||
elem: Elem::F32,
|
||||
visibility: Visibility::Read,
|
||||
location: Location::Storage,
|
||||
size: Some(inputs_scalar_f32.len()),
|
||||
},
|
||||
DataBuffer::F32(inputs_scalar_f32.to_vec()),
|
||||
));
|
||||
}
|
||||
|
||||
FusionKernel {
|
||||
operations: self.operations,
|
||||
input_bindings: self.input_bindings,
|
||||
output_bindings: self.output_bindings,
|
||||
named_bindings: self.named_bindings,
|
||||
functions: self.functions,
|
||||
num_elems_output: self.num_elems_output,
|
||||
device: self.device,
|
||||
client: self.client,
|
||||
_phase: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, BodyPhase> {
|
||||
/// Register the [operators](Operator) that the kernel must execute in the order provided.
|
||||
pub fn body(mut self, operators: &[Operator]) -> FusionKernel<G, F, I, OutputPhase> {
|
||||
let mut register_function = |function: Function| {
|
||||
if !self.functions.contains(&function) {
|
||||
self.functions.push(function);
|
||||
}
|
||||
strides: strides_dyn_rank(&tensor.shape),
|
||||
handle: client.empty(core::mem::size_of::<F>() * num_elems),
|
||||
};
|
||||
|
||||
// Since not all operators are native to WGSL, we need to add the custom ones.
|
||||
for ops in operators.iter() {
|
||||
match ops {
|
||||
Operator::Powf {
|
||||
lhs: _,
|
||||
rhs: _,
|
||||
out: _,
|
||||
} => {
|
||||
register_function(Function::Powf(Elem::F32));
|
||||
}
|
||||
Operator::Erf { input: _, out: _ } => {
|
||||
register_function(Function::Erf(Elem::F32));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
self.operations.push(ops.clone());
|
||||
}
|
||||
register_info_tensor(tensor, &handle_fusion);
|
||||
|
||||
FusionKernel {
|
||||
operations: self.operations,
|
||||
input_bindings: self.input_bindings,
|
||||
output_bindings: self.output_bindings,
|
||||
named_bindings: self.named_bindings,
|
||||
functions: self.functions,
|
||||
num_elems_output: self.num_elems_output,
|
||||
device: self.device,
|
||||
client: self.client,
|
||||
_phase: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, OutputPhase> {
|
||||
/// Register the outputs with their local variable index.
|
||||
///
|
||||
/// Note that the index corresponds to the registered [operator](Operator) number at the
|
||||
/// [body phase](BodyPhase).
|
||||
/// So the 4th operator registered creates the local variable 3 (N-1, since the 1th index is 0).
|
||||
pub fn outputs(
|
||||
mut self,
|
||||
outputs: &[(&TensorDescription, Elem)],
|
||||
locals: &[u16],
|
||||
) -> FusionKernel<G, F, I, ExecutionPhase> {
|
||||
let mut num_elems_launch_option = 0;
|
||||
|
||||
for (i, ((output, elem), local)) in outputs.iter().zip(locals).enumerate() {
|
||||
let num_elems_output = calculate_num_elems_dyn_rank(&output.shape);
|
||||
if num_elems_output > num_elems_launch_option {
|
||||
num_elems_launch_option = num_elems_output;
|
||||
}
|
||||
|
||||
if elem != &Elem::Bool {
|
||||
self.output_bindings.push((
|
||||
Binding {
|
||||
elem: *elem,
|
||||
visibility: Visibility::ReadWrite,
|
||||
location: Location::Storage,
|
||||
size: None,
|
||||
},
|
||||
(*output).clone(),
|
||||
));
|
||||
|
||||
self.operations.push(Operator::AssignGlobal {
|
||||
input: Variable::Local(*local, *elem),
|
||||
out: Variable::Output(i as u16, *elem),
|
||||
});
|
||||
} else {
|
||||
self.output_bindings.push((
|
||||
Binding {
|
||||
elem: Elem::I32, // I32 are used for bool tensors
|
||||
visibility: Visibility::ReadWrite,
|
||||
location: Location::Storage,
|
||||
size: None,
|
||||
},
|
||||
(*output).clone(),
|
||||
));
|
||||
|
||||
self.operations.push(Operator::AssignGlobal {
|
||||
input: Variable::Local(*local, *elem),
|
||||
out: Variable::Output(i as u16, Elem::I32),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
self.num_elems_output = num_elems_launch_option;
|
||||
|
||||
FusionKernel {
|
||||
operations: self.operations,
|
||||
input_bindings: self.input_bindings,
|
||||
output_bindings: self.output_bindings,
|
||||
named_bindings: self.named_bindings,
|
||||
functions: self.functions,
|
||||
num_elems_output: self.num_elems_output,
|
||||
device: self.device,
|
||||
client: self.client,
|
||||
_phase: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<G: GraphicsApi, F: FloatElement, I: IntElement> FusionKernel<G, F, I, ExecutionPhase> {
|
||||
/// Execute the kernel on the provided [handles](HandleContainer).
|
||||
pub fn execute(mut self, handle_container: &mut HandleContainer<Wgpu<G, F, I>>) {
|
||||
let mut inputs = Vec::with_capacity(self.input_bindings.len());
|
||||
let mut outputs = Vec::with_capacity(self.output_bindings.len());
|
||||
let mut named = Vec::with_capacity(2);
|
||||
let mut info = Vec::new();
|
||||
let mut handles =
|
||||
Vec::with_capacity(inputs.capacity() + outputs.capacity() + named.capacity());
|
||||
|
||||
// Inner function to fill the info buffer.
|
||||
let mut register_info_tensor = |tensor: &TensorDescription, handle: &WgpuFusionHandle| {
|
||||
if info.is_empty() {
|
||||
info.push(handle.strides.len() as u32);
|
||||
}
|
||||
|
||||
for s in handle.strides.iter() {
|
||||
info.push(*s as u32);
|
||||
}
|
||||
for s in tensor.shape.iter() {
|
||||
info.push(*s as u32);
|
||||
}
|
||||
};
|
||||
|
||||
// We start by registering the inputs.
|
||||
for (binding, tensor) in self.input_bindings.into_iter() {
|
||||
let handle = handle_container.get_handle(&tensor);
|
||||
register_info_tensor(&tensor, &handle);
|
||||
|
||||
inputs.push(binding);
|
||||
handles.push(handle.handle);
|
||||
}
|
||||
|
||||
// Then we follow with the outputs.
|
||||
for (binding, tensor) in self.output_bindings {
|
||||
let num_elems = calculate_num_elems_dyn_rank(&tensor.shape);
|
||||
let handle_fusion = WgpuFusionHandle {
|
||||
client: self.client.clone(),
|
||||
device: self.device.clone(),
|
||||
strides: strides_dyn_rank(&tensor.shape),
|
||||
handle: self.client.empty(core::mem::size_of::<F>() * num_elems),
|
||||
};
|
||||
register_info_tensor(&tensor, &handle_fusion);
|
||||
|
||||
handles.push(handle_fusion.handle.clone());
|
||||
handle_container.register_handle(tensor.id, handle_fusion);
|
||||
outputs.push(binding);
|
||||
}
|
||||
|
||||
// Now we can create the info handle.
|
||||
Self::build_info_handle(&mut self.named_bindings, info);
|
||||
|
||||
// Finally we finish with the named bindings.
|
||||
for (name, binding, data) in self.named_bindings {
|
||||
let handle = self.client.create(match &data {
|
||||
DataBuffer::F32(values) => bytemuck::cast_slice(values),
|
||||
DataBuffer::U32(values) => bytemuck::cast_slice(values),
|
||||
});
|
||||
named.push((name, binding));
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
// We create the shader codegen type and launch the kernel.
|
||||
let kernel = ComputeShader {
|
||||
inputs,
|
||||
outputs,
|
||||
named,
|
||||
workgroup_size: WorkgroupSize::default(),
|
||||
body: Body::new(self.operations),
|
||||
num_workgroups: true,
|
||||
global_invocation_id: true,
|
||||
functions: self.functions,
|
||||
};
|
||||
|
||||
let workgroup = elemwise_workgroup(self.num_elems_output, WORKGROUP_DEFAULT);
|
||||
let kernel = Box::new(DynamicKernel::new(kernel, workgroup));
|
||||
|
||||
self.client
|
||||
.execute(kernel, &handles.iter().collect::<Vec<_>>());
|
||||
}
|
||||
|
||||
fn build_info_handle(named_bindings: &mut Vec<(String, Binding, DataBuffer)>, info: Vec<u32>) {
|
||||
named_bindings.push((
|
||||
"info".to_string(),
|
||||
Binding {
|
||||
elem: Elem::U32,
|
||||
visibility: Visibility::Read,
|
||||
location: Location::Storage,
|
||||
size: None, // We avoid putting the length here since it will force a new kernel
|
||||
// for each tensor rank.
|
||||
},
|
||||
DataBuffer::U32(info),
|
||||
));
|
||||
handles.push(handle_fusion.handle.clone());
|
||||
context
|
||||
.handles
|
||||
.register_handle(tensor.id.clone(), handle_fusion);
|
||||
}
|
||||
|
||||
handles.push(client.create(bytemuck::cast_slice(&info)));
|
||||
|
||||
// Finally we finish with the named bindings.
|
||||
if scalars_f32 > 0 {
|
||||
handles.push(client.create(bytemuck::cast_slice(&context.scalar_floats[0..scalars_f32])));
|
||||
}
|
||||
|
||||
let workgroup = elemwise_workgroup(num_elems_output, WORKGROUP_DEFAULT);
|
||||
let kernel = Box::new(DynamicKernel::new(kernel, workgroup));
|
||||
client.execute(kernel, &handles.iter().collect::<Vec<_>>());
|
||||
}
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
mod base;
|
||||
mod elemwise;
|
||||
|
||||
pub(crate) mod codegen;
|
||||
pub(crate) mod cache;
|
||||
pub(crate) mod kernel;
|
||||
|
||||
pub use base::*;
|
||||
|
|
|
@ -1,78 +1,27 @@
|
|||
use super::unary;
|
||||
use crate::{
|
||||
compute::StaticKernel,
|
||||
codegen::{Operator, Variable},
|
||||
element::WgpuElement,
|
||||
kernel::{unary_scalar, unary_scalar_inplace_default, WORKGROUP_DEFAULT},
|
||||
kernel_wgsl,
|
||||
ops::numeric::empty_device,
|
||||
tensor::WgpuTensor,
|
||||
unary_scalar, unary_scalar_inplace,
|
||||
unary,
|
||||
};
|
||||
|
||||
use super::{elemwise_workgroup, KernelSettings};
|
||||
|
||||
kernel_wgsl!(Clamp, "../template/clamp/clamp.wgsl");
|
||||
kernel_wgsl!(ClampInplace, "../template/clamp/clamp_inplace.wgsl");
|
||||
|
||||
pub(crate) fn clamp_min<E: WgpuElement, const D: usize>(
|
||||
input: WgpuTensor<E, D>,
|
||||
min_value: E,
|
||||
) -> WgpuTensor<E, D> {
|
||||
unary_scalar!(ClampMin, func "max");
|
||||
unary_scalar_inplace!(ClampMinInplace, func "max");
|
||||
|
||||
if input.can_mut() {
|
||||
return unary_scalar_inplace_default::<ClampMinInplace, E, D>(input, min_value);
|
||||
}
|
||||
|
||||
unary_scalar::<ClampMin, E, D, WORKGROUP_DEFAULT>(input, min_value)
|
||||
}
|
||||
|
||||
pub(crate) fn clamp_max<E: WgpuElement, const D: usize>(
|
||||
input: WgpuTensor<E, D>,
|
||||
max_value: E,
|
||||
) -> WgpuTensor<E, D> {
|
||||
unary_scalar!(ClampMax, func "min");
|
||||
unary_scalar_inplace!(ClampMaxInPlace, func "min");
|
||||
|
||||
if input.can_mut() {
|
||||
return unary_scalar_inplace_default::<ClampMaxInPlace, E, D>(input, max_value);
|
||||
}
|
||||
|
||||
unary_scalar::<ClampMax, E, D, WORKGROUP_DEFAULT>(input, max_value)
|
||||
}
|
||||
unary!(
|
||||
|elem| Operator::Clamp {
|
||||
input: Variable::Input(0, elem),
|
||||
min_value: Variable::Scalar(0, elem),
|
||||
max_value: Variable::Scalar(1, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
scalar 2
|
||||
);
|
||||
|
||||
pub(crate) fn clamp<E: WgpuElement, const D: usize>(
|
||||
input: WgpuTensor<E, D>,
|
||||
min_value: E,
|
||||
max_value: E,
|
||||
) -> WgpuTensor<E, D> {
|
||||
let num_elems = input.shape.num_elements();
|
||||
let min_handle = input.client.create(E::as_bytes(&[min_value]));
|
||||
let max_handle = input.client.create(E::as_bytes(&[max_value]));
|
||||
|
||||
if input.can_mut() {
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<ClampInplace, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
|
||||
|
||||
input
|
||||
.client
|
||||
.execute(Box::new(kernel), &[&input.handle, &min_handle, &max_handle]);
|
||||
|
||||
return input;
|
||||
}
|
||||
|
||||
let output = empty_device(input.client.clone(), input.device.clone(), input.shape);
|
||||
let kernel = StaticKernel::<
|
||||
KernelSettings<Clamp, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
|
||||
>::new(elemwise_workgroup(num_elems, WORKGROUP_DEFAULT));
|
||||
|
||||
input.client.execute(
|
||||
Box::new(kernel),
|
||||
&[&input.handle, &output.handle, &min_handle, &max_handle],
|
||||
);
|
||||
|
||||
output
|
||||
unary::<Ops<E>, OpsInplace<E>, E, D>(input, Some(&[min_value, max_value]))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -80,30 +29,6 @@ mod tests {
|
|||
use crate::tests::{ReferenceBackend, TestBackend};
|
||||
use burn_tensor::{Distribution, Tensor};
|
||||
|
||||
#[test]
|
||||
fn clamp_min_should_match_reference() {
|
||||
let input = Tensor::<TestBackend, 4>::random([1, 5, 32, 32], Distribution::Default);
|
||||
let input_ref = Tensor::<ReferenceBackend, 4>::from_data(input.to_data());
|
||||
|
||||
let output = input.clamp_min(0.5);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq(&input_ref.clamp_min(0.5).into_data(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clamp_max_should_match_reference() {
|
||||
let input = Tensor::<TestBackend, 4>::random([1, 5, 32, 32], Distribution::Default);
|
||||
let input_ref = Tensor::<ReferenceBackend, 4>::from_data(input.to_data());
|
||||
|
||||
let output = input.clamp_max(0.5);
|
||||
|
||||
output
|
||||
.into_data()
|
||||
.assert_approx_eq(&input_ref.clamp_max(0.5).into_data(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn clamp_should_match_reference() {
|
||||
let input = Tensor::<TestBackend, 4>::random([1, 5, 32, 32], Distribution::Default);
|
||||
|
|
|
@ -8,14 +8,12 @@ mod index;
|
|||
mod mask;
|
||||
mod source;
|
||||
mod unary;
|
||||
mod unary_scalar;
|
||||
|
||||
pub use base::*;
|
||||
pub use binary_elemwise::*;
|
||||
pub use cast::*;
|
||||
pub use source::*;
|
||||
pub use unary::*;
|
||||
pub use unary_scalar::*;
|
||||
|
||||
/// Convolution kernels
|
||||
pub mod conv;
|
||||
|
|
|
@ -1,169 +1,220 @@
|
|||
use super::{elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT};
|
||||
use crate::{compute::StaticKernel, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
|
||||
|
||||
kernel_wgsl!(UnaryRaw, "../template/unary.wgsl");
|
||||
kernel_wgsl!(UnaryInplaceRaw, "../template/unary_inplace.wgsl");
|
||||
use super::StaticKernelSource;
|
||||
use crate::{
|
||||
codegen::{execute_static, StaticHandle},
|
||||
element::WgpuElement,
|
||||
tensor::WgpuTensor,
|
||||
};
|
||||
|
||||
/// Creates a unary kernel.
|
||||
#[macro_export]
|
||||
macro_rules! unary {
|
||||
(
|
||||
$struct:ident,
|
||||
func $func:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
operator: $ops:expr,
|
||||
input: $input:expr,
|
||||
elem: $elem:ty
|
||||
) => {{
|
||||
unary!($ops);
|
||||
|
||||
impl $crate::kernel::StaticKernelSource for $struct {
|
||||
$crate::kernel::unary::<Ops<$elem>, OpsInplace<$elem>, $elem, D>($input, None)
|
||||
}};
|
||||
(
|
||||
operator: $ops:expr,
|
||||
input: $input:expr; $scalar:expr,
|
||||
elem: $elem:ty
|
||||
) => {{
|
||||
unary!($ops, scalar 1);
|
||||
|
||||
$crate::kernel::unary::<Ops<$elem>, OpsInplace<$elem>, $elem, D>($input, Some(&[$scalar]))
|
||||
}};
|
||||
|
||||
(
|
||||
$ops:expr
|
||||
) => {
|
||||
pub struct Ops<E> {
|
||||
_e: core::marker::PhantomData<E>,
|
||||
}
|
||||
pub struct OpsInplace<E> {
|
||||
_e: core::marker::PhantomData<E>,
|
||||
}
|
||||
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
impl<E: $crate::element::WgpuElement> $crate::kernel::StaticKernelSource for Ops<E> {
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
let source = $crate::kernel::UnaryRaw::source();
|
||||
source.register("body", format!("output[id] = {}(input[id]);", $func))
|
||||
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
|
||||
.inputs(&[$crate::codegen::Input::Array {
|
||||
elem: E::elem_type(),
|
||||
visibility: $crate::codegen::Visibility::Read,
|
||||
strategy: $crate::codegen::ReadingStrategy::IntoContiguous,
|
||||
}])
|
||||
.body(&[$ops(E::elem_type())])
|
||||
.outputs(&[$crate::codegen::Output::Array {
|
||||
elem: E::elem_type(),
|
||||
local: 0,
|
||||
}])
|
||||
.compile();
|
||||
|
||||
$crate::kernel::SourceTemplate::new(shader.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
impl<E: $crate::element::WgpuElement> $crate::kernel::StaticKernelSource for OpsInplace<E> {
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
|
||||
.inputs(&[$crate::codegen::Input::Array {
|
||||
elem: E::elem_type(),
|
||||
visibility: $crate::codegen::Visibility::ReadWrite,
|
||||
strategy: $crate::codegen::ReadingStrategy::Plain,
|
||||
}])
|
||||
.body(&[$ops(E::elem_type())])
|
||||
.outputs(&[$crate::codegen::Output::Input {
|
||||
elem: E::elem_type(),
|
||||
input: 0,
|
||||
local: 0,
|
||||
}])
|
||||
.compile();
|
||||
|
||||
$crate::kernel::SourceTemplate::new(shader.to_string())
|
||||
}
|
||||
}
|
||||
};
|
||||
(
|
||||
$struct:ident,
|
||||
body $body:expr
|
||||
$ops:expr,
|
||||
scalar $num:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
pub struct Ops<E> {
|
||||
_e: core::marker::PhantomData<E>,
|
||||
}
|
||||
pub struct OpsInplace<E> {
|
||||
_e: core::marker::PhantomData<E>,
|
||||
}
|
||||
|
||||
impl $crate::kernel::StaticKernelSource for $struct {
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
impl<E: $crate::element::WgpuElement> $crate::kernel::StaticKernelSource for Ops<E> {
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
$crate::kernel::UnaryRaw::source().register("body", $body)
|
||||
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
|
||||
.inputs(&[
|
||||
$crate::codegen::Input::Array {
|
||||
elem: E::elem_type(),
|
||||
visibility: $crate::codegen::Visibility::Read,
|
||||
strategy: $crate::codegen::ReadingStrategy::IntoContiguous,
|
||||
},
|
||||
$crate::codegen::Input::Scalar {
|
||||
elem: E::elem_type(),
|
||||
size: $num,
|
||||
},
|
||||
])
|
||||
.body(&[$ops(E::elem_type())])
|
||||
.outputs(&[$crate::codegen::Output::Array {
|
||||
elem: E::elem_type(),
|
||||
local: 0,
|
||||
}])
|
||||
.compile();
|
||||
|
||||
$crate::kernel::SourceTemplate::new(shader.to_string())
|
||||
}
|
||||
}
|
||||
};
|
||||
(
|
||||
$struct:ident,
|
||||
func $func:expr,
|
||||
include $file:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::StaticKernelSource for $struct {
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
impl<E: $crate::element::WgpuElement> $crate::kernel::StaticKernelSource for OpsInplace<E> {
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
$crate::kernel::UnaryRaw::source()
|
||||
.register("body", format!("output[id] = {}(input[id]);", $func))
|
||||
.add_template(include_str!($file))
|
||||
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
|
||||
.inputs(&[
|
||||
$crate::codegen::Input::Array {
|
||||
elem: E::elem_type(),
|
||||
visibility: $crate::codegen::Visibility::ReadWrite,
|
||||
strategy: $crate::codegen::ReadingStrategy::Plain,
|
||||
},
|
||||
$crate::codegen::Input::Scalar {
|
||||
elem: E::elem_type(),
|
||||
size: $num,
|
||||
},
|
||||
])
|
||||
.body(&[$ops(E::elem_type())])
|
||||
.outputs(&[$crate::codegen::Output::Input {
|
||||
elem: E::elem_type(),
|
||||
input: 0,
|
||||
local: 0,
|
||||
}])
|
||||
.compile();
|
||||
|
||||
$crate::kernel::SourceTemplate::new(shader.to_string())
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Creates a unary inplace kernel.
|
||||
#[macro_export]
|
||||
macro_rules! unary_inplace {
|
||||
(
|
||||
$struct:ident,
|
||||
func $func:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::StaticKernelSource for $struct {
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
$crate::kernel::UnaryInplaceRaw::source()
|
||||
.register("body", format!("input[id] = {}(input[id]);", $func))
|
||||
}
|
||||
}
|
||||
};
|
||||
(
|
||||
$struct:ident,
|
||||
body $body:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::StaticKernelSource for $struct {
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
$crate::kernel::UnaryInplaceRaw::source().register("body", $body)
|
||||
}
|
||||
}
|
||||
};
|
||||
(
|
||||
$struct:ident,
|
||||
func $func:expr,
|
||||
include $file:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::StaticKernelSource for $struct {
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
$crate::kernel::UnaryInplaceRaw::source()
|
||||
.register("body", format!("input[id] = {}(input[id]);", $func))
|
||||
.add_template(include_str!($file))
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Execute a unary kernel using the default settings.
|
||||
pub fn unary_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
|
||||
input: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
unary::<K, E, D, WORKGROUP_DEFAULT>(input)
|
||||
}
|
||||
|
||||
/// Execute a unary inplace kernel using the default settings.
|
||||
pub fn unary_inplace_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
|
||||
input: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
unary_inplace::<K, E, D, WORKGROUP_DEFAULT>(input)
|
||||
}
|
||||
|
||||
/// Execute a unary inplace kernel using the provided WORKGROUP.
|
||||
pub fn unary_inplace<
|
||||
/// Launch an unary operation.
|
||||
pub fn unary<K, KI, E, const D: usize>(
|
||||
tensor: WgpuTensor<E, D>,
|
||||
scalars: Option<&[E]>,
|
||||
) -> WgpuTensor<E, D>
|
||||
where
|
||||
K: StaticKernelSource,
|
||||
KI: StaticKernelSource,
|
||||
E: WgpuElement,
|
||||
const D: usize,
|
||||
const WORKGROUP: usize,
|
||||
>(
|
||||
input: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
let num_elems = input.shape.num_elements();
|
||||
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(num_elems, WORKGROUP),
|
||||
);
|
||||
{
|
||||
if !tensor.can_mut() {
|
||||
let num_elems = tensor.shape.num_elements();
|
||||
let buffer = tensor.client.empty(num_elems * core::mem::size_of::<E>());
|
||||
let output = WgpuTensor::new(
|
||||
tensor.client.clone(),
|
||||
tensor.device,
|
||||
tensor.shape.clone(),
|
||||
buffer,
|
||||
);
|
||||
|
||||
input.client.execute(Box::new(kernel), &[&input.handle]);
|
||||
execute_static::<K, E>(
|
||||
&[StaticHandle::new(
|
||||
&tensor.handle,
|
||||
&tensor.strides,
|
||||
&tensor.shape.dims,
|
||||
)],
|
||||
&[StaticHandle::new(
|
||||
&output.handle,
|
||||
&output.strides,
|
||||
&output.shape.dims,
|
||||
)],
|
||||
scalars,
|
||||
tensor.client,
|
||||
);
|
||||
|
||||
input
|
||||
}
|
||||
output
|
||||
} else {
|
||||
execute_static::<KI, E>(
|
||||
&[],
|
||||
&[StaticHandle::new(
|
||||
&tensor.handle,
|
||||
&tensor.strides,
|
||||
&tensor.shape.dims,
|
||||
)],
|
||||
scalars,
|
||||
tensor.client.clone(),
|
||||
);
|
||||
|
||||
/// Execute a unary kernel using the provided WORKGROUP.
|
||||
pub fn unary<K: StaticKernelSource, E: WgpuElement, const D: usize, const WORKGROUP: usize>(
|
||||
input: WgpuTensor<E, D>,
|
||||
) -> WgpuTensor<E, D> {
|
||||
let num_elems = input.shape.num_elements();
|
||||
let buffer = input.client.empty(num_elems * core::mem::size_of::<E>());
|
||||
let mut output = WgpuTensor::new(input.client.clone(), input.device, input.shape, buffer);
|
||||
// Since we don't handle the stride inside the kernel, the output tensor have the same strides
|
||||
// as the input tensor. It might not be in the default format.
|
||||
output.strides = input.strides;
|
||||
|
||||
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(num_elems, WORKGROUP),
|
||||
);
|
||||
input
|
||||
.client
|
||||
.execute(Box::new(kernel), &[&input.handle, &output.handle]);
|
||||
|
||||
output
|
||||
tensor
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::codegen::{Operator, Variable};
|
||||
use crate::tests::{ReferenceBackend, TestBackend};
|
||||
use burn_tensor::{Distribution, Tensor};
|
||||
|
||||
unary!(TestKernel, func "log");
|
||||
unary_inplace!(TestKernelInplace, func "log");
|
||||
unary!(|elem| Operator::Tanh {
|
||||
input: Variable::Input(0, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
});
|
||||
|
||||
#[test]
|
||||
fn unary_should_work_with_multiple_invocations() {
|
||||
let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
|
||||
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
|
||||
|
||||
let actual = unary::<TestKernel, _, 2, 16>(tensor.into_primitive());
|
||||
let expected = tensor_ref.log();
|
||||
let actual = unary::<Ops<f32>, OpsInplace<f32>, f32, 2>(tensor.into_primitive(), None);
|
||||
let expected = tensor_ref.tanh();
|
||||
|
||||
expected.into_data().assert_approx_eq(
|
||||
&Tensor::<TestBackend, 2>::from_primitive(actual).into_data(),
|
||||
|
@ -176,8 +227,8 @@ mod tests {
|
|||
let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
|
||||
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
|
||||
|
||||
let actual = unary_inplace::<TestKernelInplace, _, 2, 16>(tensor.into_primitive());
|
||||
let expected = tensor_ref.log();
|
||||
let actual = unary::<Ops<f32>, OpsInplace<f32>, f32, 2>(tensor.into_primitive(), None);
|
||||
let expected = tensor_ref.tanh();
|
||||
|
||||
expected.into_data().assert_approx_eq(
|
||||
&Tensor::<TestBackend, 2>::from_primitive(actual).into_data(),
|
||||
|
|
|
@ -1,220 +0,0 @@
|
|||
use super::{elemwise_workgroup, KernelSettings, StaticKernelSource, WORKGROUP_DEFAULT};
|
||||
use crate::{compute::StaticKernel, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
|
||||
|
||||
kernel_wgsl!(UnaryScalarRaw, "../template/unary_scalar.wgsl");
|
||||
kernel_wgsl!(
|
||||
UnaryScalarInplaceRaw,
|
||||
"../template/unary_scalar_inplace.wgsl"
|
||||
);
|
||||
|
||||
/// Creates a unary scalar kernel.
|
||||
#[macro_export]
|
||||
macro_rules! unary_scalar {
|
||||
(
|
||||
$struct:ident,
|
||||
ops $ops:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::StaticKernelSource for $struct {
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
$crate::kernel::UnaryScalarRaw::source()
|
||||
.register("body", format!("output[id] = lhs[id] {} rhs;", $ops))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
(
|
||||
$struct:ident,
|
||||
func $func:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::StaticKernelSource for $struct {
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
$crate::kernel::UnaryScalarRaw::source()
|
||||
.register("body", format!("output[id] = {}(lhs[id], rhs);", $func))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
(
|
||||
$struct:ident,
|
||||
func $func:expr,
|
||||
include $file:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::StaticKernelSource for $struct {
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
$crate::kernel::UnaryScalarRaw::source()
|
||||
.register("body", format!("output[id] = {}(lhs[id], rhs);", $func))
|
||||
.add_template(include_str!($file))
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Creates a unary scalar inplace kernel.
|
||||
#[macro_export]
|
||||
macro_rules! unary_scalar_inplace {
|
||||
(
|
||||
$struct:ident,
|
||||
ops $ops:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::StaticKernelSource for $struct {
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
$crate::kernel::UnaryScalarInplaceRaw::source()
|
||||
.register("body", format!("lhs[id] = lhs[id] {} rhs;", $ops))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
(
|
||||
$struct:ident,
|
||||
body $body:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::StaticKernelSource for $struct {
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
$crate::kernel::UnaryScalarInplaceRaw::source().register("body", $body)
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
(
|
||||
$struct:ident,
|
||||
func $func:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::StaticKernelSource for $struct {
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
$crate::kernel::UnaryScalarInplaceRaw::source()
|
||||
.register("body", format!("lhs[id] = {}(lhs[id], rhs);", $func))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
(
|
||||
$struct:ident,
|
||||
func $func:expr,
|
||||
include $file:expr
|
||||
) => {
|
||||
pub struct $struct;
|
||||
|
||||
impl $crate::kernel::StaticKernelSource for $struct {
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
$crate::kernel::UnaryScalarInplaceRaw::source()
|
||||
.register("body", format!("lhs[id] = {}(lhs[id], rhs);", $func))
|
||||
.add_template(include_str!($file))
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Execute a unary scalar kernel using the default settings.
|
||||
pub fn unary_scalar_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
scalar: E,
|
||||
) -> WgpuTensor<E, D> {
|
||||
unary_scalar::<K, E, D, WORKGROUP_DEFAULT>(lhs, scalar)
|
||||
}
|
||||
|
||||
/// Execute a unary scalar kernel using the provided WORKGROUP.
|
||||
pub fn unary_scalar<
|
||||
K: StaticKernelSource,
|
||||
E: WgpuElement,
|
||||
const D: usize,
|
||||
const WORKGROUP: usize,
|
||||
>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
scalar: E,
|
||||
) -> WgpuTensor<E, D> {
|
||||
let num_elems = lhs.shape.num_elements();
|
||||
let buffer = lhs.client.empty(num_elems * core::mem::size_of::<E>());
|
||||
let output = WgpuTensor::new(lhs.client.clone(), lhs.device, lhs.shape, buffer);
|
||||
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(num_elems, WORKGROUP),
|
||||
);
|
||||
let rhs_handle = lhs.client.create(E::as_bytes(&[scalar]));
|
||||
|
||||
lhs.client.execute(
|
||||
Box::new(kernel),
|
||||
&[&lhs.handle, &rhs_handle, &output.handle],
|
||||
);
|
||||
|
||||
output
|
||||
}
|
||||
|
||||
/// Execute a unary scalar inplace kernel using the default settings.
|
||||
pub fn unary_scalar_inplace_default<K: StaticKernelSource, E: WgpuElement, const D: usize>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
scalar: E,
|
||||
) -> WgpuTensor<E, D> {
|
||||
unary_scalar_inplace::<K, E, D, WORKGROUP_DEFAULT>(lhs, scalar)
|
||||
}
|
||||
|
||||
/// Execute a unary scalar inplace kernel using the provided WORKGROUP.
|
||||
pub fn unary_scalar_inplace<
|
||||
K: StaticKernelSource,
|
||||
E: WgpuElement,
|
||||
const D: usize,
|
||||
const WORKGROUP: usize,
|
||||
>(
|
||||
lhs: WgpuTensor<E, D>,
|
||||
scalar: E,
|
||||
) -> WgpuTensor<E, D> {
|
||||
let num_elems = lhs.shape.num_elements();
|
||||
let kernel = StaticKernel::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>::new(
|
||||
elemwise_workgroup(num_elems, WORKGROUP),
|
||||
);
|
||||
let rhs_handle = lhs.client.create(E::as_bytes(&[scalar]));
|
||||
|
||||
lhs.client
|
||||
.execute(Box::new(kernel), &[&lhs.handle, &rhs_handle]);
|
||||
|
||||
lhs
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::tests::{ReferenceBackend, TestBackend};
|
||||
use burn_tensor::{Distribution, Tensor};
|
||||
|
||||
unary_scalar!(TestKernel, ops "*");
|
||||
unary_scalar_inplace!(TestKernelInplace, ops "*");
|
||||
|
||||
#[test]
|
||||
fn unary_scalar_should_work_with_multiple_invocations() {
|
||||
let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
|
||||
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
|
||||
|
||||
let actual = unary_scalar::<TestKernel, _, 2, 16>(tensor.into_primitive(), 5.0);
|
||||
let expected = tensor_ref.mul_scalar(5.0);
|
||||
|
||||
expected.into_data().assert_approx_eq(
|
||||
&Tensor::<TestBackend, 2>::from_primitive(actual).into_data(),
|
||||
3,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unary_scalar_inplace_should_work_with_multiple_invocations() {
|
||||
let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
|
||||
let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
|
||||
|
||||
let actual =
|
||||
unary_scalar_inplace::<TestKernelInplace, _, 2, 16>(tensor.into_primitive(), 5.0);
|
||||
let expected = tensor_ref.mul_scalar(5.0);
|
||||
|
||||
expected.into_data().assert_approx_eq(
|
||||
&Tensor::<TestBackend, 2>::from_primitive(actual).into_data(),
|
||||
3,
|
||||
);
|
||||
}
|
||||
}
|
|
@ -15,6 +15,8 @@ pub mod kernel;
|
|||
/// Tensor module.
|
||||
pub mod tensor;
|
||||
|
||||
pub(crate) mod codegen;
|
||||
|
||||
mod element;
|
||||
pub use element::{FloatElement, IntElement};
|
||||
|
||||
|
|
|
@ -1,10 +1,8 @@
|
|||
use burn_tensor::ops::{ActivationOps, FloatTensor};
|
||||
|
||||
use crate::{
|
||||
element::{FloatElement, IntElement},
|
||||
kernel::{unary_default, unary_inplace_default},
|
||||
unary, unary_inplace, GraphicsApi, Wgpu,
|
||||
GraphicsApi, Wgpu,
|
||||
};
|
||||
use burn_tensor::ops::ActivationOps;
|
||||
|
||||
impl<G, F, I> ActivationOps<Wgpu<G, F, I>> for Wgpu<G, F, I>
|
||||
where
|
||||
|
@ -12,14 +10,4 @@ where
|
|||
F: FloatElement,
|
||||
I: IntElement,
|
||||
{
|
||||
fn relu<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
unary!(Relu, body "output[id] = max(input[id], 0.0);");
|
||||
unary_inplace!(ReluInplace, body "input[id] = max(input[id], 0.0);");
|
||||
|
||||
if tensor.can_mut() {
|
||||
return unary_inplace_default::<ReluInplace, F, D>(tensor);
|
||||
}
|
||||
|
||||
unary_default::<Relu, F, D>(tensor)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use super::numeric;
|
||||
use crate::codegen::{Elem, Operator, Variable};
|
||||
#[cfg(not(feature = "autotune"))]
|
||||
use crate::kernel::matmul::init_matmul_output;
|
||||
#[cfg(feature = "autotune")]
|
||||
|
@ -8,18 +9,14 @@ use crate::kernel::matmul::vec4::matmul_tiling_2d_vec4;
|
|||
use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
|
||||
#[cfg(not(feature = "autotune"))]
|
||||
use crate::kernel::reduce::init_reduce_output;
|
||||
use crate::kernel::{
|
||||
self, reduce, unary_default, unary_inplace_default, unary_scalar_default,
|
||||
unary_scalar_inplace_default,
|
||||
};
|
||||
use crate::{unary, unary_inplace, unary_scalar, FloatElement, GraphicsApi, IntElement, Wgpu};
|
||||
use crate::{unary_scalar_inplace, WgpuDevice};
|
||||
use crate::kernel::{self, reduce};
|
||||
use crate::WgpuDevice;
|
||||
use crate::{unary, FloatElement, GraphicsApi, IntElement, Wgpu};
|
||||
use burn_tensor::ops::{
|
||||
BoolTensor, Device, FloatElem, FloatTensor, FullPrecisionBackend, IntTensor,
|
||||
};
|
||||
use burn_tensor::{ops::TensorOps, Data, Distribution, Shape};
|
||||
use burn_tensor::{ElementConversion, Reader};
|
||||
|
||||
use std::ops::Range;
|
||||
|
||||
impl<G, F, I> TensorOps<Wgpu<G, F, I>> for Wgpu<G, F, I>
|
||||
|
@ -357,122 +354,115 @@ where
|
|||
kernel::cast(tensor)
|
||||
}
|
||||
|
||||
fn exp<const D: usize>(lhs: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
unary!(Exp, func "exp");
|
||||
unary_inplace!(ExpInplace, func "exp");
|
||||
|
||||
if lhs.can_mut() {
|
||||
return unary_inplace_default::<ExpInplace, F, D>(lhs);
|
||||
}
|
||||
|
||||
unary_default::<Exp, F, D>(lhs)
|
||||
fn exp<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
unary!(
|
||||
operator: |elem: Elem| Operator::Exp {
|
||||
input: Variable::Input(0, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
input: tensor,
|
||||
elem: F
|
||||
)
|
||||
}
|
||||
|
||||
fn log<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
unary!(Log, func "log");
|
||||
unary_inplace!(LogInplace, func "log");
|
||||
|
||||
if tensor.can_mut() {
|
||||
return unary_inplace_default::<LogInplace, F, D>(tensor);
|
||||
}
|
||||
|
||||
unary_default::<Log, F, D>(tensor)
|
||||
unary!(
|
||||
operator: |elem: Elem| Operator::Log {
|
||||
input: Variable::Input(0, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
input: tensor,
|
||||
elem: F
|
||||
)
|
||||
}
|
||||
|
||||
fn log1p<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
unary!(Log1p, body "output[id] = log(1.0 + input[id]);");
|
||||
unary_inplace!(Log1pInplace, body "input[id] = log(1.0 + input[id]);");
|
||||
|
||||
if tensor.can_mut() {
|
||||
return unary_inplace_default::<Log1pInplace, F, D>(tensor);
|
||||
}
|
||||
|
||||
unary_default::<Log1p, F, D>(tensor)
|
||||
unary!(
|
||||
operator: |elem: Elem| Operator::Log1p {
|
||||
input: Variable::Input(0, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
input: tensor,
|
||||
elem: F
|
||||
)
|
||||
}
|
||||
|
||||
fn powf<const D: usize>(lhs: FloatTensor<Self, D>, rhs: f32) -> FloatTensor<Self, D> {
|
||||
unary_scalar!(Powf, func "powf", include "../template/powf.wgsl");
|
||||
unary_scalar_inplace!(PowfInplace, func "powf", include "../template/powf.wgsl");
|
||||
|
||||
if lhs.can_mut() {
|
||||
return unary_scalar_inplace_default::<PowfInplace, F, D>(lhs, rhs.elem());
|
||||
}
|
||||
|
||||
unary_scalar_default::<Powf, F, D>(lhs, rhs.elem())
|
||||
unary!(
|
||||
operator: |elem: Elem| Operator::Powf {
|
||||
lhs: Variable::Input(0, elem),
|
||||
rhs: Variable::Scalar(0, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
input: lhs; rhs.elem(),
|
||||
elem: F
|
||||
)
|
||||
}
|
||||
|
||||
fn sqrt<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
unary!(Sqrt, func "sqrt");
|
||||
unary_inplace!(SqrtInplace, func "sqrt");
|
||||
|
||||
if tensor.can_mut() {
|
||||
return unary_inplace_default::<SqrtInplace, F, D>(tensor);
|
||||
}
|
||||
|
||||
unary_default::<Sqrt, F, D>(tensor)
|
||||
unary!(
|
||||
operator: |elem: Elem| Operator::Sqrt {
|
||||
input: Variable::Input(0, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
input: tensor,
|
||||
elem: F
|
||||
)
|
||||
}
|
||||
|
||||
fn abs<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
unary!(Abs, func "abs");
|
||||
unary_inplace!(AbsInplace, func "abs");
|
||||
|
||||
if tensor.can_mut() {
|
||||
return unary_inplace_default::<AbsInplace, F, D>(tensor);
|
||||
}
|
||||
|
||||
unary_default::<Abs, F, D>(tensor)
|
||||
unary!(
|
||||
operator: |elem: Elem| Operator::Abs {
|
||||
input: Variable::Input(0, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
input: tensor,
|
||||
elem: F
|
||||
)
|
||||
}
|
||||
|
||||
fn cos<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
unary!(Cos, func "cos");
|
||||
unary_inplace!(CosInplace, func "cos");
|
||||
|
||||
if tensor.can_mut() {
|
||||
return unary_inplace_default::<CosInplace, F, D>(tensor);
|
||||
}
|
||||
|
||||
unary_default::<Cos, F, D>(tensor)
|
||||
unary!(
|
||||
operator: |elem: Elem| Operator::Cos {
|
||||
input: Variable::Input(0, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
input: tensor,
|
||||
elem: F
|
||||
)
|
||||
}
|
||||
|
||||
fn sin<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
unary!(Sin, func "sin");
|
||||
unary_inplace!(SinInplace, func "sin");
|
||||
|
||||
if tensor.can_mut() {
|
||||
return unary_inplace_default::<SinInplace, F, D>(tensor);
|
||||
}
|
||||
|
||||
unary_default::<Sin, F, D>(tensor)
|
||||
unary!(
|
||||
operator: |elem: Elem| Operator::Sin {
|
||||
input: Variable::Input(0, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
input: tensor,
|
||||
elem: F
|
||||
)
|
||||
}
|
||||
|
||||
fn tanh<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
// Metal has a weird numerical behaviour with tanh which require a new function
|
||||
#[cfg(target_os = "macos")]
|
||||
unary!(Tanh, func "safe_tanh", include "../template/safe_tanh.wgsl");
|
||||
#[cfg(target_os = "macos")]
|
||||
unary_inplace!(TanhInplace, func "safe_tanh", include "../template/safe_tanh.wgsl");
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
unary!(Tanh, func "tanh");
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
unary_inplace!(TanhInplace, func "tanh");
|
||||
|
||||
if tensor.can_mut() {
|
||||
return unary_inplace_default::<TanhInplace, F, D>(tensor);
|
||||
}
|
||||
|
||||
unary_default::<Tanh, F, D>(tensor)
|
||||
unary!(
|
||||
operator: |elem: Elem| Operator::Tanh {
|
||||
input: Variable::Input(0, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
input: tensor,
|
||||
elem: F
|
||||
)
|
||||
}
|
||||
|
||||
fn erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
unary!(Erf, func "erf", include "../template/erf.wgsl");
|
||||
unary_inplace!(ErfInplace, func "erf", include "../template/erf.wgsl");
|
||||
|
||||
if tensor.can_mut() {
|
||||
return unary_inplace_default::<ErfInplace, F, D>(tensor);
|
||||
}
|
||||
|
||||
unary_default::<Erf, F, D>(tensor)
|
||||
unary!(
|
||||
operator: |elem: Elem| Operator::Erf {
|
||||
input: Variable::Input(0, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
input: tensor,
|
||||
elem: F
|
||||
)
|
||||
}
|
||||
|
||||
fn cat<const D: usize>(tensors: Vec<FloatTensor<Self, D>>, dim: usize) -> FloatTensor<Self, D> {
|
||||
|
@ -491,20 +481,6 @@ where
|
|||
kernel::cast(tensor)
|
||||
}
|
||||
|
||||
fn clamp_min<const D: usize>(
|
||||
tensor: FloatTensor<Self, D>,
|
||||
min: FloatElem<Self>,
|
||||
) -> FloatTensor<Self, D> {
|
||||
kernel::clamp_min(tensor, min)
|
||||
}
|
||||
|
||||
fn clamp_max<const D: usize>(
|
||||
tensor: FloatTensor<Self, D>,
|
||||
max: FloatElem<Self>,
|
||||
) -> FloatTensor<Self, D> {
|
||||
kernel::clamp_max(tensor, max)
|
||||
}
|
||||
|
||||
fn clamp<const D: usize>(
|
||||
tensor: FloatTensor<Self, D>,
|
||||
min: FloatElem<Self>,
|
||||
|
@ -516,14 +492,14 @@ where
|
|||
fn recip<const D: usize>(
|
||||
tensor: FloatTensor<Wgpu<G, F, I>, D>,
|
||||
) -> FloatTensor<Wgpu<G, F, I>, D> {
|
||||
unary!(Recip, func "1.0 /");
|
||||
unary_inplace!(RecipInplace, func "1.0 /");
|
||||
|
||||
if tensor.can_mut() {
|
||||
return unary_inplace_default::<RecipInplace, F, D>(tensor);
|
||||
}
|
||||
|
||||
unary_default::<Recip, F, D>(tensor)
|
||||
unary!(
|
||||
operator: |elem: Elem| Operator::Recip {
|
||||
input: Variable::Input(0, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
input: tensor,
|
||||
elem: F
|
||||
)
|
||||
}
|
||||
|
||||
fn repeat<const D: usize>(
|
||||
|
|
|
@ -1,10 +1,10 @@
|
|||
use super::numeric;
|
||||
|
||||
use crate::codegen::{Elem, Operator, Variable};
|
||||
use crate::kernel::reduce::{self, init_reduce_output};
|
||||
use crate::kernel::{unary_default, unary_inplace_default};
|
||||
use crate::{
|
||||
element::{FloatElement, IntElement},
|
||||
kernel, unary, unary_inplace, GraphicsApi, Wgpu,
|
||||
kernel, unary, GraphicsApi, Wgpu,
|
||||
};
|
||||
use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
|
||||
|
||||
|
@ -280,20 +280,6 @@ where
|
|||
kernel::reduce::argmin(tensor, dim)
|
||||
}
|
||||
|
||||
fn int_clamp_min<const D: usize>(
|
||||
tensor: IntTensor<Self, D>,
|
||||
min: IntElem<Self>,
|
||||
) -> IntTensor<Self, D> {
|
||||
kernel::clamp_min(tensor, min)
|
||||
}
|
||||
|
||||
fn int_clamp_max<const D: usize>(
|
||||
tensor: IntTensor<Self, D>,
|
||||
max: IntElem<Self>,
|
||||
) -> IntTensor<Self, D> {
|
||||
kernel::clamp_max(tensor, max)
|
||||
}
|
||||
|
||||
fn int_clamp<const D: usize>(
|
||||
tensor: IntTensor<Self, D>,
|
||||
min: IntElem<Self>,
|
||||
|
@ -303,14 +289,14 @@ where
|
|||
}
|
||||
|
||||
fn int_abs<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, D> {
|
||||
unary!(IntAbs, func "abs");
|
||||
unary_inplace!(IntAbsInplace, func "abs");
|
||||
|
||||
if tensor.can_mut() {
|
||||
return unary_inplace_default::<IntAbsInplace, I, D>(tensor);
|
||||
}
|
||||
|
||||
unary_default::<IntAbs, I, D>(tensor)
|
||||
unary!(
|
||||
operator: |elem: Elem| Operator::Abs {
|
||||
input: Variable::Input(0, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
input: tensor,
|
||||
elem: I
|
||||
)
|
||||
}
|
||||
|
||||
fn int_into_float<const D: usize>(tensor: IntTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
|
|
|
@ -1,11 +1,8 @@
|
|||
use crate::codegen::{Elem, Operator, Variable};
|
||||
use crate::compute::{compute_client, WgpuComputeClient};
|
||||
use crate::kernel::{
|
||||
binary_elemwise_default, binary_elemwise_inplace_default, unary_scalar_default,
|
||||
unary_scalar_inplace_default,
|
||||
};
|
||||
use crate::kernel::{binary_elemwise_default, binary_elemwise_inplace_default};
|
||||
use crate::{
|
||||
binary_elemwise, binary_elemwise_inplace, element::WgpuElement, tensor::WgpuTensor,
|
||||
unary_scalar, unary_scalar_inplace,
|
||||
binary_elemwise, binary_elemwise_inplace, element::WgpuElement, tensor::WgpuTensor, unary,
|
||||
};
|
||||
use crate::{GraphicsApi, WgpuDevice};
|
||||
use burn_tensor::{Element, ElementConversion, Shape};
|
||||
|
@ -28,8 +25,14 @@ pub fn full_device<E: WgpuElement + Element, const D: usize>(
|
|||
) -> WgpuTensor<E, D> {
|
||||
let empty = empty_device(client, device, shape);
|
||||
|
||||
unary_scalar_inplace!(Full, body "lhs[id] = rhs;");
|
||||
unary_scalar_inplace_default::<Full, E, D>(empty, value)
|
||||
unary!(
|
||||
operator: |elem: Elem| Operator::AssignLocal {
|
||||
input: Variable::Scalar(0, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
input: empty; value,
|
||||
elem: E
|
||||
)
|
||||
}
|
||||
|
||||
pub fn zeros<G: GraphicsApi, E: WgpuElement + Element, const D: usize>(
|
||||
|
@ -98,14 +101,15 @@ pub fn add_scalar<E: WgpuElement, const D: usize>(
|
|||
lhs: WgpuTensor<E, D>,
|
||||
rhs: E,
|
||||
) -> WgpuTensor<E, D> {
|
||||
unary_scalar!(AddScalar, ops "+");
|
||||
unary_scalar_inplace!(AddScalarInplace, ops "+");
|
||||
|
||||
if lhs.can_mut() {
|
||||
return unary_scalar_inplace_default::<AddScalarInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
|
||||
unary_scalar_default::<AddScalar, E, D>(lhs, rhs)
|
||||
unary!(
|
||||
operator: |elem: Elem| Operator::Add {
|
||||
lhs: Variable::Input(0, elem),
|
||||
rhs: Variable::Scalar(0, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
input: lhs; rhs,
|
||||
elem: E
|
||||
)
|
||||
}
|
||||
|
||||
pub fn sub<E: WgpuElement, const D: usize>(
|
||||
|
@ -126,14 +130,15 @@ pub fn sub_scalar<E: WgpuElement, const D: usize>(
|
|||
lhs: WgpuTensor<E, D>,
|
||||
rhs: E,
|
||||
) -> WgpuTensor<E, D> {
|
||||
unary_scalar!(SubScalar, ops "-");
|
||||
unary_scalar_inplace!(SubScalarInplace, ops "-");
|
||||
|
||||
if lhs.can_mut() {
|
||||
return unary_scalar_inplace_default::<SubScalarInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
|
||||
unary_scalar_default::<SubScalar, E, D>(lhs, rhs)
|
||||
unary!(
|
||||
operator: |elem: Elem| Operator::Sub {
|
||||
lhs: Variable::Input(0, elem),
|
||||
rhs: Variable::Scalar(0, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
input: lhs; rhs,
|
||||
elem: E
|
||||
)
|
||||
}
|
||||
|
||||
pub fn mul<E: WgpuElement, const D: usize>(
|
||||
|
@ -158,14 +163,15 @@ pub fn mul_scalar<E: WgpuElement, const D: usize>(
|
|||
lhs: WgpuTensor<E, D>,
|
||||
rhs: E,
|
||||
) -> WgpuTensor<E, D> {
|
||||
unary_scalar!(MulScalar, ops "*");
|
||||
unary_scalar_inplace!(MulScalarInplace, ops "*");
|
||||
|
||||
if lhs.can_mut() {
|
||||
return unary_scalar_inplace_default::<MulScalarInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
|
||||
unary_scalar_default::<MulScalar, E, D>(lhs, rhs)
|
||||
unary!(
|
||||
operator: |elem: Elem| Operator::Mul {
|
||||
lhs: Variable::Input(0, elem),
|
||||
rhs: Variable::Scalar(0, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
input: lhs; rhs,
|
||||
elem: E
|
||||
)
|
||||
}
|
||||
|
||||
pub fn div<E: WgpuElement, const D: usize>(
|
||||
|
@ -186,12 +192,13 @@ pub fn div_scalar<E: WgpuElement, const D: usize>(
|
|||
lhs: WgpuTensor<E, D>,
|
||||
rhs: E,
|
||||
) -> WgpuTensor<E, D> {
|
||||
unary_scalar!(DivScalar, ops "/");
|
||||
unary_scalar_inplace!(DivScalarInplace, ops "/");
|
||||
|
||||
if lhs.can_mut() {
|
||||
return unary_scalar_inplace_default::<DivScalarInplace, E, D>(lhs, rhs);
|
||||
}
|
||||
|
||||
unary_scalar_default::<DivScalar, E, D>(lhs, rhs)
|
||||
unary!(
|
||||
operator: |elem: Elem| Operator::Div {
|
||||
lhs: Variable::Input(0, elem),
|
||||
rhs: Variable::Scalar(0, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
input: lhs; rhs,
|
||||
elem: E
|
||||
)
|
||||
}
|
||||
|
|
|
@ -1,25 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read> input: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read_write> output: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read> min_value: {{ elem }};
|
||||
|
||||
@group(0)
|
||||
@binding(3)
|
||||
var<storage, read> max_value: {{ elem }};
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x;
|
||||
output[id] = clamp(input[id], min_value, max_value);
|
||||
}
|
|
@ -1,21 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read_write> input: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> min_value: {{ elem }};
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read> max_value: {{ elem }};
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x;
|
||||
input[id] = clamp(input[id], min_value, max_value);
|
||||
}
|
|
@ -1,25 +0,0 @@
|
|||
/// An approximation of the error function: https://en.wikipedia.org/wiki/Error_function#Numerical_approximations
|
||||
///
|
||||
/// > (maximum error: 1.5×10−7)
|
||||
/// > All of these approximations are valid for x ≥ 0. To use these approximations for negative x, use the fact that erf x is an odd function, so erf x = −erf(−x).
|
||||
fn erf_positive(x: {{ elem }}) -> {{ elem }} {
|
||||
let p = 0.3275911;
|
||||
let a1 = 0.254829592;
|
||||
let a2 = -0.284496736;
|
||||
let a3 = 1.421413741;
|
||||
let a4 = -1.453152027;
|
||||
let a5 = 1.061405429;
|
||||
|
||||
let t = 1.0 / (1.0 + p * abs(x));
|
||||
let tmp = ((((a5 * t + a4) * t) + a3) * t + a2) * t + a1;
|
||||
|
||||
return 1.0 - (tmp * t * exp(-x * x));
|
||||
}
|
||||
|
||||
fn erf(x: {{ elem }}) -> {{ elem }} {
|
||||
if (x < 0.0) {
|
||||
return -1.0 * erf_positive(-1.0 * x);
|
||||
}
|
||||
|
||||
return erf_positive(x);
|
||||
}
|
|
@ -1,14 +0,0 @@
|
|||
fn powf(lhs: {{ elem }}, rhs: {{ elem }}) -> {{ elem }} {
|
||||
let modulo = rhs % 2.0;
|
||||
|
||||
if (modulo == 0.0) {
|
||||
// Even number
|
||||
return pow(abs(lhs), rhs);
|
||||
} else if (modulo == 1.0 && lhs < 0.0) {
|
||||
// Odd number
|
||||
return -1.0 * pow(-1.0 * lhs, rhs);
|
||||
} else {
|
||||
// Float number
|
||||
return pow(lhs, rhs);
|
||||
}
|
||||
}
|
|
@ -1,8 +0,0 @@
|
|||
/// Metal has a weird numerical behaviour with tanh for inputs over 43.0
|
||||
fn safe_tanh(x: {{ elem }}) -> {{ elem }} {
|
||||
if x > 43.0 {
|
||||
return 1.0;
|
||||
} else {
|
||||
return tanh(x);
|
||||
}
|
||||
}
|
|
@ -1,17 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read> input: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read_write> output: array<{{ elem }}>;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x;
|
||||
{{ body }}
|
||||
}
|
|
@ -1,13 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read_write> input: array<{{ elem }}>;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x;
|
||||
{{ body }}
|
||||
}
|
|
@ -1,21 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read> lhs: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> rhs: {{ elem }};
|
||||
|
||||
@group(0)
|
||||
@binding(2)
|
||||
var<storage, read_write> output: array<{{ elem }}>;
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x;
|
||||
{{ body }}
|
||||
}
|
|
@ -1,17 +0,0 @@
|
|||
@group(0)
|
||||
@binding(0)
|
||||
var<storage, read_write> lhs: array<{{ elem }}>;
|
||||
|
||||
@group(0)
|
||||
@binding(1)
|
||||
var<storage, read> rhs: {{ elem }};
|
||||
|
||||
@compute
|
||||
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
|
||||
fn main(
|
||||
@builtin(global_invocation_id) global_id: vec3<u32>,
|
||||
@builtin(num_workgroups) num_workgroups: vec3<u32>,
|
||||
) {
|
||||
let id = global_id.y * (num_workgroups.x * {{ workgroup_size_x }}u) + global_id.x;
|
||||
{{ body }}
|
||||
}
|
|
@ -1,8 +1,9 @@
|
|||
use crate::codegen::{Elem, Operator, Variable};
|
||||
use crate::element::WgpuElement;
|
||||
use crate::{
|
||||
compute::{WgpuComputeClient, WgpuHandle},
|
||||
unary, WgpuDevice,
|
||||
};
|
||||
use crate::{element::WgpuElement, kernel::unary_default};
|
||||
use burn_tensor::Shape;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
|
@ -96,8 +97,14 @@ impl<E: WgpuElement, const D: usize> WgpuTensor<E, D> {
|
|||
// slowdowns.
|
||||
//
|
||||
// The solution is just to use a simple unary compute shader.
|
||||
unary!(CopyBuffer, body "output[id] = input[id];");
|
||||
unary_default::<CopyBuffer, E, D>(self.clone())
|
||||
unary!(
|
||||
operator: |elem: Elem| Operator::AssignLocal {
|
||||
input: Variable::Input(0, elem),
|
||||
out: Variable::Local(0, elem),
|
||||
},
|
||||
input: self.clone(),
|
||||
elem: E
|
||||
)
|
||||
}
|
||||
|
||||
/// Check if the tensor is safe to mutate.
|
||||
|
|
Loading…
Reference in New Issue