instruction set for reduce workgroup

louisfd 2024-02-28 14:06:10 -05:00
parent 67de96130c
commit 2773f26433
23 changed files with 326 additions and 43 deletions

View File

@ -3,7 +3,10 @@ use crate::fusion::JitFusionHandle;
#[cfg(feature = "fusion")]
use burn_fusion::TensorDescription;
-use super::{dialect::gpu, Compiler};
+use super::{
+dialect::gpu::{self},
+Compiler,
+};
use crate::{
codegen::dialect::gpu::{
Binding, ComputeShader, Elem, Item, Location, ReadingStrategy, Variable, Vectorization,

View File

@ -10,6 +10,8 @@ pub enum Branch {
IfElse(IfElse),
// A range loop.
RangeLoop(RangeLoop),
// A while loop.
WhileLoop(WhileLoop),
// A return statement.
Return,
// A break statement.
@ -37,6 +39,12 @@ pub struct RangeLoop {
pub scope: Scope,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct WhileLoop {
pub cond: Variable,
pub scope: Scope,
}
impl If {
/// Registers an if statement to the given scope.
pub fn register<F: Fn(&mut Scope)>(parent_scope: &mut Scope, cond: Variable, func: F) {
@ -96,3 +104,15 @@ impl RangeLoop {
}));
}
}
impl WhileLoop {
/// Registers a while loop to the given scope.
pub fn register<F: Fn(&mut Scope)>(parent_scope: &mut Scope, cond: Variable, func: F) {
let mut scope = parent_scope.child();
func(&mut scope);
let op = Self { cond, scope };
parent_scope.register(Branch::WhileLoop(op));
}
}
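A minimal sketch of driving the new construct directly (names are illustrative; it assumes `Elem::Bool` converts into an `Item` the same way `Elem::UInt` does elsewhere in this diff):

    // Sketch: register a loop that runs while `cond` holds. The body closure
    // receives a child scope, mirroring If::register and RangeLoop::register.
    fn emit_loop(scope: &mut Scope) {
        let cond = scope.create_local(Elem::Bool);
        // ... operations that initialize and later update `cond` ...
        WhileLoop::register(scope, cond, |scope| {
            // body operations are registered on the child scope here
            let _tmp = scope.create_local(Elem::UInt);
        });
    }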

View File

@ -232,6 +232,10 @@ macro_rules! gpu {
($scope:expr, range($start:expr, $end:expr).for_each($arg:expr)) => {
$crate::codegen::dialect::gpu::RangeLoop::register($scope, $start.into(), $end.into(), $arg);
};
// while (cond).then(|scope| { ... })
($scope:expr, while ($cond:expr).then($arg:expr)) => {
$crate::codegen::dialect::gpu::WhileLoop::register($scope, $cond.into(), $arg);
};
// if (cond).then(|scope| { ... })
($scope:expr, if ($cond:expr).then($arg:expr)) => {
$crate::codegen::dialect::gpu::If::register($scope, $cond.into(), $arg);
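With the arm above, kernel code gets while-loop sugar matching the existing `if` and `range` forms. A usage sketch (`cond`, `counter`, and `one` are illustrative variables created on the scope beforehand):

    // Sketch: keep incrementing `counter` until `cond` becomes false. The body
    // uses the same gpu! assignment forms seen elsewhere in this diff, e.g. `+=`.
    gpu!(scope, while (cond).then(|scope| {
        gpu!(scope, counter += one);
        // ... recompute `cond` so the loop terminates ...
    }));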

View File

@ -5,6 +5,7 @@ mod procedure;
mod processing;
mod scope;
mod shader;
mod synchronization;
mod variable;
mod vectorization;
@ -14,5 +15,6 @@ pub(crate) use operation::*;
pub(crate) use procedure::*;
pub(crate) use scope::*;
pub(crate) use shader::*;
pub(crate) use synchronization::*;
pub(crate) use variable::*;
pub(crate) use vectorization::*;

View File

@ -1,4 +1,4 @@
-use super::{Branch, Procedure, Variable};
+use super::{Branch, Procedure, Synchronization, Variable};
use serde::{Deserialize, Serialize};
/// All operations that can be used in a GPU compute shader.
@ -16,6 +16,7 @@ pub enum Operation {
Procedure(Procedure),
Metadata(Metadata),
Branch(Branch),
Synchronization(Synchronization),
}
/// All operators that can be used in a GPU compute shader.
@ -115,6 +116,12 @@ impl From<Branch> for Operation {
}
}
impl From<Synchronization> for Operation {
fn from(value: Synchronization) -> Self {
Self::Synchronization(value)
}
}
impl From<Metadata> for Operation {
fn from(val: Metadata) -> Self {
Operation::Metadata(val)

View File

@ -17,6 +17,7 @@ pub struct Scope {
pub depth: u8,
pub operations: Vec<Operation>,
locals: Vec<Variable>,
shared: Vec<Variable>,
reads_global: Vec<(Variable, ReadingStrategy, Variable)>,
index_offset_with_output_layout_position: Vec<usize>,
writes_global: Vec<(Variable, Variable)>,
@ -43,6 +44,7 @@ impl Scope {
depth: 0,
operations: Vec::new(),
locals: Vec::new(),
shared: Vec::new(),
reads_global: Vec::new(),
index_offset_with_output_layout_position: Vec::new(),
writes_global: Vec::new(),
@ -58,6 +60,7 @@ impl Scope {
gpu!(self, local = zero);
local
}
/// Create a local variable of the given [item type](Item).
pub fn create_local<I: Into<Item>>(&mut self, item: I) -> Variable {
let item = item.into();
@ -191,6 +194,7 @@ impl Scope {
depth: self.depth + 1,
operations: Vec::new(),
locals: Vec::new(),
shared: Vec::new(),
reads_global: Vec::new(),
index_offset_with_output_layout_position: Vec::new(),
writes_global: Vec::new(),
@ -284,6 +288,10 @@ impl Scope {
self.reads_scalar.len() as u16
}
fn new_shared_index(&self) -> u16 {
self.shared.len() as u16
}
fn read_input_strategy(
&mut self,
index: u16,
@ -306,4 +314,13 @@ impl Scope {
self.locals.push(local);
local
}
/// Create a shared memory variable of the given [item type](Item).
pub fn create_shared<I: Into<Item>>(&mut self, item: I, shared_memory_size: u32) -> Variable {
let item = item.into();
let index = self.new_shared_index();
let shared_memory = Variable::SharedMemory(index, item, shared_memory_size);
self.shared.push(shared_memory);
shared_memory
}
}
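A short sketch of the new entry point (the size 256 is illustrative, e.g. one slot per invocation of a 16x16 workgroup; `item` is whatever element item the kernel reduces):

    // Sketch: allocate a workgroup-shared buffer. On the WGSL side this lowers
    // to a module-level `var<workgroup> shared_memory_0: array<..., 256>;`.
    let shared_memory = scope.create_shared(item, 256);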

View File

@ -7,7 +7,6 @@ use super::Scope;
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
pub enum Location {
Storage,
-#[allow(dead_code)]
Workgroup,
}
@ -86,6 +85,12 @@ impl Default for WorkgroupSize {
}
}
impl WorkgroupSize {
pub fn num_elements(&self) -> u32 {
self.x * self.y * self.z
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComputeShader {
pub inputs: Vec<Binding>,

View File

@ -0,0 +1,8 @@
use serde::{Deserialize, Serialize};
/// All synchronization types.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum Synchronization {
// A workgroup barrier.
WorkgroupBarrier,
}
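Emitting a barrier then becomes a one-liner. A sketch, assuming `Scope::register` accepts any `Into<Operation>` (as the `Branch` case in `WhileLoop::register` suggests):

    // Sketch: make all shared-memory writes visible to the whole workgroup
    // before the read phase of a reduction.
    scope.register(Synchronization::WorkgroupBarrier);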

View File

@ -9,8 +9,12 @@ pub enum Variable {
Local(u16, Item, u8),
LocalScalar(u16, Elem, u8),
ConstantScalar(f64, Elem),
SharedMemory(u16, Item, u32),
Id,
-InvocationIndex,
+LocalInvocationIndex,
+LocalInvocationIdX,
+LocalInvocationIdY,
+LocalInvocationIdZ,
WorkgroupIdX,
WorkgroupIdY,
WorkgroupIdZ,
@ -18,6 +22,12 @@ pub enum Variable {
GlobalInvocationIdY,
GlobalInvocationIdZ,
Rank,
WorkgroupSizeX,
WorkgroupSizeY,
WorkgroupSizeZ,
NumWorkgroupsX,
NumWorkgroupsY,
NumWorkgroupsZ,
}
impl Variable {
@ -35,8 +45,12 @@ impl Variable {
Variable::LocalScalar(idx, _, _) => Some(*idx),
Variable::GlobalOutputArray(idx, _) => Some(*idx),
Variable::ConstantScalar(_, _) => None,
Variable::SharedMemory(idx, _, _) => Some(*idx),
Variable::Id => None,
-Variable::InvocationIndex => None,
+Variable::LocalInvocationIndex => None,
+Variable::LocalInvocationIdX => None,
+Variable::LocalInvocationIdY => None,
+Variable::LocalInvocationIdZ => None,
Variable::Rank => None,
Variable::WorkgroupIdX => None,
Variable::WorkgroupIdY => None,
@ -44,6 +58,12 @@ impl Variable {
Variable::GlobalInvocationIdX => None,
Variable::GlobalInvocationIdY => None,
Variable::GlobalInvocationIdZ => None,
Variable::WorkgroupSizeX => None,
Variable::WorkgroupSizeY => None,
Variable::WorkgroupSizeZ => None,
Variable::NumWorkgroupsX => None,
Variable::NumWorkgroupsY => None,
Variable::NumWorkgroupsZ => None,
}
}
pub fn item(&self) -> Item {
@ -54,15 +74,25 @@ impl Variable {
Variable::Local(_, item, _) => *item,
Variable::LocalScalar(_, elem, _) => Item::Scalar(*elem),
Variable::ConstantScalar(_, elem) => Item::Scalar(*elem),
Variable::SharedMemory(_, item, _) => *item,
Variable::Id => Item::Scalar(Elem::UInt),
Variable::Rank => Item::Scalar(Elem::UInt),
-Variable::InvocationIndex => Item::Scalar(Elem::UInt),
+Variable::LocalInvocationIndex => Item::Scalar(Elem::UInt),
+Variable::LocalInvocationIdX => Item::Scalar(Elem::UInt),
+Variable::LocalInvocationIdY => Item::Scalar(Elem::UInt),
+Variable::LocalInvocationIdZ => Item::Scalar(Elem::UInt),
Variable::WorkgroupIdX => Item::Scalar(Elem::UInt),
Variable::WorkgroupIdY => Item::Scalar(Elem::UInt),
Variable::WorkgroupIdZ => Item::Scalar(Elem::UInt),
Variable::GlobalInvocationIdX => Item::Scalar(Elem::UInt),
Variable::GlobalInvocationIdY => Item::Scalar(Elem::UInt),
Variable::GlobalInvocationIdZ => Item::Scalar(Elem::UInt),
Variable::WorkgroupSizeX => Item::Scalar(Elem::UInt),
Variable::WorkgroupSizeY => Item::Scalar(Elem::UInt),
Variable::WorkgroupSizeZ => Item::Scalar(Elem::UInt),
Variable::NumWorkgroupsX => Item::Scalar(Elem::UInt),
Variable::NumWorkgroupsY => Item::Scalar(Elem::UInt),
Variable::NumWorkgroupsZ => Item::Scalar(Elem::UInt),
}
}
}

View File

@ -26,6 +26,9 @@ impl Operation {
Operation::Branch(_) => panic!(
"A branch can't be vectorized, they should only be generated after vectorization."
),
Operation::Synchronization(_) => panic!(
"Synchronization instructions can't be vectorized, they should only be generated after vectorization."
),
}
}
}
@ -112,18 +115,32 @@ impl Variable {
Variable::GlobalOutputArray(index, item) => {
Variable::GlobalOutputArray(*index, item.vectorize(vectorize))
}
Variable::SharedMemory(index, item, size) => Variable::SharedMemory(
*index,
item.vectorize(vectorize),
item.vectorized_size(vectorize, *size),
),
Variable::ConstantScalar(_, _) => *self,
Variable::GlobalScalar(_, _) => *self,
Variable::Id => *self,
Variable::Rank => *self,
Variable::LocalScalar(_, _, _) => *self,
-Variable::InvocationIndex => *self,
+Variable::LocalInvocationIndex => *self,
+Variable::LocalInvocationIdX => *self,
+Variable::LocalInvocationIdY => *self,
+Variable::LocalInvocationIdZ => *self,
Variable::WorkgroupIdX => *self,
Variable::WorkgroupIdY => *self,
Variable::WorkgroupIdZ => *self,
Variable::GlobalInvocationIdX => *self,
Variable::GlobalInvocationIdY => *self,
Variable::GlobalInvocationIdZ => *self,
Variable::WorkgroupSizeX => *self,
Variable::WorkgroupSizeY => *self,
Variable::WorkgroupSizeZ => *self,
Variable::NumWorkgroupsX => *self,
Variable::NumWorkgroupsY => *self,
Variable::NumWorkgroupsZ => *self,
}
}
}
@ -137,4 +154,13 @@ impl Item {
Vectorization::Scalar => Item::Scalar(self.elem()),
}
}
pub fn vectorized_size(&self, vectorize: Vectorization, size: u32) -> u32 {
match vectorize {
Vectorization::Vec4 => size / 4,
Vectorization::Vec3 => size / 3,
Vectorization::Vec2 => size / 2,
Vectorization::Scalar => size,
}
}
}
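`vectorized_size` keeps the total scalar count of a buffer constant: the element type widens while the slot count shrinks. The integer division truncates, so the declared size is implicitly assumed to be a multiple of the vector width. A worked sketch (`Elem::Float` is assumed to exist alongside the `Elem::UInt` used elsewhere):

    // 256 f32 scalars under Vec4 vectorization become 64 vec4 slots:
    // 64 * 4 = 256, so no storage is gained or lost.
    let slots = Item::Scalar(Elem::Float).vectorized_size(Vectorization::Vec4, 256);
    assert_eq!(slots, 64);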

View File

@ -17,8 +17,12 @@ pub enum Variable {
elem: Elem,
scope_depth: u8,
},
SharedMemory(u16, Item, u32),
Id,
LocalInvocationIndex,
LocalInvocationIdX,
LocalInvocationIdY,
LocalInvocationIdZ,
Rank,
WorkgroupIdX,
WorkgroupIdY,
@ -26,6 +30,12 @@ pub enum Variable {
GlobalInvocationIdX,
GlobalInvocationIdY,
GlobalInvocationIdZ,
WorkgroupSizeX,
WorkgroupSizeY,
WorkgroupSizeZ,
NumWorkgroupsX,
NumWorkgroupsY,
NumWorkgroupsZ,
}
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
@ -62,9 +72,13 @@ impl Variable {
} => true,
Variable::Id => true,
Variable::LocalInvocationIndex => true,
Variable::LocalInvocationIdX => true,
Variable::LocalInvocationIdY => true,
Variable::LocalInvocationIdZ => true,
Variable::Rank => true,
Variable::GlobalInputArray(_, _) => false,
Variable::GlobalOutputArray(_, _) => false,
Variable::SharedMemory(_, _, _) => false,
Variable::Local {
index: _,
item: _,
@ -76,6 +90,12 @@ impl Variable {
Variable::GlobalInvocationIdX => true,
Variable::GlobalInvocationIdY => true,
Variable::GlobalInvocationIdZ => true,
Variable::WorkgroupSizeX => true,
Variable::WorkgroupSizeY => true,
Variable::WorkgroupSizeZ => true,
Variable::NumWorkgroupsX => true,
Variable::NumWorkgroupsY => true,
Variable::NumWorkgroupsZ => true,
}
}
pub fn index(&self, index: usize) -> IndexedVariable {
@ -89,6 +109,7 @@ impl Variable {
match self {
Self::GlobalInputArray(_, e) => *e,
Self::GlobalOutputArray(_, e) => *e,
Self::SharedMemory(_, e, _) => *e,
Self::Local {
index: _,
item,
@ -98,6 +119,9 @@ impl Variable {
Self::GlobalScalar(_, e, _) => Item::Scalar(*e),
Self::Id => Item::Scalar(Elem::U32),
Self::LocalInvocationIndex => Item::Scalar(Elem::U32),
Self::LocalInvocationIdX => Item::Scalar(Elem::U32),
Self::LocalInvocationIdY => Item::Scalar(Elem::U32),
Self::LocalInvocationIdZ => Item::Scalar(Elem::U32),
Self::Rank => Item::Scalar(Elem::U32),
Self::LocalScalar {
index: _,
@ -110,6 +134,12 @@ impl Variable {
Self::GlobalInvocationIdX => Item::Scalar(Elem::U32),
Self::GlobalInvocationIdY => Item::Scalar(Elem::U32),
Self::GlobalInvocationIdZ => Item::Scalar(Elem::U32),
Self::WorkgroupSizeX => Item::Scalar(Elem::U32),
Self::WorkgroupSizeY => Item::Scalar(Elem::U32),
Self::WorkgroupSizeZ => Item::Scalar(Elem::U32),
Self::NumWorkgroupsX => Item::Scalar(Elem::U32),
Self::NumWorkgroupsY => Item::Scalar(Elem::U32),
Self::NumWorkgroupsZ => Item::Scalar(Elem::U32),
}
}
pub fn elem(&self) -> Elem {
@ -184,8 +214,14 @@ impl Display for Variable {
f.write_fmt(format_args!("scalars_{elem}[{number}]"))
}
Variable::ConstantScalar(number, elem) => f.write_fmt(format_args!("{elem}({number})")),
Variable::SharedMemory(number, _, _) => {
f.write_fmt(format_args!("shared_memory_{number}"))
}
Variable::Id => f.write_str("id"),
Variable::LocalInvocationIndex => f.write_str("local_idx"),
Variable::LocalInvocationIdX => f.write_str("local_invocation_id.x"),
Variable::LocalInvocationIdY => f.write_str("local_invocation_id.y"),
Variable::LocalInvocationIdZ => f.write_str("local_invocation_id.z"),
Variable::Rank => f.write_str("rank"),
Variable::WorkgroupIdX => f.write_str("workgroup_id.x"),
Variable::WorkgroupIdY => f.write_str("workgroup_id.y"),
@ -193,6 +229,12 @@ impl Display for Variable {
Variable::GlobalInvocationIdX => f.write_str("global_id.x"),
Variable::GlobalInvocationIdY => f.write_str("global_id.y"),
Variable::GlobalInvocationIdZ => f.write_str("global_id.z"),
Variable::WorkgroupSizeX => f.write_str("WORKGROUP_SIZE_X"),
Variable::WorkgroupSizeY => f.write_str("WORKGROUP_SIZE_Y"),
Variable::WorkgroupSizeZ => f.write_str("WORKGROUP_SIZE_Z"),
Variable::NumWorkgroupsX => f.write_str("num_workgroups.x"),
Variable::NumWorkgroupsY => f.write_str("num_workgroups.y"),
Variable::NumWorkgroupsZ => f.write_str("num_workgroups.z"),
}
}
}

View File

@ -1,4 +1,4 @@
-use super::{shader::ComputeShader, Item};
+use super::{shader::ComputeShader, Item, SharedMemory};
use crate::{
codegen::{
compiler,
@ -13,13 +13,16 @@ use std::marker::PhantomData;
pub struct Compiler<F: FloatElement, I: IntElement> {
num_inputs: usize,
num_outputs: usize,
-invocation_index: bool,
+local_invocation_index: bool,
+local_invocation_id: bool,
global_invocation_id: bool,
workgroup_id: bool,
rank: bool,
id: bool,
stride: bool,
shape: bool,
num_workgroups: bool,
shared_memories: Vec<SharedMemory>,
_float: PhantomData<F>,
_int: PhantomData<I>,
}
@ -35,13 +38,16 @@ impl<F: FloatElement, I: IntElement> Default for Compiler<F, I> {
Self {
num_inputs: 0,
num_outputs: 0,
-invocation_index: false,
+local_invocation_index: false,
+local_invocation_id: false,
global_invocation_id: false,
workgroup_id: false,
rank: false,
id: false,
stride: false,
shape: false,
num_workgroups: false,
shared_memories: Vec::default(),
_float: PhantomData,
_int: PhantomData,
}
@ -95,10 +101,12 @@ impl<F: FloatElement, I: IntElement> Compiler<F, I> {
.into_iter()
.map(|(name, binding)| (name, Self::compile_binding(binding)))
.collect(),
shared_memories: self.shared_memories.clone(),
workgroup_size: value.workgroup_size,
global_invocation_id: self.global_invocation_id || self.id,
-local_invocation_index: self.invocation_index,
-num_workgroups: self.id,
+local_invocation_index: self.local_invocation_index,
+local_invocation_id: self.local_invocation_id,
+num_workgroups: self.id || self.num_workgroups,
workgroup_id: self.workgroup_id,
body,
extensions,
@ -147,6 +155,12 @@ impl<F: FloatElement, I: IntElement> Compiler<F, I> {
gpu::Variable::ConstantScalar(index, elem) => {
wgsl::Variable::ConstantScalar(index, Self::compile_elem(elem))
}
gpu::Variable::SharedMemory(index, item, size) => {
let item = Self::compile_item(item);
let shared_memory = SharedMemory::new(index, item, size);
// Declare each shared-memory buffer once, even though the variable is
// compiled at every use site (relies on the derived PartialEq).
if !self.shared_memories.contains(&shared_memory) {
self.shared_memories.push(shared_memory);
}
wgsl::Variable::SharedMemory(index, item, size)
}
gpu::Variable::Id => {
self.id = true;
wgsl::Variable::Id
@ -155,10 +169,22 @@ impl<F: FloatElement, I: IntElement> Compiler<F, I> {
self.rank = true;
wgsl::Variable::Rank
}
-gpu::Variable::InvocationIndex => {
-self.invocation_index = true;
+gpu::Variable::LocalInvocationIndex => {
+self.local_invocation_index = true;
wgsl::Variable::LocalInvocationIndex
}
gpu::Variable::LocalInvocationIdX => {
self.local_invocation_id = true;
wgsl::Variable::LocalInvocationIdX
}
gpu::Variable::LocalInvocationIdY => {
self.local_invocation_id = true;
wgsl::Variable::LocalInvocationIdY
}
gpu::Variable::LocalInvocationIdZ => {
self.local_invocation_id = true;
wgsl::Variable::LocalInvocationIdZ
}
gpu::Variable::WorkgroupIdX => {
self.workgroup_id = true;
wgsl::Variable::WorkgroupIdX
@ -183,6 +209,21 @@ impl<F: FloatElement, I: IntElement> Compiler<F, I> {
self.global_invocation_id = true;
wgsl::Variable::GlobalInvocationIdZ
}
gpu::Variable::WorkgroupSizeX => wgsl::Variable::WorkgroupSizeX,
gpu::Variable::WorkgroupSizeY => wgsl::Variable::WorkgroupSizeY,
gpu::Variable::WorkgroupSizeZ => wgsl::Variable::WorkgroupSizeZ,
gpu::Variable::NumWorkgroupsX => {
self.num_workgroups = true;
wgsl::Variable::NumWorkgroupsX
}
gpu::Variable::NumWorkgroupsY => {
self.num_workgroups = true;
wgsl::Variable::NumWorkgroupsY
}
gpu::Variable::NumWorkgroupsZ => {
self.num_workgroups = true;
wgsl::Variable::NumWorkgroupsZ
}
}
}
@ -215,6 +256,7 @@ impl<F: FloatElement, I: IntElement> Compiler<F, I> {
gpu::Operation::Procedure(proc) => self.compile_procedure(instructions, proc, scope),
gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op)),
gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
gpu::Operation::Synchronization(val) => self.compile_synchronization(instructions, val),
}
}
@ -239,6 +281,22 @@ impl<F: FloatElement, I: IntElement> Compiler<F, I> {
instructions: self.compile_scope(&mut range_loop.scope),
})
}
gpu::Branch::WhileLoop(mut op) => instructions.push(wgsl::Instruction::WhileLoop {
cond: self.compile_variable(op.cond),
instructions: self.compile_scope(&mut op.scope),
}),
};
}
fn compile_synchronization(
&mut self,
instructions: &mut Vec<wgsl::Instruction>,
synchronization: gpu::Synchronization,
) {
match synchronization {
gpu::Synchronization::WorkgroupBarrier => {
instructions.push(wgsl::Instruction::WorkgroupBarrier)
}
};
}

View File

@ -24,6 +24,7 @@ pub enum Instruction {
},
Return,
Break,
WorkgroupBarrier,
// Index handles casting to correct local variable.
Index {
lhs: Variable,
@ -157,6 +158,10 @@ pub enum Instruction {
end: Variable,
instructions: Vec<Instruction>,
},
WhileLoop {
cond: Variable,
instructions: Vec<Instruction>,
},
}
impl Display for Instruction {
@ -371,9 +376,17 @@ for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{
}
Instruction::Return => f.write_str("return;\n"),
Instruction::Break => f.write_str("break;\n"),
Instruction::WorkgroupBarrier => f.write_str("workgroupBarrier();\n"),
Instruction::ArrayLength { var, out } => {
f.write_fmt(format_args!("{out} = arrayLength(&{var});\n"))
}
Instruction::WhileLoop { cond, instructions } => {
f.write_fmt(format_args!("while {cond} {{\n"))?;
for i in instructions {
f.write_fmt(format_args!("{i}"))?;
}
f.write_str("}\n")
}
}
}
}
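For reference, a sketch of what the two new arms print (`Variable::Id` stands in for a real boolean condition purely to keep the example self-contained):

    let wgsl_loop = Instruction::WhileLoop {
        cond: Variable::Id,
        instructions: vec![Instruction::WorkgroupBarrier],
    };
    // format!("{wgsl_loop}") yields:
    // while id {
    // workgroupBarrier();
    // }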

View File

@ -23,14 +23,35 @@ pub struct Binding {
pub size: Option<usize>,
}
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct SharedMemory {
location: Location,
index: u16,
item: Item,
size: u32,
}
impl SharedMemory {
pub fn new(index: u16, item: Item, size: u32) -> Self {
Self {
location: Location::Workgroup,
index,
item,
size,
}
}
}
#[derive(Debug, Clone)]
pub struct ComputeShader {
pub inputs: Vec<Binding>,
pub outputs: Vec<Binding>,
pub named: Vec<(String, Binding)>,
pub shared_memories: Vec<SharedMemory>,
pub workgroup_size: WorkgroupSize,
pub global_invocation_id: bool,
pub local_invocation_index: bool,
pub local_invocation_id: bool,
pub num_workgroups: bool,
pub workgroup_id: bool,
pub body: Body,
@ -51,6 +72,13 @@ impl Display for ComputeShader {
)?;
}
for shared_memory in self.shared_memories.iter() {
f.write_fmt(format_args!(
"var<{}> shared_memory_{}: array<{}, {}>;\n\n",
shared_memory.location, shared_memory.index, shared_memory.item, shared_memory.size
))?;
}
f.write_fmt(format_args!(
"const WORKGROUP_SIZE_X = {}u;
const WORKGROUP_SIZE_Y = {}u;
@ -75,6 +103,10 @@ fn main(
f.write_str(" @builtin(local_invocation_index) local_idx: u32,\n")?;
}
if self.local_invocation_id {
f.write_str(" @builtin(local_invocation_id) local_invocation_id: vec3<u32>,\n")?;
}
if self.num_workgroups {
f.write_str(" @builtin(num_workgroups) num_workgroups: vec3<u32>,\n")?;
}

View File

@ -186,6 +186,7 @@ where
}
let source = kernel.source().complete();
println!("{}", source);
let pipeline = self.compile_source(&source);
self.pipelines.insert(kernel_id.clone(), pipeline.clone());

View File

@ -329,6 +329,9 @@ impl TraceBuilder {
Operation::Branch(_) => {
// Nothing to do, should never impact read-write access to bindings.
}
Operation::Synchronization(_) => {
// Nothing to do, should never impact read-write access to bindings.
}
}
}

View File

@ -30,7 +30,7 @@ struct MatmulComputeShader {
impl MatmulComputeShader {
fn expand(self, scope: &mut Scope) {
// Define out global variables.
-let local_idx = Variable::InvocationIndex;
+let local_idx = Variable::LocalInvocationIndex;
let batch = Variable::GlobalInvocationIdZ;
let rank = Variable::Rank;
let block_size: Variable = self.block_size.into();

View File

@ -1,11 +1,12 @@
mod base;
+mod naive_reduce_shader;
mod reduction;
mod reduction_shared_memory;
-mod shader;
mod tune;
+mod workgroup_reduce_shader;
pub use base::*;
+pub(crate) use naive_reduce_shader::*;
pub use reduction::*;
pub use reduction_shared_memory::*;
-pub(crate) use shader::*;
pub use tune::*;

View File

@ -10,16 +10,18 @@ use crate::{
Runtime,
};
-pub(crate) trait ReduceDim: Send + Sync + 'static {
+pub(crate) trait NaiveReduceDim: Send + Sync + 'static {
type Accumulator: Copy;
fn initialize(scope: &mut Scope, input_item: Item, output_item: Item) -> Self::Accumulator;
fn inner_loop(
scope: &mut Scope,
accumulator: Self::Accumulator,
current_value: Variable,
i: Variable,
);
fn assign(
scope: &mut Scope,
output: Variable,
@ -30,14 +32,14 @@ pub(crate) trait ReduceDim: Send + Sync + 'static {
pub(crate) struct SumDim;
-impl ReduceDim for SumDim {
+impl NaiveReduceDim for SumDim {
type Accumulator = Variable;
fn initialize(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable {
scope.zero(output_item)
}
-fn inner_loop(scope: &mut Scope, accumulator: Variable, value: Variable, i: Variable) {
+fn inner_loop(scope: &mut Scope, accumulator: Variable, value: Variable, _i: Variable) {
gpu!(scope, accumulator += value);
}
@ -54,7 +56,7 @@ impl ReduceDim for SumDim {
pub(crate) struct MeanDim;
-impl ReduceDim for MeanDim {
+impl NaiveReduceDim for MeanDim {
type Accumulator = Variable;
fn initialize(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable {
@ -81,10 +83,10 @@ impl ReduceDim for MeanDim {
pub(crate) struct ArgMax;
-impl ReduceDim for ArgMax {
+impl NaiveReduceDim for ArgMax {
type Accumulator = (Variable, Variable);
-fn initialize(scope: &mut Scope, input_item: Item, output_item: Item) -> Self::Accumulator {
+fn initialize(scope: &mut Scope, input_item: Item, _output_item: Item) -> Self::Accumulator {
let max = scope.create_local(input_item);
let index = scope.create_local(Elem::UInt);
gpu!(scope, max = cast(-32767.0));
@ -118,10 +120,10 @@ impl ReduceDim for ArgMax {
pub(crate) struct ArgMin;
-impl ReduceDim for ArgMin {
+impl NaiveReduceDim for ArgMin {
type Accumulator = (Variable, Variable);
-fn initialize(scope: &mut Scope, input_item: Item, output_item: Item) -> Self::Accumulator {
+fn initialize(scope: &mut Scope, input_item: Item, _output_item: Item) -> Self::Accumulator {
let min = scope.create_local(input_item);
let index = scope.create_local(Elem::UInt);
gpu!(scope, min = cast(32767.0));
@ -153,15 +155,20 @@ impl ReduceDim for ArgMin {
}
}
-pub(crate) struct ReduceDimComputeShader<RD: ReduceDim> {
+pub(crate) struct NaiveReduceDimComputeShader<RD: NaiveReduceDim> {
tensor: Variable,
dim: usize,
output: Variable,
-reduce_dim: PhantomData<RD>,
+_reduce_dim: PhantomData<RD>,
}
#[derive(new)]
-pub(crate) struct ReduceDimEagerKernel<RD: ReduceDim, R: Runtime, EI: JitElement, EO: JitElement> {
+pub(crate) struct NaiveReduceDimEagerKernel<
+RD: NaiveReduceDim,
+R: Runtime,
+EI: JitElement,
+EO: JitElement,
+> {
dim: usize,
reduce_dim: PhantomData<RD>,
_runtime: PhantomData<R>,
@ -169,8 +176,8 @@ pub(crate) struct ReduceDimEagerKernel<RD: ReduceDim, R: Runtime, EI: JitElement
_elem_out: PhantomData<EO>,
}
-impl<RD: ReduceDim, R: Runtime, EI: JitElement, EO: JitElement> DynamicKernelSource
-for ReduceDimEagerKernel<RD, R, EI, EO>
+impl<RD: NaiveReduceDim, R: Runtime, EI: JitElement, EO: JitElement> DynamicKernelSource
+for NaiveReduceDimEagerKernel<RD, R, EI, EO>
{
fn source(&self) -> crate::kernel::SourceTemplate {
let mut scope = Scope::root();
@ -180,11 +187,11 @@ impl<RD: ReduceDim, R: Runtime, EI: JitElement, EO: JitElement> DynamicKernelSou
let tensor = Variable::GlobalInputArray(0, item_input);
let output = Variable::GlobalOutputArray(0, item_output);
-ReduceDimComputeShader {
+NaiveReduceDimComputeShader {
tensor,
dim: self.dim,
output,
-reduce_dim: PhantomData::<RD>,
+_reduce_dim: PhantomData::<RD>,
}
.expand(&mut scope);
@ -214,7 +221,7 @@ impl<RD: ReduceDim, R: Runtime, EI: JitElement, EO: JitElement> DynamicKernelSou
}
}
-impl<RD: ReduceDim> ReduceDimComputeShader<RD> {
+impl<RD: NaiveReduceDim> NaiveReduceDimComputeShader<RD> {
pub(crate) fn expand(self, scope: &mut Scope) {
let tensor = self.tensor;
let dim: Variable = self.dim.into();

View File

@ -1,13 +1,16 @@
use crate::{
codegen::{execute_dynamic, EagerHandle, WorkgroupLaunch},
element::JitElement,
-kernel::reduce,
tensor::JitTensor,
Runtime,
};
use burn_tensor::Shape;
-use super::{init_reduce_output, ArgMax, ArgMin, MeanDim, ReduceDim, ReduceDimEagerKernel, SumDim};
+#[cfg(not(feature = "autotune"))]
+use super::init_reduce_output;
+#[cfg(feature = "autotune")]
+use super::tune::{mean_dim_autotune, sum_dim_autotune};
+use super::{ArgMax, ArgMin, MeanDim, NaiveReduceDim, NaiveReduceDimEagerKernel, SumDim};
/// Sum all elements in the input buffer.
pub fn sum<R: Runtime, E: JitElement, const D: usize>(
@ -18,13 +21,14 @@ pub fn sum<R: Runtime, E: JitElement, const D: usize>(
sum_dim(input, 0)
}
/// Sum all elements on one dimension. Autotunable
pub fn sum_dim<R: Runtime, E: JitElement, const D: usize>(
tensor: JitTensor<R, E, D>,
dim: usize,
) -> JitTensor<R, E, D> {
#[cfg(feature = "autotune")]
{
-reduce::sum_dim_autotune(tensor, dim)
+sum_dim_autotune(tensor, dim)
}
#[cfg(not(feature = "autotune"))]
@ -34,13 +38,14 @@ pub fn sum_dim<R: Runtime, E: JitElement, const D: usize>(
}
}
/// Mean of all elements on one dimension. Autotunable
pub fn mean_dim<R: Runtime, E: JitElement, const D: usize>(
tensor: JitTensor<R, E, D>,
dim: usize,
) -> JitTensor<R, E, D> {
#[cfg(feature = "autotune")]
{
-reduce::mean_dim_autotune(tensor, dim)
+mean_dim_autotune(tensor, dim)
}
#[cfg(not(feature = "autotune"))]
@ -60,7 +65,7 @@ pub fn sum_dim_naive<R: Runtime, E: JitElement, const D: usize>(
}
pub(crate) fn reduce_dim_naive<
-RD: ReduceDim,
+RD: NaiveReduceDim,
R: Runtime,
EI: JitElement,
EO: JitElement,
@ -70,9 +75,9 @@ pub(crate) fn reduce_dim_naive<
output: JitTensor<R, EO, D>,
dim: usize,
) -> JitTensor<R, EO, D> {
-let kernel = ReduceDimEagerKernel::new(dim);
+let kernel = NaiveReduceDimEagerKernel::new(dim);
-execute_dynamic::<R, ReduceDimEagerKernel<RD, R, EI, EO>, EI>(
+execute_dynamic::<R, NaiveReduceDimEagerKernel<RD, R, EI, EO>, EI>(
&[EagerHandle::new(
&input.handle,
&input.strides,

View File

@ -116,6 +116,7 @@ fn reduction_dim_shared_memory<K: StaticKernelSource, R: Runtime, E: JitElement,
let n_invocation_per_workgroup = WORKGROUP_DEFAULT * WORKGROUP_DEFAULT;
let n_reduce_elements_per_thread =
f32::ceil(reduce_group_size as f32 / n_invocation_per_workgroup as f32) as u32;
// n_reduce_elements_per_thread = ceil(reduce_group_size / n_invocation_per_workgroup)
// Add dimension of reduction and how many reduce elements are treated per thread
info.push(reduce_dim as u32);
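A worked example of the split, assuming `WORKGROUP_DEFAULT` is 16:

    // 16 * 16 = 256 invocations per workgroup, so reducing a dimension of
    // size 1000 assigns ceil(1000 / 256) = 4 elements to each thread.
    let n_invocation_per_workgroup = 16 * 16;
    let n_reduce_elements_per_thread =
        f32::ceil(1000_f32 / n_invocation_per_workgroup as f32) as u32;
    assert_eq!(n_reduce_elements_per_thread, 4);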

View File

@ -7,7 +7,7 @@ use crate::{
kernel::{
prng::{random_like_uniform, random_like_uniform_int},
reduce::{
-init_reduce_output, int_mean_dim_naive, int_mean_dim_shared_memory, mean_dim, mean_dim_naive,
+init_reduce_output, int_mean_dim_naive, int_mean_dim_shared_memory, mean_dim_naive,
mean_dim_shared_memory,
},
},

View File

@ -7,8 +7,6 @@ use crate::kernel::matmul::matmul_autotune;
#[cfg(not(feature = "autotune"))]
use crate::kernel::matmul::vec4::matmul_tiling_2d_vec4;
use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
-// #[cfg(not(feature = "autotune"))]
-use crate::kernel::reduce::init_reduce_output;
use crate::kernel::{self, reduce};
use crate::tensor::JitTensor;
use crate::Runtime;