mirror of https://github.com/tracel-ai/burn.git

commit 01e5c39ae4
Merge branch 'main' into refactor/cube/expand
@@ -423,17 +423,24 @@ impl KernelIntegrator {
} else {
item
};
let elem_adapted = bool_item(item);
let item_adapted = bool_item(item);

self.output_bindings.push(Binding {
item: elem_adapted,
item: item_adapted,
visibility: Visibility::ReadWrite,
location: Location::Storage,
size: None,
});
self.expansion.scope.write_global(
Variable::Local(local, item, self.expansion.scope.depth),
Variable::GlobalOutputArray(index, elem_adapted),
Variable::Local {
id: local,
item,
depth: self.expansion.scope.depth,
},
Variable::GlobalOutputArray {
id: index,
item: item_adapted,
},
position,
);
index += 1;

@@ -451,8 +458,15 @@ impl KernelIntegrator {
};

self.expansion.scope.write_global(
Variable::Local(local, item, self.expansion.scope.depth),
Variable::GlobalInputArray(input, bool_item(item)),
Variable::Local {
id: local,
item,
depth: self.expansion.scope.depth,
},
Variable::GlobalInputArray {
id: input,
item: bool_item(item),
},
position,
);
}
|
@@ -27,11 +27,11 @@ where

if unroll {
let start = match start.deref() {
Variable::ConstantScalar(val, _) => *val as usize,
Variable::ConstantScalar { value, .. } => *value as usize,
_ => panic!("Only constant start can be unrolled."),
};
let end = match end.deref() {
Variable::ConstantScalar(val, _) => *val as usize,
Variable::ConstantScalar { value, .. } => *value as usize,
_ => panic!("Only constant end can be unrolled."),
};
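The unroll path above extracts the same compile-time constant twice, once for `start` and once for `end`. Below is a minimal, self-contained sketch of how that pattern could be shared; the `Variable` enum here is a pared-down stand-in for the IR enum touched in this commit, and `unroll_bound` is a hypothetical helper that is not part of the diff.

// Illustrative sketch only: stand-in types, hypothetical helper.
#[derive(Debug, Clone, Copy)]
enum Variable {
    ConstantScalar { value: f64 },
    Other,
}

// Both `start` and `end` repeat the same extraction; a shared helper keeps
// the cast and the panic message in one place.
fn unroll_bound(var: &Variable, name: &str) -> usize {
    match var {
        Variable::ConstantScalar { value } => *value as usize,
        _ => panic!("Only constant {name} can be unrolled."),
    }
}

fn main() {
    let start = Variable::ConstantScalar { value: 0.0 };
    let end = Variable::ConstantScalar { value: 8.0 };
    assert_eq!(unroll_bound(&start, "start"), 0);
    assert_eq!(unroll_bound(&end, "end"), 8);
    let _ = Variable::Other; // keep the fallback variant used in this sketch
}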
@@ -15,14 +15,14 @@
//! 16,
//! 16,
//! 16,
//! cmma::MatrixLayout::ColMajor,
//! cmma::MatrixLayout::RowMajor,
//! );
//! let b = cmma::Matrix::<F16>::new(
//! cmma::MatrixIdent::B,
//! 16,
//! 16,
//! 16,
//! cmma::MatrixLayout::RowMajor,
//! cmma::MatrixLayout::ColMajor,
//! );
//! let c = cmma::Matrix::<F32>::new(
//! cmma::MatrixIdent::Accumulator,

@@ -32,12 +32,17 @@
//! cmma::MatrixLayout::Undefined,
//! );
//! cmma::fill::<F32>(&c, F32::new(0.0));
//! cmma::load::<F16>(&a, lhs, UInt::new(16));
//! cmma::load::<F16>(&b, rhs, UInt::new(16));
//! cmma::load::<F16>(&a, lhs.as_slice(), UInt::new(16));
//! cmma::load::<F16>(&b, rhs.as_slice(), UInt::new(16));
//!
//! cmma::execute::<F16, F16, F32, F32>(&a, &b, &c, &c);
//!
//! cmma::store::<F32>(out, &c, UInt::new(16), cmma::MatrixLayout::RowMajor);
//! cmma::store::<F32>(
//! out.as_slice_mut(),
//! &c,
//! UInt::new(16),
//! cmma::MatrixLayout::RowMajor,
//! );
//! }
//! ```

@@ -49,7 +54,8 @@ use crate::{
};

use super::{
Array, CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, UInt,
CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, Slice, SliceMut,
UInt,
};

pub use ir::{MatrixIdent, MatrixLayout};

@@ -142,7 +148,7 @@ pub mod fill {

/// Load the matrix with the provided array using the stride.
#[allow(unused_variables)]
pub fn load<C: CubeType>(mat: &Matrix<C>, value: &Array<C>, stride: UInt) {
pub fn load<C: CubeType>(mat: &Matrix<C>, value: &Slice<'_, C>, stride: UInt) {
unexpanded!()
}

@@ -155,7 +161,7 @@ pub mod load {
pub fn __expand<C: CubeType>(
context: &mut CubeContext,
mat: MatrixExpand,
value: ExpandElementTyped<Array<C>>,
value: ExpandElementTyped<Slice<'static, C>>,
stride: ExpandElement,
) {
context.register(Operation::CoopMma(ir::CoopMma::Load {

@@ -169,7 +175,7 @@ pub mod load {
/// Store the matrix in the given array following the given stride and layout.
#[allow(unused_variables)]
pub fn store<C: CubePrimitive>(
output: &Array<C>,
output: &mut SliceMut<'_, C>,
mat: &Matrix<C>,
stride: UInt,
layout: MatrixLayout,

@@ -185,7 +191,7 @@ pub mod store {
#[allow(unused_variables)]
pub fn __expand<C: CubePrimitive>(
context: &mut CubeContext,
output: ExpandElementTyped<Array<C>>,
output: ExpandElementTyped<SliceMut<'static, C>>,
mat: MatrixExpand,
stride: ExpandElement,
layout: MatrixLayout,
@@ -4,8 +4,6 @@ use alloc::rc::Rc;
use core::cell::RefCell;
use std::collections::HashMap;

use super::{CubePrimitive, SharedMemoryExpand};

#[derive(Default, Clone)]
pub struct VariablePool {
map: Rc<RefCell<HashMap<Item, Vec<ExpandElement>>>>,

@@ -114,14 +112,14 @@ impl CubeContext {
ExpandElement::Plain(variable)
}

pub fn create_shared<T: CubePrimitive>(
&mut self,
item: Item,
size: u32,
) -> SharedMemoryExpand<T> {
SharedMemoryExpand {
val: ExpandElement::Plain(self.root.borrow_mut().create_shared(item, size)),
}
/// Create a new slice element.
pub fn create_slice(&mut self, item: Item) -> ExpandElement {
let variable = self.scope.borrow_mut().create_slice(item);
ExpandElement::Plain(variable)
}

pub fn create_shared(&mut self, item: Item, size: u32) -> ExpandElement {
ExpandElement::Plain(self.root.borrow_mut().create_shared(item, size))
}

pub fn create_local_array(&mut self, item: Item, size: u32) -> ExpandElement {

@@ -129,19 +127,19 @@ impl CubeContext {
}

/// Obtain the index-th input
pub fn input(&mut self, index: u16, item: Item) -> ExpandElement {
ExpandElement::Plain(crate::ir::Variable::GlobalInputArray(index, item))
pub fn input(&mut self, id: u16, item: Item) -> ExpandElement {
ExpandElement::Plain(crate::ir::Variable::GlobalInputArray { id, item })
}

/// Obtain the index-th output
pub fn output(&mut self, index: u16, item: Item) -> ExpandElement {
let var = crate::ir::Variable::GlobalOutputArray(index, item);
pub fn output(&mut self, id: u16, item: Item) -> ExpandElement {
let var = crate::ir::Variable::GlobalOutputArray { id, item };
self.scope.borrow_mut().write_global_custom(var);
ExpandElement::Plain(var)
}

/// Obtain the index-th scalar
pub fn scalar(&self, index: u16, elem: Elem) -> ExpandElement {
ExpandElement::Plain(crate::ir::Variable::GlobalScalar(index, elem))
pub fn scalar(&self, id: u16, elem: Elem) -> ExpandElement {
ExpandElement::Plain(crate::ir::Variable::GlobalScalar { id, elem })
}
}
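The hunks above show the migration that recurs throughout this commit: `Variable`'s tuple variants become struct variants with named fields. A minimal stand-alone sketch of the before/after call-site pattern is below; the enum, its fields, and `id_of` are illustrative stand-ins, not the crate's real definitions.

// Illustrative sketch only: why named-field variants read better at call sites.
mod old {
    #[allow(dead_code)]
    pub enum Variable {
        GlobalInputArray(u16, u32), // (id, item) encoded positionally
    }
    pub fn id_of(var: &Variable) -> u16 {
        match var {
            Variable::GlobalInputArray(id, _) => *id, // reader must recall field order
        }
    }
}

mod new {
    #[allow(dead_code)]
    pub enum Variable {
        GlobalInputArray { id: u16, item: u32 },
    }
    pub fn id_of(var: &Variable) -> u16 {
        match var {
            // Named fields plus `..` keep matches readable and tolerant of new fields.
            Variable::GlobalInputArray { id, .. } => *id,
        }
    }
}

fn main() {
    assert_eq!(old::id_of(&old::Variable::GlobalInputArray(3, 7)), 3);
    assert_eq!(new::id_of(&new::Variable::GlobalInputArray { id: 3, item: 7 }), 3);
}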
|
@@ -40,7 +40,7 @@ impl<T: CubePrimitive + Clone> Array<T> {
) -> <Self as CubeType>::ExpandType {
let size = size.value();
let size = match size {
crate::ir::Variable::ConstantScalar(val, _) => val as u32,
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
_ => panic!("Array need constant initialization value"),
};
context

@@ -55,7 +55,7 @@ impl<T: CubePrimitive + Clone> Array<T> {
) -> <Self as CubeType>::ExpandType {
let size = size.value();
let size = match size {
crate::ir::Variable::ConstantScalar(val, _) => val as u32,
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
_ => panic!("Shared memory need constant initialization value"),
};
context
|
@@ -160,7 +160,7 @@ impl ExpandElement {
pub fn can_mut(&self) -> bool {
match self {
ExpandElement::Managed(var) => {
if let Variable::Local(_, _, _) = var.as_ref() {
if let Variable::Local { .. } = var.as_ref() {
Rc::strong_count(var) <= 2
} else {
false

@@ -201,10 +201,10 @@ impl Init for ExpandElement {
let mut init = |elem: Self| init_expand(context, elem, Operator::Assign);

match *self {
Variable::GlobalScalar(_, _) => init(self),
Variable::LocalScalar(_, _, _) => init(self),
Variable::ConstantScalar(_, _) => init(self),
Variable::Local(_, _, _) => init(self),
Variable::GlobalScalar { .. } => init(self),
Variable::LocalScalar { .. } => init(self),
Variable::ConstantScalar { .. } => init(self),
Variable::Local { .. } => init(self),
// Constant should be initialized since the new variable can be mutated afterward.
// And it is assumed those values are cloned.
Variable::Rank

@@ -230,11 +230,12 @@ impl Init for ExpandElement {
| Variable::AbsolutePosY
| Variable::AbsolutePosZ => init(self),
// Array types can't be copied, so we should simply return the same variable.
Variable::SharedMemory(_, _, _)
| Variable::GlobalInputArray(_, _)
| Variable::GlobalOutputArray(_, _)
| Variable::LocalArray(_, _, _, _)
| Variable::Matrix(_, _) => self,
Variable::SharedMemory { .. }
| Variable::GlobalInputArray { .. }
| Variable::GlobalOutputArray { .. }
| Variable::LocalArray { .. }
| Variable::Slice { .. }
| Variable::Matrix { .. } => self,
}
}
}
|
@@ -41,9 +41,9 @@ impl_into_expand_element!(i64);
/// Useful for Comptime
impl From<UInt> for ExpandElement {
fn from(value: UInt) -> Self {
ExpandElement::Plain(crate::ir::Variable::ConstantScalar(
value.val as f64,
UInt::as_elem(),
))
ExpandElement::Plain(crate::ir::Variable::ConstantScalar {
value: value.val as f64,
elem: UInt::as_elem(),
})
}
}
|
@@ -93,7 +93,10 @@ macro_rules! impl_float {
_context: &mut CubeContext,
val: f32,
) -> <Self as CubeType>::ExpandType {
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
let new_var = Variable::ConstantScalar {
value: val as f64,
elem: Self::as_elem(),
};
ExpandElement::Plain(new_var)
}

@@ -63,7 +63,10 @@ macro_rules! impl_int {
_context: &mut CubeContext,
val: i64,
) -> <Self as CubeType>::ExpandType {
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
let new_var = Variable::ConstantScalar {
value: val as f64,
elem: Self::as_elem(),
};
ExpandElement::Plain(new_var)
}
|
@@ -7,6 +7,7 @@ mod float;
mod int;
mod numeric;
mod shared_memory;
mod slice;
mod tensor;
mod uint;
mod vectorized;

@@ -20,6 +21,7 @@ pub use float::*;
pub use int::*;
pub use numeric::*;
pub use shared_memory::*;
pub use slice::*;
pub use tensor::*;
pub use uint::*;
pub use vectorized::*;
|
@@ -50,7 +50,10 @@ pub trait Numeric:
}

fn __expand_from_int(_context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType {
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
let new_var = Variable::ConstantScalar {
value: val as f64,
elem: Self::as_elem(),
};
ExpandElement::Plain(new_var)
}
@@ -5,32 +5,21 @@ use crate::{
ir::Item,
};

use super::{ExpandElement, Init, UInt};
use super::{ExpandElementTyped, Init, UInt};

#[derive(Clone, Copy)]
pub struct SharedMemory<T: CubeType> {
_val: PhantomData<T>,
}

#[derive(Clone)]
pub struct SharedMemoryExpand<T: CubePrimitive> {
pub val: <T as CubeType>::ExpandType,
}

impl<T: CubePrimitive> From<SharedMemoryExpand<T>> for ExpandElement {
fn from(shared_memory_expand: SharedMemoryExpand<T>) -> Self {
shared_memory_expand.val
}
}

impl<T: CubePrimitive> Init for SharedMemoryExpand<T> {
impl<T: CubePrimitive> Init for ExpandElementTyped<SharedMemory<T>> {
fn init(self, _context: &mut CubeContext) -> Self {
self
}
}

impl<T: CubePrimitive> CubeType for SharedMemory<T> {
type ExpandType = SharedMemoryExpand<T>;
type ExpandType = ExpandElementTyped<SharedMemory<T>>;
}

impl<T: CubePrimitive + Clone> SharedMemory<T> {

@@ -49,13 +38,14 @@ impl<T: CubePrimitive + Clone> SharedMemory<T> {
) -> <Self as CubeType>::ExpandType {
let size = size.value();
let size = match size {
crate::ir::Variable::ConstantScalar(val, _) => val as u32,
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
_ => panic!("Shared memory need constant initialization value"),
};
context.create_shared(
let var = context.create_shared(
Item::vectorized(T::as_elem(), vectorization_factor.val as u8),
size,
)
);
ExpandElementTyped::new(var)
}

pub fn __expand_new<S: Index>(

@@ -64,9 +54,10 @@ impl<T: CubePrimitive + Clone> SharedMemory<T> {
) -> <Self as CubeType>::ExpandType {
let size = size.value();
let size = match size {
crate::ir::Variable::ConstantScalar(val, _) => val as u32,
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
_ => panic!("Shared memory need constant initialization value"),
};
context.create_shared(Item::new(T::as_elem()), size)
let var = context.create_shared(Item::new(T::as_elem()), size);
ExpandElementTyped::new(var)
}
}
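In the hunks above, the bespoke `SharedMemoryExpand<T>` wrapper is replaced by the shared `ExpandElementTyped<SharedMemory<T>>`. The sketch below shows the general idea of one typed wrapper over an untyped expand handle; every name here is a stand-in for illustration, not the actual burn-cube definition.

// Illustrative sketch only: a single generic typed wrapper replaces per-type expand structs.
use std::marker::PhantomData;

#[derive(Clone, Debug, PartialEq)]
struct ExpandElement(u16); // untyped IR handle (just an id in this sketch)

#[derive(Clone)]
struct ExpandElementTyped<T> {
    expand: ExpandElement,
    _ty: PhantomData<T>,
}

impl<T> ExpandElementTyped<T> {
    fn new(expand: ExpandElement) -> Self {
        Self { expand, _ty: PhantomData }
    }
}

#[allow(dead_code)]
struct SharedMemory<T>(PhantomData<T>); // front-end marker type

fn main() {
    // One wrapper works for every front-end type, so dedicated expand structs
    // (and their From/Init impls) are no longer needed.
    let shared: ExpandElementTyped<SharedMemory<f32>> =
        ExpandElementTyped::new(ExpandElement(0));
    assert_eq!(shared.expand, ExpandElement(0));
}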
|
@ -0,0 +1,285 @@
|
|||
use std::marker::PhantomData;
|
||||
|
||||
use super::{
|
||||
Array, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, SharedMemory, Tensor,
|
||||
UInt,
|
||||
};
|
||||
use crate::{
|
||||
frontend::indexation::Index,
|
||||
ir::{self, Operator},
|
||||
prelude::CubeContext,
|
||||
unexpanded,
|
||||
};
|
||||
|
||||
/// A read-only contiguous list of elements
|
||||
pub struct Slice<'a, E> {
|
||||
_e: PhantomData<E>,
|
||||
_l: &'a (),
|
||||
}
|
||||
|
||||
/// A read-write contiguous list of elements.
|
||||
pub struct SliceMut<'a, E> {
|
||||
_e: PhantomData<E>,
|
||||
_l: &'a mut (),
|
||||
}
|
||||
|
||||
impl<'a, E> Slice<'a, E> {
|
||||
/// Get the length of the slice.
|
||||
pub fn len(&self) -> UInt {
|
||||
unexpanded!()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, E> SliceMut<'a, E> {
|
||||
/// Get the length of the slice.
|
||||
pub fn len(&self) -> UInt {
|
||||
unexpanded!()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, E: CubeType> CubeType for Slice<'a, E> {
|
||||
type ExpandType = ExpandElementTyped<Slice<'static, E>>;
|
||||
}
|
||||
|
||||
impl<'a, C: CubeType> Init for ExpandElementTyped<Slice<'a, C>> {
|
||||
fn init(self, _context: &mut crate::prelude::CubeContext) -> Self {
|
||||
// The type can't be deeply cloned/copied.
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, E: CubeType> CubeType for SliceMut<'a, E> {
|
||||
type ExpandType = ExpandElementTyped<SliceMut<'static, E>>;
|
||||
}
|
||||
|
||||
impl<'a, C: CubeType> Init for ExpandElementTyped<SliceMut<'a, C>> {
|
||||
fn init(self, _context: &mut crate::prelude::CubeContext) -> Self {
|
||||
// The type can't be deeply cloned/copied.
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
pub trait SliceOperator<E>: CubeType<ExpandType = Self::Expand> {
|
||||
type Expand: SliceOperatorExpand<E>;
|
||||
|
||||
/// Return a read-only view of all elements comprise between the start and end index.
|
||||
#[allow(unused_variables)]
|
||||
fn slice<Start: Index, End: Index>(&self, start: Start, end: End) -> &'_ Slice<'_, E> {
|
||||
unexpanded!()
|
||||
}
|
||||
/// Expand function of [SliceOperator::slice].
|
||||
fn slice_expand<Start: Index, End: Index>(
|
||||
context: &mut CubeContext,
|
||||
expand: Self::Expand,
|
||||
start: Start,
|
||||
end: End,
|
||||
) -> ExpandElementTyped<Slice<'static, E>> {
|
||||
expand.slice_expand(context, start, end)
|
||||
}
|
||||
|
||||
/// Return a read-write view of all elements comprise between the start and end index.
|
||||
#[allow(unused_variables)]
|
||||
fn slice_mut<Start: Index, End: Index>(
|
||||
&mut self,
|
||||
start: Start,
|
||||
end: End,
|
||||
) -> &'_ mut SliceMut<'_, E> {
|
||||
unexpanded!()
|
||||
}
|
||||
|
||||
/// Expand function of [SliceOperator::slice_mut].
|
||||
fn slice_mut_expand<Start: Index, End: Index>(
|
||||
context: &mut CubeContext,
|
||||
expand: Self::Expand,
|
||||
start: Start,
|
||||
end: End,
|
||||
) -> ExpandElementTyped<SliceMut<'static, E>> {
|
||||
expand.slice_mut_expand(context, start, end)
|
||||
}
|
||||
|
||||
/// Return a read-write view of all elements comprise between the start and end index.
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// Ignore the multiple borrow rule.
|
||||
#[allow(unused_variables)]
|
||||
fn slice_mut_unsafe<Start: Index, End: Index>(
|
||||
&self,
|
||||
start: Start,
|
||||
end: End,
|
||||
) -> SliceMut<'static, E> {
|
||||
unexpanded!()
|
||||
}
|
||||
|
||||
/// Expand function of [SliceOperator::slice_mut_unsafe].
|
||||
fn slice_mut_unsafe_expand<Start: Index, End: Index>(
|
||||
context: &mut CubeContext,
|
||||
expand: Self::Expand,
|
||||
start: Start,
|
||||
end: End,
|
||||
) -> ExpandElementTyped<SliceMut<'static, E>> {
|
||||
expand.slice_mut_unsafe_expand(context, start, end)
|
||||
}
|
||||
|
||||
/// Reinterprete the current type as a read-only slice.
|
||||
#[allow(unused_variables)]
|
||||
fn as_slice(&self) -> &'_ Slice<'_, E> {
|
||||
unexpanded!()
|
||||
}
|
||||
|
||||
/// Expand function of [SliceOperator::as_slice].
|
||||
fn as_slice_expand(
|
||||
context: &mut CubeContext,
|
||||
expand: Self::Expand,
|
||||
) -> ExpandElementTyped<Slice<'static, E>> {
|
||||
expand.as_slice_expand(context)
|
||||
}
|
||||
|
||||
/// Reinterprete the current type as a read-write slice.
|
||||
#[allow(unused_variables)]
|
||||
fn as_slice_mut(&mut self) -> &'_ mut SliceMut<'_, E> {
|
||||
unexpanded!()
|
||||
}
|
||||
|
||||
/// Expand function of [SliceOperator::as_slice_mut].
|
||||
fn as_slice_mut_expand(
|
||||
context: &mut CubeContext,
|
||||
expand: Self::Expand,
|
||||
) -> ExpandElementTyped<SliceMut<'static, E>> {
|
||||
expand.as_slice_mut_expand(context)
|
||||
}
|
||||
|
||||
/// Reinterprete the current type as a read-write slice.
|
||||
///
|
||||
/// # Warning
|
||||
///
|
||||
/// Ignore the multiple borrow rule.
|
||||
#[allow(unused_variables)]
|
||||
fn as_slice_mut_unsafe(&self) -> SliceMut<'static, E> {
|
||||
unexpanded!()
|
||||
}
|
||||
|
||||
/// Expand function of [SliceOperator::as_slice_mut_unsafe].
|
||||
fn as_slice_mut_unsafe_expand(
|
||||
context: &mut CubeContext,
|
||||
expand: Self::Expand,
|
||||
) -> ExpandElementTyped<SliceMut<'static, E>> {
|
||||
expand.as_slice_mut_unsafe_expand(context)
|
||||
}
|
||||
}
|
||||
|
||||
pub trait SliceOperatorExpand<E>: Into<ExpandElement> + Clone {
|
||||
fn slice_base<Start: Index, End: Index>(
|
||||
&self,
|
||||
context: &mut CubeContext,
|
||||
start: Start,
|
||||
end: End,
|
||||
) -> ExpandElement;
|
||||
|
||||
fn slice_expand<Start: Index, End: Index>(
|
||||
&self,
|
||||
context: &mut CubeContext,
|
||||
start: Start,
|
||||
end: End,
|
||||
) -> ExpandElementTyped<Slice<'static, E>> {
|
||||
ExpandElementTyped::new(self.slice_base(context, start, end))
|
||||
}
|
||||
|
||||
fn slice_mut_expand<Start: Index, End: Index>(
|
||||
&self,
|
||||
context: &mut CubeContext,
|
||||
start: Start,
|
||||
end: End,
|
||||
) -> ExpandElementTyped<SliceMut<'static, E>> {
|
||||
ExpandElementTyped::new(self.slice_base(context, start, end))
|
||||
}
|
||||
|
||||
fn slice_mut_unsafe_expand<Start: Index, End: Index>(
|
||||
&self,
|
||||
context: &mut CubeContext,
|
||||
start: Start,
|
||||
end: End,
|
||||
) -> ExpandElementTyped<SliceMut<'static, E>> {
|
||||
ExpandElementTyped::new(self.slice_base(context, start, end))
|
||||
}
|
||||
|
||||
fn as_slice_expand(&self, _context: &mut CubeContext) -> ExpandElementTyped<Slice<'static, E>> {
|
||||
let expand = self.clone().into();
|
||||
ExpandElementTyped::new(expand)
|
||||
}
|
||||
|
||||
fn as_slice_mut_unsafe_expand(
|
||||
&self,
|
||||
context: &mut CubeContext,
|
||||
) -> ExpandElementTyped<SliceMut<'static, E>> {
|
||||
self.as_slice_mut_expand(context)
|
||||
}
|
||||
|
||||
fn as_slice_mut_expand(
|
||||
&self,
|
||||
_context: &mut CubeContext,
|
||||
) -> ExpandElementTyped<SliceMut<'static, E>> {
|
||||
let expand = self.clone().into();
|
||||
ExpandElementTyped::new(expand)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! slice_op {
|
||||
($type:ident) => {
|
||||
impl<E: CubePrimitive> SliceOperator<E> for $type<E> {
|
||||
type Expand = ExpandElementTyped<$type<E>>;
|
||||
}
|
||||
|
||||
impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<$type<E>> {
|
||||
fn slice_base<Start: Index, End: Index>(
|
||||
&self,
|
||||
context: &mut CubeContext,
|
||||
start: Start,
|
||||
end: End,
|
||||
) -> ExpandElement {
|
||||
slice_expand(context, self.clone(), start, end)
|
||||
}
|
||||
}
|
||||
};
|
||||
(slice $type:ident) => {
|
||||
impl<'a, E: CubePrimitive> SliceOperator<E> for $type<'a, E> {
|
||||
type Expand = ExpandElementTyped<$type<'static, E>>;
|
||||
}
|
||||
|
||||
impl<'a, E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<$type<'a, E>> {
|
||||
fn slice_base<Start: Index, End: Index>(
|
||||
&self,
|
||||
context: &mut CubeContext,
|
||||
start: Start,
|
||||
end: End,
|
||||
) -> ExpandElement {
|
||||
slice_expand(context, self.clone(), start, end)
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
slice_op!(Array);
|
||||
slice_op!(Tensor);
|
||||
slice_op!(SharedMemory);
|
||||
slice_op!(slice Slice);
|
||||
slice_op!(slice SliceMut);
|
||||
|
||||
pub fn slice_expand<I: Into<ExpandElement>, S1: Index, S2: Index>(
|
||||
context: &mut CubeContext,
|
||||
input: I,
|
||||
start: S1,
|
||||
end: S2, // Todo use it to get the length.
|
||||
) -> ExpandElement {
|
||||
let input = input.into();
|
||||
let out = context.create_slice(input.item());
|
||||
|
||||
context.register(Operator::Slice(ir::SliceOperator {
|
||||
input: *input,
|
||||
start: start.value(),
|
||||
end: end.value(),
|
||||
out: *out,
|
||||
}));
|
||||
|
||||
out
|
||||
}
|
|
@ -202,7 +202,7 @@ impl<T> ExpandElementTyped<T> {
|
|||
// Expanded version of len
|
||||
pub fn len_expand(self, context: &mut CubeContext) -> ExpandElement {
|
||||
let out = context.create_local(Item::new(Elem::UInt));
|
||||
context.register(Metadata::ArrayLength {
|
||||
context.register(Metadata::Length {
|
||||
var: self.expand.into(),
|
||||
out: out.clone().into(),
|
||||
});
|
||||
|
|
|
@ -49,7 +49,10 @@ impl UInt {
|
|||
}
|
||||
|
||||
pub fn __expand_new(_context: &mut CubeContext, val: u32) -> <Self as CubeType>::ExpandType {
|
||||
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
|
||||
let new_var = Variable::ConstantScalar {
|
||||
value: val as f64,
|
||||
elem: Self::as_elem(),
|
||||
};
|
||||
ExpandElement::Plain(new_var)
|
||||
}
|
||||
|
||||
|
|
|
@ -7,31 +7,46 @@ pub trait Index {
|
|||
|
||||
impl Index for Comptime<u32> {
|
||||
fn value(self) -> Variable {
|
||||
Variable::ConstantScalar(self.inner as f64, Elem::UInt)
|
||||
Variable::ConstantScalar {
|
||||
value: self.inner as f64,
|
||||
elem: Elem::UInt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Index for Comptime<i32> {
|
||||
fn value(self) -> Variable {
|
||||
Variable::ConstantScalar(self.inner as f64, Elem::UInt)
|
||||
Variable::ConstantScalar {
|
||||
value: self.inner as f64,
|
||||
elem: Elem::UInt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Index for i32 {
|
||||
fn value(self) -> Variable {
|
||||
Variable::ConstantScalar(self as f64, Elem::UInt)
|
||||
Variable::ConstantScalar {
|
||||
value: self as f64,
|
||||
elem: Elem::UInt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Index for u32 {
|
||||
fn value(self) -> Variable {
|
||||
Variable::ConstantScalar(self as f64, Elem::UInt)
|
||||
Variable::ConstantScalar {
|
||||
value: self as f64,
|
||||
elem: Elem::UInt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Index for UInt {
|
||||
fn value(self) -> Variable {
|
||||
Variable::ConstantScalar(self.val as f64, Elem::UInt)
|
||||
Variable::ConstantScalar {
|
||||
value: self.val as f64,
|
||||
elem: Elem::UInt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ pub mod assign {
|
|||
}
|
||||
|
||||
pub mod index_assign {
|
||||
use crate::{frontend::CubeType, unexpanded};
|
||||
use crate::{frontend::CubeType, prelude::SliceMut, unexpanded};
|
||||
|
||||
use self::ir::{BinaryOperator, Operator, Variable};
|
||||
|
||||
|
@ -34,7 +34,10 @@ pub mod index_assign {
|
|||
let array = array.into();
|
||||
let index: Variable = *index.into();
|
||||
let index = match index {
|
||||
Variable::ConstantScalar(val, _) => Variable::ConstantScalar(val, ir::Elem::UInt),
|
||||
Variable::ConstantScalar { value, .. } => Variable::ConstantScalar {
|
||||
value,
|
||||
elem: ir::Elem::UInt,
|
||||
},
|
||||
_ => index,
|
||||
};
|
||||
context.register(Operator::IndexAssign(BinaryOperator {
|
||||
|
@ -58,6 +61,12 @@ pub mod index_assign {
|
|||
impl_index!(Array);
|
||||
impl_index!(Tensor);
|
||||
impl_index!(SharedMemory);
|
||||
|
||||
impl<'a, E: CubeType, I: Into<UInt>> core::ops::IndexMut<I> for SliceMut<'a, E> {
|
||||
fn index_mut(&mut self, _index: I) -> &mut Self::Output {
|
||||
unexpanded!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub mod index {
|
||||
|
@ -66,6 +75,7 @@ pub mod index {
|
|||
operation::base::{binary_expand, binary_expand_no_vec},
|
||||
CubeType,
|
||||
},
|
||||
prelude::{Slice, SliceMut},
|
||||
unexpanded,
|
||||
};
|
||||
|
||||
|
@ -78,20 +88,21 @@ pub mod index {
|
|||
array: L,
|
||||
index: R,
|
||||
) -> ExpandElement {
|
||||
let index = index.into();
|
||||
let index: ExpandElement = index.into();
|
||||
let index_var: Variable = *index;
|
||||
let index = match index_var {
|
||||
Variable::ConstantScalar(val, _) => {
|
||||
ExpandElement::Plain(Variable::ConstantScalar(val, ir::Elem::UInt))
|
||||
Variable::ConstantScalar { value, .. } => {
|
||||
ExpandElement::Plain(Variable::ConstantScalar {
|
||||
value,
|
||||
elem: ir::Elem::UInt,
|
||||
})
|
||||
}
|
||||
_ => index,
|
||||
};
|
||||
let array = array.into();
|
||||
let array: ExpandElement = array.into();
|
||||
let var: Variable = *array;
|
||||
match var {
|
||||
Variable::Local(_, _, _) => {
|
||||
binary_expand_no_vec(context, array, index, Operator::Index)
|
||||
}
|
||||
Variable::Local { .. } => binary_expand_no_vec(context, array, index, Operator::Index),
|
||||
_ => binary_expand(context, array, index, Operator::Index),
|
||||
}
|
||||
}
|
||||
|
@ -111,6 +122,20 @@ pub mod index {
|
|||
impl_index!(Array);
|
||||
impl_index!(Tensor);
|
||||
impl_index!(SharedMemory);
|
||||
|
||||
impl<'a, E: CubeType, I: Into<UInt>> core::ops::Index<I> for SliceMut<'a, E> {
|
||||
type Output = E;
|
||||
fn index(&self, _index: I) -> &Self::Output {
|
||||
unexpanded!()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, E: CubeType, I: Into<UInt>> core::ops::Index<I> for Slice<'a, E> {
|
||||
type Output = E;
|
||||
fn index(&self, _index: I) -> &Self::Output {
|
||||
unexpanded!()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub mod add_assign_array_op {
|
||||
|
|
|
@ -352,7 +352,7 @@ macro_rules! cpa {
|
|||
};
|
||||
// out = len(array)
|
||||
($scope:expr, $out:ident = len($input:expr)) => {
|
||||
$scope.register($crate::ir::Metadata::ArrayLength {
|
||||
$scope.register($crate::ir::Metadata::Length {
|
||||
var: $input.into(),
|
||||
out: $out.into(),
|
||||
});
|
||||
|
@ -398,43 +398,64 @@ macro_rules! cpa {
|
|||
|
||||
impl From<bool> for Variable {
|
||||
fn from(value: bool) -> Self {
|
||||
Self::ConstantScalar(if value { 1.0 } else { 0.0 }, super::Elem::Bool)
|
||||
Self::ConstantScalar {
|
||||
value: if value { 1.0 } else { 0.0 },
|
||||
elem: super::Elem::Bool,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<i32> for Variable {
|
||||
fn from(value: i32) -> Self {
|
||||
Self::ConstantScalar(value as f64, super::Elem::Int(super::IntKind::I32))
|
||||
Self::ConstantScalar {
|
||||
value: value as f64,
|
||||
elem: super::Elem::Int(super::IntKind::I32),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<i64> for Variable {
|
||||
fn from(value: i64) -> Self {
|
||||
Self::ConstantScalar(value as f64, super::Elem::Int(super::IntKind::I64))
|
||||
Self::ConstantScalar {
|
||||
value: value as f64,
|
||||
elem: super::Elem::Int(super::IntKind::I64),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<f32> for Variable {
|
||||
fn from(value: f32) -> Self {
|
||||
Self::ConstantScalar(value as f64, super::Elem::Float(super::FloatKind::F32))
|
||||
Self::ConstantScalar {
|
||||
value: value as f64,
|
||||
elem: super::Elem::Float(super::FloatKind::F32),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<f64> for Variable {
|
||||
fn from(value: f64) -> Self {
|
||||
Self::ConstantScalar(value, super::Elem::Float(super::FloatKind::F64))
|
||||
Self::ConstantScalar {
|
||||
value,
|
||||
elem: super::Elem::Float(super::FloatKind::F64),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<u32> for Variable {
|
||||
fn from(value: u32) -> Self {
|
||||
Self::ConstantScalar(value as f64, super::Elem::UInt)
|
||||
Self::ConstantScalar {
|
||||
value: value as f64,
|
||||
elem: super::Elem::UInt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<usize> for Variable {
|
||||
fn from(value: usize) -> Self {
|
||||
Self::ConstantScalar(value as f64, super::Elem::UInt)
|
||||
Self::ConstantScalar {
|
||||
value: value as f64,
|
||||
elem: super::Elem::UInt,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -53,6 +53,7 @@ pub enum Operator {
|
|||
Assign(UnaryOperator),
|
||||
Modulo(BinaryOperator),
|
||||
Index(BinaryOperator),
|
||||
Slice(SliceOperator),
|
||||
UncheckedIndex(BinaryOperator),
|
||||
IndexAssign(BinaryOperator),
|
||||
UncheckedIndexAssign(BinaryOperator),
|
||||
|
@ -84,7 +85,7 @@ pub enum Metadata {
|
|||
var: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
ArrayLength {
|
||||
Length {
|
||||
var: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
|
@ -120,6 +121,15 @@ pub struct ClampOperator {
|
|||
pub out: Variable,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct SliceOperator {
|
||||
pub input: Variable,
|
||||
pub start: Variable,
|
||||
pub end: Variable,
|
||||
pub out: Variable,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
#[allow(missing_docs)]
|
||||
pub struct ReadGlobalOperator {
|
||||
|
|
|
@ -19,6 +19,7 @@ pub struct Scope {
|
|||
pub operations: Vec<Operation>,
|
||||
locals: Vec<Variable>,
|
||||
matrices: Vec<Variable>,
|
||||
slices: Vec<Variable>,
|
||||
shared_memories: Vec<Variable>,
|
||||
local_arrays: Vec<Variable>,
|
||||
reads_global: Vec<(Variable, ReadingStrategy, Variable, Variable)>,
|
||||
|
@ -49,6 +50,7 @@ impl Scope {
|
|||
operations: Vec::new(),
|
||||
locals: Vec::new(),
|
||||
matrices: Vec::new(),
|
||||
slices: Vec::new(),
|
||||
local_arrays: Vec::new(),
|
||||
shared_memories: Vec::new(),
|
||||
reads_global: Vec::new(),
|
||||
|
@ -75,7 +77,10 @@ impl Scope {
|
|||
I: Into<Item> + Copy,
|
||||
{
|
||||
let local = self.create_local(item);
|
||||
let value = Variable::ConstantScalar(value.into(), item.into().elem());
|
||||
let value = Variable::ConstantScalar {
|
||||
value: value.into(),
|
||||
elem: item.into().elem(),
|
||||
};
|
||||
cpa!(self, local = value);
|
||||
local
|
||||
}
|
||||
|
@ -83,16 +88,35 @@ impl Scope {
|
|||
/// Create a matrix variable
|
||||
pub fn create_matrix(&mut self, matrix: Matrix) -> Variable {
|
||||
let index = self.matrices.len() as u16;
|
||||
let variable = Variable::Matrix(index, matrix);
|
||||
let variable = Variable::Matrix {
|
||||
id: index,
|
||||
mat: matrix,
|
||||
};
|
||||
self.matrices.push(variable);
|
||||
variable
|
||||
}
|
||||
|
||||
/// Create a slice variable
|
||||
pub fn create_slice(&mut self, item: Item) -> Variable {
|
||||
let id = self.slices.len() as u16;
|
||||
let variable = Variable::Slice {
|
||||
id,
|
||||
item,
|
||||
depth: self.depth,
|
||||
};
|
||||
self.slices.push(variable);
|
||||
variable
|
||||
}
|
||||
|
||||
/// Create a local variable of the given [item type](Item).
|
||||
pub fn create_local<I: Into<Item>>(&mut self, item: I) -> Variable {
|
||||
let item = item.into();
|
||||
let index = self.new_local_index();
|
||||
let local = Variable::Local(index, item, self.depth);
|
||||
let local = Variable::Local {
|
||||
id: index,
|
||||
item,
|
||||
depth: self.depth,
|
||||
};
|
||||
self.locals.push(local);
|
||||
local
|
||||
}
|
||||
|
@ -102,7 +126,11 @@ impl Scope {
|
|||
/// Useful for _for loops_ and other algorithms that require the control over initialization.
|
||||
pub fn create_local_undeclared(&mut self, item: Item) -> Variable {
|
||||
let index = self.new_local_index();
|
||||
let local = Variable::Local(index, item, self.depth);
|
||||
let local = Variable::Local {
|
||||
id: index,
|
||||
item,
|
||||
depth: self.depth,
|
||||
};
|
||||
self.undeclared += 1;
|
||||
local
|
||||
}
|
||||
|
@ -131,8 +159,12 @@ impl Scope {
|
|||
///
|
||||
/// The index refers to the scalar position for the same [element](Elem) type.
|
||||
pub fn read_scalar(&mut self, index: u16, elem: Elem) -> Variable {
|
||||
let local = Variable::LocalScalar(self.new_local_scalar_index(), elem, self.depth);
|
||||
let scalar = Variable::GlobalScalar(index, elem);
|
||||
let local = Variable::LocalScalar {
|
||||
id: self.new_local_scalar_index(),
|
||||
elem,
|
||||
depth: self.depth,
|
||||
};
|
||||
let scalar = Variable::GlobalScalar { id: index, elem };
|
||||
|
||||
self.reads_scalar.push((local, scalar));
|
||||
|
||||
|
@ -215,7 +247,7 @@ impl Scope {
|
|||
self.reads_global
|
||||
.iter()
|
||||
.map(|(var, strategy, _, _)| match var {
|
||||
Variable::GlobalInputArray(id, _) => (*id, *strategy),
|
||||
Variable::GlobalInputArray { id, .. } => (*id, *strategy),
|
||||
_ => panic!("Can only read global input arrays."),
|
||||
})
|
||||
.collect()
|
||||
|
@ -233,6 +265,7 @@ impl Scope {
|
|||
operations: Vec::new(),
|
||||
locals: Vec::new(),
|
||||
matrices: Vec::new(),
|
||||
slices: Vec::new(),
|
||||
shared_memories: Vec::new(),
|
||||
local_arrays: Vec::new(),
|
||||
reads_global: Vec::new(),
|
||||
|
@ -259,6 +292,9 @@ impl Scope {
|
|||
for var in self.matrices.drain(..) {
|
||||
variables.push(var);
|
||||
}
|
||||
for var in self.slices.drain(..) {
|
||||
variables.push(var);
|
||||
}
|
||||
|
||||
for index in self.index_offset_with_output_layout_position.drain(..) {
|
||||
if let Some(Operation::Procedure(Procedure::IndexOffsetGlobalWithLayout(proc))) =
|
||||
|
@ -357,9 +393,16 @@ impl Scope {
|
|||
},
|
||||
_ => item,
|
||||
};
|
||||
let input = Variable::GlobalInputArray(index, item_global);
|
||||
let input = Variable::GlobalInputArray {
|
||||
id: index,
|
||||
item: item_global,
|
||||
};
|
||||
let index = self.new_local_index();
|
||||
let local = Variable::Local(index, item, self.depth);
|
||||
let local = Variable::Local {
|
||||
id: index,
|
||||
item,
|
||||
depth: self.depth,
|
||||
};
|
||||
self.reads_global.push((input, strategy, local, position));
|
||||
self.locals.push(local);
|
||||
local
|
||||
|
@ -369,7 +412,11 @@ impl Scope {
|
|||
pub fn create_shared<I: Into<Item>>(&mut self, item: I, shared_memory_size: u32) -> Variable {
|
||||
let item = item.into();
|
||||
let index = self.new_shared_index();
|
||||
let shared_memory = Variable::SharedMemory(index, item, shared_memory_size);
|
||||
let shared_memory = Variable::SharedMemory {
|
||||
id: index,
|
||||
item,
|
||||
length: shared_memory_size,
|
||||
};
|
||||
self.shared_memories.push(shared_memory);
|
||||
shared_memory
|
||||
}
|
||||
|
@ -378,7 +425,12 @@ impl Scope {
|
|||
pub fn create_local_array<I: Into<Item>>(&mut self, item: I, array_size: u32) -> Variable {
|
||||
let item = item.into();
|
||||
let index = self.new_local_array_index();
|
||||
let local_array = Variable::LocalArray(index, item, self.depth, array_size);
|
||||
let local_array = Variable::LocalArray {
|
||||
id: index,
|
||||
item,
|
||||
depth: self.depth,
|
||||
length: array_size,
|
||||
};
|
||||
self.local_arrays.push(local_array);
|
||||
local_array
|
||||
}
|
||||
|
|
|
@ -5,14 +5,52 @@ use serde::{Deserialize, Serialize};
|
|||
#[allow(missing_docs)]
|
||||
pub enum Variable {
|
||||
Rank,
|
||||
GlobalInputArray(u16, Item),
|
||||
GlobalScalar(u16, Elem),
|
||||
GlobalOutputArray(u16, Item),
|
||||
Local(u16, Item, u8),
|
||||
LocalScalar(u16, Elem, u8),
|
||||
ConstantScalar(f64, Elem),
|
||||
SharedMemory(u16, Item, u32),
|
||||
LocalArray(u16, Item, u8, u32),
|
||||
GlobalInputArray {
|
||||
id: u16,
|
||||
item: Item,
|
||||
},
|
||||
GlobalScalar {
|
||||
id: u16,
|
||||
elem: Elem,
|
||||
},
|
||||
GlobalOutputArray {
|
||||
id: u16,
|
||||
item: Item,
|
||||
},
|
||||
Local {
|
||||
id: u16,
|
||||
item: Item,
|
||||
depth: u8,
|
||||
},
|
||||
LocalScalar {
|
||||
id: u16,
|
||||
elem: Elem,
|
||||
depth: u8,
|
||||
},
|
||||
ConstantScalar {
|
||||
value: f64,
|
||||
elem: Elem,
|
||||
},
|
||||
SharedMemory {
|
||||
id: u16,
|
||||
item: Item,
|
||||
length: u32,
|
||||
},
|
||||
LocalArray {
|
||||
id: u16,
|
||||
item: Item,
|
||||
depth: u8,
|
||||
length: u32,
|
||||
},
|
||||
Matrix {
|
||||
id: u16,
|
||||
mat: Matrix,
|
||||
},
|
||||
Slice {
|
||||
id: u16,
|
||||
item: Item,
|
||||
depth: u8,
|
||||
},
|
||||
UnitPos,
|
||||
UnitPosX,
|
||||
UnitPosY,
|
||||
|
@ -34,20 +72,21 @@ pub enum Variable {
|
|||
AbsolutePosX,
|
||||
AbsolutePosY,
|
||||
AbsolutePosZ,
|
||||
Matrix(u16, Matrix),
|
||||
}
|
||||
|
||||
impl Variable {
|
||||
pub fn index(&self) -> Option<u16> {
|
||||
match self {
|
||||
Variable::GlobalInputArray(idx, _) => Some(*idx),
|
||||
Variable::GlobalScalar(idx, _) => Some(*idx),
|
||||
Variable::Local(idx, _, _) => Some(*idx),
|
||||
Variable::LocalScalar(idx, _, _) => Some(*idx),
|
||||
Variable::GlobalOutputArray(idx, _) => Some(*idx),
|
||||
Variable::ConstantScalar(_, _) => None,
|
||||
Variable::SharedMemory(idx, _, _) => Some(*idx),
|
||||
Variable::LocalArray(idx, _, _, _) => Some(*idx),
|
||||
Variable::GlobalInputArray { id, .. } => Some(*id),
|
||||
Variable::GlobalScalar { id, .. } => Some(*id),
|
||||
Variable::Local { id, .. } => Some(*id),
|
||||
Variable::Slice { id, .. } => Some(*id),
|
||||
Variable::LocalScalar { id, .. } => Some(*id),
|
||||
Variable::GlobalOutputArray { id, .. } => Some(*id),
|
||||
Variable::ConstantScalar { .. } => None,
|
||||
Variable::SharedMemory { id, .. } => Some(*id),
|
||||
Variable::LocalArray { id, .. } => Some(*id),
|
||||
Variable::Matrix { id, .. } => Some(*id),
|
||||
Variable::AbsolutePos => None,
|
||||
Variable::UnitPos => None,
|
||||
Variable::UnitPosX => None,
|
||||
|
@ -70,21 +109,22 @@ impl Variable {
|
|||
Variable::CubeCount => None,
|
||||
Variable::CubeDim => None,
|
||||
Variable::SubcubeDim => None,
|
||||
Variable::Matrix(idx, _) => Some(*idx),
|
||||
}
|
||||
}
|
||||
|
||||
/// Fetch the item of the variable.
|
||||
pub fn item(&self) -> Item {
|
||||
match self {
|
||||
Variable::GlobalInputArray(_, item) => *item,
|
||||
Variable::GlobalOutputArray(_, item) => *item,
|
||||
Variable::GlobalScalar(_, elem) => Item::new(*elem),
|
||||
Variable::Local(_, item, _) => *item,
|
||||
Variable::LocalScalar(_, elem, _) => Item::new(*elem),
|
||||
Variable::ConstantScalar(_, elem) => Item::new(*elem),
|
||||
Variable::SharedMemory(_, item, _) => *item,
|
||||
Variable::LocalArray(_, item, _, _) => *item,
|
||||
Variable::GlobalInputArray { item, .. } => *item,
|
||||
Variable::GlobalOutputArray { item, .. } => *item,
|
||||
Variable::GlobalScalar { elem, .. } => Item::new(*elem),
|
||||
Variable::Local { item, .. } => *item,
|
||||
Variable::LocalScalar { elem, .. } => Item::new(*elem),
|
||||
Variable::ConstantScalar { elem, .. } => Item::new(*elem),
|
||||
Variable::SharedMemory { item, .. } => *item,
|
||||
Variable::LocalArray { item, .. } => *item,
|
||||
Variable::Slice { item, .. } => *item,
|
||||
Variable::Matrix { mat, .. } => Item::new(mat.elem),
|
||||
Variable::AbsolutePos => Item::new(Elem::UInt),
|
||||
Variable::Rank => Item::new(Elem::UInt),
|
||||
Variable::UnitPos => Item::new(Elem::UInt),
|
||||
|
@ -107,7 +147,6 @@ impl Variable {
|
|||
Variable::CubeCount => Item::new(Elem::UInt),
|
||||
Variable::CubeDim => Item::new(Elem::UInt),
|
||||
Variable::SubcubeDim => Item::new(Elem::UInt),
|
||||
Variable::Matrix(_, matrix) => Item::new(matrix.elem),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use super::{
|
||||
BinaryOperator, ClampOperator, FmaOperator, InitOperator, Item, Operation, Operator, Subcube,
|
||||
UnaryOperator, Variable,
|
||||
BinaryOperator, ClampOperator, FmaOperator, InitOperator, Item, Operation, Operator,
|
||||
SliceOperator, Subcube, UnaryOperator, Variable,
|
||||
};
|
||||
|
||||
pub type Vectorization = u8;
|
||||
|
@ -60,7 +60,7 @@ impl Operator {
|
|||
Operator::LowerEqual(op) => Operator::LowerEqual(op.vectorize(vectorization)),
|
||||
Operator::GreaterEqual(op) => Operator::GreaterEqual(op.vectorize(vectorization)),
|
||||
Operator::Assign(op) => {
|
||||
if let Variable::GlobalScalar(_, _) = op.input {
|
||||
if let Variable::GlobalScalar { .. } = op.input {
|
||||
// Assign will not change the type of the output if the input can't be
|
||||
// vectorized.
|
||||
return Operator::Assign(op.clone());
|
||||
|
@ -81,6 +81,7 @@ impl Operator {
|
|||
Operator::ShiftLeft(op) => Operator::ShiftLeft(op.vectorize(vectorization)),
|
||||
Operator::ShiftRight(op) => Operator::ShiftRight(op.vectorize(vectorization)),
|
||||
Operator::Remainder(op) => Operator::Remainder(op.vectorize(vectorization)),
|
||||
Operator::Slice(op) => Operator::Slice(op.vectorize(vectorization)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -104,6 +105,22 @@ impl UnaryOperator {
|
|||
}
|
||||
}
|
||||
|
||||
impl SliceOperator {
|
||||
pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self {
|
||||
let input = self.input.vectorize(vectorization);
|
||||
let start = self.start.vectorize(vectorization);
|
||||
let end = self.end.vectorize(vectorization);
|
||||
let out = self.out.vectorize(vectorization);
|
||||
|
||||
Self {
|
||||
input,
|
||||
start,
|
||||
end,
|
||||
out,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl InitOperator {
|
||||
pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self {
|
||||
let out = self.out.vectorize(vectorization);
|
||||
|
@ -155,31 +172,46 @@ impl FmaOperator {
|
|||
impl Variable {
|
||||
pub(crate) fn vectorize(&self, vectorize: Vectorization) -> Self {
|
||||
match self {
|
||||
Variable::GlobalInputArray(index, item) => {
|
||||
Variable::GlobalInputArray(*index, item.vectorize(vectorize))
|
||||
}
|
||||
Variable::Local(index, item, name) => {
|
||||
Variable::Local(*index, item.vectorize(vectorize), *name)
|
||||
}
|
||||
Variable::GlobalOutputArray(index, item) => {
|
||||
Variable::GlobalOutputArray(*index, item.vectorize(vectorize))
|
||||
}
|
||||
Variable::SharedMemory(index, item, size) => Variable::SharedMemory(
|
||||
*index,
|
||||
item.vectorize(vectorize),
|
||||
item.vectorized_size(vectorize, *size),
|
||||
),
|
||||
Variable::LocalArray(index, item, name, size) => Variable::LocalArray(
|
||||
*index,
|
||||
item.vectorize(vectorize),
|
||||
*name,
|
||||
item.vectorized_size(vectorize, *size),
|
||||
),
|
||||
Variable::ConstantScalar(_, _) => *self,
|
||||
Variable::GlobalScalar(_, _) => *self,
|
||||
Variable::GlobalInputArray { id, item } => Variable::GlobalInputArray {
|
||||
id: *id,
|
||||
item: item.vectorize(vectorize),
|
||||
},
|
||||
Variable::Local { id, item, depth } => Variable::Local {
|
||||
id: *id,
|
||||
item: item.vectorize(vectorize),
|
||||
depth: *depth,
|
||||
},
|
||||
Variable::Slice { id, item, depth } => Variable::Slice {
|
||||
id: *id,
|
||||
item: item.vectorize(vectorize),
|
||||
depth: *depth,
|
||||
},
|
||||
Variable::GlobalOutputArray { id, item } => Variable::GlobalOutputArray {
|
||||
id: *id,
|
||||
item: item.vectorize(vectorize),
|
||||
},
|
||||
Variable::SharedMemory { id, item, length } => Variable::SharedMemory {
|
||||
id: *id,
|
||||
item: item.vectorize(vectorize),
|
||||
length: item.vectorized_size(vectorize, *length),
|
||||
},
|
||||
Variable::LocalArray {
|
||||
id,
|
||||
item,
|
||||
depth,
|
||||
length,
|
||||
} => Variable::LocalArray {
|
||||
id: *id,
|
||||
item: item.vectorize(vectorize),
|
||||
depth: *depth,
|
||||
length: item.vectorized_size(vectorize, *length),
|
||||
},
|
||||
Variable::ConstantScalar { .. } => *self,
|
||||
Variable::GlobalScalar { .. } => *self,
|
||||
Variable::AbsolutePos => *self,
|
||||
Variable::Rank => *self,
|
||||
Variable::LocalScalar(_, _, _) => *self,
|
||||
Variable::LocalScalar { .. } => *self,
|
||||
Variable::Matrix { .. } => *self,
|
||||
Variable::UnitPos => *self,
|
||||
Variable::UnitPosX => *self,
|
||||
Variable::UnitPosY => *self,
|
||||
|
@ -200,7 +232,6 @@ impl Variable {
|
|||
Variable::CubeCount => *self,
|
||||
Variable::CubeDim => *self,
|
||||
Variable::SubcubeDim => *self,
|
||||
Variable::Matrix(_, _) => *self,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,7 +11,8 @@ pub use crate::runtime::Runtime;
|
|||
|
||||
/// Elements
|
||||
pub use crate::frontend::{
|
||||
Array, ArrayHandle, Bool, Float, LaunchArg, Tensor, TensorArg, UInt, F16, F32, F64, I32, I64,
|
||||
Array, ArrayHandle, Bool, Float, LaunchArg, Slice, SliceMut, Tensor, TensorArg, UInt, F16, F32,
|
||||
F64, I32, I64,
|
||||
};
|
||||
pub use crate::pod::CubeElement;
|
||||
|
||||
|
|
|
@ -31,12 +31,17 @@ pub fn kernel_simple_1(lhs: &Array<F16>, rhs: &Array<F16>, out: &mut Array<F32>)
|
|||
cmma::MatrixLayout::Undefined,
|
||||
);
|
||||
cmma::fill::<F32>(&c, F32::new(0.0));
|
||||
cmma::load::<F16>(&a, lhs, UInt::new(16));
|
||||
cmma::load::<F16>(&b, rhs, UInt::new(16));
|
||||
cmma::load::<F16>(&a, lhs.as_slice(), UInt::new(16));
|
||||
cmma::load::<F16>(&b, rhs.as_slice(), UInt::new(16));
|
||||
|
||||
cmma::execute::<F16, F16, F32, F32>(&a, &b, &c, &c);
|
||||
|
||||
cmma::store::<F32>(out, &c, UInt::new(16), cmma::MatrixLayout::RowMajor);
|
||||
cmma::store::<F32>(
|
||||
out.as_slice_mut(),
|
||||
&c,
|
||||
UInt::new(16),
|
||||
cmma::MatrixLayout::RowMajor,
|
||||
);
|
||||
}
|
||||
|
||||
pub fn test_simple_1<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
pub mod cmma;
|
||||
pub mod launch;
|
||||
pub mod slice;
|
||||
pub mod subcube;
|
||||
|
||||
#[allow(missing_docs)]
|
||||
|
@ -11,5 +12,6 @@ macro_rules! testgen_all {
|
|||
burn_cube::testgen_subcube!();
|
||||
burn_cube::testgen_launch!();
|
||||
burn_cube::testgen_cmma!();
|
||||
burn_cube::testgen_slice!();
|
||||
};
|
||||
}
|
||||
|
|
|
@ -0,0 +1,107 @@
|
|||
use crate as burn_cube;
|
||||
use burn_cube::prelude::*;
|
||||
|
||||
#[cube(launch)]
|
||||
pub fn slice_select<F: Float>(input: &Array<F>, output: &mut Array<F>) {
|
||||
if UNIT_POS == UInt::new(0) {
|
||||
let slice = input.slice(2, 3);
|
||||
output[0] = slice[0u32];
|
||||
}
|
||||
}
|
||||
|
||||
#[cube(launch)]
|
||||
pub fn slice_assign<F: Float>(input: &Array<F>, output: &mut Array<F>) {
|
||||
if UNIT_POS == UInt::new(0) {
|
||||
let slice_1 = output.slice_mut(2, 3);
|
||||
slice_1[0] = input[0u32];
|
||||
}
|
||||
}
|
||||
|
||||
#[cube(launch)]
|
||||
pub fn slice_len<F: Float>(input: &Array<F>, output: &mut Array<UInt>) {
|
||||
if UNIT_POS == UInt::new(0) {
|
||||
let slice = input.slice(2, 4);
|
||||
let _tmp = slice[0]; // It must be used at least once, otherwise wgpu isn't happy.
|
||||
output[0] = slice.len();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn test_slice_select<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
|
||||
let input = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0]));
|
||||
let output = client.empty(core::mem::size_of::<f32>());
|
||||
|
||||
slice_select_launch::<F32, R>(
|
||||
client.clone(),
|
||||
CubeCount::Static(1, 1, 1),
|
||||
CubeDim::new(1, 1, 1),
|
||||
ArrayArg::new(&input, 5),
|
||||
ArrayArg::new(&output, 1),
|
||||
);
|
||||
|
||||
let actual = client.read(output.binding());
|
||||
let actual = f32::from_bytes(&actual);
|
||||
|
||||
assert_eq!(actual[0], 2.0);
|
||||
}
|
||||
|
||||
pub fn test_slice_len<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
|
||||
let input = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0]));
|
||||
let output = client.empty(core::mem::size_of::<u32>());
|
||||
|
||||
slice_len_launch::<F32, R>(
|
||||
client.clone(),
|
||||
CubeCount::Static(1, 1, 1),
|
||||
CubeDim::new(1, 1, 1),
|
||||
ArrayArg::new(&input, 5),
|
||||
ArrayArg::new(&output, 1),
|
||||
);
|
||||
|
||||
let actual = client.read(output.binding());
|
||||
let actual = u32::from_bytes(&actual);
|
||||
|
||||
assert_eq!(actual, &[2]);
|
||||
}
|
||||
|
||||
pub fn test_slice_assign<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
|
||||
let input = client.create(f32::as_bytes(&[15.0]));
|
||||
let output = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0]));
|
||||
|
||||
slice_assign_launch::<F32, R>(
|
||||
client.clone(),
|
||||
CubeCount::Static(1, 1, 1),
|
||||
CubeDim::new(1, 1, 1),
|
||||
ArrayArg::new(&input, 5),
|
||||
ArrayArg::new(&output, 1),
|
||||
);
|
||||
|
||||
let actual = client.read(output.binding());
|
||||
let actual = f32::from_bytes(&actual);
|
||||
|
||||
assert_eq!(actual, &[0.0, 1.0, 15.0, 3.0, 4.0]);
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
#[macro_export]
|
||||
macro_rules! testgen_slice {
|
||||
() => {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_slice_select() {
|
||||
let client = TestRuntime::client(&Default::default());
|
||||
burn_cube::runtime_tests::slice::test_slice_select::<TestRuntime>(client);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_slice_assign() {
|
||||
let client = TestRuntime::client(&Default::default());
|
||||
burn_cube::runtime_tests::slice::test_slice_assign::<TestRuntime>(client);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_slice_len() {
|
||||
let client = TestRuntime::client(&Default::default());
|
||||
burn_cube::runtime_tests::slice::test_slice_len::<TestRuntime>(client);
|
||||
}
|
||||
};
|
||||
}
|
|
@ -36,7 +36,7 @@ mod tests {
|
|||
use super::*;
|
||||
use burn_cube::{
|
||||
cpa,
|
||||
ir::{Elem, Item, Variable},
|
||||
ir::{self, Elem, Item, Variable},
|
||||
};
|
||||
|
||||
type ElemType = F32;
|
||||
|
@ -47,7 +47,7 @@ mod tests {
|
|||
|
||||
array_read_write::__expand::<ElemType>(&mut context, 512);
|
||||
assert_eq!(
|
||||
format!("{:?}", context.into_scope().operations),
|
||||
context.into_scope().operations,
|
||||
inline_macro_ref_read_write()
|
||||
)
|
||||
}
|
||||
|
@ -60,10 +60,7 @@ mod tests {
|
|||
array_add_assign_simple::__expand(&mut context, array.into());
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
format!("{:?}", scope.operations),
|
||||
inline_macro_array_add_assign_simple()
|
||||
);
|
||||
assert_eq!(scope.operations, inline_macro_array_add_assign_simple());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -72,7 +69,7 @@ mod tests {
|
|||
|
||||
array_to_vectorized_variable::__expand::<ElemType>(&mut context);
|
||||
assert_eq!(
|
||||
format!("{:?}", context.into_scope().operations),
|
||||
context.into_scope().operations,
|
||||
inline_macro_ref_to_vectorized()
|
||||
);
|
||||
}
|
||||
|
@ -83,12 +80,12 @@ mod tests {
|
|||
|
||||
array_of_one_to_vectorized_variable::__expand::<ElemType>(&mut context);
|
||||
assert_eq!(
|
||||
format!("{:?}", context.into_scope().operations),
|
||||
context.into_scope().operations,
|
||||
inline_macro_ref_one_to_vectorized()
|
||||
);
|
||||
}
|
||||
|
||||
fn inline_macro_ref_read_write() -> String {
|
||||
fn inline_macro_ref_read_write() -> Vec<ir::Operation> {
|
||||
let context = CubeContext::root();
|
||||
let item = Item::new(ElemType::as_elem());
|
||||
|
||||
|
@ -105,7 +102,7 @@ mod tests {
|
|||
// Read
|
||||
cpa!(scope, var = array[pos]);
|
||||
|
||||
format!("{:?}", scope.operations)
|
||||
scope.operations
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -116,30 +113,36 @@ mod tests {
|
|||
array_add_assign_expr::__expand(&mut context, array.into());
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
format!("{:?}", scope.operations),
|
||||
inline_macro_array_add_assign_expr()
|
||||
);
|
||||
assert_eq!(scope.operations, inline_macro_array_add_assign_expr());
|
||||
}
|
||||
|
||||
fn inline_macro_array_add_assign_simple() -> String {
|
||||
fn inline_macro_array_add_assign_simple() -> Vec<ir::Operation> {
|
||||
let context = CubeContext::root();
|
||||
|
||||
let mut scope = context.into_scope();
|
||||
let local = scope.create_local(Item::new(Elem::UInt));
|
||||
|
||||
let array = Variable::GlobalInputArray(0, Item::new(Elem::UInt));
|
||||
let index = Variable::ConstantScalar(1., Elem::UInt);
|
||||
let value = Variable::ConstantScalar(1., Elem::UInt);
|
||||
let array = Variable::GlobalInputArray {
|
||||
id: 0,
|
||||
item: Item::new(Elem::UInt),
|
||||
};
|
||||
let index = Variable::ConstantScalar {
|
||||
value: 1.,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let value = Variable::ConstantScalar {
|
||||
value: 1.,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
|
||||
cpa!(scope, local = array[index]);
|
||||
cpa!(scope, local += value);
|
||||
cpa!(scope, array[index] = local);
|
||||
|
||||
format!("{:?}", scope.operations)
|
||||
scope.operations
|
||||
}
|
||||
|
||||
fn inline_macro_ref_to_vectorized() -> String {
|
||||
fn inline_macro_ref_to_vectorized() -> Vec<ir::Operation> {
|
||||
let context = CubeContext::root();
|
||||
let scalar_item = Item::new(ElemType::as_elem());
|
||||
let vectorized_item = Item::vectorized(ElemType::as_elem(), 2);
|
||||
|
@ -158,10 +161,10 @@ mod tests {
|
|||
cpa!(scope, tmp = array[pos1]);
|
||||
cpa!(scope, vectorized_var[pos1] = tmp);
|
||||
|
||||
format!("{:?}", scope.operations)
|
||||
scope.operations
|
||||
}
|
||||
|
||||
fn inline_macro_ref_one_to_vectorized() -> String {
|
||||
fn inline_macro_ref_one_to_vectorized() -> Vec<ir::Operation> {
|
||||
let context = CubeContext::root();
|
||||
let scalar_item = Item::new(ElemType::as_elem());
|
||||
let unvectorized_item = Item::new(ElemType::as_elem());
|
||||
|
@ -176,26 +179,38 @@ mod tests {
|
|||
cpa!(scope, tmp = array[pos0]);
|
||||
cpa!(scope, unvectorized_var = tmp);
|
||||
|
||||
format!("{:?}", scope.operations)
|
||||
scope.operations
|
||||
}
|
||||
|
||||
fn inline_macro_array_add_assign_expr() -> String {
|
||||
fn inline_macro_array_add_assign_expr() -> Vec<ir::Operation> {
|
||||
let context = CubeContext::root();
|
||||
|
||||
let mut scope = context.into_scope();
|
||||
let index = scope.create_local(Item::new(Elem::UInt));
|
||||
let local = scope.create_local(Item::new(Elem::UInt));
|
||||
|
||||
let array = Variable::GlobalInputArray(0, Item::new(Elem::UInt));
|
||||
let const1 = Variable::ConstantScalar(1., Elem::UInt);
|
||||
let const2 = Variable::ConstantScalar(5., Elem::UInt);
|
||||
let value = Variable::ConstantScalar(1., Elem::UInt);
|
||||
let array = Variable::GlobalInputArray {
|
||||
id: 0,
|
||||
item: Item::new(Elem::UInt),
|
||||
};
|
||||
let const1 = Variable::ConstantScalar {
|
||||
value: 1.,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let const2 = Variable::ConstantScalar {
|
||||
value: 5.,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let value = Variable::ConstantScalar {
|
||||
value: 1.,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
|
||||
cpa!(scope, index = const1 + const2);
|
||||
cpa!(scope, local = array[index]);
|
||||
cpa!(scope, local += value);
|
||||
cpa!(scope, array[index] = local);
|
||||
|
||||
format!("{:?}", scope.operations)
|
||||
scope.operations
|
||||
}
|
||||
}
|
||||
|
|
|
@ -98,8 +98,14 @@ mod tests {
|
|||
let mut scope = context.into_scope();
|
||||
let x = scope.create_local(Item::new(Elem::UInt));
|
||||
|
||||
let zero = Variable::ConstantScalar(0., Elem::UInt);
|
||||
let one = Variable::ConstantScalar(1., Elem::UInt);
|
||||
let zero = Variable::ConstantScalar {
|
||||
value: 0.,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let one = Variable::ConstantScalar {
|
||||
value: 1.,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
cpa!(scope, x = zero);
|
||||
cpa!(scope, x = x + one);
|
||||
|
||||
|
@ -115,8 +121,14 @@ mod tests {
|
|||
let y: Variable = y.into();
|
||||
let x = scope.create_local(item);
|
||||
|
||||
let one = Variable::ConstantScalar(1., Elem::UInt);
|
||||
let two = Variable::ConstantScalar(2., Elem::UInt);
|
||||
let one = Variable::ConstantScalar {
|
||||
value: 1.,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let two = Variable::ConstantScalar {
|
||||
value: 2.,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
cpa!(scope, x = y);
|
||||
cpa!(scope, x = x + one);
|
||||
cpa!(scope, y = y + two);
|
||||
|
@ -133,8 +145,14 @@ mod tests {
|
|||
let y: Variable = y.into();
|
||||
let x = scope.create_local(item);
|
||||
|
||||
let one = Variable::ConstantScalar(1., Elem::UInt);
|
||||
let two = Variable::ConstantScalar(2., Elem::UInt);
|
||||
let one = Variable::ConstantScalar {
|
||||
value: 1.,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let two = Variable::ConstantScalar {
|
||||
value: 2.,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
cpa!(scope, x = y);
|
||||
cpa!(scope, y = y + one);
|
||||
cpa!(scope, x = x + two);
|
||||
|
@ -151,10 +169,22 @@ mod tests {
|
|||
let y: Variable = y.into();
|
||||
let x = scope.create_local(item);
|
||||
|
||||
let zero = Variable::ConstantScalar(0., Elem::UInt);
|
||||
let one = Variable::ConstantScalar(1., Elem::UInt);
|
||||
let two = Variable::ConstantScalar(2., Elem::UInt);
|
||||
let three = Variable::ConstantScalar(3., Elem::UInt);
|
||||
let zero = Variable::ConstantScalar {
|
||||
value: 0.,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let one = Variable::ConstantScalar {
|
||||
value: 1.,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let two = Variable::ConstantScalar {
|
||||
value: 2.,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let three = Variable::ConstantScalar {
|
||||
value: 3.,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
cpa!(scope, x[zero] = one);
|
||||
cpa!(scope, x[one] = one);
|
||||
cpa!(scope, x[two] = one);
|
||||
@ -395,31 +395,31 @@ mod tests {
|
|||
let out_number = if in_type == out_type { 0 } else { 2 };
|
||||
format!(
|
||||
"[Operator({ops_name}(BinaryOperator {{ \
|
||||
lhs: Local(0, Item {{ \
|
||||
lhs: Local {{ id: 0, item: Item {{ \
|
||||
elem: {in_type}, \
|
||||
vectorization: 1 \
|
||||
}}, 0), \
|
||||
rhs: Local(1, Item {{ \
|
||||
}}, depth: 0 }}, \
|
||||
rhs: Local {{ id: 1, item: Item {{ \
|
||||
elem: {in_type}, \
|
||||
vectorization: 1 \
|
||||
}}, 0), \
|
||||
out: Local({out_number}, Item {{ \
|
||||
}}, depth: 0 }}, \
|
||||
out: Local {{ id: {out_number}, item: Item {{ \
|
||||
elem: {out_type}, \
|
||||
vectorization: 1 \
|
||||
}}, 0) \
|
||||
}}, depth: 0 }} \
|
||||
}}))]"
|
||||
)
|
||||
} else {
|
||||
format!(
|
||||
"[Operator({ops_name}(UnaryOperator {{ \
|
||||
input: Local(0, Item {{ \
|
||||
input: Local {{ id: 0, item: Item {{ \
|
||||
elem: {in_type}, \
|
||||
vectorization: 1 \
|
||||
}}, 0), \
|
||||
out: Local(0, Item {{ \
|
||||
}}, depth: 0 }}, \
|
||||
out: Local {{ id: 0, item: Item {{ \
|
||||
elem: {out_type}, \
|
||||
vectorization: 1 \
|
||||
}}, 0) \
|
||||
}}, depth: 0 }} \
|
||||
}}))]"
|
||||
)
|
||||
}
|
||||
@ -117,7 +117,10 @@ mod tests {
|
|||
|
||||
let i = scope.create_with_value(1, item);
|
||||
cpa!(scope, x += i);
|
||||
let value = Variable::ConstantScalar(2., item.elem());
|
||||
let value = Variable::ConstantScalar {
|
||||
value: 2.,
|
||||
elem: item.elem(),
|
||||
};
|
||||
cpa!(scope, i = value);
|
||||
cpa!(scope, x += i);
|
||||
|
||||
|
@ -156,7 +159,10 @@ mod tests {
|
|||
cpa!(
|
||||
&mut scope,
|
||||
range(0u32, end, false).for_each(|_, scope| {
|
||||
let value = Variable::ConstantScalar(2.into(), item.elem());
|
||||
let value = Variable::ConstantScalar {
|
||||
value: 2.into(),
|
||||
elem: item.elem(),
|
||||
};
|
||||
cpa!(scope, y = value);
|
||||
cpa!(scope, x += y);
|
||||
})
|
||||
@ -83,6 +83,9 @@ impl CudaCompiler {
|
|||
let processing = value.process();
|
||||
|
||||
for var in processing.variables {
|
||||
if let gpu::Variable::Slice { .. } = var {
|
||||
continue;
|
||||
}
|
||||
instructions.push(Instruction::DeclareVariable {
|
||||
var: self.compile_variable(var),
|
||||
});
|
||||
|
@ -192,8 +195,8 @@ impl CudaCompiler {
|
|||
gpu::Metadata::Stride { dim, var, out } => {
|
||||
self.stride = true;
|
||||
let position = match var {
|
||||
gpu::Variable::GlobalInputArray(idx, _) => idx as usize,
|
||||
gpu::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
|
||||
gpu::Variable::GlobalInputArray { id, .. } => id as usize,
|
||||
gpu::Variable::GlobalOutputArray { id, .. } => self.num_inputs + id as usize,
|
||||
_ => panic!("Only Input and Output have a stride, got: {:?}", var),
|
||||
};
|
||||
Instruction::Stride {
|
||||
|
@ -205,8 +208,8 @@ impl CudaCompiler {
|
|||
gpu::Metadata::Shape { dim, var, out } => {
|
||||
self.shape = true;
|
||||
let position = match var {
|
||||
gpu::Variable::GlobalInputArray(idx, _) => idx as usize,
|
||||
gpu::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
|
||||
gpu::Variable::GlobalInputArray { id, .. } => id as usize,
|
||||
gpu::Variable::GlobalOutputArray { id, .. } => self.num_inputs + id as usize,
|
||||
_ => panic!("Only Input and Output have a shape, got {:?}", var),
|
||||
};
|
||||
Instruction::Shape {
|
||||
|
@ -215,12 +218,20 @@ impl CudaCompiler {
|
|||
out: self.compile_variable(out),
|
||||
}
|
||||
}
|
||||
gpu::Metadata::ArrayLength { var, out } => super::Instruction::ArrayLength {
|
||||
input: self.compile_variable(var),
|
||||
out: self.compile_variable(out),
|
||||
num_inputs: self.num_inputs,
|
||||
num_outputs: self.num_outputs,
|
||||
},
|
||||
gpu::Metadata::Length { var, out } => {
|
||||
let input = self.compile_variable(var);
|
||||
let out = self.compile_variable(out);
|
||||
|
||||
match input {
|
||||
super::Variable::Slice { .. } => super::Instruction::SliceLength { input, out },
|
||||
_ => super::Instruction::Length {
|
||||
input,
|
||||
out,
|
||||
num_inputs: self.num_inputs,
|
||||
num_outputs: self.num_outputs,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
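The `Metadata::Length` arm above picks a different instruction depending on whether the compiled input is a slice (whose length is known locally) or a global buffer (whose length lives in the kernel's metadata). A hedged, self-contained sketch of that dispatch follows; the `Var`/`Instr` enums and the naming scheme are simplified placeholders, not the real compiler types.

// Illustrative sketch: choose a slice-length or buffer-length instruction.
#[derive(Debug)]
enum Var {
    Slice { id: u16 },
    GlobalInputArray { id: u16 },
}

#[derive(Debug)]
enum Instr {
    // A slice carries its own length (end - start), so no metadata lookup is needed.
    SliceLength { input: String, out: String },
    // A global buffer length is read from the kernel info, indexed past inputs/outputs.
    Length { input: String, out: String, num_inputs: usize, num_outputs: usize },
}

fn compile_length(var: Var, out: String, num_inputs: usize, num_outputs: usize) -> Instr {
    let input = match &var {
        Var::Slice { id } => format!("slice_0_{id}"),
        Var::GlobalInputArray { id } => format!("input_{id}"),
    };
    match var {
        Var::Slice { .. } => Instr::SliceLength { input, out },
        _ => Instr::Length { input, out, num_inputs, num_outputs },
    }
}

fn main() {
    let a = compile_length(Var::Slice { id: 3 }, "l_0_0".into(), 2, 1);
    let b = compile_length(Var::GlobalInputArray { id: 0 }, "l_0_1".into(), 2, 1);
    println!("{a:?}\n{b:?}");
}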
@ -297,6 +308,12 @@ impl CudaCompiler {
|
|||
gpu::Operator::Div(op) => Instruction::Div(self.compile_binary(op)),
|
||||
gpu::Operator::Sub(op) => Instruction::Sub(self.compile_binary(op)),
|
||||
gpu::Operator::Assign(op) => Instruction::Assign(self.compile_unary(op)),
|
||||
gpu::Operator::Slice(op) => Instruction::Slice {
|
||||
input: self.compile_variable(op.input),
|
||||
start: self.compile_variable(op.start),
|
||||
end: self.compile_variable(op.end),
|
||||
out: self.compile_variable(op.out),
|
||||
},
|
||||
gpu::Operator::Index(op) => Instruction::Index(self.compile_binary(op)),
|
||||
gpu::Operator::UncheckedIndex(op) => Instruction::Index(self.compile_binary(op)),
|
||||
gpu::Operator::IndexAssign(op) => Instruction::IndexAssign(self.compile_binary(op)),
|
||||
|
@ -372,35 +389,40 @@ impl CudaCompiler {
|
|||
|
||||
fn compile_variable(&mut self, value: gpu::Variable) -> super::Variable {
|
||||
match value {
|
||||
gpu::Variable::GlobalInputArray(index, item) => {
|
||||
super::Variable::GlobalInputArray(index, Self::compile_item(item))
|
||||
gpu::Variable::GlobalInputArray { id, item } => {
|
||||
super::Variable::GlobalInputArray(id, Self::compile_item(item))
|
||||
}
|
||||
gpu::Variable::GlobalScalar(index, elem) => {
|
||||
super::Variable::GlobalScalar(index, Self::compile_elem(elem), elem)
|
||||
gpu::Variable::GlobalScalar { id, elem } => {
|
||||
super::Variable::GlobalScalar(id, Self::compile_elem(elem), elem)
|
||||
}
|
||||
gpu::Variable::Local(index, item, scope_depth) => super::Variable::Local {
|
||||
index,
|
||||
gpu::Variable::Local { id, item, depth } => super::Variable::Local {
|
||||
id,
|
||||
item: Self::compile_item(item),
|
||||
scope_depth,
|
||||
depth,
|
||||
},
|
||||
gpu::Variable::LocalScalar(index, elem, scope_depth) => super::Variable::LocalScalar {
|
||||
index,
|
||||
gpu::Variable::Slice { id, item, depth } => super::Variable::Slice {
|
||||
id,
|
||||
item: Self::compile_item(item),
|
||||
depth,
|
||||
},
|
||||
gpu::Variable::LocalScalar { id, elem, depth } => super::Variable::LocalScalar {
|
||||
id,
|
||||
elem: Self::compile_elem(elem),
|
||||
scope_depth,
|
||||
depth,
|
||||
},
|
||||
gpu::Variable::GlobalOutputArray(index, item) => {
|
||||
super::Variable::GlobalOutputArray(index, Self::compile_item(item))
|
||||
gpu::Variable::GlobalOutputArray { id, item } => {
|
||||
super::Variable::GlobalOutputArray(id, Self::compile_item(item))
|
||||
}
|
||||
gpu::Variable::ConstantScalar(index, elem) => {
|
||||
super::Variable::ConstantScalar(index, Self::compile_elem(elem))
|
||||
gpu::Variable::ConstantScalar { value, elem } => {
|
||||
super::Variable::ConstantScalar(value, Self::compile_elem(elem))
|
||||
}
|
||||
gpu::Variable::SharedMemory(index, item, size) => {
|
||||
gpu::Variable::SharedMemory { id, item, length } => {
|
||||
let item = Self::compile_item(item);
|
||||
if !self.shared_memories.iter().any(|s| s.index == index) {
|
||||
if !self.shared_memories.iter().any(|s| s.index == id) {
|
||||
self.shared_memories
|
||||
.push(super::SharedMemory::new(index, item, size));
|
||||
.push(super::SharedMemory::new(id, item, length));
|
||||
}
|
||||
super::Variable::SharedMemory(index, item, size)
|
||||
super::Variable::SharedMemory(id, item, length)
|
||||
}
|
||||
gpu::Variable::AbsolutePos => {
|
||||
self.id = true;
|
||||
|
@ -438,7 +460,12 @@ impl CudaCompiler {
|
|||
gpu::Variable::CubeCountX => super::Variable::NumWorkgroupsX,
|
||||
gpu::Variable::CubeCountY => super::Variable::NumWorkgroupsY,
|
||||
gpu::Variable::CubeCountZ => super::Variable::NumWorkgroupsZ,
|
||||
gpu::Variable::LocalArray(id, item, depth, size) => {
|
||||
gpu::Variable::LocalArray {
|
||||
id,
|
||||
item,
|
||||
depth,
|
||||
length,
|
||||
} => {
|
||||
let item = Self::compile_item(item);
|
||||
if !self
|
||||
.local_arrays
|
||||
|
@ -446,19 +473,19 @@ impl CudaCompiler {
|
|||
.any(|s| s.index == id && s.depth == depth)
|
||||
{
|
||||
self.local_arrays
|
||||
.push(super::LocalArray::new(id, item, depth, size));
|
||||
.push(super::LocalArray::new(id, item, depth, length));
|
||||
}
|
||||
super::Variable::LocalArray(id, item, depth, size)
|
||||
super::Variable::LocalArray(id, item, depth, length)
|
||||
}
|
||||
gpu::Variable::CubePos => todo!(),
|
||||
gpu::Variable::CubeDim => todo!(),
|
||||
gpu::Variable::CubeCount => todo!(),
|
||||
gpu::Variable::SubcubeDim => todo!(),
|
||||
gpu::Variable::Matrix(index, matrix) => {
|
||||
gpu::Variable::Matrix { id, mat } => {
|
||||
self.wmma = true;
|
||||
super::Variable::WmmaFragment {
|
||||
index,
|
||||
frag: Self::compile_matrix(matrix),
|
||||
id,
|
||||
frag: Self::compile_matrix(mat),
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -361,9 +361,9 @@ impl Binary for IndexAssign {
|
|||
out: &Variable,
|
||||
) -> std::fmt::Result {
|
||||
if let Variable::Local {
|
||||
index: _,
|
||||
id: _,
|
||||
item: _,
|
||||
scope_depth: _,
|
||||
depth: _,
|
||||
} = out
|
||||
{
|
||||
return IndexAssignVector::format(f, lhs, rhs, out);
|
||||
|
@ -388,9 +388,9 @@ impl Binary for Index {
|
|||
out: &Variable,
|
||||
) -> std::fmt::Result {
|
||||
if let Variable::Local {
|
||||
index: _,
|
||||
id: _,
|
||||
item: _,
|
||||
scope_depth: _,
|
||||
depth: _,
|
||||
} = lhs
|
||||
{
|
||||
return IndexVector::format(f, lhs, rhs, out);
|
||||
@ -86,9 +86,14 @@ impl Component for Variable {
|
|||
Variable::GlobalOutputArray(_, e) => *e,
|
||||
Variable::SharedMemory(_, e, _) => *e,
|
||||
Variable::Local {
|
||||
index: _,
|
||||
id: _,
|
||||
item,
|
||||
scope_depth: _,
|
||||
depth: _,
|
||||
} => *item,
|
||||
Variable::Slice {
|
||||
id: _,
|
||||
item,
|
||||
depth: _,
|
||||
} => *item,
|
||||
Variable::ConstantScalar(_, e) => Item::Scalar(*e),
|
||||
Variable::GlobalScalar(_, e, _) => Item::Scalar(*e),
|
||||
|
@ -99,9 +104,9 @@ impl Component for Variable {
|
|||
Variable::LocalInvocationIdZ => Item::Scalar(Elem::U32),
|
||||
Variable::Rank => Item::Scalar(Elem::U32),
|
||||
Variable::LocalScalar {
|
||||
index: _,
|
||||
id: _,
|
||||
elem,
|
||||
scope_depth: _,
|
||||
depth: _,
|
||||
} => Item::Scalar(*elem),
|
||||
Variable::WorkgroupIdX => Item::Scalar(Elem::U32),
|
||||
Variable::WorkgroupIdY => Item::Scalar(Elem::U32),
|
||||
|
@ -117,7 +122,7 @@ impl Component for Variable {
|
|||
Variable::NumWorkgroupsZ => Item::Scalar(Elem::U32),
|
||||
Variable::LocalArray(_, e, _, _) => *e,
|
||||
Variable::WarpSize => Item::Scalar(Elem::U32),
|
||||
Variable::WmmaFragment { index: _, frag } => Item::Scalar(frag.elem),
|
||||
Variable::WmmaFragment { id: _, frag } => Item::Scalar(frag.elem),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -129,16 +134,9 @@ pub enum Variable {
GlobalOutputArray(u16, Item),
GlobalScalar(u16, Elem, gpu::Elem),
ConstantScalar(f64, Elem),
Local {
index: u16,
item: Item,
scope_depth: u8,
},
LocalScalar {
index: u16,
elem: Elem,
scope_depth: u8,
},
Local { id: u16, item: Item, depth: u8 },
Slice { id: u16, item: Item, depth: u8 },
LocalScalar { id: u16, elem: Elem, depth: u8 },
SharedMemory(u16, Item, u32),
LocalArray(u16, Item, u8, u32),
Id,

@ -159,10 +157,7 @@ pub enum Variable {
NumWorkgroupsX,
NumWorkgroupsY,
NumWorkgroupsZ,
WmmaFragment {
index: u16,
frag: Fragment,
},
WmmaFragment { id: u16, frag: Fragment },
}

impl Display for Variable {
@ -170,15 +165,18 @@ impl Display for Variable {
|
|||
match self {
|
||||
Variable::GlobalInputArray(number, _) => f.write_fmt(format_args!("input_{number}")),
|
||||
Variable::LocalScalar {
|
||||
index,
|
||||
id: index,
|
||||
elem: _,
|
||||
scope_depth,
|
||||
depth: scope_depth,
|
||||
} => f.write_fmt(format_args!("s_{scope_depth}_{index}")),
|
||||
Variable::Local {
|
||||
index,
|
||||
id: index,
|
||||
item: _,
|
||||
scope_depth,
|
||||
depth: scope_depth,
|
||||
} => f.write_fmt(format_args!("l_{scope_depth}_{index}")),
|
||||
Variable::Slice { id, item: _, depth } => {
|
||||
f.write_fmt(format_args!("slice_{depth}_{id}"))
|
||||
}
|
||||
Variable::GlobalOutputArray(number, _) => f.write_fmt(format_args!("output_{number}")),
|
||||
Variable::GlobalScalar(number, _, elem) => {
|
||||
f.write_fmt(format_args!("scalars_{elem}[{number}]"))
|
||||
|
@ -209,7 +207,9 @@ impl Display for Variable {
|
|||
f.write_fmt(format_args!("l_arr_{}_{}", id, depth))
|
||||
}
|
||||
Variable::WarpSize => f.write_str("warpSize"),
|
||||
Variable::WmmaFragment { index, frag: _ } => f.write_fmt(format_args!("frag_{index}")),
|
||||
Variable::WmmaFragment { id: index, frag: _ } => {
|
||||
f.write_fmt(format_args!("frag_{index}"))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -220,9 +220,9 @@ impl Variable {
|
|||
Variable::GlobalScalar(_, _, _) => true,
|
||||
Variable::ConstantScalar(_, _) => true,
|
||||
Variable::LocalScalar {
|
||||
index: _,
|
||||
id: _,
|
||||
elem: _,
|
||||
scope_depth: _,
|
||||
depth: _,
|
||||
} => true,
|
||||
Variable::Id => true,
|
||||
Variable::LocalInvocationIndex => true,
|
||||
|
@ -234,9 +234,14 @@ impl Variable {
|
|||
Variable::GlobalOutputArray(_, _) => false,
|
||||
Variable::SharedMemory(_, _, _) => false,
|
||||
Variable::Local {
|
||||
index: _,
|
||||
id: _,
|
||||
item: _,
|
||||
scope_depth: _,
|
||||
depth: _,
|
||||
} => false,
|
||||
Variable::Slice {
|
||||
id: _,
|
||||
item: _,
|
||||
depth: _,
|
||||
} => false,
|
||||
Variable::WorkgroupIdX => true,
|
||||
Variable::WorkgroupIdY => true,
|
||||
|
@ -252,7 +257,7 @@ impl Variable {
|
|||
Variable::NumWorkgroupsZ => true,
|
||||
Variable::LocalArray(_, _, _, _) => false,
|
||||
Variable::WarpSize => true,
|
||||
Variable::WmmaFragment { index: _, frag: _ } => false,
|
||||
Variable::WmmaFragment { id: _, frag: _ } => false,
|
||||
}
|
||||
}
|
||||
|
||||
@ -16,12 +16,16 @@ pub struct UnaryInstruction {
|
|||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Instruction {
|
||||
ArrayLength {
|
||||
Length {
|
||||
input: Variable,
|
||||
out: Variable,
|
||||
num_inputs: usize,
|
||||
num_outputs: usize,
|
||||
},
|
||||
SliceLength {
|
||||
input: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
DeclareVariable {
|
||||
var: Variable,
|
||||
},
|
||||
|
@ -58,6 +62,12 @@ pub enum Instruction {
|
|||
instructions_if: Vec<Self>,
|
||||
instructions_else: Vec<Self>,
|
||||
},
|
||||
Slice {
|
||||
input: Variable,
|
||||
start: Variable,
|
||||
end: Variable,
|
||||
out: Variable,
|
||||
},
|
||||
Return,
|
||||
Break,
|
||||
Stride {
|
||||
|
@ -114,7 +124,7 @@ impl Display for Instruction {
|
|||
Instruction::Return => f.write_str("return;"),
|
||||
Instruction::Break => f.write_str("break;"),
|
||||
Instruction::DeclareVariable { var } => match var {
|
||||
Variable::WmmaFragment { index: _, frag } => {
|
||||
Variable::WmmaFragment { id: _, frag } => {
|
||||
f.write_fmt(format_args!("{frag} {var};\n"))
|
||||
}
|
||||
_ => {
|
||||
|
@ -123,6 +133,16 @@ impl Display for Instruction {
|
|||
}
|
||||
},
|
||||
Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
|
||||
Instruction::Slice {
|
||||
input,
|
||||
start,
|
||||
end,
|
||||
out,
|
||||
} => {
|
||||
let item = out.item();
|
||||
f.write_fmt(format_args!("uint {out}_length = {end} - {start};\n"))?;
|
||||
f.write_fmt(format_args!("{item} *{out} = {input} + {start};\n"))
|
||||
}
|
||||
Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
|
||||
Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
|
||||
Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
|
||||
|
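The `Instruction::Slice` arm above formats the slice as plain pointer arithmetic in the generated source: a `<out>_length` helper plus an offset into the parent buffer, which is also why `SliceLength` can later read `{input}_length` directly. A small standalone sketch of that emission follows; the variable names passed in `main` are made up for the demo.

// Illustrative sketch: the two lines of kernel source a slice lowers to.
use std::fmt::Write;

fn emit_slice(input: &str, start: &str, end: &str, out: &str, item: &str) -> String {
    let mut src = String::new();
    // The slice length is kept in a companion `<out>_length` variable...
    let _ = writeln!(src, "uint {out}_length = {end} - {start};");
    // ...and the slice itself is just a pointer into the original array.
    let _ = writeln!(src, "{item} *{out} = {input} + {start};");
    src
}

fn main() {
    print!("{}", emit_slice("input_0", "l_0_0", "l_0_1", "slice_0_0", "float"));
}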
@ -225,7 +245,10 @@ for (uint {i} = {start}; {i} < {end}; {i}++) {{
|
|||
Instruction::SyncThreads => f.write_str("__syncthreads();\n"),
|
||||
Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
|
||||
Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
|
||||
Instruction::ArrayLength {
|
||||
Instruction::SliceLength { input, out } => {
|
||||
f.write_fmt(format_args!("{out} = {input}_length;\n"))
|
||||
}
|
||||
Instruction::Length {
|
||||
input,
|
||||
out,
|
||||
num_inputs,
|
||||
@ -288,7 +288,10 @@ impl<R: JitRuntime> ElementWiseBuilder<R> {
|
|||
return false;
|
||||
}
|
||||
|
||||
let input = Variable::ConstantScalar(1.0, desc.dtype.into());
|
||||
let input = Variable::ConstantScalar {
|
||||
value: 1.0,
|
||||
elem: desc.dtype.into(),
|
||||
};
|
||||
let out = self.builder.output(desc, Variable::AbsolutePos);
|
||||
|
||||
self.builder
|
||||
|
@ -301,7 +304,10 @@ impl<R: JitRuntime> ElementWiseBuilder<R> {
|
|||
return false;
|
||||
}
|
||||
|
||||
let input = Variable::ConstantScalar(0.0, desc.dtype.into());
|
||||
let input = Variable::ConstantScalar {
|
||||
value: 0.0,
|
||||
elem: desc.dtype.into(),
|
||||
};
|
||||
let out = self.builder.output(desc, Variable::AbsolutePos);
|
||||
|
||||
self.builder
|
||||
@ -56,9 +56,11 @@ impl TraceBuilder {
|
|||
}
|
||||
true => match self.output_to_local.get(&tensor.id) {
|
||||
// Is a local variable.
|
||||
Some(local_index) => {
|
||||
Variable::Local(*local_index, Item::new(elem), self.scope.depth)
|
||||
}
|
||||
Some(local_index) => Variable::Local {
|
||||
id: *local_index,
|
||||
item: Item::new(elem),
|
||||
depth: self.scope.depth,
|
||||
},
|
||||
// Isn't an operation output variable, so must be an existing input.
|
||||
None => self
|
||||
.inputs
|
||||
|
@ -85,7 +87,11 @@ impl TraceBuilder {
|
|||
|
||||
// Output already registered as a local variable.
|
||||
if let Some(index) = self.output_to_local.get(&tensor.id) {
|
||||
return Variable::Local(*index, Item::new(elem), self.scope.depth);
|
||||
return Variable::Local {
|
||||
id: *index,
|
||||
item: Item::new(elem),
|
||||
depth: self.scope.depth,
|
||||
};
|
||||
}
|
||||
|
||||
let variable = self.scope.create_local(Item::new(elem));
|
||||
|
@ -159,11 +165,11 @@ impl TraceBuilder {
|
|||
//
|
||||
// Only local variables can become outputs.
|
||||
let mark = |var: &Variable, list: &mut Vec<TensorId>| {
|
||||
if let Variable::Local(index, _, _) = var {
|
||||
if let Variable::Local { id: id_local, .. } = var {
|
||||
if let Some((id, _)) = self
|
||||
.output_to_local
|
||||
.iter()
|
||||
.find(|(_id, position)| *position == index)
|
||||
.find(|(_tensor_id, position)| *position == id_local)
|
||||
{
|
||||
if !list.contains(id) {
|
||||
list.push(*id);
|
||||
|
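The `mark` closure above walks from a fused operation's variables back to tensor ids: only `Variable::Local` values can correspond to registered outputs, and the lookup matches the local id against the `output_to_local` map. A reduced, runnable sketch of that logic follows, with plain placeholder types standing in for the real trace-builder structures.

// Illustrative sketch: resolve a local variable back to the tensor it produces.
use std::collections::HashMap;

type TensorId = u64;

#[derive(Debug, Clone, Copy)]
enum Variable {
    Local { id: u16, depth: u8 },
    Constant { value: f64 },
}

fn mark(var: &Variable, output_to_local: &HashMap<TensorId, u16>, list: &mut Vec<TensorId>) {
    // Only local variables can become fusion outputs.
    if let Variable::Local { id: id_local, .. } = var {
        if let Some((tensor_id, _)) = output_to_local
            .iter()
            .find(|(_tensor_id, position)| *position == id_local)
        {
            if !list.contains(tensor_id) {
                list.push(*tensor_id);
            }
        }
    }
}

fn main() {
    let mut output_to_local = HashMap::new();
    output_to_local.insert(42u64, 3u16);
    let mut outputs = Vec::new();
    mark(&Variable::Local { id: 3, depth: 0 }, &output_to_local, &mut outputs);
    mark(&Variable::Constant { value: 1.0 }, &output_to_local, &mut outputs);
    println!("{outputs:?}"); // [42]
}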
@ -201,6 +207,11 @@ impl TraceBuilder {
|
|||
mark(&op.c, &mut local_tensor_ids_input);
|
||||
mark(&op.out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::Slice(op) => {
|
||||
mark(&op.input, &mut local_tensor_ids_input);
|
||||
mark(&op.start, &mut local_tensor_ids_input);
|
||||
mark(&op.out, &mut local_tensor_ids_output);
|
||||
}
|
||||
Operator::Max(op) => mark_binary(
|
||||
op,
|
||||
&mut local_tensor_ids_input,
|
||||
@ -68,8 +68,14 @@ impl<R: JitRuntime, EI: JitElement, EO: JitElement> Kernel for CastEagerKernel<R
|
|||
let item_input = EI::cube_elem().into();
|
||||
let item_output = EO::cube_elem().into();
|
||||
|
||||
let tensor = Variable::GlobalInputArray(0, item_input);
|
||||
let output = Variable::GlobalOutputArray(0, item_output);
|
||||
let tensor = Variable::GlobalInputArray {
|
||||
id: 0,
|
||||
item: item_input,
|
||||
};
|
||||
let output = Variable::GlobalOutputArray {
|
||||
id: 0,
|
||||
item: item_output,
|
||||
};
|
||||
|
||||
CastShader { tensor, output }.expand(&mut scope);
|
||||
|
||||
@ -60,8 +60,14 @@ impl<R: JitRuntime, EO: JitElement> Kernel for BoolCastEagerKernel<R, EO> {
|
|||
let item_input = Item::new(Elem::Bool);
|
||||
let item_output = EO::cube_elem().into();
|
||||
|
||||
let tensor = Variable::GlobalInputArray(0, item_input);
|
||||
let output = Variable::GlobalOutputArray(0, item_output);
|
||||
let tensor = Variable::GlobalInputArray {
|
||||
id: 0,
|
||||
item: item_input,
|
||||
};
|
||||
let output = Variable::GlobalOutputArray {
|
||||
id: 0,
|
||||
item: item_output,
|
||||
};
|
||||
|
||||
BoolCastShader { tensor, output }.expand(&mut scope);
|
||||
|
||||
@ -93,13 +93,34 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
|
|||
cpa!(scope, kernel_size_0 = shape(weight, 2u32));
|
||||
cpa!(scope, kernel_size_1 = shape(weight, 3u32));
|
||||
|
||||
let conv_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
|
||||
let conv_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
|
||||
let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
|
||||
let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
|
||||
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
|
||||
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
|
||||
let groups = Variable::GlobalScalar(6, Elem::UInt);
|
||||
let conv_stride_0 = Variable::GlobalScalar {
|
||||
id: 0,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let conv_stride_1 = Variable::GlobalScalar {
|
||||
id: 1,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let dilation_0 = Variable::GlobalScalar {
|
||||
id: 2,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let dilation_1 = Variable::GlobalScalar {
|
||||
id: 3,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let padding_0 = Variable::GlobalScalar {
|
||||
id: 4,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let padding_1 = Variable::GlobalScalar {
|
||||
id: 5,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let groups = Variable::GlobalScalar {
|
||||
id: 6,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
|
||||
let stride_0_i = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let stride_1_i = scope.create_local(Elem::Int(IntKind::I32));
|
||||
|
@ -294,10 +315,10 @@ impl<R: JitRuntime, E: JitElement> Kernel for Conv2dTransposeEagerKernel<R, E> {
|
|||
let mut scope = Scope::root();
|
||||
let item = E::cube_elem().into();
|
||||
|
||||
let input = Variable::GlobalInputArray(0, item);
|
||||
let weight = Variable::GlobalInputArray(1, item);
|
||||
let bias = Variable::GlobalInputArray(2, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
let input = Variable::GlobalInputArray { id: 0, item };
|
||||
let weight = Variable::GlobalInputArray { id: 1, item };
|
||||
let bias = Variable::GlobalInputArray { id: 2, item };
|
||||
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
@ -105,16 +105,46 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
|
|||
cpa!(scope, kernel_size_1 = shape(weight, 3u32));
|
||||
cpa!(scope, kernel_size_2 = shape(weight, 4u32));
|
||||
|
||||
let conv_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
|
||||
let conv_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
|
||||
let conv_stride_2 = Variable::GlobalScalar(2, Elem::UInt);
|
||||
let dilation_0 = Variable::GlobalScalar(3, Elem::UInt);
|
||||
let dilation_1 = Variable::GlobalScalar(4, Elem::UInt);
|
||||
let dilation_2 = Variable::GlobalScalar(5, Elem::UInt);
|
||||
let padding_0 = Variable::GlobalScalar(6, Elem::UInt);
|
||||
let padding_1 = Variable::GlobalScalar(7, Elem::UInt);
|
||||
let padding_2 = Variable::GlobalScalar(8, Elem::UInt);
|
||||
let groups = Variable::GlobalScalar(9, Elem::UInt);
|
||||
let conv_stride_0 = Variable::GlobalScalar {
|
||||
id: 0,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let conv_stride_1 = Variable::GlobalScalar {
|
||||
id: 1,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let conv_stride_2 = Variable::GlobalScalar {
|
||||
id: 2,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let dilation_0 = Variable::GlobalScalar {
|
||||
id: 3,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let dilation_1 = Variable::GlobalScalar {
|
||||
id: 4,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let dilation_2 = Variable::GlobalScalar {
|
||||
id: 5,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let padding_0 = Variable::GlobalScalar {
|
||||
id: 6,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let padding_1 = Variable::GlobalScalar {
|
||||
id: 7,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let padding_2 = Variable::GlobalScalar {
|
||||
id: 8,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let groups = Variable::GlobalScalar {
|
||||
id: 9,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
|
||||
let stride_0_i = scope.create_local(Elem::Int(IntKind::I32));
|
||||
let stride_1_i = scope.create_local(Elem::Int(IntKind::I32));
|
||||
|
@ -362,10 +392,10 @@ impl<R: JitRuntime, E: JitElement> Kernel for Conv3dTransposeEagerKernel<R, E> {
|
|||
let mut scope = Scope::root();
|
||||
let item = E::cube_elem().into();
|
||||
|
||||
let input = Variable::GlobalInputArray(0, item);
|
||||
let weight = Variable::GlobalInputArray(1, item);
|
||||
let bias = Variable::GlobalInputArray(2, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
let input = Variable::GlobalInputArray { id: 0, item };
|
||||
let weight = Variable::GlobalInputArray { id: 1, item };
|
||||
let bias = Variable::GlobalInputArray { id: 2, item };
|
||||
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
@ -43,7 +43,10 @@ impl FlipComputeShader {
|
|||
cpa!(scope, shape = shape(output, i));
|
||||
cpa!(
|
||||
scope,
|
||||
flip = cast(Variable::GlobalScalar(i as u16, Elem::UInt))
|
||||
flip = cast(Variable::GlobalScalar {
|
||||
id: i as u16,
|
||||
elem: Elem::UInt
|
||||
})
|
||||
);
|
||||
cpa!(scope, flip_bool = flip == 1u32);
|
||||
|
||||
|
@ -70,8 +73,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for FlipEagerKernel<R, E> {
|
|||
let mut scope = Scope::root();
|
||||
let item = E::cube_elem().into();
|
||||
|
||||
let input = Variable::GlobalInputArray(0, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
let input = Variable::GlobalInputArray { id: 0, item };
|
||||
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
@ -27,8 +27,8 @@ struct GatherComputeShader {
|
|||
impl GatherComputeShader {
|
||||
pub fn expand(self, scope: &mut Scope) {
|
||||
match self.tensor {
|
||||
Variable::GlobalInputArray(_, _) => (),
|
||||
Variable::GlobalOutputArray(_, _) => (),
|
||||
Variable::GlobalInputArray { .. } => (),
|
||||
Variable::GlobalOutputArray { .. } => (),
|
||||
_ => panic!("Tensor variable must be an global array."),
|
||||
};
|
||||
|
||||
|
@ -78,10 +78,16 @@ impl<R: JitRuntime, E: JitElement> Kernel for GatherEagerKernel<R, E> {
|
|||
let item_tensor = E::cube_elem().into();
|
||||
let item_indices: Item = Elem::Int(IntKind::I32).into();
|
||||
|
||||
let tensor = Variable::GlobalInputArray(0, item_tensor);
|
||||
let tensor = Variable::GlobalInputArray {
|
||||
id: 0,
|
||||
item: item_tensor,
|
||||
};
|
||||
let indices = scope.read_array(1, item_indices, Variable::AbsolutePos);
|
||||
|
||||
let output_array = Variable::GlobalOutputArray(0, item_tensor);
|
||||
let output_array = Variable::GlobalOutputArray {
|
||||
id: 0,
|
||||
item: item_tensor,
|
||||
};
|
||||
let output_local = scope.create_local(item_tensor);
|
||||
|
||||
GatherComputeShader {
|
||||
@ -61,8 +61,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for RepeatEagerKernel<R, E> {
|
|||
let mut scope = Scope::root();
|
||||
let item = E::cube_elem().into();
|
||||
|
||||
let input = Variable::GlobalInputArray(0, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
let input = Variable::GlobalInputArray { id: 0, item };
|
||||
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
@ -32,13 +32,13 @@ struct ScatterComputeShader {
|
|||
impl ScatterComputeShader {
|
||||
pub fn expand(self, scope: &mut Scope) {
|
||||
match self.input {
|
||||
Variable::GlobalInputArray(_, _) => (),
|
||||
Variable::GlobalOutputArray(_, _) => (),
|
||||
Variable::GlobalInputArray { .. } => (),
|
||||
Variable::GlobalOutputArray { .. } => (),
|
||||
_ => panic!("Input variable must be an global array."),
|
||||
};
|
||||
match self.value {
|
||||
Variable::GlobalInputArray(_, _) => (),
|
||||
Variable::GlobalOutputArray(_, _) => (),
|
||||
Variable::GlobalInputArray { .. } => (),
|
||||
Variable::GlobalOutputArray { .. } => (),
|
||||
_ => panic!("Value variable must be an global array."),
|
||||
};
|
||||
|
||||
|
@ -136,9 +136,18 @@ impl<R: JitRuntime, E: JitElement> Kernel for ScatterEagerKernel<R, E> {
|
|||
let item_value = E::cube_elem().into();
|
||||
let item_indices: Item = Elem::Int(IntKind::I32).into();
|
||||
|
||||
let input_output = Variable::GlobalInputArray(0, item_value);
|
||||
let indices = Variable::GlobalInputArray(1, Elem::Int(IntKind::I32).into());
|
||||
let value = Variable::GlobalInputArray(2, item_value);
|
||||
let input_output = Variable::GlobalInputArray {
|
||||
id: 0,
|
||||
item: item_value,
|
||||
};
|
||||
let indices = Variable::GlobalInputArray {
|
||||
id: 1,
|
||||
item: Elem::Int(IntKind::I32).into(),
|
||||
};
|
||||
let value = Variable::GlobalInputArray {
|
||||
id: 2,
|
||||
item: item_value,
|
||||
};
|
||||
|
||||
scope.write_global_custom(input_output);
|
||||
|
||||
@ -73,9 +73,12 @@ impl<R: JitRuntime, E: JitElement> Kernel for SelectEagerKernel<R, E> {
|
|||
let item = E::cube_elem().into();
|
||||
let item_indices: Item = Elem::Int(IntKind::I32).into();
|
||||
|
||||
let input = Variable::GlobalInputArray(0, item);
|
||||
let indices = Variable::GlobalInputArray(1, item_indices);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
let input = Variable::GlobalInputArray { id: 0, item };
|
||||
let indices = Variable::GlobalInputArray {
|
||||
id: 1,
|
||||
item: item_indices,
|
||||
};
|
||||
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
@ -134,9 +134,12 @@ impl<R: JitRuntime, E: JitElement> Kernel for SelectAssignEagerKernel<R, E> {
|
|||
let item = E::cube_elem().into();
|
||||
let item_indices: Item = Elem::Int(IntKind::I32).into();
|
||||
|
||||
let tensor = Variable::GlobalInputArray(0, item);
|
||||
let value = Variable::GlobalInputArray(1, item);
|
||||
let indices = Variable::GlobalInputArray(2, item_indices);
|
||||
let tensor = Variable::GlobalInputArray { id: 0, item };
|
||||
let value = Variable::GlobalInputArray { id: 1, item };
|
||||
let indices = Variable::GlobalInputArray {
|
||||
id: 2,
|
||||
item: item_indices,
|
||||
};
|
||||
|
||||
scope.write_global_custom(tensor);
|
||||
|
||||
@ -44,7 +44,10 @@ impl SliceComputeShader {
|
|||
cpa!(scope, shape_output = shape(output, i));
|
||||
cpa!(
|
||||
scope,
|
||||
range_start = cast(Variable::GlobalScalar(i as u16, Elem::UInt))
|
||||
range_start = cast(Variable::GlobalScalar {
|
||||
id: i as u16,
|
||||
elem: Elem::UInt
|
||||
})
|
||||
);
|
||||
|
||||
cpa!(scope, offset_local = id / stride_output);
|
||||
|
@ -66,8 +69,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for SliceEagerKernel<R, E> {
|
|||
let mut scope = Scope::root();
|
||||
let item = E::cube_elem().into();
|
||||
|
||||
let input = Variable::GlobalInputArray(0, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
let input = Variable::GlobalInputArray { id: 0, item };
|
||||
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
@ -47,7 +47,10 @@ impl SliceAssignComputeShader {
|
|||
cpa!(scope, shape_input = shape(input, i));
|
||||
cpa!(
|
||||
scope,
|
||||
range_start = cast(Variable::GlobalScalar(i as u16, Elem::UInt))
|
||||
range_start = cast(Variable::GlobalScalar {
|
||||
id: i as u16,
|
||||
elem: Elem::UInt
|
||||
})
|
||||
);
|
||||
|
||||
cpa!(scope, offset_local = id / stride_value);
|
||||
|
@ -75,8 +78,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for SliceAssignEagerKernel<R, E> {
|
|||
let mut scope = Scope::root();
|
||||
let item = E::cube_elem().into();
|
||||
|
||||
let input = Variable::GlobalInputArray(0, item);
|
||||
let value = Variable::GlobalInputArray(1, item);
|
||||
let input = Variable::GlobalInputArray { id: 0, item };
|
||||
let value = Variable::GlobalInputArray { id: 1, item };
|
||||
|
||||
scope.write_global_custom(input);
|
||||
|
||||
@ -371,8 +371,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for InterpolateBicubicEagerKernel<R, E
|
|||
let mut scope = Scope::root();
|
||||
let item = E::cube_elem().into();
|
||||
|
||||
let input = Variable::GlobalInputArray(0, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
let input = Variable::GlobalInputArray { id: 0, item };
|
||||
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||
|
||||
InterpolateBicubicShader {
|
||||
input,
|
||||
@ -189,8 +189,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for InterpolateBilinearEagerKernel<R,
|
|||
let mut scope = Scope::root();
|
||||
let item = E::cube_elem().into();
|
||||
|
||||
let input = Variable::GlobalInputArray(0, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
let input = Variable::GlobalInputArray { id: 0, item };
|
||||
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||
|
||||
InterpolateBilinearShader { input, output }.expand(&mut scope);
|
||||
|
||||
@ -127,8 +127,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for InterpolateNearestEagerKernel<R, E
|
|||
let mut scope = Scope::root();
|
||||
let item = E::cube_elem().into();
|
||||
|
||||
let input = Variable::GlobalInputArray(0, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
let input = Variable::GlobalInputArray { id: 0, item };
|
||||
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||
|
||||
InterpolateNearestShader {
|
||||
input,
|
||||
@ -185,8 +185,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for InterpolateNearestBackwardEagerKer
|
|||
let mut scope = Scope::root();
|
||||
let item = E::cube_elem().into();
|
||||
|
||||
let out_grad = Variable::GlobalInputArray(0, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
let out_grad = Variable::GlobalInputArray { id: 0, item };
|
||||
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||
|
||||
InterpolateNearestBackwardShader {
|
||||
out_grad,
|
||||
@ -40,7 +40,10 @@ impl MaskStrategy for MaskFill {
|
|||
}
|
||||
|
||||
fn value_variable(value_item: Item) -> Variable {
|
||||
Variable::GlobalScalar(0, value_item.elem())
|
||||
Variable::GlobalScalar {
|
||||
id: 0,
|
||||
elem: value_item.elem(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -65,7 +68,10 @@ impl MaskStrategy for MaskWhere {
|
|||
}
|
||||
|
||||
fn value_variable(value_item: Item) -> Variable {
|
||||
Variable::GlobalInputArray(2, value_item)
|
||||
Variable::GlobalInputArray {
|
||||
id: 2,
|
||||
item: value_item,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -102,10 +108,19 @@ impl<M: MaskStrategy, R: JitRuntime, EI: JitElement, EM: JitElement> Kernel
|
|||
let tensor_item = EI::cube_elem().into();
|
||||
let mask_item = EM::cube_elem().into();
|
||||
|
||||
let input = Variable::GlobalInputArray(0, tensor_item);
|
||||
let mask = Variable::GlobalInputArray(1, mask_item);
|
||||
let input = Variable::GlobalInputArray {
|
||||
id: 0,
|
||||
item: tensor_item,
|
||||
};
|
||||
let mask = Variable::GlobalInputArray {
|
||||
id: 1,
|
||||
item: mask_item,
|
||||
};
|
||||
let value = M::value_variable(tensor_item);
|
||||
let output = Variable::GlobalOutputArray(0, tensor_item);
|
||||
let output = Variable::GlobalOutputArray {
|
||||
id: 0,
|
||||
item: tensor_item,
|
||||
};
|
||||
|
||||
MaskShader::<EI, EM, M> {
|
||||
input,
|
||||
|
@ -176,8 +191,14 @@ impl<M: MaskStrategy, R: JitRuntime, EI: JitElement, EM: JitElement> Kernel
|
|||
let tensor_item = EI::cube_elem().into();
|
||||
let mask_item = EM::cube_elem().into();
|
||||
|
||||
let input = Variable::GlobalInputArray(0, tensor_item);
|
||||
let mask = Variable::GlobalInputArray(1, mask_item);
|
||||
let input = Variable::GlobalInputArray {
|
||||
id: 0,
|
||||
item: tensor_item,
|
||||
};
|
||||
let mask = Variable::GlobalInputArray {
|
||||
id: 1,
|
||||
item: mask_item,
|
||||
};
|
||||
let value = M::value_variable(tensor_item);
|
||||
|
||||
MaskShader::<EI, EM, M> {
|
||||
@ -40,9 +40,9 @@ impl<R: JitRuntime, E: JitElement> Kernel for MatmulTiling2dEagerKernel<R, E> {
|
|||
);
|
||||
let item = elem.into();
|
||||
|
||||
let lhs = Variable::GlobalInputArray(0, item);
|
||||
let rhs = Variable::GlobalInputArray(1, item);
|
||||
let out = Variable::GlobalOutputArray(0, item);
|
||||
let lhs = Variable::GlobalInputArray { id: 0, item };
|
||||
let rhs = Variable::GlobalInputArray { id: 1, item };
|
||||
let out = Variable::GlobalOutputArray { id: 0, item };
|
||||
|
||||
scope.write_global_custom(out);
|
||||
|
||||
@ -207,8 +207,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for AdaptiveAvgPool2dBackwardEagerKern
|
|||
let mut scope = Scope::root();
|
||||
let item = E::cube_elem().into();
|
||||
|
||||
let grad = Variable::GlobalInputArray(0, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
let grad = Variable::GlobalInputArray { id: 0, item };
|
||||
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
@ -193,8 +193,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for AdaptivePool2dEagerKernel<R, E> {
|
|||
let mut scope = Scope::root();
|
||||
let item = E::cube_elem().into();
|
||||
|
||||
let input = Variable::GlobalInputArray(0, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
let input = Variable::GlobalInputArray { id: 0, item };
|
||||
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
@ -71,10 +71,22 @@ impl AvgPool2dBackwardComputeShader {
|
|||
cpa!(scope, output_shape_2 = shape(output, 2u32));
|
||||
cpa!(scope, output_shape_3 = shape(output, 3u32));
|
||||
|
||||
let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
|
||||
let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
|
||||
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
|
||||
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
|
||||
let pool_stride_0 = Variable::GlobalScalar {
|
||||
id: 0,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let pool_stride_1 = Variable::GlobalScalar {
|
||||
id: 1,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let padding_0 = Variable::GlobalScalar {
|
||||
id: 4,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let padding_1 = Variable::GlobalScalar {
|
||||
id: 5,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let [kernel_size_0, kernel_size_1] = self.kernel_size;
|
||||
|
||||
let b = scope.create_local(Elem::UInt);
|
||||
|
@ -215,12 +227,30 @@ impl AvgPool2dBackwardComputeShader {
|
|||
output_stride_2: Variable,
|
||||
output_stride_3: Variable,
|
||||
) -> (Variable, Variable, Variable, Variable) {
|
||||
let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
|
||||
let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
|
||||
let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
|
||||
let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
|
||||
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
|
||||
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
|
||||
let pool_stride_0 = Variable::GlobalScalar {
|
||||
id: 0,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let pool_stride_1 = Variable::GlobalScalar {
|
||||
id: 1,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let dilation_0 = Variable::GlobalScalar {
|
||||
id: 2,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let dilation_1 = Variable::GlobalScalar {
|
||||
id: 3,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let padding_0 = Variable::GlobalScalar {
|
||||
id: 4,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let padding_1 = Variable::GlobalScalar {
|
||||
id: 5,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
|
||||
let [kernel_size_0, kernel_size_1] = self.kernel_size;
|
||||
|
||||
|
@ -321,8 +351,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for AvgPool2dBackwardEagerKernel<R, E>
|
|||
let mut scope = Scope::root();
|
||||
let item = E::cube_elem().into();
|
||||
|
||||
let grad = Variable::GlobalInputArray(0, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
let grad = Variable::GlobalInputArray { id: 0, item };
|
||||
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
@ -21,7 +21,10 @@ impl<E: JitElement> PoolStrategy for MaxPool<E> {
|
|||
|
||||
fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator {
|
||||
let max_val = scope.create_local(item);
|
||||
let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), item.elem());
|
||||
let max_initial = Variable::ConstantScalar {
|
||||
value: E::minimum_value().to_f64(),
|
||||
elem: item.elem(),
|
||||
};
|
||||
cpa!(scope, max_val = max_initial);
|
||||
max_val
|
||||
}
|
||||
|
@ -67,7 +70,10 @@ impl<E: JitElement> PoolStrategy for MaxPoolWithIndices<E> {
|
|||
|
||||
fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator {
|
||||
let max_val = scope.create_local(item);
|
||||
let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), item.elem());
|
||||
let max_initial = Variable::ConstantScalar {
|
||||
value: E::minimum_value().to_f64(),
|
||||
elem: item.elem(),
|
||||
};
|
||||
cpa!(scope, max_val = max_initial);
|
||||
let max_index = scope.create_local(Elem::UInt);
|
||||
(max_val, max_index)
|
||||
@ -161,12 +161,30 @@ impl MaxPool2dBackwardComputeShader {
|
|||
output_stride_2: Variable,
|
||||
output_stride_3: Variable,
|
||||
) -> (Variable, Variable, Variable, Variable) {
|
||||
let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
|
||||
let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
|
||||
let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
|
||||
let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
|
||||
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
|
||||
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
|
||||
let pool_stride_0 = Variable::GlobalScalar {
|
||||
id: 0,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let pool_stride_1 = Variable::GlobalScalar {
|
||||
id: 1,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let dilation_0 = Variable::GlobalScalar {
|
||||
id: 2,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let dilation_1 = Variable::GlobalScalar {
|
||||
id: 3,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let padding_0 = Variable::GlobalScalar {
|
||||
id: 4,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let padding_1 = Variable::GlobalScalar {
|
||||
id: 5,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
|
||||
let [kernel_size_0, kernel_size_1] = self.kernel_size;
|
||||
|
||||
|
@ -267,9 +285,12 @@ impl<R: JitRuntime, E: JitElement> Kernel for MaxPool2dWithIndicesBackwardEagerK
|
|||
let mut scope = Scope::root();
|
||||
let item = E::cube_elem().into();
|
||||
|
||||
let indices = Variable::GlobalInputArray(0, Item::new(Elem::Int(IntKind::I32)));
|
||||
let grad = Variable::GlobalInputArray(1, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
let indices = Variable::GlobalInputArray {
|
||||
id: 0,
|
||||
item: Item::new(Elem::Int(IntKind::I32)),
|
||||
};
|
||||
let grad = Variable::GlobalInputArray { id: 1, item };
|
||||
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||
|
||||
scope.write_global_custom(output);
|
||||
|
||||
@ -65,12 +65,30 @@ impl<P: PoolStrategy, R: JitRuntime, E: JitElement> Pool2dComputeShader<P, R, E>
|
|||
cpa!(scope, output_shape_2 = shape(output, 2u32));
|
||||
cpa!(scope, output_shape_3 = shape(output, 3u32));
|
||||
|
||||
let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
|
||||
let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
|
||||
let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
|
||||
let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
|
||||
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
|
||||
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
|
||||
let pool_stride_0 = Variable::GlobalScalar {
|
||||
id: 0,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let pool_stride_1 = Variable::GlobalScalar {
|
||||
id: 1,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let dilation_0 = Variable::GlobalScalar {
|
||||
id: 2,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let dilation_1 = Variable::GlobalScalar {
|
||||
id: 3,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let padding_0 = Variable::GlobalScalar {
|
||||
id: 4,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let padding_1 = Variable::GlobalScalar {
|
||||
id: 5,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
|
||||
let b = scope.create_local(Elem::UInt);
|
||||
let c = scope.create_local(Elem::UInt);
|
||||
|
@ -177,13 +195,13 @@ impl<P: PoolStrategy, R: JitRuntime, E: JitElement> Kernel for Pool2dEagerKernel
|
|||
let mut scope = Scope::root();
|
||||
let item = E::cube_elem().into();
|
||||
|
||||
let input = Variable::GlobalInputArray(0, item);
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
let input = Variable::GlobalInputArray { id: 0, item };
|
||||
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||
let indices = if P::with_indices() {
|
||||
Some(Variable::GlobalOutputArray(
|
||||
1,
|
||||
Item::new(Elem::Int(IntKind::I32)),
|
||||
))
|
||||
Some(Variable::GlobalOutputArray {
|
||||
id: 1,
|
||||
item: Item::new(Elem::Int(IntKind::I32)),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
};
|
||||
@ -66,17 +66,32 @@ impl<P: Prng<E>, R: JitRuntime, E: JitElement> Kernel for PrngEagerKernel<P, R,
|
|||
let mut scope = Scope::root();
|
||||
let item = E::cube_elem().into();
|
||||
|
||||
let output = Variable::GlobalOutputArray(0, item);
|
||||
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||
|
||||
let seed0 = Variable::GlobalScalar(0, Elem::UInt);
|
||||
let seed1 = Variable::GlobalScalar(1, Elem::UInt);
|
||||
let seed2 = Variable::GlobalScalar(2, Elem::UInt);
|
||||
let seed3 = Variable::GlobalScalar(3, Elem::UInt);
|
||||
let seed0 = Variable::GlobalScalar {
|
||||
id: 0,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let seed1 = Variable::GlobalScalar {
|
||||
id: 1,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let seed2 = Variable::GlobalScalar {
|
||||
id: 2,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let seed3 = Variable::GlobalScalar {
|
||||
id: 3,
|
||||
elem: Elem::UInt,
|
||||
};
|
||||
let seeds = [seed0, seed1, seed2, seed3];
|
||||
|
||||
let mut args = Vec::<Variable>::new();
|
||||
for i in 0..P::args_length() {
|
||||
args.push(Variable::GlobalScalar(i as u16, item.elem()));
|
||||
args.push(Variable::GlobalScalar {
|
||||
id: i as u16,
|
||||
elem: item.elem(),
|
||||
});
|
||||
}
|
||||
|
||||
PrngShader::<P, E>::new(output, N_VALUES_PER_THREAD, seeds, args).expand(&mut scope);
|
||||
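The PRNG kernel above declares its seeds and extra arguments as `GlobalScalar` variables numbered in declaration order: scalars handed to a kernel are grouped per element type and addressed by index, which the variable display elsewhere renders as `scalars_{elem}[{id}]`. For illustration, a minimal sketch of that naming follows; the `Elem` type and the string forms are simplified assumptions, not the real codegen.

// Illustrative sketch: positional scalar bindings rendered as indexed accesses.
#[derive(Debug, Clone, Copy)]
enum Elem {
    UInt,
    F32,
}

impl std::fmt::Display for Elem {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Elem::UInt => f.write_str("uint"),
            Elem::F32 => f.write_str("f32"),
        }
    }
}

struct GlobalScalar {
    id: u16,
    elem: Elem,
}

fn scalar_expr(s: &GlobalScalar) -> String {
    format!("scalars_{}[{}]", s.elem, s.id)
}

fn main() {
    let seed0 = GlobalScalar { id: 0, elem: Elem::UInt };
    let seed1 = GlobalScalar { id: 1, elem: Elem::UInt };
    println!("{}", scalar_expr(&seed0)); // scalars_uint[0]
    println!("{}", scalar_expr(&seed1)); // scalars_uint[1]
}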
@ -16,7 +16,10 @@ impl<E: JitElement> ReduceDimNaive<E> for Argmax {
|
|||
) -> Self::Accumulator {
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
let max = scope.create_local(input_item);
|
||||
let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), input_item.elem());
|
||||
let max_initial = Variable::ConstantScalar {
|
||||
value: E::minimum_value().to_f64(),
|
||||
elem: input_item.elem(),
|
||||
};
|
||||
cpa!(scope, max = max_initial);
|
||||
|
||||
(max, index)
|
||||
@ -17,7 +17,10 @@ impl<E: JitElement> ReduceDimNaive<E> for Argmin {
|
|||
) -> Self::Accumulator {
|
||||
let index = scope.create_local(Elem::UInt);
|
||||
let min = scope.create_local(input_item);
|
||||
let min_initial = Variable::ConstantScalar(E::maximum_value().to_f64(), input_item.elem());
|
||||
let min_initial = Variable::ConstantScalar {
|
||||
value: E::maximum_value().to_f64(),
|
||||
elem: input_item.elem(),
|
||||
};
|
||||
cpa!(scope, min = min_initial);
|
||||
|
||||
(min, index)
|
||||
@ -41,8 +41,14 @@ impl<RD: ReduceDimNaive<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Kern
|
|||
let item_input = EI::cube_elem().into();
|
||||
let item_output = EO::cube_elem().into();
|
||||
|
||||
let tensor = Variable::GlobalInputArray(0, item_input);
|
||||
let output = Variable::GlobalOutputArray(0, item_output);
|
||||
let tensor = Variable::GlobalInputArray {
|
||||
id: 0,
|
||||
item: item_input,
|
||||
};
|
||||
let output = Variable::GlobalOutputArray {
|
||||
id: 0,
|
||||
item: item_output,
|
||||
};
|
||||
|
||||
NaiveReduceDimComputeShader {
|
||||
tensor,
|
||||
@ -18,7 +18,10 @@ impl<E: JitElement> ReduceDimShared<E> for Argmax {
|
|||
let value_shared_memory = scope.create_shared(input_item, shared_memory_size);
|
||||
let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size);
|
||||
|
||||
let max = Variable::ConstantScalar(E::minimum_value().to_f64(), input_item.elem());
|
||||
let max = Variable::ConstantScalar {
|
||||
value: E::minimum_value().to_f64(),
|
||||
elem: input_item.elem(),
|
||||
};
|
||||
cpa!(scope, value_shared_memory[write_position] = max);
|
||||
(value_shared_memory, index_shared_memory)
|
||||
}
|
||||
@ -19,7 +19,10 @@ impl<E: JitElement> ReduceDimShared<E> for Argmin {
|
|||
let value_shared_memory = scope.create_shared(input_item, shared_memory_size);
|
||||
let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size);
|
||||
|
||||
let min = Variable::ConstantScalar(E::maximum_value().to_f64(), input_item.elem());
|
||||
let min = Variable::ConstantScalar {
|
||||
value: E::maximum_value().to_f64(),
|
||||
elem: input_item.elem(),
|
||||
};
|
||||
cpa!(scope, value_shared_memory[write_position] = min);
|
||||
(value_shared_memory, index_shared_memory)
|
||||
}
|
||||
@ -51,8 +51,14 @@ impl<RD: ReduceDimShared<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Ker
|
|||
let item_input = EI::cube_elem().into();
|
||||
let item_output = EO::cube_elem().into();
|
||||
|
||||
let tensor = Variable::GlobalInputArray(0, item_input);
|
||||
let output = Variable::GlobalOutputArray(0, item_output);
|
||||
let tensor = Variable::GlobalInputArray {
|
||||
id: 0,
|
||||
item: item_input,
|
||||
};
|
||||
let output = Variable::GlobalOutputArray {
|
||||
id: 0,
|
||||
item: item_output,
|
||||
};
|
||||
|
||||
// Reduce groups are elements that are aligned along the reduce dim
|
||||
SharedReduceDimComputeShader {
|
||||
@ -9,14 +9,24 @@ pub enum Variable {
GlobalScalar(u16, Elem, cube::Elem),
ConstantScalar(f64, Elem),
Local {
index: u16,
id: u16,
item: Item,
scope_depth: u8,
depth: u8,
},
Named {
name: String,
item: Item,
is_array: bool,
},
Slice {
id: u16,
item: Item,
depth: u8,
},
LocalScalar {
index: u16,
id: u16,
elem: Elem,
scope_depth: u8,
depth: u8,
},
SharedMemory(u16, Item, u32),
LocalArray(u16, Item, u8, u32),
@ -71,9 +81,9 @@ impl Variable {
|
|||
Variable::GlobalScalar(_, _, _) => true,
|
||||
Variable::ConstantScalar(_, _) => true,
|
||||
Variable::LocalScalar {
|
||||
index: _,
|
||||
id: _,
|
||||
elem: _,
|
||||
scope_depth: _,
|
||||
depth: _,
|
||||
} => true,
|
||||
Variable::Id => true,
|
||||
Variable::LocalInvocationIndex => true,
|
||||
|
@ -85,11 +95,9 @@ impl Variable {
|
|||
Variable::GlobalOutputArray(_, _) => false,
|
||||
Variable::SharedMemory(_, _, _) => false,
|
||||
Variable::LocalArray(_, _, _, _) => false,
|
||||
Variable::Local {
|
||||
index: _,
|
||||
item: _,
|
||||
scope_depth: _,
|
||||
} => false,
|
||||
Variable::Local { .. } => false,
|
||||
Variable::Named { .. } => false,
|
||||
Variable::Slice { .. } => false,
|
||||
Variable::WorkgroupIdX => true,
|
||||
Variable::WorkgroupIdY => true,
|
||||
Variable::WorkgroupIdZ => true,
|
||||
|
@ -121,11 +129,9 @@ impl Variable {
|
|||
Self::GlobalOutputArray(_, e) => *e,
|
||||
Self::SharedMemory(_, e, _) => *e,
|
||||
Self::LocalArray(_, e, _, _) => *e,
|
||||
Self::Local {
|
||||
index: _,
|
||||
item,
|
||||
scope_depth: _,
|
||||
} => *item,
|
||||
Self::Local { item, .. } => *item,
|
||||
Self::Slice { item, .. } => *item,
|
||||
Self::Named { item, .. } => *item,
|
||||
Self::ConstantScalar(_, e) => Item::Scalar(*e),
|
||||
Self::GlobalScalar(_, e, _) => Item::Scalar(*e),
|
||||
Self::Id => Item::Scalar(Elem::U32),
|
||||
|
@ -134,11 +140,7 @@ impl Variable {
|
|||
Self::LocalInvocationIdY => Item::Scalar(Elem::U32),
|
||||
Self::LocalInvocationIdZ => Item::Scalar(Elem::U32),
|
||||
Self::Rank => Item::Scalar(Elem::U32),
|
||||
Self::LocalScalar {
|
||||
index: _,
|
||||
elem,
|
||||
scope_depth: _,
|
||||
} => Item::Scalar(*elem),
|
||||
Self::LocalScalar { elem, .. } => Item::Scalar(*elem),
|
||||
Self::WorkgroupId => Item::Scalar(Elem::U32),
|
||||
Self::WorkgroupIdX => Item::Scalar(Elem::U32),
|
||||
Self::WorkgroupIdY => Item::Scalar(Elem::U32),
|
||||
|
@ -222,15 +224,21 @@ impl Display for Variable {
|
|||
f.write_fmt(format_args!("input_{number}_global"))
|
||||
}
|
||||
Variable::LocalScalar {
|
||||
index,
|
||||
id: index,
|
||||
elem: _,
|
||||
scope_depth,
|
||||
depth: scope_depth,
|
||||
} => f.write_fmt(format_args!("s_{scope_depth}_{index}")),
|
||||
Variable::Local {
|
||||
index,
|
||||
id: index,
|
||||
item: _,
|
||||
scope_depth,
|
||||
depth: scope_depth,
|
||||
} => f.write_fmt(format_args!("l_{scope_depth}_{index}")),
|
||||
Variable::Named { name, .. } => f.write_str(name),
|
||||
Variable::Slice {
|
||||
id: index,
|
||||
item: _,
|
||||
depth: scope_depth,
|
||||
} => f.write_fmt(format_args!("slice_{scope_depth}_{index}")),
|
||||
Variable::GlobalOutputArray(number, _) => {
|
||||
f.write_fmt(format_args!("output_{number}_global"))
|
||||
}
|
||||
@@ -127,43 +127,53 @@ impl WgslCompiler {

    fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable {
        match value {
            cube::Variable::GlobalInputArray(index, item) => {
                wgsl::Variable::GlobalInputArray(index, Self::compile_item(item))
            cube::Variable::GlobalInputArray { id, item } => {
                wgsl::Variable::GlobalInputArray(id, Self::compile_item(item))
            }
            cube::Variable::GlobalScalar(index, elem) => {
                wgsl::Variable::GlobalScalar(index, Self::compile_elem(elem), elem)
            cube::Variable::GlobalScalar { id, elem } => {
                wgsl::Variable::GlobalScalar(id, Self::compile_elem(elem), elem)
            }
            cube::Variable::Local(index, item, scope_depth) => wgsl::Variable::Local {
                index,
            cube::Variable::Local { id, item, depth } => wgsl::Variable::Local {
                id,
                item: Self::compile_item(item),
                scope_depth,
                depth,
            },
            cube::Variable::LocalScalar(index, elem, scope_depth) => wgsl::Variable::LocalScalar {
                index,
            cube::Variable::Slice { id, item, depth } => wgsl::Variable::Slice {
                id,
                item: Self::compile_item(item),
                depth,
            },
            cube::Variable::LocalScalar { id, elem, depth } => wgsl::Variable::LocalScalar {
                id,
                elem: Self::compile_elem(elem),
                scope_depth,
                depth,
            },
            cube::Variable::GlobalOutputArray(index, item) => {
                wgsl::Variable::GlobalOutputArray(index, Self::compile_item(item))
            cube::Variable::GlobalOutputArray { id, item } => {
                wgsl::Variable::GlobalOutputArray(id, Self::compile_item(item))
            }
            cube::Variable::ConstantScalar(index, elem) => {
                wgsl::Variable::ConstantScalar(index, Self::compile_elem(elem))
            cube::Variable::ConstantScalar { value, elem } => {
                wgsl::Variable::ConstantScalar(value, Self::compile_elem(elem))
            }
            cube::Variable::SharedMemory(index, item, size) => {
            cube::Variable::SharedMemory { id, item, length } => {
                let item = Self::compile_item(item);
                if !self.shared_memories.iter().any(|s| s.index == index) {
                if !self.shared_memories.iter().any(|s| s.index == id) {
                    self.shared_memories
                        .push(SharedMemory::new(index, item, size));
                        .push(SharedMemory::new(id, item, length));
                }
                wgsl::Variable::SharedMemory(index, item, size)
                wgsl::Variable::SharedMemory(id, item, length)
            }
            cube::Variable::LocalArray(index, item, scope_depth, size) => {
            cube::Variable::LocalArray {
                id,
                item,
                depth,
                length,
            } => {
                let item = Self::compile_item(item);
                if !self.local_arrays.iter().any(|s| s.index == index) {
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays
                        .push(LocalArray::new(index, item, scope_depth, size));
                        .push(LocalArray::new(id, item, depth, length));
                }
                wgsl::Variable::LocalArray(index, item, scope_depth, size)
                wgsl::Variable::LocalArray(id, item, depth, length)
            }
            cube::Variable::AbsolutePos => {
                self.id = true;
@@ -241,7 +251,7 @@ impl WgslCompiler {
                wgsl::Variable::NumWorkgroups
            }
            cube::Variable::SubcubeDim => wgsl::Variable::SubgroupSize,
            cube::Variable::Matrix(_, _) => {
            cube::Variable::Matrix { .. } => {
                panic!("Cooperative matrix-multiply and accumulate not supported.")
            }
        }
@@ -252,6 +262,11 @@ impl WgslCompiler {
        let processing = value.process();

        for var in processing.variables {
            // We don't declare slices.
            if let cube::Variable::Slice { .. } = var {
                continue;
            }

            instructions.push(wgsl::Instruction::DeclareVariable {
                var: self.compile_variable(var),
            });
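Slices deliberately get no `var` declaration in the generated WGSL; they are materialized later by the `Slice` instruction as a set of `let` bindings, so the declaration pass above simply skips them. A reduced sketch of that filtering step, using a hypothetical `Var` type rather than the real compiler structures:

```rust
// Hypothetical stand-in for the compiler's variable type; only the shape of
// the match matters here.
#[derive(Debug)]
enum Var {
    Local { id: u16 },
    Slice { id: u16 },
}

// Mirror of the `if let cube::Variable::Slice { .. } = var { continue; }` guard:
// every variable except a slice gets a declaration.
fn declarations(vars: &[Var]) -> Vec<String> {
    vars.iter()
        .filter(|v| !matches!(v, Var::Slice { .. }))
        .map(|v| format!("declare {v:?}"))
        .collect()
}

fn main() {
    let vars = [Var::Local { id: 0 }, Var::Slice { id: 1 }];
    let decls = declarations(&vars);
    assert_eq!(decls, vec!["declare Local { id: 0 }".to_string()]);
    println!("{decls:?}");
}
```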
@@ -427,8 +442,8 @@ impl WgslCompiler {
            cube::Metadata::Stride { dim, var, out } => {
                self.stride = true;
                let position = match var {
                    cube::Variable::GlobalInputArray(idx, _) => idx as usize,
                    cube::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
                    cube::Variable::GlobalInputArray { id, .. } => id as usize,
                    cube::Variable::GlobalOutputArray { id, .. } => self.num_inputs + id as usize,
                    _ => panic!("Only Input and Output have a stride, got: {:?}", var),
                };
                wgsl::Instruction::Stride {
@@ -440,8 +455,8 @@ impl WgslCompiler {
            cube::Metadata::Shape { dim, var, out } => {
                self.shape = true;
                let position = match var {
                    cube::Variable::GlobalInputArray(idx, _) => idx as usize,
                    cube::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
                    cube::Variable::GlobalInputArray { id, .. } => id as usize,
                    cube::Variable::GlobalOutputArray { id, .. } => self.num_inputs + id as usize,
                    _ => panic!("Only Input and Output have a shape, got {:?}", var),
                };
                wgsl::Instruction::Shape {
@@ -450,7 +465,7 @@
                    out: self.compile_variable(out),
                }
            }
            cube::Metadata::ArrayLength { var, out } => wgsl::Instruction::ArrayLength {
            cube::Metadata::Length { var, out } => wgsl::Instruction::Length {
                out: self.compile_variable(out),
                var: self.compile_variable(var),
            },
@@ -652,6 +667,12 @@
                rhs: self.compile_variable(op.rhs),
                out: self.compile_variable(op.out),
            },
            cube::Operator::Slice(op) => wgsl::Instruction::Slice {
                input: self.compile_variable(op.input),
                start: self.compile_variable(op.start),
                end: self.compile_variable(op.end),
                out: self.compile_variable(op.out),
            },
        }
    }

@@ -1,6 +1,6 @@
use super::{
    base::{Item, Variable},
    IndexedVariable, Subgroup,
    Elem, IndexedVariable, Subgroup,
};
use std::fmt::Display;

@@ -167,7 +167,7 @@ pub enum Instruction {
        position: usize,
        out: Variable,
    },
    ArrayLength {
    Length {
        var: Variable,
        out: Variable,
    },
@@ -232,6 +232,12 @@ pub enum Instruction {
        rhs: Variable,
        out: Variable,
    },
    Slice {
        input: Variable,
        start: Variable,
        end: Variable,
        out: Variable,
    },
    Subgroup(Subgroup),
}
@@ -245,6 +251,16 @@ impl Display for Instruction {
            Instruction::Add { lhs, rhs, out } => {
                f.write_fmt(format_args!("{out} = {lhs} + {rhs};\n"))
            }
            Instruction::Slice {
                input,
                start,
                end,
                out,
            } => {
                f.write_fmt(format_args!("let {out}_offset = {start};\n"))?;
                f.write_fmt(format_args!("let {out}_length = {end} - {start};\n"))?;
                f.write_fmt(format_args!("let {out}_ptr = &{input};\n"))
            }
            Instruction::Fma { a, b, c, out } => {
                f.write_fmt(format_args!("{out} = fma({a}, {b}, {c});\n"))
            }
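The new `Slice` arm lowers one slice to three WGSL `let` bindings: the start offset, the computed length, and a pointer to the backing array. A standalone sketch of the same string formatting, with plain `&str` names standing in for the real `Variable` operands:

```rust
// Format the three bindings the `Instruction::Slice` arm writes, using plain
// string names in place of the real `Variable` operands.
fn emit_slice(out: &str, input: &str, start: &str, end: &str) -> String {
    format!(
        "let {out}_offset = {start};\nlet {out}_length = {end} - {start};\nlet {out}_ptr = &{input};\n"
    )
}

fn main() {
    // Prints, for example:
    //   let slice_0_2_offset = l_0_0;
    //   let slice_0_2_length = l_0_1 - l_0_0;
    //   let slice_0_2_ptr = &input_0_global;
    print!("{}", emit_slice("slice_0_2", "input_0_global", "l_0_0", "l_0_1"));
}
```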
@@ -261,10 +277,22 @@
                f.write_fmt(format_args!("{out} = {lhs} || {rhs};\n"))
            }
            Instruction::Not { input, out } => f.write_fmt(format_args!("{out} = !{input};\n")),
            Instruction::Index { lhs, rhs, out } => {
                let item = out.item();
                f.write_fmt(format_args!("{out} = {item}({lhs}[{rhs}]);\n"))
            }
            Instruction::Index { lhs, rhs, out } => match lhs {
                Variable::Slice { item, .. } => {
                    let offset = Variable::Named {
                        name: format!("{lhs}_offset"),
                        item: Item::Scalar(Elem::U32),
                        is_array: false,
                    };
                    let lhs = Variable::Named {
                        name: format!("(*{lhs}_ptr)"),
                        item: *item,
                        is_array: true,
                    };
                    index(f, &lhs, rhs, out, Some(offset))
                }
                _ => index(f, lhs, rhs, out, None),
            },
            Instruction::Modulo { lhs, rhs, out } => {
                f.write_fmt(format_args!("{out} = {lhs} % {rhs};\n"))
            }
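For reads, the slice case above emits no new syntax of its own: it substitutes two `Variable::Named` values, the `(*{lhs}_ptr)` pointer and the `{lhs}_offset` binding, and then reuses the shared `index` helper so the element is fetched through the pointer with the offset added. A string-level sketch of the resulting substitution (the real helper additionally wraps the read in a cast to the output item):

```rust
// String-level view of the slice read: the generated WGSL goes through the
// `_ptr` binding and adds the `_offset` binding to the index.
fn index_slice(out: &str, slice: &str, idx: &str) -> String {
    let ptr = format!("(*{slice}_ptr)");
    let offset = format!("{slice}_offset");
    format!("{out} = {ptr}[{idx} + {offset}];\n")
}

fn main() {
    // Prints: l_0_3 = (*slice_0_2_ptr)[i + slice_0_2_offset];
    print!("{}", index_slice("l_0_3", "slice_0_2", "i"));
}
```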
@@ -406,88 +434,24 @@ for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{

                f.write_str("}\n")
            }
            Instruction::IndexAssign { lhs, rhs, out } => match lhs.item() {
                Item::Vec4(elem) => {
                    let lhs0 = lhs.index(0);
                    let lhs1 = lhs.index(1);
                    let lhs2 = lhs.index(2);
                    let lhs3 = lhs.index(3);
            Instruction::IndexAssign { lhs, rhs, out } => {
                if let Variable::Slice { item, .. } = out {
                    let offset = Variable::Named {
                        name: format!("{out}_offset"),
                        item: Item::Scalar(Elem::U32),
                        is_array: false,
                    };
                    let out = Variable::Named {
                        name: format!("(*{out}_ptr)"),
                        item: *item,
                        is_array: true,
                    };

                    let rhs0 = rhs.index(0);
                    let rhs1 = rhs.index(1);
                    let rhs2 = rhs.index(2);
                    let rhs3 = rhs.index(3);

                    f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
                    f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?;
                    f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n"))?;
                    f.write_fmt(format_args!("{out}[{lhs3}] = {elem}({rhs3});\n"))
                    index_assign(f, lhs, rhs, &out, Some(offset))
                } else {
                    index_assign(f, lhs, rhs, out, None)
                }
                Item::Vec3(elem) => {
                    let lhs0 = lhs.index(0);
                    let lhs1 = lhs.index(1);
                    let lhs2 = lhs.index(2);

                    let rhs0 = rhs.index(0);
                    let rhs1 = rhs.index(1);
                    let rhs2 = rhs.index(2);

                    f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
                    f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?;
                    f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n"))
                }
                Item::Vec2(elem) => {
                    let lhs0 = lhs.index(0);
                    let lhs1 = lhs.index(1);

                    let rhs0 = rhs.index(0);
                    let rhs1 = rhs.index(1);

                    f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
                    f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))
                }
                Item::Scalar(_elem) => {
                    let is_array = matches!(
                        out,
                        Variable::GlobalInputArray(_, _)
                            | Variable::GlobalOutputArray(_, _)
                            | Variable::SharedMemory(_, _, _)
                            | Variable::LocalArray(_, _, _, _)
                    );

                    if !is_array {
                        let elem_out = out.elem();
                        let casting_type = match rhs.item() {
                            Item::Vec4(_) => Item::Vec4(elem_out),
                            Item::Vec3(_) => Item::Vec3(elem_out),
                            Item::Vec2(_) => Item::Vec2(elem_out),
                            Item::Scalar(_) => Item::Scalar(elem_out),
                        };
                        f.write_fmt(format_args!("{out}[{lhs}] = {casting_type}({rhs});\n"))
                    } else {
                        let item_rhs = rhs.item();
                        let item_out = out.item();

                        let vectorization_factor = item_out.vectorization_factor();
                        if vectorization_factor > item_rhs.vectorization_factor() {
                            let casting_type = item_out.elem();
                            f.write_fmt(format_args!("{out}[{lhs}] = vec{vectorization_factor}("))?;
                            for i in 0..vectorization_factor {
                                let value = rhs.index(i);
                                f.write_fmt(format_args!("{casting_type}({value})"))?;

                                if i < vectorization_factor - 1 {
                                    f.write_str(",")?;
                                }
                            }
                            f.write_str(");\n")
                        } else {
                            let casting_type = item_out;
                            f.write_fmt(format_args!("{out}[{lhs}] = {casting_type}({rhs});\n"))
                        }
                    }
                }
            },
            }
            Instruction::If { cond, instructions } => {
                f.write_fmt(format_args!("if {cond} {{\n"))?;
                for i in instructions {
@@ -513,9 +477,10 @@ for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{
            Instruction::Return => f.write_str("return;\n"),
            Instruction::Break => f.write_str("break;\n"),
            Instruction::WorkgroupBarrier => f.write_str("workgroupBarrier();\n"),
            Instruction::ArrayLength { var, out } => {
                f.write_fmt(format_args!("{out} = arrayLength(&{var});\n"))
            }
            Instruction::Length { var, out } => match var {
                Variable::Slice { .. } => f.write_fmt(format_args!("{out} = {var}_length;\n")),
                _ => f.write_fmt(format_args!("{out} = arrayLength(&{var});\n")),
            },
            Instruction::Loop { instructions } => {
                f.write_fmt(format_args!("loop {{\n"))?;
                for i in instructions {
@@ -623,3 +588,140 @@ fn unroll<
    }
    Ok(())
}

struct IndexOffset {
    var: Variable,
    offset: Option<Variable>,
    index: usize,
}
impl IndexOffset {
    fn new(var: &Variable, offset: &Option<Variable>, index: usize) -> Self {
        Self {
            var: var.clone(),
            offset: offset.clone(),
            index,
        }
    }
}

impl Display for IndexOffset {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let var = self.var.index(self.index);

        match &self.offset {
            Some(offset) => {
                let offset = offset.index(self.index);
                f.write_fmt(format_args!("{var} + {offset}"))
            }
            None => f.write_fmt(format_args!("{var}")),
        }
    }
}

fn index(
    f: &mut std::fmt::Formatter<'_>,
    lhs: &Variable,
    rhs: &Variable,
    out: &Variable,
    offset: Option<Variable>,
) -> core::fmt::Result {
    let item = out.item();
    match offset {
        Some(offset) => f.write_fmt(format_args!("{out} = {item}({lhs}[{rhs} + {offset}]);\n")),
        None => f.write_fmt(format_args!("{out} = {item}({lhs}[{rhs}]);\n")),
    }
}

fn index_assign(
    f: &mut std::fmt::Formatter<'_>,
    lhs: &Variable,
    rhs: &Variable,
    out: &Variable,
    offset: Option<Variable>,
) -> core::fmt::Result {
    match lhs.item() {
        Item::Vec4(elem) => {
            let lhs0 = IndexOffset::new(lhs, &offset, 0);
            let lhs1 = IndexOffset::new(lhs, &offset, 1);
            let lhs2 = IndexOffset::new(lhs, &offset, 2);
            let lhs3 = IndexOffset::new(lhs, &offset, 3);

            let rhs0 = rhs.index(0);
            let rhs1 = rhs.index(1);
            let rhs2 = rhs.index(2);
            let rhs3 = rhs.index(3);

            f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
            f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?;
            f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n"))?;
            f.write_fmt(format_args!("{out}[{lhs3}] = {elem}({rhs3});\n"))
        }
        Item::Vec3(elem) => {
            let lhs0 = IndexOffset::new(lhs, &offset, 0);
            let lhs1 = IndexOffset::new(lhs, &offset, 1);
            let lhs2 = IndexOffset::new(lhs, &offset, 2);

            let rhs0 = rhs.index(0);
            let rhs1 = rhs.index(1);
            let rhs2 = rhs.index(2);

            f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
            f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?;
            f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n"))
        }
        Item::Vec2(elem) => {
            let lhs0 = IndexOffset::new(lhs, &offset, 0);
            let lhs1 = IndexOffset::new(lhs, &offset, 1);

            let rhs0 = rhs.index(0);
            let rhs1 = rhs.index(1);

            f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
            f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))
        }
        Item::Scalar(_elem) => {
            let is_array = match out {
                Variable::GlobalInputArray(_, _)
                | Variable::GlobalOutputArray(_, _)
                | Variable::SharedMemory(_, _, _)
                | Variable::Slice { .. }
                | Variable::LocalArray(_, _, _, _) => true,
                Variable::Named { is_array, .. } => *is_array,
                _ => false,
            };

            if !is_array {
                let elem_out = out.elem();
                let casting_type = match rhs.item() {
                    Item::Vec4(_) => Item::Vec4(elem_out),
                    Item::Vec3(_) => Item::Vec3(elem_out),
                    Item::Vec2(_) => Item::Vec2(elem_out),
                    Item::Scalar(_) => Item::Scalar(elem_out),
                };
                f.write_fmt(format_args!("{out}[{lhs}] = {casting_type}({rhs});\n"))
            } else {
                let item_rhs = rhs.item();
                let item_out = out.item();
                let lhs = IndexOffset::new(lhs, &offset, 0);

                let vectorization_factor = item_out.vectorization_factor();
                if vectorization_factor > item_rhs.vectorization_factor() {
                    let casting_type = item_out.elem();
                    f.write_fmt(format_args!("{out}[{lhs}] = vec{vectorization_factor}("))?;
                    for i in 0..vectorization_factor {
                        let value = rhs.index(i);
                        f.write_fmt(format_args!("{casting_type}({value})"))?;

                        if i < vectorization_factor - 1 {
                            f.write_str(",")?;
                        }
                    }
                    f.write_str(");\n")
                } else {
                    let casting_type = item_out;
                    f.write_fmt(format_args!("{out}[{lhs}] = {casting_type}({rhs});\n"))
                }
            }
        }
    }
}
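The `IndexOffset` helper added above pairs a vector lane of the indexing variable with the matching lane of an optional offset, so the `index_assign` arms can print either `{var}` or `{var} + {offset}` through a single `Display` implementation. A cut-down sketch of that idea over plain strings (the field types here are simplified; the real helper stores `Variable`s and selects a lane with `.index(i)`):

```rust
use std::fmt::{self, Display, Formatter};

// Simplified stand-in: one string per lane is enough to show the formatting logic.
struct IndexOffset {
    var: String,
    offset: Option<String>,
}

impl Display for IndexOffset {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match &self.offset {
            Some(offset) => write!(f, "{} + {}", self.var, offset),
            None => write!(f, "{}", self.var),
        }
    }
}

fn main() {
    let plain = IndexOffset { var: "lhs.x".into(), offset: None };
    let shifted = IndexOffset { var: "lhs.x".into(), offset: Some("out_offset.x".into()) };
    println!("{plain}");   // lhs.x
    println!("{shifted}"); // lhs.x + out_offset.x
}
```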