mirror of https://github.com/tracel-ai/burn.git
instruction set for reduce workgroup
This commit is contained in:
parent
67de96130c
commit
2773f26433
|
@ -3,7 +3,10 @@ use crate::fusion::JitFusionHandle;
|
|||
#[cfg(feature = "fusion")]
|
||||
use burn_fusion::TensorDescription;
|
||||
|
||||
use super::{dialect::gpu, Compiler};
|
||||
use super::{
|
||||
dialect::gpu::{self},
|
||||
Compiler,
|
||||
};
|
||||
use crate::{
|
||||
codegen::dialect::gpu::{
|
||||
Binding, ComputeShader, Elem, Item, Location, ReadingStrategy, Variable, Vectorization,
|
||||
|
|
|
@ -10,6 +10,8 @@ pub enum Branch {
|
|||
IfElse(IfElse),
|
||||
// A range loop.
|
||||
RangeLoop(RangeLoop),
|
||||
// A while loop.
|
||||
WhileLoop(WhileLoop),
|
||||
// A return statement.
|
||||
Return,
|
||||
// A break statement.
|
||||
|
@ -37,6 +39,12 @@ pub struct RangeLoop {
|
|||
pub scope: Scope,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub struct WhileLoop {
|
||||
pub cond: Variable,
|
||||
pub scope: Scope,
|
||||
}
|
||||
|
||||
impl If {
|
||||
/// Registers an if statement to the given scope.
|
||||
pub fn register<F: Fn(&mut Scope)>(parent_scope: &mut Scope, cond: Variable, func: F) {
|
||||
|
@ -96,3 +104,15 @@ impl RangeLoop {
|
|||
}));
|
||||
}
|
||||
}
|
||||
|
||||
impl WhileLoop {
|
||||
/// Registers a while loop to the given scope.
|
||||
pub fn register<F: Fn(&mut Scope)>(parent_scope: &mut Scope, cond: Variable, func: F) {
|
||||
let mut scope = parent_scope.child();
|
||||
|
||||
func(&mut scope);
|
||||
|
||||
let op = Self { cond, scope };
|
||||
parent_scope.register(Branch::WhileLoop(op));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -232,6 +232,10 @@ macro_rules! gpu {
|
|||
($scope:expr, range($start:expr, $end:expr).for_each($arg:expr)) => {
|
||||
$crate::codegen::dialect::gpu::RangeLoop::register($scope, $start.into(), $end.into(), $arg);
|
||||
};
|
||||
// while (cond).then(|scope| { ... })
|
||||
($scope:expr, while ($cond:expr).then($arg:expr)) => {
|
||||
$crate::codegen::dialect::gpu::WhileLoop::register($scope, $cond.into(), $arg);
|
||||
};
|
||||
// if (cond).then(|scope| { ... })
|
||||
($scope:expr, if ($cond:expr).then($arg:expr)) => {
|
||||
$crate::codegen::dialect::gpu::If::register($scope, $cond.into(), $arg);
|
||||
|
|
|
@ -5,6 +5,7 @@ mod procedure;
|
|||
mod processing;
|
||||
mod scope;
|
||||
mod shader;
|
||||
mod synchronization;
|
||||
mod variable;
|
||||
mod vectorization;
|
||||
|
||||
|
@ -14,5 +15,6 @@ pub(crate) use operation::*;
|
|||
pub(crate) use procedure::*;
|
||||
pub(crate) use scope::*;
|
||||
pub(crate) use shader::*;
|
||||
pub(crate) use synchronization::*;
|
||||
pub(crate) use variable::*;
|
||||
pub(crate) use vectorization::*;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use super::{Branch, Procedure, Variable};
|
||||
use super::{Branch, Procedure, Synchronization, Variable};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// All operations that can be used in a GPU compute shader.
|
||||
|
@ -16,6 +16,7 @@ pub enum Operation {
|
|||
Procedure(Procedure),
|
||||
Metadata(Metadata),
|
||||
Branch(Branch),
|
||||
Synchronization(Synchronization),
|
||||
}
|
||||
|
||||
/// All operators that can be used in a GPU compute shader.
|
||||
|
@ -115,6 +116,12 @@ impl From<Branch> for Operation {
|
|||
}
|
||||
}
|
||||
|
||||
impl From<Synchronization> for Operation {
|
||||
fn from(value: Synchronization) -> Self {
|
||||
Self::Synchronization(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Metadata> for Operation {
|
||||
fn from(val: Metadata) -> Self {
|
||||
Operation::Metadata(val)
|
||||
|
|
|
@ -17,6 +17,7 @@ pub struct Scope {
|
|||
pub depth: u8,
|
||||
pub operations: Vec<Operation>,
|
||||
locals: Vec<Variable>,
|
||||
shared: Vec<Variable>,
|
||||
reads_global: Vec<(Variable, ReadingStrategy, Variable)>,
|
||||
index_offset_with_output_layout_position: Vec<usize>,
|
||||
writes_global: Vec<(Variable, Variable)>,
|
||||
|
@ -43,6 +44,7 @@ impl Scope {
|
|||
depth: 0,
|
||||
operations: Vec::new(),
|
||||
locals: Vec::new(),
|
||||
shared: Vec::new(),
|
||||
reads_global: Vec::new(),
|
||||
index_offset_with_output_layout_position: Vec::new(),
|
||||
writes_global: Vec::new(),
|
||||
|
@ -58,6 +60,7 @@ impl Scope {
|
|||
gpu!(self, local = zero);
|
||||
local
|
||||
}
|
||||
|
||||
/// Create a local variable of the given [item type](Item).
|
||||
pub fn create_local<I: Into<Item>>(&mut self, item: I) -> Variable {
|
||||
let item = item.into();
|
||||
|
@ -191,6 +194,7 @@ impl Scope {
|
|||
depth: self.depth + 1,
|
||||
operations: Vec::new(),
|
||||
locals: Vec::new(),
|
||||
shared: Vec::new(),
|
||||
reads_global: Vec::new(),
|
||||
index_offset_with_output_layout_position: Vec::new(),
|
||||
writes_global: Vec::new(),
|
||||
|
@ -284,6 +288,10 @@ impl Scope {
|
|||
self.reads_scalar.len() as u16
|
||||
}
|
||||
|
||||
fn new_shared_index(&self) -> u16 {
|
||||
self.shared.len() as u16
|
||||
}
|
||||
|
||||
fn read_input_strategy(
|
||||
&mut self,
|
||||
index: u16,
|
||||
|
@ -306,4 +314,13 @@ impl Scope {
|
|||
self.locals.push(local);
|
||||
local
|
||||
}
|
||||
|
||||
/// Create a local variable of the given [item type](Item).
|
||||
pub fn create_shared<I: Into<Item>>(&mut self, item: I, shared_memory_size: u32) -> Variable {
|
||||
let item = item.into();
|
||||
let index = self.new_shared_index();
|
||||
let shared_memory = Variable::SharedMemory(index, item, shared_memory_size);
|
||||
self.shared.push(shared_memory);
|
||||
shared_memory
|
||||
}
|
||||
}
|
||||
|
|
|
@ -7,7 +7,6 @@ use super::Scope;
|
|||
#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
|
||||
pub enum Location {
|
||||
Storage,
|
||||
#[allow(dead_code)]
|
||||
Workgroup,
|
||||
}
|
||||
|
||||
|
@ -86,6 +85,12 @@ impl Default for WorkgroupSize {
|
|||
}
|
||||
}
|
||||
|
||||
impl WorkgroupSize {
|
||||
pub fn num_elements(&self) -> u32 {
|
||||
self.x * self.y * self.z
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ComputeShader {
|
||||
pub inputs: Vec<Binding>,
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// All synchronization types.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
pub enum Synchronization {
|
||||
// A workgroup barrier
|
||||
WorkgroupBarrier,
|
||||
}
|
|
@ -9,8 +9,12 @@ pub enum Variable {
|
|||
Local(u16, Item, u8),
|
||||
LocalScalar(u16, Elem, u8),
|
||||
ConstantScalar(f64, Elem),
|
||||
SharedMemory(u16, Item, u32),
|
||||
Id,
|
||||
InvocationIndex,
|
||||
LocalInvocationIndex,
|
||||
LocalInvocationIdX,
|
||||
LocalInvocationIdY,
|
||||
LocalInvocationIdZ,
|
||||
WorkgroupIdX,
|
||||
WorkgroupIdY,
|
||||
WorkgroupIdZ,
|
||||
|
@ -18,6 +22,12 @@ pub enum Variable {
|
|||
GlobalInvocationIdY,
|
||||
GlobalInvocationIdZ,
|
||||
Rank,
|
||||
WorkgroupSizeX,
|
||||
WorkgroupSizeY,
|
||||
WorkgroupSizeZ,
|
||||
NumWorkgroupsX,
|
||||
NumWorkgroupsY,
|
||||
NumWorkgroupsZ,
|
||||
}
|
||||
|
||||
impl Variable {
|
||||
|
@ -35,8 +45,12 @@ impl Variable {
|
|||
Variable::LocalScalar(idx, _, _) => Some(*idx),
|
||||
Variable::GlobalOutputArray(idx, _) => Some(*idx),
|
||||
Variable::ConstantScalar(_, _) => None,
|
||||
Variable::SharedMemory(idx, _, _) => Some(*idx),
|
||||
Variable::Id => None,
|
||||
Variable::InvocationIndex => None,
|
||||
Variable::LocalInvocationIndex => None,
|
||||
Variable::LocalInvocationIdX => None,
|
||||
Variable::LocalInvocationIdY => None,
|
||||
Variable::LocalInvocationIdZ => None,
|
||||
Variable::Rank => None,
|
||||
Variable::WorkgroupIdX => None,
|
||||
Variable::WorkgroupIdY => None,
|
||||
|
@ -44,6 +58,12 @@ impl Variable {
|
|||
Variable::GlobalInvocationIdX => None,
|
||||
Variable::GlobalInvocationIdY => None,
|
||||
Variable::GlobalInvocationIdZ => None,
|
||||
Variable::WorkgroupSizeX => None,
|
||||
Variable::WorkgroupSizeY => None,
|
||||
Variable::WorkgroupSizeZ => None,
|
||||
Variable::NumWorkgroupsX => None,
|
||||
Variable::NumWorkgroupsY => None,
|
||||
Variable::NumWorkgroupsZ => None,
|
||||
}
|
||||
}
|
||||
pub fn item(&self) -> Item {
|
||||
|
@ -54,15 +74,25 @@ impl Variable {
|
|||
Variable::Local(_, item, _) => *item,
|
||||
Variable::LocalScalar(_, elem, _) => Item::Scalar(*elem),
|
||||
Variable::ConstantScalar(_, elem) => Item::Scalar(*elem),
|
||||
Variable::SharedMemory(_, item, _) => *item,
|
||||
Variable::Id => Item::Scalar(Elem::UInt),
|
||||
Variable::Rank => Item::Scalar(Elem::UInt),
|
||||
Variable::InvocationIndex => Item::Scalar(Elem::UInt),
|
||||
Variable::LocalInvocationIndex => Item::Scalar(Elem::UInt),
|
||||
Variable::LocalInvocationIdX => Item::Scalar(Elem::UInt),
|
||||
Variable::LocalInvocationIdY => Item::Scalar(Elem::UInt),
|
||||
Variable::LocalInvocationIdZ => Item::Scalar(Elem::UInt),
|
||||
Variable::WorkgroupIdX => Item::Scalar(Elem::UInt),
|
||||
Variable::WorkgroupIdY => Item::Scalar(Elem::UInt),
|
||||
Variable::WorkgroupIdZ => Item::Scalar(Elem::UInt),
|
||||
Variable::GlobalInvocationIdX => Item::Scalar(Elem::UInt),
|
||||
Variable::GlobalInvocationIdY => Item::Scalar(Elem::UInt),
|
||||
Variable::GlobalInvocationIdZ => Item::Scalar(Elem::UInt),
|
||||
Variable::WorkgroupSizeX => Item::Scalar(Elem::UInt),
|
||||
Variable::WorkgroupSizeY => Item::Scalar(Elem::UInt),
|
||||
Variable::WorkgroupSizeZ => Item::Scalar(Elem::UInt),
|
||||
Variable::NumWorkgroupsX => Item::Scalar(Elem::UInt),
|
||||
Variable::NumWorkgroupsY => Item::Scalar(Elem::UInt),
|
||||
Variable::NumWorkgroupsZ => Item::Scalar(Elem::UInt),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -26,6 +26,9 @@ impl Operation {
|
|||
Operation::Branch(_) => panic!(
|
||||
"A branch can't be vectorized, they should only be generated after vectorization."
|
||||
),
|
||||
Operation::Synchronization(_) => panic!(
|
||||
"Synchronization instructions can't be vectorized, they should only be generated after vectorization."
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -112,18 +115,32 @@ impl Variable {
|
|||
Variable::GlobalOutputArray(index, item) => {
|
||||
Variable::GlobalOutputArray(*index, item.vectorize(vectorize))
|
||||
}
|
||||
Variable::SharedMemory(index, item, size) => Variable::SharedMemory(
|
||||
*index,
|
||||
item.vectorize(vectorize),
|
||||
item.vectorized_size(vectorize, *size),
|
||||
),
|
||||
Variable::ConstantScalar(_, _) => *self,
|
||||
Variable::GlobalScalar(_, _) => *self,
|
||||
Variable::Id => *self,
|
||||
Variable::Rank => *self,
|
||||
Variable::LocalScalar(_, _, _) => *self,
|
||||
Variable::InvocationIndex => *self,
|
||||
Variable::LocalInvocationIndex => *self,
|
||||
Variable::LocalInvocationIdX => *self,
|
||||
Variable::LocalInvocationIdY => *self,
|
||||
Variable::LocalInvocationIdZ => *self,
|
||||
Variable::WorkgroupIdX => *self,
|
||||
Variable::WorkgroupIdY => *self,
|
||||
Variable::WorkgroupIdZ => *self,
|
||||
Variable::GlobalInvocationIdX => *self,
|
||||
Variable::GlobalInvocationIdY => *self,
|
||||
Variable::GlobalInvocationIdZ => *self,
|
||||
Variable::WorkgroupSizeX => *self,
|
||||
Variable::WorkgroupSizeY => *self,
|
||||
Variable::WorkgroupSizeZ => *self,
|
||||
Variable::NumWorkgroupsX => *self,
|
||||
Variable::NumWorkgroupsY => *self,
|
||||
Variable::NumWorkgroupsZ => *self,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -137,4 +154,13 @@ impl Item {
|
|||
Vectorization::Scalar => Item::Scalar(self.elem()),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn vectorized_size(&self, vectorize: Vectorization, size: u32) -> u32 {
|
||||
match vectorize {
|
||||
Vectorization::Vec4 => size / 4,
|
||||
Vectorization::Vec3 => size / 3,
|
||||
Vectorization::Vec2 => size / 2,
|
||||
Vectorization::Scalar => size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -17,8 +17,12 @@ pub enum Variable {
|
|||
elem: Elem,
|
||||
scope_depth: u8,
|
||||
},
|
||||
SharedMemory(u16, Item, u32),
|
||||
Id,
|
||||
LocalInvocationIndex,
|
||||
LocalInvocationIdX,
|
||||
LocalInvocationIdY,
|
||||
LocalInvocationIdZ,
|
||||
Rank,
|
||||
WorkgroupIdX,
|
||||
WorkgroupIdY,
|
||||
|
@ -26,6 +30,12 @@ pub enum Variable {
|
|||
GlobalInvocationIdX,
|
||||
GlobalInvocationIdY,
|
||||
GlobalInvocationIdZ,
|
||||
WorkgroupSizeX,
|
||||
WorkgroupSizeY,
|
||||
WorkgroupSizeZ,
|
||||
NumWorkgroupsX,
|
||||
NumWorkgroupsY,
|
||||
NumWorkgroupsZ,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Copy)]
|
||||
|
@ -62,9 +72,13 @@ impl Variable {
|
|||
} => true,
|
||||
Variable::Id => true,
|
||||
Variable::LocalInvocationIndex => true,
|
||||
Variable::LocalInvocationIdX => true,
|
||||
Variable::LocalInvocationIdY => true,
|
||||
Variable::LocalInvocationIdZ => true,
|
||||
Variable::Rank => true,
|
||||
Variable::GlobalInputArray(_, _) => false,
|
||||
Variable::GlobalOutputArray(_, _) => false,
|
||||
Variable::SharedMemory(_, _, _) => false,
|
||||
Variable::Local {
|
||||
index: _,
|
||||
item: _,
|
||||
|
@ -76,6 +90,12 @@ impl Variable {
|
|||
Variable::GlobalInvocationIdX => true,
|
||||
Variable::GlobalInvocationIdY => true,
|
||||
Variable::GlobalInvocationIdZ => true,
|
||||
Variable::WorkgroupSizeX => true,
|
||||
Variable::WorkgroupSizeY => true,
|
||||
Variable::WorkgroupSizeZ => true,
|
||||
Variable::NumWorkgroupsX => true,
|
||||
Variable::NumWorkgroupsY => true,
|
||||
Variable::NumWorkgroupsZ => true,
|
||||
}
|
||||
}
|
||||
pub fn index(&self, index: usize) -> IndexedVariable {
|
||||
|
@ -89,6 +109,7 @@ impl Variable {
|
|||
match self {
|
||||
Self::GlobalInputArray(_, e) => *e,
|
||||
Self::GlobalOutputArray(_, e) => *e,
|
||||
Self::SharedMemory(_, e, _) => *e,
|
||||
Self::Local {
|
||||
index: _,
|
||||
item,
|
||||
|
@ -98,6 +119,9 @@ impl Variable {
|
|||
Self::GlobalScalar(_, e, _) => Item::Scalar(*e),
|
||||
Self::Id => Item::Scalar(Elem::U32),
|
||||
Self::LocalInvocationIndex => Item::Scalar(Elem::U32),
|
||||
Self::LocalInvocationIdX => Item::Scalar(Elem::U32),
|
||||
Self::LocalInvocationIdY => Item::Scalar(Elem::U32),
|
||||
Self::LocalInvocationIdZ => Item::Scalar(Elem::U32),
|
||||
Self::Rank => Item::Scalar(Elem::U32),
|
||||
Self::LocalScalar {
|
||||
index: _,
|
||||
|
@ -110,6 +134,12 @@ impl Variable {
|
|||
Self::GlobalInvocationIdX => Item::Scalar(Elem::U32),
|
||||
Self::GlobalInvocationIdY => Item::Scalar(Elem::U32),
|
||||
Self::GlobalInvocationIdZ => Item::Scalar(Elem::U32),
|
||||
Self::WorkgroupSizeX => Item::Scalar(Elem::U32),
|
||||
Self::WorkgroupSizeY => Item::Scalar(Elem::U32),
|
||||
Self::WorkgroupSizeZ => Item::Scalar(Elem::U32),
|
||||
Self::NumWorkgroupsX => Item::Scalar(Elem::U32),
|
||||
Self::NumWorkgroupsY => Item::Scalar(Elem::U32),
|
||||
Self::NumWorkgroupsZ => Item::Scalar(Elem::U32),
|
||||
}
|
||||
}
|
||||
pub fn elem(&self) -> Elem {
|
||||
|
@ -184,8 +214,14 @@ impl Display for Variable {
|
|||
f.write_fmt(format_args!("scalars_{elem}[{number}]"))
|
||||
}
|
||||
Variable::ConstantScalar(number, elem) => f.write_fmt(format_args!("{elem}({number})")),
|
||||
Variable::SharedMemory(number, _, _) => {
|
||||
f.write_fmt(format_args!("shared_memory_{number}"))
|
||||
}
|
||||
Variable::Id => f.write_str("id"),
|
||||
Variable::LocalInvocationIndex => f.write_str("local_idx"),
|
||||
Variable::LocalInvocationIdX => f.write_str("local_invocation_id.x"),
|
||||
Variable::LocalInvocationIdY => f.write_str("local_invocation_id.y"),
|
||||
Variable::LocalInvocationIdZ => f.write_str("local_invocation_id.z"),
|
||||
Variable::Rank => f.write_str("rank"),
|
||||
Variable::WorkgroupIdX => f.write_str("workgroup_id.x"),
|
||||
Variable::WorkgroupIdY => f.write_str("workgroup_id.y"),
|
||||
|
@ -193,6 +229,12 @@ impl Display for Variable {
|
|||
Variable::GlobalInvocationIdX => f.write_str("global_id.x"),
|
||||
Variable::GlobalInvocationIdY => f.write_str("global_id.y"),
|
||||
Variable::GlobalInvocationIdZ => f.write_str("global_id.z"),
|
||||
Variable::WorkgroupSizeX => f.write_str("WORKGROUP_SIZE_X"),
|
||||
Variable::WorkgroupSizeY => f.write_str("WORKGROUP_SIZE_Y"),
|
||||
Variable::WorkgroupSizeZ => f.write_str("WORKGROUP_SIZE_Z"),
|
||||
Variable::NumWorkgroupsX => f.write_str("num_workgroups.x"),
|
||||
Variable::NumWorkgroupsY => f.write_str("num_workgroups.y"),
|
||||
Variable::NumWorkgroupsZ => f.write_str("num_workgroups.z"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use super::{shader::ComputeShader, Item};
|
||||
use super::{shader::ComputeShader, Item, SharedMemory};
|
||||
use crate::{
|
||||
codegen::{
|
||||
compiler,
|
||||
|
@ -13,13 +13,16 @@ use std::marker::PhantomData;
|
|||
pub struct Compiler<F: FloatElement, I: IntElement> {
|
||||
num_inputs: usize,
|
||||
num_outputs: usize,
|
||||
invocation_index: bool,
|
||||
local_invocation_index: bool,
|
||||
local_invocation_id: bool,
|
||||
global_invocation_id: bool,
|
||||
workgroup_id: bool,
|
||||
rank: bool,
|
||||
id: bool,
|
||||
stride: bool,
|
||||
shape: bool,
|
||||
num_workgroups: bool,
|
||||
shared_memories: Vec<SharedMemory>,
|
||||
_float: PhantomData<F>,
|
||||
_int: PhantomData<I>,
|
||||
}
|
||||
|
@ -35,13 +38,16 @@ impl<F: FloatElement, I: IntElement> Default for Compiler<F, I> {
|
|||
Self {
|
||||
num_inputs: 0,
|
||||
num_outputs: 0,
|
||||
invocation_index: false,
|
||||
local_invocation_index: false,
|
||||
local_invocation_id: false,
|
||||
global_invocation_id: false,
|
||||
workgroup_id: false,
|
||||
rank: false,
|
||||
id: false,
|
||||
stride: false,
|
||||
shape: false,
|
||||
num_workgroups: false,
|
||||
shared_memories: Vec::default(),
|
||||
_float: PhantomData,
|
||||
_int: PhantomData,
|
||||
}
|
||||
|
@ -95,10 +101,12 @@ impl<F: FloatElement, I: IntElement> Compiler<F, I> {
|
|||
.into_iter()
|
||||
.map(|(name, binding)| (name, Self::compile_binding(binding)))
|
||||
.collect(),
|
||||
shared_memories: self.shared_memories.clone(),
|
||||
workgroup_size: value.workgroup_size,
|
||||
global_invocation_id: self.global_invocation_id || self.id,
|
||||
local_invocation_index: self.invocation_index,
|
||||
num_workgroups: self.id,
|
||||
local_invocation_index: self.local_invocation_index,
|
||||
local_invocation_id: self.local_invocation_id,
|
||||
num_workgroups: self.id || self.num_workgroups,
|
||||
workgroup_id: self.workgroup_id,
|
||||
body,
|
||||
extensions,
|
||||
|
@ -147,6 +155,12 @@ impl<F: FloatElement, I: IntElement> Compiler<F, I> {
|
|||
gpu::Variable::ConstantScalar(index, elem) => {
|
||||
wgsl::Variable::ConstantScalar(index, Self::compile_elem(elem))
|
||||
}
|
||||
gpu::Variable::SharedMemory(index, item, size) => {
|
||||
let item = Self::compile_item(item);
|
||||
self.shared_memories
|
||||
.push(SharedMemory::new(index, item, size));
|
||||
wgsl::Variable::SharedMemory(index, item, size)
|
||||
}
|
||||
gpu::Variable::Id => {
|
||||
self.id = true;
|
||||
wgsl::Variable::Id
|
||||
|
@ -155,10 +169,22 @@ impl<F: FloatElement, I: IntElement> Compiler<F, I> {
|
|||
self.rank = true;
|
||||
wgsl::Variable::Rank
|
||||
}
|
||||
gpu::Variable::InvocationIndex => {
|
||||
self.invocation_index = true;
|
||||
gpu::Variable::LocalInvocationIndex => {
|
||||
self.local_invocation_index = true;
|
||||
wgsl::Variable::LocalInvocationIndex
|
||||
}
|
||||
gpu::Variable::LocalInvocationIdX => {
|
||||
self.local_invocation_id = true;
|
||||
wgsl::Variable::LocalInvocationIdX
|
||||
}
|
||||
gpu::Variable::LocalInvocationIdY => {
|
||||
self.local_invocation_id = true;
|
||||
wgsl::Variable::LocalInvocationIdY
|
||||
}
|
||||
gpu::Variable::LocalInvocationIdZ => {
|
||||
self.local_invocation_id = true;
|
||||
wgsl::Variable::LocalInvocationIdZ
|
||||
}
|
||||
gpu::Variable::WorkgroupIdX => {
|
||||
self.workgroup_id = true;
|
||||
wgsl::Variable::WorkgroupIdX
|
||||
|
@ -183,6 +209,21 @@ impl<F: FloatElement, I: IntElement> Compiler<F, I> {
|
|||
self.global_invocation_id = true;
|
||||
wgsl::Variable::GlobalInvocationIdZ
|
||||
}
|
||||
gpu::Variable::WorkgroupSizeX => wgsl::Variable::WorkgroupSizeX,
|
||||
gpu::Variable::WorkgroupSizeY => wgsl::Variable::WorkgroupSizeY,
|
||||
gpu::Variable::WorkgroupSizeZ => wgsl::Variable::WorkgroupSizeZ,
|
||||
gpu::Variable::NumWorkgroupsX => {
|
||||
self.num_workgroups = true;
|
||||
wgsl::Variable::NumWorkgroupsX
|
||||
}
|
||||
gpu::Variable::NumWorkgroupsY => {
|
||||
self.num_workgroups = true;
|
||||
wgsl::Variable::NumWorkgroupsY
|
||||
}
|
||||
gpu::Variable::NumWorkgroupsZ => {
|
||||
self.num_workgroups = true;
|
||||
wgsl::Variable::NumWorkgroupsZ
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -215,6 +256,7 @@ impl<F: FloatElement, I: IntElement> Compiler<F, I> {
|
|||
gpu::Operation::Procedure(proc) => self.compile_procedure(instructions, proc, scope),
|
||||
gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op)),
|
||||
gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
|
||||
gpu::Operation::Synchronization(val) => self.compile_synchronization(instructions, val),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -239,6 +281,22 @@ impl<F: FloatElement, I: IntElement> Compiler<F, I> {
|
|||
instructions: self.compile_scope(&mut range_loop.scope),
|
||||
})
|
||||
}
|
||||
gpu::Branch::WhileLoop(mut op) => instructions.push(wgsl::Instruction::WhileLoop {
|
||||
cond: self.compile_variable(op.cond),
|
||||
instructions: self.compile_scope(&mut op.scope),
|
||||
}),
|
||||
};
|
||||
}
|
||||
|
||||
fn compile_synchronization(
|
||||
&mut self,
|
||||
instructions: &mut Vec<wgsl::Instruction>,
|
||||
synchronization: gpu::Synchronization,
|
||||
) {
|
||||
match synchronization {
|
||||
gpu::Synchronization::WorkgroupBarrier => {
|
||||
instructions.push(wgsl::Instruction::WorkgroupBarrier)
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -24,6 +24,7 @@ pub enum Instruction {
|
|||
},
|
||||
Return,
|
||||
Break,
|
||||
WorkgroupBarrier,
|
||||
// Index handles casting to correct local variable.
|
||||
Index {
|
||||
lhs: Variable,
|
||||
|
@ -157,6 +158,10 @@ pub enum Instruction {
|
|||
end: Variable,
|
||||
instructions: Vec<Instruction>,
|
||||
},
|
||||
WhileLoop {
|
||||
cond: Variable,
|
||||
instructions: Vec<Instruction>,
|
||||
},
|
||||
}
|
||||
|
||||
impl Display for Instruction {
|
||||
|
@ -371,9 +376,17 @@ for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{
|
|||
}
|
||||
Instruction::Return => f.write_str("return;\n"),
|
||||
Instruction::Break => f.write_str("break;\n"),
|
||||
Instruction::WorkgroupBarrier => f.write_str("workgroupBarrier();\n"),
|
||||
Instruction::ArrayLength { var, out } => {
|
||||
f.write_fmt(format_args!("{out} = arrayLength(&{var});\n"))
|
||||
}
|
||||
Instruction::WhileLoop { cond, instructions } => {
|
||||
f.write_fmt(format_args!("while {cond} {{\n"))?;
|
||||
for i in instructions {
|
||||
f.write_fmt(format_args!("{i}"))?;
|
||||
}
|
||||
f.write_str("}\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,14 +23,35 @@ pub struct Binding {
|
|||
pub size: Option<usize>,
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub struct SharedMemory {
|
||||
location: Location,
|
||||
index: u16,
|
||||
item: Item,
|
||||
size: u32,
|
||||
}
|
||||
|
||||
impl SharedMemory {
|
||||
pub fn new(index: u16, item: Item, size: u32) -> Self {
|
||||
Self {
|
||||
location: Location::Workgroup,
|
||||
index,
|
||||
item,
|
||||
size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ComputeShader {
|
||||
pub inputs: Vec<Binding>,
|
||||
pub outputs: Vec<Binding>,
|
||||
pub named: Vec<(String, Binding)>,
|
||||
pub shared_memories: Vec<SharedMemory>,
|
||||
pub workgroup_size: WorkgroupSize,
|
||||
pub global_invocation_id: bool,
|
||||
pub local_invocation_index: bool,
|
||||
pub local_invocation_id: bool,
|
||||
pub num_workgroups: bool,
|
||||
pub workgroup_id: bool,
|
||||
pub body: Body,
|
||||
|
@ -51,6 +72,13 @@ impl Display for ComputeShader {
|
|||
)?;
|
||||
}
|
||||
|
||||
for shared_memory in self.shared_memories.iter() {
|
||||
f.write_fmt(format_args!(
|
||||
"var<{}> shared_memory_{}: array<{}, {}>;\n\n",
|
||||
shared_memory.location, shared_memory.index, shared_memory.item, shared_memory.size
|
||||
))?;
|
||||
}
|
||||
|
||||
f.write_fmt(format_args!(
|
||||
"const WORKGROUP_SIZE_X = {}u;
|
||||
const WORKGROUP_SIZE_Y = {}u;
|
||||
|
@ -75,6 +103,10 @@ fn main(
|
|||
f.write_str(" @builtin(local_invocation_index) local_idx: u32,\n")?;
|
||||
}
|
||||
|
||||
if self.local_invocation_id {
|
||||
f.write_str(" @builtin(local_invocation_id) local_invocation_id: vec3<u32>,\n")?;
|
||||
}
|
||||
|
||||
if self.num_workgroups {
|
||||
f.write_str(" @builtin(num_workgroups) num_workgroups: vec3<u32>,\n")?;
|
||||
}
|
||||
|
|
|
@ -186,6 +186,7 @@ where
|
|||
}
|
||||
|
||||
let source = kernel.source().complete();
|
||||
println!("{}", source);
|
||||
let pipeline = self.compile_source(&source);
|
||||
self.pipelines.insert(kernel_id.clone(), pipeline.clone());
|
||||
|
||||
|
|
|
@ -329,6 +329,9 @@ impl TraceBuilder {
|
|||
Operation::Branch(_) => {
|
||||
// Nothing to do, should never impact read-write access to bindings.
|
||||
}
|
||||
Operation::Synchronization(_) => {
|
||||
// Nothing to do, should never impact read-write access to bindings.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -30,7 +30,7 @@ struct MatmulComputeShader {
|
|||
impl MatmulComputeShader {
|
||||
fn expand(self, scope: &mut Scope) {
|
||||
// Define our global variables.
|
||||
let local_idx = Variable::InvocationIndex;
|
||||
let local_idx = Variable::LocalInvocationIndex;
|
||||
let batch = Variable::GlobalInvocationIdZ;
|
||||
let rank = Variable::Rank;
|
||||
let block_size: Variable = self.block_size.into();
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
mod base;
|
||||
mod naive_reduce_shader;
|
||||
mod reduction;
|
||||
mod reduction_shared_memory;
|
||||
mod shader;
|
||||
mod tune;
|
||||
mod workgroup_reduce_shader;
|
||||
|
||||
pub use base::*;
|
||||
pub(crate) use naive_reduce_shader::*;
|
||||
pub use reduction::*;
|
||||
pub use reduction_shared_memory::*;
|
||||
pub(crate) use shader::*;
|
||||
pub use tune::*;
|
||||
|
|
|
@ -10,16 +10,18 @@ use crate::{
|
|||
Runtime,
|
||||
};
|
||||
|
||||
pub(crate) trait ReduceDim: Send + Sync + 'static {
|
||||
pub(crate) trait NaiveReduceDim: Send + Sync + 'static {
|
||||
type Accumulator: Copy;
|
||||
|
||||
fn initialize(scope: &mut Scope, input_item: Item, output_item: Item) -> Self::Accumulator;
|
||||
|
||||
fn inner_loop(
|
||||
scope: &mut Scope,
|
||||
accumulator: Self::Accumulator,
|
||||
current_value: Variable,
|
||||
i: Variable,
|
||||
);
|
||||
|
||||
fn assign(
|
||||
scope: &mut Scope,
|
||||
output: Variable,
|
||||
|
@ -30,14 +32,14 @@ pub(crate) trait ReduceDim: Send + Sync + 'static {
|
|||
|
||||
pub(crate) struct SumDim;
|
||||
|
||||
impl ReduceDim for SumDim {
|
||||
impl NaiveReduceDim for SumDim {
|
||||
type Accumulator = Variable;
|
||||
|
||||
fn initialize(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable {
|
||||
scope.zero(output_item)
|
||||
}
|
||||
|
||||
fn inner_loop(scope: &mut Scope, accumulator: Variable, value: Variable, i: Variable) {
|
||||
fn inner_loop(scope: &mut Scope, accumulator: Variable, value: Variable, _i: Variable) {
|
||||
gpu!(scope, accumulator += value);
|
||||
}
|
||||
|
||||
|
@ -54,7 +56,7 @@ impl ReduceDim for SumDim {
|
|||
|
||||
pub(crate) struct MeanDim;
|
||||
|
||||
impl ReduceDim for MeanDim {
|
||||
impl NaiveReduceDim for MeanDim {
|
||||
type Accumulator = Variable;
|
||||
|
||||
fn initialize(scope: &mut Scope, _input_item: Item, output_item: Item) -> Variable {
|
||||
|
@ -81,10 +83,10 @@ impl ReduceDim for MeanDim {
|
|||
|
||||
pub(crate) struct ArgMax;
|
||||
|
||||
impl ReduceDim for ArgMax {
|
||||
impl NaiveReduceDim for ArgMax {
|
||||
type Accumulator = (Variable, Variable);
|
||||
|
||||
fn initialize(scope: &mut Scope, input_item: Item, output_item: Item) -> Self::Accumulator {
|
||||
fn initialize(scope: &mut Scope, input_item: Item, _output_item: Item) -> Self::Accumulator {
|
||||
let max = scope.create_local(input_item);
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, max = cast(-32767.0));
|
||||
|
@ -118,10 +120,10 @@ impl ReduceDim for ArgMax {
|
|||
|
||||
pub(crate) struct ArgMin;
|
||||
|
||||
impl ReduceDim for ArgMin {
|
||||
impl NaiveReduceDim for ArgMin {
|
||||
type Accumulator = (Variable, Variable);
|
||||
|
||||
fn initialize(scope: &mut Scope, input_item: Item, output_item: Item) -> Self::Accumulator {
|
||||
fn initialize(scope: &mut Scope, input_item: Item, _output_item: Item) -> Self::Accumulator {
|
||||
let min = scope.create_local(input_item);
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, min = cast(32767.0));
|
||||
|
@ -153,15 +155,20 @@ impl ReduceDim for ArgMin {
|
|||
}
|
||||
}
|
||||
|
||||
pub(crate) struct ReduceDimComputeShader<RD: ReduceDim> {
|
||||
pub(crate) struct NaiveReduceDimComputeShader<RD: NaiveReduceDim> {
|
||||
tensor: Variable,
|
||||
dim: usize,
|
||||
output: Variable,
|
||||
reduce_dim: PhantomData<RD>,
|
||||
_reduce_dim: PhantomData<RD>,
|
||||
}
|
||||
|
||||
#[derive(new)]
|
||||
pub(crate) struct ReduceDimEagerKernel<RD: ReduceDim, R: Runtime, EI: JitElement, EO: JitElement> {
|
||||
pub(crate) struct NaiveReduceDimEagerKernel<
|
||||
RD: NaiveReduceDim,
|
||||
R: Runtime,
|
||||
EI: JitElement,
|
||||
EO: JitElement,
|
||||
> {
|
||||
dim: usize,
|
||||
reduce_dim: PhantomData<RD>,
|
||||
_runtime: PhantomData<R>,
|
||||
|
@ -169,8 +176,8 @@ pub(crate) struct ReduceDimEagerKernel<RD: ReduceDim, R: Runtime, EI: JitElement
|
|||
_elem_out: PhantomData<EO>,
|
||||
}
|
||||
|
||||
impl<RD: ReduceDim, R: Runtime, EI: JitElement, EO: JitElement> DynamicKernelSource
|
||||
for ReduceDimEagerKernel<RD, R, EI, EO>
|
||||
impl<RD: NaiveReduceDim, R: Runtime, EI: JitElement, EO: JitElement> DynamicKernelSource
|
||||
for NaiveReduceDimEagerKernel<RD, R, EI, EO>
|
||||
{
|
||||
fn source(&self) -> crate::kernel::SourceTemplate {
|
||||
let mut scope = Scope::root();
|
||||
|
@ -180,11 +187,11 @@ impl<RD: ReduceDim, R: Runtime, EI: JitElement, EO: JitElement> DynamicKernelSou
|
|||
let tensor = Variable::GlobalInputArray(0, item_input);
|
||||
let output = Variable::GlobalOutputArray(0, item_output);
|
||||
|
||||
ReduceDimComputeShader {
|
||||
NaiveReduceDimComputeShader {
|
||||
tensor,
|
||||
dim: self.dim,
|
||||
output,
|
||||
reduce_dim: PhantomData::<RD>,
|
||||
_reduce_dim: PhantomData::<RD>,
|
||||
}
|
||||
.expand(&mut scope);
|
||||
|
||||
|
@ -214,7 +221,7 @@ impl<RD: ReduceDim, R: Runtime, EI: JitElement, EO: JitElement> DynamicKernelSou
|
|||
}
|
||||
}
|
||||
|
||||
impl<RD: ReduceDim> ReduceDimComputeShader<RD> {
|
||||
impl<RD: NaiveReduceDim> NaiveReduceDimComputeShader<RD> {
|
||||
pub(crate) fn expand(self, scope: &mut Scope) {
|
||||
let tensor = self.tensor;
|
||||
let dim: Variable = self.dim.into();
|
|
@ -1,13 +1,16 @@
|
|||
use crate::{
|
||||
codegen::{execute_dynamic, EagerHandle, WorkgroupLaunch},
|
||||
element::JitElement,
|
||||
kernel::reduce,
|
||||
tensor::JitTensor,
|
||||
Runtime,
|
||||
};
|
||||
use burn_tensor::Shape;
|
||||
|
||||
use super::{init_reduce_output, ArgMax, ArgMin, MeanDim, ReduceDim, ReduceDimEagerKernel, SumDim};
|
||||
#[cfg(not(feature = "autotune"))]
|
||||
use super::init_reduce_output;
|
||||
#[cfg(feature = "autotune")]
|
||||
use super::tune::{mean_dim_autotune, sum_dim_autotune};
|
||||
use super::{ArgMax, ArgMin, MeanDim, NaiveReduceDim, NaiveReduceDimEagerKernel, SumDim};
|
||||
|
||||
/// Sum all elements in the input buffer.
|
||||
pub fn sum<R: Runtime, E: JitElement, const D: usize>(
|
||||
|
@ -18,13 +21,14 @@ pub fn sum<R: Runtime, E: JitElement, const D: usize>(
|
|||
sum_dim(input, 0)
|
||||
}
|
||||
|
||||
/// Sum all elements on one dimension. Autotunable
|
||||
pub fn sum_dim<R: Runtime, E: JitElement, const D: usize>(
|
||||
tensor: JitTensor<R, E, D>,
|
||||
dim: usize,
|
||||
) -> JitTensor<R, E, D> {
|
||||
#[cfg(feature = "autotune")]
|
||||
{
|
||||
reduce::sum_dim_autotune(tensor, dim)
|
||||
sum_dim_autotune(tensor, dim)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "autotune"))]
|
||||
|
@ -34,13 +38,14 @@ pub fn sum_dim<R: Runtime, E: JitElement, const D: usize>(
|
|||
}
|
||||
}
|
||||
|
||||
/// Mean of all elements on one dimension. Autotunable
|
||||
pub fn mean_dim<R: Runtime, E: JitElement, const D: usize>(
|
||||
tensor: JitTensor<R, E, D>,
|
||||
dim: usize,
|
||||
) -> JitTensor<R, E, D> {
|
||||
#[cfg(feature = "autotune")]
|
||||
{
|
||||
reduce::mean_dim_autotune(tensor, dim)
|
||||
mean_dim_autotune(tensor, dim)
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "autotune"))]
|
||||
|
@ -60,7 +65,7 @@ pub fn sum_dim_naive<R: Runtime, E: JitElement, const D: usize>(
|
|||
}
|
||||
|
||||
pub(crate) fn reduce_dim_naive<
|
||||
RD: ReduceDim,
|
||||
RD: NaiveReduceDim,
|
||||
R: Runtime,
|
||||
EI: JitElement,
|
||||
EO: JitElement,
|
||||
|
@ -70,9 +75,9 @@ pub(crate) fn reduce_dim_naive<
|
|||
output: JitTensor<R, EO, D>,
|
||||
dim: usize,
|
||||
) -> JitTensor<R, EO, D> {
|
||||
let kernel = ReduceDimEagerKernel::new(dim);
|
||||
let kernel = NaiveReduceDimEagerKernel::new(dim);
|
||||
|
||||
execute_dynamic::<R, ReduceDimEagerKernel<RD, R, EI, EO>, EI>(
|
||||
execute_dynamic::<R, NaiveReduceDimEagerKernel<RD, R, EI, EO>, EI>(
|
||||
&[EagerHandle::new(
|
||||
&input.handle,
|
||||
&input.strides,
|
||||
|
|
|
@ -116,6 +116,7 @@ fn reduction_dim_shared_memory<K: StaticKernelSource, R: Runtime, E: JitElement,
|
|||
let n_invocation_per_workgroup = WORKGROUP_DEFAULT * WORKGROUP_DEFAULT;
|
||||
let n_reduce_elements_per_thread =
|
||||
f32::ceil(reduce_group_size as f32 / n_invocation_per_workgroup as f32) as u32;
|
||||
// n_reduce_elements_per_thread: u32(ceil(f32(shape_dim)/f32(WGD*WGD)))
|
||||
|
||||
// Add dimension of reduction and how many reduce elements are treated per thread
|
||||
info.push(reduce_dim as u32);
|
||||
|
|
|
@ -7,7 +7,7 @@ use crate::{
|
|||
kernel::{
|
||||
prng::{random_like_uniform, random_like_uniform_int},
|
||||
reduce::{
|
||||
init_reduce_output, int_mean_dim_naive, int_mean_dim_shared_memory, mean_dim, mean_dim_naive,
|
||||
init_reduce_output, int_mean_dim_naive, int_mean_dim_shared_memory, mean_dim_naive,
|
||||
mean_dim_shared_memory,
|
||||
},
|
||||
},
|
||||
|
|
|
@ -7,8 +7,6 @@ use crate::kernel::matmul::matmul_autotune;
|
|||
#[cfg(not(feature = "autotune"))]
|
||||
use crate::kernel::matmul::vec4::matmul_tiling_2d_vec4;
|
||||
use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
|
||||
// #[cfg(not(feature = "autotune"))]
|
||||
use crate::kernel::reduce::init_reduce_output;
|
||||
use crate::kernel::{self, reduce};
|
||||
use crate::tensor::JitTensor;
|
||||
use crate::Runtime;
|
||||
|
|
Loading…
Reference in New Issue