mirror of https://github.com/tracel-ai/burn.git
Feat/cube/slice (#2004)
* Refactor Variable types * Slice * Implement slice WGSL * Handle lifetime correctly * Add CUDA impl * Update cmma * Cleanup * Fix tests * Fix slice signature
This commit is contained in:
parent
c30ffcf6ac
commit
35345de62a
|
@ -423,17 +423,24 @@ impl KernelIntegrator {
|
||||||
} else {
|
} else {
|
||||||
item
|
item
|
||||||
};
|
};
|
||||||
let elem_adapted = bool_item(item);
|
let item_adapted = bool_item(item);
|
||||||
|
|
||||||
self.output_bindings.push(Binding {
|
self.output_bindings.push(Binding {
|
||||||
item: elem_adapted,
|
item: item_adapted,
|
||||||
visibility: Visibility::ReadWrite,
|
visibility: Visibility::ReadWrite,
|
||||||
location: Location::Storage,
|
location: Location::Storage,
|
||||||
size: None,
|
size: None,
|
||||||
});
|
});
|
||||||
self.expansion.scope.write_global(
|
self.expansion.scope.write_global(
|
||||||
Variable::Local(local, item, self.expansion.scope.depth),
|
Variable::Local {
|
||||||
Variable::GlobalOutputArray(index, elem_adapted),
|
id: local,
|
||||||
|
item,
|
||||||
|
depth: self.expansion.scope.depth,
|
||||||
|
},
|
||||||
|
Variable::GlobalOutputArray {
|
||||||
|
id: index,
|
||||||
|
item: item_adapted,
|
||||||
|
},
|
||||||
position,
|
position,
|
||||||
);
|
);
|
||||||
index += 1;
|
index += 1;
|
||||||
|
@ -451,8 +458,15 @@ impl KernelIntegrator {
|
||||||
};
|
};
|
||||||
|
|
||||||
self.expansion.scope.write_global(
|
self.expansion.scope.write_global(
|
||||||
Variable::Local(local, item, self.expansion.scope.depth),
|
Variable::Local {
|
||||||
Variable::GlobalInputArray(input, bool_item(item)),
|
id: local,
|
||||||
|
item,
|
||||||
|
depth: self.expansion.scope.depth,
|
||||||
|
},
|
||||||
|
Variable::GlobalInputArray {
|
||||||
|
id: input,
|
||||||
|
item: bool_item(item),
|
||||||
|
},
|
||||||
position,
|
position,
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,11 +27,11 @@ where
|
||||||
|
|
||||||
if unroll {
|
if unroll {
|
||||||
let start = match start.deref() {
|
let start = match start.deref() {
|
||||||
Variable::ConstantScalar(val, _) => *val as usize,
|
Variable::ConstantScalar { value, .. } => *value as usize,
|
||||||
_ => panic!("Only constant start can be unrolled."),
|
_ => panic!("Only constant start can be unrolled."),
|
||||||
};
|
};
|
||||||
let end = match end.deref() {
|
let end = match end.deref() {
|
||||||
Variable::ConstantScalar(val, _) => *val as usize,
|
Variable::ConstantScalar { value, .. } => *value as usize,
|
||||||
_ => panic!("Only constant end can be unrolled."),
|
_ => panic!("Only constant end can be unrolled."),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -15,14 +15,14 @@
|
||||||
//! 16,
|
//! 16,
|
||||||
//! 16,
|
//! 16,
|
||||||
//! 16,
|
//! 16,
|
||||||
//! cmma::MatrixLayout::ColMajor,
|
//! cmma::MatrixLayout::RowMajor,
|
||||||
//! );
|
//! );
|
||||||
//! let b = cmma::Matrix::<F16>::new(
|
//! let b = cmma::Matrix::<F16>::new(
|
||||||
//! cmma::MatrixIdent::B,
|
//! cmma::MatrixIdent::B,
|
||||||
//! 16,
|
//! 16,
|
||||||
//! 16,
|
//! 16,
|
||||||
//! 16,
|
//! 16,
|
||||||
//! cmma::MatrixLayout::RowMajor,
|
//! cmma::MatrixLayout::ColMajor,
|
||||||
//! );
|
//! );
|
||||||
//! let c = cmma::Matrix::<F32>::new(
|
//! let c = cmma::Matrix::<F32>::new(
|
||||||
//! cmma::MatrixIdent::Accumulator,
|
//! cmma::MatrixIdent::Accumulator,
|
||||||
|
@ -32,12 +32,17 @@
|
||||||
//! cmma::MatrixLayout::Undefined,
|
//! cmma::MatrixLayout::Undefined,
|
||||||
//! );
|
//! );
|
||||||
//! cmma::fill::<F32>(&c, F32::new(0.0));
|
//! cmma::fill::<F32>(&c, F32::new(0.0));
|
||||||
//! cmma::load::<F16>(&a, lhs, UInt::new(16));
|
//! cmma::load::<F16>(&a, lhs.as_slice(), UInt::new(16));
|
||||||
//! cmma::load::<F16>(&b, rhs, UInt::new(16));
|
//! cmma::load::<F16>(&b, rhs.as_slice(), UInt::new(16));
|
||||||
//!
|
//!
|
||||||
//! cmma::execute::<F16, F16, F32, F32>(&a, &b, &c, &c);
|
//! cmma::execute::<F16, F16, F32, F32>(&a, &b, &c, &c);
|
||||||
//!
|
//!
|
||||||
//! cmma::store::<F32>(out, &c, UInt::new(16), cmma::MatrixLayout::RowMajor);
|
//! cmma::store::<F32>(
|
||||||
|
//! out.as_slice_mut(),
|
||||||
|
//! &c,
|
||||||
|
//! UInt::new(16),
|
||||||
|
//! cmma::MatrixLayout::RowMajor,
|
||||||
|
//! );
|
||||||
//! }
|
//! }
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
|
@ -49,7 +54,8 @@ use crate::{
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
Array, CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, UInt,
|
CubeContext, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, Slice, SliceMut,
|
||||||
|
UInt,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub use ir::{MatrixIdent, MatrixLayout};
|
pub use ir::{MatrixIdent, MatrixLayout};
|
||||||
|
@ -137,7 +143,7 @@ pub fn fill_expand<C: CubeType>(
|
||||||
|
|
||||||
/// Load the matrix with the provided array using the stride.
|
/// Load the matrix with the provided array using the stride.
|
||||||
#[allow(unused_variables)]
|
#[allow(unused_variables)]
|
||||||
pub fn load<C: CubeType>(mat: &Matrix<C>, value: &Array<C>, stride: UInt) {
|
pub fn load<C: CubeType>(mat: &Matrix<C>, value: &Slice<'_, C>, stride: UInt) {
|
||||||
unexpanded!()
|
unexpanded!()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -146,7 +152,7 @@ pub fn load<C: CubeType>(mat: &Matrix<C>, value: &Array<C>, stride: UInt) {
|
||||||
pub fn load_expand<C: CubeType>(
|
pub fn load_expand<C: CubeType>(
|
||||||
context: &mut CubeContext,
|
context: &mut CubeContext,
|
||||||
mat: MatrixExpand,
|
mat: MatrixExpand,
|
||||||
value: ExpandElementTyped<Array<C>>,
|
value: ExpandElementTyped<Slice<'static, C>>,
|
||||||
stride: ExpandElement,
|
stride: ExpandElement,
|
||||||
) {
|
) {
|
||||||
context.register(Operation::CoopMma(ir::CoopMma::Load {
|
context.register(Operation::CoopMma(ir::CoopMma::Load {
|
||||||
|
@ -159,7 +165,7 @@ pub fn load_expand<C: CubeType>(
|
||||||
/// Store the matrix in the given array following the given stride and layout.
|
/// Store the matrix in the given array following the given stride and layout.
|
||||||
#[allow(unused_variables)]
|
#[allow(unused_variables)]
|
||||||
pub fn store<C: CubePrimitive>(
|
pub fn store<C: CubePrimitive>(
|
||||||
output: &Array<C>,
|
output: &mut SliceMut<'_, C>,
|
||||||
mat: &Matrix<C>,
|
mat: &Matrix<C>,
|
||||||
stride: UInt,
|
stride: UInt,
|
||||||
layout: MatrixLayout,
|
layout: MatrixLayout,
|
||||||
|
@ -171,7 +177,7 @@ pub fn store<C: CubePrimitive>(
|
||||||
#[allow(unused_variables)]
|
#[allow(unused_variables)]
|
||||||
pub fn store_expand<C: CubePrimitive>(
|
pub fn store_expand<C: CubePrimitive>(
|
||||||
context: &mut CubeContext,
|
context: &mut CubeContext,
|
||||||
output: ExpandElementTyped<Array<C>>,
|
output: ExpandElementTyped<SliceMut<'static, C>>,
|
||||||
mat: MatrixExpand,
|
mat: MatrixExpand,
|
||||||
stride: ExpandElement,
|
stride: ExpandElement,
|
||||||
layout: MatrixLayout,
|
layout: MatrixLayout,
|
||||||
|
|
|
@ -4,8 +4,6 @@ use alloc::rc::Rc;
|
||||||
use core::cell::RefCell;
|
use core::cell::RefCell;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use super::{CubePrimitive, SharedMemoryExpand};
|
|
||||||
|
|
||||||
#[derive(Default, Clone)]
|
#[derive(Default, Clone)]
|
||||||
pub struct VariablePool {
|
pub struct VariablePool {
|
||||||
map: Rc<RefCell<HashMap<Item, Vec<ExpandElement>>>>,
|
map: Rc<RefCell<HashMap<Item, Vec<ExpandElement>>>>,
|
||||||
|
@ -114,14 +112,14 @@ impl CubeContext {
|
||||||
ExpandElement::Plain(variable)
|
ExpandElement::Plain(variable)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn create_shared<T: CubePrimitive>(
|
/// Create a new slice element.
|
||||||
&mut self,
|
pub fn create_slice(&mut self, item: Item) -> ExpandElement {
|
||||||
item: Item,
|
let variable = self.scope.borrow_mut().create_slice(item);
|
||||||
size: u32,
|
ExpandElement::Plain(variable)
|
||||||
) -> SharedMemoryExpand<T> {
|
}
|
||||||
SharedMemoryExpand {
|
|
||||||
val: ExpandElement::Plain(self.root.borrow_mut().create_shared(item, size)),
|
pub fn create_shared(&mut self, item: Item, size: u32) -> ExpandElement {
|
||||||
}
|
ExpandElement::Plain(self.root.borrow_mut().create_shared(item, size))
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn create_local_array(&mut self, item: Item, size: u32) -> ExpandElement {
|
pub fn create_local_array(&mut self, item: Item, size: u32) -> ExpandElement {
|
||||||
|
@ -129,19 +127,19 @@ impl CubeContext {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Obtain the index-th input
|
/// Obtain the index-th input
|
||||||
pub fn input(&mut self, index: u16, item: Item) -> ExpandElement {
|
pub fn input(&mut self, id: u16, item: Item) -> ExpandElement {
|
||||||
ExpandElement::Plain(crate::ir::Variable::GlobalInputArray(index, item))
|
ExpandElement::Plain(crate::ir::Variable::GlobalInputArray { id, item })
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Obtain the index-th output
|
/// Obtain the index-th output
|
||||||
pub fn output(&mut self, index: u16, item: Item) -> ExpandElement {
|
pub fn output(&mut self, id: u16, item: Item) -> ExpandElement {
|
||||||
let var = crate::ir::Variable::GlobalOutputArray(index, item);
|
let var = crate::ir::Variable::GlobalOutputArray { id, item };
|
||||||
self.scope.borrow_mut().write_global_custom(var);
|
self.scope.borrow_mut().write_global_custom(var);
|
||||||
ExpandElement::Plain(var)
|
ExpandElement::Plain(var)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Obtain the index-th scalar
|
/// Obtain the index-th scalar
|
||||||
pub fn scalar(&self, index: u16, elem: Elem) -> ExpandElement {
|
pub fn scalar(&self, id: u16, elem: Elem) -> ExpandElement {
|
||||||
ExpandElement::Plain(crate::ir::Variable::GlobalScalar(index, elem))
|
ExpandElement::Plain(crate::ir::Variable::GlobalScalar { id, elem })
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,7 +36,7 @@ impl<T: CubePrimitive + Clone> Array<T> {
|
||||||
) -> <Self as CubeType>::ExpandType {
|
) -> <Self as CubeType>::ExpandType {
|
||||||
let size = size.value();
|
let size = size.value();
|
||||||
let size = match size {
|
let size = match size {
|
||||||
crate::ir::Variable::ConstantScalar(val, _) => val as u32,
|
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
|
||||||
_ => panic!("Array need constant initialization value"),
|
_ => panic!("Array need constant initialization value"),
|
||||||
};
|
};
|
||||||
context
|
context
|
||||||
|
@ -55,7 +55,7 @@ impl<T: CubePrimitive + Clone> Array<T> {
|
||||||
) -> <Self as CubeType>::ExpandType {
|
) -> <Self as CubeType>::ExpandType {
|
||||||
let size = size.value();
|
let size = size.value();
|
||||||
let size = match size {
|
let size = match size {
|
||||||
crate::ir::Variable::ConstantScalar(val, _) => val as u32,
|
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
|
||||||
_ => panic!("Shared memory need constant initialization value"),
|
_ => panic!("Shared memory need constant initialization value"),
|
||||||
};
|
};
|
||||||
context
|
context
|
||||||
|
|
|
@ -160,7 +160,7 @@ impl ExpandElement {
|
||||||
pub fn can_mut(&self) -> bool {
|
pub fn can_mut(&self) -> bool {
|
||||||
match self {
|
match self {
|
||||||
ExpandElement::Managed(var) => {
|
ExpandElement::Managed(var) => {
|
||||||
if let Variable::Local(_, _, _) = var.as_ref() {
|
if let Variable::Local { .. } = var.as_ref() {
|
||||||
Rc::strong_count(var) <= 2
|
Rc::strong_count(var) <= 2
|
||||||
} else {
|
} else {
|
||||||
false
|
false
|
||||||
|
@ -201,10 +201,10 @@ impl Init for ExpandElement {
|
||||||
let mut init = |elem: Self| init_expand(context, elem, Operator::Assign);
|
let mut init = |elem: Self| init_expand(context, elem, Operator::Assign);
|
||||||
|
|
||||||
match *self {
|
match *self {
|
||||||
Variable::GlobalScalar(_, _) => init(self),
|
Variable::GlobalScalar { .. } => init(self),
|
||||||
Variable::LocalScalar(_, _, _) => init(self),
|
Variable::LocalScalar { .. } => init(self),
|
||||||
Variable::ConstantScalar(_, _) => init(self),
|
Variable::ConstantScalar { .. } => init(self),
|
||||||
Variable::Local(_, _, _) => init(self),
|
Variable::Local { .. } => init(self),
|
||||||
// Constant should be initialized since the new variable can be mutated afterward.
|
// Constant should be initialized since the new variable can be mutated afterward.
|
||||||
// And it is assumed those values are cloned.
|
// And it is assumed those values are cloned.
|
||||||
Variable::Rank
|
Variable::Rank
|
||||||
|
@ -230,11 +230,12 @@ impl Init for ExpandElement {
|
||||||
| Variable::AbsolutePosY
|
| Variable::AbsolutePosY
|
||||||
| Variable::AbsolutePosZ => init(self),
|
| Variable::AbsolutePosZ => init(self),
|
||||||
// Array types can't be copied, so we should simply return the same variable.
|
// Array types can't be copied, so we should simply return the same variable.
|
||||||
Variable::SharedMemory(_, _, _)
|
Variable::SharedMemory { .. }
|
||||||
| Variable::GlobalInputArray(_, _)
|
| Variable::GlobalInputArray { .. }
|
||||||
| Variable::GlobalOutputArray(_, _)
|
| Variable::GlobalOutputArray { .. }
|
||||||
| Variable::LocalArray(_, _, _, _)
|
| Variable::LocalArray { .. }
|
||||||
| Variable::Matrix(_, _) => self,
|
| Variable::Slice { .. }
|
||||||
|
| Variable::Matrix { .. } => self,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -41,9 +41,9 @@ impl_into_expand_element!(i64);
|
||||||
/// Useful for Comptime
|
/// Useful for Comptime
|
||||||
impl From<UInt> for ExpandElement {
|
impl From<UInt> for ExpandElement {
|
||||||
fn from(value: UInt) -> Self {
|
fn from(value: UInt) -> Self {
|
||||||
ExpandElement::Plain(crate::ir::Variable::ConstantScalar(
|
ExpandElement::Plain(crate::ir::Variable::ConstantScalar {
|
||||||
value.val as f64,
|
value: value.val as f64,
|
||||||
UInt::as_elem(),
|
elem: UInt::as_elem(),
|
||||||
))
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -75,7 +75,10 @@ macro_rules! impl_float {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn new_expand(_context: &mut CubeContext, val: f32) -> <Self as CubeType>::ExpandType {
|
fn new_expand(_context: &mut CubeContext, val: f32) -> <Self as CubeType>::ExpandType {
|
||||||
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
|
let new_var = Variable::ConstantScalar {
|
||||||
|
value: val as f64,
|
||||||
|
elem: Self::as_elem(),
|
||||||
|
};
|
||||||
ExpandElement::Plain(new_var)
|
ExpandElement::Plain(new_var)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -49,7 +49,10 @@ macro_rules! impl_int {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn new_expand(_context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType {
|
fn new_expand(_context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType {
|
||||||
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
|
let new_var = Variable::ConstantScalar {
|
||||||
|
value: val as f64,
|
||||||
|
elem: Self::as_elem(),
|
||||||
|
};
|
||||||
ExpandElement::Plain(new_var)
|
ExpandElement::Plain(new_var)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,6 +7,7 @@ mod float;
|
||||||
mod int;
|
mod int;
|
||||||
mod numeric;
|
mod numeric;
|
||||||
mod shared_memory;
|
mod shared_memory;
|
||||||
|
mod slice;
|
||||||
mod tensor;
|
mod tensor;
|
||||||
mod uint;
|
mod uint;
|
||||||
mod vectorized;
|
mod vectorized;
|
||||||
|
@ -19,6 +20,7 @@ pub use float::*;
|
||||||
pub use int::*;
|
pub use int::*;
|
||||||
pub use numeric::*;
|
pub use numeric::*;
|
||||||
pub use shared_memory::*;
|
pub use shared_memory::*;
|
||||||
|
pub use slice::*;
|
||||||
pub use tensor::*;
|
pub use tensor::*;
|
||||||
pub use uint::*;
|
pub use uint::*;
|
||||||
pub use vectorized::*;
|
pub use vectorized::*;
|
||||||
|
|
|
@ -47,7 +47,10 @@ pub trait Numeric:
|
||||||
|
|
||||||
/// Expand version of from_int
|
/// Expand version of from_int
|
||||||
fn from_int_expand(_context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType {
|
fn from_int_expand(_context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType {
|
||||||
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
|
let new_var = Variable::ConstantScalar {
|
||||||
|
value: val as f64,
|
||||||
|
elem: Self::as_elem(),
|
||||||
|
};
|
||||||
ExpandElement::Plain(new_var)
|
ExpandElement::Plain(new_var)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -5,32 +5,21 @@ use crate::{
|
||||||
ir::Item,
|
ir::Item,
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::{ExpandElement, Init, UInt};
|
use super::{ExpandElementTyped, Init, UInt};
|
||||||
|
|
||||||
#[derive(Clone, Copy)]
|
#[derive(Clone, Copy)]
|
||||||
pub struct SharedMemory<T: CubeType> {
|
pub struct SharedMemory<T: CubeType> {
|
||||||
_val: PhantomData<T>,
|
_val: PhantomData<T>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
impl<T: CubePrimitive> Init for ExpandElementTyped<SharedMemory<T>> {
|
||||||
pub struct SharedMemoryExpand<T: CubePrimitive> {
|
|
||||||
pub val: <T as CubeType>::ExpandType,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: CubePrimitive> From<SharedMemoryExpand<T>> for ExpandElement {
|
|
||||||
fn from(shared_memory_expand: SharedMemoryExpand<T>) -> Self {
|
|
||||||
shared_memory_expand.val
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<T: CubePrimitive> Init for SharedMemoryExpand<T> {
|
|
||||||
fn init(self, _context: &mut CubeContext) -> Self {
|
fn init(self, _context: &mut CubeContext) -> Self {
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: CubePrimitive> CubeType for SharedMemory<T> {
|
impl<T: CubePrimitive> CubeType for SharedMemory<T> {
|
||||||
type ExpandType = SharedMemoryExpand<T>;
|
type ExpandType = ExpandElementTyped<SharedMemory<T>>;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: CubePrimitive + Clone> SharedMemory<T> {
|
impl<T: CubePrimitive + Clone> SharedMemory<T> {
|
||||||
|
@ -44,10 +33,11 @@ impl<T: CubePrimitive + Clone> SharedMemory<T> {
|
||||||
) -> <Self as CubeType>::ExpandType {
|
) -> <Self as CubeType>::ExpandType {
|
||||||
let size = size.value();
|
let size = size.value();
|
||||||
let size = match size {
|
let size = match size {
|
||||||
crate::ir::Variable::ConstantScalar(val, _) => val as u32,
|
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
|
||||||
_ => panic!("Shared memory need constant initialization value"),
|
_ => panic!("Shared memory need constant initialization value"),
|
||||||
};
|
};
|
||||||
context.create_shared(Item::new(T::as_elem()), size)
|
let var = context.create_shared(Item::new(T::as_elem()), size);
|
||||||
|
ExpandElementTyped::new(var)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn vectorized<S: Index>(_size: S, _vectorization_factor: UInt) -> Self {
|
pub fn vectorized<S: Index>(_size: S, _vectorization_factor: UInt) -> Self {
|
||||||
|
@ -61,12 +51,13 @@ impl<T: CubePrimitive + Clone> SharedMemory<T> {
|
||||||
) -> <Self as CubeType>::ExpandType {
|
) -> <Self as CubeType>::ExpandType {
|
||||||
let size = size.value();
|
let size = size.value();
|
||||||
let size = match size {
|
let size = match size {
|
||||||
crate::ir::Variable::ConstantScalar(val, _) => val as u32,
|
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
|
||||||
_ => panic!("Shared memory need constant initialization value"),
|
_ => panic!("Shared memory need constant initialization value"),
|
||||||
};
|
};
|
||||||
context.create_shared(
|
let var = context.create_shared(
|
||||||
Item::vectorized(T::as_elem(), vectorization_factor.val as u8),
|
Item::vectorized(T::as_elem(), vectorization_factor.val as u8),
|
||||||
size,
|
size,
|
||||||
)
|
);
|
||||||
|
ExpandElementTyped::new(var)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,285 @@
|
||||||
|
use std::marker::PhantomData;
|
||||||
|
|
||||||
|
use super::{
|
||||||
|
Array, CubePrimitive, CubeType, ExpandElement, ExpandElementTyped, Init, SharedMemory, Tensor,
|
||||||
|
UInt,
|
||||||
|
};
|
||||||
|
use crate::{
|
||||||
|
frontend::indexation::Index,
|
||||||
|
ir::{self, Operator},
|
||||||
|
prelude::CubeContext,
|
||||||
|
unexpanded,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// A read-only contiguous list of elements
|
||||||
|
pub struct Slice<'a, E> {
|
||||||
|
_e: PhantomData<E>,
|
||||||
|
_l: &'a (),
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A read-write contiguous list of elements.
|
||||||
|
pub struct SliceMut<'a, E> {
|
||||||
|
_e: PhantomData<E>,
|
||||||
|
_l: &'a mut (),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, E> Slice<'a, E> {
|
||||||
|
/// Get the length of the slice.
|
||||||
|
pub fn len(&self) -> UInt {
|
||||||
|
unexpanded!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, E> SliceMut<'a, E> {
|
||||||
|
/// Get the length of the slice.
|
||||||
|
pub fn len(&self) -> UInt {
|
||||||
|
unexpanded!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, E: CubeType> CubeType for Slice<'a, E> {
|
||||||
|
type ExpandType = ExpandElementTyped<Slice<'static, E>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, C: CubeType> Init for ExpandElementTyped<Slice<'a, C>> {
|
||||||
|
fn init(self, _context: &mut crate::prelude::CubeContext) -> Self {
|
||||||
|
// The type can't be deeply cloned/copied.
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, E: CubeType> CubeType for SliceMut<'a, E> {
|
||||||
|
type ExpandType = ExpandElementTyped<SliceMut<'static, E>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, C: CubeType> Init for ExpandElementTyped<SliceMut<'a, C>> {
|
||||||
|
fn init(self, _context: &mut crate::prelude::CubeContext) -> Self {
|
||||||
|
// The type can't be deeply cloned/copied.
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait SliceOperator<E>: CubeType<ExpandType = Self::Expand> {
|
||||||
|
type Expand: SliceOperatorExpand<E>;
|
||||||
|
|
||||||
|
/// Return a read-only view of all elements comprised between the start and end indices.
|
||||||
|
#[allow(unused_variables)]
|
||||||
|
fn slice<Start: Index, End: Index>(&self, start: Start, end: End) -> &'_ Slice<'_, E> {
|
||||||
|
unexpanded!()
|
||||||
|
}
|
||||||
|
/// Expand function of [SliceOperator::slice].
|
||||||
|
fn slice_expand<Start: Index, End: Index>(
|
||||||
|
context: &mut CubeContext,
|
||||||
|
expand: Self::Expand,
|
||||||
|
start: Start,
|
||||||
|
end: End,
|
||||||
|
) -> ExpandElementTyped<Slice<'static, E>> {
|
||||||
|
expand.slice_expand(context, start, end)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return a read-write view of all elements comprised between the start and end indices.
|
||||||
|
#[allow(unused_variables)]
|
||||||
|
fn slice_mut<Start: Index, End: Index>(
|
||||||
|
&mut self,
|
||||||
|
start: Start,
|
||||||
|
end: End,
|
||||||
|
) -> &'_ mut SliceMut<'_, E> {
|
||||||
|
unexpanded!()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Expand function of [SliceOperator::slice_mut].
|
||||||
|
fn slice_mut_expand<Start: Index, End: Index>(
|
||||||
|
context: &mut CubeContext,
|
||||||
|
expand: Self::Expand,
|
||||||
|
start: Start,
|
||||||
|
end: End,
|
||||||
|
) -> ExpandElementTyped<SliceMut<'static, E>> {
|
||||||
|
expand.slice_mut_expand(context, start, end)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Return a read-write view of all elements comprised between the start and end indices.
|
||||||
|
///
|
||||||
|
/// # Warning
|
||||||
|
///
|
||||||
|
/// Ignore the multiple borrow rule.
|
||||||
|
#[allow(unused_variables)]
|
||||||
|
fn slice_mut_unsafe<Start: Index, End: Index>(
|
||||||
|
&self,
|
||||||
|
start: Start,
|
||||||
|
end: End,
|
||||||
|
) -> SliceMut<'static, E> {
|
||||||
|
unexpanded!()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Expand function of [SliceOperator::slice_mut_unsafe].
|
||||||
|
fn slice_mut_unsafe_expand<Start: Index, End: Index>(
|
||||||
|
context: &mut CubeContext,
|
||||||
|
expand: Self::Expand,
|
||||||
|
start: Start,
|
||||||
|
end: End,
|
||||||
|
) -> ExpandElementTyped<SliceMut<'static, E>> {
|
||||||
|
expand.slice_mut_unsafe_expand(context, start, end)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reinterpret the current type as a read-only slice.
|
||||||
|
#[allow(unused_variables)]
|
||||||
|
fn as_slice(&self) -> &'_ Slice<'_, E> {
|
||||||
|
unexpanded!()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Expand function of [SliceOperator::as_slice].
|
||||||
|
fn as_slice_expand(
|
||||||
|
context: &mut CubeContext,
|
||||||
|
expand: Self::Expand,
|
||||||
|
) -> ExpandElementTyped<Slice<'static, E>> {
|
||||||
|
expand.as_slice_expand(context)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reinterpret the current type as a read-write slice.
|
||||||
|
#[allow(unused_variables)]
|
||||||
|
fn as_slice_mut(&mut self) -> &'_ mut SliceMut<'_, E> {
|
||||||
|
unexpanded!()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Expand function of [SliceOperator::as_slice_mut].
|
||||||
|
fn as_slice_mut_expand(
|
||||||
|
context: &mut CubeContext,
|
||||||
|
expand: Self::Expand,
|
||||||
|
) -> ExpandElementTyped<SliceMut<'static, E>> {
|
||||||
|
expand.as_slice_mut_expand(context)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Reinterpret the current type as a read-write slice.
|
||||||
|
///
|
||||||
|
/// # Warning
|
||||||
|
///
|
||||||
|
/// Ignore the multiple borrow rule.
|
||||||
|
#[allow(unused_variables)]
|
||||||
|
fn as_slice_mut_unsafe(&self) -> SliceMut<'static, E> {
|
||||||
|
unexpanded!()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Expand function of [SliceOperator::as_slice_mut_unsafe].
|
||||||
|
fn as_slice_mut_unsafe_expand(
|
||||||
|
context: &mut CubeContext,
|
||||||
|
expand: Self::Expand,
|
||||||
|
) -> ExpandElementTyped<SliceMut<'static, E>> {
|
||||||
|
expand.as_slice_mut_unsafe_expand(context)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub trait SliceOperatorExpand<E>: Into<ExpandElement> + Clone {
|
||||||
|
fn slice_base<Start: Index, End: Index>(
|
||||||
|
&self,
|
||||||
|
context: &mut CubeContext,
|
||||||
|
start: Start,
|
||||||
|
end: End,
|
||||||
|
) -> ExpandElement;
|
||||||
|
|
||||||
|
fn slice_expand<Start: Index, End: Index>(
|
||||||
|
&self,
|
||||||
|
context: &mut CubeContext,
|
||||||
|
start: Start,
|
||||||
|
end: End,
|
||||||
|
) -> ExpandElementTyped<Slice<'static, E>> {
|
||||||
|
ExpandElementTyped::new(self.slice_base(context, start, end))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn slice_mut_expand<Start: Index, End: Index>(
|
||||||
|
&self,
|
||||||
|
context: &mut CubeContext,
|
||||||
|
start: Start,
|
||||||
|
end: End,
|
||||||
|
) -> ExpandElementTyped<SliceMut<'static, E>> {
|
||||||
|
ExpandElementTyped::new(self.slice_base(context, start, end))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn slice_mut_unsafe_expand<Start: Index, End: Index>(
|
||||||
|
&self,
|
||||||
|
context: &mut CubeContext,
|
||||||
|
start: Start,
|
||||||
|
end: End,
|
||||||
|
) -> ExpandElementTyped<SliceMut<'static, E>> {
|
||||||
|
ExpandElementTyped::new(self.slice_base(context, start, end))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn as_slice_expand(&self, _context: &mut CubeContext) -> ExpandElementTyped<Slice<'static, E>> {
|
||||||
|
let expand = self.clone().into();
|
||||||
|
ExpandElementTyped::new(expand)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn as_slice_mut_unsafe_expand(
|
||||||
|
&self,
|
||||||
|
context: &mut CubeContext,
|
||||||
|
) -> ExpandElementTyped<SliceMut<'static, E>> {
|
||||||
|
self.as_slice_mut_expand(context)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn as_slice_mut_expand(
|
||||||
|
&self,
|
||||||
|
_context: &mut CubeContext,
|
||||||
|
) -> ExpandElementTyped<SliceMut<'static, E>> {
|
||||||
|
let expand = self.clone().into();
|
||||||
|
ExpandElementTyped::new(expand)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! slice_op {
|
||||||
|
($type:ident) => {
|
||||||
|
impl<E: CubePrimitive> SliceOperator<E> for $type<E> {
|
||||||
|
type Expand = ExpandElementTyped<$type<E>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<$type<E>> {
|
||||||
|
fn slice_base<Start: Index, End: Index>(
|
||||||
|
&self,
|
||||||
|
context: &mut CubeContext,
|
||||||
|
start: Start,
|
||||||
|
end: End,
|
||||||
|
) -> ExpandElement {
|
||||||
|
slice_expand(context, self.clone(), start, end)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
(slice $type:ident) => {
|
||||||
|
impl<'a, E: CubePrimitive> SliceOperator<E> for $type<'a, E> {
|
||||||
|
type Expand = ExpandElementTyped<$type<'static, E>>;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, E: CubePrimitive> SliceOperatorExpand<E> for ExpandElementTyped<$type<'a, E>> {
|
||||||
|
fn slice_base<Start: Index, End: Index>(
|
||||||
|
&self,
|
||||||
|
context: &mut CubeContext,
|
||||||
|
start: Start,
|
||||||
|
end: End,
|
||||||
|
) -> ExpandElement {
|
||||||
|
slice_expand(context, self.clone(), start, end)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
slice_op!(Array);
|
||||||
|
slice_op!(Tensor);
|
||||||
|
slice_op!(SharedMemory);
|
||||||
|
slice_op!(slice Slice);
|
||||||
|
slice_op!(slice SliceMut);
|
||||||
|
|
||||||
|
pub fn slice_expand<I: Into<ExpandElement>, S1: Index, S2: Index>(
|
||||||
|
context: &mut CubeContext,
|
||||||
|
input: I,
|
||||||
|
start: S1,
|
||||||
|
end: S2, // Todo use it to get the length.
|
||||||
|
) -> ExpandElement {
|
||||||
|
let input = input.into();
|
||||||
|
let out = context.create_slice(input.item());
|
||||||
|
|
||||||
|
context.register(Operator::Slice(ir::SliceOperator {
|
||||||
|
input: *input,
|
||||||
|
start: start.value(),
|
||||||
|
end: end.value(),
|
||||||
|
out: *out,
|
||||||
|
}));
|
||||||
|
|
||||||
|
out
|
||||||
|
}
|
|
@ -202,7 +202,7 @@ impl<T> ExpandElementTyped<T> {
|
||||||
// Expanded version of len
|
// Expanded version of len
|
||||||
pub fn len_expand(self, context: &mut CubeContext) -> ExpandElement {
|
pub fn len_expand(self, context: &mut CubeContext) -> ExpandElement {
|
||||||
let out = context.create_local(Item::new(Elem::UInt));
|
let out = context.create_local(Item::new(Elem::UInt));
|
||||||
context.register(Metadata::ArrayLength {
|
context.register(Metadata::Length {
|
||||||
var: self.expand.into(),
|
var: self.expand.into(),
|
||||||
out: out.clone().into(),
|
out: out.clone().into(),
|
||||||
});
|
});
|
||||||
|
|
|
@ -49,7 +49,10 @@ impl UInt {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_expand(_context: &mut CubeContext, val: u32) -> <Self as CubeType>::ExpandType {
|
pub fn new_expand(_context: &mut CubeContext, val: u32) -> <Self as CubeType>::ExpandType {
|
||||||
let new_var = Variable::ConstantScalar(val as f64, Self::as_elem());
|
let new_var = Variable::ConstantScalar {
|
||||||
|
value: val as f64,
|
||||||
|
elem: Self::as_elem(),
|
||||||
|
};
|
||||||
ExpandElement::Plain(new_var)
|
ExpandElement::Plain(new_var)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -7,31 +7,46 @@ pub trait Index {
|
||||||
|
|
||||||
impl Index for Comptime<u32> {
|
impl Index for Comptime<u32> {
|
||||||
fn value(self) -> Variable {
|
fn value(self) -> Variable {
|
||||||
Variable::ConstantScalar(self.inner as f64, Elem::UInt)
|
Variable::ConstantScalar {
|
||||||
|
value: self.inner as f64,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Index for Comptime<i32> {
|
impl Index for Comptime<i32> {
|
||||||
fn value(self) -> Variable {
|
fn value(self) -> Variable {
|
||||||
Variable::ConstantScalar(self.inner as f64, Elem::UInt)
|
Variable::ConstantScalar {
|
||||||
|
value: self.inner as f64,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Index for i32 {
|
impl Index for i32 {
|
||||||
fn value(self) -> Variable {
|
fn value(self) -> Variable {
|
||||||
Variable::ConstantScalar(self as f64, Elem::UInt)
|
Variable::ConstantScalar {
|
||||||
|
value: self as f64,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Index for u32 {
|
impl Index for u32 {
|
||||||
fn value(self) -> Variable {
|
fn value(self) -> Variable {
|
||||||
Variable::ConstantScalar(self as f64, Elem::UInt)
|
Variable::ConstantScalar {
|
||||||
|
value: self as f64,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Index for UInt {
|
impl Index for UInt {
|
||||||
fn value(self) -> Variable {
|
fn value(self) -> Variable {
|
||||||
Variable::ConstantScalar(self.val as f64, Elem::UInt)
|
Variable::ConstantScalar {
|
||||||
|
value: self.val as f64,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ pub mod assign {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub mod index_assign {
|
pub mod index_assign {
|
||||||
use crate::{frontend::CubeType, unexpanded};
|
use crate::{frontend::CubeType, prelude::SliceMut, unexpanded};
|
||||||
|
|
||||||
use self::ir::{BinaryOperator, Operator, Variable};
|
use self::ir::{BinaryOperator, Operator, Variable};
|
||||||
|
|
||||||
|
@ -34,7 +34,10 @@ pub mod index_assign {
|
||||||
let array = array.into();
|
let array = array.into();
|
||||||
let index: Variable = *index.into();
|
let index: Variable = *index.into();
|
||||||
let index = match index {
|
let index = match index {
|
||||||
Variable::ConstantScalar(val, _) => Variable::ConstantScalar(val, ir::Elem::UInt),
|
Variable::ConstantScalar { value, .. } => Variable::ConstantScalar {
|
||||||
|
value,
|
||||||
|
elem: ir::Elem::UInt,
|
||||||
|
},
|
||||||
_ => index,
|
_ => index,
|
||||||
};
|
};
|
||||||
context.register(Operator::IndexAssign(BinaryOperator {
|
context.register(Operator::IndexAssign(BinaryOperator {
|
||||||
|
@ -58,6 +61,12 @@ pub mod index_assign {
|
||||||
impl_index!(Array);
|
impl_index!(Array);
|
||||||
impl_index!(Tensor);
|
impl_index!(Tensor);
|
||||||
impl_index!(SharedMemory);
|
impl_index!(SharedMemory);
|
||||||
|
|
||||||
|
impl<'a, E: CubeType, I: Into<UInt>> core::ops::IndexMut<I> for SliceMut<'a, E> {
|
||||||
|
fn index_mut(&mut self, _index: I) -> &mut Self::Output {
|
||||||
|
unexpanded!()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub mod index {
|
pub mod index {
|
||||||
|
@ -66,6 +75,7 @@ pub mod index {
|
||||||
operation::base::{binary_expand, binary_expand_no_vec},
|
operation::base::{binary_expand, binary_expand_no_vec},
|
||||||
CubeType,
|
CubeType,
|
||||||
},
|
},
|
||||||
|
prelude::{Slice, SliceMut},
|
||||||
unexpanded,
|
unexpanded,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -78,20 +88,21 @@ pub mod index {
|
||||||
array: L,
|
array: L,
|
||||||
index: R,
|
index: R,
|
||||||
) -> ExpandElement {
|
) -> ExpandElement {
|
||||||
let index = index.into();
|
let index: ExpandElement = index.into();
|
||||||
let index_var: Variable = *index;
|
let index_var: Variable = *index;
|
||||||
let index = match index_var {
|
let index = match index_var {
|
||||||
Variable::ConstantScalar(val, _) => {
|
Variable::ConstantScalar { value, .. } => {
|
||||||
ExpandElement::Plain(Variable::ConstantScalar(val, ir::Elem::UInt))
|
ExpandElement::Plain(Variable::ConstantScalar {
|
||||||
|
value,
|
||||||
|
elem: ir::Elem::UInt,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
_ => index,
|
_ => index,
|
||||||
};
|
};
|
||||||
let array = array.into();
|
let array: ExpandElement = array.into();
|
||||||
let var: Variable = *array;
|
let var: Variable = *array;
|
||||||
match var {
|
match var {
|
||||||
Variable::Local(_, _, _) => {
|
Variable::Local { .. } => binary_expand_no_vec(context, array, index, Operator::Index),
|
||||||
binary_expand_no_vec(context, array, index, Operator::Index)
|
|
||||||
}
|
|
||||||
_ => binary_expand(context, array, index, Operator::Index),
|
_ => binary_expand(context, array, index, Operator::Index),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -111,6 +122,20 @@ pub mod index {
|
||||||
impl_index!(Array);
|
impl_index!(Array);
|
||||||
impl_index!(Tensor);
|
impl_index!(Tensor);
|
||||||
impl_index!(SharedMemory);
|
impl_index!(SharedMemory);
|
||||||
|
|
||||||
|
impl<'a, E: CubeType, I: Into<UInt>> core::ops::Index<I> for SliceMut<'a, E> {
|
||||||
|
type Output = E;
|
||||||
|
fn index(&self, _index: I) -> &Self::Output {
|
||||||
|
unexpanded!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, E: CubeType, I: Into<UInt>> core::ops::Index<I> for Slice<'a, E> {
|
||||||
|
type Output = E;
|
||||||
|
fn index(&self, _index: I) -> &Self::Output {
|
||||||
|
unexpanded!()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub mod add_assign_array_op {
|
pub mod add_assign_array_op {
|
||||||
|
|
|
@ -352,7 +352,7 @@ macro_rules! cpa {
|
||||||
};
|
};
|
||||||
// out = len(array)
|
// out = len(array)
|
||||||
($scope:expr, $out:ident = len($input:expr)) => {
|
($scope:expr, $out:ident = len($input:expr)) => {
|
||||||
$scope.register($crate::ir::Metadata::ArrayLength {
|
$scope.register($crate::ir::Metadata::Length {
|
||||||
var: $input.into(),
|
var: $input.into(),
|
||||||
out: $out.into(),
|
out: $out.into(),
|
||||||
});
|
});
|
||||||
|
@ -398,43 +398,64 @@ macro_rules! cpa {
|
||||||
|
|
||||||
impl From<bool> for Variable {
|
impl From<bool> for Variable {
|
||||||
fn from(value: bool) -> Self {
|
fn from(value: bool) -> Self {
|
||||||
Self::ConstantScalar(if value { 1.0 } else { 0.0 }, super::Elem::Bool)
|
Self::ConstantScalar {
|
||||||
|
value: if value { 1.0 } else { 0.0 },
|
||||||
|
elem: super::Elem::Bool,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<i32> for Variable {
|
impl From<i32> for Variable {
|
||||||
fn from(value: i32) -> Self {
|
fn from(value: i32) -> Self {
|
||||||
Self::ConstantScalar(value as f64, super::Elem::Int(super::IntKind::I32))
|
Self::ConstantScalar {
|
||||||
|
value: value as f64,
|
||||||
|
elem: super::Elem::Int(super::IntKind::I32),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<i64> for Variable {
|
impl From<i64> for Variable {
|
||||||
fn from(value: i64) -> Self {
|
fn from(value: i64) -> Self {
|
||||||
Self::ConstantScalar(value as f64, super::Elem::Int(super::IntKind::I64))
|
Self::ConstantScalar {
|
||||||
|
value: value as f64,
|
||||||
|
elem: super::Elem::Int(super::IntKind::I64),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<f32> for Variable {
|
impl From<f32> for Variable {
|
||||||
fn from(value: f32) -> Self {
|
fn from(value: f32) -> Self {
|
||||||
Self::ConstantScalar(value as f64, super::Elem::Float(super::FloatKind::F32))
|
Self::ConstantScalar {
|
||||||
|
value: value as f64,
|
||||||
|
elem: super::Elem::Float(super::FloatKind::F32),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<f64> for Variable {
|
impl From<f64> for Variable {
|
||||||
fn from(value: f64) -> Self {
|
fn from(value: f64) -> Self {
|
||||||
Self::ConstantScalar(value, super::Elem::Float(super::FloatKind::F64))
|
Self::ConstantScalar {
|
||||||
|
value,
|
||||||
|
elem: super::Elem::Float(super::FloatKind::F64),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<u32> for Variable {
|
impl From<u32> for Variable {
|
||||||
fn from(value: u32) -> Self {
|
fn from(value: u32) -> Self {
|
||||||
Self::ConstantScalar(value as f64, super::Elem::UInt)
|
Self::ConstantScalar {
|
||||||
|
value: value as f64,
|
||||||
|
elem: super::Elem::UInt,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<usize> for Variable {
|
impl From<usize> for Variable {
|
||||||
fn from(value: usize) -> Self {
|
fn from(value: usize) -> Self {
|
||||||
Self::ConstantScalar(value as f64, super::Elem::UInt)
|
Self::ConstantScalar {
|
||||||
|
value: value as f64,
|
||||||
|
elem: super::Elem::UInt,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -53,6 +53,7 @@ pub enum Operator {
|
||||||
Assign(UnaryOperator),
|
Assign(UnaryOperator),
|
||||||
Modulo(BinaryOperator),
|
Modulo(BinaryOperator),
|
||||||
Index(BinaryOperator),
|
Index(BinaryOperator),
|
||||||
|
Slice(SliceOperator),
|
||||||
UncheckedIndex(BinaryOperator),
|
UncheckedIndex(BinaryOperator),
|
||||||
IndexAssign(BinaryOperator),
|
IndexAssign(BinaryOperator),
|
||||||
UncheckedIndexAssign(BinaryOperator),
|
UncheckedIndexAssign(BinaryOperator),
|
||||||
|
@ -84,7 +85,7 @@ pub enum Metadata {
|
||||||
var: Variable,
|
var: Variable,
|
||||||
out: Variable,
|
out: Variable,
|
||||||
},
|
},
|
||||||
ArrayLength {
|
Length {
|
||||||
var: Variable,
|
var: Variable,
|
||||||
out: Variable,
|
out: Variable,
|
||||||
},
|
},
|
||||||
|
@ -120,6 +121,15 @@ pub struct ClampOperator {
|
||||||
pub out: Variable,
|
pub out: Variable,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||||
|
#[allow(missing_docs)]
|
||||||
|
pub struct SliceOperator {
|
||||||
|
pub input: Variable,
|
||||||
|
pub start: Variable,
|
||||||
|
pub end: Variable,
|
||||||
|
pub out: Variable,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||||
#[allow(missing_docs)]
|
#[allow(missing_docs)]
|
||||||
pub struct ReadGlobalOperator {
|
pub struct ReadGlobalOperator {
|
||||||
|
|
|
@ -19,6 +19,7 @@ pub struct Scope {
|
||||||
pub operations: Vec<Operation>,
|
pub operations: Vec<Operation>,
|
||||||
locals: Vec<Variable>,
|
locals: Vec<Variable>,
|
||||||
matrices: Vec<Variable>,
|
matrices: Vec<Variable>,
|
||||||
|
slices: Vec<Variable>,
|
||||||
shared_memories: Vec<Variable>,
|
shared_memories: Vec<Variable>,
|
||||||
local_arrays: Vec<Variable>,
|
local_arrays: Vec<Variable>,
|
||||||
reads_global: Vec<(Variable, ReadingStrategy, Variable, Variable)>,
|
reads_global: Vec<(Variable, ReadingStrategy, Variable, Variable)>,
|
||||||
|
@ -49,6 +50,7 @@ impl Scope {
|
||||||
operations: Vec::new(),
|
operations: Vec::new(),
|
||||||
locals: Vec::new(),
|
locals: Vec::new(),
|
||||||
matrices: Vec::new(),
|
matrices: Vec::new(),
|
||||||
|
slices: Vec::new(),
|
||||||
local_arrays: Vec::new(),
|
local_arrays: Vec::new(),
|
||||||
shared_memories: Vec::new(),
|
shared_memories: Vec::new(),
|
||||||
reads_global: Vec::new(),
|
reads_global: Vec::new(),
|
||||||
|
@ -75,7 +77,10 @@ impl Scope {
|
||||||
I: Into<Item> + Copy,
|
I: Into<Item> + Copy,
|
||||||
{
|
{
|
||||||
let local = self.create_local(item);
|
let local = self.create_local(item);
|
||||||
let value = Variable::ConstantScalar(value.into(), item.into().elem());
|
let value = Variable::ConstantScalar {
|
||||||
|
value: value.into(),
|
||||||
|
elem: item.into().elem(),
|
||||||
|
};
|
||||||
cpa!(self, local = value);
|
cpa!(self, local = value);
|
||||||
local
|
local
|
||||||
}
|
}
|
||||||
|
@ -83,16 +88,35 @@ impl Scope {
|
||||||
/// Create a matrix variable
|
/// Create a matrix variable
|
||||||
pub fn create_matrix(&mut self, matrix: Matrix) -> Variable {
|
pub fn create_matrix(&mut self, matrix: Matrix) -> Variable {
|
||||||
let index = self.matrices.len() as u16;
|
let index = self.matrices.len() as u16;
|
||||||
let variable = Variable::Matrix(index, matrix);
|
let variable = Variable::Matrix {
|
||||||
|
id: index,
|
||||||
|
mat: matrix,
|
||||||
|
};
|
||||||
self.matrices.push(variable);
|
self.matrices.push(variable);
|
||||||
variable
|
variable
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a slice variable
|
||||||
|
pub fn create_slice(&mut self, item: Item) -> Variable {
|
||||||
|
let id = self.slices.len() as u16;
|
||||||
|
let variable = Variable::Slice {
|
||||||
|
id,
|
||||||
|
item,
|
||||||
|
depth: self.depth,
|
||||||
|
};
|
||||||
|
self.slices.push(variable);
|
||||||
|
variable
|
||||||
|
}
|
||||||
|
|
||||||
/// Create a local variable of the given [item type](Item).
|
/// Create a local variable of the given [item type](Item).
|
||||||
pub fn create_local<I: Into<Item>>(&mut self, item: I) -> Variable {
|
pub fn create_local<I: Into<Item>>(&mut self, item: I) -> Variable {
|
||||||
let item = item.into();
|
let item = item.into();
|
||||||
let index = self.new_local_index();
|
let index = self.new_local_index();
|
||||||
let local = Variable::Local(index, item, self.depth);
|
let local = Variable::Local {
|
||||||
|
id: index,
|
||||||
|
item,
|
||||||
|
depth: self.depth,
|
||||||
|
};
|
||||||
self.locals.push(local);
|
self.locals.push(local);
|
||||||
local
|
local
|
||||||
}
|
}
|
||||||
|
@ -102,7 +126,11 @@ impl Scope {
|
||||||
/// Useful for _for loops_ and other algorithms that require the control over initialization.
|
/// Useful for _for loops_ and other algorithms that require the control over initialization.
|
||||||
pub fn create_local_undeclared(&mut self, item: Item) -> Variable {
|
pub fn create_local_undeclared(&mut self, item: Item) -> Variable {
|
||||||
let index = self.new_local_index();
|
let index = self.new_local_index();
|
||||||
let local = Variable::Local(index, item, self.depth);
|
let local = Variable::Local {
|
||||||
|
id: index,
|
||||||
|
item,
|
||||||
|
depth: self.depth,
|
||||||
|
};
|
||||||
self.undeclared += 1;
|
self.undeclared += 1;
|
||||||
local
|
local
|
||||||
}
|
}
|
||||||
|
@ -131,8 +159,12 @@ impl Scope {
|
||||||
///
|
///
|
||||||
/// The index refers to the scalar position for the same [element](Elem) type.
|
/// The index refers to the scalar position for the same [element](Elem) type.
|
||||||
pub fn read_scalar(&mut self, index: u16, elem: Elem) -> Variable {
|
pub fn read_scalar(&mut self, index: u16, elem: Elem) -> Variable {
|
||||||
let local = Variable::LocalScalar(self.new_local_scalar_index(), elem, self.depth);
|
let local = Variable::LocalScalar {
|
||||||
let scalar = Variable::GlobalScalar(index, elem);
|
id: self.new_local_scalar_index(),
|
||||||
|
elem,
|
||||||
|
depth: self.depth,
|
||||||
|
};
|
||||||
|
let scalar = Variable::GlobalScalar { id: index, elem };
|
||||||
|
|
||||||
self.reads_scalar.push((local, scalar));
|
self.reads_scalar.push((local, scalar));
|
||||||
|
|
||||||
|
@ -215,7 +247,7 @@ impl Scope {
|
||||||
self.reads_global
|
self.reads_global
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(var, strategy, _, _)| match var {
|
.map(|(var, strategy, _, _)| match var {
|
||||||
Variable::GlobalInputArray(id, _) => (*id, *strategy),
|
Variable::GlobalInputArray { id, .. } => (*id, *strategy),
|
||||||
_ => panic!("Can only read global input arrays."),
|
_ => panic!("Can only read global input arrays."),
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
|
@ -233,6 +265,7 @@ impl Scope {
|
||||||
operations: Vec::new(),
|
operations: Vec::new(),
|
||||||
locals: Vec::new(),
|
locals: Vec::new(),
|
||||||
matrices: Vec::new(),
|
matrices: Vec::new(),
|
||||||
|
slices: Vec::new(),
|
||||||
shared_memories: Vec::new(),
|
shared_memories: Vec::new(),
|
||||||
local_arrays: Vec::new(),
|
local_arrays: Vec::new(),
|
||||||
reads_global: Vec::new(),
|
reads_global: Vec::new(),
|
||||||
|
@ -259,6 +292,9 @@ impl Scope {
|
||||||
for var in self.matrices.drain(..) {
|
for var in self.matrices.drain(..) {
|
||||||
variables.push(var);
|
variables.push(var);
|
||||||
}
|
}
|
||||||
|
for var in self.slices.drain(..) {
|
||||||
|
variables.push(var);
|
||||||
|
}
|
||||||
|
|
||||||
for index in self.index_offset_with_output_layout_position.drain(..) {
|
for index in self.index_offset_with_output_layout_position.drain(..) {
|
||||||
if let Some(Operation::Procedure(Procedure::IndexOffsetGlobalWithLayout(proc))) =
|
if let Some(Operation::Procedure(Procedure::IndexOffsetGlobalWithLayout(proc))) =
|
||||||
|
@ -357,9 +393,16 @@ impl Scope {
|
||||||
},
|
},
|
||||||
_ => item,
|
_ => item,
|
||||||
};
|
};
|
||||||
let input = Variable::GlobalInputArray(index, item_global);
|
let input = Variable::GlobalInputArray {
|
||||||
|
id: index,
|
||||||
|
item: item_global,
|
||||||
|
};
|
||||||
let index = self.new_local_index();
|
let index = self.new_local_index();
|
||||||
let local = Variable::Local(index, item, self.depth);
|
let local = Variable::Local {
|
||||||
|
id: index,
|
||||||
|
item,
|
||||||
|
depth: self.depth,
|
||||||
|
};
|
||||||
self.reads_global.push((input, strategy, local, position));
|
self.reads_global.push((input, strategy, local, position));
|
||||||
self.locals.push(local);
|
self.locals.push(local);
|
||||||
local
|
local
|
||||||
|
@ -369,7 +412,11 @@ impl Scope {
|
||||||
pub fn create_shared<I: Into<Item>>(&mut self, item: I, shared_memory_size: u32) -> Variable {
|
pub fn create_shared<I: Into<Item>>(&mut self, item: I, shared_memory_size: u32) -> Variable {
|
||||||
let item = item.into();
|
let item = item.into();
|
||||||
let index = self.new_shared_index();
|
let index = self.new_shared_index();
|
||||||
let shared_memory = Variable::SharedMemory(index, item, shared_memory_size);
|
let shared_memory = Variable::SharedMemory {
|
||||||
|
id: index,
|
||||||
|
item,
|
||||||
|
length: shared_memory_size,
|
||||||
|
};
|
||||||
self.shared_memories.push(shared_memory);
|
self.shared_memories.push(shared_memory);
|
||||||
shared_memory
|
shared_memory
|
||||||
}
|
}
|
||||||
|
@ -378,7 +425,12 @@ impl Scope {
|
||||||
pub fn create_local_array<I: Into<Item>>(&mut self, item: I, array_size: u32) -> Variable {
|
pub fn create_local_array<I: Into<Item>>(&mut self, item: I, array_size: u32) -> Variable {
|
||||||
let item = item.into();
|
let item = item.into();
|
||||||
let index = self.new_local_array_index();
|
let index = self.new_local_array_index();
|
||||||
let local_array = Variable::LocalArray(index, item, self.depth, array_size);
|
let local_array = Variable::LocalArray {
|
||||||
|
id: index,
|
||||||
|
item,
|
||||||
|
depth: self.depth,
|
||||||
|
length: array_size,
|
||||||
|
};
|
||||||
self.local_arrays.push(local_array);
|
self.local_arrays.push(local_array);
|
||||||
local_array
|
local_array
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,14 +5,52 @@ use serde::{Deserialize, Serialize};
|
||||||
#[allow(missing_docs)]
|
#[allow(missing_docs)]
|
||||||
pub enum Variable {
|
pub enum Variable {
|
||||||
Rank,
|
Rank,
|
||||||
GlobalInputArray(u16, Item),
|
GlobalInputArray {
|
||||||
GlobalScalar(u16, Elem),
|
id: u16,
|
||||||
GlobalOutputArray(u16, Item),
|
item: Item,
|
||||||
Local(u16, Item, u8),
|
},
|
||||||
LocalScalar(u16, Elem, u8),
|
GlobalScalar {
|
||||||
ConstantScalar(f64, Elem),
|
id: u16,
|
||||||
SharedMemory(u16, Item, u32),
|
elem: Elem,
|
||||||
LocalArray(u16, Item, u8, u32),
|
},
|
||||||
|
GlobalOutputArray {
|
||||||
|
id: u16,
|
||||||
|
item: Item,
|
||||||
|
},
|
||||||
|
Local {
|
||||||
|
id: u16,
|
||||||
|
item: Item,
|
||||||
|
depth: u8,
|
||||||
|
},
|
||||||
|
LocalScalar {
|
||||||
|
id: u16,
|
||||||
|
elem: Elem,
|
||||||
|
depth: u8,
|
||||||
|
},
|
||||||
|
ConstantScalar {
|
||||||
|
value: f64,
|
||||||
|
elem: Elem,
|
||||||
|
},
|
||||||
|
SharedMemory {
|
||||||
|
id: u16,
|
||||||
|
item: Item,
|
||||||
|
length: u32,
|
||||||
|
},
|
||||||
|
LocalArray {
|
||||||
|
id: u16,
|
||||||
|
item: Item,
|
||||||
|
depth: u8,
|
||||||
|
length: u32,
|
||||||
|
},
|
||||||
|
Matrix {
|
||||||
|
id: u16,
|
||||||
|
mat: Matrix,
|
||||||
|
},
|
||||||
|
Slice {
|
||||||
|
id: u16,
|
||||||
|
item: Item,
|
||||||
|
depth: u8,
|
||||||
|
},
|
||||||
UnitPos,
|
UnitPos,
|
||||||
UnitPosX,
|
UnitPosX,
|
||||||
UnitPosY,
|
UnitPosY,
|
||||||
|
@ -34,20 +72,21 @@ pub enum Variable {
|
||||||
AbsolutePosX,
|
AbsolutePosX,
|
||||||
AbsolutePosY,
|
AbsolutePosY,
|
||||||
AbsolutePosZ,
|
AbsolutePosZ,
|
||||||
Matrix(u16, Matrix),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Variable {
|
impl Variable {
|
||||||
pub fn index(&self) -> Option<u16> {
|
pub fn index(&self) -> Option<u16> {
|
||||||
match self {
|
match self {
|
||||||
Variable::GlobalInputArray(idx, _) => Some(*idx),
|
Variable::GlobalInputArray { id, .. } => Some(*id),
|
||||||
Variable::GlobalScalar(idx, _) => Some(*idx),
|
Variable::GlobalScalar { id, .. } => Some(*id),
|
||||||
Variable::Local(idx, _, _) => Some(*idx),
|
Variable::Local { id, .. } => Some(*id),
|
||||||
Variable::LocalScalar(idx, _, _) => Some(*idx),
|
Variable::Slice { id, .. } => Some(*id),
|
||||||
Variable::GlobalOutputArray(idx, _) => Some(*idx),
|
Variable::LocalScalar { id, .. } => Some(*id),
|
||||||
Variable::ConstantScalar(_, _) => None,
|
Variable::GlobalOutputArray { id, .. } => Some(*id),
|
||||||
Variable::SharedMemory(idx, _, _) => Some(*idx),
|
Variable::ConstantScalar { .. } => None,
|
||||||
Variable::LocalArray(idx, _, _, _) => Some(*idx),
|
Variable::SharedMemory { id, .. } => Some(*id),
|
||||||
|
Variable::LocalArray { id, .. } => Some(*id),
|
||||||
|
Variable::Matrix { id, .. } => Some(*id),
|
||||||
Variable::AbsolutePos => None,
|
Variable::AbsolutePos => None,
|
||||||
Variable::UnitPos => None,
|
Variable::UnitPos => None,
|
||||||
Variable::UnitPosX => None,
|
Variable::UnitPosX => None,
|
||||||
|
@ -70,21 +109,22 @@ impl Variable {
|
||||||
Variable::CubeCount => None,
|
Variable::CubeCount => None,
|
||||||
Variable::CubeDim => None,
|
Variable::CubeDim => None,
|
||||||
Variable::SubcubeDim => None,
|
Variable::SubcubeDim => None,
|
||||||
Variable::Matrix(idx, _) => Some(*idx),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Fetch the item of the variable.
|
/// Fetch the item of the variable.
|
||||||
pub fn item(&self) -> Item {
|
pub fn item(&self) -> Item {
|
||||||
match self {
|
match self {
|
||||||
Variable::GlobalInputArray(_, item) => *item,
|
Variable::GlobalInputArray { item, .. } => *item,
|
||||||
Variable::GlobalOutputArray(_, item) => *item,
|
Variable::GlobalOutputArray { item, .. } => *item,
|
||||||
Variable::GlobalScalar(_, elem) => Item::new(*elem),
|
Variable::GlobalScalar { elem, .. } => Item::new(*elem),
|
||||||
Variable::Local(_, item, _) => *item,
|
Variable::Local { item, .. } => *item,
|
||||||
Variable::LocalScalar(_, elem, _) => Item::new(*elem),
|
Variable::LocalScalar { elem, .. } => Item::new(*elem),
|
||||||
Variable::ConstantScalar(_, elem) => Item::new(*elem),
|
Variable::ConstantScalar { elem, .. } => Item::new(*elem),
|
||||||
Variable::SharedMemory(_, item, _) => *item,
|
Variable::SharedMemory { item, .. } => *item,
|
||||||
Variable::LocalArray(_, item, _, _) => *item,
|
Variable::LocalArray { item, .. } => *item,
|
||||||
|
Variable::Slice { item, .. } => *item,
|
||||||
|
Variable::Matrix { mat, .. } => Item::new(mat.elem),
|
||||||
Variable::AbsolutePos => Item::new(Elem::UInt),
|
Variable::AbsolutePos => Item::new(Elem::UInt),
|
||||||
Variable::Rank => Item::new(Elem::UInt),
|
Variable::Rank => Item::new(Elem::UInt),
|
||||||
Variable::UnitPos => Item::new(Elem::UInt),
|
Variable::UnitPos => Item::new(Elem::UInt),
|
||||||
|
@ -107,7 +147,6 @@ impl Variable {
|
||||||
Variable::CubeCount => Item::new(Elem::UInt),
|
Variable::CubeCount => Item::new(Elem::UInt),
|
||||||
Variable::CubeDim => Item::new(Elem::UInt),
|
Variable::CubeDim => Item::new(Elem::UInt),
|
||||||
Variable::SubcubeDim => Item::new(Elem::UInt),
|
Variable::SubcubeDim => Item::new(Elem::UInt),
|
||||||
Variable::Matrix(_, matrix) => Item::new(matrix.elem),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use super::{
|
use super::{
|
||||||
BinaryOperator, ClampOperator, FmaOperator, InitOperator, Item, Operation, Operator, Subcube,
|
BinaryOperator, ClampOperator, FmaOperator, InitOperator, Item, Operation, Operator,
|
||||||
UnaryOperator, Variable,
|
SliceOperator, Subcube, UnaryOperator, Variable,
|
||||||
};
|
};
|
||||||
|
|
||||||
pub type Vectorization = u8;
|
pub type Vectorization = u8;
|
||||||
|
@ -60,7 +60,7 @@ impl Operator {
|
||||||
Operator::LowerEqual(op) => Operator::LowerEqual(op.vectorize(vectorization)),
|
Operator::LowerEqual(op) => Operator::LowerEqual(op.vectorize(vectorization)),
|
||||||
Operator::GreaterEqual(op) => Operator::GreaterEqual(op.vectorize(vectorization)),
|
Operator::GreaterEqual(op) => Operator::GreaterEqual(op.vectorize(vectorization)),
|
||||||
Operator::Assign(op) => {
|
Operator::Assign(op) => {
|
||||||
if let Variable::GlobalScalar(_, _) = op.input {
|
if let Variable::GlobalScalar { .. } = op.input {
|
||||||
// Assign will not change the type of the output if the input can't be
|
// Assign will not change the type of the output if the input can't be
|
||||||
// vectorized.
|
// vectorized.
|
||||||
return Operator::Assign(op.clone());
|
return Operator::Assign(op.clone());
|
||||||
|
@ -81,6 +81,7 @@ impl Operator {
|
||||||
Operator::ShiftLeft(op) => Operator::ShiftLeft(op.vectorize(vectorization)),
|
Operator::ShiftLeft(op) => Operator::ShiftLeft(op.vectorize(vectorization)),
|
||||||
Operator::ShiftRight(op) => Operator::ShiftRight(op.vectorize(vectorization)),
|
Operator::ShiftRight(op) => Operator::ShiftRight(op.vectorize(vectorization)),
|
||||||
Operator::Remainder(op) => Operator::Remainder(op.vectorize(vectorization)),
|
Operator::Remainder(op) => Operator::Remainder(op.vectorize(vectorization)),
|
||||||
|
Operator::Slice(op) => Operator::Slice(op.vectorize(vectorization)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -104,6 +105,22 @@ impl UnaryOperator {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl SliceOperator {
|
||||||
|
pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self {
|
||||||
|
let input = self.input.vectorize(vectorization);
|
||||||
|
let start = self.start.vectorize(vectorization);
|
||||||
|
let end = self.end.vectorize(vectorization);
|
||||||
|
let out = self.out.vectorize(vectorization);
|
||||||
|
|
||||||
|
Self {
|
||||||
|
input,
|
||||||
|
start,
|
||||||
|
end,
|
||||||
|
out,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl InitOperator {
|
impl InitOperator {
|
||||||
pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self {
|
pub(crate) fn vectorize(&self, vectorization: Vectorization) -> Self {
|
||||||
let out = self.out.vectorize(vectorization);
|
let out = self.out.vectorize(vectorization);
|
||||||
|
@ -155,31 +172,46 @@ impl FmaOperator {
|
||||||
impl Variable {
|
impl Variable {
|
||||||
pub(crate) fn vectorize(&self, vectorize: Vectorization) -> Self {
|
pub(crate) fn vectorize(&self, vectorize: Vectorization) -> Self {
|
||||||
match self {
|
match self {
|
||||||
Variable::GlobalInputArray(index, item) => {
|
Variable::GlobalInputArray { id, item } => Variable::GlobalInputArray {
|
||||||
Variable::GlobalInputArray(*index, item.vectorize(vectorize))
|
id: *id,
|
||||||
}
|
item: item.vectorize(vectorize),
|
||||||
Variable::Local(index, item, name) => {
|
},
|
||||||
Variable::Local(*index, item.vectorize(vectorize), *name)
|
Variable::Local { id, item, depth } => Variable::Local {
|
||||||
}
|
id: *id,
|
||||||
Variable::GlobalOutputArray(index, item) => {
|
item: item.vectorize(vectorize),
|
||||||
Variable::GlobalOutputArray(*index, item.vectorize(vectorize))
|
depth: *depth,
|
||||||
}
|
},
|
||||||
Variable::SharedMemory(index, item, size) => Variable::SharedMemory(
|
Variable::Slice { id, item, depth } => Variable::Slice {
|
||||||
*index,
|
id: *id,
|
||||||
item.vectorize(vectorize),
|
item: item.vectorize(vectorize),
|
||||||
item.vectorized_size(vectorize, *size),
|
depth: *depth,
|
||||||
),
|
},
|
||||||
Variable::LocalArray(index, item, name, size) => Variable::LocalArray(
|
Variable::GlobalOutputArray { id, item } => Variable::GlobalOutputArray {
|
||||||
*index,
|
id: *id,
|
||||||
item.vectorize(vectorize),
|
item: item.vectorize(vectorize),
|
||||||
*name,
|
},
|
||||||
item.vectorized_size(vectorize, *size),
|
Variable::SharedMemory { id, item, length } => Variable::SharedMemory {
|
||||||
),
|
id: *id,
|
||||||
Variable::ConstantScalar(_, _) => *self,
|
item: item.vectorize(vectorize),
|
||||||
Variable::GlobalScalar(_, _) => *self,
|
length: item.vectorized_size(vectorize, *length),
|
||||||
|
},
|
||||||
|
Variable::LocalArray {
|
||||||
|
id,
|
||||||
|
item,
|
||||||
|
depth,
|
||||||
|
length,
|
||||||
|
} => Variable::LocalArray {
|
||||||
|
id: *id,
|
||||||
|
item: item.vectorize(vectorize),
|
||||||
|
depth: *depth,
|
||||||
|
length: item.vectorized_size(vectorize, *length),
|
||||||
|
},
|
||||||
|
Variable::ConstantScalar { .. } => *self,
|
||||||
|
Variable::GlobalScalar { .. } => *self,
|
||||||
Variable::AbsolutePos => *self,
|
Variable::AbsolutePos => *self,
|
||||||
Variable::Rank => *self,
|
Variable::Rank => *self,
|
||||||
Variable::LocalScalar(_, _, _) => *self,
|
Variable::LocalScalar { .. } => *self,
|
||||||
|
Variable::Matrix { .. } => *self,
|
||||||
Variable::UnitPos => *self,
|
Variable::UnitPos => *self,
|
||||||
Variable::UnitPosX => *self,
|
Variable::UnitPosX => *self,
|
||||||
Variable::UnitPosY => *self,
|
Variable::UnitPosY => *self,
|
||||||
|
@ -200,7 +232,6 @@ impl Variable {
|
||||||
Variable::CubeCount => *self,
|
Variable::CubeCount => *self,
|
||||||
Variable::CubeDim => *self,
|
Variable::CubeDim => *self,
|
||||||
Variable::SubcubeDim => *self,
|
Variable::SubcubeDim => *self,
|
||||||
Variable::Matrix(_, _) => *self,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,7 +11,7 @@ pub use crate::runtime::Runtime;
|
||||||
|
|
||||||
/// Elements
|
/// Elements
|
||||||
pub use crate::frontend::{
|
pub use crate::frontend::{
|
||||||
Array, ArrayHandle, Float, LaunchArg, Tensor, TensorArg, UInt, F16, F32, F64, I32, I64,
|
Array, ArrayHandle, Float, LaunchArg, Slice, Tensor, TensorArg, UInt, F16, F32, F64, I32, I64,
|
||||||
};
|
};
|
||||||
pub use crate::pod::CubeElement;
|
pub use crate::pod::CubeElement;
|
||||||
|
|
||||||
|
|
|
@ -31,12 +31,17 @@ pub fn kernel_simple_1(lhs: &Array<F16>, rhs: &Array<F16>, out: &mut Array<F32>)
|
||||||
cmma::MatrixLayout::Undefined,
|
cmma::MatrixLayout::Undefined,
|
||||||
);
|
);
|
||||||
cmma::fill::<F32>(&c, F32::new(0.0));
|
cmma::fill::<F32>(&c, F32::new(0.0));
|
||||||
cmma::load::<F16>(&a, lhs, UInt::new(16));
|
cmma::load::<F16>(&a, lhs.as_slice(), UInt::new(16));
|
||||||
cmma::load::<F16>(&b, rhs, UInt::new(16));
|
cmma::load::<F16>(&b, rhs.as_slice(), UInt::new(16));
|
||||||
|
|
||||||
cmma::execute::<F16, F16, F32, F32>(&a, &b, &c, &c);
|
cmma::execute::<F16, F16, F32, F32>(&a, &b, &c, &c);
|
||||||
|
|
||||||
cmma::store::<F32>(out, &c, UInt::new(16), cmma::MatrixLayout::RowMajor);
|
cmma::store::<F32>(
|
||||||
|
out.as_slice_mut(),
|
||||||
|
&c,
|
||||||
|
UInt::new(16),
|
||||||
|
cmma::MatrixLayout::RowMajor,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn test_simple_1<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
|
pub fn test_simple_1<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
pub mod cmma;
|
pub mod cmma;
|
||||||
pub mod launch;
|
pub mod launch;
|
||||||
|
pub mod slice;
|
||||||
pub mod subcube;
|
pub mod subcube;
|
||||||
|
|
||||||
#[allow(missing_docs)]
|
#[allow(missing_docs)]
|
||||||
|
@ -11,5 +12,6 @@ macro_rules! testgen_all {
|
||||||
burn_cube::testgen_subcube!();
|
burn_cube::testgen_subcube!();
|
||||||
burn_cube::testgen_launch!();
|
burn_cube::testgen_launch!();
|
||||||
burn_cube::testgen_cmma!();
|
burn_cube::testgen_cmma!();
|
||||||
|
burn_cube::testgen_slice!();
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
|
@ -0,0 +1,107 @@
|
||||||
|
use crate as burn_cube;
|
||||||
|
use burn_cube::prelude::*;
|
||||||
|
|
||||||
|
#[cube(launch)]
|
||||||
|
pub fn slice_select<F: Float>(input: &Array<F>, output: &mut Array<F>) {
|
||||||
|
if UNIT_POS == UInt::new(0) {
|
||||||
|
let slice = input.slice(2, 3);
|
||||||
|
output[0] = slice[0u32];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cube(launch)]
|
||||||
|
pub fn slice_assign<F: Float>(input: &Array<F>, output: &mut Array<F>) {
|
||||||
|
if UNIT_POS == UInt::new(0) {
|
||||||
|
let slice_1 = output.slice_mut(2, 3);
|
||||||
|
slice_1[0] = input[0u32];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cube(launch)]
|
||||||
|
pub fn slice_len<F: Float>(input: &Array<F>, output: &mut Array<UInt>) {
|
||||||
|
if UNIT_POS == UInt::new(0) {
|
||||||
|
let slice = input.slice(2, 4);
|
||||||
|
let _tmp = slice[0]; // It must be used at least once, otherwise wgpu isn't happy.
|
||||||
|
output[0] = slice.len();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn test_slice_select<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
|
||||||
|
let input = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0]));
|
||||||
|
let output = client.empty(core::mem::size_of::<f32>());
|
||||||
|
|
||||||
|
slice_select_launch::<F32, R>(
|
||||||
|
client.clone(),
|
||||||
|
CubeCount::Static(1, 1, 1),
|
||||||
|
CubeDim::new(1, 1, 1),
|
||||||
|
ArrayArg::new(&input, 5),
|
||||||
|
ArrayArg::new(&output, 1),
|
||||||
|
);
|
||||||
|
|
||||||
|
let actual = client.read(output.binding());
|
||||||
|
let actual = f32::from_bytes(&actual);
|
||||||
|
|
||||||
|
assert_eq!(actual[0], 2.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn test_slice_len<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
|
||||||
|
let input = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0]));
|
||||||
|
let output = client.empty(core::mem::size_of::<u32>());
|
||||||
|
|
||||||
|
slice_len_launch::<F32, R>(
|
||||||
|
client.clone(),
|
||||||
|
CubeCount::Static(1, 1, 1),
|
||||||
|
CubeDim::new(1, 1, 1),
|
||||||
|
ArrayArg::new(&input, 5),
|
||||||
|
ArrayArg::new(&output, 1),
|
||||||
|
);
|
||||||
|
|
||||||
|
let actual = client.read(output.binding());
|
||||||
|
let actual = u32::from_bytes(&actual);
|
||||||
|
|
||||||
|
assert_eq!(actual, &[2]);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn test_slice_assign<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
|
||||||
|
let input = client.create(f32::as_bytes(&[15.0]));
|
||||||
|
let output = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0]));
|
||||||
|
|
||||||
|
slice_assign_launch::<F32, R>(
|
||||||
|
client.clone(),
|
||||||
|
CubeCount::Static(1, 1, 1),
|
||||||
|
CubeDim::new(1, 1, 1),
|
||||||
|
ArrayArg::new(&input, 5),
|
||||||
|
ArrayArg::new(&output, 1),
|
||||||
|
);
|
||||||
|
|
||||||
|
let actual = client.read(output.binding());
|
||||||
|
let actual = f32::from_bytes(&actual);
|
||||||
|
|
||||||
|
assert_eq!(actual, &[0.0, 1.0, 15.0, 3.0, 4.0]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expands to one `#[test]` per slice runtime test. The expansion site must have
// a `TestRuntime` type in scope (brought in through `use super::*`).
#[allow(missing_docs)]
#[macro_export]
macro_rules! testgen_slice {
    () => {
        use super::*;

        #[test]
        fn test_slice_select() {
            burn_cube::runtime_tests::slice::test_slice_select::<TestRuntime>(
                TestRuntime::client(&Default::default()),
            );
        }

        #[test]
        fn test_slice_assign() {
            burn_cube::runtime_tests::slice::test_slice_assign::<TestRuntime>(
                TestRuntime::client(&Default::default()),
            );
        }

        #[test]
        fn test_slice_len() {
            burn_cube::runtime_tests::slice::test_slice_len::<TestRuntime>(
                TestRuntime::client(&Default::default()),
            );
        }
    };
}
|
|
@ -36,7 +36,7 @@ mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use burn_cube::{
|
use burn_cube::{
|
||||||
cpa,
|
cpa,
|
||||||
ir::{Elem, Item, Variable},
|
ir::{self, Elem, Item, Variable},
|
||||||
};
|
};
|
||||||
|
|
||||||
type ElemType = F32;
|
type ElemType = F32;
|
||||||
|
@ -47,7 +47,7 @@ mod tests {
|
||||||
|
|
||||||
array_read_write_expand::<ElemType>(&mut context, 512);
|
array_read_write_expand::<ElemType>(&mut context, 512);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
format!("{:?}", context.into_scope().operations),
|
context.into_scope().operations,
|
||||||
inline_macro_ref_read_write()
|
inline_macro_ref_read_write()
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -60,10 +60,7 @@ mod tests {
|
||||||
array_add_assign_simple_expand(&mut context, array.into());
|
array_add_assign_simple_expand(&mut context, array.into());
|
||||||
let scope = context.into_scope();
|
let scope = context.into_scope();
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(scope.operations, inline_macro_array_add_assign_simple());
|
||||||
format!("{:?}", scope.operations),
|
|
||||||
inline_macro_array_add_assign_simple()
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -72,7 +69,7 @@ mod tests {
|
||||||
|
|
||||||
array_to_vectorized_variable_expand::<ElemType>(&mut context);
|
array_to_vectorized_variable_expand::<ElemType>(&mut context);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
format!("{:?}", context.into_scope().operations),
|
context.into_scope().operations,
|
||||||
inline_macro_ref_to_vectorized()
|
inline_macro_ref_to_vectorized()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
@ -83,12 +80,12 @@ mod tests {
|
||||||
|
|
||||||
array_of_one_to_vectorized_variable_expand::<ElemType>(&mut context);
|
array_of_one_to_vectorized_variable_expand::<ElemType>(&mut context);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
format!("{:?}", context.into_scope().operations),
|
context.into_scope().operations,
|
||||||
inline_macro_ref_one_to_vectorized()
|
inline_macro_ref_one_to_vectorized()
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn inline_macro_ref_read_write() -> String {
|
fn inline_macro_ref_read_write() -> Vec<ir::Operation> {
|
||||||
let context = CubeContext::root();
|
let context = CubeContext::root();
|
||||||
let item = Item::new(ElemType::as_elem());
|
let item = Item::new(ElemType::as_elem());
|
||||||
|
|
||||||
|
@ -105,7 +102,7 @@ mod tests {
|
||||||
// Read
|
// Read
|
||||||
cpa!(scope, var = array[pos]);
|
cpa!(scope, var = array[pos]);
|
||||||
|
|
||||||
format!("{:?}", scope.operations)
|
scope.operations
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -116,30 +113,36 @@ mod tests {
|
||||||
array_add_assign_expr_expand(&mut context, array.into());
|
array_add_assign_expr_expand(&mut context, array.into());
|
||||||
let scope = context.into_scope();
|
let scope = context.into_scope();
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(scope.operations, inline_macro_array_add_assign_expr());
|
||||||
format!("{:?}", scope.operations),
|
|
||||||
inline_macro_array_add_assign_expr()
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn inline_macro_array_add_assign_simple() -> String {
|
fn inline_macro_array_add_assign_simple() -> Vec<ir::Operation> {
|
||||||
let context = CubeContext::root();
|
let context = CubeContext::root();
|
||||||
|
|
||||||
let mut scope = context.into_scope();
|
let mut scope = context.into_scope();
|
||||||
let local = scope.create_local(Item::new(Elem::UInt));
|
let local = scope.create_local(Item::new(Elem::UInt));
|
||||||
|
|
||||||
let array = Variable::GlobalInputArray(0, Item::new(Elem::UInt));
|
let array = Variable::GlobalInputArray {
|
||||||
let index = Variable::ConstantScalar(1., Elem::UInt);
|
id: 0,
|
||||||
let value = Variable::ConstantScalar(1., Elem::UInt);
|
item: Item::new(Elem::UInt),
|
||||||
|
};
|
||||||
|
let index = Variable::ConstantScalar {
|
||||||
|
value: 1.,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let value = Variable::ConstantScalar {
|
||||||
|
value: 1.,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
|
||||||
cpa!(scope, local = array[index]);
|
cpa!(scope, local = array[index]);
|
||||||
cpa!(scope, local += value);
|
cpa!(scope, local += value);
|
||||||
cpa!(scope, array[index] = local);
|
cpa!(scope, array[index] = local);
|
||||||
|
|
||||||
format!("{:?}", scope.operations)
|
scope.operations
|
||||||
}
|
}
|
||||||
|
|
||||||
fn inline_macro_ref_to_vectorized() -> String {
|
fn inline_macro_ref_to_vectorized() -> Vec<ir::Operation> {
|
||||||
let context = CubeContext::root();
|
let context = CubeContext::root();
|
||||||
let scalar_item = Item::new(ElemType::as_elem());
|
let scalar_item = Item::new(ElemType::as_elem());
|
||||||
let vectorized_item = Item::vectorized(ElemType::as_elem(), 2);
|
let vectorized_item = Item::vectorized(ElemType::as_elem(), 2);
|
||||||
|
@ -158,10 +161,10 @@ mod tests {
|
||||||
cpa!(scope, tmp = array[pos1]);
|
cpa!(scope, tmp = array[pos1]);
|
||||||
cpa!(scope, vectorized_var[pos1] = tmp);
|
cpa!(scope, vectorized_var[pos1] = tmp);
|
||||||
|
|
||||||
format!("{:?}", scope.operations)
|
scope.operations
|
||||||
}
|
}
|
||||||
|
|
||||||
fn inline_macro_ref_one_to_vectorized() -> String {
|
fn inline_macro_ref_one_to_vectorized() -> Vec<ir::Operation> {
|
||||||
let context = CubeContext::root();
|
let context = CubeContext::root();
|
||||||
let scalar_item = Item::new(ElemType::as_elem());
|
let scalar_item = Item::new(ElemType::as_elem());
|
||||||
let unvectorized_item = Item::new(ElemType::as_elem());
|
let unvectorized_item = Item::new(ElemType::as_elem());
|
||||||
|
@ -176,26 +179,38 @@ mod tests {
|
||||||
cpa!(scope, tmp = array[pos0]);
|
cpa!(scope, tmp = array[pos0]);
|
||||||
cpa!(scope, unvectorized_var = tmp);
|
cpa!(scope, unvectorized_var = tmp);
|
||||||
|
|
||||||
format!("{:?}", scope.operations)
|
scope.operations
|
||||||
}
|
}
|
||||||
|
|
||||||
fn inline_macro_array_add_assign_expr() -> String {
|
fn inline_macro_array_add_assign_expr() -> Vec<ir::Operation> {
|
||||||
let context = CubeContext::root();
|
let context = CubeContext::root();
|
||||||
|
|
||||||
let mut scope = context.into_scope();
|
let mut scope = context.into_scope();
|
||||||
let index = scope.create_local(Item::new(Elem::UInt));
|
let index = scope.create_local(Item::new(Elem::UInt));
|
||||||
let local = scope.create_local(Item::new(Elem::UInt));
|
let local = scope.create_local(Item::new(Elem::UInt));
|
||||||
|
|
||||||
let array = Variable::GlobalInputArray(0, Item::new(Elem::UInt));
|
let array = Variable::GlobalInputArray {
|
||||||
let const1 = Variable::ConstantScalar(1., Elem::UInt);
|
id: 0,
|
||||||
let const2 = Variable::ConstantScalar(5., Elem::UInt);
|
item: Item::new(Elem::UInt),
|
||||||
let value = Variable::ConstantScalar(1., Elem::UInt);
|
};
|
||||||
|
let const1 = Variable::ConstantScalar {
|
||||||
|
value: 1.,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let const2 = Variable::ConstantScalar {
|
||||||
|
value: 5.,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let value = Variable::ConstantScalar {
|
||||||
|
value: 1.,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
|
||||||
cpa!(scope, index = const1 + const2);
|
cpa!(scope, index = const1 + const2);
|
||||||
cpa!(scope, local = array[index]);
|
cpa!(scope, local = array[index]);
|
||||||
cpa!(scope, local += value);
|
cpa!(scope, local += value);
|
||||||
cpa!(scope, array[index] = local);
|
cpa!(scope, array[index] = local);
|
||||||
|
|
||||||
format!("{:?}", scope.operations)
|
scope.operations
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -98,8 +98,14 @@ mod tests {
|
||||||
let mut scope = context.into_scope();
|
let mut scope = context.into_scope();
|
||||||
let x = scope.create_local(Item::new(Elem::UInt));
|
let x = scope.create_local(Item::new(Elem::UInt));
|
||||||
|
|
||||||
let zero = Variable::ConstantScalar(0., Elem::UInt);
|
let zero = Variable::ConstantScalar {
|
||||||
let one = Variable::ConstantScalar(1., Elem::UInt);
|
value: 0.,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let one = Variable::ConstantScalar {
|
||||||
|
value: 1.,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
cpa!(scope, x = zero);
|
cpa!(scope, x = zero);
|
||||||
cpa!(scope, x = x + one);
|
cpa!(scope, x = x + one);
|
||||||
|
|
||||||
|
@ -115,8 +121,14 @@ mod tests {
|
||||||
let y: Variable = y.into();
|
let y: Variable = y.into();
|
||||||
let x = scope.create_local(item);
|
let x = scope.create_local(item);
|
||||||
|
|
||||||
let one = Variable::ConstantScalar(1., Elem::UInt);
|
let one = Variable::ConstantScalar {
|
||||||
let two = Variable::ConstantScalar(2., Elem::UInt);
|
value: 1.,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let two = Variable::ConstantScalar {
|
||||||
|
value: 2.,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
cpa!(scope, x = y);
|
cpa!(scope, x = y);
|
||||||
cpa!(scope, x = x + one);
|
cpa!(scope, x = x + one);
|
||||||
cpa!(scope, y = y + two);
|
cpa!(scope, y = y + two);
|
||||||
|
@ -133,8 +145,14 @@ mod tests {
|
||||||
let y: Variable = y.into();
|
let y: Variable = y.into();
|
||||||
let x = scope.create_local(item);
|
let x = scope.create_local(item);
|
||||||
|
|
||||||
let one = Variable::ConstantScalar(1., Elem::UInt);
|
let one = Variable::ConstantScalar {
|
||||||
let two = Variable::ConstantScalar(2., Elem::UInt);
|
value: 1.,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let two = Variable::ConstantScalar {
|
||||||
|
value: 2.,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
cpa!(scope, x = y);
|
cpa!(scope, x = y);
|
||||||
cpa!(scope, y = y + one);
|
cpa!(scope, y = y + one);
|
||||||
cpa!(scope, x = x + two);
|
cpa!(scope, x = x + two);
|
||||||
|
@ -151,10 +169,22 @@ mod tests {
|
||||||
let y: Variable = y.into();
|
let y: Variable = y.into();
|
||||||
let x = scope.create_local(item);
|
let x = scope.create_local(item);
|
||||||
|
|
||||||
let zero = Variable::ConstantScalar(0., Elem::UInt);
|
let zero = Variable::ConstantScalar {
|
||||||
let one = Variable::ConstantScalar(1., Elem::UInt);
|
value: 0.,
|
||||||
let two = Variable::ConstantScalar(2., Elem::UInt);
|
elem: Elem::UInt,
|
||||||
let three = Variable::ConstantScalar(3., Elem::UInt);
|
};
|
||||||
|
let one = Variable::ConstantScalar {
|
||||||
|
value: 1.,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let two = Variable::ConstantScalar {
|
||||||
|
value: 2.,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let three = Variable::ConstantScalar {
|
||||||
|
value: 3.,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
cpa!(scope, x[zero] = one);
|
cpa!(scope, x[zero] = one);
|
||||||
cpa!(scope, x[one] = one);
|
cpa!(scope, x[one] = one);
|
||||||
cpa!(scope, x[two] = one);
|
cpa!(scope, x[two] = one);
|
||||||
|
|
|
@ -380,31 +380,31 @@ mod tests {
|
||||||
let out_number = if in_type == out_type { 0 } else { 2 };
|
let out_number = if in_type == out_type { 0 } else { 2 };
|
||||||
format!(
|
format!(
|
||||||
"[Operator({ops_name}(BinaryOperator {{ \
|
"[Operator({ops_name}(BinaryOperator {{ \
|
||||||
lhs: Local(0, Item {{ \
|
lhs: Local {{ id: 0, item: Item {{ \
|
||||||
elem: {in_type}, \
|
elem: {in_type}, \
|
||||||
vectorization: 1 \
|
vectorization: 1 \
|
||||||
}}, 0), \
|
}}, depth: 0 }}, \
|
||||||
rhs: Local(1, Item {{ \
|
rhs: Local {{ id: 1, item: Item {{ \
|
||||||
elem: {in_type}, \
|
elem: {in_type}, \
|
||||||
vectorization: 1 \
|
vectorization: 1 \
|
||||||
}}, 0), \
|
}}, depth: 0 }}, \
|
||||||
out: Local({out_number}, Item {{ \
|
out: Local {{ id: {out_number}, item: Item {{ \
|
||||||
elem: {out_type}, \
|
elem: {out_type}, \
|
||||||
vectorization: 1 \
|
vectorization: 1 \
|
||||||
}}, 0) \
|
}}, depth: 0 }} \
|
||||||
}}))]"
|
}}))]"
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
format!(
|
format!(
|
||||||
"[Operator({ops_name}(UnaryOperator {{ \
|
"[Operator({ops_name}(UnaryOperator {{ \
|
||||||
input: Local(0, Item {{ \
|
input: Local {{ id: 0, item: Item {{ \
|
||||||
elem: {in_type}, \
|
elem: {in_type}, \
|
||||||
vectorization: 1 \
|
vectorization: 1 \
|
||||||
}}, 0), \
|
}}, depth: 0 }}, \
|
||||||
out: Local(0, Item {{ \
|
out: Local {{ id: 0, item: Item {{ \
|
||||||
elem: {out_type}, \
|
elem: {out_type}, \
|
||||||
vectorization: 1 \
|
vectorization: 1 \
|
||||||
}}, 0) \
|
}}, depth: 0 }} \
|
||||||
}}))]"
|
}}))]"
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
|
@ -117,7 +117,10 @@ mod tests {
|
||||||
|
|
||||||
let i = scope.create_with_value(1, item);
|
let i = scope.create_with_value(1, item);
|
||||||
cpa!(scope, x += i);
|
cpa!(scope, x += i);
|
||||||
let value = Variable::ConstantScalar(2., item.elem());
|
let value = Variable::ConstantScalar {
|
||||||
|
value: 2.,
|
||||||
|
elem: item.elem(),
|
||||||
|
};
|
||||||
cpa!(scope, i = value);
|
cpa!(scope, i = value);
|
||||||
cpa!(scope, x += i);
|
cpa!(scope, x += i);
|
||||||
|
|
||||||
|
@ -156,7 +159,10 @@ mod tests {
|
||||||
cpa!(
|
cpa!(
|
||||||
&mut scope,
|
&mut scope,
|
||||||
range(0u32, end, false).for_each(|_, scope| {
|
range(0u32, end, false).for_each(|_, scope| {
|
||||||
let value = Variable::ConstantScalar(2.into(), item.elem());
|
let value = Variable::ConstantScalar {
|
||||||
|
value: 2.into(),
|
||||||
|
elem: item.elem(),
|
||||||
|
};
|
||||||
cpa!(scope, y = value);
|
cpa!(scope, y = value);
|
||||||
cpa!(scope, x += y);
|
cpa!(scope, x += y);
|
||||||
})
|
})
|
||||||
|
|
|
@ -83,6 +83,9 @@ impl CudaCompiler {
|
||||||
let processing = value.process();
|
let processing = value.process();
|
||||||
|
|
||||||
for var in processing.variables {
|
for var in processing.variables {
|
||||||
|
if let gpu::Variable::Slice { .. } = var {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
instructions.push(Instruction::DeclareVariable {
|
instructions.push(Instruction::DeclareVariable {
|
||||||
var: self.compile_variable(var),
|
var: self.compile_variable(var),
|
||||||
});
|
});
|
||||||
|
@ -192,8 +195,8 @@ impl CudaCompiler {
|
||||||
gpu::Metadata::Stride { dim, var, out } => {
|
gpu::Metadata::Stride { dim, var, out } => {
|
||||||
self.stride = true;
|
self.stride = true;
|
||||||
let position = match var {
|
let position = match var {
|
||||||
gpu::Variable::GlobalInputArray(idx, _) => idx as usize,
|
gpu::Variable::GlobalInputArray { id, .. } => id as usize,
|
||||||
gpu::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
|
gpu::Variable::GlobalOutputArray { id, .. } => self.num_inputs + id as usize,
|
||||||
_ => panic!("Only Input and Output have a stride, got: {:?}", var),
|
_ => panic!("Only Input and Output have a stride, got: {:?}", var),
|
||||||
};
|
};
|
||||||
Instruction::Stride {
|
Instruction::Stride {
|
||||||
|
@ -205,8 +208,8 @@ impl CudaCompiler {
|
||||||
gpu::Metadata::Shape { dim, var, out } => {
|
gpu::Metadata::Shape { dim, var, out } => {
|
||||||
self.shape = true;
|
self.shape = true;
|
||||||
let position = match var {
|
let position = match var {
|
||||||
gpu::Variable::GlobalInputArray(idx, _) => idx as usize,
|
gpu::Variable::GlobalInputArray { id, .. } => id as usize,
|
||||||
gpu::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
|
gpu::Variable::GlobalOutputArray { id, .. } => self.num_inputs + id as usize,
|
||||||
_ => panic!("Only Input and Output have a shape, got {:?}", var),
|
_ => panic!("Only Input and Output have a shape, got {:?}", var),
|
||||||
};
|
};
|
||||||
Instruction::Shape {
|
Instruction::Shape {
|
||||||
|
@ -215,12 +218,20 @@ impl CudaCompiler {
|
||||||
out: self.compile_variable(out),
|
out: self.compile_variable(out),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
gpu::Metadata::ArrayLength { var, out } => super::Instruction::ArrayLength {
|
gpu::Metadata::Length { var, out } => {
|
||||||
input: self.compile_variable(var),
|
let input = self.compile_variable(var);
|
||||||
out: self.compile_variable(out),
|
let out = self.compile_variable(out);
|
||||||
num_inputs: self.num_inputs,
|
|
||||||
num_outputs: self.num_outputs,
|
match input {
|
||||||
},
|
super::Variable::Slice { .. } => super::Instruction::SliceLength { input, out },
|
||||||
|
_ => super::Instruction::Length {
|
||||||
|
input,
|
||||||
|
out,
|
||||||
|
num_inputs: self.num_inputs,
|
||||||
|
num_outputs: self.num_outputs,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -297,6 +308,12 @@ impl CudaCompiler {
|
||||||
gpu::Operator::Div(op) => Instruction::Div(self.compile_binary(op)),
|
gpu::Operator::Div(op) => Instruction::Div(self.compile_binary(op)),
|
||||||
gpu::Operator::Sub(op) => Instruction::Sub(self.compile_binary(op)),
|
gpu::Operator::Sub(op) => Instruction::Sub(self.compile_binary(op)),
|
||||||
gpu::Operator::Assign(op) => Instruction::Assign(self.compile_unary(op)),
|
gpu::Operator::Assign(op) => Instruction::Assign(self.compile_unary(op)),
|
||||||
|
gpu::Operator::Slice(op) => Instruction::Slice {
|
||||||
|
input: self.compile_variable(op.input),
|
||||||
|
start: self.compile_variable(op.start),
|
||||||
|
end: self.compile_variable(op.end),
|
||||||
|
out: self.compile_variable(op.out),
|
||||||
|
},
|
||||||
gpu::Operator::Index(op) => Instruction::Index(self.compile_binary(op)),
|
gpu::Operator::Index(op) => Instruction::Index(self.compile_binary(op)),
|
||||||
gpu::Operator::UncheckedIndex(op) => Instruction::Index(self.compile_binary(op)),
|
gpu::Operator::UncheckedIndex(op) => Instruction::Index(self.compile_binary(op)),
|
||||||
gpu::Operator::IndexAssign(op) => Instruction::IndexAssign(self.compile_binary(op)),
|
gpu::Operator::IndexAssign(op) => Instruction::IndexAssign(self.compile_binary(op)),
|
||||||
|
@ -372,35 +389,40 @@ impl CudaCompiler {
|
||||||
|
|
||||||
fn compile_variable(&mut self, value: gpu::Variable) -> super::Variable {
|
fn compile_variable(&mut self, value: gpu::Variable) -> super::Variable {
|
||||||
match value {
|
match value {
|
||||||
gpu::Variable::GlobalInputArray(index, item) => {
|
gpu::Variable::GlobalInputArray { id, item } => {
|
||||||
super::Variable::GlobalInputArray(index, Self::compile_item(item))
|
super::Variable::GlobalInputArray(id, Self::compile_item(item))
|
||||||
}
|
}
|
||||||
gpu::Variable::GlobalScalar(index, elem) => {
|
gpu::Variable::GlobalScalar { id, elem } => {
|
||||||
super::Variable::GlobalScalar(index, Self::compile_elem(elem), elem)
|
super::Variable::GlobalScalar(id, Self::compile_elem(elem), elem)
|
||||||
}
|
}
|
||||||
gpu::Variable::Local(index, item, scope_depth) => super::Variable::Local {
|
gpu::Variable::Local { id, item, depth } => super::Variable::Local {
|
||||||
index,
|
id,
|
||||||
item: Self::compile_item(item),
|
item: Self::compile_item(item),
|
||||||
scope_depth,
|
depth,
|
||||||
},
|
},
|
||||||
gpu::Variable::LocalScalar(index, elem, scope_depth) => super::Variable::LocalScalar {
|
gpu::Variable::Slice { id, item, depth } => super::Variable::Slice {
|
||||||
index,
|
id,
|
||||||
|
item: Self::compile_item(item),
|
||||||
|
depth,
|
||||||
|
},
|
||||||
|
gpu::Variable::LocalScalar { id, elem, depth } => super::Variable::LocalScalar {
|
||||||
|
id,
|
||||||
elem: Self::compile_elem(elem),
|
elem: Self::compile_elem(elem),
|
||||||
scope_depth,
|
depth,
|
||||||
},
|
},
|
||||||
gpu::Variable::GlobalOutputArray(index, item) => {
|
gpu::Variable::GlobalOutputArray { id, item } => {
|
||||||
super::Variable::GlobalOutputArray(index, Self::compile_item(item))
|
super::Variable::GlobalOutputArray(id, Self::compile_item(item))
|
||||||
}
|
}
|
||||||
gpu::Variable::ConstantScalar(index, elem) => {
|
gpu::Variable::ConstantScalar { value, elem } => {
|
||||||
super::Variable::ConstantScalar(index, Self::compile_elem(elem))
|
super::Variable::ConstantScalar(value, Self::compile_elem(elem))
|
||||||
}
|
}
|
||||||
gpu::Variable::SharedMemory(index, item, size) => {
|
gpu::Variable::SharedMemory { id, item, length } => {
|
||||||
let item = Self::compile_item(item);
|
let item = Self::compile_item(item);
|
||||||
if !self.shared_memories.iter().any(|s| s.index == index) {
|
if !self.shared_memories.iter().any(|s| s.index == id) {
|
||||||
self.shared_memories
|
self.shared_memories
|
||||||
.push(super::SharedMemory::new(index, item, size));
|
.push(super::SharedMemory::new(id, item, length));
|
||||||
}
|
}
|
||||||
super::Variable::SharedMemory(index, item, size)
|
super::Variable::SharedMemory(id, item, length)
|
||||||
}
|
}
|
||||||
gpu::Variable::AbsolutePos => {
|
gpu::Variable::AbsolutePos => {
|
||||||
self.id = true;
|
self.id = true;
|
||||||
|
@ -438,7 +460,12 @@ impl CudaCompiler {
|
||||||
gpu::Variable::CubeCountX => super::Variable::NumWorkgroupsX,
|
gpu::Variable::CubeCountX => super::Variable::NumWorkgroupsX,
|
||||||
gpu::Variable::CubeCountY => super::Variable::NumWorkgroupsY,
|
gpu::Variable::CubeCountY => super::Variable::NumWorkgroupsY,
|
||||||
gpu::Variable::CubeCountZ => super::Variable::NumWorkgroupsZ,
|
gpu::Variable::CubeCountZ => super::Variable::NumWorkgroupsZ,
|
||||||
gpu::Variable::LocalArray(id, item, depth, size) => {
|
gpu::Variable::LocalArray {
|
||||||
|
id,
|
||||||
|
item,
|
||||||
|
depth,
|
||||||
|
length,
|
||||||
|
} => {
|
||||||
let item = Self::compile_item(item);
|
let item = Self::compile_item(item);
|
||||||
if !self
|
if !self
|
||||||
.local_arrays
|
.local_arrays
|
||||||
|
@ -446,19 +473,19 @@ impl CudaCompiler {
|
||||||
.any(|s| s.index == id && s.depth == depth)
|
.any(|s| s.index == id && s.depth == depth)
|
||||||
{
|
{
|
||||||
self.local_arrays
|
self.local_arrays
|
||||||
.push(super::LocalArray::new(id, item, depth, size));
|
.push(super::LocalArray::new(id, item, depth, length));
|
||||||
}
|
}
|
||||||
super::Variable::LocalArray(id, item, depth, size)
|
super::Variable::LocalArray(id, item, depth, length)
|
||||||
}
|
}
|
||||||
gpu::Variable::CubePos => todo!(),
|
gpu::Variable::CubePos => todo!(),
|
||||||
gpu::Variable::CubeDim => todo!(),
|
gpu::Variable::CubeDim => todo!(),
|
||||||
gpu::Variable::CubeCount => todo!(),
|
gpu::Variable::CubeCount => todo!(),
|
||||||
gpu::Variable::SubcubeDim => todo!(),
|
gpu::Variable::SubcubeDim => todo!(),
|
||||||
gpu::Variable::Matrix(index, matrix) => {
|
gpu::Variable::Matrix { id, mat } => {
|
||||||
self.wmma = true;
|
self.wmma = true;
|
||||||
super::Variable::WmmaFragment {
|
super::Variable::WmmaFragment {
|
||||||
index,
|
id,
|
||||||
frag: Self::compile_matrix(matrix),
|
frag: Self::compile_matrix(mat),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -361,9 +361,9 @@ impl Binary for IndexAssign {
|
||||||
out: &Variable,
|
out: &Variable,
|
||||||
) -> std::fmt::Result {
|
) -> std::fmt::Result {
|
||||||
if let Variable::Local {
|
if let Variable::Local {
|
||||||
index: _,
|
id: _,
|
||||||
item: _,
|
item: _,
|
||||||
scope_depth: _,
|
depth: _,
|
||||||
} = out
|
} = out
|
||||||
{
|
{
|
||||||
return IndexAssignVector::format(f, lhs, rhs, out);
|
return IndexAssignVector::format(f, lhs, rhs, out);
|
||||||
|
@ -388,9 +388,9 @@ impl Binary for Index {
|
||||||
out: &Variable,
|
out: &Variable,
|
||||||
) -> std::fmt::Result {
|
) -> std::fmt::Result {
|
||||||
if let Variable::Local {
|
if let Variable::Local {
|
||||||
index: _,
|
id: _,
|
||||||
item: _,
|
item: _,
|
||||||
scope_depth: _,
|
depth: _,
|
||||||
} = lhs
|
} = lhs
|
||||||
{
|
{
|
||||||
return IndexVector::format(f, lhs, rhs, out);
|
return IndexVector::format(f, lhs, rhs, out);
|
||||||
|
|
|
@ -86,9 +86,14 @@ impl Component for Variable {
|
||||||
Variable::GlobalOutputArray(_, e) => *e,
|
Variable::GlobalOutputArray(_, e) => *e,
|
||||||
Variable::SharedMemory(_, e, _) => *e,
|
Variable::SharedMemory(_, e, _) => *e,
|
||||||
Variable::Local {
|
Variable::Local {
|
||||||
index: _,
|
id: _,
|
||||||
item,
|
item,
|
||||||
scope_depth: _,
|
depth: _,
|
||||||
|
} => *item,
|
||||||
|
Variable::Slice {
|
||||||
|
id: _,
|
||||||
|
item,
|
||||||
|
depth: _,
|
||||||
} => *item,
|
} => *item,
|
||||||
Variable::ConstantScalar(_, e) => Item::Scalar(*e),
|
Variable::ConstantScalar(_, e) => Item::Scalar(*e),
|
||||||
Variable::GlobalScalar(_, e, _) => Item::Scalar(*e),
|
Variable::GlobalScalar(_, e, _) => Item::Scalar(*e),
|
||||||
|
@ -99,9 +104,9 @@ impl Component for Variable {
|
||||||
Variable::LocalInvocationIdZ => Item::Scalar(Elem::U32),
|
Variable::LocalInvocationIdZ => Item::Scalar(Elem::U32),
|
||||||
Variable::Rank => Item::Scalar(Elem::U32),
|
Variable::Rank => Item::Scalar(Elem::U32),
|
||||||
Variable::LocalScalar {
|
Variable::LocalScalar {
|
||||||
index: _,
|
id: _,
|
||||||
elem,
|
elem,
|
||||||
scope_depth: _,
|
depth: _,
|
||||||
} => Item::Scalar(*elem),
|
} => Item::Scalar(*elem),
|
||||||
Variable::WorkgroupIdX => Item::Scalar(Elem::U32),
|
Variable::WorkgroupIdX => Item::Scalar(Elem::U32),
|
||||||
Variable::WorkgroupIdY => Item::Scalar(Elem::U32),
|
Variable::WorkgroupIdY => Item::Scalar(Elem::U32),
|
||||||
|
@ -117,7 +122,7 @@ impl Component for Variable {
|
||||||
Variable::NumWorkgroupsZ => Item::Scalar(Elem::U32),
|
Variable::NumWorkgroupsZ => Item::Scalar(Elem::U32),
|
||||||
Variable::LocalArray(_, e, _, _) => *e,
|
Variable::LocalArray(_, e, _, _) => *e,
|
||||||
Variable::WarpSize => Item::Scalar(Elem::U32),
|
Variable::WarpSize => Item::Scalar(Elem::U32),
|
||||||
Variable::WmmaFragment { index: _, frag } => Item::Scalar(frag.elem),
|
Variable::WmmaFragment { id: _, frag } => Item::Scalar(frag.elem),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -129,16 +134,9 @@ pub enum Variable {
|
||||||
GlobalOutputArray(u16, Item),
|
GlobalOutputArray(u16, Item),
|
||||||
GlobalScalar(u16, Elem, gpu::Elem),
|
GlobalScalar(u16, Elem, gpu::Elem),
|
||||||
ConstantScalar(f64, Elem),
|
ConstantScalar(f64, Elem),
|
||||||
Local {
|
Local { id: u16, item: Item, depth: u8 },
|
||||||
index: u16,
|
Slice { id: u16, item: Item, depth: u8 },
|
||||||
item: Item,
|
LocalScalar { id: u16, elem: Elem, depth: u8 },
|
||||||
scope_depth: u8,
|
|
||||||
},
|
|
||||||
LocalScalar {
|
|
||||||
index: u16,
|
|
||||||
elem: Elem,
|
|
||||||
scope_depth: u8,
|
|
||||||
},
|
|
||||||
SharedMemory(u16, Item, u32),
|
SharedMemory(u16, Item, u32),
|
||||||
LocalArray(u16, Item, u8, u32),
|
LocalArray(u16, Item, u8, u32),
|
||||||
Id,
|
Id,
|
||||||
|
@ -159,10 +157,7 @@ pub enum Variable {
|
||||||
NumWorkgroupsX,
|
NumWorkgroupsX,
|
||||||
NumWorkgroupsY,
|
NumWorkgroupsY,
|
||||||
NumWorkgroupsZ,
|
NumWorkgroupsZ,
|
||||||
WmmaFragment {
|
WmmaFragment { id: u16, frag: Fragment },
|
||||||
index: u16,
|
|
||||||
frag: Fragment,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Display for Variable {
|
impl Display for Variable {
|
||||||
|
@ -170,15 +165,18 @@ impl Display for Variable {
|
||||||
match self {
|
match self {
|
||||||
Variable::GlobalInputArray(number, _) => f.write_fmt(format_args!("input_{number}")),
|
Variable::GlobalInputArray(number, _) => f.write_fmt(format_args!("input_{number}")),
|
||||||
Variable::LocalScalar {
|
Variable::LocalScalar {
|
||||||
index,
|
id: index,
|
||||||
elem: _,
|
elem: _,
|
||||||
scope_depth,
|
depth: scope_depth,
|
||||||
} => f.write_fmt(format_args!("s_{scope_depth}_{index}")),
|
} => f.write_fmt(format_args!("s_{scope_depth}_{index}")),
|
||||||
Variable::Local {
|
Variable::Local {
|
||||||
index,
|
id: index,
|
||||||
item: _,
|
item: _,
|
||||||
scope_depth,
|
depth: scope_depth,
|
||||||
} => f.write_fmt(format_args!("l_{scope_depth}_{index}")),
|
} => f.write_fmt(format_args!("l_{scope_depth}_{index}")),
|
||||||
|
Variable::Slice { id, item: _, depth } => {
|
||||||
|
f.write_fmt(format_args!("slice_{depth}_{id}"))
|
||||||
|
}
|
||||||
Variable::GlobalOutputArray(number, _) => f.write_fmt(format_args!("output_{number}")),
|
Variable::GlobalOutputArray(number, _) => f.write_fmt(format_args!("output_{number}")),
|
||||||
Variable::GlobalScalar(number, _, elem) => {
|
Variable::GlobalScalar(number, _, elem) => {
|
||||||
f.write_fmt(format_args!("scalars_{elem}[{number}]"))
|
f.write_fmt(format_args!("scalars_{elem}[{number}]"))
|
||||||
|
@ -209,7 +207,9 @@ impl Display for Variable {
|
||||||
f.write_fmt(format_args!("l_arr_{}_{}", id, depth))
|
f.write_fmt(format_args!("l_arr_{}_{}", id, depth))
|
||||||
}
|
}
|
||||||
Variable::WarpSize => f.write_str("warpSize"),
|
Variable::WarpSize => f.write_str("warpSize"),
|
||||||
Variable::WmmaFragment { index, frag: _ } => f.write_fmt(format_args!("frag_{index}")),
|
Variable::WmmaFragment { id: index, frag: _ } => {
|
||||||
|
f.write_fmt(format_args!("frag_{index}"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -220,9 +220,9 @@ impl Variable {
|
||||||
Variable::GlobalScalar(_, _, _) => true,
|
Variable::GlobalScalar(_, _, _) => true,
|
||||||
Variable::ConstantScalar(_, _) => true,
|
Variable::ConstantScalar(_, _) => true,
|
||||||
Variable::LocalScalar {
|
Variable::LocalScalar {
|
||||||
index: _,
|
id: _,
|
||||||
elem: _,
|
elem: _,
|
||||||
scope_depth: _,
|
depth: _,
|
||||||
} => true,
|
} => true,
|
||||||
Variable::Id => true,
|
Variable::Id => true,
|
||||||
Variable::LocalInvocationIndex => true,
|
Variable::LocalInvocationIndex => true,
|
||||||
|
@ -234,9 +234,14 @@ impl Variable {
|
||||||
Variable::GlobalOutputArray(_, _) => false,
|
Variable::GlobalOutputArray(_, _) => false,
|
||||||
Variable::SharedMemory(_, _, _) => false,
|
Variable::SharedMemory(_, _, _) => false,
|
||||||
Variable::Local {
|
Variable::Local {
|
||||||
index: _,
|
id: _,
|
||||||
item: _,
|
item: _,
|
||||||
scope_depth: _,
|
depth: _,
|
||||||
|
} => false,
|
||||||
|
Variable::Slice {
|
||||||
|
id: _,
|
||||||
|
item: _,
|
||||||
|
depth: _,
|
||||||
} => false,
|
} => false,
|
||||||
Variable::WorkgroupIdX => true,
|
Variable::WorkgroupIdX => true,
|
||||||
Variable::WorkgroupIdY => true,
|
Variable::WorkgroupIdY => true,
|
||||||
|
@ -252,7 +257,7 @@ impl Variable {
|
||||||
Variable::NumWorkgroupsZ => true,
|
Variable::NumWorkgroupsZ => true,
|
||||||
Variable::LocalArray(_, _, _, _) => false,
|
Variable::LocalArray(_, _, _, _) => false,
|
||||||
Variable::WarpSize => true,
|
Variable::WarpSize => true,
|
||||||
Variable::WmmaFragment { index: _, frag: _ } => false,
|
Variable::WmmaFragment { id: _, frag: _ } => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -16,12 +16,16 @@ pub struct UnaryInstruction {
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub enum Instruction {
|
pub enum Instruction {
|
||||||
ArrayLength {
|
Length {
|
||||||
input: Variable,
|
input: Variable,
|
||||||
out: Variable,
|
out: Variable,
|
||||||
num_inputs: usize,
|
num_inputs: usize,
|
||||||
num_outputs: usize,
|
num_outputs: usize,
|
||||||
},
|
},
|
||||||
|
SliceLength {
|
||||||
|
input: Variable,
|
||||||
|
out: Variable,
|
||||||
|
},
|
||||||
DeclareVariable {
|
DeclareVariable {
|
||||||
var: Variable,
|
var: Variable,
|
||||||
},
|
},
|
||||||
|
@ -58,6 +62,12 @@ pub enum Instruction {
|
||||||
instructions_if: Vec<Self>,
|
instructions_if: Vec<Self>,
|
||||||
instructions_else: Vec<Self>,
|
instructions_else: Vec<Self>,
|
||||||
},
|
},
|
||||||
|
Slice {
|
||||||
|
input: Variable,
|
||||||
|
start: Variable,
|
||||||
|
end: Variable,
|
||||||
|
out: Variable,
|
||||||
|
},
|
||||||
Return,
|
Return,
|
||||||
Break,
|
Break,
|
||||||
Stride {
|
Stride {
|
||||||
|
@ -114,7 +124,7 @@ impl Display for Instruction {
|
||||||
Instruction::Return => f.write_str("return;"),
|
Instruction::Return => f.write_str("return;"),
|
||||||
Instruction::Break => f.write_str("break;"),
|
Instruction::Break => f.write_str("break;"),
|
||||||
Instruction::DeclareVariable { var } => match var {
|
Instruction::DeclareVariable { var } => match var {
|
||||||
Variable::WmmaFragment { index: _, frag } => {
|
Variable::WmmaFragment { id: _, frag } => {
|
||||||
f.write_fmt(format_args!("{frag} {var};\n"))
|
f.write_fmt(format_args!("{frag} {var};\n"))
|
||||||
}
|
}
|
||||||
_ => {
|
_ => {
|
||||||
|
@ -123,6 +133,16 @@ impl Display for Instruction {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
|
Instruction::Add(it) => Add::format(f, &it.lhs, &it.rhs, &it.out),
|
||||||
|
Instruction::Slice {
|
||||||
|
input,
|
||||||
|
start,
|
||||||
|
end,
|
||||||
|
out,
|
||||||
|
} => {
|
||||||
|
let item = out.item();
|
||||||
|
f.write_fmt(format_args!("uint {out}_length = {end} - {start};\n"))?;
|
||||||
|
f.write_fmt(format_args!("{item} *{out} = {input} + {start};\n"))
|
||||||
|
}
|
||||||
Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
|
Instruction::Mul(it) => Mul::format(f, &it.lhs, &it.rhs, &it.out),
|
||||||
Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
|
Instruction::Div(it) => Div::format(f, &it.lhs, &it.rhs, &it.out),
|
||||||
Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
|
Instruction::Sub(it) => Sub::format(f, &it.lhs, &it.rhs, &it.out),
|
||||||
|
@ -225,7 +245,10 @@ for (uint {i} = {start}; {i} < {end}; {i}++) {{
|
||||||
Instruction::SyncThreads => f.write_str("__syncthreads();\n"),
|
Instruction::SyncThreads => f.write_str("__syncthreads();\n"),
|
||||||
Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
|
Instruction::Ceil(it) => Ceil::format(f, &it.input, &it.out),
|
||||||
Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
|
Instruction::Floor(it) => Floor::format(f, &it.input, &it.out),
|
||||||
Instruction::ArrayLength {
|
Instruction::SliceLength { input, out } => {
|
||||||
|
f.write_fmt(format_args!("{out} = {input}_length;\n"))
|
||||||
|
}
|
||||||
|
Instruction::Length {
|
||||||
input,
|
input,
|
||||||
out,
|
out,
|
||||||
num_inputs,
|
num_inputs,
|
||||||
|
|
|
@ -288,7 +288,10 @@ impl<R: JitRuntime> ElementWiseBuilder<R> {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
let input = Variable::ConstantScalar(1.0, desc.dtype.into());
|
let input = Variable::ConstantScalar {
|
||||||
|
value: 1.0,
|
||||||
|
elem: desc.dtype.into(),
|
||||||
|
};
|
||||||
let out = self.builder.output(desc, Variable::AbsolutePos);
|
let out = self.builder.output(desc, Variable::AbsolutePos);
|
||||||
|
|
||||||
self.builder
|
self.builder
|
||||||
|
@ -301,7 +304,10 @@ impl<R: JitRuntime> ElementWiseBuilder<R> {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
let input = Variable::ConstantScalar(0.0, desc.dtype.into());
|
let input = Variable::ConstantScalar {
|
||||||
|
value: 0.0,
|
||||||
|
elem: desc.dtype.into(),
|
||||||
|
};
|
||||||
let out = self.builder.output(desc, Variable::AbsolutePos);
|
let out = self.builder.output(desc, Variable::AbsolutePos);
|
||||||
|
|
||||||
self.builder
|
self.builder
|
||||||
|
|
|
@ -56,9 +56,11 @@ impl TraceBuilder {
|
||||||
}
|
}
|
||||||
true => match self.output_to_local.get(&tensor.id) {
|
true => match self.output_to_local.get(&tensor.id) {
|
||||||
// Is a local variable.
|
// Is a local variable.
|
||||||
Some(local_index) => {
|
Some(local_index) => Variable::Local {
|
||||||
Variable::Local(*local_index, Item::new(elem), self.scope.depth)
|
id: *local_index,
|
||||||
}
|
item: Item::new(elem),
|
||||||
|
depth: self.scope.depth,
|
||||||
|
},
|
||||||
// Isn't an operation output variable, so must be an existing input.
|
// Isn't an operation output variable, so must be an existing input.
|
||||||
None => self
|
None => self
|
||||||
.inputs
|
.inputs
|
||||||
|
@ -85,7 +87,11 @@ impl TraceBuilder {
|
||||||
|
|
||||||
// Output already registered as a local variable.
|
// Output already registered as a local variable.
|
||||||
if let Some(index) = self.output_to_local.get(&tensor.id) {
|
if let Some(index) = self.output_to_local.get(&tensor.id) {
|
||||||
return Variable::Local(*index, Item::new(elem), self.scope.depth);
|
return Variable::Local {
|
||||||
|
id: *index,
|
||||||
|
item: Item::new(elem),
|
||||||
|
depth: self.scope.depth,
|
||||||
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
let variable = self.scope.create_local(Item::new(elem));
|
let variable = self.scope.create_local(Item::new(elem));
|
||||||
|
@ -159,11 +165,11 @@ impl TraceBuilder {
|
||||||
//
|
//
|
||||||
// Only local variables can become outputs.
|
// Only local variables can become outputs.
|
||||||
let mark = |var: &Variable, list: &mut Vec<TensorId>| {
|
let mark = |var: &Variable, list: &mut Vec<TensorId>| {
|
||||||
if let Variable::Local(index, _, _) = var {
|
if let Variable::Local { id: id_local, .. } = var {
|
||||||
if let Some((id, _)) = self
|
if let Some((id, _)) = self
|
||||||
.output_to_local
|
.output_to_local
|
||||||
.iter()
|
.iter()
|
||||||
.find(|(_id, position)| *position == index)
|
.find(|(_tensor_id, position)| *position == id_local)
|
||||||
{
|
{
|
||||||
if !list.contains(id) {
|
if !list.contains(id) {
|
||||||
list.push(*id);
|
list.push(*id);
|
||||||
|
@ -201,6 +207,11 @@ impl TraceBuilder {
|
||||||
mark(&op.c, &mut local_tensor_ids_input);
|
mark(&op.c, &mut local_tensor_ids_input);
|
||||||
mark(&op.out, &mut local_tensor_ids_output);
|
mark(&op.out, &mut local_tensor_ids_output);
|
||||||
}
|
}
|
||||||
|
Operator::Slice(op) => {
|
||||||
|
mark(&op.input, &mut local_tensor_ids_input);
|
||||||
|
mark(&op.start, &mut local_tensor_ids_input);
|
||||||
|
mark(&op.out, &mut local_tensor_ids_output);
|
||||||
|
}
|
||||||
Operator::Max(op) => mark_binary(
|
Operator::Max(op) => mark_binary(
|
||||||
op,
|
op,
|
||||||
&mut local_tensor_ids_input,
|
&mut local_tensor_ids_input,
|
||||||
|
|
|
@ -68,8 +68,14 @@ impl<R: JitRuntime, EI: JitElement, EO: JitElement> Kernel for CastEagerKernel<R
|
||||||
let item_input = EI::cube_elem().into();
|
let item_input = EI::cube_elem().into();
|
||||||
let item_output = EO::cube_elem().into();
|
let item_output = EO::cube_elem().into();
|
||||||
|
|
||||||
let tensor = Variable::GlobalInputArray(0, item_input);
|
let tensor = Variable::GlobalInputArray {
|
||||||
let output = Variable::GlobalOutputArray(0, item_output);
|
id: 0,
|
||||||
|
item: item_input,
|
||||||
|
};
|
||||||
|
let output = Variable::GlobalOutputArray {
|
||||||
|
id: 0,
|
||||||
|
item: item_output,
|
||||||
|
};
|
||||||
|
|
||||||
CastShader { tensor, output }.expand(&mut scope);
|
CastShader { tensor, output }.expand(&mut scope);
|
||||||
|
|
||||||
|
|
|
@ -60,8 +60,14 @@ impl<R: JitRuntime, EO: JitElement> Kernel for BoolCastEagerKernel<R, EO> {
|
||||||
let item_input = Item::new(Elem::Bool);
|
let item_input = Item::new(Elem::Bool);
|
||||||
let item_output = EO::cube_elem().into();
|
let item_output = EO::cube_elem().into();
|
||||||
|
|
||||||
let tensor = Variable::GlobalInputArray(0, item_input);
|
let tensor = Variable::GlobalInputArray {
|
||||||
let output = Variable::GlobalOutputArray(0, item_output);
|
id: 0,
|
||||||
|
item: item_input,
|
||||||
|
};
|
||||||
|
let output = Variable::GlobalOutputArray {
|
||||||
|
id: 0,
|
||||||
|
item: item_output,
|
||||||
|
};
|
||||||
|
|
||||||
BoolCastShader { tensor, output }.expand(&mut scope);
|
BoolCastShader { tensor, output }.expand(&mut scope);
|
||||||
|
|
||||||
|
|
|
@ -93,13 +93,34 @@ impl<E: JitElement> Conv2dTransposeComputeShader<E> {
|
||||||
cpa!(scope, kernel_size_0 = shape(weight, 2u32));
|
cpa!(scope, kernel_size_0 = shape(weight, 2u32));
|
||||||
cpa!(scope, kernel_size_1 = shape(weight, 3u32));
|
cpa!(scope, kernel_size_1 = shape(weight, 3u32));
|
||||||
|
|
||||||
let conv_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
|
let conv_stride_0 = Variable::GlobalScalar {
|
||||||
let conv_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
|
id: 0,
|
||||||
let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
|
elem: Elem::UInt,
|
||||||
let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
|
};
|
||||||
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
|
let conv_stride_1 = Variable::GlobalScalar {
|
||||||
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
|
id: 1,
|
||||||
let groups = Variable::GlobalScalar(6, Elem::UInt);
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let dilation_0 = Variable::GlobalScalar {
|
||||||
|
id: 2,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let dilation_1 = Variable::GlobalScalar {
|
||||||
|
id: 3,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let padding_0 = Variable::GlobalScalar {
|
||||||
|
id: 4,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let padding_1 = Variable::GlobalScalar {
|
||||||
|
id: 5,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let groups = Variable::GlobalScalar {
|
||||||
|
id: 6,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
|
||||||
let stride_0_i = scope.create_local(Elem::Int(IntKind::I32));
|
let stride_0_i = scope.create_local(Elem::Int(IntKind::I32));
|
||||||
let stride_1_i = scope.create_local(Elem::Int(IntKind::I32));
|
let stride_1_i = scope.create_local(Elem::Int(IntKind::I32));
|
||||||
|
@ -294,10 +315,10 @@ impl<R: JitRuntime, E: JitElement> Kernel for Conv2dTransposeEagerKernel<R, E> {
|
||||||
let mut scope = Scope::root();
|
let mut scope = Scope::root();
|
||||||
let item = E::cube_elem().into();
|
let item = E::cube_elem().into();
|
||||||
|
|
||||||
let input = Variable::GlobalInputArray(0, item);
|
let input = Variable::GlobalInputArray { id: 0, item };
|
||||||
let weight = Variable::GlobalInputArray(1, item);
|
let weight = Variable::GlobalInputArray { id: 1, item };
|
||||||
let bias = Variable::GlobalInputArray(2, item);
|
let bias = Variable::GlobalInputArray { id: 2, item };
|
||||||
let output = Variable::GlobalOutputArray(0, item);
|
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||||
|
|
||||||
scope.write_global_custom(output);
|
scope.write_global_custom(output);
|
||||||
|
|
||||||
|
|
|
@ -105,16 +105,46 @@ impl<E: JitElement> Conv3dTransposeComputeShader<E> {
|
||||||
cpa!(scope, kernel_size_1 = shape(weight, 3u32));
|
cpa!(scope, kernel_size_1 = shape(weight, 3u32));
|
||||||
cpa!(scope, kernel_size_2 = shape(weight, 4u32));
|
cpa!(scope, kernel_size_2 = shape(weight, 4u32));
|
||||||
|
|
||||||
let conv_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
|
let conv_stride_0 = Variable::GlobalScalar {
|
||||||
let conv_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
|
id: 0,
|
||||||
let conv_stride_2 = Variable::GlobalScalar(2, Elem::UInt);
|
elem: Elem::UInt,
|
||||||
let dilation_0 = Variable::GlobalScalar(3, Elem::UInt);
|
};
|
||||||
let dilation_1 = Variable::GlobalScalar(4, Elem::UInt);
|
let conv_stride_1 = Variable::GlobalScalar {
|
||||||
let dilation_2 = Variable::GlobalScalar(5, Elem::UInt);
|
id: 1,
|
||||||
let padding_0 = Variable::GlobalScalar(6, Elem::UInt);
|
elem: Elem::UInt,
|
||||||
let padding_1 = Variable::GlobalScalar(7, Elem::UInt);
|
};
|
||||||
let padding_2 = Variable::GlobalScalar(8, Elem::UInt);
|
let conv_stride_2 = Variable::GlobalScalar {
|
||||||
let groups = Variable::GlobalScalar(9, Elem::UInt);
|
id: 2,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let dilation_0 = Variable::GlobalScalar {
|
||||||
|
id: 3,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let dilation_1 = Variable::GlobalScalar {
|
||||||
|
id: 4,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let dilation_2 = Variable::GlobalScalar {
|
||||||
|
id: 5,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let padding_0 = Variable::GlobalScalar {
|
||||||
|
id: 6,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let padding_1 = Variable::GlobalScalar {
|
||||||
|
id: 7,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let padding_2 = Variable::GlobalScalar {
|
||||||
|
id: 8,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let groups = Variable::GlobalScalar {
|
||||||
|
id: 9,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
|
||||||
let stride_0_i = scope.create_local(Elem::Int(IntKind::I32));
|
let stride_0_i = scope.create_local(Elem::Int(IntKind::I32));
|
||||||
let stride_1_i = scope.create_local(Elem::Int(IntKind::I32));
|
let stride_1_i = scope.create_local(Elem::Int(IntKind::I32));
|
||||||
|
@ -362,10 +392,10 @@ impl<R: JitRuntime, E: JitElement> Kernel for Conv3dTransposeEagerKernel<R, E> {
|
||||||
let mut scope = Scope::root();
|
let mut scope = Scope::root();
|
||||||
let item = E::cube_elem().into();
|
let item = E::cube_elem().into();
|
||||||
|
|
||||||
let input = Variable::GlobalInputArray(0, item);
|
let input = Variable::GlobalInputArray { id: 0, item };
|
||||||
let weight = Variable::GlobalInputArray(1, item);
|
let weight = Variable::GlobalInputArray { id: 1, item };
|
||||||
let bias = Variable::GlobalInputArray(2, item);
|
let bias = Variable::GlobalInputArray { id: 2, item };
|
||||||
let output = Variable::GlobalOutputArray(0, item);
|
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||||
|
|
||||||
scope.write_global_custom(output);
|
scope.write_global_custom(output);
|
||||||
|
|
||||||
|
|
|
@ -43,7 +43,10 @@ impl FlipComputeShader {
|
||||||
cpa!(scope, shape = shape(output, i));
|
cpa!(scope, shape = shape(output, i));
|
||||||
cpa!(
|
cpa!(
|
||||||
scope,
|
scope,
|
||||||
flip = cast(Variable::GlobalScalar(i as u16, Elem::UInt))
|
flip = cast(Variable::GlobalScalar {
|
||||||
|
id: i as u16,
|
||||||
|
elem: Elem::UInt
|
||||||
|
})
|
||||||
);
|
);
|
||||||
cpa!(scope, flip_bool = flip == 1u32);
|
cpa!(scope, flip_bool = flip == 1u32);
|
||||||
|
|
||||||
|
@ -70,8 +73,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for FlipEagerKernel<R, E> {
|
||||||
let mut scope = Scope::root();
|
let mut scope = Scope::root();
|
||||||
let item = E::cube_elem().into();
|
let item = E::cube_elem().into();
|
||||||
|
|
||||||
let input = Variable::GlobalInputArray(0, item);
|
let input = Variable::GlobalInputArray { id: 0, item };
|
||||||
let output = Variable::GlobalOutputArray(0, item);
|
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||||
|
|
||||||
scope.write_global_custom(output);
|
scope.write_global_custom(output);
|
||||||
|
|
||||||
|
|
|
@ -27,8 +27,8 @@ struct GatherComputeShader {
|
||||||
impl GatherComputeShader {
|
impl GatherComputeShader {
|
||||||
pub fn expand(self, scope: &mut Scope) {
|
pub fn expand(self, scope: &mut Scope) {
|
||||||
match self.tensor {
|
match self.tensor {
|
||||||
Variable::GlobalInputArray(_, _) => (),
|
Variable::GlobalInputArray { .. } => (),
|
||||||
Variable::GlobalOutputArray(_, _) => (),
|
Variable::GlobalOutputArray { .. } => (),
|
||||||
_ => panic!("Tensor variable must be an global array."),
|
_ => panic!("Tensor variable must be an global array."),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -78,10 +78,16 @@ impl<R: JitRuntime, E: JitElement> Kernel for GatherEagerKernel<R, E> {
|
||||||
let item_tensor = E::cube_elem().into();
|
let item_tensor = E::cube_elem().into();
|
||||||
let item_indices: Item = Elem::Int(IntKind::I32).into();
|
let item_indices: Item = Elem::Int(IntKind::I32).into();
|
||||||
|
|
||||||
let tensor = Variable::GlobalInputArray(0, item_tensor);
|
let tensor = Variable::GlobalInputArray {
|
||||||
|
id: 0,
|
||||||
|
item: item_tensor,
|
||||||
|
};
|
||||||
let indices = scope.read_array(1, item_indices, Variable::AbsolutePos);
|
let indices = scope.read_array(1, item_indices, Variable::AbsolutePos);
|
||||||
|
|
||||||
let output_array = Variable::GlobalOutputArray(0, item_tensor);
|
let output_array = Variable::GlobalOutputArray {
|
||||||
|
id: 0,
|
||||||
|
item: item_tensor,
|
||||||
|
};
|
||||||
let output_local = scope.create_local(item_tensor);
|
let output_local = scope.create_local(item_tensor);
|
||||||
|
|
||||||
GatherComputeShader {
|
GatherComputeShader {
|
||||||
|
|
|
@ -61,8 +61,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for RepeatEagerKernel<R, E> {
|
||||||
let mut scope = Scope::root();
|
let mut scope = Scope::root();
|
||||||
let item = E::cube_elem().into();
|
let item = E::cube_elem().into();
|
||||||
|
|
||||||
let input = Variable::GlobalInputArray(0, item);
|
let input = Variable::GlobalInputArray { id: 0, item };
|
||||||
let output = Variable::GlobalOutputArray(0, item);
|
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||||
|
|
||||||
scope.write_global_custom(output);
|
scope.write_global_custom(output);
|
||||||
|
|
||||||
|
|
|
@ -32,13 +32,13 @@ struct ScatterComputeShader {
|
||||||
impl ScatterComputeShader {
|
impl ScatterComputeShader {
|
||||||
pub fn expand(self, scope: &mut Scope) {
|
pub fn expand(self, scope: &mut Scope) {
|
||||||
match self.input {
|
match self.input {
|
||||||
Variable::GlobalInputArray(_, _) => (),
|
Variable::GlobalInputArray { .. } => (),
|
||||||
Variable::GlobalOutputArray(_, _) => (),
|
Variable::GlobalOutputArray { .. } => (),
|
||||||
_ => panic!("Input variable must be an global array."),
|
_ => panic!("Input variable must be an global array."),
|
||||||
};
|
};
|
||||||
match self.value {
|
match self.value {
|
||||||
Variable::GlobalInputArray(_, _) => (),
|
Variable::GlobalInputArray { .. } => (),
|
||||||
Variable::GlobalOutputArray(_, _) => (),
|
Variable::GlobalOutputArray { .. } => (),
|
||||||
_ => panic!("Value variable must be an global array."),
|
_ => panic!("Value variable must be an global array."),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -136,9 +136,18 @@ impl<R: JitRuntime, E: JitElement> Kernel for ScatterEagerKernel<R, E> {
|
||||||
let item_value = E::cube_elem().into();
|
let item_value = E::cube_elem().into();
|
||||||
let item_indices: Item = Elem::Int(IntKind::I32).into();
|
let item_indices: Item = Elem::Int(IntKind::I32).into();
|
||||||
|
|
||||||
let input_output = Variable::GlobalInputArray(0, item_value);
|
let input_output = Variable::GlobalInputArray {
|
||||||
let indices = Variable::GlobalInputArray(1, Elem::Int(IntKind::I32).into());
|
id: 0,
|
||||||
let value = Variable::GlobalInputArray(2, item_value);
|
item: item_value,
|
||||||
|
};
|
||||||
|
let indices = Variable::GlobalInputArray {
|
||||||
|
id: 1,
|
||||||
|
item: Elem::Int(IntKind::I32).into(),
|
||||||
|
};
|
||||||
|
let value = Variable::GlobalInputArray {
|
||||||
|
id: 2,
|
||||||
|
item: item_value,
|
||||||
|
};
|
||||||
|
|
||||||
scope.write_global_custom(input_output);
|
scope.write_global_custom(input_output);
|
||||||
|
|
||||||
|
|
|
@ -73,9 +73,12 @@ impl<R: JitRuntime, E: JitElement> Kernel for SelectEagerKernel<R, E> {
|
||||||
let item = E::cube_elem().into();
|
let item = E::cube_elem().into();
|
||||||
let item_indices: Item = Elem::Int(IntKind::I32).into();
|
let item_indices: Item = Elem::Int(IntKind::I32).into();
|
||||||
|
|
||||||
let input = Variable::GlobalInputArray(0, item);
|
let input = Variable::GlobalInputArray { id: 0, item };
|
||||||
let indices = Variable::GlobalInputArray(1, item_indices);
|
let indices = Variable::GlobalInputArray {
|
||||||
let output = Variable::GlobalOutputArray(0, item);
|
id: 1,
|
||||||
|
item: item_indices,
|
||||||
|
};
|
||||||
|
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||||
|
|
||||||
scope.write_global_custom(output);
|
scope.write_global_custom(output);
|
||||||
|
|
||||||
|
|
|
@ -134,9 +134,12 @@ impl<R: JitRuntime, E: JitElement> Kernel for SelectAssignEagerKernel<R, E> {
|
||||||
let item = E::cube_elem().into();
|
let item = E::cube_elem().into();
|
||||||
let item_indices: Item = Elem::Int(IntKind::I32).into();
|
let item_indices: Item = Elem::Int(IntKind::I32).into();
|
||||||
|
|
||||||
let tensor = Variable::GlobalInputArray(0, item);
|
let tensor = Variable::GlobalInputArray { id: 0, item };
|
||||||
let value = Variable::GlobalInputArray(1, item);
|
let value = Variable::GlobalInputArray { id: 1, item };
|
||||||
let indices = Variable::GlobalInputArray(2, item_indices);
|
let indices = Variable::GlobalInputArray {
|
||||||
|
id: 2,
|
||||||
|
item: item_indices,
|
||||||
|
};
|
||||||
|
|
||||||
scope.write_global_custom(tensor);
|
scope.write_global_custom(tensor);
|
||||||
|
|
||||||
|
|
|
@ -44,7 +44,10 @@ impl SliceComputeShader {
|
||||||
cpa!(scope, shape_output = shape(output, i));
|
cpa!(scope, shape_output = shape(output, i));
|
||||||
cpa!(
|
cpa!(
|
||||||
scope,
|
scope,
|
||||||
range_start = cast(Variable::GlobalScalar(i as u16, Elem::UInt))
|
range_start = cast(Variable::GlobalScalar {
|
||||||
|
id: i as u16,
|
||||||
|
elem: Elem::UInt
|
||||||
|
})
|
||||||
);
|
);
|
||||||
|
|
||||||
cpa!(scope, offset_local = id / stride_output);
|
cpa!(scope, offset_local = id / stride_output);
|
||||||
|
@ -66,8 +69,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for SliceEagerKernel<R, E> {
|
||||||
let mut scope = Scope::root();
|
let mut scope = Scope::root();
|
||||||
let item = E::cube_elem().into();
|
let item = E::cube_elem().into();
|
||||||
|
|
||||||
let input = Variable::GlobalInputArray(0, item);
|
let input = Variable::GlobalInputArray { id: 0, item };
|
||||||
let output = Variable::GlobalOutputArray(0, item);
|
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||||
|
|
||||||
scope.write_global_custom(output);
|
scope.write_global_custom(output);
|
||||||
|
|
||||||
|
|
|
@ -47,7 +47,10 @@ impl SliceAssignComputeShader {
|
||||||
cpa!(scope, shape_input = shape(input, i));
|
cpa!(scope, shape_input = shape(input, i));
|
||||||
cpa!(
|
cpa!(
|
||||||
scope,
|
scope,
|
||||||
range_start = cast(Variable::GlobalScalar(i as u16, Elem::UInt))
|
range_start = cast(Variable::GlobalScalar {
|
||||||
|
id: i as u16,
|
||||||
|
elem: Elem::UInt
|
||||||
|
})
|
||||||
);
|
);
|
||||||
|
|
||||||
cpa!(scope, offset_local = id / stride_value);
|
cpa!(scope, offset_local = id / stride_value);
|
||||||
|
@ -75,8 +78,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for SliceAssignEagerKernel<R, E> {
|
||||||
let mut scope = Scope::root();
|
let mut scope = Scope::root();
|
||||||
let item = E::cube_elem().into();
|
let item = E::cube_elem().into();
|
||||||
|
|
||||||
let input = Variable::GlobalInputArray(0, item);
|
let input = Variable::GlobalInputArray { id: 0, item };
|
||||||
let value = Variable::GlobalInputArray(1, item);
|
let value = Variable::GlobalInputArray { id: 1, item };
|
||||||
|
|
||||||
scope.write_global_custom(input);
|
scope.write_global_custom(input);
|
||||||
|
|
||||||
|
|
|
@ -371,8 +371,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for InterpolateBicubicEagerKernel<R, E
|
||||||
let mut scope = Scope::root();
|
let mut scope = Scope::root();
|
||||||
let item = E::cube_elem().into();
|
let item = E::cube_elem().into();
|
||||||
|
|
||||||
let input = Variable::GlobalInputArray(0, item);
|
let input = Variable::GlobalInputArray { id: 0, item };
|
||||||
let output = Variable::GlobalOutputArray(0, item);
|
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||||
|
|
||||||
InterpolateBicubicShader {
|
InterpolateBicubicShader {
|
||||||
input,
|
input,
|
||||||
|
|
|
@ -189,8 +189,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for InterpolateBilinearEagerKernel<R,
|
||||||
let mut scope = Scope::root();
|
let mut scope = Scope::root();
|
||||||
let item = E::cube_elem().into();
|
let item = E::cube_elem().into();
|
||||||
|
|
||||||
let input = Variable::GlobalInputArray(0, item);
|
let input = Variable::GlobalInputArray { id: 0, item };
|
||||||
let output = Variable::GlobalOutputArray(0, item);
|
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||||
|
|
||||||
InterpolateBilinearShader { input, output }.expand(&mut scope);
|
InterpolateBilinearShader { input, output }.expand(&mut scope);
|
||||||
|
|
||||||
|
|
|
@ -127,8 +127,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for InterpolateNearestEagerKernel<R, E
|
||||||
let mut scope = Scope::root();
|
let mut scope = Scope::root();
|
||||||
let item = E::cube_elem().into();
|
let item = E::cube_elem().into();
|
||||||
|
|
||||||
let input = Variable::GlobalInputArray(0, item);
|
let input = Variable::GlobalInputArray { id: 0, item };
|
||||||
let output = Variable::GlobalOutputArray(0, item);
|
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||||
|
|
||||||
InterpolateNearestShader {
|
InterpolateNearestShader {
|
||||||
input,
|
input,
|
||||||
|
|
|
@ -185,8 +185,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for InterpolateNearestBackwardEagerKer
|
||||||
let mut scope = Scope::root();
|
let mut scope = Scope::root();
|
||||||
let item = E::cube_elem().into();
|
let item = E::cube_elem().into();
|
||||||
|
|
||||||
let out_grad = Variable::GlobalInputArray(0, item);
|
let out_grad = Variable::GlobalInputArray { id: 0, item };
|
||||||
let output = Variable::GlobalOutputArray(0, item);
|
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||||
|
|
||||||
InterpolateNearestBackwardShader {
|
InterpolateNearestBackwardShader {
|
||||||
out_grad,
|
out_grad,
|
||||||
|
|
|
@ -40,7 +40,10 @@ impl MaskStrategy for MaskFill {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn value_variable(value_item: Item) -> Variable {
|
fn value_variable(value_item: Item) -> Variable {
|
||||||
Variable::GlobalScalar(0, value_item.elem())
|
Variable::GlobalScalar {
|
||||||
|
id: 0,
|
||||||
|
elem: value_item.elem(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -65,7 +68,10 @@ impl MaskStrategy for MaskWhere {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn value_variable(value_item: Item) -> Variable {
|
fn value_variable(value_item: Item) -> Variable {
|
||||||
Variable::GlobalInputArray(2, value_item)
|
Variable::GlobalInputArray {
|
||||||
|
id: 2,
|
||||||
|
item: value_item,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -102,10 +108,19 @@ impl<M: MaskStrategy, R: JitRuntime, EI: JitElement, EM: JitElement> Kernel
|
||||||
let tensor_item = EI::cube_elem().into();
|
let tensor_item = EI::cube_elem().into();
|
||||||
let mask_item = EM::cube_elem().into();
|
let mask_item = EM::cube_elem().into();
|
||||||
|
|
||||||
let input = Variable::GlobalInputArray(0, tensor_item);
|
let input = Variable::GlobalInputArray {
|
||||||
let mask = Variable::GlobalInputArray(1, mask_item);
|
id: 0,
|
||||||
|
item: tensor_item,
|
||||||
|
};
|
||||||
|
let mask = Variable::GlobalInputArray {
|
||||||
|
id: 1,
|
||||||
|
item: mask_item,
|
||||||
|
};
|
||||||
let value = M::value_variable(tensor_item);
|
let value = M::value_variable(tensor_item);
|
||||||
let output = Variable::GlobalOutputArray(0, tensor_item);
|
let output = Variable::GlobalOutputArray {
|
||||||
|
id: 0,
|
||||||
|
item: tensor_item,
|
||||||
|
};
|
||||||
|
|
||||||
MaskShader::<EI, EM, M> {
|
MaskShader::<EI, EM, M> {
|
||||||
input,
|
input,
|
||||||
|
@ -176,8 +191,14 @@ impl<M: MaskStrategy, R: JitRuntime, EI: JitElement, EM: JitElement> Kernel
|
||||||
let tensor_item = EI::cube_elem().into();
|
let tensor_item = EI::cube_elem().into();
|
||||||
let mask_item = EM::cube_elem().into();
|
let mask_item = EM::cube_elem().into();
|
||||||
|
|
||||||
let input = Variable::GlobalInputArray(0, tensor_item);
|
let input = Variable::GlobalInputArray {
|
||||||
let mask = Variable::GlobalInputArray(1, mask_item);
|
id: 0,
|
||||||
|
item: tensor_item,
|
||||||
|
};
|
||||||
|
let mask = Variable::GlobalInputArray {
|
||||||
|
id: 1,
|
||||||
|
item: mask_item,
|
||||||
|
};
|
||||||
let value = M::value_variable(tensor_item);
|
let value = M::value_variable(tensor_item);
|
||||||
|
|
||||||
MaskShader::<EI, EM, M> {
|
MaskShader::<EI, EM, M> {
|
||||||
|
|
|
@ -40,9 +40,9 @@ impl<R: JitRuntime, E: JitElement> Kernel for MatmulTiling2dEagerKernel<R, E> {
|
||||||
);
|
);
|
||||||
let item = elem.into();
|
let item = elem.into();
|
||||||
|
|
||||||
let lhs = Variable::GlobalInputArray(0, item);
|
let lhs = Variable::GlobalInputArray { id: 0, item };
|
||||||
let rhs = Variable::GlobalInputArray(1, item);
|
let rhs = Variable::GlobalInputArray { id: 1, item };
|
||||||
let out = Variable::GlobalOutputArray(0, item);
|
let out = Variable::GlobalOutputArray { id: 0, item };
|
||||||
|
|
||||||
scope.write_global_custom(out);
|
scope.write_global_custom(out);
|
||||||
|
|
||||||
|
|
|
@ -207,8 +207,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for AdaptiveAvgPool2dBackwardEagerKern
|
||||||
let mut scope = Scope::root();
|
let mut scope = Scope::root();
|
||||||
let item = E::cube_elem().into();
|
let item = E::cube_elem().into();
|
||||||
|
|
||||||
let grad = Variable::GlobalInputArray(0, item);
|
let grad = Variable::GlobalInputArray { id: 0, item };
|
||||||
let output = Variable::GlobalOutputArray(0, item);
|
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||||
|
|
||||||
scope.write_global_custom(output);
|
scope.write_global_custom(output);
|
||||||
|
|
||||||
|
|
|
@ -193,8 +193,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for AdaptivePool2dEagerKernel<R, E> {
|
||||||
let mut scope = Scope::root();
|
let mut scope = Scope::root();
|
||||||
let item = E::cube_elem().into();
|
let item = E::cube_elem().into();
|
||||||
|
|
||||||
let input = Variable::GlobalInputArray(0, item);
|
let input = Variable::GlobalInputArray { id: 0, item };
|
||||||
let output = Variable::GlobalOutputArray(0, item);
|
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||||
|
|
||||||
scope.write_global_custom(output);
|
scope.write_global_custom(output);
|
||||||
|
|
||||||
|
|
|
@ -71,10 +71,22 @@ impl AvgPool2dBackwardComputeShader {
|
||||||
cpa!(scope, output_shape_2 = shape(output, 2u32));
|
cpa!(scope, output_shape_2 = shape(output, 2u32));
|
||||||
cpa!(scope, output_shape_3 = shape(output, 3u32));
|
cpa!(scope, output_shape_3 = shape(output, 3u32));
|
||||||
|
|
||||||
let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
|
let pool_stride_0 = Variable::GlobalScalar {
|
||||||
let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
|
id: 0,
|
||||||
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
|
elem: Elem::UInt,
|
||||||
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
|
};
|
||||||
|
let pool_stride_1 = Variable::GlobalScalar {
|
||||||
|
id: 1,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let padding_0 = Variable::GlobalScalar {
|
||||||
|
id: 4,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let padding_1 = Variable::GlobalScalar {
|
||||||
|
id: 5,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
let [kernel_size_0, kernel_size_1] = self.kernel_size;
|
let [kernel_size_0, kernel_size_1] = self.kernel_size;
|
||||||
|
|
||||||
let b = scope.create_local(Elem::UInt);
|
let b = scope.create_local(Elem::UInt);
|
||||||
|
@ -215,12 +227,30 @@ impl AvgPool2dBackwardComputeShader {
|
||||||
output_stride_2: Variable,
|
output_stride_2: Variable,
|
||||||
output_stride_3: Variable,
|
output_stride_3: Variable,
|
||||||
) -> (Variable, Variable, Variable, Variable) {
|
) -> (Variable, Variable, Variable, Variable) {
|
||||||
let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
|
let pool_stride_0 = Variable::GlobalScalar {
|
||||||
let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
|
id: 0,
|
||||||
let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
|
elem: Elem::UInt,
|
||||||
let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
|
};
|
||||||
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
|
let pool_stride_1 = Variable::GlobalScalar {
|
||||||
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
|
id: 1,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let dilation_0 = Variable::GlobalScalar {
|
||||||
|
id: 2,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let dilation_1 = Variable::GlobalScalar {
|
||||||
|
id: 3,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let padding_0 = Variable::GlobalScalar {
|
||||||
|
id: 4,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let padding_1 = Variable::GlobalScalar {
|
||||||
|
id: 5,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
|
||||||
let [kernel_size_0, kernel_size_1] = self.kernel_size;
|
let [kernel_size_0, kernel_size_1] = self.kernel_size;
|
||||||
|
|
||||||
|
@ -321,8 +351,8 @@ impl<R: JitRuntime, E: JitElement> Kernel for AvgPool2dBackwardEagerKernel<R, E>
|
||||||
let mut scope = Scope::root();
|
let mut scope = Scope::root();
|
||||||
let item = E::cube_elem().into();
|
let item = E::cube_elem().into();
|
||||||
|
|
||||||
let grad = Variable::GlobalInputArray(0, item);
|
let grad = Variable::GlobalInputArray { id: 0, item };
|
||||||
let output = Variable::GlobalOutputArray(0, item);
|
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||||
|
|
||||||
scope.write_global_custom(output);
|
scope.write_global_custom(output);
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,10 @@ impl<E: JitElement> PoolStrategy for MaxPool<E> {
|
||||||
|
|
||||||
fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator {
|
fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator {
|
||||||
let max_val = scope.create_local(item);
|
let max_val = scope.create_local(item);
|
||||||
let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), item.elem());
|
let max_initial = Variable::ConstantScalar {
|
||||||
|
value: E::minimum_value().to_f64(),
|
||||||
|
elem: item.elem(),
|
||||||
|
};
|
||||||
cpa!(scope, max_val = max_initial);
|
cpa!(scope, max_val = max_initial);
|
||||||
max_val
|
max_val
|
||||||
}
|
}
|
||||||
|
@ -67,7 +70,10 @@ impl<E: JitElement> PoolStrategy for MaxPoolWithIndices<E> {
|
||||||
|
|
||||||
fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator {
|
fn initialize(&self, scope: &mut Scope, item: Item) -> Self::Accumulator {
|
||||||
let max_val = scope.create_local(item);
|
let max_val = scope.create_local(item);
|
||||||
let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), item.elem());
|
let max_initial = Variable::ConstantScalar {
|
||||||
|
value: E::minimum_value().to_f64(),
|
||||||
|
elem: item.elem(),
|
||||||
|
};
|
||||||
cpa!(scope, max_val = max_initial);
|
cpa!(scope, max_val = max_initial);
|
||||||
let max_index = scope.create_local(Elem::UInt);
|
let max_index = scope.create_local(Elem::UInt);
|
||||||
(max_val, max_index)
|
(max_val, max_index)
|
||||||
|
|
|
@ -161,12 +161,30 @@ impl MaxPool2dBackwardComputeShader {
|
||||||
output_stride_2: Variable,
|
output_stride_2: Variable,
|
||||||
output_stride_3: Variable,
|
output_stride_3: Variable,
|
||||||
) -> (Variable, Variable, Variable, Variable) {
|
) -> (Variable, Variable, Variable, Variable) {
|
||||||
let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
|
let pool_stride_0 = Variable::GlobalScalar {
|
||||||
let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
|
id: 0,
|
||||||
let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
|
elem: Elem::UInt,
|
||||||
let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
|
};
|
||||||
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
|
let pool_stride_1 = Variable::GlobalScalar {
|
||||||
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
|
id: 1,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let dilation_0 = Variable::GlobalScalar {
|
||||||
|
id: 2,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let dilation_1 = Variable::GlobalScalar {
|
||||||
|
id: 3,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let padding_0 = Variable::GlobalScalar {
|
||||||
|
id: 4,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let padding_1 = Variable::GlobalScalar {
|
||||||
|
id: 5,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
|
||||||
let [kernel_size_0, kernel_size_1] = self.kernel_size;
|
let [kernel_size_0, kernel_size_1] = self.kernel_size;
|
||||||
|
|
||||||
|
@ -267,9 +285,12 @@ impl<R: JitRuntime, E: JitElement> Kernel for MaxPool2dWithIndicesBackwardEagerK
|
||||||
let mut scope = Scope::root();
|
let mut scope = Scope::root();
|
||||||
let item = E::cube_elem().into();
|
let item = E::cube_elem().into();
|
||||||
|
|
||||||
let indices = Variable::GlobalInputArray(0, Item::new(Elem::Int(IntKind::I32)));
|
let indices = Variable::GlobalInputArray {
|
||||||
let grad = Variable::GlobalInputArray(1, item);
|
id: 0,
|
||||||
let output = Variable::GlobalOutputArray(0, item);
|
item: Item::new(Elem::Int(IntKind::I32)),
|
||||||
|
};
|
||||||
|
let grad = Variable::GlobalInputArray { id: 1, item };
|
||||||
|
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||||
|
|
||||||
scope.write_global_custom(output);
|
scope.write_global_custom(output);
|
||||||
|
|
||||||
|
|
|
@ -65,12 +65,30 @@ impl<P: PoolStrategy, R: JitRuntime, E: JitElement> Pool2dComputeShader<P, R, E>
|
||||||
cpa!(scope, output_shape_2 = shape(output, 2u32));
|
cpa!(scope, output_shape_2 = shape(output, 2u32));
|
||||||
cpa!(scope, output_shape_3 = shape(output, 3u32));
|
cpa!(scope, output_shape_3 = shape(output, 3u32));
|
||||||
|
|
||||||
let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
|
let pool_stride_0 = Variable::GlobalScalar {
|
||||||
let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
|
id: 0,
|
||||||
let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
|
elem: Elem::UInt,
|
||||||
let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
|
};
|
||||||
let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
|
let pool_stride_1 = Variable::GlobalScalar {
|
||||||
let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
|
id: 1,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let dilation_0 = Variable::GlobalScalar {
|
||||||
|
id: 2,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let dilation_1 = Variable::GlobalScalar {
|
||||||
|
id: 3,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let padding_0 = Variable::GlobalScalar {
|
||||||
|
id: 4,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let padding_1 = Variable::GlobalScalar {
|
||||||
|
id: 5,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
|
||||||
let b = scope.create_local(Elem::UInt);
|
let b = scope.create_local(Elem::UInt);
|
||||||
let c = scope.create_local(Elem::UInt);
|
let c = scope.create_local(Elem::UInt);
|
||||||
|
@ -177,13 +195,13 @@ impl<P: PoolStrategy, R: JitRuntime, E: JitElement> Kernel for Pool2dEagerKernel
|
||||||
let mut scope = Scope::root();
|
let mut scope = Scope::root();
|
||||||
let item = E::cube_elem().into();
|
let item = E::cube_elem().into();
|
||||||
|
|
||||||
let input = Variable::GlobalInputArray(0, item);
|
let input = Variable::GlobalInputArray { id: 0, item };
|
||||||
let output = Variable::GlobalOutputArray(0, item);
|
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||||
let indices = if P::with_indices() {
|
let indices = if P::with_indices() {
|
||||||
Some(Variable::GlobalOutputArray(
|
Some(Variable::GlobalOutputArray {
|
||||||
1,
|
id: 1,
|
||||||
Item::new(Elem::Int(IntKind::I32)),
|
item: Item::new(Elem::Int(IntKind::I32)),
|
||||||
))
|
})
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
|
@ -66,17 +66,32 @@ impl<P: Prng<E>, R: JitRuntime, E: JitElement> Kernel for PrngEagerKernel<P, R,
|
||||||
let mut scope = Scope::root();
|
let mut scope = Scope::root();
|
||||||
let item = E::cube_elem().into();
|
let item = E::cube_elem().into();
|
||||||
|
|
||||||
let output = Variable::GlobalOutputArray(0, item);
|
let output = Variable::GlobalOutputArray { id: 0, item };
|
||||||
|
|
||||||
let seed0 = Variable::GlobalScalar(0, Elem::UInt);
|
let seed0 = Variable::GlobalScalar {
|
||||||
let seed1 = Variable::GlobalScalar(1, Elem::UInt);
|
id: 0,
|
||||||
let seed2 = Variable::GlobalScalar(2, Elem::UInt);
|
elem: Elem::UInt,
|
||||||
let seed3 = Variable::GlobalScalar(3, Elem::UInt);
|
};
|
||||||
|
let seed1 = Variable::GlobalScalar {
|
||||||
|
id: 1,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let seed2 = Variable::GlobalScalar {
|
||||||
|
id: 2,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
|
let seed3 = Variable::GlobalScalar {
|
||||||
|
id: 3,
|
||||||
|
elem: Elem::UInt,
|
||||||
|
};
|
||||||
let seeds = [seed0, seed1, seed2, seed3];
|
let seeds = [seed0, seed1, seed2, seed3];
|
||||||
|
|
||||||
let mut args = Vec::<Variable>::new();
|
let mut args = Vec::<Variable>::new();
|
||||||
for i in 0..P::args_length() {
|
for i in 0..P::args_length() {
|
||||||
args.push(Variable::GlobalScalar(i as u16, item.elem()));
|
args.push(Variable::GlobalScalar {
|
||||||
|
id: i as u16,
|
||||||
|
elem: item.elem(),
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
PrngShader::<P, E>::new(output, N_VALUES_PER_THREAD, seeds, args).expand(&mut scope);
|
PrngShader::<P, E>::new(output, N_VALUES_PER_THREAD, seeds, args).expand(&mut scope);
|
||||||
|
|
|
@ -16,7 +16,10 @@ impl<E: JitElement> ReduceDimNaive<E> for Argmax {
|
||||||
) -> Self::Accumulator {
|
) -> Self::Accumulator {
|
||||||
let index = scope.create_local(Elem::UInt);
|
let index = scope.create_local(Elem::UInt);
|
||||||
let max = scope.create_local(input_item);
|
let max = scope.create_local(input_item);
|
||||||
let max_initial = Variable::ConstantScalar(E::minimum_value().to_f64(), input_item.elem());
|
let max_initial = Variable::ConstantScalar {
|
||||||
|
value: E::minimum_value().to_f64(),
|
||||||
|
elem: input_item.elem(),
|
||||||
|
};
|
||||||
cpa!(scope, max = max_initial);
|
cpa!(scope, max = max_initial);
|
||||||
|
|
||||||
(max, index)
|
(max, index)
|
||||||
|
|
|
@ -17,7 +17,10 @@ impl<E: JitElement> ReduceDimNaive<E> for Argmin {
|
||||||
) -> Self::Accumulator {
|
) -> Self::Accumulator {
|
||||||
let index = scope.create_local(Elem::UInt);
|
let index = scope.create_local(Elem::UInt);
|
||||||
let min = scope.create_local(input_item);
|
let min = scope.create_local(input_item);
|
||||||
let min_initial = Variable::ConstantScalar(E::maximum_value().to_f64(), input_item.elem());
|
let min_initial = Variable::ConstantScalar {
|
||||||
|
value: E::maximum_value().to_f64(),
|
||||||
|
elem: input_item.elem(),
|
||||||
|
};
|
||||||
cpa!(scope, min = min_initial);
|
cpa!(scope, min = min_initial);
|
||||||
|
|
||||||
(min, index)
|
(min, index)
|
||||||
|
|
|
@ -41,8 +41,14 @@ impl<RD: ReduceDimNaive<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Kern
|
||||||
let item_input = EI::cube_elem().into();
|
let item_input = EI::cube_elem().into();
|
||||||
let item_output = EO::cube_elem().into();
|
let item_output = EO::cube_elem().into();
|
||||||
|
|
||||||
let tensor = Variable::GlobalInputArray(0, item_input);
|
let tensor = Variable::GlobalInputArray {
|
||||||
let output = Variable::GlobalOutputArray(0, item_output);
|
id: 0,
|
||||||
|
item: item_input,
|
||||||
|
};
|
||||||
|
let output = Variable::GlobalOutputArray {
|
||||||
|
id: 0,
|
||||||
|
item: item_output,
|
||||||
|
};
|
||||||
|
|
||||||
NaiveReduceDimComputeShader {
|
NaiveReduceDimComputeShader {
|
||||||
tensor,
|
tensor,
|
||||||
|
|
|
@ -18,7 +18,10 @@ impl<E: JitElement> ReduceDimShared<E> for Argmax {
|
||||||
let value_shared_memory = scope.create_shared(input_item, shared_memory_size);
|
let value_shared_memory = scope.create_shared(input_item, shared_memory_size);
|
||||||
let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size);
|
let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size);
|
||||||
|
|
||||||
let max = Variable::ConstantScalar(E::minimum_value().to_f64(), input_item.elem());
|
let max = Variable::ConstantScalar {
|
||||||
|
value: E::minimum_value().to_f64(),
|
||||||
|
elem: input_item.elem(),
|
||||||
|
};
|
||||||
cpa!(scope, value_shared_memory[write_position] = max);
|
cpa!(scope, value_shared_memory[write_position] = max);
|
||||||
(value_shared_memory, index_shared_memory)
|
(value_shared_memory, index_shared_memory)
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,7 +19,10 @@ impl<E: JitElement> ReduceDimShared<E> for Argmin {
|
||||||
let value_shared_memory = scope.create_shared(input_item, shared_memory_size);
|
let value_shared_memory = scope.create_shared(input_item, shared_memory_size);
|
||||||
let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size);
|
let index_shared_memory = scope.create_shared(Elem::UInt, shared_memory_size);
|
||||||
|
|
||||||
let min = Variable::ConstantScalar(E::maximum_value().to_f64(), input_item.elem());
|
let min = Variable::ConstantScalar {
|
||||||
|
value: E::maximum_value().to_f64(),
|
||||||
|
elem: input_item.elem(),
|
||||||
|
};
|
||||||
cpa!(scope, value_shared_memory[write_position] = min);
|
cpa!(scope, value_shared_memory[write_position] = min);
|
||||||
(value_shared_memory, index_shared_memory)
|
(value_shared_memory, index_shared_memory)
|
||||||
}
|
}
|
||||||
|
|
|
@ -51,8 +51,14 @@ impl<RD: ReduceDimShared<EI>, R: JitRuntime, EI: JitElement, EO: JitElement> Ker
|
||||||
let item_input = EI::cube_elem().into();
|
let item_input = EI::cube_elem().into();
|
||||||
let item_output = EO::cube_elem().into();
|
let item_output = EO::cube_elem().into();
|
||||||
|
|
||||||
let tensor = Variable::GlobalInputArray(0, item_input);
|
let tensor = Variable::GlobalInputArray {
|
||||||
let output = Variable::GlobalOutputArray(0, item_output);
|
id: 0,
|
||||||
|
item: item_input,
|
||||||
|
};
|
||||||
|
let output = Variable::GlobalOutputArray {
|
||||||
|
id: 0,
|
||||||
|
item: item_output,
|
||||||
|
};
|
||||||
|
|
||||||
// Reduce groups are elements that are aligned along the reduce dim
|
// Reduce groups are elements that are aligned along the reduce dim
|
||||||
SharedReduceDimComputeShader {
|
SharedReduceDimComputeShader {
|
||||||
|
|
|
@ -9,14 +9,24 @@ pub enum Variable {
|
||||||
GlobalScalar(u16, Elem, cube::Elem),
|
GlobalScalar(u16, Elem, cube::Elem),
|
||||||
ConstantScalar(f64, Elem),
|
ConstantScalar(f64, Elem),
|
||||||
Local {
|
Local {
|
||||||
index: u16,
|
id: u16,
|
||||||
item: Item,
|
item: Item,
|
||||||
scope_depth: u8,
|
depth: u8,
|
||||||
|
},
|
||||||
|
Named {
|
||||||
|
name: String,
|
||||||
|
item: Item,
|
||||||
|
is_array: bool,
|
||||||
|
},
|
||||||
|
Slice {
|
||||||
|
id: u16,
|
||||||
|
item: Item,
|
||||||
|
depth: u8,
|
||||||
},
|
},
|
||||||
LocalScalar {
|
LocalScalar {
|
||||||
index: u16,
|
id: u16,
|
||||||
elem: Elem,
|
elem: Elem,
|
||||||
scope_depth: u8,
|
depth: u8,
|
||||||
},
|
},
|
||||||
SharedMemory(u16, Item, u32),
|
SharedMemory(u16, Item, u32),
|
||||||
LocalArray(u16, Item, u8, u32),
|
LocalArray(u16, Item, u8, u32),
|
||||||
|
@ -71,9 +81,9 @@ impl Variable {
|
||||||
Variable::GlobalScalar(_, _, _) => true,
|
Variable::GlobalScalar(_, _, _) => true,
|
||||||
Variable::ConstantScalar(_, _) => true,
|
Variable::ConstantScalar(_, _) => true,
|
||||||
Variable::LocalScalar {
|
Variable::LocalScalar {
|
||||||
index: _,
|
id: _,
|
||||||
elem: _,
|
elem: _,
|
||||||
scope_depth: _,
|
depth: _,
|
||||||
} => true,
|
} => true,
|
||||||
Variable::Id => true,
|
Variable::Id => true,
|
||||||
Variable::LocalInvocationIndex => true,
|
Variable::LocalInvocationIndex => true,
|
||||||
|
@ -85,11 +95,9 @@ impl Variable {
|
||||||
Variable::GlobalOutputArray(_, _) => false,
|
Variable::GlobalOutputArray(_, _) => false,
|
||||||
Variable::SharedMemory(_, _, _) => false,
|
Variable::SharedMemory(_, _, _) => false,
|
||||||
Variable::LocalArray(_, _, _, _) => false,
|
Variable::LocalArray(_, _, _, _) => false,
|
||||||
Variable::Local {
|
Variable::Local { .. } => false,
|
||||||
index: _,
|
Variable::Named { .. } => false,
|
||||||
item: _,
|
Variable::Slice { .. } => false,
|
||||||
scope_depth: _,
|
|
||||||
} => false,
|
|
||||||
Variable::WorkgroupIdX => true,
|
Variable::WorkgroupIdX => true,
|
||||||
Variable::WorkgroupIdY => true,
|
Variable::WorkgroupIdY => true,
|
||||||
Variable::WorkgroupIdZ => true,
|
Variable::WorkgroupIdZ => true,
|
||||||
|
@ -121,11 +129,9 @@ impl Variable {
|
||||||
Self::GlobalOutputArray(_, e) => *e,
|
Self::GlobalOutputArray(_, e) => *e,
|
||||||
Self::SharedMemory(_, e, _) => *e,
|
Self::SharedMemory(_, e, _) => *e,
|
||||||
Self::LocalArray(_, e, _, _) => *e,
|
Self::LocalArray(_, e, _, _) => *e,
|
||||||
Self::Local {
|
Self::Local { item, .. } => *item,
|
||||||
index: _,
|
Self::Slice { item, .. } => *item,
|
||||||
item,
|
Self::Named { item, .. } => *item,
|
||||||
scope_depth: _,
|
|
||||||
} => *item,
|
|
||||||
Self::ConstantScalar(_, e) => Item::Scalar(*e),
|
Self::ConstantScalar(_, e) => Item::Scalar(*e),
|
||||||
Self::GlobalScalar(_, e, _) => Item::Scalar(*e),
|
Self::GlobalScalar(_, e, _) => Item::Scalar(*e),
|
||||||
Self::Id => Item::Scalar(Elem::U32),
|
Self::Id => Item::Scalar(Elem::U32),
|
||||||
|
@ -134,11 +140,7 @@ impl Variable {
|
||||||
Self::LocalInvocationIdY => Item::Scalar(Elem::U32),
|
Self::LocalInvocationIdY => Item::Scalar(Elem::U32),
|
||||||
Self::LocalInvocationIdZ => Item::Scalar(Elem::U32),
|
Self::LocalInvocationIdZ => Item::Scalar(Elem::U32),
|
||||||
Self::Rank => Item::Scalar(Elem::U32),
|
Self::Rank => Item::Scalar(Elem::U32),
|
||||||
Self::LocalScalar {
|
Self::LocalScalar { elem, .. } => Item::Scalar(*elem),
|
||||||
index: _,
|
|
||||||
elem,
|
|
||||||
scope_depth: _,
|
|
||||||
} => Item::Scalar(*elem),
|
|
||||||
Self::WorkgroupId => Item::Scalar(Elem::U32),
|
Self::WorkgroupId => Item::Scalar(Elem::U32),
|
||||||
Self::WorkgroupIdX => Item::Scalar(Elem::U32),
|
Self::WorkgroupIdX => Item::Scalar(Elem::U32),
|
||||||
Self::WorkgroupIdY => Item::Scalar(Elem::U32),
|
Self::WorkgroupIdY => Item::Scalar(Elem::U32),
|
||||||
|
@ -222,15 +224,21 @@ impl Display for Variable {
|
||||||
f.write_fmt(format_args!("input_{number}_global"))
|
f.write_fmt(format_args!("input_{number}_global"))
|
||||||
}
|
}
|
||||||
Variable::LocalScalar {
|
Variable::LocalScalar {
|
||||||
index,
|
id: index,
|
||||||
elem: _,
|
elem: _,
|
||||||
scope_depth,
|
depth: scope_depth,
|
||||||
} => f.write_fmt(format_args!("s_{scope_depth}_{index}")),
|
} => f.write_fmt(format_args!("s_{scope_depth}_{index}")),
|
||||||
Variable::Local {
|
Variable::Local {
|
||||||
index,
|
id: index,
|
||||||
item: _,
|
item: _,
|
||||||
scope_depth,
|
depth: scope_depth,
|
||||||
} => f.write_fmt(format_args!("l_{scope_depth}_{index}")),
|
} => f.write_fmt(format_args!("l_{scope_depth}_{index}")),
|
||||||
|
Variable::Named { name, .. } => f.write_str(name),
|
||||||
|
Variable::Slice {
|
||||||
|
id: index,
|
||||||
|
item: _,
|
||||||
|
depth: scope_depth,
|
||||||
|
} => f.write_fmt(format_args!("slice_{scope_depth}_{index}")),
|
||||||
Variable::GlobalOutputArray(number, _) => {
|
Variable::GlobalOutputArray(number, _) => {
|
||||||
f.write_fmt(format_args!("output_{number}_global"))
|
f.write_fmt(format_args!("output_{number}_global"))
|
||||||
}
|
}
|
||||||
|
|
|
@ -127,43 +127,53 @@ impl WgslCompiler {
|
||||||
|
|
||||||
fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable {
|
fn compile_variable(&mut self, value: cube::Variable) -> wgsl::Variable {
|
||||||
match value {
|
match value {
|
||||||
cube::Variable::GlobalInputArray(index, item) => {
|
cube::Variable::GlobalInputArray { id, item } => {
|
||||||
wgsl::Variable::GlobalInputArray(index, Self::compile_item(item))
|
wgsl::Variable::GlobalInputArray(id, Self::compile_item(item))
|
||||||
}
|
}
|
||||||
cube::Variable::GlobalScalar(index, elem) => {
|
cube::Variable::GlobalScalar { id, elem } => {
|
||||||
wgsl::Variable::GlobalScalar(index, Self::compile_elem(elem), elem)
|
wgsl::Variable::GlobalScalar(id, Self::compile_elem(elem), elem)
|
||||||
}
|
}
|
||||||
cube::Variable::Local(index, item, scope_depth) => wgsl::Variable::Local {
|
cube::Variable::Local { id, item, depth } => wgsl::Variable::Local {
|
||||||
index,
|
id,
|
||||||
item: Self::compile_item(item),
|
item: Self::compile_item(item),
|
||||||
scope_depth,
|
depth,
|
||||||
},
|
},
|
||||||
cube::Variable::LocalScalar(index, elem, scope_depth) => wgsl::Variable::LocalScalar {
|
cube::Variable::Slice { id, item, depth } => wgsl::Variable::Slice {
|
||||||
index,
|
id,
|
||||||
|
item: Self::compile_item(item),
|
||||||
|
depth,
|
||||||
|
},
|
||||||
|
cube::Variable::LocalScalar { id, elem, depth } => wgsl::Variable::LocalScalar {
|
||||||
|
id,
|
||||||
elem: Self::compile_elem(elem),
|
elem: Self::compile_elem(elem),
|
||||||
scope_depth,
|
depth,
|
||||||
},
|
},
|
||||||
cube::Variable::GlobalOutputArray(index, item) => {
|
cube::Variable::GlobalOutputArray { id, item } => {
|
||||||
wgsl::Variable::GlobalOutputArray(index, Self::compile_item(item))
|
wgsl::Variable::GlobalOutputArray(id, Self::compile_item(item))
|
||||||
}
|
}
|
||||||
cube::Variable::ConstantScalar(index, elem) => {
|
cube::Variable::ConstantScalar { value, elem } => {
|
||||||
wgsl::Variable::ConstantScalar(index, Self::compile_elem(elem))
|
wgsl::Variable::ConstantScalar(value, Self::compile_elem(elem))
|
||||||
}
|
}
|
||||||
cube::Variable::SharedMemory(index, item, size) => {
|
cube::Variable::SharedMemory { id, item, length } => {
|
||||||
let item = Self::compile_item(item);
|
let item = Self::compile_item(item);
|
||||||
if !self.shared_memories.iter().any(|s| s.index == index) {
|
if !self.shared_memories.iter().any(|s| s.index == id) {
|
||||||
self.shared_memories
|
self.shared_memories
|
||||||
.push(SharedMemory::new(index, item, size));
|
.push(SharedMemory::new(id, item, length));
|
||||||
}
|
}
|
||||||
wgsl::Variable::SharedMemory(index, item, size)
|
wgsl::Variable::SharedMemory(id, item, length)
|
||||||
}
|
}
|
||||||
cube::Variable::LocalArray(index, item, scope_depth, size) => {
|
cube::Variable::LocalArray {
|
||||||
|
id,
|
||||||
|
item,
|
||||||
|
depth,
|
||||||
|
length,
|
||||||
|
} => {
|
||||||
let item = Self::compile_item(item);
|
let item = Self::compile_item(item);
|
||||||
if !self.local_arrays.iter().any(|s| s.index == index) {
|
if !self.local_arrays.iter().any(|s| s.index == id) {
|
||||||
self.local_arrays
|
self.local_arrays
|
||||||
.push(LocalArray::new(index, item, scope_depth, size));
|
.push(LocalArray::new(id, item, depth, length));
|
||||||
}
|
}
|
||||||
wgsl::Variable::LocalArray(index, item, scope_depth, size)
|
wgsl::Variable::LocalArray(id, item, depth, length)
|
||||||
}
|
}
|
||||||
cube::Variable::AbsolutePos => {
|
cube::Variable::AbsolutePos => {
|
||||||
self.id = true;
|
self.id = true;
|
||||||
|
@ -241,7 +251,7 @@ impl WgslCompiler {
|
||||||
wgsl::Variable::NumWorkgroups
|
wgsl::Variable::NumWorkgroups
|
||||||
}
|
}
|
||||||
cube::Variable::SubcubeDim => wgsl::Variable::SubgroupSize,
|
cube::Variable::SubcubeDim => wgsl::Variable::SubgroupSize,
|
||||||
cube::Variable::Matrix(_, _) => {
|
cube::Variable::Matrix { .. } => {
|
||||||
panic!("Cooperative matrix-multiply and accumulate not supported.")
|
panic!("Cooperative matrix-multiply and accumulate not supported.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -252,6 +262,11 @@ impl WgslCompiler {
|
||||||
let processing = value.process();
|
let processing = value.process();
|
||||||
|
|
||||||
for var in processing.variables {
|
for var in processing.variables {
|
||||||
|
// We don't declare slices.
|
||||||
|
if let cube::Variable::Slice { .. } = var {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
instructions.push(wgsl::Instruction::DeclareVariable {
|
instructions.push(wgsl::Instruction::DeclareVariable {
|
||||||
var: self.compile_variable(var),
|
var: self.compile_variable(var),
|
||||||
});
|
});
|
||||||
|
@ -427,8 +442,8 @@ impl WgslCompiler {
|
||||||
cube::Metadata::Stride { dim, var, out } => {
|
cube::Metadata::Stride { dim, var, out } => {
|
||||||
self.stride = true;
|
self.stride = true;
|
||||||
let position = match var {
|
let position = match var {
|
||||||
cube::Variable::GlobalInputArray(idx, _) => idx as usize,
|
cube::Variable::GlobalInputArray { id, .. } => id as usize,
|
||||||
cube::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
|
cube::Variable::GlobalOutputArray { id, .. } => self.num_inputs + id as usize,
|
||||||
_ => panic!("Only Input and Output have a stride, got: {:?}", var),
|
_ => panic!("Only Input and Output have a stride, got: {:?}", var),
|
||||||
};
|
};
|
||||||
wgsl::Instruction::Stride {
|
wgsl::Instruction::Stride {
|
||||||
|
@ -440,8 +455,8 @@ impl WgslCompiler {
|
||||||
cube::Metadata::Shape { dim, var, out } => {
|
cube::Metadata::Shape { dim, var, out } => {
|
||||||
self.shape = true;
|
self.shape = true;
|
||||||
let position = match var {
|
let position = match var {
|
||||||
cube::Variable::GlobalInputArray(idx, _) => idx as usize,
|
cube::Variable::GlobalInputArray { id, .. } => id as usize,
|
||||||
cube::Variable::GlobalOutputArray(idx, _) => self.num_inputs + idx as usize,
|
cube::Variable::GlobalOutputArray { id, .. } => self.num_inputs + id as usize,
|
||||||
_ => panic!("Only Input and Output have a shape, got {:?}", var),
|
_ => panic!("Only Input and Output have a shape, got {:?}", var),
|
||||||
};
|
};
|
||||||
wgsl::Instruction::Shape {
|
wgsl::Instruction::Shape {
|
||||||
|
@ -450,7 +465,7 @@ impl WgslCompiler {
|
||||||
out: self.compile_variable(out),
|
out: self.compile_variable(out),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
cube::Metadata::ArrayLength { var, out } => wgsl::Instruction::ArrayLength {
|
cube::Metadata::Length { var, out } => wgsl::Instruction::Length {
|
||||||
out: self.compile_variable(out),
|
out: self.compile_variable(out),
|
||||||
var: self.compile_variable(var),
|
var: self.compile_variable(var),
|
||||||
},
|
},
|
||||||
|
@ -652,6 +667,12 @@ impl WgslCompiler {
|
||||||
rhs: self.compile_variable(op.rhs),
|
rhs: self.compile_variable(op.rhs),
|
||||||
out: self.compile_variable(op.out),
|
out: self.compile_variable(op.out),
|
||||||
},
|
},
|
||||||
|
cube::Operator::Slice(op) => wgsl::Instruction::Slice {
|
||||||
|
input: self.compile_variable(op.input),
|
||||||
|
start: self.compile_variable(op.start),
|
||||||
|
end: self.compile_variable(op.end),
|
||||||
|
out: self.compile_variable(op.out),
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
use super::{
|
use super::{
|
||||||
base::{Item, Variable},
|
base::{Item, Variable},
|
||||||
IndexedVariable, Subgroup,
|
Elem, IndexedVariable, Subgroup,
|
||||||
};
|
};
|
||||||
use std::fmt::Display;
|
use std::fmt::Display;
|
||||||
|
|
||||||
|
@ -167,7 +167,7 @@ pub enum Instruction {
|
||||||
position: usize,
|
position: usize,
|
||||||
out: Variable,
|
out: Variable,
|
||||||
},
|
},
|
||||||
ArrayLength {
|
Length {
|
||||||
var: Variable,
|
var: Variable,
|
||||||
out: Variable,
|
out: Variable,
|
||||||
},
|
},
|
||||||
|
@ -232,6 +232,12 @@ pub enum Instruction {
|
||||||
rhs: Variable,
|
rhs: Variable,
|
||||||
out: Variable,
|
out: Variable,
|
||||||
},
|
},
|
||||||
|
Slice {
|
||||||
|
input: Variable,
|
||||||
|
start: Variable,
|
||||||
|
end: Variable,
|
||||||
|
out: Variable,
|
||||||
|
},
|
||||||
Subgroup(Subgroup),
|
Subgroup(Subgroup),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -245,6 +251,16 @@ impl Display for Instruction {
|
||||||
Instruction::Add { lhs, rhs, out } => {
|
Instruction::Add { lhs, rhs, out } => {
|
||||||
f.write_fmt(format_args!("{out} = {lhs} + {rhs};\n"))
|
f.write_fmt(format_args!("{out} = {lhs} + {rhs};\n"))
|
||||||
}
|
}
|
||||||
|
Instruction::Slice {
|
||||||
|
input,
|
||||||
|
start,
|
||||||
|
end,
|
||||||
|
out,
|
||||||
|
} => {
|
||||||
|
f.write_fmt(format_args!("let {out}_offset = {start};\n"))?;
|
||||||
|
f.write_fmt(format_args!("let {out}_length = {end} - {start};\n"))?;
|
||||||
|
f.write_fmt(format_args!("let {out}_ptr = &{input};\n"))
|
||||||
|
}
|
||||||
Instruction::Fma { a, b, c, out } => {
|
Instruction::Fma { a, b, c, out } => {
|
||||||
f.write_fmt(format_args!("{out} = fma({a}, {b}, {c});\n"))
|
f.write_fmt(format_args!("{out} = fma({a}, {b}, {c});\n"))
|
||||||
}
|
}
|
||||||
|
@ -261,10 +277,22 @@ impl Display for Instruction {
|
||||||
f.write_fmt(format_args!("{out} = {lhs} || {rhs};\n"))
|
f.write_fmt(format_args!("{out} = {lhs} || {rhs};\n"))
|
||||||
}
|
}
|
||||||
Instruction::Not { input, out } => f.write_fmt(format_args!("{out} = !{input};\n")),
|
Instruction::Not { input, out } => f.write_fmt(format_args!("{out} = !{input};\n")),
|
||||||
Instruction::Index { lhs, rhs, out } => {
|
Instruction::Index { lhs, rhs, out } => match lhs {
|
||||||
let item = out.item();
|
Variable::Slice { item, .. } => {
|
||||||
f.write_fmt(format_args!("{out} = {item}({lhs}[{rhs}]);\n"))
|
let offset = Variable::Named {
|
||||||
}
|
name: format!("{lhs}_offset"),
|
||||||
|
item: Item::Scalar(Elem::U32),
|
||||||
|
is_array: false,
|
||||||
|
};
|
||||||
|
let lhs = Variable::Named {
|
||||||
|
name: format!("(*{lhs}_ptr)"),
|
||||||
|
item: *item,
|
||||||
|
is_array: true,
|
||||||
|
};
|
||||||
|
index(f, &lhs, rhs, out, Some(offset))
|
||||||
|
}
|
||||||
|
_ => index(f, lhs, rhs, out, None),
|
||||||
|
},
|
||||||
Instruction::Modulo { lhs, rhs, out } => {
|
Instruction::Modulo { lhs, rhs, out } => {
|
||||||
f.write_fmt(format_args!("{out} = {lhs} % {rhs};\n"))
|
f.write_fmt(format_args!("{out} = {lhs} % {rhs};\n"))
|
||||||
}
|
}
|
||||||
|
@ -406,88 +434,24 @@ for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{
|
||||||
|
|
||||||
f.write_str("}\n")
|
f.write_str("}\n")
|
||||||
}
|
}
|
||||||
Instruction::IndexAssign { lhs, rhs, out } => match lhs.item() {
|
Instruction::IndexAssign { lhs, rhs, out } => {
|
||||||
Item::Vec4(elem) => {
|
if let Variable::Slice { item, .. } = out {
|
||||||
let lhs0 = lhs.index(0);
|
let offset = Variable::Named {
|
||||||
let lhs1 = lhs.index(1);
|
name: format!("{out}_offset"),
|
||||||
let lhs2 = lhs.index(2);
|
item: Item::Scalar(Elem::U32),
|
||||||
let lhs3 = lhs.index(3);
|
is_array: false,
|
||||||
|
};
|
||||||
|
let out = Variable::Named {
|
||||||
|
name: format!("(*{out}_ptr)"),
|
||||||
|
item: *item,
|
||||||
|
is_array: true,
|
||||||
|
};
|
||||||
|
|
||||||
let rhs0 = rhs.index(0);
|
index_assign(f, lhs, rhs, &out, Some(offset))
|
||||||
let rhs1 = rhs.index(1);
|
} else {
|
||||||
let rhs2 = rhs.index(2);
|
index_assign(f, lhs, rhs, out, None)
|
||||||
let rhs3 = rhs.index(3);
|
|
||||||
|
|
||||||
f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
|
|
||||||
f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?;
|
|
||||||
f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n"))?;
|
|
||||||
f.write_fmt(format_args!("{out}[{lhs3}] = {elem}({rhs3});\n"))
|
|
||||||
}
|
}
|
||||||
Item::Vec3(elem) => {
|
}
|
||||||
let lhs0 = lhs.index(0);
|
|
||||||
let lhs1 = lhs.index(1);
|
|
||||||
let lhs2 = lhs.index(2);
|
|
||||||
|
|
||||||
let rhs0 = rhs.index(0);
|
|
||||||
let rhs1 = rhs.index(1);
|
|
||||||
let rhs2 = rhs.index(2);
|
|
||||||
|
|
||||||
f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
|
|
||||||
f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?;
|
|
||||||
f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n"))
|
|
||||||
}
|
|
||||||
Item::Vec2(elem) => {
|
|
||||||
let lhs0 = lhs.index(0);
|
|
||||||
let lhs1 = lhs.index(1);
|
|
||||||
|
|
||||||
let rhs0 = rhs.index(0);
|
|
||||||
let rhs1 = rhs.index(1);
|
|
||||||
|
|
||||||
f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
|
|
||||||
f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))
|
|
||||||
}
|
|
||||||
Item::Scalar(_elem) => {
|
|
||||||
let is_array = matches!(
|
|
||||||
out,
|
|
||||||
Variable::GlobalInputArray(_, _)
|
|
||||||
| Variable::GlobalOutputArray(_, _)
|
|
||||||
| Variable::SharedMemory(_, _, _)
|
|
||||||
| Variable::LocalArray(_, _, _, _)
|
|
||||||
);
|
|
||||||
|
|
||||||
if !is_array {
|
|
||||||
let elem_out = out.elem();
|
|
||||||
let casting_type = match rhs.item() {
|
|
||||||
Item::Vec4(_) => Item::Vec4(elem_out),
|
|
||||||
Item::Vec3(_) => Item::Vec3(elem_out),
|
|
||||||
Item::Vec2(_) => Item::Vec2(elem_out),
|
|
||||||
Item::Scalar(_) => Item::Scalar(elem_out),
|
|
||||||
};
|
|
||||||
f.write_fmt(format_args!("{out}[{lhs}] = {casting_type}({rhs});\n"))
|
|
||||||
} else {
|
|
||||||
let item_rhs = rhs.item();
|
|
||||||
let item_out = out.item();
|
|
||||||
|
|
||||||
let vectorization_factor = item_out.vectorization_factor();
|
|
||||||
if vectorization_factor > item_rhs.vectorization_factor() {
|
|
||||||
let casting_type = item_out.elem();
|
|
||||||
f.write_fmt(format_args!("{out}[{lhs}] = vec{vectorization_factor}("))?;
|
|
||||||
for i in 0..vectorization_factor {
|
|
||||||
let value = rhs.index(i);
|
|
||||||
f.write_fmt(format_args!("{casting_type}({value})"))?;
|
|
||||||
|
|
||||||
if i < vectorization_factor - 1 {
|
|
||||||
f.write_str(",")?;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
f.write_str(");\n")
|
|
||||||
} else {
|
|
||||||
let casting_type = item_out;
|
|
||||||
f.write_fmt(format_args!("{out}[{lhs}] = {casting_type}({rhs});\n"))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
Instruction::If { cond, instructions } => {
|
Instruction::If { cond, instructions } => {
|
||||||
f.write_fmt(format_args!("if {cond} {{\n"))?;
|
f.write_fmt(format_args!("if {cond} {{\n"))?;
|
||||||
for i in instructions {
|
for i in instructions {
|
||||||
|
@ -513,9 +477,10 @@ for (var {i}: u32 = {start}; {i} < {end}; {i}++) {{
|
||||||
Instruction::Return => f.write_str("return;\n"),
|
Instruction::Return => f.write_str("return;\n"),
|
||||||
Instruction::Break => f.write_str("break;\n"),
|
Instruction::Break => f.write_str("break;\n"),
|
||||||
Instruction::WorkgroupBarrier => f.write_str("workgroupBarrier();\n"),
|
Instruction::WorkgroupBarrier => f.write_str("workgroupBarrier();\n"),
|
||||||
Instruction::ArrayLength { var, out } => {
|
Instruction::Length { var, out } => match var {
|
||||||
f.write_fmt(format_args!("{out} = arrayLength(&{var});\n"))
|
Variable::Slice { .. } => f.write_fmt(format_args!("{out} = {var}_length;\n")),
|
||||||
}
|
_ => f.write_fmt(format_args!("{out} = arrayLength(&{var});\n")),
|
||||||
|
},
|
||||||
Instruction::Loop { instructions } => {
|
Instruction::Loop { instructions } => {
|
||||||
f.write_fmt(format_args!("loop {{\n"))?;
|
f.write_fmt(format_args!("loop {{\n"))?;
|
||||||
for i in instructions {
|
for i in instructions {
|
||||||
|
@ -623,3 +588,140 @@ fn unroll<
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct IndexOffset {
|
||||||
|
var: Variable,
|
||||||
|
offset: Option<Variable>,
|
||||||
|
index: usize,
|
||||||
|
}
|
||||||
|
impl IndexOffset {
|
||||||
|
fn new(var: &Variable, offset: &Option<Variable>, index: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
var: var.clone(),
|
||||||
|
offset: offset.clone(),
|
||||||
|
index,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for IndexOffset {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
let var = self.var.index(self.index);
|
||||||
|
|
||||||
|
match &self.offset {
|
||||||
|
Some(offset) => {
|
||||||
|
let offset = offset.index(self.index);
|
||||||
|
f.write_fmt(format_args!("{var} + {offset}"))
|
||||||
|
}
|
||||||
|
None => f.write_fmt(format_args!("{var}")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn index(
|
||||||
|
f: &mut std::fmt::Formatter<'_>,
|
||||||
|
lhs: &Variable,
|
||||||
|
rhs: &Variable,
|
||||||
|
out: &Variable,
|
||||||
|
offset: Option<Variable>,
|
||||||
|
) -> core::fmt::Result {
|
||||||
|
let item = out.item();
|
||||||
|
match offset {
|
||||||
|
Some(offset) => f.write_fmt(format_args!("{out} = {item}({lhs}[{rhs} + {offset}]);\n")),
|
||||||
|
None => f.write_fmt(format_args!("{out} = {item}({lhs}[{rhs}]);\n")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn index_assign(
|
||||||
|
f: &mut std::fmt::Formatter<'_>,
|
||||||
|
lhs: &Variable,
|
||||||
|
rhs: &Variable,
|
||||||
|
out: &Variable,
|
||||||
|
offset: Option<Variable>,
|
||||||
|
) -> core::fmt::Result {
|
||||||
|
match lhs.item() {
|
||||||
|
Item::Vec4(elem) => {
|
||||||
|
let lhs0 = IndexOffset::new(lhs, &offset, 0);
|
||||||
|
let lhs1 = IndexOffset::new(lhs, &offset, 1);
|
||||||
|
let lhs2 = IndexOffset::new(lhs, &offset, 2);
|
||||||
|
let lhs3 = IndexOffset::new(lhs, &offset, 3);
|
||||||
|
|
||||||
|
let rhs0 = rhs.index(0);
|
||||||
|
let rhs1 = rhs.index(1);
|
||||||
|
let rhs2 = rhs.index(2);
|
||||||
|
let rhs3 = rhs.index(3);
|
||||||
|
|
||||||
|
f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
|
||||||
|
f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?;
|
||||||
|
f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n"))?;
|
||||||
|
f.write_fmt(format_args!("{out}[{lhs3}] = {elem}({rhs3});\n"))
|
||||||
|
}
|
||||||
|
Item::Vec3(elem) => {
|
||||||
|
let lhs0 = IndexOffset::new(lhs, &offset, 0);
|
||||||
|
let lhs1 = IndexOffset::new(lhs, &offset, 1);
|
||||||
|
let lhs2 = IndexOffset::new(lhs, &offset, 2);
|
||||||
|
|
||||||
|
let rhs0 = rhs.index(0);
|
||||||
|
let rhs1 = rhs.index(1);
|
||||||
|
let rhs2 = rhs.index(2);
|
||||||
|
|
||||||
|
f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
|
||||||
|
f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))?;
|
||||||
|
f.write_fmt(format_args!("{out}[{lhs2}] = {elem}({rhs2});\n"))
|
||||||
|
}
|
||||||
|
Item::Vec2(elem) => {
|
||||||
|
let lhs0 = IndexOffset::new(lhs, &offset, 0);
|
||||||
|
let lhs1 = IndexOffset::new(lhs, &offset, 1);
|
||||||
|
|
||||||
|
let rhs0 = rhs.index(0);
|
||||||
|
let rhs1 = rhs.index(1);
|
||||||
|
|
||||||
|
f.write_fmt(format_args!("{out}[{lhs0}] = {elem}({rhs0});\n"))?;
|
||||||
|
f.write_fmt(format_args!("{out}[{lhs1}] = {elem}({rhs1});\n"))
|
||||||
|
}
|
||||||
|
Item::Scalar(_elem) => {
|
||||||
|
let is_array = match out {
|
||||||
|
Variable::GlobalInputArray(_, _)
|
||||||
|
| Variable::GlobalOutputArray(_, _)
|
||||||
|
| Variable::SharedMemory(_, _, _)
|
||||||
|
| Variable::Slice { .. }
|
||||||
|
| Variable::LocalArray(_, _, _, _) => true,
|
||||||
|
Variable::Named { is_array, .. } => *is_array,
|
||||||
|
_ => false,
|
||||||
|
};
|
||||||
|
|
||||||
|
if !is_array {
|
||||||
|
let elem_out = out.elem();
|
||||||
|
let casting_type = match rhs.item() {
|
||||||
|
Item::Vec4(_) => Item::Vec4(elem_out),
|
||||||
|
Item::Vec3(_) => Item::Vec3(elem_out),
|
||||||
|
Item::Vec2(_) => Item::Vec2(elem_out),
|
||||||
|
Item::Scalar(_) => Item::Scalar(elem_out),
|
||||||
|
};
|
||||||
|
f.write_fmt(format_args!("{out}[{lhs}] = {casting_type}({rhs});\n"))
|
||||||
|
} else {
|
||||||
|
let item_rhs = rhs.item();
|
||||||
|
let item_out = out.item();
|
||||||
|
let lhs = IndexOffset::new(lhs, &offset, 0);
|
||||||
|
|
||||||
|
let vectorization_factor = item_out.vectorization_factor();
|
||||||
|
if vectorization_factor > item_rhs.vectorization_factor() {
|
||||||
|
let casting_type = item_out.elem();
|
||||||
|
f.write_fmt(format_args!("{out}[{lhs}] = vec{vectorization_factor}("))?;
|
||||||
|
for i in 0..vectorization_factor {
|
||||||
|
let value = rhs.index(i);
|
||||||
|
f.write_fmt(format_args!("{casting_type}({value})"))?;
|
||||||
|
|
||||||
|
if i < vectorization_factor - 1 {
|
||||||
|
f.write_str(",")?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f.write_str(");\n")
|
||||||
|
} else {
|
||||||
|
let casting_type = item_out;
|
||||||
|
f.write_fmt(format_args!("{out}[{lhs}] = {casting_type}({rhs});\n"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue