diff --git a/crates/burn-cube/src/codegen/integrator.rs b/crates/burn-cube/src/codegen/integrator.rs
index d65df226b..c19b810a8 100644
--- a/crates/burn-cube/src/codegen/integrator.rs
+++ b/crates/burn-cube/src/codegen/integrator.rs
@@ -423,17 +423,24 @@ impl KernelIntegrator {
             } else {
                 item
             };
-            let elem_adapted = bool_item(item);
+            let item_adapted = bool_item(item);
 
             self.output_bindings.push(Binding {
-                item: elem_adapted,
+                item: item_adapted,
                 visibility: Visibility::ReadWrite,
                 location: Location::Storage,
                 size: None,
             });
             self.expansion.scope.write_global(
-                Variable::Local(local, item, self.expansion.scope.depth),
-                Variable::GlobalOutputArray(index, elem_adapted),
+                Variable::Local {
+                    id: local,
+                    item,
+                    depth: self.expansion.scope.depth,
+                },
+                Variable::GlobalOutputArray {
+                    id: index,
+                    item: item_adapted,
+                },
                 position,
             );
             index += 1;
@@ -451,8 +458,15 @@ impl KernelIntegrator {
             };
 
             self.expansion.scope.write_global(
-                Variable::Local(local, item, self.expansion.scope.depth),
-                Variable::GlobalInputArray(input, bool_item(item)),
+                Variable::Local {
+                    id: local,
+                    item,
+                    depth: self.expansion.scope.depth,
+                },
+                Variable::GlobalInputArray {
+                    id: input,
+                    item: bool_item(item),
+                },
                 position,
             );
         }
diff --git a/crates/burn-cube/src/frontend/branch.rs b/crates/burn-cube/src/frontend/branch.rs
index b2355ffc5..6caf4c34e 100644
--- a/crates/burn-cube/src/frontend/branch.rs
+++ b/crates/burn-cube/src/frontend/branch.rs
@@ -27,11 +27,11 @@ where
 
     if unroll {
         let start = match start.deref() {
-            Variable::ConstantScalar(val, _) => *val as usize,
+            Variable::ConstantScalar { value, .. } => *value as usize,
            _ => panic!("Only constant start can be unrolled."),
        };
        let end = match end.deref() {
-            Variable::ConstantScalar(val, _) => *val as usize,
+            Variable::ConstantScalar { value, .. } => *value as usize,
            _ => panic!("Only constant end can be unrolled."),
        };
 
diff --git a/crates/burn-cube/src/frontend/cmma.rs b/crates/burn-cube/src/frontend/cmma.rs
index ce636e9dd..f00e1dce3 100644
--- a/crates/burn-cube/src/frontend/cmma.rs
+++ b/crates/burn-cube/src/frontend/cmma.rs
@@ -15,14 +15,14 @@
 //!         16,
 //!         16,
 //!         16,
-//!         cmma::MatrixLayout::ColMajor,
+//!         cmma::MatrixLayout::RowMajor,
 //!     );
 //!     let b = cmma::Matrix::<F16>::new(
 //!         cmma::MatrixIdent::B,
 //!         16,
 //!         16,
 //!         16,
-//!         cmma::MatrixLayout::RowMajor,
+//!         cmma::MatrixLayout::ColMajor,
 //!     );
 //!     let c = cmma::Matrix::<F32>::new(
 //!         cmma::MatrixIdent::Accumulator,
@@ -32,12 +32,17 @@
 //!         cmma::MatrixLayout::Undefined,
 //!     );
 //!     cmma::fill::<F32>(&c, F32::new(0.0));
-//!     cmma::load::<F16>(&a, lhs, UInt::new(16));
-//!     cmma::load::<F16>(&b, rhs, UInt::new(16));
+//!     cmma::load::<F16>(&a, lhs.as_slice(), UInt::new(16));
+//!     cmma::load::<F16>(&b, rhs.as_slice(), UInt::new(16));
 //!
 //!     cmma::execute::<F16, F16, F32, F32>(&a, &b, &c, &c);
 //!
-//!     cmma::store::<F32>(out, &c, UInt::new(16), cmma::MatrixLayout::RowMajor);
+//!     cmma::store::<F32>(
+//!         out.as_slice_mut(),
+//!         &c,
+//!         UInt::new(16),
+//!         cmma::MatrixLayout::RowMajor,
+//!     );
 //! }
 //! ```
 
@@ -49,7 +54,8 @@ use crate::{
 };
 
 use super::{
-    Array, CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, UInt,
+    CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, Slice, SliceMut,
+    UInt,
 };
 
 pub use ir::{MatrixIdent, MatrixLayout};
@@ -142,7 +148,7 @@ pub mod fill {
 
 /// Load the matrix with the provided array using the stride.
#[allow(unused_variables)] -pub fn load(mat: &Matrix, value: &Array, stride: UInt) { +pub fn load(mat: &Matrix, value: &Slice<'_, C>, stride: UInt) { unexpanded!() } @@ -155,7 +161,7 @@ pub mod load { pub fn __expand( context: &mut CubeContext, mat: MatrixExpand, - value: ExpandElementTyped>, + value: ExpandElementTyped>, stride: ExpandElement, ) { context.register(Operation::CoopMma(ir::CoopMma::Load { @@ -169,7 +175,7 @@ pub mod load { /// Store the matrix in the given array following the given stride and layout. #[allow(unused_variables)] pub fn store( - output: &Array, + output: &mut SliceMut<'_, C>, mat: &Matrix, stride: UInt, layout: MatrixLayout, @@ -185,7 +191,7 @@ pub mod store { #[allow(unused_variables)] pub fn __expand( context: &mut CubeContext, - output: ExpandElementTyped>, + output: ExpandElementTyped>, mat: MatrixExpand, stride: ExpandElement, layout: MatrixLayout, diff --git a/crates/burn-cube/src/frontend/context.rs b/crates/burn-cube/src/frontend/context.rs index ef0378f21..1d8c52290 100644 --- a/crates/burn-cube/src/frontend/context.rs +++ b/crates/burn-cube/src/frontend/context.rs @@ -4,8 +4,6 @@ use alloc::rc::Rc; use core::cell::RefCell; use std::collections::HashMap; -use super::{CubePrimitive, SharedMemoryExpand}; - #[derive(Default, Clone)] pub struct VariablePool { map: Rc>>>, @@ -114,14 +112,14 @@ impl CubeContext { ExpandElement::Plain(variable) } - pub fn create_shared( - &mut self, - item: Item, - size: u32, - ) -> SharedMemoryExpand { - SharedMemoryExpand { - val: ExpandElement::Plain(self.root.borrow_mut().create_shared(item, size)), - } + /// Create a new slice element. + pub fn create_slice(&mut self, item: Item) -> ExpandElement { + let variable = self.scope.borrow_mut().create_slice(item); + ExpandElement::Plain(variable) + } + + pub fn create_shared(&mut self, item: Item, size: u32) -> ExpandElement { + ExpandElement::Plain(self.root.borrow_mut().create_shared(item, size)) } pub fn create_local_array(&mut self, item: Item, size: u32) -> ExpandElement { @@ -129,19 +127,19 @@ impl CubeContext { } /// Obtain the index-th input - pub fn input(&mut self, index: u16, item: Item) -> ExpandElement { - ExpandElement::Plain(crate::ir::Variable::GlobalInputArray(index, item)) + pub fn input(&mut self, id: u16, item: Item) -> ExpandElement { + ExpandElement::Plain(crate::ir::Variable::GlobalInputArray { id, item }) } /// Obtain the index-th output - pub fn output(&mut self, index: u16, item: Item) -> ExpandElement { - let var = crate::ir::Variable::GlobalOutputArray(index, item); + pub fn output(&mut self, id: u16, item: Item) -> ExpandElement { + let var = crate::ir::Variable::GlobalOutputArray { id, item }; self.scope.borrow_mut().write_global_custom(var); ExpandElement::Plain(var) } /// Obtain the index-th scalar - pub fn scalar(&self, index: u16, elem: Elem) -> ExpandElement { - ExpandElement::Plain(crate::ir::Variable::GlobalScalar(index, elem)) + pub fn scalar(&self, id: u16, elem: Elem) -> ExpandElement { + ExpandElement::Plain(crate::ir::Variable::GlobalScalar { id, elem }) } } diff --git a/crates/burn-cube/src/frontend/element/array.rs b/crates/burn-cube/src/frontend/element/array.rs index 16b6a5507..5831a8d47 100644 --- a/crates/burn-cube/src/frontend/element/array.rs +++ b/crates/burn-cube/src/frontend/element/array.rs @@ -40,7 +40,7 @@ impl Array { ) -> ::ExpandType { let size = size.value(); let size = match size { - crate::ir::Variable::ConstantScalar(val, _) => val as u32, + crate::ir::Variable::ConstantScalar { value, .. 
} => value as u32, _ => panic!("Array need constant initialization value"), }; context @@ -55,7 +55,7 @@ impl Array { ) -> ::ExpandType { let size = size.value(); let size = match size { - crate::ir::Variable::ConstantScalar(val, _) => val as u32, + crate::ir::Variable::ConstantScalar { value, .. } => value as u32, _ => panic!("Shared memory need constant initialization value"), }; context diff --git a/crates/burn-cube/src/frontend/element/base.rs b/crates/burn-cube/src/frontend/element/base.rs index 8c7389135..f9de7ad4f 100644 --- a/crates/burn-cube/src/frontend/element/base.rs +++ b/crates/burn-cube/src/frontend/element/base.rs @@ -160,7 +160,7 @@ impl ExpandElement { pub fn can_mut(&self) -> bool { match self { ExpandElement::Managed(var) => { - if let Variable::Local(_, _, _) = var.as_ref() { + if let Variable::Local { .. } = var.as_ref() { Rc::strong_count(var) <= 2 } else { false @@ -201,10 +201,10 @@ impl Init for ExpandElement { let mut init = |elem: Self| init_expand(context, elem, Operator::Assign); match *self { - Variable::GlobalScalar(_, _) => init(self), - Variable::LocalScalar(_, _, _) => init(self), - Variable::ConstantScalar(_, _) => init(self), - Variable::Local(_, _, _) => init(self), + Variable::GlobalScalar { .. } => init(self), + Variable::LocalScalar { .. } => init(self), + Variable::ConstantScalar { .. } => init(self), + Variable::Local { .. } => init(self), // Constant should be initialized since the new variable can be mutated afterward. // And it is assumed those values are cloned. Variable::Rank @@ -230,11 +230,12 @@ impl Init for ExpandElement { | Variable::AbsolutePosY | Variable::AbsolutePosZ => init(self), // Array types can't be copied, so we should simply return the same variable. - Variable::SharedMemory(_, _, _) - | Variable::GlobalInputArray(_, _) - | Variable::GlobalOutputArray(_, _) - | Variable::LocalArray(_, _, _, _) - | Variable::Matrix(_, _) => self, + Variable::SharedMemory { .. } + | Variable::GlobalInputArray { .. } + | Variable::GlobalOutputArray { .. } + | Variable::LocalArray { .. } + | Variable::Slice { .. } + | Variable::Matrix { .. } => self, } } } diff --git a/crates/burn-cube/src/frontend/element/cube_elem.rs b/crates/burn-cube/src/frontend/element/cube_elem.rs index 4efe920ae..e949171f5 100644 --- a/crates/burn-cube/src/frontend/element/cube_elem.rs +++ b/crates/burn-cube/src/frontend/element/cube_elem.rs @@ -41,9 +41,9 @@ impl_into_expand_element!(i64); /// Useful for Comptime impl From for ExpandElement { fn from(value: UInt) -> Self { - ExpandElement::Plain(crate::ir::Variable::ConstantScalar( - value.val as f64, - UInt::as_elem(), - )) + ExpandElement::Plain(crate::ir::Variable::ConstantScalar { + value: value.val as f64, + elem: UInt::as_elem(), + }) } } diff --git a/crates/burn-cube/src/frontend/element/float.rs b/crates/burn-cube/src/frontend/element/float.rs index 584fc873e..c8dec62ab 100644 --- a/crates/burn-cube/src/frontend/element/float.rs +++ b/crates/burn-cube/src/frontend/element/float.rs @@ -93,7 +93,10 @@ macro_rules! 
impl_float { _context: &mut CubeContext, val: f32, ) -> ::ExpandType { - let new_var = Variable::ConstantScalar(val as f64, Self::as_elem()); + let new_var = Variable::ConstantScalar { + value: val as f64, + elem: Self::as_elem(), + }; ExpandElement::Plain(new_var) } diff --git a/crates/burn-cube/src/frontend/element/int.rs b/crates/burn-cube/src/frontend/element/int.rs index c0a4b9253..d5a92f73c 100644 --- a/crates/burn-cube/src/frontend/element/int.rs +++ b/crates/burn-cube/src/frontend/element/int.rs @@ -63,7 +63,10 @@ macro_rules! impl_int { _context: &mut CubeContext, val: i64, ) -> ::ExpandType { - let new_var = Variable::ConstantScalar(val as f64, Self::as_elem()); + let new_var = Variable::ConstantScalar { + value: val as f64, + elem: Self::as_elem(), + }; ExpandElement::Plain(new_var) } diff --git a/crates/burn-cube/src/frontend/element/mod.rs b/crates/burn-cube/src/frontend/element/mod.rs index 98e69ddd6..d67fb8647 100644 --- a/crates/burn-cube/src/frontend/element/mod.rs +++ b/crates/burn-cube/src/frontend/element/mod.rs @@ -7,6 +7,7 @@ mod float; mod int; mod numeric; mod shared_memory; +mod slice; mod tensor; mod uint; mod vectorized; @@ -20,6 +21,7 @@ pub use float::*; pub use int::*; pub use numeric::*; pub use shared_memory::*; +pub use slice::*; pub use tensor::*; pub use uint::*; pub use vectorized::*; diff --git a/crates/burn-cube/src/frontend/element/numeric.rs b/crates/burn-cube/src/frontend/element/numeric.rs index 12abf708c..3c92a2eeb 100644 --- a/crates/burn-cube/src/frontend/element/numeric.rs +++ b/crates/burn-cube/src/frontend/element/numeric.rs @@ -50,7 +50,10 @@ pub trait Numeric: } fn __expand_from_int(_context: &mut CubeContext, val: i64) -> ::ExpandType { - let new_var = Variable::ConstantScalar(val as f64, Self::as_elem()); + let new_var = Variable::ConstantScalar { + value: val as f64, + elem: Self::as_elem(), + }; ExpandElement::Plain(new_var) } diff --git a/crates/burn-cube/src/frontend/element/shared_memory.rs b/crates/burn-cube/src/frontend/element/shared_memory.rs index ff7597e02..3ad49c330 100644 --- a/crates/burn-cube/src/frontend/element/shared_memory.rs +++ b/crates/burn-cube/src/frontend/element/shared_memory.rs @@ -5,32 +5,21 @@ use crate::{ ir::Item, }; -use super::{ExpandElement, Init, UInt}; +use super::{ExpandElementTyped, Init, UInt}; #[derive(Clone, Copy)] pub struct SharedMemory { _val: PhantomData, } -#[derive(Clone)] -pub struct SharedMemoryExpand { - pub val: ::ExpandType, -} - -impl From> for ExpandElement { - fn from(shared_memory_expand: SharedMemoryExpand) -> Self { - shared_memory_expand.val - } -} - -impl Init for SharedMemoryExpand { +impl Init for ExpandElementTyped> { fn init(self, _context: &mut CubeContext) -> Self { self } } impl CubeType for SharedMemory { - type ExpandType = SharedMemoryExpand; + type ExpandType = ExpandElementTyped>; } impl SharedMemory { @@ -49,13 +38,14 @@ impl SharedMemory { ) -> ::ExpandType { let size = size.value(); let size = match size { - crate::ir::Variable::ConstantScalar(val, _) => val as u32, + crate::ir::Variable::ConstantScalar { value, .. 
} => value as u32, _ => panic!("Shared memory need constant initialization value"), }; - context.create_shared( + let var = context.create_shared( Item::vectorized(T::as_elem(), vectorization_factor.val as u8), size, - ) + ); + ExpandElementTyped::new(var) } pub fn __expand_new( @@ -64,9 +54,10 @@ impl SharedMemory { ) -> ::ExpandType { let size = size.value(); let size = match size { - crate::ir::Variable::ConstantScalar(val, _) => val as u32, + crate::ir::Variable::ConstantScalar { value, .. } => value as u32, _ => panic!("Shared memory need constant initialization value"), }; - context.create_shared(Item::new(T::as_elem()), size) + let var = context.create_shared(Item::new(T::as_elem()), size); + ExpandElementTyped::new(var) } } diff --git a/crates/burn-cube/src/frontend/element/slice.rs b/crates/burn-cube/src/frontend/element/slice.rs new file mode 100644 index 000000000..fea7e801d --- /dev/null +++ b/crates/burn-cube/src/frontend/element/slice.rs @@ -0,0 +1,285 @@ +use std::marker::PhantomData; + +use super::{ + Array, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, SharedMemory, Tensor, + UInt, +}; +use crate::{ + frontend::indexation::Index, + ir::{self, Operator}, + prelude::CubeContext, + unexpanded, +}; + +/// A read-only contiguous list of elements +pub struct Slice<'a, E> { + _e: PhantomData, + _l: &'a (), +} + +/// A read-write contiguous list of elements. +pub struct SliceMut<'a, E> { + _e: PhantomData, + _l: &'a mut (), +} + +impl<'a, E> Slice<'a, E> { + /// Get the length of the slice. + pub fn len(&self) -> UInt { + unexpanded!() + } +} + +impl<'a, E> SliceMut<'a, E> { + /// Get the length of the slice. + pub fn len(&self) -> UInt { + unexpanded!() + } +} + +impl<'a, E: CubeType> CubeType for Slice<'a, E> { + type ExpandType = ExpandElementTyped>; +} + +impl<'a, C: CubeType> Init for ExpandElementTyped> { + fn init(self, _context: &mut crate::prelude::CubeContext) -> Self { + // The type can't be deeply cloned/copied. + self + } +} + +impl<'a, E: CubeType> CubeType for SliceMut<'a, E> { + type ExpandType = ExpandElementTyped>; +} + +impl<'a, C: CubeType> Init for ExpandElementTyped> { + fn init(self, _context: &mut crate::prelude::CubeContext) -> Self { + // The type can't be deeply cloned/copied. + self + } +} + +pub trait SliceOperator: CubeType { + type Expand: SliceOperatorExpand; + + /// Return a read-only view of all elements comprise between the start and end index. + #[allow(unused_variables)] + fn slice(&self, start: Start, end: End) -> &'_ Slice<'_, E> { + unexpanded!() + } + /// Expand function of [SliceOperator::slice]. + fn slice_expand( + context: &mut CubeContext, + expand: Self::Expand, + start: Start, + end: End, + ) -> ExpandElementTyped> { + expand.slice_expand(context, start, end) + } + + /// Return a read-write view of all elements comprise between the start and end index. + #[allow(unused_variables)] + fn slice_mut( + &mut self, + start: Start, + end: End, + ) -> &'_ mut SliceMut<'_, E> { + unexpanded!() + } + + /// Expand function of [SliceOperator::slice_mut]. + fn slice_mut_expand( + context: &mut CubeContext, + expand: Self::Expand, + start: Start, + end: End, + ) -> ExpandElementTyped> { + expand.slice_mut_expand(context, start, end) + } + + /// Return a read-write view of all elements comprise between the start and end index. + /// + /// # Warning + /// + /// Ignore the multiple borrow rule. 
+ #[allow(unused_variables)] + fn slice_mut_unsafe( + &self, + start: Start, + end: End, + ) -> SliceMut<'static, E> { + unexpanded!() + } + + /// Expand function of [SliceOperator::slice_mut_unsafe]. + fn slice_mut_unsafe_expand( + context: &mut CubeContext, + expand: Self::Expand, + start: Start, + end: End, + ) -> ExpandElementTyped> { + expand.slice_mut_unsafe_expand(context, start, end) + } + + /// Reinterprete the current type as a read-only slice. + #[allow(unused_variables)] + fn as_slice(&self) -> &'_ Slice<'_, E> { + unexpanded!() + } + + /// Expand function of [SliceOperator::as_slice]. + fn as_slice_expand( + context: &mut CubeContext, + expand: Self::Expand, + ) -> ExpandElementTyped> { + expand.as_slice_expand(context) + } + + /// Reinterprete the current type as a read-write slice. + #[allow(unused_variables)] + fn as_slice_mut(&mut self) -> &'_ mut SliceMut<'_, E> { + unexpanded!() + } + + /// Expand function of [SliceOperator::as_slice_mut]. + fn as_slice_mut_expand( + context: &mut CubeContext, + expand: Self::Expand, + ) -> ExpandElementTyped> { + expand.as_slice_mut_expand(context) + } + + /// Reinterprete the current type as a read-write slice. + /// + /// # Warning + /// + /// Ignore the multiple borrow rule. + #[allow(unused_variables)] + fn as_slice_mut_unsafe(&self) -> SliceMut<'static, E> { + unexpanded!() + } + + /// Expand function of [SliceOperator::as_slice_mut_unsafe]. + fn as_slice_mut_unsafe_expand( + context: &mut CubeContext, + expand: Self::Expand, + ) -> ExpandElementTyped> { + expand.as_slice_mut_unsafe_expand(context) + } +} + +pub trait SliceOperatorExpand: Into + Clone { + fn slice_base( + &self, + context: &mut CubeContext, + start: Start, + end: End, + ) -> ExpandElement; + + fn slice_expand( + &self, + context: &mut CubeContext, + start: Start, + end: End, + ) -> ExpandElementTyped> { + ExpandElementTyped::new(self.slice_base(context, start, end)) + } + + fn slice_mut_expand( + &self, + context: &mut CubeContext, + start: Start, + end: End, + ) -> ExpandElementTyped> { + ExpandElementTyped::new(self.slice_base(context, start, end)) + } + + fn slice_mut_unsafe_expand( + &self, + context: &mut CubeContext, + start: Start, + end: End, + ) -> ExpandElementTyped> { + ExpandElementTyped::new(self.slice_base(context, start, end)) + } + + fn as_slice_expand(&self, _context: &mut CubeContext) -> ExpandElementTyped> { + let expand = self.clone().into(); + ExpandElementTyped::new(expand) + } + + fn as_slice_mut_unsafe_expand( + &self, + context: &mut CubeContext, + ) -> ExpandElementTyped> { + self.as_slice_mut_expand(context) + } + + fn as_slice_mut_expand( + &self, + _context: &mut CubeContext, + ) -> ExpandElementTyped> { + let expand = self.clone().into(); + ExpandElementTyped::new(expand) + } +} + +macro_rules! 
slice_op { + ($type:ident) => { + impl SliceOperator for $type { + type Expand = ExpandElementTyped<$type>; + } + + impl SliceOperatorExpand for ExpandElementTyped<$type> { + fn slice_base( + &self, + context: &mut CubeContext, + start: Start, + end: End, + ) -> ExpandElement { + slice_expand(context, self.clone(), start, end) + } + } + }; + (slice $type:ident) => { + impl<'a, E: CubePrimitive> SliceOperator for $type<'a, E> { + type Expand = ExpandElementTyped<$type<'static, E>>; + } + + impl<'a, E: CubePrimitive> SliceOperatorExpand for ExpandElementTyped<$type<'a, E>> { + fn slice_base( + &self, + context: &mut CubeContext, + start: Start, + end: End, + ) -> ExpandElement { + slice_expand(context, self.clone(), start, end) + } + } + }; +} + +slice_op!(Array); +slice_op!(Tensor); +slice_op!(SharedMemory); +slice_op!(slice Slice); +slice_op!(slice SliceMut); + +pub fn slice_expand, S1: Index, S2: Index>( + context: &mut CubeContext, + input: I, + start: S1, + end: S2, // Todo use it to get the length. +) -> ExpandElement { + let input = input.into(); + let out = context.create_slice(input.item()); + + context.register(Operator::Slice(ir::SliceOperator { + input: *input, + start: start.value(), + end: end.value(), + out: *out, + })); + + out +} diff --git a/crates/burn-cube/src/frontend/element/tensor.rs b/crates/burn-cube/src/frontend/element/tensor.rs index 958b8fc41..7a61422c4 100644 --- a/crates/burn-cube/src/frontend/element/tensor.rs +++ b/crates/burn-cube/src/frontend/element/tensor.rs @@ -202,7 +202,7 @@ impl ExpandElementTyped { // Expanded version of len pub fn len_expand(self, context: &mut CubeContext) -> ExpandElement { let out = context.create_local(Item::new(Elem::UInt)); - context.register(Metadata::ArrayLength { + context.register(Metadata::Length { var: self.expand.into(), out: out.clone().into(), }); diff --git a/crates/burn-cube/src/frontend/element/uint.rs b/crates/burn-cube/src/frontend/element/uint.rs index 6a4fa61fe..4df2516b6 100644 --- a/crates/burn-cube/src/frontend/element/uint.rs +++ b/crates/burn-cube/src/frontend/element/uint.rs @@ -49,7 +49,10 @@ impl UInt { } pub fn __expand_new(_context: &mut CubeContext, val: u32) -> ::ExpandType { - let new_var = Variable::ConstantScalar(val as f64, Self::as_elem()); + let new_var = Variable::ConstantScalar { + value: val as f64, + elem: Self::as_elem(), + }; ExpandElement::Plain(new_var) } diff --git a/crates/burn-cube/src/frontend/indexation.rs b/crates/burn-cube/src/frontend/indexation.rs index 4b9a2f1eb..9960d1f4b 100644 --- a/crates/burn-cube/src/frontend/indexation.rs +++ b/crates/burn-cube/src/frontend/indexation.rs @@ -7,31 +7,46 @@ pub trait Index { impl Index for Comptime { fn value(self) -> Variable { - Variable::ConstantScalar(self.inner as f64, Elem::UInt) + Variable::ConstantScalar { + value: self.inner as f64, + elem: Elem::UInt, + } } } impl Index for Comptime { fn value(self) -> Variable { - Variable::ConstantScalar(self.inner as f64, Elem::UInt) + Variable::ConstantScalar { + value: self.inner as f64, + elem: Elem::UInt, + } } } impl Index for i32 { fn value(self) -> Variable { - Variable::ConstantScalar(self as f64, Elem::UInt) + Variable::ConstantScalar { + value: self as f64, + elem: Elem::UInt, + } } } impl Index for u32 { fn value(self) -> Variable { - Variable::ConstantScalar(self as f64, Elem::UInt) + Variable::ConstantScalar { + value: self as f64, + elem: Elem::UInt, + } } } impl Index for UInt { fn value(self) -> Variable { - Variable::ConstantScalar(self.val as f64, Elem::UInt) + 
Variable::ConstantScalar { + value: self.val as f64, + elem: Elem::UInt, + } } } diff --git a/crates/burn-cube/src/frontend/operation/assignation.rs b/crates/burn-cube/src/frontend/operation/assignation.rs index 420d485fe..0136b8025 100644 --- a/crates/burn-cube/src/frontend/operation/assignation.rs +++ b/crates/burn-cube/src/frontend/operation/assignation.rs @@ -19,7 +19,7 @@ pub mod assign { } pub mod index_assign { - use crate::{frontend::CubeType, unexpanded}; + use crate::{frontend::CubeType, prelude::SliceMut, unexpanded}; use self::ir::{BinaryOperator, Operator, Variable}; @@ -34,7 +34,10 @@ pub mod index_assign { let array = array.into(); let index: Variable = *index.into(); let index = match index { - Variable::ConstantScalar(val, _) => Variable::ConstantScalar(val, ir::Elem::UInt), + Variable::ConstantScalar { value, .. } => Variable::ConstantScalar { + value, + elem: ir::Elem::UInt, + }, _ => index, }; context.register(Operator::IndexAssign(BinaryOperator { @@ -58,6 +61,12 @@ pub mod index_assign { impl_index!(Array); impl_index!(Tensor); impl_index!(SharedMemory); + + impl<'a, E: CubeType, I: Into> core::ops::IndexMut for SliceMut<'a, E> { + fn index_mut(&mut self, _index: I) -> &mut Self::Output { + unexpanded!() + } + } } pub mod index { @@ -66,6 +75,7 @@ pub mod index { operation::base::{binary_expand, binary_expand_no_vec}, CubeType, }, + prelude::{Slice, SliceMut}, unexpanded, }; @@ -78,20 +88,21 @@ pub mod index { array: L, index: R, ) -> ExpandElement { - let index = index.into(); + let index: ExpandElement = index.into(); let index_var: Variable = *index; let index = match index_var { - Variable::ConstantScalar(val, _) => { - ExpandElement::Plain(Variable::ConstantScalar(val, ir::Elem::UInt)) + Variable::ConstantScalar { value, .. } => { + ExpandElement::Plain(Variable::ConstantScalar { + value, + elem: ir::Elem::UInt, + }) } _ => index, }; - let array = array.into(); + let array: ExpandElement = array.into(); let var: Variable = *array; match var { - Variable::Local(_, _, _) => { - binary_expand_no_vec(context, array, index, Operator::Index) - } + Variable::Local { .. } => binary_expand_no_vec(context, array, index, Operator::Index), _ => binary_expand(context, array, index, Operator::Index), } } @@ -111,6 +122,20 @@ pub mod index { impl_index!(Array); impl_index!(Tensor); impl_index!(SharedMemory); + + impl<'a, E: CubeType, I: Into> core::ops::Index for SliceMut<'a, E> { + type Output = E; + fn index(&self, _index: I) -> &Self::Output { + unexpanded!() + } + } + + impl<'a, E: CubeType, I: Into> core::ops::Index for Slice<'a, E> { + type Output = E; + fn index(&self, _index: I) -> &Self::Output { + unexpanded!() + } + } } pub mod add_assign_array_op { diff --git a/crates/burn-cube/src/ir/macros.rs b/crates/burn-cube/src/ir/macros.rs index e7ccd628f..d50035b51 100644 --- a/crates/burn-cube/src/ir/macros.rs +++ b/crates/burn-cube/src/ir/macros.rs @@ -352,7 +352,7 @@ macro_rules! cpa { }; // out = len(array) ($scope:expr, $out:ident = len($input:expr)) => { - $scope.register($crate::ir::Metadata::ArrayLength { + $scope.register($crate::ir::Metadata::Length { var: $input.into(), out: $out.into(), }); @@ -398,43 +398,64 @@ macro_rules! 
cpa {
 impl From<bool> for Variable {
     fn from(value: bool) -> Self {
-        Self::ConstantScalar(if value { 1.0 } else { 0.0 }, super::Elem::Bool)
+        Self::ConstantScalar {
+            value: if value { 1.0 } else { 0.0 },
+            elem: super::Elem::Bool,
+        }
     }
 }
 
 impl From<i32> for Variable {
     fn from(value: i32) -> Self {
-        Self::ConstantScalar(value as f64, super::Elem::Int(super::IntKind::I32))
+        Self::ConstantScalar {
+            value: value as f64,
+            elem: super::Elem::Int(super::IntKind::I32),
+        }
     }
 }
 
 impl From<i64> for Variable {
     fn from(value: i64) -> Self {
-        Self::ConstantScalar(value as f64, super::Elem::Int(super::IntKind::I64))
+        Self::ConstantScalar {
+            value: value as f64,
+            elem: super::Elem::Int(super::IntKind::I64),
+        }
     }
 }
 
 impl From<f32> for Variable {
     fn from(value: f32) -> Self {
-        Self::ConstantScalar(value as f64, super::Elem::Float(super::FloatKind::F32))
+        Self::ConstantScalar {
+            value: value as f64,
+            elem: super::Elem::Float(super::FloatKind::F32),
+        }
     }
 }
 
 impl From<f64> for Variable {
     fn from(value: f64) -> Self {
-        Self::ConstantScalar(value, super::Elem::Float(super::FloatKind::F64))
+        Self::ConstantScalar {
+            value,
+            elem: super::Elem::Float(super::FloatKind::F64),
+        }
     }
 }
 
 impl From<u32> for Variable {
     fn from(value: u32) -> Self {
-        Self::ConstantScalar(value as f64, super::Elem::UInt)
+        Self::ConstantScalar {
+            value: value as f64,
+            elem: super::Elem::UInt,
+        }
     }
 }
 
 impl From<usize> for Variable {
     fn from(value: usize) -> Self {
-        Self::ConstantScalar(value as f64, super::Elem::UInt)
+        Self::ConstantScalar {
+            value: value as f64,
+            elem: super::Elem::UInt,
+        }
     }
 }
diff --git a/crates/burn-cube/src/ir/operation.rs b/crates/burn-cube/src/ir/operation.rs
index 1075e54bb..f5ed445bf 100644
--- a/crates/burn-cube/src/ir/operation.rs
+++ b/crates/burn-cube/src/ir/operation.rs
@@ -53,6 +53,7 @@ pub enum Operator {
     Assign(UnaryOperator),
     Modulo(BinaryOperator),
     Index(BinaryOperator),
+    Slice(SliceOperator),
     UncheckedIndex(BinaryOperator),
     IndexAssign(BinaryOperator),
     UncheckedIndexAssign(BinaryOperator),
@@ -84,7 +85,7 @@ pub enum Metadata {
         var: Variable,
         out: Variable,
     },
-    ArrayLength {
+    Length {
         var: Variable,
         out: Variable,
     },
@@ -120,6 +121,15 @@ pub struct ClampOperator {
     pub out: Variable,
 }
 
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
+#[allow(missing_docs)]
+pub struct SliceOperator {
+    pub input: Variable,
+    pub start: Variable,
+    pub end: Variable,
+    pub out: Variable,
+}
+
 #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
 #[allow(missing_docs)]
 pub struct ReadGlobalOperator {
diff --git a/crates/burn-cube/src/ir/scope.rs b/crates/burn-cube/src/ir/scope.rs
index 2c119faa7..2e6675981 100644
--- a/crates/burn-cube/src/ir/scope.rs
+++ b/crates/burn-cube/src/ir/scope.rs
@@ -19,6 +19,7 @@ pub struct Scope {
     pub operations: Vec<Operation>,
     locals: Vec<Variable>,
     matrices: Vec<Variable>,
+    slices: Vec<Variable>,
     shared_memories: Vec<Variable>,
     local_arrays: Vec<Variable>,
     reads_global: Vec<(Variable, ReadingStrategy, Variable, Variable)>,
@@ -49,6 +50,7 @@ impl Scope {
             operations: Vec::new(),
             locals: Vec::new(),
             matrices: Vec::new(),
+            slices: Vec::new(),
             local_arrays: Vec::new(),
             shared_memories: Vec::new(),
             reads_global: Vec::new(),
@@ -75,7 +77,10 @@ impl Scope {
         I: Into + Copy,
     {
         let local = self.create_local(item);
-        let value = Variable::ConstantScalar(value.into(), item.into().elem());
+        let value = Variable::ConstantScalar {
+            value: value.into(),
+            elem: item.into().elem(),
+        };
         cpa!(self, local = value);
         local
     }
@@ -83,16 +88,35 @@
     /// Create a matrix variable
     pub fn create_matrix(&mut self, matrix: Matrix) -> Variable {
         let index =
self.matrices.len() as u16; - let variable = Variable::Matrix(index, matrix); + let variable = Variable::Matrix { + id: index, + mat: matrix, + }; self.matrices.push(variable); variable } + /// Create a slice variable + pub fn create_slice(&mut self, item: Item) -> Variable { + let id = self.slices.len() as u16; + let variable = Variable::Slice { + id, + item, + depth: self.depth, + }; + self.slices.push(variable); + variable + } + /// Create a local variable of the given [item type](Item). pub fn create_local>(&mut self, item: I) -> Variable { let item = item.into(); let index = self.new_local_index(); - let local = Variable::Local(index, item, self.depth); + let local = Variable::Local { + id: index, + item, + depth: self.depth, + }; self.locals.push(local); local } @@ -102,7 +126,11 @@ impl Scope { /// Useful for _for loops_ and other algorithms that require the control over initialization. pub fn create_local_undeclared(&mut self, item: Item) -> Variable { let index = self.new_local_index(); - let local = Variable::Local(index, item, self.depth); + let local = Variable::Local { + id: index, + item, + depth: self.depth, + }; self.undeclared += 1; local } @@ -131,8 +159,12 @@ impl Scope { /// /// The index refers to the scalar position for the same [element](Elem) type. pub fn read_scalar(&mut self, index: u16, elem: Elem) -> Variable { - let local = Variable::LocalScalar(self.new_local_scalar_index(), elem, self.depth); - let scalar = Variable::GlobalScalar(index, elem); + let local = Variable::LocalScalar { + id: self.new_local_scalar_index(), + elem, + depth: self.depth, + }; + let scalar = Variable::GlobalScalar { id: index, elem }; self.reads_scalar.push((local, scalar)); @@ -215,7 +247,7 @@ impl Scope { self.reads_global .iter() .map(|(var, strategy, _, _)| match var { - Variable::GlobalInputArray(id, _) => (*id, *strategy), + Variable::GlobalInputArray { id, .. } => (*id, *strategy), _ => panic!("Can only read global input arrays."), }) .collect() @@ -233,6 +265,7 @@ impl Scope { operations: Vec::new(), locals: Vec::new(), matrices: Vec::new(), + slices: Vec::new(), shared_memories: Vec::new(), local_arrays: Vec::new(), reads_global: Vec::new(), @@ -259,6 +292,9 @@ impl Scope { for var in self.matrices.drain(..) { variables.push(var); } + for var in self.slices.drain(..) { + variables.push(var); + } for index in self.index_offset_with_output_layout_position.drain(..) 
{ if let Some(Operation::Procedure(Procedure::IndexOffsetGlobalWithLayout(proc))) = @@ -357,9 +393,16 @@ impl Scope { }, _ => item, }; - let input = Variable::GlobalInputArray(index, item_global); + let input = Variable::GlobalInputArray { + id: index, + item: item_global, + }; let index = self.new_local_index(); - let local = Variable::Local(index, item, self.depth); + let local = Variable::Local { + id: index, + item, + depth: self.depth, + }; self.reads_global.push((input, strategy, local, position)); self.locals.push(local); local @@ -369,7 +412,11 @@ impl Scope { pub fn create_shared>(&mut self, item: I, shared_memory_size: u32) -> Variable { let item = item.into(); let index = self.new_shared_index(); - let shared_memory = Variable::SharedMemory(index, item, shared_memory_size); + let shared_memory = Variable::SharedMemory { + id: index, + item, + length: shared_memory_size, + }; self.shared_memories.push(shared_memory); shared_memory } @@ -378,7 +425,12 @@ impl Scope { pub fn create_local_array>(&mut self, item: I, array_size: u32) -> Variable { let item = item.into(); let index = self.new_local_array_index(); - let local_array = Variable::LocalArray(index, item, self.depth, array_size); + let local_array = Variable::LocalArray { + id: index, + item, + depth: self.depth, + length: array_size, + }; self.local_arrays.push(local_array); local_array } diff --git a/crates/burn-cube/src/ir/variable.rs b/crates/burn-cube/src/ir/variable.rs index aa41bdda2..ea6335b7a 100644 --- a/crates/burn-cube/src/ir/variable.rs +++ b/crates/burn-cube/src/ir/variable.rs @@ -5,14 +5,52 @@ use serde::{Deserialize, Serialize}; #[allow(missing_docs)] pub enum Variable { Rank, - GlobalInputArray(u16, Item), - GlobalScalar(u16, Elem), - GlobalOutputArray(u16, Item), - Local(u16, Item, u8), - LocalScalar(u16, Elem, u8), - ConstantScalar(f64, Elem), - SharedMemory(u16, Item, u32), - LocalArray(u16, Item, u8, u32), + GlobalInputArray { + id: u16, + item: Item, + }, + GlobalScalar { + id: u16, + elem: Elem, + }, + GlobalOutputArray { + id: u16, + item: Item, + }, + Local { + id: u16, + item: Item, + depth: u8, + }, + LocalScalar { + id: u16, + elem: Elem, + depth: u8, + }, + ConstantScalar { + value: f64, + elem: Elem, + }, + SharedMemory { + id: u16, + item: Item, + length: u32, + }, + LocalArray { + id: u16, + item: Item, + depth: u8, + length: u32, + }, + Matrix { + id: u16, + mat: Matrix, + }, + Slice { + id: u16, + item: Item, + depth: u8, + }, UnitPos, UnitPosX, UnitPosY, @@ -34,20 +72,21 @@ pub enum Variable { AbsolutePosX, AbsolutePosY, AbsolutePosZ, - Matrix(u16, Matrix), } impl Variable { pub fn index(&self) -> Option { match self { - Variable::GlobalInputArray(idx, _) => Some(*idx), - Variable::GlobalScalar(idx, _) => Some(*idx), - Variable::Local(idx, _, _) => Some(*idx), - Variable::LocalScalar(idx, _, _) => Some(*idx), - Variable::GlobalOutputArray(idx, _) => Some(*idx), - Variable::ConstantScalar(_, _) => None, - Variable::SharedMemory(idx, _, _) => Some(*idx), - Variable::LocalArray(idx, _, _, _) => Some(*idx), + Variable::GlobalInputArray { id, .. } => Some(*id), + Variable::GlobalScalar { id, .. } => Some(*id), + Variable::Local { id, .. } => Some(*id), + Variable::Slice { id, .. } => Some(*id), + Variable::LocalScalar { id, .. } => Some(*id), + Variable::GlobalOutputArray { id, .. } => Some(*id), + Variable::ConstantScalar { .. } => None, + Variable::SharedMemory { id, .. } => Some(*id), + Variable::LocalArray { id, .. } => Some(*id), + Variable::Matrix { id, .. 
} => Some(*id), Variable::AbsolutePos => None, Variable::UnitPos => None, Variable::UnitPosX => None, @@ -70,21 +109,22 @@ impl Variable { Variable::CubeCount => None, Variable::CubeDim => None, Variable::SubcubeDim => None, - Variable::Matrix(idx, _) => Some(*idx), } } /// Fetch the item of the variable. pub fn item(&self) -> Item { match self { - Variable::GlobalInputArray(_, item) => *item, - Variable::GlobalOutputArray(_, item) => *item, - Variable::GlobalScalar(_, elem) => Item::new(*elem), - Variable::Local(_, item, _) => *item, - Variable::LocalScalar(_, elem, _) => Item::new(*elem), - Variable::ConstantScalar(_, elem) => Item::new(*elem), - Variable::SharedMemory(_, item, _) => *item, - Variable::LocalArray(_, item, _, _) => *item, + Variable::GlobalInputArray { item, .. } => *item, + Variable::GlobalOutputArray { item, .. } => *item, + Variable::GlobalScalar { elem, .. } => Item::new(*elem), + Variable::Local { item, .. } => *item, + Variable::LocalScalar { elem, .. } => Item::new(*elem), + Variable::ConstantScalar { elem, .. } => Item::new(*elem), + Variable::SharedMemory { item, .. } => *item, + Variable::LocalArray { item, .. } => *item, + Variable::Slice { item, .. } => *item, + Variable::Matrix { mat, .. } => Item::new(mat.elem), Variable::AbsolutePos => Item::new(Elem::UInt), Variable::Rank => Item::new(Elem::UInt), Variable::UnitPos => Item::new(Elem::UInt), @@ -107,7 +147,6 @@ impl Variable { Variable::CubeCount => Item::new(Elem::UInt), Variable::CubeDim => Item::new(Elem::UInt), Variable::SubcubeDim => Item::new(Elem::UInt), - Variable::Matrix(_, matrix) => Item::new(matrix.elem), } } } diff --git a/crates/burn-cube/src/ir/vectorization.rs b/crates/burn-cube/src/ir/vectorization.rs index b0dd59a87..15c81b3f0 100644 --- a/crates/burn-cube/src/ir/vectorization.rs +++ b/crates/burn-cube/src/ir/vectorization.rs @@ -1,6 +1,6 @@ use super::{ - BinaryOperator, ClampOperator, FmaOperator, InitOperator, Item, Operation, Operator, Subcube, - UnaryOperator, Variable, + BinaryOperator, ClampOperator, FmaOperator, InitOperator, Item, Operation, Operator, + SliceOperator, Subcube, UnaryOperator, Variable, }; pub type Vectorization = u8; @@ -60,7 +60,7 @@ impl Operator { Operator::LowerEqual(op) => Operator::LowerEqual(op.vectorize(vectorization)), Operator::GreaterEqual(op) => Operator::GreaterEqual(op.vectorize(vectorization)), Operator::Assign(op) => { - if let Variable::GlobalScalar(_, _) = op.input { + if let Variable::GlobalScalar { .. } = op.input { // Assign will not change the type of the output if the input can't be // vectorized. 
return Operator::Assign(op.clone()); @@ -81,6 +81,7 @@ impl Operator { Operator::ShiftLeft(op) => Operator::ShiftLeft(op.vectorize(vectorization)), Operator::ShiftRight(op) => Operator::ShiftRight(op.vectorize(vectorization)), Operator::Remainder(op) => Operator::Remainder(op.vectorize(vectorization)), + Operator::Slice(op) => Operator::Slice(op.vectorize(vectorization)), } } } @@ -104,6 +105,22 @@ impl UnaryOperator { } } +impl SliceOperator { + pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { + let input = self.input.vectorize(vectorization); + let start = self.start.vectorize(vectorization); + let end = self.end.vectorize(vectorization); + let out = self.out.vectorize(vectorization); + + Self { + input, + start, + end, + out, + } + } +} + impl InitOperator { pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self { let out = self.out.vectorize(vectorization); @@ -155,31 +172,46 @@ impl FmaOperator { impl Variable { pub(crate) fn vectorize(&self, vectorize: Vectorization) -> Self { match self { - Variable::GlobalInputArray(index, item) => { - Variable::GlobalInputArray(*index, item.vectorize(vectorize)) - } - Variable::Local(index, item, name) => { - Variable::Local(*index, item.vectorize(vectorize), *name) - } - Variable::GlobalOutputArray(index, item) => { - Variable::GlobalOutputArray(*index, item.vectorize(vectorize)) - } - Variable::SharedMemory(index, item, size) => Variable::SharedMemory( - *index, - item.vectorize(vectorize), - item.vectorized_size(vectorize, *size), - ), - Variable::LocalArray(index, item, name, size) => Variable::LocalArray( - *index, - item.vectorize(vectorize), - *name, - item.vectorized_size(vectorize, *size), - ), - Variable::ConstantScalar(_, _) => *self, - Variable::GlobalScalar(_, _) => *self, + Variable::GlobalInputArray { id, item } => Variable::GlobalInputArray { + id: *id, + item: item.vectorize(vectorize), + }, + Variable::Local { id, item, depth } => Variable::Local { + id: *id, + item: item.vectorize(vectorize), + depth: *depth, + }, + Variable::Slice { id, item, depth } => Variable::Slice { + id: *id, + item: item.vectorize(vectorize), + depth: *depth, + }, + Variable::GlobalOutputArray { id, item } => Variable::GlobalOutputArray { + id: *id, + item: item.vectorize(vectorize), + }, + Variable::SharedMemory { id, item, length } => Variable::SharedMemory { + id: *id, + item: item.vectorize(vectorize), + length: item.vectorized_size(vectorize, *length), + }, + Variable::LocalArray { + id, + item, + depth, + length, + } => Variable::LocalArray { + id: *id, + item: item.vectorize(vectorize), + depth: *depth, + length: item.vectorized_size(vectorize, *length), + }, + Variable::ConstantScalar { .. } => *self, + Variable::GlobalScalar { .. } => *self, Variable::AbsolutePos => *self, Variable::Rank => *self, - Variable::LocalScalar(_, _, _) => *self, + Variable::LocalScalar { .. } => *self, + Variable::Matrix { .. 
} => *self,
             Variable::UnitPos => *self,
             Variable::UnitPosX => *self,
             Variable::UnitPosY => *self,
@@ -200,7 +232,6 @@ impl Variable {
             Variable::CubeCount => *self,
             Variable::CubeDim => *self,
             Variable::SubcubeDim => *self,
-            Variable::Matrix(_, _) => *self,
         }
     }
 }
diff --git a/crates/burn-cube/src/prelude.rs b/crates/burn-cube/src/prelude.rs
index b724e8dfc..027a7f866 100644
--- a/crates/burn-cube/src/prelude.rs
+++ b/crates/burn-cube/src/prelude.rs
@@ -11,7 +11,8 @@ pub use crate::runtime::Runtime;
 
 /// Elements
 pub use crate::frontend::{
-    Array, ArrayHandle, Bool, Float, LaunchArg, Tensor, TensorArg, UInt, F16, F32, F64, I32, I64,
+    Array, ArrayHandle, Bool, Float, LaunchArg, Slice, SliceMut, Tensor, TensorArg, UInt, F16, F32,
+    F64, I32, I64,
 };
 
 pub use crate::pod::CubeElement;
diff --git a/crates/burn-cube/src/runtime_tests/cmma.rs b/crates/burn-cube/src/runtime_tests/cmma.rs
index 1ef211818..6c0668cbb 100644
--- a/crates/burn-cube/src/runtime_tests/cmma.rs
+++ b/crates/burn-cube/src/runtime_tests/cmma.rs
@@ -31,12 +31,17 @@ pub fn kernel_simple_1(lhs: &Array<F16>, rhs: &Array<F16>, out: &mut Array<F32>) {
         cmma::MatrixLayout::Undefined,
     );
     cmma::fill::<F32>(&c, F32::new(0.0));
-    cmma::load::<F16>(&a, lhs, UInt::new(16));
-    cmma::load::<F16>(&b, rhs, UInt::new(16));
+    cmma::load::<F16>(&a, lhs.as_slice(), UInt::new(16));
+    cmma::load::<F16>(&b, rhs.as_slice(), UInt::new(16));
 
     cmma::execute::<F16, F16, F32, F32>(&a, &b, &c, &c);
 
-    cmma::store::<F32>(out, &c, UInt::new(16), cmma::MatrixLayout::RowMajor);
+    cmma::store::<F32>(
+        out.as_slice_mut(),
+        &c,
+        UInt::new(16),
+        cmma::MatrixLayout::RowMajor,
+    );
 }
 
 pub fn test_simple_1<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
diff --git a/crates/burn-cube/src/runtime_tests/mod.rs b/crates/burn-cube/src/runtime_tests/mod.rs
index 0654bccfa..1eda96028 100644
--- a/crates/burn-cube/src/runtime_tests/mod.rs
+++ b/crates/burn-cube/src/runtime_tests/mod.rs
@@ -1,5 +1,6 @@
 pub mod cmma;
 pub mod launch;
+pub mod slice;
 pub mod subcube;
 
 #[allow(missing_docs)]
@@ -11,5 +12,6 @@ macro_rules! testgen_all {
         burn_cube::testgen_subcube!();
         burn_cube::testgen_launch!();
         burn_cube::testgen_cmma!();
+        burn_cube::testgen_slice!();
     };
 }
diff --git a/crates/burn-cube/src/runtime_tests/slice.rs b/crates/burn-cube/src/runtime_tests/slice.rs
new file mode 100644
index 000000000..a8f4ab96c
--- /dev/null
+++ b/crates/burn-cube/src/runtime_tests/slice.rs
@@ -0,0 +1,107 @@
+use crate as burn_cube;
+use burn_cube::prelude::*;
+
+#[cube(launch)]
+pub fn slice_select(input: &Array, output: &mut Array) {
+    if UNIT_POS == UInt::new(0) {
+        let slice = input.slice(2, 3);
+        output[0] = slice[0u32];
+    }
+}
+
+#[cube(launch)]
+pub fn slice_assign(input: &Array, output: &mut Array) {
+    if UNIT_POS == UInt::new(0) {
+        let slice_1 = output.slice_mut(2, 3);
+        slice_1[0] = input[0u32];
+    }
+}
+
+#[cube(launch)]
+pub fn slice_len(input: &Array, output: &mut Array) {
+    if UNIT_POS == UInt::new(0) {
+        let slice = input.slice(2, 4);
+        let _tmp = slice[0]; // It must be used at least once, otherwise wgpu isn't happy.
+ output[0] = slice.len(); + } +} + +pub fn test_slice_select(client: ComputeClient) { + let input = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0])); + let output = client.empty(core::mem::size_of::()); + + slice_select_launch::( + client.clone(), + CubeCount::Static(1, 1, 1), + CubeDim::new(1, 1, 1), + ArrayArg::new(&input, 5), + ArrayArg::new(&output, 1), + ); + + let actual = client.read(output.binding()); + let actual = f32::from_bytes(&actual); + + assert_eq!(actual[0], 2.0); +} + +pub fn test_slice_len(client: ComputeClient) { + let input = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0])); + let output = client.empty(core::mem::size_of::()); + + slice_len_launch::( + client.clone(), + CubeCount::Static(1, 1, 1), + CubeDim::new(1, 1, 1), + ArrayArg::new(&input, 5), + ArrayArg::new(&output, 1), + ); + + let actual = client.read(output.binding()); + let actual = u32::from_bytes(&actual); + + assert_eq!(actual, &[2]); +} + +pub fn test_slice_assign(client: ComputeClient) { + let input = client.create(f32::as_bytes(&[15.0])); + let output = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0])); + + slice_assign_launch::( + client.clone(), + CubeCount::Static(1, 1, 1), + CubeDim::new(1, 1, 1), + ArrayArg::new(&input, 5), + ArrayArg::new(&output, 1), + ); + + let actual = client.read(output.binding()); + let actual = f32::from_bytes(&actual); + + assert_eq!(actual, &[0.0, 1.0, 15.0, 3.0, 4.0]); +} + +#[allow(missing_docs)] +#[macro_export] +macro_rules! testgen_slice { + () => { + use super::*; + + #[test] + fn test_slice_select() { + let client = TestRuntime::client(&Default::default()); + burn_cube::runtime_tests::slice::test_slice_select::(client); + } + + #[test] + fn test_slice_assign() { + let client = TestRuntime::client(&Default::default()); + burn_cube::runtime_tests::slice::test_slice_assign::(client); + } + + #[test] + fn test_slice_len() { + let client = TestRuntime::client(&Default::default()); + burn_cube::runtime_tests::slice::test_slice_len::(client); + } + }; +} diff --git a/crates/burn-cube/tests/frontend/array.rs b/crates/burn-cube/tests/frontend/array.rs index b0f2638a2..502ee2b24 100644 --- a/crates/burn-cube/tests/frontend/array.rs +++ b/crates/burn-cube/tests/frontend/array.rs @@ -36,7 +36,7 @@ mod tests { use super::*; use burn_cube::{ cpa, - ir::{Elem, Item, Variable}, + ir::{self, Elem, Item, Variable}, }; type ElemType = F32; @@ -47,7 +47,7 @@ mod tests { array_read_write::__expand::(&mut context, 512); assert_eq!( - format!("{:?}", context.into_scope().operations), + context.into_scope().operations, inline_macro_ref_read_write() ) } @@ -60,10 +60,7 @@ mod tests { array_add_assign_simple::__expand(&mut context, array.into()); let scope = context.into_scope(); - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_array_add_assign_simple() - ); + assert_eq!(scope.operations, inline_macro_array_add_assign_simple()); } #[test] @@ -72,7 +69,7 @@ mod tests { array_to_vectorized_variable::__expand::(&mut context); assert_eq!( - format!("{:?}", context.into_scope().operations), + context.into_scope().operations, inline_macro_ref_to_vectorized() ); } @@ -83,12 +80,12 @@ mod tests { array_of_one_to_vectorized_variable::__expand::(&mut context); assert_eq!( - format!("{:?}", context.into_scope().operations), + context.into_scope().operations, inline_macro_ref_one_to_vectorized() ); } - fn inline_macro_ref_read_write() -> String { + fn inline_macro_ref_read_write() -> Vec { let context = CubeContext::root(); let item = 
Item::new(ElemType::as_elem()); @@ -105,7 +102,7 @@ mod tests { // Read cpa!(scope, var = array[pos]); - format!("{:?}", scope.operations) + scope.operations } #[test] @@ -116,30 +113,36 @@ mod tests { array_add_assign_expr::__expand(&mut context, array.into()); let scope = context.into_scope(); - assert_eq!( - format!("{:?}", scope.operations), - inline_macro_array_add_assign_expr() - ); + assert_eq!(scope.operations, inline_macro_array_add_assign_expr()); } - fn inline_macro_array_add_assign_simple() -> String { + fn inline_macro_array_add_assign_simple() -> Vec { let context = CubeContext::root(); let mut scope = context.into_scope(); let local = scope.create_local(Item::new(Elem::UInt)); - let array = Variable::GlobalInputArray(0, Item::new(Elem::UInt)); - let index = Variable::ConstantScalar(1., Elem::UInt); - let value = Variable::ConstantScalar(1., Elem::UInt); + let array = Variable::GlobalInputArray { + id: 0, + item: Item::new(Elem::UInt), + }; + let index = Variable::ConstantScalar { + value: 1., + elem: Elem::UInt, + }; + let value = Variable::ConstantScalar { + value: 1., + elem: Elem::UInt, + }; cpa!(scope, local = array[index]); cpa!(scope, local += value); cpa!(scope, array[index] = local); - format!("{:?}", scope.operations) + scope.operations } - fn inline_macro_ref_to_vectorized() -> String { + fn inline_macro_ref_to_vectorized() -> Vec { let context = CubeContext::root(); let scalar_item = Item::new(ElemType::as_elem()); let vectorized_item = Item::vectorized(ElemType::as_elem(), 2); @@ -158,10 +161,10 @@ mod tests { cpa!(scope, tmp = array[pos1]); cpa!(scope, vectorized_var[pos1] = tmp); - format!("{:?}", scope.operations) + scope.operations } - fn inline_macro_ref_one_to_vectorized() -> String { + fn inline_macro_ref_one_to_vectorized() -> Vec { let context = CubeContext::root(); let scalar_item = Item::new(ElemType::as_elem()); let unvectorized_item = Item::new(ElemType::as_elem()); @@ -176,26 +179,38 @@ mod tests { cpa!(scope, tmp = array[pos0]); cpa!(scope, unvectorized_var = tmp); - format!("{:?}", scope.operations) + scope.operations } - fn inline_macro_array_add_assign_expr() -> String { + fn inline_macro_array_add_assign_expr() -> Vec { let context = CubeContext::root(); let mut scope = context.into_scope(); let index = scope.create_local(Item::new(Elem::UInt)); let local = scope.create_local(Item::new(Elem::UInt)); - let array = Variable::GlobalInputArray(0, Item::new(Elem::UInt)); - let const1 = Variable::ConstantScalar(1., Elem::UInt); - let const2 = Variable::ConstantScalar(5., Elem::UInt); - let value = Variable::ConstantScalar(1., Elem::UInt); + let array = Variable::GlobalInputArray { + id: 0, + item: Item::new(Elem::UInt), + }; + let const1 = Variable::ConstantScalar { + value: 1., + elem: Elem::UInt, + }; + let const2 = Variable::ConstantScalar { + value: 5., + elem: Elem::UInt, + }; + let value = Variable::ConstantScalar { + value: 1., + elem: Elem::UInt, + }; cpa!(scope, index = const1 + const2); cpa!(scope, local = array[index]); cpa!(scope, local += value); cpa!(scope, array[index] = local); - format!("{:?}", scope.operations) + scope.operations } } diff --git a/crates/burn-cube/tests/frontend/assign.rs b/crates/burn-cube/tests/frontend/assign.rs index dde8e4d9b..b720a3b87 100644 --- a/crates/burn-cube/tests/frontend/assign.rs +++ b/crates/burn-cube/tests/frontend/assign.rs @@ -98,8 +98,14 @@ mod tests { let mut scope = context.into_scope(); let x = scope.create_local(Item::new(Elem::UInt)); - let zero = Variable::ConstantScalar(0., Elem::UInt); - 
let one = Variable::ConstantScalar(1., Elem::UInt); + let zero = Variable::ConstantScalar { + value: 0., + elem: Elem::UInt, + }; + let one = Variable::ConstantScalar { + value: 1., + elem: Elem::UInt, + }; cpa!(scope, x = zero); cpa!(scope, x = x + one); @@ -115,8 +121,14 @@ mod tests { let y: Variable = y.into(); let x = scope.create_local(item); - let one = Variable::ConstantScalar(1., Elem::UInt); - let two = Variable::ConstantScalar(2., Elem::UInt); + let one = Variable::ConstantScalar { + value: 1., + elem: Elem::UInt, + }; + let two = Variable::ConstantScalar { + value: 2., + elem: Elem::UInt, + }; cpa!(scope, x = y); cpa!(scope, x = x + one); cpa!(scope, y = y + two); @@ -133,8 +145,14 @@ mod tests { let y: Variable = y.into(); let x = scope.create_local(item); - let one = Variable::ConstantScalar(1., Elem::UInt); - let two = Variable::ConstantScalar(2., Elem::UInt); + let one = Variable::ConstantScalar { + value: 1., + elem: Elem::UInt, + }; + let two = Variable::ConstantScalar { + value: 2., + elem: Elem::UInt, + }; cpa!(scope, x = y); cpa!(scope, y = y + one); cpa!(scope, x = x + two); @@ -151,10 +169,22 @@ mod tests { let y: Variable = y.into(); let x = scope.create_local(item); - let zero = Variable::ConstantScalar(0., Elem::UInt); - let one = Variable::ConstantScalar(1., Elem::UInt); - let two = Variable::ConstantScalar(2., Elem::UInt); - let three = Variable::ConstantScalar(3., Elem::UInt); + let zero = Variable::ConstantScalar { + value: 0., + elem: Elem::UInt, + }; + let one = Variable::ConstantScalar { + value: 1., + elem: Elem::UInt, + }; + let two = Variable::ConstantScalar { + value: 2., + elem: Elem::UInt, + }; + let three = Variable::ConstantScalar { + value: 3., + elem: Elem::UInt, + }; cpa!(scope, x[zero] = one); cpa!(scope, x[one] = one); cpa!(scope, x[two] = one); diff --git a/crates/burn-cube/tests/frontend/ops.rs b/crates/burn-cube/tests/frontend/ops.rs index 3d32d164d..e53f664c1 100644 --- a/crates/burn-cube/tests/frontend/ops.rs +++ b/crates/burn-cube/tests/frontend/ops.rs @@ -395,31 +395,31 @@ mod tests { let out_number = if in_type == out_type { 0 } else { 2 }; format!( "[Operator({ops_name}(BinaryOperator {{ \ - lhs: Local(0, Item {{ \ + lhs: Local {{ id: 0, item: Item {{ \ elem: {in_type}, \ vectorization: 1 \ - }}, 0), \ - rhs: Local(1, Item {{ \ + }}, depth: 0 }}, \ + rhs: Local {{ id: 1, item: Item {{ \ elem: {in_type}, \ vectorization: 1 \ - }}, 0), \ - out: Local({out_number}, Item {{ \ + }}, depth: 0 }}, \ + out: Local {{ id: {out_number}, item: Item {{ \ elem: {out_type}, \ vectorization: 1 \ - }}, 0) \ + }}, depth: 0 }} \ }}))]" ) } else { format!( "[Operator({ops_name}(UnaryOperator {{ \ - input: Local(0, Item {{ \ + input: Local {{ id: 0, item: Item {{ \ elem: {in_type}, \ vectorization: 1 \ - }}, 0), \ - out: Local(0, Item {{ \ + }}, depth: 0 }}, \ + out: Local {{ id: 0, item: Item {{ \ elem: {out_type}, \ vectorization: 1 \ - }}, 0) \ + }}, depth: 0 }} \ }}))]" ) } diff --git a/crates/burn-cube/tests/frontend/redeclare.rs b/crates/burn-cube/tests/frontend/redeclare.rs index d73ccd84d..4fb786721 100644 --- a/crates/burn-cube/tests/frontend/redeclare.rs +++ b/crates/burn-cube/tests/frontend/redeclare.rs @@ -117,7 +117,10 @@ mod tests { let i = scope.create_with_value(1, item); cpa!(scope, x += i); - let value = Variable::ConstantScalar(2., item.elem()); + let value = Variable::ConstantScalar { + value: 2., + elem: item.elem(), + }; cpa!(scope, i = value); cpa!(scope, x += i); @@ -156,7 +159,10 @@ mod tests { cpa!( &mut scope, range(0u32, end, 
false).for_each(|_, scope| { - let value = Variable::ConstantScalar(2.into(), item.elem()); + let value = Variable::ConstantScalar { + value: 2.into(), + elem: item.elem(), + }; cpa!(scope, y = value); cpa!(scope, x += y); }) diff --git a/crates/burn-cuda/src/compiler/base.rs b/crates/burn-cuda/src/compiler/base.rs index f3596124d..d0f2a0553 100644 --- a/crates/burn-cuda/src/compiler/base.rs +++ b/crates/burn-cuda/src/compiler/base.rs @@ -83,6 +83,9 @@ impl CudaCompiler { let processing = value.process(); for var in processing.variables { + if let gpu::Variable::Slice { .. } = var { + continue; + } instructions.push(Instruction::DeclareVariable { var: self.compile_variable(var), }); @@ -192,8 +195,8 @@ impl CudaCompiler { gpu::Metadata::Stride { dim, var, out } => { self.stride = true; let position = match var { - gpu::Variable::GlobalInputArray(idx, _) => idx as usize, - gpu::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize, + gpu::Variable::GlobalInputArray { id, .. } => id as usize, + gpu::Variable::GlobalOutputArray { id, .. } => self.num_inputs + id as usize, _ => panic!("Only Input and Output have a stride, got: {:?}", var), }; Instruction::Stride { @@ -205,8 +208,8 @@ impl CudaCompiler { gpu::Metadata::Shape { dim, var, out } => { self.shape = true; let position = match var { - gpu::Variable::GlobalInputArray(idx, _) => idx as usize, - gpu::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize, + gpu::Variable::GlobalInputArray { id, .. } => id as usize, + gpu::Variable::GlobalOutputArray { id, .. } => self.num_inputs + id as usize, _ => panic!("Only Input and Output have a shape, got {:?}", var), }; Instruction::Shape { @@ -215,12 +218,20 @@ impl CudaCompiler { out: self.compile_variable(out), } } - gpu::Metadata::ArrayLength { var, out } => super::Instruction::ArrayLength { - input: self.compile_variable(var), - out: self.compile_variable(out), - num_inputs: self.num_inputs, - num_outputs: self.num_outputs, - }, + gpu::Metadata::Length { var, out } => { + let input = self.compile_variable(var); + let out = self.compile_variable(out); + + match input { + super::Variable::Slice { .. 
} => super::Instruction::SliceLength { input, out }, + _ => super::Instruction::Length { + input, + out, + num_inputs: self.num_inputs, + num_outputs: self.num_outputs, + }, + } + } } } @@ -297,6 +308,12 @@ impl CudaCompiler { gpu::Operator::Div(op) => Instruction::Div(self.compile_binary(op)), gpu::Operator::Sub(op) => Instruction::Sub(self.compile_binary(op)), gpu::Operator::Assign(op) => Instruction::Assign(self.compile_unary(op)), + gpu::Operator::Slice(op) => Instruction::Slice { + input: self.compile_variable(op.input), + start: self.compile_variable(op.start), + end: self.compile_variable(op.end), + out: self.compile_variable(op.out), + }, gpu::Operator::Index(op) => Instruction::Index(self.compile_binary(op)), gpu::Operator::UncheckedIndex(op) => Instruction::Index(self.compile_binary(op)), gpu::Operator::IndexAssign(op) => Instruction::IndexAssign(self.compile_binary(op)), @@ -372,35 +389,40 @@ impl CudaCompiler { fn compile_variable(&mut self, value: gpu::Variable) -> super::Variable { match value { - gpu::Variable::GlobalInputArray(index, item) => { - super::Variable::GlobalInputArray(index, Self::compile_item(item)) + gpu::Variable::GlobalInputArray { id, item } => { + super::Variable::GlobalInputArray(id, Self::compile_item(item)) } - gpu::Variable::GlobalScalar(index, elem) => { - super::Variable::GlobalScalar(index, Self::compile_elem(elem), elem) + gpu::Variable::GlobalScalar { id, elem } => { + super::Variable::GlobalScalar(id, Self::compile_elem(elem), elem) } - gpu::Variable::Local(index, item, scope_depth) => super::Variable::Local { - index, + gpu::Variable::Local { id, item, depth } => super::Variable::Local { + id, item: Self::compile_item(item), - scope_depth, + depth, }, - gpu::Variable::LocalScalar(index, elem, scope_depth) => super::Variable::LocalScalar { - index, + gpu::Variable::Slice { id, item, depth } => super::Variable::Slice { + id, + item: Self::compile_item(item), + depth, + }, + gpu::Variable::LocalScalar { id, elem, depth } => super::Variable::LocalScalar { + id, elem: Self::compile_elem(elem), - scope_depth, + depth, }, - gpu::Variable::GlobalOutputArray(index, item) => { - super::Variable::GlobalOutputArray(index, Self::compile_item(item)) + gpu::Variable::GlobalOutputArray { id, item } => { + super::Variable::GlobalOutputArray(id, Self::compile_item(item)) } - gpu::Variable::ConstantScalar(index, elem) => { - super::Variable::ConstantScalar(index, Self::compile_elem(elem)) + gpu::Variable::ConstantScalar { value, elem } => { + super::Variable::ConstantScalar(value, Self::compile_elem(elem)) } - gpu::Variable::SharedMemory(index, item, size) => { + gpu::Variable::SharedMemory { id, item, length } => { let item = Self::compile_item(item); - if !self.shared_memories.iter().any(|s| s.index == index) { + if !self.shared_memories.iter().any(|s| s.index == id) { self.shared_memories - .push(super::SharedMemory::new(index, item, size)); + .push(super::SharedMemory::new(id, item, length)); } - super::Variable::SharedMemory(index, item, size) + super::Variable::SharedMemory(id, item, length) } gpu::Variable::AbsolutePos => { self.id = true; @@ -438,7 +460,12 @@ impl CudaCompiler { gpu::Variable::CubeCountX => super::Variable::NumWorkgroupsX, gpu::Variable::CubeCountY => super::Variable::NumWorkgroupsY, gpu::Variable::CubeCountZ => super::Variable::NumWorkgroupsZ, - gpu::Variable::LocalArray(id, item, depth, size) => { + gpu::Variable::LocalArray { + id, + item, + depth, + length, + } => { let item = Self::compile_item(item); if !self .local_arrays @@ -446,19 
+473,19 @@ impl CudaCompiler { .any(|s| s.index == id && s.depth == depth) { self.local_arrays - .push(super::LocalArray::new(id, item, depth, size)); + .push(super::LocalArray::new(id, item, depth, length)); } - super::Variable::LocalArray(id, item, depth, size) + super::Variable::LocalArray(id, item, depth, length) } gpu::Variable::CubePos => todo!(), gpu::Variable::CubeDim => todo!(), gpu::Variable::CubeCount => todo!(), gpu::Variable::SubcubeDim => todo!(), - gpu::Variable::Matrix(index, matrix) => { + gpu::Variable::Matrix { id, mat } => { self.wmma = true; super::Variable::WmmaFragment { - index, - frag: Self::compile_matrix(matrix), + id, + frag: Self::compile_matrix(mat), } } } diff --git a/crates/burn-cuda/src/compiler/binary.rs b/crates/burn-cuda/src/compiler/binary.rs index f2fae1c71..ed722a8ac 100644 --- a/crates/burn-cuda/src/compiler/binary.rs +++ b/crates/burn-cuda/src/compiler/binary.rs @@ -361,9 +361,9 @@ impl Binary for IndexAssign { out: &Variable, ) -> std::fmt::Result { if let Variable::Local { - index: _, + id: _, item: _, - scope_depth: _, + depth: _, } = out { return IndexAssignVector::format(f, lhs, rhs, out); @@ -388,9 +388,9 @@ impl Binary for Index { out: &Variable, ) -> std::fmt::Result { if let Variable::Local { - index: _, + id: _, item: _, - scope_depth: _, + depth: _, } = lhs { return IndexVector::format(f, lhs, rhs, out); diff --git a/crates/burn-cuda/src/compiler/element.rs b/crates/burn-cuda/src/compiler/element.rs index 42950cf4c..583f2bb24 100644 --- a/crates/burn-cuda/src/compiler/element.rs +++ b/crates/burn-cuda/src/compiler/element.rs @@ -86,9 +86,14 @@ impl Component for Variable { Variable::GlobalOutputArray(_, e) => *e, Variable::SharedMemory(_, e, _) => *e, Variable::Local { - index: _, + id: _, item, - scope_depth: _, + depth: _, + } => *item, + Variable::Slice { + id: _, + item, + depth: _, } => *item, Variable::ConstantScalar(_, e) => Item::Scalar(*e), Variable::GlobalScalar(_, e, _) => Item::Scalar(*e), @@ -99,9 +104,9 @@ impl Component for Variable { Variable::LocalInvocationIdZ => Item::Scalar(Elem::U32), Variable::Rank => Item::Scalar(Elem::U32), Variable::LocalScalar { - index: _, + id: _, elem, - scope_depth: _, + depth: _, } => Item::Scalar(*elem), Variable::WorkgroupIdX => Item::Scalar(Elem::U32), Variable::WorkgroupIdY => Item::Scalar(Elem::U32), @@ -117,7 +122,7 @@ impl Component for Variable { Variable::NumWorkgroupsZ => Item::Scalar(Elem::U32), Variable::LocalArray(_, e, _, _) => *e, Variable::WarpSize => Item::Scalar(Elem::U32), - Variable::WmmaFragment { index: _, frag } => Item::Scalar(frag.elem), + Variable::WmmaFragment { id: _, frag } => Item::Scalar(frag.elem), } } } @@ -129,16 +134,9 @@ pub enum Variable { GlobalOutputArray(u16, Item), GlobalScalar(u16, Elem, gpu::Elem), ConstantScalar(f64, Elem), - Local { - index: u16, - item: Item, - scope_depth: u8, - }, - LocalScalar { - index: u16, - elem: Elem, - scope_depth: u8, - }, + Local { id: u16, item: Item, depth: u8 }, + Slice { id: u16, item: Item, depth: u8 }, + LocalScalar { id: u16, elem: Elem, depth: u8 }, SharedMemory(u16, Item, u32), LocalArray(u16, Item, u8, u32), Id, @@ -159,10 +157,7 @@ pub enum Variable { NumWorkgroupsX, NumWorkgroupsY, NumWorkgroupsZ, - WmmaFragment { - index: u16, - frag: Fragment, - }, + WmmaFragment { id: u16, frag: Fragment }, } impl Display for Variable { @@ -170,15 +165,18 @@ impl Display for Variable { match self { Variable::GlobalInputArray(number, _) => f.write_fmt(format_args!("input_{number}")), Variable::LocalScalar { - index, + id: 
index, elem: _, - scope_depth, + depth: scope_depth, } => f.write_fmt(format_args!("s_{scope_depth}_{index}")), Variable::Local { - index, + id: index, item: _, - scope_depth, + depth: scope_depth, } => f.write_fmt(format_args!("l_{scope_depth}_{index}")), + Variable::Slice { id, item: _, depth } => { + f.write_fmt(format_args!("slice_{depth}_{id}")) + } Variable::GlobalOutputArray(number, _) => f.write_fmt(format_args!("output_{number}")), Variable::GlobalScalar(number, _, elem) => { f.write_fmt(format_args!("scalars_{elem}[{number}]")) @@ -209,7 +207,9 @@ impl Display for Variable { f.write_fmt(format_args!("l_arr_{}_{}", id, depth)) } Variable::WarpSize => f.write_str("warpSize"), - Variable::WmmaFragment { index, frag: _ } => f.write_fmt(format_args!("frag_{index}")), + Variable::WmmaFragment { id: index, frag: _ } => { + f.write_fmt(format_args!("frag_{index}")) + } } } } @@ -220,9 +220,9 @@ impl Variable { Variable::GlobalScalar(_, _, _) => true, Variable::ConstantScalar(_, _) => true, Variable::LocalScalar { - index: _, + id: _, elem: _, - scope_depth: _, + depth: _, } => true, Variable::Id => true, Variable::LocalInvocationIndex => true, @@ -234,9 +234,14 @@ impl Variable { Variable::GlobalOutputArray(_, _) => false, Variable::SharedMemory(_, _, _) => false, Variable::Local { - index: _, + id: _, item: _, - scope_depth: _, + depth: _, + } => false, + Variable::Slice { + id: _, + item: _, + depth: _, } => false, Variable::WorkgroupIdX => true, Variable::WorkgroupIdY => true, @@ -252,7 +257,7 @@ impl Variable { Variable::NumWorkgroupsZ => true, Variable::LocalArray(_, _, _, _) => false, Variable::WarpSize => true, - Variable::WmmaFragment { index: _, frag: _ } => false, + Variable::WmmaFragment { id: _, frag: _ } => false, } } diff --git a/crates/burn-cuda/src/compiler/instruction.rs b/crates/burn-cuda/src/compiler/instruction.rs index 2f28422c7..622416b23 100644 --- a/crates/burn-cuda/src/compiler/instruction.rs +++ b/crates/burn-cuda/src/compiler/instruction.rs @@ -16,12 +16,16 @@ pub struct UnaryInstruction { #[derive(Debug, Clone)] pub enum Instruction { - ArrayLength { + Length { input: Variable, out: Variable, num_inputs: usize, num_outputs: usize, }, + SliceLength { + input: Variable, + out: Variable, + }, DeclareVariable { var: Variable, }, @@ -58,6 +62,12 @@ pub enum Instruction { instructions_if: Vec, instructions_else: Vec, }, + Slice { + input: Variable, + start: Variable, + end: Variable, + out: Variable, + }, Return, Break, Stride { @@ -114,7 +124,7 @@ impl Display for Instruction { Instruction::Return => f.write_str("return;"), Instruction::Break => f.write_str("break;"), Instruction::DeclareVariable { var } => match var { - Variable::WmmaFragment { index: _, frag } => { + Variable::WmmaFragment { id: _, frag } => { f.write_fmt(format_args!("{frag} {var};\n")) } _ => { @@ -123,6 +133,16 @@ impl Display for Instruction { } }, Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out), + Instruction::Slice { + input, + start, + end, + out, + } => { + let item = out.item(); + f.write_fmt(format_args!("uint {out}_length = {end} - {start};\n"))?; + f.write_fmt(format_args!("{item} *{out} = {input} + {start};\n")) + } Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out), Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out), @@ -225,7 +245,10 @@ for (uint {i} = {start}; {i} < {end}; {i}++) {{ Instruction::SyncThreads => f.write_str("__syncthreads();\n"), Instruction::Ceil(it) 
=> Ceil::format(f, &it.input, &it.out), Instruction::Floor(it) => Floor::format(f, &it.input, &it.out), - Instruction::ArrayLength { + Instruction::SliceLength { input, out } => { + f.write_fmt(format_args!("{out} = {input}_length;\n")) + } + Instruction::Length { input, out, num_inputs, diff --git a/crates/burn-jit/src/fusion/elemwise/builder.rs b/crates/burn-jit/src/fusion/elemwise/builder.rs index 784f3db66..169331062 100644 --- a/crates/burn-jit/src/fusion/elemwise/builder.rs +++ b/crates/burn-jit/src/fusion/elemwise/builder.rs @@ -288,7 +288,10 @@ impl ElementWiseBuilder { return false; } - let input = Variable::ConstantScalar(1.0, desc.dtype.into()); + let input = Variable::ConstantScalar { + value: 1.0, + elem: desc.dtype.into(), + }; let out = self.builder.output(desc, Variable::AbsolutePos); self.builder @@ -301,7 +304,10 @@ impl ElementWiseBuilder { return false; } - let input = Variable::ConstantScalar(0.0, desc.dtype.into()); + let input = Variable::ConstantScalar { + value: 0.0, + elem: desc.dtype.into(), + }; let out = self.builder.output(desc, Variable::AbsolutePos); self.builder diff --git a/crates/burn-jit/src/fusion/tracing/builder.rs b/crates/burn-jit/src/fusion/tracing/builder.rs index 32187e7b6..efb86ad11 100644 --- a/crates/burn-jit/src/fusion/tracing/builder.rs +++ b/crates/burn-jit/src/fusion/tracing/builder.rs @@ -56,9 +56,11 @@ impl TraceBuilder { } true => match self.output_to_local.get(&tensor.id) { // Is a local variable. - Some(local_index) => { - Variable::Local(*local_index, Item::new(elem), self.scope.depth) - } + Some(local_index) => Variable::Local { + id: *local_index, + item: Item::new(elem), + depth: self.scope.depth, + }, // Isn't an operation output variable, so must be an existing input. None => self .inputs @@ -85,7 +87,11 @@ impl TraceBuilder { // Output already registered as a local variable. if let Some(index) = self.output_to_local.get(&tensor.id) { - return Variable::Local(*index, Item::new(elem), self.scope.depth); + return Variable::Local { + id: *index, + item: Item::new(elem), + depth: self.scope.depth, + }; } let variable = self.scope.create_local(Item::new(elem)); @@ -159,11 +165,11 @@ impl TraceBuilder { // // Only local variables can become outputs. let mark = |var: &Variable, list: &mut Vec| { - if let Variable::Local(index, _, _) = var { + if let Variable::Local { id: id_local, .. 
} = var { if let Some((id, _)) = self .output_to_local .iter() - .find(|(_id, position)| *position == index) + .find(|(_tensor_id, position)| *position == id_local) { if !list.contains(id) { list.push(*id); @@ -201,6 +207,11 @@ impl TraceBuilder { mark(&op.c, &mut local_tensor_ids_input); mark(&op.out, &mut local_tensor_ids_output); } + Operator::Slice(op) => { + mark(&op.input, &mut local_tensor_ids_input); + mark(&op.start, &mut local_tensor_ids_input); + mark(&op.out, &mut local_tensor_ids_output); + } Operator::Max(op) => mark_binary( op, &mut local_tensor_ids_input, diff --git a/crates/burn-jit/src/kernel/cast/base.rs b/crates/burn-jit/src/kernel/cast/base.rs index 5ccc7cb3d..944a1abad 100644 --- a/crates/burn-jit/src/kernel/cast/base.rs +++ b/crates/burn-jit/src/kernel/cast/base.rs @@ -68,8 +68,14 @@ impl Kernel for CastEagerKernel Kernel for BoolCastEagerKernel { let item_input = Item::new(Elem::Bool); let item_output = EO::cube_elem().into(); - let tensor = Variable::GlobalInputArray(0, item_input); - let output = Variable::GlobalOutputArray(0, item_output); + let tensor = Variable::GlobalInputArray { + id: 0, + item: item_input, + }; + let output = Variable::GlobalOutputArray { + id: 0, + item: item_output, + }; BoolCastShader { tensor, output }.expand(&mut scope); diff --git a/crates/burn-jit/src/kernel/conv/conv_transpose2d.rs b/crates/burn-jit/src/kernel/conv/conv_transpose2d.rs index 5b28e2aea..657b47739 100644 --- a/crates/burn-jit/src/kernel/conv/conv_transpose2d.rs +++ b/crates/burn-jit/src/kernel/conv/conv_transpose2d.rs @@ -93,13 +93,34 @@ impl Conv2dTransposeComputeShader { cpa!(scope, kernel_size_0 = shape(weight, 2u32)); cpa!(scope, kernel_size_1 = shape(weight, 3u32)); - let conv_stride_0 = Variable::GlobalScalar(0, Elem::UInt); - let conv_stride_1 = Variable::GlobalScalar(1, Elem::UInt); - let dilation_0 = Variable::GlobalScalar(2, Elem::UInt); - let dilation_1 = Variable::GlobalScalar(3, Elem::UInt); - let padding_0 = Variable::GlobalScalar(4, Elem::UInt); - let padding_1 = Variable::GlobalScalar(5, Elem::UInt); - let groups = Variable::GlobalScalar(6, Elem::UInt); + let conv_stride_0 = Variable::GlobalScalar { + id: 0, + elem: Elem::UInt, + }; + let conv_stride_1 = Variable::GlobalScalar { + id: 1, + elem: Elem::UInt, + }; + let dilation_0 = Variable::GlobalScalar { + id: 2, + elem: Elem::UInt, + }; + let dilation_1 = Variable::GlobalScalar { + id: 3, + elem: Elem::UInt, + }; + let padding_0 = Variable::GlobalScalar { + id: 4, + elem: Elem::UInt, + }; + let padding_1 = Variable::GlobalScalar { + id: 5, + elem: Elem::UInt, + }; + let groups = Variable::GlobalScalar { + id: 6, + elem: Elem::UInt, + }; let stride_0_i = scope.create_local(Elem::Int(IntKind::I32)); let stride_1_i = scope.create_local(Elem::Int(IntKind::I32)); @@ -294,10 +315,10 @@ impl Kernel for Conv2dTransposeEagerKernel { let mut scope = Scope::root(); let item = E::cube_elem().into(); - let input = Variable::GlobalInputArray(0, item); - let weight = Variable::GlobalInputArray(1, item); - let bias = Variable::GlobalInputArray(2, item); - let output = Variable::GlobalOutputArray(0, item); + let input = Variable::GlobalInputArray { id: 0, item }; + let weight = Variable::GlobalInputArray { id: 1, item }; + let bias = Variable::GlobalInputArray { id: 2, item }; + let output = Variable::GlobalOutputArray { id: 0, item }; scope.write_global_custom(output); diff --git a/crates/burn-jit/src/kernel/conv/conv_transpose3d.rs b/crates/burn-jit/src/kernel/conv/conv_transpose3d.rs index d7c09524e..cc4ca53af 
100644 --- a/crates/burn-jit/src/kernel/conv/conv_transpose3d.rs +++ b/crates/burn-jit/src/kernel/conv/conv_transpose3d.rs @@ -105,16 +105,46 @@ impl Conv3dTransposeComputeShader { cpa!(scope, kernel_size_1 = shape(weight, 3u32)); cpa!(scope, kernel_size_2 = shape(weight, 4u32)); - let conv_stride_0 = Variable::GlobalScalar(0, Elem::UInt); - let conv_stride_1 = Variable::GlobalScalar(1, Elem::UInt); - let conv_stride_2 = Variable::GlobalScalar(2, Elem::UInt); - let dilation_0 = Variable::GlobalScalar(3, Elem::UInt); - let dilation_1 = Variable::GlobalScalar(4, Elem::UInt); - let dilation_2 = Variable::GlobalScalar(5, Elem::UInt); - let padding_0 = Variable::GlobalScalar(6, Elem::UInt); - let padding_1 = Variable::GlobalScalar(7, Elem::UInt); - let padding_2 = Variable::GlobalScalar(8, Elem::UInt); - let groups = Variable::GlobalScalar(9, Elem::UInt); + let conv_stride_0 = Variable::GlobalScalar { + id: 0, + elem: Elem::UInt, + }; + let conv_stride_1 = Variable::GlobalScalar { + id: 1, + elem: Elem::UInt, + }; + let conv_stride_2 = Variable::GlobalScalar { + id: 2, + elem: Elem::UInt, + }; + let dilation_0 = Variable::GlobalScalar { + id: 3, + elem: Elem::UInt, + }; + let dilation_1 = Variable::GlobalScalar { + id: 4, + elem: Elem::UInt, + }; + let dilation_2 = Variable::GlobalScalar { + id: 5, + elem: Elem::UInt, + }; + let padding_0 = Variable::GlobalScalar { + id: 6, + elem: Elem::UInt, + }; + let padding_1 = Variable::GlobalScalar { + id: 7, + elem: Elem::UInt, + }; + let padding_2 = Variable::GlobalScalar { + id: 8, + elem: Elem::UInt, + }; + let groups = Variable::GlobalScalar { + id: 9, + elem: Elem::UInt, + }; let stride_0_i = scope.create_local(Elem::Int(IntKind::I32)); let stride_1_i = scope.create_local(Elem::Int(IntKind::I32)); @@ -362,10 +392,10 @@ impl Kernel for Conv3dTransposeEagerKernel { let mut scope = Scope::root(); let item = E::cube_elem().into(); - let input = Variable::GlobalInputArray(0, item); - let weight = Variable::GlobalInputArray(1, item); - let bias = Variable::GlobalInputArray(2, item); - let output = Variable::GlobalOutputArray(0, item); + let input = Variable::GlobalInputArray { id: 0, item }; + let weight = Variable::GlobalInputArray { id: 1, item }; + let bias = Variable::GlobalInputArray { id: 2, item }; + let output = Variable::GlobalOutputArray { id: 0, item }; scope.write_global_custom(output); diff --git a/crates/burn-jit/src/kernel/index/flip.rs b/crates/burn-jit/src/kernel/index/flip.rs index ed5f2fb22..d2a8ffae2 100644 --- a/crates/burn-jit/src/kernel/index/flip.rs +++ b/crates/burn-jit/src/kernel/index/flip.rs @@ -43,7 +43,10 @@ impl FlipComputeShader { cpa!(scope, shape = shape(output, i)); cpa!( scope, - flip = cast(Variable::GlobalScalar(i as u16, Elem::UInt)) + flip = cast(Variable::GlobalScalar { + id: i as u16, + elem: Elem::UInt + }) ); cpa!(scope, flip_bool = flip == 1u32); @@ -70,8 +73,8 @@ impl Kernel for FlipEagerKernel { let mut scope = Scope::root(); let item = E::cube_elem().into(); - let input = Variable::GlobalInputArray(0, item); - let output = Variable::GlobalOutputArray(0, item); + let input = Variable::GlobalInputArray { id: 0, item }; + let output = Variable::GlobalOutputArray { id: 0, item }; scope.write_global_custom(output); diff --git a/crates/burn-jit/src/kernel/index/gather.rs b/crates/burn-jit/src/kernel/index/gather.rs index 8f4ce5d5b..a58d32491 100644 --- a/crates/burn-jit/src/kernel/index/gather.rs +++ b/crates/burn-jit/src/kernel/index/gather.rs @@ -27,8 +27,8 @@ struct GatherComputeShader { impl 
GatherComputeShader { pub fn expand(self, scope: &mut Scope) { match self.tensor { - Variable::GlobalInputArray(_, _) => (), - Variable::GlobalOutputArray(_, _) => (), + Variable::GlobalInputArray { .. } => (), + Variable::GlobalOutputArray { .. } => (), _ => panic!("Tensor variable must be an global array."), }; @@ -78,10 +78,16 @@ impl Kernel for GatherEagerKernel { let item_tensor = E::cube_elem().into(); let item_indices: Item = Elem::Int(IntKind::I32).into(); - let tensor = Variable::GlobalInputArray(0, item_tensor); + let tensor = Variable::GlobalInputArray { + id: 0, + item: item_tensor, + }; let indices = scope.read_array(1, item_indices, Variable::AbsolutePos); - let output_array = Variable::GlobalOutputArray(0, item_tensor); + let output_array = Variable::GlobalOutputArray { + id: 0, + item: item_tensor, + }; let output_local = scope.create_local(item_tensor); GatherComputeShader { diff --git a/crates/burn-jit/src/kernel/index/repeat.rs b/crates/burn-jit/src/kernel/index/repeat.rs index c4cf71362..111b985c4 100644 --- a/crates/burn-jit/src/kernel/index/repeat.rs +++ b/crates/burn-jit/src/kernel/index/repeat.rs @@ -61,8 +61,8 @@ impl Kernel for RepeatEagerKernel { let mut scope = Scope::root(); let item = E::cube_elem().into(); - let input = Variable::GlobalInputArray(0, item); - let output = Variable::GlobalOutputArray(0, item); + let input = Variable::GlobalInputArray { id: 0, item }; + let output = Variable::GlobalOutputArray { id: 0, item }; scope.write_global_custom(output); diff --git a/crates/burn-jit/src/kernel/index/scatter.rs b/crates/burn-jit/src/kernel/index/scatter.rs index 7a795e18d..3438c4980 100644 --- a/crates/burn-jit/src/kernel/index/scatter.rs +++ b/crates/burn-jit/src/kernel/index/scatter.rs @@ -32,13 +32,13 @@ struct ScatterComputeShader { impl ScatterComputeShader { pub fn expand(self, scope: &mut Scope) { match self.input { - Variable::GlobalInputArray(_, _) => (), - Variable::GlobalOutputArray(_, _) => (), + Variable::GlobalInputArray { .. } => (), + Variable::GlobalOutputArray { .. } => (), _ => panic!("Input variable must be an global array."), }; match self.value { - Variable::GlobalInputArray(_, _) => (), - Variable::GlobalOutputArray(_, _) => (), + Variable::GlobalInputArray { .. } => (), + Variable::GlobalOutputArray { .. 
} => (), _ => panic!("Value variable must be an global array."), }; @@ -136,9 +136,18 @@ impl Kernel for ScatterEagerKernel { let item_value = E::cube_elem().into(); let item_indices: Item = Elem::Int(IntKind::I32).into(); - let input_output = Variable::GlobalInputArray(0, item_value); - let indices = Variable::GlobalInputArray(1, Elem::Int(IntKind::I32).into()); - let value = Variable::GlobalInputArray(2, item_value); + let input_output = Variable::GlobalInputArray { + id: 0, + item: item_value, + }; + let indices = Variable::GlobalInputArray { + id: 1, + item: Elem::Int(IntKind::I32).into(), + }; + let value = Variable::GlobalInputArray { + id: 2, + item: item_value, + }; scope.write_global_custom(input_output); diff --git a/crates/burn-jit/src/kernel/index/select.rs b/crates/burn-jit/src/kernel/index/select.rs index 23c453826..6e1a5782d 100644 --- a/crates/burn-jit/src/kernel/index/select.rs +++ b/crates/burn-jit/src/kernel/index/select.rs @@ -73,9 +73,12 @@ impl Kernel for SelectEagerKernel { let item = E::cube_elem().into(); let item_indices: Item = Elem::Int(IntKind::I32).into(); - let input = Variable::GlobalInputArray(0, item); - let indices = Variable::GlobalInputArray(1, item_indices); - let output = Variable::GlobalOutputArray(0, item); + let input = Variable::GlobalInputArray { id: 0, item }; + let indices = Variable::GlobalInputArray { + id: 1, + item: item_indices, + }; + let output = Variable::GlobalOutputArray { id: 0, item }; scope.write_global_custom(output); diff --git a/crates/burn-jit/src/kernel/index/select_assign.rs b/crates/burn-jit/src/kernel/index/select_assign.rs index b2e348cd0..78c48706e 100644 --- a/crates/burn-jit/src/kernel/index/select_assign.rs +++ b/crates/burn-jit/src/kernel/index/select_assign.rs @@ -134,9 +134,12 @@ impl Kernel for SelectAssignEagerKernel { let item = E::cube_elem().into(); let item_indices: Item = Elem::Int(IntKind::I32).into(); - let tensor = Variable::GlobalInputArray(0, item); - let value = Variable::GlobalInputArray(1, item); - let indices = Variable::GlobalInputArray(2, item_indices); + let tensor = Variable::GlobalInputArray { id: 0, item }; + let value = Variable::GlobalInputArray { id: 1, item }; + let indices = Variable::GlobalInputArray { + id: 2, + item: item_indices, + }; scope.write_global_custom(tensor); diff --git a/crates/burn-jit/src/kernel/index/slice.rs b/crates/burn-jit/src/kernel/index/slice.rs index ae1abb336..b0f15862c 100644 --- a/crates/burn-jit/src/kernel/index/slice.rs +++ b/crates/burn-jit/src/kernel/index/slice.rs @@ -44,7 +44,10 @@ impl SliceComputeShader { cpa!(scope, shape_output = shape(output, i)); cpa!( scope, - range_start = cast(Variable::GlobalScalar(i as u16, Elem::UInt)) + range_start = cast(Variable::GlobalScalar { + id: i as u16, + elem: Elem::UInt + }) ); cpa!(scope, offset_local = id / stride_output); @@ -66,8 +69,8 @@ impl Kernel for SliceEagerKernel { let mut scope = Scope::root(); let item = E::cube_elem().into(); - let input = Variable::GlobalInputArray(0, item); - let output = Variable::GlobalOutputArray(0, item); + let input = Variable::GlobalInputArray { id: 0, item }; + let output = Variable::GlobalOutputArray { id: 0, item }; scope.write_global_custom(output); diff --git a/crates/burn-jit/src/kernel/index/slice_assign.rs b/crates/burn-jit/src/kernel/index/slice_assign.rs index 036e0707d..8ccdb18e8 100644 --- a/crates/burn-jit/src/kernel/index/slice_assign.rs +++ b/crates/burn-jit/src/kernel/index/slice_assign.rs @@ -47,7 +47,10 @@ impl SliceAssignComputeShader { cpa!(scope, 
shape_input = shape(input, i)); cpa!( scope, - range_start = cast(Variable::GlobalScalar(i as u16, Elem::UInt)) + range_start = cast(Variable::GlobalScalar { + id: i as u16, + elem: Elem::UInt + }) ); cpa!(scope, offset_local = id / stride_value); @@ -75,8 +78,8 @@ impl Kernel for SliceAssignEagerKernel { let mut scope = Scope::root(); let item = E::cube_elem().into(); - let input = Variable::GlobalInputArray(0, item); - let value = Variable::GlobalInputArray(1, item); + let input = Variable::GlobalInputArray { id: 0, item }; + let value = Variable::GlobalInputArray { id: 1, item }; scope.write_global_custom(input); diff --git a/crates/burn-jit/src/kernel/interpolate/bicubic.rs b/crates/burn-jit/src/kernel/interpolate/bicubic.rs index 374632e76..60527f298 100644 --- a/crates/burn-jit/src/kernel/interpolate/bicubic.rs +++ b/crates/burn-jit/src/kernel/interpolate/bicubic.rs @@ -371,8 +371,8 @@ impl Kernel for InterpolateBicubicEagerKernel Kernel for InterpolateBilinearEagerKernel Kernel for InterpolateNearestEagerKernel Kernel for InterpolateNearestBackwardEagerKer let mut scope = Scope::root(); let item = E::cube_elem().into(); - let out_grad = Variable::GlobalInputArray(0, item); - let output = Variable::GlobalOutputArray(0, item); + let out_grad = Variable::GlobalInputArray { id: 0, item }; + let output = Variable::GlobalOutputArray { id: 0, item }; InterpolateNearestBackwardShader { out_grad, diff --git a/crates/burn-jit/src/kernel/mask/shader.rs b/crates/burn-jit/src/kernel/mask/shader.rs index 192007643..b24bc0d90 100644 --- a/crates/burn-jit/src/kernel/mask/shader.rs +++ b/crates/burn-jit/src/kernel/mask/shader.rs @@ -40,7 +40,10 @@ impl MaskStrategy for MaskFill { } fn value_variable(value_item: Item) -> Variable { - Variable::GlobalScalar(0, value_item.elem()) + Variable::GlobalScalar { + id: 0, + elem: value_item.elem(), + } } } @@ -65,7 +68,10 @@ impl MaskStrategy for MaskWhere { } fn value_variable(value_item: Item) -> Variable { - Variable::GlobalInputArray(2, value_item) + Variable::GlobalInputArray { + id: 2, + item: value_item, + } } } @@ -102,10 +108,19 @@ impl Kernel let tensor_item = EI::cube_elem().into(); let mask_item = EM::cube_elem().into(); - let input = Variable::GlobalInputArray(0, tensor_item); - let mask = Variable::GlobalInputArray(1, mask_item); + let input = Variable::GlobalInputArray { + id: 0, + item: tensor_item, + }; + let mask = Variable::GlobalInputArray { + id: 1, + item: mask_item, + }; let value = M::value_variable(tensor_item); - let output = Variable::GlobalOutputArray(0, tensor_item); + let output = Variable::GlobalOutputArray { + id: 0, + item: tensor_item, + }; MaskShader:: { input, @@ -176,8 +191,14 @@ impl Kernel let tensor_item = EI::cube_elem().into(); let mask_item = EM::cube_elem().into(); - let input = Variable::GlobalInputArray(0, tensor_item); - let mask = Variable::GlobalInputArray(1, mask_item); + let input = Variable::GlobalInputArray { + id: 0, + item: tensor_item, + }; + let mask = Variable::GlobalInputArray { + id: 1, + item: mask_item, + }; let value = M::value_variable(tensor_item); MaskShader:: { diff --git a/crates/burn-jit/src/kernel/matmul/tiling2d.rs b/crates/burn-jit/src/kernel/matmul/tiling2d.rs index cb284b8fe..74807f44f 100644 --- a/crates/burn-jit/src/kernel/matmul/tiling2d.rs +++ b/crates/burn-jit/src/kernel/matmul/tiling2d.rs @@ -40,9 +40,9 @@ impl Kernel for MatmulTiling2dEagerKernel { ); let item = elem.into(); - let lhs = Variable::GlobalInputArray(0, item); - let rhs = Variable::GlobalInputArray(1, item); - let 
out = Variable::GlobalOutputArray(0, item); + let lhs = Variable::GlobalInputArray { id: 0, item }; + let rhs = Variable::GlobalInputArray { id: 1, item }; + let out = Variable::GlobalOutputArray { id: 0, item }; scope.write_global_custom(out); diff --git a/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs index fe2346fe5..e3754b086 100644 --- a/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/adaptive_avg_pool2d_backward.rs @@ -207,8 +207,8 @@ impl Kernel for AdaptiveAvgPool2dBackwardEagerKern let mut scope = Scope::root(); let item = E::cube_elem().into(); - let grad = Variable::GlobalInputArray(0, item); - let output = Variable::GlobalOutputArray(0, item); + let grad = Variable::GlobalInputArray { id: 0, item }; + let output = Variable::GlobalOutputArray { id: 0, item }; scope.write_global_custom(output); diff --git a/crates/burn-jit/src/kernel/pool/adaptive_pool2d_shader.rs b/crates/burn-jit/src/kernel/pool/adaptive_pool2d_shader.rs index eb87aca4a..28c0eb5df 100644 --- a/crates/burn-jit/src/kernel/pool/adaptive_pool2d_shader.rs +++ b/crates/burn-jit/src/kernel/pool/adaptive_pool2d_shader.rs @@ -193,8 +193,8 @@ impl Kernel for AdaptivePool2dEagerKernel { let mut scope = Scope::root(); let item = E::cube_elem().into(); - let input = Variable::GlobalInputArray(0, item); - let output = Variable::GlobalOutputArray(0, item); + let input = Variable::GlobalInputArray { id: 0, item }; + let output = Variable::GlobalOutputArray { id: 0, item }; scope.write_global_custom(output); diff --git a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs index 2cac774e2..0eb6bce76 100644 --- a/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/avg_pool2d_backward.rs @@ -71,10 +71,22 @@ impl AvgPool2dBackwardComputeShader { cpa!(scope, output_shape_2 = shape(output, 2u32)); cpa!(scope, output_shape_3 = shape(output, 3u32)); - let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt); - let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt); - let padding_0 = Variable::GlobalScalar(4, Elem::UInt); - let padding_1 = Variable::GlobalScalar(5, Elem::UInt); + let pool_stride_0 = Variable::GlobalScalar { + id: 0, + elem: Elem::UInt, + }; + let pool_stride_1 = Variable::GlobalScalar { + id: 1, + elem: Elem::UInt, + }; + let padding_0 = Variable::GlobalScalar { + id: 4, + elem: Elem::UInt, + }; + let padding_1 = Variable::GlobalScalar { + id: 5, + elem: Elem::UInt, + }; let [kernel_size_0, kernel_size_1] = self.kernel_size; let b = scope.create_local(Elem::UInt); @@ -215,12 +227,30 @@ impl AvgPool2dBackwardComputeShader { output_stride_2: Variable, output_stride_3: Variable, ) -> (Variable, Variable, Variable, Variable) { - let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt); - let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt); - let dilation_0 = Variable::GlobalScalar(2, Elem::UInt); - let dilation_1 = Variable::GlobalScalar(3, Elem::UInt); - let padding_0 = Variable::GlobalScalar(4, Elem::UInt); - let padding_1 = Variable::GlobalScalar(5, Elem::UInt); + let pool_stride_0 = Variable::GlobalScalar { + id: 0, + elem: Elem::UInt, + }; + let pool_stride_1 = Variable::GlobalScalar { + id: 1, + elem: Elem::UInt, + }; + let dilation_0 = Variable::GlobalScalar { + id: 2, + elem: Elem::UInt, + }; + let dilation_1 = Variable::GlobalScalar { + id: 3, + elem: Elem::UInt, + }; + 
let padding_0 = Variable::GlobalScalar { + id: 4, + elem: Elem::UInt, + }; + let padding_1 = Variable::GlobalScalar { + id: 5, + elem: Elem::UInt, + }; let [kernel_size_0, kernel_size_1] = self.kernel_size; @@ -321,8 +351,8 @@ impl Kernel for AvgPool2dBackwardEagerKernel let mut scope = Scope::root(); let item = E::cube_elem().into(); - let grad = Variable::GlobalInputArray(0, item); - let output = Variable::GlobalOutputArray(0, item); + let grad = Variable::GlobalInputArray { id: 0, item }; + let output = Variable::GlobalOutputArray { id: 0, item }; scope.write_global_custom(output); diff --git a/crates/burn-jit/src/kernel/pool/max_pool2d.rs b/crates/burn-jit/src/kernel/pool/max_pool2d.rs index 7e7d5f0bc..6e2e0ae34 100644 --- a/crates/burn-jit/src/kernel/pool/max_pool2d.rs +++ b/crates/burn-jit/src/kernel/pool/max_pool2d.rs @@ -21,7 +21,10 @@ impl PoolStrategy for MaxPool { fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator { let max_val = scope.create_local(item); - let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), item.elem()); + let max_initial = Variable::ConstantScalar { + value: E::minimum_value().to_f64(), + elem: item.elem(), + }; cpa!(scope, max_val = max_initial); max_val } @@ -67,7 +70,10 @@ impl PoolStrategy for MaxPoolWithIndices { fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator { let max_val = scope.create_local(item); - let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), item.elem()); + let max_initial = Variable::ConstantScalar { + value: E::minimum_value().to_f64(), + elem: item.elem(), + }; cpa!(scope, max_val = max_initial); let max_index = scope.create_local(Elem::UInt); (max_val, max_index) diff --git a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs index c0b570218..764700c7b 100644 --- a/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs +++ b/crates/burn-jit/src/kernel/pool/max_pool2d_backward.rs @@ -161,12 +161,30 @@ impl MaxPool2dBackwardComputeShader { output_stride_2: Variable, output_stride_3: Variable, ) -> (Variable, Variable, Variable, Variable) { - let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt); - let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt); - let dilation_0 = Variable::GlobalScalar(2, Elem::UInt); - let dilation_1 = Variable::GlobalScalar(3, Elem::UInt); - let padding_0 = Variable::GlobalScalar(4, Elem::UInt); - let padding_1 = Variable::GlobalScalar(5, Elem::UInt); + let pool_stride_0 = Variable::GlobalScalar { + id: 0, + elem: Elem::UInt, + }; + let pool_stride_1 = Variable::GlobalScalar { + id: 1, + elem: Elem::UInt, + }; + let dilation_0 = Variable::GlobalScalar { + id: 2, + elem: Elem::UInt, + }; + let dilation_1 = Variable::GlobalScalar { + id: 3, + elem: Elem::UInt, + }; + let padding_0 = Variable::GlobalScalar { + id: 4, + elem: Elem::UInt, + }; + let padding_1 = Variable::GlobalScalar { + id: 5, + elem: Elem::UInt, + }; let [kernel_size_0, kernel_size_1] = self.kernel_size; @@ -267,9 +285,12 @@ impl Kernel for MaxPool2dWithIndicesBackwardEagerK let mut scope = Scope::root(); let item = E::cube_elem().into(); - let indices = Variable::GlobalInputArray(0, Item::new(Elem::Int(IntKind::I32))); - let grad = Variable::GlobalInputArray(1, item); - let output = Variable::GlobalOutputArray(0, item); + let indices = Variable::GlobalInputArray { + id: 0, + item: Item::new(Elem::Int(IntKind::I32)), + }; + let grad = Variable::GlobalInputArray { id: 1, item }; + let output = 
Variable::GlobalOutputArray { id: 0, item }; scope.write_global_custom(output); diff --git a/crates/burn-jit/src/kernel/pool/pool2d_shader.rs b/crates/burn-jit/src/kernel/pool/pool2d_shader.rs index 58c77540a..6a7dc62c9 100644 --- a/crates/burn-jit/src/kernel/pool/pool2d_shader.rs +++ b/crates/burn-jit/src/kernel/pool/pool2d_shader.rs @@ -65,12 +65,30 @@ impl Pool2dComputeShader cpa!(scope, output_shape_2 = shape(output, 2u32)); cpa!(scope, output_shape_3 = shape(output, 3u32)); - let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt); - let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt); - let dilation_0 = Variable::GlobalScalar(2, Elem::UInt); - let dilation_1 = Variable::GlobalScalar(3, Elem::UInt); - let padding_0 = Variable::GlobalScalar(4, Elem::UInt); - let padding_1 = Variable::GlobalScalar(5, Elem::UInt); + let pool_stride_0 = Variable::GlobalScalar { + id: 0, + elem: Elem::UInt, + }; + let pool_stride_1 = Variable::GlobalScalar { + id: 1, + elem: Elem::UInt, + }; + let dilation_0 = Variable::GlobalScalar { + id: 2, + elem: Elem::UInt, + }; + let dilation_1 = Variable::GlobalScalar { + id: 3, + elem: Elem::UInt, + }; + let padding_0 = Variable::GlobalScalar { + id: 4, + elem: Elem::UInt, + }; + let padding_1 = Variable::GlobalScalar { + id: 5, + elem: Elem::UInt, + }; let b = scope.create_local(Elem::UInt); let c = scope.create_local(Elem::UInt); @@ -177,13 +195,13 @@ impl Kernel for Pool2dEagerKernel let mut scope = Scope::root(); let item = E::cube_elem().into(); - let input = Variable::GlobalInputArray(0, item); - let output = Variable::GlobalOutputArray(0, item); + let input = Variable::GlobalInputArray { id: 0, item }; + let output = Variable::GlobalOutputArray { id: 0, item }; let indices = if P::with_indices() { - Some(Variable::GlobalOutputArray( - 1, - Item::new(Elem::Int(IntKind::I32)), - )) + Some(Variable::GlobalOutputArray { + id: 1, + item: Item::new(Elem::Int(IntKind::I32)), + }) } else { None }; diff --git a/crates/burn-jit/src/kernel/prng/base.rs b/crates/burn-jit/src/kernel/prng/base.rs index bf04e33e2..6d15d1bde 100644 --- a/crates/burn-jit/src/kernel/prng/base.rs +++ b/crates/burn-jit/src/kernel/prng/base.rs @@ -66,17 +66,32 @@ impl, R: JitRuntime, E: JitElement> Kernel for PrngEagerKernel::new(); for i in 0..P::args_length() { - args.push(Variable::GlobalScalar(i as u16, item.elem())); + args.push(Variable::GlobalScalar { + id: i as u16, + elem: item.elem(), + }); } PrngShader::::new(output, N_VALUES_PER_THREAD, seeds, args).expand(&mut scope); diff --git a/crates/burn-jit/src/kernel/reduce/naive/argmax.rs b/crates/burn-jit/src/kernel/reduce/naive/argmax.rs index e3d3a3c68..ec5af005f 100644 --- a/crates/burn-jit/src/kernel/reduce/naive/argmax.rs +++ b/crates/burn-jit/src/kernel/reduce/naive/argmax.rs @@ -16,7 +16,10 @@ impl ReduceDimNaive for Argmax { ) -> Self::Accumulator { let index = scope.create_local(Elem::UInt); let max = scope.create_local(input_item); - let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), input_item.elem()); + let max_initial = Variable::ConstantScalar { + value: E::minimum_value().to_f64(), + elem: input_item.elem(), + }; cpa!(scope, max = max_initial); (max, index) diff --git a/crates/burn-jit/src/kernel/reduce/naive/argmin.rs b/crates/burn-jit/src/kernel/reduce/naive/argmin.rs index fe5b2d13f..44c1f24d6 100644 --- a/crates/burn-jit/src/kernel/reduce/naive/argmin.rs +++ b/crates/burn-jit/src/kernel/reduce/naive/argmin.rs @@ -17,7 +17,10 @@ impl ReduceDimNaive for Argmin { ) -> Self::Accumulator { let 
index = scope.create_local(Elem::UInt); let min = scope.create_local(input_item); - let min_initial = Variable::ConstantScalar(E::maximum_value().to_f64(), input_item.elem()); + let min_initial = Variable::ConstantScalar { + value: E::maximum_value().to_f64(), + elem: input_item.elem(), + }; cpa!(scope, min = min_initial); (min, index) diff --git a/crates/burn-jit/src/kernel/reduce/naive/shader.rs b/crates/burn-jit/src/kernel/reduce/naive/shader.rs index 85639e110..ed70b0cbd 100644 --- a/crates/burn-jit/src/kernel/reduce/naive/shader.rs +++ b/crates/burn-jit/src/kernel/reduce/naive/shader.rs @@ -41,8 +41,14 @@ impl, R: JitRuntime, EI: JitElement, EO: JitElement> Kern let item_input = EI::cube_elem().into(); let item_output = EO::cube_elem().into(); - let tensor = Variable::GlobalInputArray(0, item_input); - let output = Variable::GlobalOutputArray(0, item_output); + let tensor = Variable::GlobalInputArray { + id: 0, + item: item_input, + }; + let output = Variable::GlobalOutputArray { + id: 0, + item: item_output, + }; NaiveReduceDimComputeShader { tensor, diff --git a/crates/burn-jit/src/kernel/reduce/shared/argmax.rs b/crates/burn-jit/src/kernel/reduce/shared/argmax.rs index 977be5aeb..910f98697 100644 --- a/crates/burn-jit/src/kernel/reduce/shared/argmax.rs +++ b/crates/burn-jit/src/kernel/reduce/shared/argmax.rs @@ -18,7 +18,10 @@ impl ReduceDimShared for Argmax { let value_shared_memory = scope.create_shared(input_item, shared_memory_size); let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size); - let max = Variable::ConstantScalar(E::minimum_value().to_f64(), input_item.elem()); + let max = Variable::ConstantScalar { + value: E::minimum_value().to_f64(), + elem: input_item.elem(), + }; cpa!(scope, value_shared_memory[write_position] = max); (value_shared_memory, index_shared_memory) } diff --git a/crates/burn-jit/src/kernel/reduce/shared/argmin.rs b/crates/burn-jit/src/kernel/reduce/shared/argmin.rs index aa11511bb..3d1772840 100644 --- a/crates/burn-jit/src/kernel/reduce/shared/argmin.rs +++ b/crates/burn-jit/src/kernel/reduce/shared/argmin.rs @@ -19,7 +19,10 @@ impl ReduceDimShared for Argmin { let value_shared_memory = scope.create_shared(input_item, shared_memory_size); let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size); - let min = Variable::ConstantScalar(E::maximum_value().to_f64(), input_item.elem()); + let min = Variable::ConstantScalar { + value: E::maximum_value().to_f64(), + elem: input_item.elem(), + }; cpa!(scope, value_shared_memory[write_position] = min); (value_shared_memory, index_shared_memory) } diff --git a/crates/burn-jit/src/kernel/reduce/shared/shader.rs b/crates/burn-jit/src/kernel/reduce/shared/shader.rs index c7704a591..db80791e4 100644 --- a/crates/burn-jit/src/kernel/reduce/shared/shader.rs +++ b/crates/burn-jit/src/kernel/reduce/shared/shader.rs @@ -51,8 +51,14 @@ impl, R: JitRuntime, EI: JitElement, EO: JitElement> Ker let item_input = EI::cube_elem().into(); let item_output = EO::cube_elem().into(); - let tensor = Variable::GlobalInputArray(0, item_input); - let output = Variable::GlobalOutputArray(0, item_output); + let tensor = Variable::GlobalInputArray { + id: 0, + item: item_input, + }; + let output = Variable::GlobalOutputArray { + id: 0, + item: item_output, + }; // Reduce groups are elements that are aligned along the reduce dim SharedReduceDimComputeShader { diff --git a/crates/burn-wgpu/src/compiler/wgsl/base.rs b/crates/burn-wgpu/src/compiler/wgsl/base.rs index 1ee7db1f2..3eda1c3d2 100644 --- 
a/crates/burn-wgpu/src/compiler/wgsl/base.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/base.rs @@ -9,14 +9,24 @@ pub enum Variable { GlobalScalar(u16, Elem, cube::Elem), ConstantScalar(f64, Elem), Local { - index: u16, + id: u16, item: Item, - scope_depth: u8, + depth: u8, + }, + Named { + name: String, + item: Item, + is_array: bool, + }, + Slice { + id: u16, + item: Item, + depth: u8, }, LocalScalar { - index: u16, + id: u16, elem: Elem, - scope_depth: u8, + depth: u8, }, SharedMemory(u16, Item, u32), LocalArray(u16, Item, u8, u32), @@ -71,9 +81,9 @@ impl Variable { Variable::GlobalScalar(_, _, _) => true, Variable::ConstantScalar(_, _) => true, Variable::LocalScalar { - index: _, + id: _, elem: _, - scope_depth: _, + depth: _, } => true, Variable::Id => true, Variable::LocalInvocationIndex => true, @@ -85,11 +95,9 @@ impl Variable { Variable::GlobalOutputArray(_, _) => false, Variable::SharedMemory(_, _, _) => false, Variable::LocalArray(_, _, _, _) => false, - Variable::Local { - index: _, - item: _, - scope_depth: _, - } => false, + Variable::Local { .. } => false, + Variable::Named { .. } => false, + Variable::Slice { .. } => false, Variable::WorkgroupIdX => true, Variable::WorkgroupIdY => true, Variable::WorkgroupIdZ => true, @@ -121,11 +129,9 @@ impl Variable { Self::GlobalOutputArray(_, e) => *e, Self::SharedMemory(_, e, _) => *e, Self::LocalArray(_, e, _, _) => *e, - Self::Local { - index: _, - item, - scope_depth: _, - } => *item, + Self::Local { item, .. } => *item, + Self::Slice { item, .. } => *item, + Self::Named { item, .. } => *item, Self::ConstantScalar(_, e) => Item::Scalar(*e), Self::GlobalScalar(_, e, _) => Item::Scalar(*e), Self::Id => Item::Scalar(Elem::U32), @@ -134,11 +140,7 @@ impl Variable { Self::LocalInvocationIdY => Item::Scalar(Elem::U32), Self::LocalInvocationIdZ => Item::Scalar(Elem::U32), Self::Rank => Item::Scalar(Elem::U32), - Self::LocalScalar { - index: _, - elem, - scope_depth: _, - } => Item::Scalar(*elem), + Self::LocalScalar { elem, .. } => Item::Scalar(*elem), Self::WorkgroupId => Item::Scalar(Elem::U32), Self::WorkgroupIdX => Item::Scalar(Elem::U32), Self::WorkgroupIdY => Item::Scalar(Elem::U32), @@ -222,15 +224,21 @@ impl Display for Variable { f.write_fmt(format_args!("input_{number}_global")) } Variable::LocalScalar { - index, + id: index, elem: _, - scope_depth, + depth: scope_depth, } => f.write_fmt(format_args!("s_{scope_depth}_{index}")), Variable::Local { - index, + id: index, item: _, - scope_depth, + depth: scope_depth, } => f.write_fmt(format_args!("l_{scope_depth}_{index}")), + Variable::Named { name, .. 
} => f.write_str(name), + Variable::Slice { + id: index, + item: _, + depth: scope_depth, + } => f.write_fmt(format_args!("slice_{scope_depth}_{index}")), Variable::GlobalOutputArray(number, _) => { f.write_fmt(format_args!("output_{number}_global")) } diff --git a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs index 5054b5bad..f5aeac4fc 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/compiler.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/compiler.rs @@ -127,43 +127,53 @@ impl WgslCompiler { fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable { match value { - cube::Variable::GlobalInputArray(index, item) => { - wgsl::Variable::GlobalInputArray(index, Self::compile_item(item)) + cube::Variable::GlobalInputArray { id, item } => { + wgsl::Variable::GlobalInputArray(id, Self::compile_item(item)) } - cube::Variable::GlobalScalar(index, elem) => { - wgsl::Variable::GlobalScalar(index, Self::compile_elem(elem), elem) + cube::Variable::GlobalScalar { id, elem } => { + wgsl::Variable::GlobalScalar(id, Self::compile_elem(elem), elem) } - cube::Variable::Local(index, item, scope_depth) => wgsl::Variable::Local { - index, + cube::Variable::Local { id, item, depth } => wgsl::Variable::Local { + id, item: Self::compile_item(item), - scope_depth, + depth, }, - cube::Variable::LocalScalar(index, elem, scope_depth) => wgsl::Variable::LocalScalar { - index, + cube::Variable::Slice { id, item, depth } => wgsl::Variable::Slice { + id, + item: Self::compile_item(item), + depth, + }, + cube::Variable::LocalScalar { id, elem, depth } => wgsl::Variable::LocalScalar { + id, elem: Self::compile_elem(elem), - scope_depth, + depth, }, - cube::Variable::GlobalOutputArray(index, item) => { - wgsl::Variable::GlobalOutputArray(index, Self::compile_item(item)) + cube::Variable::GlobalOutputArray { id, item } => { + wgsl::Variable::GlobalOutputArray(id, Self::compile_item(item)) } - cube::Variable::ConstantScalar(index, elem) => { - wgsl::Variable::ConstantScalar(index, Self::compile_elem(elem)) + cube::Variable::ConstantScalar { value, elem } => { + wgsl::Variable::ConstantScalar(value, Self::compile_elem(elem)) } - cube::Variable::SharedMemory(index, item, size) => { + cube::Variable::SharedMemory { id, item, length } => { let item = Self::compile_item(item); - if !self.shared_memories.iter().any(|s| s.index == index) { + if !self.shared_memories.iter().any(|s| s.index == id) { self.shared_memories - .push(SharedMemory::new(index, item, size)); + .push(SharedMemory::new(id, item, length)); } - wgsl::Variable::SharedMemory(index, item, size) + wgsl::Variable::SharedMemory(id, item, length) } - cube::Variable::LocalArray(index, item, scope_depth, size) => { + cube::Variable::LocalArray { + id, + item, + depth, + length, + } => { let item = Self::compile_item(item); - if !self.local_arrays.iter().any(|s| s.index == index) { + if !self.local_arrays.iter().any(|s| s.index == id) { self.local_arrays - .push(LocalArray::new(index, item, scope_depth, size)); + .push(LocalArray::new(id, item, depth, length)); } - wgsl::Variable::LocalArray(index, item, scope_depth, size) + wgsl::Variable::LocalArray(id, item, depth, length) } cube::Variable::AbsolutePos => { self.id = true; @@ -241,7 +251,7 @@ impl WgslCompiler { wgsl::Variable::NumWorkgroups } cube::Variable::SubcubeDim => wgsl::Variable::SubgroupSize, - cube::Variable::Matrix(_, _) => { + cube::Variable::Matrix { .. 
} => { panic!("Cooperative matrix-multiply and accumulate not supported.") } } @@ -252,6 +262,11 @@ impl WgslCompiler { let processing = value.process(); for var in processing.variables { + // We don't declare slices. + if let cube::Variable::Slice { .. } = var { + continue; + } + instructions.push(wgsl::Instruction::DeclareVariable { var: self.compile_variable(var), }); @@ -427,8 +442,8 @@ impl WgslCompiler { cube::Metadata::Stride { dim, var, out } => { self.stride = true; let position = match var { - cube::Variable::GlobalInputArray(idx, _) => idx as usize, - cube::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize, + cube::Variable::GlobalInputArray { id, .. } => id as usize, + cube::Variable::GlobalOutputArray { id, .. } => self.num_inputs + id as usize, _ => panic!("Only Input and Output have a stride, got: {:?}", var), }; wgsl::Instruction::Stride { @@ -440,8 +455,8 @@ impl WgslCompiler { cube::Metadata::Shape { dim, var, out } => { self.shape = true; let position = match var { - cube::Variable::GlobalInputArray(idx, _) => idx as usize, - cube::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize, + cube::Variable::GlobalInputArray { id, .. } => id as usize, + cube::Variable::GlobalOutputArray { id, .. } => self.num_inputs + id as usize, _ => panic!("Only Input and Output have a shape, got {:?}", var), }; wgsl::Instruction::Shape { @@ -450,7 +465,7 @@ impl WgslCompiler { out: self.compile_variable(out), } } - cube::Metadata::ArrayLength { var, out } => wgsl::Instruction::ArrayLength { + cube::Metadata::Length { var, out } => wgsl::Instruction::Length { out: self.compile_variable(out), var: self.compile_variable(var), }, @@ -652,6 +667,12 @@ impl WgslCompiler { rhs: self.compile_variable(op.rhs), out: self.compile_variable(op.out), }, + cube::Operator::Slice(op) => wgsl::Instruction::Slice { + input: self.compile_variable(op.input), + start: self.compile_variable(op.start), + end: self.compile_variable(op.end), + out: self.compile_variable(op.out), + }, } } diff --git a/crates/burn-wgpu/src/compiler/wgsl/instructions.rs b/crates/burn-wgpu/src/compiler/wgsl/instructions.rs index 5c65d669c..fedd8ce3c 100644 --- a/crates/burn-wgpu/src/compiler/wgsl/instructions.rs +++ b/crates/burn-wgpu/src/compiler/wgsl/instructions.rs @@ -1,6 +1,6 @@ use super::{ base::{Item, Variable}, - IndexedVariable, Subgroup, + Elem, IndexedVariable, Subgroup, }; use std::fmt::Display; @@ -167,7 +167,7 @@ pub enum Instruction { position: usize, out: Variable, }, - ArrayLength { + Length { var: Variable, out: Variable, }, @@ -232,6 +232,12 @@ pub enum Instruction { rhs: Variable, out: Variable, }, + Slice { + input: Variable, + start: Variable, + end: Variable, + out: Variable, + }, Subgroup(Subgroup), } @@ -245,6 +251,16 @@ impl Display for Instruction { Instruction::Add { lhs, rhs, out } => { f.write_fmt(format_args!("{out} = {lhs} + {rhs};\n")) } + Instruction::Slice { + input, + start, + end, + out, + } => { + f.write_fmt(format_args!("let {out}_offset = {start};\n"))?; + f.write_fmt(format_args!("let {out}_length = {end} - {start};\n"))?; + f.write_fmt(format_args!("let {out}_ptr = &{input};\n")) + } Instruction::Fma { a, b, c, out } => { f.write_fmt(format_args!("{out} = fma({a}, {b}, {c});\n")) } @@ -261,10 +277,22 @@ impl Display for Instruction { f.write_fmt(format_args!("{out} = {lhs} || {rhs};\n")) } Instruction::Not { input, out } => f.write_fmt(format_args!("{out} = !{input};\n")), - Instruction::Index { lhs, rhs, out } => { - let item = out.item(); - 
f.write_fmt(format_args!("{out} = {item}({lhs}[{rhs}]);\n")) - } + Instruction::Index { lhs, rhs, out } => match lhs { + Variable::Slice { item, .. } => { + let offset = Variable::Named { + name: format!("{lhs}_offset"), + item: Item::Scalar(Elem::U32), + is_array: false, + }; + let lhs = Variable::Named { + name: format!("(*{lhs}_ptr)"), + item: *item, + is_array: true, + }; + index(f, &lhs, rhs, out, Some(offset)) + } + _ => index(f, lhs, rhs, out, None), + }, Instruction::Modulo { lhs, rhs, out } => { f.write_fmt(format_args!("{out} = {lhs} % {rhs};\n")) } @@ -406,88 +434,24 @@ for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{ f.write_str("}\n") } - Instruction::IndexAssign { lhs, rhs, out } => match lhs.item() { - Item::Vec4(elem) => { - let lhs0 = lhs.index(0); - let lhs1 = lhs.index(1); - let lhs2 = lhs.index(2); - let lhs3 = lhs.index(3); + Instruction::IndexAssign { lhs, rhs, out } => { + if let Variable::Slice { item, .. } = out { + let offset = Variable::Named { + name: format!("{out}_offset"), + item: Item::Scalar(Elem::U32), + is_array: false, + }; + let out = Variable::Named { + name: format!("(*{out}_ptr)"), + item: *item, + is_array: true, + }; - let rhs0 = rhs.index(0); - let rhs1 = rhs.index(1); - let rhs2 = rhs.index(2); - let rhs3 = rhs.index(3); - - f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?; - f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?; - f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n"))?; - f.write_fmt(format_args!("{out}[{lhs3}] = {elem}({rhs3});\n")) + index_assign(f, lhs, rhs, &out, Some(offset)) + } else { + index_assign(f, lhs, rhs, out, None) } - Item::Vec3(elem) => { - let lhs0 = lhs.index(0); - let lhs1 = lhs.index(1); - let lhs2 = lhs.index(2); - - let rhs0 = rhs.index(0); - let rhs1 = rhs.index(1); - let rhs2 = rhs.index(2); - - f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?; - f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?; - f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n")) - } - Item::Vec2(elem) => { - let lhs0 = lhs.index(0); - let lhs1 = lhs.index(1); - - let rhs0 = rhs.index(0); - let rhs1 = rhs.index(1); - - f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?; - f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n")) - } - Item::Scalar(_elem) => { - let is_array = matches!( - out, - Variable::GlobalInputArray(_, _) - | Variable::GlobalOutputArray(_, _) - | Variable::SharedMemory(_, _, _) - | Variable::LocalArray(_, _, _, _) - ); - - if !is_array { - let elem_out = out.elem(); - let casting_type = match rhs.item() { - Item::Vec4(_) => Item::Vec4(elem_out), - Item::Vec3(_) => Item::Vec3(elem_out), - Item::Vec2(_) => Item::Vec2(elem_out), - Item::Scalar(_) => Item::Scalar(elem_out), - }; - f.write_fmt(format_args!("{out}[{lhs}] = {casting_type}({rhs});\n")) - } else { - let item_rhs = rhs.item(); - let item_out = out.item(); - - let vectorization_factor = item_out.vectorization_factor(); - if vectorization_factor > item_rhs.vectorization_factor() { - let casting_type = item_out.elem(); - f.write_fmt(format_args!("{out}[{lhs}] = vec{vectorization_factor}("))?; - for i in 0..vectorization_factor { - let value = rhs.index(i); - f.write_fmt(format_args!("{casting_type}({value})"))?; - - if i < vectorization_factor - 1 { - f.write_str(",")?; - } - } - f.write_str(");\n") - } else { - let casting_type = item_out; - f.write_fmt(format_args!("{out}[{lhs}] = {casting_type}({rhs});\n")) - } - } - } - }, + } Instruction::If { cond, instructions 
} => { f.write_fmt(format_args!("if {cond} {{\n"))?; for i in instructions { @@ -513,9 +477,10 @@ for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{ Instruction::Return => f.write_str("return;\n"), Instruction::Break => f.write_str("break;\n"), Instruction::WorkgroupBarrier => f.write_str("workgroupBarrier();\n"), - Instruction::ArrayLength { var, out } => { - f.write_fmt(format_args!("{out} = arrayLength(&{var});\n")) - } + Instruction::Length { var, out } => match var { + Variable::Slice { .. } => f.write_fmt(format_args!("{out} = {var}_length;\n")), + _ => f.write_fmt(format_args!("{out} = arrayLength(&{var});\n")), + }, Instruction::Loop { instructions } => { f.write_fmt(format_args!("loop {{\n"))?; for i in instructions { @@ -623,3 +588,140 @@ fn unroll< } Ok(()) } + +struct IndexOffset { + var: Variable, + offset: Option, + index: usize, +} +impl IndexOffset { + fn new(var: &Variable, offset: &Option, index: usize) -> Self { + Self { + var: var.clone(), + offset: offset.clone(), + index, + } + } +} + +impl Display for IndexOffset { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let var = self.var.index(self.index); + + match &self.offset { + Some(offset) => { + let offset = offset.index(self.index); + f.write_fmt(format_args!("{var} + {offset}")) + } + None => f.write_fmt(format_args!("{var}")), + } + } +} + +fn index( + f: &mut std::fmt::Formatter<'_>, + lhs: &Variable, + rhs: &Variable, + out: &Variable, + offset: Option, +) -> core::fmt::Result { + let item = out.item(); + match offset { + Some(offset) => f.write_fmt(format_args!("{out} = {item}({lhs}[{rhs} + {offset}]);\n")), + None => f.write_fmt(format_args!("{out} = {item}({lhs}[{rhs}]);\n")), + } +} + +fn index_assign( + f: &mut std::fmt::Formatter<'_>, + lhs: &Variable, + rhs: &Variable, + out: &Variable, + offset: Option, +) -> core::fmt::Result { + match lhs.item() { + Item::Vec4(elem) => { + let lhs0 = IndexOffset::new(lhs, &offset, 0); + let lhs1 = IndexOffset::new(lhs, &offset, 1); + let lhs2 = IndexOffset::new(lhs, &offset, 2); + let lhs3 = IndexOffset::new(lhs, &offset, 3); + + let rhs0 = rhs.index(0); + let rhs1 = rhs.index(1); + let rhs2 = rhs.index(2); + let rhs3 = rhs.index(3); + + f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?; + f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?; + f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n"))?; + f.write_fmt(format_args!("{out}[{lhs3}] = {elem}({rhs3});\n")) + } + Item::Vec3(elem) => { + let lhs0 = IndexOffset::new(lhs, &offset, 0); + let lhs1 = IndexOffset::new(lhs, &offset, 1); + let lhs2 = IndexOffset::new(lhs, &offset, 2); + + let rhs0 = rhs.index(0); + let rhs1 = rhs.index(1); + let rhs2 = rhs.index(2); + + f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?; + f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?; + f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n")) + } + Item::Vec2(elem) => { + let lhs0 = IndexOffset::new(lhs, &offset, 0); + let lhs1 = IndexOffset::new(lhs, &offset, 1); + + let rhs0 = rhs.index(0); + let rhs1 = rhs.index(1); + + f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?; + f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n")) + } + Item::Scalar(_elem) => { + let is_array = match out { + Variable::GlobalInputArray(_, _) + | Variable::GlobalOutputArray(_, _) + | Variable::SharedMemory(_, _, _) + | Variable::Slice { .. } + | Variable::LocalArray(_, _, _, _) => true, + Variable::Named { is_array, .. 
} => *is_array, + _ => false, + }; + + if !is_array { + let elem_out = out.elem(); + let casting_type = match rhs.item() { + Item::Vec4(_) => Item::Vec4(elem_out), + Item::Vec3(_) => Item::Vec3(elem_out), + Item::Vec2(_) => Item::Vec2(elem_out), + Item::Scalar(_) => Item::Scalar(elem_out), + }; + f.write_fmt(format_args!("{out}[{lhs}] = {casting_type}({rhs});\n")) + } else { + let item_rhs = rhs.item(); + let item_out = out.item(); + let lhs = IndexOffset::new(lhs, &offset, 0); + + let vectorization_factor = item_out.vectorization_factor(); + if vectorization_factor > item_rhs.vectorization_factor() { + let casting_type = item_out.elem(); + f.write_fmt(format_args!("{out}[{lhs}] = vec{vectorization_factor}("))?; + for i in 0..vectorization_factor { + let value = rhs.index(i); + f.write_fmt(format_args!("{casting_type}({value})"))?; + + if i < vectorization_factor - 1 { + f.write_str(",")?; + } + } + f.write_str(");\n") + } else { + let casting_type = item_out; + f.write_fmt(format_args!("{out}[{lhs}] = {casting_type}({rhs});\n")) + } + } + } + } +}
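
The refactor running through this diff replaces the IR's positional tuple variants with struct variants (`GlobalInputArray { id, item }`, `ConstantScalar { value, elem }`, `Local { id, item, depth }`, plus the new `Slice` variant). Below is a minimal, self-contained sketch of that shape, not the real `burn_cube::ir::Variable` definition; the types and field sets are trimmed down purely for illustration.

// Illustrative sketch only: a trimmed Variable enum mirroring the struct
// variants used throughout the diff, with named fields replacing the old
// positional tuple fields.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Elem {
    UInt,
    F32,
}

#[derive(Debug, Clone, Copy, PartialEq)]
struct Item {
    elem: Elem,
}

#[derive(Debug, Clone, Copy, PartialEq)]
enum Variable {
    GlobalInputArray { id: u16, item: Item },
    GlobalOutputArray { id: u16, item: Item },
    GlobalScalar { id: u16, elem: Elem },
    ConstantScalar { value: f64, elem: Elem },
    Local { id: u16, item: Item, depth: u8 },
    Slice { id: u16, item: Item, depth: u8 },
}

fn main() {
    let item = Item { elem: Elem::F32 };

    // Call sites are now self-describing compared to the old positional form
    // `GlobalInputArray(0, item)` / `ConstantScalar(0.0, elem)`.
    let input = Variable::GlobalInputArray { id: 0, item };
    let output = Variable::GlobalOutputArray { id: 0, item };
    let zero = Variable::ConstantScalar { value: 0.0, elem: Elem::F32 };
    let stride = Variable::GlobalScalar { id: 0, elem: Elem::UInt };
    let local = Variable::Local { id: 0, item, depth: 0 };
    let slice = Variable::Slice { id: 0, item, depth: 0 };

    // Struct variants also allow `..` patterns, as used in the diff when only
    // some fields matter.
    assert!(matches!(input, Variable::GlobalInputArray { .. }));
    assert!(matches!(slice, Variable::Slice { depth: 0, .. }));
    let _ = (output, zero, stride, local);
}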
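
For the CUDA backend, `Instruction::Slice` lowers a slice to pointer arithmetic on the source buffer plus a companion `<name>_length` scalar, which `SliceLength` later reads back. A small sketch of that string generation, using plain `&str` identifiers instead of the backend's `Variable` type, so the names (`input_0`, `slice_0_0`, `l_0_3`, `float`) are only examples:

use std::fmt::Write;

// Sketch of the CUDA lowering above: a slice becomes a pointer offset into
// its source buffer plus a `<name>_length` scalar.
fn write_slice(out: &mut String, input: &str, start: &str, end: &str, name: &str, item: &str) {
    writeln!(out, "uint {name}_length = {end} - {start};").unwrap();
    writeln!(out, "{item} *{name} = {input} + {start};").unwrap();
}

// Sketch of `SliceLength`: the length of a slice is read back from the
// companion scalar instead of being recomputed from metadata.
fn write_slice_length(out: &mut String, input: &str, name: &str) {
    writeln!(out, "{name} = {input}_length;").unwrap();
}

fn main() {
    let mut cuda = String::new();
    write_slice(&mut cuda, "input_0", "start_0", "end_0", "slice_0_0", "float");
    write_slice_length(&mut cuda, "slice_0_0", "l_0_3");
    // Expected output:
    // uint slice_0_0_length = end_0 - start_0;
    // float *slice_0_0 = input_0 + start_0;
    // l_0_3 = slice_0_0_length;
    print!("{cuda}");
}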
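
The WGSL backend cannot offset a storage buffer pointer the same way, so a slice becomes three `let` bindings (`_offset`, `_length`, `_ptr`) and every `Index`/`IndexAssign` whose array operand is a slice is rewritten to read through `(*name_ptr)` with the offset added to the index, which is the role of the `IndexOffset` helper. A sketch of the read path, again with string identifiers standing in for the real variable types:

// Sketch of the offset-aware indexing used for WGSL slices: when the array
// side of an index operation is a slice, the generated code reads through
// the slice pointer and adds the slice offset to the index.
fn index_read(lhs: &str, rhs: &str, out: &str, item: &str, offset: Option<&str>) -> String {
    match offset {
        Some(offset) => format!("{out} = {item}({lhs}[{rhs} + {offset}]);\n"),
        None => format!("{out} = {item}({lhs}[{rhs}]);\n"),
    }
}

fn main() {
    // Reading element `i` from a plain global array.
    print!("{}", index_read("input_0_global", "i", "l_0_2", "f32", None));
    // Reading the same element through a slice: the pointer indirection and
    // the offset come from the `let slice_0_0_ptr = &...;` bindings emitted
    // for `Instruction::Slice`.
    print!(
        "{}",
        index_read("(*slice_0_0_ptr)", "i", "l_0_2", "f32", Some("slice_0_0_offset"))
    );
}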
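
Length queries are likewise dispatched on the variable kind: a slice reuses the `<name>_length` binding emitted by `Instruction::Slice`, while any other array still goes through WGSL's `arrayLength`. A minimal sketch of that dispatch, with a simplified enum and example identifiers:

// Sketch of the `Instruction::Length` lowering: slices carry their own
// length binding, arrays fall back to arrayLength on the buffer.
enum Var<'a> {
    Slice(&'a str),
    Array(&'a str),
}

fn write_length(var: &Var, out: &str) -> String {
    match var {
        Var::Slice(name) => format!("{out} = {name}_length;\n"),
        Var::Array(name) => format!("{out} = arrayLength(&{name});\n"),
    }
}

fn main() {
    print!("{}", write_length(&Var::Array("input_0_global"), "l_0_0"));
    print!("{}", write_length(&Var::Slice("slice_0_0"), "l_0_1"));
}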
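
In the fusion trace builder, the new `Operator::Slice` arm reuses the existing marking logic: operands that are read (`input`, `start`) are recorded as kernel inputs and the produced local as an output, so the fused kernel knows which tensors cross its boundary. A self-contained sketch of that marking step, with `TensorId` and the variable type reduced to the minimum needed for illustration:

use std::collections::HashMap;

type TensorId = u64;

#[derive(Clone, Copy)]
enum Var {
    Local { id: u16 },
    Other,
}

// Same shape as the marking closure in the trace builder: a variable only
// maps back to a fused tensor if it is a Local whose id was registered in
// `output_to_local`.
fn mark(var: &Var, output_to_local: &HashMap<TensorId, u16>, list: &mut Vec<TensorId>) {
    if let Var::Local { id } = var {
        if let Some((tensor_id, _)) = output_to_local.iter().find(|(_, local)| *local == id) {
            if !list.contains(tensor_id) {
                list.push(*tensor_id);
            }
        }
    }
}

fn main() {
    let mut output_to_local = HashMap::new();
    output_to_local.insert(7u64, 0u16);

    let input = Var::Local { id: 0 };
    let start = Var::Other;
    let out = Var::Local { id: 0 };

    let mut inputs = Vec::new();
    let mut outputs = Vec::new();

    // Mirrors the `Operator::Slice` arm: read operands feed the input list,
    // the produced local feeds the output list.
    mark(&input, &output_to_local, &mut inputs);
    mark(&start, &output_to_local, &mut inputs);
    mark(&out, &output_to_local, &mut outputs);

    assert_eq!(inputs, vec![7u64]);
    assert_eq!(outputs, vec![7u64]);
}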