mirror of https://github.com/tracel-ai/burn.git
[Refactor] Just-In-Time Compilation Pipeline (#1313)
This commit is contained in:
parent
24287237d1
commit
843dd492c2
|
@ -0,0 +1,261 @@
|
|||
use super::dialect::gpu;
|
||||
use crate::codegen::dialect::gpu::{
|
||||
Binding, ComputeShader, Elem, Item, Location, Variable, Vectorization, Visibility,
|
||||
WorkgroupSize,
|
||||
};
|
||||
|
||||
/// The compilation struct allows you to create a [compute shader](ComputeShader) based on
|
||||
/// [compilation info](CompilationInfo) and [compilation settings](CompilationSettings).
|
||||
#[derive(Clone)]
|
||||
pub struct Compilation {
|
||||
info: CompilationInfo,
|
||||
input_bindings: Vec<Binding>,
|
||||
output_bindings: Vec<Binding>,
|
||||
named_bindings: Vec<(String, Binding)>,
|
||||
}
|
||||
|
||||
/// The information necessary to compile a [compute shader](ComputeShader).
|
||||
#[derive(Clone)]
|
||||
pub struct CompilationInfo {
|
||||
pub inputs: Vec<InputInfo>,
|
||||
pub outputs: Vec<OutputInfo>,
|
||||
pub scope: gpu::Scope,
|
||||
pub mappings: Vec<InplaceMapping>,
|
||||
}
|
||||
|
||||
/// Simply indicate the output that can be replaced by the input.
|
||||
#[derive(new, Clone, Copy)]
|
||||
pub struct InplaceMapping {
|
||||
/// Input position.
|
||||
pub pos_input: usize,
|
||||
/// Output position.
|
||||
pub pos_output: usize,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct CompilationSettings {
|
||||
vectorization: Vectorization,
|
||||
inplace_available: bool,
|
||||
workgroup_size: WorkgroupSize,
|
||||
}
|
||||
|
||||
impl CompilationSettings {
|
||||
/// Compile the shader with vectorization enabled.
|
||||
#[allow(dead_code)]
|
||||
pub fn vectorize(mut self, vectorization: Vectorization) -> Self {
|
||||
self.vectorization = vectorization;
|
||||
self
|
||||
}
|
||||
/// Compile the shader with inplace enabled.
|
||||
///
|
||||
/// Notes:
|
||||
///
|
||||
/// This won't guarantee that the shader will use input arrays as outputs, since it is only
|
||||
/// possible when [inplace mappings](InplaceMapping) are provided as [compilation info](CompilationInfo)
|
||||
pub fn inplace(mut self, available: bool) -> Self {
|
||||
self.inplace_available = available;
|
||||
self
|
||||
}
|
||||
/// Set the grid size.
|
||||
#[allow(dead_code)] // Only used for fusion for now.
|
||||
pub fn workgroup_size(mut self, workgroup_size: WorkgroupSize) -> Self {
|
||||
self.workgroup_size = workgroup_size;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Information related to an input.
|
||||
#[derive(Clone)]
|
||||
pub enum InputInfo {
|
||||
Array { item: Item, visibility: Visibility },
|
||||
Scalar { elem: Elem, size: usize },
|
||||
}
|
||||
|
||||
/// Information related to an output.
|
||||
#[derive(Clone)]
|
||||
pub enum OutputInfo {
|
||||
/// Write the local variable to a new array.
|
||||
///
|
||||
/// This will create a new binding in the [compute shader](ComputeShader).
|
||||
Array { item: Item, local: u16 },
|
||||
/// Write the local variable to an existing input binding.
|
||||
Input { item: Item, input: u16, local: u16 },
|
||||
}
|
||||
|
||||
impl Compilation {
|
||||
/// Starts a new compilation.
|
||||
pub fn new(info: CompilationInfo) -> Self {
|
||||
Self {
|
||||
info,
|
||||
input_bindings: Default::default(),
|
||||
output_bindings: Default::default(),
|
||||
named_bindings: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Performs the compilation with the provided [settings](CompilationSettings).
|
||||
pub fn compile(mut self, settings: CompilationSettings) -> ComputeShader {
|
||||
self.info.scope.vectorize(settings.vectorization);
|
||||
|
||||
self.register_inputs(&settings);
|
||||
self.register_outputs(&settings);
|
||||
|
||||
let inputs = self.input_bindings;
|
||||
let outputs = self.output_bindings;
|
||||
let mut named = Vec::with_capacity(2);
|
||||
|
||||
named.push((
|
||||
"info".to_string(),
|
||||
Binding {
|
||||
item: Item::Scalar(Elem::UInt),
|
||||
visibility: Visibility::Read,
|
||||
location: Location::Storage,
|
||||
size: None, // We avoid putting the length here since it will force a new kernel
|
||||
// for each tensor rank.
|
||||
},
|
||||
));
|
||||
|
||||
for (name, binding) in self.named_bindings.into_iter() {
|
||||
named.push((name, binding));
|
||||
}
|
||||
|
||||
ComputeShader {
|
||||
inputs,
|
||||
outputs,
|
||||
named,
|
||||
workgroup_size: settings.workgroup_size,
|
||||
body: self.info.scope,
|
||||
num_workgroups: true,
|
||||
global_invocation_id: true,
|
||||
}
|
||||
}
|
||||
|
||||
fn register_inputs(&mut self, settings: &CompilationSettings) {
|
||||
for input in self.info.inputs.drain(..) {
|
||||
match input {
|
||||
InputInfo::Array { item, visibility } => {
|
||||
let item = item.vectorize(settings.vectorization);
|
||||
|
||||
self.input_bindings.push(Binding {
|
||||
item: bool_item(item),
|
||||
visibility,
|
||||
location: Location::Storage,
|
||||
size: None,
|
||||
});
|
||||
}
|
||||
InputInfo::Scalar { elem, size } => {
|
||||
let elem = bool_elem(elem);
|
||||
|
||||
self.named_bindings.push((
|
||||
format!("scalars_{}", elem),
|
||||
Binding {
|
||||
item: Item::Scalar(elem),
|
||||
visibility: Visibility::Read,
|
||||
location: Location::Storage,
|
||||
size: Some(size),
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn register_outputs(&mut self, settings: &CompilationSettings) {
|
||||
let mut index = 0;
|
||||
|
||||
if settings.inplace_available {
|
||||
let mut mappings = Vec::new();
|
||||
core::mem::swap(&mut self.info.mappings, &mut mappings);
|
||||
|
||||
for mapping in mappings {
|
||||
self.register_inplace_mapping(mapping);
|
||||
}
|
||||
}
|
||||
|
||||
for array in self.info.outputs.drain(..) {
|
||||
match array {
|
||||
OutputInfo::Array { item, local } => {
|
||||
let item = item.vectorize(settings.vectorization);
|
||||
let elem_adapted = bool_item(item);
|
||||
|
||||
self.output_bindings.push(Binding {
|
||||
item: elem_adapted,
|
||||
visibility: Visibility::ReadWrite,
|
||||
location: Location::Storage,
|
||||
size: None,
|
||||
});
|
||||
self.info.scope.write_global(
|
||||
Variable::Local(local, item, self.info.scope.depth),
|
||||
Variable::GlobalOutputArray(index, elem_adapted),
|
||||
);
|
||||
index += 1;
|
||||
}
|
||||
OutputInfo::Input { item, input, local } => {
|
||||
let item = item.vectorize(settings.vectorization);
|
||||
|
||||
self.info.scope.write_global(
|
||||
Variable::Local(local, item, self.info.scope.depth),
|
||||
Variable::GlobalInputArray(input, bool_item(item)),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn register_inplace_mapping(&mut self, mapping: InplaceMapping) {
|
||||
let output = match self.info.outputs.get_mut(mapping.pos_output) {
|
||||
Some(output) => output,
|
||||
None => return, // No output to update.
|
||||
};
|
||||
|
||||
let (item, local) = match output {
|
||||
OutputInfo::Array { item, local } => (item, local),
|
||||
OutputInfo::Input {
|
||||
item: _,
|
||||
input: _,
|
||||
local: _,
|
||||
} => return, // Output already updated.
|
||||
};
|
||||
|
||||
let item = match self.input_bindings.get_mut(mapping.pos_input) {
|
||||
Some(binding) => {
|
||||
// Update input visibility.
|
||||
binding.visibility = Visibility::ReadWrite;
|
||||
// Inputs modified inplace should be read without any specified layout.
|
||||
self.info
|
||||
.scope
|
||||
.update_read(mapping.pos_input as u16, gpu::ReadingStrategy::Plain);
|
||||
|
||||
// Use the same item as the input.
|
||||
//
|
||||
// The output can be different (i.e inplace boolean operations on float bindings).
|
||||
binding.item
|
||||
}
|
||||
None => *item,
|
||||
};
|
||||
|
||||
// Update the output.
|
||||
*output = OutputInfo::Input {
|
||||
item,
|
||||
input: mapping.pos_input as u16,
|
||||
local: *local,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
fn bool_item(ty: Item) -> Item {
|
||||
match ty {
|
||||
Item::Vec4(elem) => Item::Vec4(bool_elem(elem)),
|
||||
Item::Vec3(elem) => Item::Vec3(bool_elem(elem)),
|
||||
Item::Vec2(elem) => Item::Vec2(bool_elem(elem)),
|
||||
Item::Scalar(elem) => Item::Scalar(bool_elem(elem)),
|
||||
}
|
||||
}
|
||||
|
||||
fn bool_elem(elem: Elem) -> Elem {
|
||||
match elem {
|
||||
// U32 are used for bool tensors
|
||||
Elem::Bool => Elem::UInt,
|
||||
_ => elem,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,57 @@
|
|||
use super::{
|
||||
gpu, Elem, Item, Metadata, Operator, ReadGlobalAlgo, ReadGlobalWithLayoutAlgo, Scope, Variable,
|
||||
};
|
||||
use crate::codegen::dialect::gpu::BinaryOperator;
|
||||
|
||||
impl ReadGlobalAlgo {
|
||||
pub fn expand(self, scope: &mut Scope) {
|
||||
scope.register(Operator::Index(BinaryOperator {
|
||||
lhs: self.global,
|
||||
rhs: Variable::Id,
|
||||
out: self.out,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
impl ReadGlobalWithLayoutAlgo {
|
||||
pub fn expand(self, scope: &mut Scope) {
|
||||
let out = self.out;
|
||||
let tensor = self.global;
|
||||
let layout = self.layout;
|
||||
let index_item_ty = Item::Scalar(Elem::UInt);
|
||||
let index_local = scope.create_local(index_item_ty);
|
||||
let zero: Variable = 0u32.into();
|
||||
let id = Variable::Id;
|
||||
let offset: Variable = match self.global.item() {
|
||||
Item::Vec4(_) => 4u32,
|
||||
Item::Vec3(_) => 3u32,
|
||||
Item::Vec2(_) => 2u32,
|
||||
Item::Scalar(_) => 1u32,
|
||||
}
|
||||
.into();
|
||||
|
||||
gpu!(scope, index_local = zero);
|
||||
gpu!(
|
||||
scope,
|
||||
range(zero, Variable::Rank).for_each(|i, scope| {
|
||||
let stride = scope.create_local(index_item_ty);
|
||||
let stride_layout = scope.create_local(index_item_ty);
|
||||
let shape = scope.create_local(index_item_ty);
|
||||
let tmp = scope.create_local(index_item_ty);
|
||||
|
||||
gpu!(scope, stride = stride(tensor, i));
|
||||
gpu!(scope, shape = shape(tensor, i));
|
||||
gpu!(scope, stride_layout = stride(layout, i));
|
||||
|
||||
gpu!(scope, tmp = id * offset);
|
||||
gpu!(scope, tmp = tmp / stride_layout);
|
||||
gpu!(scope, tmp = tmp % shape);
|
||||
gpu!(scope, tmp = tmp * stride);
|
||||
gpu!(scope, index_local = index_local + tmp);
|
||||
})
|
||||
);
|
||||
|
||||
gpu!(scope, index_local = index_local / offset);
|
||||
gpu!(scope, out = tensor[index_local]);
|
||||
}
|
||||
}
|
|
@ -1,7 +0,0 @@
|
|||
use super::Operation;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, new)]
|
||||
pub struct Body {
|
||||
pub operators: Vec<Operation>,
|
||||
}
|
|
@ -0,0 +1,260 @@
|
|||
use super::Variable;
|
||||
|
||||
macro_rules! gpu {
|
||||
// out = lhs + rhs
|
||||
($scope:expr, $out:ident = $lhs:ident + $rhs:expr) => {
|
||||
gpu!($scope, $out = add($lhs, $rhs))
|
||||
};
|
||||
// out = add(lhs, rhs)
|
||||
($scope:expr, $out:ident = add($lhs:expr, $rhs:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Add(
|
||||
gpu!(binary $lhs, $rhs, $out)
|
||||
));
|
||||
};
|
||||
// out = lhs - rhs
|
||||
($scope:expr, $out:ident = $lhs:ident - $rhs:expr) => {
|
||||
gpu!($scope, $out = sub($lhs, $rhs))
|
||||
};
|
||||
// out = sub(lhs, rhs)
|
||||
($scope:expr, $out:ident = sub($lhs:expr, $rhs:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Sub(
|
||||
gpu!(binary $lhs, $rhs, $out)
|
||||
));
|
||||
};
|
||||
// out = lhs * rhs
|
||||
($scope:expr, $out:ident = $lhs:ident * $rhs:expr) => {
|
||||
gpu!($scope, $out = mul($lhs, $rhs))
|
||||
};
|
||||
// out = mul(lhs, rhs)
|
||||
($scope:expr, $out:ident = mul($lhs:expr, $rhs:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Mul(
|
||||
gpu!(binary $lhs, $rhs, $out)
|
||||
));
|
||||
};
|
||||
// out = lhs / rhs
|
||||
($scope:expr, $out:ident = $lhs:ident / $rhs:expr) => {
|
||||
gpu!($scope, $out = div($lhs, $rhs))
|
||||
};
|
||||
// out = div(lhs, rhs)
|
||||
($scope:expr, $out:ident = div($lhs:expr, $rhs:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Div(
|
||||
gpu!(binary $lhs, $rhs, $out)
|
||||
));
|
||||
};
|
||||
// out = lhs % rhs
|
||||
($scope:expr, $out:ident = $lhs:ident % $rhs:expr) => {
|
||||
gpu!($scope, $out = modulo($lhs, $rhs))
|
||||
};
|
||||
// out = modulo(lhs, rhs)
|
||||
($scope:expr, $out:ident = modulo($lhs:expr, $rhs:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Modulo(
|
||||
gpu!(binary $lhs, $rhs, $out)
|
||||
));
|
||||
};
|
||||
// out = lhs ^ rhs
|
||||
($scope:expr, $out:ident = $lhs:ident ^ $rhs:expr) => {
|
||||
gpu!($scope, $out = powf($lhs, $rhs))
|
||||
};
|
||||
// out = powf(lhs, rhs)
|
||||
($scope:expr, $out:ident = powf($lhs:expr, $rhs:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Powf(
|
||||
gpu!(binary $lhs, $rhs, $out)
|
||||
));
|
||||
};
|
||||
// out = lhs == rhs
|
||||
($scope:expr, $out:ident = $lhs:ident == $rhs:expr) => {
|
||||
gpu!($scope, $out = equal($lhs, $rhs))
|
||||
};
|
||||
// out = equal(lhs, rhs)
|
||||
($scope:expr, $out:ident = equal($lhs:expr, $rhs:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Equal(
|
||||
gpu!(binary $lhs, $rhs, $out)
|
||||
));
|
||||
};
|
||||
// out = lhs > rhs
|
||||
($scope:expr, $out:ident = $lhs:ident > $rhs:expr) => {
|
||||
gpu!($scope, $out = greater($lhs, $rhs))
|
||||
};
|
||||
// out = greater(lhs, rhs)
|
||||
($scope:expr, $out:ident = greater($lhs:expr, $rhs:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Greater(
|
||||
gpu!(binary $lhs, $rhs, $out)
|
||||
));
|
||||
};
|
||||
// out = lhs >= rhs
|
||||
($scope:expr, $out:ident = $lhs:ident >= $rhs:expr) => {
|
||||
gpu!($scope, $out = greater_equal($lhs, $rhs))
|
||||
};
|
||||
// out = greater_equal(lhs, rhs)
|
||||
($scope:expr, $out:ident = greater_equal($lhs:expr, $rhs:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::GreaterEqual(
|
||||
gpu!(binary $lhs, $rhs, $out)
|
||||
));
|
||||
};
|
||||
// out = lhs < rhs
|
||||
($scope:expr, $out:ident = $lhs:ident < $rhs:expr) => {
|
||||
gpu!($scope, $out = lower($lhs, $rhs))
|
||||
};
|
||||
// out = lower(lhs, rhs)
|
||||
($scope:expr, $out:ident = lower($lhs:expr, $rhs:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Lower(
|
||||
gpu!(binary $lhs, $rhs, $out)
|
||||
));
|
||||
};
|
||||
// out = lhs <= rhs
|
||||
($scope:expr, $out:ident = $lhs:ident <= $rhs:expr) => {
|
||||
gpu!($scope, $out = lower_equal($lhs, $rhs))
|
||||
};
|
||||
// out = lower_equal(lhs, rhs)
|
||||
($scope:expr, $out:ident = lower_equal($lhs:expr, $rhs:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::LowerEqual(
|
||||
gpu!(binary $lhs, $rhs, $out)
|
||||
));
|
||||
};
|
||||
// out = lhs[rhs]
|
||||
($scope:expr, $out:ident = $lhs:ident[$rhs:expr]) => {
|
||||
gpu!($scope, $out = index($lhs, $rhs))
|
||||
};
|
||||
// out = index(lhs, rhs)
|
||||
($scope:expr, $out:ident = index($lhs:expr, $rhs:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Index(
|
||||
gpu!(binary $lhs, $rhs, $out)
|
||||
));
|
||||
};
|
||||
// out = |input|
|
||||
($scope:expr, $out:ident = |$input:ident|) => {
|
||||
gpu!($scope, $out = abs($input))
|
||||
};
|
||||
// out = abs(input)
|
||||
($scope:expr, $out:ident = abs($input:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Abs(
|
||||
gpu!(unary $input, $out)
|
||||
));
|
||||
};
|
||||
// out = exp(input)
|
||||
($scope:expr, $out:ident = exp($input:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Exp(
|
||||
gpu!(unary $input, $out)
|
||||
));
|
||||
};
|
||||
// out = log(input)
|
||||
($scope:expr, $out:ident = log($input:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Log(
|
||||
gpu!(unary $input, $out)
|
||||
));
|
||||
};
|
||||
// out = log1p(input)
|
||||
($scope:expr, $out:ident = log1p($input:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Log1p(
|
||||
gpu!(unary $input, $out)
|
||||
));
|
||||
};
|
||||
// out = cos(input)
|
||||
($scope:expr, $out:ident = cos($input:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Cos(
|
||||
gpu!(unary $input, $out)
|
||||
));
|
||||
};
|
||||
// out = sin(input)
|
||||
($scope:expr, $out:ident = sin($input:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Sin(
|
||||
gpu!(unary $input, $out)
|
||||
));
|
||||
};
|
||||
// out = tanh(input)
|
||||
($scope:expr, $out:ident = tanh($input:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Tanh(
|
||||
gpu!(unary $input, $out)
|
||||
));
|
||||
};
|
||||
// out = sqrt(input)
|
||||
($scope:expr, $out:ident = sqrt($input:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Sqrt(
|
||||
gpu!(unary $input, $out)
|
||||
));
|
||||
};
|
||||
// out = erf(input)
|
||||
($scope:expr, $out:ident = erf($input:expr)) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::Erf(
|
||||
gpu!(unary $input, $out)
|
||||
));
|
||||
};
|
||||
// out = input
|
||||
($scope:expr, eval $arg:expr) => {
|
||||
gpu!($scope, $arg);
|
||||
};
|
||||
// out = input
|
||||
($scope:expr, $out:ident = $input:ident) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::AssignLocal(
|
||||
gpu!(unary $input, $out)
|
||||
));
|
||||
};
|
||||
// out = input
|
||||
($scope:expr, $out:ident = $input:ident) => {
|
||||
$scope.register($crate::codegen::dialect::gpu::Operator::AssignLocal(
|
||||
gpu!(unary $input, $out)
|
||||
));
|
||||
};
|
||||
($scope:expr, $out:ident = shape($input:expr, $dim:expr)) => {
|
||||
$scope.register(Metadata::Shape {
|
||||
dim: $dim,
|
||||
var: $input,
|
||||
out: $out,
|
||||
});
|
||||
};
|
||||
($scope:expr, $out:ident = stride($input:expr, $dim:expr)) => {
|
||||
$scope.register(Metadata::Stride {
|
||||
dim: $dim,
|
||||
var: $input,
|
||||
out: $out,
|
||||
});
|
||||
};
|
||||
($scope:expr, range($start:expr, $end:expr).for_each($arg:expr)) => {
|
||||
$crate::codegen::dialect::gpu::RangeLoop::register($scope, $start.into(), $end.into(), $arg);
|
||||
};
|
||||
(binary $lhs:expr, $rhs:expr, $out:expr) => {
|
||||
$crate::codegen::dialect::gpu::BinaryOperator {
|
||||
lhs: $lhs.into(),
|
||||
rhs: $rhs.into(),
|
||||
out: $out.into(),
|
||||
}
|
||||
};
|
||||
(unary $input:expr, $out:expr) => {
|
||||
$crate::codegen::dialect::gpu::UnaryOperator {
|
||||
input: $input.into(),
|
||||
out: $out.into(),
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl From<bool> for Variable {
|
||||
fn from(value: bool) -> Self {
|
||||
Self::ConstantScalar(if value { 1.0 } else { 0.0 }, super::Elem::Bool)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<i32> for Variable {
|
||||
fn from(value: i32) -> Self {
|
||||
Self::ConstantScalar(value as f64, super::Elem::Int)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<f32> for Variable {
|
||||
fn from(value: f32) -> Self {
|
||||
Self::ConstantScalar(value as f64, super::Elem::Float)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u32> for Variable {
|
||||
fn from(value: u32) -> Self {
|
||||
Self::ConstantScalar(value as f64, super::Elem::UInt)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<usize> for Variable {
|
||||
fn from(value: usize) -> Self {
|
||||
Self::ConstantScalar(value as f64, super::Elem::UInt)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) use gpu;
|
|
@ -1,11 +1,15 @@
|
|||
mod body;
|
||||
pub(crate) mod algorithm;
|
||||
|
||||
mod macros;
|
||||
mod operation;
|
||||
mod scope;
|
||||
mod shader;
|
||||
mod variable;
|
||||
mod vectorization;
|
||||
|
||||
pub(crate) use body::*;
|
||||
pub(crate) use macros::gpu;
|
||||
pub(crate) use operation::*;
|
||||
pub(crate) use scope::*;
|
||||
pub(crate) use shader::*;
|
||||
pub(crate) use variable::*;
|
||||
pub(crate) use vectorization::*;
|
||||
|
|
|
@ -1,56 +1,169 @@
|
|||
use super::Variable;
|
||||
use super::{Elem, Item, Scope, Variable};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// All operations that can be used in a GPU compute shader.
|
||||
///
|
||||
/// Notes:
|
||||
///
|
||||
/// [Operator] and [Algorithm] can be vectorized, but [Metadata] and [Loop] can't.
|
||||
/// Therefore, during tracing, only operators and algorithms can be registered, and during the
|
||||
/// compilation phase, the algorithm will be expanded.
|
||||
///
|
||||
/// Algorithm expansions can safely use [Metadata] and [Loop] operations.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[allow(dead_code)] // Some variants might not be used with different flags
|
||||
pub enum Operation {
|
||||
Add(BinaryOperation),
|
||||
Sub(BinaryOperation),
|
||||
Mul(BinaryOperation),
|
||||
Div(BinaryOperation),
|
||||
Abs(UnaryOperation),
|
||||
Exp(UnaryOperation),
|
||||
Log(UnaryOperation),
|
||||
Log1p(UnaryOperation),
|
||||
Cos(UnaryOperation),
|
||||
Sin(UnaryOperation),
|
||||
Tanh(UnaryOperation),
|
||||
Powf(BinaryOperation),
|
||||
Sqrt(UnaryOperation),
|
||||
Erf(UnaryOperation),
|
||||
Recip(UnaryOperation),
|
||||
Equal(BinaryOperation),
|
||||
Lower(BinaryOperation),
|
||||
Clamp(ClampOperation),
|
||||
Greater(BinaryOperation),
|
||||
LowerEqual(BinaryOperation),
|
||||
GreaterEqual(BinaryOperation),
|
||||
ConditionalAssign(ConditionalAssignOperation),
|
||||
AssignGlobal(UnaryOperation),
|
||||
AssignLocal(UnaryOperation),
|
||||
ReadGlobal(ReadGlobalOperation),
|
||||
ReadGlobalWithLayout(ReadGlobalWithLayoutOperation),
|
||||
Operator(Operator),
|
||||
Metadata(Metadata),
|
||||
Algorithm(Algorithm),
|
||||
Loop(Loop),
|
||||
}
|
||||
|
||||
/// All operator that can be used in a GPU compute shader.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[allow(dead_code)] // Some variants might not be used with different flags
|
||||
pub enum Operator {
|
||||
Add(BinaryOperator),
|
||||
Sub(BinaryOperator),
|
||||
Mul(BinaryOperator),
|
||||
Div(BinaryOperator),
|
||||
Abs(UnaryOperator),
|
||||
Exp(UnaryOperator),
|
||||
Log(UnaryOperator),
|
||||
Log1p(UnaryOperator),
|
||||
Cos(UnaryOperator),
|
||||
Sin(UnaryOperator),
|
||||
Tanh(UnaryOperator),
|
||||
Powf(BinaryOperator),
|
||||
Sqrt(UnaryOperator),
|
||||
Erf(UnaryOperator),
|
||||
Recip(UnaryOperator),
|
||||
Equal(BinaryOperator),
|
||||
Lower(BinaryOperator),
|
||||
Clamp(ClampOperator),
|
||||
Greater(BinaryOperator),
|
||||
LowerEqual(BinaryOperator),
|
||||
GreaterEqual(BinaryOperator),
|
||||
ConditionalAssign(ConditionalAssignOperator),
|
||||
AssignGlobal(UnaryOperator),
|
||||
AssignLocal(UnaryOperator),
|
||||
Modulo(BinaryOperator),
|
||||
Index(BinaryOperator),
|
||||
}
|
||||
|
||||
/// Tensor operations that can't be executed with a simple [operator](Operator) should use an
|
||||
/// algorithm.
|
||||
///
|
||||
/// Algorithms can be expanded to basic [operator](Operator) during compilation, but after
|
||||
/// vectorization, since for loops and other construct can't simply be vectorized. This also gives
|
||||
/// the vectorization state to the expansion function, which may create a different set of
|
||||
/// [operator](Operator) depending on the vectorization state.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum Algorithm {
|
||||
/// Read an input array with the given layout.
|
||||
///
|
||||
/// Crucial to read arrays that aren't contiguous and to perform correct broadcasting.
|
||||
ReadGlobalWithLayout(ReadGlobalWithLayoutAlgo),
|
||||
/// Read an input array.
|
||||
ReadGlobal(ReadGlobalAlgo),
|
||||
}
|
||||
|
||||
/// Settings for the [Algorithm::ReadGlobalWithLayout] variant.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ReadGlobalWithLayoutAlgo {
|
||||
/// The array to be read.
|
||||
pub global: Variable,
|
||||
/// The layout to be used.
|
||||
pub layout: Variable,
|
||||
/// The output variable to write the result.
|
||||
pub out: Variable,
|
||||
}
|
||||
|
||||
/// Settings for the [Algorithm::ReadGlobal] variant.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ReadGlobalAlgo {
|
||||
/// The array to be read.
|
||||
pub global: Variable,
|
||||
/// The output variable to write the result.
|
||||
pub out: Variable,
|
||||
}
|
||||
|
||||
/// All loop variants.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum Loop {
|
||||
/// A basic range loop.
|
||||
Range(RangeLoop),
|
||||
}
|
||||
|
||||
/// All metadata that can be access in a shader.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum Metadata {
|
||||
/// The stride of an array at the given dimension.
|
||||
Stride {
|
||||
dim: Variable,
|
||||
var: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
/// The shape of an array at the given dimension.
|
||||
Shape {
|
||||
dim: Variable,
|
||||
var: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
}
|
||||
|
||||
/// Settings for the [Loop::Range] variant.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RangeLoop {
|
||||
/// The loop index variable.
|
||||
pub i: Variable,
|
||||
/// The start value.
|
||||
pub start: Variable,
|
||||
/// The end value.
|
||||
pub end: Variable,
|
||||
/// The scope that contains all operations and variables declared in the loop body.
|
||||
pub scope: Scope,
|
||||
}
|
||||
|
||||
impl RangeLoop {
|
||||
/// Registers a range loop to the given scope.
|
||||
pub fn register<F: Fn(Variable, &mut Scope)>(
|
||||
parent_scope: &mut Scope,
|
||||
start: Variable,
|
||||
end: Variable,
|
||||
func: F,
|
||||
) {
|
||||
let mut scope = parent_scope.child();
|
||||
let index_ty = Item::Scalar(Elem::UInt);
|
||||
let i = scope.create_local_undeclare(index_ty);
|
||||
|
||||
func(i, &mut scope);
|
||||
|
||||
let op = Self {
|
||||
i,
|
||||
start,
|
||||
end,
|
||||
scope,
|
||||
};
|
||||
parent_scope.register(Loop::Range(op));
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[allow(dead_code)] // Some variants might not be used with different flags
|
||||
pub struct BinaryOperation {
|
||||
pub struct BinaryOperator {
|
||||
pub lhs: Variable,
|
||||
pub rhs: Variable,
|
||||
pub out: Variable,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[allow(dead_code)] // Some variants might not be used with different flags
|
||||
pub struct UnaryOperation {
|
||||
pub struct UnaryOperator {
|
||||
pub input: Variable,
|
||||
pub out: Variable,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[allow(dead_code)] // Some variants might not be used with different flags
|
||||
pub struct ClampOperation {
|
||||
pub struct ClampOperator {
|
||||
pub input: Variable,
|
||||
pub min_value: Variable,
|
||||
pub max_value: Variable,
|
||||
|
@ -58,8 +171,7 @@ pub struct ClampOperation {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[allow(dead_code)] // Some variants might not be used with different flags
|
||||
pub struct ConditionalAssignOperation {
|
||||
pub struct ConditionalAssignOperator {
|
||||
pub cond: Variable,
|
||||
pub lhs: Variable,
|
||||
pub rhs: Variable,
|
||||
|
@ -67,15 +179,37 @@ pub struct ConditionalAssignOperation {
|
|||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[allow(dead_code)] // Some variants might not be used with different flags
|
||||
pub struct ReadGlobalOperation {
|
||||
pub struct ReadGlobalOperator {
|
||||
pub variable: Variable,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[allow(dead_code)] // Some variants might not be used with different flags
|
||||
pub struct ReadGlobalWithLayoutOperation {
|
||||
pub struct ReadGlobalWithLayoutOperator {
|
||||
pub variable: Variable,
|
||||
pub tensor_read_pos: usize,
|
||||
pub tensor_layout_pos: usize,
|
||||
}
|
||||
|
||||
impl From<Operator> for Operation {
|
||||
fn from(val: Operator) -> Self {
|
||||
Operation::Operator(val)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Metadata> for Operation {
|
||||
fn from(val: Metadata) -> Self {
|
||||
Operation::Metadata(val)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Algorithm> for Operation {
|
||||
fn from(val: Algorithm) -> Self {
|
||||
Operation::Algorithm(val)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Loop> for Operation {
|
||||
fn from(val: Loop) -> Self {
|
||||
Operation::Loop(val)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,257 @@
|
|||
use super::{
|
||||
Algorithm, Elem, Item, Operation, Operator, ReadGlobalAlgo, ReadGlobalWithLayoutAlgo,
|
||||
UnaryOperator, Variable, Vectorization,
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// The scope is the main [operation](Operation) and [variable](Variable) container that simplify
|
||||
/// the process of reading inputs, creating local variables and adding new operations.
|
||||
///
|
||||
/// Notes:
|
||||
///
|
||||
/// This type isn't responsible for creating [shader bindings](super::Binding) and figuring out which
|
||||
/// variable can be written to.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Scope {
|
||||
pub depth: u8,
|
||||
pub operations: Vec<Operation>,
|
||||
locals: Vec<Variable>,
|
||||
reads_global: Vec<(Variable, ReadingStrategy, Variable)>,
|
||||
writes_global: Vec<(Variable, Variable)>,
|
||||
reads_scalar: Vec<(Variable, Variable)>,
|
||||
output_ref: Option<Variable>,
|
||||
undeclared: u16,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum ReadingStrategy {
|
||||
/// Each element will be read in a way to be compatible with the output layout.
|
||||
OutputLayout,
|
||||
/// Keep the current layout.
|
||||
Plain,
|
||||
}
|
||||
|
||||
/// Information necessary when compiling a scope.
|
||||
pub struct ScopeProcessing {
|
||||
/// The variable declarations.
|
||||
pub variables: Vec<Variable>,
|
||||
/// The operations.
|
||||
pub operations: Vec<Operation>,
|
||||
}
|
||||
|
||||
impl Scope {
|
||||
/// Create a scope that is at the root of a
|
||||
/// [compute shader](crate::codegen::dialect::gpu::ComputeShader).
|
||||
///
|
||||
/// A local scope can be created with the [child](Self::child) method.
|
||||
pub fn root() -> Self {
|
||||
Self {
|
||||
depth: 0,
|
||||
operations: Vec::new(),
|
||||
locals: Vec::new(),
|
||||
reads_global: Vec::new(),
|
||||
writes_global: Vec::new(),
|
||||
reads_scalar: Vec::new(),
|
||||
output_ref: None,
|
||||
undeclared: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a local variable of the given [item type](Item).
|
||||
pub fn create_local<I: Into<Item>>(&mut self, item: I) -> Variable {
|
||||
let item = item.into();
|
||||
let index = self.new_local_index();
|
||||
let local = Variable::Local(index, item, self.depth);
|
||||
self.locals.push(local);
|
||||
local
|
||||
}
|
||||
|
||||
/// Create a new local variable, but doesn't perform the declaration.
|
||||
///
|
||||
/// Useful for _for loops_ and other algorithms that require the control over initialization.
|
||||
pub fn create_local_undeclare(&mut self, item: Item) -> Variable {
|
||||
let index = self.new_local_index();
|
||||
let local = Variable::Local(index, item, self.depth);
|
||||
self.undeclared += 1;
|
||||
local
|
||||
}
|
||||
|
||||
/// Reads an input array to a local variable.
|
||||
///
|
||||
/// The index refers to the argument position of the array in the compute shader.
|
||||
pub fn read_array<I: Into<Item>>(&mut self, index: u16, item: I) -> Variable {
|
||||
self.read_input_strategy(index, item.into(), ReadingStrategy::OutputLayout)
|
||||
}
|
||||
|
||||
/// Reads an input scalar to a local variable.
|
||||
///
|
||||
/// The index refers to the scalar position for the same [element](Elem) type.
|
||||
pub fn read_scalar(&mut self, index: u16, elem: Elem) -> Variable {
|
||||
let local = Variable::LocalScalar(self.new_local_scalar_index(), elem, self.depth);
|
||||
let scalar = Variable::GlobalScalar(index, elem);
|
||||
|
||||
self.reads_scalar.push((local, scalar));
|
||||
|
||||
local
|
||||
}
|
||||
|
||||
/// Retrieve the last local variable that was created.
|
||||
pub fn last_local_index(&self) -> Option<&Variable> {
|
||||
self.locals.last()
|
||||
}
|
||||
|
||||
/// Vectorize the scope using the [vectorization](Vectorization) type.
|
||||
///
|
||||
/// Notes:
|
||||
///
|
||||
/// Scopes created _during_ compilation (after the tracing is done) should not be vectorized.
|
||||
pub fn vectorize(&mut self, vectorization: Vectorization) {
|
||||
self.operations
|
||||
.iter_mut()
|
||||
.for_each(|op| *op = op.vectorize(vectorization));
|
||||
self.locals
|
||||
.iter_mut()
|
||||
.for_each(|var| *var = var.vectorize(vectorization));
|
||||
self.reads_global.iter_mut().for_each(|(input, _, output)| {
|
||||
*input = input.vectorize(vectorization);
|
||||
*output = output.vectorize(vectorization);
|
||||
});
|
||||
self.writes_global.iter_mut().for_each(|(input, output)| {
|
||||
*input = input.vectorize(vectorization);
|
||||
*output = output.vectorize(vectorization);
|
||||
});
|
||||
}
|
||||
|
||||
/// Writes a variable to given output.
|
||||
///
|
||||
/// Notes:
|
||||
///
|
||||
/// This should only be used when doing compilation.
|
||||
pub(crate) fn write_global(&mut self, input: Variable, output: Variable) {
|
||||
if self.output_ref.is_none() {
|
||||
self.output_ref = Some(output);
|
||||
}
|
||||
self.writes_global.push((input, output));
|
||||
}
|
||||
|
||||
/// Update the [reading strategy](ReadingStrategy) for an input array.
|
||||
///
|
||||
/// Notes:
|
||||
///
|
||||
/// This should only be used when doing compilation.
|
||||
pub(crate) fn update_read(&mut self, index: u16, strategy: ReadingStrategy) {
|
||||
if let Some((_, strategy_old, _)) = self
|
||||
.reads_global
|
||||
.iter_mut()
|
||||
.find(|(var, _, _)| var.index() == Some(index))
|
||||
{
|
||||
*strategy_old = strategy;
|
||||
}
|
||||
}
|
||||
|
||||
/// Register an [operation](Operation) into the scope.
|
||||
pub fn register<T: Into<Operation>>(&mut self, operation: T) {
|
||||
self.operations.push(operation.into())
|
||||
}
|
||||
|
||||
/// Create an empty child scope.
|
||||
pub fn child(&mut self) -> Self {
|
||||
Self {
|
||||
depth: self.depth + 1,
|
||||
operations: Vec::new(),
|
||||
locals: Vec::new(),
|
||||
reads_global: Vec::new(),
|
||||
writes_global: Vec::new(),
|
||||
reads_scalar: Vec::new(),
|
||||
output_ref: None,
|
||||
undeclared: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the variables and operations to be declared and executed.
|
||||
///
|
||||
/// Notes:
|
||||
///
|
||||
/// New operations and variables can be created within the same scope without having name
|
||||
/// conflicts.
|
||||
pub fn process(&mut self) -> ScopeProcessing {
|
||||
self.undeclared += self.locals.len() as u16;
|
||||
|
||||
let mut variables = Vec::new();
|
||||
core::mem::swap(&mut self.locals, &mut variables);
|
||||
|
||||
let mut operations = Vec::new();
|
||||
|
||||
for (input, strategy, local) in self.reads_global.drain(..) {
|
||||
match strategy {
|
||||
ReadingStrategy::OutputLayout => {
|
||||
let output = self.output_ref.expect(
|
||||
"Output should be set when processing an input with output layout.",
|
||||
);
|
||||
operations.push(Operation::Algorithm(Algorithm::ReadGlobalWithLayout(
|
||||
ReadGlobalWithLayoutAlgo {
|
||||
global: input,
|
||||
layout: output,
|
||||
out: local,
|
||||
},
|
||||
)));
|
||||
}
|
||||
ReadingStrategy::Plain => operations.push(Operation::Algorithm(
|
||||
Algorithm::ReadGlobal(ReadGlobalAlgo {
|
||||
global: input,
|
||||
out: local,
|
||||
}),
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
for (local, scalar) in self.reads_scalar.drain(..) {
|
||||
operations.push(
|
||||
Operator::AssignLocal(UnaryOperator {
|
||||
input: scalar,
|
||||
out: local,
|
||||
})
|
||||
.into(),
|
||||
);
|
||||
variables.push(local);
|
||||
}
|
||||
|
||||
for op in self.operations.drain(..) {
|
||||
operations.push(op);
|
||||
}
|
||||
|
||||
for (input, out) in self.writes_global.drain(..) {
|
||||
operations.push(Operation::Operator(Operator::AssignGlobal(UnaryOperator {
|
||||
input,
|
||||
out,
|
||||
})))
|
||||
}
|
||||
|
||||
ScopeProcessing {
|
||||
variables,
|
||||
operations,
|
||||
}
|
||||
}
|
||||
|
||||
fn new_local_index(&self) -> u16 {
|
||||
self.locals.len() as u16 + self.undeclared
|
||||
}
|
||||
|
||||
fn new_local_scalar_index(&self) -> u16 {
|
||||
self.reads_scalar.len() as u16
|
||||
}
|
||||
|
||||
fn read_input_strategy(
|
||||
&mut self,
|
||||
index: u16,
|
||||
item: Item,
|
||||
strategy: ReadingStrategy,
|
||||
) -> Variable {
|
||||
let input = Variable::GlobalInputArray(index, item);
|
||||
let index = self.new_local_index();
|
||||
let local = Variable::Local(index, item, self.depth);
|
||||
self.reads_global.push((input, strategy, local));
|
||||
self.locals.push(local);
|
||||
local
|
||||
}
|
||||
}
|
|
@ -1,8 +1,9 @@
|
|||
use super::Body;
|
||||
use crate::kernel::WORKGROUP_DEFAULT;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Display;
|
||||
|
||||
use super::Scope;
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
|
||||
pub enum Location {
|
||||
Storage,
|
||||
|
@ -16,7 +17,7 @@ pub enum Visibility {
|
|||
ReadWrite,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Copy, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Copy, Hash, Serialize, Deserialize)]
|
||||
pub enum Elem {
|
||||
Float,
|
||||
Int,
|
||||
|
@ -24,6 +25,12 @@ pub enum Elem {
|
|||
Bool,
|
||||
}
|
||||
|
||||
impl From<Elem> for Item {
|
||||
fn from(val: Elem) -> Self {
|
||||
Item::Scalar(val)
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Elem {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
|
@ -87,5 +94,5 @@ pub struct ComputeShader {
|
|||
pub workgroup_size: WorkgroupSize,
|
||||
pub global_invocation_id: bool,
|
||||
pub num_workgroups: bool,
|
||||
pub body: Body,
|
||||
pub body: Scope,
|
||||
}
|
||||
|
|
|
@ -1,11 +1,47 @@
|
|||
use super::Item;
|
||||
use super::{Elem, Item};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
|
||||
pub enum Variable {
|
||||
Input(u16, Item),
|
||||
Scalar(u16, Item),
|
||||
Local(u16, Item),
|
||||
Output(u16, Item),
|
||||
Constant(f64, Item),
|
||||
GlobalInputArray(u16, Item),
|
||||
GlobalScalar(u16, Elem),
|
||||
GlobalOutputArray(u16, Item),
|
||||
Local(u16, Item, u8),
|
||||
LocalScalar(u16, Elem, u8),
|
||||
ConstantScalar(f64, Elem),
|
||||
Id,
|
||||
Rank,
|
||||
}
|
||||
|
||||
impl Variable {
|
||||
pub fn const_value(&self) -> Option<f64> {
|
||||
match self {
|
||||
Variable::ConstantScalar(value, _) => Some(*value),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
pub fn index(&self) -> Option<u16> {
|
||||
match self {
|
||||
Variable::GlobalInputArray(idx, _) => Some(*idx),
|
||||
Variable::GlobalScalar(idx, _) => Some(*idx),
|
||||
Variable::Local(idx, _, _) => Some(*idx),
|
||||
Variable::LocalScalar(idx, _, _) => Some(*idx),
|
||||
Variable::GlobalOutputArray(idx, _) => Some(*idx),
|
||||
Variable::ConstantScalar(_, _) => None,
|
||||
Variable::Id => None,
|
||||
Variable::Rank => None,
|
||||
}
|
||||
}
|
||||
pub fn item(&self) -> Item {
|
||||
match self {
|
||||
Variable::GlobalInputArray(_, item) => *item,
|
||||
Variable::GlobalOutputArray(_, item) => *item,
|
||||
Variable::GlobalScalar(_, elem) => Item::Scalar(*elem),
|
||||
Variable::Local(_, item, _) => *item,
|
||||
Variable::LocalScalar(_, elem, _) => Item::Scalar(*elem),
|
||||
Variable::ConstantScalar(_, elem) => Item::Scalar(*elem),
|
||||
Variable::Id => Item::Scalar(Elem::UInt),
|
||||
Variable::Rank => Item::Scalar(Elem::UInt),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,11 +1,11 @@
|
|||
use super::{
|
||||
BinaryOperation, ClampOperation, ConditionalAssignOperation, Item, Operation,
|
||||
ReadGlobalOperation, ReadGlobalWithLayoutOperation, UnaryOperation, Variable,
|
||||
Algorithm, BinaryOperator, ClampOperator, ConditionalAssignOperator, Item, Operation, Operator,
|
||||
ReadGlobalAlgo, ReadGlobalWithLayoutAlgo, UnaryOperator, Variable,
|
||||
};
|
||||
|
||||
/// Define a vectorization scheme.
|
||||
#[allow(dead_code)]
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
#[derive(Copy, Clone, Debug, Default)]
|
||||
pub enum Vectorization {
|
||||
/// Use vec4 for vectorization.
|
||||
Vec4,
|
||||
|
@ -14,47 +14,99 @@ pub enum Vectorization {
|
|||
/// Use vec2 for vectorization.
|
||||
Vec2,
|
||||
/// Don't vectorize.
|
||||
#[default]
|
||||
Scalar,
|
||||
}
|
||||
|
||||
impl Operation {
|
||||
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
|
||||
match self {
|
||||
Operation::Add(op) => Operation::Add(op.vectorize(vectorization)),
|
||||
Operation::Sub(op) => Operation::Sub(op.vectorize(vectorization)),
|
||||
Operation::Mul(op) => Operation::Mul(op.vectorize(vectorization)),
|
||||
Operation::Div(op) => Operation::Div(op.vectorize(vectorization)),
|
||||
Operation::Abs(op) => Operation::Abs(op.vectorize(vectorization)),
|
||||
Operation::Exp(op) => Operation::Exp(op.vectorize(vectorization)),
|
||||
Operation::Log(op) => Operation::Log(op.vectorize(vectorization)),
|
||||
Operation::Log1p(op) => Operation::Log1p(op.vectorize(vectorization)),
|
||||
Operation::Cos(op) => Operation::Cos(op.vectorize(vectorization)),
|
||||
Operation::Sin(op) => Operation::Sin(op.vectorize(vectorization)),
|
||||
Operation::Tanh(op) => Operation::Tanh(op.vectorize(vectorization)),
|
||||
Operation::Powf(op) => Operation::Powf(op.vectorize(vectorization)),
|
||||
Operation::Sqrt(op) => Operation::Sqrt(op.vectorize(vectorization)),
|
||||
Operation::Erf(op) => Operation::Erf(op.vectorize(vectorization)),
|
||||
Operation::Recip(op) => Operation::Recip(op.vectorize(vectorization)),
|
||||
Operation::Equal(op) => Operation::Equal(op.vectorize(vectorization)),
|
||||
Operation::Lower(op) => Operation::Lower(op.vectorize(vectorization)),
|
||||
Operation::Clamp(op) => Operation::Clamp(op.vectorize(vectorization)),
|
||||
Operation::Greater(op) => Operation::Greater(op.vectorize(vectorization)),
|
||||
Operation::LowerEqual(op) => Operation::LowerEqual(op.vectorize(vectorization)),
|
||||
Operation::GreaterEqual(op) => Operation::GreaterEqual(op.vectorize(vectorization)),
|
||||
Operation::ConditionalAssign(op) => {
|
||||
Operation::ConditionalAssign(op.vectorize(vectorization))
|
||||
}
|
||||
Operation::AssignGlobal(op) => Operation::AssignGlobal(op.vectorize(vectorization)),
|
||||
Operation::AssignLocal(op) => Operation::AssignLocal(op.vectorize(vectorization)),
|
||||
Operation::ReadGlobal(op) => Operation::ReadGlobal(op.vectorize(vectorization)),
|
||||
Operation::ReadGlobalWithLayout(op) => {
|
||||
Operation::ReadGlobalWithLayout(op.vectorize(vectorization))
|
||||
}
|
||||
Operation::Operator(op) => Operation::Operator(op.vectorize(vectorization)),
|
||||
Operation::Algorithm(op) => Operation::Algorithm(op.vectorize(vectorization)),
|
||||
Operation::Metadata(_) => panic!(
|
||||
"Metadata can't be vectorized, they should only be generated after vectorization."
|
||||
),
|
||||
Operation::Loop(_) => panic!(
|
||||
"Loops can't be vectorized, they should only be generated after vectorization."
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BinaryOperation {
|
||||
impl Algorithm {
|
||||
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
|
||||
match self {
|
||||
Algorithm::ReadGlobalWithLayout(op) => {
|
||||
Algorithm::ReadGlobalWithLayout(op.vectorize(vectorization))
|
||||
}
|
||||
Algorithm::ReadGlobal(op) => Algorithm::ReadGlobal(op.vectorize(vectorization)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReadGlobalWithLayoutAlgo {
|
||||
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
|
||||
Self {
|
||||
global: self.global.vectorize(vectorization),
|
||||
layout: self.layout.vectorize(vectorization),
|
||||
out: self.out.vectorize(vectorization),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ReadGlobalAlgo {
|
||||
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
|
||||
Self {
|
||||
global: self.global.vectorize(vectorization),
|
||||
out: self.out.vectorize(vectorization),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Operator {
|
||||
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
|
||||
match self {
|
||||
Operator::Add(op) => Operator::Add(op.vectorize(vectorization)),
|
||||
Operator::Index(op) => Operator::Index(op.vectorize(vectorization)),
|
||||
Operator::Sub(op) => Operator::Sub(op.vectorize(vectorization)),
|
||||
Operator::Mul(op) => Operator::Mul(op.vectorize(vectorization)),
|
||||
Operator::Div(op) => Operator::Div(op.vectorize(vectorization)),
|
||||
Operator::Abs(op) => Operator::Abs(op.vectorize(vectorization)),
|
||||
Operator::Exp(op) => Operator::Exp(op.vectorize(vectorization)),
|
||||
Operator::Log(op) => Operator::Log(op.vectorize(vectorization)),
|
||||
Operator::Log1p(op) => Operator::Log1p(op.vectorize(vectorization)),
|
||||
Operator::Cos(op) => Operator::Cos(op.vectorize(vectorization)),
|
||||
Operator::Sin(op) => Operator::Sin(op.vectorize(vectorization)),
|
||||
Operator::Tanh(op) => Operator::Tanh(op.vectorize(vectorization)),
|
||||
Operator::Powf(op) => Operator::Powf(op.vectorize(vectorization)),
|
||||
Operator::Sqrt(op) => Operator::Sqrt(op.vectorize(vectorization)),
|
||||
Operator::Erf(op) => Operator::Erf(op.vectorize(vectorization)),
|
||||
Operator::Recip(op) => Operator::Recip(op.vectorize(vectorization)),
|
||||
Operator::Equal(op) => Operator::Equal(op.vectorize(vectorization)),
|
||||
Operator::Lower(op) => Operator::Lower(op.vectorize(vectorization)),
|
||||
Operator::Clamp(op) => Operator::Clamp(op.vectorize(vectorization)),
|
||||
Operator::Greater(op) => Operator::Greater(op.vectorize(vectorization)),
|
||||
Operator::LowerEqual(op) => Operator::LowerEqual(op.vectorize(vectorization)),
|
||||
Operator::GreaterEqual(op) => Operator::GreaterEqual(op.vectorize(vectorization)),
|
||||
Operator::ConditionalAssign(op) => {
|
||||
Operator::ConditionalAssign(op.vectorize(vectorization))
|
||||
}
|
||||
Operator::AssignGlobal(op) => Operator::AssignGlobal(op.vectorize(vectorization)),
|
||||
Operator::AssignLocal(op) => {
|
||||
if let Variable::GlobalScalar(_, _) = op.input {
|
||||
// Assign will not change the type of the output if the input can't be
|
||||
// vectorized.
|
||||
return Operator::AssignLocal(op.clone());
|
||||
}
|
||||
|
||||
Operator::AssignLocal(op.vectorize(vectorization))
|
||||
}
|
||||
Operator::Modulo(op) => Operator::Modulo(op.vectorize(vectorization)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BinaryOperator {
|
||||
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
|
||||
let lhs = self.lhs.vectorize(vectorization);
|
||||
let rhs = self.rhs.vectorize(vectorization);
|
||||
|
@ -64,7 +116,7 @@ impl BinaryOperation {
|
|||
}
|
||||
}
|
||||
|
||||
impl UnaryOperation {
|
||||
impl UnaryOperator {
|
||||
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
|
||||
let input = self.input.vectorize(vectorization);
|
||||
let out = self.out.vectorize(vectorization);
|
||||
|
@ -73,7 +125,7 @@ impl UnaryOperation {
|
|||
}
|
||||
}
|
||||
|
||||
impl ClampOperation {
|
||||
impl ClampOperator {
|
||||
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
|
||||
let input = self.input.vectorize(vectorization);
|
||||
let out = self.out.vectorize(vectorization);
|
||||
|
@ -89,7 +141,7 @@ impl ClampOperation {
|
|||
}
|
||||
}
|
||||
|
||||
impl ConditionalAssignOperation {
|
||||
impl ConditionalAssignOperator {
|
||||
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
|
||||
let cond = self.cond.vectorize(vectorization);
|
||||
let lhs = self.lhs.vectorize(vectorization);
|
||||
|
@ -105,39 +157,23 @@ impl ConditionalAssignOperation {
|
|||
}
|
||||
}
|
||||
|
||||
impl ReadGlobalOperation {
|
||||
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
|
||||
let variable = self.variable.vectorize(vectorization);
|
||||
|
||||
Self { variable }
|
||||
}
|
||||
}
|
||||
|
||||
impl ReadGlobalWithLayoutOperation {
|
||||
pub fn vectorize(&self, vectorization: Vectorization) -> Self {
|
||||
let variable = self.variable.vectorize(vectorization);
|
||||
let tensor_read_pos = self.tensor_read_pos;
|
||||
let tensor_layout_pos = self.tensor_layout_pos;
|
||||
|
||||
Self {
|
||||
variable,
|
||||
tensor_read_pos,
|
||||
tensor_layout_pos,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Variable {
|
||||
pub fn vectorize(&self, vectorize: Vectorization) -> Self {
|
||||
match self {
|
||||
Variable::Input(index, item) => Variable::Input(*index, item.vectorize(vectorize)),
|
||||
Variable::Local(index, item) => Variable::Local(*index, item.vectorize(vectorize)),
|
||||
Variable::Output(index, item) => Variable::Output(*index, item.vectorize(vectorize)),
|
||||
Variable::Constant(index, item) => {
|
||||
Variable::Constant(*index, item.vectorize(vectorize))
|
||||
Variable::GlobalInputArray(index, item) => {
|
||||
Variable::GlobalInputArray(*index, item.vectorize(vectorize))
|
||||
}
|
||||
Variable::Scalar(index, item) => Variable::Scalar(*index, *item), // Don't vectorize
|
||||
// scalar variables.
|
||||
Variable::Local(index, item, name) => {
|
||||
Variable::Local(*index, item.vectorize(vectorize), *name)
|
||||
}
|
||||
Variable::GlobalOutputArray(index, item) => {
|
||||
Variable::GlobalOutputArray(*index, item.vectorize(vectorize))
|
||||
}
|
||||
Variable::ConstantScalar(_, _) => *self,
|
||||
Variable::GlobalScalar(_, _) => *self,
|
||||
Variable::Id => *self,
|
||||
Variable::Rank => *self,
|
||||
Variable::LocalScalar(_, _, _) => *self,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,11 +3,22 @@ use std::fmt::Display;
|
|||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Variable {
|
||||
Input(u16, Item),
|
||||
Scalar(u16, Item, gpu::Elem),
|
||||
Local(u16, Item),
|
||||
Output(u16, Item),
|
||||
Constant(f64, Item),
|
||||
GlobalInputArray(u16, Item),
|
||||
GlobalOutputArray(u16, Item),
|
||||
GlobalScalar(u16, Elem, gpu::Elem),
|
||||
ConstantScalar(f64, Elem),
|
||||
Local {
|
||||
index: u16,
|
||||
item: Item,
|
||||
scope_depth: u8,
|
||||
},
|
||||
LocalScalar {
|
||||
index: u16,
|
||||
elem: Elem,
|
||||
scope_depth: u8,
|
||||
},
|
||||
Id,
|
||||
Rank,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
|
||||
|
@ -33,6 +44,26 @@ pub struct IndexedVariable {
|
|||
}
|
||||
|
||||
impl Variable {
|
||||
pub fn is_always_scalar(&self) -> bool {
|
||||
match self {
|
||||
Variable::GlobalScalar(_, _, _) => true,
|
||||
Variable::ConstantScalar(_, _) => true,
|
||||
Variable::LocalScalar {
|
||||
index: _,
|
||||
elem: _,
|
||||
scope_depth: _,
|
||||
} => true,
|
||||
Variable::Id => true,
|
||||
Variable::Rank => true,
|
||||
Variable::GlobalInputArray(_, _) => false,
|
||||
Variable::GlobalOutputArray(_, _) => false,
|
||||
Variable::Local {
|
||||
index: _,
|
||||
item: _,
|
||||
scope_depth: _,
|
||||
} => false,
|
||||
}
|
||||
}
|
||||
pub fn index(&self, index: usize) -> IndexedVariable {
|
||||
IndexedVariable {
|
||||
var: self.clone(),
|
||||
|
@ -40,15 +71,29 @@ impl Variable {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn item(&self) -> &Item {
|
||||
pub fn item(&self) -> Item {
|
||||
match self {
|
||||
Self::Input(_, e) => e,
|
||||
Self::Scalar(_, e, _) => e,
|
||||
Self::Local(_, e) => e,
|
||||
Self::Output(_, e) => e,
|
||||
Self::Constant(_, e) => e,
|
||||
Self::GlobalInputArray(_, e) => *e,
|
||||
Self::GlobalOutputArray(_, e) => *e,
|
||||
Self::Local {
|
||||
index: _,
|
||||
item,
|
||||
scope_depth: _,
|
||||
} => *item,
|
||||
Self::ConstantScalar(_, e) => Item::Scalar(*e),
|
||||
Self::GlobalScalar(_, e, _) => Item::Scalar(*e),
|
||||
Self::Id => Item::Scalar(Elem::U32),
|
||||
Self::Rank => Item::Scalar(Elem::U32),
|
||||
Self::LocalScalar {
|
||||
index: _,
|
||||
elem,
|
||||
scope_depth: _,
|
||||
} => Item::Scalar(*elem),
|
||||
}
|
||||
}
|
||||
pub fn elem(&self) -> Elem {
|
||||
*self.item().elem()
|
||||
}
|
||||
}
|
||||
|
||||
impl Item {
|
||||
|
@ -98,39 +143,24 @@ impl Display for Item {
|
|||
impl Display for Variable {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Variable::Input(number, _) => f.write_fmt(format_args!("input_{number}")),
|
||||
Variable::Local(number, _) => f.write_fmt(format_args!("local_{number}")),
|
||||
Variable::Output(number, _) => f.write_fmt(format_args!("output_{number}")),
|
||||
Variable::Scalar(number, _, elem) => {
|
||||
Variable::GlobalInputArray(number, _) => f.write_fmt(format_args!("input_{number}")),
|
||||
Variable::LocalScalar {
|
||||
index,
|
||||
elem: _,
|
||||
scope_depth,
|
||||
} => f.write_fmt(format_args!("s_{scope_depth}_{index}")),
|
||||
Variable::Local {
|
||||
index,
|
||||
item: _,
|
||||
scope_depth,
|
||||
} => f.write_fmt(format_args!("l_{scope_depth}_{index}")),
|
||||
Variable::GlobalOutputArray(number, _) => f.write_fmt(format_args!("output_{number}")),
|
||||
Variable::GlobalScalar(number, _, elem) => {
|
||||
f.write_fmt(format_args!("scalars_{elem}[{number}]"))
|
||||
}
|
||||
Variable::Constant(number, item) => match item {
|
||||
Item::Vec4(elem) => f.write_fmt(format_args!(
|
||||
"
|
||||
vec4(
|
||||
{elem}({number}),
|
||||
{elem}({number}),
|
||||
{elem}({number}),
|
||||
{elem}({number}),
|
||||
)"
|
||||
)),
|
||||
Item::Vec3(elem) => f.write_fmt(format_args!(
|
||||
"
|
||||
vec3(
|
||||
{elem}({number}),
|
||||
{elem}({number}),
|
||||
{elem}({number}),
|
||||
)"
|
||||
)),
|
||||
Item::Vec2(elem) => f.write_fmt(format_args!(
|
||||
"
|
||||
vec2(
|
||||
{elem}({number}),
|
||||
{elem}({number}),
|
||||
)"
|
||||
)),
|
||||
Item::Scalar(elem) => f.write_fmt(format_args!("{elem}({number})")),
|
||||
},
|
||||
Variable::ConstantScalar(number, elem) => f.write_fmt(format_args!("{elem}({number})")),
|
||||
Variable::Id => f.write_str("id"),
|
||||
Variable::Rank => f.write_str("rank"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -149,8 +179,8 @@ impl Display for IndexedVariable {
|
|||
let index = self.index;
|
||||
|
||||
match self.var {
|
||||
Variable::Scalar(_, _, _) => f.write_fmt(format_args!("{var}")),
|
||||
_ => match should_index(item) {
|
||||
Variable::GlobalScalar(_, _, _) => f.write_fmt(format_args!("{var}")),
|
||||
_ => match should_index(&item) {
|
||||
true => f.write_fmt(format_args!("{var}[{index}]")),
|
||||
false => f.write_fmt(format_args!("{var}")),
|
||||
},
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use super::Operation;
|
||||
use super::Instruction;
|
||||
use std::fmt::Display;
|
||||
|
||||
/// A body is composed of a list of [operations](Operation).
|
||||
|
@ -6,11 +6,11 @@ use std::fmt::Display;
|
|||
/// Note that the body assumes that the kernel will run on a 2D grid defined by the workgroup size
|
||||
/// X and Y, but with Z=1.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Body {
|
||||
pub operators: Vec<Operation>,
|
||||
pub struct Scope {
|
||||
pub operators: Vec<Instruction>,
|
||||
}
|
||||
|
||||
impl Display for Body {
|
||||
impl Display for Scope {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_str(
|
||||
"let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;\n",
|
||||
|
@ -19,7 +19,6 @@ impl Display for Body {
|
|||
|
||||
for ops in self.operators.iter() {
|
||||
f.write_fmt(format_args!("{ops}"))?;
|
||||
f.write_str("\n")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
|
|
@ -11,6 +11,8 @@ use std::marker::PhantomData;
|
|||
/// Wgsl Compiler.
|
||||
#[derive(Clone)]
|
||||
pub struct Compiler<F: FloatElement, I: IntElement> {
|
||||
num_inputs: usize,
|
||||
num_outputs: usize,
|
||||
_float: PhantomData<F>,
|
||||
_int: PhantomData<I>,
|
||||
}
|
||||
|
@ -24,6 +26,8 @@ impl<F: FloatElement, I: IntElement> core::fmt::Debug for Compiler<F, I> {
|
|||
impl<F: FloatElement, I: IntElement> Default for Compiler<F, I> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
num_inputs: 0,
|
||||
num_outputs: 0,
|
||||
_float: PhantomData,
|
||||
_int: PhantomData,
|
||||
}
|
||||
|
@ -37,7 +41,8 @@ impl<F: FloatElement, I: IntElement> compiler::Compiler for Compiler<F, I> {
|
|||
type FullPrecisionCompiler = Compiler<f32, i32>;
|
||||
|
||||
fn compile(shader: gpu::ComputeShader) -> Self::Representation {
|
||||
Self::compile_shader(shader)
|
||||
let mut compiler = Self::default();
|
||||
compiler.compile_shader(shader)
|
||||
}
|
||||
|
||||
fn elem_size(elem: gpu::Elem) -> usize {
|
||||
|
@ -46,6 +51,37 @@ impl<F: FloatElement, I: IntElement> compiler::Compiler for Compiler<F, I> {
|
|||
}
|
||||
|
||||
impl<F: FloatElement, I: IntElement> Compiler<F, I> {
|
||||
fn compile_shader(&mut self, mut value: gpu::ComputeShader) -> wgsl::ComputeShader {
|
||||
self.num_inputs = value.inputs.len();
|
||||
self.num_outputs = value.outputs.len();
|
||||
|
||||
let body = self.compile_scope(&mut value.body);
|
||||
let extensions = register_extensions(&body);
|
||||
|
||||
wgsl::ComputeShader {
|
||||
inputs: value
|
||||
.inputs
|
||||
.into_iter()
|
||||
.map(Self::compile_binding)
|
||||
.collect(),
|
||||
outputs: value
|
||||
.outputs
|
||||
.into_iter()
|
||||
.map(Self::compile_binding)
|
||||
.collect(),
|
||||
named: value
|
||||
.named
|
||||
.into_iter()
|
||||
.map(|(name, binding)| (name, Self::compile_binding(binding)))
|
||||
.collect(),
|
||||
workgroup_size: value.workgroup_size,
|
||||
global_invocation_id: value.global_invocation_id,
|
||||
num_workgroups: value.num_workgroups,
|
||||
body,
|
||||
extensions,
|
||||
}
|
||||
}
|
||||
|
||||
fn compile_item(item: gpu::Item) -> Item {
|
||||
match item {
|
||||
gpu::Item::Vec4(elem) => wgsl::Item::Vec4(Self::compile_elem(elem)),
|
||||
|
@ -66,155 +102,252 @@ impl<F: FloatElement, I: IntElement> Compiler<F, I> {
|
|||
|
||||
fn compile_variable(value: gpu::Variable) -> wgsl::Variable {
|
||||
match value {
|
||||
gpu::Variable::Input(index, item) => {
|
||||
wgsl::Variable::Input(index, Self::compile_item(item))
|
||||
gpu::Variable::GlobalInputArray(index, item) => {
|
||||
wgsl::Variable::GlobalInputArray(index, Self::compile_item(item))
|
||||
}
|
||||
gpu::Variable::Scalar(index, item) => {
|
||||
let elem = item.elem();
|
||||
wgsl::Variable::Scalar(index, Self::compile_item(item), elem)
|
||||
gpu::Variable::GlobalScalar(index, elem) => {
|
||||
wgsl::Variable::GlobalScalar(index, Self::compile_elem(elem), elem)
|
||||
}
|
||||
gpu::Variable::Local(index, item) => {
|
||||
wgsl::Variable::Local(index, Self::compile_item(item))
|
||||
gpu::Variable::Local(index, item, scope_depth) => wgsl::Variable::Local {
|
||||
index,
|
||||
item: Self::compile_item(item),
|
||||
scope_depth,
|
||||
},
|
||||
gpu::Variable::LocalScalar(index, elem, scope_depth) => wgsl::Variable::LocalScalar {
|
||||
index,
|
||||
elem: Self::compile_elem(elem),
|
||||
scope_depth,
|
||||
},
|
||||
gpu::Variable::GlobalOutputArray(index, item) => {
|
||||
wgsl::Variable::GlobalOutputArray(index, Self::compile_item(item))
|
||||
}
|
||||
gpu::Variable::Output(index, item) => {
|
||||
wgsl::Variable::Output(index, Self::compile_item(item))
|
||||
}
|
||||
gpu::Variable::Constant(index, item) => {
|
||||
wgsl::Variable::Constant(index, Self::compile_item(item))
|
||||
gpu::Variable::ConstantScalar(index, elem) => {
|
||||
wgsl::Variable::ConstantScalar(index, Self::compile_elem(elem))
|
||||
}
|
||||
gpu::Variable::Id => wgsl::Variable::Id,
|
||||
gpu::Variable::Rank => wgsl::Variable::Rank,
|
||||
}
|
||||
}
|
||||
|
||||
fn compile_body(value: gpu::Body) -> wgsl::Body {
|
||||
wgsl::Body {
|
||||
operators: value
|
||||
.operators
|
||||
fn compile_scope(&self, value: &mut gpu::Scope) -> wgsl::Scope {
|
||||
let mut operations = Vec::new();
|
||||
let processing = value.process();
|
||||
|
||||
for var in processing.variables {
|
||||
operations.push(wgsl::Instruction::DeclareVariable {
|
||||
var: Self::compile_variable(var),
|
||||
});
|
||||
}
|
||||
|
||||
processing
|
||||
.operations
|
||||
.into_iter()
|
||||
.map(Self::compile_operation)
|
||||
.collect(),
|
||||
.for_each(|op| self.compile_operation(&mut operations, op, value));
|
||||
|
||||
wgsl::Scope {
|
||||
operators: operations,
|
||||
}
|
||||
}
|
||||
|
||||
fn compile_operation(value: gpu::Operation) -> wgsl::Operation {
|
||||
fn compile_operation(
|
||||
&self,
|
||||
instructions: &mut Vec<wgsl::Instruction>,
|
||||
operation: gpu::Operation,
|
||||
scope: &mut gpu::Scope,
|
||||
) {
|
||||
match operation {
|
||||
gpu::Operation::Operator(op) => instructions.push(self.compile_instruction(op)),
|
||||
gpu::Operation::Algorithm(algo) => self.compile_algorithm(instructions, algo, scope),
|
||||
gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op)),
|
||||
gpu::Operation::Loop(val) => instructions.push(self.compile_loop(val)),
|
||||
}
|
||||
}
|
||||
|
||||
fn compile_algorithm(
|
||||
&self,
|
||||
instructions: &mut Vec<wgsl::Instruction>,
|
||||
algo: gpu::Algorithm,
|
||||
scope: &mut gpu::Scope,
|
||||
) {
|
||||
let mut compile = |scope: &mut gpu::Scope| {
|
||||
let compiled = self.compile_scope(scope).operators;
|
||||
instructions.extend(compiled);
|
||||
};
|
||||
|
||||
match algo {
|
||||
gpu::Algorithm::ReadGlobalWithLayout(algo) => {
|
||||
algo.expand(scope);
|
||||
compile(scope);
|
||||
}
|
||||
gpu::Algorithm::ReadGlobal(algo) => {
|
||||
algo.expand(scope);
|
||||
compile(scope);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn compile_loop(&self, loop_val: gpu::Loop) -> wgsl::Instruction {
|
||||
match loop_val {
|
||||
gpu::Loop::Range(mut range_loop) => wgsl::Instruction::RangeLoop {
|
||||
i: Self::compile_variable(range_loop.i),
|
||||
start: Self::compile_variable(range_loop.start),
|
||||
end: Self::compile_variable(range_loop.end),
|
||||
instructions: self.compile_scope(&mut range_loop.scope).operators,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn compile_metadata(&self, metadata: gpu::Metadata) -> wgsl::Instruction {
|
||||
match metadata {
|
||||
gpu::Metadata::Stride { dim, var, out } => {
|
||||
let position = match var {
|
||||
gpu::Variable::GlobalInputArray(idx, _) => idx as usize,
|
||||
gpu::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
|
||||
_ => panic!("Only Input and Output have a stride, got: {:?}", var),
|
||||
};
|
||||
wgsl::Instruction::Stride {
|
||||
dim: Self::compile_variable(dim),
|
||||
position,
|
||||
out: Self::compile_variable(out),
|
||||
}
|
||||
}
|
||||
gpu::Metadata::Shape { dim, var, out } => {
|
||||
let position = match var {
|
||||
gpu::Variable::GlobalInputArray(idx, _) => idx as usize,
|
||||
gpu::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
|
||||
_ => panic!("Only Input and Output have a shape, got {:?}", var),
|
||||
};
|
||||
wgsl::Instruction::Shape {
|
||||
dim: Self::compile_variable(dim),
|
||||
position,
|
||||
out: Self::compile_variable(out),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn compile_instruction(&self, value: gpu::Operator) -> wgsl::Instruction {
|
||||
match value {
|
||||
gpu::Operation::Add(op) => wgsl::Operation::Add {
|
||||
gpu::Operator::Add(op) => wgsl::Instruction::Add {
|
||||
lhs: Self::compile_variable(op.lhs),
|
||||
rhs: Self::compile_variable(op.rhs),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::Sub(op) => wgsl::Operation::Sub {
|
||||
gpu::Operator::Index(op) => wgsl::Instruction::Index {
|
||||
lhs: Self::compile_variable(op.lhs),
|
||||
rhs: Self::compile_variable(op.rhs),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::Mul(op) => wgsl::Operation::Mul {
|
||||
gpu::Operator::Modulo(op) => wgsl::Instruction::Modulo {
|
||||
lhs: Self::compile_variable(op.lhs),
|
||||
rhs: Self::compile_variable(op.rhs),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::Div(op) => wgsl::Operation::Div {
|
||||
gpu::Operator::Sub(op) => wgsl::Instruction::Sub {
|
||||
lhs: Self::compile_variable(op.lhs),
|
||||
rhs: Self::compile_variable(op.rhs),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::Abs(op) => wgsl::Operation::Abs {
|
||||
input: Self::compile_variable(op.input),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::Exp(op) => wgsl::Operation::Exp {
|
||||
input: Self::compile_variable(op.input),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::Log(op) => wgsl::Operation::Log {
|
||||
input: Self::compile_variable(op.input),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::Log1p(op) => wgsl::Operation::Log1p {
|
||||
input: Self::compile_variable(op.input),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::Cos(op) => wgsl::Operation::Cos {
|
||||
input: Self::compile_variable(op.input),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::Sin(op) => wgsl::Operation::Sin {
|
||||
input: Self::compile_variable(op.input),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::Tanh(op) => wgsl::Operation::Tanh {
|
||||
input: Self::compile_variable(op.input),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::Powf(op) => wgsl::Operation::Powf {
|
||||
gpu::Operator::Mul(op) => wgsl::Instruction::Mul {
|
||||
lhs: Self::compile_variable(op.lhs),
|
||||
rhs: Self::compile_variable(op.rhs),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::Sqrt(op) => wgsl::Operation::Sqrt {
|
||||
input: Self::compile_variable(op.input),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::Erf(op) => wgsl::Operation::Erf {
|
||||
input: Self::compile_variable(op.input),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::Recip(op) => wgsl::Operation::Recip {
|
||||
input: Self::compile_variable(op.input),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::Equal(op) => wgsl::Operation::Equal {
|
||||
gpu::Operator::Div(op) => wgsl::Instruction::Div {
|
||||
lhs: Self::compile_variable(op.lhs),
|
||||
rhs: Self::compile_variable(op.rhs),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::Lower(op) => wgsl::Operation::Lower {
|
||||
gpu::Operator::Abs(op) => wgsl::Instruction::Abs {
|
||||
input: Self::compile_variable(op.input),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::Exp(op) => wgsl::Instruction::Exp {
|
||||
input: Self::compile_variable(op.input),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::Log(op) => wgsl::Instruction::Log {
|
||||
input: Self::compile_variable(op.input),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::Log1p(op) => wgsl::Instruction::Log1p {
|
||||
input: Self::compile_variable(op.input),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::Cos(op) => wgsl::Instruction::Cos {
|
||||
input: Self::compile_variable(op.input),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::Sin(op) => wgsl::Instruction::Sin {
|
||||
input: Self::compile_variable(op.input),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::Tanh(op) => wgsl::Instruction::Tanh {
|
||||
input: Self::compile_variable(op.input),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::Powf(op) => wgsl::Instruction::Powf {
|
||||
lhs: Self::compile_variable(op.lhs),
|
||||
rhs: Self::compile_variable(op.rhs),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::Clamp(op) => wgsl::Operation::Clamp {
|
||||
gpu::Operator::Sqrt(op) => wgsl::Instruction::Sqrt {
|
||||
input: Self::compile_variable(op.input),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::Erf(op) => wgsl::Instruction::Erf {
|
||||
input: Self::compile_variable(op.input),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::Recip(op) => wgsl::Instruction::Recip {
|
||||
input: Self::compile_variable(op.input),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::Equal(op) => wgsl::Instruction::Equal {
|
||||
lhs: Self::compile_variable(op.lhs),
|
||||
rhs: Self::compile_variable(op.rhs),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::Lower(op) => wgsl::Instruction::Lower {
|
||||
lhs: Self::compile_variable(op.lhs),
|
||||
rhs: Self::compile_variable(op.rhs),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::Clamp(op) => wgsl::Instruction::Clamp {
|
||||
input: Self::compile_variable(op.input),
|
||||
min_value: Self::compile_variable(op.min_value),
|
||||
max_value: Self::compile_variable(op.max_value),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::Greater(op) => wgsl::Operation::Greater {
|
||||
gpu::Operator::Greater(op) => wgsl::Instruction::Greater {
|
||||
lhs: Self::compile_variable(op.lhs),
|
||||
rhs: Self::compile_variable(op.rhs),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::LowerEqual(op) => wgsl::Operation::LowerEqual {
|
||||
gpu::Operator::LowerEqual(op) => wgsl::Instruction::LowerEqual {
|
||||
lhs: Self::compile_variable(op.lhs),
|
||||
rhs: Self::compile_variable(op.rhs),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::GreaterEqual(op) => wgsl::Operation::GreaterEqual {
|
||||
gpu::Operator::GreaterEqual(op) => wgsl::Instruction::GreaterEqual {
|
||||
lhs: Self::compile_variable(op.lhs),
|
||||
rhs: Self::compile_variable(op.rhs),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::ConditionalAssign(op) => wgsl::Operation::ConditionalAssign {
|
||||
gpu::Operator::ConditionalAssign(op) => wgsl::Instruction::ConditionalAssign {
|
||||
cond: Self::compile_variable(op.cond),
|
||||
lhs: Self::compile_variable(op.lhs),
|
||||
rhs: Self::compile_variable(op.rhs),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::AssignGlobal(op) => wgsl::Operation::AssignGlobal {
|
||||
gpu::Operator::AssignGlobal(op) => wgsl::Instruction::AssignGlobal {
|
||||
input: Self::compile_variable(op.input),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::AssignLocal(op) => wgsl::Operation::AssignLocal {
|
||||
gpu::Operator::AssignLocal(op) => wgsl::Instruction::AssignLocal {
|
||||
input: Self::compile_variable(op.input),
|
||||
out: Self::compile_variable(op.out),
|
||||
},
|
||||
gpu::Operation::ReadGlobal(op) => wgsl::Operation::ReadGlobal {
|
||||
variable: Self::compile_variable(op.variable),
|
||||
},
|
||||
gpu::Operation::ReadGlobalWithLayout(op) => wgsl::Operation::ReadGlobalWithLayout {
|
||||
variable: Self::compile_variable(op.variable),
|
||||
tensor_read_pos: op.tensor_read_pos,
|
||||
tensor_layout_pos: op.tensor_layout_pos,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -240,37 +373,9 @@ impl<F: FloatElement, I: IntElement> Compiler<F, I> {
|
|||
size: value.size,
|
||||
}
|
||||
}
|
||||
|
||||
fn compile_shader(value: gpu::ComputeShader) -> wgsl::ComputeShader {
|
||||
let body = Self::compile_body(value.body);
|
||||
let extensions = register_extensions(&body);
|
||||
|
||||
wgsl::ComputeShader {
|
||||
inputs: value
|
||||
.inputs
|
||||
.into_iter()
|
||||
.map(Self::compile_binding)
|
||||
.collect(),
|
||||
outputs: value
|
||||
.outputs
|
||||
.into_iter()
|
||||
.map(Self::compile_binding)
|
||||
.collect(),
|
||||
named: value
|
||||
.named
|
||||
.into_iter()
|
||||
.map(|(name, binding)| (name, Self::compile_binding(binding)))
|
||||
.collect(),
|
||||
workgroup_size: value.workgroup_size,
|
||||
global_invocation_id: value.global_invocation_id,
|
||||
num_workgroups: value.num_workgroups,
|
||||
body,
|
||||
extensions,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn register_extensions(body: &wgsl::Body) -> Vec<wgsl::Extension> {
|
||||
fn register_extensions(body: &wgsl::Scope) -> Vec<wgsl::Extension> {
|
||||
let mut extensions = Vec::new();
|
||||
|
||||
let mut register_extension = |extension: wgsl::Extension| {
|
||||
|
@ -282,20 +387,21 @@ fn register_extensions(body: &wgsl::Body) -> Vec<wgsl::Extension> {
|
|||
// Since not all operators are native to WGSL, we need to add the custom ones.
|
||||
for op in body.operators.iter() {
|
||||
match op {
|
||||
wgsl::Operation::Powf { lhs: _, rhs, out } => match rhs {
|
||||
wgsl::Variable::Scalar(_, _, _) => {
|
||||
register_extension(wgsl::Extension::PowfScalar(*out.item()));
|
||||
wgsl::Instruction::Powf { lhs: _, rhs, out } => {
|
||||
register_extension(wgsl::Extension::PowfPrimitive(out.item()));
|
||||
|
||||
if rhs.is_always_scalar() {
|
||||
register_extension(wgsl::Extension::PowfScalar(out.item()));
|
||||
} else {
|
||||
register_extension(wgsl::Extension::Powf(out.item()));
|
||||
}
|
||||
_ => {
|
||||
register_extension(wgsl::Extension::Powf(*out.item()));
|
||||
}
|
||||
},
|
||||
wgsl::Operation::Erf { input, out: _ } => {
|
||||
register_extension(wgsl::Extension::Erf(*input.item()));
|
||||
wgsl::Instruction::Erf { input, out: _ } => {
|
||||
register_extension(wgsl::Extension::Erf(input.item()));
|
||||
}
|
||||
#[cfg(target_os = "macos")]
|
||||
wgsl::Operation::Tanh { input, out: _ } => {
|
||||
register_extension(wgsl::Extension::SafeTanh(*input.item()))
|
||||
wgsl::Instruction::Tanh { input, out: _ } => {
|
||||
register_extension(wgsl::Extension::SafeTanh(input.item()))
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ use std::fmt::Display;
|
|||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub enum Extension {
|
||||
PowfScalar(Item),
|
||||
PowfPrimitive(Item),
|
||||
Powf(Item),
|
||||
Erf(Item),
|
||||
#[cfg(target_os = "macos")]
|
||||
|
@ -15,6 +16,7 @@ impl Display for Extension {
|
|||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Extension::PowfScalar(elem) => format_powf_scalar(f, elem),
|
||||
Extension::PowfPrimitive(elem) => format_powf_primitive(f, elem),
|
||||
Extension::Powf(elem) => format_powf(f, elem),
|
||||
Extension::Erf(elem) => format_erf(f, elem),
|
||||
#[cfg(target_os = "macos")]
|
||||
|
@ -24,12 +26,10 @@ impl Display for Extension {
|
|||
}
|
||||
|
||||
fn format_powf_scalar(f: &mut core::fmt::Formatter<'_>, item: &Item) -> core::fmt::Result {
|
||||
base_powf_fmt(f, item)?;
|
||||
|
||||
match item {
|
||||
Item::Vec4(elem) => f.write_fmt(format_args!(
|
||||
"
|
||||
fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
|
||||
fn powf_scalar(lhs: {item}, rhs: {elem}) -> {item} {{
|
||||
return vec4(
|
||||
powf_primitive(lhs[0], rhs),
|
||||
powf_primitive(lhs[1], rhs),
|
||||
|
@ -41,7 +41,7 @@ fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
|
|||
)),
|
||||
Item::Vec3(elem) => f.write_fmt(format_args!(
|
||||
"
|
||||
fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
|
||||
fn powf_scalar(lhs: {item}, rhs: {elem}) -> {item} {{
|
||||
return vec3(
|
||||
powf_primitive(lhs[0], rhs),
|
||||
powf_primitive(lhs[1], rhs),
|
||||
|
@ -52,7 +52,7 @@ fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
|
|||
)),
|
||||
Item::Vec2(elem) => f.write_fmt(format_args!(
|
||||
"
|
||||
fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
|
||||
fn powf_scalar(lhs: {item}, rhs: {elem}) -> {item} {{
|
||||
return vec2(
|
||||
powf_primitive(lhs[0], rhs),
|
||||
powf_primitive(lhs[1], rhs),
|
||||
|
@ -62,7 +62,7 @@ fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
|
|||
)),
|
||||
Item::Scalar(elem) => f.write_fmt(format_args!(
|
||||
"
|
||||
fn powf(lhs: {elem}, rhs: {elem}) -> {elem} {{
|
||||
fn powf_scalar(lhs: {elem}, rhs: {elem}) -> {elem} {{
|
||||
return powf_primitive(lhs, rhs);
|
||||
}}
|
||||
"
|
||||
|
@ -70,7 +70,10 @@ fn powf(lhs: {elem}, rhs: {elem}) -> {elem} {{
|
|||
}
|
||||
}
|
||||
|
||||
fn base_powf_fmt(f: &mut std::fmt::Formatter<'_>, item: &Item) -> Result<(), std::fmt::Error> {
|
||||
fn format_powf_primitive(
|
||||
f: &mut std::fmt::Formatter<'_>,
|
||||
item: &Item,
|
||||
) -> Result<(), std::fmt::Error> {
|
||||
let elem = item.elem();
|
||||
f.write_fmt(format_args!(
|
||||
"
|
||||
|
@ -96,12 +99,10 @@ fn powf_primitive(lhs: {elem}, rhs: {elem}) -> {elem} {{
|
|||
}
|
||||
|
||||
fn format_powf(f: &mut core::fmt::Formatter<'_>, item: &Item) -> core::fmt::Result {
|
||||
base_powf_fmt(f, item)?;
|
||||
|
||||
match item {
|
||||
Item::Vec4(elem) => f.write_fmt(format_args!(
|
||||
Item::Vec4(_) => f.write_fmt(format_args!(
|
||||
"
|
||||
fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
|
||||
fn powf(lhs: {item}, rhs: {item}) -> {item} {{
|
||||
return vec4(
|
||||
powf_primitive(lhs[0], rhs[0]),
|
||||
powf_primitive(lhs[1], rhs[1]),
|
||||
|
@ -111,9 +112,9 @@ fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
|
|||
}}
|
||||
"
|
||||
)),
|
||||
Item::Vec3(elem) => f.write_fmt(format_args!(
|
||||
Item::Vec3(_) => f.write_fmt(format_args!(
|
||||
"
|
||||
fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
|
||||
fn powf(lhs: {item}, rhs: {item}) -> {item} {{
|
||||
return vec3(
|
||||
powf_primitive(lhs[0], rhs[0]),
|
||||
powf_primitive(lhs[1], rhs[1]),
|
||||
|
@ -122,9 +123,9 @@ fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
|
|||
}}
|
||||
"
|
||||
)),
|
||||
Item::Vec2(elem) => f.write_fmt(format_args!(
|
||||
Item::Vec2(_) => f.write_fmt(format_args!(
|
||||
"
|
||||
fn powf(lhs: {item}, rhs: {elem}) -> {item} {{
|
||||
fn powf(lhs: {item}, rhs: {item}) -> {item} {{
|
||||
return vec2(
|
||||
powf_primitive(lhs[0], rhs[0]),
|
||||
powf_primitive(lhs[1], rhs[1]),
|
||||
|
|
|
@ -1,15 +1,28 @@
|
|||
use super::base::{Item, Variable};
|
||||
use std::fmt::Display;
|
||||
|
||||
/// All operations that can be used in a WGSL compute shader.
|
||||
/// All instructions that can be used in a WGSL compute shader.
|
||||
#[derive(Debug, Clone)]
|
||||
#[allow(dead_code)] // Some variants might not be used with different flags
|
||||
pub enum Operation {
|
||||
pub enum Instruction {
|
||||
DeclareVariable {
|
||||
var: Variable,
|
||||
},
|
||||
Add {
|
||||
lhs: Variable,
|
||||
rhs: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
Index {
|
||||
lhs: Variable,
|
||||
rhs: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
Modulo {
|
||||
lhs: Variable,
|
||||
rhs: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
Sub {
|
||||
lhs: Variable,
|
||||
rhs: Variable,
|
||||
|
@ -115,72 +128,99 @@ pub enum Operation {
|
|||
input: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
ReadGlobal {
|
||||
variable: Variable,
|
||||
Stride {
|
||||
dim: Variable,
|
||||
position: usize,
|
||||
out: Variable,
|
||||
},
|
||||
/// Read the tensor in a way to be compatible with another tensor layout.
|
||||
ReadGlobalWithLayout {
|
||||
variable: Variable,
|
||||
tensor_read_pos: usize,
|
||||
tensor_layout_pos: usize,
|
||||
Shape {
|
||||
dim: Variable,
|
||||
position: usize,
|
||||
out: Variable,
|
||||
},
|
||||
RangeLoop {
|
||||
i: Variable,
|
||||
start: Variable,
|
||||
end: Variable,
|
||||
instructions: Vec<Instruction>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Display for Operation {
|
||||
impl Display for Instruction {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Operation::Add { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("let {out} = {lhs} + {rhs};"))
|
||||
Instruction::DeclareVariable { var } => {
|
||||
let item = var.item();
|
||||
f.write_fmt(format_args!("var {var}: {item};\n"))
|
||||
}
|
||||
Operation::Sub { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("let {out} = {lhs} - {rhs};"))
|
||||
Instruction::Add { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("{out} = {lhs} + {rhs};\n"))
|
||||
}
|
||||
Operation::Mul { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("let {out} = {lhs} * {rhs};"))
|
||||
Instruction::Index { lhs, rhs, out } => {
|
||||
let item = out.item();
|
||||
let lhs = match lhs {
|
||||
Variable::GlobalInputArray(index, _) => format!("input_{index}_global"),
|
||||
Variable::GlobalOutputArray(index, _) => format!("output_{index}_global"),
|
||||
_ => format!("{lhs}"),
|
||||
};
|
||||
f.write_fmt(format_args!("{out} = {item}({lhs}[{rhs}]);\n"))
|
||||
}
|
||||
Operation::Div { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("let {out} = {lhs} / {rhs};"))
|
||||
Instruction::Modulo { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("{out} = {lhs} % {rhs};\n"))
|
||||
}
|
||||
Operation::Abs { input, out } => f.write_fmt(format_args!("let {out} = abs({input});")),
|
||||
Operation::Exp { input, out } => f.write_fmt(format_args!("let {out} = exp({input});")),
|
||||
Operation::Log { input, out } => f.write_fmt(format_args!("let {out} = log({input});")),
|
||||
Operation::Clamp {
|
||||
Instruction::Sub { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("{out} = {lhs} - {rhs};\n"))
|
||||
}
|
||||
Instruction::Mul { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("{out} = {lhs} * {rhs};\n"))
|
||||
}
|
||||
Instruction::Div { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("{out} = {lhs} / {rhs};\n"))
|
||||
}
|
||||
Instruction::Abs { input, out } => f.write_fmt(format_args!("{out} = abs({input});\n")),
|
||||
Instruction::Exp { input, out } => f.write_fmt(format_args!("{out} = exp({input});\n")),
|
||||
Instruction::Log { input, out } => f.write_fmt(format_args!("{out} = log({input});\n")),
|
||||
Instruction::Clamp {
|
||||
input,
|
||||
min_value,
|
||||
max_value,
|
||||
out,
|
||||
} => f.write_fmt(format_args!(
|
||||
"let {out} = clamp({input}, {min_value}, {max_value});"
|
||||
"{out} = clamp({input}, {min_value}, {max_value});\n"
|
||||
)),
|
||||
Operation::Powf { lhs, rhs, out } => {
|
||||
f.write_fmt(format_args!("let {out} = powf({lhs}, {rhs});"))
|
||||
Instruction::Powf { lhs, rhs, out } => {
|
||||
if rhs.is_always_scalar() {
|
||||
f.write_fmt(format_args!("{out} = powf_scalar({lhs}, {rhs});\n"))
|
||||
} else {
|
||||
f.write_fmt(format_args!("{out} = powf({lhs}, {rhs});\n"))
|
||||
}
|
||||
Operation::Sqrt { input, out } => {
|
||||
f.write_fmt(format_args!("let {out} = sqrt({input});"))
|
||||
}
|
||||
Operation::Log1p { input, out } => {
|
||||
f.write_fmt(format_args!("let {out} = log({input} + 1.0);"))
|
||||
Instruction::Sqrt { input, out } => {
|
||||
f.write_fmt(format_args!("{out} = sqrt({input});\n"))
|
||||
}
|
||||
Operation::Cos { input, out } => f.write_fmt(format_args!("let {out} = cos({input});")),
|
||||
Operation::Sin { input, out } => f.write_fmt(format_args!("let {out} = sin({input});")),
|
||||
Operation::Tanh { input, out } => {
|
||||
Instruction::Log1p { input, out } => {
|
||||
f.write_fmt(format_args!("{out} = log({input} + 1.0);\n"))
|
||||
}
|
||||
Instruction::Cos { input, out } => f.write_fmt(format_args!("{out} = cos({input});\n")),
|
||||
Instruction::Sin { input, out } => f.write_fmt(format_args!("{out} = sin({input});\n")),
|
||||
Instruction::Tanh { input, out } => {
|
||||
#[cfg(target_os = "macos")]
|
||||
let result = f.write_fmt(format_args!("let {out} = safe_tanh({input});"));
|
||||
let result = f.write_fmt(format_args!("{out} = safe_tanh({input});\n"));
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
let result = f.write_fmt(format_args!("let {out} = tanh({input});"));
|
||||
let result = f.write_fmt(format_args!("{out} = tanh({input});\n"));
|
||||
|
||||
result
|
||||
}
|
||||
Operation::Erf { input, out } => f.write_fmt(format_args!("let {out} = erf({input});")),
|
||||
Operation::Recip { input, out } => {
|
||||
f.write_fmt(format_args!("let {out} = 1.0 / {input};"))
|
||||
Instruction::Erf { input, out } => f.write_fmt(format_args!("{out} = erf({input});\n")),
|
||||
Instruction::Recip { input, out } => {
|
||||
f.write_fmt(format_args!("{out} = 1.0 / {input};"))
|
||||
}
|
||||
Operation::Equal { lhs, rhs, out } => comparison(lhs, rhs, out, "==", f),
|
||||
Operation::Lower { lhs, rhs, out } => comparison(lhs, rhs, out, "<", f),
|
||||
Operation::Greater { lhs, rhs, out } => comparison(lhs, rhs, out, ">", f),
|
||||
Operation::LowerEqual { lhs, rhs, out } => comparison(lhs, rhs, out, "<=", f),
|
||||
Operation::GreaterEqual { lhs, rhs, out } => comparison(lhs, rhs, out, ">=", f),
|
||||
Operation::AssignGlobal { input, out } => {
|
||||
Instruction::Equal { lhs, rhs, out } => comparison(lhs, rhs, out, "==", f),
|
||||
Instruction::Lower { lhs, rhs, out } => comparison(lhs, rhs, out, "<", f),
|
||||
Instruction::Greater { lhs, rhs, out } => comparison(lhs, rhs, out, ">", f),
|
||||
Instruction::LowerEqual { lhs, rhs, out } => comparison(lhs, rhs, out, "<=", f),
|
||||
Instruction::GreaterEqual { lhs, rhs, out } => comparison(lhs, rhs, out, ">=", f),
|
||||
Instruction::AssignGlobal { input, out } => {
|
||||
let elem_out = out.item();
|
||||
let elem_in = input.item();
|
||||
|
||||
|
@ -211,76 +251,18 @@ impl Display for Operation {
|
|||
);"
|
||||
)),
|
||||
Item::Scalar(elem) => {
|
||||
f.write_fmt(format_args!("{out}_global[id] = {elem}({input});"))
|
||||
f.write_fmt(format_args!("{out}_global[id] = {elem}({input});\n"))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
f.write_fmt(format_args!("{out}_global[id] = {elem_out}({input});"))
|
||||
f.write_fmt(format_args!("{out}_global[id] = {elem_out}({input});\n"))
|
||||
}
|
||||
}
|
||||
Operation::AssignLocal { input, out } => {
|
||||
let elem = out.item();
|
||||
f.write_fmt(format_args!("let {out} = {elem}({input});"))
|
||||
Instruction::AssignLocal { input, out } => {
|
||||
let item = out.item();
|
||||
f.write_fmt(format_args!("{out} = {item}({input});\n"))
|
||||
}
|
||||
Operation::ReadGlobal { variable } => match variable {
|
||||
Variable::Input(number, _elem) => f.write_fmt(format_args!(
|
||||
"let input_{number} = input_{number}_global[id];"
|
||||
)),
|
||||
Variable::Local(_, _) => panic!("can't read global local variable."),
|
||||
Variable::Output(number, _elem) => f.write_fmt(format_args!(
|
||||
"let output_{number} = output_{number}_global[id];"
|
||||
)),
|
||||
Variable::Scalar(_, _, _) => panic!("Can't read global scalar variable."),
|
||||
Variable::Constant(_, _) => panic!("Can't read global constant variable."),
|
||||
},
|
||||
Operation::ReadGlobalWithLayout {
|
||||
variable,
|
||||
tensor_read_pos: position,
|
||||
tensor_layout_pos: position_out,
|
||||
} => {
|
||||
let (global, local, elem) = match variable {
|
||||
Variable::Input(number, elem) => (
|
||||
format!("input_{number}_global"),
|
||||
format!("input_{number}"),
|
||||
elem,
|
||||
),
|
||||
Variable::Local(_, _) => panic!("can't read global local variable."),
|
||||
Variable::Output(number, elem) => (
|
||||
format!("output_{number}_global"),
|
||||
format!("output_{number}"),
|
||||
elem,
|
||||
),
|
||||
Variable::Scalar(_, _, _) => panic!("Can't read global scalar variable."),
|
||||
Variable::Constant(_, _) => panic!("Can't read global constant variable."),
|
||||
};
|
||||
|
||||
let offset = match elem {
|
||||
Item::Vec4(_) => 4,
|
||||
Item::Vec3(_) => 3,
|
||||
Item::Vec2(_) => 2,
|
||||
Item::Scalar(_) => 1,
|
||||
};
|
||||
|
||||
f.write_fmt(format_args!(
|
||||
"
|
||||
var index_{local}: u32 = 0u;
|
||||
|
||||
for (var i: u32 = 1u; i <= rank; i++) {{
|
||||
let position = {position}u * (2u * rank);
|
||||
let position_out = {position_out}u * (2u * rank);
|
||||
|
||||
let stride = info[position + i];
|
||||
let stride_out = info[position_out + i];
|
||||
let shape = info[position + rank + i];
|
||||
|
||||
index_{local} += (id * {offset}u) / stride_out % shape * stride;
|
||||
}}
|
||||
|
||||
let {local} = {elem}({global}[index_{local} / {offset}u]);
|
||||
"
|
||||
))
|
||||
}
|
||||
Operation::ConditionalAssign {
|
||||
Instruction::ConditionalAssign {
|
||||
cond,
|
||||
lhs,
|
||||
rhs,
|
||||
|
@ -301,7 +283,6 @@ let {local} = {elem}({global}[index_{local} / {offset}u]);
|
|||
|
||||
f.write_fmt(format_args!(
|
||||
"
|
||||
var {out}: {elem};
|
||||
if {cond}[0] {{
|
||||
{out}[0] = {lhs0};
|
||||
}} else {{
|
||||
|
@ -335,7 +316,6 @@ if {cond}[3] {{
|
|||
|
||||
f.write_fmt(format_args!(
|
||||
"
|
||||
var {out}: {elem};
|
||||
if {cond}[0] {{
|
||||
{out}[0] = {lhs0};
|
||||
}} else {{
|
||||
|
@ -362,7 +342,6 @@ if {cond}[2] {{
|
|||
|
||||
f.write_fmt(format_args!(
|
||||
"
|
||||
var {out}: {elem};
|
||||
if {cond}[0] {{
|
||||
{out}[0] = {lhs0};
|
||||
}} else {{
|
||||
|
@ -378,7 +357,6 @@ if {cond}[1] {{
|
|||
}
|
||||
Item::Scalar(_) => f.write_fmt(format_args!(
|
||||
"
|
||||
var {out}: {elem};
|
||||
if {cond} {{
|
||||
{out} = {lhs};
|
||||
}} else {{
|
||||
|
@ -388,6 +366,29 @@ if {cond} {{
|
|||
)),
|
||||
}
|
||||
}
|
||||
Instruction::Stride { dim, position, out } => f.write_fmt(format_args!(
|
||||
"{out} = info[({position}u * (2u * rank)) + {dim} + 1u];\n"
|
||||
)),
|
||||
Instruction::Shape { dim, position, out } => f.write_fmt(format_args!(
|
||||
"{out} = info[({position}u * (2u * rank)) + rank + {dim} + 1u];\n"
|
||||
)),
|
||||
Instruction::RangeLoop {
|
||||
i,
|
||||
start,
|
||||
end,
|
||||
instructions,
|
||||
} => {
|
||||
f.write_fmt(format_args!(
|
||||
"
|
||||
for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{
|
||||
"
|
||||
))?;
|
||||
for instruction in instructions {
|
||||
f.write_fmt(format_args!("{instruction}"))?;
|
||||
}
|
||||
|
||||
f.write_str("}\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -412,7 +413,7 @@ fn comparison(
|
|||
|
||||
f.write_fmt(format_args!(
|
||||
"
|
||||
let {out} = vec4({lhs0} {op} {rhs0}, {lhs1} {op} {rhs1}, {lhs2} {op} {rhs2}, {lhs3} {op} {rhs3});
|
||||
{out} = vec4({lhs0} {op} {rhs0}, {lhs1} {op} {rhs1}, {lhs2} {op} {rhs2}, {lhs3} {op} {rhs3});
|
||||
"
|
||||
))
|
||||
}
|
||||
|
@ -426,7 +427,7 @@ let {out} = vec4({lhs0} {op} {rhs0}, {lhs1} {op} {rhs1}, {lhs2} {op} {rhs2}, {lh
|
|||
|
||||
f.write_fmt(format_args!(
|
||||
"
|
||||
let {out} = vec3({lhs0} {op} {rhs0}, {lhs1} {op} {rhs1}, {lhs2} {op} {rhs2});
|
||||
{out} = vec3({lhs0} {op} {rhs0}, {lhs1} {op} {rhs1}, {lhs2} {op} {rhs2});
|
||||
"
|
||||
))
|
||||
}
|
||||
|
@ -438,12 +439,12 @@ let {out} = vec3({lhs0} {op} {rhs0}, {lhs1} {op} {rhs1}, {lhs2} {op} {rhs2});
|
|||
|
||||
f.write_fmt(format_args!(
|
||||
"
|
||||
let {out} = vec2({lhs0} {op} {rhs0}, {lhs1} {op} {rhs1});
|
||||
{out} = vec2({lhs0} {op} {rhs0}, {lhs1} {op} {rhs1});
|
||||
"
|
||||
))
|
||||
}
|
||||
Item::Scalar(_) => match rhs.item() {
|
||||
Item::Scalar(_) => f.write_fmt(format_args!("let {out} = {lhs} {op} {rhs};")),
|
||||
Item::Scalar(_) => f.write_fmt(format_args!("{out} = {lhs} {op} {rhs};\n")),
|
||||
_ => panic!("Can only compare a scalar when the output is a scalar"),
|
||||
},
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use super::{Body, Extension, Item};
|
||||
use super::{Extension, Item, Scope};
|
||||
use crate::codegen::dialect::gpu::WorkgroupSize;
|
||||
use std::fmt::Display;
|
||||
|
||||
|
@ -31,7 +31,7 @@ pub struct ComputeShader {
|
|||
pub workgroup_size: WorkgroupSize,
|
||||
pub global_invocation_id: bool,
|
||||
pub num_workgroups: bool,
|
||||
pub body: Body,
|
||||
pub body: Scope,
|
||||
pub extensions: Vec<Extension>,
|
||||
}
|
||||
|
||||
|
|
|
@ -1,343 +1,8 @@
|
|||
use burn_compute::client::ComputeClient;
|
||||
|
||||
use crate::codegen::dialect::gpu::{
|
||||
Binding, Body, ComputeShader, Elem, Item, Location, Operation, ReadGlobalOperation,
|
||||
ReadGlobalWithLayoutOperation, UnaryOperation, Variable, Vectorization, Visibility,
|
||||
WorkgroupSize,
|
||||
};
|
||||
use crate::compute::StaticKernel;
|
||||
use crate::element::JitElement;
|
||||
use crate::kernel::{elemwise_workgroup, StaticKernelSource, WORKGROUP_DEFAULT};
|
||||
use crate::Runtime;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
/// Kernel creation input phase, see [kernel codegen](ElemWiseKernelCodegen) for more details.
|
||||
pub struct InputPhase;
|
||||
/// Kernel creation body phase, see [kernel codegen](ElemWiseKernelCodegen) for more details.
|
||||
pub struct BodyPhase;
|
||||
/// Kernel creation output phase, see [kernel codegen](ElemWiseKernelCodegen) for more details.
|
||||
pub struct OutputPhase;
|
||||
/// Kernel compilation phase, see [kernel codegen](ElemWiseKernelCodegen) for more details.
|
||||
pub struct CompilationPhase;
|
||||
|
||||
#[derive(new, Clone, Copy)]
|
||||
pub struct InplaceMapping {
|
||||
pub position_input: usize,
|
||||
pub position_output: usize,
|
||||
}
|
||||
|
||||
/// Allows to create custom wgsl kernels based on configured inputs, body and outputs.
|
||||
///
|
||||
/// This type has 4 phases that must be executed in order, but no worry the type system won't allow
|
||||
/// you to make mistakes.
|
||||
///
|
||||
/// 1. [Input Phase](InputPhase)
|
||||
/// This phase focuses on registering the input arrays and scalars that are going to be used by
|
||||
/// the kernel.
|
||||
/// 2. [Body Phase](BodyPhase)
|
||||
/// After the input phase is done, all the operations that happen in the body must be
|
||||
/// registered.
|
||||
/// 3. [Output Phase](OutputPhase)
|
||||
/// This step focuses on registering all output arrays or inputs that the kernel needs to write to.
|
||||
/// 4. [Compilation Phase](CompilationPhase)
|
||||
/// Now that all other phases are completed, we can actually compile the kernel.
|
||||
pub struct ElemWiseKernelCodegen<Phase = InputPhase> {
|
||||
operations: Vec<Operation>,
|
||||
input_bindings: Vec<Binding>,
|
||||
output_bindings: Vec<Binding>,
|
||||
named_bindings: Vec<(String, Binding)>,
|
||||
vectorization: Vectorization,
|
||||
mappings_inplace: Vec<InplaceMapping>,
|
||||
workgroup_size: WorkgroupSize,
|
||||
_phase: PhantomData<Phase>,
|
||||
}
|
||||
|
||||
pub enum Input {
|
||||
Array {
|
||||
item: Item,
|
||||
visibility: Visibility,
|
||||
strategy: ReadingStrategy,
|
||||
},
|
||||
Scalar {
|
||||
elem: Elem,
|
||||
size: usize,
|
||||
},
|
||||
}
|
||||
|
||||
pub enum ReadingStrategy {
|
||||
/// Each element will be read in a way to be compatible with the output layout.
|
||||
OutputLayout,
|
||||
/// Keep the current layout.
|
||||
Plain,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub enum Output {
|
||||
Array { item: Item, local: u16 },
|
||||
Input { item: Item, input: u16, local: u16 },
|
||||
}
|
||||
|
||||
impl Default for ElemWiseKernelCodegen<InputPhase> {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
operations: Vec::new(),
|
||||
input_bindings: Vec::new(),
|
||||
output_bindings: Vec::new(),
|
||||
named_bindings: Vec::new(),
|
||||
vectorization: Vectorization::Scalar,
|
||||
mappings_inplace: Vec::new(),
|
||||
workgroup_size: WorkgroupSize::default(),
|
||||
_phase: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ElemWiseKernelCodegen<InputPhase> {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn vectorize(mut self, vectorization: Vectorization) -> Self {
|
||||
self.vectorization = vectorization;
|
||||
self
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn inplace(mut self, mappings: &[InplaceMapping]) -> Self {
|
||||
self.mappings_inplace = mappings.to_vec();
|
||||
self
|
||||
}
|
||||
|
||||
/// Register the inputs used by the kernel.
|
||||
pub fn inputs(mut self, inputs: &[Input]) -> ElemWiseKernelCodegen<BodyPhase> {
|
||||
let mut index: u16 = 0;
|
||||
|
||||
for input in inputs {
|
||||
match input {
|
||||
Input::Array {
|
||||
item,
|
||||
visibility,
|
||||
strategy,
|
||||
} => {
|
||||
let item = item.vectorize(self.vectorization);
|
||||
|
||||
self.input_bindings.push(Binding {
|
||||
item: bool_item(item),
|
||||
visibility: *visibility,
|
||||
location: Location::Storage,
|
||||
size: None,
|
||||
});
|
||||
|
||||
match strategy {
|
||||
ReadingStrategy::OutputLayout => {
|
||||
self.operations.push(Operation::ReadGlobalWithLayout(
|
||||
ReadGlobalWithLayoutOperation {
|
||||
variable: Variable::Input(index, item),
|
||||
tensor_read_pos: index as usize,
|
||||
tensor_layout_pos: 0, // Will set the right value during the output
|
||||
// phase.
|
||||
},
|
||||
));
|
||||
}
|
||||
ReadingStrategy::Plain => {
|
||||
self.operations
|
||||
.push(Operation::ReadGlobal(ReadGlobalOperation {
|
||||
variable: Variable::Input(index, item),
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
index += 1;
|
||||
}
|
||||
Input::Scalar { elem, size } => {
|
||||
let elem = bool_elem(*elem);
|
||||
|
||||
self.named_bindings.push((
|
||||
format!("scalars_{}", elem),
|
||||
Binding {
|
||||
item: Item::Scalar(elem),
|
||||
visibility: Visibility::Read,
|
||||
location: Location::Storage,
|
||||
size: Some(*size),
|
||||
},
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ElemWiseKernelCodegen {
|
||||
operations: self.operations,
|
||||
input_bindings: self.input_bindings,
|
||||
output_bindings: self.output_bindings,
|
||||
named_bindings: self.named_bindings,
|
||||
vectorization: self.vectorization,
|
||||
mappings_inplace: self.mappings_inplace,
|
||||
workgroup_size: self.workgroup_size,
|
||||
_phase: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ElemWiseKernelCodegen<BodyPhase> {
|
||||
/// Register the [operators](Operator) that the kernel must execute in the order provided.
|
||||
pub fn body(mut self, operators: &[Operation]) -> ElemWiseKernelCodegen<OutputPhase> {
|
||||
for ops in operators.iter() {
|
||||
self.operations.push(ops.vectorize(self.vectorization));
|
||||
}
|
||||
|
||||
ElemWiseKernelCodegen {
|
||||
operations: self.operations,
|
||||
input_bindings: self.input_bindings,
|
||||
output_bindings: self.output_bindings,
|
||||
named_bindings: self.named_bindings,
|
||||
vectorization: self.vectorization,
|
||||
mappings_inplace: self.mappings_inplace,
|
||||
workgroup_size: self.workgroup_size,
|
||||
_phase: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ElemWiseKernelCodegen<OutputPhase> {
|
||||
/// Register the outputs with their local variable index.
|
||||
///
|
||||
/// Note that the index corresponds to the registered [operator](Operator) number at the
|
||||
/// [body phase](BodyPhase).
|
||||
/// So the 4th operator registered creates the local variable 3 (N-1, since the 1th index is 0).
|
||||
pub fn outputs(mut self, outputs: &[Output]) -> ElemWiseKernelCodegen<CompilationPhase> {
|
||||
let mut index = 0;
|
||||
let mut position_out = 0;
|
||||
|
||||
let mut outputs = outputs.to_vec();
|
||||
|
||||
for mapping in self.mappings_inplace.iter() {
|
||||
match outputs.get_mut(mapping.position_output) {
|
||||
Some(output) => match output {
|
||||
Output::Array { item, local } => {
|
||||
*output = Output::Input {
|
||||
item: *item,
|
||||
input: mapping.position_input as u16,
|
||||
local: *local,
|
||||
};
|
||||
}
|
||||
Output::Input {
|
||||
item: _,
|
||||
input: _,
|
||||
local: _,
|
||||
} => continue,
|
||||
},
|
||||
None => continue,
|
||||
}
|
||||
|
||||
if let Some(binding) = self.input_bindings.get_mut(mapping.position_input) {
|
||||
binding.visibility = Visibility::ReadWrite
|
||||
}
|
||||
}
|
||||
|
||||
for array in &outputs {
|
||||
match array {
|
||||
Output::Array { item, local } => {
|
||||
let item = item.vectorize(self.vectorization);
|
||||
let elem_adapted = bool_item(item);
|
||||
|
||||
self.output_bindings.push(Binding {
|
||||
item: elem_adapted,
|
||||
visibility: Visibility::ReadWrite,
|
||||
location: Location::Storage,
|
||||
size: None,
|
||||
});
|
||||
self.operations
|
||||
.push(Operation::AssignGlobal(UnaryOperation {
|
||||
input: Variable::Local(*local, item),
|
||||
out: Variable::Output(index, elem_adapted),
|
||||
}));
|
||||
index += 1;
|
||||
|
||||
if index == 1 {
|
||||
position_out = self.input_bindings.len(); // First output when we have a
|
||||
// new array for the output.
|
||||
}
|
||||
}
|
||||
Output::Input { item, input, local } => {
|
||||
let item = item.vectorize(self.vectorization);
|
||||
|
||||
self.operations
|
||||
.push(Operation::AssignGlobal(UnaryOperation {
|
||||
input: Variable::Local(*local, item),
|
||||
out: Variable::Input(*input, bool_item(item)),
|
||||
}));
|
||||
position_out = *input as usize; // Input number when we use inplace operation.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// We set the output number that will be used for the stride definition.
|
||||
for i in 0..self.input_bindings.len() {
|
||||
if let Some(Operation::ReadGlobalWithLayout(ReadGlobalWithLayoutOperation {
|
||||
variable: _,
|
||||
tensor_read_pos: _,
|
||||
tensor_layout_pos,
|
||||
})) = self.operations.get_mut(i)
|
||||
{
|
||||
{
|
||||
*tensor_layout_pos = position_out;
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
ElemWiseKernelCodegen {
|
||||
operations: self.operations,
|
||||
input_bindings: self.input_bindings,
|
||||
output_bindings: self.output_bindings,
|
||||
named_bindings: self.named_bindings,
|
||||
vectorization: self.vectorization,
|
||||
mappings_inplace: self.mappings_inplace,
|
||||
workgroup_size: self.workgroup_size,
|
||||
_phase: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ElemWiseKernelCodegen<CompilationPhase> {
|
||||
#[allow(dead_code)] // Only used for fusion for now.
|
||||
pub fn workgroup_size(mut self, workgroup_size: WorkgroupSize) -> Self {
|
||||
self.workgroup_size = workgroup_size;
|
||||
self
|
||||
}
|
||||
|
||||
/// Compile the kernel into a [compute shader](ComputeShader).
|
||||
pub fn compile(self) -> ComputeShader {
|
||||
let inputs = self.input_bindings;
|
||||
let outputs = self.output_bindings;
|
||||
let mut named = Vec::with_capacity(2);
|
||||
|
||||
named.push((
|
||||
"info".to_string(),
|
||||
Binding {
|
||||
item: Item::Scalar(Elem::UInt),
|
||||
visibility: Visibility::Read,
|
||||
location: Location::Storage,
|
||||
size: None, // We avoid putting the length here since it will force a new kernel
|
||||
// for each tensor rank.
|
||||
},
|
||||
));
|
||||
|
||||
for (name, binding) in self.named_bindings.into_iter() {
|
||||
named.push((name, binding));
|
||||
}
|
||||
|
||||
ComputeShader {
|
||||
inputs,
|
||||
outputs,
|
||||
named,
|
||||
workgroup_size: self.workgroup_size,
|
||||
body: Body::new(self.operations),
|
||||
num_workgroups: true,
|
||||
global_invocation_id: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
use burn_compute::client::ComputeClient;
|
||||
|
||||
#[derive(new)]
|
||||
pub struct StaticHandle<'a, R: Runtime> {
|
||||
|
@ -433,20 +98,3 @@ pub(crate) fn calculate_num_elems_dyn_rank(shape: &[usize]) -> usize {
|
|||
}
|
||||
num_elems
|
||||
}
|
||||
|
||||
fn bool_item(ty: Item) -> Item {
|
||||
match ty {
|
||||
Item::Vec4(elem) => Item::Vec4(bool_elem(elem)),
|
||||
Item::Vec3(elem) => Item::Vec3(bool_elem(elem)),
|
||||
Item::Vec2(elem) => Item::Vec2(bool_elem(elem)),
|
||||
Item::Scalar(elem) => Item::Scalar(bool_elem(elem)),
|
||||
}
|
||||
}
|
||||
|
||||
fn bool_elem(elem: Elem) -> Elem {
|
||||
match elem {
|
||||
// U32 are used for bool tensors
|
||||
Elem::Bool => Elem::UInt,
|
||||
_ => elem,
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
mod compilation;
|
||||
pub(crate) mod compiler;
|
||||
pub(crate) mod dialect;
|
||||
|
||||
mod kernel;
|
||||
|
||||
pub(crate) use compilation::*;
|
||||
pub(crate) use compiler::*;
|
||||
pub(crate) use kernel::*;
|
||||
|
|
|
@ -75,7 +75,7 @@ mod tests {
|
|||
use super::*;
|
||||
use crate::{
|
||||
binary,
|
||||
codegen::dialect::gpu::{BinaryOperation, Elem, Item, Operation, Variable},
|
||||
codegen::dialect::gpu::{BinaryOperator, Elem, Operator, Scope},
|
||||
kernel::{KernelSettings, WORKGROUP_DEFAULT},
|
||||
tests::{TestCompiler, TestRuntime},
|
||||
Runtime, WgpuDevice,
|
||||
|
@ -84,10 +84,10 @@ mod tests {
|
|||
#[test]
|
||||
fn can_run_kernel() {
|
||||
binary!(
|
||||
operation: |elem: Elem| Operation::Add(BinaryOperation {
|
||||
lhs: Variable::Input(0, Item::Scalar(elem)),
|
||||
rhs: Variable::Input(1, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::Add(BinaryOperator {
|
||||
lhs: scope.read_array(0, elem),
|
||||
rhs: scope.read_array(1, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
compiler: TestCompiler,
|
||||
elem_in: f32,
|
||||
|
|
|
@ -70,7 +70,7 @@ impl<R: Runtime> FusionBackend for JitBackend<R> {
|
|||
type OptimizationState = WgpuOptimizationState;
|
||||
type Optimization = WgpuOptimization<R>;
|
||||
type FusionDevice = R::Device;
|
||||
type Handle = WgpuFusionHandle<R>;
|
||||
type Handle = JitFusionHandle<R>;
|
||||
type FusionClient = MutexFusionClient<Self>;
|
||||
|
||||
fn optimizations(
|
||||
|
@ -126,7 +126,7 @@ pub fn strides_dyn_rank(shape: &[usize]) -> Vec<usize> {
|
|||
}
|
||||
|
||||
/// Handle to be used when fusing operations.
|
||||
pub struct WgpuFusionHandle<R: Runtime> {
|
||||
pub struct JitFusionHandle<R: Runtime> {
|
||||
/// Compute client for wgpu.
|
||||
pub client: ComputeClient<R::Server, R::Channel>,
|
||||
/// The buffer where the data are stored.
|
||||
|
@ -136,13 +136,17 @@ pub struct WgpuFusionHandle<R: Runtime> {
|
|||
pub(crate) strides: Vec<usize>,
|
||||
}
|
||||
|
||||
impl<R: Runtime> core::fmt::Debug for WgpuFusionHandle<R> {
|
||||
fn fmt(&self, _f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
todo!()
|
||||
impl<R: Runtime> core::fmt::Debug for JitFusionHandle<R> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.write_fmt(format_args!(
|
||||
"JitFusionHandle {{ device: {:?}, runtime: {}}}",
|
||||
self.device,
|
||||
R::name(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime> Clone for WgpuFusionHandle<R> {
|
||||
impl<R: Runtime> Clone for JitFusionHandle<R> {
|
||||
fn clone(&self) -> Self {
|
||||
Self {
|
||||
client: self.client.clone(),
|
||||
|
@ -153,10 +157,10 @@ impl<R: Runtime> Clone for WgpuFusionHandle<R> {
|
|||
}
|
||||
}
|
||||
|
||||
unsafe impl<R: Runtime> Send for WgpuFusionHandle<R> {}
|
||||
unsafe impl<R: Runtime> Sync for WgpuFusionHandle<R> {}
|
||||
unsafe impl<R: Runtime> Send for JitFusionHandle<R> {}
|
||||
unsafe impl<R: Runtime> Sync for JitFusionHandle<R> {}
|
||||
|
||||
impl<R: Runtime> WgpuFusionHandle<R> {
|
||||
impl<R: Runtime> JitFusionHandle<R> {
|
||||
pub(crate) fn into_tensor<const D: usize, E: JitElement>(
|
||||
self,
|
||||
shape: Shape<D>,
|
||||
|
@ -172,7 +176,7 @@ impl<R: Runtime> WgpuFusionHandle<R> {
|
|||
}
|
||||
}
|
||||
|
||||
impl<R: Runtime, E: JitElement, const D: usize> From<JitTensor<R, E, D>> for WgpuFusionHandle<R> {
|
||||
impl<R: Runtime, E: JitElement, const D: usize> From<JitTensor<R, E, D>> for JitFusionHandle<R> {
|
||||
fn from(value: JitTensor<R, E, D>) -> Self {
|
||||
Self {
|
||||
client: value.client,
|
||||
|
|
|
@ -1,11 +1,10 @@
|
|||
use super::{optimization::ElementWise, CompilationPhase, Scalars};
|
||||
use super::{optimization::ElementWise, CompilationPhase};
|
||||
use crate::{
|
||||
codegen::dialect::gpu::{
|
||||
BinaryOperation, ConditionalAssignOperation, Elem, Item, Operation, UnaryOperation,
|
||||
Variable,
|
||||
BinaryOperator, ConditionalAssignOperator, Elem, Operator, UnaryOperator, Variable,
|
||||
},
|
||||
element::JitElement,
|
||||
fusion::WgpuOptimization,
|
||||
fusion::{tracing::TraceBuilder, WgpuOptimization},
|
||||
JitBackend, Runtime,
|
||||
};
|
||||
use burn_fusion::{
|
||||
|
@ -14,27 +13,20 @@ use burn_fusion::{
|
|||
NumericOperationDescription, OperationDescription, ScalarOperationDescription,
|
||||
UnaryOperationDescription,
|
||||
},
|
||||
OptimizationBuilder, OptimizationProperties, OptimizationStatus, TensorDescription, TensorId,
|
||||
OptimizationBuilder, OptimizationProperties, OptimizationStatus, TensorDescription,
|
||||
};
|
||||
use burn_tensor::{
|
||||
ops::{FloatElem, IntElem},
|
||||
Device, Element,
|
||||
};
|
||||
use hashbrown::HashMap;
|
||||
|
||||
/// Fused element wise operations that are normally memory bound.
|
||||
pub(crate) struct ElementWiseBuilder<R: Runtime> {
|
||||
pub(crate) inputs: Vec<TensorDescription>,
|
||||
pub(crate) locals: HashMap<TensorId, u16>,
|
||||
pub(crate) tensors: HashMap<TensorId, (TensorDescription, Elem)>,
|
||||
pub(crate) scalars_float: usize,
|
||||
pub(crate) scalars_int: usize,
|
||||
pub(crate) scalars_uint: usize,
|
||||
pub(crate) booleans: usize,
|
||||
pub(crate) operators: Vec<Operation>,
|
||||
pub(crate) current_output_shape: Vec<usize>,
|
||||
pub(crate) status: OptimizationStatus,
|
||||
pub(crate) device: R::Device,
|
||||
builder: TraceBuilder,
|
||||
current_output_shape: Vec<usize>,
|
||||
status: OptimizationStatus,
|
||||
num_added: usize,
|
||||
device: R::Device,
|
||||
}
|
||||
|
||||
impl<R: Runtime> OptimizationBuilder<WgpuOptimization<R>> for ElementWiseBuilder<R> {
|
||||
|
@ -81,22 +73,13 @@ impl<R: Runtime> OptimizationBuilder<WgpuOptimization<R>> for ElementWiseBuilder
|
|||
};
|
||||
|
||||
self.status = OptimizationStatus::Open;
|
||||
self.num_added += 1;
|
||||
}
|
||||
|
||||
fn build(&self) -> WgpuOptimization<R> {
|
||||
let inputs = self.input_descriptions();
|
||||
let outputs = self.output_descriptions();
|
||||
let locals = outputs
|
||||
.iter()
|
||||
.map(|out| *self.locals.get(&out.0.id).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let op = ElementWise::new(
|
||||
inputs,
|
||||
outputs,
|
||||
locals,
|
||||
Scalars::new(self.scalars_float, self.scalars_uint, self.scalars_int),
|
||||
self.operators.clone(),
|
||||
self.builder.clone().build(),
|
||||
self.num_added,
|
||||
self.device.clone(),
|
||||
CompilationPhase,
|
||||
);
|
||||
|
@ -105,18 +88,12 @@ impl<R: Runtime> OptimizationBuilder<WgpuOptimization<R>> for ElementWiseBuilder
|
|||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.operators.len()
|
||||
self.num_added
|
||||
}
|
||||
|
||||
fn reset(&mut self) {
|
||||
self.inputs.clear();
|
||||
self.locals.drain();
|
||||
self.tensors.clear();
|
||||
self.scalars_float = 0;
|
||||
self.scalars_int = 0;
|
||||
self.scalars_uint = 0;
|
||||
self.booleans = 0;
|
||||
self.operators.clear();
|
||||
self.builder = TraceBuilder::new();
|
||||
self.num_added = 0;
|
||||
self.status = OptimizationStatus::Open;
|
||||
self.current_output_shape.clear();
|
||||
}
|
||||
|
@ -126,11 +103,11 @@ impl<R: Runtime> OptimizationBuilder<WgpuOptimization<R>> for ElementWiseBuilder
|
|||
}
|
||||
|
||||
fn properties(&self) -> OptimizationProperties {
|
||||
let ready = !self.operators.is_empty();
|
||||
let ready = self.num_added > 0;
|
||||
|
||||
OptimizationProperties {
|
||||
ready,
|
||||
score: self.operators.len() as u64,
|
||||
score: self.num_added as u64,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -138,269 +115,20 @@ impl<R: Runtime> OptimizationBuilder<WgpuOptimization<R>> for ElementWiseBuilder
|
|||
impl<R: Runtime> ElementWiseBuilder<R> {
|
||||
pub fn new(device: Device<JitBackend<R>>) -> Self {
|
||||
Self {
|
||||
inputs: Vec::new(),
|
||||
locals: HashMap::new(),
|
||||
tensors: HashMap::new(),
|
||||
scalars_float: 0,
|
||||
scalars_int: 0,
|
||||
scalars_uint: 0,
|
||||
booleans: 0,
|
||||
operators: Vec::new(),
|
||||
builder: TraceBuilder::new(),
|
||||
num_added: 0,
|
||||
current_output_shape: Vec::new(),
|
||||
status: OptimizationStatus::Open,
|
||||
device,
|
||||
}
|
||||
}
|
||||
|
||||
fn input_descriptions(&self) -> Vec<(TensorDescription, Elem)> {
|
||||
self.inputs
|
||||
.iter()
|
||||
.map(|input| {
|
||||
let updated_tensor = self.tensors.get(&input.id).unwrap();
|
||||
updated_tensor.clone()
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn output_descriptions(&self) -> Vec<(TensorDescription, Elem)> {
|
||||
let mut outputs = Vec::new();
|
||||
let mut local_tensor_ids_input = Vec::new();
|
||||
let mut local_tensor_ids_output = Vec::new();
|
||||
|
||||
// Mark a variable to the provided list of tensor ids using the variable list.
|
||||
//
|
||||
// Only local variables can become outputs.
|
||||
let mark = |var: &Variable, list: &mut Vec<TensorId>| {
|
||||
if let Variable::Local(index, _) = var {
|
||||
if let Some((id, _)) = self
|
||||
.locals
|
||||
.iter()
|
||||
.find(|(_id, position)| *position == index)
|
||||
{
|
||||
if !list.contains(id) {
|
||||
list.push(*id);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
let mark_binary =
|
||||
|op: &BinaryOperation, inputs: &mut Vec<TensorId>, outputs: &mut Vec<TensorId>| {
|
||||
mark(&op.lhs, inputs);
|
||||
mark(&op.rhs, inputs);
|
||||
mark(&op.out, outputs);
|
||||
};
|
||||
let mark_unary =
|
||||
|op: &UnaryOperation, inputs: &mut Vec<TensorId>, outputs: &mut Vec<TensorId>| {
|
||||
mark(&op.input, inputs);
|
||||
mark(&op.out, outputs);
|
||||
};
|
||||
|
||||
// For all operators, mark their local tensor id in the proper set.
|
||||
for ops in self.operators.iter() {
|
||||
match ops {
|
||||
Operation::AssignGlobal(_) => {
|
||||
// Nothing to do here.
|
||||
}
|
||||
Operation::AssignLocal(op) => {
|
||||
mark(&op.out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operation::ReadGlobalWithLayout(_) => {
|
||||
// Nothing to do here.
|
||||
}
|
||||
Operation::ReadGlobal(_) => {
|
||||
// Nothing to do here.
|
||||
}
|
||||
Operation::Add(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
Operation::Sub(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
Operation::Mul(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
Operation::Div(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
Operation::Exp(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
Operation::Abs(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
Operation::Erf(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
Operation::Log(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
Operation::Log1p(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
Operation::Cos(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
Operation::Sin(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
Operation::Tanh(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
Operation::Clamp(op) => {
|
||||
mark(&op.input, &mut local_tensor_ids_input);
|
||||
mark(&op.out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operation::Powf(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
Operation::Recip(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
Operation::Lower(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
Operation::Greater(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
Operation::LowerEqual(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
Operation::GreaterEqual(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
Operation::Equal(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
Operation::ConditionalAssign(op) => {
|
||||
mark(&op.cond, &mut local_tensor_ids_input);
|
||||
mark(&op.lhs, &mut local_tensor_ids_input);
|
||||
mark(&op.rhs, &mut local_tensor_ids_input);
|
||||
mark(&op.out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operation::Sqrt(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
// All output tensors that are never read by a following operation should be written to
|
||||
// since they are essentially the "logical" output of the shader.
|
||||
for out in local_tensor_ids_output {
|
||||
let is_read = local_tensor_ids_input.contains(&out);
|
||||
|
||||
if !is_read {
|
||||
outputs.push(self.tensors.get(&out).unwrap().clone());
|
||||
}
|
||||
}
|
||||
|
||||
// All tensors where their latest description is read only should be written to since they
|
||||
// are going to be used after the fused kernel by other operations.
|
||||
for entry in self.tensors.values() {
|
||||
let (tensor, _) = &entry;
|
||||
if let burn_fusion::TensorStatus::ReadOnly = tensor.status {
|
||||
if self.locals.contains_key(&tensor.id) {
|
||||
outputs.push(entry.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
outputs
|
||||
}
|
||||
|
||||
fn input_to_var(&mut self, tensor: &TensorDescription, elem: Elem) -> Variable {
|
||||
let already_exists = self.tensors.contains_key(&tensor.id);
|
||||
|
||||
let variable = match already_exists {
|
||||
false => {
|
||||
// New input
|
||||
let var = Variable::Input(self.inputs.len() as u16, Item::Scalar(elem));
|
||||
self.inputs.push(tensor.clone());
|
||||
var
|
||||
}
|
||||
true => match self.locals.get(&tensor.id) {
|
||||
// Is a local variable.
|
||||
Some(local_index) => Variable::Local(*local_index, Item::Scalar(elem)),
|
||||
// Isn't a local variable, so must be an existing input.
|
||||
None => {
|
||||
let input = self
|
||||
.inputs
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, input)| input.id == tensor.id)
|
||||
.unwrap();
|
||||
let input_index = input.0;
|
||||
Variable::Input(input_index as u16, Item::Scalar(elem))
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
// Update the tensor description with the new version.
|
||||
self.tensors.insert(tensor.id, (tensor.clone(), elem));
|
||||
|
||||
variable
|
||||
}
|
||||
|
||||
fn output_to_var(&mut self, tensor: &TensorDescription, elem: Elem) -> Variable {
|
||||
// Update the tensor description to the new version.
|
||||
self.tensors.insert(tensor.id, (tensor.clone(), elem));
|
||||
|
||||
// Output already registered as a local variable.
|
||||
if let Some(index) = self.locals.get(&tensor.id) {
|
||||
return Variable::Local(*index, Item::Scalar(elem));
|
||||
}
|
||||
|
||||
// New local variable.
|
||||
let local_index = self.locals.len() as u16;
|
||||
self.locals.insert(tensor.id, local_index);
|
||||
Variable::Local(local_index, Item::Scalar(elem))
|
||||
}
|
||||
|
||||
fn register_base<E: JitElement>(&mut self, ops: &BaseOperationDescription) -> bool {
|
||||
match ops {
|
||||
BaseOperationDescription::Equal(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::gpu_elem(), E::gpu_elem(), Elem::Bool),
|
||||
|lhs, rhs, out| Operation::Equal(BinaryOperation { lhs, rhs, out }),
|
||||
|lhs, rhs, out| Operator::Equal(BinaryOperator { lhs, rhs, out }),
|
||||
),
|
||||
_ => false,
|
||||
}
|
||||
|
@ -410,47 +138,47 @@ impl<R: Runtime> ElementWiseBuilder<R> {
|
|||
match ops {
|
||||
FloatOperationDescription::Exp(desc) => {
|
||||
self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| {
|
||||
Operation::Exp(UnaryOperation { input, out })
|
||||
Operator::Exp(UnaryOperator { input, out })
|
||||
})
|
||||
}
|
||||
FloatOperationDescription::Log(desc) => {
|
||||
self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| {
|
||||
Operation::Log(UnaryOperation { input, out })
|
||||
Operator::Log(UnaryOperator { input, out })
|
||||
})
|
||||
}
|
||||
FloatOperationDescription::Log1p(desc) => {
|
||||
self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| {
|
||||
Operation::Log1p(UnaryOperation { input, out })
|
||||
Operator::Log1p(UnaryOperator { input, out })
|
||||
})
|
||||
}
|
||||
FloatOperationDescription::Cos(desc) => {
|
||||
self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| {
|
||||
Operation::Cos(UnaryOperation { input, out })
|
||||
Operator::Cos(UnaryOperator { input, out })
|
||||
})
|
||||
}
|
||||
FloatOperationDescription::Sin(desc) => {
|
||||
self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| {
|
||||
Operation::Sin(UnaryOperation { input, out })
|
||||
Operator::Sin(UnaryOperator { input, out })
|
||||
})
|
||||
}
|
||||
FloatOperationDescription::PowfScalar(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::gpu_elem(), E::gpu_elem(), E::gpu_elem()),
|
||||
|lhs, rhs, out| Operation::Powf(BinaryOperation { lhs, rhs, out }),
|
||||
|lhs, rhs, out| Operator::Powf(BinaryOperator { lhs, rhs, out }),
|
||||
),
|
||||
FloatOperationDescription::Tanh(desc) => {
|
||||
self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| {
|
||||
Operation::Tanh(UnaryOperation { input, out })
|
||||
Operator::Tanh(UnaryOperator { input, out })
|
||||
})
|
||||
}
|
||||
FloatOperationDescription::Erf(desc) => {
|
||||
self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| {
|
||||
Operation::Erf(UnaryOperation { input, out })
|
||||
Operator::Erf(UnaryOperator { input, out })
|
||||
})
|
||||
}
|
||||
FloatOperationDescription::Recip(desc) => {
|
||||
self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| {
|
||||
Operation::Recip(UnaryOperation { input, out })
|
||||
Operator::Recip(UnaryOperator { input, out })
|
||||
})
|
||||
}
|
||||
_ => false,
|
||||
|
@ -465,110 +193,111 @@ impl<R: Runtime> ElementWiseBuilder<R> {
|
|||
NumericOperationDescription::Add(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::gpu_elem(), E::gpu_elem(), E::gpu_elem()),
|
||||
|lhs, rhs, out| Operation::Add(BinaryOperation { lhs, rhs, out }),
|
||||
|lhs, rhs, out| Operator::Add(BinaryOperator { lhs, rhs, out }),
|
||||
),
|
||||
NumericOperationDescription::AddScalar(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::gpu_elem(), E::gpu_elem(), E::gpu_elem()),
|
||||
|lhs, rhs, out| Operation::Add(BinaryOperation { lhs, rhs, out }),
|
||||
|lhs, rhs, out| Operator::Add(BinaryOperator { lhs, rhs, out }),
|
||||
),
|
||||
NumericOperationDescription::Sub(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::gpu_elem(), E::gpu_elem(), E::gpu_elem()),
|
||||
|lhs, rhs, out| Operation::Sub(BinaryOperation { lhs, rhs, out }),
|
||||
|lhs, rhs, out| Operator::Sub(BinaryOperator { lhs, rhs, out }),
|
||||
),
|
||||
NumericOperationDescription::SubScalar(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::gpu_elem(), E::gpu_elem(), E::gpu_elem()),
|
||||
|lhs, rhs, out| Operation::Sub(BinaryOperation { lhs, rhs, out }),
|
||||
|lhs, rhs, out| Operator::Sub(BinaryOperator { lhs, rhs, out }),
|
||||
),
|
||||
NumericOperationDescription::Mul(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::gpu_elem(), E::gpu_elem(), E::gpu_elem()),
|
||||
|lhs, rhs, out| Operation::Mul(BinaryOperation { lhs, rhs, out }),
|
||||
|lhs, rhs, out| Operator::Mul(BinaryOperator { lhs, rhs, out }),
|
||||
),
|
||||
NumericOperationDescription::MulScalar(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::gpu_elem(), E::gpu_elem(), E::gpu_elem()),
|
||||
|lhs, rhs, out| Operation::Mul(BinaryOperation { lhs, rhs, out }),
|
||||
|lhs, rhs, out| Operator::Mul(BinaryOperator { lhs, rhs, out }),
|
||||
),
|
||||
NumericOperationDescription::Div(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::gpu_elem(), E::gpu_elem(), E::gpu_elem()),
|
||||
|lhs, rhs, out| Operation::Div(BinaryOperation { lhs, rhs, out }),
|
||||
|lhs, rhs, out| Operator::Div(BinaryOperator { lhs, rhs, out }),
|
||||
),
|
||||
NumericOperationDescription::DivScalar(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::gpu_elem(), E::gpu_elem(), E::gpu_elem()),
|
||||
|lhs, rhs, out| Operation::Div(BinaryOperation { lhs, rhs, out }),
|
||||
|lhs, rhs, out| Operator::Div(BinaryOperator { lhs, rhs, out }),
|
||||
),
|
||||
NumericOperationDescription::Abs(desc) => {
|
||||
self.register_unary_ops(desc, (E::gpu_elem(), E::gpu_elem()), |input, out| {
|
||||
Operation::Abs(UnaryOperation { input, out })
|
||||
Operator::Abs(UnaryOperator { input, out })
|
||||
})
|
||||
}
|
||||
NumericOperationDescription::Lower(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::gpu_elem(), E::gpu_elem(), Elem::Bool),
|
||||
|lhs, rhs, out| Operation::Lower(BinaryOperation { lhs, rhs, out }),
|
||||
|lhs, rhs, out| Operator::Lower(BinaryOperator { lhs, rhs, out }),
|
||||
),
|
||||
NumericOperationDescription::LowerElem(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::gpu_elem(), E::gpu_elem(), Elem::Bool),
|
||||
|lhs, rhs, out| Operation::Lower(BinaryOperation { lhs, rhs, out }),
|
||||
|lhs, rhs, out| Operator::Lower(BinaryOperator { lhs, rhs, out }),
|
||||
),
|
||||
NumericOperationDescription::Greater(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::gpu_elem(), E::gpu_elem(), Elem::Bool),
|
||||
|lhs, rhs, out| Operation::Greater(BinaryOperation { lhs, rhs, out }),
|
||||
|lhs, rhs, out| Operator::Greater(BinaryOperator { lhs, rhs, out }),
|
||||
),
|
||||
NumericOperationDescription::GreaterElem(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::gpu_elem(), E::gpu_elem(), Elem::Bool),
|
||||
|lhs, rhs, out| Operation::Greater(BinaryOperation { lhs, rhs, out }),
|
||||
|lhs, rhs, out| Operator::Greater(BinaryOperator { lhs, rhs, out }),
|
||||
),
|
||||
NumericOperationDescription::LowerEqual(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::gpu_elem(), E::gpu_elem(), Elem::Bool),
|
||||
|lhs, rhs, out| Operation::LowerEqual(BinaryOperation { lhs, rhs, out }),
|
||||
|lhs, rhs, out| Operator::LowerEqual(BinaryOperator { lhs, rhs, out }),
|
||||
),
|
||||
NumericOperationDescription::LowerEqualElem(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::gpu_elem(), E::gpu_elem(), Elem::Bool),
|
||||
|lhs, rhs, out| Operation::LowerEqual(BinaryOperation { lhs, rhs, out }),
|
||||
|lhs, rhs, out| Operator::LowerEqual(BinaryOperator { lhs, rhs, out }),
|
||||
),
|
||||
NumericOperationDescription::GreaterEqual(desc) => self.register_binary_ops(
|
||||
desc,
|
||||
(E::gpu_elem(), E::gpu_elem(), Elem::Bool),
|
||||
|lhs, rhs, out| Operation::GreaterEqual(BinaryOperation { lhs, rhs, out }),
|
||||
|lhs, rhs, out| Operator::GreaterEqual(BinaryOperator { lhs, rhs, out }),
|
||||
),
|
||||
NumericOperationDescription::GreaterEqualElem(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::gpu_elem(), E::gpu_elem(), Elem::Bool),
|
||||
|lhs, rhs, out| Operation::GreaterEqual(BinaryOperation { lhs, rhs, out }),
|
||||
|lhs, rhs, out| Operator::GreaterEqual(BinaryOperator { lhs, rhs, out }),
|
||||
),
|
||||
NumericOperationDescription::EqualElem(desc) => self.register_scalar_ops(
|
||||
desc,
|
||||
(E::gpu_elem(), E::gpu_elem(), Elem::Bool),
|
||||
|lhs, rhs, out| Operation::Equal(BinaryOperation { lhs, rhs, out }),
|
||||
|lhs, rhs, out| Operator::Equal(BinaryOperator { lhs, rhs, out }),
|
||||
),
|
||||
NumericOperationDescription::MaskWhere(desc) => {
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let cond = self.input_to_var(&desc.mask, Elem::Bool);
|
||||
let lhs = self.input_to_var(&desc.value, E::gpu_elem());
|
||||
let rhs = self.input_to_var(&desc.tensor, E::gpu_elem());
|
||||
let out = self.output_to_var(&desc.out, E::gpu_elem());
|
||||
let cond = self.builder.input(&desc.mask, Elem::Bool);
|
||||
let lhs = self.builder.input(&desc.value, E::gpu_elem());
|
||||
let rhs = self.builder.input(&desc.tensor, E::gpu_elem());
|
||||
let out = self.builder.output(&desc.out, E::gpu_elem());
|
||||
|
||||
let ops = Operation::ConditionalAssign(ConditionalAssignOperation {
|
||||
self.builder.register_operation(Operator::ConditionalAssign(
|
||||
ConditionalAssignOperator {
|
||||
cond,
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
});
|
||||
self.operators.push(ops);
|
||||
},
|
||||
));
|
||||
|
||||
true
|
||||
}
|
||||
|
@ -577,18 +306,19 @@ impl<R: Runtime> ElementWiseBuilder<R> {
|
|||
return false;
|
||||
}
|
||||
|
||||
let cond = self.input_to_var(&desc.mask, Elem::Bool);
|
||||
let lhs = self.scalar_to_var(&desc.value, E::gpu_elem());
|
||||
let rhs = self.input_to_var(&desc.tensor, E::gpu_elem());
|
||||
let out = self.output_to_var(&desc.out, E::gpu_elem());
|
||||
let cond = self.builder.input(&desc.mask, Elem::Bool);
|
||||
let lhs = self.builder.scalar(&desc.value, E::gpu_elem());
|
||||
let rhs = self.builder.input(&desc.tensor, E::gpu_elem());
|
||||
let out = self.builder.output(&desc.out, E::gpu_elem());
|
||||
|
||||
self.operators
|
||||
.push(Operation::ConditionalAssign(ConditionalAssignOperation {
|
||||
self.builder.register_operation(Operator::ConditionalAssign(
|
||||
ConditionalAssignOperator {
|
||||
cond,
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
}));
|
||||
},
|
||||
));
|
||||
|
||||
true
|
||||
}
|
||||
|
@ -597,11 +327,11 @@ impl<R: Runtime> ElementWiseBuilder<R> {
|
|||
return false;
|
||||
}
|
||||
|
||||
let input = Variable::Constant(1.0, Item::Scalar(E::gpu_elem()));
|
||||
let out = self.output_to_var(desc, E::gpu_elem());
|
||||
let input = Variable::ConstantScalar(1.0, E::gpu_elem());
|
||||
let out = self.builder.output(desc, E::gpu_elem());
|
||||
|
||||
self.operators
|
||||
.push(Operation::AssignLocal(UnaryOperation { input, out }));
|
||||
self.builder
|
||||
.register_operation(Operator::AssignLocal(UnaryOperator { input, out }));
|
||||
|
||||
true
|
||||
}
|
||||
|
@ -610,11 +340,11 @@ impl<R: Runtime> ElementWiseBuilder<R> {
|
|||
return false;
|
||||
}
|
||||
|
||||
let input = Variable::Constant(0.0, Item::Scalar(E::gpu_elem()));
|
||||
let out = self.output_to_var(desc, E::gpu_elem());
|
||||
let input = Variable::ConstantScalar(0.0, E::gpu_elem());
|
||||
let out = self.builder.output(desc, E::gpu_elem());
|
||||
|
||||
self.operators
|
||||
.push(Operation::AssignLocal(UnaryOperation { input, out }));
|
||||
self.builder
|
||||
.register_operation(Operator::AssignLocal(UnaryOperator { input, out }));
|
||||
|
||||
true
|
||||
}
|
||||
|
@ -623,11 +353,11 @@ impl<R: Runtime> ElementWiseBuilder<R> {
|
|||
return false;
|
||||
}
|
||||
|
||||
let input = self.scalar_to_var(elem, E::gpu_elem());
|
||||
let out = self.output_to_var(desc, E::gpu_elem());
|
||||
let input = self.builder.scalar(elem, E::gpu_elem());
|
||||
let out = self.builder.output(desc, E::gpu_elem());
|
||||
|
||||
self.operators
|
||||
.push(Operation::AssignLocal(UnaryOperation { input, out }));
|
||||
self.builder
|
||||
.register_operation(Operator::AssignLocal(UnaryOperator { input, out }));
|
||||
|
||||
true
|
||||
}
|
||||
|
@ -642,17 +372,17 @@ impl<R: Runtime> ElementWiseBuilder<R> {
|
|||
func: Func,
|
||||
) -> bool
|
||||
where
|
||||
Func: Fn(Variable, Variable, Variable) -> Operation,
|
||||
Func: Fn(Variable, Variable, Variable) -> Operator,
|
||||
{
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let lhs = self.input_to_var(&desc.lhs, elem_lhs);
|
||||
let rhs = self.input_to_var(&desc.rhs, elem_rhs);
|
||||
let out = self.output_to_var(&desc.out, elem_out);
|
||||
let lhs = self.builder.input(&desc.lhs, elem_lhs);
|
||||
let rhs = self.builder.input(&desc.rhs, elem_rhs);
|
||||
let out = self.builder.output(&desc.out, elem_out);
|
||||
|
||||
self.operators.push(func(lhs, rhs, out));
|
||||
self.builder.register_operation(func(lhs, rhs, out));
|
||||
|
||||
true
|
||||
}
|
||||
|
@ -664,16 +394,16 @@ impl<R: Runtime> ElementWiseBuilder<R> {
|
|||
func: Func,
|
||||
) -> bool
|
||||
where
|
||||
Func: Fn(Variable, Variable) -> Operation,
|
||||
Func: Fn(Variable, Variable) -> Operator,
|
||||
{
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let input = self.input_to_var(&desc.input, elem_input);
|
||||
let out = self.output_to_var(&desc.out, elem_out);
|
||||
let input = self.builder.input(&desc.input, elem_input);
|
||||
let out = self.builder.output(&desc.out, elem_out);
|
||||
|
||||
self.operators.push(func(input, out));
|
||||
self.builder.register_operation(func(input, out));
|
||||
|
||||
true
|
||||
}
|
||||
|
@ -685,41 +415,21 @@ impl<R: Runtime> ElementWiseBuilder<R> {
|
|||
func: Func,
|
||||
) -> bool
|
||||
where
|
||||
Func: Fn(Variable, Variable, Variable) -> Operation,
|
||||
Func: Fn(Variable, Variable, Variable) -> Operator,
|
||||
{
|
||||
if !self.output_is_compatible(&desc.out) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let lhs = self.input_to_var(&desc.lhs, elem_lhs);
|
||||
let rhs = self.scalar_to_var(&desc.rhs, elem_rhs);
|
||||
let out = self.output_to_var(&desc.out, elem_out);
|
||||
let lhs = self.builder.input(&desc.lhs, elem_lhs);
|
||||
let rhs = self.builder.scalar(&desc.rhs, elem_rhs);
|
||||
let out = self.builder.output(&desc.out, elem_out);
|
||||
|
||||
self.operators.push(func(lhs, rhs, out));
|
||||
self.builder.register_operation(func(lhs, rhs, out));
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
fn scalar_to_var<E: Element>(&mut self, _value: &E, elem_type: Elem) -> Variable {
|
||||
match elem_type {
|
||||
Elem::Float => {
|
||||
self.scalars_float += 1;
|
||||
Variable::Scalar(self.scalars_float as u16 - 1, Item::Scalar(Elem::Float))
|
||||
}
|
||||
Elem::Int => {
|
||||
self.scalars_int += 1;
|
||||
Variable::Scalar(self.scalars_int as u16 - 1, Item::Scalar(Elem::Int))
|
||||
}
|
||||
Elem::UInt => {
|
||||
self.scalars_uint += 1;
|
||||
Variable::Scalar(self.scalars_uint as u16 - 1, Item::Scalar(Elem::UInt))
|
||||
}
|
||||
Elem::Bool => {
|
||||
panic!("Bool scalars not supported")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn output_is_compatible(&mut self, out: &TensorDescription) -> bool {
|
||||
if self.current_output_shape.is_empty() {
|
||||
self.current_output_shape = out.shape.clone();
|
||||
|
|
|
@ -4,7 +4,7 @@ use crate::{
|
|||
fusion::{
|
||||
kernel::{FusionKernel, OutputInfo, Priority, SelectedKernel},
|
||||
source::GpuKernelSource,
|
||||
WgpuFusionHandle,
|
||||
JitFusionHandle,
|
||||
},
|
||||
kernel::elemwise_workgroup,
|
||||
Runtime,
|
||||
|
@ -23,7 +23,7 @@ pub struct VecElementWise<R: Runtime> {
|
|||
impl<R: Runtime> FusionKernel<R> for ScalarElementWise<R> {
|
||||
fn kernel(
|
||||
&self,
|
||||
handles_inputs: &[WgpuFusionHandle<R>],
|
||||
handles_inputs: &[JitFusionHandle<R>],
|
||||
inputs: &[&TensorDescription],
|
||||
outputs: &[&TensorDescription],
|
||||
) -> SelectedKernel {
|
||||
|
@ -32,7 +32,7 @@ impl<R: Runtime> FusionKernel<R> for ScalarElementWise<R> {
|
|||
|
||||
fn priority(
|
||||
&self,
|
||||
_handles_inputs: &[WgpuFusionHandle<R>],
|
||||
_handles_inputs: &[JitFusionHandle<R>],
|
||||
_inputs: &[&TensorDescription],
|
||||
_outputs: &[&TensorDescription],
|
||||
) -> Priority {
|
||||
|
@ -43,7 +43,7 @@ impl<R: Runtime> FusionKernel<R> for ScalarElementWise<R> {
|
|||
impl<R: Runtime> FusionKernel<R> for VecElementWise<R> {
|
||||
fn kernel(
|
||||
&self,
|
||||
handles_inputs: &[WgpuFusionHandle<R>],
|
||||
handles_inputs: &[JitFusionHandle<R>],
|
||||
inputs: &[&TensorDescription],
|
||||
outputs: &[&TensorDescription],
|
||||
) -> SelectedKernel {
|
||||
|
@ -52,11 +52,11 @@ impl<R: Runtime> FusionKernel<R> for VecElementWise<R> {
|
|||
|
||||
fn priority(
|
||||
&self,
|
||||
handles_inputs: &[WgpuFusionHandle<R>],
|
||||
handles_inputs: &[JitFusionHandle<R>],
|
||||
inputs: &[&TensorDescription],
|
||||
_outputs: &[&TensorDescription],
|
||||
) -> Priority {
|
||||
let is_unavailable_input = |handle: &WgpuFusionHandle<R>, desc: &TensorDescription| {
|
||||
let is_unavailable_input = |handle: &JitFusionHandle<R>, desc: &TensorDescription| {
|
||||
let rank = handle.strides.len();
|
||||
|
||||
// Last dimension strides should be 1, otherwise vecX won't be contiguous.
|
||||
|
@ -96,7 +96,7 @@ impl<R: Runtime> FusionKernel<R> for VecElementWise<R> {
|
|||
impl<R: Runtime> ElementWiseSource<R> {
|
||||
fn kernel(
|
||||
&self,
|
||||
handles_inputs: &[WgpuFusionHandle<R>],
|
||||
handles_inputs: &[JitFusionHandle<R>],
|
||||
inputs: &[&TensorDescription],
|
||||
outputs: &[&TensorDescription],
|
||||
) -> SelectedKernel {
|
||||
|
@ -110,7 +110,7 @@ impl<R: Runtime> ElementWiseSource<R> {
|
|||
|
||||
match inplace_available(&self.mappings, handles_inputs) {
|
||||
true => {
|
||||
let reference_tensor = inputs[self.mappings[0].position_input];
|
||||
let reference_tensor = inputs[self.mappings[0].pos_input];
|
||||
let num_elems = calculate_num_elems_dyn_rank(&reference_tensor.shape);
|
||||
let workgroup = elemwise_workgroup(num_elems / self.factor, workgroup_size);
|
||||
let kernel = Box::new(DynamicKernel::new(self.source_inplace.clone(), workgroup));
|
||||
|
@ -172,7 +172,7 @@ impl<R: Runtime> ElementWiseSource<R> {
|
|||
let mut inplace_output2input = vec![None; num_output];
|
||||
|
||||
for mapping in mappings.iter() {
|
||||
inplace_output2input[mapping.position_output] = Some(mapping.position_input);
|
||||
inplace_output2input[mapping.pos_output] = Some(mapping.pos_input);
|
||||
}
|
||||
|
||||
Self {
|
||||
|
@ -214,14 +214,14 @@ impl<R: Runtime> VecElementWise<R> {
|
|||
|
||||
fn inplace_available<R: Runtime>(
|
||||
mappings: &[InplaceMapping],
|
||||
handles_inputs: &[WgpuFusionHandle<R>],
|
||||
handles_inputs: &[JitFusionHandle<R>],
|
||||
) -> bool {
|
||||
if mappings.is_empty() {
|
||||
return false;
|
||||
}
|
||||
|
||||
for mapping in mappings.iter() {
|
||||
let handle = &handles_inputs[mapping.position_input];
|
||||
let handle = &handles_inputs[mapping.pos_input];
|
||||
|
||||
if !handle.handle.can_mut() {
|
||||
return false;
|
||||
|
|
|
@ -4,163 +4,57 @@ use super::{
|
|||
FusionElemWiseAutotuneKey,
|
||||
};
|
||||
use crate::{
|
||||
codegen::dialect::gpu::{Elem, Item, Operation, Vectorization, Visibility, WorkgroupSize},
|
||||
codegen::{ElemWiseKernelCodegen, InplaceMapping, Input, Output, ReadingStrategy},
|
||||
codegen::{
|
||||
dialect::gpu::{Vectorization, WorkgroupSize},
|
||||
Compilation, CompilationInfo, CompilationSettings,
|
||||
},
|
||||
compute::JitAutotuneKey,
|
||||
fusion::{kernel::FusionKernelSet, source::GpuKernelSource},
|
||||
fusion::{kernel::FusionKernelSet, source::GpuKernelSource, tracing::Trace},
|
||||
JitBackend, Runtime,
|
||||
};
|
||||
use burn_common::id::IdGenerator;
|
||||
use burn_compute::client::ComputeClient;
|
||||
use burn_fusion::{stream::Context, TensorDescription};
|
||||
use burn_fusion::stream::Context;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(new)]
|
||||
pub struct ElementWise<R: Runtime, Phase = ExecutionPhase<R>> {
|
||||
pub(super) inputs: Vec<(TensorDescription, Elem)>,
|
||||
pub(super) outputs: Vec<(TensorDescription, Elem)>,
|
||||
pub(super) locals: Vec<u16>,
|
||||
pub(super) scalars: Scalars,
|
||||
pub(super) operators: Vec<Operation>,
|
||||
pub(super) trace: Trace,
|
||||
pub(super) num_operations: usize,
|
||||
pub(super) device: R::Device,
|
||||
pub(super) phase: Phase,
|
||||
}
|
||||
|
||||
#[derive(new, Clone, Serialize, Deserialize)]
|
||||
pub struct Scalars {
|
||||
pub(super) num_f32: usize,
|
||||
pub(super) num_u32: usize,
|
||||
pub(super) num_i32: usize,
|
||||
}
|
||||
|
||||
/// Phase where the kernel should be compiled.
|
||||
pub struct CompilationPhase;
|
||||
|
||||
/// Phase where the kernel should be executed.
|
||||
#[derive(new)]
|
||||
pub struct ExecutionPhase<R: Runtime> {
|
||||
/// Kernel set with default workgroup size.
|
||||
pub(super) kernel_set_1: FusionKernelSet<R>,
|
||||
/// Kernel set with custom workgroup size.
|
||||
pub(super) kernel_set_2: FusionKernelSet<R>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[derive(new, Serialize, Deserialize)]
|
||||
pub struct ElementWiseState {
|
||||
inputs: Vec<(TensorDescription, Elem)>,
|
||||
outputs: Vec<(TensorDescription, Elem)>,
|
||||
scalars: Scalars,
|
||||
operators: Vec<Operation>,
|
||||
locals: Vec<u16>,
|
||||
trace: Trace,
|
||||
num_operations: usize,
|
||||
}
|
||||
|
||||
impl<R: Runtime> ElementWise<R, CompilationPhase> {
|
||||
pub(crate) fn compile(self) -> ElementWise<R, ExecutionPhase<R>> {
|
||||
let mut inputs = self
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|(_tensor, elem)| Input::Array {
|
||||
item: Item::Scalar(*elem),
|
||||
visibility: Visibility::Read,
|
||||
strategy: ReadingStrategy::OutputLayout,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let info = self.trace.compiling();
|
||||
|
||||
let outputs = self
|
||||
.outputs
|
||||
.iter()
|
||||
.zip(self.locals.iter())
|
||||
.map(|((_tensor, elem), local)| Output::Array {
|
||||
item: Item::Scalar(*elem),
|
||||
local: *local,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if self.scalars.num_f32 > 0 {
|
||||
inputs.push(Input::Scalar {
|
||||
elem: Elem::Float,
|
||||
size: self.scalars.num_f32,
|
||||
})
|
||||
}
|
||||
|
||||
if self.scalars.num_u32 > 0 {
|
||||
inputs.push(Input::Scalar {
|
||||
elem: Elem::UInt,
|
||||
size: self.scalars.num_u32,
|
||||
})
|
||||
}
|
||||
|
||||
if self.scalars.num_i32 > 0 {
|
||||
inputs.push(Input::Scalar {
|
||||
elem: Elem::Int,
|
||||
size: self.scalars.num_i32,
|
||||
})
|
||||
}
|
||||
|
||||
let mut potential_inplace = self
|
||||
.inputs
|
||||
.iter()
|
||||
.zip(inputs.iter())
|
||||
.enumerate()
|
||||
.filter(|(_pos, ((desc, _elem), _input))| match desc.status {
|
||||
burn_fusion::TensorStatus::ReadOnly => false,
|
||||
burn_fusion::TensorStatus::ReadWrite => true,
|
||||
burn_fusion::TensorStatus::NotInit => false,
|
||||
})
|
||||
.map(|(pos, ((desc, elem), input))| (pos, desc, elem, input))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mappings = self
|
||||
.outputs
|
||||
.iter()
|
||||
.zip(outputs.iter())
|
||||
.enumerate()
|
||||
.filter_map(|(pos, ((desc, elem), _output))| {
|
||||
if potential_inplace.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut chosen = None;
|
||||
for (index, (_pos_input, desc_input, elem_input, _input)) in
|
||||
potential_inplace.iter().enumerate()
|
||||
{
|
||||
if chosen.is_some() {
|
||||
break;
|
||||
}
|
||||
if desc.shape == desc_input.shape && *elem_input == elem {
|
||||
chosen = Some(index);
|
||||
}
|
||||
}
|
||||
|
||||
match chosen {
|
||||
Some(index) => {
|
||||
let input = potential_inplace.remove(index);
|
||||
Some(InplaceMapping::new(input.0, pos))
|
||||
}
|
||||
None => None,
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let kernel_set_1 = build_kernel_set::<R>(
|
||||
&inputs,
|
||||
&outputs,
|
||||
&self.operators,
|
||||
&mappings,
|
||||
WorkgroupSize::default(),
|
||||
);
|
||||
let kernel_set_2 = build_kernel_set::<R>(
|
||||
&inputs,
|
||||
&outputs,
|
||||
&self.operators,
|
||||
&mappings,
|
||||
WorkgroupSize::new(16, 16, 1),
|
||||
);
|
||||
let kernel_set_1 = build_kernel_set::<R>(&info, WorkgroupSize::default());
|
||||
let kernel_set_2 = build_kernel_set::<R>(&info, WorkgroupSize::new(16, 16, 1));
|
||||
|
||||
ElementWise {
|
||||
inputs: self.inputs,
|
||||
outputs: self.outputs,
|
||||
scalars: self.scalars,
|
||||
trace: self.trace,
|
||||
device: self.device,
|
||||
operators: self.operators,
|
||||
locals: self.locals,
|
||||
phase: ExecutionPhase::new(kernel_set_1, kernel_set_2),
|
||||
num_operations: self.num_operations,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -170,7 +64,7 @@ impl<R: Runtime> ElementWise<R, ExecutionPhase<R>> {
|
|||
let client = R::client(&self.device);
|
||||
|
||||
let key = JitAutotuneKey::FusionElemWise(FusionElemWiseAutotuneKey::new(
|
||||
self.operators.len(),
|
||||
self.num_operations,
|
||||
self.autotune_shape(context),
|
||||
));
|
||||
|
||||
|
@ -187,22 +81,14 @@ impl<R: Runtime> ElementWise<R, ExecutionPhase<R>> {
|
|||
client: ComputeClient<R::Server, R::Channel>,
|
||||
fastest_set_index: usize,
|
||||
) {
|
||||
let info = self.trace.running();
|
||||
let kernel_set = match fastest_set_index {
|
||||
0 => &self.phase.kernel_set_1,
|
||||
1 => &self.phase.kernel_set_2,
|
||||
_ => panic!("Should be 0 or 1, got {fastest_set_index}"),
|
||||
};
|
||||
|
||||
let kernel = kernel_set.select(
|
||||
&self.inputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
|
||||
&self.outputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
|
||||
self.scalars.num_f32,
|
||||
self.scalars.num_i32,
|
||||
context,
|
||||
self.device.clone(),
|
||||
client,
|
||||
true,
|
||||
);
|
||||
let kernel = kernel_set.select(&info, context, self.device.clone(), client, true);
|
||||
|
||||
kernel.execute();
|
||||
}
|
||||
|
@ -213,31 +99,24 @@ impl<R: Runtime> ElementWise<R, ExecutionPhase<R>> {
|
|||
client: ComputeClient<R::Server, R::Channel>,
|
||||
key: JitAutotuneKey,
|
||||
) {
|
||||
let info = self.trace.running();
|
||||
|
||||
let kernel_1 = self.phase.kernel_set_1.select(
|
||||
&self.inputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
|
||||
&self.outputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
|
||||
self.scalars.num_f32,
|
||||
self.scalars.num_i32,
|
||||
&info,
|
||||
context,
|
||||
self.device.clone(),
|
||||
client.clone(),
|
||||
false, // Should not mutate the context.
|
||||
);
|
||||
let kernel_2 = self.phase.kernel_set_1.select(
|
||||
&self.inputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
|
||||
&self.outputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
|
||||
self.scalars.num_f32,
|
||||
self.scalars.num_i32,
|
||||
&info,
|
||||
context,
|
||||
self.device.clone(),
|
||||
client.clone(),
|
||||
false, // Should not mutate the context.
|
||||
);
|
||||
let kernel_default = self.phase.kernel_set_1.select(
|
||||
&self.inputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
|
||||
&self.outputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
|
||||
self.scalars.num_f32,
|
||||
self.scalars.num_i32,
|
||||
&info,
|
||||
context,
|
||||
self.device.clone(),
|
||||
client.clone(),
|
||||
|
@ -253,7 +132,7 @@ impl<R: Runtime> ElementWise<R, ExecutionPhase<R>> {
|
|||
}
|
||||
|
||||
pub(crate) fn len(&self) -> usize {
|
||||
self.operators.len()
|
||||
self.num_operations
|
||||
}
|
||||
|
||||
/// The first output is chosen when possible, otherwise the first input is chosen.
|
||||
|
@ -261,13 +140,15 @@ impl<R: Runtime> ElementWise<R, ExecutionPhase<R>> {
|
|||
&self,
|
||||
context: &mut Context<'a, JitBackend<R>>,
|
||||
) -> &'a [usize] {
|
||||
if let Some(tensor) = self.outputs.first() {
|
||||
let tensor = context.tensors.get(&tensor.0.id).unwrap();
|
||||
let info = self.trace.running();
|
||||
|
||||
if let Some(tensor) = info.outputs.first() {
|
||||
let tensor = context.tensors.get(&tensor.id).unwrap();
|
||||
return &tensor.shape;
|
||||
}
|
||||
|
||||
if let Some(tensor) = self.inputs.first() {
|
||||
let tensor = context.tensors.get(&tensor.0.id).unwrap();
|
||||
if let Some(tensor) = info.inputs.first() {
|
||||
let tensor = context.tensors.get(&tensor.id).unwrap();
|
||||
return &tensor.shape;
|
||||
}
|
||||
|
||||
|
@ -281,109 +162,86 @@ impl<R: Runtime> ElementWise<R, ExecutionPhase<R>> {
|
|||
// It is still unclear if the deserialization would be that much faster than
|
||||
// simply recompiling it.
|
||||
ElementWise {
|
||||
inputs: state.inputs,
|
||||
outputs: state.outputs,
|
||||
scalars: state.scalars,
|
||||
trace: state.trace,
|
||||
device: device.clone(),
|
||||
locals: state.locals,
|
||||
operators: state.operators,
|
||||
phase: CompilationPhase,
|
||||
num_operations: state.num_operations,
|
||||
}
|
||||
.compile()
|
||||
}
|
||||
|
||||
pub(crate) fn to_state(&self) -> ElementWiseState {
|
||||
ElementWiseState {
|
||||
inputs: self.inputs.clone(),
|
||||
outputs: self.outputs.clone(),
|
||||
scalars: self.scalars.clone(),
|
||||
operators: self.operators.clone(),
|
||||
locals: self.locals.clone(),
|
||||
trace: self.trace.clone(),
|
||||
num_operations: self.num_operations,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn build_kernel_set<R: Runtime>(
|
||||
inputs: &[Input],
|
||||
outputs: &[Output],
|
||||
operators: &[Operation],
|
||||
mappings: &[InplaceMapping],
|
||||
info: &CompilationInfo,
|
||||
workgroup_size: WorkgroupSize,
|
||||
) -> FusionKernelSet<R> {
|
||||
let scalar = ScalarElementWise::<R>::new(
|
||||
GpuKernelSource::new(
|
||||
IdGenerator::generate(),
|
||||
ElemWiseKernelCodegen::new()
|
||||
.inputs(inputs)
|
||||
.body(operators)
|
||||
.outputs(outputs)
|
||||
.workgroup_size(workgroup_size)
|
||||
.compile(),
|
||||
Compilation::new(info.clone())
|
||||
.compile(CompilationSettings::default().workgroup_size(workgroup_size)),
|
||||
),
|
||||
GpuKernelSource::new(
|
||||
IdGenerator::generate(),
|
||||
ElemWiseKernelCodegen::new()
|
||||
.inplace(mappings)
|
||||
.inputs(inputs)
|
||||
.body(operators)
|
||||
.outputs(outputs)
|
||||
.workgroup_size(workgroup_size)
|
||||
.compile(),
|
||||
Compilation::new(info.clone()).compile(
|
||||
CompilationSettings::default()
|
||||
.inplace(true)
|
||||
.workgroup_size(workgroup_size),
|
||||
),
|
||||
mappings.to_vec(),
|
||||
outputs.len(),
|
||||
),
|
||||
info.mappings.to_vec(),
|
||||
info.outputs.len(),
|
||||
);
|
||||
|
||||
let vec2 = VecElementWise::<R>::new(
|
||||
GpuKernelSource::new(
|
||||
IdGenerator::generate(),
|
||||
ElemWiseKernelCodegen::new()
|
||||
Compilation::new(info.clone()).compile(
|
||||
CompilationSettings::default()
|
||||
.vectorize(Vectorization::Vec2)
|
||||
.inputs(inputs)
|
||||
.body(operators)
|
||||
.outputs(outputs)
|
||||
.workgroup_size(workgroup_size)
|
||||
.compile(),
|
||||
.workgroup_size(workgroup_size),
|
||||
),
|
||||
),
|
||||
GpuKernelSource::new(
|
||||
IdGenerator::generate(),
|
||||
ElemWiseKernelCodegen::new()
|
||||
Compilation::new(info.clone()).compile(
|
||||
CompilationSettings::default()
|
||||
.inplace(true)
|
||||
.vectorize(Vectorization::Vec2)
|
||||
.inplace(mappings)
|
||||
.inputs(inputs)
|
||||
.body(operators)
|
||||
.outputs(outputs)
|
||||
.workgroup_size(workgroup_size)
|
||||
.compile(),
|
||||
.workgroup_size(workgroup_size),
|
||||
),
|
||||
mappings.to_vec(),
|
||||
outputs.len(),
|
||||
),
|
||||
info.mappings.to_vec(),
|
||||
info.outputs.len(),
|
||||
2,
|
||||
);
|
||||
let vec4 = VecElementWise::<R>::new(
|
||||
GpuKernelSource::new(
|
||||
IdGenerator::generate(),
|
||||
ElemWiseKernelCodegen::new()
|
||||
Compilation::new(info.clone()).compile(
|
||||
CompilationSettings::default()
|
||||
.vectorize(Vectorization::Vec4)
|
||||
.inputs(inputs)
|
||||
.body(operators)
|
||||
.outputs(outputs)
|
||||
.workgroup_size(workgroup_size)
|
||||
.compile(),
|
||||
.workgroup_size(workgroup_size),
|
||||
),
|
||||
),
|
||||
GpuKernelSource::new(
|
||||
IdGenerator::generate(),
|
||||
ElemWiseKernelCodegen::new()
|
||||
Compilation::new(info.clone()).compile(
|
||||
CompilationSettings::default()
|
||||
.inplace(true)
|
||||
.vectorize(Vectorization::Vec4)
|
||||
.inplace(mappings)
|
||||
.inputs(inputs)
|
||||
.body(operators)
|
||||
.outputs(outputs)
|
||||
.workgroup_size(workgroup_size)
|
||||
.compile(),
|
||||
.workgroup_size(workgroup_size),
|
||||
),
|
||||
mappings.to_vec(),
|
||||
outputs.len(),
|
||||
),
|
||||
info.mappings.to_vec(),
|
||||
info.outputs.len(),
|
||||
4,
|
||||
);
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::compute::Kernel;
|
||||
use crate::fusion::strides_dyn_rank;
|
||||
use crate::fusion::WgpuFusionHandle;
|
||||
use crate::fusion::JitFusionHandle;
|
||||
use crate::JitBackend;
|
||||
use crate::Runtime;
|
||||
use burn_compute::client::ComputeClient;
|
||||
|
@ -11,6 +11,8 @@ use burn_fusion::{TensorDescription, TensorStatus};
|
|||
use burn_tensor::Device;
|
||||
use std::sync::Arc;
|
||||
|
||||
use super::tracing::ExecutionInfo;
|
||||
|
||||
/// Many kernels can be used for the same set of tensor operations fused into one.
|
||||
///
|
||||
/// This type makes it easy to group those potential kernels and execute the best one depending on the context.
|
||||
|
@ -104,14 +106,14 @@ pub trait FusionKernel<R: Runtime>: Send + Sync {
|
|||
/// Returns the priority of this kernel based on the input and output information.
|
||||
fn priority(
|
||||
&self,
|
||||
handles_inputs: &[WgpuFusionHandle<R>],
|
||||
handles_inputs: &[JitFusionHandle<R>],
|
||||
inputs: &[&TensorDescription],
|
||||
outputs: &[&TensorDescription],
|
||||
) -> Priority;
|
||||
/// Returns a [selected kernel](SelectedKernel) that can be executed by the compute server.
|
||||
fn kernel(
|
||||
&self,
|
||||
handles_inputs: &[WgpuFusionHandle<R>],
|
||||
handles_inputs: &[JitFusionHandle<R>],
|
||||
inputs: &[&TensorDescription],
|
||||
outputs: &[&TensorDescription],
|
||||
) -> SelectedKernel;
|
||||
|
@ -119,20 +121,21 @@ pub trait FusionKernel<R: Runtime>: Send + Sync {
|
|||
|
||||
impl<R: Runtime> FusionKernelSet<R> {
|
||||
/// Select the best kernel based on the given information.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn select(
|
||||
&self,
|
||||
inputs: &[&TensorDescription],
|
||||
outputs: &[&TensorDescription],
|
||||
scalars_f32: usize,
|
||||
scalars_i32: usize,
|
||||
running_info: &ExecutionInfo<'_>,
|
||||
context: &mut Context<'_, JitBackend<R>>,
|
||||
device: Device<JitBackend<R>>,
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
stateful: bool,
|
||||
) -> ExecutableKernel<R> {
|
||||
let (handles_input, inputs_description_updated, outputs_description_updated) =
|
||||
process_inputs_outputs(inputs, outputs, context, stateful);
|
||||
process_inputs_outputs(
|
||||
&running_info.inputs,
|
||||
&running_info.outputs,
|
||||
context,
|
||||
stateful,
|
||||
);
|
||||
|
||||
let selected = self.select_kernel(
|
||||
&handles_input,
|
||||
|
@ -140,19 +143,27 @@ impl<R: Runtime> FusionKernelSet<R> {
|
|||
&outputs_description_updated,
|
||||
);
|
||||
|
||||
let rank_input = inputs.first().map(|desc| desc.shape.len()).unwrap_or(1);
|
||||
let rank_output = outputs.first().map(|desc| desc.shape.len()).unwrap_or(1);
|
||||
let rank_input = running_info
|
||||
.inputs
|
||||
.first()
|
||||
.map(|desc| desc.shape.len())
|
||||
.unwrap_or(1);
|
||||
let rank_output = running_info
|
||||
.outputs
|
||||
.first()
|
||||
.map(|desc| desc.shape.len())
|
||||
.unwrap_or(1);
|
||||
let rank = usize::max(rank_input, rank_output);
|
||||
|
||||
let num_tensors = inputs.len() + outputs.len();
|
||||
let num_tensors = running_info.inputs.len() + running_info.outputs.len();
|
||||
// The buffer starts with the rank, then each tensor shape and stride.
|
||||
let info_size = (num_tensors * rank * 2) + 1;
|
||||
|
||||
let mut num_handles = num_tensors + 1;
|
||||
if scalars_f32 > 0 {
|
||||
if running_info.scalars.num_float > 0 {
|
||||
num_handles += 1;
|
||||
}
|
||||
if scalars_i32 > 0 {
|
||||
if running_info.scalars.num_int > 0 {
|
||||
num_handles += 1;
|
||||
}
|
||||
|
||||
|
@ -175,7 +186,7 @@ impl<R: Runtime> FusionKernelSet<R> {
|
|||
// Use the input inplace for this output.
|
||||
OutputInfo::Inplace { input_index } => {
|
||||
let handle = handles.get(*input_index).unwrap().clone();
|
||||
let handle_fusion = WgpuFusionHandle {
|
||||
let handle_fusion = JitFusionHandle {
|
||||
client: client.clone(),
|
||||
device: device.clone(),
|
||||
strides: strides_dyn_rank(&tensor.shape),
|
||||
|
@ -185,7 +196,7 @@ impl<R: Runtime> FusionKernelSet<R> {
|
|||
}
|
||||
// Create a new buffer for this output.
|
||||
OutputInfo::Array { size } => {
|
||||
let handle_fusion = WgpuFusionHandle {
|
||||
let handle_fusion = JitFusionHandle {
|
||||
client: client.clone(),
|
||||
device: device.clone(),
|
||||
strides: strides_dyn_rank(&tensor.shape),
|
||||
|
@ -203,13 +214,16 @@ impl<R: Runtime> FusionKernelSet<R> {
|
|||
handles.push(client.create(bytemuck::cast_slice(&info)));
|
||||
|
||||
// Finally we finish with the named bindings.
|
||||
if scalars_f32 > 0 {
|
||||
handles
|
||||
.push(client.create(bytemuck::cast_slice(&context.scalar_floats[0..scalars_f32])));
|
||||
if running_info.scalars.num_float > 0 {
|
||||
handles.push(client.create(bytemuck::cast_slice(
|
||||
&context.scalar_floats[0..running_info.scalars.num_float],
|
||||
)));
|
||||
}
|
||||
|
||||
if scalars_i32 > 0 {
|
||||
handles.push(client.create(bytemuck::cast_slice(&context.scalar_ints[0..scalars_i32])));
|
||||
if running_info.scalars.num_int > 0 {
|
||||
handles.push(client.create(bytemuck::cast_slice(
|
||||
&context.scalar_ints[0..running_info.scalars.num_int],
|
||||
)));
|
||||
}
|
||||
|
||||
// We have to register the output handles to the context.
|
||||
|
@ -222,7 +236,7 @@ impl<R: Runtime> FusionKernelSet<R> {
|
|||
|
||||
fn select_kernel(
|
||||
&self,
|
||||
handles_input: &[WgpuFusionHandle<R>],
|
||||
handles_input: &[JitFusionHandle<R>],
|
||||
inputs: &[&TensorDescription],
|
||||
outputs: &[&TensorDescription],
|
||||
) -> SelectedKernel {
|
||||
|
@ -249,7 +263,7 @@ impl<R: Runtime> FusionKernelSet<R> {
|
|||
fn register_info_tensor<R: Runtime>(
|
||||
info: &mut Vec<u32>,
|
||||
tensor: &TensorDescription,
|
||||
handle: &WgpuFusionHandle<R>,
|
||||
handle: &JitFusionHandle<R>,
|
||||
) {
|
||||
if info.is_empty() {
|
||||
info.push(handle.strides.len() as u32);
|
||||
|
@ -269,7 +283,7 @@ fn process_inputs_outputs<'a, R: Runtime>(
|
|||
context: &'a mut Context<'_, JitBackend<R>>,
|
||||
stateful: bool,
|
||||
) -> (
|
||||
Vec<WgpuFusionHandle<R>>,
|
||||
Vec<JitFusionHandle<R>>,
|
||||
Vec<&'a TensorDescription>,
|
||||
Vec<&'a TensorDescription>,
|
||||
) {
|
||||
|
|
|
@ -3,6 +3,7 @@ mod elemwise;
|
|||
|
||||
pub(crate) mod kernel;
|
||||
pub(crate) mod source;
|
||||
pub(crate) mod tracing;
|
||||
|
||||
pub use base::*;
|
||||
pub(crate) use elemwise::*;
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
#[derive(Default, Clone, Serialize, Deserialize)]
|
||||
pub struct Scalars {
|
||||
pub(crate) num_float: usize,
|
||||
pub(crate) num_int: usize,
|
||||
pub(crate) num_uint: usize,
|
||||
pub(crate) num_bool: usize,
|
||||
}
|
|
@ -0,0 +1,352 @@
|
|||
use super::{trace::Trace, Scalars};
|
||||
use crate::codegen::dialect::gpu::{self, Operation, Variable};
|
||||
use burn_fusion::{TensorDescription, TensorId};
|
||||
use burn_tensor::Element;
|
||||
use hashbrown::HashMap;
|
||||
|
||||
/// Type facilitating building a [trace](Trace) by doing most of the conversions between the
|
||||
/// operations provided in [burn_fusion] and the [gpu dialect](gpu).
|
||||
#[derive(Clone)]
|
||||
pub struct TraceBuilder {
|
||||
// Input tensor descriptions with the variables created after reading from global memory.
|
||||
inputs: Vec<(TensorDescription, Variable)>,
|
||||
// Each output tensor id with the output variable index created by the operation.
|
||||
output_to_local: HashMap<TensorId, u16>,
|
||||
tensors: HashMap<TensorId, (TensorDescription, gpu::Elem)>,
|
||||
scalars: Scalars,
|
||||
scope: gpu::Scope,
|
||||
}
|
||||
|
||||
impl TraceBuilder {
|
||||
/// Create a new builder.
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
inputs: Vec::new(),
|
||||
output_to_local: HashMap::new(),
|
||||
tensors: HashMap::new(),
|
||||
scalars: Scalars::default(),
|
||||
scope: gpu::Scope::root(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a [gpu operation](gpu::Operation).
|
||||
pub fn register_operation<T: Into<gpu::Operation>>(&mut self, value: T) {
|
||||
self.scope.register(value)
|
||||
}
|
||||
|
||||
/// Create a variable from an input [tensor description](TensorDescription).
|
||||
pub fn input(&mut self, tensor: &TensorDescription, elem: gpu::Elem) -> gpu::Variable {
|
||||
let already_exists = self.tensors.contains_key(&tensor.id);
|
||||
|
||||
let variable = match already_exists {
|
||||
false => {
|
||||
// New input
|
||||
let index = self.inputs.len() as u16;
|
||||
let item = gpu::Item::Scalar(elem);
|
||||
|
||||
let local = self.scope.read_array(index, item);
|
||||
self.inputs.push((tensor.clone(), local));
|
||||
local
|
||||
}
|
||||
true => match self.output_to_local.get(&tensor.id) {
|
||||
// Is a local variable.
|
||||
Some(local_index) => {
|
||||
gpu::Variable::Local(*local_index, gpu::Item::Scalar(elem), self.scope.depth)
|
||||
}
|
||||
// Isn't an operation output variable, so must be an existing input.
|
||||
None => self
|
||||
.inputs
|
||||
.iter()
|
||||
.find(|(input, _local)| input.id == tensor.id)
|
||||
.map(|(_, local)| *local)
|
||||
.unwrap(),
|
||||
},
|
||||
};
|
||||
|
||||
// Update the tensor description with the new version.
|
||||
self.tensors.insert(tensor.id, (tensor.clone(), elem));
|
||||
|
||||
variable
|
||||
}
|
||||
|
||||
/// Create a variable from an output [tensor description](TensorDescription).
|
||||
pub fn output(&mut self, tensor: &TensorDescription, elem: gpu::Elem) -> gpu::Variable {
|
||||
// Update the tensor description to the new version.
|
||||
self.tensors.insert(tensor.id, (tensor.clone(), elem));
|
||||
|
||||
// Output already registered as a local variable.
|
||||
if let Some(index) = self.output_to_local.get(&tensor.id) {
|
||||
return gpu::Variable::Local(*index, gpu::Item::Scalar(elem), self.scope.depth);
|
||||
}
|
||||
|
||||
let variable = self.scope.create_local(gpu::Item::Scalar(elem));
|
||||
let local_index = variable.index().unwrap();
|
||||
self.output_to_local.insert(tensor.id, local_index);
|
||||
variable
|
||||
}
|
||||
|
||||
/// Create a variable from an input [scalar](Element).
|
||||
pub fn scalar<E: Element>(&mut self, _value: &E, elem_type: gpu::Elem) -> gpu::Variable {
|
||||
match elem_type {
|
||||
gpu::Elem::Float => {
|
||||
let var = self
|
||||
.scope
|
||||
.read_scalar(self.scalars.num_float as u16, elem_type);
|
||||
self.scalars.num_float += 1;
|
||||
var
|
||||
}
|
||||
gpu::Elem::Int => {
|
||||
let var = self
|
||||
.scope
|
||||
.read_scalar(self.scalars.num_int as u16, elem_type);
|
||||
self.scalars.num_int += 1;
|
||||
var
|
||||
}
|
||||
gpu::Elem::UInt => {
|
||||
let var = self
|
||||
.scope
|
||||
.read_scalar(self.scalars.num_uint as u16, elem_type);
|
||||
self.scalars.num_uint += 1;
|
||||
var
|
||||
}
|
||||
gpu::Elem::Bool => {
|
||||
let var = self
|
||||
.scope
|
||||
.read_scalar(self.scalars.num_bool as u16, elem_type);
|
||||
self.scalars.num_bool += 1;
|
||||
var
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Build the [trace](Trace).
|
||||
pub fn build(self) -> Trace {
|
||||
let inputs = self.input_descriptions();
|
||||
let outputs = self.output_descriptions();
|
||||
let locals = outputs
|
||||
.iter()
|
||||
.map(|out| *self.output_to_local.get(&out.0.id).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Trace::new(inputs, outputs, locals, self.scalars, self.scope)
|
||||
}
|
||||
|
||||
fn input_descriptions(&self) -> Vec<(TensorDescription, gpu::Elem)> {
|
||||
self.inputs
|
||||
.iter()
|
||||
.map(|(input, _local)| {
|
||||
let updated_tensor = self.tensors.get(&input.id).unwrap();
|
||||
updated_tensor.clone()
|
||||
})
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
|
||||
fn output_descriptions(&self) -> Vec<(TensorDescription, gpu::Elem)> {
|
||||
let mut outputs = Vec::new();
|
||||
let mut local_tensor_ids_input = Vec::new();
|
||||
let mut local_tensor_ids_output = Vec::new();
|
||||
|
||||
// Mark a variable to the provided list of tensor ids using the variable list.
|
||||
//
|
||||
// Only local variables can become outputs.
|
||||
let mark = |var: &gpu::Variable, list: &mut Vec<TensorId>| {
|
||||
if let gpu::Variable::Local(index, _, _) = var {
|
||||
if let Some((id, _)) = self
|
||||
.output_to_local
|
||||
.iter()
|
||||
.find(|(_id, position)| *position == index)
|
||||
{
|
||||
if !list.contains(id) {
|
||||
list.push(*id);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
let mark_binary =
|
||||
|op: &gpu::BinaryOperator, inputs: &mut Vec<TensorId>, outputs: &mut Vec<TensorId>| {
|
||||
mark(&op.lhs, inputs);
|
||||
mark(&op.rhs, inputs);
|
||||
mark(&op.out, outputs);
|
||||
};
|
||||
let mark_unary =
|
||||
|op: &gpu::UnaryOperator, inputs: &mut Vec<TensorId>, outputs: &mut Vec<TensorId>| {
|
||||
mark(&op.input, inputs);
|
||||
mark(&op.out, outputs);
|
||||
};
|
||||
|
||||
// For all operators, mark their local tensor id in the proper set.
|
||||
for op in self.scope.operations.iter() {
|
||||
match op {
|
||||
Operation::Operator(op) => {
|
||||
match op {
|
||||
gpu::Operator::AssignGlobal(_) => {
|
||||
// Nothing to do here.
|
||||
}
|
||||
gpu::Operator::AssignLocal(op) => {
|
||||
mark(&op.out, &mut local_tensor_ids_output);
|
||||
}
|
||||
gpu::Operator::Add(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Index(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Sub(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Mul(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Div(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Exp(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Abs(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Erf(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Log(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Log1p(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Cos(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Sin(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Tanh(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Clamp(op) => {
|
||||
mark(&op.input, &mut local_tensor_ids_input);
|
||||
mark(&op.out, &mut local_tensor_ids_output);
|
||||
}
|
||||
gpu::Operator::Powf(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Recip(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Lower(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Greater(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::LowerEqual(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::GreaterEqual(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Equal(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::ConditionalAssign(op) => {
|
||||
mark(&op.cond, &mut local_tensor_ids_input);
|
||||
mark(&op.lhs, &mut local_tensor_ids_input);
|
||||
mark(&op.rhs, &mut local_tensor_ids_input);
|
||||
mark(&op.out, &mut local_tensor_ids_output);
|
||||
}
|
||||
gpu::Operator::Sqrt(op) => mark_unary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
gpu::Operator::Modulo(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
&mut local_tensor_ids_output,
|
||||
),
|
||||
}
|
||||
}
|
||||
Operation::Algorithm(algo) => {
|
||||
match algo {
|
||||
gpu::Algorithm::ReadGlobalWithLayout(_) => {
|
||||
// Nothing to do here.
|
||||
}
|
||||
gpu::Algorithm::ReadGlobal(_) => {
|
||||
// Nothing to do here.
|
||||
}
|
||||
}
|
||||
}
|
||||
Operation::Metadata(_) => {
|
||||
// Nothing to do, should never impact read-write access to bindings.
|
||||
}
|
||||
Operation::Loop(_) => {
|
||||
// Nothing to do, should never impact read-write access to bindings.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// All output tensors that are never read by a following operation should be written to
|
||||
// since they are essentially the "logical" output of the shader.
|
||||
for out in local_tensor_ids_output {
|
||||
let is_read = local_tensor_ids_input.contains(&out);
|
||||
|
||||
if !is_read {
|
||||
outputs.push(self.tensors.get(&out).unwrap().clone());
|
||||
}
|
||||
}
|
||||
|
||||
// All tensors where their latest description is read only should be written to since they
|
||||
// are going to be used after the fused kernel by other operations.
|
||||
for entry in self.tensors.values() {
|
||||
let (tensor, _) = &entry;
|
||||
if let burn_fusion::TensorStatus::ReadOnly = tensor.status {
|
||||
if self.output_to_local.contains_key(&tensor.id) {
|
||||
outputs.push(entry.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
outputs
|
||||
}
|
||||
}
|
|
@ -0,0 +1,7 @@
|
|||
mod base;
|
||||
mod builder;
|
||||
mod trace;
|
||||
|
||||
pub use base::*;
|
||||
pub use builder::*;
|
||||
pub use trace::*;
|
|
@ -0,0 +1,133 @@
|
|||
use super::Scalars;
|
||||
use crate::codegen::{dialect::gpu, CompilationInfo, InplaceMapping, InputInfo, OutputInfo};
|
||||
use burn_fusion::TensorDescription;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A trace encaptulates all information necessary to perform the compilation and execution of
|
||||
/// captured [tensor operations](burn_fusion::stream::OperationDescription).
|
||||
///
|
||||
/// A trace should be built using a [builder](super::TraceBuilder).
|
||||
#[derive(new, Clone, Serialize, Deserialize)]
|
||||
pub struct Trace {
|
||||
inputs: Vec<(TensorDescription, gpu::Elem)>,
|
||||
outputs: Vec<(TensorDescription, gpu::Elem)>,
|
||||
locals: Vec<u16>,
|
||||
scalars: Scalars,
|
||||
scope: gpu::Scope,
|
||||
}
|
||||
|
||||
/// Information necessary to execute a kernel.
|
||||
pub struct ExecutionInfo<'a> {
|
||||
/// Tensor inputs.
|
||||
pub inputs: Vec<&'a TensorDescription>,
|
||||
/// Tensor outputs.
|
||||
pub outputs: Vec<&'a TensorDescription>,
|
||||
/// Scalar inputs.
|
||||
pub scalars: &'a Scalars,
|
||||
}
|
||||
|
||||
impl Trace {
|
||||
/// Collect information related to running the trace.
|
||||
pub fn running(&self) -> ExecutionInfo<'_> {
|
||||
ExecutionInfo {
|
||||
inputs: self.inputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
|
||||
outputs: self.outputs.iter().map(|a| &a.0).collect::<Vec<_>>(),
|
||||
scalars: &self.scalars,
|
||||
}
|
||||
}
|
||||
|
||||
/// Collect information related to compiling the trace.
|
||||
pub fn compiling(&self) -> CompilationInfo {
|
||||
let mut inputs = self
|
||||
.inputs
|
||||
.iter()
|
||||
.map(|(_tensor, elem)| InputInfo::Array {
|
||||
item: gpu::Item::Scalar(*elem),
|
||||
visibility: gpu::Visibility::Read,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let outputs = self
|
||||
.outputs
|
||||
.iter()
|
||||
.zip(self.locals.iter())
|
||||
.map(|((_tensor, elem), local)| OutputInfo::Array {
|
||||
item: gpu::Item::Scalar(*elem),
|
||||
local: *local,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if self.scalars.num_float > 0 {
|
||||
inputs.push(InputInfo::Scalar {
|
||||
elem: gpu::Elem::Float,
|
||||
size: self.scalars.num_float,
|
||||
})
|
||||
}
|
||||
|
||||
if self.scalars.num_uint > 0 {
|
||||
inputs.push(InputInfo::Scalar {
|
||||
elem: gpu::Elem::UInt,
|
||||
size: self.scalars.num_uint,
|
||||
})
|
||||
}
|
||||
|
||||
if self.scalars.num_int > 0 {
|
||||
inputs.push(InputInfo::Scalar {
|
||||
elem: gpu::Elem::Int,
|
||||
size: self.scalars.num_int,
|
||||
})
|
||||
}
|
||||
|
||||
let mut potential_inplace = self
|
||||
.inputs
|
||||
.iter()
|
||||
.zip(inputs.iter())
|
||||
.enumerate()
|
||||
.filter(|(_pos, ((desc, _elem), _input))| match desc.status {
|
||||
burn_fusion::TensorStatus::ReadOnly => false,
|
||||
burn_fusion::TensorStatus::ReadWrite => true,
|
||||
burn_fusion::TensorStatus::NotInit => false,
|
||||
})
|
||||
.map(|(pos, ((desc, elem), input))| (pos, desc, elem, input))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let mappings = self
|
||||
.outputs
|
||||
.iter()
|
||||
.zip(outputs.iter())
|
||||
.enumerate()
|
||||
.filter_map(|(pos, ((desc, elem), _output))| {
|
||||
if potential_inplace.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut chosen = None;
|
||||
for (index, (_pos_input, desc_input, elem_input, _input)) in
|
||||
potential_inplace.iter().enumerate()
|
||||
{
|
||||
if chosen.is_some() {
|
||||
break;
|
||||
}
|
||||
if desc.shape == desc_input.shape && *elem_input == elem {
|
||||
chosen = Some(index);
|
||||
}
|
||||
}
|
||||
|
||||
match chosen {
|
||||
Some(index) => {
|
||||
let input = potential_inplace.remove(index);
|
||||
Some(InplaceMapping::new(input.0, pos))
|
||||
}
|
||||
None => None,
|
||||
}
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
CompilationInfo {
|
||||
inputs,
|
||||
outputs,
|
||||
scope: self.scope.clone(),
|
||||
mappings,
|
||||
}
|
||||
}
|
||||
}
|
|
@ -49,6 +49,46 @@ macro_rules! binary {
|
|||
_o: core::marker::PhantomData<O>,
|
||||
}
|
||||
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
fn compile<C, I, O>(
|
||||
settings: $crate::codegen::CompilationSettings,
|
||||
mappings: Vec<$crate::codegen::InplaceMapping>,
|
||||
) -> $crate::kernel::SourceTemplate
|
||||
where
|
||||
C: $crate::codegen::Compiler,
|
||||
I: $crate::element::JitElement,
|
||||
O: $crate::element::JitElement
|
||||
{
|
||||
let mut scope = $crate::codegen::dialect::gpu::Scope::root();
|
||||
let op = $ops(&mut scope, I::gpu_elem());
|
||||
scope.register(op);
|
||||
|
||||
let local = scope.last_local_index().unwrap().index().unwrap();
|
||||
|
||||
let lhs = $crate::codegen::InputInfo::Array {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(I::gpu_elem()),
|
||||
visibility: $crate::codegen::dialect::gpu::Visibility::Read,
|
||||
};
|
||||
let rhs = $crate::codegen::InputInfo::Array {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(I::gpu_elem()),
|
||||
visibility: $crate::codegen::dialect::gpu::Visibility::Read,
|
||||
};
|
||||
let out = $crate::codegen::OutputInfo::Array {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(O::gpu_elem()),
|
||||
local,
|
||||
};
|
||||
let info = $crate::codegen::CompilationInfo {
|
||||
inputs: vec![lhs, rhs],
|
||||
outputs: vec![out],
|
||||
scope,
|
||||
mappings,
|
||||
};
|
||||
let shader = $crate::codegen::Compilation::new(info).compile(settings);
|
||||
|
||||
let compiled = C::compile(shader);
|
||||
$crate::kernel::SourceTemplate::new(compiled.to_string())
|
||||
}
|
||||
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
impl<C, I, O> $crate::kernel::StaticKernelSource for Ops<C, I, O>
|
||||
where
|
||||
|
@ -57,28 +97,8 @@ macro_rules! binary {
|
|||
O: $crate::element::JitElement
|
||||
{
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
|
||||
.inputs(&[
|
||||
$crate::codegen::Input::Array {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(I::gpu_elem()),
|
||||
visibility: $crate::codegen::dialect::gpu::Visibility::Read,
|
||||
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
|
||||
},
|
||||
$crate::codegen::Input::Array {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(I::gpu_elem()),
|
||||
visibility: $crate::codegen::dialect::gpu::Visibility::Read,
|
||||
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
|
||||
},
|
||||
])
|
||||
.body(&[$ops(I::gpu_elem())])
|
||||
.outputs(&[$crate::codegen::Output::Array {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(O::gpu_elem()),
|
||||
local: 0,
|
||||
}])
|
||||
.compile();
|
||||
|
||||
let compiled = C::compile(shader);
|
||||
$crate::kernel::SourceTemplate::new(compiled.to_string())
|
||||
let settings = $crate::codegen::CompilationSettings::default();
|
||||
compile::<C, I, O>(settings, Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -91,29 +111,13 @@ macro_rules! binary {
|
|||
O: $crate::element::JitElement
|
||||
{
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
|
||||
.inputs(&[
|
||||
$crate::codegen::Input::Array {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(I::gpu_elem()),
|
||||
visibility: $crate::codegen::dialect::gpu::Visibility::ReadWrite,
|
||||
strategy: $crate::codegen::ReadingStrategy::Plain,
|
||||
},
|
||||
$crate::codegen::Input::Array {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(I::gpu_elem()),
|
||||
visibility: $crate::codegen::dialect::gpu::Visibility::Read,
|
||||
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
|
||||
},
|
||||
])
|
||||
.body(&[$ops(I::gpu_elem())])
|
||||
.outputs(&[$crate::codegen::Output::Input {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(I::gpu_elem()),
|
||||
input: 0,
|
||||
local: 0,
|
||||
}])
|
||||
.compile();
|
||||
|
||||
let compiled = C::compile(shader);
|
||||
$crate::kernel::SourceTemplate::new(compiled.to_string())
|
||||
let settings = $crate::codegen::CompilationSettings::default()
|
||||
.inplace(true);
|
||||
let mapping = $crate::codegen::InplaceMapping {
|
||||
pos_input: 0,
|
||||
pos_output: 0,
|
||||
};
|
||||
compile::<C, I, O>(settings, vec![mapping])
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -126,29 +130,13 @@ macro_rules! binary {
|
|||
O: $crate::element::JitElement
|
||||
{
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
|
||||
.inputs(&[
|
||||
$crate::codegen::Input::Array {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(I::gpu_elem()),
|
||||
visibility: $crate::codegen::dialect::gpu::Visibility::Read,
|
||||
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
|
||||
},
|
||||
$crate::codegen::Input::Array {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(I::gpu_elem()),
|
||||
visibility: $crate::codegen::dialect::gpu::Visibility::ReadWrite,
|
||||
strategy: $crate::codegen::ReadingStrategy::Plain,
|
||||
},
|
||||
])
|
||||
.body(&[$ops(I::gpu_elem())])
|
||||
.outputs(&[$crate::codegen::Output::Input {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(I::gpu_elem()),
|
||||
input: 1,
|
||||
local: 0,
|
||||
}])
|
||||
.compile();
|
||||
|
||||
let compiled = C::compile(shader);
|
||||
$crate::kernel::SourceTemplate::new(compiled.to_string())
|
||||
let settings = $crate::codegen::CompilationSettings::default()
|
||||
.inplace(true);
|
||||
let mapping = $crate::codegen::InplaceMapping {
|
||||
pos_input: 1,
|
||||
pos_output: 0,
|
||||
};
|
||||
compile::<C, I, O>(settings, vec![mapping])
|
||||
}
|
||||
}
|
||||
};
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use super::unary;
|
||||
use crate::{
|
||||
codegen::dialect::gpu::{ClampOperation, Item, Operation, Variable},
|
||||
codegen::dialect::gpu::{ClampOperator, Operator, Scope},
|
||||
element::JitElement,
|
||||
tensor::JitTensor,
|
||||
unary, Runtime,
|
||||
|
@ -12,11 +12,11 @@ pub(crate) fn clamp<R: Runtime, E: JitElement, const D: usize>(
|
|||
max_value: E,
|
||||
) -> JitTensor<R, E, D> {
|
||||
unary!(
|
||||
operation: |elem| Operation::Clamp(ClampOperation {
|
||||
input: Variable::Input(0, Item::Scalar(elem)),
|
||||
min_value: Variable::Scalar(0, Item::Scalar(elem)),
|
||||
max_value: Variable::Scalar(1, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem| Operator::Clamp(ClampOperator {
|
||||
input: scope.read_array(0, elem),
|
||||
min_value: scope.read_scalar(0, elem),
|
||||
max_value: scope.read_scalar(1, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
compiler: R::Compiler,
|
||||
scalar 2
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use crate::{
|
||||
binary,
|
||||
codegen::dialect::gpu::{BinaryOperation, Elem, Item, Operation, Variable},
|
||||
codegen::dialect::gpu::{BinaryOperator, Elem, Operator, Scope},
|
||||
element::JitElement,
|
||||
kernel::StaticKernelSource,
|
||||
kernel::{binary::binary, unary::unary},
|
||||
|
@ -51,10 +51,10 @@ pub fn equal<R: Runtime, E: JitElement, const D: usize>(
|
|||
rhs: JitTensor<R, E, D>,
|
||||
) -> JitTensor<R, u32, D> {
|
||||
comparison!(
|
||||
binary: |elem: Elem| Operation::Equal(BinaryOperation {
|
||||
lhs: Variable::Input(0, Item::Scalar(elem)),
|
||||
rhs: Variable::Input(1, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(Elem::Bool)),
|
||||
binary: |scope: &mut Scope, elem: Elem| Operator::Equal(BinaryOperator {
|
||||
lhs: scope.read_array(0, elem),
|
||||
rhs: scope.read_array(1, elem),
|
||||
out: scope.create_local(Elem::Bool),
|
||||
}),
|
||||
runtime: R,
|
||||
input: lhs; rhs,
|
||||
|
@ -67,10 +67,10 @@ pub fn greater<R: Runtime, E: JitElement, const D: usize>(
|
|||
rhs: JitTensor<R, E, D>,
|
||||
) -> JitTensor<R, u32, D> {
|
||||
comparison!(
|
||||
binary: |elem: Elem| Operation::Greater(BinaryOperation {
|
||||
lhs: Variable::Input(0, Item::Scalar(elem)),
|
||||
rhs: Variable::Input(1, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(Elem::Bool)),
|
||||
binary: |scope: &mut Scope, elem: Elem| Operator::Greater(BinaryOperator {
|
||||
lhs: scope.read_array(0, elem),
|
||||
rhs: scope.read_array(1, elem),
|
||||
out: scope.create_local(Elem::Bool),
|
||||
}),
|
||||
runtime: R,
|
||||
input: lhs; rhs,
|
||||
|
@ -83,10 +83,10 @@ pub fn greater_equal<R: Runtime, E: JitElement, const D: usize>(
|
|||
rhs: JitTensor<R, E, D>,
|
||||
) -> JitTensor<R, u32, D> {
|
||||
comparison!(
|
||||
binary: |elem: Elem| Operation::GreaterEqual(BinaryOperation {
|
||||
lhs: Variable::Input(0, Item::Scalar(elem)),
|
||||
rhs: Variable::Input(1, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(Elem::Bool)),
|
||||
binary: |scope: &mut Scope, elem: Elem| Operator::GreaterEqual(BinaryOperator {
|
||||
lhs: scope.read_array(0, elem),
|
||||
rhs: scope.read_array(1, elem),
|
||||
out: scope.create_local(Elem::Bool),
|
||||
}),
|
||||
runtime: R,
|
||||
input: lhs; rhs,
|
||||
|
@ -99,10 +99,10 @@ pub fn lower<R: Runtime, E: JitElement, const D: usize>(
|
|||
rhs: JitTensor<R, E, D>,
|
||||
) -> JitTensor<R, u32, D> {
|
||||
comparison!(
|
||||
binary: |elem: Elem| Operation::Lower(BinaryOperation {
|
||||
lhs: Variable::Input(0, Item::Scalar(elem)),
|
||||
rhs: Variable::Input(1, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(Elem::Bool)),
|
||||
binary: |scope: &mut Scope, elem: Elem| Operator::Lower(BinaryOperator {
|
||||
lhs: scope.read_array(0, elem),
|
||||
rhs: scope.read_array(1, elem),
|
||||
out: scope.create_local(Elem::Bool),
|
||||
}),
|
||||
runtime: R,
|
||||
input: lhs; rhs,
|
||||
|
@ -115,10 +115,10 @@ pub fn lower_equal<R: Runtime, E: JitElement, const D: usize>(
|
|||
rhs: JitTensor<R, E, D>,
|
||||
) -> JitTensor<R, u32, D> {
|
||||
comparison!(
|
||||
binary: |elem: Elem| Operation::LowerEqual(BinaryOperation {
|
||||
lhs: Variable::Input(0, Item::Scalar(elem)),
|
||||
rhs: Variable::Input(1, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(Elem::Bool)),
|
||||
binary: |scope: &mut Scope, elem: Elem| Operator::LowerEqual(BinaryOperator {
|
||||
lhs: scope.read_array(0, elem),
|
||||
rhs: scope.read_array(1, elem),
|
||||
out: scope.create_local(Elem::Bool),
|
||||
}),
|
||||
runtime: R,
|
||||
input: lhs; rhs,
|
||||
|
@ -131,10 +131,10 @@ pub fn equal_elem<R: Runtime, E: JitElement, const D: usize>(
|
|||
rhs: E,
|
||||
) -> JitTensor<R, u32, D> {
|
||||
comparison!(
|
||||
unary: |elem: Elem| Operation::Equal(BinaryOperation {
|
||||
lhs: Variable::Input(0, Item::Scalar(elem)),
|
||||
rhs: Variable::Scalar(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(Elem::Bool)),
|
||||
unary: |scope: &mut Scope, elem: Elem| Operator::Equal(BinaryOperator {
|
||||
lhs: scope.read_array(0, elem),
|
||||
rhs: scope.read_scalar(0, elem),
|
||||
out: scope.create_local(Elem::Bool),
|
||||
}),
|
||||
runtime: R,
|
||||
input: lhs; rhs,
|
||||
|
@ -147,10 +147,10 @@ pub fn greater_elem<R: Runtime, E: JitElement, const D: usize>(
|
|||
rhs: E,
|
||||
) -> JitTensor<R, u32, D> {
|
||||
comparison!(
|
||||
unary: |elem: Elem| Operation::Greater(BinaryOperation {
|
||||
lhs: Variable::Input(0, Item::Scalar(elem)),
|
||||
rhs: Variable::Scalar(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(Elem::Bool)),
|
||||
unary: |scope: &mut Scope, elem: Elem| Operator::Greater(BinaryOperator {
|
||||
lhs: scope.read_array(0, elem),
|
||||
rhs: scope.read_scalar(0, elem),
|
||||
out: scope.create_local(Elem::Bool),
|
||||
}),
|
||||
runtime: R,
|
||||
input: lhs; rhs,
|
||||
|
@ -163,10 +163,10 @@ pub fn lower_elem<R: Runtime, E: JitElement, const D: usize>(
|
|||
rhs: E,
|
||||
) -> JitTensor<R, u32, D> {
|
||||
comparison!(
|
||||
unary: |elem: Elem| Operation::Lower(BinaryOperation {
|
||||
lhs: Variable::Input(0, Item::Scalar(elem)),
|
||||
rhs: Variable::Scalar(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(Elem::Bool)),
|
||||
unary: |scope: &mut Scope, elem: Elem| Operator::Lower(BinaryOperator {
|
||||
lhs: scope.read_array(0, elem),
|
||||
rhs: scope.read_scalar(0, elem),
|
||||
out: scope.create_local(Elem::Bool),
|
||||
}),
|
||||
runtime: R,
|
||||
input: lhs; rhs,
|
||||
|
@ -179,10 +179,10 @@ pub fn greater_equal_elem<R: Runtime, E: JitElement, const D: usize>(
|
|||
rhs: E,
|
||||
) -> JitTensor<R, u32, D> {
|
||||
comparison!(
|
||||
unary: |elem: Elem| Operation::GreaterEqual(BinaryOperation {
|
||||
lhs: Variable::Input(0, Item::Scalar(elem)),
|
||||
rhs: Variable::Scalar(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(Elem::Bool)),
|
||||
unary: |scope: &mut Scope, elem: Elem| Operator::GreaterEqual(BinaryOperator {
|
||||
lhs: scope.read_array(0, elem),
|
||||
rhs: scope.read_scalar(0, elem),
|
||||
out: scope.create_local(Elem::Bool),
|
||||
}),
|
||||
runtime: R,
|
||||
input: lhs; rhs,
|
||||
|
@ -195,10 +195,10 @@ pub fn lower_equal_elem<R: Runtime, E: JitElement, const D: usize>(
|
|||
rhs: E,
|
||||
) -> JitTensor<R, u32, D> {
|
||||
comparison!(
|
||||
unary: |elem: Elem| Operation::LowerEqual(BinaryOperation {
|
||||
lhs: Variable::Input(0, Item::Scalar(elem)),
|
||||
rhs: Variable::Scalar(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(Elem::Bool)),
|
||||
unary: |scope: &mut Scope, elem: Elem| Operator::LowerEqual(BinaryOperator {
|
||||
lhs: scope.read_array(0, elem),
|
||||
rhs: scope.read_scalar(0, elem),
|
||||
out: scope.create_local(Elem::Bool),
|
||||
}),
|
||||
runtime: R,
|
||||
input: lhs; rhs,
|
||||
|
|
|
@ -55,6 +55,42 @@ macro_rules! unary {
|
|||
_e: core::marker::PhantomData<E>,
|
||||
}
|
||||
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
fn compile<C, E>(
|
||||
settings: $crate::codegen::CompilationSettings,
|
||||
mappings: Vec<$crate::codegen::InplaceMapping>,
|
||||
) -> $crate::kernel::SourceTemplate
|
||||
where
|
||||
C: $crate::codegen::Compiler,
|
||||
E: $crate::element::JitElement
|
||||
{
|
||||
|
||||
let mut scope = $crate::codegen::dialect::gpu::Scope::root();
|
||||
let op = $ops(&mut scope, E::gpu_elem());
|
||||
scope.register(op);
|
||||
|
||||
let local = scope.last_local_index().unwrap().index().unwrap();
|
||||
|
||||
let input = $crate::codegen::InputInfo::Array {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
|
||||
visibility: $crate::codegen::dialect::gpu::Visibility::Read,
|
||||
};
|
||||
let out = $crate::codegen::OutputInfo::Array {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
|
||||
local,
|
||||
};
|
||||
let info = $crate::codegen::CompilationInfo {
|
||||
inputs: vec![input],
|
||||
outputs: vec![out],
|
||||
scope,
|
||||
mappings,
|
||||
};
|
||||
let shader = $crate::codegen::Compilation::new(info).compile(settings);
|
||||
|
||||
let compiled = C::compile(shader);
|
||||
$crate::kernel::SourceTemplate::new(compiled.to_string())
|
||||
}
|
||||
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
impl<C, E> $crate::kernel::StaticKernelSource for Ops<C, E>
|
||||
where
|
||||
|
@ -62,21 +98,8 @@ macro_rules! unary {
|
|||
E: $crate::element::JitElement,
|
||||
{
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
|
||||
.inputs(&[$crate::codegen::Input::Array {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
|
||||
visibility: $crate::codegen::dialect::gpu::Visibility::Read,
|
||||
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
|
||||
}])
|
||||
.body(&[$ops(E::gpu_elem())])
|
||||
.outputs(&[$crate::codegen::Output::Array {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
|
||||
local: 0,
|
||||
}])
|
||||
.compile();
|
||||
|
||||
let compiled = C::compile(shader);
|
||||
$crate::kernel::SourceTemplate::new(compiled.to_string())
|
||||
let settings = $crate::codegen::CompilationSettings::default();
|
||||
compile::<C, E>(settings, Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -87,22 +110,13 @@ macro_rules! unary {
|
|||
E: $crate::element::JitElement,
|
||||
{
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
|
||||
.inputs(&[$crate::codegen::Input::Array {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
|
||||
visibility: $crate::codegen::dialect::gpu::Visibility::ReadWrite,
|
||||
strategy: $crate::codegen::ReadingStrategy::Plain,
|
||||
}])
|
||||
.body(&[$ops(E::gpu_elem())])
|
||||
.outputs(&[$crate::codegen::Output::Input {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
|
||||
input: 0,
|
||||
local: 0,
|
||||
}])
|
||||
.compile();
|
||||
|
||||
let compiled = C::compile(shader);
|
||||
$crate::kernel::SourceTemplate::new(compiled.to_string())
|
||||
let settings = $crate::codegen::CompilationSettings::default()
|
||||
.inplace(true);
|
||||
let mapping = $crate::codegen::InplaceMapping {
|
||||
pos_input: 0,
|
||||
pos_output: 0,
|
||||
};
|
||||
compile::<C, E>(settings, vec![mapping])
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -120,6 +134,46 @@ macro_rules! unary {
|
|||
_e: core::marker::PhantomData<E>,
|
||||
}
|
||||
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
fn compile<C, E>(
|
||||
settings: $crate::codegen::CompilationSettings,
|
||||
mappings: Vec<$crate::codegen::InplaceMapping>,
|
||||
) -> $crate::kernel::SourceTemplate
|
||||
where
|
||||
C: $crate::codegen::Compiler,
|
||||
E: $crate::element::JitElement
|
||||
{
|
||||
|
||||
let mut scope = $crate::codegen::dialect::gpu::Scope::root();
|
||||
let op = $ops(&mut scope, E::gpu_elem());
|
||||
scope.register(op);
|
||||
|
||||
let local = scope.last_local_index().unwrap().index().unwrap();
|
||||
|
||||
let input = $crate::codegen::InputInfo::Array {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
|
||||
visibility: $crate::codegen::dialect::gpu::Visibility::Read,
|
||||
};
|
||||
let scalars = $crate::codegen::InputInfo::Scalar {
|
||||
elem: E::gpu_elem(),
|
||||
size: $num,
|
||||
};
|
||||
let out = $crate::codegen::OutputInfo::Array {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
|
||||
local,
|
||||
};
|
||||
let info = $crate::codegen::CompilationInfo {
|
||||
inputs: vec![input, scalars],
|
||||
outputs: vec![out],
|
||||
scope,
|
||||
mappings,
|
||||
};
|
||||
let shader = $crate::codegen::Compilation::new(info).compile(settings);
|
||||
|
||||
let compiled = C::compile(shader);
|
||||
$crate::kernel::SourceTemplate::new(compiled.to_string())
|
||||
}
|
||||
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
impl<C, E> $crate::kernel::StaticKernelSource for Ops<C, E>
|
||||
where
|
||||
|
@ -127,27 +181,8 @@ macro_rules! unary {
|
|||
E: $crate::element::JitElement,
|
||||
{
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
|
||||
.inputs(&[
|
||||
$crate::codegen::Input::Array {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
|
||||
visibility: $crate::codegen::dialect::gpu::Visibility::Read,
|
||||
strategy: $crate::codegen::ReadingStrategy::OutputLayout,
|
||||
},
|
||||
$crate::codegen::Input::Scalar {
|
||||
elem: E::gpu_elem(),
|
||||
size: $num,
|
||||
},
|
||||
])
|
||||
.body(&[$ops(E::gpu_elem())])
|
||||
.outputs(&[$crate::codegen::Output::Array {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
|
||||
local: 0,
|
||||
}])
|
||||
.compile();
|
||||
|
||||
let compiled = C::compile(shader);
|
||||
$crate::kernel::SourceTemplate::new(compiled.to_string())
|
||||
let settings = $crate::codegen::CompilationSettings::default();
|
||||
compile::<C, E>(settings, Vec::new())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -158,28 +193,13 @@ macro_rules! unary {
|
|||
E: $crate::element::JitElement,
|
||||
{
|
||||
fn source() -> $crate::kernel::SourceTemplate {
|
||||
let shader = $crate::codegen::ElemWiseKernelCodegen::new()
|
||||
.inputs(&[
|
||||
$crate::codegen::Input::Array {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
|
||||
visibility: $crate::codegen::dialect::gpu::Visibility::ReadWrite,
|
||||
strategy: $crate::codegen::ReadingStrategy::Plain,
|
||||
},
|
||||
$crate::codegen::Input::Scalar {
|
||||
elem: E::gpu_elem(),
|
||||
size: $num,
|
||||
},
|
||||
])
|
||||
.body(&[$ops(E::gpu_elem())])
|
||||
.outputs(&[$crate::codegen::Output::Input {
|
||||
item: $crate::codegen::dialect::gpu::Item::Scalar(E::gpu_elem()),
|
||||
input: 0,
|
||||
local: 0,
|
||||
}])
|
||||
.compile();
|
||||
|
||||
let compiled = C::compile(shader);
|
||||
$crate::kernel::SourceTemplate::new(compiled.to_string())
|
||||
let settings = $crate::codegen::CompilationSettings::default()
|
||||
.inplace(true);
|
||||
let mapping = $crate::codegen::InplaceMapping {
|
||||
pos_input: 0,
|
||||
pos_output: 0,
|
||||
};
|
||||
compile::<C, E>(settings, vec![mapping])
|
||||
}
|
||||
}
|
||||
};
|
||||
|
@ -243,14 +263,14 @@ where
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::codegen::dialect::gpu::{Item, Operation, UnaryOperation, Variable};
|
||||
use crate::codegen::dialect::gpu::{Operator, Scope, UnaryOperator};
|
||||
use crate::tests::{ReferenceBackend, TestBackend, TestCompiler, TestRuntime};
|
||||
use burn_tensor::{Distribution, Tensor};
|
||||
|
||||
unary!(
|
||||
operation: |elem| Operation::Tanh(UnaryOperation {
|
||||
input: Variable::Input(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem| Operator::Tanh(UnaryOperator {
|
||||
input: scope.read_array(0, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
compiler: TestCompiler
|
||||
);
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
use super::numeric;
|
||||
use crate::codegen::dialect::gpu::{
|
||||
BinaryOperation, Elem, Item, Operation, UnaryOperation, Variable,
|
||||
};
|
||||
use crate::codegen::dialect::gpu::{BinaryOperator, Elem, Operator, Scope, UnaryOperator};
|
||||
#[cfg(not(feature = "autotune"))]
|
||||
use crate::kernel::matmul::init_matmul_output;
|
||||
#[cfg(feature = "autotune")]
|
||||
|
@ -364,9 +362,9 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
|
|||
|
||||
fn float_exp<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
unary!(
|
||||
operation: |elem: Elem| Operation::Exp(UnaryOperation {
|
||||
input: Variable::Input(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::Exp(UnaryOperator {
|
||||
input: scope.read_array(0, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
input: tensor,
|
||||
|
@ -376,9 +374,9 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
|
|||
|
||||
fn float_log<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
unary!(
|
||||
operation: |elem: Elem| Operation::Log(UnaryOperation {
|
||||
input: Variable::Input(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::Log(UnaryOperator {
|
||||
input: scope.read_array(0, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
input: tensor,
|
||||
|
@ -388,9 +386,9 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
|
|||
|
||||
fn float_log1p<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
unary!(
|
||||
operation: |elem: Elem| Operation::Log1p(UnaryOperation {
|
||||
input: Variable::Input(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::Log1p(UnaryOperator {
|
||||
input: scope.read_array(0, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
input: tensor,
|
||||
|
@ -403,10 +401,10 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
|
|||
rhs: f32,
|
||||
) -> FloatTensor<Self, D> {
|
||||
unary!(
|
||||
operation: |elem: Elem| Operation::Powf(BinaryOperation {
|
||||
lhs: Variable::Input(0, Item::Scalar(elem)),
|
||||
rhs: Variable::Scalar(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::Powf(BinaryOperator {
|
||||
lhs: scope.read_array(0, elem),
|
||||
rhs: scope.read_scalar(0, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
input: lhs; rhs.elem(),
|
||||
|
@ -416,9 +414,9 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
|
|||
|
||||
fn float_sqrt<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
unary!(
|
||||
operation: |elem: Elem| Operation::Sqrt(UnaryOperation {
|
||||
input: Variable::Input(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::Sqrt(UnaryOperator {
|
||||
input: scope.read_array(0, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
input: tensor,
|
||||
|
@ -428,9 +426,9 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
|
|||
|
||||
fn float_abs<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
unary!(
|
||||
operation: |elem: Elem| Operation::Abs(UnaryOperation {
|
||||
input: Variable::Input(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::Abs(UnaryOperator {
|
||||
input: scope.read_array(0, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
input: tensor,
|
||||
|
@ -440,9 +438,9 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
|
|||
|
||||
fn float_cos<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
unary!(
|
||||
operation: |elem: Elem| Operation::Cos(UnaryOperation {
|
||||
input: Variable::Input(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::Cos(UnaryOperator {
|
||||
input: scope.read_array(0, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
input: tensor,
|
||||
|
@ -452,9 +450,9 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
|
|||
|
||||
fn float_sin<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
unary!(
|
||||
operation: |elem: Elem| Operation::Sin(UnaryOperation {
|
||||
input: Variable::Input(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::Sin(UnaryOperator {
|
||||
input: scope.read_array(0, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
input: tensor,
|
||||
|
@ -464,9 +462,9 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
|
|||
|
||||
fn float_tanh<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
unary!(
|
||||
operation: |elem: Elem| Operation::Tanh(UnaryOperation {
|
||||
input: Variable::Input(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::Tanh(UnaryOperator {
|
||||
input: scope.read_array(0, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
input: tensor,
|
||||
|
@ -476,9 +474,9 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
|
|||
|
||||
fn float_erf<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
unary!(
|
||||
operation: |elem: Elem| Operation::Erf(UnaryOperation {
|
||||
input: Variable::Input(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::Erf(UnaryOperator {
|
||||
input: scope.read_array(0, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
input: tensor,
|
||||
|
@ -521,9 +519,9 @@ impl<R: Runtime> FloatTensorOps<Self> for JitBackend<R> {
|
|||
|
||||
fn float_recip<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, D> {
|
||||
unary!(
|
||||
operation: |elem: Elem| Operation::Recip(UnaryOperation {
|
||||
input: Variable::Input(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::Recip(UnaryOperator {
|
||||
input: scope.read_array(0, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
input: tensor,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use super::numeric;
|
||||
use crate::codegen::dialect::gpu::{Elem, Item, Operation, UnaryOperation, Variable};
|
||||
use crate::codegen::dialect::gpu::{Elem, Item, Operator, Scope, UnaryOperator};
|
||||
use crate::kernel::reduce::{self, init_reduce_output};
|
||||
use crate::{kernel, unary, JitBackend, Runtime};
|
||||
use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
|
||||
|
@ -281,9 +281,9 @@ impl<R: Runtime> IntTensorOps<Self> for JitBackend<R> {
|
|||
|
||||
fn int_abs<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, D> {
|
||||
unary!(
|
||||
operation: |elem: Elem| Operation::Abs(UnaryOperation {
|
||||
input: Variable::Input(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::Abs(UnaryOperator {
|
||||
input: scope.read_array(0, Item::Scalar(elem)),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
input: tensor,
|
||||
|
|
|
@ -1,6 +1,4 @@
|
|||
use crate::codegen::dialect::gpu::{
|
||||
BinaryOperation, Elem, Item, Operation, UnaryOperation, Variable,
|
||||
};
|
||||
use crate::codegen::dialect::gpu::{BinaryOperator, Elem, Operator, Scope, UnaryOperator};
|
||||
use crate::{binary, Runtime};
|
||||
use crate::{element::JitElement, tensor::JitTensor, unary};
|
||||
use burn_compute::client::ComputeClient;
|
||||
|
@ -25,9 +23,9 @@ pub fn full_device<R: Runtime, E: JitElement, const D: usize>(
|
|||
let empty = empty_device(client, device, shape);
|
||||
|
||||
unary!(
|
||||
operation: |elem: Elem| Operation::AssignLocal(UnaryOperation {
|
||||
input: Variable::Scalar(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::AssignLocal(UnaryOperator {
|
||||
input: scope.read_scalar(0, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
input: empty; value,
|
||||
|
@ -84,10 +82,10 @@ pub fn add<R: Runtime, E: JitElement, const D: usize>(
|
|||
rhs: JitTensor<R, E, D>,
|
||||
) -> JitTensor<R, E, D> {
|
||||
binary!(
|
||||
operation: |elem: Elem| Operation::Add(BinaryOperation {
|
||||
lhs: Variable::Input(0, Item::Scalar(elem)),
|
||||
rhs: Variable::Input(1, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::Add(BinaryOperator {
|
||||
lhs: scope.read_array(0, elem),
|
||||
rhs: scope.read_array(1, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
input: lhs; rhs,
|
||||
|
@ -100,10 +98,10 @@ pub fn add_scalar<R: Runtime, E: JitElement, const D: usize>(
|
|||
rhs: E,
|
||||
) -> JitTensor<R, E, D> {
|
||||
unary!(
|
||||
operation: |elem: Elem| Operation::Add(BinaryOperation {
|
||||
lhs: Variable::Input(0, Item::Scalar(elem)),
|
||||
rhs: Variable::Scalar(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::Add(BinaryOperator {
|
||||
lhs: scope.read_array(0, elem),
|
||||
rhs: scope.read_scalar(0, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
input: lhs; rhs,
|
||||
|
@ -116,10 +114,10 @@ pub fn sub<R: Runtime, E: JitElement, const D: usize>(
|
|||
rhs: JitTensor<R, E, D>,
|
||||
) -> JitTensor<R, E, D> {
|
||||
binary!(
|
||||
operation: |elem: Elem| Operation::Sub(BinaryOperation {
|
||||
lhs: Variable::Input(0, Item::Scalar(elem)),
|
||||
rhs: Variable::Input(1, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::Sub(BinaryOperator {
|
||||
lhs: scope.read_array(0, elem),
|
||||
rhs: scope.read_array(1, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
input: lhs; rhs,
|
||||
|
@ -132,10 +130,10 @@ pub fn sub_scalar<R: Runtime, E: JitElement, const D: usize>(
|
|||
rhs: E,
|
||||
) -> JitTensor<R, E, D> {
|
||||
unary!(
|
||||
operation: |elem: Elem| Operation::Sub(BinaryOperation {
|
||||
lhs: Variable::Input(0, Item::Scalar(elem)),
|
||||
rhs: Variable::Scalar(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::Sub(BinaryOperator {
|
||||
lhs: scope.read_array(0, elem),
|
||||
rhs: scope.read_scalar(0, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
input: lhs; rhs,
|
||||
|
@ -148,10 +146,10 @@ pub fn mul<R: Runtime, E: JitElement, const D: usize>(
|
|||
rhs: JitTensor<R, E, D>,
|
||||
) -> JitTensor<R, E, D> {
|
||||
binary!(
|
||||
operation: |elem: Elem| Operation::Mul(BinaryOperation {
|
||||
lhs: Variable::Input(0, Item::Scalar(elem)),
|
||||
rhs: Variable::Input(1, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::Mul(BinaryOperator {
|
||||
lhs: scope.read_array(0, elem),
|
||||
rhs: scope.read_array(1, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
input: lhs; rhs,
|
||||
|
@ -164,10 +162,10 @@ pub fn mul_scalar<R: Runtime, E: JitElement, const D: usize>(
|
|||
rhs: E,
|
||||
) -> JitTensor<R, E, D> {
|
||||
unary!(
|
||||
operation: |elem: Elem| Operation::Mul(BinaryOperation {
|
||||
lhs: Variable::Input(0, Item::Scalar(elem)),
|
||||
rhs: Variable::Scalar(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::Mul(BinaryOperator {
|
||||
lhs: scope.read_array(0, elem),
|
||||
rhs: scope.read_scalar(0, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
input: lhs; rhs,
|
||||
|
@ -180,10 +178,10 @@ pub fn div<R: Runtime, E: JitElement, const D: usize>(
|
|||
rhs: JitTensor<R, E, D>,
|
||||
) -> JitTensor<R, E, D> {
|
||||
binary!(
|
||||
operation: |elem: Elem| Operation::Div(BinaryOperation {
|
||||
lhs: Variable::Input(0, Item::Scalar(elem)),
|
||||
rhs: Variable::Input(1, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::Div(BinaryOperator {
|
||||
lhs: scope.read_array(0, elem),
|
||||
rhs: scope.read_array(1, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
input: lhs; rhs,
|
||||
|
@ -196,10 +194,10 @@ pub fn div_scalar<R: Runtime, E: JitElement, const D: usize>(
|
|||
rhs: E,
|
||||
) -> JitTensor<R, E, D> {
|
||||
unary!(
|
||||
operation: |elem: Elem| Operation::Div(BinaryOperation {
|
||||
lhs: Variable::Input(0, Item::Scalar(elem)),
|
||||
rhs: Variable::Scalar(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::Div(BinaryOperator {
|
||||
lhs: scope.read_array(0, elem),
|
||||
rhs: scope.read_scalar(0, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
input: lhs; rhs,
|
||||
|
@ -212,10 +210,10 @@ pub fn pow<R: Runtime, E: JitElement, const D: usize>(
|
|||
rhs: JitTensor<R, E, D>,
|
||||
) -> JitTensor<R, E, D> {
|
||||
binary!(
|
||||
operation: |elem: Elem| Operation::Powf(BinaryOperation {
|
||||
lhs: Variable::Input(0, Item::Scalar(elem)),
|
||||
rhs: Variable::Input(1, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::Powf(BinaryOperator {
|
||||
lhs: scope.read_array(0, elem),
|
||||
rhs: scope.read_array(1, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
input: lhs; rhs,
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::codegen::dialect::gpu::{Elem, Item, Operation, UnaryOperation, Variable};
|
||||
use crate::codegen::dialect::gpu::{Elem, Operator, Scope, UnaryOperator};
|
||||
use crate::element::JitElement;
|
||||
use crate::{unary, Runtime};
|
||||
use burn_compute::client::ComputeClient;
|
||||
|
@ -145,9 +145,9 @@ where
|
|||
//
|
||||
// The solution is just to use a simple unary compute shader.
|
||||
unary!(
|
||||
operation: |elem: Elem| Operation::AssignLocal(UnaryOperation {
|
||||
input: Variable::Input(0, Item::Scalar(elem)),
|
||||
out: Variable::Local(0, Item::Scalar(elem)),
|
||||
operation: |scope: &mut Scope, elem: Elem| Operator::AssignLocal(UnaryOperator {
|
||||
input: scope.read_array(0, elem),
|
||||
out: scope.create_local(elem),
|
||||
}),
|
||||
runtime: R,
|
||||
input: self.clone(),
|
||||
|
|
Loading…
Reference in New Issue