Merge branch 'main' into refactor/cube/expand

This commit is contained in:
nathaniel 2024-07-11 17:35:07 -04:00
commit 01e5c39ae4
70 changed files with 1664 additions and 565 deletions

View File

@ -423,17 +423,24 @@ impl KernelIntegrator {
} else {
item
};
let elem_adapted = bool_item(item);
let item_adapted = bool_item(item);
self.output_bindings.push(Binding {
item: elem_adapted,
item: item_adapted,
visibility: Visibility::ReadWrite,
location: Location::Storage,
size: None,
});
self.expansion.scope.write_global(
Variable::Local(local, item, self.expansion.scope.depth),
Variable::GlobalOutputArray(index, elem_adapted),
Variable::Local {
id: local,
item,
depth: self.expansion.scope.depth,
},
Variable::GlobalOutputArray {
id: index,
item: item_adapted,
},
position,
);
index += 1;
@ -451,8 +458,15 @@ impl KernelIntegrator {
};
self.expansion.scope.write_global(
Variable::Local(local, item, self.expansion.scope.depth),
Variable::GlobalInputArray(input, bool_item(item)),
Variable::Local {
id: local,
item,
depth: self.expansion.scope.depth,
},
Variable::GlobalInputArray {
id: input,
item: bool_item(item),
},
position,
);
}

View File

@ -27,11 +27,11 @@ where
if unroll {
let start = match start.deref() {
Variable::ConstantScalar(val, _) => *val as usize,
Variable::ConstantScalar { value, .. } => *value as usize,
_ => panic!("Only constant start can be unrolled."),
};
let end = match end.deref() {
Variable::ConstantScalar(val, _) => *val as usize,
Variable::ConstantScalar { value, .. } => *value as usize,
_ => panic!("Only constant end can be unrolled."),
};

View File

@ -15,14 +15,14 @@
//! 16,
//! 16,
//! 16,
//! cmma::MatrixLayout::ColMajor,
//! cmma::MatrixLayout::RowMajor,
//! );
//! let b = cmma::Matrix::<F16>::new(
//! cmma::MatrixIdent::B,
//! 16,
//! 16,
//! 16,
//! cmma::MatrixLayout::RowMajor,
//! cmma::MatrixLayout::ColMajor,
//! );
//! let c = cmma::Matrix::<F32>::new(
//! cmma::MatrixIdent::Accumulator,
@ -32,12 +32,17 @@
//! cmma::MatrixLayout::Undefined,
//! );
//! cmma::fill::<F32>(&c, F32::new(0.0));
//! cmma::load::<F16>(&a, lhs, UInt::new(16));
//! cmma::load::<F16>(&b, rhs, UInt::new(16));
//! cmma::load::<F16>(&a, lhs.as_slice(), UInt::new(16));
//! cmma::load::<F16>(&b, rhs.as_slice(), UInt::new(16));
//!
//! cmma::execute::<F16, F16, F32, F32>(&a, &b, &c, &c);
//!
//! cmma::store::<F32>(out, &c, UInt::new(16), cmma::MatrixLayout::RowMajor);
//! cmma::store::<F32>(
//! out.as_slice_mut(),
//! &c,
//! UInt::new(16),
//! cmma::MatrixLayout::RowMajor,
//! );
//! }
//! ```
@ -49,7 +54,8 @@ use crate::{
};
use super::{
Array, CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, UInt,
CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, Slice, SliceMut,
UInt,
};
pub use ir::{MatrixIdent, MatrixLayout};
@ -142,7 +148,7 @@ pub mod fill {
/// Load the matrix with the provided array using the stride.
#[allow(unused_variables)]
pub fn load<C: CubeType>(mat: &Matrix<C>, value: &Array<C>, stride: UInt) {
pub fn load<C: CubeType>(mat: &Matrix<C>, value: &Slice<'_, C>, stride: UInt) {
unexpanded!()
}
@ -155,7 +161,7 @@ pub mod load {
pub fn __expand<C: CubeType>(
context: &mut CubeContext,
mat: MatrixExpand,
value: ExpandElementTyped<Array<C>>,
value: ExpandElementTyped<Slice<'static, C>>,
stride: ExpandElement,
) {
context.register(Operation::CoopMma(ir::CoopMma::Load {
@ -169,7 +175,7 @@ pub mod load {
/// Store the matrix in the given array following the given stride and layout.
#[allow(unused_variables)]
pub fn store<C: CubePrimitive>(
output: &Array<C>,
output: &mut SliceMut<'_, C>,
mat: &Matrix<C>,
stride: UInt,
layout: MatrixLayout,
@ -185,7 +191,7 @@ pub mod store {
#[allow(unused_variables)]
pub fn __expand<C: CubePrimitive>(
context: &mut CubeContext,
output: ExpandElementTyped<Array<C>>,
output: ExpandElementTyped<SliceMut<'static, C>>,
mat: MatrixExpand,
stride: ExpandElement,
layout: MatrixLayout,

View File

@ -4,8 +4,6 @@ use alloc::rc::Rc;
use core::cell::RefCell;
use std::collections::HashMap;
use super::{CubePrimitive, SharedMemoryExpand};
#[derive(Default, Clone)]
pub struct VariablePool {
map: Rc<RefCell<HashMap<Item, Vec<ExpandElement>>>>,
@ -114,14 +112,14 @@ impl CubeContext {
ExpandElement::Plain(variable)
}
pub fn create_shared<T: CubePrimitive>(
&mut self,
item: Item,
size: u32,
) -> SharedMemoryExpand<T> {
SharedMemoryExpand {
val: ExpandElement::Plain(self.root.borrow_mut().create_shared(item, size)),
}
/// Create a new slice element.
pub fn create_slice(&mut self, item: Item) -> ExpandElement {
let variable = self.scope.borrow_mut().create_slice(item);
ExpandElement::Plain(variable)
}
pub fn create_shared(&mut self, item: Item, size: u32) -> ExpandElement {
ExpandElement::Plain(self.root.borrow_mut().create_shared(item, size))
}
pub fn create_local_array(&mut self, item: Item, size: u32) -> ExpandElement {
@ -129,19 +127,19 @@ impl CubeContext {
}
/// Obtain the index-th input
pub fn input(&mut self, index: u16, item: Item) -> ExpandElement {
ExpandElement::Plain(crate::ir::Variable::GlobalInputArray(index, item))
pub fn input(&mut self, id: u16, item: Item) -> ExpandElement {
ExpandElement::Plain(crate::ir::Variable::GlobalInputArray { id, item })
}
/// Obtain the index-th output
pub fn output(&mut self, index: u16, item: Item) -> ExpandElement {
let var = crate::ir::Variable::GlobalOutputArray(index, item);
pub fn output(&mut self, id: u16, item: Item) -> ExpandElement {
let var = crate::ir::Variable::GlobalOutputArray { id, item };
self.scope.borrow_mut().write_global_custom(var);
ExpandElement::Plain(var)
}
/// Obtain the index-th scalar
pub fn scalar(&self, index: u16, elem: Elem) -> ExpandElement {
ExpandElement::Plain(crate::ir::Variable::GlobalScalar(index, elem))
pub fn scalar(&self, id: u16, elem: Elem) -> ExpandElement {
ExpandElement::Plain(crate::ir::Variable::GlobalScalar { id, elem })
}
}

View File

@ -40,7 +40,7 @@ impl<T: CubePrimitive + Clone> Array<T> {
) -> <Self as CubeType>::ExpandType {
let size = size.value();
let size = match size {
crate::ir::Variable::ConstantScalar(val, _) => val as u32,
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
_ => panic!("Array need constant initialization value"),
};
context
@ -55,7 +55,7 @@ impl<T: CubePrimitive + Clone> Array<T> {
) -> <Self as CubeType>::ExpandType {
let size = size.value();
let size = match size {
crate::ir::Variable::ConstantScalar(val, _) => val as u32,
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
_ => panic!("Shared memory need constant initialization value"),
};
context

View File

@ -160,7 +160,7 @@ impl ExpandElement {
pub fn can_mut(&self) -> bool {
match self {
ExpandElement::Managed(var) => {
if let Variable::Local(_, _, _) = var.as_ref() {
if let Variable::Local { .. } = var.as_ref() {
Rc::strong_count(var) <= 2
} else {
false
@ -201,10 +201,10 @@ impl Init for ExpandElement {
let mut init = |elem: Self| init_expand(context, elem, Operator::Assign);
match *self {
Variable::GlobalScalar(_, _) => init(self),
Variable::LocalScalar(_, _, _) => init(self),
Variable::ConstantScalar(_, _) => init(self),
Variable::Local(_, _, _) => init(self),
Variable::GlobalScalar { .. } => init(self),
Variable::LocalScalar { .. } => init(self),
Variable::ConstantScalar { .. } => init(self),
Variable::Local { .. } => init(self),
// Constant should be initialized since the new variable can be mutated afterward.
// And it is assumed those values are cloned.
Variable::Rank
@ -230,11 +230,12 @@ impl Init for ExpandElement {
| Variable::AbsolutePosY
| Variable::AbsolutePosZ => init(self),
// Array types can't be copied, so we should simply return the same variable.
Variable::SharedMemory(_, _, _)
| Variable::GlobalInputArray(_, _)
| Variable::GlobalOutputArray(_, _)
| Variable::LocalArray(_, _, _, _)
| Variable::Matrix(_, _) => self,
Variable::SharedMemory { .. }
| Variable::GlobalInputArray { .. }
| Variable::GlobalOutputArray { .. }
| Variable::LocalArray { .. }
| Variable::Slice { .. }
| Variable::Matrix { .. } => self,
}
}
}

View File

@ -41,9 +41,9 @@ impl_into_expand_element!(i64);
/// Useful for Comptime
impl From<UInt> for ExpandElement {
fn from(value: UInt) -> Self {
ExpandElement::Plain(crate::ir::Variable::ConstantScalar(
value.val as f64,
UInt::as_elem(),
))
ExpandElement::Plain(crate::ir::Variable::ConstantScalar {
value: value.val as f64,
elem: UInt::as_elem(),
})
}
}

View File

@ -93,7 +93,10 @@ macro_rules! impl_float {
_context: &mut CubeContext,
val: f32,
) -> <Self as CubeType>::ExpandType {
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
let new_var = Variable::ConstantScalar {
value: val as f64,
elem: Self::as_elem(),
};
ExpandElement::Plain(new_var)
}

View File

@ -63,7 +63,10 @@ macro_rules! impl_int {
_context: &mut CubeContext,
val: i64,
) -> <Self as CubeType>::ExpandType {
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
let new_var = Variable::ConstantScalar {
value: val as f64,
elem: Self::as_elem(),
};
ExpandElement::Plain(new_var)
}

View File

@ -7,6 +7,7 @@ mod float;
mod int;
mod numeric;
mod shared_memory;
mod slice;
mod tensor;
mod uint;
mod vectorized;
@ -20,6 +21,7 @@ pub use float::*;
pub use int::*;
pub use numeric::*;
pub use shared_memory::*;
pub use slice::*;
pub use tensor::*;
pub use uint::*;
pub use vectorized::*;

View File

@ -50,7 +50,10 @@ pub trait Numeric:
}
fn __expand_from_int(_context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType {
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
let new_var = Variable::ConstantScalar {
value: val as f64,
elem: Self::as_elem(),
};
ExpandElement::Plain(new_var)
}

View File

@ -5,32 +5,21 @@ use crate::{
ir::Item,
};
use super::{ExpandElement, Init, UInt};
use super::{ExpandElementTyped, Init, UInt};
#[derive(Clone, Copy)]
pub struct SharedMemory<T: CubeType> {
_val: PhantomData<T>,
}
#[derive(Clone)]
pub struct SharedMemoryExpand<T: CubePrimitive> {
pub val: <T as CubeType>::ExpandType,
}
impl<T: CubePrimitive> From<SharedMemoryExpand<T>> for ExpandElement {
fn from(shared_memory_expand: SharedMemoryExpand<T>) -> Self {
shared_memory_expand.val
}
}
impl<T: CubePrimitive> Init for SharedMemoryExpand<T> {
impl<T: CubePrimitive> Init for ExpandElementTyped<SharedMemory<T>> {
fn init(self, _context: &mut CubeContext) -> Self {
self
}
}
impl<T: CubePrimitive> CubeType for SharedMemory<T> {
type ExpandType = SharedMemoryExpand<T>;
type ExpandType = ExpandElementTyped<SharedMemory<T>>;
}
impl<T: CubePrimitive + Clone> SharedMemory<T> {
@ -49,13 +38,14 @@ impl<T: CubePrimitive + Clone> SharedMemory<T> {
) -> <Self as CubeType>::ExpandType {
let size = size.value();
let size = match size {
crate::ir::Variable::ConstantScalar(val, _) => val as u32,
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
_ => panic!("Shared memory need constant initialization value"),
};
context.create_shared(
let var = context.create_shared(
Item::vectorized(T::as_elem(), vectorization_factor.val as u8),
size,
)
);
ExpandElementTyped::new(var)
}
pub fn __expand_new<S: Index>(
@ -64,9 +54,10 @@ impl<T: CubePrimitive + Clone> SharedMemory<T> {
) -> <Self as CubeType>::ExpandType {
let size = size.value();
let size = match size {
crate::ir::Variable::ConstantScalar(val, _) => val as u32,
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
_ => panic!("Shared memory need constant initialization value"),
};
context.create_shared(Item::new(T::as_elem()), size)
let var = context.create_shared(Item::new(T::as_elem()), size);
ExpandElementTyped::new(var)
}
}

View File

@ -0,0 +1,285 @@
use std::marker::PhantomData;
use super::{
Array, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, SharedMemory, Tensor,
UInt,
};
use crate::{
frontend::indexation::Index,
ir::{self, Operator},
prelude::CubeContext,
unexpanded,
};
/// A read-only contiguous list of elements.
///
/// The struct stores no element data: it only carries the element type `E`
/// and the borrow lifetime `'a` at the type level.
pub struct Slice<'a, E> {
    /// Zero-sized marker for the element type.
    _e: PhantomData<E>,
    /// Shared reference to unit, held only to bind the lifetime `'a`.
    _l: &'a (),
}
/// A read-write contiguous list of elements.
///
/// Like its read-only counterpart, the struct stores no element data: it only
/// carries the element type `E` and an exclusive borrow of lifetime `'a`.
pub struct SliceMut<'a, E> {
    /// Zero-sized marker for the element type.
    _e: PhantomData<E>,
    /// Exclusive reference to unit, held only to bind the lifetime `'a`.
    _l: &'a mut (),
}
impl<E> Slice<'_, E> {
    /// Get the length of the slice.
    ///
    /// The body is only the `unexpanded!()` placeholder, so this method
    /// contributes a signature rather than an implementation.
    // NOTE(review): presumably the real behavior comes from a matching expand
    // function elsewhere in the crate — confirm.
    pub fn len(&self) -> UInt {
        unexpanded!()
    }
}
impl<E> SliceMut<'_, E> {
    /// Get the length of the slice.
    ///
    /// The body is only the `unexpanded!()` placeholder, so this method
    /// contributes a signature rather than an implementation.
    // NOTE(review): presumably the real behavior comes from a matching expand
    // function elsewhere in the crate — confirm.
    pub fn len(&self) -> UInt {
        unexpanded!()
    }
}
/// A read-only slice of any lifetime expands to a typed element whose slice
/// lifetime is pinned to `'static`.
impl<E: CubeType> CubeType for Slice<'_, E> {
    type ExpandType = ExpandElementTyped<Slice<'static, E>>;
}
impl<C: CubeType> Init for ExpandElementTyped<Slice<'_, C>> {
    /// Return `self` unchanged; the context is not used.
    ///
    /// A slice can't be deeply cloned or copied, so initialization hands back
    /// the very same element.
    fn init(self, _context: &mut crate::prelude::CubeContext) -> Self {
        self
    }
}
/// A read-write slice of any lifetime expands to a typed element whose slice
/// lifetime is pinned to `'static`.
impl<E: CubeType> CubeType for SliceMut<'_, E> {
    type ExpandType = ExpandElementTyped<SliceMut<'static, E>>;
}
impl<C: CubeType> Init for ExpandElementTyped<SliceMut<'_, C>> {
    /// Return `self` unchanged; the context is not used.
    ///
    /// A slice can't be deeply cloned or copied, so initialization hands back
    /// the very same element.
    fn init(self, _context: &mut crate::prelude::CubeContext) -> Self {
        self
    }
}
/// Slicing operations available on a cube type holding elements of type `E`.
///
/// Every operation comes as a pair: a signature-only method whose body is
/// `unexpanded!()`, and an `*_expand` associated function that forwards to the
/// method of the same name on [`SliceOperatorExpand`], called on
/// [`Self::Expand`].
pub trait SliceOperator<E>: CubeType<ExpandType = Self::Expand> {
    /// Expand-side type on which the slicing operations are carried out.
    type Expand: SliceOperatorExpand<E>;

    /// Return a read-only view of all elements between the `start` and `end` indices.
    #[allow(unused_variables)]
    fn slice<Start, End>(&self, start: Start, end: End) -> &Slice<'_, E>
    where
        Start: Index,
        End: Index,
    {
        unexpanded!()
    }

    /// Expand function of [`SliceOperator::slice`].
    fn slice_expand<Start, End>(
        context: &mut CubeContext,
        expand: Self::Expand,
        start: Start,
        end: End,
    ) -> ExpandElementTyped<Slice<'static, E>>
    where
        Start: Index,
        End: Index,
    {
        expand.slice_expand(context, start, end)
    }

    /// Return a read-write view of all elements between the `start` and `end` indices.
    #[allow(unused_variables)]
    fn slice_mut<Start, End>(&mut self, start: Start, end: End) -> &mut SliceMut<'_, E>
    where
        Start: Index,
        End: Index,
    {
        unexpanded!()
    }

    /// Expand function of [`SliceOperator::slice_mut`].
    fn slice_mut_expand<Start, End>(
        context: &mut CubeContext,
        expand: Self::Expand,
        start: Start,
        end: End,
    ) -> ExpandElementTyped<SliceMut<'static, E>>
    where
        Start: Index,
        End: Index,
    {
        expand.slice_mut_expand(context, start, end)
    }

    /// Return a read-write view of all elements between the `start` and `end` indices,
    /// starting from a shared borrow of `self`.
    ///
    /// # Warning
    ///
    /// Ignores the multiple borrow rule.
    #[allow(unused_variables)]
    fn slice_mut_unsafe<Start, End>(&self, start: Start, end: End) -> SliceMut<'static, E>
    where
        Start: Index,
        End: Index,
    {
        unexpanded!()
    }

    /// Expand function of [`SliceOperator::slice_mut_unsafe`].
    fn slice_mut_unsafe_expand<Start, End>(
        context: &mut CubeContext,
        expand: Self::Expand,
        start: Start,
        end: End,
    ) -> ExpandElementTyped<SliceMut<'static, E>>
    where
        Start: Index,
        End: Index,
    {
        expand.slice_mut_unsafe_expand(context, start, end)
    }

    /// Reinterpret the current type as a read-only slice.
    #[allow(unused_variables)]
    fn as_slice(&self) -> &Slice<'_, E> {
        unexpanded!()
    }

    /// Expand function of [`SliceOperator::as_slice`].
    fn as_slice_expand(
        context: &mut CubeContext,
        expand: Self::Expand,
    ) -> ExpandElementTyped<Slice<'static, E>> {
        expand.as_slice_expand(context)
    }

    /// Reinterpret the current type as a read-write slice.
    #[allow(unused_variables)]
    fn as_slice_mut(&mut self) -> &mut SliceMut<'_, E> {
        unexpanded!()
    }

    /// Expand function of [`SliceOperator::as_slice_mut`].
    fn as_slice_mut_expand(
        context: &mut CubeContext,
        expand: Self::Expand,
    ) -> ExpandElementTyped<SliceMut<'static, E>> {
        expand.as_slice_mut_expand(context)
    }

    /// Reinterpret the current type as a read-write slice, starting from a
    /// shared borrow of `self`.
    ///
    /// # Warning
    ///
    /// Ignores the multiple borrow rule.
    #[allow(unused_variables)]
    fn as_slice_mut_unsafe(&self) -> SliceMut<'static, E> {
        unexpanded!()
    }

    /// Expand function of [`SliceOperator::as_slice_mut_unsafe`].
    fn as_slice_mut_unsafe_expand(
        context: &mut CubeContext,
        expand: Self::Expand,
    ) -> ExpandElementTyped<SliceMut<'static, E>> {
        expand.as_slice_mut_unsafe_expand(context)
    }
}
/// Expand-side counterpart of [`SliceOperator`].
///
/// Implementors only have to provide [`slice_base`](Self::slice_base); every
/// other method has a default built either on it or on the
/// `Into<ExpandElement>` conversion required by the supertrait bound.
pub trait SliceOperatorExpand<E>: Into<ExpandElement> + Clone {
    /// Build the untyped element for the view delimited by `start` and `end`.
    fn slice_base<Start: Index, End: Index>(
        &self,
        context: &mut CubeContext,
        start: Start,
        end: End,
    ) -> ExpandElement;

    /// Wrap the result of [`slice_base`](Self::slice_base) as a read-only slice.
    fn slice_expand<Start: Index, End: Index>(
        &self,
        context: &mut CubeContext,
        start: Start,
        end: End,
    ) -> ExpandElementTyped<Slice<'static, E>> {
        let element = self.slice_base(context, start, end);
        ExpandElementTyped::new(element)
    }

    /// Wrap the result of [`slice_base`](Self::slice_base) as a read-write slice.
    fn slice_mut_expand<Start: Index, End: Index>(
        &self,
        context: &mut CubeContext,
        start: Start,
        end: End,
    ) -> ExpandElementTyped<SliceMut<'static, E>> {
        let element = self.slice_base(context, start, end);
        ExpandElementTyped::new(element)
    }

    /// Same construction as [`slice_mut_expand`](Self::slice_mut_expand): the
    /// result of [`slice_base`](Self::slice_base) wrapped as a read-write slice.
    fn slice_mut_unsafe_expand<Start: Index, End: Index>(
        &self,
        context: &mut CubeContext,
        start: Start,
        end: End,
    ) -> ExpandElementTyped<SliceMut<'static, E>> {
        let element = self.slice_base(context, start, end);
        ExpandElementTyped::new(element)
    }

    /// Clone `self`, convert the clone into an [`ExpandElement`] and re-wrap it
    /// as a read-only slice. The context is not used.
    fn as_slice_expand(&self, _context: &mut CubeContext) -> ExpandElementTyped<Slice<'static, E>> {
        let element: ExpandElement = self.clone().into();
        ExpandElementTyped::new(element)
    }

    /// Clone `self`, convert the clone into an [`ExpandElement`] and re-wrap it
    /// as a read-write slice. The context is not used.
    fn as_slice_mut_expand(
        &self,
        _context: &mut CubeContext,
    ) -> ExpandElementTyped<SliceMut<'static, E>> {
        let element: ExpandElement = self.clone().into();
        ExpandElementTyped::new(element)
    }

    /// Delegates to [`as_slice_mut_expand`](Self::as_slice_mut_expand).
    fn as_slice_mut_unsafe_expand(
        &self,
        context: &mut CubeContext,
    ) -> ExpandElementTyped<SliceMut<'static, E>> {
        self.as_slice_mut_expand(context)
    }
}
/// Implement [`SliceOperator`] and [`SliceOperatorExpand`] for a type.
///
/// * `slice_op!(Type)` targets a type with a single element parameter (`Type<E>`).
/// * `slice_op!(slice Type)` targets a type that also carries a lifetime
///   (`Type<'a, E>`); its `Expand` type is pinned to the `'static` lifetime.
///
/// In both arms `slice_base` forwards a clone of `self` to the free function
/// `slice_expand`.
macro_rules! slice_op {
    ($container:ident) => {
        impl<E: CubePrimitive> SliceOperator<E> for $container<E> {
            type Expand = ExpandElementTyped<$container<E>>;
        }

        impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<$container<E>> {
            fn slice_base<Start: Index, End: Index>(
                &self,
                context: &mut CubeContext,
                start: Start,
                end: End,
            ) -> ExpandElement {
                slice_expand(context, self.clone(), start, end)
            }
        }
    };
    (slice $view:ident) => {
        impl<E: CubePrimitive> SliceOperator<E> for $view<'_, E> {
            type Expand = ExpandElementTyped<$view<'static, E>>;
        }

        impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<$view<'_, E>> {
            fn slice_base<Start: Index, End: Index>(
                &self,
                context: &mut CubeContext,
                start: Start,
                end: End,
            ) -> ExpandElement {
                slice_expand(context, self.clone(), start, end)
            }
        }
    };
}

// Types with a single element parameter.
slice_op!(Array);
slice_op!(Tensor);
slice_op!(SharedMemory);
// Types that also carry a lifetime.
slice_op!(slice Slice);
slice_op!(slice SliceMut);
/// Register a slice operation on `input` and return the element holding the result.
///
/// The output is created through `context.create_slice` with the same item as
/// the input. An `Operator::Slice` carrying the input, the `start` and `end`
/// index values and that output is then registered on the context.
pub fn slice_expand<I, S1, S2>(
    context: &mut CubeContext,
    input: I,
    start: S1,
    end: S2, // TODO (kept from the original): use it to get the length.
) -> ExpandElement
where
    I: Into<ExpandElement>,
    S1: Index,
    S2: Index,
{
    let source: ExpandElement = input.into();
    let view = context.create_slice(source.item());

    let operator = ir::SliceOperator {
        input: *source,
        start: start.value(),
        end: end.value(),
        out: *view,
    };
    context.register(Operator::Slice(operator));

    view
}

View File

@ -202,7 +202,7 @@ impl<T> ExpandElementTyped<T> {
// Expanded version of len
pub fn len_expand(self, context: &mut CubeContext) -> ExpandElement {
let out = context.create_local(Item::new(Elem::UInt));
context.register(Metadata::ArrayLength {
context.register(Metadata::Length {
var: self.expand.into(),
out: out.clone().into(),
});

View File

@ -49,7 +49,10 @@ impl UInt {
}
pub fn __expand_new(_context: &mut CubeContext, val: u32) -> <Self as CubeType>::ExpandType {
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
let new_var = Variable::ConstantScalar {
value: val as f64,
elem: Self::as_elem(),
};
ExpandElement::Plain(new_var)
}

View File

@ -7,31 +7,46 @@ pub trait Index {
impl Index for Comptime<u32> {
fn value(self) -> Variable {
Variable::ConstantScalar(self.inner as f64, Elem::UInt)
Variable::ConstantScalar {
value: self.inner as f64,
elem: Elem::UInt,
}
}
}
impl Index for Comptime<i32> {
fn value(self) -> Variable {
Variable::ConstantScalar(self.inner as f64, Elem::UInt)
Variable::ConstantScalar {
value: self.inner as f64,
elem: Elem::UInt,
}
}
}
impl Index for i32 {
fn value(self) -> Variable {
Variable::ConstantScalar(self as f64, Elem::UInt)
Variable::ConstantScalar {
value: self as f64,
elem: Elem::UInt,
}
}
}
impl Index for u32 {
fn value(self) -> Variable {
Variable::ConstantScalar(self as f64, Elem::UInt)
Variable::ConstantScalar {
value: self as f64,
elem: Elem::UInt,
}
}
}
impl Index for UInt {
fn value(self) -> Variable {
Variable::ConstantScalar(self.val as f64, Elem::UInt)
Variable::ConstantScalar {
value: self.val as f64,
elem: Elem::UInt,
}
}
}

View File

@ -19,7 +19,7 @@ pub mod assign {
}
pub mod index_assign {
use crate::{frontend::CubeType, unexpanded};
use crate::{frontend::CubeType, prelude::SliceMut, unexpanded};
use self::ir::{BinaryOperator, Operator, Variable};
@ -34,7 +34,10 @@ pub mod index_assign {
let array = array.into();
let index: Variable = *index.into();
let index = match index {
Variable::ConstantScalar(val, _) => Variable::ConstantScalar(val, ir::Elem::UInt),
Variable::ConstantScalar { value, .. } => Variable::ConstantScalar {
value,
elem: ir::Elem::UInt,
},
_ => index,
};
context.register(Operator::IndexAssign(BinaryOperator {
@ -58,6 +61,12 @@ pub mod index_assign {
impl_index!(Array);
impl_index!(Tensor);
impl_index!(SharedMemory);
/// Signature-only mutable indexing for read-write slices: the body is the
/// `unexpanded!()` placeholder and the index argument is ignored here.
impl<E: CubeType, I: Into<UInt>> core::ops::IndexMut<I> for SliceMut<'_, E> {
    fn index_mut(&mut self, _index: I) -> &mut Self::Output {
        unexpanded!()
    }
}
}
pub mod index {
@ -66,6 +75,7 @@ pub mod index {
operation::base::{binary_expand, binary_expand_no_vec},
CubeType,
},
prelude::{Slice, SliceMut},
unexpanded,
};
@ -78,20 +88,21 @@ pub mod index {
array: L,
index: R,
) -> ExpandElement {
let index = index.into();
let index: ExpandElement = index.into();
let index_var: Variable = *index;
let index = match index_var {
Variable::ConstantScalar(val, _) => {
ExpandElement::Plain(Variable::ConstantScalar(val, ir::Elem::UInt))
Variable::ConstantScalar { value, .. } => {
ExpandElement::Plain(Variable::ConstantScalar {
value,
elem: ir::Elem::UInt,
})
}
_ => index,
};
let array = array.into();
let array: ExpandElement = array.into();
let var: Variable = *array;
match var {
Variable::Local(_, _, _) => {
binary_expand_no_vec(context, array, index, Operator::Index)
}
Variable::Local { .. } => binary_expand_no_vec(context, array, index, Operator::Index),
_ => binary_expand(context, array, index, Operator::Index),
}
}
@ -111,6 +122,20 @@ pub mod index {
impl_index!(Array);
impl_index!(Tensor);
impl_index!(SharedMemory);
/// Signature-only indexing for read-write slices: the body is the
/// `unexpanded!()` placeholder and the index argument is ignored here.
impl<E: CubeType, I: Into<UInt>> core::ops::Index<I> for SliceMut<'_, E> {
    /// Indexing yields a reference to one element.
    type Output = E;

    fn index(&self, _index: I) -> &Self::Output {
        unexpanded!()
    }
}
/// Signature-only indexing for read-only slices: the body is the
/// `unexpanded!()` placeholder and the index argument is ignored here.
impl<E: CubeType, I: Into<UInt>> core::ops::Index<I> for Slice<'_, E> {
    /// Indexing yields a reference to one element.
    type Output = E;

    fn index(&self, _index: I) -> &Self::Output {
        unexpanded!()
    }
}
}
pub mod add_assign_array_op {

View File

@ -352,7 +352,7 @@ macro_rules! cpa {
};
// out = len(array)
($scope:expr, $out:ident = len($input:expr)) => {
$scope.register($crate::ir::Metadata::ArrayLength {
$scope.register($crate::ir::Metadata::Length {
var: $input.into(),
out: $out.into(),
});
@ -398,43 +398,64 @@ macro_rules! cpa {
impl From<bool> for Variable {
fn from(value: bool) -> Self {
Self::ConstantScalar(if value { 1.0 } else { 0.0 }, super::Elem::Bool)
Self::ConstantScalar {
value: if value { 1.0 } else { 0.0 },
elem: super::Elem::Bool,
}
}
}
impl From<i32> for Variable {
fn from(value: i32) -> Self {
Self::ConstantScalar(value as f64, super::Elem::Int(super::IntKind::I32))
Self::ConstantScalar {
value: value as f64,
elem: super::Elem::Int(super::IntKind::I32),
}
}
}
impl From<i64> for Variable {
fn from(value: i64) -> Self {
Self::ConstantScalar(value as f64, super::Elem::Int(super::IntKind::I64))
Self::ConstantScalar {
value: value as f64,
elem: super::Elem::Int(super::IntKind::I64),
}
}
}
impl From<f32> for Variable {
fn from(value: f32) -> Self {
Self::ConstantScalar(value as f64, super::Elem::Float(super::FloatKind::F32))
Self::ConstantScalar {
value: value as f64,
elem: super::Elem::Float(super::FloatKind::F32),
}
}
}
impl From<f64> for Variable {
fn from(value: f64) -> Self {
Self::ConstantScalar(value, super::Elem::Float(super::FloatKind::F64))
Self::ConstantScalar {
value,
elem: super::Elem::Float(super::FloatKind::F64),
}
}
}
impl From<u32> for Variable {
fn from(value: u32) -> Self {
Self::ConstantScalar(value as f64, super::Elem::UInt)
Self::ConstantScalar {
value: value as f64,
elem: super::Elem::UInt,
}
}
}
impl From<usize> for Variable {
fn from(value: usize) -> Self {
Self::ConstantScalar(value as f64, super::Elem::UInt)
Self::ConstantScalar {
value: value as f64,
elem: super::Elem::UInt,
}
}
}

View File

@ -53,6 +53,7 @@ pub enum Operator {
Assign(UnaryOperator),
Modulo(BinaryOperator),
Index(BinaryOperator),
Slice(SliceOperator),
UncheckedIndex(BinaryOperator),
IndexAssign(BinaryOperator),
UncheckedIndexAssign(BinaryOperator),
@ -84,7 +85,7 @@ pub enum Metadata {
var: Variable,
out: Variable,
},
ArrayLength {
Length {
var: Variable,
out: Variable,
},
@ -120,6 +121,15 @@ pub struct ClampOperator {
pub out: Variable,
}
/// Operands of the `Slice` operator: one input, two index bounds and one output.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(missing_docs)]
pub struct SliceOperator {
    /// Variable being sliced.
    pub input: Variable,
    /// Start index of the slice.
    pub start: Variable,
    /// End index of the slice.
    // NOTE(review): whether `end` is inclusive or exclusive is not visible here — confirm.
    pub end: Variable,
    /// Variable receiving the resulting slice.
    pub out: Variable,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[allow(missing_docs)]
pub struct ReadGlobalOperator {

View File

@ -19,6 +19,7 @@ pub struct Scope {
pub operations: Vec<Operation>,
locals: Vec<Variable>,
matrices: Vec<Variable>,
slices: Vec<Variable>,
shared_memories: Vec<Variable>,
local_arrays: Vec<Variable>,
reads_global: Vec<(Variable, ReadingStrategy, Variable, Variable)>,
@ -49,6 +50,7 @@ impl Scope {
operations: Vec::new(),
locals: Vec::new(),
matrices: Vec::new(),
slices: Vec::new(),
local_arrays: Vec::new(),
shared_memories: Vec::new(),
reads_global: Vec::new(),
@ -75,7 +77,10 @@ impl Scope {
I: Into<Item> + Copy,
{
let local = self.create_local(item);
let value = Variable::ConstantScalar(value.into(), item.into().elem());
let value = Variable::ConstantScalar {
value: value.into(),
elem: item.into().elem(),
};
cpa!(self, local = value);
local
}
@ -83,16 +88,35 @@ impl Scope {
/// Create a matrix variable
pub fn create_matrix(&mut self, matrix: Matrix) -> Variable {
let index = self.matrices.len() as u16;
let variable = Variable::Matrix(index, matrix);
let variable = Variable::Matrix {
id: index,
mat: matrix,
};
self.matrices.push(variable);
variable
}
/// Create a slice variable.
///
/// The new variable uses the number of slices already recorded in this scope
/// as its id and the scope's current depth. It is appended to the scope's
/// slice list and returned.
pub fn create_slice(&mut self, item: Item) -> Variable {
    // NOTE(review): `as u16` truncates silently if the slice count ever
    // exceeds `u16::MAX`.
    let slice = Variable::Slice {
        id: self.slices.len() as u16,
        item,
        depth: self.depth,
    };
    self.slices.push(slice);
    slice
}
/// Create a local variable of the given [item type](Item).
pub fn create_local<I: Into<Item>>(&mut self, item: I) -> Variable {
let item = item.into();
let index = self.new_local_index();
let local = Variable::Local(index, item, self.depth);
let local = Variable::Local {
id: index,
item,
depth: self.depth,
};
self.locals.push(local);
local
}
@ -102,7 +126,11 @@ impl Scope {
/// Useful for _for loops_ and other algorithms that require the control over initialization.
pub fn create_local_undeclared(&mut self, item: Item) -> Variable {
let index = self.new_local_index();
let local = Variable::Local(index, item, self.depth);
let local = Variable::Local {
id: index,
item,
depth: self.depth,
};
self.undeclared += 1;
local
}
@ -131,8 +159,12 @@ impl Scope {
///
/// The index refers to the scalar position for the same [element](Elem) type.
pub fn read_scalar(&mut self, index: u16, elem: Elem) -> Variable {
let local = Variable::LocalScalar(self.new_local_scalar_index(), elem, self.depth);
let scalar = Variable::GlobalScalar(index, elem);
let local = Variable::LocalScalar {
id: self.new_local_scalar_index(),
elem,
depth: self.depth,
};
let scalar = Variable::GlobalScalar { id: index, elem };
self.reads_scalar.push((local, scalar));
@ -215,7 +247,7 @@ impl Scope {
self.reads_global
.iter()
.map(|(var, strategy, _, _)| match var {
Variable::GlobalInputArray(id, _) => (*id, *strategy),
Variable::GlobalInputArray { id, .. } => (*id, *strategy),
_ => panic!("Can only read global input arrays."),
})
.collect()
@ -233,6 +265,7 @@ impl Scope {
operations: Vec::new(),
locals: Vec::new(),
matrices: Vec::new(),
slices: Vec::new(),
shared_memories: Vec::new(),
local_arrays: Vec::new(),
reads_global: Vec::new(),
@ -259,6 +292,9 @@ impl Scope {
for var in self.matrices.drain(..) {
variables.push(var);
}
for var in self.slices.drain(..) {
variables.push(var);
}
for index in self.index_offset_with_output_layout_position.drain(..) {
if let Some(Operation::Procedure(Procedure::IndexOffsetGlobalWithLayout(proc))) =
@ -357,9 +393,16 @@ impl Scope {
},
_ => item,
};
let input = Variable::GlobalInputArray(index, item_global);
let input = Variable::GlobalInputArray {
id: index,
item: item_global,
};
let index = self.new_local_index();
let local = Variable::Local(index, item, self.depth);
let local = Variable::Local {
id: index,
item,
depth: self.depth,
};
self.reads_global.push((input, strategy, local, position));
self.locals.push(local);
local
@ -369,7 +412,11 @@ impl Scope {
pub fn create_shared<I: Into<Item>>(&mut self, item: I, shared_memory_size: u32) -> Variable {
let item = item.into();
let index = self.new_shared_index();
let shared_memory = Variable::SharedMemory(index, item, shared_memory_size);
let shared_memory = Variable::SharedMemory {
id: index,
item,
length: shared_memory_size,
};
self.shared_memories.push(shared_memory);
shared_memory
}
@ -378,7 +425,12 @@ impl Scope {
pub fn create_local_array<I: Into<Item>>(&mut self, item: I, array_size: u32) -> Variable {
let item = item.into();
let index = self.new_local_array_index();
let local_array = Variable::LocalArray(index, item, self.depth, array_size);
let local_array = Variable::LocalArray {
id: index,
item,
depth: self.depth,
length: array_size,
};
self.local_arrays.push(local_array);
local_array
}

View File

@ -5,14 +5,52 @@ use serde::{Deserialize, Serialize};
#[allow(missing_docs)]
pub enum Variable {
Rank,
GlobalInputArray(u16, Item),
GlobalScalar(u16, Elem),
GlobalOutputArray(u16, Item),
Local(u16, Item, u8),
LocalScalar(u16, Elem, u8),
ConstantScalar(f64, Elem),
SharedMemory(u16, Item, u32),
LocalArray(u16, Item, u8, u32),
GlobalInputArray {
id: u16,
item: Item,
},
GlobalScalar {
id: u16,
elem: Elem,
},
GlobalOutputArray {
id: u16,
item: Item,
},
Local {
id: u16,
item: Item,
depth: u8,
},
LocalScalar {
id: u16,
elem: Elem,
depth: u8,
},
ConstantScalar {
value: f64,
elem: Elem,
},
SharedMemory {
id: u16,
item: Item,
length: u32,
},
LocalArray {
id: u16,
item: Item,
depth: u8,
length: u32,
},
Matrix {
id: u16,
mat: Matrix,
},
Slice {
id: u16,
item: Item,
depth: u8,
},
UnitPos,
UnitPosX,
UnitPosY,
@ -34,20 +72,21 @@ pub enum Variable {
AbsolutePosX,
AbsolutePosY,
AbsolutePosZ,
Matrix(u16, Matrix),
}
impl Variable {
pub fn index(&self) -> Option<u16> {
match self {
Variable::GlobalInputArray(idx, _) => Some(*idx),
Variable::GlobalScalar(idx, _) => Some(*idx),
Variable::Local(idx, _, _) => Some(*idx),
Variable::LocalScalar(idx, _, _) => Some(*idx),
Variable::GlobalOutputArray(idx, _) => Some(*idx),
Variable::ConstantScalar(_, _) => None,
Variable::SharedMemory(idx, _, _) => Some(*idx),
Variable::LocalArray(idx, _, _, _) => Some(*idx),
Variable::GlobalInputArray { id, .. } => Some(*id),
Variable::GlobalScalar { id, .. } => Some(*id),
Variable::Local { id, .. } => Some(*id),
Variable::Slice { id, .. } => Some(*id),
Variable::LocalScalar { id, .. } => Some(*id),
Variable::GlobalOutputArray { id, .. } => Some(*id),
Variable::ConstantScalar { .. } => None,
Variable::SharedMemory { id, .. } => Some(*id),
Variable::LocalArray { id, .. } => Some(*id),
Variable::Matrix { id, .. } => Some(*id),
Variable::AbsolutePos => None,
Variable::UnitPos => None,
Variable::UnitPosX => None,
@ -70,21 +109,22 @@ impl Variable {
Variable::CubeCount => None,
Variable::CubeDim => None,
Variable::SubcubeDim => None,
Variable::Matrix(idx, _) => Some(*idx),
}
}
/// Fetch the item of the variable.
pub fn item(&self) -> Item {
match self {
Variable::GlobalInputArray(_, item) => *item,
Variable::GlobalOutputArray(_, item) => *item,
Variable::GlobalScalar(_, elem) => Item::new(*elem),
Variable::Local(_, item, _) => *item,
Variable::LocalScalar(_, elem, _) => Item::new(*elem),
Variable::ConstantScalar(_, elem) => Item::new(*elem),
Variable::SharedMemory(_, item, _) => *item,
Variable::LocalArray(_, item, _, _) => *item,
Variable::GlobalInputArray { item, .. } => *item,
Variable::GlobalOutputArray { item, .. } => *item,
Variable::GlobalScalar { elem, .. } => Item::new(*elem),
Variable::Local { item, .. } => *item,
Variable::LocalScalar { elem, .. } => Item::new(*elem),
Variable::ConstantScalar { elem, .. } => Item::new(*elem),
Variable::SharedMemory { item, .. } => *item,
Variable::LocalArray { item, .. } => *item,
Variable::Slice { item, .. } => *item,
Variable::Matrix { mat, .. } => Item::new(mat.elem),
Variable::AbsolutePos => Item::new(Elem::UInt),
Variable::Rank => Item::new(Elem::UInt),
Variable::UnitPos => Item::new(Elem::UInt),
@ -107,7 +147,6 @@ impl Variable {
Variable::CubeCount => Item::new(Elem::UInt),
Variable::CubeDim => Item::new(Elem::UInt),
Variable::SubcubeDim => Item::new(Elem::UInt),
Variable::Matrix(_, matrix) => Item::new(matrix.elem),
}
}
}

View File

@ -1,6 +1,6 @@
use super::{
BinaryOperator, ClampOperator, FmaOperator, InitOperator, Item, Operation, Operator, Subcube,
UnaryOperator, Variable,
BinaryOperator, ClampOperator, FmaOperator, InitOperator, Item, Operation, Operator,
SliceOperator, Subcube, UnaryOperator, Variable,
};
pub type Vectorization = u8;
@ -60,7 +60,7 @@ impl Operator {
Operator::LowerEqual(op) => Operator::LowerEqual(op.vectorize(vectorization)),
Operator::GreaterEqual(op) => Operator::GreaterEqual(op.vectorize(vectorization)),
Operator::Assign(op) => {
if let Variable::GlobalScalar(_, _) = op.input {
if let Variable::GlobalScalar { .. } = op.input {
// Assign will not change the type of the output if the input can't be
// vectorized.
return Operator::Assign(op.clone());
@ -81,6 +81,7 @@ impl Operator {
Operator::ShiftLeft(op) => Operator::ShiftLeft(op.vectorize(vectorization)),
Operator::ShiftRight(op) => Operator::ShiftRight(op.vectorize(vectorization)),
Operator::Remainder(op) => Operator::Remainder(op.vectorize(vectorization)),
Operator::Slice(op) => Operator::Slice(op.vectorize(vectorization)),
}
}
}
@ -104,6 +105,22 @@ impl UnaryOperator {
}
}
impl SliceOperator {
    /// Returns a new slice operator whose four operands (`input`, `start`, `end` and `out`)
    /// are each the result of calling `vectorize` on the corresponding operand of `self`
    /// with the given vectorization factor.
    pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self {
        Self {
            input: self.input.vectorize(vectorization),
            start: self.start.vectorize(vectorization),
            end: self.end.vectorize(vectorization),
            out: self.out.vectorize(vectorization),
        }
    }
}
impl InitOperator {
pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self {
let out = self.out.vectorize(vectorization);
@ -155,31 +172,46 @@ impl FmaOperator {
impl Variable {
pub(crate) fn vectorize(&self, vectorize: Vectorization) -> Self {
match self {
Variable::GlobalInputArray(index, item) => {
Variable::GlobalInputArray(*index, item.vectorize(vectorize))
}
Variable::Local(index, item, name) => {
Variable::Local(*index, item.vectorize(vectorize), *name)
}
Variable::GlobalOutputArray(index, item) => {
Variable::GlobalOutputArray(*index, item.vectorize(vectorize))
}
Variable::SharedMemory(index, item, size) => Variable::SharedMemory(
*index,
item.vectorize(vectorize),
item.vectorized_size(vectorize, *size),
),
Variable::LocalArray(index, item, name, size) => Variable::LocalArray(
*index,
item.vectorize(vectorize),
*name,
item.vectorized_size(vectorize, *size),
),
Variable::ConstantScalar(_, _) => *self,
Variable::GlobalScalar(_, _) => *self,
Variable::GlobalInputArray { id, item } => Variable::GlobalInputArray {
id: *id,
item: item.vectorize(vectorize),
},
Variable::Local { id, item, depth } => Variable::Local {
id: *id,
item: item.vectorize(vectorize),
depth: *depth,
},
Variable::Slice { id, item, depth } => Variable::Slice {
id: *id,
item: item.vectorize(vectorize),
depth: *depth,
},
Variable::GlobalOutputArray { id, item } => Variable::GlobalOutputArray {
id: *id,
item: item.vectorize(vectorize),
},
Variable::SharedMemory { id, item, length } => Variable::SharedMemory {
id: *id,
item: item.vectorize(vectorize),
length: item.vectorized_size(vectorize, *length),
},
Variable::LocalArray {
id,
item,
depth,
length,
} => Variable::LocalArray {
id: *id,
item: item.vectorize(vectorize),
depth: *depth,
length: item.vectorized_size(vectorize, *length),
},
Variable::ConstantScalar { .. } => *self,
Variable::GlobalScalar { .. } => *self,
Variable::AbsolutePos => *self,
Variable::Rank => *self,
Variable::LocalScalar(_, _, _) => *self,
Variable::LocalScalar { .. } => *self,
Variable::Matrix { .. } => *self,
Variable::UnitPos => *self,
Variable::UnitPosX => *self,
Variable::UnitPosY => *self,
@ -200,7 +232,6 @@ impl Variable {
Variable::CubeCount => *self,
Variable::CubeDim => *self,
Variable::SubcubeDim => *self,
Variable::Matrix(_, _) => *self,
}
}
}

View File

@ -11,7 +11,8 @@ pub use crate::runtime::Runtime;
/// Elements
pub use crate::frontend::{
Array, ArrayHandle, Bool, Float, LaunchArg, Tensor, TensorArg, UInt, F16, F32, F64, I32, I64,
Array, ArrayHandle, Bool, Float, LaunchArg, Slice, SliceMut, Tensor, TensorArg, UInt, F16, F32,
F64, I32, I64,
};
pub use crate::pod::CubeElement;

View File

@ -31,12 +31,17 @@ pub fn kernel_simple_1(lhs: &Array<F16>, rhs: &Array<F16>, out: &mut Array<F32>)
cmma::MatrixLayout::Undefined,
);
cmma::fill::<F32>(&c, F32::new(0.0));
cmma::load::<F16>(&a, lhs, UInt::new(16));
cmma::load::<F16>(&b, rhs, UInt::new(16));
cmma::load::<F16>(&a, lhs.as_slice(), UInt::new(16));
cmma::load::<F16>(&b, rhs.as_slice(), UInt::new(16));
cmma::execute::<F16, F16, F32, F32>(&a, &b, &c, &c);
cmma::store::<F32>(out, &c, UInt::new(16), cmma::MatrixLayout::RowMajor);
cmma::store::<F32>(
out.as_slice_mut(),
&c,
UInt::new(16),
cmma::MatrixLayout::RowMajor,
);
}
pub fn test_simple_1<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {

View File

@ -1,5 +1,6 @@
pub mod cmma;
pub mod launch;
pub mod slice;
pub mod subcube;
#[allow(missing_docs)]
@ -11,5 +12,6 @@ macro_rules! testgen_all {
burn_cube::testgen_subcube!();
burn_cube::testgen_launch!();
burn_cube::testgen_cmma!();
burn_cube::testgen_slice!();
};
}

View File

@ -0,0 +1,107 @@
use crate as burn_cube;
use burn_cube::prelude::*;
// Kernel: the unit whose `UNIT_POS` is 0 takes `input.slice(2, 3)` and copies the first
// element of that slice into `output[0]`. Every other unit does nothing.
#[cube(launch)]
pub fn slice_select<F: Float>(input: &Array<F>, output: &mut Array<F>) {
if UNIT_POS == UInt::new(0) {
// Index 0 of the slice maps to index 2 of `input`; `test_slice_select` relies on this
// when it expects 2.0 for the input [0.0, 1.0, 2.0, 3.0, 4.0].
let slice = input.slice(2, 3);
output[0] = slice[0u32];
}
}
// Kernel: the unit whose `UNIT_POS` is 0 takes the mutable slice `output.slice_mut(2, 3)`
// and stores `input[0]` into the first element of that slice. Every other unit does nothing.
#[cube(launch)]
pub fn slice_assign<F: Float>(input: &Array<F>, output: &mut Array<F>) {
if UNIT_POS == UInt::new(0) {
// Writing through the slice modifies `output` itself: `test_slice_assign` expects the
// element at index 2 of the output buffer to be replaced.
let slice_1 = output.slice_mut(2, 3);
slice_1[0] = input[0u32];
}
}
// Kernel: the unit whose `UNIT_POS` is 0 takes `input.slice(2, 4)` and writes the value of
// the slice's `len()` into `output[0]`. `test_slice_len` expects that value to be 2.
#[cube(launch)]
pub fn slice_len<F: Float>(input: &Array<F>, output: &mut Array<UInt>) {
if UNIT_POS == UInt::new(0) {
let slice = input.slice(2, 4);
let _tmp = slice[0]; // It must be used at least once, otherwise wgpu isn't happy.
output[0] = slice.len();
}
}
/// Launches the `slice_select` kernel with a single unit on a five-element `f32` input and a
/// one-element output, then asserts that the value read back from the output is `2.0`.
pub fn test_slice_select<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
    let source: [f32; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];
    let input = client.create(f32::as_bytes(&source));
    // Room for exactly one `f32` result.
    let output = client.empty(core::mem::size_of::<f32>());

    slice_select_launch::<F32, R>(
        client.clone(),
        CubeCount::Static(1, 1, 1),
        CubeDim::new(1, 1, 1),
        ArrayArg::new(&input, 5),
        ArrayArg::new(&output, 1),
    );

    let raw = client.read(output.binding());
    let selected = f32::from_bytes(&raw);

    assert_eq!(selected[0], 2.0);
}
/// Launches the `slice_len` kernel with a single unit on a five-element `f32` input and a
/// one-element `u32` output, then asserts that the output buffer contains exactly `[2]`.
pub fn test_slice_len<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
    let source: [f32; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];
    let input = client.create(f32::as_bytes(&source));
    // Room for exactly one `u32` result.
    let output = client.empty(core::mem::size_of::<u32>());

    slice_len_launch::<F32, R>(
        client.clone(),
        CubeCount::Static(1, 1, 1),
        CubeDim::new(1, 1, 1),
        ArrayArg::new(&input, 5),
        ArrayArg::new(&output, 1),
    );

    let raw = client.read(output.binding());
    let lengths = u32::from_bytes(&raw);

    assert_eq!(lengths, &[2]);
}
/// Launches the `slice_assign` kernel with a single unit and checks the output buffer.
///
/// The input buffer holds one `f32` (`15.0`) and the output buffer holds five
/// (`[0.0, 1.0, 2.0, 3.0, 4.0]`). After the launch, the output is expected to be
/// `[0.0, 1.0, 15.0, 3.0, 4.0]`: only the element at index 2 is overwritten.
pub fn test_slice_assign<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
    let input = client.create(f32::as_bytes(&[15.0]));
    let output = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0]));

    slice_assign_launch::<F32, R>(
        client.clone(),
        CubeCount::Static(1, 1, 1),
        CubeDim::new(1, 1, 1),
        // Declared lengths mirror the buffers created above: one input element and five
        // output elements (the two values were previously swapped).
        ArrayArg::new(&input, 1),
        ArrayArg::new(&output, 5),
    );

    let actual = client.read(output.binding());
    let actual = f32::from_bytes(&actual);

    assert_eq!(actual, &[0.0, 1.0, 15.0, 3.0, 4.0]);
}
#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_slice {
    () => {
        use super::*;

        // Each generated test creates a `TestRuntime` client from `Default::default()` and
        // passes it to the runtime test of the same name in `burn_cube::runtime_tests::slice`.

        #[test]
        fn test_slice_select() {
            burn_cube::runtime_tests::slice::test_slice_select::<TestRuntime>(
                TestRuntime::client(&Default::default()),
            );
        }

        #[test]
        fn test_slice_assign() {
            burn_cube::runtime_tests::slice::test_slice_assign::<TestRuntime>(
                TestRuntime::client(&Default::default()),
            );
        }

        #[test]
        fn test_slice_len() {
            burn_cube::runtime_tests::slice::test_slice_len::<TestRuntime>(
                TestRuntime::client(&Default::default()),
            );
        }
    };
}

View File

@ -36,7 +36,7 @@ mod tests {
use super::*;
use burn_cube::{
cpa,
ir::{Elem, Item, Variable},
ir::{self, Elem, Item, Variable},
};
type ElemType = F32;
@ -47,7 +47,7 @@ mod tests {
array_read_write::__expand::<ElemType>(&mut context, 512);
assert_eq!(
format!("{:?}", context.into_scope().operations),
context.into_scope().operations,
inline_macro_ref_read_write()
)
}
@ -60,10 +60,7 @@ mod tests {
array_add_assign_simple::__expand(&mut context, array.into());
let scope = context.into_scope();
assert_eq!(
format!("{:?}", scope.operations),
inline_macro_array_add_assign_simple()
);
assert_eq!(scope.operations, inline_macro_array_add_assign_simple());
}
#[test]
@ -72,7 +69,7 @@ mod tests {
array_to_vectorized_variable::__expand::<ElemType>(&mut context);
assert_eq!(
format!("{:?}", context.into_scope().operations),
context.into_scope().operations,
inline_macro_ref_to_vectorized()
);
}
@ -83,12 +80,12 @@ mod tests {
array_of_one_to_vectorized_variable::__expand::<ElemType>(&mut context);
assert_eq!(
format!("{:?}", context.into_scope().operations),
context.into_scope().operations,
inline_macro_ref_one_to_vectorized()
);
}
fn inline_macro_ref_read_write() -> String {
fn inline_macro_ref_read_write() -> Vec<ir::Operation> {
let context = CubeContext::root();
let item = Item::new(ElemType::as_elem());
@ -105,7 +102,7 @@ mod tests {
// Read
cpa!(scope, var = array[pos]);
format!("{:?}", scope.operations)
scope.operations
}
#[test]
@ -116,30 +113,36 @@ mod tests {
array_add_assign_expr::__expand(&mut context, array.into());
let scope = context.into_scope();
assert_eq!(
format!("{:?}", scope.operations),
inline_macro_array_add_assign_expr()
);
assert_eq!(scope.operations, inline_macro_array_add_assign_expr());
}
fn inline_macro_array_add_assign_simple() -> String {
fn inline_macro_array_add_assign_simple() -> Vec<ir::Operation> {
let context = CubeContext::root();
let mut scope = context.into_scope();
let local = scope.create_local(Item::new(Elem::UInt));
let array = Variable::GlobalInputArray(0, Item::new(Elem::UInt));
let index = Variable::ConstantScalar(1., Elem::UInt);
let value = Variable::ConstantScalar(1., Elem::UInt);
let array = Variable::GlobalInputArray {
id: 0,
item: Item::new(Elem::UInt),
};
let index = Variable::ConstantScalar {
value: 1.,
elem: Elem::UInt,
};
let value = Variable::ConstantScalar {
value: 1.,
elem: Elem::UInt,
};
cpa!(scope, local = array[index]);
cpa!(scope, local += value);
cpa!(scope, array[index] = local);
format!("{:?}", scope.operations)
scope.operations
}
fn inline_macro_ref_to_vectorized() -> String {
fn inline_macro_ref_to_vectorized() -> Vec<ir::Operation> {
let context = CubeContext::root();
let scalar_item = Item::new(ElemType::as_elem());
let vectorized_item = Item::vectorized(ElemType::as_elem(), 2);
@ -158,10 +161,10 @@ mod tests {
cpa!(scope, tmp = array[pos1]);
cpa!(scope, vectorized_var[pos1] = tmp);
format!("{:?}", scope.operations)
scope.operations
}
fn inline_macro_ref_one_to_vectorized() -> String {
fn inline_macro_ref_one_to_vectorized() -> Vec<ir::Operation> {
let context = CubeContext::root();
let scalar_item = Item::new(ElemType::as_elem());
let unvectorized_item = Item::new(ElemType::as_elem());
@ -176,26 +179,38 @@ mod tests {
cpa!(scope, tmp = array[pos0]);
cpa!(scope, unvectorized_var = tmp);
format!("{:?}", scope.operations)
scope.operations
}
fn inline_macro_array_add_assign_expr() -> String {
fn inline_macro_array_add_assign_expr() -> Vec<ir::Operation> {
let context = CubeContext::root();
let mut scope = context.into_scope();
let index = scope.create_local(Item::new(Elem::UInt));
let local = scope.create_local(Item::new(Elem::UInt));
let array = Variable::GlobalInputArray(0, Item::new(Elem::UInt));
let const1 = Variable::ConstantScalar(1., Elem::UInt);
let const2 = Variable::ConstantScalar(5., Elem::UInt);
let value = Variable::ConstantScalar(1., Elem::UInt);
let array = Variable::GlobalInputArray {
id: 0,
item: Item::new(Elem::UInt),
};
let const1 = Variable::ConstantScalar {
value: 1.,
elem: Elem::UInt,
};
let const2 = Variable::ConstantScalar {
value: 5.,
elem: Elem::UInt,
};
let value = Variable::ConstantScalar {
value: 1.,
elem: Elem::UInt,
};
cpa!(scope, index = const1 + const2);
cpa!(scope, local = array[index]);
cpa!(scope, local += value);
cpa!(scope, array[index] = local);
format!("{:?}", scope.operations)
scope.operations
}
}

View File

@ -98,8 +98,14 @@ mod tests {
let mut scope = context.into_scope();
let x = scope.create_local(Item::new(Elem::UInt));
let zero = Variable::ConstantScalar(0., Elem::UInt);
let one = Variable::ConstantScalar(1., Elem::UInt);
let zero = Variable::ConstantScalar {
value: 0.,
elem: Elem::UInt,
};
let one = Variable::ConstantScalar {
value: 1.,
elem: Elem::UInt,
};
cpa!(scope, x = zero);
cpa!(scope, x = x + one);
@ -115,8 +121,14 @@ mod tests {
let y: Variable = y.into();
let x = scope.create_local(item);
let one = Variable::ConstantScalar(1., Elem::UInt);
let two = Variable::ConstantScalar(2., Elem::UInt);
let one = Variable::ConstantScalar {
value: 1.,
elem: Elem::UInt,
};
let two = Variable::ConstantScalar {
value: 2.,
elem: Elem::UInt,
};
cpa!(scope, x = y);
cpa!(scope, x = x + one);
cpa!(scope, y = y + two);
@ -133,8 +145,14 @@ mod tests {
let y: Variable = y.into();
let x = scope.create_local(item);
let one = Variable::ConstantScalar(1., Elem::UInt);
let two = Variable::ConstantScalar(2., Elem::UInt);
let one = Variable::ConstantScalar {
value: 1.,
elem: Elem::UInt,
};
let two = Variable::ConstantScalar {
value: 2.,
elem: Elem::UInt,
};
cpa!(scope, x = y);
cpa!(scope, y = y + one);
cpa!(scope, x = x + two);
@ -151,10 +169,22 @@ mod tests {
let y: Variable = y.into();
let x = scope.create_local(item);
let zero = Variable::ConstantScalar(0., Elem::UInt);
let one = Variable::ConstantScalar(1., Elem::UInt);
let two = Variable::ConstantScalar(2., Elem::UInt);
let three = Variable::ConstantScalar(3., Elem::UInt);
let zero = Variable::ConstantScalar {
value: 0.,
elem: Elem::UInt,
};
let one = Variable::ConstantScalar {
value: 1.,
elem: Elem::UInt,
};
let two = Variable::ConstantScalar {
value: 2.,
elem: Elem::UInt,
};
let three = Variable::ConstantScalar {
value: 3.,
elem: Elem::UInt,
};
cpa!(scope, x[zero] = one);
cpa!(scope, x[one] = one);
cpa!(scope, x[two] = one);

View File

@ -395,31 +395,31 @@ mod tests {
let out_number = if in_type == out_type { 0 } else { 2 };
format!(
"[Operator({ops_name}(BinaryOperator {{ \
lhs: Local(0, Item {{ \
lhs: Local {{ id: 0, item: Item {{ \
elem: {in_type}, \
vectorization: 1 \
}}, 0), \
rhs: Local(1, Item {{ \
}}, depth: 0 }}, \
rhs: Local {{ id: 1, item: Item {{ \
elem: {in_type}, \
vectorization: 1 \
}}, 0), \
out: Local({out_number}, Item {{ \
}}, depth: 0 }}, \
out: Local {{ id: {out_number}, item: Item {{ \
elem: {out_type}, \
vectorization: 1 \
}}, 0) \
}}, depth: 0 }} \
}}))]"
)
} else {
format!(
"[Operator({ops_name}(UnaryOperator {{ \
input: Local(0, Item {{ \
input: Local {{ id: 0, item: Item {{ \
elem: {in_type}, \
vectorization: 1 \
}}, 0), \
out: Local(0, Item {{ \
}}, depth: 0 }}, \
out: Local {{ id: 0, item: Item {{ \
elem: {out_type}, \
vectorization: 1 \
}}, 0) \
}}, depth: 0 }} \
}}))]"
)
}

View File

@ -117,7 +117,10 @@ mod tests {
let i = scope.create_with_value(1, item);
cpa!(scope, x += i);
let value = Variable::ConstantScalar(2., item.elem());
let value = Variable::ConstantScalar {
value: 2.,
elem: item.elem(),
};
cpa!(scope, i = value);
cpa!(scope, x += i);
@ -156,7 +159,10 @@ mod tests {
cpa!(
&mut scope,
range(0u32, end, false).for_each(|_, scope| {
let value = Variable::ConstantScalar(2.into(), item.elem());
let value = Variable::ConstantScalar {
value: 2.into(),
elem: item.elem(),
};
cpa!(scope, y = value);
cpa!(scope, x += y);
})

View File

@ -83,6 +83,9 @@ impl CudaCompiler {
let processing = value.process();
for var in processing.variables {
if let gpu::Variable::Slice { .. } = var {
continue;
}
instructions.push(Instruction::DeclareVariable {
var: self.compile_variable(var),
});
@ -192,8 +195,8 @@ impl CudaCompiler {
gpu::Metadata::Stride { dim, var, out } => {
self.stride = true;
let position = match var {
gpu::Variable::GlobalInputArray(idx, _) => idx as usize,
gpu::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
gpu::Variable::GlobalInputArray { id, .. } => id as usize,
gpu::Variable::GlobalOutputArray { id, .. } => self.num_inputs + id as usize,
_ => panic!("Only Input and Output have a stride, got: {:?}", var),
};
Instruction::Stride {
@ -205,8 +208,8 @@ impl CudaCompiler {
gpu::Metadata::Shape { dim, var, out } => {
self.shape = true;
let position = match var {
gpu::Variable::GlobalInputArray(idx, _) => idx as usize,
gpu::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
gpu::Variable::GlobalInputArray { id, .. } => id as usize,
gpu::Variable::GlobalOutputArray { id, .. } => self.num_inputs + id as usize,
_ => panic!("Only Input and Output have a shape, got {:?}", var),
};
Instruction::Shape {
@ -215,12 +218,20 @@ impl CudaCompiler {
out: self.compile_variable(out),
}
}
gpu::Metadata::ArrayLength { var, out } => super::Instruction::ArrayLength {
input: self.compile_variable(var),
out: self.compile_variable(out),
num_inputs: self.num_inputs,
num_outputs: self.num_outputs,
},
gpu::Metadata::Length { var, out } => {
let input = self.compile_variable(var);
let out = self.compile_variable(out);
match input {
super::Variable::Slice { .. } => super::Instruction::SliceLength { input, out },
_ => super::Instruction::Length {
input,
out,
num_inputs: self.num_inputs,
num_outputs: self.num_outputs,
},
}
}
}
}
@ -297,6 +308,12 @@ impl CudaCompiler {
gpu::Operator::Div(op) => Instruction::Div(self.compile_binary(op)),
gpu::Operator::Sub(op) => Instruction::Sub(self.compile_binary(op)),
gpu::Operator::Assign(op) => Instruction::Assign(self.compile_unary(op)),
gpu::Operator::Slice(op) => Instruction::Slice {
input: self.compile_variable(op.input),
start: self.compile_variable(op.start),
end: self.compile_variable(op.end),
out: self.compile_variable(op.out),
},
gpu::Operator::Index(op) => Instruction::Index(self.compile_binary(op)),
gpu::Operator::UncheckedIndex(op) => Instruction::Index(self.compile_binary(op)),
gpu::Operator::IndexAssign(op) => Instruction::IndexAssign(self.compile_binary(op)),
@ -372,35 +389,40 @@ impl CudaCompiler {
fn compile_variable(&mut self, value: gpu::Variable) -> super::Variable {
match value {
gpu::Variable::GlobalInputArray(index, item) => {
super::Variable::GlobalInputArray(index, Self::compile_item(item))
gpu::Variable::GlobalInputArray { id, item } => {
super::Variable::GlobalInputArray(id, Self::compile_item(item))
}
gpu::Variable::GlobalScalar(index, elem) => {
super::Variable::GlobalScalar(index, Self::compile_elem(elem), elem)
gpu::Variable::GlobalScalar { id, elem } => {
super::Variable::GlobalScalar(id, Self::compile_elem(elem), elem)
}
gpu::Variable::Local(index, item, scope_depth) => super::Variable::Local {
index,
gpu::Variable::Local { id, item, depth } => super::Variable::Local {
id,
item: Self::compile_item(item),
scope_depth,
depth,
},
gpu::Variable::LocalScalar(index, elem, scope_depth) => super::Variable::LocalScalar {
index,
gpu::Variable::Slice { id, item, depth } => super::Variable::Slice {
id,
item: Self::compile_item(item),
depth,
},
gpu::Variable::LocalScalar { id, elem, depth } => super::Variable::LocalScalar {
id,
elem: Self::compile_elem(elem),
scope_depth,
depth,
},
gpu::Variable::GlobalOutputArray(index, item) => {
super::Variable::GlobalOutputArray(index, Self::compile_item(item))
gpu::Variable::GlobalOutputArray { id, item } => {
super::Variable::GlobalOutputArray(id, Self::compile_item(item))
}
gpu::Variable::ConstantScalar(index, elem) => {
super::Variable::ConstantScalar(index, Self::compile_elem(elem))
gpu::Variable::ConstantScalar { value, elem } => {
super::Variable::ConstantScalar(value, Self::compile_elem(elem))
}
gpu::Variable::SharedMemory(index, item, size) => {
gpu::Variable::SharedMemory { id, item, length } => {
let item = Self::compile_item(item);
if !self.shared_memories.iter().any(|s| s.index == index) {
if !self.shared_memories.iter().any(|s| s.index == id) {
self.shared_memories
.push(super::SharedMemory::new(index, item, size));
.push(super::SharedMemory::new(id, item, length));
}
super::Variable::SharedMemory(index, item, size)
super::Variable::SharedMemory(id, item, length)
}
gpu::Variable::AbsolutePos => {
self.id = true;
@ -438,7 +460,12 @@ impl CudaCompiler {
gpu::Variable::CubeCountX => super::Variable::NumWorkgroupsX,
gpu::Variable::CubeCountY => super::Variable::NumWorkgroupsY,
gpu::Variable::CubeCountZ => super::Variable::NumWorkgroupsZ,
gpu::Variable::LocalArray(id, item, depth, size) => {
gpu::Variable::LocalArray {
id,
item,
depth,
length,
} => {
let item = Self::compile_item(item);
if !self
.local_arrays
@ -446,19 +473,19 @@ impl CudaCompiler {
.any(|s| s.index == id && s.depth == depth)
{
self.local_arrays
.push(super::LocalArray::new(id, item, depth, size));
.push(super::LocalArray::new(id, item, depth, length));
}
super::Variable::LocalArray(id, item, depth, size)
super::Variable::LocalArray(id, item, depth, length)
}
gpu::Variable::CubePos => todo!(),
gpu::Variable::CubeDim => todo!(),
gpu::Variable::CubeCount => todo!(),
gpu::Variable::SubcubeDim => todo!(),
gpu::Variable::Matrix(index, matrix) => {
gpu::Variable::Matrix { id, mat } => {
self.wmma = true;
super::Variable::WmmaFragment {
index,
frag: Self::compile_matrix(matrix),
id,
frag: Self::compile_matrix(mat),
}
}
}

View File

@ -361,9 +361,9 @@ impl Binary for IndexAssign {
out: &Variable,
) -> std::fmt::Result {
if let Variable::Local {
index: _,
id: _,
item: _,
scope_depth: _,
depth: _,
} = out
{
return IndexAssignVector::format(f, lhs, rhs, out);
@ -388,9 +388,9 @@ impl Binary for Index {
out: &Variable,
) -> std::fmt::Result {
if let Variable::Local {
index: _,
id: _,
item: _,
scope_depth: _,
depth: _,
} = lhs
{
return IndexVector::format(f, lhs, rhs, out);

View File

@ -86,9 +86,14 @@ impl Component for Variable {
Variable::GlobalOutputArray(_, e) => *e,
Variable::SharedMemory(_, e, _) => *e,
Variable::Local {
index: _,
id: _,
item,
scope_depth: _,
depth: _,
} => *item,
Variable::Slice {
id: _,
item,
depth: _,
} => *item,
Variable::ConstantScalar(_, e) => Item::Scalar(*e),
Variable::GlobalScalar(_, e, _) => Item::Scalar(*e),
@ -99,9 +104,9 @@ impl Component for Variable {
Variable::LocalInvocationIdZ => Item::Scalar(Elem::U32),
Variable::Rank => Item::Scalar(Elem::U32),
Variable::LocalScalar {
index: _,
id: _,
elem,
scope_depth: _,
depth: _,
} => Item::Scalar(*elem),
Variable::WorkgroupIdX => Item::Scalar(Elem::U32),
Variable::WorkgroupIdY => Item::Scalar(Elem::U32),
@ -117,7 +122,7 @@ impl Component for Variable {
Variable::NumWorkgroupsZ => Item::Scalar(Elem::U32),
Variable::LocalArray(_, e, _, _) => *e,
Variable::WarpSize => Item::Scalar(Elem::U32),
Variable::WmmaFragment { index: _, frag } => Item::Scalar(frag.elem),
Variable::WmmaFragment { id: _, frag } => Item::Scalar(frag.elem),
}
}
}
@ -129,16 +134,9 @@ pub enum Variable {
GlobalOutputArray(u16, Item),
GlobalScalar(u16, Elem, gpu::Elem),
ConstantScalar(f64, Elem),
Local {
index: u16,
item: Item,
scope_depth: u8,
},
LocalScalar {
index: u16,
elem: Elem,
scope_depth: u8,
},
Local { id: u16, item: Item, depth: u8 },
Slice { id: u16, item: Item, depth: u8 },
LocalScalar { id: u16, elem: Elem, depth: u8 },
SharedMemory(u16, Item, u32),
LocalArray(u16, Item, u8, u32),
Id,
@ -159,10 +157,7 @@ pub enum Variable {
NumWorkgroupsX,
NumWorkgroupsY,
NumWorkgroupsZ,
WmmaFragment {
index: u16,
frag: Fragment,
},
WmmaFragment { id: u16, frag: Fragment },
}
impl Display for Variable {
@ -170,15 +165,18 @@ impl Display for Variable {
match self {
Variable::GlobalInputArray(number, _) => f.write_fmt(format_args!("input_{number}")),
Variable::LocalScalar {
index,
id: index,
elem: _,
scope_depth,
depth: scope_depth,
} => f.write_fmt(format_args!("s_{scope_depth}_{index}")),
Variable::Local {
index,
id: index,
item: _,
scope_depth,
depth: scope_depth,
} => f.write_fmt(format_args!("l_{scope_depth}_{index}")),
Variable::Slice { id, item: _, depth } => {
f.write_fmt(format_args!("slice_{depth}_{id}"))
}
Variable::GlobalOutputArray(number, _) => f.write_fmt(format_args!("output_{number}")),
Variable::GlobalScalar(number, _, elem) => {
f.write_fmt(format_args!("scalars_{elem}[{number}]"))
@ -209,7 +207,9 @@ impl Display for Variable {
f.write_fmt(format_args!("l_arr_{}_{}", id, depth))
}
Variable::WarpSize => f.write_str("warpSize"),
Variable::WmmaFragment { index, frag: _ } => f.write_fmt(format_args!("frag_{index}")),
Variable::WmmaFragment { id: index, frag: _ } => {
f.write_fmt(format_args!("frag_{index}"))
}
}
}
}
@ -220,9 +220,9 @@ impl Variable {
Variable::GlobalScalar(_, _, _) => true,
Variable::ConstantScalar(_, _) => true,
Variable::LocalScalar {
index: _,
id: _,
elem: _,
scope_depth: _,
depth: _,
} => true,
Variable::Id => true,
Variable::LocalInvocationIndex => true,
@ -234,9 +234,14 @@ impl Variable {
Variable::GlobalOutputArray(_, _) => false,
Variable::SharedMemory(_, _, _) => false,
Variable::Local {
index: _,
id: _,
item: _,
scope_depth: _,
depth: _,
} => false,
Variable::Slice {
id: _,
item: _,
depth: _,
} => false,
Variable::WorkgroupIdX => true,
Variable::WorkgroupIdY => true,
@ -252,7 +257,7 @@ impl Variable {
Variable::NumWorkgroupsZ => true,
Variable::LocalArray(_, _, _, _) => false,
Variable::WarpSize => true,
Variable::WmmaFragment { index: _, frag: _ } => false,
Variable::WmmaFragment { id: _, frag: _ } => false,
}
}

View File

@ -16,12 +16,16 @@ pub struct UnaryInstruction {
#[derive(Debug, Clone)]
pub enum Instruction {
ArrayLength {
Length {
input: Variable,
out: Variable,
num_inputs: usize,
num_outputs: usize,
},
SliceLength {
input: Variable,
out: Variable,
},
DeclareVariable {
var: Variable,
},
@ -58,6 +62,12 @@ pub enum Instruction {
instructions_if: Vec<Self>,
instructions_else: Vec<Self>,
},
Slice {
input: Variable,
start: Variable,
end: Variable,
out: Variable,
},
Return,
Break,
Stride {
@ -114,7 +124,7 @@ impl Display for Instruction {
Instruction::Return => f.write_str("return;"),
Instruction::Break => f.write_str("break;"),
Instruction::DeclareVariable { var } => match var {
Variable::WmmaFragment { index: _, frag } => {
Variable::WmmaFragment { id: _, frag } => {
f.write_fmt(format_args!("{frag} {var};\n"))
}
_ => {
@ -123,6 +133,16 @@ impl Display for Instruction {
}
},
Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
Instruction::Slice {
input,
start,
end,
out,
} => {
let item = out.item();
f.write_fmt(format_args!("uint {out}_length = {end} - {start};\n"))?;
f.write_fmt(format_args!("{item} *{out} = {input} + {start};\n"))
}
Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
@ -225,7 +245,10 @@ for (uint {i} = {start}; {i} < {end}; {i}++) {{
Instruction::SyncThreads => f.write_str("__syncthreads();\n"),
Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
Instruction::ArrayLength {
Instruction::SliceLength { input, out } => {
f.write_fmt(format_args!("{out} = {input}_length;\n"))
}
Instruction::Length {
input,
out,
num_inputs,

View File

@ -288,7 +288,10 @@ impl<R: JitRuntime> ElementWiseBuilder<R> {
return false;
}
let input = Variable::ConstantScalar(1.0, desc.dtype.into());
let input = Variable::ConstantScalar {
value: 1.0,
elem: desc.dtype.into(),
};
let out = self.builder.output(desc, Variable::AbsolutePos);
self.builder
@ -301,7 +304,10 @@ impl<R: JitRuntime> ElementWiseBuilder<R> {
return false;
}
let input = Variable::ConstantScalar(0.0, desc.dtype.into());
let input = Variable::ConstantScalar {
value: 0.0,
elem: desc.dtype.into(),
};
let out = self.builder.output(desc, Variable::AbsolutePos);
self.builder

View File

@ -56,9 +56,11 @@ impl TraceBuilder {
}
true => match self.output_to_local.get(&tensor.id) {
// Is a local variable.
Some(local_index) => {
Variable::Local(*local_index, Item::new(elem), self.scope.depth)
}
Some(local_index) => Variable::Local {
id: *local_index,
item: Item::new(elem),
depth: self.scope.depth,
},
// Isn't an operation output variable, so must be an existing input.
None => self
.inputs
@ -85,7 +87,11 @@ impl TraceBuilder {
// Output already registered as a local variable.
if let Some(index) = self.output_to_local.get(&tensor.id) {
return Variable::Local(*index, Item::new(elem), self.scope.depth);
return Variable::Local {
id: *index,
item: Item::new(elem),
depth: self.scope.depth,
};
}
let variable = self.scope.create_local(Item::new(elem));
@ -159,11 +165,11 @@ impl TraceBuilder {
//
// Only local variables can become outputs.
let mark = |var: &Variable, list: &mut Vec<TensorId>| {
if let Variable::Local(index, _, _) = var {
if let Variable::Local { id: id_local, .. } = var {
if let Some((id, _)) = self
.output_to_local
.iter()
.find(|(_id, position)| *position == index)
.find(|(_tensor_id, position)| *position == id_local)
{
if !list.contains(id) {
list.push(*id);
@ -201,6 +207,11 @@ impl TraceBuilder {
mark(&op.c, &mut local_tensor_ids_input);
mark(&op.out, &mut local_tensor_ids_output);
}
Operator::Slice(op) => {
mark(&op.input, &mut local_tensor_ids_input);
mark(&op.start, &mut local_tensor_ids_input);
mark(&op.out, &mut local_tensor_ids_output);
}
Operator::Max(op) => mark_binary(
op,
&mut local_tensor_ids_input,

View File

@ -68,8 +68,14 @@ impl<R: JitRuntime, EI: JitElement, EO: JitElement> Kernel for CastEagerKernel<R
let item_input = EI::cube_elem().into();
let item_output = EO::cube_elem().into();
let tensor = Variable::GlobalInputArray(0, item_input);
let output = Variable::GlobalOutputArray(0, item_output);
let tensor = Variable::GlobalInputArray {
id: 0,
item: item_input,
};
let output = Variable::GlobalOutputArray {
id: 0,
item: item_output,
};
CastShader { tensor, output }.expand(&mut scope);

View File

@ -60,8 +60,14 @@ impl<R: JitRuntime, EO: JitElement> Kernel for BoolCastEagerKernel<R, EO> {
let item_input = Item::new(Elem::Bool);
let item_output = EO::cube_elem().into();
let tensor = Variable::GlobalInputArray(0, item_input);
let output = Variable::GlobalOutputArray(0, item_output);
let tensor = Variable::GlobalInputArray {
id: 0,
item: item_input,
};
let output = Variable::GlobalOutputArray {
id: 0,
item: item_output,
};
BoolCastShader { tensor, output }.expand(&mut scope);

View File

@ -93,13 +93,34 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
cpa!(scope, kernel_size_0 = shape(weight, 2u32));
cpa!(scope, kernel_size_1 = shape(weight, 3u32));
let conv_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
let conv_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
let groups = Variable::GlobalScalar(6, Elem::UInt);
let conv_stride_0 = Variable::GlobalScalar {
id: 0,
elem: Elem::UInt,
};
let conv_stride_1 = Variable::GlobalScalar {
id: 1,
elem: Elem::UInt,
};
let dilation_0 = Variable::GlobalScalar {
id: 2,
elem: Elem::UInt,
};
let dilation_1 = Variable::GlobalScalar {
id: 3,
elem: Elem::UInt,
};
let padding_0 = Variable::GlobalScalar {
id: 4,
elem: Elem::UInt,
};
let padding_1 = Variable::GlobalScalar {
id: 5,
elem: Elem::UInt,
};
let groups = Variable::GlobalScalar {
id: 6,
elem: Elem::UInt,
};
let stride_0_i = scope.create_local(Elem::Int(IntKind::I32));
let stride_1_i = scope.create_local(Elem::Int(IntKind::I32));
@ -294,10 +315,10 @@ impl<R: JitRuntime, E: JitElement> Kernel for Conv2dTransposeEagerKernel<R, E> {
let mut scope = Scope::root();
let item = E::cube_elem().into();
let input = Variable::GlobalInputArray(0, item);
let weight = Variable::GlobalInputArray(1, item);
let bias = Variable::GlobalInputArray(2, item);
let output = Variable::GlobalOutputArray(0, item);
let input = Variable::GlobalInputArray { id: 0, item };
let weight = Variable::GlobalInputArray { id: 1, item };
let bias = Variable::GlobalInputArray { id: 2, item };
let output = Variable::GlobalOutputArray { id: 0, item };
scope.write_global_custom(output);

View File

@ -105,16 +105,46 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
cpa!(scope, kernel_size_1 = shape(weight, 3u32));
cpa!(scope, kernel_size_2 = shape(weight, 4u32));
let conv_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
let conv_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
let conv_stride_2 = Variable::GlobalScalar(2, Elem::UInt);
let dilation_0 = Variable::GlobalScalar(3, Elem::UInt);
let dilation_1 = Variable::GlobalScalar(4, Elem::UInt);
let dilation_2 = Variable::GlobalScalar(5, Elem::UInt);
let padding_0 = Variable::GlobalScalar(6, Elem::UInt);
let padding_1 = Variable::GlobalScalar(7, Elem::UInt);
let padding_2 = Variable::GlobalScalar(8, Elem::UInt);
let groups = Variable::GlobalScalar(9, Elem::UInt);
let conv_stride_0 = Variable::GlobalScalar {
id: 0,
elem: Elem::UInt,
};
let conv_stride_1 = Variable::GlobalScalar {
id: 1,
elem: Elem::UInt,
};
let conv_stride_2 = Variable::GlobalScalar {
id: 2,
elem: Elem::UInt,
};
let dilation_0 = Variable::GlobalScalar {
id: 3,
elem: Elem::UInt,
};
let dilation_1 = Variable::GlobalScalar {
id: 4,
elem: Elem::UInt,
};
let dilation_2 = Variable::GlobalScalar {
id: 5,
elem: Elem::UInt,
};
let padding_0 = Variable::GlobalScalar {
id: 6,
elem: Elem::UInt,
};
let padding_1 = Variable::GlobalScalar {
id: 7,
elem: Elem::UInt,
};
let padding_2 = Variable::GlobalScalar {
id: 8,
elem: Elem::UInt,
};
let groups = Variable::GlobalScalar {
id: 9,
elem: Elem::UInt,
};
let stride_0_i = scope.create_local(Elem::Int(IntKind::I32));
let stride_1_i = scope.create_local(Elem::Int(IntKind::I32));
@ -362,10 +392,10 @@ impl<R: JitRuntime, E: JitElement> Kernel for Conv3dTransposeEagerKernel<R, E> {
let mut scope = Scope::root();
let item = E::cube_elem().into();
let input = Variable::GlobalInputArray(0, item);
let weight = Variable::GlobalInputArray(1, item);
let bias = Variable::GlobalInputArray(2, item);
let output = Variable::GlobalOutputArray(0, item);
let input = Variable::GlobalInputArray { id: 0, item };
let weight = Variable::GlobalInputArray { id: 1, item };
let bias = Variable::GlobalInputArray { id: 2, item };
let output = Variable::GlobalOutputArray { id: 0, item };
scope.write_global_custom(output);

View File

@ -43,7 +43,10 @@ impl FlipComputeShader {
cpa!(scope, shape = shape(output, i));
cpa!(
scope,
flip = cast(Variable::GlobalScalar(i as u16, Elem::UInt))
flip = cast(Variable::GlobalScalar {
id: i as u16,
elem: Elem::UInt
})
);
cpa!(scope, flip_bool = flip == 1u32);
@ -70,8 +73,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for FlipEagerKernel<R, E> {
let mut scope = Scope::root();
let item = E::cube_elem().into();
let input = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);
let input = Variable::GlobalInputArray { id: 0, item };
let output = Variable::GlobalOutputArray { id: 0, item };
scope.write_global_custom(output);

View File

@ -27,8 +27,8 @@ struct GatherComputeShader {
impl GatherComputeShader {
pub fn expand(self, scope: &mut Scope) {
match self.tensor {
Variable::GlobalInputArray(_, _) => (),
Variable::GlobalOutputArray(_, _) => (),
Variable::GlobalInputArray { .. } => (),
Variable::GlobalOutputArray { .. } => (),
_ => panic!("Tensor variable must be an global array."),
};
@ -78,10 +78,16 @@ impl<R: JitRuntime, E: JitElement> Kernel for GatherEagerKernel<R, E> {
let item_tensor = E::cube_elem().into();
let item_indices: Item = Elem::Int(IntKind::I32).into();
let tensor = Variable::GlobalInputArray(0, item_tensor);
let tensor = Variable::GlobalInputArray {
id: 0,
item: item_tensor,
};
let indices = scope.read_array(1, item_indices, Variable::AbsolutePos);
let output_array = Variable::GlobalOutputArray(0, item_tensor);
let output_array = Variable::GlobalOutputArray {
id: 0,
item: item_tensor,
};
let output_local = scope.create_local(item_tensor);
GatherComputeShader {

View File

@ -61,8 +61,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for RepeatEagerKernel<R, E> {
let mut scope = Scope::root();
let item = E::cube_elem().into();
let input = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);
let input = Variable::GlobalInputArray { id: 0, item };
let output = Variable::GlobalOutputArray { id: 0, item };
scope.write_global_custom(output);

View File

@ -32,13 +32,13 @@ struct ScatterComputeShader {
impl ScatterComputeShader {
pub fn expand(self, scope: &mut Scope) {
match self.input {
Variable::GlobalInputArray(_, _) => (),
Variable::GlobalOutputArray(_, _) => (),
Variable::GlobalInputArray { .. } => (),
Variable::GlobalOutputArray { .. } => (),
_ => panic!("Input variable must be an global array."),
};
match self.value {
Variable::GlobalInputArray(_, _) => (),
Variable::GlobalOutputArray(_, _) => (),
Variable::GlobalInputArray { .. } => (),
Variable::GlobalOutputArray { .. } => (),
_ => panic!("Value variable must be an global array."),
};
@ -136,9 +136,18 @@ impl<R: JitRuntime, E: JitElement> Kernel for ScatterEagerKernel<R, E> {
let item_value = E::cube_elem().into();
let item_indices: Item = Elem::Int(IntKind::I32).into();
let input_output = Variable::GlobalInputArray(0, item_value);
let indices = Variable::GlobalInputArray(1, Elem::Int(IntKind::I32).into());
let value = Variable::GlobalInputArray(2, item_value);
let input_output = Variable::GlobalInputArray {
id: 0,
item: item_value,
};
let indices = Variable::GlobalInputArray {
id: 1,
item: Elem::Int(IntKind::I32).into(),
};
let value = Variable::GlobalInputArray {
id: 2,
item: item_value,
};
scope.write_global_custom(input_output);

View File

@ -73,9 +73,12 @@ impl<R: JitRuntime, E: JitElement> Kernel for SelectEagerKernel<R, E> {
let item = E::cube_elem().into();
let item_indices: Item = Elem::Int(IntKind::I32).into();
let input = Variable::GlobalInputArray(0, item);
let indices = Variable::GlobalInputArray(1, item_indices);
let output = Variable::GlobalOutputArray(0, item);
let input = Variable::GlobalInputArray { id: 0, item };
let indices = Variable::GlobalInputArray {
id: 1,
item: item_indices,
};
let output = Variable::GlobalOutputArray { id: 0, item };
scope.write_global_custom(output);

View File

@ -134,9 +134,12 @@ impl<R: JitRuntime, E: JitElement> Kernel for SelectAssignEagerKernel<R, E> {
let item = E::cube_elem().into();
let item_indices: Item = Elem::Int(IntKind::I32).into();
let tensor = Variable::GlobalInputArray(0, item);
let value = Variable::GlobalInputArray(1, item);
let indices = Variable::GlobalInputArray(2, item_indices);
let tensor = Variable::GlobalInputArray { id: 0, item };
let value = Variable::GlobalInputArray { id: 1, item };
let indices = Variable::GlobalInputArray {
id: 2,
item: item_indices,
};
scope.write_global_custom(tensor);

View File

@ -44,7 +44,10 @@ impl SliceComputeShader {
cpa!(scope, shape_output = shape(output, i));
cpa!(
scope,
range_start = cast(Variable::GlobalScalar(i as u16, Elem::UInt))
range_start = cast(Variable::GlobalScalar {
id: i as u16,
elem: Elem::UInt
})
);
cpa!(scope, offset_local = id / stride_output);
@ -66,8 +69,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for SliceEagerKernel<R, E> {
let mut scope = Scope::root();
let item = E::cube_elem().into();
let input = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);
let input = Variable::GlobalInputArray { id: 0, item };
let output = Variable::GlobalOutputArray { id: 0, item };
scope.write_global_custom(output);

View File

@ -47,7 +47,10 @@ impl SliceAssignComputeShader {
cpa!(scope, shape_input = shape(input, i));
cpa!(
scope,
range_start = cast(Variable::GlobalScalar(i as u16, Elem::UInt))
range_start = cast(Variable::GlobalScalar {
id: i as u16,
elem: Elem::UInt
})
);
cpa!(scope, offset_local = id / stride_value);
@ -75,8 +78,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for SliceAssignEagerKernel<R, E> {
let mut scope = Scope::root();
let item = E::cube_elem().into();
let input = Variable::GlobalInputArray(0, item);
let value = Variable::GlobalInputArray(1, item);
let input = Variable::GlobalInputArray { id: 0, item };
let value = Variable::GlobalInputArray { id: 1, item };
scope.write_global_custom(input);

View File

@ -371,8 +371,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for InterpolateBicubicEagerKernel<R, E
let mut scope = Scope::root();
let item = E::cube_elem().into();
let input = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);
let input = Variable::GlobalInputArray { id: 0, item };
let output = Variable::GlobalOutputArray { id: 0, item };
InterpolateBicubicShader {
input,

View File

@ -189,8 +189,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for InterpolateBilinearEagerKernel<R,
let mut scope = Scope::root();
let item = E::cube_elem().into();
let input = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);
let input = Variable::GlobalInputArray { id: 0, item };
let output = Variable::GlobalOutputArray { id: 0, item };
InterpolateBilinearShader { input, output }.expand(&mut scope);

View File

@ -127,8 +127,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for InterpolateNearestEagerKernel<R, E
let mut scope = Scope::root();
let item = E::cube_elem().into();
let input = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);
let input = Variable::GlobalInputArray { id: 0, item };
let output = Variable::GlobalOutputArray { id: 0, item };
InterpolateNearestShader {
input,

View File

@ -185,8 +185,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for InterpolateNearestBackwardEagerKer
let mut scope = Scope::root();
let item = E::cube_elem().into();
let out_grad = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);
let out_grad = Variable::GlobalInputArray { id: 0, item };
let output = Variable::GlobalOutputArray { id: 0, item };
InterpolateNearestBackwardShader {
out_grad,

View File

@ -40,7 +40,10 @@ impl MaskStrategy for MaskFill {
}
fn value_variable(value_item: Item) -> Variable {
Variable::GlobalScalar(0, value_item.elem())
Variable::GlobalScalar {
id: 0,
elem: value_item.elem(),
}
}
}
@ -65,7 +68,10 @@ impl MaskStrategy for MaskWhere {
}
fn value_variable(value_item: Item) -> Variable {
Variable::GlobalInputArray(2, value_item)
Variable::GlobalInputArray {
id: 2,
item: value_item,
}
}
}
@ -102,10 +108,19 @@ impl<M: MaskStrategy, R: JitRuntime, EI: JitElement, EM: JitElement> Kernel
let tensor_item = EI::cube_elem().into();
let mask_item = EM::cube_elem().into();
let input = Variable::GlobalInputArray(0, tensor_item);
let mask = Variable::GlobalInputArray(1, mask_item);
let input = Variable::GlobalInputArray {
id: 0,
item: tensor_item,
};
let mask = Variable::GlobalInputArray {
id: 1,
item: mask_item,
};
let value = M::value_variable(tensor_item);
let output = Variable::GlobalOutputArray(0, tensor_item);
let output = Variable::GlobalOutputArray {
id: 0,
item: tensor_item,
};
MaskShader::<EI, EM, M> {
input,
@ -176,8 +191,14 @@ impl<M: MaskStrategy, R: JitRuntime, EI: JitElement, EM: JitElement> Kernel
let tensor_item = EI::cube_elem().into();
let mask_item = EM::cube_elem().into();
let input = Variable::GlobalInputArray(0, tensor_item);
let mask = Variable::GlobalInputArray(1, mask_item);
let input = Variable::GlobalInputArray {
id: 0,
item: tensor_item,
};
let mask = Variable::GlobalInputArray {
id: 1,
item: mask_item,
};
let value = M::value_variable(tensor_item);
MaskShader::<EI, EM, M> {

View File

@ -40,9 +40,9 @@ impl<R: JitRuntime, E: JitElement> Kernel for MatmulTiling2dEagerKernel<R, E> {
);
let item = elem.into();
let lhs = Variable::GlobalInputArray(0, item);
let rhs = Variable::GlobalInputArray(1, item);
let out = Variable::GlobalOutputArray(0, item);
let lhs = Variable::GlobalInputArray { id: 0, item };
let rhs = Variable::GlobalInputArray { id: 1, item };
let out = Variable::GlobalOutputArray { id: 0, item };
scope.write_global_custom(out);

View File

@ -207,8 +207,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for AdaptiveAvgPool2dBackwardEagerKern
let mut scope = Scope::root();
let item = E::cube_elem().into();
let grad = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);
let grad = Variable::GlobalInputArray { id: 0, item };
let output = Variable::GlobalOutputArray { id: 0, item };
scope.write_global_custom(output);

View File

@ -193,8 +193,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for AdaptivePool2dEagerKernel<R, E> {
let mut scope = Scope::root();
let item = E::cube_elem().into();
let input = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);
let input = Variable::GlobalInputArray { id: 0, item };
let output = Variable::GlobalOutputArray { id: 0, item };
scope.write_global_custom(output);

View File

@ -71,10 +71,22 @@ impl AvgPool2dBackwardComputeShader {
cpa!(scope, output_shape_2 = shape(output, 2u32));
cpa!(scope, output_shape_3 = shape(output, 3u32));
let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
let pool_stride_0 = Variable::GlobalScalar {
id: 0,
elem: Elem::UInt,
};
let pool_stride_1 = Variable::GlobalScalar {
id: 1,
elem: Elem::UInt,
};
let padding_0 = Variable::GlobalScalar {
id: 4,
elem: Elem::UInt,
};
let padding_1 = Variable::GlobalScalar {
id: 5,
elem: Elem::UInt,
};
let [kernel_size_0, kernel_size_1] = self.kernel_size;
let b = scope.create_local(Elem::UInt);
@ -215,12 +227,30 @@ impl AvgPool2dBackwardComputeShader {
output_stride_2: Variable,
output_stride_3: Variable,
) -> (Variable, Variable, Variable, Variable) {
let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
let pool_stride_0 = Variable::GlobalScalar {
id: 0,
elem: Elem::UInt,
};
let pool_stride_1 = Variable::GlobalScalar {
id: 1,
elem: Elem::UInt,
};
let dilation_0 = Variable::GlobalScalar {
id: 2,
elem: Elem::UInt,
};
let dilation_1 = Variable::GlobalScalar {
id: 3,
elem: Elem::UInt,
};
let padding_0 = Variable::GlobalScalar {
id: 4,
elem: Elem::UInt,
};
let padding_1 = Variable::GlobalScalar {
id: 5,
elem: Elem::UInt,
};
let [kernel_size_0, kernel_size_1] = self.kernel_size;
@ -321,8 +351,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for AvgPool2dBackwardEagerKernel<R, E>
let mut scope = Scope::root();
let item = E::cube_elem().into();
let grad = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);
let grad = Variable::GlobalInputArray { id: 0, item };
let output = Variable::GlobalOutputArray { id: 0, item };
scope.write_global_custom(output);

View File

@ -21,7 +21,10 @@ impl<E: JitElement> PoolStrategy for MaxPool<E> {
fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator {
let max_val = scope.create_local(item);
let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), item.elem());
let max_initial = Variable::ConstantScalar {
value: E::minimum_value().to_f64(),
elem: item.elem(),
};
cpa!(scope, max_val = max_initial);
max_val
}
@ -67,7 +70,10 @@ impl<E: JitElement> PoolStrategy for MaxPoolWithIndices<E> {
fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator {
let max_val = scope.create_local(item);
let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), item.elem());
let max_initial = Variable::ConstantScalar {
value: E::minimum_value().to_f64(),
elem: item.elem(),
};
cpa!(scope, max_val = max_initial);
let max_index = scope.create_local(Elem::UInt);
(max_val, max_index)

View File

@ -161,12 +161,30 @@ impl MaxPool2dBackwardComputeShader {
output_stride_2: Variable,
output_stride_3: Variable,
) -> (Variable, Variable, Variable, Variable) {
let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
let pool_stride_0 = Variable::GlobalScalar {
id: 0,
elem: Elem::UInt,
};
let pool_stride_1 = Variable::GlobalScalar {
id: 1,
elem: Elem::UInt,
};
let dilation_0 = Variable::GlobalScalar {
id: 2,
elem: Elem::UInt,
};
let dilation_1 = Variable::GlobalScalar {
id: 3,
elem: Elem::UInt,
};
let padding_0 = Variable::GlobalScalar {
id: 4,
elem: Elem::UInt,
};
let padding_1 = Variable::GlobalScalar {
id: 5,
elem: Elem::UInt,
};
let [kernel_size_0, kernel_size_1] = self.kernel_size;
@ -267,9 +285,12 @@ impl<R: JitRuntime, E: JitElement> Kernel for MaxPool2dWithIndicesBackwardEagerK
let mut scope = Scope::root();
let item = E::cube_elem().into();
let indices = Variable::GlobalInputArray(0, Item::new(Elem::Int(IntKind::I32)));
let grad = Variable::GlobalInputArray(1, item);
let output = Variable::GlobalOutputArray(0, item);
let indices = Variable::GlobalInputArray {
id: 0,
item: Item::new(Elem::Int(IntKind::I32)),
};
let grad = Variable::GlobalInputArray { id: 1, item };
let output = Variable::GlobalOutputArray { id: 0, item };
scope.write_global_custom(output);

View File

@ -65,12 +65,30 @@ impl<P: PoolStrategy, R: JitRuntime, E: JitElement> Pool2dComputeShader<P, R, E>
cpa!(scope, output_shape_2 = shape(output, 2u32));
cpa!(scope, output_shape_3 = shape(output, 3u32));
let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
let pool_stride_0 = Variable::GlobalScalar {
id: 0,
elem: Elem::UInt,
};
let pool_stride_1 = Variable::GlobalScalar {
id: 1,
elem: Elem::UInt,
};
let dilation_0 = Variable::GlobalScalar {
id: 2,
elem: Elem::UInt,
};
let dilation_1 = Variable::GlobalScalar {
id: 3,
elem: Elem::UInt,
};
let padding_0 = Variable::GlobalScalar {
id: 4,
elem: Elem::UInt,
};
let padding_1 = Variable::GlobalScalar {
id: 5,
elem: Elem::UInt,
};
let b = scope.create_local(Elem::UInt);
let c = scope.create_local(Elem::UInt);
@ -177,13 +195,13 @@ impl<P: PoolStrategy, R: JitRuntime, E: JitElement> Kernel for Pool2dEagerKernel
let mut scope = Scope::root();
let item = E::cube_elem().into();
let input = Variable::GlobalInputArray(0, item);
let output = Variable::GlobalOutputArray(0, item);
let input = Variable::GlobalInputArray { id: 0, item };
let output = Variable::GlobalOutputArray { id: 0, item };
let indices = if P::with_indices() {
Some(Variable::GlobalOutputArray(
1,
Item::new(Elem::Int(IntKind::I32)),
))
Some(Variable::GlobalOutputArray {
id: 1,
item: Item::new(Elem::Int(IntKind::I32)),
})
} else {
None
};

View File

@ -66,17 +66,32 @@ impl<P: Prng<E>, R: JitRuntime, E: JitElement> Kernel for PrngEagerKernel<P, R,
let mut scope = Scope::root();
let item = E::cube_elem().into();
let output = Variable::GlobalOutputArray(0, item);
let output = Variable::GlobalOutputArray { id: 0, item };
let seed0 = Variable::GlobalScalar(0, Elem::UInt);
let seed1 = Variable::GlobalScalar(1, Elem::UInt);
let seed2 = Variable::GlobalScalar(2, Elem::UInt);
let seed3 = Variable::GlobalScalar(3, Elem::UInt);
let seed0 = Variable::GlobalScalar {
id: 0,
elem: Elem::UInt,
};
let seed1 = Variable::GlobalScalar {
id: 1,
elem: Elem::UInt,
};
let seed2 = Variable::GlobalScalar {
id: 2,
elem: Elem::UInt,
};
let seed3 = Variable::GlobalScalar {
id: 3,
elem: Elem::UInt,
};
let seeds = [seed0, seed1, seed2, seed3];
let mut args = Vec::<Variable>::new();
for i in 0..P::args_length() {
args.push(Variable::GlobalScalar(i as u16, item.elem()));
args.push(Variable::GlobalScalar {
id: i as u16,
elem: item.elem(),
});
}
PrngShader::<P, E>::new(output, N_VALUES_PER_THREAD, seeds, args).expand(&mut scope);

View File

@ -16,7 +16,10 @@ impl<E: JitElement> ReduceDimNaive<E> for Argmax {
) -> Self::Accumulator {
let index = scope.create_local(Elem::UInt);
let max = scope.create_local(input_item);
let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), input_item.elem());
let max_initial = Variable::ConstantScalar {
value: E::minimum_value().to_f64(),
elem: input_item.elem(),
};
cpa!(scope, max = max_initial);
(max, index)

View File

@ -17,7 +17,10 @@ impl<E: JitElement> ReduceDimNaive<E> for Argmin {
) -> Self::Accumulator {
let index = scope.create_local(Elem::UInt);
let min = scope.create_local(input_item);
let min_initial = Variable::ConstantScalar(E::maximum_value().to_f64(), input_item.elem());
let min_initial = Variable::ConstantScalar {
value: E::maximum_value().to_f64(),
elem: input_item.elem(),
};
cpa!(scope, min = min_initial);
(min, index)

View File

@ -41,8 +41,14 @@ impl<RD: ReduceDimNaive<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Kern
let item_input = EI::cube_elem().into();
let item_output = EO::cube_elem().into();
let tensor = Variable::GlobalInputArray(0, item_input);
let output = Variable::GlobalOutputArray(0, item_output);
let tensor = Variable::GlobalInputArray {
id: 0,
item: item_input,
};
let output = Variable::GlobalOutputArray {
id: 0,
item: item_output,
};
NaiveReduceDimComputeShader {
tensor,

View File

@ -18,7 +18,10 @@ impl<E: JitElement> ReduceDimShared<E> for Argmax {
let value_shared_memory = scope.create_shared(input_item, shared_memory_size);
let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size);
let max = Variable::ConstantScalar(E::minimum_value().to_f64(), input_item.elem());
let max = Variable::ConstantScalar {
value: E::minimum_value().to_f64(),
elem: input_item.elem(),
};
cpa!(scope, value_shared_memory[write_position] = max);
(value_shared_memory, index_shared_memory)
}

View File

@ -19,7 +19,10 @@ impl<E: JitElement> ReduceDimShared<E> for Argmin {
let value_shared_memory = scope.create_shared(input_item, shared_memory_size);
let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size);
let min = Variable::ConstantScalar(E::maximum_value().to_f64(), input_item.elem());
let min = Variable::ConstantScalar {
value: E::maximum_value().to_f64(),
elem: input_item.elem(),
};
cpa!(scope, value_shared_memory[write_position] = min);
(value_shared_memory, index_shared_memory)
}

View File

@ -51,8 +51,14 @@ impl<RD: ReduceDimShared<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Ker
let item_input = EI::cube_elem().into();
let item_output = EO::cube_elem().into();
let tensor = Variable::GlobalInputArray(0, item_input);
let output = Variable::GlobalOutputArray(0, item_output);
let tensor = Variable::GlobalInputArray {
id: 0,
item: item_input,
};
let output = Variable::GlobalOutputArray {
id: 0,
item: item_output,
};
// Reduce groups are elements that are aligned along the reduce dim
SharedReduceDimComputeShader {

View File

@ -9,14 +9,24 @@ pub enum Variable {
GlobalScalar(u16, Elem, cube::Elem),
ConstantScalar(f64, Elem),
Local {
index: u16,
id: u16,
item: Item,
scope_depth: u8,
depth: u8,
},
Named {
name: String,
item: Item,
is_array: bool,
},
Slice {
id: u16,
item: Item,
depth: u8,
},
LocalScalar {
index: u16,
id: u16,
elem: Elem,
scope_depth: u8,
depth: u8,
},
SharedMemory(u16, Item, u32),
LocalArray(u16, Item, u8, u32),
@ -71,9 +81,9 @@ impl Variable {
Variable::GlobalScalar(_, _, _) => true,
Variable::ConstantScalar(_, _) => true,
Variable::LocalScalar {
index: _,
id: _,
elem: _,
scope_depth: _,
depth: _,
} => true,
Variable::Id => true,
Variable::LocalInvocationIndex => true,
@ -85,11 +95,9 @@ impl Variable {
Variable::GlobalOutputArray(_, _) => false,
Variable::SharedMemory(_, _, _) => false,
Variable::LocalArray(_, _, _, _) => false,
Variable::Local {
index: _,
item: _,
scope_depth: _,
} => false,
Variable::Local { .. } => false,
Variable::Named { .. } => false,
Variable::Slice { .. } => false,
Variable::WorkgroupIdX => true,
Variable::WorkgroupIdY => true,
Variable::WorkgroupIdZ => true,
@ -121,11 +129,9 @@ impl Variable {
Self::GlobalOutputArray(_, e) => *e,
Self::SharedMemory(_, e, _) => *e,
Self::LocalArray(_, e, _, _) => *e,
Self::Local {
index: _,
item,
scope_depth: _,
} => *item,
Self::Local { item, .. } => *item,
Self::Slice { item, .. } => *item,
Self::Named { item, .. } => *item,
Self::ConstantScalar(_, e) => Item::Scalar(*e),
Self::GlobalScalar(_, e, _) => Item::Scalar(*e),
Self::Id => Item::Scalar(Elem::U32),
@ -134,11 +140,7 @@ impl Variable {
Self::LocalInvocationIdY => Item::Scalar(Elem::U32),
Self::LocalInvocationIdZ => Item::Scalar(Elem::U32),
Self::Rank => Item::Scalar(Elem::U32),
Self::LocalScalar {
index: _,
elem,
scope_depth: _,
} => Item::Scalar(*elem),
Self::LocalScalar { elem, .. } => Item::Scalar(*elem),
Self::WorkgroupId => Item::Scalar(Elem::U32),
Self::WorkgroupIdX => Item::Scalar(Elem::U32),
Self::WorkgroupIdY => Item::Scalar(Elem::U32),
@ -222,15 +224,21 @@ impl Display for Variable {
f.write_fmt(format_args!("input_{number}_global"))
}
Variable::LocalScalar {
index,
id: index,
elem: _,
scope_depth,
depth: scope_depth,
} => f.write_fmt(format_args!("s_{scope_depth}_{index}")),
Variable::Local {
index,
id: index,
item: _,
scope_depth,
depth: scope_depth,
} => f.write_fmt(format_args!("l_{scope_depth}_{index}")),
Variable::Named { name, .. } => f.write_str(name),
Variable::Slice {
id: index,
item: _,
depth: scope_depth,
} => f.write_fmt(format_args!("slice_{scope_depth}_{index}")),
Variable::GlobalOutputArray(number, _) => {
f.write_fmt(format_args!("output_{number}_global"))
}

View File

@ -127,43 +127,53 @@ impl WgslCompiler {
fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable {
match value {
cube::Variable::GlobalInputArray(index, item) => {
wgsl::Variable::GlobalInputArray(index, Self::compile_item(item))
cube::Variable::GlobalInputArray { id, item } => {
wgsl::Variable::GlobalInputArray(id, Self::compile_item(item))
}
cube::Variable::GlobalScalar(index, elem) => {
wgsl::Variable::GlobalScalar(index, Self::compile_elem(elem), elem)
cube::Variable::GlobalScalar { id, elem } => {
wgsl::Variable::GlobalScalar(id, Self::compile_elem(elem), elem)
}
cube::Variable::Local(index, item, scope_depth) => wgsl::Variable::Local {
index,
cube::Variable::Local { id, item, depth } => wgsl::Variable::Local {
id,
item: Self::compile_item(item),
scope_depth,
depth,
},
cube::Variable::LocalScalar(index, elem, scope_depth) => wgsl::Variable::LocalScalar {
index,
cube::Variable::Slice { id, item, depth } => wgsl::Variable::Slice {
id,
item: Self::compile_item(item),
depth,
},
cube::Variable::LocalScalar { id, elem, depth } => wgsl::Variable::LocalScalar {
id,
elem: Self::compile_elem(elem),
scope_depth,
depth,
},
cube::Variable::GlobalOutputArray(index, item) => {
wgsl::Variable::GlobalOutputArray(index, Self::compile_item(item))
cube::Variable::GlobalOutputArray { id, item } => {
wgsl::Variable::GlobalOutputArray(id, Self::compile_item(item))
}
cube::Variable::ConstantScalar(index, elem) => {
wgsl::Variable::ConstantScalar(index, Self::compile_elem(elem))
cube::Variable::ConstantScalar { value, elem } => {
wgsl::Variable::ConstantScalar(value, Self::compile_elem(elem))
}
cube::Variable::SharedMemory(index, item, size) => {
cube::Variable::SharedMemory { id, item, length } => {
let item = Self::compile_item(item);
if !self.shared_memories.iter().any(|s| s.index == index) {
if !self.shared_memories.iter().any(|s| s.index == id) {
self.shared_memories
.push(SharedMemory::new(index, item, size));
.push(SharedMemory::new(id, item, length));
}
wgsl::Variable::SharedMemory(index, item, size)
wgsl::Variable::SharedMemory(id, item, length)
}
cube::Variable::LocalArray(index, item, scope_depth, size) => {
cube::Variable::LocalArray {
id,
item,
depth,
length,
} => {
let item = Self::compile_item(item);
if !self.local_arrays.iter().any(|s| s.index == index) {
if !self.local_arrays.iter().any(|s| s.index == id) {
self.local_arrays
.push(LocalArray::new(index, item, scope_depth, size));
.push(LocalArray::new(id, item, depth, length));
}
wgsl::Variable::LocalArray(index, item, scope_depth, size)
wgsl::Variable::LocalArray(id, item, depth, length)
}
cube::Variable::AbsolutePos => {
self.id = true;
@ -241,7 +251,7 @@ impl WgslCompiler {
wgsl::Variable::NumWorkgroups
}
cube::Variable::SubcubeDim => wgsl::Variable::SubgroupSize,
cube::Variable::Matrix(_, _) => {
cube::Variable::Matrix { .. } => {
panic!("Cooperative matrix-multiply and accumulate not supported.")
}
}
@ -252,6 +262,11 @@ impl WgslCompiler {
let processing = value.process();
for var in processing.variables {
// We don't declare slices.
if let cube::Variable::Slice { .. } = var {
continue;
}
instructions.push(wgsl::Instruction::DeclareVariable {
var: self.compile_variable(var),
});
@ -427,8 +442,8 @@ impl WgslCompiler {
cube::Metadata::Stride { dim, var, out } => {
self.stride = true;
let position = match var {
cube::Variable::GlobalInputArray(idx, _) => idx as usize,
cube::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
cube::Variable::GlobalInputArray { id, .. } => id as usize,
cube::Variable::GlobalOutputArray { id, .. } => self.num_inputs + id as usize,
_ => panic!("Only Input and Output have a stride, got: {:?}", var),
};
wgsl::Instruction::Stride {
@ -440,8 +455,8 @@ impl WgslCompiler {
cube::Metadata::Shape { dim, var, out } => {
self.shape = true;
let position = match var {
cube::Variable::GlobalInputArray(idx, _) => idx as usize,
cube::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
cube::Variable::GlobalInputArray { id, .. } => id as usize,
cube::Variable::GlobalOutputArray { id, .. } => self.num_inputs + id as usize,
_ => panic!("Only Input and Output have a shape, got {:?}", var),
};
wgsl::Instruction::Shape {
@ -450,7 +465,7 @@ impl WgslCompiler {
out: self.compile_variable(out),
}
}
cube::Metadata::ArrayLength { var, out } => wgsl::Instruction::ArrayLength {
cube::Metadata::Length { var, out } => wgsl::Instruction::Length {
out: self.compile_variable(out),
var: self.compile_variable(var),
},
@ -652,6 +667,12 @@ impl WgslCompiler {
rhs: self.compile_variable(op.rhs),
out: self.compile_variable(op.out),
},
cube::Operator::Slice(op) => wgsl::Instruction::Slice {
input: self.compile_variable(op.input),
start: self.compile_variable(op.start),
end: self.compile_variable(op.end),
out: self.compile_variable(op.out),
},
}
}

View File

@ -1,6 +1,6 @@
use super::{
base::{Item, Variable},
IndexedVariable, Subgroup,
Elem, IndexedVariable, Subgroup,
};
use std::fmt::Display;
@ -167,7 +167,7 @@ pub enum Instruction {
position: usize,
out: Variable,
},
ArrayLength {
Length {
var: Variable,
out: Variable,
},
@ -232,6 +232,12 @@ pub enum Instruction {
rhs: Variable,
out: Variable,
},
Slice {
input: Variable,
start: Variable,
end: Variable,
out: Variable,
},
Subgroup(Subgroup),
}
@ -245,6 +251,16 @@ impl Display for Instruction {
Instruction::Add { lhs, rhs, out } => {
f.write_fmt(format_args!("{out} = {lhs} + {rhs};\n"))
}
Instruction::Slice {
input,
start,
end,
out,
} => {
f.write_fmt(format_args!("let {out}_offset = {start};\n"))?;
f.write_fmt(format_args!("let {out}_length = {end} - {start};\n"))?;
f.write_fmt(format_args!("let {out}_ptr = &{input};\n"))
}
Instruction::Fma { a, b, c, out } => {
f.write_fmt(format_args!("{out} = fma({a}, {b}, {c});\n"))
}
@ -261,10 +277,22 @@ impl Display for Instruction {
f.write_fmt(format_args!("{out} = {lhs} || {rhs};\n"))
}
Instruction::Not { input, out } => f.write_fmt(format_args!("{out} = !{input};\n")),
Instruction::Index { lhs, rhs, out } => {
let item = out.item();
f.write_fmt(format_args!("{out} = {item}({lhs}[{rhs}]);\n"))
}
Instruction::Index { lhs, rhs, out } => match lhs {
Variable::Slice { item, .. } => {
let offset = Variable::Named {
name: format!("{lhs}_offset"),
item: Item::Scalar(Elem::U32),
is_array: false,
};
let lhs = Variable::Named {
name: format!("(*{lhs}_ptr)"),
item: *item,
is_array: true,
};
index(f, &lhs, rhs, out, Some(offset))
}
_ => index(f, lhs, rhs, out, None),
},
Instruction::Modulo { lhs, rhs, out } => {
f.write_fmt(format_args!("{out} = {lhs} % {rhs};\n"))
}
@ -406,88 +434,24 @@ for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{
f.write_str("}\n")
}
Instruction::IndexAssign { lhs, rhs, out } => match lhs.item() {
Item::Vec4(elem) => {
let lhs0 = lhs.index(0);
let lhs1 = lhs.index(1);
let lhs2 = lhs.index(2);
let lhs3 = lhs.index(3);
Instruction::IndexAssign { lhs, rhs, out } => {
if let Variable::Slice { item, .. } = out {
let offset = Variable::Named {
name: format!("{out}_offset"),
item: Item::Scalar(Elem::U32),
is_array: false,
};
let out = Variable::Named {
name: format!("(*{out}_ptr)"),
item: *item,
is_array: true,
};
let rhs0 = rhs.index(0);
let rhs1 = rhs.index(1);
let rhs2 = rhs.index(2);
let rhs3 = rhs.index(3);
f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?;
f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n"))?;
f.write_fmt(format_args!("{out}[{lhs3}] = {elem}({rhs3});\n"))
index_assign(f, lhs, rhs, &out, Some(offset))
} else {
index_assign(f, lhs, rhs, out, None)
}
Item::Vec3(elem) => {
let lhs0 = lhs.index(0);
let lhs1 = lhs.index(1);
let lhs2 = lhs.index(2);
let rhs0 = rhs.index(0);
let rhs1 = rhs.index(1);
let rhs2 = rhs.index(2);
f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?;
f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n"))
}
Item::Vec2(elem) => {
let lhs0 = lhs.index(0);
let lhs1 = lhs.index(1);
let rhs0 = rhs.index(0);
let rhs1 = rhs.index(1);
f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))
}
Item::Scalar(_elem) => {
let is_array = matches!(
out,
Variable::GlobalInputArray(_, _)
| Variable::GlobalOutputArray(_, _)
| Variable::SharedMemory(_, _, _)
| Variable::LocalArray(_, _, _, _)
);
if !is_array {
let elem_out = out.elem();
let casting_type = match rhs.item() {
Item::Vec4(_) => Item::Vec4(elem_out),
Item::Vec3(_) => Item::Vec3(elem_out),
Item::Vec2(_) => Item::Vec2(elem_out),
Item::Scalar(_) => Item::Scalar(elem_out),
};
f.write_fmt(format_args!("{out}[{lhs}] = {casting_type}({rhs});\n"))
} else {
let item_rhs = rhs.item();
let item_out = out.item();
let vectorization_factor = item_out.vectorization_factor();
if vectorization_factor > item_rhs.vectorization_factor() {
let casting_type = item_out.elem();
f.write_fmt(format_args!("{out}[{lhs}] = vec{vectorization_factor}("))?;
for i in 0..vectorization_factor {
let value = rhs.index(i);
f.write_fmt(format_args!("{casting_type}({value})"))?;
if i < vectorization_factor - 1 {
f.write_str(",")?;
}
}
f.write_str(");\n")
} else {
let casting_type = item_out;
f.write_fmt(format_args!("{out}[{lhs}] = {casting_type}({rhs});\n"))
}
}
}
},
}
Instruction::If { cond, instructions } => {
f.write_fmt(format_args!("if {cond} {{\n"))?;
for i in instructions {
@ -513,9 +477,10 @@ for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{
Instruction::Return => f.write_str("return;\n"),
Instruction::Break => f.write_str("break;\n"),
Instruction::WorkgroupBarrier => f.write_str("workgroupBarrier();\n"),
Instruction::ArrayLength { var, out } => {
f.write_fmt(format_args!("{out} = arrayLength(&{var});\n"))
}
Instruction::Length { var, out } => match var {
Variable::Slice { .. } => f.write_fmt(format_args!("{out} = {var}_length;\n")),
_ => f.write_fmt(format_args!("{out} = arrayLength(&{var});\n")),
},
Instruction::Loop { instructions } => {
f.write_fmt(format_args!("loop {{\n"))?;
for i in instructions {
@ -623,3 +588,140 @@ fn unroll<
}
Ok(())
}
/// Display helper that renders one component of an index variable,
/// optionally followed by ` + <offset>`.
struct IndexOffset {
/// Variable whose component `index` is rendered.
var: Variable,
/// Optional offset variable; when present its component `index` is appended after ` + `.
offset: Option<Variable>,
/// Component passed to `Variable::index` for both `var` and `offset`.
index: usize,
}
impl IndexOffset {
    /// Builds an `IndexOffset` owning clones of `var` and of `offset` (if any),
    /// together with the component `index` to render.
    fn new(var: &Variable, offset: &Option<Variable>, index: usize) -> Self {
        let var = var.clone();
        let offset = offset.as_ref().cloned();
        Self { var, offset, index }
    }
}
impl Display for IndexOffset {
    /// Writes `<var component>` or, when an offset is set,
    /// `<var component> + <offset component>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let component = self.var.index(self.index);
        if let Some(base) = &self.offset {
            let base = base.index(self.index);
            write!(f, "{component} + {base}")
        } else {
            write!(f, "{component}")
        }
    }
}
/// Writes the statement `out = <item>(lhs[rhs]);` where `<item>` is `out.item()`.
///
/// When `offset` is provided, the index expression becomes `rhs + offset`.
fn index(
    f: &mut std::fmt::Formatter<'_>,
    lhs: &Variable,
    rhs: &Variable,
    out: &Variable,
    offset: Option<Variable>,
) -> core::fmt::Result {
    let item = out.item();

    if let Some(offset) = offset {
        return write!(f, "{out} = {item}({lhs}[{rhs} + {offset}]);\n");
    }

    write!(f, "{out} = {item}({lhs}[{rhs}]);\n")
}
/// Writes the statement(s) for an indexed assignment `out[lhs] = rhs`.
///
/// * When `lhs.item()` is a `Vec2`/`Vec3`/`Vec4`, one statement is written per
///   component: `out[lhs_i (+ offset)] = <elem>(rhs_i);`, where `<elem>` is the
///   element carried by `lhs.item()`.
/// * When `lhs.item()` is a scalar and `out` is not one of the array-like
///   variables, a single `out[lhs] = <cast>(rhs);` is written; `offset` is not
///   used on this path.
/// * When `lhs.item()` is a scalar and `out` is array-like, the index is
///   `lhs (+ offset)`; if `out`'s item is wider than `rhs`'s, the value is
///   built component by component inside a `vecN(...)` constructor.
fn index_assign(
    f: &mut std::fmt::Formatter<'_>,
    lhs: &Variable,
    rhs: &Variable,
    out: &Variable,
    offset: Option<Variable>,
) -> core::fmt::Result {
    // Vector index: unroll into one assignment per component.
    let vectorized = match lhs.item() {
        Item::Vec4(elem) => Some((4, elem)),
        Item::Vec3(elem) => Some((3, elem)),
        Item::Vec2(elem) => Some((2, elem)),
        Item::Scalar(_) => None,
    };
    if let Some((lanes, elem)) = vectorized {
        // NOTE(review): the cast uses the element type of the index (`lhs`),
        // not that of `out` — confirm this is intended.
        for lane in 0..lanes {
            let position = IndexOffset::new(lhs, &offset, lane);
            let value = rhs.index(lane);
            write!(f, "{out}[{position}] = {elem}({value});\n")?;
        }
        return Ok(());
    }

    // Scalar index from here on.
    let out_is_array = matches!(
        out,
        Variable::GlobalInputArray(_, _)
            | Variable::GlobalOutputArray(_, _)
            | Variable::SharedMemory(_, _, _)
            | Variable::Slice { .. }
            | Variable::LocalArray(_, _, _, _)
            | Variable::Named { is_array: true, .. }
    );

    if !out_is_array {
        // Cast to `out`'s element type while keeping `rhs`'s vector width.
        let elem_out = out.elem();
        let casting_type = match rhs.item() {
            Item::Vec4(_) => Item::Vec4(elem_out),
            Item::Vec3(_) => Item::Vec3(elem_out),
            Item::Vec2(_) => Item::Vec2(elem_out),
            Item::Scalar(_) => Item::Scalar(elem_out),
        };
        return write!(f, "{out}[{lhs}] = {casting_type}({rhs});\n");
    }

    let position = IndexOffset::new(lhs, &offset, 0);
    let item_out = out.item();
    let width = item_out.vectorization_factor();

    if width <= rhs.item().vectorization_factor() {
        // `rhs` is at least as wide as `out`: a single cast is enough.
        return write!(f, "{out}[{position}] = {item_out}({rhs});\n");
    }

    // `out` is wider than `rhs`: build the vector one component at a time.
    let casting_type = item_out.elem();
    write!(f, "{out}[{position}] = vec{width}(")?;
    for i in 0..width {
        if i > 0 {
            f.write_str(",")?;
        }
        let value = rhs.index(i);
        write!(f, "{casting_type}({value})")?;
    }
    f.write_str(");\n")
}