[Refactor] Just-In-Time Compilation Pipeline (#1313)

This commit is contained in:
Nathaniel Simard 2024-02-16 14:45:59 -05:00 committed by GitHub
parent 24287237d1
commit 843dd492c2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
37 changed files with 2633 additions and 1710 deletions

View File

@ -0,0 +1,261 @@
use super::dialect::gpu;
use crate::codegen::dialect::gpu::{
Binding, ComputeShader, Elem, Item, Location, Variable, Vectorization, Visibility,
WorkgroupSize,
};
/// The compilation struct allows you to create a [compute shader](ComputeShader) based on
/// [compilation info](CompilationInfo) and [compilation settings](CompilationSettings).
#[derive(Clone)]
pub struct Compilation {
info: CompilationInfo,
input_bindings: Vec<Binding>,
output_bindings: Vec<Binding>,
named_bindings: Vec<(String, Binding)>,
}
/// The information necessary to compile a [compute shader](ComputeShader).
#[derive(Clone)]
pub struct CompilationInfo {
pub inputs: Vec<InputInfo>,
pub outputs: Vec<OutputInfo>,
pub scope: gpu::Scope,
pub mappings: Vec<InplaceMapping>,
}
/// Simply indicate the output that can be replaced by the input.
#[derive(new, Clone, Copy)]
pub struct InplaceMapping {
/// Input position.
pub pos_input: usize,
/// Output position.
pub pos_output: usize,
}
#[derive(Default)]
pub struct CompilationSettings {
vectorization: Vectorization,
inplace_available: bool,
workgroup_size: WorkgroupSize,
}
impl CompilationSettings {
/// Compile the shader with vectorization enabled.
#[allow(dead_code)]
pub fn vectorize(mut self, vectorization: Vectorization) -> Self {
self.vectorization = vectorization;
self
}
/// Compile the shader with inplace enabled.
///
/// Notes:
///
/// This won't guarantee that the shader will use input arrays as outputs, since it is only
/// possible when [inplace mappings](InplaceMapping) are provided as [compilation info](CompilationInfo)
pub fn inplace(mut self, available: bool) -> Self {
self.inplace_available = available;
self
}
/// Set the grid size.
#[allow(dead_code)] // Only used for fusion for now.
pub fn workgroup_size(mut self, workgroup_size: WorkgroupSize) -> Self {
self.workgroup_size = workgroup_size;
self
}
}
/// Information related to an input.
#[derive(Clone)]
pub enum InputInfo {
Array { item: Item, visibility: Visibility },
Scalar { elem: Elem, size: usize },
}
/// Information related to an output.
#[derive(Clone)]
pub enum OutputInfo {
/// Write the local variable to a new array.
///
/// This will create a new binding in the [compute shader](ComputeShader).
Array { item: Item, local: u16 },
/// Write the local variable to an existing input binding.
Input { item: Item, input: u16, local: u16 },
}
impl Compilation {
/// Starts a new compilation.
pub fn new(info: CompilationInfo) -> Self {
Self {
info,
input_bindings: Default::default(),
output_bindings: Default::default(),
named_bindings: Default::default(),
}
}
/// Performs the compilation with the provided [settings](CompilationSettings).
pub fn compile(mut self, settings: CompilationSettings) -> ComputeShader {
self.info.scope.vectorize(settings.vectorization);
self.register_inputs(&settings);
self.register_outputs(&settings);
let inputs = self.input_bindings;
let outputs = self.output_bindings;
let mut named = Vec::with_capacity(2);
named.push((
"info".to_string(),
Binding {
item: Item::Scalar(Elem::UInt),
visibility: Visibility::Read,
location: Location::Storage,
size: None, // We avoid putting the length here since it will force a new kernel
// for each tensor rank.
},
));
for (name, binding) in self.named_bindings.into_iter() {
named.push((name, binding));
}
ComputeShader {
inputs,
outputs,
named,
workgroup_size: settings.workgroup_size,
body: self.info.scope,
num_workgroups: true,
global_invocation_id: true,
}
}
fn register_inputs(&mut self, settings: &CompilationSettings) {
for input in self.info.inputs.drain(..) {
match input {
InputInfo::Array { item, visibility } => {
let item = item.vectorize(settings.vectorization);
self.input_bindings.push(Binding {
item: bool_item(item),
visibility,
location: Location::Storage,
size: None,
});
}
InputInfo::Scalar { elem, size } => {
let elem = bool_elem(elem);
self.named_bindings.push((
format!("scalars_{}", elem),
Binding {
item: Item::Scalar(elem),
visibility: Visibility::Read,
location: Location::Storage,
size: Some(size),
},
));
}
}
}
}
fn register_outputs(&mut self, settings: &CompilationSettings) {
let mut index = 0;
if settings.inplace_available {
let mut mappings = Vec::new();
core::mem::swap(&mut self.info.mappings, &mut mappings);
for mapping in mappings {
self.register_inplace_mapping(mapping);
}
}
for array in self.info.outputs.drain(..) {
match array {
OutputInfo::Array { item, local } => {
let item = item.vectorize(settings.vectorization);
let elem_adapted = bool_item(item);
self.output_bindings.push(Binding {
item: elem_adapted,
visibility: Visibility::ReadWrite,
location: Location::Storage,
size: None,
});
self.info.scope.write_global(
Variable::Local(local, item, self.info.scope.depth),
Variable::GlobalOutputArray(index, elem_adapted),
);
index += 1;
}
OutputInfo::Input { item, input, local } => {
let item = item.vectorize(settings.vectorization);
self.info.scope.write_global(
Variable::Local(local, item, self.info.scope.depth),
Variable::GlobalInputArray(input, bool_item(item)),
);
}
}
}
}
fn register_inplace_mapping(&mut self, mapping: InplaceMapping) {
let output = match self.info.outputs.get_mut(mapping.pos_output) {
Some(output) => output,
None => return, // No output to update.
};
let (item, local) = match output {
OutputInfo::Array { item, local } => (item, local),
OutputInfo::Input {
item: _,
input: _,
local: _,
} => return, // Output already updated.
};
let item = match self.input_bindings.get_mut(mapping.pos_input) {
Some(binding) => {
// Update input visibility.
binding.visibility = Visibility::ReadWrite;
// Inputs modified inplace should be read without any specified layout.
self.info
.scope
.update_read(mapping.pos_input as u16, gpu::ReadingStrategy::Plain);
// Use the same item as the input.
//
// The output can be different (i.e inplace boolean operations on float bindings).
binding.item
}
None => *item,
};
// Update the output.
*output = OutputInfo::Input {
item,
input: mapping.pos_input as u16,
local: *local,
};
}
}
fn bool_item(ty: Item) -> Item {
match ty {
Item::Vec4(elem) => Item::Vec4(bool_elem(elem)),
Item::Vec3(elem) => Item::Vec3(bool_elem(elem)),
Item::Vec2(elem) => Item::Vec2(bool_elem(elem)),
Item::Scalar(elem) => Item::Scalar(bool_elem(elem)),
}
}
fn bool_elem(elem: Elem) -> Elem {
match elem {
// U32 are used for bool tensors
Elem::Bool => Elem::UInt,
_ => elem,
}
}

View File

@ -0,0 +1,57 @@
use super::{
gpu, Elem, Item, Metadata, Operator, ReadGlobalAlgo, ReadGlobalWithLayoutAlgo, Scope, Variable,
};
use crate::codegen::dialect::gpu::BinaryOperator;
impl ReadGlobalAlgo {
pub fn expand(self, scope: &mut Scope) {
scope.register(Operator::Index(BinaryOperator {
lhs: self.global,
rhs: Variable::Id,
out: self.out,
}));
}
}
impl ReadGlobalWithLayoutAlgo {
pub fn expand(self, scope: &mut Scope) {
let out = self.out;
let tensor = self.global;
let layout = self.layout;
let index_item_ty = Item::Scalar(Elem::UInt);
let index_local = scope.create_local(index_item_ty);
let zero: Variable = 0u32.into();
let id = Variable::Id;
let offset: Variable = match self.global.item() {
Item::Vec4(_) => 4u32,
Item::Vec3(_) => 3u32,
Item::Vec2(_) => 2u32,
Item::Scalar(_) => 1u32,
}
.into();
gpu!(scope, index_local = zero);
gpu!(
scope,
range(zero, Variable::Rank).for_each(|i, scope| {
let stride = scope.create_local(index_item_ty);
let stride_layout = scope.create_local(index_item_ty);
let shape = scope.create_local(index_item_ty);
let tmp = scope.create_local(index_item_ty);
gpu!(scope, stride = stride(tensor, i));
gpu!(scope, shape = shape(tensor, i));
gpu!(scope, stride_layout = stride(layout, i));
gpu!(scope, tmp = id * offset);
gpu!(scope, tmp = tmp / stride_layout);
gpu!(scope, tmp = tmp % shape);
gpu!(scope, tmp = tmp * stride);
gpu!(scope, index_local = index_local + tmp);
})
);
gpu!(scope, index_local = index_local / offset);
gpu!(scope, out = tensor[index_local]);
}
}

View File

@ -1,7 +0,0 @@
use super::Operation;
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize, new)]
pub struct Body {
pub operators: Vec<Operation>,
}

View File

@ -0,0 +1,260 @@
use super::Variable;
macro_rules! gpu {
// out = lhs + rhs
($scope:expr, $out:ident = $lhs:ident + $rhs:expr) => {
gpu!($scope, $out = add($lhs, $rhs))
};
// out = add(lhs, rhs)
($scope:expr, $out:ident = add($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Add(
gpu!(binary $lhs, $rhs, $out)
));
};
// out = lhs - rhs
($scope:expr, $out:ident = $lhs:ident - $rhs:expr) => {
gpu!($scope, $out = sub($lhs, $rhs))
};
// out = sub(lhs, rhs)
($scope:expr, $out:ident = sub($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Sub(
gpu!(binary $lhs, $rhs, $out)
));
};
// out = lhs * rhs
($scope:expr, $out:ident = $lhs:ident * $rhs:expr) => {
gpu!($scope, $out = mul($lhs, $rhs))
};
// out = mul(lhs, rhs)
($scope:expr, $out:ident = mul($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Mul(
gpu!(binary $lhs, $rhs, $out)
));
};
// out = lhs / rhs
($scope:expr, $out:ident = $lhs:ident / $rhs:expr) => {
gpu!($scope, $out = div($lhs, $rhs))
};
// out = div(lhs, rhs)
($scope:expr, $out:ident = div($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Div(
gpu!(binary $lhs, $rhs, $out)
));
};
// out = lhs % rhs
($scope:expr, $out:ident = $lhs:ident % $rhs:expr) => {
gpu!($scope, $out = modulo($lhs, $rhs))
};
// out = modulo(lhs, rhs)
($scope:expr, $out:ident = modulo($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Modulo(
gpu!(binary $lhs, $rhs, $out)
));
};
// out = lhs ^ rhs
($scope:expr, $out:ident = $lhs:ident ^ $rhs:expr) => {
gpu!($scope, $out = powf($lhs, $rhs))
};
// out = powf(lhs, rhs)
($scope:expr, $out:ident = powf($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Powf(
gpu!(binary $lhs, $rhs, $out)
));
};
// out = lhs == rhs
($scope:expr, $out:ident = $lhs:ident == $rhs:expr) => {
gpu!($scope, $out = equal($lhs, $rhs))
};
// out = equal(lhs, rhs)
($scope:expr, $out:ident = equal($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Equal(
gpu!(binary $lhs, $rhs, $out)
));
};
// out = lhs > rhs
($scope:expr, $out:ident = $lhs:ident > $rhs:expr) => {
gpu!($scope, $out = greater($lhs, $rhs))
};
// out = greater(lhs, rhs)
($scope:expr, $out:ident = greater($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Greater(
gpu!(binary $lhs, $rhs, $out)
));
};
// out = lhs >= rhs
($scope:expr, $out:ident = $lhs:ident >= $rhs:expr) => {
gpu!($scope, $out = greater_equal($lhs, $rhs))
};
// out = greater_equal(lhs, rhs)
($scope:expr, $out:ident = greater_equal($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::GreaterEqual(
gpu!(binary $lhs, $rhs, $out)
));
};
// out = lhs < rhs
($scope:expr, $out:ident = $lhs:ident < $rhs:expr) => {
gpu!($scope, $out = lower($lhs, $rhs))
};
// out = lower(lhs, rhs)
($scope:expr, $out:ident = lower($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Lower(
gpu!(binary $lhs, $rhs, $out)
));
};
// out = lhs <= rhs
($scope:expr, $out:ident = $lhs:ident <= $rhs:expr) => {
gpu!($scope, $out = lower_equal($lhs, $rhs))
};
// out = lower_equal(lhs, rhs)
($scope:expr, $out:ident = lower_equal($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::LowerEqual(
gpu!(binary $lhs, $rhs, $out)
));
};
// out = lhs[rhs]
($scope:expr, $out:ident = $lhs:ident[$rhs:expr]) => {
gpu!($scope, $out = index($lhs, $rhs))
};
// out = index(lhs, rhs)
($scope:expr, $out:ident = index($lhs:expr, $rhs:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Index(
gpu!(binary $lhs, $rhs, $out)
));
};
// out = |input|
($scope:expr, $out:ident = |$input:ident|) => {
gpu!($scope, $out = abs($input))
};
// out = abs(input)
($scope:expr, $out:ident = abs($input:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Abs(
gpu!(unary $input, $out)
));
};
// out = exp(input)
($scope:expr, $out:ident = exp($input:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Exp(
gpu!(unary $input, $out)
));
};
// out = log(input)
($scope:expr, $out:ident = log($input:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Log(
gpu!(unary $input, $out)
));
};
// out = log1p(input)
($scope:expr, $out:ident = log1p($input:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Log1p(
gpu!(unary $input, $out)
));
};
// out = cos(input)
($scope:expr, $out:ident = cos($input:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Cos(
gpu!(unary $input, $out)
));
};
// out = sin(input)
($scope:expr, $out:ident = sin($input:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Sin(
gpu!(unary $input, $out)
));
};
// out = tanh(input)
($scope:expr, $out:ident = tanh($input:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Tanh(
gpu!(unary $input, $out)
));
};
// out = sqrt(input)
($scope:expr, $out:ident = sqrt($input:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Sqrt(
gpu!(unary $input, $out)
));
};
// out = erf(input)
($scope:expr, $out:ident = erf($input:expr)) => {
$scope.register($crate::codegen::dialect::gpu::Operator::Erf(
gpu!(unary $input, $out)
));
};
// out = input
($scope:expr, eval $arg:expr) => {
gpu!($scope, $arg);
};
// out = input
($scope:expr, $out:ident = $input:ident) => {
$scope.register($crate::codegen::dialect::gpu::Operator::AssignLocal(
gpu!(unary $input, $out)
));
};
// out = input
($scope:expr, $out:ident = $input:ident) => {
$scope.register($crate::codegen::dialect::gpu::Operator::AssignLocal(
gpu!(unary $input, $out)
));
};
($scope:expr, $out:ident = shape($input:expr, $dim:expr)) => {
$scope.register(Metadata::Shape {
dim: $dim,
var: $input,
out: $out,
});
};
($scope:expr, $out:ident = stride($input:expr, $dim:expr)) => {
$scope.register(Metadata::Stride {
dim: $dim,
var: $input,
out: $out,
});
};
($scope:expr, range($start:expr, $end:expr).for_each($arg:expr)) => {
$crate::codegen::dialect::gpu::RangeLoop::register($scope, $start.into(), $end.into(), $arg);
};
(binary $lhs:expr, $rhs:expr, $out:expr) => {
$crate::codegen::dialect::gpu::BinaryOperator {
lhs: $lhs.into(),
rhs: $rhs.into(),
out: $out.into(),
}
};
(unary $input:expr, $out:expr) => {
$crate::codegen::dialect::gpu::UnaryOperator {
input: $input.into(),
out: $out.into(),
}
};
}
impl From<bool> for Variable {
fn from(value: bool) -> Self {
Self::ConstantScalar(if value { 1.0 } else { 0.0 }, super::Elem::Bool)
}
}
impl From<i32> for Variable {
fn from(value: i32) -> Self {
Self::ConstantScalar(value as f64, super::Elem::Int)
}
}
impl From<f32> for Variable {
fn from(value: f32) -> Self {
Self::ConstantScalar(value as f64, super::Elem::Float)
}
}
impl From<u32> for Variable {
fn from(value: u32) -> Self {
Self::ConstantScalar(value as f64, super::Elem::UInt)
}
}
impl From<usize> for Variable {
fn from(value: usize) -> Self {
Self::ConstantScalar(value as f64, super::Elem::UInt)
}
}
pub(crate) use gpu;

View File

@ -1,11 +1,15 @@
mod body;
pub(crate) mod algorithm;
mod macros;
mod operation;
mod scope;
mod shader;
mod variable;
mod vectorization;
pub(crate) use body::*;
pub(crate) use macros::gpu;
pub(crate) use operation::*;
pub(crate) use scope::*;
pub(crate) use shader::*;
pub(crate) use variable::*;
pub(crate) use vectorization::*;

View File

@ -1,56 +1,169 @@
use super::Variable;
use super::{Elem, Item, Scope, Variable};
use serde::{Deserialize, Serialize};
/// All operations that can be used in a GPU compute shader.
///
/// Notes:
///
/// [Operator] and [Algorithm] can be vectorized, but [Metadata] and [Loop] can't.
/// Therefore, during tracing, only operators and algorithms can be registered, and during the
/// compilation phase, the algorithm will be expanded.
///
/// Algorithm expansions can safely use [Metadata] and [Loop] operations.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)] // Some variants might not be used with different flags
pub enum Operation {
Add(BinaryOperation),
Sub(BinaryOperation),
Mul(BinaryOperation),
Div(BinaryOperation),
Abs(UnaryOperation),
Exp(UnaryOperation),
Log(UnaryOperation),
Log1p(UnaryOperation),
Cos(UnaryOperation),
Sin(UnaryOperation),
Tanh(UnaryOperation),
Powf(BinaryOperation),
Sqrt(UnaryOperation),
Erf(UnaryOperation),
Recip(UnaryOperation),
Equal(BinaryOperation),
Lower(BinaryOperation),
Clamp(ClampOperation),
Greater(BinaryOperation),
LowerEqual(BinaryOperation),
GreaterEqual(BinaryOperation),
ConditionalAssign(ConditionalAssignOperation),
AssignGlobal(UnaryOperation),
AssignLocal(UnaryOperation),
ReadGlobal(ReadGlobalOperation),
ReadGlobalWithLayout(ReadGlobalWithLayoutOperation),
Operator(Operator),
Metadata(Metadata),
Algorithm(Algorithm),
Loop(Loop),
}
/// All operator that can be used in a GPU compute shader.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)] // Some variants might not be used with different flags
pub enum Operator {
Add(BinaryOperator),
Sub(BinaryOperator),
Mul(BinaryOperator),
Div(BinaryOperator),
Abs(UnaryOperator),
Exp(UnaryOperator),
Log(UnaryOperator),
Log1p(UnaryOperator),
Cos(UnaryOperator),
Sin(UnaryOperator),
Tanh(UnaryOperator),
Powf(BinaryOperator),
Sqrt(UnaryOperator),
Erf(UnaryOperator),
Recip(UnaryOperator),
Equal(BinaryOperator),
Lower(BinaryOperator),
Clamp(ClampOperator),
Greater(BinaryOperator),
LowerEqual(BinaryOperator),
GreaterEqual(BinaryOperator),
ConditionalAssign(ConditionalAssignOperator),
AssignGlobal(UnaryOperator),
AssignLocal(UnaryOperator),
Modulo(BinaryOperator),
Index(BinaryOperator),
}
/// Tensor operations that can't be executed with a simple [operator](Operator) should use an
/// algorithm.
///
/// Algorithms can be expanded to basic [operator](Operator) during compilation, but after
/// vectorization, since for loops and other construct can't simply be vectorized. This also gives
/// the vectorization state to the expansion function, which may create a different set of
/// [operator](Operator) depending on the vectorization state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Algorithm {
/// Read an input array with the given layout.
///
/// Crucial to read arrays that aren't contiguous and to perform correct broadcasting.
ReadGlobalWithLayout(ReadGlobalWithLayoutAlgo),
/// Read an input array.
ReadGlobal(ReadGlobalAlgo),
}
/// Settings for the [Algorithm::ReadGlobalWithLayout] variant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReadGlobalWithLayoutAlgo {
/// The array to be read.
pub global: Variable,
/// The layout to be used.
pub layout: Variable,
/// The output variable to write the result.
pub out: Variable,
}
/// Settings for the [Algorithm::ReadGlobal] variant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReadGlobalAlgo {
/// The array to be read.
pub global: Variable,
/// The output variable to write the result.
pub out: Variable,
}
/// All loop variants.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Loop {
/// A basic range loop.
Range(RangeLoop),
}
/// All metadata that can be access in a shader.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Metadata {
/// The stride of an array at the given dimension.
Stride {
dim: Variable,
var: Variable,
out: Variable,
},
/// The shape of an array at the given dimension.
Shape {
dim: Variable,
var: Variable,
out: Variable,
},
}
/// Settings for the [Loop::Range] variant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RangeLoop {
/// The loop index variable.
pub i: Variable,
/// The start value.
pub start: Variable,
/// The end value.
pub end: Variable,
/// The scope that contains all operations and variables declared in the loop body.
pub scope: Scope,
}
impl RangeLoop {
/// Registers a range loop to the given scope.
pub fn register<F: Fn(Variable, &mut Scope)>(
parent_scope: &mut Scope,
start: Variable,
end: Variable,
func: F,
) {
let mut scope = parent_scope.child();
let index_ty = Item::Scalar(Elem::UInt);
let i = scope.create_local_undeclare(index_ty);
func(i, &mut scope);
let op = Self {
i,
start,
end,
scope,
};
parent_scope.register(Loop::Range(op));
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)] // Some variants might not be used with different flags
pub struct BinaryOperation {
pub struct BinaryOperator {
pub lhs: Variable,
pub rhs: Variable,
pub out: Variable,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)] // Some variants might not be used with different flags
pub struct UnaryOperation {
pub struct UnaryOperator {
pub input: Variable,
pub out: Variable,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)] // Some variants might not be used with different flags
pub struct ClampOperation {
pub struct ClampOperator {
pub input: Variable,
pub min_value: Variable,
pub max_value: Variable,
@ -58,8 +171,7 @@ pub struct ClampOperation {
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)] // Some variants might not be used with different flags
pub struct ConditionalAssignOperation {
pub struct ConditionalAssignOperator {
pub cond: Variable,
pub lhs: Variable,
pub rhs: Variable,
@ -67,15 +179,37 @@ pub struct ConditionalAssignOperation {
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)] // Some variants might not be used with different flags
pub struct ReadGlobalOperation {
pub struct ReadGlobalOperator {
pub variable: Variable,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[allow(dead_code)] // Some variants might not be used with different flags
pub struct ReadGlobalWithLayoutOperation {
pub struct ReadGlobalWithLayoutOperator {
pub variable: Variable,
pub tensor_read_pos: usize,
pub tensor_layout_pos: usize,
}
impl From<Operator> for Operation {
fn from(val: Operator) -> Self {
Operation::Operator(val)
}
}
impl From<Metadata> for Operation {
fn from(val: Metadata) -> Self {
Operation::Metadata(val)
}
}
impl From<Algorithm> for Operation {
fn from(val: Algorithm) -> Self {
Operation::Algorithm(val)
}
}
impl From<Loop> for Operation {
fn from(val: Loop) -> Self {
Operation::Loop(val)
}
}

View File

@ -0,0 +1,257 @@
use super::{
Algorithm, Elem, Item, Operation, Operator, ReadGlobalAlgo, ReadGlobalWithLayoutAlgo,
UnaryOperator, Variable, Vectorization,
};
use serde::{Deserialize, Serialize};
/// The scope is the main [operation](Operation) and [variable](Variable) container that simplify
/// the process of reading inputs, creating local variables and adding new operations.
///
/// Notes:
///
/// This type isn't responsible for creating [shader bindings](super::Binding) and figuring out which
/// variable can be written to.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Scope {
pub depth: u8,
pub operations: Vec<Operation>,
locals: Vec<Variable>,
reads_global: Vec<(Variable, ReadingStrategy, Variable)>,
writes_global: Vec<(Variable, Variable)>,
reads_scalar: Vec<(Variable, Variable)>,
output_ref: Option<Variable>,
undeclared: u16,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ReadingStrategy {
/// Each element will be read in a way to be compatible with the output layout.
OutputLayout,
/// Keep the current layout.
Plain,
}
/// Information necessary when compiling a scope.
pub struct ScopeProcessing {
/// The variable declarations.
pub variables: Vec<Variable>,
/// The operations.
pub operations: Vec<Operation>,
}
impl Scope {
/// Create a scope that is at the root of a
/// [compute shader](crate::codegen::dialect::gpu::ComputeShader).
///
/// A local scope can be created with the [child](Self::child) method.
pub fn root() -> Self {
Self {
depth: 0,
operations: Vec::new(),
locals: Vec::new(),
reads_global: Vec::new(),
writes_global: Vec::new(),
reads_scalar: Vec::new(),
output_ref: None,
undeclared: 0,
}
}
/// Create a local variable of the given [item type](Item).
pub fn create_local<I: Into<Item>>(&mut self, item: I) -> Variable {
let item = item.into();
let index = self.new_local_index();
let local = Variable::Local(index, item, self.depth);
self.locals.push(local);
local
}
/// Create a new local variable, but doesn't perform the declaration.
///
/// Useful for _for loops_ and other algorithms that require the control over initialization.
pub fn create_local_undeclare(&mut self, item: Item) -> Variable {
let index = self.new_local_index();
let local = Variable::Local(index, item, self.depth);
self.undeclared += 1;
local
}
/// Reads an input array to a local variable.
///
/// The index refers to the argument position of the array in the compute shader.
pub fn read_array<I: Into<Item>>(&mut self, index: u16, item: I) -> Variable {
self.read_input_strategy(index, item.into(), ReadingStrategy::OutputLayout)
}
/// Reads an input scalar to a local variable.
///
/// The index refers to the scalar position for the same [element](Elem) type.
pub fn read_scalar(&mut self, index: u16, elem: Elem) -> Variable {
let local = Variable::LocalScalar(self.new_local_scalar_index(), elem, self.depth);
let scalar = Variable::GlobalScalar(index, elem);
self.reads_scalar.push((local, scalar));
local
}
/// Retrieve the last local variable that was created.
pub fn last_local_index(&self) -> Option<&Variable> {
self.locals.last()
}
/// Vectorize the scope using the [vectorization](Vectorization) type.
///
/// Notes:
///
/// Scopes created _during_ compilation (after the tracing is done) should not be vectorized.
pub fn vectorize(&mut self, vectorization: Vectorization) {
self.operations
.iter_mut()
.for_each(|op| *op = op.vectorize(vectorization));
self.locals
.iter_mut()
.for_each(|var| *var = var.vectorize(vectorization));
self.reads_global.iter_mut().for_each(|(input, _, output)| {
*input = input.vectorize(vectorization);
*output = output.vectorize(vectorization);
});
self.writes_global.iter_mut().for_each(|(input, output)| {
*input = input.vectorize(vectorization);
*output = output.vectorize(vectorization);
});
}
/// Writes a variable to given output.
///
/// Notes:
///
/// This should only be used when doing compilation.
pub(crate) fn write_global(&mut self, input: Variable, output: Variable) {
if self.output_ref.is_none() {
self.output_ref = Some(output);
}
self.writes_global.push((input, output));
}
/// Update the [reading strategy](ReadingStrategy) for an input array.
///
/// Notes:
///
/// This should only be used when doing compilation.
pub(crate) fn update_read(&mut self, index: u16, strategy: ReadingStrategy) {
if let Some((_, strategy_old, _)) = self
.reads_global
.iter_mut()
.find(|(var, _, _)| var.index() == Some(index))
{
*strategy_old = strategy;
}
}
/// Register an [operation](Operation) into the scope.
pub fn register<T: Into<Operation>>(&mut self, operation: T) {
self.operations.push(operation.into())
}
/// Create an empty child scope.
pub fn child(&mut self) -> Self {
Self {
depth: self.depth + 1,
operations: Vec::new(),
locals: Vec::new(),
reads_global: Vec::new(),
writes_global: Vec::new(),
reads_scalar: Vec::new(),
output_ref: None,
undeclared: 0,
}
}
/// Returns the variables and operations to be declared and executed.
///
/// Notes:
///
/// New operations and variables can be created within the same scope without having name
/// conflicts.
pub fn process(&mut self) -> ScopeProcessing {
self.undeclared += self.locals.len() as u16;
let mut variables = Vec::new();
core::mem::swap(&mut self.locals, &mut variables);
let mut operations = Vec::new();
for (input, strategy, local) in self.reads_global.drain(..) {
match strategy {
ReadingStrategy::OutputLayout => {
let output = self.output_ref.expect(
"Output should be set when processing an input with output layout.",
);
operations.push(Operation::Algorithm(Algorithm::ReadGlobalWithLayout(
ReadGlobalWithLayoutAlgo {
global: input,
layout: output,
out: local,
},
)));
}
ReadingStrategy::Plain => operations.push(Operation::Algorithm(
Algorithm::ReadGlobal(ReadGlobalAlgo {
global: input,
out: local,
}),
)),
}
}
for (local, scalar) in self.reads_scalar.drain(..) {
operations.push(
Operator::AssignLocal(UnaryOperator {
input: scalar,
out: local,
})
.into(),
);
variables.push(local);
}
for op in self.operations.drain(..) {
operations.push(op);
}
for (input, out) in self.writes_global.drain(..) {
operations.push(Operation::Operator(Operator::AssignGlobal(UnaryOperator {
input,
out,
})))
}
ScopeProcessing {
variables,
operations,
}
}
fn new_local_index(&self) -> u16 {
self.locals.len() as u16 + self.undeclared
}
fn new_local_scalar_index(&self) -> u16 {
self.reads_scalar.len() as u16
}
fn read_input_strategy(
&mut self,
index: u16,
item: Item,
strategy: ReadingStrategy,
) -> Variable {
let input = Variable::GlobalInputArray(index, item);
let index = self.new_local_index();
let local = Variable::Local(index, item, self.depth);
self.reads_global.push((input, strategy, local));
self.locals.push(local);
local
}
}

View File

@ -1,8 +1,9 @@
use super::Body;
use crate::kernel::WORKGROUP_DEFAULT;
use serde::{Deserialize, Serialize};
use std::fmt::Display;
use super::Scope;
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum Location {
Storage,
@ -16,7 +17,7 @@ pub enum Visibility {
ReadWrite,
}
#[derive(Debug, Clone, PartialEq, Eq, Copy, Serialize, Deserialize)]
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize)]
pub enum Elem {
Float,
Int,
@ -24,6 +25,12 @@ pub enum Elem {
Bool,
}
impl From<Elem> for Item {
fn from(val: Elem) -> Self {
Item::Scalar(val)
}
}
impl Display for Elem {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
@ -87,5 +94,5 @@ pub struct ComputeShader {
pub workgroup_size: WorkgroupSize,
pub global_invocation_id: bool,
pub num_workgroups: bool,
pub body: Body,
pub body: Scope,
}

View File

@ -1,11 +1,47 @@
use super::Item;
use super::{Elem, Item};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum Variable {
Input(u16, Item),
Scalar(u16, Item),
Local(u16, Item),
Output(u16, Item),
Constant(f64, Item),
GlobalInputArray(u16, Item),
GlobalScalar(u16, Elem),
GlobalOutputArray(u16, Item),
Local(u16, Item, u8),
LocalScalar(u16, Elem, u8),
ConstantScalar(f64, Elem),
Id,
Rank,
}
impl Variable {
pub fn const_value(&self) -> Option<f64> {
match self {
Variable::ConstantScalar(value, _) => Some(*value),
_ => None,
}
}
pub fn index(&self) -> Option<u16> {
match self {
Variable::GlobalInputArray(idx, _) => Some(*idx),
Variable::GlobalScalar(idx, _) => Some(*idx),
Variable::Local(idx, _, _) => Some(*idx),
Variable::LocalScalar(idx, _, _) => Some(*idx),
Variable::GlobalOutputArray(idx, _) => Some(*idx),
Variable::ConstantScalar(_, _) => None,
Variable::Id => None,
Variable::Rank => None,
}
}
pub fn item(&self) -> Item {
match self {
Variable::GlobalInputArray(_, item) => *item,
Variable::GlobalOutputArray(_, item) => *item,
Variable::GlobalScalar(_, elem) => Item::Scalar(*elem),
Variable::Local(_, item, _) => *item,
Variable::LocalScalar(_, elem, _) => Item::Scalar(*elem),
Variable::ConstantScalar(_, elem) => Item::Scalar(*elem),
Variable::Id => Item::Scalar(Elem::UInt),
Variable::Rank => Item::Scalar(Elem::UInt),
}
}
}

View File

@ -1,11 +1,11 @@
use super::{
BinaryOperation, ClampOperation, ConditionalAssignOperation, Item, Operation,
ReadGlobalOperation, ReadGlobalWithLayoutOperation, UnaryOperation, Variable,
Algorithm, BinaryOperator, ClampOperator, ConditionalAssignOperator, Item, Operation, Operator,
ReadGlobalAlgo, ReadGlobalWithLayoutAlgo, UnaryOperator, Variable,
};
/// Define a vectorization scheme.
#[allow(dead_code)]
#[derive(Copy, Clone, Debug)]
#[derive(Copy, Clone, Debug, Default)]
pub enum Vectorization {
/// Use vec4 for vectorization.
Vec4,
@ -14,47 +14,99 @@ pub enum Vectorization {
/// Use vec2 for vectorization.
Vec2,
/// Don't vectorize.
#[default]
Scalar,
}
impl Operation {
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
match self {
Operation::Add(op) => Operation::Add(op.vectorize(vectorization)),
Operation::Sub(op) => Operation::Sub(op.vectorize(vectorization)),
Operation::Mul(op) => Operation::Mul(op.vectorize(vectorization)),
Operation::Div(op) => Operation::Div(op.vectorize(vectorization)),
Operation::Abs(op) => Operation::Abs(op.vectorize(vectorization)),
Operation::Exp(op) => Operation::Exp(op.vectorize(vectorization)),
Operation::Log(op) => Operation::Log(op.vectorize(vectorization)),
Operation::Log1p(op) => Operation::Log1p(op.vectorize(vectorization)),
Operation::Cos(op) => Operation::Cos(op.vectorize(vectorization)),
Operation::Sin(op) => Operation::Sin(op.vectorize(vectorization)),
Operation::Tanh(op) => Operation::Tanh(op.vectorize(vectorization)),
Operation::Powf(op) => Operation::Powf(op.vectorize(vectorization)),
Operation::Sqrt(op) => Operation::Sqrt(op.vectorize(vectorization)),
Operation::Erf(op) => Operation::Erf(op.vectorize(vectorization)),
Operation::Recip(op) => Operation::Recip(op.vectorize(vectorization)),
Operation::Equal(op) => Operation::Equal(op.vectorize(vectorization)),
Operation::Lower(op) => Operation::Lower(op.vectorize(vectorization)),
Operation::Clamp(op) => Operation::Clamp(op.vectorize(vectorization)),
Operation::Greater(op) => Operation::Greater(op.vectorize(vectorization)),
Operation::LowerEqual(op) => Operation::LowerEqual(op.vectorize(vectorization)),
Operation::GreaterEqual(op) => Operation::GreaterEqual(op.vectorize(vectorization)),
Operation::ConditionalAssign(op) => {
Operation::ConditionalAssign(op.vectorize(vectorization))
}
Operation::AssignGlobal(op) => Operation::AssignGlobal(op.vectorize(vectorization)),
Operation::AssignLocal(op) => Operation::AssignLocal(op.vectorize(vectorization)),
Operation::ReadGlobal(op) => Operation::ReadGlobal(op.vectorize(vectorization)),
Operation::ReadGlobalWithLayout(op) => {
Operation::ReadGlobalWithLayout(op.vectorize(vectorization))
}
Operation::Operator(op) => Operation::Operator(op.vectorize(vectorization)),
Operation::Algorithm(op) => Operation::Algorithm(op.vectorize(vectorization)),
Operation::Metadata(_) => panic!(
"Metadata can't be vectorized, they should only be generated after vectorization."
),
Operation::Loop(_) => panic!(
"Loops can't be vectorized, they should only be generated after vectorization."
),
}
}
}
impl BinaryOperation {
impl Algorithm {
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
match self {
Algorithm::ReadGlobalWithLayout(op) => {
Algorithm::ReadGlobalWithLayout(op.vectorize(vectorization))
}
Algorithm::ReadGlobal(op) => Algorithm::ReadGlobal(op.vectorize(vectorization)),
}
}
}
impl ReadGlobalWithLayoutAlgo {
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
Self {
global: self.global.vectorize(vectorization),
layout: self.layout.vectorize(vectorization),
out: self.out.vectorize(vectorization),
}
}
}
impl ReadGlobalAlgo {
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
Self {
global: self.global.vectorize(vectorization),
out: self.out.vectorize(vectorization),
}
}
}
impl Operator {
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
match self {
Operator::Add(op) => Operator::Add(op.vectorize(vectorization)),
Operator::Index(op) => Operator::Index(op.vectorize(vectorization)),
Operator::Sub(op) => Operator::Sub(op.vectorize(vectorization)),
Operator::Mul(op) => Operator::Mul(op.vectorize(vectorization)),
Operator::Div(op) => Operator::Div(op.vectorize(vectorization)),
Operator::Abs(op) => Operator::Abs(op.vectorize(vectorization)),
Operator::Exp(op) => Operator::Exp(op.vectorize(vectorization)),
Operator::Log(op) => Operator::Log(op.vectorize(vectorization)),
Operator::Log1p(op) => Operator::Log1p(op.vectorize(vectorization)),
Operator::Cos(op) => Operator::Cos(op.vectorize(vectorization)),
Operator::Sin(op) => Operator::Sin(op.vectorize(vectorization)),
Operator::Tanh(op) => Operator::Tanh(op.vectorize(vectorization)),
Operator::Powf(op) => Operator::Powf(op.vectorize(vectorization)),
Operator::Sqrt(op) => Operator::Sqrt(op.vectorize(vectorization)),
Operator::Erf(op) => Operator::Erf(op.vectorize(vectorization)),
Operator::Recip(op) => Operator::Recip(op.vectorize(vectorization)),
Operator::Equal(op) => Operator::Equal(op.vectorize(vectorization)),
Operator::Lower(op) => Operator::Lower(op.vectorize(vectorization)),
Operator::Clamp(op) => Operator::Clamp(op.vectorize(vectorization)),
Operator::Greater(op) => Operator::Greater(op.vectorize(vectorization)),
Operator::LowerEqual(op) => Operator::LowerEqual(op.vectorize(vectorization)),
Operator::GreaterEqual(op) => Operator::GreaterEqual(op.vectorize(vectorization)),
Operator::ConditionalAssign(op) => {
Operator::ConditionalAssign(op.vectorize(vectorization))
}
Operator::AssignGlobal(op) => Operator::AssignGlobal(op.vectorize(vectorization)),
Operator::AssignLocal(op) => {
if let Variable::GlobalScalar(_, _) = op.input {
// Assign will not change the type of the output if the input can't be
// vectorized.
return Operator::AssignLocal(op.clone());
}
Operator::AssignLocal(op.vectorize(vectorization))
}
Operator::Modulo(op) => Operator::Modulo(op.vectorize(vectorization)),
}
}
}
impl BinaryOperator {
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
let lhs = self.lhs.vectorize(vectorization);
let rhs = self.rhs.vectorize(vectorization);
@ -64,7 +116,7 @@ impl BinaryOperation {
}
}
impl UnaryOperation {
impl UnaryOperator {
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
let input = self.input.vectorize(vectorization);
let out = self.out.vectorize(vectorization);
@ -73,7 +125,7 @@ impl UnaryOperation {
}
}
impl ClampOperation {
impl ClampOperator {
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
let input = self.input.vectorize(vectorization);
let out = self.out.vectorize(vectorization);
@ -89,7 +141,7 @@ impl ClampOperation {
}
}
impl ConditionalAssignOperation {
impl ConditionalAssignOperator {
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
let cond = self.cond.vectorize(vectorization);
let lhs = self.lhs.vectorize(vectorization);
@ -105,39 +157,23 @@ impl ConditionalAssignOperation {
}
}
impl ReadGlobalOperation {
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
let variable = self.variable.vectorize(vectorization);
Self { variable }
}
}
impl ReadGlobalWithLayoutOperation {
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
let variable = self.variable.vectorize(vectorization);
let tensor_read_pos = self.tensor_read_pos;
let tensor_layout_pos = self.tensor_layout_pos;
Self {
variable,
tensor_read_pos,
tensor_layout_pos,
}
}
}
impl Variable {
pub fn vectorize(&self, vectorize: Vectorization) -> Self {
match self {
Variable::Input(index, item) => Variable::Input(*index, item.vectorize(vectorize)),
Variable::Local(index, item) => Variable::Local(*index, item.vectorize(vectorize)),
Variable::Output(index, item) => Variable::Output(*index, item.vectorize(vectorize)),
Variable::Constant(index, item) => {
Variable::Constant(*index, item.vectorize(vectorize))
Variable::GlobalInputArray(index, item) => {
Variable::GlobalInputArray(*index, item.vectorize(vectorize))
}
Variable::Scalar(index, item) => Variable::Scalar(*index, *item), // Don't vectorize
// scalar variables.
Variable::Local(index, item, name) => {
Variable::Local(*index, item.vectorize(vectorize), *name)
}
Variable::GlobalOutputArray(index, item) => {
Variable::GlobalOutputArray(*index, item.vectorize(vectorize))
}
Variable::ConstantScalar(_, _) => *self,
Variable::GlobalScalar(_, _) => *self,
Variable::Id => *self,
Variable::Rank => *self,
Variable::LocalScalar(_, _, _) => *self,
}
}
}

View File

@ -3,11 +3,22 @@ use std::fmt::Display;
#[derive(Debug, Clone)]
pub enum Variable {
Input(u16, Item),
Scalar(u16, Item, gpu::Elem),
Local(u16, Item),
Output(u16, Item),
Constant(f64, Item),
GlobalInputArray(u16, Item),
GlobalOutputArray(u16, Item),
GlobalScalar(u16, Elem, gpu::Elem),
ConstantScalar(f64, Elem),
Local {
index: u16,
item: Item,
scope_depth: u8,
},
LocalScalar {
index: u16,
elem: Elem,
scope_depth: u8,
},
Id,
Rank,
}
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
@ -33,6 +44,26 @@ pub struct IndexedVariable {
}
impl Variable {
pub fn is_always_scalar(&self) -> bool {
match self {
Variable::GlobalScalar(_, _, _) => true,
Variable::ConstantScalar(_, _) => true,
Variable::LocalScalar {
index: _,
elem: _,
scope_depth: _,
} => true,
Variable::Id => true,
Variable::Rank => true,
Variable::GlobalInputArray(_, _) => false,
Variable::GlobalOutputArray(_, _) => false,
Variable::Local {
index: _,
item: _,
scope_depth: _,
} => false,
}
}
pub fn index(&self, index: usize) -> IndexedVariable {
IndexedVariable {
var: self.clone(),
@ -40,15 +71,29 @@ impl Variable {
}
}
pub fn item(&self) -> &Item {
pub fn item(&self) -> Item {
match self {
Self::Input(_, e) => e,
Self::Scalar(_, e, _) => e,
Self::Local(_, e) => e,
Self::Output(_, e) => e,
Self::Constant(_, e) => e,
Self::GlobalInputArray(_, e) => *e,
Self::GlobalOutputArray(_, e) => *e,
Self::Local {
index: _,
item,
scope_depth: _,
} => *item,
Self::ConstantScalar(_, e) => Item::Scalar(*e),
Self::GlobalScalar(_, e, _) => Item::Scalar(*e),
Self::Id => Item::Scalar(Elem::U32),
Self::Rank => Item::Scalar(Elem::U32),
Self::LocalScalar {
index: _,
elem,
scope_depth: _,
} => Item::Scalar(*elem),
}
}
pub fn elem(&self) -> Elem {
*self.item().elem()
}
}
impl Item {
@ -98,39 +143,24 @@ impl Display for Item {
impl Display for Variable {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Variable::Input(number, _) => f.write_fmt(format_args!("input_{number}")),
Variable::Local(number, _) => f.write_fmt(format_args!("local_{number}")),
Variable::Output(number, _) => f.write_fmt(format_args!("output_{number}")),
Variable::Scalar(number, _, elem) => {
Variable::GlobalInputArray(number, _) => f.write_fmt(format_args!("input_{number}")),
Variable::LocalScalar {
index,
elem: _,
scope_depth,
} => f.write_fmt(format_args!("s_{scope_depth}_{index}")),
Variable::Local {
index,
item: _,
scope_depth,
} => f.write_fmt(format_args!("l_{scope_depth}_{index}")),
Variable::GlobalOutputArray(number, _) => f.write_fmt(format_args!("output_{number}")),
Variable::GlobalScalar(number, _, elem) => {
f.write_fmt(format_args!("scalars_{elem}[{number}]"))
}
Variable::Constant(number, item) => match item {
Item::Vec4(elem) => f.write_fmt(format_args!(
"
vec4(
{elem}({number}),
{elem}({number}),
{elem}({number}),
{elem}({number}),
)"
)),
Item::Vec3(elem) => f.write_fmt(format_args!(
"
vec3(
{elem}({number}),
{elem}({number}),
{elem}({number}),
)"
)),
Item::Vec2(elem) => f.write_fmt(format_args!(
"
vec2(
{elem}({number}),
{elem}({number}),
)"
)),
Item::Scalar(elem) => f.write_fmt(format_args!("{elem}({number})")),
},
Variable::ConstantScalar(number, elem) => f.write_fmt(format_args!("{elem}({number})")),
Variable::Id => f.write_str("id"),
Variable::Rank => f.write_str("rank"),
}
}
}
@ -149,8 +179,8 @@ impl Display for IndexedVariable {
let index = self.index;
match self.var {
Variable::Scalar(_, _, _) => f.write_fmt(format_args!("{var}")),
_ => match should_index(item) {
Variable::GlobalScalar(_, _, _) => f.write_fmt(format_args!("{var}")),
_ => match should_index(&item) {
true => f.write_fmt(format_args!("{var}[{index}]")),
false => f.write_fmt(format_args!("{var}")),
},

View File

@ -1,4 +1,4 @@
use super::Operation;
use super::Instruction;
use std::fmt::Display;
/// A body is composed of a list of [operations](Operation).
@ -6,11 +6,11 @@ use std::fmt::Display;
/// Note that the body assumes that the kernel will run on a 2D grid defined by the workgroup size
/// X and Y, but with Z=1.
#[derive(Debug, Clone)]
pub struct Body {
pub operators: Vec<Operation>,
pub struct Scope {
pub operators: Vec<Instruction>,
}
impl Display for Body {
impl Display for Scope {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(
"let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;\n",
@ -19,7 +19,6 @@ impl Display for Body {
for ops in self.operators.iter() {
f.write_fmt(format_args!("{ops}"))?;
f.write_str("\n")?;
}
Ok(())

View File

@ -11,6 +11,8 @@ use std::marker::PhantomData;
/// Wgsl Compiler.
#[derive(Clone)]
pub struct Compiler<F: FloatElement, I: IntElement> {
num_inputs: usize,
num_outputs: usize,
_float: PhantomData<F>,
_int: PhantomData<I>,
}
@ -24,6 +26,8 @@ impl<F: FloatElement, I: IntElement> core::fmt::Debug for Compiler<F, I> {
impl<F: FloatElement, I: IntElement> Default for Compiler<F, I> {
fn default() -> Self {
Self {
num_inputs: 0,
num_outputs: 0,
_float: PhantomData,
_int: PhantomData,
}
@ -37,7 +41,8 @@ impl<F: FloatElement, I: IntElement> compiler::Compiler for Compiler<F, I> {
type FullPrecisionCompiler = Compiler<f32, i32>;
fn compile(shader: gpu::ComputeShader) -> Self::Representation {
Self::compile_shader(shader)
let mut compiler = Self::default();
compiler.compile_shader(shader)
}
fn elem_size(elem: gpu::Elem) -> usize {
@ -46,6 +51,37 @@ impl<F: FloatElement, I: IntElement> compiler::Compiler for Compiler<F, I> {
}
impl<F: FloatElement, I: IntElement> Compiler<F, I> {
fn compile_shader(&mut self, mut value: gpu::ComputeShader) -> wgsl::ComputeShader {
self.num_inputs = value.inputs.len();
self.num_outputs = value.outputs.len();
let body = self.compile_scope(&mut value.body);
let extensions = register_extensions(&body);
wgsl::ComputeShader {
inputs: value
.inputs
.into_iter()
.map(Self::compile_binding)
.collect(),
outputs: value
.outputs
.into_iter()
.map(Self::compile_binding)
.collect(),
named: value
.named
.into_iter()
.map(|(name, binding)| (name, Self::compile_binding(binding)))
.collect(),
workgroup_size: value.workgroup_size,
global_invocation_id: value.global_invocation_id,
num_workgroups: value.num_workgroups,
body,
extensions,
}
}
fn compile_item(item: gpu::Item) -> Item {
match item {
gpu::Item::Vec4(elem) => wgsl::Item::Vec4(Self::compile_elem(elem)),
@ -66,155 +102,252 @@ impl<F: FloatElement, I: IntElement> Compiler<F, I> {
fn compile_variable(value: gpu::Variable) -> wgsl::Variable {
match value {
gpu::Variable::Input(index, item) => {
wgsl::Variable::Input(index, Self::compile_item(item))
gpu::Variable::GlobalInputArray(index, item) => {
wgsl::Variable::GlobalInputArray(index, Self::compile_item(item))
}
gpu::Variable::Scalar(index, item) => {
let elem = item.elem();
wgsl::Variable::Scalar(index, Self::compile_item(item), elem)
gpu::Variable::GlobalScalar(index, elem) => {
wgsl::Variable::GlobalScalar(index, Self::compile_elem(elem), elem)
}
gpu::Variable::Local(index, item) => {
wgsl::Variable::Local(index, Self::compile_item(item))
gpu::Variable::Local(index, item, scope_depth) => wgsl::Variable::Local {
index,
item: Self::compile_item(item),
scope_depth,
},
gpu::Variable::LocalScalar(index, elem, scope_depth) => wgsl::Variable::LocalScalar {
index,
elem: Self::compile_elem(elem),
scope_depth,
},
gpu::Variable::GlobalOutputArray(index, item) => {
wgsl::Variable::GlobalOutputArray(index, Self::compile_item(item))
}
gpu::Variable::Output(index, item) => {
wgsl::Variable::Output(index, Self::compile_item(item))
gpu::Variable::ConstantScalar(index, elem) => {
wgsl::Variable::ConstantScalar(index, Self::compile_elem(elem))
}
gpu::Variable::Constant(index, item) => {
wgsl::Variable::Constant(index, Self::compile_item(item))
gpu::Variable::Id => wgsl::Variable::Id,
gpu::Variable::Rank => wgsl::Variable::Rank,
}
}
fn compile_scope(&self, value: &mut gpu::Scope) -> wgsl::Scope {
let mut operations = Vec::new();
let processing = value.process();
for var in processing.variables {
operations.push(wgsl::Instruction::DeclareVariable {
var: Self::compile_variable(var),
});
}
processing
.operations
.into_iter()
.for_each(|op| self.compile_operation(&mut operations, op, value));
wgsl::Scope {
operators: operations,
}
}
fn compile_operation(
&self,
instructions: &mut Vec<wgsl::Instruction>,
operation: gpu::Operation,
scope: &mut gpu::Scope,
) {
match operation {
gpu::Operation::Operator(op) => instructions.push(self.compile_instruction(op)),
gpu::Operation::Algorithm(algo) => self.compile_algorithm(instructions, algo, scope),
gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op)),
gpu::Operation::Loop(val) => instructions.push(self.compile_loop(val)),
}
}
fn compile_algorithm(
&self,
instructions: &mut Vec<wgsl::Instruction>,
algo: gpu::Algorithm,
scope: &mut gpu::Scope,
) {
let mut compile = |scope: &mut gpu::Scope| {
let compiled = self.compile_scope(scope).operators;
instructions.extend(compiled);
};
match algo {
gpu::Algorithm::ReadGlobalWithLayout(algo) => {
algo.expand(scope);
compile(scope);
}
gpu::Algorithm::ReadGlobal(algo) => {
algo.expand(scope);
compile(scope);
}
}
}
fn compile_body(value: gpu::Body) -> wgsl::Body {
wgsl::Body {
operators: value
.operators
.into_iter()
.map(Self::compile_operation)
.collect(),
fn compile_loop(&self, loop_val: gpu::Loop) -> wgsl::Instruction {
match loop_val {
gpu::Loop::Range(mut range_loop) => wgsl::Instruction::RangeLoop {
i: Self::compile_variable(range_loop.i),
start: Self::compile_variable(range_loop.start),
end: Self::compile_variable(range_loop.end),
instructions: self.compile_scope(&mut range_loop.scope).operators,
},
}
}
fn compile_operation(value: gpu::Operation) -> wgsl::Operation {
fn compile_metadata(&self, metadata: gpu::Metadata) -> wgsl::Instruction {
match metadata {
gpu::Metadata::Stride { dim, var, out } => {
let position = match var {
gpu::Variable::GlobalInputArray(idx, _) => idx as usize,
gpu::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
_ => panic!("Only Input and Output have a stride, got: {:?}", var),
};
wgsl::Instruction::Stride {
dim: Self::compile_variable(dim),
position,
out: Self::compile_variable(out),
}
}
gpu::Metadata::Shape { dim, var, out } => {
let position = match var {
gpu::Variable::GlobalInputArray(idx, _) => idx as usize,
gpu::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
_ => panic!("Only Input and Output have a shape, got {:?}", var),
};
wgsl::Instruction::Shape {
dim: Self::compile_variable(dim),
position,
out: Self::compile_variable(out),
}
}
}
}
fn compile_instruction(&self, value: gpu::Operator) -> wgsl::Instruction {
match value {
gpu::Operation::Add(op) => wgsl::Operation::Add {
gpu::Operator::Add(op) => wgsl::Instruction::Add {
lhs: Self::compile_variable(op.lhs),
rhs: Self::compile_variable(op.rhs),
out: Self::compile_variable(op.out),
},
gpu::Operation::Sub(op) => wgsl::Operation::Sub {
gpu::Operator::Index(op) => wgsl::Instruction::Index {
lhs: Self::compile_variable(op.lhs),
rhs: Self::compile_variable(op.rhs),
out: Self::compile_variable(op.out),
},
gpu::Operation::Mul(op) => wgsl::Operation::Mul {
gpu::Operator::Modulo(op) => wgsl::Instruction::Modulo {
lhs: Self::compile_variable(op.lhs),
rhs: Self::compile_variable(op.rhs),
out: Self::compile_variable(op.out),
},
gpu::Operation::Div(op) => wgsl::Operation::Div {
gpu::Operator::Sub(op) => wgsl::Instruction::Sub {
lhs: Self::compile_variable(op.lhs),
rhs: Self::compile_variable(op.rhs),
out: Self::compile_variable(op.out),
},
gpu::Operation::Abs(op) => wgsl::Operation::Abs {
input: Self::compile_variable(op.input),
out: Self::compile_variable(op.out),
},
gpu::Operation::Exp(op) => wgsl::Operation::Exp {
input: Self::compile_variable(op.input),
out: Self::compile_variable(op.out),
},
gpu::Operation::Log(op) => wgsl::Operation::Log {
input: Self::compile_variable(op.input),
out: Self::compile_variable(op.out),
},
gpu::Operation::Log1p(op) => wgsl::Operation::Log1p {
input: Self::compile_variable(op.input),
out: Self::compile_variable(op.out),
},
gpu::Operation::Cos(op) => wgsl::Operation::Cos {
input: Self::compile_variable(op.input),
out: Self::compile_variable(op.out),
},
gpu::Operation::Sin(op) => wgsl::Operation::Sin {
input: Self::compile_variable(op.input),
out: Self::compile_variable(op.out),
},
gpu::Operation::Tanh(op) => wgsl::Operation::Tanh {
input: Self::compile_variable(op.input),
out: Self::compile_variable(op.out),
},
gpu::Operation::Powf(op) => wgsl::Operation::Powf {
gpu::Operator::Mul(op) => wgsl::Instruction::Mul {
lhs: Self::compile_variable(op.lhs),
rhs: Self::compile_variable(op.rhs),
out: Self::compile_variable(op.out),
},
gpu::Operation::Sqrt(op) => wgsl::Operation::Sqrt {
input: Self::compile_variable(op.input),
out: Self::compile_variable(op.out),
},
gpu::Operation::Erf(op) => wgsl::Operation::Erf {
input: Self::compile_variable(op.input),
out: Self::compile_variable(op.out),
},
gpu::Operation::Recip(op) => wgsl::Operation::Recip {
input: Self::compile_variable(op.input),
out: Self::compile_variable(op.out),
},
gpu::Operation::Equal(op) => wgsl::Operation::Equal {
gpu::Operator::Div(op) => wgsl::Instruction::Div {
lhs: Self::compile_variable(op.lhs),
rhs: Self::compile_variable(op.rhs),
out: Self::compile_variable(op.out),
},
gpu::Operation::Lower(op) => wgsl::Operation::Lower {
gpu::Operator::Abs(op) => wgsl::Instruction::Abs {
input: Self::compile_variable(op.input),
out: Self::compile_variable(op.out),
},
gpu::Operator::Exp(op) => wgsl::Instruction::Exp {
input: Self::compile_variable(op.input),
out: Self::compile_variable(op.out),
},
gpu::Operator::Log(op) => wgsl::Instruction::Log {
input: Self::compile_variable(op.input),
out: Self::compile_variable(op.out),
},
gpu::Operator::Log1p(op) => wgsl::Instruction::Log1p {
input: Self::compile_variable(op.input),
out: Self::compile_variable(op.out),
},
gpu::Operator::Cos(op) => wgsl::Instruction::Cos {
input: Self::compile_variable(op.input),
out: Self::compile_variable(op.out),
},
gpu::Operator::Sin(op) => wgsl::Instruction::Sin {
input: Self::compile_variable(op.input),
out: Self::compile_variable(op.out),
},
gpu::Operator::Tanh(op) => wgsl::Instruction::Tanh {
input: Self::compile_variable(op.input),
out: Self::compile_variable(op.out),
},
gpu::Operator::Powf(op) => wgsl::Instruction::Powf {
lhs: Self::compile_variable(op.lhs),
rhs: Self::compile_variable(op.rhs),
out: Self::compile_variable(op.out),
},
gpu::Operation::Clamp(op) => wgsl::Operation::Clamp {
gpu::Operator::Sqrt(op) => wgsl::Instruction::Sqrt {
input: Self::compile_variable(op.input),
out: Self::compile_variable(op.out),
},
gpu::Operator::Erf(op) => wgsl::Instruction::Erf {
input: Self::compile_variable(op.input),
out: Self::compile_variable(op.out),
},
gpu::Operator::Recip(op) => wgsl::Instruction::Recip {
input: Self::compile_variable(op.input),
out: Self::compile_variable(op.out),
},
gpu::Operator::Equal(op) => wgsl::Instruction::Equal {
lhs: Self::compile_variable(op.lhs),
rhs: Self::compile_variable(op.rhs),
out: Self::compile_variable(op.out),
},
gpu::Operator::Lower(op) => wgsl::Instruction::Lower {
lhs: Self::compile_variable(op.lhs),
rhs: Self::compile_variable(op.rhs),
out: Self::compile_variable(op.out),
},
gpu::Operator::Clamp(op) => wgsl::Instruction::Clamp {
input: Self::compile_variable(op.input),
min_value: Self::compile_variable(op.min_value),
max_value: Self::compile_variable(op.max_value),
out: Self::compile_variable(op.out),
},
gpu::Operation::Greater(op) => wgsl::Operation::Greater {
gpu::Operator::Greater(op) => wgsl::Instruction::Greater {
lhs: Self::compile_variable(op.lhs),
rhs: Self::compile_variable(op.rhs),
out: Self::compile_variable(op.out),
},
gpu::Operation::LowerEqual(op) => wgsl::Operation::LowerEqual {
gpu::Operator::LowerEqual(op) => wgsl::Instruction::LowerEqual {
lhs: Self::compile_variable(op.lhs),
rhs: Self::compile_variable(op.rhs),
out: Self::compile_variable(op.out),
},
gpu::Operation::GreaterEqual(op) => wgsl::Operation::GreaterEqual {
gpu::Operator::GreaterEqual(op) => wgsl::Instruction::GreaterEqual {
lhs: Self::compile_variable(op.lhs),
rhs: Self::compile_variable(op.rhs),
out: Self::compile_variable(op.out),
},
gpu::Operation::ConditionalAssign(op) => wgsl::Operation::ConditionalAssign {
gpu::Operator::ConditionalAssign(op) => wgsl::Instruction::ConditionalAssign {
cond: Self::compile_variable(op.cond),
lhs: Self::compile_variable(op.lhs),
rhs: Self::compile_variable(op.rhs),
out: Self::compile_variable(op.out),
},
gpu::Operation::AssignGlobal(op) => wgsl::Operation::AssignGlobal {
gpu::Operator::AssignGlobal(op) => wgsl::Instruction::AssignGlobal {
input: Self::compile_variable(op.input),
out: Self::compile_variable(op.out),
},
gpu::Operation::AssignLocal(op) => wgsl::Operation::AssignLocal {
gpu::Operator::AssignLocal(op) => wgsl::Instruction::AssignLocal {
input: Self::compile_variable(op.input),
out: Self::compile_variable(op.out),
},
gpu::Operation::ReadGlobal(op) => wgsl::Operation::ReadGlobal {
variable: Self::compile_variable(op.variable),
},
gpu::Operation::ReadGlobalWithLayout(op) => wgsl::Operation::ReadGlobalWithLayout {
variable: Self::compile_variable(op.variable),
tensor_read_pos: op.tensor_read_pos,
tensor_layout_pos: op.tensor_layout_pos,
},
}
}
@ -240,37 +373,9 @@ impl<F: FloatElement, I: IntElement> Compiler<F, I> {
size: value.size,
}
}
fn compile_shader(value: gpu::ComputeShader) -> wgsl::ComputeShader {
let body = Self::compile_body(value.body);
let extensions = register_extensions(&body);
wgsl::ComputeShader {
inputs: value
.inputs
.into_iter()
.map(Self::compile_binding)
.collect(),
outputs: value
.outputs
.into_iter()
.map(Self::compile_binding)
.collect(),
named: value
.named
.into_iter()
.map(|(name, binding)| (name, Self::compile_binding(binding)))
.collect(),
workgroup_size: value.workgroup_size,
global_invocation_id: value.global_invocation_id,
num_workgroups: value.num_workgroups,
body,
extensions,
}
}
}
fn register_extensions(body: &wgsl::Body) -> Vec<wgsl::Extension> {
fn register_extensions(body: &wgsl::Scope) -> Vec<wgsl::Extension> {
let mut extensions = Vec::new();
let mut register_extension = |extension: wgsl::Extension| {
@ -282,20 +387,21 @@ fn register_extensions(body: &wgsl::Body) -> Vec<wgsl::Extension> {
// Since not all operators are native to WGSL, we need to add the custom ones.
for op in body.operators.iter() {
match op {
wgsl::Operation::Powf { lhs: _, rhs, out } => match rhs {
wgsl::Variable::Scalar(_, _, _) => {
register_extension(wgsl::Extension::PowfScalar(*out.item()));
wgsl::Instruction::Powf { lhs: _, rhs, out } => {
register_extension(wgsl::Extension::PowfPrimitive(out.item()));
if rhs.is_always_scalar() {
register_extension(wgsl::Extension::PowfScalar(out.item()));
} else {
register_extension(wgsl::Extension::Powf(out.item()));
}
_ => {
register_extension(wgsl::Extension::Powf(*out.item()));
}
},
wgsl::Operation::Erf { input, out: _ } => {
register_extension(wgsl::Extension::Erf(*input.item()));
}
wgsl::Instruction::Erf { input, out: _ } => {
register_extension(wgsl::Extension::Erf(input.item()));
}
#[cfg(target_os = "macos")]
wgsl::Operation::Tanh { input, out: _ } => {
register_extension(wgsl::Extension::SafeTanh(*input.item()))
wgsl::Instruction::Tanh { input, out: _ } => {
register_extension(wgsl::Extension::SafeTanh(input.item()))
}
_ => {}
}

View File

@ -5,6 +5,7 @@ use std::fmt::Display;
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum Extension {
PowfScalar(Item),
PowfPrimitive(Item),
Powf(Item),
Erf(Item),
#[cfg(target_os = "macos")]
@ -15,6 +16,7 @@ impl Display for Extension {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Extension::PowfScalar(elem) => format_powf_scalar(f, elem),
Extension::PowfPrimitive(elem) => format_powf_primitive(f, elem),
Extension::Powf(elem) => format_powf(f, elem),
Extension::Erf(elem) => format_erf(f, elem),
#[cfg(target_os = "macos")]
@ -24,12 +26,10 @@ impl Display for Extension {
}
fn format_powf_scalar(f: &mut core::fmt::Formatter<'_>, item: &Item) -> core::fmt::Result {
base_powf_fmt(f, item)?;
match item {
Item::Vec4(elem) => f.write_fmt(format_args!(
"
fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
fn powf_scalar(lhs: {item}, rhs: {elem}) -> {item} {{
return vec4(
powf_primitive(lhs[0], rhs),
powf_primitive(lhs[1], rhs),
@ -41,7 +41,7 @@ fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
)),
Item::Vec3(elem) => f.write_fmt(format_args!(
"
fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
fn powf_scalar(lhs: {item}, rhs: {elem}) -> {item} {{
return vec3(
powf_primitive(lhs[0], rhs),
powf_primitive(lhs[1], rhs),
@ -52,7 +52,7 @@ fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
)),
Item::Vec2(elem) => f.write_fmt(format_args!(
"
fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
fn powf_scalar(lhs: {item}, rhs: {elem}) -> {item} {{
return vec2(
powf_primitive(lhs[0], rhs),
powf_primitive(lhs[1], rhs),
@ -62,7 +62,7 @@ fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
)),
Item::Scalar(elem) => f.write_fmt(format_args!(
"
fn powf(lhs: {elem}, rhs: {elem}) -> {elem} {{
fn powf_scalar(lhs: {elem}, rhs: {elem}) -> {elem} {{
return powf_primitive(lhs, rhs);
}}
"
@ -70,7 +70,10 @@ fn powf(lhs: {elem}, rhs: {elem}) -> {elem} {{
}
}
fn base_powf_fmt(f: &mut std::fmt::Formatter<'_>, item: &Item) -> Result<(), std::fmt::Error> {
fn format_powf_primitive(
f: &mut std::fmt::Formatter<'_>,
item: &Item,
) -> Result<(), std::fmt::Error> {
let elem = item.elem();
f.write_fmt(format_args!(
"
@ -96,12 +99,10 @@ fn powf_primitive(lhs: {elem}, rhs: {elem}) -> {elem} {{
}
fn format_powf(f: &mut core::fmt::Formatter<'_>, item: &Item) -> core::fmt::Result {
base_powf_fmt(f, item)?;
match item {
Item::Vec4(elem) => f.write_fmt(format_args!(
Item::Vec4(_) => f.write_fmt(format_args!(
"
fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
fn powf(lhs: {item}, rhs: {item}) -> {item} {{
return vec4(
powf_primitive(lhs[0], rhs[0]),
powf_primitive(lhs[1], rhs[1]),
@ -111,9 +112,9 @@ fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
}}
"
)),
Item::Vec3(elem) => f.write_fmt(format_args!(
Item::Vec3(_) => f.write_fmt(format_args!(
"
fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
fn powf(lhs: {item}, rhs: {item}) -> {item} {{
return vec3(
powf_primitive(lhs[0], rhs[0]),
powf_primitive(lhs[1], rhs[1]),
@ -122,9 +123,9 @@ fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
}}
"
)),
Item::Vec2(elem) => f.write_fmt(format_args!(
Item::Vec2(_) => f.write_fmt(format_args!(
"
fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
fn powf(lhs: {item}, rhs: {item}) -> {item} {{
return vec2(
powf_primitive(lhs[0], rhs[0]),
powf_primitive(lhs[1], rhs[1]),

View File

@ -1,15 +1,28 @@
use super::base::{Item, Variable};
use std::fmt::Display;
/// All operations that can be used in a WGSL compute shader.
/// All instructions that can be used in a WGSL compute shader.
#[derive(Debug, Clone)]
#[allow(dead_code)] // Some variants might not be used with different flags
pub enum Operation {
pub enum Instruction {
DeclareVariable {
var: Variable,
},
Add {
lhs: Variable,
rhs: Variable,
out: Variable,
},
Index {
lhs: Variable,
rhs: Variable,
out: Variable,
},
Modulo {
lhs: Variable,
rhs: Variable,
out: Variable,
},
Sub {
lhs: Variable,
rhs: Variable,
@ -115,72 +128,99 @@ pub enum Operation {
input: Variable,
out: Variable,
},
ReadGlobal {
variable: Variable,
Stride {
dim: Variable,
position: usize,
out: Variable,
},
/// Read the tensor in a way to be compatible with another tensor layout.
ReadGlobalWithLayout {
variable: Variable,
tensor_read_pos: usize,
tensor_layout_pos: usize,
Shape {
dim: Variable,
position: usize,
out: Variable,
},
RangeLoop {
i: Variable,
start: Variable,
end: Variable,
instructions: Vec<Instruction>,
},
}
impl Display for Operation {
impl Display for Instruction {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Operation::Add { lhs, rhs, out } => {
f.write_fmt(format_args!("let {out} = {lhs} + {rhs};"))
Instruction::DeclareVariable { var } => {
let item = var.item();
f.write_fmt(format_args!("var {var}: {item};\n"))
}
Operation::Sub { lhs, rhs, out } => {
f.write_fmt(format_args!("let {out} = {lhs} - {rhs};"))
Instruction::Add { lhs, rhs, out } => {
f.write_fmt(format_args!("{out} = {lhs} + {rhs};\n"))
}
Operation::Mul { lhs, rhs, out } => {
f.write_fmt(format_args!("let {out} = {lhs} * {rhs};"))
Instruction::Index { lhs, rhs, out } => {
let item = out.item();
let lhs = match lhs {
Variable::GlobalInputArray(index, _) => format!("input_{index}_global"),
Variable::GlobalOutputArray(index, _) => format!("output_{index}_global"),
_ => format!("{lhs}"),
};
f.write_fmt(format_args!("{out} = {item}({lhs}[{rhs}]);\n"))
}
Operation::Div { lhs, rhs, out } => {
f.write_fmt(format_args!("let {out} = {lhs} / {rhs};"))
Instruction::Modulo { lhs, rhs, out } => {
f.write_fmt(format_args!("{out} = {lhs} % {rhs};\n"))
}
Operation::Abs { input, out } => f.write_fmt(format_args!("let {out} = abs({input});")),
Operation::Exp { input, out } => f.write_fmt(format_args!("let {out} = exp({input});")),
Operation::Log { input, out } => f.write_fmt(format_args!("let {out} = log({input});")),
Operation::Clamp {
Instruction::Sub { lhs, rhs, out } => {
f.write_fmt(format_args!("{out} = {lhs} - {rhs};\n"))
}
Instruction::Mul { lhs, rhs, out } => {
f.write_fmt(format_args!("{out} = {lhs} * {rhs};\n"))
}
Instruction::Div { lhs, rhs, out } => {
f.write_fmt(format_args!("{out} = {lhs} / {rhs};\n"))
}
Instruction::Abs { input, out } => f.write_fmt(format_args!("{out} = abs({input});\n")),
Instruction::Exp { input, out } => f.write_fmt(format_args!("{out} = exp({input});\n")),
Instruction::Log { input, out } => f.write_fmt(format_args!("{out} = log({input});\n")),
Instruction::Clamp {
input,
min_value,
max_value,
out,
} => f.write_fmt(format_args!(
"let {out} = clamp({input}, {min_value}, {max_value});"
"{out} = clamp({input}, {min_value}, {max_value});\n"
)),
Operation::Powf { lhs, rhs, out } => {
f.write_fmt(format_args!("let {out} = powf({lhs}, {rhs});"))
Instruction::Powf { lhs, rhs, out } => {
if rhs.is_always_scalar() {
f.write_fmt(format_args!("{out} = powf_scalar({lhs}, {rhs});\n"))
} else {
f.write_fmt(format_args!("{out} = powf({lhs}, {rhs});\n"))
}
}
Operation::Sqrt { input, out } => {
f.write_fmt(format_args!("let {out} = sqrt({input});"))
Instruction::Sqrt { input, out } => {
f.write_fmt(format_args!("{out} = sqrt({input});\n"))
}
Operation::Log1p { input, out } => {
f.write_fmt(format_args!("let {out} = log({input} + 1.0);"))
Instruction::Log1p { input, out } => {
f.write_fmt(format_args!("{out} = log({input} + 1.0);\n"))
}
Operation::Cos { input, out } => f.write_fmt(format_args!("let {out} = cos({input});")),
Operation::Sin { input, out } => f.write_fmt(format_args!("let {out} = sin({input});")),
Operation::Tanh { input, out } => {
Instruction::Cos { input, out } => f.write_fmt(format_args!("{out} = cos({input});\n")),
Instruction::Sin { input, out } => f.write_fmt(format_args!("{out} = sin({input});\n")),
Instruction::Tanh { input, out } => {
#[cfg(target_os = "macos")]
let result = f.write_fmt(format_args!("let {out} = safe_tanh({input});"));
let result = f.write_fmt(format_args!("{out} = safe_tanh({input});\n"));
#[cfg(not(target_os = "macos"))]
let result = f.write_fmt(format_args!("let {out} = tanh({input});"));
let result = f.write_fmt(format_args!("{out} = tanh({input});\n"));
result
}
Operation::Erf { input, out } => f.write_fmt(format_args!("let {out} = erf({input});")),
Operation::Recip { input, out } => {
f.write_fmt(format_args!("let {out} = 1.0 / {input};"))
Instruction::Erf { input, out } => f.write_fmt(format_args!("{out} = erf({input});\n")),
Instruction::Recip { input, out } => {
f.write_fmt(format_args!("{out} = 1.0 / {input};"))
}
Operation::Equal { lhs, rhs, out } => comparison(lhs, rhs, out, "==", f),
Operation::Lower { lhs, rhs, out } => comparison(lhs, rhs, out, "<", f),
Operation::Greater { lhs, rhs, out } => comparison(lhs, rhs, out, ">", f),
Operation::LowerEqual { lhs, rhs, out } => comparison(lhs, rhs, out, "<=", f),
Operation::GreaterEqual { lhs, rhs, out } => comparison(lhs, rhs, out, ">=", f),
Operation::AssignGlobal { input, out } => {
Instruction::Equal { lhs, rhs, out } => comparison(lhs, rhs, out, "==", f),
Instruction::Lower { lhs, rhs, out } => comparison(lhs, rhs, out, "<", f),
Instruction::Greater { lhs, rhs, out } => comparison(lhs, rhs, out, ">", f),
Instruction::LowerEqual { lhs, rhs, out } => comparison(lhs, rhs, out, "<=", f),
Instruction::GreaterEqual { lhs, rhs, out } => comparison(lhs, rhs, out, ">=", f),
Instruction::AssignGlobal { input, out } => {
let elem_out = out.item();
let elem_in = input.item();
@ -211,76 +251,18 @@ impl Display for Operation {
);"
)),
Item::Scalar(elem) => {
f.write_fmt(format_args!("{out}_global[id] = {elem}({input});"))
f.write_fmt(format_args!("{out}_global[id] = {elem}({input});\n"))
}
}
} else {
f.write_fmt(format_args!("{out}_global[id] = {elem_out}({input});"))
f.write_fmt(format_args!("{out}_global[id] = {elem_out}({input});\n"))
}
}
Operation::AssignLocal { input, out } => {
let elem = out.item();
f.write_fmt(format_args!("let {out} = {elem}({input});"))
Instruction::AssignLocal { input, out } => {
let item = out.item();
f.write_fmt(format_args!("{out} = {item}({input});\n"))
}
Operation::ReadGlobal { variable } => match variable {
Variable::Input(number, _elem) => f.write_fmt(format_args!(
"let input_{number} = input_{number}_global[id];"
)),
Variable::Local(_, _) => panic!("can't read global local variable."),
Variable::Output(number, _elem) => f.write_fmt(format_args!(
"let output_{number} = output_{number}_global[id];"
)),
Variable::Scalar(_, _, _) => panic!("Can't read global scalar variable."),
Variable::Constant(_, _) => panic!("Can't read global constant variable."),
},
Operation::ReadGlobalWithLayout {
variable,
tensor_read_pos: position,
tensor_layout_pos: position_out,
} => {
let (global, local, elem) = match variable {
Variable::Input(number, elem) => (
format!("input_{number}_global"),
format!("input_{number}"),
elem,
),
Variable::Local(_, _) => panic!("can't read global local variable."),
Variable::Output(number, elem) => (
format!("output_{number}_global"),
format!("output_{number}"),
elem,
),
Variable::Scalar(_, _, _) => panic!("Can't read global scalar variable."),
Variable::Constant(_, _) => panic!("Can't read global constant variable."),
};
let offset = match elem {
Item::Vec4(_) => 4,
Item::Vec3(_) => 3,
Item::Vec2(_) => 2,
Item::Scalar(_) => 1,
};
f.write_fmt(format_args!(
"
var index_{local}: u32 = 0u;
for (var i: u32 = 1u; i <= rank; i++) {{
let position = {position}u * (2u * rank);
let position_out = {position_out}u * (2u * rank);
let stride = info[position + i];
let stride_out = info[position_out + i];
let shape = info[position + rank + i];
index_{local} += (id * {offset}u) / stride_out % shape * stride;
}}
let {local} = {elem}({global}[index_{local} / {offset}u]);
"
))
}
Operation::ConditionalAssign {
Instruction::ConditionalAssign {
cond,
lhs,
rhs,
@ -301,7 +283,6 @@ let {local} = {elem}({global}[index_{local} / {offset}u]);
f.write_fmt(format_args!(
"
var {out}: {elem};
if {cond}[0] {{
{out}[0] = {lhs0};
}} else {{
@ -335,7 +316,6 @@ if {cond}[3] {{
f.write_fmt(format_args!(
"
var {out}: {elem};
if {cond}[0] {{
{out}[0] = {lhs0};
}} else {{
@ -362,7 +342,6 @@ if {cond}[2] {{
f.write_fmt(format_args!(
"
var {out}: {elem};
if {cond}[0] {{
{out}[0] = {lhs0};
}} else {{
@ -378,7 +357,6 @@ if {cond}[1] {{
}
Item::Scalar(_) => f.write_fmt(format_args!(
"
var {out}: {elem};
if {cond} {{
{out} = {lhs};
}} else {{
@ -388,6 +366,29 @@ if {cond} {{
)),
}
}
Instruction::Stride { dim, position, out } => f.write_fmt(format_args!(
"{out} = info[({position}u * (2u * rank)) + {dim} + 1u];\n"
)),
Instruction::Shape { dim, position, out } => f.write_fmt(format_args!(
"{out} = info[({position}u * (2u * rank)) + rank + {dim} + 1u];\n"
)),
Instruction::RangeLoop {
i,
start,
end,
instructions,
} => {
f.write_fmt(format_args!(
"
for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{
"
))?;
for instruction in instructions {
f.write_fmt(format_args!("{instruction}"))?;
}
f.write_str("}\n")
}
}
}
}
@ -412,7 +413,7 @@ fn comparison(
f.write_fmt(format_args!(
"
let {out} = vec4({lhs0} {op} {rhs0}, {lhs1} {op} {rhs1}, {lhs2} {op} {rhs2}, {lhs3} {op} {rhs3});
{out} = vec4({lhs0} {op} {rhs0}, {lhs1} {op} {rhs1}, {lhs2} {op} {rhs2}, {lhs3} {op} {rhs3});
"
))
}
@ -426,7 +427,7 @@ let {out} = vec4({lhs0} {op} {rhs0}, {lhs1} {op} {rhs1}, {lhs2} {op} {rhs2}, {lh
f.write_fmt(format_args!(
"
let {out} = vec3({lhs0} {op} {rhs0}, {lhs1} {op} {rhs1}, {lhs2} {op} {rhs2});
{out} = vec3({lhs0} {op} {rhs0}, {lhs1} {op} {rhs1}, {lhs2} {op} {rhs2});
"
))
}
@ -438,12 +439,12 @@ let {out} = vec3({lhs0} {op} {rhs0}, {lhs1} {op} {rhs1}, {lhs2} {op} {rhs2});
f.write_fmt(format_args!(
"
let {out} = vec2({lhs0} {op} {rhs0}, {lhs1} {op} {rhs1});
{out} = vec2({lhs0} {op} {rhs0}, {lhs1} {op} {rhs1});
"
))
}
Item::Scalar(_) => match rhs.item() {
Item::Scalar(_) => f.write_fmt(format_args!("let {out} = {lhs} {op} {rhs};")),
Item::Scalar(_) => f.write_fmt(format_args!("{out} = {lhs} {op} {rhs};\n")),
_ => panic!("Can only compare a scalar when the output is a scalar"),
},
}

View File

@ -1,4 +1,4 @@
use super::{Body, Extension, Item};
use super::{Extension, Item, Scope};
use crate::codegen::dialect::gpu::WorkgroupSize;
use std::fmt::Display;
@ -31,7 +31,7 @@ pub struct ComputeShader {
pub workgroup_size: WorkgroupSize,
pub global_invocation_id: bool,
pub num_workgroups: bool,
pub body: Body,
pub body: Scope,
pub extensions: Vec<Extension>,
}

View File

@ -1,343 +1,8 @@
use burn_compute::client::ComputeClient;
use crate::codegen::dialect::gpu::{
Binding, Body, ComputeShader, Elem, Item, Location, Operation, ReadGlobalOperation,
ReadGlobalWithLayoutOperation, UnaryOperation, Variable, Vectorization, Visibility,
WorkgroupSize,
};
use crate::compute::StaticKernel;
use crate::element::JitElement;
use crate::kernel::{elemwise_workgroup, StaticKernelSource, WORKGROUP_DEFAULT};
use crate::Runtime;
use std::marker::PhantomData;
/// Kernel creation input phase, see [kernel codegen](ElemWiseKernelCodegen) for more details.
pub struct InputPhase;
/// Kernel creation body phase, see [kernel codegen](ElemWiseKernelCodegen) for more details.
pub struct BodyPhase;
/// Kernel creation output phase, see [kernel codegen](ElemWiseKernelCodegen) for more details.
pub struct OutputPhase;
/// Kernel compilation phase, see [kernel codegen](ElemWiseKernelCodegen) for more details.
pub struct CompilationPhase;
#[derive(new, Clone, Copy)]
pub struct InplaceMapping {
pub position_input: usize,
pub position_output: usize,
}
/// Allows to create custom wgsl kernels based on configured inputs, body and outputs.
///
/// This type has 4 phases that must be executed in order, but no worry the type system won't allow
/// you to make mistakes.
///
/// 1. [Input Phase](InputPhase)
/// This phase focuses on registering the input arrays and scalars that are going to be used by
/// the kernel.
/// 2. [Body Phase](BodyPhase)
/// After the input phase is done, all the operations that happen in the body must be
/// registered.
/// 3. [Output Phase](OutputPhase)
/// This step focuses on registering all output arrays or inputs that the kernel needs to write to.
/// 4. [Compilation Phase](CompilationPhase)
/// Now that all other phases are completed, we can actually compile the kernel.
pub struct ElemWiseKernelCodegen<Phase = InputPhase> {
operations: Vec<Operation>,
input_bindings: Vec<Binding>,
output_bindings: Vec<Binding>,
named_bindings: Vec<(String, Binding)>,
vectorization: Vectorization,
mappings_inplace: Vec<InplaceMapping>,
workgroup_size: WorkgroupSize,
_phase: PhantomData<Phase>,
}
pub enum Input {
Array {
item: Item,
visibility: Visibility,
strategy: ReadingStrategy,
},
Scalar {
elem: Elem,
size: usize,
},
}
pub enum ReadingStrategy {
/// Each element will be read in a way to be compatible with the output layout.
OutputLayout,
/// Keep the current layout.
Plain,
}
#[derive(Clone)]
pub enum Output {
Array { item: Item, local: u16 },
Input { item: Item, input: u16, local: u16 },
}
impl Default for ElemWiseKernelCodegen<InputPhase> {
fn default() -> Self {
Self {
operations: Vec::new(),
input_bindings: Vec::new(),
output_bindings: Vec::new(),
named_bindings: Vec::new(),
vectorization: Vectorization::Scalar,
mappings_inplace: Vec::new(),
workgroup_size: WorkgroupSize::default(),
_phase: PhantomData,
}
}
}
impl ElemWiseKernelCodegen<InputPhase> {
pub fn new() -> Self {
Self::default()
}
#[allow(dead_code)]
pub fn vectorize(mut self, vectorization: Vectorization) -> Self {
self.vectorization = vectorization;
self
}
#[allow(dead_code)]
pub fn inplace(mut self, mappings: &[InplaceMapping]) -> Self {
self.mappings_inplace = mappings.to_vec();
self
}
/// Register the inputs used by the kernel.
pub fn inputs(mut self, inputs: &[Input]) -> ElemWiseKernelCodegen<BodyPhase> {
let mut index: u16 = 0;
for input in inputs {
match input {
Input::Array {
item,
visibility,
strategy,
} => {
let item = item.vectorize(self.vectorization);
self.input_bindings.push(Binding {
item: bool_item(item),
visibility: *visibility,
location: Location::Storage,
size: None,
});
match strategy {
ReadingStrategy::OutputLayout => {
self.operations.push(Operation::ReadGlobalWithLayout(
ReadGlobalWithLayoutOperation {
variable: Variable::Input(index, item),
tensor_read_pos: index as usize,
tensor_layout_pos: 0, // Will set the right value during the output
// phase.
},
));
}
ReadingStrategy::Plain => {
self.operations
.push(Operation::ReadGlobal(ReadGlobalOperation {
variable: Variable::Input(index, item),
}));
}
}
index += 1;
}
Input::Scalar { elem, size } => {
let elem = bool_elem(*elem);
self.named_bindings.push((
format!("scalars_{}", elem),
Binding {
item: Item::Scalar(elem),
visibility: Visibility::Read,
location: Location::Storage,
size: Some(*size),
},
));
}
}
}
ElemWiseKernelCodegen {
operations: self.operations,
input_bindings: self.input_bindings,
output_bindings: self.output_bindings,
named_bindings: self.named_bindings,
vectorization: self.vectorization,
mappings_inplace: self.mappings_inplace,
workgroup_size: self.workgroup_size,
_phase: PhantomData,
}
}
}
impl ElemWiseKernelCodegen<BodyPhase> {
/// Register the [operators](Operator) that the kernel must execute in the order provided.
pub fn body(mut self, operators: &[Operation]) -> ElemWiseKernelCodegen<OutputPhase> {
for ops in operators.iter() {
self.operations.push(ops.vectorize(self.vectorization));
}
ElemWiseKernelCodegen {
operations: self.operations,
input_bindings: self.input_bindings,
output_bindings: self.output_bindings,
named_bindings: self.named_bindings,
vectorization: self.vectorization,
mappings_inplace: self.mappings_inplace,
workgroup_size: self.workgroup_size,
_phase: PhantomData,
}
}
}
impl ElemWiseKernelCodegen<OutputPhase> {
/// Register the outputs with their local variable index.
///
/// Note that the index corresponds to the registered [operator](Operator) number at the
/// [body phase](BodyPhase).
/// So the 4th operator registered creates the local variable 3 (N-1, since the 1th index is 0).
pub fn outputs(mut self, outputs: &[Output]) -> ElemWiseKernelCodegen<CompilationPhase> {
let mut index = 0;
let mut position_out = 0;
let mut outputs = outputs.to_vec();
for mapping in self.mappings_inplace.iter() {
match outputs.get_mut(mapping.position_output) {
Some(output) => match output {
Output::Array { item, local } => {
*output = Output::Input {
item: *item,
input: mapping.position_input as u16,
local: *local,
};
}
Output::Input {
item: _,
input: _,
local: _,
} => continue,
},
None => continue,
}
if let Some(binding) = self.input_bindings.get_mut(mapping.position_input) {
binding.visibility = Visibility::ReadWrite
}
}
for array in &outputs {
match array {
Output::Array { item, local } => {
let item = item.vectorize(self.vectorization);
let elem_adapted = bool_item(item);
self.output_bindings.push(Binding {
item: elem_adapted,
visibility: Visibility::ReadWrite,
location: Location::Storage,
size: None,
});
self.operations
.push(Operation::AssignGlobal(UnaryOperation {
input: Variable::Local(*local, item),
out: Variable::Output(index, elem_adapted),
}));
index += 1;
if index == 1 {
position_out = self.input_bindings.len(); // First output when we have a
// new array for the output.
}
}
Output::Input { item, input, local } => {
let item = item.vectorize(self.vectorization);
self.operations
.push(Operation::AssignGlobal(UnaryOperation {
input: Variable::Local(*local, item),
out: Variable::Input(*input, bool_item(item)),
}));
position_out = *input as usize; // Input number when we use inplace operation.
}
}
}
// We set the output number that will be used for the stride definition.
for i in 0..self.input_bindings.len() {
if let Some(Operation::ReadGlobalWithLayout(ReadGlobalWithLayoutOperation {
variable: _,
tensor_read_pos: _,
tensor_layout_pos,
})) = self.operations.get_mut(i)
{
{
*tensor_layout_pos = position_out;
}
};
}
ElemWiseKernelCodegen {
operations: self.operations,
input_bindings: self.input_bindings,
output_bindings: self.output_bindings,
named_bindings: self.named_bindings,
vectorization: self.vectorization,
mappings_inplace: self.mappings_inplace,
workgroup_size: self.workgroup_size,
_phase: PhantomData,
}
}
}
impl ElemWiseKernelCodegen<CompilationPhase> {
#[allow(dead_code)] // Only used for fusion for now.
pub fn workgroup_size(mut self, workgroup_size: WorkgroupSize) -> Self {
self.workgroup_size = workgroup_size;
self
}
/// Compile the kernel into a [compute shader](ComputeShader).
pub fn compile(self) -> ComputeShader {
let inputs = self.input_bindings;
let outputs = self.output_bindings;
let mut named = Vec::with_capacity(2);
named.push((
"info".to_string(),
Binding {
item: Item::Scalar(Elem::UInt),
visibility: Visibility::Read,
location: Location::Storage,
size: None, // We avoid putting the length here since it will force a new kernel
// for each tensor rank.
},
));
for (name, binding) in self.named_bindings.into_iter() {
named.push((name, binding));
}
ComputeShader {
inputs,
outputs,
named,
workgroup_size: self.workgroup_size,
body: Body::new(self.operations),
num_workgroups: true,
global_invocation_id: true,
}
}
}
use burn_compute::client::ComputeClient;
#[derive(new)]
pub struct StaticHandle<'a, R: Runtime> {
@ -433,20 +98,3 @@ pub(crate) fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize {
}
num_elems
}
fn bool_item(ty: Item) -> Item {
match ty {
Item::Vec4(elem) => Item::Vec4(bool_elem(elem)),
Item::Vec3(elem) => Item::Vec3(bool_elem(elem)),
Item::Vec2(elem) => Item::Vec2(bool_elem(elem)),
Item::Scalar(elem) => Item::Scalar(bool_elem(elem)),
}
}
fn bool_elem(elem: Elem) -> Elem {
match elem {
// U32 are used for bool tensors
Elem::Bool => Elem::UInt,
_ => elem,
}
}

View File

@ -1,7 +1,9 @@
mod compilation;
pub(crate) mod compiler;
pub(crate) mod dialect;
mod kernel;
pub(crate) use compilation::*;
pub(crate) use compiler::*;
pub(crate) use kernel::*;

View File

@ -75,7 +75,7 @@ mod tests {
use super::*;
use crate::{
binary,
codegen::dialect::gpu::{BinaryOperation, Elem, Item, Operation, Variable},
codegen::dialect::gpu::{BinaryOperator, Elem, Operator, Scope},
kernel::{KernelSettings, WORKGROUP_DEFAULT},
tests::{TestCompiler, TestRuntime},
Runtime, WgpuDevice,
@ -84,10 +84,10 @@ mod tests {
#[test]
fn can_run_kernel() {
binary!(
operation: |elem: Elem| Operation::Add(BinaryOperation {
lhs: Variable::Input(0, Item::Scalar(elem)),
rhs: Variable::Input(1, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::Add(BinaryOperator {
lhs: scope.read_array(0, elem),
rhs: scope.read_array(1, elem),
out: scope.create_local(elem),
}),
compiler: TestCompiler,
elem_in: f32,

View File

@ -70,7 +70,7 @@ impl<R: Runtime> FusionBackend for JitBackend<R> {
type OptimizationState = WgpuOptimizationState;
type Optimization = WgpuOptimization<R>;
type FusionDevice = R::Device;
type Handle = WgpuFusionHandle<R>;
type Handle = JitFusionHandle<R>;
type FusionClient = MutexFusionClient<Self>;
fn optimizations(
@ -126,7 +126,7 @@ pub fn strides_dyn_rank(shape: &[usize]) -> Vec<usize> {
}
/// Handle to be used when fusing operations.
pub struct WgpuFusionHandle<R: Runtime> {
pub struct JitFusionHandle<R: Runtime> {
/// Compute client for wgpu.
pub client: ComputeClient<R::Server, R::Channel>,
/// The buffer where the data are stored.
@ -136,13 +136,17 @@ pub struct WgpuFusionHandle<R: Runtime> {
pub(crate) strides: Vec<usize>,
}
impl<R: Runtime> core::fmt::Debug for WgpuFusionHandle<R> {
fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
todo!()
impl<R: Runtime> core::fmt::Debug for JitFusionHandle<R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!(
"JitFusionHandle {{ device: {:?}, runtime: {}}}",
self.device,
R::name(),
))
}
}
impl<R: Runtime> Clone for WgpuFusionHandle<R> {
impl<R: Runtime> Clone for JitFusionHandle<R> {
fn clone(&self) -> Self {
Self {
client: self.client.clone(),
@ -153,10 +157,10 @@ impl<R: Runtime> Clone for WgpuFusionHandle<R> {
}
}
unsafe impl<R: Runtime> Send for WgpuFusionHandle<R> {}
unsafe impl<R: Runtime> Sync for WgpuFusionHandle<R> {}
unsafe impl<R: Runtime> Send for JitFusionHandle<R> {}
unsafe impl<R: Runtime> Sync for JitFusionHandle<R> {}
impl<R: Runtime> WgpuFusionHandle<R> {
impl<R: Runtime> JitFusionHandle<R> {
pub(crate) fn into_tensor<const D: usize, E: JitElement>(
self,
shape: Shape<D>,
@ -172,7 +176,7 @@ impl<R: Runtime> WgpuFusionHandle<R> {
}
}
impl<R: Runtime, E: JitElement, const D: usize> From<JitTensor<R, E, D>> for WgpuFusionHandle<R> {
impl<R: Runtime, E: JitElement, const D: usize> From<JitTensor<R, E, D>> for JitFusionHandle<R> {
fn from(value: JitTensor<R, E, D>) -> Self {
Self {
client: value.client,

View File

@ -1,11 +1,10 @@
use super::{optimization::ElementWise, CompilationPhase, Scalars};
use super::{optimization::ElementWise, CompilationPhase};
use crate::{
codegen::dialect::gpu::{
BinaryOperation, ConditionalAssignOperation, Elem, Item, Operation, UnaryOperation,
Variable,
BinaryOperator, ConditionalAssignOperator, Elem, Operator, UnaryOperator, Variable,
},
element::JitElement,
fusion::WgpuOptimization,
fusion::{tracing::TraceBuilder, WgpuOptimization},
JitBackend, Runtime,
};
use burn_fusion::{
@ -14,27 +13,20 @@ use burn_fusion::{
NumericOperationDescription, OperationDescription, ScalarOperationDescription,
UnaryOperationDescription,
},
OptimizationBuilder, OptimizationProperties, OptimizationStatus, TensorDescription, TensorId,
OptimizationBuilder, OptimizationProperties, OptimizationStatus, TensorDescription,
};
use burn_tensor::{
ops::{FloatElem, IntElem},
Device, Element,
};
use hashbrown::HashMap;
/// Fused element wise operations that are normally memory bound.
pub(crate) struct ElementWiseBuilder<R: Runtime> {
pub(crate) inputs: Vec<TensorDescription>,
pub(crate) locals: HashMap<TensorId, u16>,
pub(crate) tensors: HashMap<TensorId, (TensorDescription, Elem)>,
pub(crate) scalars_float: usize,
pub(crate) scalars_int: usize,
pub(crate) scalars_uint: usize,
pub(crate) booleans: usize,
pub(crate) operators: Vec<Operation>,
pub(crate) current_output_shape: Vec<usize>,
pub(crate) status: OptimizationStatus,
pub(crate) device: R::Device,
builder: TraceBuilder,
current_output_shape: Vec<usize>,
status: OptimizationStatus,
num_added: usize,
device: R::Device,
}
impl<R: Runtime> OptimizationBuilder<WgpuOptimization<R>> for ElementWiseBuilder<R> {
@ -81,22 +73,13 @@ impl<R: Runtime> OptimizationBuilder<WgpuOptimization<R>> for ElementWiseBuilder
};
self.status = OptimizationStatus::Open;
self.num_added += 1;
}
fn build(&self) -> WgpuOptimization<R> {
let inputs = self.input_descriptions();
let outputs = self.output_descriptions();
let locals = outputs
.iter()
.map(|out| *self.locals.get(&out.0.id).unwrap())
.collect::<Vec<_>>();
let op = ElementWise::new(
inputs,
outputs,
locals,
Scalars::new(self.scalars_float, self.scalars_uint, self.scalars_int),
self.operators.clone(),
self.builder.clone().build(),
self.num_added,
self.device.clone(),
CompilationPhase,
);
@ -105,18 +88,12 @@ impl<R: Runtime> OptimizationBuilder<WgpuOptimization<R>> for ElementWiseBuilder
}
fn len(&self) -> usize {
self.operators.len()
self.num_added
}
fn reset(&mut self) {
self.inputs.clear();
self.locals.drain();
self.tensors.clear();
self.scalars_float = 0;
self.scalars_int = 0;
self.scalars_uint = 0;
self.booleans = 0;
self.operators.clear();
self.builder = TraceBuilder::new();
self.num_added = 0;
self.status = OptimizationStatus::Open;
self.current_output_shape.clear();
}
@ -126,11 +103,11 @@ impl<R: Runtime> OptimizationBuilder<WgpuOptimization<R>> for ElementWiseBuilder
}
fn properties(&self) -> OptimizationProperties {
let ready = !self.operators.is_empty();
let ready = self.num_added > 0;
OptimizationProperties {
ready,
score: self.operators.len() as u64,
score: self.num_added as u64,
}
}
}
@ -138,269 +115,20 @@ impl<R: Runtime> OptimizationBuilder<WgpuOptimization<R>> for ElementWiseBuilder
impl<R: Runtime> ElementWiseBuilder<R> {
pub fn new(device: Device<JitBackend<R>>) -> Self {
Self {
inputs: Vec::new(),
locals: HashMap::new(),
tensors: HashMap::new(),
scalars_float: 0,
scalars_int: 0,
scalars_uint: 0,
booleans: 0,
operators: Vec::new(),
builder: TraceBuilder::new(),
num_added: 0,
current_output_shape: Vec::new(),
status: OptimizationStatus::Open,
device,
}
}
fn input_descriptions(&self) -> Vec<(TensorDescription, Elem)> {
self.inputs
.iter()
.map(|input| {
let updated_tensor = self.tensors.get(&input.id).unwrap();
updated_tensor.clone()
})
.collect::<Vec<_>>()
}
fn output_descriptions(&self) -> Vec<(TensorDescription, Elem)> {
let mut outputs = Vec::new();
let mut local_tensor_ids_input = Vec::new();
let mut local_tensor_ids_output = Vec::new();
// Mark a variable to the provided list of tensor ids using the variable list.
//
// Only local variables can become outputs.
let mark = |var: &Variable, list: &mut Vec<TensorId>| {
if let Variable::Local(index, _) = var {
if let Some((id, _)) = self
.locals
.iter()
.find(|(_id, position)| *position == index)
{
if !list.contains(id) {
list.push(*id);
}
}
}
};
let mark_binary =
|op: &BinaryOperation, inputs: &mut Vec<TensorId>, outputs: &mut Vec<TensorId>| {
mark(&op.lhs, inputs);
mark(&op.rhs, inputs);
mark(&op.out, outputs);
};
let mark_unary =
|op: &UnaryOperation, inputs: &mut Vec<TensorId>, outputs: &mut Vec<TensorId>| {
mark(&op.input, inputs);
mark(&op.out, outputs);
};
// For all operators, mark their local tensor id in the proper set.
for ops in self.operators.iter() {
match ops {
Operation::AssignGlobal(_) => {
// Nothing to do here.
}
Operation::AssignLocal(op) => {
mark(&op.out, &mut local_tensor_ids_output);
}
Operation::ReadGlobalWithLayout(_) => {
// Nothing to do here.
}
Operation::ReadGlobal(_) => {
// Nothing to do here.
}
Operation::Add(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operation::Sub(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operation::Mul(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operation::Div(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operation::Exp(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operation::Abs(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operation::Erf(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operation::Log(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operation::Log1p(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operation::Cos(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operation::Sin(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operation::Tanh(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operation::Clamp(op) => {
mark(&op.input, &mut local_tensor_ids_input);
mark(&op.out, &mut local_tensor_ids_output);
}
Operation::Powf(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operation::Recip(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operation::Lower(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operation::Greater(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operation::LowerEqual(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operation::GreaterEqual(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operation::Equal(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
Operation::ConditionalAssign(op) => {
mark(&op.cond, &mut local_tensor_ids_input);
mark(&op.lhs, &mut local_tensor_ids_input);
mark(&op.rhs, &mut local_tensor_ids_input);
mark(&op.out, &mut local_tensor_ids_output);
}
Operation::Sqrt(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
}
}
// All output tensors that are never read by a following operation should be written to
// since they are essentially the "logical" output of the shader.
for out in local_tensor_ids_output {
let is_read = local_tensor_ids_input.contains(&out);
if !is_read {
outputs.push(self.tensors.get(&out).unwrap().clone());
}
}
// All tensors where their latest description is read only should be written to since they
// are going to be used after the fused kernel by other operations.
for entry in self.tensors.values() {
let (tensor, _) = &entry;
if let burn_fusion::TensorStatus::ReadOnly = tensor.status {
if self.locals.contains_key(&tensor.id) {
outputs.push(entry.clone());
}
}
}
outputs
}
fn input_to_var(&mut self, tensor: &TensorDescription, elem: Elem) -> Variable {
let already_exists = self.tensors.contains_key(&tensor.id);
let variable = match already_exists {
false => {
// New input
let var = Variable::Input(self.inputs.len() as u16, Item::Scalar(elem));
self.inputs.push(tensor.clone());
var
}
true => match self.locals.get(&tensor.id) {
// Is a local variable.
Some(local_index) => Variable::Local(*local_index, Item::Scalar(elem)),
// Isn't a local variable, so must be an existing input.
None => {
let input = self
.inputs
.iter()
.enumerate()
.find(|(_, input)| input.id == tensor.id)
.unwrap();
let input_index = input.0;
Variable::Input(input_index as u16, Item::Scalar(elem))
}
},
};
// Update the tensor description with the new version.
self.tensors.insert(tensor.id, (tensor.clone(), elem));
variable
}
fn output_to_var(&mut self, tensor: &TensorDescription, elem: Elem) -> Variable {
// Update the tensor description to the new version.
self.tensors.insert(tensor.id, (tensor.clone(), elem));
// Output already registered as a local variable.
if let Some(index) = self.locals.get(&tensor.id) {
return Variable::Local(*index, Item::Scalar(elem));
}
// New local variable.
let local_index = self.locals.len() as u16;
self.locals.insert(tensor.id, local_index);
Variable::Local(local_index, Item::Scalar(elem))
}
fn register_base<E: JitElement>(&mut self, ops: &BaseOperationDescription) -> bool {
match ops {
BaseOperationDescription::Equal(desc) => self.register_binary_ops(
desc,
(E::gpu_elem(), E::gpu_elem(), Elem::Bool),
|lhs, rhs, out| Operation::Equal(BinaryOperation { lhs, rhs, out }),
|lhs, rhs, out| Operator::Equal(BinaryOperator { lhs, rhs, out }),
),
_ => false,
}
@ -410,47 +138,47 @@ impl<R: Runtime> ElementWiseBuilder<R> {
match ops {
FloatOperationDescription::Exp(desc) => {
self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| {
Operation::Exp(UnaryOperation { input, out })
Operator::Exp(UnaryOperator { input, out })
})
}
FloatOperationDescription::Log(desc) => {
self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| {
Operation::Log(UnaryOperation { input, out })
Operator::Log(UnaryOperator { input, out })
})
}
FloatOperationDescription::Log1p(desc) => {
self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| {
Operation::Log1p(UnaryOperation { input, out })
Operator::Log1p(UnaryOperator { input, out })
})
}
FloatOperationDescription::Cos(desc) => {
self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| {
Operation::Cos(UnaryOperation { input, out })
Operator::Cos(UnaryOperator { input, out })
})
}
FloatOperationDescription::Sin(desc) => {
self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| {
Operation::Sin(UnaryOperation { input, out })
Operator::Sin(UnaryOperator { input, out })
})
}
FloatOperationDescription::PowfScalar(desc) => self.register_scalar_ops(
desc,
(E::gpu_elem(), E::gpu_elem(), E::gpu_elem()),
|lhs, rhs, out| Operation::Powf(BinaryOperation { lhs, rhs, out }),
|lhs, rhs, out| Operator::Powf(BinaryOperator { lhs, rhs, out }),
),
FloatOperationDescription::Tanh(desc) => {
self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| {
Operation::Tanh(UnaryOperation { input, out })
Operator::Tanh(UnaryOperator { input, out })
})
}
FloatOperationDescription::Erf(desc) => {
self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| {
Operation::Erf(UnaryOperation { input, out })
Operator::Erf(UnaryOperator { input, out })
})
}
FloatOperationDescription::Recip(desc) => {
self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| {
Operation::Recip(UnaryOperation { input, out })
Operator::Recip(UnaryOperator { input, out })
})
}
_ => false,
@ -465,110 +193,111 @@ impl<R: Runtime> ElementWiseBuilder<R> {
NumericOperationDescription::Add(desc) => self.register_binary_ops(
desc,
(E::gpu_elem(), E::gpu_elem(), E::gpu_elem()),
|lhs, rhs, out| Operation::Add(BinaryOperation { lhs, rhs, out }),
|lhs, rhs, out| Operator::Add(BinaryOperator { lhs, rhs, out }),
),
NumericOperationDescription::AddScalar(desc) => self.register_scalar_ops(
desc,
(E::gpu_elem(), E::gpu_elem(), E::gpu_elem()),
|lhs, rhs, out| Operation::Add(BinaryOperation { lhs, rhs, out }),
|lhs, rhs, out| Operator::Add(BinaryOperator { lhs, rhs, out }),
),
NumericOperationDescription::Sub(desc) => self.register_binary_ops(
desc,
(E::gpu_elem(), E::gpu_elem(), E::gpu_elem()),
|lhs, rhs, out| Operation::Sub(BinaryOperation { lhs, rhs, out }),
|lhs, rhs, out| Operator::Sub(BinaryOperator { lhs, rhs, out }),
),
NumericOperationDescription::SubScalar(desc) => self.register_scalar_ops(
desc,
(E::gpu_elem(), E::gpu_elem(), E::gpu_elem()),
|lhs, rhs, out| Operation::Sub(BinaryOperation { lhs, rhs, out }),
|lhs, rhs, out| Operator::Sub(BinaryOperator { lhs, rhs, out }),
),
NumericOperationDescription::Mul(desc) => self.register_binary_ops(
desc,
(E::gpu_elem(), E::gpu_elem(), E::gpu_elem()),
|lhs, rhs, out| Operation::Mul(BinaryOperation { lhs, rhs, out }),
|lhs, rhs, out| Operator::Mul(BinaryOperator { lhs, rhs, out }),
),
NumericOperationDescription::MulScalar(desc) => self.register_scalar_ops(
desc,
(E::gpu_elem(), E::gpu_elem(), E::gpu_elem()),
|lhs, rhs, out| Operation::Mul(BinaryOperation { lhs, rhs, out }),
|lhs, rhs, out| Operator::Mul(BinaryOperator { lhs, rhs, out }),
),
NumericOperationDescription::Div(desc) => self.register_binary_ops(
desc,
(E::gpu_elem(), E::gpu_elem(), E::gpu_elem()),
|lhs, rhs, out| Operation::Div(BinaryOperation { lhs, rhs, out }),
|lhs, rhs, out| Operator::Div(BinaryOperator { lhs, rhs, out }),
),
NumericOperationDescription::DivScalar(desc) => self.register_scalar_ops(
desc,
(E::gpu_elem(), E::gpu_elem(), E::gpu_elem()),
|lhs, rhs, out| Operation::Div(BinaryOperation { lhs, rhs, out }),
|lhs, rhs, out| Operator::Div(BinaryOperator { lhs, rhs, out }),
),
NumericOperationDescription::Abs(desc) => {
self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| {
Operation::Abs(UnaryOperation { input, out })
Operator::Abs(UnaryOperator { input, out })
})
}
NumericOperationDescription::Lower(desc) => self.register_binary_ops(
desc,
(E::gpu_elem(), E::gpu_elem(), Elem::Bool),
|lhs, rhs, out| Operation::Lower(BinaryOperation { lhs, rhs, out }),
|lhs, rhs, out| Operator::Lower(BinaryOperator { lhs, rhs, out }),
),
NumericOperationDescription::LowerElem(desc) => self.register_scalar_ops(
desc,
(E::gpu_elem(), E::gpu_elem(), Elem::Bool),
|lhs, rhs, out| Operation::Lower(BinaryOperation { lhs, rhs, out }),
|lhs, rhs, out| Operator::Lower(BinaryOperator { lhs, rhs, out }),
),
NumericOperationDescription::Greater(desc) => self.register_binary_ops(
desc,
(E::gpu_elem(), E::gpu_elem(), Elem::Bool),
|lhs, rhs, out| Operation::Greater(BinaryOperation { lhs, rhs, out }),
|lhs, rhs, out| Operator::Greater(BinaryOperator { lhs, rhs, out }),
),
NumericOperationDescription::GreaterElem(desc) => self.register_scalar_ops(
desc,
(E::gpu_elem(), E::gpu_elem(), Elem::Bool),
|lhs, rhs, out| Operation::Greater(BinaryOperation { lhs, rhs, out }),
|lhs, rhs, out| Operator::Greater(BinaryOperator { lhs, rhs, out }),
),
NumericOperationDescription::LowerEqual(desc) => self.register_binary_ops(
desc,
(E::gpu_elem(), E::gpu_elem(), Elem::Bool),
|lhs, rhs, out| Operation::LowerEqual(BinaryOperation { lhs, rhs, out }),
|lhs, rhs, out| Operator::LowerEqual(BinaryOperator { lhs, rhs, out }),
),
NumericOperationDescription::LowerEqualElem(desc) => self.register_scalar_ops(
desc,
(E::gpu_elem(), E::gpu_elem(), Elem::Bool),
|lhs, rhs, out| Operation::LowerEqual(BinaryOperation { lhs, rhs, out }),
|lhs, rhs, out| Operator::LowerEqual(BinaryOperator { lhs, rhs, out }),
),
NumericOperationDescription::GreaterEqual(desc) => self.register_binary_ops(
desc,
(E::gpu_elem(), E::gpu_elem(), Elem::Bool),
|lhs, rhs, out| Operation::GreaterEqual(BinaryOperation { lhs, rhs, out }),
|lhs, rhs, out| Operator::GreaterEqual(BinaryOperator { lhs, rhs, out }),
),
NumericOperationDescription::GreaterEqualElem(desc) => self.register_scalar_ops(
desc,
(E::gpu_elem(), E::gpu_elem(), Elem::Bool),
|lhs, rhs, out| Operation::GreaterEqual(BinaryOperation { lhs, rhs, out }),
|lhs, rhs, out| Operator::GreaterEqual(BinaryOperator { lhs, rhs, out }),
),
NumericOperationDescription::EqualElem(desc) => self.register_scalar_ops(
desc,
(E::gpu_elem(), E::gpu_elem(), Elem::Bool),
|lhs, rhs, out| Operation::Equal(BinaryOperation { lhs, rhs, out }),
|lhs, rhs, out| Operator::Equal(BinaryOperator { lhs, rhs, out }),
),
NumericOperationDescription::MaskWhere(desc) => {
if !self.output_is_compatible(&desc.out) {
return false;
}
let cond = self.input_to_var(&desc.mask, Elem::Bool);
let lhs = self.input_to_var(&desc.value, E::gpu_elem());
let rhs = self.input_to_var(&desc.tensor, E::gpu_elem());
let out = self.output_to_var(&desc.out, E::gpu_elem());
let cond = self.builder.input(&desc.mask, Elem::Bool);
let lhs = self.builder.input(&desc.value, E::gpu_elem());
let rhs = self.builder.input(&desc.tensor, E::gpu_elem());
let out = self.builder.output(&desc.out, E::gpu_elem());
let ops = Operation::ConditionalAssign(ConditionalAssignOperation {
cond,
lhs,
rhs,
out,
});
self.operators.push(ops);
self.builder.register_operation(Operator::ConditionalAssign(
ConditionalAssignOperator {
cond,
lhs,
rhs,
out,
},
));
true
}
@ -577,18 +306,19 @@ impl<R: Runtime> ElementWiseBuilder<R> {
return false;
}
let cond = self.input_to_var(&desc.mask, Elem::Bool);
let lhs = self.scalar_to_var(&desc.value, E::gpu_elem());
let rhs = self.input_to_var(&desc.tensor, E::gpu_elem());
let out = self.output_to_var(&desc.out, E::gpu_elem());
let cond = self.builder.input(&desc.mask, Elem::Bool);
let lhs = self.builder.scalar(&desc.value, E::gpu_elem());
let rhs = self.builder.input(&desc.tensor, E::gpu_elem());
let out = self.builder.output(&desc.out, E::gpu_elem());
self.operators
.push(Operation::ConditionalAssign(ConditionalAssignOperation {
self.builder.register_operation(Operator::ConditionalAssign(
ConditionalAssignOperator {
cond,
lhs,
rhs,
out,
}));
},
));
true
}
@ -597,11 +327,11 @@ impl<R: Runtime> ElementWiseBuilder<R> {
return false;
}
let input = Variable::Constant(1.0, Item::Scalar(E::gpu_elem()));
let out = self.output_to_var(desc, E::gpu_elem());
let input = Variable::ConstantScalar(1.0, E::gpu_elem());
let out = self.builder.output(desc, E::gpu_elem());
self.operators
.push(Operation::AssignLocal(UnaryOperation { input, out }));
self.builder
.register_operation(Operator::AssignLocal(UnaryOperator { input, out }));
true
}
@ -610,11 +340,11 @@ impl<R: Runtime> ElementWiseBuilder<R> {
return false;
}
let input = Variable::Constant(0.0, Item::Scalar(E::gpu_elem()));
let out = self.output_to_var(desc, E::gpu_elem());
let input = Variable::ConstantScalar(0.0, E::gpu_elem());
let out = self.builder.output(desc, E::gpu_elem());
self.operators
.push(Operation::AssignLocal(UnaryOperation { input, out }));
self.builder
.register_operation(Operator::AssignLocal(UnaryOperator { input, out }));
true
}
@ -623,11 +353,11 @@ impl<R: Runtime> ElementWiseBuilder<R> {
return false;
}
let input = self.scalar_to_var(elem, E::gpu_elem());
let out = self.output_to_var(desc, E::gpu_elem());
let input = self.builder.scalar(elem, E::gpu_elem());
let out = self.builder.output(desc, E::gpu_elem());
self.operators
.push(Operation::AssignLocal(UnaryOperation { input, out }));
self.builder
.register_operation(Operator::AssignLocal(UnaryOperator { input, out }));
true
}
@ -642,17 +372,17 @@ impl<R: Runtime> ElementWiseBuilder<R> {
func: Func,
) -> bool
where
Func: Fn(Variable, Variable, Variable) -> Operation,
Func: Fn(Variable, Variable, Variable) -> Operator,
{
if !self.output_is_compatible(&desc.out) {
return false;
}
let lhs = self.input_to_var(&desc.lhs, elem_lhs);
let rhs = self.input_to_var(&desc.rhs, elem_rhs);
let out = self.output_to_var(&desc.out, elem_out);
let lhs = self.builder.input(&desc.lhs, elem_lhs);
let rhs = self.builder.input(&desc.rhs, elem_rhs);
let out = self.builder.output(&desc.out, elem_out);
self.operators.push(func(lhs, rhs, out));
self.builder.register_operation(func(lhs, rhs, out));
true
}
@ -664,16 +394,16 @@ impl<R: Runtime> ElementWiseBuilder<R> {
func: Func,
) -> bool
where
Func: Fn(Variable, Variable) -> Operation,
Func: Fn(Variable, Variable) -> Operator,
{
if !self.output_is_compatible(&desc.out) {
return false;
}
let input = self.input_to_var(&desc.input, elem_input);
let out = self.output_to_var(&desc.out, elem_out);
let input = self.builder.input(&desc.input, elem_input);
let out = self.builder.output(&desc.out, elem_out);
self.operators.push(func(input, out));
self.builder.register_operation(func(input, out));
true
}
@ -685,41 +415,21 @@ impl<R: Runtime> ElementWiseBuilder<R> {
func: Func,
) -> bool
where
Func: Fn(Variable, Variable, Variable) -> Operation,
Func: Fn(Variable, Variable, Variable) -> Operator,
{
if !self.output_is_compatible(&desc.out) {
return false;
}
let lhs = self.input_to_var(&desc.lhs, elem_lhs);
let rhs = self.scalar_to_var(&desc.rhs, elem_rhs);
let out = self.output_to_var(&desc.out, elem_out);
let lhs = self.builder.input(&desc.lhs, elem_lhs);
let rhs = self.builder.scalar(&desc.rhs, elem_rhs);
let out = self.builder.output(&desc.out, elem_out);
self.operators.push(func(lhs, rhs, out));
self.builder.register_operation(func(lhs, rhs, out));
true
}
fn scalar_to_var<E: Element>(&mut self, _value: &E, elem_type: Elem) -> Variable {
match elem_type {
Elem::Float => {
self.scalars_float += 1;
Variable::Scalar(self.scalars_float as u16 - 1, Item::Scalar(Elem::Float))
}
Elem::Int => {
self.scalars_int += 1;
Variable::Scalar(self.scalars_int as u16 - 1, Item::Scalar(Elem::Int))
}
Elem::UInt => {
self.scalars_uint += 1;
Variable::Scalar(self.scalars_uint as u16 - 1, Item::Scalar(Elem::UInt))
}
Elem::Bool => {
panic!("Bool scalars not supported")
}
}
}
fn output_is_compatible(&mut self, out: &TensorDescription) -> bool {
if self.current_output_shape.is_empty() {
self.current_output_shape = out.shape.clone();

View File

@ -4,7 +4,7 @@ use crate::{
fusion::{
kernel::{FusionKernel, OutputInfo, Priority, SelectedKernel},
source::GpuKernelSource,
WgpuFusionHandle,
JitFusionHandle,
},
kernel::elemwise_workgroup,
Runtime,
@ -23,7 +23,7 @@ pub struct VecElementWise<R: Runtime> {
impl<R: Runtime> FusionKernel<R> for ScalarElementWise<R> {
fn kernel(
&self,
handles_inputs: &[WgpuFusionHandle<R>],
handles_inputs: &[JitFusionHandle<R>],
inputs: &[&TensorDescription],
outputs: &[&TensorDescription],
) -> SelectedKernel {
@ -32,7 +32,7 @@ impl<R: Runtime> FusionKernel<R> for ScalarElementWise<R> {
fn priority(
&self,
_handles_inputs: &[WgpuFusionHandle<R>],
_handles_inputs: &[JitFusionHandle<R>],
_inputs: &[&TensorDescription],
_outputs: &[&TensorDescription],
) -> Priority {
@ -43,7 +43,7 @@ impl<R: Runtime> FusionKernel<R> for ScalarElementWise<R> {
impl<R: Runtime> FusionKernel<R> for VecElementWise<R> {
fn kernel(
&self,
handles_inputs: &[WgpuFusionHandle<R>],
handles_inputs: &[JitFusionHandle<R>],
inputs: &[&TensorDescription],
outputs: &[&TensorDescription],
) -> SelectedKernel {
@ -52,11 +52,11 @@ impl<R: Runtime> FusionKernel<R> for VecElementWise<R> {
fn priority(
&self,
handles_inputs: &[WgpuFusionHandle<R>],
handles_inputs: &[JitFusionHandle<R>],
inputs: &[&TensorDescription],
_outputs: &[&TensorDescription],
) -> Priority {
let is_unavailable_input = |handle: &WgpuFusionHandle<R>, desc: &TensorDescription| {
let is_unavailable_input = |handle: &JitFusionHandle<R>, desc: &TensorDescription| {
let rank = handle.strides.len();
// Last dimension strides should be 1, otherwise vecX won't be contiguous.
@ -96,7 +96,7 @@ impl<R: Runtime> FusionKernel<R> for VecElementWise<R> {
impl<R: Runtime> ElementWiseSource<R> {
fn kernel(
&self,
handles_inputs: &[WgpuFusionHandle<R>],
handles_inputs: &[JitFusionHandle<R>],
inputs: &[&TensorDescription],
outputs: &[&TensorDescription],
) -> SelectedKernel {
@ -110,7 +110,7 @@ impl<R: Runtime> ElementWiseSource<R> {
match inplace_available(&self.mappings, handles_inputs) {
true => {
let reference_tensor = inputs[self.mappings[0].position_input];
let reference_tensor = inputs[self.mappings[0].pos_input];
let num_elems = calculate_num_elems_dyn_rank(&reference_tensor.shape);
let workgroup = elemwise_workgroup(num_elems / self.factor, workgroup_size);
let kernel = Box::new(DynamicKernel::new(self.source_inplace.clone(), workgroup));
@ -172,7 +172,7 @@ impl<R: Runtime> ElementWiseSource<R> {
let mut inplace_output2input = vec![None; num_output];
for mapping in mappings.iter() {
inplace_output2input[mapping.position_output] = Some(mapping.position_input);
inplace_output2input[mapping.pos_output] = Some(mapping.pos_input);
}
Self {
@ -214,14 +214,14 @@ impl<R: Runtime> VecElementWise<R> {
fn inplace_available<R: Runtime>(
mappings: &[InplaceMapping],
handles_inputs: &[WgpuFusionHandle<R>],
handles_inputs: &[JitFusionHandle<R>],
) -> bool {
if mappings.is_empty() {
return false;
}
for mapping in mappings.iter() {
let handle = &handles_inputs[mapping.position_input];
let handle = &handles_inputs[mapping.pos_input];
if !handle.handle.can_mut() {
return false;

View File

@ -4,163 +4,57 @@ use super::{
FusionElemWiseAutotuneKey,
};
use crate::{
codegen::dialect::gpu::{Elem, Item, Operation, Vectorization, Visibility, WorkgroupSize},
codegen::{ElemWiseKernelCodegen, InplaceMapping, Input, Output, ReadingStrategy},
codegen::{
dialect::gpu::{Vectorization, WorkgroupSize},
Compilation, CompilationInfo, CompilationSettings,
},
compute::JitAutotuneKey,
fusion::{kernel::FusionKernelSet, source::GpuKernelSource},
fusion::{kernel::FusionKernelSet, source::GpuKernelSource, tracing::Trace},
JitBackend, Runtime,
};
use burn_common::id::IdGenerator;
use burn_compute::client::ComputeClient;
use burn_fusion::{stream::Context, TensorDescription};
use burn_fusion::stream::Context;
use serde::{Deserialize, Serialize};
#[derive(new)]
pub struct ElementWise<R: Runtime, Phase = ExecutionPhase<R>> {
pub(super) inputs: Vec<(TensorDescription, Elem)>,
pub(super) outputs: Vec<(TensorDescription, Elem)>,
pub(super) locals: Vec<u16>,
pub(super) scalars: Scalars,
pub(super) operators: Vec<Operation>,
pub(super) trace: Trace,
pub(super) num_operations: usize,
pub(super) device: R::Device,
pub(super) phase: Phase,
}
#[derive(new, Clone, Serialize, Deserialize)]
pub struct Scalars {
pub(super) num_f32: usize,
pub(super) num_u32: usize,
pub(super) num_i32: usize,
}
/// Phase where the kernel should be compiled.
pub struct CompilationPhase;
/// Phase where the kernel should be executed.
#[derive(new)]
pub struct ExecutionPhase<R: Runtime> {
/// Kernel set with default workgroup size.
pub(super) kernel_set_1: FusionKernelSet<R>,
/// Kernel set with custom workgroup size.
pub(super) kernel_set_2: FusionKernelSet<R>,
}
#[derive(Serialize, Deserialize)]
#[derive(new, Serialize, Deserialize)]
pub struct ElementWiseState {
inputs: Vec<(TensorDescription, Elem)>,
outputs: Vec<(TensorDescription, Elem)>,
scalars: Scalars,
operators: Vec<Operation>,
locals: Vec<u16>,
trace: Trace,
num_operations: usize,
}
impl<R: Runtime> ElementWise<R, CompilationPhase> {
pub(crate) fn compile(self) -> ElementWise<R, ExecutionPhase<R>> {
let mut inputs = self
.inputs
.iter()
.map(|(_tensor, elem)| Input::Array {
item: Item::Scalar(*elem),
visibility: Visibility::Read,
strategy: ReadingStrategy::OutputLayout,
})
.collect::<Vec<_>>();
let info = self.trace.compiling();
let outputs = self
.outputs
.iter()
.zip(self.locals.iter())
.map(|((_tensor, elem), local)| Output::Array {
item: Item::Scalar(*elem),
local: *local,
})
.collect::<Vec<_>>();
if self.scalars.num_f32 > 0 {
inputs.push(Input::Scalar {
elem: Elem::Float,
size: self.scalars.num_f32,
})
}
if self.scalars.num_u32 > 0 {
inputs.push(Input::Scalar {
elem: Elem::UInt,
size: self.scalars.num_u32,
})
}
if self.scalars.num_i32 > 0 {
inputs.push(Input::Scalar {
elem: Elem::Int,
size: self.scalars.num_i32,
})
}
let mut potential_inplace = self
.inputs
.iter()
.zip(inputs.iter())
.enumerate()
.filter(|(_pos, ((desc, _elem), _input))| match desc.status {
burn_fusion::TensorStatus::ReadOnly => false,
burn_fusion::TensorStatus::ReadWrite => true,
burn_fusion::TensorStatus::NotInit => false,
})
.map(|(pos, ((desc, elem), input))| (pos, desc, elem, input))
.collect::<Vec<_>>();
let mappings = self
.outputs
.iter()
.zip(outputs.iter())
.enumerate()
.filter_map(|(pos, ((desc, elem), _output))| {
if potential_inplace.is_empty() {
return None;
}
let mut chosen = None;
for (index, (_pos_input, desc_input, elem_input, _input)) in
potential_inplace.iter().enumerate()
{
if chosen.is_some() {
break;
}
if desc.shape == desc_input.shape && *elem_input == elem {
chosen = Some(index);
}
}
match chosen {
Some(index) => {
let input = potential_inplace.remove(index);
Some(InplaceMapping::new(input.0, pos))
}
None => None,
}
})
.collect::<Vec<_>>();
let kernel_set_1 = build_kernel_set::<R>(
&inputs,
&outputs,
&self.operators,
&mappings,
WorkgroupSize::default(),
);
let kernel_set_2 = build_kernel_set::<R>(
&inputs,
&outputs,
&self.operators,
&mappings,
WorkgroupSize::new(16, 16, 1),
);
let kernel_set_1 = build_kernel_set::<R>(&info, WorkgroupSize::default());
let kernel_set_2 = build_kernel_set::<R>(&info, WorkgroupSize::new(16, 16, 1));
ElementWise {
inputs: self.inputs,
outputs: self.outputs,
scalars: self.scalars,
trace: self.trace,
device: self.device,
operators: self.operators,
locals: self.locals,
phase: ExecutionPhase::new(kernel_set_1, kernel_set_2),
num_operations: self.num_operations,
}
}
}
@ -170,7 +64,7 @@ impl<R: Runtime> ElementWise<R, ExecutionPhase<R>> {
let client = R::client(&self.device);
let key = JitAutotuneKey::FusionElemWise(FusionElemWiseAutotuneKey::new(
self.operators.len(),
self.num_operations,
self.autotune_shape(context),
));
@ -187,22 +81,14 @@ impl<R: Runtime> ElementWise<R, ExecutionPhase<R>> {
client: ComputeClient<R::Server, R::Channel>,
fastest_set_index: usize,
) {
let info = self.trace.running();
let kernel_set = match fastest_set_index {
0 => &self.phase.kernel_set_1,
1 => &self.phase.kernel_set_2,
_ => panic!("Should be 0 or 1, got {fastest_set_index}"),
};
let kernel = kernel_set.select(
&self.inputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
&self.outputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
self.scalars.num_f32,
self.scalars.num_i32,
context,
self.device.clone(),
client,
true,
);
let kernel = kernel_set.select(&info, context, self.device.clone(), client, true);
kernel.execute();
}
@ -213,31 +99,24 @@ impl<R: Runtime> ElementWise<R, ExecutionPhase<R>> {
client: ComputeClient<R::Server, R::Channel>,
key: JitAutotuneKey,
) {
let info = self.trace.running();
let kernel_1 = self.phase.kernel_set_1.select(
&self.inputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
&self.outputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
self.scalars.num_f32,
self.scalars.num_i32,
&info,
context,
self.device.clone(),
client.clone(),
false, // Should not mutate the context.
);
let kernel_2 = self.phase.kernel_set_1.select(
&self.inputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
&self.outputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
self.scalars.num_f32,
self.scalars.num_i32,
&info,
context,
self.device.clone(),
client.clone(),
false, // Should not mutate the context.
);
let kernel_default = self.phase.kernel_set_1.select(
&self.inputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
&self.outputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
self.scalars.num_f32,
self.scalars.num_i32,
&info,
context,
self.device.clone(),
client.clone(),
@ -253,7 +132,7 @@ impl<R: Runtime> ElementWise<R, ExecutionPhase<R>> {
}
pub(crate) fn len(&self) -> usize {
self.operators.len()
self.num_operations
}
/// The first output is chosen when possible, otherwise the first input is chosen.
@ -261,13 +140,15 @@ impl<R: Runtime> ElementWise<R, ExecutionPhase<R>> {
&self,
context: &mut Context<'a, JitBackend<R>>,
) -> &'a [usize] {
if let Some(tensor) = self.outputs.first() {
let tensor = context.tensors.get(&tensor.0.id).unwrap();
let info = self.trace.running();
if let Some(tensor) = info.outputs.first() {
let tensor = context.tensors.get(&tensor.id).unwrap();
return &tensor.shape;
}
if let Some(tensor) = self.inputs.first() {
let tensor = context.tensors.get(&tensor.0.id).unwrap();
if let Some(tensor) = info.inputs.first() {
let tensor = context.tensors.get(&tensor.id).unwrap();
return &tensor.shape;
}
@ -281,109 +162,86 @@ impl<R: Runtime> ElementWise<R, ExecutionPhase<R>> {
// It is still unclear if the deserialization would be that much faster than
// simply recompiling it.
ElementWise {
inputs: state.inputs,
outputs: state.outputs,
scalars: state.scalars,
trace: state.trace,
device: device.clone(),
locals: state.locals,
operators: state.operators,
phase: CompilationPhase,
num_operations: state.num_operations,
}
.compile()
}
pub(crate) fn to_state(&self) -> ElementWiseState {
ElementWiseState {
inputs: self.inputs.clone(),
outputs: self.outputs.clone(),
scalars: self.scalars.clone(),
operators: self.operators.clone(),
locals: self.locals.clone(),
trace: self.trace.clone(),
num_operations: self.num_operations,
}
}
}
fn build_kernel_set<R: Runtime>(
inputs: &[Input],
outputs: &[Output],
operators: &[Operation],
mappings: &[InplaceMapping],
info: &CompilationInfo,
workgroup_size: WorkgroupSize,
) -> FusionKernelSet<R> {
let scalar = ScalarElementWise::<R>::new(
GpuKernelSource::new(
IdGenerator::generate(),
ElemWiseKernelCodegen::new()
.inputs(inputs)
.body(operators)
.outputs(outputs)
.workgroup_size(workgroup_size)
.compile(),
Compilation::new(info.clone())
.compile(CompilationSettings::default().workgroup_size(workgroup_size)),
),
GpuKernelSource::new(
IdGenerator::generate(),
ElemWiseKernelCodegen::new()
.inplace(mappings)
.inputs(inputs)
.body(operators)
.outputs(outputs)
.workgroup_size(workgroup_size)
.compile(),
Compilation::new(info.clone()).compile(
CompilationSettings::default()
.inplace(true)
.workgroup_size(workgroup_size),
),
),
mappings.to_vec(),
outputs.len(),
info.mappings.to_vec(),
info.outputs.len(),
);
let vec2 = VecElementWise::<R>::new(
GpuKernelSource::new(
IdGenerator::generate(),
ElemWiseKernelCodegen::new()
.vectorize(Vectorization::Vec2)
.inputs(inputs)
.body(operators)
.outputs(outputs)
.workgroup_size(workgroup_size)
.compile(),
Compilation::new(info.clone()).compile(
CompilationSettings::default()
.vectorize(Vectorization::Vec2)
.workgroup_size(workgroup_size),
),
),
GpuKernelSource::new(
IdGenerator::generate(),
ElemWiseKernelCodegen::new()
.vectorize(Vectorization::Vec2)
.inplace(mappings)
.inputs(inputs)
.body(operators)
.outputs(outputs)
.workgroup_size(workgroup_size)
.compile(),
Compilation::new(info.clone()).compile(
CompilationSettings::default()
.inplace(true)
.vectorize(Vectorization::Vec2)
.workgroup_size(workgroup_size),
),
),
mappings.to_vec(),
outputs.len(),
info.mappings.to_vec(),
info.outputs.len(),
2,
);
let vec4 = VecElementWise::<R>::new(
GpuKernelSource::new(
IdGenerator::generate(),
ElemWiseKernelCodegen::new()
.vectorize(Vectorization::Vec4)
.inputs(inputs)
.body(operators)
.outputs(outputs)
.workgroup_size(workgroup_size)
.compile(),
Compilation::new(info.clone()).compile(
CompilationSettings::default()
.vectorize(Vectorization::Vec4)
.workgroup_size(workgroup_size),
),
),
GpuKernelSource::new(
IdGenerator::generate(),
ElemWiseKernelCodegen::new()
.vectorize(Vectorization::Vec4)
.inplace(mappings)
.inputs(inputs)
.body(operators)
.outputs(outputs)
.workgroup_size(workgroup_size)
.compile(),
Compilation::new(info.clone()).compile(
CompilationSettings::default()
.inplace(true)
.vectorize(Vectorization::Vec4)
.workgroup_size(workgroup_size),
),
),
mappings.to_vec(),
outputs.len(),
info.mappings.to_vec(),
info.outputs.len(),
4,
);

View File

@ -1,6 +1,6 @@
use crate::compute::Kernel;
use crate::fusion::strides_dyn_rank;
use crate::fusion::WgpuFusionHandle;
use crate::fusion::JitFusionHandle;
use crate::JitBackend;
use crate::Runtime;
use burn_compute::client::ComputeClient;
@ -11,6 +11,8 @@ use burn_fusion::{TensorDescription, TensorStatus};
use burn_tensor::Device;
use std::sync::Arc;
use super::tracing::ExecutionInfo;
/// Many kernels can be used for the same set of tensor operations fused into one.
///
/// This type makes it easy to group those potential kernels and execute the best one depending on the context.
@ -104,14 +106,14 @@ pub trait FusionKernel<R: Runtime>: Send + Sync {
/// Returns the priority of this kernel based on the input and output information.
fn priority(
&self,
handles_inputs: &[WgpuFusionHandle<R>],
handles_inputs: &[JitFusionHandle<R>],
inputs: &[&TensorDescription],
outputs: &[&TensorDescription],
) -> Priority;
/// Returns a [selected kernel](SelectedKernel) that can be executed by the compute server.
fn kernel(
&self,
handles_inputs: &[WgpuFusionHandle<R>],
handles_inputs: &[JitFusionHandle<R>],
inputs: &[&TensorDescription],
outputs: &[&TensorDescription],
) -> SelectedKernel;
@ -119,20 +121,21 @@ pub trait FusionKernel<R: Runtime>: Send + Sync {
impl<R: Runtime> FusionKernelSet<R> {
/// Select the best kernel based on the given information.
#[allow(clippy::too_many_arguments)]
pub fn select(
&self,
inputs: &[&TensorDescription],
outputs: &[&TensorDescription],
scalars_f32: usize,
scalars_i32: usize,
running_info: &ExecutionInfo<'_>,
context: &mut Context<'_, JitBackend<R>>,
device: Device<JitBackend<R>>,
client: ComputeClient<R::Server, R::Channel>,
stateful: bool,
) -> ExecutableKernel<R> {
let (handles_input, inputs_description_updated, outputs_description_updated) =
process_inputs_outputs(inputs, outputs, context, stateful);
process_inputs_outputs(
&running_info.inputs,
&running_info.outputs,
context,
stateful,
);
let selected = self.select_kernel(
&handles_input,
@ -140,19 +143,27 @@ impl<R: Runtime> FusionKernelSet<R> {
&outputs_description_updated,
);
let rank_input = inputs.first().map(|desc| desc.shape.len()).unwrap_or(1);
let rank_output = outputs.first().map(|desc| desc.shape.len()).unwrap_or(1);
let rank_input = running_info
.inputs
.first()
.map(|desc| desc.shape.len())
.unwrap_or(1);
let rank_output = running_info
.outputs
.first()
.map(|desc| desc.shape.len())
.unwrap_or(1);
let rank = usize::max(rank_input, rank_output);
let num_tensors = inputs.len() + outputs.len();
let num_tensors = running_info.inputs.len() + running_info.outputs.len();
// The buffer starts with the rank, then each tensor shape and stride.
let info_size = (num_tensors * rank * 2) + 1;
let mut num_handles = num_tensors + 1;
if scalars_f32 > 0 {
if running_info.scalars.num_float > 0 {
num_handles += 1;
}
if scalars_i32 > 0 {
if running_info.scalars.num_int > 0 {
num_handles += 1;
}
@ -175,7 +186,7 @@ impl<R: Runtime> FusionKernelSet<R> {
// Use the input inplace for this output.
OutputInfo::Inplace { input_index } => {
let handle = handles.get(*input_index).unwrap().clone();
let handle_fusion = WgpuFusionHandle {
let handle_fusion = JitFusionHandle {
client: client.clone(),
device: device.clone(),
strides: strides_dyn_rank(&tensor.shape),
@ -185,7 +196,7 @@ impl<R: Runtime> FusionKernelSet<R> {
}
// Create a new buffer for this output.
OutputInfo::Array { size } => {
let handle_fusion = WgpuFusionHandle {
let handle_fusion = JitFusionHandle {
client: client.clone(),
device: device.clone(),
strides: strides_dyn_rank(&tensor.shape),
@ -203,13 +214,16 @@ impl<R: Runtime> FusionKernelSet<R> {
handles.push(client.create(bytemuck::cast_slice(&info)));
// Finally we finish with the named bindings.
if scalars_f32 > 0 {
handles
.push(client.create(bytemuck::cast_slice(&context.scalar_floats[0..scalars_f32])));
if running_info.scalars.num_float > 0 {
handles.push(client.create(bytemuck::cast_slice(
&context.scalar_floats[0..running_info.scalars.num_float],
)));
}
if scalars_i32 > 0 {
handles.push(client.create(bytemuck::cast_slice(&context.scalar_ints[0..scalars_i32])));
if running_info.scalars.num_int > 0 {
handles.push(client.create(bytemuck::cast_slice(
&context.scalar_ints[0..running_info.scalars.num_int],
)));
}
// We have to register the output handles to the context.
@ -222,7 +236,7 @@ impl<R: Runtime> FusionKernelSet<R> {
fn select_kernel(
&self,
handles_input: &[WgpuFusionHandle<R>],
handles_input: &[JitFusionHandle<R>],
inputs: &[&TensorDescription],
outputs: &[&TensorDescription],
) -> SelectedKernel {
@ -249,7 +263,7 @@ impl<R: Runtime> FusionKernelSet<R> {
fn register_info_tensor<R: Runtime>(
info: &mut Vec<u32>,
tensor: &TensorDescription,
handle: &WgpuFusionHandle<R>,
handle: &JitFusionHandle<R>,
) {
if info.is_empty() {
info.push(handle.strides.len() as u32);
@ -269,7 +283,7 @@ fn process_inputs_outputs<'a, R: Runtime>(
context: &'a mut Context<'_, JitBackend<R>>,
stateful: bool,
) -> (
Vec<WgpuFusionHandle<R>>,
Vec<JitFusionHandle<R>>,
Vec<&'a TensorDescription>,
Vec<&'a TensorDescription>,
) {

View File

@ -3,6 +3,7 @@ mod elemwise;
pub(crate) mod kernel;
pub(crate) mod source;
pub(crate) mod tracing;
pub use base::*;
pub(crate) use elemwise::*;

View File

@ -0,0 +1,8 @@
use serde::{Deserialize, Serialize};
#[derive(Default, Clone, Serialize, Deserialize)]
pub struct Scalars {
pub(crate) num_float: usize,
pub(crate) num_int: usize,
pub(crate) num_uint: usize,
pub(crate) num_bool: usize,
}

View File

@ -0,0 +1,352 @@
use super::{trace::Trace, Scalars};
use crate::codegen::dialect::gpu::{self, Operation, Variable};
use burn_fusion::{TensorDescription, TensorId};
use burn_tensor::Element;
use hashbrown::HashMap;
/// Type facilitating building a [trace](Trace) by doing most of the conversions between the
/// operations provided in [burn_fusion] and the [gpu dialect](gpu).
#[derive(Clone)]
pub struct TraceBuilder {
// Input tensor descriptions with the variables created after reading from global memory.
inputs: Vec<(TensorDescription, Variable)>,
// Each output tensor id with the output variable index created by the operation.
output_to_local: HashMap<TensorId, u16>,
tensors: HashMap<TensorId, (TensorDescription, gpu::Elem)>,
scalars: Scalars,
scope: gpu::Scope,
}
impl TraceBuilder {
/// Create a new builder.
pub fn new() -> Self {
Self {
inputs: Vec::new(),
output_to_local: HashMap::new(),
tensors: HashMap::new(),
scalars: Scalars::default(),
scope: gpu::Scope::root(),
}
}
/// Register a [gpu operation](gpu::Operation).
pub fn register_operation<T: Into<gpu::Operation>>(&mut self, value: T) {
self.scope.register(value)
}
/// Create a variable from an input [tensor description](TensorDescription).
pub fn input(&mut self, tensor: &TensorDescription, elem: gpu::Elem) -> gpu::Variable {
let already_exists = self.tensors.contains_key(&tensor.id);
let variable = match already_exists {
false => {
// New input
let index = self.inputs.len() as u16;
let item = gpu::Item::Scalar(elem);
let local = self.scope.read_array(index, item);
self.inputs.push((tensor.clone(), local));
local
}
true => match self.output_to_local.get(&tensor.id) {
// Is a local variable.
Some(local_index) => {
gpu::Variable::Local(*local_index, gpu::Item::Scalar(elem), self.scope.depth)
}
// Isn't an operation output variable, so must be an existing input.
None => self
.inputs
.iter()
.find(|(input, _local)| input.id == tensor.id)
.map(|(_, local)| *local)
.unwrap(),
},
};
// Update the tensor description with the new version.
self.tensors.insert(tensor.id, (tensor.clone(), elem));
variable
}
/// Create a variable from an output [tensor description](TensorDescription).
pub fn output(&mut self, tensor: &TensorDescription, elem: gpu::Elem) -> gpu::Variable {
// Update the tensor description to the new version.
self.tensors.insert(tensor.id, (tensor.clone(), elem));
// Output already registered as a local variable.
if let Some(index) = self.output_to_local.get(&tensor.id) {
return gpu::Variable::Local(*index, gpu::Item::Scalar(elem), self.scope.depth);
}
let variable = self.scope.create_local(gpu::Item::Scalar(elem));
let local_index = variable.index().unwrap();
self.output_to_local.insert(tensor.id, local_index);
variable
}
/// Create a variable from an input [scalar](Element).
pub fn scalar<E: Element>(&mut self, _value: &E, elem_type: gpu::Elem) -> gpu::Variable {
match elem_type {
gpu::Elem::Float => {
let var = self
.scope
.read_scalar(self.scalars.num_float as u16, elem_type);
self.scalars.num_float += 1;
var
}
gpu::Elem::Int => {
let var = self
.scope
.read_scalar(self.scalars.num_int as u16, elem_type);
self.scalars.num_int += 1;
var
}
gpu::Elem::UInt => {
let var = self
.scope
.read_scalar(self.scalars.num_uint as u16, elem_type);
self.scalars.num_uint += 1;
var
}
gpu::Elem::Bool => {
let var = self
.scope
.read_scalar(self.scalars.num_bool as u16, elem_type);
self.scalars.num_bool += 1;
var
}
}
}
/// Build the [trace](Trace).
pub fn build(self) -> Trace {
let inputs = self.input_descriptions();
let outputs = self.output_descriptions();
let locals = outputs
.iter()
.map(|out| *self.output_to_local.get(&out.0.id).unwrap())
.collect::<Vec<_>>();
Trace::new(inputs, outputs, locals, self.scalars, self.scope)
}
fn input_descriptions(&self) -> Vec<(TensorDescription, gpu::Elem)> {
self.inputs
.iter()
.map(|(input, _local)| {
let updated_tensor = self.tensors.get(&input.id).unwrap();
updated_tensor.clone()
})
.collect::<Vec<_>>()
}
fn output_descriptions(&self) -> Vec<(TensorDescription, gpu::Elem)> {
let mut outputs = Vec::new();
let mut local_tensor_ids_input = Vec::new();
let mut local_tensor_ids_output = Vec::new();
// Mark a variable to the provided list of tensor ids using the variable list.
//
// Only local variables can become outputs.
let mark = |var: &gpu::Variable, list: &mut Vec<TensorId>| {
if let gpu::Variable::Local(index, _, _) = var {
if let Some((id, _)) = self
.output_to_local
.iter()
.find(|(_id, position)| *position == index)
{
if !list.contains(id) {
list.push(*id);
}
}
}
};
let mark_binary =
|op: &gpu::BinaryOperator, inputs: &mut Vec<TensorId>, outputs: &mut Vec<TensorId>| {
mark(&op.lhs, inputs);
mark(&op.rhs, inputs);
mark(&op.out, outputs);
};
let mark_unary =
|op: &gpu::UnaryOperator, inputs: &mut Vec<TensorId>, outputs: &mut Vec<TensorId>| {
mark(&op.input, inputs);
mark(&op.out, outputs);
};
// For all operators, mark their local tensor id in the proper set.
for op in self.scope.operations.iter() {
match op {
Operation::Operator(op) => {
match op {
gpu::Operator::AssignGlobal(_) => {
// Nothing to do here.
}
gpu::Operator::AssignLocal(op) => {
mark(&op.out, &mut local_tensor_ids_output);
}
gpu::Operator::Add(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Index(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Sub(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Mul(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Div(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Exp(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Abs(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Erf(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Log(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Log1p(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Cos(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Sin(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Tanh(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Clamp(op) => {
mark(&op.input, &mut local_tensor_ids_input);
mark(&op.out, &mut local_tensor_ids_output);
}
gpu::Operator::Powf(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Recip(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Lower(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Greater(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::LowerEqual(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::GreaterEqual(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Equal(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::ConditionalAssign(op) => {
mark(&op.cond, &mut local_tensor_ids_input);
mark(&op.lhs, &mut local_tensor_ids_input);
mark(&op.rhs, &mut local_tensor_ids_input);
mark(&op.out, &mut local_tensor_ids_output);
}
gpu::Operator::Sqrt(op) => mark_unary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
gpu::Operator::Modulo(op) => mark_binary(
op,
&mut local_tensor_ids_input,
&mut local_tensor_ids_output,
),
}
}
Operation::Algorithm(algo) => {
match algo {
gpu::Algorithm::ReadGlobalWithLayout(_) => {
// Nothing to do here.
}
gpu::Algorithm::ReadGlobal(_) => {
// Nothing to do here.
}
}
}
Operation::Metadata(_) => {
// Nothing to do, should never impact read-write access to bindings.
}
Operation::Loop(_) => {
// Nothing to do, should never impact read-write access to bindings.
}
}
}
// All output tensors that are never read by a following operation should be written to
// since they are essentially the "logical" output of the shader.
for out in local_tensor_ids_output {
let is_read = local_tensor_ids_input.contains(&out);
if !is_read {
outputs.push(self.tensors.get(&out).unwrap().clone());
}
}
// All tensors where their latest description is read only should be written to since they
// are going to be used after the fused kernel by other operations.
for entry in self.tensors.values() {
let (tensor, _) = &entry;
if let burn_fusion::TensorStatus::ReadOnly = tensor.status {
if self.output_to_local.contains_key(&tensor.id) {
outputs.push(entry.clone());
}
}
}
outputs
}
}

View File

@ -0,0 +1,7 @@
mod base;
mod builder;
mod trace;
pub use base::*;
pub use builder::*;
pub use trace::*;

View File

@ -0,0 +1,133 @@
use super::Scalars;
use crate::codegen::{dialect::gpu, CompilationInfo, InplaceMapping, InputInfo, OutputInfo};
use burn_fusion::TensorDescription;
use serde::{Deserialize, Serialize};
/// A trace encaptulates all information necessary to perform the compilation and execution of
/// captured [tensor operations](burn_fusion::stream::OperationDescription).
///
/// A trace should be built using a [builder](super::TraceBuilder).
#[derive(new, Clone, Serialize, Deserialize)]
pub struct Trace {
inputs: Vec<(TensorDescription, gpu::Elem)>,
outputs: Vec<(TensorDescription, gpu::Elem)>,
locals: Vec<u16>,
scalars: Scalars,
scope: gpu::Scope,
}
/// Information necessary to execute a kernel.
pub struct ExecutionInfo<'a> {
/// Tensor inputs.
pub inputs: Vec<&'a TensorDescription>,
/// Tensor outputs.
pub outputs: Vec<&'a TensorDescription>,
/// Scalar inputs.
pub scalars: &'a Scalars,
}
impl Trace {
/// Collect information related to running the trace.
pub fn running(&self) -> ExecutionInfo<'_> {
ExecutionInfo {
inputs: self.inputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
outputs: self.outputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
scalars: &self.scalars,
}
}
/// Collect information related to compiling the trace.
pub fn compiling(&self) -> CompilationInfo {
let mut inputs = self
.inputs
.iter()
.map(|(_tensor, elem)| InputInfo::Array {
item: gpu::Item::Scalar(*elem),
visibility: gpu::Visibility::Read,
})
.collect::<Vec<_>>();
let outputs = self
.outputs
.iter()
.zip(self.locals.iter())
.map(|((_tensor, elem), local)| OutputInfo::Array {
item: gpu::Item::Scalar(*elem),
local: *local,
})
.collect::<Vec<_>>();
if self.scalars.num_float > 0 {
inputs.push(InputInfo::Scalar {
elem: gpu::Elem::Float,
size: self.scalars.num_float,
})
}
if self.scalars.num_uint > 0 {
inputs.push(InputInfo::Scalar {
elem: gpu::Elem::UInt,
size: self.scalars.num_uint,
})
}
if self.scalars.num_int > 0 {
inputs.push(InputInfo::Scalar {
elem: gpu::Elem::Int,
size: self.scalars.num_int,
})
}
let mut potential_inplace = self
.inputs
.iter()
.zip(inputs.iter())
.enumerate()
.filter(|(_pos, ((desc, _elem), _input))| match desc.status {
burn_fusion::TensorStatus::ReadOnly => false,
burn_fusion::TensorStatus::ReadWrite => true,
burn_fusion::TensorStatus::NotInit => false,
})
.map(|(pos, ((desc, elem), input))| (pos, desc, elem, input))
.collect::<Vec<_>>();
let mappings = self
.outputs
.iter()
.zip(outputs.iter())
.enumerate()
.filter_map(|(pos, ((desc, elem), _output))| {
if potential_inplace.is_empty() {
return None;
}
let mut chosen = None;
for (index, (_pos_input, desc_input, elem_input, _input)) in
potential_inplace.iter().enumerate()
{
if chosen.is_some() {
break;
}
if desc.shape == desc_input.shape && *elem_input == elem {
chosen = Some(index);
}
}
match chosen {
Some(index) => {
let input = potential_inplace.remove(index);
Some(InplaceMapping::new(input.0, pos))
}
None => None,
}
})
.collect::<Vec<_>>();
CompilationInfo {
inputs,
outputs,
scope: self.scope.clone(),
mappings,
}
}
}

View File

@ -49,6 +49,46 @@ macro_rules! binary {
_o: core::marker::PhantomData<O>,
}
#[allow(clippy::redundant_closure_call)]
fn compile<C, I, O>(
settings: $crate::codegen::CompilationSettings,
mappings: Vec<$crate::codegen::InplaceMapping>,
) -> $crate::kernel::SourceTemplate
where
C: $crate::codegen::Compiler,
I: $crate::element::JitElement,
O: $crate::element::JitElement
{
let mut scope = $crate::codegen::dialect::gpu::Scope::root();
let op = $ops(&mut scope, I::gpu_elem());
scope.register(op);
let local = scope.last_local_index().unwrap().index().unwrap();
let lhs = $crate::codegen::InputInfo::Array {
item: $crate::codegen::dialect::gpu::Item::Scalar(I::gpu_elem()),
visibility: $crate::codegen::dialect::gpu::Visibility::Read,
};
let rhs = $crate::codegen::InputInfo::Array {
item: $crate::codegen::dialect::gpu::Item::Scalar(I::gpu_elem()),
visibility: $crate::codegen::dialect::gpu::Visibility::Read,
};
let out = $crate::codegen::OutputInfo::Array {
item: $crate::codegen::dialect::gpu::Item::Scalar(O::gpu_elem()),
local,
};
let info = $crate::codegen::CompilationInfo {
inputs: vec![lhs, rhs],
outputs: vec![out],
scope,
mappings,
};
let shader = $crate::codegen::Compilation::new(info).compile(settings);
let compiled = C::compile(shader);
$crate::kernel::SourceTemplate::new(compiled.to_string())
}
#[allow(clippy::redundant_closure_call)]
impl<C, I, O> $crate::kernel::StaticKernelSource for Ops<C, I, O>
where
@ -57,28 +97,8 @@ macro_rules! binary {
O: $crate::element::JitElement
{
fn source() -> $crate::kernel::SourceTemplate {
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
.inputs(&[
$crate::codegen::Input::Array {
item: $crate::codegen::dialect::gpu::Item::Scalar(I::gpu_elem()),
visibility: $crate::codegen::dialect::gpu::Visibility::Read,
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
},
$crate::codegen::Input::Array {
item: $crate::codegen::dialect::gpu::Item::Scalar(I::gpu_elem()),
visibility: $crate::codegen::dialect::gpu::Visibility::Read,
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
},
])
.body(&[$ops(I::gpu_elem())])
.outputs(&[$crate::codegen::Output::Array {
item: $crate::codegen::dialect::gpu::Item::Scalar(O::gpu_elem()),
local: 0,
}])
.compile();
let compiled = C::compile(shader);
$crate::kernel::SourceTemplate::new(compiled.to_string())
let settings = $crate::codegen::CompilationSettings::default();
compile::<C, I, O>(settings, Vec::new())
}
}
@ -91,29 +111,13 @@ macro_rules! binary {
O: $crate::element::JitElement
{
fn source() -> $crate::kernel::SourceTemplate {
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
.inputs(&[
$crate::codegen::Input::Array {
item: $crate::codegen::dialect::gpu::Item::Scalar(I::gpu_elem()),
visibility: $crate::codegen::dialect::gpu::Visibility::ReadWrite,
strategy: $crate::codegen::ReadingStrategy::Plain,
},
$crate::codegen::Input::Array {
item: $crate::codegen::dialect::gpu::Item::Scalar(I::gpu_elem()),
visibility: $crate::codegen::dialect::gpu::Visibility::Read,
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
},
])
.body(&[$ops(I::gpu_elem())])
.outputs(&[$crate::codegen::Output::Input {
item: $crate::codegen::dialect::gpu::Item::Scalar(I::gpu_elem()),
input: 0,
local: 0,
}])
.compile();
let compiled = C::compile(shader);
$crate::kernel::SourceTemplate::new(compiled.to_string())
let settings = $crate::codegen::CompilationSettings::default()
.inplace(true);
let mapping = $crate::codegen::InplaceMapping {
pos_input: 0,
pos_output: 0,
};
compile::<C, I, O>(settings, vec![mapping])
}
}
@ -126,29 +130,13 @@ macro_rules! binary {
O: $crate::element::JitElement
{
fn source() -> $crate::kernel::SourceTemplate {
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
.inputs(&[
$crate::codegen::Input::Array {
item: $crate::codegen::dialect::gpu::Item::Scalar(I::gpu_elem()),
visibility: $crate::codegen::dialect::gpu::Visibility::Read,
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
},
$crate::codegen::Input::Array {
item: $crate::codegen::dialect::gpu::Item::Scalar(I::gpu_elem()),
visibility: $crate::codegen::dialect::gpu::Visibility::ReadWrite,
strategy: $crate::codegen::ReadingStrategy::Plain,
},
])
.body(&[$ops(I::gpu_elem())])
.outputs(&[$crate::codegen::Output::Input {
item: $crate::codegen::dialect::gpu::Item::Scalar(I::gpu_elem()),
input: 1,
local: 0,
}])
.compile();
let compiled = C::compile(shader);
$crate::kernel::SourceTemplate::new(compiled.to_string())
let settings = $crate::codegen::CompilationSettings::default()
.inplace(true);
let mapping = $crate::codegen::InplaceMapping {
pos_input: 1,
pos_output: 0,
};
compile::<C, I, O>(settings, vec![mapping])
}
}
};

View File

@ -1,6 +1,6 @@
use super::unary;
use crate::{
codegen::dialect::gpu::{ClampOperation, Item, Operation, Variable},
codegen::dialect::gpu::{ClampOperator, Operator, Scope},
element::JitElement,
tensor::JitTensor,
unary, Runtime,
@ -12,11 +12,11 @@ pub(crate) fn clamp<R: Runtime, E: JitElement, const D: usize>(
max_value: E,
) -> JitTensor<R, E, D> {
unary!(
operation: |elem| Operation::Clamp(ClampOperation {
input: Variable::Input(0, Item::Scalar(elem)),
min_value: Variable::Scalar(0, Item::Scalar(elem)),
max_value: Variable::Scalar(1, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem| Operator::Clamp(ClampOperator {
input: scope.read_array(0, elem),
min_value: scope.read_scalar(0, elem),
max_value: scope.read_scalar(1, elem),
out: scope.create_local(elem),
}),
compiler: R::Compiler,
scalar 2

View File

@ -1,6 +1,6 @@
use crate::{
binary,
codegen::dialect::gpu::{BinaryOperation, Elem, Item, Operation, Variable},
codegen::dialect::gpu::{BinaryOperator, Elem, Operator, Scope},
element::JitElement,
kernel::StaticKernelSource,
kernel::{binary::binary, unary::unary},
@ -51,10 +51,10 @@ pub fn equal<R: Runtime, E: JitElement, const D: usize>(
rhs: JitTensor<R, E, D>,
) -> JitTensor<R, u32, D> {
comparison!(
binary: |elem: Elem| Operation::Equal(BinaryOperation {
lhs: Variable::Input(0, Item::Scalar(elem)),
rhs: Variable::Input(1, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(Elem::Bool)),
binary: |scope: &mut Scope, elem: Elem| Operator::Equal(BinaryOperator {
lhs: scope.read_array(0, elem),
rhs: scope.read_array(1, elem),
out: scope.create_local(Elem::Bool),
}),
runtime: R,
input: lhs; rhs,
@ -67,10 +67,10 @@ pub fn greater<R: Runtime, E: JitElement, const D: usize>(
rhs: JitTensor<R, E, D>,
) -> JitTensor<R, u32, D> {
comparison!(
binary: |elem: Elem| Operation::Greater(BinaryOperation {
lhs: Variable::Input(0, Item::Scalar(elem)),
rhs: Variable::Input(1, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(Elem::Bool)),
binary: |scope: &mut Scope, elem: Elem| Operator::Greater(BinaryOperator {
lhs: scope.read_array(0, elem),
rhs: scope.read_array(1, elem),
out: scope.create_local(Elem::Bool),
}),
runtime: R,
input: lhs; rhs,
@ -83,10 +83,10 @@ pub fn greater_equal<R: Runtime, E: JitElement, const D: usize>(
rhs: JitTensor<R, E, D>,
) -> JitTensor<R, u32, D> {
comparison!(
binary: |elem: Elem| Operation::GreaterEqual(BinaryOperation {
lhs: Variable::Input(0, Item::Scalar(elem)),
rhs: Variable::Input(1, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(Elem::Bool)),
binary: |scope: &mut Scope, elem: Elem| Operator::GreaterEqual(BinaryOperator {
lhs: scope.read_array(0, elem),
rhs: scope.read_array(1, elem),
out: scope.create_local(Elem::Bool),
}),
runtime: R,
input: lhs; rhs,
@ -99,10 +99,10 @@ pub fn lower<R: Runtime, E: JitElement, const D: usize>(
rhs: JitTensor<R, E, D>,
) -> JitTensor<R, u32, D> {
comparison!(
binary: |elem: Elem| Operation::Lower(BinaryOperation {
lhs: Variable::Input(0, Item::Scalar(elem)),
rhs: Variable::Input(1, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(Elem::Bool)),
binary: |scope: &mut Scope, elem: Elem| Operator::Lower(BinaryOperator {
lhs: scope.read_array(0, elem),
rhs: scope.read_array(1, elem),
out: scope.create_local(Elem::Bool),
}),
runtime: R,
input: lhs; rhs,
@ -115,10 +115,10 @@ pub fn lower_equal<R: Runtime, E: JitElement, const D: usize>(
rhs: JitTensor<R, E, D>,
) -> JitTensor<R, u32, D> {
comparison!(
binary: |elem: Elem| Operation::LowerEqual(BinaryOperation {
lhs: Variable::Input(0, Item::Scalar(elem)),
rhs: Variable::Input(1, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(Elem::Bool)),
binary: |scope: &mut Scope, elem: Elem| Operator::LowerEqual(BinaryOperator {
lhs: scope.read_array(0, elem),
rhs: scope.read_array(1, elem),
out: scope.create_local(Elem::Bool),
}),
runtime: R,
input: lhs; rhs,
@ -131,10 +131,10 @@ pub fn equal_elem<R: Runtime, E: JitElement, const D: usize>(
rhs: E,
) -> JitTensor<R, u32, D> {
comparison!(
unary: |elem: Elem| Operation::Equal(BinaryOperation {
lhs: Variable::Input(0, Item::Scalar(elem)),
rhs: Variable::Scalar(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(Elem::Bool)),
unary: |scope: &mut Scope, elem: Elem| Operator::Equal(BinaryOperator {
lhs: scope.read_array(0, elem),
rhs: scope.read_scalar(0, elem),
out: scope.create_local(Elem::Bool),
}),
runtime: R,
input: lhs; rhs,
@ -147,10 +147,10 @@ pub fn greater_elem<R: Runtime, E: JitElement, const D: usize>(
rhs: E,
) -> JitTensor<R, u32, D> {
comparison!(
unary: |elem: Elem| Operation::Greater(BinaryOperation {
lhs: Variable::Input(0, Item::Scalar(elem)),
rhs: Variable::Scalar(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(Elem::Bool)),
unary: |scope: &mut Scope, elem: Elem| Operator::Greater(BinaryOperator {
lhs: scope.read_array(0, elem),
rhs: scope.read_scalar(0, elem),
out: scope.create_local(Elem::Bool),
}),
runtime: R,
input: lhs; rhs,
@ -163,10 +163,10 @@ pub fn lower_elem<R: Runtime, E: JitElement, const D: usize>(
rhs: E,
) -> JitTensor<R, u32, D> {
comparison!(
unary: |elem: Elem| Operation::Lower(BinaryOperation {
lhs: Variable::Input(0, Item::Scalar(elem)),
rhs: Variable::Scalar(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(Elem::Bool)),
unary: |scope: &mut Scope, elem: Elem| Operator::Lower(BinaryOperator {
lhs: scope.read_array(0, elem),
rhs: scope.read_scalar(0, elem),
out: scope.create_local(Elem::Bool),
}),
runtime: R,
input: lhs; rhs,
@ -179,10 +179,10 @@ pub fn greater_equal_elem<R: Runtime, E: JitElement, const D: usize>(
rhs: E,
) -> JitTensor<R, u32, D> {
comparison!(
unary: |elem: Elem| Operation::GreaterEqual(BinaryOperation {
lhs: Variable::Input(0, Item::Scalar(elem)),
rhs: Variable::Scalar(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(Elem::Bool)),
unary: |scope: &mut Scope, elem: Elem| Operator::GreaterEqual(BinaryOperator {
lhs: scope.read_array(0, elem),
rhs: scope.read_scalar(0, elem),
out: scope.create_local(Elem::Bool),
}),
runtime: R,
input: lhs; rhs,
@ -195,10 +195,10 @@ pub fn lower_equal_elem<R: Runtime, E: JitElement, const D: usize>(
rhs: E,
) -> JitTensor<R, u32, D> {
comparison!(
unary: |elem: Elem| Operation::LowerEqual(BinaryOperation {
lhs: Variable::Input(0, Item::Scalar(elem)),
rhs: Variable::Scalar(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(Elem::Bool)),
unary: |scope: &mut Scope, elem: Elem| Operator::LowerEqual(BinaryOperator {
lhs: scope.read_array(0, elem),
rhs: scope.read_scalar(0, elem),
out: scope.create_local(Elem::Bool),
}),
runtime: R,
input: lhs; rhs,

View File

@ -55,6 +55,42 @@ macro_rules! unary {
_e: core::marker::PhantomData<E>,
}
#[allow(clippy::redundant_closure_call)]
fn compile<C, E>(
settings: $crate::codegen::CompilationSettings,
mappings: Vec<$crate::codegen::InplaceMapping>,
) -> $crate::kernel::SourceTemplate
where
C: $crate::codegen::Compiler,
E: $crate::element::JitElement
{
let mut scope = $crate::codegen::dialect::gpu::Scope::root();
let op = $ops(&mut scope, E::gpu_elem());
scope.register(op);
let local = scope.last_local_index().unwrap().index().unwrap();
let input = $crate::codegen::InputInfo::Array {
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
visibility: $crate::codegen::dialect::gpu::Visibility::Read,
};
let out = $crate::codegen::OutputInfo::Array {
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
local,
};
let info = $crate::codegen::CompilationInfo {
inputs: vec![input],
outputs: vec![out],
scope,
mappings,
};
let shader = $crate::codegen::Compilation::new(info).compile(settings);
let compiled = C::compile(shader);
$crate::kernel::SourceTemplate::new(compiled.to_string())
}
#[allow(clippy::redundant_closure_call)]
impl<C, E> $crate::kernel::StaticKernelSource for Ops<C, E>
where
@ -62,21 +98,8 @@ macro_rules! unary {
E: $crate::element::JitElement,
{
fn source() -> $crate::kernel::SourceTemplate {
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
.inputs(&[$crate::codegen::Input::Array {
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
visibility: $crate::codegen::dialect::gpu::Visibility::Read,
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
}])
.body(&[$ops(E::gpu_elem())])
.outputs(&[$crate::codegen::Output::Array {
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
local: 0,
}])
.compile();
let compiled = C::compile(shader);
$crate::kernel::SourceTemplate::new(compiled.to_string())
let settings = $crate::codegen::CompilationSettings::default();
compile::<C, E>(settings, Vec::new())
}
}
@ -87,22 +110,13 @@ macro_rules! unary {
E: $crate::element::JitElement,
{
fn source() -> $crate::kernel::SourceTemplate {
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
.inputs(&[$crate::codegen::Input::Array {
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
visibility: $crate::codegen::dialect::gpu::Visibility::ReadWrite,
strategy: $crate::codegen::ReadingStrategy::Plain,
}])
.body(&[$ops(E::gpu_elem())])
.outputs(&[$crate::codegen::Output::Input {
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
input: 0,
local: 0,
}])
.compile();
let compiled = C::compile(shader);
$crate::kernel::SourceTemplate::new(compiled.to_string())
let settings = $crate::codegen::CompilationSettings::default()
.inplace(true);
let mapping = $crate::codegen::InplaceMapping {
pos_input: 0,
pos_output: 0,
};
compile::<C, E>(settings, vec![mapping])
}
}
};
@ -120,6 +134,46 @@ macro_rules! unary {
_e: core::marker::PhantomData<E>,
}
#[allow(clippy::redundant_closure_call)]
fn compile<C, E>(
settings: $crate::codegen::CompilationSettings,
mappings: Vec<$crate::codegen::InplaceMapping>,
) -> $crate::kernel::SourceTemplate
where
C: $crate::codegen::Compiler,
E: $crate::element::JitElement
{
let mut scope = $crate::codegen::dialect::gpu::Scope::root();
let op = $ops(&mut scope, E::gpu_elem());
scope.register(op);
let local = scope.last_local_index().unwrap().index().unwrap();
let input = $crate::codegen::InputInfo::Array {
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
visibility: $crate::codegen::dialect::gpu::Visibility::Read,
};
let scalars = $crate::codegen::InputInfo::Scalar {
elem: E::gpu_elem(),
size: $num,
};
let out = $crate::codegen::OutputInfo::Array {
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
local,
};
let info = $crate::codegen::CompilationInfo {
inputs: vec![input, scalars],
outputs: vec![out],
scope,
mappings,
};
let shader = $crate::codegen::Compilation::new(info).compile(settings);
let compiled = C::compile(shader);
$crate::kernel::SourceTemplate::new(compiled.to_string())
}
#[allow(clippy::redundant_closure_call)]
impl<C, E> $crate::kernel::StaticKernelSource for Ops<C, E>
where
@ -127,27 +181,8 @@ macro_rules! unary {
E: $crate::element::JitElement,
{
fn source() -> $crate::kernel::SourceTemplate {
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
.inputs(&[
$crate::codegen::Input::Array {
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
visibility: $crate::codegen::dialect::gpu::Visibility::Read,
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
},
$crate::codegen::Input::Scalar {
elem: E::gpu_elem(),
size: $num,
},
])
.body(&[$ops(E::gpu_elem())])
.outputs(&[$crate::codegen::Output::Array {
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
local: 0,
}])
.compile();
let compiled = C::compile(shader);
$crate::kernel::SourceTemplate::new(compiled.to_string())
let settings = $crate::codegen::CompilationSettings::default();
compile::<C, E>(settings, Vec::new())
}
}
@ -158,28 +193,13 @@ macro_rules! unary {
E: $crate::element::JitElement,
{
fn source() -> $crate::kernel::SourceTemplate {
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
.inputs(&[
$crate::codegen::Input::Array {
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
visibility: $crate::codegen::dialect::gpu::Visibility::ReadWrite,
strategy: $crate::codegen::ReadingStrategy::Plain,
},
$crate::codegen::Input::Scalar {
elem: E::gpu_elem(),
size: $num,
},
])
.body(&[$ops(E::gpu_elem())])
.outputs(&[$crate::codegen::Output::Input {
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
input: 0,
local: 0,
}])
.compile();
let compiled = C::compile(shader);
$crate::kernel::SourceTemplate::new(compiled.to_string())
let settings = $crate::codegen::CompilationSettings::default()
.inplace(true);
let mapping = $crate::codegen::InplaceMapping {
pos_input: 0,
pos_output: 0,
};
compile::<C, E>(settings, vec![mapping])
}
}
};
@ -243,14 +263,14 @@ where
#[cfg(test)]
mod tests {
use super::*;
use crate::codegen::dialect::gpu::{Item, Operation, UnaryOperation, Variable};
use crate::codegen::dialect::gpu::{Operator, Scope, UnaryOperator};
use crate::tests::{ReferenceBackend, TestBackend, TestCompiler, TestRuntime};
use burn_tensor::{Distribution, Tensor};
unary!(
operation: |elem| Operation::Tanh(UnaryOperation {
input: Variable::Input(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem| Operator::Tanh(UnaryOperator {
input: scope.read_array(0, elem),
out: scope.create_local(elem),
}),
compiler: TestCompiler
);

View File

@ -1,7 +1,5 @@
use super::numeric;
use crate::codegen::dialect::gpu::{
BinaryOperation, Elem, Item, Operation, UnaryOperation, Variable,
};
use crate::codegen::dialect::gpu::{BinaryOperator, Elem, Operator, Scope, UnaryOperator};
#[cfg(not(feature = "autotune"))]
use crate::kernel::matmul::init_matmul_output;
#[cfg(feature = "autotune")]
@ -364,9 +362,9 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
fn float_exp<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(
operation: |elem: Elem| Operation::Exp(UnaryOperation {
input: Variable::Input(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::Exp(UnaryOperator {
input: scope.read_array(0, elem),
out: scope.create_local(elem),
}),
runtime: R,
input: tensor,
@ -376,9 +374,9 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
fn float_log<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(
operation: |elem: Elem| Operation::Log(UnaryOperation {
input: Variable::Input(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::Log(UnaryOperator {
input: scope.read_array(0, elem),
out: scope.create_local(elem),
}),
runtime: R,
input: tensor,
@ -388,9 +386,9 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
fn float_log1p<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(
operation: |elem: Elem| Operation::Log1p(UnaryOperation {
input: Variable::Input(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::Log1p(UnaryOperator {
input: scope.read_array(0, elem),
out: scope.create_local(elem),
}),
runtime: R,
input: tensor,
@ -403,10 +401,10 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
rhs: f32,
) -> FloatTensor<Self, D> {
unary!(
operation: |elem: Elem| Operation::Powf(BinaryOperation {
lhs: Variable::Input(0, Item::Scalar(elem)),
rhs: Variable::Scalar(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::Powf(BinaryOperator {
lhs: scope.read_array(0, elem),
rhs: scope.read_scalar(0, elem),
out: scope.create_local(elem),
}),
runtime: R,
input: lhs; rhs.elem(),
@ -416,9 +414,9 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
fn float_sqrt<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(
operation: |elem: Elem| Operation::Sqrt(UnaryOperation {
input: Variable::Input(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::Sqrt(UnaryOperator {
input: scope.read_array(0, elem),
out: scope.create_local(elem),
}),
runtime: R,
input: tensor,
@ -428,9 +426,9 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
fn float_abs<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(
operation: |elem: Elem| Operation::Abs(UnaryOperation {
input: Variable::Input(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::Abs(UnaryOperator {
input: scope.read_array(0, elem),
out: scope.create_local(elem),
}),
runtime: R,
input: tensor,
@ -440,9 +438,9 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
fn float_cos<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(
operation: |elem: Elem| Operation::Cos(UnaryOperation {
input: Variable::Input(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::Cos(UnaryOperator {
input: scope.read_array(0, elem),
out: scope.create_local(elem),
}),
runtime: R,
input: tensor,
@ -452,9 +450,9 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
fn float_sin<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(
operation: |elem: Elem| Operation::Sin(UnaryOperation {
input: Variable::Input(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::Sin(UnaryOperator {
input: scope.read_array(0, elem),
out: scope.create_local(elem),
}),
runtime: R,
input: tensor,
@ -464,9 +462,9 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
fn float_tanh<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(
operation: |elem: Elem| Operation::Tanh(UnaryOperation {
input: Variable::Input(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::Tanh(UnaryOperator {
input: scope.read_array(0, elem),
out: scope.create_local(elem),
}),
runtime: R,
input: tensor,
@ -476,9 +474,9 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
fn float_erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(
operation: |elem: Elem| Operation::Erf(UnaryOperation {
input: Variable::Input(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::Erf(UnaryOperator {
input: scope.read_array(0, elem),
out: scope.create_local(elem),
}),
runtime: R,
input: tensor,
@ -521,9 +519,9 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
fn float_recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
unary!(
operation: |elem: Elem| Operation::Recip(UnaryOperation {
input: Variable::Input(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::Recip(UnaryOperator {
input: scope.read_array(0, elem),
out: scope.create_local(elem),
}),
runtime: R,
input: tensor,

View File

@ -1,5 +1,5 @@
use super::numeric;
use crate::codegen::dialect::gpu::{Elem, Item, Operation, UnaryOperation, Variable};
use crate::codegen::dialect::gpu::{Elem, Item, Operator, Scope, UnaryOperator};
use crate::kernel::reduce::{self, init_reduce_output};
use crate::{kernel, unary, JitBackend, Runtime};
use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
@ -281,9 +281,9 @@ impl<R: Runtime> IntTensorOps<Self> for JitBackend<R> {
fn int_abs<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, D> {
unary!(
operation: |elem: Elem| Operation::Abs(UnaryOperation {
input: Variable::Input(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::Abs(UnaryOperator {
input: scope.read_array(0, Item::Scalar(elem)),
out: scope.create_local(elem),
}),
runtime: R,
input: tensor,

View File

@ -1,6 +1,4 @@
use crate::codegen::dialect::gpu::{
BinaryOperation, Elem, Item, Operation, UnaryOperation, Variable,
};
use crate::codegen::dialect::gpu::{BinaryOperator, Elem, Operator, Scope, UnaryOperator};
use crate::{binary, Runtime};
use crate::{element::JitElement, tensor::JitTensor, unary};
use burn_compute::client::ComputeClient;
@ -25,9 +23,9 @@ pub fn full_device<R: Runtime, E: JitElement, const D: usize>(
let empty = empty_device(client, device, shape);
unary!(
operation: |elem: Elem| Operation::AssignLocal(UnaryOperation {
input: Variable::Scalar(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::AssignLocal(UnaryOperator {
input: scope.read_scalar(0, elem),
out: scope.create_local(elem),
}),
runtime: R,
input: empty; value,
@ -84,10 +82,10 @@ pub fn add<R: Runtime, E: JitElement, const D: usize>(
rhs: JitTensor<R, E, D>,
) -> JitTensor<R, E, D> {
binary!(
operation: |elem: Elem| Operation::Add(BinaryOperation {
lhs: Variable::Input(0, Item::Scalar(elem)),
rhs: Variable::Input(1, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::Add(BinaryOperator {
lhs: scope.read_array(0, elem),
rhs: scope.read_array(1, elem),
out: scope.create_local(elem),
}),
runtime: R,
input: lhs; rhs,
@ -100,10 +98,10 @@ pub fn add_scalar<R: Runtime, E: JitElement, const D: usize>(
rhs: E,
) -> JitTensor<R, E, D> {
unary!(
operation: |elem: Elem| Operation::Add(BinaryOperation {
lhs: Variable::Input(0, Item::Scalar(elem)),
rhs: Variable::Scalar(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::Add(BinaryOperator {
lhs: scope.read_array(0, elem),
rhs: scope.read_scalar(0, elem),
out: scope.create_local(elem),
}),
runtime: R,
input: lhs; rhs,
@ -116,10 +114,10 @@ pub fn sub<R: Runtime, E: JitElement, const D: usize>(
rhs: JitTensor<R, E, D>,
) -> JitTensor<R, E, D> {
binary!(
operation: |elem: Elem| Operation::Sub(BinaryOperation {
lhs: Variable::Input(0, Item::Scalar(elem)),
rhs: Variable::Input(1, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::Sub(BinaryOperator {
lhs: scope.read_array(0, elem),
rhs: scope.read_array(1, elem),
out: scope.create_local(elem),
}),
runtime: R,
input: lhs; rhs,
@ -132,10 +130,10 @@ pub fn sub_scalar<R: Runtime, E: JitElement, const D: usize>(
rhs: E,
) -> JitTensor<R, E, D> {
unary!(
operation: |elem: Elem| Operation::Sub(BinaryOperation {
lhs: Variable::Input(0, Item::Scalar(elem)),
rhs: Variable::Scalar(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::Sub(BinaryOperator {
lhs: scope.read_array(0, elem),
rhs: scope.read_scalar(0, elem),
out: scope.create_local(elem),
}),
runtime: R,
input: lhs; rhs,
@ -148,10 +146,10 @@ pub fn mul<R: Runtime, E: JitElement, const D: usize>(
rhs: JitTensor<R, E, D>,
) -> JitTensor<R, E, D> {
binary!(
operation: |elem: Elem| Operation::Mul(BinaryOperation {
lhs: Variable::Input(0, Item::Scalar(elem)),
rhs: Variable::Input(1, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::Mul(BinaryOperator {
lhs: scope.read_array(0, elem),
rhs: scope.read_array(1, elem),
out: scope.create_local(elem),
}),
runtime: R,
input: lhs; rhs,
@ -164,10 +162,10 @@ pub fn mul_scalar<R: Runtime, E: JitElement, const D: usize>(
rhs: E,
) -> JitTensor<R, E, D> {
unary!(
operation: |elem: Elem| Operation::Mul(BinaryOperation {
lhs: Variable::Input(0, Item::Scalar(elem)),
rhs: Variable::Scalar(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::Mul(BinaryOperator {
lhs: scope.read_array(0, elem),
rhs: scope.read_scalar(0, elem),
out: scope.create_local(elem),
}),
runtime: R,
input: lhs; rhs,
@ -180,10 +178,10 @@ pub fn div<R: Runtime, E: JitElement, const D: usize>(
rhs: JitTensor<R, E, D>,
) -> JitTensor<R, E, D> {
binary!(
operation: |elem: Elem| Operation::Div(BinaryOperation {
lhs: Variable::Input(0, Item::Scalar(elem)),
rhs: Variable::Input(1, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::Div(BinaryOperator {
lhs: scope.read_array(0, elem),
rhs: scope.read_array(1, elem),
out: scope.create_local(elem),
}),
runtime: R,
input: lhs; rhs,
@ -196,10 +194,10 @@ pub fn div_scalar<R: Runtime, E: JitElement, const D: usize>(
rhs: E,
) -> JitTensor<R, E, D> {
unary!(
operation: |elem: Elem| Operation::Div(BinaryOperation {
lhs: Variable::Input(0, Item::Scalar(elem)),
rhs: Variable::Scalar(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::Div(BinaryOperator {
lhs: scope.read_array(0, elem),
rhs: scope.read_scalar(0, elem),
out: scope.create_local(elem),
}),
runtime: R,
input: lhs; rhs,
@ -212,10 +210,10 @@ pub fn pow<R: Runtime, E: JitElement, const D: usize>(
rhs: JitTensor<R, E, D>,
) -> JitTensor<R, E, D> {
binary!(
operation: |elem: Elem| Operation::Powf(BinaryOperation {
lhs: Variable::Input(0, Item::Scalar(elem)),
rhs: Variable::Input(1, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::Powf(BinaryOperator {
lhs: scope.read_array(0, elem),
rhs: scope.read_array(1, elem),
out: scope.create_local(elem),
}),
runtime: R,
input: lhs; rhs,

View File

@ -1,4 +1,4 @@
use crate::codegen::dialect::gpu::{Elem, Item, Operation, UnaryOperation, Variable};
use crate::codegen::dialect::gpu::{Elem, Operator, Scope, UnaryOperator};
use crate::element::JitElement;
use crate::{unary, Runtime};
use burn_compute::client::ComputeClient;
@ -145,9 +145,9 @@ where
//
// The solution is just to use a simple unary compute shader.
unary!(
operation: |elem: Elem| Operation::AssignLocal(UnaryOperation {
input: Variable::Input(0, Item::Scalar(elem)),
out: Variable::Local(0, Item::Scalar(elem)),
operation: |scope: &mut Scope, elem: Elem| Operator::AssignLocal(UnaryOperator {
input: scope.read_array(0, elem),
out: scope.create_local(elem),
}),
runtime: R,
input: self.clone(),