Refactor/cube/expand & fix double imports (#2009)

* Refactored function

* WIP

* Basic stuff done

* Fix traits

* Cleanup

* Cleanup

* Cleanup
This commit is contained in:
Nathaniel Simard 2024-07-12 09:18:38 -04:00 committed by GitHub
parent 35345de62a
commit 19f5ad7be5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
71 changed files with 714 additions and 542 deletions

View File

@ -2,10 +2,17 @@ use quote::ToTokens;
use crate::tracker::VariableTracker;
/// Selects the identifier that `expand_sig` gives to the generated expand signature.
#[derive(Copy, Clone, Debug)]
pub enum ExpandMode {
/// The generated signature is named `__expand`.
FuncImpl,
/// The generated signature is named `__expand_{ident}`, where `ident` is the
/// identifier of the original signature.
MethodImpl,
}
pub fn expand_sig(
sig: &syn::Signature,
visibility: &syn::Visibility,
mut variable_tracker: Option<&mut VariableTracker>,
mode: ExpandMode,
) -> proc_macro2::TokenStream {
let mut inputs = quote::quote!();
@ -42,7 +49,10 @@ pub fn expand_sig(
}
let ident = &sig.ident;
let ident = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span());
let ident = match mode {
ExpandMode::FuncImpl => syn::Ident::new("__expand".to_string().as_str(), ident.span()),
_ => syn::Ident::new(format!("__expand_{ident}").as_str(), ident.span()),
};
let generics = sig.generics.clone().into_token_stream();

View File

@ -106,22 +106,42 @@ pub(crate) fn codegen_call(
// Path
let mut path_tokens = TokenStream::new();
let mut is_comptime = false;
let mut is_plain_func = true;
let mut comptime_func: Option<String> = None;
for (i, (ident, generics)) in path.iter().enumerate() {
if *ident == "Comptime" {
let name = ident.to_string();
if name == "Comptime" {
is_comptime = true;
continue;
}
if let Some(first_char) = name.chars().next() {
if first_char.is_uppercase() {
is_plain_func = false;
}
}
if i == path.len() - 1 {
if is_comptime {
comptime_func = Some(ident.to_string());
break;
}
let func_name_expand = syn::Ident::new(
format!("{ident}_expand").as_str(),
proc_macro2::Span::call_site(),
);
let func_name_expand = if is_plain_func {
quote::quote! {
#ident::__expand
}
} else {
let ident = syn::Ident::new(
format!("__expand_{ident}").as_str(),
proc_macro2::Span::call_site(),
);
quote::quote! {
#ident
}
};
path_tokens.extend(quote_spanned! {func_name_expand.span() => #func_name_expand });
if let Some(generics) = generics {
path_tokens.extend(quote_spanned! {generics.span() => #generics });

View File

@ -211,7 +211,7 @@ impl Codegen {
}
}
fn gen_define_impl(&self, expand: &Ident) -> TokenStream {
fn gen_define_impl(&self, expand: &TokenStream) -> TokenStream {
let mut expand_args = quote::quote! { &mut builder.context, };
let mut variables = quote::quote! {};
@ -340,7 +340,7 @@ impl Codegen {
tokens
}
fn gen_compile_impl(&self, expand: &Ident) -> TokenStream {
fn gen_compile_impl(&self, expand: &TokenStream) -> TokenStream {
let ident = Ident::new(&self.name, Span::call_site());
let generics = add_runtime(self.generics.clone());
let (impl_gen, ty_gen, where_gen) = generics.split_for_impl();
@ -453,22 +453,27 @@ pub fn codegen_launch(sig: &syn::Signature) -> TokenStream {
let codegen = Codegen::from_sig(sig);
let ident = &sig.ident;
let ident_expand = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span());
let ident = syn::Ident::new(format!("{ident}_launch").as_str(), ident.span());
let ident_expand = quote::quote! {
__expand
};
let generics = add_runtime(add_lifetime(sig.generics.clone()));
let body = codegen.gen_launch_body();
let kernel = codegen.gen_kernel_struct();
let compile = codegen.gen_compile_impl(&ident_expand);
let (inputs, output) = (codegen.fn_inputs, codegen.fn_output);
let doc =
format!("Launch the kernel [{ident}] with the provided argument on the given runtime.");
quote::quote! {
#kernel
#compile
#[allow(clippy::too_many_arguments)]
/// Launch
pub fn #ident #generics (
#[doc = #doc]
/// Launch the kernel.
pub fn launch #generics (
client: ComputeClient<R::Server, R::Channel>,
cube_count: CubeCount<R::Server>,
cube_dim: CubeDim,

View File

@ -1,4 +1,6 @@
use crate::codegen_common::signature::expand_sig;
use proc_macro2::TokenStream;
use crate::codegen_common::signature::{expand_sig, ExpandMode};
pub fn expand_trait_def(mut tr: syn::ItemTrait) -> proc_macro2::TokenStream {
let mut expand_items = Vec::new();
@ -6,7 +8,12 @@ pub fn expand_trait_def(mut tr: syn::ItemTrait) -> proc_macro2::TokenStream {
for item in tr.items.iter() {
match item {
syn::TraitItem::Fn(func) => {
let expand = expand_sig(&func.sig, &syn::Visibility::Inherited, None);
let expand = expand_sig(
&func.sig,
&syn::Visibility::Inherited,
None,
ExpandMode::MethodImpl,
);
expand_items.push(syn::parse_quote!(#expand;));
}
_ => continue,
@ -26,7 +33,7 @@ pub fn expand_trait_impl(mut tr: syn::ItemImpl) -> proc_macro2::TokenStream {
match item {
syn::ImplItem::Fn(func) => {
let ident = &func.sig.ident;
let ident = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span());
let ident = quote::quote! {#ident::__expand};
let mut inputs = quote::quote!();
for input in &func.sig.inputs {
@ -41,7 +48,12 @@ pub fn expand_trait_impl(mut tr: syn::ItemImpl) -> proc_macro2::TokenStream {
}
}
let expand = expand_sig(&func.sig, &syn::Visibility::Inherited, None);
let expand = expand_sig(
&func.sig,
&syn::Visibility::Inherited,
None,
ExpandMode::MethodImpl,
);
let tokens = if !tr.generics.params.is_empty() {
let mut func = func.clone();
@ -67,7 +79,7 @@ pub fn expand_trait_impl(mut tr: syn::ItemImpl) -> proc_macro2::TokenStream {
fn register_expand(
func: &syn::ImplItemFn,
name: &syn::Ident,
name: &TokenStream,
expand: proc_macro2::TokenStream,
inputs: proc_macro2::TokenStream,
) -> proc_macro2::TokenStream {
@ -91,7 +103,7 @@ fn register_expand(
quote::quote! (
#expand {
#[cube]
#func
pub #func
#func_expand
}
)

View File

@ -10,7 +10,7 @@ mod tracker;
pub(crate) mod codegen_common;
use analyzer::VariableAnalyzer;
use codegen_common::signature::expand_sig;
use codegen_common::signature::{expand_sig, ExpandMode};
use codegen_function::{codegen_launch, codegen_statement};
use codegen_trait::{expand_trait_def, expand_trait_impl};
use codegen_type::generate_cube_type;
@ -69,20 +69,8 @@ pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream {
fn cube_fn(func: syn::ItemFn, attrs: &SupportedAttributes) -> TokenStream {
let mut variable_tracker = VariableAnalyzer::create_tracker(&func);
match codegen_cube(&func, &mut variable_tracker) {
Ok(code) => {
if attrs.launch {
let launch = codegen_launch(&func.sig);
quote::quote! {
#code
#launch
}
.into()
} else {
code.into()
}
}
match codegen_cube(&func, &mut variable_tracker, attrs.launch) {
Ok(code) => code.into(),
Err(err) => err.into(),
}
}
@ -120,8 +108,15 @@ fn parse_attributes(args: &Punctuated<Meta, Comma>) -> SupportedAttributes {
fn codegen_cube(
func: &syn::ItemFn,
variable_tracker: &mut VariableTracker,
launch: bool,
) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
let signature = expand_sig(&func.sig, &func.vis, Some(variable_tracker));
let signature = expand_sig(
&func.sig,
&syn::Visibility::Public(Default::default()), // Always public, otherwise we can't import
// it from an outside module.
Some(variable_tracker),
ExpandMode::FuncImpl,
);
let mut body = quote::quote! {};
for statement in func.block.stmts.iter() {
@ -148,15 +143,36 @@ fn codegen_cube(
return Err(code);
}
let launch_doc = if launch { "and launch function " } else { "" };
let launch = if launch {
codegen_launch(&func.sig)
} else {
quote::quote! {}
};
let mod_name = &func.sig.ident;
let vis = &func.vis;
let doc = format!("Module containing the expand method {launch_doc}of {mod_name}.");
Ok(quote::quote! {
#[allow(dead_code)]
#[allow(clippy::too_many_arguments)]
#func
#[allow(unused_mut)]
#[allow(clippy::too_many_arguments)]
#signature {
#body
#[doc = #doc]
#vis mod #mod_name {
use super::*;
#launch
#[allow(unused_mut)]
#[allow(clippy::too_many_arguments)]
#signature {
#body
}
}
})
}

View File

@ -86,7 +86,7 @@ impl Init for MatrixExpand {
impl<C: CubePrimitive> Matrix<C> {
/// Create a new matrix that is going to be used in the
/// [matrix-multiply and accumulate](execute) function.
/// [matrix-multiply and accumulate](execute()) function.
///
/// You have to declare the shape used for the execution.
/// The shape of the current matrix is determined using the [MatrixIdent].
@ -103,7 +103,7 @@ impl<C: CubePrimitive> Matrix<C> {
Matrix { _c: PhantomData }
}
pub fn new_expand(
pub fn __expand_new(
context: &mut CubeContext,
ident: MatrixIdent,
m: u8,
@ -129,16 +129,21 @@ pub fn fill<C: CubeType>(mat: &Matrix<C>, value: C) {
unexpanded!()
}
/// Expand method of [fill].
pub fn fill_expand<C: CubeType>(
context: &mut CubeContext,
mat: MatrixExpand,
value: ExpandElement,
) {
context.register(Operation::CoopMma(ir::CoopMma::Fill {
mat: *mat.elem,
value: *value,
}));
/// Namespace holding the expand counterpart of [fill()].
pub mod fill {
    use super::*;

    /// Expand counterpart of [fill()].
    ///
    /// Registers a [`ir::CoopMma::Fill`] operation on `context`, built from the
    /// matrix element and the fill value.
    pub fn __expand<C: CubeType>(
        context: &mut CubeContext,
        mat: MatrixExpand,
        value: ExpandElement,
    ) {
        let fill_op = ir::CoopMma::Fill {
            mat: *mat.elem,
            value: *value,
        };
        context.register(Operation::CoopMma(fill_op));
    }
}
/// Load the matrix with the provided array using the stride.
@ -147,19 +152,24 @@ pub fn load<C: CubeType>(mat: &Matrix<C>, value: &Slice<'_, C>, stride: UInt) {
unexpanded!()
}
/// Expand method of [load].
#[allow(unused_variables)]
pub fn load_expand<C: CubeType>(
context: &mut CubeContext,
mat: MatrixExpand,
value: ExpandElementTyped<Slice<'static, C>>,
stride: ExpandElement,
) {
context.register(Operation::CoopMma(ir::CoopMma::Load {
mat: *mat.elem,
value: *value.expand,
stride: *stride,
}));
/// Namespace holding the expand counterpart of [load()].
pub mod load {
    use super::*;

    /// Expand counterpart of [load()].
    ///
    /// Registers a [`ir::CoopMma::Load`] operation on `context`, built from the
    /// matrix element, the source slice and the stride.
    #[allow(unused_variables)]
    pub fn __expand<C: CubeType>(
        context: &mut CubeContext,
        mat: MatrixExpand,
        value: ExpandElementTyped<Slice<'static, C>>,
        stride: ExpandElement,
    ) {
        let load_op = ir::CoopMma::Load {
            mat: *mat.elem,
            value: *value.expand,
            stride: *stride,
        };
        context.register(Operation::CoopMma(load_op));
    }
}
/// Store the matrix in the given array following the given stride and layout.
@ -173,21 +183,26 @@ pub fn store<C: CubePrimitive>(
unexpanded!()
}
/// Expand method of [store].
#[allow(unused_variables)]
pub fn store_expand<C: CubePrimitive>(
context: &mut CubeContext,
output: ExpandElementTyped<SliceMut<'static, C>>,
mat: MatrixExpand,
stride: ExpandElement,
layout: MatrixLayout,
) {
context.register(Operation::CoopMma(ir::CoopMma::Store {
output: *output.expand,
mat: *mat.elem,
stride: *stride,
layout,
}));
/// Namespace holding the expand counterpart of [store()].
pub mod store {
    use super::*;

    /// Expand counterpart of [store()].
    ///
    /// Registers a [`ir::CoopMma::Store`] operation on `context`, built from the
    /// output slice, the matrix element, the stride and the requested layout.
    #[allow(unused_variables)]
    pub fn __expand<C: CubePrimitive>(
        context: &mut CubeContext,
        output: ExpandElementTyped<SliceMut<'static, C>>,
        mat: MatrixExpand,
        stride: ExpandElement,
        layout: MatrixLayout,
    ) {
        let store_op = ir::CoopMma::Store {
            output: *output.expand,
            mat: *mat.elem,
            stride: *stride,
            layout,
        };
        context.register(Operation::CoopMma(store_op));
    }
}
/// Execute the matrix-multiply and accumulate operation on the given [matrices](Matrix).
@ -201,18 +216,23 @@ pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrim
unexpanded!()
}
/// Expand method of [execute].
pub fn execute_expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
context: &mut CubeContext,
mat_a: MatrixExpand,
mat_b: MatrixExpand,
mat_c: MatrixExpand,
mat_d: MatrixExpand,
) {
context.register(Operation::CoopMma(ir::CoopMma::Execute {
mat_a: *mat_a.elem,
mat_b: *mat_b.elem,
mat_c: *mat_c.elem,
mat_d: *mat_d.elem,
}));
/// Namespace holding the expand counterpart of [execute()].
pub mod execute {
    use super::*;

    /// Expand counterpart of [execute()].
    ///
    /// Registers a [`ir::CoopMma::Execute`] operation on `context`, built from the
    /// elements of the four matrices.
    pub fn __expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
        context: &mut CubeContext,
        mat_a: MatrixExpand,
        mat_b: MatrixExpand,
        mat_c: MatrixExpand,
        mat_d: MatrixExpand,
    ) {
        let execute_op = ir::CoopMma::Execute {
            mat_a: *mat_a.elem,
            mat_b: *mat_b.elem,
            mat_c: *mat_c.elem,
            mat_d: *mat_d.elem,
        };
        context.register(Operation::CoopMma(execute_op));
    }
}

View File

@ -30,7 +30,11 @@ impl<T: CubePrimitive + Clone> Array<T> {
Array { _val: PhantomData }
}
pub fn new_expand<S: Index>(
pub fn vectorized<S: Index>(_size: S, _vectorization_factor: UInt) -> Self {
Array { _val: PhantomData }
}
pub fn __expand_new<S: Index>(
context: &mut CubeContext,
size: S,
) -> <Self as CubeType>::ExpandType {
@ -44,11 +48,7 @@ impl<T: CubePrimitive + Clone> Array<T> {
.into()
}
pub fn vectorized<S: Index>(_size: S, _vectorization_factor: UInt) -> Self {
Array { _val: PhantomData }
}
pub fn vectorized_expand<S: Index>(
pub fn __expand_vectorized<S: Index>(
context: &mut CubeContext,
size: S,
vectorization_factor: UInt,

View File

@ -3,6 +3,10 @@ use crate::ir::Elem;
use super::Vectorized;
// To be consistent with other primitive type.
/// Boolean type.
pub type Bool = bool;
impl CubeType for bool {
type ExpandType = ExpandElement;
}

View File

@ -6,7 +6,7 @@ use crate::{frontend::ExpandElement, unexpanded};
pub trait Cast: CubePrimitive {
fn cast_from<From: CubePrimitive>(value: From) -> Self;
fn cast_from_expand<From>(
fn __expand_cast_from<From>(
context: &mut CubeContext,
value: From,
) -> <Self as CubeType>::ExpandType

View File

@ -29,15 +29,15 @@ pub trait Float:
+ core::ops::IndexMut<UInt, Output = Self>
{
fn new(val: f32) -> Self;
fn new_expand(context: &mut CubeContext, val: f32) -> <Self as CubeType>::ExpandType;
fn vectorized(val: f32, vectorization: UInt) -> Self;
fn vectorized_expand(
fn vectorized_empty(vectorization: UInt) -> Self;
fn __expand_new(context: &mut CubeContext, val: f32) -> <Self as CubeType>::ExpandType;
fn __expand_vectorized(
context: &mut CubeContext,
val: f32,
vectorization: UInt,
) -> <Self as CubeType>::ExpandType;
fn vectorized_empty(vectorization: UInt) -> Self;
fn vectorized_empty_expand(
fn __expand_vectorized_empty(
context: &mut CubeContext,
vectorization: UInt,
) -> <Self as CubeType>::ExpandType;
@ -74,14 +74,6 @@ macro_rules! impl_float {
}
}
fn new_expand(_context: &mut CubeContext, val: f32) -> <Self as CubeType>::ExpandType {
let new_var = Variable::ConstantScalar {
value: val as f64,
elem: Self::as_elem(),
};
ExpandElement::Plain(new_var)
}
fn vectorized(val: f32, vectorization: UInt) -> Self {
if vectorization.val == 1 {
Self::new(val)
@ -93,13 +85,28 @@ macro_rules! impl_float {
}
}
fn vectorized_expand(
/// Builds a value by delegating to `Self::vectorized` with `0.` as the value
/// and the given `vectorization`.
fn vectorized_empty(vectorization: UInt) -> Self {
Self::vectorized(0., vectorization)
}
/// Expand counterpart of `new`.
///
/// Wraps `val` (widened to `f64`) in a `Variable::ConstantScalar` tagged with
/// this type's element and returns it as a plain expand element. The context
/// is not used.
fn __expand_new(
    _context: &mut CubeContext,
    val: f32,
) -> <Self as CubeType>::ExpandType {
    ExpandElement::Plain(Variable::ConstantScalar {
        value: val as f64,
        elem: Self::as_elem(),
    })
}
fn __expand_vectorized(
context: &mut CubeContext,
val: f32,
vectorization: UInt,
) -> <Self as CubeType>::ExpandType {
if vectorization.val == 1 {
Self::new_expand(context, val)
Self::__expand_new(context, val)
} else {
let mut new_var = context
.create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8));
@ -111,16 +118,12 @@ macro_rules! impl_float {
}
}
fn vectorized_empty(vectorization: UInt) -> Self {
Self::vectorized(0., vectorization)
}
fn vectorized_empty_expand(
fn __expand_vectorized_empty(
context: &mut CubeContext,
vectorization: UInt,
) -> <Self as CubeType>::ExpandType {
if vectorization.val == 1 {
Self::new_expand(context, 0.)
Self::__expand_new(context, 0.)
} else {
context.create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8))
}

View File

@ -9,9 +9,9 @@ use super::{LaunchArgExpand, ScalarArgSettings, UInt, Vectorized};
/// Signed integer. Used as input in int kernels
pub trait Int: Numeric + std::ops::Rem<Output = Self> {
fn new(val: i64) -> Self;
fn new_expand(context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType;
fn vectorized(val: i64, vectorization: UInt) -> Self;
fn vectorized_expand(
fn __expand_new(context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType;
fn __expand_vectorized(
context: &mut CubeContext,
val: i64,
vectorization: UInt,
@ -48,14 +48,6 @@ macro_rules! impl_int {
}
}
fn new_expand(_context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType {
let new_var = Variable::ConstantScalar {
value: val as f64,
elem: Self::as_elem(),
};
ExpandElement::Plain(new_var)
}
fn vectorized(val: i64, vectorization: UInt) -> Self {
if vectorization.val == 1 {
Self::new(val)
@ -67,13 +59,24 @@ macro_rules! impl_int {
}
}
fn vectorized_expand(
/// Expand counterpart of `new`.
///
/// Wraps `val` (converted to `f64`) in a `Variable::ConstantScalar` tagged with
/// this type's element and returns it as a plain expand element. The context
/// is not used.
fn __expand_new(
    _context: &mut CubeContext,
    val: i64,
) -> <Self as CubeType>::ExpandType {
    ExpandElement::Plain(Variable::ConstantScalar {
        value: val as f64,
        elem: Self::as_elem(),
    })
}
fn __expand_vectorized(
context: &mut CubeContext,
val: i64,
vectorization: UInt,
) -> <Self as CubeType>::ExpandType {
if vectorization.val == 1 {
Self::new_expand(context, val)
Self::__expand_new(context, val)
} else {
let mut new_var = context
.create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8));

View File

@ -14,6 +14,7 @@ mod vectorized;
pub use array::*;
pub use base::*;
pub use bool::*;
pub use cast::*;
pub use cube_elem::*;
pub use float::*;

View File

@ -45,8 +45,11 @@ pub trait Numeric:
type Primitive: ScalarArgSettings;
/// Expand version of from_int
fn from_int_expand(_context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType {
fn from_vec<const D: usize>(_vec: [i64; D]) -> Self {
unexpanded!()
}
fn __expand_from_int(_context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType {
let new_var = Variable::ConstantScalar {
value: val as f64,
elem: Self::as_elem(),
@ -54,11 +57,7 @@ pub trait Numeric:
ExpandElement::Plain(new_var)
}
fn from_vec<const D: usize>(_vec: [i64; D]) -> Self {
unexpanded!()
}
fn from_vec_expand<const D: usize>(
fn __expand_from_vec<const D: usize>(
context: &mut CubeContext,
vec: [i64; D],
) -> <Self as CubeType>::ExpandType {

View File

@ -27,24 +27,11 @@ impl<T: CubePrimitive + Clone> SharedMemory<T> {
SharedMemory { _val: PhantomData }
}
pub fn new_expand<S: Index>(
context: &mut CubeContext,
size: S,
) -> <Self as CubeType>::ExpandType {
let size = size.value();
let size = match size {
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
_ => panic!("Shared memory need constant initialization value"),
};
let var = context.create_shared(Item::new(T::as_elem()), size);
ExpandElementTyped::new(var)
}
pub fn vectorized<S: Index>(_size: S, _vectorization_factor: UInt) -> Self {
SharedMemory { _val: PhantomData }
}
pub fn vectorized_expand<S: Index>(
pub fn __expand_vectorized<S: Index>(
context: &mut CubeContext,
size: S,
vectorization_factor: UInt,
@ -60,4 +47,17 @@ impl<T: CubePrimitive + Clone> SharedMemory<T> {
);
ExpandElementTyped::new(var)
}
/// Expand counterpart of `new`.
///
/// Creates a shared variable on `context` for items of `T`'s element, using
/// `size` as its length.
///
/// # Panics
///
/// Panics when `size.value()` is not a `Variable::ConstantScalar`.
pub fn __expand_new<S: Index>(
    context: &mut CubeContext,
    size: S,
) -> <Self as CubeType>::ExpandType {
    // Only a constant scalar is accepted as the size.
    let length = match size.value() {
        crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
        _ => panic!("Shared memory need constant initialization value"),
    };
    let shared = context.create_shared(Item::new(T::as_elem()), length);
    ExpandElementTyped::new(shared)
}
}

View File

@ -48,7 +48,7 @@ impl UInt {
}
}
pub fn new_expand(_context: &mut CubeContext, val: u32) -> <Self as CubeType>::ExpandType {
pub fn __expand_new(_context: &mut CubeContext, val: u32) -> <Self as CubeType>::ExpandType {
let new_var = Variable::ConstantScalar {
value: val as f64,
elem: Self::as_elem(),
@ -67,13 +67,13 @@ impl UInt {
}
}
pub fn vectorized_expand(
pub fn __expand_vectorized(
context: &mut CubeContext,
val: u32,
vectorization: UInt,
) -> <Self as CubeType>::ExpandType {
if vectorization.val == 1 {
Self::new_expand(context, val)
Self::__expand_new(context, val)
} else {
let mut new_var =
context.create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8));

View File

@ -280,11 +280,20 @@ macro_rules! impl_binary_func {
}
}
impl_binary_func!(Powf, powf, powf_expand, Operator::Powf, F16, BF16, F32, F64);
impl_binary_func!(
Powf,
powf,
__expand_powf,
Operator::Powf,
F16,
BF16,
F32,
F64
);
impl_binary_func!(
Max,
max,
max_expand,
__expand_max,
Operator::Max,
F16,
BF16,
@ -297,7 +306,7 @@ impl_binary_func!(
impl_binary_func!(
Min,
min,
min_expand,
__expand_min,
Operator::Min,
F16,
BF16,
@ -310,7 +319,7 @@ impl_binary_func!(
impl_binary_func!(
Remainder,
rem,
rem_expand,
__expand_rem,
Operator::Remainder,
F16,
BF16,

View File

@ -12,7 +12,7 @@ pub trait Clamp: CubePrimitive + Sized {
fn clamp(input: Self, min_value: Self, max_value: Self) -> Self {
unexpanded!()
}
fn clamp_expand(
fn __expand_clamp(
context: &mut CubeContext,
input: Self::ExpandType,
min_value: Self::ExpandType,

View File

@ -33,7 +33,7 @@ macro_rules! impl_unary_func {
impl_unary_func!(
Abs,
abs,
abs_expand,
__expand_abs,
Operator::Abs,
F16,
BF16,
@ -43,38 +43,65 @@ impl_unary_func!(
I64,
UInt
);
impl_unary_func!(Exp, exp, exp_expand, Operator::Exp, F16, BF16, F32, F64);
impl_unary_func!(Log, log, log_expand, Operator::Log, F16, BF16, F32, F64);
impl_unary_func!(Exp, exp, __expand_exp, Operator::Exp, F16, BF16, F32, F64);
impl_unary_func!(Log, log, __expand_log, Operator::Log, F16, BF16, F32, F64);
impl_unary_func!(
Log1p,
log1p,
log1p_expand,
__expand_log1p,
Operator::Log1p,
F16,
BF16,
F32,
F64
);
impl_unary_func!(Cos, cos, cos_expand, Operator::Cos, F16, BF16, F32, F64);
impl_unary_func!(Sin, sin, sin_expand, Operator::Sin, F16, BF16, F32, F64);
impl_unary_func!(Tanh, tanh, tanh_expand, Operator::Tanh, F16, BF16, F32, F64);
impl_unary_func!(Sqrt, sqrt, sqrt_expand, Operator::Sqrt, F16, BF16, F32, F64);
impl_unary_func!(Cos, cos, __expand_cos, Operator::Cos, F16, BF16, F32, F64);
impl_unary_func!(Sin, sin, __expand_sin, Operator::Sin, F16, BF16, F32, F64);
impl_unary_func!(
Tanh,
tanh,
__expand_tanh,
Operator::Tanh,
F16,
BF16,
F32,
F64
);
impl_unary_func!(
Sqrt,
sqrt,
__expand_sqrt,
Operator::Sqrt,
F16,
BF16,
F32,
F64
);
impl_unary_func!(
Floor,
floor,
floor_expand,
__expand_floor,
Operator::Floor,
F16,
BF16,
F32,
F64
);
impl_unary_func!(Ceil, ceil, ceil_expand, Operator::Ceil, F16, BF16, F32, F64);
impl_unary_func!(Erf, erf, erf_expand, Operator::Erf, F16, BF16, F32, F64);
impl_unary_func!(
Ceil,
ceil,
__expand_ceil,
Operator::Ceil,
F16,
BF16,
F32,
F64
);
impl_unary_func!(Erf, erf, __expand_erf, Operator::Erf, F16, BF16, F32, F64);
impl_unary_func!(
Recip,
recip,
recip_expand,
__expand_recip,
Operator::Recip,
F16,
BF16,

View File

@ -19,107 +19,143 @@ pub fn subcube_elect_expand<E: CubePrimitive>(context: &mut CubeContext) -> Expa
output
}
pub fn subcube_sum<E: CubePrimitive>(_elem: E) -> E {
/// Perform a reduce sum operation across all units in a subcube.
#[allow(unused_variables)]
pub fn subcube_sum<E: CubePrimitive>(value: E) -> E {
unexpanded!()
}
pub fn subcube_sum_expand<E: CubePrimitive>(
context: &mut CubeContext,
elem: ExpandElement,
) -> ExpandElement {
let output = context.create_local(elem.item());
/// Module containing the expand function for [subcube_sum()].
pub mod subcube_sum {
use super::*;
let out = *output;
let input = *elem;
/// Expand method of [subcube_sum()].
pub fn __expand<E: CubePrimitive>(
context: &mut CubeContext,
elem: ExpandElement,
) -> ExpandElement {
let output = context.create_local(elem.item());
context.register(Operation::Subcube(Subcube::Sum(UnaryOperator {
input,
out,
})));
let out = *output;
let input = *elem;
output
context.register(Operation::Subcube(Subcube::Sum(UnaryOperator {
input,
out,
})));
output
}
}
/// Perform a reduce prod operation across all units in a subcube.
pub fn subcube_prod<E: CubePrimitive>(_elem: E) -> E {
unexpanded!()
}
pub fn subcube_prod_expand<E: CubePrimitive>(
context: &mut CubeContext,
elem: ExpandElement,
) -> ExpandElement {
let output = context.create_local(elem.item());
/// Module containing the expand function for [subcube_prod()].
pub mod subcube_prod {
use super::*;
let out = *output;
let input = *elem;
/// Expand method of [subcube_prod()].
pub fn __expand<E: CubePrimitive>(
context: &mut CubeContext,
elem: ExpandElement,
) -> ExpandElement {
let output = context.create_local(elem.item());
context.register(Operation::Subcube(Subcube::Prod(UnaryOperator {
input,
out,
})));
let out = *output;
let input = *elem;
output
context.register(Operation::Subcube(Subcube::Prod(UnaryOperator {
input,
out,
})));
output
}
}
/// Perform a reduce max operation across all units in a subcube.
pub fn subcube_max<E: CubePrimitive>(_elem: E) -> E {
unexpanded!()
}
pub fn subcube_max_expand<E: CubePrimitive>(
context: &mut CubeContext,
elem: ExpandElement,
) -> ExpandElement {
let output = context.create_local(elem.item());
/// Module containing the expand function for [subcube_max()].
pub mod subcube_max {
use super::*;
let out = *output;
let input = *elem;
/// Expand method of [subcube_max()].
pub fn __expand<E: CubePrimitive>(
context: &mut CubeContext,
elem: ExpandElement,
) -> ExpandElement {
let output = context.create_local(elem.item());
context.register(Operation::Subcube(Subcube::Max(UnaryOperator {
input,
out,
})));
let out = *output;
let input = *elem;
output
context.register(Operation::Subcube(Subcube::Max(UnaryOperator {
input,
out,
})));
output
}
}
/// Perform a reduce min operation across all units in a subcube.
pub fn subcube_min<E: CubePrimitive>(_elem: E) -> E {
unexpanded!()
}
pub fn subcube_min_expand<E: CubePrimitive>(
context: &mut CubeContext,
elem: ExpandElement,
) -> ExpandElement {
let output = context.create_local(elem.item());
/// Module containing the expand function for [subcube_min()].
pub mod subcube_min {
use super::*;
let out = *output;
let input = *elem;
/// Expand method of [subcube_min()].
pub fn __expand<E: CubePrimitive>(
context: &mut CubeContext,
elem: ExpandElement,
) -> ExpandElement {
let output = context.create_local(elem.item());
context.register(Operation::Subcube(Subcube::Min(UnaryOperator {
input,
out,
})));
let out = *output;
let input = *elem;
output
context.register(Operation::Subcube(Subcube::Min(UnaryOperator {
input,
out,
})));
output
}
}
/// Perform a reduce all operation across all units in a subcube.
pub fn subcube_all<E: CubePrimitive>(_elem: E) -> E {
unexpanded!()
}
pub fn subcube_all_expand<E: CubePrimitive>(
context: &mut CubeContext,
elem: ExpandElement,
) -> ExpandElement {
let output = context.create_local(elem.item());
/// Module containing the expand function for [subcube_all()].
pub mod subcube_all {
use super::*;
let out = *output;
let input = *elem;
/// Expand method of [subcube_all()].
pub fn __expand<E: CubePrimitive>(
context: &mut CubeContext,
elem: ExpandElement,
) -> ExpandElement {
let output = context.create_local(elem.item());
context.register(Operation::Subcube(Subcube::All(UnaryOperator {
input,
out,
})));
let out = *output;
let input = *elem;
output
context.register(Operation::Subcube(Subcube::All(UnaryOperator {
input,
out,
})));
output
}
}

View File

@ -3,6 +3,10 @@ use crate::ir::Synchronization;
pub fn sync_units() {}
pub fn sync_units_expand(context: &mut CubeContext) {
context.register(Synchronization::SyncUnits)
/// Module containing the expand function for [sync_units()].
pub mod sync_units {
use super::*;
/// Expand method of [sync_units()]: registers `Synchronization::SyncUnits`
/// on the given context.
pub fn __expand(context: &mut CubeContext) {
context.register(Synchronization::SyncUnits)
}
}

View File

@ -11,7 +11,8 @@ pub use crate::runtime::Runtime;
/// Elements
pub use crate::frontend::{
Array, ArrayHandle, Float, LaunchArg, Slice, Tensor, TensorArg, UInt, F16, F32, F64, I32, I64,
Array, ArrayHandle, Bool, Float, LaunchArg, Slice, SliceMut, Tensor, TensorArg, UInt, F16, F32,
F64, I32, I64,
};
pub use crate::pod::CubeElement;
@ -23,10 +24,7 @@ pub use crate::frontend::{
};
/// Export subcube operations.
pub use crate::frontend::{
subcube_all, subcube_all_expand, subcube_max, subcube_max_expand, subcube_min,
subcube_min_expand, subcube_prod, subcube_prod_expand, subcube_sum, subcube_sum_expand,
};
pub use crate::frontend::{subcube_all, subcube_max, subcube_min, subcube_prod, subcube_sum};
pub use burn_compute::client::ComputeClient;
pub use crate::frontend::*;

View File

@ -64,7 +64,7 @@ pub fn test_simple_1<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
let rhs = client.create(f16::as_bytes(&rhs));
let out = client.empty(core::mem::size_of::<f32>() * 256);
kernel_simple_1_launch::<R>(
kernel_simple_1::launch::<R>(
client.clone(),
CubeCount::Static(1, 1, 1),
CubeDim::new(16, 16, 1),

View File

@ -18,7 +18,7 @@ pub fn kernel_without_generics(output: &mut Array<F32>) {
pub fn test_kernel_with_generics<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
let handle = client.create(f32::as_bytes(&[0.0, 1.0]));
kernel_with_generics_launch::<F32, R>(
kernel_with_generics::launch::<F32, R>(
client.clone(),
CubeCount::Static(1, 1, 1),
CubeDim::default(),
@ -34,7 +34,7 @@ pub fn test_kernel_with_generics<R: Runtime>(client: ComputeClient<R::Server, R:
pub fn test_kernel_without_generics<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
let handle = client.create(f32::as_bytes(&[0.0, 1.0]));
kernel_without_generics_launch::<R>(
kernel_without_generics::launch::<R>(
client.clone(),
CubeCount::Static(1, 1, 1),
CubeDim::default(),

View File

@ -30,7 +30,7 @@ pub fn test_slice_select<R: Runtime>(client: ComputeClient<R::Server, R::Channel
let input = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0]));
let output = client.empty(core::mem::size_of::<f32>());
slice_select_launch::<F32, R>(
slice_select::launch::<F32, R>(
client.clone(),
CubeCount::Static(1, 1, 1),
CubeDim::new(1, 1, 1),
@ -48,7 +48,7 @@ pub fn test_slice_len<R: Runtime>(client: ComputeClient<R::Server, R::Channel>)
let input = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0]));
let output = client.empty(core::mem::size_of::<u32>());
slice_len_launch::<F32, R>(
slice_len::launch::<F32, R>(
client.clone(),
CubeCount::Static(1, 1, 1),
CubeDim::new(1, 1, 1),
@ -66,7 +66,7 @@ pub fn test_slice_assign<R: Runtime>(client: ComputeClient<R::Server, R::Channel
let input = client.create(f32::as_bytes(&[15.0]));
let output = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0]));
slice_assign_launch::<F32, R>(
slice_assign::launch::<F32, R>(
client.clone(),
CubeCount::Static(1, 1, 1),
CubeDim::new(1, 1, 1),

View File

@ -49,7 +49,7 @@ pub fn test_subcube_sum<TestRuntime: Runtime>(
&[17.0, 5.0, 7.0, 1.0],
client.clone(),
|cube_count, cube_dim, handle| {
kernel_sum_launch::<F32, TestRuntime>(client.clone(), cube_count, cube_dim, handle)
kernel_sum::launch::<F32, TestRuntime>(client.clone(), cube_count, cube_dim, handle)
},
);
}
@ -62,7 +62,7 @@ pub fn test_subcube_prod<TestRuntime: Runtime>(
&[140.0, 5.0, 7.0, 1.0],
client.clone(),
|cube_dim, settings, handle| {
kernel_prod_launch::<F32, TestRuntime>(client.clone(), cube_dim, settings, handle)
kernel_prod::launch::<F32, TestRuntime>(client.clone(), cube_dim, settings, handle)
},
);
}
@ -74,7 +74,7 @@ pub fn test_subcube_max<TestRuntime: Runtime>(
&[7.0, 5.0, 7.0, 1.0],
client.clone(),
|cube_dim, settings, handle| {
kernel_max_launch::<F32, TestRuntime>(client.clone(), cube_dim, settings, handle)
kernel_max::launch::<F32, TestRuntime>(client.clone(), cube_dim, settings, handle)
},
);
}
@ -87,7 +87,7 @@ pub fn test_subcube_min<TestRuntime: Runtime>(
&[1.0, 5.0, 7.0, 1.0],
client.clone(),
|cube_dim, settings, handle| {
kernel_min_launch::<F32, TestRuntime>(client.clone(), cube_dim, settings, handle)
kernel_min::launch::<F32, TestRuntime>(client.clone(), cube_dim, settings, handle)
},
);
}

View File

@ -1,14 +1,14 @@
use burn_cube::prelude::*;
#[cube]
fn array_read_write<T: Numeric>(array_size: Comptime<u32>) {
pub fn array_read_write<T: Numeric>(array_size: Comptime<u32>) {
let mut array = Array::<T>::new(array_size);
array[0] = T::from_int(3);
let _ = array[0];
}
#[cube]
fn array_to_vectorized_variable<T: Numeric>() -> T {
pub fn array_to_vectorized_variable<T: Numeric>() -> T {
let mut array = Array::<T>::new(2);
array[0] = T::from_int(0);
array[1] = T::from_int(1);
@ -16,19 +16,19 @@ fn array_to_vectorized_variable<T: Numeric>() -> T {
}
#[cube]
fn array_of_one_to_vectorized_variable<T: Numeric>() -> T {
pub fn array_of_one_to_vectorized_variable<T: Numeric>() -> T {
let mut array = Array::<T>::new(1);
array[0] = T::from_int(3);
array.to_vectorized(Comptime::new(UInt::new(1)))
}
#[cube]
fn array_add_assign_simple(array: &mut Array<UInt>) {
pub fn array_add_assign_simple(array: &mut Array<UInt>) {
array[UInt::new(1)] += UInt::new(1);
}
#[cube]
fn array_add_assign_expr(array: &mut Array<UInt>) {
pub fn array_add_assign_expr(array: &mut Array<UInt>) {
array[UInt::new(1) + UInt::new(5)] += UInt::new(1);
}
@ -45,7 +45,7 @@ mod tests {
fn cube_support_array() {
let mut context = CubeContext::root();
array_read_write_expand::<ElemType>(&mut context, 512);
array_read_write::__expand::<ElemType>(&mut context, 512);
assert_eq!(
context.into_scope().operations,
inline_macro_ref_read_write()
@ -57,7 +57,7 @@ mod tests {
let mut context = CubeContext::root();
let array = context.input(0, Item::new(Elem::UInt));
array_add_assign_simple_expand(&mut context, array.into());
array_add_assign_simple::__expand(&mut context, array.into());
let scope = context.into_scope();
assert_eq!(scope.operations, inline_macro_array_add_assign_simple());
@ -67,7 +67,7 @@ mod tests {
fn cube_array_to_vectorized() {
let mut context = CubeContext::root();
array_to_vectorized_variable_expand::<ElemType>(&mut context);
array_to_vectorized_variable::__expand::<ElemType>(&mut context);
assert_eq!(
context.into_scope().operations,
inline_macro_ref_to_vectorized()
@ -78,7 +78,7 @@ mod tests {
fn cube_array_of_one_to_vectorized() {
let mut context = CubeContext::root();
array_of_one_to_vectorized_variable_expand::<ElemType>(&mut context);
array_of_one_to_vectorized_variable::__expand::<ElemType>(&mut context);
assert_eq!(
context.into_scope().operations,
inline_macro_ref_one_to_vectorized()
@ -110,7 +110,7 @@ mod tests {
let mut context = CubeContext::root();
let array = context.input(0, Item::new(Elem::UInt));
array_add_assign_expr_expand(&mut context, array.into());
array_add_assign_expr::__expand(&mut context, array.into());
let scope = context.into_scope();
assert_eq!(scope.operations, inline_macro_array_add_assign_expr());

View File

@ -1,27 +1,27 @@
use burn_cube::prelude::*;
#[cube]
fn mut_assign() {
pub fn mut_assign() {
let mut x = UInt::new(0);
x += UInt::new(1);
}
#[cube]
fn mut_assign_input(y: UInt) -> UInt {
pub fn mut_assign_input(y: UInt) -> UInt {
let mut x = y;
x += UInt::new(1);
y + UInt::new(2)
}
#[cube]
fn assign_mut_input(mut y: UInt) -> UInt {
pub fn assign_mut_input(mut y: UInt) -> UInt {
let x = y;
y += UInt::new(1);
x + UInt::new(2)
}
#[cube]
fn assign_vectorized(y: UInt) -> UInt {
pub fn assign_vectorized(y: UInt) -> UInt {
let vectorization_factor = Comptime::vectorization(&y);
let x = UInt::vectorized(1, Comptime::get(vectorization_factor));
x + y
@ -38,7 +38,7 @@ mod tests {
fn cube_mut_assign_test() {
let mut context = CubeContext::root();
mut_assign_expand(&mut context);
mut_assign::__expand(&mut context);
let scope = context.into_scope();
assert_eq!(
@ -53,7 +53,7 @@ mod tests {
let y = context.create_local(Item::new(UInt::as_elem()));
mut_assign_input_expand(&mut context, y);
mut_assign_input::__expand(&mut context, y);
let scope = context.into_scope();
assert_eq!(
@ -68,7 +68,7 @@ mod tests {
let y = context.create_local(Item::new(UInt::as_elem()));
assign_mut_input_expand(&mut context, y);
assign_mut_input::__expand(&mut context, y);
let scope = context.into_scope();
assert_eq!(
@ -83,7 +83,7 @@ mod tests {
let y = context.create_local(Item::vectorized(UInt::as_elem(), 4));
assign_vectorized_expand(&mut context, y);
assign_vectorized::__expand(&mut context, y);
let scope = context.into_scope();
assert_eq!(

View File

@ -1,6 +1,6 @@
use burn_cube::{
cube,
frontend::{Cast, Numeric, UInt, F32, I32},
frontend::{Bool, Cast, Numeric, UInt, F32, I32},
};
// From float
@ -26,7 +26,7 @@ pub fn float_to_uint(x: F32) {
#[allow(clippy::overly_complex_bool_expr)]
pub fn float_to_bool(x: F32) {
let y = x + F32::from_int(2);
let _ = bool::cast_from(y) || true;
let _ = Bool::cast_from(y) || true;
}
// From int
@ -53,7 +53,7 @@ pub fn int_to_uint(x: I32) {
#[allow(clippy::overly_complex_bool_expr)]
pub fn int_to_bool(x: I32) {
let y = x + I32::from_int(2);
let _ = bool::cast_from(y) || true;
let _ = Bool::cast_from(y) || true;
}
// // From uint
@ -80,27 +80,27 @@ pub fn uint_to_uint(x: UInt) {
#[allow(clippy::overly_complex_bool_expr)]
pub fn uint_to_bool(x: UInt) {
let y = x + UInt::from_int(2);
let _ = bool::cast_from(y) || true;
let _ = Bool::cast_from(y) || true;
}
// From bool
#[cube]
#[allow(clippy::overly_complex_bool_expr)]
pub fn bool_to_float(x: bool) {
pub fn bool_to_float(x: Bool) {
let y = x && false;
let _ = F32::cast_from(y) + F32::from_int(34);
}
#[cube]
#[allow(clippy::overly_complex_bool_expr)]
pub fn bool_to_int(x: bool) {
pub fn bool_to_int(x: Bool) {
let y = x && false;
let _ = I32::cast_from(y) + I32::from_int(34);
}
#[cube]
#[allow(clippy::overly_complex_bool_expr)]
pub fn bool_to_uint(x: bool) {
pub fn bool_to_uint(x: Bool) {
let y = x && false;
let _ = UInt::cast_from(y) + UInt::from_int(34);
}
@ -108,9 +108,9 @@ pub fn bool_to_uint(x: bool) {
#[cube]
#[allow(clippy::overly_complex_bool_expr)]
#[allow(clippy::useless_conversion)]
pub fn bool_to_bool(x: bool) {
pub fn bool_to_bool(x: Bool) {
let y = x && false;
let _ = bool::cast_from(y) || true;
let _ = Bool::cast_from(y) || true;
}
mod tests {
@ -122,7 +122,7 @@ mod tests {
};
macro_rules! cast_test {
($name:ident, $module:ident, $from:expr, $to:expr) => {
($name:ident, $module:expr, $from:expr, $to:expr) => {
#[test]
fn $name() {
let mut context = CubeContext::root();
@ -142,112 +142,112 @@ mod tests {
cast_test!(
cube_float_to_float_test,
float_to_float_expand,
float_to_float::__expand,
Item::new(F32::as_elem()),
Item::new(F32::as_elem())
);
cast_test!(
cube_float_to_int_test,
float_to_int_expand,
float_to_int::__expand,
Item::new(F32::as_elem()),
Item::new(I32::as_elem())
);
cast_test!(
cube_float_to_uint_test,
float_to_uint_expand,
float_to_uint::__expand,
Item::new(F32::as_elem()),
Item::new(Elem::UInt)
);
cast_test!(
cube_float_to_bool_test,
float_to_bool_expand,
float_to_bool::__expand,
Item::new(F32::as_elem()),
Item::new(Elem::Bool)
);
cast_test!(
cube_int_to_float_test,
int_to_float_expand,
int_to_float::__expand,
Item::new(I32::as_elem()),
Item::new(F32::as_elem())
);
cast_test!(
cube_int_to_int_test,
int_to_int_expand,
int_to_int::__expand,
Item::new(I32::as_elem()),
Item::new(I32::as_elem())
);
cast_test!(
cube_int_to_uint_test,
int_to_uint_expand,
int_to_uint::__expand,
Item::new(I32::as_elem()),
Item::new(Elem::UInt)
);
cast_test!(
cube_int_to_bool_test,
int_to_bool_expand,
int_to_bool::__expand,
Item::new(I32::as_elem()),
Item::new(Elem::Bool)
);
cast_test!(
cube_uint_to_float_test,
uint_to_float_expand,
uint_to_float::__expand,
Item::new(Elem::UInt),
Item::new(F32::as_elem())
);
cast_test!(
cube_uint_to_int_test,
uint_to_int_expand,
uint_to_int::__expand,
Item::new(Elem::UInt),
Item::new(I32::as_elem())
);
cast_test!(
cube_uint_to_uint_test,
uint_to_uint_expand,
uint_to_uint::__expand,
Item::new(Elem::UInt),
Item::new(Elem::UInt)
);
cast_test!(
cube_uint_to_bool_test,
uint_to_bool_expand,
uint_to_bool::__expand,
Item::new(Elem::UInt),
Item::new(Elem::Bool)
);
cast_test!(
cube_bool_to_float_test,
bool_to_float_expand,
bool_to_float::__expand,
Item::new(Elem::Bool),
Item::new(F32::as_elem())
);
cast_test!(
cube_bool_to_int_test,
bool_to_int_expand,
bool_to_int::__expand,
Item::new(Elem::Bool),
Item::new(I32::as_elem())
);
cast_test!(
cube_bool_to_uint_test,
bool_to_uint_expand,
bool_to_uint::__expand,
Item::new(Elem::Bool),
Item::new(Elem::UInt)
);
cast_test!(
cube_bool_to_bool_test,
bool_to_bool_expand,
bool_to_bool::__expand,
Item::new(Elem::Bool),
Item::new(Elem::Bool)
);

View File

@ -46,7 +46,7 @@ mod tests {
let input = context.create_local(item);
cast_float_kind_expand::<F64, F32>(&mut context, input);
cast_float_kind::__expand::<F64, F32>(&mut context, input);
let scope = context.into_scope();
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_float());
@ -59,7 +59,7 @@ mod tests {
let input = context.create_local(item);
cast_int_kind_expand::<I32, I64>(&mut context, input);
cast_int_kind::__expand::<I32, I64>(&mut context, input);
let scope = context.into_scope();
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int());
@ -72,7 +72,7 @@ mod tests {
let input = context.create_local(item);
cast_numeric_to_kind_expand::<I32, I64>(&mut context, input);
cast_numeric_to_kind::__expand::<I32, I64>(&mut context, input);
let scope = context.into_scope();
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int());
@ -85,7 +85,7 @@ mod tests {
let input = context.create_local(item);
cast_int_to_numeric_expand::<I32, I64>(&mut context, input);
cast_int_to_numeric::__expand::<I32, I64>(&mut context, input);
let scope = context.into_scope();
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int());

View File

@ -122,7 +122,7 @@ mod tests {
let lhs = context.create_local(Item::new(ElemType::as_elem()));
comptime_if_else_expand::<ElemType>(&mut context, lhs, true);
comptime_if_else::__expand::<ElemType>(&mut context, lhs, true);
let scope = context.into_scope();
assert_eq!(
@ -137,7 +137,7 @@ mod tests {
let lhs = context.create_local(Item::new(ElemType::as_elem()));
comptime_if_expr_expand::<ElemType>(&mut context, lhs, UInt::new(4), UInt::new(5));
comptime_if_expr::__expand::<ElemType>(&mut context, lhs, UInt::new(4), UInt::new(5));
let scope = context.into_scope();
assert_eq!(
@ -152,7 +152,7 @@ mod tests {
let lhs = context.create_local(Item::new(ElemType::as_elem()));
comptime_if_else_expand::<ElemType>(&mut context, lhs, false);
comptime_if_else::__expand::<ElemType>(&mut context, lhs, false);
let scope = context.into_scope();
assert_eq!(
@ -167,12 +167,12 @@ mod tests {
for cond2 in [false, true] {
let mut context1 = CubeContext::root();
let lhs = context1.create_local(Item::new(ElemType::as_elem()));
comptime_else_then_if_expand::<ElemType>(&mut context1, lhs, cond1, cond2);
comptime_else_then_if::__expand::<ElemType>(&mut context1, lhs, cond1, cond2);
let scope1 = context1.into_scope();
let mut context2 = CubeContext::root();
let lhs = context2.create_local(Item::new(ElemType::as_elem()));
comptime_elsif_expand::<ElemType>(&mut context2, lhs, cond1, cond2);
comptime_elsif::__expand::<ElemType>(&mut context2, lhs, cond1, cond2);
let scope2 = context2.into_scope();
assert_eq!(
@ -188,7 +188,7 @@ mod tests {
for cond in [false, true] {
let mut context = CubeContext::root();
let lhs = context.create_local(Item::new(ElemType::as_elem()));
comptime_elsif_with_runtime1_expand::<ElemType>(&mut context, lhs, cond);
comptime_elsif_with_runtime1::__expand::<ElemType>(&mut context, lhs, cond);
let scope = context.into_scope();
assert_eq!(
@ -203,7 +203,7 @@ mod tests {
for cond in [false, true] {
let mut context = CubeContext::root();
let lhs = context.create_local(Item::new(ElemType::as_elem()));
comptime_elsif_with_runtime2_expand::<ElemType>(&mut context, lhs, cond);
comptime_elsif_with_runtime2::__expand::<ElemType>(&mut context, lhs, cond);
let scope = context.into_scope();
assert_eq!(
@ -227,8 +227,8 @@ mod tests {
bound: 4,
};
comptime_with_map_bool_expand::<ElemType>(&mut context1, comptime_state_true);
comptime_with_map_bool_expand::<ElemType>(&mut context2, comptime_state_false);
comptime_with_map_bool::__expand::<ElemType>(&mut context1, comptime_state_true);
comptime_with_map_bool::__expand::<ElemType>(&mut context2, comptime_state_false);
let scope1 = context1.into_scope();
let scope2 = context2.into_scope();
@ -248,7 +248,7 @@ mod tests {
bound: 4,
};
comptime_with_map_uint_expand::<ElemType>(&mut context, comptime_state);
comptime_with_map_uint::__expand::<ElemType>(&mut context, comptime_state);
let scope = context.into_scope();

View File

@ -42,12 +42,12 @@ impl<C: Float> CombinedTraitFunctionGeneric<C> for Test {
}
#[cube]
fn simple<C: Float>(lhs: C, rhs: C) -> C {
pub fn simple<C: Float>(lhs: C, rhs: C) -> C {
lhs + rhs
}
#[cube]
fn with_cast<C: Float, O: Numeric>(lhs: C, rhs: C) -> O {
pub fn with_cast<C: Float, O: Numeric>(lhs: C, rhs: C) -> O {
O::cast_from(lhs + rhs)
}
@ -62,7 +62,7 @@ mod tests {
let lhs = context.create_local(Item::new(F32::as_elem()));
let rhs = context.create_local(Item::new(F32::as_elem()));
<Test as FunctionGeneric>::test_expand::<F32>(&mut context, lhs, rhs);
<Test as FunctionGeneric>::__expand_test::<F32>(&mut context, lhs, rhs);
assert_eq!(simple_scope(), context.into_scope());
}
@ -73,7 +73,7 @@ mod tests {
let lhs = context.create_local(Item::new(F32::as_elem()));
let rhs = context.create_local(Item::new(F32::as_elem()));
<Test as TraitGeneric<F32>>::test_expand(&mut context, lhs, rhs);
<Test as TraitGeneric<F32>>::__expand_test(&mut context, lhs, rhs);
assert_eq!(simple_scope(), context.into_scope());
}
@ -84,7 +84,7 @@ mod tests {
let lhs = context.create_local(Item::new(F32::as_elem()));
let rhs = context.create_local(Item::new(F32::as_elem()));
<Test as CombinedTraitFunctionGeneric<F32>>::test_expand::<UInt>(&mut context, lhs, rhs);
<Test as CombinedTraitFunctionGeneric<F32>>::__expand_test::<UInt>(&mut context, lhs, rhs);
assert_eq!(with_cast_scope(), context.into_scope());
}
@ -94,7 +94,7 @@ mod tests {
let lhs = context_ref.create_local(Item::new(F32::as_elem()));
let rhs = context_ref.create_local(Item::new(F32::as_elem()));
simple_expand::<F32>(&mut context_ref, lhs, rhs);
simple::__expand::<F32>(&mut context_ref, lhs, rhs);
context_ref.into_scope()
}
@ -103,7 +103,7 @@ mod tests {
let lhs = context_ref.create_local(Item::new(F32::as_elem()));
let rhs = context_ref.create_local(Item::new(F32::as_elem()));
with_cast_expand::<F32, UInt>(&mut context_ref, lhs, rhs);
with_cast::__expand::<F32, UInt>(&mut context_ref, lhs, rhs);
context_ref.into_scope()
}
}

View File

@ -30,7 +30,7 @@ mod tests {
let rhs = context.create_local(Item::new(ElemType::as_elem()));
let end = 4u32.into();
for_loop_expand::<ElemType>(&mut context, lhs.into(), rhs, end, unroll);
for_loop::__expand::<ElemType>(&mut context, lhs.into(), rhs, end, unroll);
let scope = context.into_scope();
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(unroll));
@ -45,7 +45,7 @@ mod tests {
let rhs = context.create_local(Item::new(ElemType::as_elem()));
let end = 4u32.into();
for_loop_expand::<ElemType>(&mut context, lhs.into(), rhs, end, unroll);
for_loop::__expand::<ElemType>(&mut context, lhs.into(), rhs, end, unroll);
let scope = context.into_scope();
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(unroll));

View File

@ -59,12 +59,12 @@ mod tests {
fn cube_call_equivalent_to_no_call_no_arg_test() {
let mut caller_context = CubeContext::root();
let x = caller_context.create_local(Item::new(Elem::UInt));
caller_no_arg_expand(&mut caller_context, x);
caller_no_arg::__expand(&mut caller_context, x);
let caller_scope = caller_context.into_scope();
let mut no_call_context = CubeContext::root();
let x = no_call_context.create_local(Item::new(Elem::UInt));
no_call_no_arg_expand(&mut no_call_context, x);
no_call_no_arg::__expand(&mut no_call_context, x);
let no_call_scope = no_call_context.into_scope();
assert_eq!(
@ -78,12 +78,12 @@ mod tests {
let mut caller_context = CubeContext::root();
let x = caller_context.create_local(Item::new(Elem::UInt));
caller_with_arg_expand(&mut caller_context, x);
caller_with_arg::__expand(&mut caller_context, x);
let caller_scope = caller_context.into_scope();
let mut no_call_context = CubeContext::root();
let x = no_call_context.create_local(Item::new(Elem::UInt));
no_call_with_arg_expand(&mut no_call_context, x);
no_call_with_arg::__expand(&mut no_call_context, x);
let no_call_scope = no_call_context.into_scope();
assert_eq!(
@ -97,12 +97,12 @@ mod tests {
let mut caller_context = CubeContext::root();
type ElemType = I64;
let x = caller_context.create_local(Item::new(ElemType::as_elem()));
caller_with_generics_expand::<ElemType>(&mut caller_context, x);
caller_with_generics::__expand::<ElemType>(&mut caller_context, x);
let caller_scope = caller_context.into_scope();
let mut no_call_context = CubeContext::root();
let x = no_call_context.create_local(Item::new(ElemType::as_elem()));
no_call_with_generics_expand::<ElemType>(&mut no_call_context, x);
no_call_with_generics::__expand::<ElemType>(&mut no_call_context, x);
let no_call_scope = no_call_context.into_scope();
assert_eq!(

View File

@ -20,7 +20,7 @@ mod tests {
let lhs = context.create_local(Item::new(F32::as_elem()));
generic_kernel_expand::<F32>(&mut context, lhs);
generic_kernel::__expand::<F32>(&mut context, lhs);
let scope = context.into_scope();
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_float());
@ -32,7 +32,7 @@ mod tests {
let lhs = context.create_local(Item::new(I32::as_elem()));
generic_kernel_expand::<I32>(&mut context, lhs);
generic_kernel::__expand::<I32>(&mut context, lhs);
let scope = context.into_scope();
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int());

View File

@ -52,7 +52,7 @@ mod tests {
let lhs = context.create_local(Item::new(ElemType::as_elem()));
if_greater_expand::<ElemType>(&mut context, lhs);
if_greater::__expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_if());
@ -64,7 +64,7 @@ mod tests {
let lhs = context.create_local(Item::new(ElemType::as_elem()));
if_then_else_expand::<ElemType>(&mut context, lhs);
if_then_else::__expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();
assert_eq!(
@ -79,7 +79,7 @@ mod tests {
let lhs = context.create_local(Item::new(ElemType::as_elem()));
elsif_expand::<ElemType>(&mut context, lhs);
elsif::__expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_elsif());

View File

@ -25,7 +25,7 @@ mod tests {
let lhs = context.create_local(Item::new(ElemType::as_elem()));
literal_expand::<ElemType>(&mut context, lhs);
literal::__expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref());
@ -37,7 +37,7 @@ mod tests {
let lhs = context.create_local(Item::new(ElemType::as_elem()));
literal_float_no_decimals_expand::<ElemType>(&mut context, lhs);
literal_float_no_decimals::__expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref());

View File

@ -42,7 +42,7 @@ mod tests {
let lhs = context.create_local(Item::new(ElemType::as_elem()));
while_not_expand::<ElemType>(&mut context, lhs);
while_not::__expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(false));
@ -54,7 +54,7 @@ mod tests {
let lhs = context.create_local(Item::new(ElemType::as_elem()));
manual_loop_break_expand::<ElemType>(&mut context, lhs);
manual_loop_break::__expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(false));
@ -66,7 +66,7 @@ mod tests {
let lhs = context.create_local(Item::new(ElemType::as_elem()));
loop_with_return_expand::<ElemType>(&mut context, lhs);
loop_with_return::__expand::<ElemType>(&mut context, lhs);
let scope = context.into_scope();
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(true));

View File

@ -33,12 +33,12 @@ mod tests {
fn cube_call_equivalent_to_no_call_no_arg_test() {
let mut caller_context = CubeContext::root();
let x = caller_context.create_local(Item::new(ElemType::as_elem()));
here::caller_expand::<ElemType>(&mut caller_context, x);
here::caller::__expand::<ElemType>(&mut caller_context, x);
let caller_scope = caller_context.into_scope();
let mut no_call_context = CubeContext::root();
let x = no_call_context.create_local(Item::new(ElemType::as_elem()));
here::no_call_ref_expand::<ElemType>(&mut no_call_context, x);
here::no_call_ref::__expand::<ElemType>(&mut no_call_context, x);
let no_call_scope = no_call_context.into_scope();
assert_eq!(

View File

@ -1,192 +1,192 @@
use burn_cube::prelude::*;
#[cube]
fn add_op<T: Numeric>(a: T, b: T) -> T {
pub fn add_op<T: Numeric>(a: T, b: T) -> T {
a + b
}
#[cube]
fn sub_op<T: Numeric>(a: T, b: T) -> T {
pub fn sub_op<T: Numeric>(a: T, b: T) -> T {
a - b
}
#[cube]
fn mul_op<T: Numeric>(a: T, b: T) -> T {
pub fn mul_op<T: Numeric>(a: T, b: T) -> T {
a * b
}
#[cube]
fn div_op<T: Numeric>(a: T, b: T) -> T {
pub fn div_op<T: Numeric>(a: T, b: T) -> T {
a / b
}
#[cube]
fn abs_op<T: Numeric>(a: T) -> T {
pub fn abs_op<T: Numeric>(a: T) -> T {
T::abs(a)
}
#[cube]
fn exp_op<F: Float>(a: F) -> F {
pub fn exp_op<F: Float>(a: F) -> F {
F::exp(a)
}
#[cube]
fn log_op<F: Float>(a: F) -> F {
pub fn log_op<F: Float>(a: F) -> F {
F::log(a)
}
#[cube]
fn log1p_op<F: Float>(a: F) -> F {
pub fn log1p_op<F: Float>(a: F) -> F {
F::log1p(a)
}
#[cube]
fn cos_op<F: Float>(a: F) -> F {
pub fn cos_op<F: Float>(a: F) -> F {
F::cos(a)
}
#[cube]
fn sin_op<F: Float>(a: F) -> F {
pub fn sin_op<F: Float>(a: F) -> F {
F::sin(a)
}
#[cube]
fn tanh_op<F: Float>(a: F) -> F {
pub fn tanh_op<F: Float>(a: F) -> F {
F::tanh(a)
}
#[cube]
fn powf_op<F: Float>(a: F, b: F) -> F {
pub fn powf_op<F: Float>(a: F, b: F) -> F {
F::powf(a, b)
}
#[cube]
fn sqrt_op<F: Float>(a: F) -> F {
pub fn sqrt_op<F: Float>(a: F) -> F {
F::sqrt(a)
}
#[cube]
fn floor_op<F: Float>(a: F) -> F {
pub fn floor_op<F: Float>(a: F) -> F {
F::floor(a)
}
#[cube]
fn ceil_op<F: Float>(a: F) -> F {
pub fn ceil_op<F: Float>(a: F) -> F {
F::ceil(a)
}
#[cube]
fn erf_op<F: Float>(a: F) -> F {
pub fn erf_op<F: Float>(a: F) -> F {
F::erf(a)
}
#[cube]
fn recip_op<F: Float>(a: F) -> F {
pub fn recip_op<F: Float>(a: F) -> F {
F::recip(a)
}
#[cube]
fn equal_op<T: CubePrimitive>(a: T, b: T) -> bool {
pub fn equal_op<T: CubePrimitive>(a: T, b: T) -> bool {
a == b
}
#[cube]
fn not_equal_op<T: CubePrimitive>(a: T, b: T) -> bool {
pub fn not_equal_op<T: CubePrimitive>(a: T, b: T) -> bool {
a != b
}
#[cube]
fn lower_op<T: Numeric>(a: T, b: T) -> bool {
pub fn lower_op<T: Numeric>(a: T, b: T) -> bool {
a < b
}
#[cube]
fn greater_op<T: Numeric>(a: T, b: T) -> bool {
pub fn greater_op<T: Numeric>(a: T, b: T) -> bool {
a > b
}
#[cube]
fn lower_equal_op<T: Numeric>(a: T, b: T) -> bool {
pub fn lower_equal_op<T: Numeric>(a: T, b: T) -> bool {
a <= b
}
#[cube]
fn greater_equal_op<T: Numeric>(a: T, b: T) -> bool {
pub fn greater_equal_op<T: Numeric>(a: T, b: T) -> bool {
a >= b
}
#[cube]
fn modulo_op(a: UInt, b: UInt) -> UInt {
pub fn modulo_op(a: UInt, b: UInt) -> UInt {
a % b
}
#[cube]
fn remainder_op<T: Numeric>(a: T, b: T) -> T {
pub fn remainder_op<T: Numeric>(a: T, b: T) -> T {
T::rem(a, b)
}
#[cube]
fn max_op<T: Numeric>(a: T, b: T) -> T {
pub fn max_op<T: Numeric>(a: T, b: T) -> T {
T::max(a, b)
}
#[cube]
fn min_op<T: Numeric>(a: T, b: T) -> T {
pub fn min_op<T: Numeric>(a: T, b: T) -> T {
T::min(a, b)
}
#[cube]
fn and_op(a: bool, b: bool) -> bool {
pub fn and_op(a: bool, b: bool) -> bool {
a && b
}
#[cube]
fn or_op(a: bool, b: bool) -> bool {
pub fn or_op(a: bool, b: bool) -> bool {
a || b
}
#[cube]
fn not_op(a: bool) -> bool {
pub fn not_op(a: bool) -> bool {
!a
}
#[cube]
fn bitand_op(a: UInt, b: UInt) -> UInt {
pub fn bitand_op(a: UInt, b: UInt) -> UInt {
a & b
}
#[cube]
fn bitxor_op(a: UInt, b: UInt) -> UInt {
pub fn bitxor_op(a: UInt, b: UInt) -> UInt {
a ^ b
}
#[cube]
fn shl_op(a: UInt, b: UInt) -> UInt {
pub fn shl_op(a: UInt, b: UInt) -> UInt {
a << b
}
#[cube]
fn shr_op(a: UInt, b: UInt) -> UInt {
pub fn shr_op(a: UInt, b: UInt) -> UInt {
a >> b
}
#[cube]
fn add_assign_op<T: Numeric>(mut a: T, b: T) {
pub fn add_assign_op<T: Numeric>(mut a: T, b: T) {
a += b;
}
#[cube]
fn sub_assign_op<T: Numeric>(mut a: T, b: T) {
pub fn sub_assign_op<T: Numeric>(mut a: T, b: T) {
a -= b;
}
#[cube]
fn mul_assign_op<T: Numeric>(mut a: T, b: T) {
pub fn mul_assign_op<T: Numeric>(mut a: T, b: T) {
a *= b;
}
#[cube]
fn div_assign_op<T: Numeric>(mut a: T, b: T) {
pub fn div_assign_op<T: Numeric>(mut a: T, b: T) {
a /= b;
}
@ -195,14 +195,14 @@ mod tests {
use burn_cube::ir::{Elem, FloatKind, Item};
macro_rules! binary_test {
($test_name:ident, $op_expand:ident, $op_name:expr, $func:ident) => {
($test_name:ident, $op_expand:expr, $op_name:expr, $func:ident) => {
#[test]
fn $test_name() {
let mut context = CubeContext::root();
let x = context.create_local(Item::new(Elem::Float(FloatKind::F32)));
let y = context.create_local(Item::new(Elem::Float(FloatKind::F32)));
$op_expand::<F32>(&mut context, x, y);
$op_expand(&mut context, x, y);
assert_eq!(
format!("{:?}", context.into_scope().operations),
@ -213,13 +213,13 @@ mod tests {
}
macro_rules! unary_test {
($test_name:ident, $op_expand:ident, $op_name:expr) => {
($test_name:ident, $op_expand:expr, $op_name:expr) => {
#[test]
fn $test_name() {
let mut context = CubeContext::root();
let x = context.create_local(Item::new(Elem::Float(FloatKind::F32)));
$op_expand::<F32>(&mut context, x);
$op_expand(&mut context, x);
assert_eq!(
format!("{:?}", context.into_scope().operations),
@ -230,7 +230,7 @@ mod tests {
}
macro_rules! binary_boolean_test {
($test_name:ident, $op_expand:ident, $op_name:expr) => {
($test_name:ident, $op_expand:expr, $op_name:expr) => {
#[test]
fn $test_name() {
let mut context = CubeContext::root();
@ -248,7 +248,7 @@ mod tests {
}
macro_rules! binary_uint_test {
($test_name:ident, $op_expand:ident, $op_name:expr) => {
($test_name:ident, $op_expand:expr, $op_name:expr) => {
#[test]
fn $test_name() {
let mut context = CubeContext::root();
@ -265,75 +265,90 @@ mod tests {
};
}
binary_test!(cube_can_add, add_op_expand, "Add", ref_ops_binary);
binary_test!(cube_can_sub, sub_op_expand, "Sub", ref_ops_binary);
binary_test!(cube_can_mul, mul_op_expand, "Mul", ref_ops_binary);
binary_test!(cube_can_div, div_op_expand, "Div", ref_ops_binary);
unary_test!(cube_can_abs, abs_op_expand, "Abs");
unary_test!(cube_can_exp, exp_op_expand, "Exp");
unary_test!(cube_can_log, log_op_expand, "Log");
unary_test!(cube_can_log1p, log1p_op_expand, "Log1p");
unary_test!(cube_can_cos, cos_op_expand, "Cos");
unary_test!(cube_can_sin, sin_op_expand, "Sin");
unary_test!(cube_can_tanh, tanh_op_expand, "Tanh");
binary_test!(cube_can_powf, powf_op_expand, "Powf", ref_ops_binary);
unary_test!(cube_can_sqrt, sqrt_op_expand, "Sqrt");
unary_test!(cube_can_erf, erf_op_expand, "Erf");
unary_test!(cube_can_recip, recip_op_expand, "Recip");
unary_test!(cube_can_floor, floor_op_expand, "Floor");
unary_test!(cube_can_ceil, ceil_op_expand, "Ceil");
binary_test!(cube_can_eq, equal_op_expand, "Equal", ref_ops_cmp);
binary_test!(cube_can_ne, not_equal_op_expand, "NotEqual", ref_ops_cmp);
binary_test!(cube_can_lt, lower_op_expand, "Lower", ref_ops_cmp);
binary_test!(cube_can_add, add_op::__expand::<F32>, "Add", ref_ops_binary);
binary_test!(cube_can_sub, sub_op::__expand::<F32>, "Sub", ref_ops_binary);
binary_test!(cube_can_mul, mul_op::__expand::<F32>, "Mul", ref_ops_binary);
binary_test!(cube_can_div, div_op::__expand::<F32>, "Div", ref_ops_binary);
unary_test!(cube_can_abs, abs_op::__expand::<F32>, "Abs");
unary_test!(cube_can_exp, exp_op::__expand::<F32>, "Exp");
unary_test!(cube_can_log, log_op::__expand::<F32>, "Log");
unary_test!(cube_can_log1p, log1p_op::__expand::<F32>, "Log1p");
unary_test!(cube_can_cos, cos_op::__expand::<F32>, "Cos");
unary_test!(cube_can_sin, sin_op::__expand::<F32>, "Sin");
unary_test!(cube_can_tanh, tanh_op::__expand::<F32>, "Tanh");
binary_test!(
cube_can_powf,
powf_op::__expand::<F32>,
"Powf",
ref_ops_binary
);
unary_test!(cube_can_sqrt, sqrt_op::__expand::<F32>, "Sqrt");
unary_test!(cube_can_erf, erf_op::__expand::<F32>, "Erf");
unary_test!(cube_can_recip, recip_op::__expand::<F32>, "Recip");
unary_test!(cube_can_floor, floor_op::__expand::<F32>, "Floor");
unary_test!(cube_can_ceil, ceil_op::__expand::<F32>, "Ceil");
binary_test!(cube_can_eq, equal_op::__expand::<F32>, "Equal", ref_ops_cmp);
binary_test!(
cube_can_ne,
not_equal_op::__expand::<F32>,
"NotEqual",
ref_ops_cmp
);
binary_test!(cube_can_lt, lower_op::__expand::<F32>, "Lower", ref_ops_cmp);
binary_test!(
cube_can_le,
lower_equal_op_expand,
lower_equal_op::__expand::<F32>,
"LowerEqual",
ref_ops_cmp
);
binary_test!(
cube_can_ge,
greater_equal_op_expand,
greater_equal_op::__expand::<F32>,
"GreaterEqual",
ref_ops_cmp
);
binary_test!(cube_can_gt, greater_op_expand, "Greater", ref_ops_cmp);
binary_test!(cube_can_max, max_op_expand, "Max", ref_ops_binary);
binary_test!(cube_can_min, min_op_expand, "Min", ref_ops_binary);
binary_test!(
cube_can_gt,
greater_op::__expand::<F32>,
"Greater",
ref_ops_cmp
);
binary_test!(cube_can_max, max_op::__expand::<F32>, "Max", ref_ops_binary);
binary_test!(cube_can_min, min_op::__expand::<F32>, "Min", ref_ops_binary);
binary_test!(
cube_can_add_assign,
add_assign_op_expand,
add_assign_op::__expand::<F32>,
"Add",
ref_ops_binary
);
binary_test!(
cube_can_sub_assign,
sub_assign_op_expand,
sub_assign_op::__expand::<F32>,
"Sub",
ref_ops_binary
);
binary_test!(
cube_can_mul_assign,
mul_assign_op_expand,
mul_assign_op::__expand::<F32>,
"Mul",
ref_ops_binary
);
binary_test!(
cube_can_div_assign,
div_assign_op_expand,
div_assign_op::__expand::<F32>,
"Div",
ref_ops_binary
);
binary_boolean_test!(cube_can_and, and_op_expand, "And");
binary_boolean_test!(cube_can_or, or_op_expand, "Or");
binary_uint_test!(cube_can_bitand, bitand_op_expand, "BitwiseAnd");
binary_uint_test!(cube_can_bitxor, bitxor_op_expand, "BitwiseXor");
binary_uint_test!(cube_can_shl, shl_op_expand, "ShiftLeft");
binary_uint_test!(cube_can_shr, shr_op_expand, "ShiftRight");
binary_uint_test!(cube_can_mod, modulo_op_expand, "Modulo");
binary_boolean_test!(cube_can_and, and_op::__expand, "And");
binary_boolean_test!(cube_can_or, or_op::__expand, "Or");
binary_uint_test!(cube_can_bitand, bitand_op::__expand, "BitwiseAnd");
binary_uint_test!(cube_can_bitxor, bitxor_op::__expand, "BitwiseXor");
binary_uint_test!(cube_can_shl, shl_op::__expand, "ShiftLeft");
binary_uint_test!(cube_can_shr, shr_op::__expand, "ShiftRight");
binary_uint_test!(cube_can_mod, modulo_op::__expand, "Modulo");
binary_test!(
cube_can_rem,
remainder_op_expand,
remainder_op::__expand::<F32>,
"Remainder",
ref_ops_binary
);
@ -343,7 +358,7 @@ mod tests {
let mut context = CubeContext::root();
let x = context.create_local(Item::new(Elem::Bool));
not_op_expand(&mut context, x);
not_op::__expand(&mut context, x);
assert_eq!(
format!("{:?}", context.into_scope().operations),

View File

@ -22,7 +22,7 @@ mod tests {
let y = context.create_local(Item::new(ElemType::as_elem()));
let z = context.create_local(Item::new(ElemType::as_elem()));
parenthesis_expand::<ElemType>(&mut context, x, y, z);
parenthesis::__expand::<ElemType>(&mut context, x, y, z);
let scope = context.into_scope();
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref());

View File

@ -53,7 +53,7 @@ mod tests {
let x = context.create_local(Item::new(ElemType::as_elem()));
redeclare_same_scope_expand::<ElemType>(&mut context, x);
redeclare_same_scope::__expand::<ElemType>(&mut context, x);
let scope = context.into_scope();
assert_eq!(
@ -68,7 +68,7 @@ mod tests {
let x = context.create_local(Item::new(ElemType::as_elem()));
redeclare_same_scope_other_type_expand::<ElemType, F32>(&mut context, x);
redeclare_same_scope_other_type::__expand::<ElemType, F32>(&mut context, x);
let scope = context.into_scope();
assert_eq!(
@ -83,7 +83,7 @@ mod tests {
let x = context.create_local(Item::new(ElemType::as_elem()));
redeclare_different_scope_expand::<ElemType>(&mut context, x);
redeclare_different_scope::__expand::<ElemType>(&mut context, x);
let scope = context.into_scope();
assert_eq!(
@ -98,7 +98,7 @@ mod tests {
let x = context.create_local(Item::new(UInt::as_elem()));
redeclare_two_for_loops_expand(&mut context, x);
redeclare_two_for_loops::__expand(&mut context, x);
let scope = context.into_scope();
assert_eq!(

View File

@ -32,7 +32,7 @@ mod tests {
let x = context.create_local(Item::new(ElemType::as_elem()));
reuse_expand::<ElemType>(&mut context, x);
reuse::__expand::<ElemType>(&mut context, x);
let scope = context.into_scope();
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_assign());
@ -44,7 +44,7 @@ mod tests {
let x = context.create_local(Item::new(ElemType::as_elem()));
reuse_incr_expand::<ElemType>(&mut context, x);
reuse_incr::__expand::<ElemType>(&mut context, x);
let scope = context.into_scope();
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_incr());

View File

@ -1,7 +1,7 @@
use burn_cube::prelude::*;
#[cube]
fn shared_memory_read_write<T: Numeric>(sm_size: Comptime<u32>) {
pub fn shared_memory_read_write<T: Numeric>(sm_size: Comptime<u32>) {
let mut shared = SharedMemory::<T>::new(sm_size);
shared[0] = T::from_int(3);
let _ = shared[0];
@ -20,7 +20,7 @@ mod tests {
fn cube_support_shared_memory() {
let mut context = CubeContext::root();
shared_memory_read_write_expand::<ElemType>(&mut context, 512);
shared_memory_read_write::__expand::<ElemType>(&mut context, 512);
assert_eq!(
format!("{:?}", context.into_scope().operations),
inline_macro_ref()

View File

@ -1,25 +1,25 @@
use burn_cube::prelude::*;
#[derive(CubeType)]
struct State<T: Numeric> {
pub struct State<T: Numeric> {
first: T,
second: T,
}
#[cube]
fn state_receiver_with_reuse<T: Numeric>(state: State<T>) -> T {
pub fn state_receiver_with_reuse<T: Numeric>(state: State<T>) -> T {
let x = state.first + state.second;
state.second + x + state.first
}
#[cube]
fn attribute_modifier_reuse_field<T: Numeric>(mut state: State<T>) -> T {
pub fn attribute_modifier_reuse_field<T: Numeric>(mut state: State<T>) -> T {
state.first = T::from_int(4);
state.first
}
#[cube]
fn attribute_modifier_reuse_struct<T: Numeric>(mut state: State<T>) -> State<T> {
pub fn attribute_modifier_reuse_struct<T: Numeric>(mut state: State<T>) -> State<T> {
state.first = T::from_int(4);
state
}
@ -48,7 +48,7 @@ mod tests {
let x = context.create_local(Item::new(ElemType::as_elem()));
let y = context.create_local(Item::new(ElemType::as_elem()));
creator_expand::<ElemType>(&mut context, x, y);
creator::__expand::<ElemType>(&mut context, x, y);
let scope = context.into_scope();
assert_eq!(
@ -68,7 +68,7 @@ mod tests {
first: x,
second: y,
};
state_receiver_with_reuse_expand::<ElemType>(&mut context, expanded_state);
state_receiver_with_reuse::__expand::<ElemType>(&mut context, expanded_state);
let scope = context.into_scope();
assert_eq!(
@ -88,7 +88,7 @@ mod tests {
first: x,
second: y,
};
attribute_modifier_reuse_field_expand::<ElemType>(&mut context, expanded_state);
attribute_modifier_reuse_field::__expand::<ElemType>(&mut context, expanded_state);
let scope = context.into_scope();
assert_eq!(
@ -108,7 +108,7 @@ mod tests {
first: x,
second: y,
};
attribute_modifier_reuse_struct_expand::<ElemType>(&mut context, expanded_state);
attribute_modifier_reuse_struct::__expand::<ElemType>(&mut context, expanded_state);
let scope = context.into_scope();
assert_eq!(

View File

@ -1,7 +1,7 @@
use burn_cube::prelude::*;
#[cube]
fn kernel<T: Numeric>(input: &Tensor<T>) {
pub fn kernel<T: Numeric>(input: &Tensor<T>) {
let _shape = input.shape(1);
let _stride = input.stride(1);
let _length = input.len();
@ -21,7 +21,7 @@ mod tests {
let mut context = CubeContext::root();
let input = context.input(0, Item::new(ElemType::as_elem()));
kernel_expand::<ElemType>(&mut context, input.into());
kernel::__expand::<ElemType>(&mut context, input.into());
assert_eq!(
format!("{:?}", context.into_scope().operations),
inline_macro_ref()

View File

@ -1,7 +1,7 @@
use burn_cube::prelude::*;
#[cube]
fn topology_kernel<T: Numeric>(input: Tensor<T>) {
pub fn topology_kernel<T: Numeric>(input: Tensor<T>) {
let x = ABSOLUTE_POS + UInt::new(4);
let _ = input[x];
}
@ -20,7 +20,7 @@ mod tests {
let mut context = CubeContext::root();
let input = context.input(0, Item::new(ElemType::as_elem()));
topology_kernel_expand::<ElemType>(&mut context, input.into());
topology_kernel::__expand::<ElemType>(&mut context, input.into());
assert_eq!(
format!("{:?}", context.into_scope().operations),
inline_macro_ref()

View File

@ -4,7 +4,7 @@ use burn_cube::prelude::*;
/// for all their methods. However, one does not need to provide its
/// implementation, see examples below.
#[cube]
trait Strategy<T: Numeric> {
pub trait Strategy<T: Numeric> {
fn operation(input_1: T, input_2: T) -> T;
}
@ -13,7 +13,7 @@ struct AddStrategy;
#[cube]
/// The actual implementation of AddStrategy's operation
/// The `#[cube]` macro automatically generates its expand variant, `add_strategy_operation::__expand`
fn add_strategy_operation<T: Numeric>(input_1: T, input_2: T) -> T {
pub fn add_strategy_operation<T: Numeric>(input_1: T, input_2: T) -> T {
input_1 + input_2
}
@ -34,19 +34,19 @@ impl<T: Numeric> Strategy<T> for SubStrategy {
}
#[cube]
fn with_strategy_trait<S: Strategy<T>, T: Numeric>(x: T, y: T) -> T {
pub fn with_strategy_trait<S: Strategy<T>, T: Numeric>(x: T, y: T) -> T {
S::operation(x, y)
}
#[cube]
fn two_strategy_traits<S1: Strategy<F>, S2: Strategy<F>, F: Float>(x: F, y: F) -> F {
pub fn two_strategy_traits<S1: Strategy<F>, S2: Strategy<F>, F: Float>(x: F, y: F) -> F {
let z = S1::operation(x, y);
S2::operation(z, y)
}
trait MethodTypedStrategy {
pub trait MethodTypedStrategy {
fn operation<T: Numeric>(input_1: T, input_2: T) -> T;
fn operation_expand<T: Numeric>(
fn __expand_operation<T: Numeric>(
_context: &mut CubeContext,
input_1: <T as CubeType>::ExpandType,
input_2: <T as CubeType>::ExpandType,
@ -58,17 +58,17 @@ impl MethodTypedStrategy for AddStrategy {
add_strategy_operation(input_1, input_2)
}
fn operation_expand<T: Numeric>(
fn __expand_operation<T: Numeric>(
context: &mut CubeContext,
input_1: <T as CubeType>::ExpandType,
input_2: <T as CubeType>::ExpandType,
) -> <T as CubeType>::ExpandType {
add_strategy_operation_expand::<T>(context, input_1, input_2)
add_strategy_operation::__expand::<T>(context, input_1, input_2)
}
}
#[cube]
fn with_trait_generic_method<S: MethodTypedStrategy, T: Numeric>(x: T, y: T) -> T {
pub fn with_trait_generic_method<S: MethodTypedStrategy, T: Numeric>(x: T, y: T) -> T {
S::operation::<T>(x, y)
}
@ -87,7 +87,7 @@ mod tests {
let x = context.create_local(Item::new(ElemType::as_elem()));
let y = context.create_local(Item::new(ElemType::as_elem()));
with_strategy_trait_expand::<AddStrategy, ElemType>(&mut context, x, y);
with_strategy_trait::__expand::<AddStrategy, ElemType>(&mut context, x, y);
let scope = context.into_scope();
assert_eq!(
@ -103,7 +103,7 @@ mod tests {
let x = context.create_local(Item::new(ElemType::as_elem()));
let y = context.create_local(Item::new(ElemType::as_elem()));
with_strategy_trait_expand::<SubStrategy, ElemType>(&mut context, x, y);
with_strategy_trait::__expand::<SubStrategy, ElemType>(&mut context, x, y);
let scope = context.into_scope();
assert_eq!(
@ -119,7 +119,7 @@ mod tests {
let x = context.create_local(Item::new(ElemType::as_elem()));
let y = context.create_local(Item::new(ElemType::as_elem()));
two_strategy_traits_expand::<SubStrategy, AddStrategy, ElemType>(&mut context, x, y);
two_strategy_traits::__expand::<SubStrategy, AddStrategy, ElemType>(&mut context, x, y);
let scope = context.into_scope();
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_two());
@ -132,7 +132,7 @@ mod tests {
let x = context.create_local(Item::new(ElemType::as_elem()));
let y = context.create_local(Item::new(ElemType::as_elem()));
with_trait_generic_method_expand::<AddStrategy, ElemType>(&mut context, x, y);
with_trait_generic_method::__expand::<AddStrategy, ElemType>(&mut context, x, y);
let scope = context.into_scope();
assert_eq!(

View File

@ -22,7 +22,7 @@ mod tests {
let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 2));
vectorization_binary_expand::<ElemType>(&mut context, lhs);
vectorization_binary::__expand::<ElemType>(&mut context, lhs);
}
#[test]
@ -32,7 +32,7 @@ mod tests {
let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 4));
vectorization_binary_expand::<ElemType>(&mut context, lhs);
vectorization_binary::__expand::<ElemType>(&mut context, lhs);
}
#[test]
@ -41,7 +41,7 @@ mod tests {
let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 2));
vectorization_cmp_expand::<ElemType>(&mut context, lhs);
vectorization_cmp::__expand::<ElemType>(&mut context, lhs);
}
#[test]
@ -51,7 +51,7 @@ mod tests {
let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 4));
vectorization_cmp_expand::<ElemType>(&mut context, lhs);
vectorization_cmp::__expand::<ElemType>(&mut context, lhs);
}
#[test]
@ -60,6 +60,6 @@ mod tests {
let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 1));
vectorization_cmp_expand::<ElemType>(&mut context, lhs);
vectorization_cmp::__expand::<ElemType>(&mut context, lhs);
}
}

View File

@ -3,23 +3,23 @@ use burn_cube::prelude::*;
use crate::kernel::{launch_unary, UnaryOp};
use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
#[derive(CubeLaunch)]
struct Options<C: Numeric> {
min_value: C,
max_value: C,
}
pub(crate) fn clamp<R: JitRuntime, E: JitElement, const D: usize>(
input: JitTensor<R, E, D>,
min_value: E,
max_value: E,
) -> JitTensor<R, E, D> {
#[derive(CubeLaunch)]
struct Options<C: Numeric> {
min_value: C,
max_value: C,
}
struct ClampOp;
impl<C: Numeric> UnaryOp<C> for ClampOp {
type Options = Options<C>;
fn execute_expand(
fn __expand_execute(
context: &mut CubeContext,
input: C::ExpandType,
options: OptionsExpand<C>,
@ -29,7 +29,7 @@ pub(crate) fn clamp<R: JitRuntime, E: JitElement, const D: usize>(
C::clamp(input, options.min_value, options.max_value)
}
execute_expand(context, input, options)
execute::__expand(context, input, options)
}
}

View File

@ -1,4 +1,4 @@
use super::{index_offset_with_layout, index_offset_with_layout_expand, Kernel};
use super::{index_offset_with_layout, Kernel};
use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
use burn_cube::{
calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor, Runtime,
@ -146,7 +146,7 @@ pub(crate) fn launch_cmp<
let same_tensor_type = core::any::TypeId::of::<E>() == core::any::TypeId::of::<UInt>();
if same_tensor_type && lhs.can_mut_broadcast(&rhs) {
kernel_cmp_launch::<E::Primitive, O, R>(
kernel_cmp::launch::<E::Primitive, O, R>(
client,
cube_count,
CubeDim::default(),
@ -170,7 +170,7 @@ pub(crate) fn launch_cmp<
JitTensor::new(lhs.client, lhs.handle, lhs.shape, lhs.device, lhs.strides)
} else if same_tensor_type && rhs.can_mut_broadcast(&lhs) {
kernel_cmp_launch::<E::Primitive, O, R>(
kernel_cmp::launch::<E::Primitive, O, R>(
client,
cube_count,
CubeDim::default(),
@ -199,7 +199,7 @@ pub(crate) fn launch_cmp<
let to_contiguous_rhs = !rhs.is_contiguous();
let output = JitTensor::new_contiguous(lhs.client.clone(), lhs.device, shape_out, buffer);
kernel_cmp_launch::<E::Primitive, O, R>(
kernel_cmp::launch::<E::Primitive, O, R>(
client,
cube_count,
CubeDim::default(),
@ -251,7 +251,7 @@ pub(crate) fn launch_scalar_cmp<
let same_tensor_type = core::any::TypeId::of::<E>() == core::any::TypeId::of::<UInt>();
if same_tensor_type && tensor.can_mut() {
kernel_scalar_cmp_launch::<E::Primitive, O, R>(
kernel_scalar_cmp::launch::<E::Primitive, O, R>(
client,
cube_count,
CubeDim::default(),
@ -282,7 +282,7 @@ pub(crate) fn launch_scalar_cmp<
tensor.strides,
);
kernel_scalar_cmp_launch::<E::Primitive, O, R>(
kernel_scalar_cmp::launch::<E::Primitive, O, R>(
client,
cube_count,
CubeDim::default(),

View File

@ -87,7 +87,7 @@ pub fn into_contiguous<R: JitRuntime, E: JitElement, const D: usize>(
SUBCUBE_DIM_APPROX,
);
into_contiguous_kernel_launch::<E::Primitive, R>(
into_contiguous_kernel::launch::<E::Primitive, R>(
client,
cube_count,
CubeDim::default(),

View File

@ -163,7 +163,7 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
let num_elems_output = output.shape.num_elements();
let cube_dim = calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX);
conv2d_kernel_launch::<E::FloatPrimitive, R>(
conv2d_kernel::launch::<E::FloatPrimitive, R>(
input.client,
cube_dim,
CubeDim::default(),

View File

@ -188,7 +188,7 @@ pub(crate) fn conv3d<R: JitRuntime, E: FloatElement>(
}
};
conv3d_kernel_launch::<E::FloatPrimitive, R>(
conv3d_kernel::launch::<E::FloatPrimitive, R>(
input.client,
calculate_cube_count_elemwise(output.shape.num_elements(), SUBCUBE_DIM_APPROX),
CubeDim::default(),

View File

@ -116,7 +116,7 @@ pub fn matmul_simple<R: JitRuntime, E: FloatElement, const D: usize>(
false => 1,
};
matmul_kernel_launch::<E::FloatPrimitive, R>(
matmul_kernel::launch::<E::FloatPrimitive, R>(
lhs.client,
cube_count,
CubeDim::new(cube_dim_x as u32, cube_dim_y as u32, 1),

View File

@ -2,11 +2,11 @@ use burn_cube::prelude::*;
use crate::kernel::matmul::config::CubeTiling2dConfig;
use super::block_loop::{block_loop, block_loop_expand};
use super::block_loop::block_loop;
#[cube(launch)]
#[allow(unused_mut)]
fn tiling2d_cube<F: Float>(
pub fn tiling2d_cube_kernel<F: Float>(
lhs: &Tensor<F>,
rhs: &Tensor<F>,
out: &mut Tensor<F>,

View File

@ -4,10 +4,10 @@ use crate::kernel::matmul::config::CubeTiling2dConfig;
use super::{
base::{BatchOffsets, Coordinates, Dimensions, SharedMemories},
compute_loop::{compute_loop, compute_loop_expand},
load_shared_memory::{load_to_shared_memories, load_to_shared_memories_expand},
compute_loop::compute_loop,
load_shared_memory::load_to_shared_memories,
tile::{loader::TileLoader, writer::TileWriter},
write_output::{write_to_output, write_to_output_expand},
write_output::write_to_output,
};
#[cube]

View File

@ -2,10 +2,7 @@ use burn_cube::prelude::*;
use crate::kernel::matmul::config::CubeTiling2dConfig;
use super::{
base::Coordinates,
outer_product::{tile_outer_product, tile_outer_product_expand},
};
use super::{base::Coordinates, outer_product::tile_outer_product};
#[cube]
#[allow(unused_mut)]
@ -105,7 +102,7 @@ pub mod tests {
const SOME_DIM: usize = 12;
let config = make_config(SOME_DIM, SOME_DIM, SOME_DIM);
compute_loop_test_launch::<F32, R>(
compute_loop_test::launch::<F32, R>(
lhs.client.clone(),
cube_count,
cube_dim,
@ -134,7 +131,7 @@ pub mod tests {
let config = make_config(4, 8, 4);
compute_loop_test_launch::<F32, R>(
compute_loop_test::launch::<F32, R>(
lhs.client.clone(),
cube_count,
cube_dim,

View File

@ -7,6 +7,7 @@ use crate::{
into_contiguous,
matmul::{
config::{tiling2d_cube_count, tiling2d_cube_dim, CubeTiling2dConfig},
tiling2d_cube::base::tiling2d_cube_kernel,
Tiling2dConfig,
},
},
@ -14,8 +15,6 @@ use crate::{
FloatElement, JitRuntime,
};
use super::base::tiling2d_cube_launch;
/// Matrix multiplication using tiling 2d algorithm
pub fn matmul_tiling_2d_cube<R: JitRuntime, E: FloatElement, const D: usize>(
lhs: JitTensor<R, E, D>,
@ -69,7 +68,7 @@ pub fn matmul_tiling_2d_cube<R: JitRuntime, E: FloatElement, const D: usize>(
let cube_dim = tiling2d_cube_dim(&config);
let cube_config = CubeTiling2dConfig::new(&config, m, k, n, lhs_transposed, rhs_transposed);
tiling2d_cube_launch::<E::FloatPrimitive, R>(
tiling2d_cube_kernel::launch::<E::FloatPrimitive, R>(
client,
cube_count,
cube_dim,

View File

@ -385,7 +385,7 @@ pub mod tests {
let config = make_config(16, 16, 8);
load_tensor_test_launch::<F32, R>(
load_tensor_test::launch::<F32, R>(
lhs.client.clone(),
cube_count,
cube_dim,
@ -417,7 +417,7 @@ pub mod tests {
let config = make_config(5, 1, 1);
load_tensor_multiple_tiles_test_launch::<F32, R>(
load_tensor_multiple_tiles_test::launch::<F32, R>(
lhs.client.clone(),
cube_count,
cube_dim,
@ -451,7 +451,7 @@ pub mod tests {
let config = make_config(8, 8, 8);
load_tensor_multiple_tiles_test_launch::<F32, R>(
load_tensor_multiple_tiles_test::launch::<F32, R>(
lhs.client.clone(),
cube_count,
cube_dim,
@ -481,7 +481,7 @@ pub mod tests {
let config = make_config(8, 8, 16);
load_tensor_multiple_tiles_test_launch::<F32, R>(
load_tensor_multiple_tiles_test::launch::<F32, R>(
lhs.client.clone(),
cube_count,
cube_dim,
@ -511,7 +511,7 @@ pub mod tests {
let config = make_config(8, 16, 16);
load_tensor_test_launch::<F32, R>(
load_tensor_test::launch::<F32, R>(
rhs.client.clone(),
cube_count,
cube_dim,
@ -543,7 +543,7 @@ pub mod tests {
let config = make_config(8, 8, 8);
load_tensor_multiple_tiles_test_launch::<F32, R>(
load_tensor_multiple_tiles_test::launch::<F32, R>(
rhs.client.clone(),
cube_count,
cube_dim,
@ -573,7 +573,7 @@ pub mod tests {
let config = make_config(16, 16, 8);
load_tensor_multiple_tiles_test_launch::<F32, R>(
load_tensor_multiple_tiles_test::launch::<F32, R>(
rhs.client.clone(),
cube_count,
cube_dim,
@ -603,7 +603,7 @@ pub mod tests {
let config = make_config(16, 16, 8);
load_tensor_permuted_test_launch::<F32, R>(
load_tensor_permuted_test::launch::<F32, R>(
lhs.client.clone(),
cube_count,
cube_dim,
@ -636,7 +636,7 @@ pub mod tests {
let config = make_config(m, k, 8);
load_tensor_permuted_test_launch::<F32, R>(
load_tensor_permuted_test::launch::<F32, R>(
lhs.client.clone(),
cube_count,
cube_dim,
@ -667,7 +667,7 @@ pub mod tests {
let config = make_config(16, 16, 8);
load_tensor_permuted_test_launch::<F32, R>(
load_tensor_permuted_test::launch::<F32, R>(
rhs.client.clone(),
cube_count,
cube_dim,
@ -699,7 +699,7 @@ pub mod tests {
let config = make_config(8, k, n);
load_tensor_permuted_test_launch::<F32, R>(
load_tensor_permuted_test::launch::<F32, R>(
rhs.client.clone(),
cube_count,
cube_dim,

View File

@ -70,7 +70,7 @@ pub mod tests {
const SOME_DIM: usize = 12;
let config = make_config(SOME_DIM, SOME_DIM, SOME_DIM);
tile_outer_product_test_launch::<F32, R>(
tile_outer_product_test::launch::<F32, R>(
client.clone(),
cube_count,
cube_dim,
@ -99,7 +99,7 @@ pub mod tests {
const SOME_DIM: usize = 12;
let config = make_config(SOME_DIM, SOME_DIM, SOME_DIM);
tile_outer_product_test_launch::<F32, R>(
tile_outer_product_test::launch::<F32, R>(
client.clone(),
cube_count,
cube_dim,

View File

@ -14,10 +14,7 @@ use crate::kernel::matmul::{
},
};
use super::base::{
all_zeros_comptime, all_zeros_comptime_expand, all_zeros_runtime, all_zeros_runtime_expand,
BlockLoader, BlockWriter,
};
use super::base::{all_zeros_comptime, all_zeros_runtime, BlockLoader, BlockWriter};
pub(crate) struct HorizontalCheckBlockIO;

View File

@ -14,7 +14,7 @@ use crate::kernel::matmul::{
},
};
use super::base::{all_zeros_runtime, all_zeros_runtime_expand, BlockLoader, BlockWriter};
use super::base::{all_zeros_runtime, BlockLoader, BlockWriter};
pub(crate) struct VerticalCheckBlockIO;

View File

@ -14,10 +14,7 @@ use crate::kernel::matmul::{
},
};
use super::base::{
all_zeros_comptime, all_zeros_comptime_expand, all_zeros_runtime, all_zeros_runtime_expand,
BlockLoader, BlockWriter,
};
use super::base::{all_zeros_comptime, all_zeros_runtime, BlockLoader, BlockWriter};
pub(crate) struct WholeCheckBlockIO;

View File

@ -130,7 +130,7 @@ pub mod tests {
let config = make_config(6, 8, 8);
write_to_output_test_launch::<F32, R>(
write_to_output_test::launch::<F32, R>(
out.client.clone(),
cube_count,
cube_dim,
@ -156,7 +156,7 @@ pub mod tests {
let config = make_config(8, 8, 4);
write_to_output_test_launch::<F32, R>(
write_to_output_test::launch::<F32, R>(
out.client.clone(),
cube_count,
cube_dim,
@ -182,7 +182,7 @@ pub mod tests {
let config = make_config(8, 8, 8);
write_to_output_test_launch::<F32, R>(
write_to_output_test::launch::<F32, R>(
out.client.clone(),
cube_count,
cube_dim,
@ -215,7 +215,7 @@ pub mod tests {
let config = make_config(8, 8, 8);
write_to_output_test_launch::<F32, R>(
write_to_output_test::launch::<F32, R>(
out.client.clone(),
cube_count,
cube_dim,
@ -248,7 +248,7 @@ pub mod tests {
let config = make_config(5, 8, 1);
write_results_to_output_out_of_bounds_test_launch::<F32, R>(
write_results_to_output_out_of_bounds_test::launch::<F32, R>(
out.client.clone(),
cube_count,
cube_dim,

View File

@ -11,7 +11,7 @@ pub use binary::*;
pub use cast::*;
pub use contiguous::*;
pub use mask::*;
pub use unary::*;
pub(crate) use unary::*;
pub use burn_cube::{Kernel, SUBCUBE_DIM_APPROX};

View File

@ -4,7 +4,7 @@ use burn_cube::{
SUBCUBE_DIM_APPROX,
};
use super::{index_offset_with_layout, index_offset_with_layout_expand, Kernel};
use super::{index_offset_with_layout, Kernel};
pub(crate) trait UnaryOp<C: CubePrimitive>: 'static + Send + Sync {
type Options: LaunchArg;
@ -13,7 +13,7 @@ pub(crate) trait UnaryOp<C: CubePrimitive>: 'static + Send + Sync {
fn execute(_input: C, _options: &Self::Options) -> C {
unexpanded!();
}
fn execute_expand(
fn __expand_execute(
context: &mut CubeContext,
input: C::ExpandType,
options: <Self::Options as CubeType>::ExpandType,
@ -78,7 +78,7 @@ where
let is_contiguous = tensor.is_contiguous();
if tensor.can_mut() && is_contiguous {
unary_kernel_launch::<E::Primitive, O, R>(
unary_kernel::launch::<E::Primitive, O, R>(
client,
cube_count,
CubeDim::default(),
@ -104,7 +104,7 @@ where
buffer,
);
unary_kernel_launch::<E::Primitive, O, R>(
unary_kernel::launch::<E::Primitive, O, R>(
client,
cube_count,
CubeDim::default(),
@ -136,7 +136,7 @@ macro_rules! unary_op {
type Options = ();
#[allow(clippy::redundant_closure_call)]
fn execute_expand(
fn __expand_execute(
context: &mut CubeContext,
input: C::ExpandType,
_options: <Self::Options as CubeType>::ExpandType,
@ -152,7 +152,7 @@ macro_rules! unary_op {
type Options = C;
#[allow(clippy::redundant_closure_call)]
fn execute_expand(
fn __expand_execute(
context: &mut CubeContext,
input: C::ExpandType,
scalar: C::ExpandType,

View File

@ -334,7 +334,7 @@ where
fn execute<C: Float>(input: C) -> C {
C::exp(input)
}
execute_expand::<C>(context, input)
execute::__expand::<C>(context, input)
})
}
@ -344,7 +344,7 @@ where
fn execute<C: Float>(input: C) -> C {
C::log(input)
}
execute_expand::<C>(context, tensor)
execute::__expand::<C>(context, tensor)
})
}
@ -354,7 +354,7 @@ where
fn execute<C: Float>(input: C) -> C {
C::log1p(input)
}
execute_expand::<C>(context, tensor)
execute::__expand::<C>(context, tensor)
})
}
@ -367,7 +367,7 @@ where
fn execute<C: Float>(input: C, scalar: C) -> C {
C::powf(input, scalar)
}
execute_expand::<C>(context, tensor, scalar)
execute::__expand::<C>(context, tensor, scalar)
})
}
@ -377,7 +377,7 @@ where
fn execute<C: Float>(input: C) -> C {
C::sqrt(input)
}
execute_expand::<C>(context, tensor)
execute::__expand::<C>(context, tensor)
})
}
@ -387,7 +387,7 @@ where
fn execute<C: Float>(input: C) -> C {
C::abs(input)
}
execute_expand::<C>(context, tensor)
execute::__expand::<C>(context, tensor)
})
}
@ -397,7 +397,7 @@ where
fn execute<C: Float>(input: C) -> C {
C::cos(input)
}
execute_expand::<C>(context, tensor)
execute::__expand::<C>(context, tensor)
})
}
@ -407,7 +407,7 @@ where
fn execute<C: Float>(input: C) -> C {
C::sin(input)
}
execute_expand::<C>(context, tensor)
execute::__expand::<C>(context, tensor)
})
}
@ -417,7 +417,7 @@ where
fn execute<C: Float>(input: C) -> C {
C::tanh(input)
}
execute_expand::<C>(context, tensor)
execute::__expand::<C>(context, tensor)
})
}
@ -427,7 +427,7 @@ where
fn execute<C: Float>(input: C) -> C {
C::erf(input)
}
execute_expand::<C>(context, tensor)
execute::__expand::<C>(context, tensor)
})
}
@ -463,7 +463,7 @@ where
fn execute<C: Float>(input: C) -> C {
C::recip(input)
}
execute_expand::<C>(context, tensor)
execute::__expand::<C>(context, tensor)
})
}

View File

@ -299,7 +299,7 @@ where
fn execute<C: Numeric>(input: C) -> C {
C::abs(input)
}
execute_expand::<C>(context, tensor)
execute::__expand::<C>(context, tensor)
})
}

View File

@ -26,7 +26,7 @@ pub fn full_device<R: JitRuntime, E: JitElement, const D: usize>(
let empty = empty_device(client, device, shape);
#[cube(launch)]
pub(crate) fn full_kernel<C: Numeric + Vectorized>(tensor: &mut Tensor<C>, value: C) {
pub fn full_kernel<C: Numeric + Vectorized>(tensor: &mut Tensor<C>, value: C) {
if ABSOLUTE_POS >= tensor.len() {
return;
}
@ -42,7 +42,7 @@ pub fn full_device<R: JitRuntime, E: JitElement, const D: usize>(
SUBCUBE_DIM_APPROX,
);
full_kernel_launch::<E::Primitive, R>(
full_kernel::launch::<E::Primitive, R>(
empty.client.clone(),
cube_count,
CubeDim::default(),
@ -127,7 +127,7 @@ pub fn add_scalar<R: JitRuntime, E: JitElement, const D: usize>(
fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
lhs + rhs
}
execute_expand::<C>(context, lhs, rhs)
execute::__expand::<C>(context, lhs, rhs)
})
}
@ -156,7 +156,7 @@ pub fn sub_scalar<R: JitRuntime, E: JitElement, const D: usize>(
fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
lhs - rhs
}
execute_expand::<C>(context, lhs, rhs)
execute::__expand::<C>(context, lhs, rhs)
})
}
@ -185,7 +185,7 @@ pub fn mul_scalar<R: JitRuntime, E: JitElement, const D: usize>(
fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
lhs * rhs
}
execute_expand::<C>(context, lhs, rhs)
execute::__expand::<C>(context, lhs, rhs)
})
}
@ -214,7 +214,7 @@ pub fn div_scalar<R: JitRuntime, E: JitElement, const D: usize>(
fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
lhs / rhs
}
execute_expand::<C>(context, lhs, rhs)
execute::__expand::<C>(context, lhs, rhs)
})
}

View File

@ -146,7 +146,7 @@ where
fn execute<C: Numeric>(input: C) -> C {
input
}
execute_expand::<C>(context, tensor)
execute::__expand::<C>(context, tensor)
})
}

View File

@ -20,7 +20,7 @@ pub fn launch<R: Runtime>(device: &R::Device) {
let input_handle = client.create(f32::as_bytes(input));
let output_handle = client.empty(input.len() * core::mem::size_of::<f32>());
gelu_launch::<F32, R>(
gelu::launch::<F32, R>(
client.clone(),
CubeCount::Static(1, 1, 1),
CubeDim::default(),