mirror of https://github.com/tracel-ai/burn.git
Refactor/cube/expand & fix double imports (#2009)
* Refactored function * WIP * Basic stuff done * Fix traits * Cleanup * Cleanup * Cleanup
This commit is contained in:
parent
35345de62a
commit
19f5ad7be5
|
@ -2,10 +2,17 @@ use quote::ToTokens;
|
|||
|
||||
use crate::tracker::VariableTracker;
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub enum ExpandMode {
|
||||
FuncImpl,
|
||||
MethodImpl,
|
||||
}
|
||||
|
||||
pub fn expand_sig(
|
||||
sig: &syn::Signature,
|
||||
visibility: &syn::Visibility,
|
||||
mut variable_tracker: Option<&mut VariableTracker>,
|
||||
mode: ExpandMode,
|
||||
) -> proc_macro2::TokenStream {
|
||||
let mut inputs = quote::quote!();
|
||||
|
||||
|
@ -42,7 +49,10 @@ pub fn expand_sig(
|
|||
}
|
||||
|
||||
let ident = &sig.ident;
|
||||
let ident = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span());
|
||||
let ident = match mode {
|
||||
ExpandMode::FuncImpl => syn::Ident::new("__expand".to_string().as_str(), ident.span()),
|
||||
_ => syn::Ident::new(format!("__expand_{ident}").as_str(), ident.span()),
|
||||
};
|
||||
|
||||
let generics = sig.generics.clone().into_token_stream();
|
||||
|
||||
|
|
|
@ -106,22 +106,42 @@ pub(crate) fn codegen_call(
|
|||
// Path
|
||||
let mut path_tokens = TokenStream::new();
|
||||
let mut is_comptime = false;
|
||||
let mut is_plain_func = true;
|
||||
let mut comptime_func: Option<String> = None;
|
||||
|
||||
for (i, (ident, generics)) in path.iter().enumerate() {
|
||||
if *ident == "Comptime" {
|
||||
let name = ident.to_string();
|
||||
|
||||
if name == "Comptime" {
|
||||
is_comptime = true;
|
||||
continue;
|
||||
}
|
||||
|
||||
if let Some(first_char) = name.chars().next() {
|
||||
if first_char.is_uppercase() {
|
||||
is_plain_func = false;
|
||||
}
|
||||
}
|
||||
|
||||
if i == path.len() - 1 {
|
||||
if is_comptime {
|
||||
comptime_func = Some(ident.to_string());
|
||||
break;
|
||||
}
|
||||
let func_name_expand = syn::Ident::new(
|
||||
format!("{ident}_expand").as_str(),
|
||||
proc_macro2::Span::call_site(),
|
||||
);
|
||||
|
||||
let func_name_expand = if is_plain_func {
|
||||
quote::quote! {
|
||||
#ident::__expand
|
||||
}
|
||||
} else {
|
||||
let ident = syn::Ident::new(
|
||||
format!("__expand_{ident}").as_str(),
|
||||
proc_macro2::Span::call_site(),
|
||||
);
|
||||
quote::quote! {
|
||||
#ident
|
||||
}
|
||||
};
|
||||
path_tokens.extend(quote_spanned! {func_name_expand.span() => #func_name_expand });
|
||||
if let Some(generics) = generics {
|
||||
path_tokens.extend(quote_spanned! {generics.span() => #generics });
|
||||
|
|
|
@ -211,7 +211,7 @@ impl Codegen {
|
|||
}
|
||||
}
|
||||
|
||||
fn gen_define_impl(&self, expand: &Ident) -> TokenStream {
|
||||
fn gen_define_impl(&self, expand: &TokenStream) -> TokenStream {
|
||||
let mut expand_args = quote::quote! { &mut builder.context, };
|
||||
|
||||
let mut variables = quote::quote! {};
|
||||
|
@ -340,7 +340,7 @@ impl Codegen {
|
|||
tokens
|
||||
}
|
||||
|
||||
fn gen_compile_impl(&self, expand: &Ident) -> TokenStream {
|
||||
fn gen_compile_impl(&self, expand: &TokenStream) -> TokenStream {
|
||||
let ident = Ident::new(&self.name, Span::call_site());
|
||||
let generics = add_runtime(self.generics.clone());
|
||||
let (impl_gen, ty_gen, where_gen) = generics.split_for_impl();
|
||||
|
@ -453,22 +453,27 @@ pub fn codegen_launch(sig: &syn::Signature) -> TokenStream {
|
|||
let codegen = Codegen::from_sig(sig);
|
||||
|
||||
let ident = &sig.ident;
|
||||
let ident_expand = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span());
|
||||
let ident = syn::Ident::new(format!("{ident}_launch").as_str(), ident.span());
|
||||
|
||||
let ident_expand = quote::quote! {
|
||||
__expand
|
||||
};
|
||||
|
||||
let generics = add_runtime(add_lifetime(sig.generics.clone()));
|
||||
let body = codegen.gen_launch_body();
|
||||
let kernel = codegen.gen_kernel_struct();
|
||||
let compile = codegen.gen_compile_impl(&ident_expand);
|
||||
let (inputs, output) = (codegen.fn_inputs, codegen.fn_output);
|
||||
let doc =
|
||||
format!("Launch the kernel [{ident}] with the provided argument on the given runtime.");
|
||||
|
||||
quote::quote! {
|
||||
#kernel
|
||||
#compile
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
/// Launch
|
||||
pub fn #ident #generics (
|
||||
#[doc = #doc]
|
||||
/// Launch the kernel.
|
||||
pub fn launch #generics (
|
||||
client: ComputeClient<R::Server, R::Channel>,
|
||||
cube_count: CubeCount<R::Server>,
|
||||
cube_dim: CubeDim,
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
use crate::codegen_common::signature::expand_sig;
|
||||
use proc_macro2::TokenStream;
|
||||
|
||||
use crate::codegen_common::signature::{expand_sig, ExpandMode};
|
||||
|
||||
pub fn expand_trait_def(mut tr: syn::ItemTrait) -> proc_macro2::TokenStream {
|
||||
let mut expand_items = Vec::new();
|
||||
|
@ -6,7 +8,12 @@ pub fn expand_trait_def(mut tr: syn::ItemTrait) -> proc_macro2::TokenStream {
|
|||
for item in tr.items.iter() {
|
||||
match item {
|
||||
syn::TraitItem::Fn(func) => {
|
||||
let expand = expand_sig(&func.sig, &syn::Visibility::Inherited, None);
|
||||
let expand = expand_sig(
|
||||
&func.sig,
|
||||
&syn::Visibility::Inherited,
|
||||
None,
|
||||
ExpandMode::MethodImpl,
|
||||
);
|
||||
expand_items.push(syn::parse_quote!(#expand;));
|
||||
}
|
||||
_ => continue,
|
||||
|
@ -26,7 +33,7 @@ pub fn expand_trait_impl(mut tr: syn::ItemImpl) -> proc_macro2::TokenStream {
|
|||
match item {
|
||||
syn::ImplItem::Fn(func) => {
|
||||
let ident = &func.sig.ident;
|
||||
let ident = syn::Ident::new(format!("{ident}_expand").as_str(), ident.span());
|
||||
let ident = quote::quote! {#ident::__expand};
|
||||
let mut inputs = quote::quote!();
|
||||
|
||||
for input in &func.sig.inputs {
|
||||
|
@ -41,7 +48,12 @@ pub fn expand_trait_impl(mut tr: syn::ItemImpl) -> proc_macro2::TokenStream {
|
|||
}
|
||||
}
|
||||
|
||||
let expand = expand_sig(&func.sig, &syn::Visibility::Inherited, None);
|
||||
let expand = expand_sig(
|
||||
&func.sig,
|
||||
&syn::Visibility::Inherited,
|
||||
None,
|
||||
ExpandMode::MethodImpl,
|
||||
);
|
||||
|
||||
let tokens = if !tr.generics.params.is_empty() {
|
||||
let mut func = func.clone();
|
||||
|
@ -67,7 +79,7 @@ pub fn expand_trait_impl(mut tr: syn::ItemImpl) -> proc_macro2::TokenStream {
|
|||
|
||||
fn register_expand(
|
||||
func: &syn::ImplItemFn,
|
||||
name: &syn::Ident,
|
||||
name: &TokenStream,
|
||||
expand: proc_macro2::TokenStream,
|
||||
inputs: proc_macro2::TokenStream,
|
||||
) -> proc_macro2::TokenStream {
|
||||
|
@ -91,7 +103,7 @@ fn register_expand(
|
|||
quote::quote! (
|
||||
#expand {
|
||||
#[cube]
|
||||
#func
|
||||
pub #func
|
||||
#func_expand
|
||||
}
|
||||
)
|
||||
|
|
|
@ -10,7 +10,7 @@ mod tracker;
|
|||
pub(crate) mod codegen_common;
|
||||
|
||||
use analyzer::VariableAnalyzer;
|
||||
use codegen_common::signature::expand_sig;
|
||||
use codegen_common::signature::{expand_sig, ExpandMode};
|
||||
use codegen_function::{codegen_launch, codegen_statement};
|
||||
use codegen_trait::{expand_trait_def, expand_trait_impl};
|
||||
use codegen_type::generate_cube_type;
|
||||
|
@ -69,20 +69,8 @@ pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream {
|
|||
fn cube_fn(func: syn::ItemFn, attrs: &SupportedAttributes) -> TokenStream {
|
||||
let mut variable_tracker = VariableAnalyzer::create_tracker(&func);
|
||||
|
||||
match codegen_cube(&func, &mut variable_tracker) {
|
||||
Ok(code) => {
|
||||
if attrs.launch {
|
||||
let launch = codegen_launch(&func.sig);
|
||||
|
||||
quote::quote! {
|
||||
#code
|
||||
#launch
|
||||
}
|
||||
.into()
|
||||
} else {
|
||||
code.into()
|
||||
}
|
||||
}
|
||||
match codegen_cube(&func, &mut variable_tracker, attrs.launch) {
|
||||
Ok(code) => code.into(),
|
||||
Err(err) => err.into(),
|
||||
}
|
||||
}
|
||||
|
@ -120,8 +108,15 @@ fn parse_attributes(args: &Punctuated<Meta, Comma>) -> SupportedAttributes {
|
|||
fn codegen_cube(
|
||||
func: &syn::ItemFn,
|
||||
variable_tracker: &mut VariableTracker,
|
||||
launch: bool,
|
||||
) -> Result<proc_macro2::TokenStream, proc_macro2::TokenStream> {
|
||||
let signature = expand_sig(&func.sig, &func.vis, Some(variable_tracker));
|
||||
let signature = expand_sig(
|
||||
&func.sig,
|
||||
&syn::Visibility::Public(Default::default()), // Always public, otherwise we can't import
|
||||
// it from an outside module.
|
||||
Some(variable_tracker),
|
||||
ExpandMode::FuncImpl,
|
||||
);
|
||||
let mut body = quote::quote! {};
|
||||
|
||||
for statement in func.block.stmts.iter() {
|
||||
|
@ -148,15 +143,36 @@ fn codegen_cube(
|
|||
return Err(code);
|
||||
}
|
||||
|
||||
let launch_doc = if launch { "and launch function " } else { "" };
|
||||
|
||||
let launch = if launch {
|
||||
codegen_launch(&func.sig)
|
||||
} else {
|
||||
quote::quote! {}
|
||||
};
|
||||
|
||||
let mod_name = &func.sig.ident;
|
||||
let vis = &func.vis;
|
||||
let doc = format!("Module containing the expand method {launch_doc}of {mod_name}.");
|
||||
|
||||
Ok(quote::quote! {
|
||||
#[allow(dead_code)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#func
|
||||
|
||||
#[allow(unused_mut)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#signature {
|
||||
#body
|
||||
|
||||
#[doc = #doc]
|
||||
#vis mod #mod_name {
|
||||
use super::*;
|
||||
|
||||
#launch
|
||||
|
||||
#[allow(unused_mut)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#signature {
|
||||
#body
|
||||
}
|
||||
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
@ -86,7 +86,7 @@ impl Init for MatrixExpand {
|
|||
|
||||
impl<C: CubePrimitive> Matrix<C> {
|
||||
/// Create a new matrix that is going to be used in the
|
||||
/// [matrix-multiply and accumulate](execute) function.
|
||||
/// [matrix-multiply and accumulate](execute()) function.
|
||||
///
|
||||
/// You have to declare the shape used for the execution.
|
||||
/// The shape of the current matrix is determined using the [MatrixIdent].
|
||||
|
@ -103,7 +103,7 @@ impl<C: CubePrimitive> Matrix<C> {
|
|||
Matrix { _c: PhantomData }
|
||||
}
|
||||
|
||||
pub fn new_expand(
|
||||
pub fn __expand_new(
|
||||
context: &mut CubeContext,
|
||||
ident: MatrixIdent,
|
||||
m: u8,
|
||||
|
@ -129,16 +129,21 @@ pub fn fill<C: CubeType>(mat: &Matrix<C>, value: C) {
|
|||
unexpanded!()
|
||||
}
|
||||
|
||||
/// Expand method of [fill].
|
||||
pub fn fill_expand<C: CubeType>(
|
||||
context: &mut CubeContext,
|
||||
mat: MatrixExpand,
|
||||
value: ExpandElement,
|
||||
) {
|
||||
context.register(Operation::CoopMma(ir::CoopMma::Fill {
|
||||
mat: *mat.elem,
|
||||
value: *value,
|
||||
}));
|
||||
/// Module containing the expand function for [fill()].
|
||||
pub mod fill {
|
||||
use super::*;
|
||||
|
||||
/// Expand method of [fill()].
|
||||
pub fn __expand<C: CubeType>(
|
||||
context: &mut CubeContext,
|
||||
mat: MatrixExpand,
|
||||
value: ExpandElement,
|
||||
) {
|
||||
context.register(Operation::CoopMma(ir::CoopMma::Fill {
|
||||
mat: *mat.elem,
|
||||
value: *value,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
/// Load the matrix with the provided array using the stride.
|
||||
|
@ -147,19 +152,24 @@ pub fn load<C: CubeType>(mat: &Matrix<C>, value: &Slice<'_, C>, stride: UInt) {
|
|||
unexpanded!()
|
||||
}
|
||||
|
||||
/// Expand method of [load].
|
||||
#[allow(unused_variables)]
|
||||
pub fn load_expand<C: CubeType>(
|
||||
context: &mut CubeContext,
|
||||
mat: MatrixExpand,
|
||||
value: ExpandElementTyped<Slice<'static, C>>,
|
||||
stride: ExpandElement,
|
||||
) {
|
||||
context.register(Operation::CoopMma(ir::CoopMma::Load {
|
||||
mat: *mat.elem,
|
||||
value: *value.expand,
|
||||
stride: *stride,
|
||||
}));
|
||||
/// Module containing the expand function for [load()].
|
||||
pub mod load {
|
||||
use super::*;
|
||||
|
||||
/// Expand method of [load()].
|
||||
#[allow(unused_variables)]
|
||||
pub fn __expand<C: CubeType>(
|
||||
context: &mut CubeContext,
|
||||
mat: MatrixExpand,
|
||||
value: ExpandElementTyped<Slice<'static, C>>,
|
||||
stride: ExpandElement,
|
||||
) {
|
||||
context.register(Operation::CoopMma(ir::CoopMma::Load {
|
||||
mat: *mat.elem,
|
||||
value: *value.expand,
|
||||
stride: *stride,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
/// Store the matrix in the given array following the given stride and layout.
|
||||
|
@ -173,21 +183,26 @@ pub fn store<C: CubePrimitive>(
|
|||
unexpanded!()
|
||||
}
|
||||
|
||||
/// Expand method of [store].
|
||||
#[allow(unused_variables)]
|
||||
pub fn store_expand<C: CubePrimitive>(
|
||||
context: &mut CubeContext,
|
||||
output: ExpandElementTyped<SliceMut<'static, C>>,
|
||||
mat: MatrixExpand,
|
||||
stride: ExpandElement,
|
||||
layout: MatrixLayout,
|
||||
) {
|
||||
context.register(Operation::CoopMma(ir::CoopMma::Store {
|
||||
output: *output.expand,
|
||||
mat: *mat.elem,
|
||||
stride: *stride,
|
||||
layout,
|
||||
}));
|
||||
/// Module containing the expand function for [store()].
|
||||
pub mod store {
|
||||
use super::*;
|
||||
|
||||
/// Expand method of [store()].
|
||||
#[allow(unused_variables)]
|
||||
pub fn __expand<C: CubePrimitive>(
|
||||
context: &mut CubeContext,
|
||||
output: ExpandElementTyped<SliceMut<'static, C>>,
|
||||
mat: MatrixExpand,
|
||||
stride: ExpandElement,
|
||||
layout: MatrixLayout,
|
||||
) {
|
||||
context.register(Operation::CoopMma(ir::CoopMma::Store {
|
||||
output: *output.expand,
|
||||
mat: *mat.elem,
|
||||
stride: *stride,
|
||||
layout,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute the matrix-multiply and accumulate operation on the given [matrices](Matrix).
|
||||
|
@ -201,18 +216,23 @@ pub fn execute<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrim
|
|||
unexpanded!()
|
||||
}
|
||||
|
||||
/// Expand method of [execute].
|
||||
pub fn execute_expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
|
||||
context: &mut CubeContext,
|
||||
mat_a: MatrixExpand,
|
||||
mat_b: MatrixExpand,
|
||||
mat_c: MatrixExpand,
|
||||
mat_d: MatrixExpand,
|
||||
) {
|
||||
context.register(Operation::CoopMma(ir::CoopMma::Execute {
|
||||
mat_a: *mat_a.elem,
|
||||
mat_b: *mat_b.elem,
|
||||
mat_c: *mat_c.elem,
|
||||
mat_d: *mat_d.elem,
|
||||
}));
|
||||
/// Module containing the expand function for [execute()].
|
||||
pub mod execute {
|
||||
use super::*;
|
||||
|
||||
/// Expand method of [execute()].
|
||||
pub fn __expand<A: CubePrimitive, B: CubePrimitive, C: CubePrimitive, D: CubePrimitive>(
|
||||
context: &mut CubeContext,
|
||||
mat_a: MatrixExpand,
|
||||
mat_b: MatrixExpand,
|
||||
mat_c: MatrixExpand,
|
||||
mat_d: MatrixExpand,
|
||||
) {
|
||||
context.register(Operation::CoopMma(ir::CoopMma::Execute {
|
||||
mat_a: *mat_a.elem,
|
||||
mat_b: *mat_b.elem,
|
||||
mat_c: *mat_c.elem,
|
||||
mat_d: *mat_d.elem,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,7 +30,11 @@ impl<T: CubePrimitive + Clone> Array<T> {
|
|||
Array { _val: PhantomData }
|
||||
}
|
||||
|
||||
pub fn new_expand<S: Index>(
|
||||
pub fn vectorized<S: Index>(_size: S, _vectorization_factor: UInt) -> Self {
|
||||
Array { _val: PhantomData }
|
||||
}
|
||||
|
||||
pub fn __expand_new<S: Index>(
|
||||
context: &mut CubeContext,
|
||||
size: S,
|
||||
) -> <Self as CubeType>::ExpandType {
|
||||
|
@ -44,11 +48,7 @@ impl<T: CubePrimitive + Clone> Array<T> {
|
|||
.into()
|
||||
}
|
||||
|
||||
pub fn vectorized<S: Index>(_size: S, _vectorization_factor: UInt) -> Self {
|
||||
Array { _val: PhantomData }
|
||||
}
|
||||
|
||||
pub fn vectorized_expand<S: Index>(
|
||||
pub fn __expand_vectorized<S: Index>(
|
||||
context: &mut CubeContext,
|
||||
size: S,
|
||||
vectorization_factor: UInt,
|
||||
|
|
|
@ -3,6 +3,10 @@ use crate::ir::Elem;
|
|||
|
||||
use super::Vectorized;
|
||||
|
||||
// To be consistent with other primitive type.
|
||||
/// Boolean type.
|
||||
pub type Bool = bool;
|
||||
|
||||
impl CubeType for bool {
|
||||
type ExpandType = ExpandElement;
|
||||
}
|
||||
|
|
|
@ -6,7 +6,7 @@ use crate::{frontend::ExpandElement, unexpanded};
|
|||
pub trait Cast: CubePrimitive {
|
||||
fn cast_from<From: CubePrimitive>(value: From) -> Self;
|
||||
|
||||
fn cast_from_expand<From>(
|
||||
fn __expand_cast_from<From>(
|
||||
context: &mut CubeContext,
|
||||
value: From,
|
||||
) -> <Self as CubeType>::ExpandType
|
||||
|
|
|
@ -29,15 +29,15 @@ pub trait Float:
|
|||
+ core::ops::IndexMut<UInt, Output = Self>
|
||||
{
|
||||
fn new(val: f32) -> Self;
|
||||
fn new_expand(context: &mut CubeContext, val: f32) -> <Self as CubeType>::ExpandType;
|
||||
fn vectorized(val: f32, vectorization: UInt) -> Self;
|
||||
fn vectorized_expand(
|
||||
fn vectorized_empty(vectorization: UInt) -> Self;
|
||||
fn __expand_new(context: &mut CubeContext, val: f32) -> <Self as CubeType>::ExpandType;
|
||||
fn __expand_vectorized(
|
||||
context: &mut CubeContext,
|
||||
val: f32,
|
||||
vectorization: UInt,
|
||||
) -> <Self as CubeType>::ExpandType;
|
||||
fn vectorized_empty(vectorization: UInt) -> Self;
|
||||
fn vectorized_empty_expand(
|
||||
fn __expand_vectorized_empty(
|
||||
context: &mut CubeContext,
|
||||
vectorization: UInt,
|
||||
) -> <Self as CubeType>::ExpandType;
|
||||
|
@ -74,14 +74,6 @@ macro_rules! impl_float {
|
|||
}
|
||||
}
|
||||
|
||||
fn new_expand(_context: &mut CubeContext, val: f32) -> <Self as CubeType>::ExpandType {
|
||||
let new_var = Variable::ConstantScalar {
|
||||
value: val as f64,
|
||||
elem: Self::as_elem(),
|
||||
};
|
||||
ExpandElement::Plain(new_var)
|
||||
}
|
||||
|
||||
fn vectorized(val: f32, vectorization: UInt) -> Self {
|
||||
if vectorization.val == 1 {
|
||||
Self::new(val)
|
||||
|
@ -93,13 +85,28 @@ macro_rules! impl_float {
|
|||
}
|
||||
}
|
||||
|
||||
fn vectorized_expand(
|
||||
fn vectorized_empty(vectorization: UInt) -> Self {
|
||||
Self::vectorized(0., vectorization)
|
||||
}
|
||||
|
||||
fn __expand_new(
|
||||
_context: &mut CubeContext,
|
||||
val: f32,
|
||||
) -> <Self as CubeType>::ExpandType {
|
||||
let new_var = Variable::ConstantScalar {
|
||||
value: val as f64,
|
||||
elem: Self::as_elem(),
|
||||
};
|
||||
ExpandElement::Plain(new_var)
|
||||
}
|
||||
|
||||
fn __expand_vectorized(
|
||||
context: &mut CubeContext,
|
||||
val: f32,
|
||||
vectorization: UInt,
|
||||
) -> <Self as CubeType>::ExpandType {
|
||||
if vectorization.val == 1 {
|
||||
Self::new_expand(context, val)
|
||||
Self::__expand_new(context, val)
|
||||
} else {
|
||||
let mut new_var = context
|
||||
.create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8));
|
||||
|
@ -111,16 +118,12 @@ macro_rules! impl_float {
|
|||
}
|
||||
}
|
||||
|
||||
fn vectorized_empty(vectorization: UInt) -> Self {
|
||||
Self::vectorized(0., vectorization)
|
||||
}
|
||||
|
||||
fn vectorized_empty_expand(
|
||||
fn __expand_vectorized_empty(
|
||||
context: &mut CubeContext,
|
||||
vectorization: UInt,
|
||||
) -> <Self as CubeType>::ExpandType {
|
||||
if vectorization.val == 1 {
|
||||
Self::new_expand(context, 0.)
|
||||
Self::__expand_new(context, 0.)
|
||||
} else {
|
||||
context.create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8))
|
||||
}
|
||||
|
|
|
@ -9,9 +9,9 @@ use super::{LaunchArgExpand, ScalarArgSettings, UInt, Vectorized};
|
|||
/// Signed integer. Used as input in int kernels
|
||||
pub trait Int: Numeric + std::ops::Rem<Output = Self> {
|
||||
fn new(val: i64) -> Self;
|
||||
fn new_expand(context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType;
|
||||
fn vectorized(val: i64, vectorization: UInt) -> Self;
|
||||
fn vectorized_expand(
|
||||
fn __expand_new(context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType;
|
||||
fn __expand_vectorized(
|
||||
context: &mut CubeContext,
|
||||
val: i64,
|
||||
vectorization: UInt,
|
||||
|
@ -48,14 +48,6 @@ macro_rules! impl_int {
|
|||
}
|
||||
}
|
||||
|
||||
fn new_expand(_context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType {
|
||||
let new_var = Variable::ConstantScalar {
|
||||
value: val as f64,
|
||||
elem: Self::as_elem(),
|
||||
};
|
||||
ExpandElement::Plain(new_var)
|
||||
}
|
||||
|
||||
fn vectorized(val: i64, vectorization: UInt) -> Self {
|
||||
if vectorization.val == 1 {
|
||||
Self::new(val)
|
||||
|
@ -67,13 +59,24 @@ macro_rules! impl_int {
|
|||
}
|
||||
}
|
||||
|
||||
fn vectorized_expand(
|
||||
fn __expand_new(
|
||||
_context: &mut CubeContext,
|
||||
val: i64,
|
||||
) -> <Self as CubeType>::ExpandType {
|
||||
let new_var = Variable::ConstantScalar {
|
||||
value: val as f64,
|
||||
elem: Self::as_elem(),
|
||||
};
|
||||
ExpandElement::Plain(new_var)
|
||||
}
|
||||
|
||||
fn __expand_vectorized(
|
||||
context: &mut CubeContext,
|
||||
val: i64,
|
||||
vectorization: UInt,
|
||||
) -> <Self as CubeType>::ExpandType {
|
||||
if vectorization.val == 1 {
|
||||
Self::new_expand(context, val)
|
||||
Self::__expand_new(context, val)
|
||||
} else {
|
||||
let mut new_var = context
|
||||
.create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8));
|
||||
|
|
|
@ -14,6 +14,7 @@ mod vectorized;
|
|||
|
||||
pub use array::*;
|
||||
pub use base::*;
|
||||
pub use bool::*;
|
||||
pub use cast::*;
|
||||
pub use cube_elem::*;
|
||||
pub use float::*;
|
||||
|
|
|
@ -45,8 +45,11 @@ pub trait Numeric:
|
|||
|
||||
type Primitive: ScalarArgSettings;
|
||||
|
||||
/// Expand version of from_int
|
||||
fn from_int_expand(_context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType {
|
||||
fn from_vec<const D: usize>(_vec: [i64; D]) -> Self {
|
||||
unexpanded!()
|
||||
}
|
||||
|
||||
fn __expand_from_int(_context: &mut CubeContext, val: i64) -> <Self as CubeType>::ExpandType {
|
||||
let new_var = Variable::ConstantScalar {
|
||||
value: val as f64,
|
||||
elem: Self::as_elem(),
|
||||
|
@ -54,11 +57,7 @@ pub trait Numeric:
|
|||
ExpandElement::Plain(new_var)
|
||||
}
|
||||
|
||||
fn from_vec<const D: usize>(_vec: [i64; D]) -> Self {
|
||||
unexpanded!()
|
||||
}
|
||||
|
||||
fn from_vec_expand<const D: usize>(
|
||||
fn __expand_from_vec<const D: usize>(
|
||||
context: &mut CubeContext,
|
||||
vec: [i64; D],
|
||||
) -> <Self as CubeType>::ExpandType {
|
||||
|
|
|
@ -27,24 +27,11 @@ impl<T: CubePrimitive + Clone> SharedMemory<T> {
|
|||
SharedMemory { _val: PhantomData }
|
||||
}
|
||||
|
||||
pub fn new_expand<S: Index>(
|
||||
context: &mut CubeContext,
|
||||
size: S,
|
||||
) -> <Self as CubeType>::ExpandType {
|
||||
let size = size.value();
|
||||
let size = match size {
|
||||
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
|
||||
_ => panic!("Shared memory need constant initialization value"),
|
||||
};
|
||||
let var = context.create_shared(Item::new(T::as_elem()), size);
|
||||
ExpandElementTyped::new(var)
|
||||
}
|
||||
|
||||
pub fn vectorized<S: Index>(_size: S, _vectorization_factor: UInt) -> Self {
|
||||
SharedMemory { _val: PhantomData }
|
||||
}
|
||||
|
||||
pub fn vectorized_expand<S: Index>(
|
||||
pub fn __expand_vectorized<S: Index>(
|
||||
context: &mut CubeContext,
|
||||
size: S,
|
||||
vectorization_factor: UInt,
|
||||
|
@ -60,4 +47,17 @@ impl<T: CubePrimitive + Clone> SharedMemory<T> {
|
|||
);
|
||||
ExpandElementTyped::new(var)
|
||||
}
|
||||
|
||||
pub fn __expand_new<S: Index>(
|
||||
context: &mut CubeContext,
|
||||
size: S,
|
||||
) -> <Self as CubeType>::ExpandType {
|
||||
let size = size.value();
|
||||
let size = match size {
|
||||
crate::ir::Variable::ConstantScalar { value, .. } => value as u32,
|
||||
_ => panic!("Shared memory need constant initialization value"),
|
||||
};
|
||||
let var = context.create_shared(Item::new(T::as_elem()), size);
|
||||
ExpandElementTyped::new(var)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -48,7 +48,7 @@ impl UInt {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn new_expand(_context: &mut CubeContext, val: u32) -> <Self as CubeType>::ExpandType {
|
||||
pub fn __expand_new(_context: &mut CubeContext, val: u32) -> <Self as CubeType>::ExpandType {
|
||||
let new_var = Variable::ConstantScalar {
|
||||
value: val as f64,
|
||||
elem: Self::as_elem(),
|
||||
|
@ -67,13 +67,13 @@ impl UInt {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn vectorized_expand(
|
||||
pub fn __expand_vectorized(
|
||||
context: &mut CubeContext,
|
||||
val: u32,
|
||||
vectorization: UInt,
|
||||
) -> <Self as CubeType>::ExpandType {
|
||||
if vectorization.val == 1 {
|
||||
Self::new_expand(context, val)
|
||||
Self::__expand_new(context, val)
|
||||
} else {
|
||||
let mut new_var =
|
||||
context.create_local(Item::vectorized(Self::as_elem(), vectorization.val as u8));
|
||||
|
|
|
@ -280,11 +280,20 @@ macro_rules! impl_binary_func {
|
|||
}
|
||||
}
|
||||
|
||||
impl_binary_func!(Powf, powf, powf_expand, Operator::Powf, F16, BF16, F32, F64);
|
||||
impl_binary_func!(
|
||||
Powf,
|
||||
powf,
|
||||
__expand_powf,
|
||||
Operator::Powf,
|
||||
F16,
|
||||
BF16,
|
||||
F32,
|
||||
F64
|
||||
);
|
||||
impl_binary_func!(
|
||||
Max,
|
||||
max,
|
||||
max_expand,
|
||||
__expand_max,
|
||||
Operator::Max,
|
||||
F16,
|
||||
BF16,
|
||||
|
@ -297,7 +306,7 @@ impl_binary_func!(
|
|||
impl_binary_func!(
|
||||
Min,
|
||||
min,
|
||||
min_expand,
|
||||
__expand_min,
|
||||
Operator::Min,
|
||||
F16,
|
||||
BF16,
|
||||
|
@ -310,7 +319,7 @@ impl_binary_func!(
|
|||
impl_binary_func!(
|
||||
Remainder,
|
||||
rem,
|
||||
rem_expand,
|
||||
__expand_rem,
|
||||
Operator::Remainder,
|
||||
F16,
|
||||
BF16,
|
||||
|
|
|
@ -12,7 +12,7 @@ pub trait Clamp: CubePrimitive + Sized {
|
|||
fn clamp(input: Self, min_value: Self, max_value: Self) -> Self {
|
||||
unexpanded!()
|
||||
}
|
||||
fn clamp_expand(
|
||||
fn __expand_clamp(
|
||||
context: &mut CubeContext,
|
||||
input: Self::ExpandType,
|
||||
min_value: Self::ExpandType,
|
||||
|
|
|
@ -33,7 +33,7 @@ macro_rules! impl_unary_func {
|
|||
impl_unary_func!(
|
||||
Abs,
|
||||
abs,
|
||||
abs_expand,
|
||||
__expand_abs,
|
||||
Operator::Abs,
|
||||
F16,
|
||||
BF16,
|
||||
|
@ -43,38 +43,65 @@ impl_unary_func!(
|
|||
I64,
|
||||
UInt
|
||||
);
|
||||
impl_unary_func!(Exp, exp, exp_expand, Operator::Exp, F16, BF16, F32, F64);
|
||||
impl_unary_func!(Log, log, log_expand, Operator::Log, F16, BF16, F32, F64);
|
||||
impl_unary_func!(Exp, exp, __expand_exp, Operator::Exp, F16, BF16, F32, F64);
|
||||
impl_unary_func!(Log, log, __expand_log, Operator::Log, F16, BF16, F32, F64);
|
||||
impl_unary_func!(
|
||||
Log1p,
|
||||
log1p,
|
||||
log1p_expand,
|
||||
__expand_log1p,
|
||||
Operator::Log1p,
|
||||
F16,
|
||||
BF16,
|
||||
F32,
|
||||
F64
|
||||
);
|
||||
impl_unary_func!(Cos, cos, cos_expand, Operator::Cos, F16, BF16, F32, F64);
|
||||
impl_unary_func!(Sin, sin, sin_expand, Operator::Sin, F16, BF16, F32, F64);
|
||||
impl_unary_func!(Tanh, tanh, tanh_expand, Operator::Tanh, F16, BF16, F32, F64);
|
||||
impl_unary_func!(Sqrt, sqrt, sqrt_expand, Operator::Sqrt, F16, BF16, F32, F64);
|
||||
impl_unary_func!(Cos, cos, __expand_cos, Operator::Cos, F16, BF16, F32, F64);
|
||||
impl_unary_func!(Sin, sin, __expand_sin, Operator::Sin, F16, BF16, F32, F64);
|
||||
impl_unary_func!(
|
||||
Tanh,
|
||||
tanh,
|
||||
__expand_tanh,
|
||||
Operator::Tanh,
|
||||
F16,
|
||||
BF16,
|
||||
F32,
|
||||
F64
|
||||
);
|
||||
impl_unary_func!(
|
||||
Sqrt,
|
||||
sqrt,
|
||||
__expand_sqrt,
|
||||
Operator::Sqrt,
|
||||
F16,
|
||||
BF16,
|
||||
F32,
|
||||
F64
|
||||
);
|
||||
impl_unary_func!(
|
||||
Floor,
|
||||
floor,
|
||||
floor_expand,
|
||||
__expand_floor,
|
||||
Operator::Floor,
|
||||
F16,
|
||||
BF16,
|
||||
F32,
|
||||
F64
|
||||
);
|
||||
impl_unary_func!(Ceil, ceil, ceil_expand, Operator::Ceil, F16, BF16, F32, F64);
|
||||
impl_unary_func!(Erf, erf, erf_expand, Operator::Erf, F16, BF16, F32, F64);
|
||||
impl_unary_func!(
|
||||
Ceil,
|
||||
ceil,
|
||||
__expand_ceil,
|
||||
Operator::Ceil,
|
||||
F16,
|
||||
BF16,
|
||||
F32,
|
||||
F64
|
||||
);
|
||||
impl_unary_func!(Erf, erf, __expand_erf, Operator::Erf, F16, BF16, F32, F64);
|
||||
impl_unary_func!(
|
||||
Recip,
|
||||
recip,
|
||||
recip_expand,
|
||||
__expand_recip,
|
||||
Operator::Recip,
|
||||
F16,
|
||||
BF16,
|
||||
|
|
|
@ -19,107 +19,143 @@ pub fn subcube_elect_expand<E: CubePrimitive>(context: &mut CubeContext) -> Expa
|
|||
output
|
||||
}
|
||||
|
||||
pub fn subcube_sum<E: CubePrimitive>(_elem: E) -> E {
|
||||
/// Perform a reduce sum operation across all units in a subcube.
|
||||
#[allow(unused_variables)]
|
||||
pub fn subcube_sum<E: CubePrimitive>(value: E) -> E {
|
||||
unexpanded!()
|
||||
}
|
||||
|
||||
pub fn subcube_sum_expand<E: CubePrimitive>(
|
||||
context: &mut CubeContext,
|
||||
elem: ExpandElement,
|
||||
) -> ExpandElement {
|
||||
let output = context.create_local(elem.item());
|
||||
/// Module containing the expand function for [subcube_sum()].
|
||||
pub mod subcube_sum {
|
||||
use super::*;
|
||||
|
||||
let out = *output;
|
||||
let input = *elem;
|
||||
/// Expand method of [subcube_sum()].
|
||||
pub fn __expand<E: CubePrimitive>(
|
||||
context: &mut CubeContext,
|
||||
elem: ExpandElement,
|
||||
) -> ExpandElement {
|
||||
let output = context.create_local(elem.item());
|
||||
|
||||
context.register(Operation::Subcube(Subcube::Sum(UnaryOperator {
|
||||
input,
|
||||
out,
|
||||
})));
|
||||
let out = *output;
|
||||
let input = *elem;
|
||||
|
||||
output
|
||||
context.register(Operation::Subcube(Subcube::Sum(UnaryOperator {
|
||||
input,
|
||||
out,
|
||||
})));
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a reduce prod operation across all units in a subcube.
|
||||
pub fn subcube_prod<E: CubePrimitive>(_elem: E) -> E {
|
||||
unexpanded!()
|
||||
}
|
||||
|
||||
pub fn subcube_prod_expand<E: CubePrimitive>(
|
||||
context: &mut CubeContext,
|
||||
elem: ExpandElement,
|
||||
) -> ExpandElement {
|
||||
let output = context.create_local(elem.item());
|
||||
/// Module containing the expand function for [subcube_prod()].
|
||||
pub mod subcube_prod {
|
||||
use super::*;
|
||||
|
||||
let out = *output;
|
||||
let input = *elem;
|
||||
/// Expand method of [subcube_prod()].
|
||||
pub fn __expand<E: CubePrimitive>(
|
||||
context: &mut CubeContext,
|
||||
elem: ExpandElement,
|
||||
) -> ExpandElement {
|
||||
let output = context.create_local(elem.item());
|
||||
|
||||
context.register(Operation::Subcube(Subcube::Prod(UnaryOperator {
|
||||
input,
|
||||
out,
|
||||
})));
|
||||
let out = *output;
|
||||
let input = *elem;
|
||||
|
||||
output
|
||||
context.register(Operation::Subcube(Subcube::Prod(UnaryOperator {
|
||||
input,
|
||||
out,
|
||||
})));
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a reduce max operation across all units in a subcube.
|
||||
pub fn subcube_max<E: CubePrimitive>(_elem: E) -> E {
|
||||
unexpanded!()
|
||||
}
|
||||
|
||||
pub fn subcube_max_expand<E: CubePrimitive>(
|
||||
context: &mut CubeContext,
|
||||
elem: ExpandElement,
|
||||
) -> ExpandElement {
|
||||
let output = context.create_local(elem.item());
|
||||
/// Module containing the expand function for [subcube_max()].
|
||||
pub mod subcube_max {
|
||||
use super::*;
|
||||
|
||||
let out = *output;
|
||||
let input = *elem;
|
||||
/// Expand method of [subcube_max()].
|
||||
pub fn __expand<E: CubePrimitive>(
|
||||
context: &mut CubeContext,
|
||||
elem: ExpandElement,
|
||||
) -> ExpandElement {
|
||||
let output = context.create_local(elem.item());
|
||||
|
||||
context.register(Operation::Subcube(Subcube::Max(UnaryOperator {
|
||||
input,
|
||||
out,
|
||||
})));
|
||||
let out = *output;
|
||||
let input = *elem;
|
||||
|
||||
output
|
||||
context.register(Operation::Subcube(Subcube::Max(UnaryOperator {
|
||||
input,
|
||||
out,
|
||||
})));
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a reduce min operation across all units in a subcube.
|
||||
pub fn subcube_min<E: CubePrimitive>(_elem: E) -> E {
|
||||
unexpanded!()
|
||||
}
|
||||
|
||||
pub fn subcube_min_expand<E: CubePrimitive>(
|
||||
context: &mut CubeContext,
|
||||
elem: ExpandElement,
|
||||
) -> ExpandElement {
|
||||
let output = context.create_local(elem.item());
|
||||
/// Module containing the expand function for [subcube_min()].
|
||||
pub mod subcube_min {
|
||||
use super::*;
|
||||
|
||||
let out = *output;
|
||||
let input = *elem;
|
||||
/// Expand method of [subcube_min()].
|
||||
pub fn __expand<E: CubePrimitive>(
|
||||
context: &mut CubeContext,
|
||||
elem: ExpandElement,
|
||||
) -> ExpandElement {
|
||||
let output = context.create_local(elem.item());
|
||||
|
||||
context.register(Operation::Subcube(Subcube::Min(UnaryOperator {
|
||||
input,
|
||||
out,
|
||||
})));
|
||||
let out = *output;
|
||||
let input = *elem;
|
||||
|
||||
output
|
||||
context.register(Operation::Subcube(Subcube::Min(UnaryOperator {
|
||||
input,
|
||||
out,
|
||||
})));
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
|
||||
/// Perform a reduce all operation across all units in a subcube.
|
||||
pub fn subcube_all<E: CubePrimitive>(_elem: E) -> E {
|
||||
unexpanded!()
|
||||
}
|
||||
|
||||
pub fn subcube_all_expand<E: CubePrimitive>(
|
||||
context: &mut CubeContext,
|
||||
elem: ExpandElement,
|
||||
) -> ExpandElement {
|
||||
let output = context.create_local(elem.item());
|
||||
/// Module containing the expand function for [subcube_all()].
|
||||
pub mod subcube_all {
|
||||
use super::*;
|
||||
|
||||
let out = *output;
|
||||
let input = *elem;
|
||||
/// Expand method of [subcube_all()].
|
||||
pub fn __expand<E: CubePrimitive>(
|
||||
context: &mut CubeContext,
|
||||
elem: ExpandElement,
|
||||
) -> ExpandElement {
|
||||
let output = context.create_local(elem.item());
|
||||
|
||||
context.register(Operation::Subcube(Subcube::All(UnaryOperator {
|
||||
input,
|
||||
out,
|
||||
})));
|
||||
let out = *output;
|
||||
let input = *elem;
|
||||
|
||||
output
|
||||
context.register(Operation::Subcube(Subcube::All(UnaryOperator {
|
||||
input,
|
||||
out,
|
||||
})));
|
||||
|
||||
output
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,6 +3,10 @@ use crate::ir::Synchronization;
|
|||
|
||||
pub fn sync_units() {}
|
||||
|
||||
pub fn sync_units_expand(context: &mut CubeContext) {
|
||||
context.register(Synchronization::SyncUnits)
|
||||
pub mod sync_units {
|
||||
use super::*;
|
||||
|
||||
pub fn __expand(context: &mut CubeContext) {
|
||||
context.register(Synchronization::SyncUnits)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,7 +11,8 @@ pub use crate::runtime::Runtime;
|
|||
|
||||
/// Elements
|
||||
pub use crate::frontend::{
|
||||
Array, ArrayHandle, Float, LaunchArg, Slice, Tensor, TensorArg, UInt, F16, F32, F64, I32, I64,
|
||||
Array, ArrayHandle, Bool, Float, LaunchArg, Slice, SliceMut, Tensor, TensorArg, UInt, F16, F32,
|
||||
F64, I32, I64,
|
||||
};
|
||||
pub use crate::pod::CubeElement;
|
||||
|
||||
|
@ -23,10 +24,7 @@ pub use crate::frontend::{
|
|||
};
|
||||
|
||||
/// Export subcube operations.
|
||||
pub use crate::frontend::{
|
||||
subcube_all, subcube_all_expand, subcube_max, subcube_max_expand, subcube_min,
|
||||
subcube_min_expand, subcube_prod, subcube_prod_expand, subcube_sum, subcube_sum_expand,
|
||||
};
|
||||
pub use crate::frontend::{subcube_all, subcube_max, subcube_min, subcube_prod, subcube_sum};
|
||||
pub use burn_compute::client::ComputeClient;
|
||||
|
||||
pub use crate::frontend::*;
|
||||
|
|
|
@ -64,7 +64,7 @@ pub fn test_simple_1<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
|
|||
let rhs = client.create(f16::as_bytes(&rhs));
|
||||
let out = client.empty(core::mem::size_of::<f32>() * 256);
|
||||
|
||||
kernel_simple_1_launch::<R>(
|
||||
kernel_simple_1::launch::<R>(
|
||||
client.clone(),
|
||||
CubeCount::Static(1, 1, 1),
|
||||
CubeDim::new(16, 16, 1),
|
||||
|
|
|
@ -18,7 +18,7 @@ pub fn kernel_without_generics(output: &mut Array<F32>) {
|
|||
pub fn test_kernel_with_generics<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
|
||||
let handle = client.create(f32::as_bytes(&[0.0, 1.0]));
|
||||
|
||||
kernel_with_generics_launch::<F32, R>(
|
||||
kernel_with_generics::launch::<F32, R>(
|
||||
client.clone(),
|
||||
CubeCount::Static(1, 1, 1),
|
||||
CubeDim::default(),
|
||||
|
@ -34,7 +34,7 @@ pub fn test_kernel_with_generics<R: Runtime>(client: ComputeClient<R::Server, R:
|
|||
pub fn test_kernel_without_generics<R: Runtime>(client: ComputeClient<R::Server, R::Channel>) {
|
||||
let handle = client.create(f32::as_bytes(&[0.0, 1.0]));
|
||||
|
||||
kernel_without_generics_launch::<R>(
|
||||
kernel_without_generics::launch::<R>(
|
||||
client.clone(),
|
||||
CubeCount::Static(1, 1, 1),
|
||||
CubeDim::default(),
|
||||
|
|
|
@ -30,7 +30,7 @@ pub fn test_slice_select<R: Runtime>(client: ComputeClient<R::Server, R::Channel
|
|||
let input = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0]));
|
||||
let output = client.empty(core::mem::size_of::<f32>());
|
||||
|
||||
slice_select_launch::<F32, R>(
|
||||
slice_select::launch::<F32, R>(
|
||||
client.clone(),
|
||||
CubeCount::Static(1, 1, 1),
|
||||
CubeDim::new(1, 1, 1),
|
||||
|
@ -48,7 +48,7 @@ pub fn test_slice_len<R: Runtime>(client: ComputeClient<R::Server, R::Channel>)
|
|||
let input = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0]));
|
||||
let output = client.empty(core::mem::size_of::<u32>());
|
||||
|
||||
slice_len_launch::<F32, R>(
|
||||
slice_len::launch::<F32, R>(
|
||||
client.clone(),
|
||||
CubeCount::Static(1, 1, 1),
|
||||
CubeDim::new(1, 1, 1),
|
||||
|
@ -66,7 +66,7 @@ pub fn test_slice_assign<R: Runtime>(client: ComputeClient<R::Server, R::Channel
|
|||
let input = client.create(f32::as_bytes(&[15.0]));
|
||||
let output = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0, 4.0]));
|
||||
|
||||
slice_assign_launch::<F32, R>(
|
||||
slice_assign::launch::<F32, R>(
|
||||
client.clone(),
|
||||
CubeCount::Static(1, 1, 1),
|
||||
CubeDim::new(1, 1, 1),
|
||||
|
|
|
@ -49,7 +49,7 @@ pub fn test_subcube_sum<TestRuntime: Runtime>(
|
|||
&[17.0, 5.0, 7.0, 1.0],
|
||||
client.clone(),
|
||||
|cube_count, cube_dim, handle| {
|
||||
kernel_sum_launch::<F32, TestRuntime>(client.clone(), cube_count, cube_dim, handle)
|
||||
kernel_sum::launch::<F32, TestRuntime>(client.clone(), cube_count, cube_dim, handle)
|
||||
},
|
||||
);
|
||||
}
|
||||
|
@ -62,7 +62,7 @@ pub fn test_subcube_prod<TestRuntime: Runtime>(
|
|||
&[140.0, 5.0, 7.0, 1.0],
|
||||
client.clone(),
|
||||
|cube_dim, settings, handle| {
|
||||
kernel_prod_launch::<F32, TestRuntime>(client.clone(), cube_dim, settings, handle)
|
||||
kernel_prod::launch::<F32, TestRuntime>(client.clone(), cube_dim, settings, handle)
|
||||
},
|
||||
);
|
||||
}
|
||||
|
@ -74,7 +74,7 @@ pub fn test_subcube_max<TestRuntime: Runtime>(
|
|||
&[7.0, 5.0, 7.0, 1.0],
|
||||
client.clone(),
|
||||
|cube_dim, settings, handle| {
|
||||
kernel_max_launch::<F32, TestRuntime>(client.clone(), cube_dim, settings, handle)
|
||||
kernel_max::launch::<F32, TestRuntime>(client.clone(), cube_dim, settings, handle)
|
||||
},
|
||||
);
|
||||
}
|
||||
|
@ -87,7 +87,7 @@ pub fn test_subcube_min<TestRuntime: Runtime>(
|
|||
&[1.0, 5.0, 7.0, 1.0],
|
||||
client.clone(),
|
||||
|cube_dim, settings, handle| {
|
||||
kernel_min_launch::<F32, TestRuntime>(client.clone(), cube_dim, settings, handle)
|
||||
kernel_min::launch::<F32, TestRuntime>(client.clone(), cube_dim, settings, handle)
|
||||
},
|
||||
);
|
||||
}
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
use burn_cube::prelude::*;
|
||||
|
||||
#[cube]
|
||||
fn array_read_write<T: Numeric>(array_size: Comptime<u32>) {
|
||||
pub fn array_read_write<T: Numeric>(array_size: Comptime<u32>) {
|
||||
let mut array = Array::<T>::new(array_size);
|
||||
array[0] = T::from_int(3);
|
||||
let _ = array[0];
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn array_to_vectorized_variable<T: Numeric>() -> T {
|
||||
pub fn array_to_vectorized_variable<T: Numeric>() -> T {
|
||||
let mut array = Array::<T>::new(2);
|
||||
array[0] = T::from_int(0);
|
||||
array[1] = T::from_int(1);
|
||||
|
@ -16,19 +16,19 @@ fn array_to_vectorized_variable<T: Numeric>() -> T {
|
|||
}
|
||||
|
||||
#[cube]
|
||||
fn array_of_one_to_vectorized_variable<T: Numeric>() -> T {
|
||||
pub fn array_of_one_to_vectorized_variable<T: Numeric>() -> T {
|
||||
let mut array = Array::<T>::new(1);
|
||||
array[0] = T::from_int(3);
|
||||
array.to_vectorized(Comptime::new(UInt::new(1)))
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn array_add_assign_simple(array: &mut Array<UInt>) {
|
||||
pub fn array_add_assign_simple(array: &mut Array<UInt>) {
|
||||
array[UInt::new(1)] += UInt::new(1);
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn array_add_assign_expr(array: &mut Array<UInt>) {
|
||||
pub fn array_add_assign_expr(array: &mut Array<UInt>) {
|
||||
array[UInt::new(1) + UInt::new(5)] += UInt::new(1);
|
||||
}
|
||||
|
||||
|
@ -45,7 +45,7 @@ mod tests {
|
|||
fn cube_support_array() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
array_read_write_expand::<ElemType>(&mut context, 512);
|
||||
array_read_write::__expand::<ElemType>(&mut context, 512);
|
||||
assert_eq!(
|
||||
context.into_scope().operations,
|
||||
inline_macro_ref_read_write()
|
||||
|
@ -57,7 +57,7 @@ mod tests {
|
|||
let mut context = CubeContext::root();
|
||||
let array = context.input(0, Item::new(Elem::UInt));
|
||||
|
||||
array_add_assign_simple_expand(&mut context, array.into());
|
||||
array_add_assign_simple::__expand(&mut context, array.into());
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(scope.operations, inline_macro_array_add_assign_simple());
|
||||
|
@ -67,7 +67,7 @@ mod tests {
|
|||
fn cube_array_to_vectorized() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
array_to_vectorized_variable_expand::<ElemType>(&mut context);
|
||||
array_to_vectorized_variable::__expand::<ElemType>(&mut context);
|
||||
assert_eq!(
|
||||
context.into_scope().operations,
|
||||
inline_macro_ref_to_vectorized()
|
||||
|
@ -78,7 +78,7 @@ mod tests {
|
|||
fn cube_array_of_one_to_vectorized() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
array_of_one_to_vectorized_variable_expand::<ElemType>(&mut context);
|
||||
array_of_one_to_vectorized_variable::__expand::<ElemType>(&mut context);
|
||||
assert_eq!(
|
||||
context.into_scope().operations,
|
||||
inline_macro_ref_one_to_vectorized()
|
||||
|
@ -110,7 +110,7 @@ mod tests {
|
|||
let mut context = CubeContext::root();
|
||||
let array = context.input(0, Item::new(Elem::UInt));
|
||||
|
||||
array_add_assign_expr_expand(&mut context, array.into());
|
||||
array_add_assign_expr::__expand(&mut context, array.into());
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(scope.operations, inline_macro_array_add_assign_expr());
|
||||
|
|
|
@ -1,27 +1,27 @@
|
|||
use burn_cube::prelude::*;
|
||||
|
||||
#[cube]
|
||||
fn mut_assign() {
|
||||
pub fn mut_assign() {
|
||||
let mut x = UInt::new(0);
|
||||
x += UInt::new(1);
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn mut_assign_input(y: UInt) -> UInt {
|
||||
pub fn mut_assign_input(y: UInt) -> UInt {
|
||||
let mut x = y;
|
||||
x += UInt::new(1);
|
||||
y + UInt::new(2)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn assign_mut_input(mut y: UInt) -> UInt {
|
||||
pub fn assign_mut_input(mut y: UInt) -> UInt {
|
||||
let x = y;
|
||||
y += UInt::new(1);
|
||||
x + UInt::new(2)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn assign_vectorized(y: UInt) -> UInt {
|
||||
pub fn assign_vectorized(y: UInt) -> UInt {
|
||||
let vectorization_factor = Comptime::vectorization(&y);
|
||||
let x = UInt::vectorized(1, Comptime::get(vectorization_factor));
|
||||
x + y
|
||||
|
@ -38,7 +38,7 @@ mod tests {
|
|||
fn cube_mut_assign_test() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
mut_assign_expand(&mut context);
|
||||
mut_assign::__expand(&mut context);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
@ -53,7 +53,7 @@ mod tests {
|
|||
|
||||
let y = context.create_local(Item::new(UInt::as_elem()));
|
||||
|
||||
mut_assign_input_expand(&mut context, y);
|
||||
mut_assign_input::__expand(&mut context, y);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
@ -68,7 +68,7 @@ mod tests {
|
|||
|
||||
let y = context.create_local(Item::new(UInt::as_elem()));
|
||||
|
||||
assign_mut_input_expand(&mut context, y);
|
||||
assign_mut_input::__expand(&mut context, y);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
@ -83,7 +83,7 @@ mod tests {
|
|||
|
||||
let y = context.create_local(Item::vectorized(UInt::as_elem(), 4));
|
||||
|
||||
assign_vectorized_expand(&mut context, y);
|
||||
assign_vectorized::__expand(&mut context, y);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use burn_cube::{
|
||||
cube,
|
||||
frontend::{Cast, Numeric, UInt, F32, I32},
|
||||
frontend::{Bool, Cast, Numeric, UInt, F32, I32},
|
||||
};
|
||||
|
||||
// From float
|
||||
|
@ -26,7 +26,7 @@ pub fn float_to_uint(x: F32) {
|
|||
#[allow(clippy::overly_complex_bool_expr)]
|
||||
pub fn float_to_bool(x: F32) {
|
||||
let y = x + F32::from_int(2);
|
||||
let _ = bool::cast_from(y) || true;
|
||||
let _ = Bool::cast_from(y) || true;
|
||||
}
|
||||
|
||||
// From int
|
||||
|
@ -53,7 +53,7 @@ pub fn int_to_uint(x: I32) {
|
|||
#[allow(clippy::overly_complex_bool_expr)]
|
||||
pub fn int_to_bool(x: I32) {
|
||||
let y = x + I32::from_int(2);
|
||||
let _ = bool::cast_from(y) || true;
|
||||
let _ = Bool::cast_from(y) || true;
|
||||
}
|
||||
|
||||
// // From uint
|
||||
|
@ -80,27 +80,27 @@ pub fn uint_to_uint(x: UInt) {
|
|||
#[allow(clippy::overly_complex_bool_expr)]
|
||||
pub fn uint_to_bool(x: UInt) {
|
||||
let y = x + UInt::from_int(2);
|
||||
let _ = bool::cast_from(y) || true;
|
||||
let _ = Bool::cast_from(y) || true;
|
||||
}
|
||||
|
||||
// From bool
|
||||
#[cube]
|
||||
#[allow(clippy::overly_complex_bool_expr)]
|
||||
pub fn bool_to_float(x: bool) {
|
||||
pub fn bool_to_float(x: Bool) {
|
||||
let y = x && false;
|
||||
let _ = F32::cast_from(y) + F32::from_int(34);
|
||||
}
|
||||
|
||||
#[cube]
|
||||
#[allow(clippy::overly_complex_bool_expr)]
|
||||
pub fn bool_to_int(x: bool) {
|
||||
pub fn bool_to_int(x: Bool) {
|
||||
let y = x && false;
|
||||
let _ = I32::cast_from(y) + I32::from_int(34);
|
||||
}
|
||||
|
||||
#[cube]
|
||||
#[allow(clippy::overly_complex_bool_expr)]
|
||||
pub fn bool_to_uint(x: bool) {
|
||||
pub fn bool_to_uint(x: Bool) {
|
||||
let y = x && false;
|
||||
let _ = UInt::cast_from(y) + UInt::from_int(34);
|
||||
}
|
||||
|
@ -108,9 +108,9 @@ pub fn bool_to_uint(x: bool) {
|
|||
#[cube]
|
||||
#[allow(clippy::overly_complex_bool_expr)]
|
||||
#[allow(clippy::useless_conversion)]
|
||||
pub fn bool_to_bool(x: bool) {
|
||||
pub fn bool_to_bool(x: Bool) {
|
||||
let y = x && false;
|
||||
let _ = bool::cast_from(y) || true;
|
||||
let _ = Bool::cast_from(y) || true;
|
||||
}
|
||||
|
||||
mod tests {
|
||||
|
@ -122,7 +122,7 @@ mod tests {
|
|||
};
|
||||
|
||||
macro_rules! cast_test {
|
||||
($name:ident, $module:ident, $from:expr, $to:expr) => {
|
||||
($name:ident, $module:expr, $from:expr, $to:expr) => {
|
||||
#[test]
|
||||
fn $name() {
|
||||
let mut context = CubeContext::root();
|
||||
|
@ -142,112 +142,112 @@ mod tests {
|
|||
|
||||
cast_test!(
|
||||
cube_float_to_float_test,
|
||||
float_to_float_expand,
|
||||
float_to_float::__expand,
|
||||
Item::new(F32::as_elem()),
|
||||
Item::new(F32::as_elem())
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_float_to_int_test,
|
||||
float_to_int_expand,
|
||||
float_to_int::__expand,
|
||||
Item::new(F32::as_elem()),
|
||||
Item::new(I32::as_elem())
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_float_to_uint_test,
|
||||
float_to_uint_expand,
|
||||
float_to_uint::__expand,
|
||||
Item::new(F32::as_elem()),
|
||||
Item::new(Elem::UInt)
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_float_to_bool_test,
|
||||
float_to_bool_expand,
|
||||
float_to_bool::__expand,
|
||||
Item::new(F32::as_elem()),
|
||||
Item::new(Elem::Bool)
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_int_to_float_test,
|
||||
int_to_float_expand,
|
||||
int_to_float::__expand,
|
||||
Item::new(I32::as_elem()),
|
||||
Item::new(F32::as_elem())
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_int_to_int_test,
|
||||
int_to_int_expand,
|
||||
int_to_int::__expand,
|
||||
Item::new(I32::as_elem()),
|
||||
Item::new(I32::as_elem())
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_int_to_uint_test,
|
||||
int_to_uint_expand,
|
||||
int_to_uint::__expand,
|
||||
Item::new(I32::as_elem()),
|
||||
Item::new(Elem::UInt)
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_int_to_bool_test,
|
||||
int_to_bool_expand,
|
||||
int_to_bool::__expand,
|
||||
Item::new(I32::as_elem()),
|
||||
Item::new(Elem::Bool)
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_uint_to_float_test,
|
||||
uint_to_float_expand,
|
||||
uint_to_float::__expand,
|
||||
Item::new(Elem::UInt),
|
||||
Item::new(F32::as_elem())
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_uint_to_int_test,
|
||||
uint_to_int_expand,
|
||||
uint_to_int::__expand,
|
||||
Item::new(Elem::UInt),
|
||||
Item::new(I32::as_elem())
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_uint_to_uint_test,
|
||||
uint_to_uint_expand,
|
||||
uint_to_uint::__expand,
|
||||
Item::new(Elem::UInt),
|
||||
Item::new(Elem::UInt)
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_uint_to_bool_test,
|
||||
uint_to_bool_expand,
|
||||
uint_to_bool::__expand,
|
||||
Item::new(Elem::UInt),
|
||||
Item::new(Elem::Bool)
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_bool_to_float_test,
|
||||
bool_to_float_expand,
|
||||
bool_to_float::__expand,
|
||||
Item::new(Elem::Bool),
|
||||
Item::new(F32::as_elem())
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_bool_to_int_test,
|
||||
bool_to_int_expand,
|
||||
bool_to_int::__expand,
|
||||
Item::new(Elem::Bool),
|
||||
Item::new(I32::as_elem())
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_bool_to_uint_test,
|
||||
bool_to_uint_expand,
|
||||
bool_to_uint::__expand,
|
||||
Item::new(Elem::Bool),
|
||||
Item::new(Elem::UInt)
|
||||
);
|
||||
|
||||
cast_test!(
|
||||
cube_bool_to_bool_test,
|
||||
bool_to_bool_expand,
|
||||
bool_to_bool::__expand,
|
||||
Item::new(Elem::Bool),
|
||||
Item::new(Elem::Bool)
|
||||
);
|
||||
|
|
|
@ -46,7 +46,7 @@ mod tests {
|
|||
|
||||
let input = context.create_local(item);
|
||||
|
||||
cast_float_kind_expand::<F64, F32>(&mut context, input);
|
||||
cast_float_kind::__expand::<F64, F32>(&mut context, input);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_float());
|
||||
|
@ -59,7 +59,7 @@ mod tests {
|
|||
|
||||
let input = context.create_local(item);
|
||||
|
||||
cast_int_kind_expand::<I32, I64>(&mut context, input);
|
||||
cast_int_kind::__expand::<I32, I64>(&mut context, input);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int());
|
||||
|
@ -72,7 +72,7 @@ mod tests {
|
|||
|
||||
let input = context.create_local(item);
|
||||
|
||||
cast_numeric_to_kind_expand::<I32, I64>(&mut context, input);
|
||||
cast_numeric_to_kind::__expand::<I32, I64>(&mut context, input);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int());
|
||||
|
@ -85,7 +85,7 @@ mod tests {
|
|||
|
||||
let input = context.create_local(item);
|
||||
|
||||
cast_int_to_numeric_expand::<I32, I64>(&mut context, input);
|
||||
cast_int_to_numeric::__expand::<I32, I64>(&mut context, input);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int());
|
||||
|
|
|
@ -122,7 +122,7 @@ mod tests {
|
|||
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
comptime_if_else_expand::<ElemType>(&mut context, lhs, true);
|
||||
comptime_if_else::__expand::<ElemType>(&mut context, lhs, true);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
@ -137,7 +137,7 @@ mod tests {
|
|||
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
comptime_if_expr_expand::<ElemType>(&mut context, lhs, UInt::new(4), UInt::new(5));
|
||||
comptime_if_expr::__expand::<ElemType>(&mut context, lhs, UInt::new(4), UInt::new(5));
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
@ -152,7 +152,7 @@ mod tests {
|
|||
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
comptime_if_else_expand::<ElemType>(&mut context, lhs, false);
|
||||
comptime_if_else::__expand::<ElemType>(&mut context, lhs, false);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
@ -167,12 +167,12 @@ mod tests {
|
|||
for cond2 in [false, true] {
|
||||
let mut context1 = CubeContext::root();
|
||||
let lhs = context1.create_local(Item::new(ElemType::as_elem()));
|
||||
comptime_else_then_if_expand::<ElemType>(&mut context1, lhs, cond1, cond2);
|
||||
comptime_else_then_if::__expand::<ElemType>(&mut context1, lhs, cond1, cond2);
|
||||
let scope1 = context1.into_scope();
|
||||
|
||||
let mut context2 = CubeContext::root();
|
||||
let lhs = context2.create_local(Item::new(ElemType::as_elem()));
|
||||
comptime_elsif_expand::<ElemType>(&mut context2, lhs, cond1, cond2);
|
||||
comptime_elsif::__expand::<ElemType>(&mut context2, lhs, cond1, cond2);
|
||||
let scope2 = context2.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
@ -188,7 +188,7 @@ mod tests {
|
|||
for cond in [false, true] {
|
||||
let mut context = CubeContext::root();
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
comptime_elsif_with_runtime1_expand::<ElemType>(&mut context, lhs, cond);
|
||||
comptime_elsif_with_runtime1::__expand::<ElemType>(&mut context, lhs, cond);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
@ -203,7 +203,7 @@ mod tests {
|
|||
for cond in [false, true] {
|
||||
let mut context = CubeContext::root();
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
comptime_elsif_with_runtime2_expand::<ElemType>(&mut context, lhs, cond);
|
||||
comptime_elsif_with_runtime2::__expand::<ElemType>(&mut context, lhs, cond);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
@ -227,8 +227,8 @@ mod tests {
|
|||
bound: 4,
|
||||
};
|
||||
|
||||
comptime_with_map_bool_expand::<ElemType>(&mut context1, comptime_state_true);
|
||||
comptime_with_map_bool_expand::<ElemType>(&mut context2, comptime_state_false);
|
||||
comptime_with_map_bool::__expand::<ElemType>(&mut context1, comptime_state_true);
|
||||
comptime_with_map_bool::__expand::<ElemType>(&mut context2, comptime_state_false);
|
||||
|
||||
let scope1 = context1.into_scope();
|
||||
let scope2 = context2.into_scope();
|
||||
|
@ -248,7 +248,7 @@ mod tests {
|
|||
bound: 4,
|
||||
};
|
||||
|
||||
comptime_with_map_uint_expand::<ElemType>(&mut context, comptime_state);
|
||||
comptime_with_map_uint::__expand::<ElemType>(&mut context, comptime_state);
|
||||
|
||||
let scope = context.into_scope();
|
||||
|
||||
|
|
|
@ -42,12 +42,12 @@ impl<C: Float> CombinedTraitFunctionGeneric<C> for Test {
|
|||
}
|
||||
|
||||
#[cube]
|
||||
fn simple<C: Float>(lhs: C, rhs: C) -> C {
|
||||
pub fn simple<C: Float>(lhs: C, rhs: C) -> C {
|
||||
lhs + rhs
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn with_cast<C: Float, O: Numeric>(lhs: C, rhs: C) -> O {
|
||||
pub fn with_cast<C: Float, O: Numeric>(lhs: C, rhs: C) -> O {
|
||||
O::cast_from(lhs + rhs)
|
||||
}
|
||||
|
||||
|
@ -62,7 +62,7 @@ mod tests {
|
|||
let lhs = context.create_local(Item::new(F32::as_elem()));
|
||||
let rhs = context.create_local(Item::new(F32::as_elem()));
|
||||
|
||||
<Test as FunctionGeneric>::test_expand::<F32>(&mut context, lhs, rhs);
|
||||
<Test as FunctionGeneric>::__expand_test::<F32>(&mut context, lhs, rhs);
|
||||
|
||||
assert_eq!(simple_scope(), context.into_scope());
|
||||
}
|
||||
|
@ -73,7 +73,7 @@ mod tests {
|
|||
let lhs = context.create_local(Item::new(F32::as_elem()));
|
||||
let rhs = context.create_local(Item::new(F32::as_elem()));
|
||||
|
||||
<Test as TraitGeneric<F32>>::test_expand(&mut context, lhs, rhs);
|
||||
<Test as TraitGeneric<F32>>::__expand_test(&mut context, lhs, rhs);
|
||||
|
||||
assert_eq!(simple_scope(), context.into_scope());
|
||||
}
|
||||
|
@ -84,7 +84,7 @@ mod tests {
|
|||
let lhs = context.create_local(Item::new(F32::as_elem()));
|
||||
let rhs = context.create_local(Item::new(F32::as_elem()));
|
||||
|
||||
<Test as CombinedTraitFunctionGeneric<F32>>::test_expand::<UInt>(&mut context, lhs, rhs);
|
||||
<Test as CombinedTraitFunctionGeneric<F32>>::__expand_test::<UInt>(&mut context, lhs, rhs);
|
||||
|
||||
assert_eq!(with_cast_scope(), context.into_scope());
|
||||
}
|
||||
|
@ -94,7 +94,7 @@ mod tests {
|
|||
let lhs = context_ref.create_local(Item::new(F32::as_elem()));
|
||||
let rhs = context_ref.create_local(Item::new(F32::as_elem()));
|
||||
|
||||
simple_expand::<F32>(&mut context_ref, lhs, rhs);
|
||||
simple::__expand::<F32>(&mut context_ref, lhs, rhs);
|
||||
context_ref.into_scope()
|
||||
}
|
||||
|
||||
|
@ -103,7 +103,7 @@ mod tests {
|
|||
let lhs = context_ref.create_local(Item::new(F32::as_elem()));
|
||||
let rhs = context_ref.create_local(Item::new(F32::as_elem()));
|
||||
|
||||
with_cast_expand::<F32, UInt>(&mut context_ref, lhs, rhs);
|
||||
with_cast::__expand::<F32, UInt>(&mut context_ref, lhs, rhs);
|
||||
context_ref.into_scope()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -30,7 +30,7 @@ mod tests {
|
|||
let rhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
let end = 4u32.into();
|
||||
|
||||
for_loop_expand::<ElemType>(&mut context, lhs.into(), rhs, end, unroll);
|
||||
for_loop::__expand::<ElemType>(&mut context, lhs.into(), rhs, end, unroll);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(unroll));
|
||||
|
@ -45,7 +45,7 @@ mod tests {
|
|||
let rhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
let end = 4u32.into();
|
||||
|
||||
for_loop_expand::<ElemType>(&mut context, lhs.into(), rhs, end, unroll);
|
||||
for_loop::__expand::<ElemType>(&mut context, lhs.into(), rhs, end, unroll);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(unroll));
|
||||
|
|
|
@ -59,12 +59,12 @@ mod tests {
|
|||
fn cube_call_equivalent_to_no_call_no_arg_test() {
|
||||
let mut caller_context = CubeContext::root();
|
||||
let x = caller_context.create_local(Item::new(Elem::UInt));
|
||||
caller_no_arg_expand(&mut caller_context, x);
|
||||
caller_no_arg::__expand(&mut caller_context, x);
|
||||
let caller_scope = caller_context.into_scope();
|
||||
|
||||
let mut no_call_context = CubeContext::root();
|
||||
let x = no_call_context.create_local(Item::new(Elem::UInt));
|
||||
no_call_no_arg_expand(&mut no_call_context, x);
|
||||
no_call_no_arg::__expand(&mut no_call_context, x);
|
||||
let no_call_scope = no_call_context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
@ -78,12 +78,12 @@ mod tests {
|
|||
let mut caller_context = CubeContext::root();
|
||||
|
||||
let x = caller_context.create_local(Item::new(Elem::UInt));
|
||||
caller_with_arg_expand(&mut caller_context, x);
|
||||
caller_with_arg::__expand(&mut caller_context, x);
|
||||
let caller_scope = caller_context.into_scope();
|
||||
|
||||
let mut no_call_context = CubeContext::root();
|
||||
let x = no_call_context.create_local(Item::new(Elem::UInt));
|
||||
no_call_with_arg_expand(&mut no_call_context, x);
|
||||
no_call_with_arg::__expand(&mut no_call_context, x);
|
||||
let no_call_scope = no_call_context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
@ -97,12 +97,12 @@ mod tests {
|
|||
let mut caller_context = CubeContext::root();
|
||||
type ElemType = I64;
|
||||
let x = caller_context.create_local(Item::new(ElemType::as_elem()));
|
||||
caller_with_generics_expand::<ElemType>(&mut caller_context, x);
|
||||
caller_with_generics::__expand::<ElemType>(&mut caller_context, x);
|
||||
let caller_scope = caller_context.into_scope();
|
||||
|
||||
let mut no_call_context = CubeContext::root();
|
||||
let x = no_call_context.create_local(Item::new(ElemType::as_elem()));
|
||||
no_call_with_generics_expand::<ElemType>(&mut no_call_context, x);
|
||||
no_call_with_generics::__expand::<ElemType>(&mut no_call_context, x);
|
||||
let no_call_scope = no_call_context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
|
|
@ -20,7 +20,7 @@ mod tests {
|
|||
|
||||
let lhs = context.create_local(Item::new(F32::as_elem()));
|
||||
|
||||
generic_kernel_expand::<F32>(&mut context, lhs);
|
||||
generic_kernel::__expand::<F32>(&mut context, lhs);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_float());
|
||||
|
@ -32,7 +32,7 @@ mod tests {
|
|||
|
||||
let lhs = context.create_local(Item::new(I32::as_elem()));
|
||||
|
||||
generic_kernel_expand::<I32>(&mut context, lhs);
|
||||
generic_kernel::__expand::<I32>(&mut context, lhs);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_int());
|
||||
|
|
|
@ -52,7 +52,7 @@ mod tests {
|
|||
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
if_greater_expand::<ElemType>(&mut context, lhs);
|
||||
if_greater::__expand::<ElemType>(&mut context, lhs);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_if());
|
||||
|
@ -64,7 +64,7 @@ mod tests {
|
|||
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
if_then_else_expand::<ElemType>(&mut context, lhs);
|
||||
if_then_else::__expand::<ElemType>(&mut context, lhs);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
@ -79,7 +79,7 @@ mod tests {
|
|||
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
elsif_expand::<ElemType>(&mut context, lhs);
|
||||
elsif::__expand::<ElemType>(&mut context, lhs);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_elsif());
|
||||
|
|
|
@ -25,7 +25,7 @@ mod tests {
|
|||
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
literal_expand::<ElemType>(&mut context, lhs);
|
||||
literal::__expand::<ElemType>(&mut context, lhs);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref());
|
||||
|
@ -37,7 +37,7 @@ mod tests {
|
|||
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
literal_float_no_decimals_expand::<ElemType>(&mut context, lhs);
|
||||
literal_float_no_decimals::__expand::<ElemType>(&mut context, lhs);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref());
|
||||
|
|
|
@ -42,7 +42,7 @@ mod tests {
|
|||
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
while_not_expand::<ElemType>(&mut context, lhs);
|
||||
while_not::__expand::<ElemType>(&mut context, lhs);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(false));
|
||||
|
@ -54,7 +54,7 @@ mod tests {
|
|||
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
manual_loop_break_expand::<ElemType>(&mut context, lhs);
|
||||
manual_loop_break::__expand::<ElemType>(&mut context, lhs);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(false));
|
||||
|
@ -66,7 +66,7 @@ mod tests {
|
|||
|
||||
let lhs = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
loop_with_return_expand::<ElemType>(&mut context, lhs);
|
||||
loop_with_return::__expand::<ElemType>(&mut context, lhs);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref(true));
|
||||
|
|
|
@ -33,12 +33,12 @@ mod tests {
|
|||
fn cube_call_equivalent_to_no_call_no_arg_test() {
|
||||
let mut caller_context = CubeContext::root();
|
||||
let x = caller_context.create_local(Item::new(ElemType::as_elem()));
|
||||
here::caller_expand::<ElemType>(&mut caller_context, x);
|
||||
here::caller::__expand::<ElemType>(&mut caller_context, x);
|
||||
let caller_scope = caller_context.into_scope();
|
||||
|
||||
let mut no_call_context = CubeContext::root();
|
||||
let x = no_call_context.create_local(Item::new(ElemType::as_elem()));
|
||||
here::no_call_ref_expand::<ElemType>(&mut no_call_context, x);
|
||||
here::no_call_ref::__expand::<ElemType>(&mut no_call_context, x);
|
||||
let no_call_scope = no_call_context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
|
|
@ -1,192 +1,192 @@
|
|||
use burn_cube::prelude::*;
|
||||
|
||||
#[cube]
|
||||
fn add_op<T: Numeric>(a: T, b: T) -> T {
|
||||
pub fn add_op<T: Numeric>(a: T, b: T) -> T {
|
||||
a + b
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn sub_op<T: Numeric>(a: T, b: T) -> T {
|
||||
pub fn sub_op<T: Numeric>(a: T, b: T) -> T {
|
||||
a - b
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn mul_op<T: Numeric>(a: T, b: T) -> T {
|
||||
pub fn mul_op<T: Numeric>(a: T, b: T) -> T {
|
||||
a * b
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn div_op<T: Numeric>(a: T, b: T) -> T {
|
||||
pub fn div_op<T: Numeric>(a: T, b: T) -> T {
|
||||
a / b
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn abs_op<T: Numeric>(a: T) -> T {
|
||||
pub fn abs_op<T: Numeric>(a: T) -> T {
|
||||
T::abs(a)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn exp_op<F: Float>(a: F) -> F {
|
||||
pub fn exp_op<F: Float>(a: F) -> F {
|
||||
F::exp(a)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn log_op<F: Float>(a: F) -> F {
|
||||
pub fn log_op<F: Float>(a: F) -> F {
|
||||
F::log(a)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn log1p_op<F: Float>(a: F) -> F {
|
||||
pub fn log1p_op<F: Float>(a: F) -> F {
|
||||
F::log1p(a)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn cos_op<F: Float>(a: F) -> F {
|
||||
pub fn cos_op<F: Float>(a: F) -> F {
|
||||
F::cos(a)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn sin_op<F: Float>(a: F) -> F {
|
||||
pub fn sin_op<F: Float>(a: F) -> F {
|
||||
F::sin(a)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn tanh_op<F: Float>(a: F) -> F {
|
||||
pub fn tanh_op<F: Float>(a: F) -> F {
|
||||
F::tanh(a)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn powf_op<F: Float>(a: F, b: F) -> F {
|
||||
pub fn powf_op<F: Float>(a: F, b: F) -> F {
|
||||
F::powf(a, b)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn sqrt_op<F: Float>(a: F) -> F {
|
||||
pub fn sqrt_op<F: Float>(a: F) -> F {
|
||||
F::sqrt(a)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn floor_op<F: Float>(a: F) -> F {
|
||||
pub fn floor_op<F: Float>(a: F) -> F {
|
||||
F::floor(a)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn ceil_op<F: Float>(a: F) -> F {
|
||||
pub fn ceil_op<F: Float>(a: F) -> F {
|
||||
F::ceil(a)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn erf_op<F: Float>(a: F) -> F {
|
||||
pub fn erf_op<F: Float>(a: F) -> F {
|
||||
F::erf(a)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn recip_op<F: Float>(a: F) -> F {
|
||||
pub fn recip_op<F: Float>(a: F) -> F {
|
||||
F::recip(a)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn equal_op<T: CubePrimitive>(a: T, b: T) -> bool {
|
||||
pub fn equal_op<T: CubePrimitive>(a: T, b: T) -> bool {
|
||||
a == b
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn not_equal_op<T: CubePrimitive>(a: T, b: T) -> bool {
|
||||
pub fn not_equal_op<T: CubePrimitive>(a: T, b: T) -> bool {
|
||||
a != b
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn lower_op<T: Numeric>(a: T, b: T) -> bool {
|
||||
pub fn lower_op<T: Numeric>(a: T, b: T) -> bool {
|
||||
a < b
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn greater_op<T: Numeric>(a: T, b: T) -> bool {
|
||||
pub fn greater_op<T: Numeric>(a: T, b: T) -> bool {
|
||||
a > b
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn lower_equal_op<T: Numeric>(a: T, b: T) -> bool {
|
||||
pub fn lower_equal_op<T: Numeric>(a: T, b: T) -> bool {
|
||||
a <= b
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn greater_equal_op<T: Numeric>(a: T, b: T) -> bool {
|
||||
pub fn greater_equal_op<T: Numeric>(a: T, b: T) -> bool {
|
||||
a >= b
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn modulo_op(a: UInt, b: UInt) -> UInt {
|
||||
pub fn modulo_op(a: UInt, b: UInt) -> UInt {
|
||||
a % b
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn remainder_op<T: Numeric>(a: T, b: T) -> T {
|
||||
pub fn remainder_op<T: Numeric>(a: T, b: T) -> T {
|
||||
T::rem(a, b)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn max_op<T: Numeric>(a: T, b: T) -> T {
|
||||
pub fn max_op<T: Numeric>(a: T, b: T) -> T {
|
||||
T::max(a, b)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn min_op<T: Numeric>(a: T, b: T) -> T {
|
||||
pub fn min_op<T: Numeric>(a: T, b: T) -> T {
|
||||
T::min(a, b)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn and_op(a: bool, b: bool) -> bool {
|
||||
pub fn and_op(a: bool, b: bool) -> bool {
|
||||
a && b
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn or_op(a: bool, b: bool) -> bool {
|
||||
pub fn or_op(a: bool, b: bool) -> bool {
|
||||
a || b
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn not_op(a: bool) -> bool {
|
||||
pub fn not_op(a: bool) -> bool {
|
||||
!a
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn bitand_op(a: UInt, b: UInt) -> UInt {
|
||||
pub fn bitand_op(a: UInt, b: UInt) -> UInt {
|
||||
a & b
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn bitxor_op(a: UInt, b: UInt) -> UInt {
|
||||
pub fn bitxor_op(a: UInt, b: UInt) -> UInt {
|
||||
a ^ b
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn shl_op(a: UInt, b: UInt) -> UInt {
|
||||
pub fn shl_op(a: UInt, b: UInt) -> UInt {
|
||||
a << b
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn shr_op(a: UInt, b: UInt) -> UInt {
|
||||
pub fn shr_op(a: UInt, b: UInt) -> UInt {
|
||||
a >> b
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn add_assign_op<T: Numeric>(mut a: T, b: T) {
|
||||
pub fn add_assign_op<T: Numeric>(mut a: T, b: T) {
|
||||
a += b;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn sub_assign_op<T: Numeric>(mut a: T, b: T) {
|
||||
pub fn sub_assign_op<T: Numeric>(mut a: T, b: T) {
|
||||
a -= b;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn mul_assign_op<T: Numeric>(mut a: T, b: T) {
|
||||
pub fn mul_assign_op<T: Numeric>(mut a: T, b: T) {
|
||||
a *= b;
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn div_assign_op<T: Numeric>(mut a: T, b: T) {
|
||||
pub fn div_assign_op<T: Numeric>(mut a: T, b: T) {
|
||||
a /= b;
|
||||
}
|
||||
|
||||
|
@ -195,14 +195,14 @@ mod tests {
|
|||
use burn_cube::ir::{Elem, FloatKind, Item};
|
||||
|
||||
macro_rules! binary_test {
|
||||
($test_name:ident, $op_expand:ident, $op_name:expr, $func:ident) => {
|
||||
($test_name:ident, $op_expand:expr, $op_name:expr, $func:ident) => {
|
||||
#[test]
|
||||
fn $test_name() {
|
||||
let mut context = CubeContext::root();
|
||||
let x = context.create_local(Item::new(Elem::Float(FloatKind::F32)));
|
||||
let y = context.create_local(Item::new(Elem::Float(FloatKind::F32)));
|
||||
|
||||
$op_expand::<F32>(&mut context, x, y);
|
||||
$op_expand(&mut context, x, y);
|
||||
|
||||
assert_eq!(
|
||||
format!("{:?}", context.into_scope().operations),
|
||||
|
@ -213,13 +213,13 @@ mod tests {
|
|||
}
|
||||
|
||||
macro_rules! unary_test {
|
||||
($test_name:ident, $op_expand:ident, $op_name:expr) => {
|
||||
($test_name:ident, $op_expand:expr, $op_name:expr) => {
|
||||
#[test]
|
||||
fn $test_name() {
|
||||
let mut context = CubeContext::root();
|
||||
let x = context.create_local(Item::new(Elem::Float(FloatKind::F32)));
|
||||
|
||||
$op_expand::<F32>(&mut context, x);
|
||||
$op_expand(&mut context, x);
|
||||
|
||||
assert_eq!(
|
||||
format!("{:?}", context.into_scope().operations),
|
||||
|
@ -230,7 +230,7 @@ mod tests {
|
|||
}
|
||||
|
||||
macro_rules! binary_boolean_test {
|
||||
($test_name:ident, $op_expand:ident, $op_name:expr) => {
|
||||
($test_name:ident, $op_expand:expr, $op_name:expr) => {
|
||||
#[test]
|
||||
fn $test_name() {
|
||||
let mut context = CubeContext::root();
|
||||
|
@ -248,7 +248,7 @@ mod tests {
|
|||
}
|
||||
|
||||
macro_rules! binary_uint_test {
|
||||
($test_name:ident, $op_expand:ident, $op_name:expr) => {
|
||||
($test_name:ident, $op_expand:expr, $op_name:expr) => {
|
||||
#[test]
|
||||
fn $test_name() {
|
||||
let mut context = CubeContext::root();
|
||||
|
@ -265,75 +265,90 @@ mod tests {
|
|||
};
|
||||
}
|
||||
|
||||
binary_test!(cube_can_add, add_op_expand, "Add", ref_ops_binary);
|
||||
binary_test!(cube_can_sub, sub_op_expand, "Sub", ref_ops_binary);
|
||||
binary_test!(cube_can_mul, mul_op_expand, "Mul", ref_ops_binary);
|
||||
binary_test!(cube_can_div, div_op_expand, "Div", ref_ops_binary);
|
||||
unary_test!(cube_can_abs, abs_op_expand, "Abs");
|
||||
unary_test!(cube_can_exp, exp_op_expand, "Exp");
|
||||
unary_test!(cube_can_log, log_op_expand, "Log");
|
||||
unary_test!(cube_can_log1p, log1p_op_expand, "Log1p");
|
||||
unary_test!(cube_can_cos, cos_op_expand, "Cos");
|
||||
unary_test!(cube_can_sin, sin_op_expand, "Sin");
|
||||
unary_test!(cube_can_tanh, tanh_op_expand, "Tanh");
|
||||
binary_test!(cube_can_powf, powf_op_expand, "Powf", ref_ops_binary);
|
||||
unary_test!(cube_can_sqrt, sqrt_op_expand, "Sqrt");
|
||||
unary_test!(cube_can_erf, erf_op_expand, "Erf");
|
||||
unary_test!(cube_can_recip, recip_op_expand, "Recip");
|
||||
unary_test!(cube_can_floor, floor_op_expand, "Floor");
|
||||
unary_test!(cube_can_ceil, ceil_op_expand, "Ceil");
|
||||
binary_test!(cube_can_eq, equal_op_expand, "Equal", ref_ops_cmp);
|
||||
binary_test!(cube_can_ne, not_equal_op_expand, "NotEqual", ref_ops_cmp);
|
||||
binary_test!(cube_can_lt, lower_op_expand, "Lower", ref_ops_cmp);
|
||||
binary_test!(cube_can_add, add_op::__expand::<F32>, "Add", ref_ops_binary);
|
||||
binary_test!(cube_can_sub, sub_op::__expand::<F32>, "Sub", ref_ops_binary);
|
||||
binary_test!(cube_can_mul, mul_op::__expand::<F32>, "Mul", ref_ops_binary);
|
||||
binary_test!(cube_can_div, div_op::__expand::<F32>, "Div", ref_ops_binary);
|
||||
unary_test!(cube_can_abs, abs_op::__expand::<F32>, "Abs");
|
||||
unary_test!(cube_can_exp, exp_op::__expand::<F32>, "Exp");
|
||||
unary_test!(cube_can_log, log_op::__expand::<F32>, "Log");
|
||||
unary_test!(cube_can_log1p, log1p_op::__expand::<F32>, "Log1p");
|
||||
unary_test!(cube_can_cos, cos_op::__expand::<F32>, "Cos");
|
||||
unary_test!(cube_can_sin, sin_op::__expand::<F32>, "Sin");
|
||||
unary_test!(cube_can_tanh, tanh_op::__expand::<F32>, "Tanh");
|
||||
binary_test!(
|
||||
cube_can_powf,
|
||||
powf_op::__expand::<F32>,
|
||||
"Powf",
|
||||
ref_ops_binary
|
||||
);
|
||||
unary_test!(cube_can_sqrt, sqrt_op::__expand::<F32>, "Sqrt");
|
||||
unary_test!(cube_can_erf, erf_op::__expand::<F32>, "Erf");
|
||||
unary_test!(cube_can_recip, recip_op::__expand::<F32>, "Recip");
|
||||
unary_test!(cube_can_floor, floor_op::__expand::<F32>, "Floor");
|
||||
unary_test!(cube_can_ceil, ceil_op::__expand::<F32>, "Ceil");
|
||||
binary_test!(cube_can_eq, equal_op::__expand::<F32>, "Equal", ref_ops_cmp);
|
||||
binary_test!(
|
||||
cube_can_ne,
|
||||
not_equal_op::__expand::<F32>,
|
||||
"NotEqual",
|
||||
ref_ops_cmp
|
||||
);
|
||||
binary_test!(cube_can_lt, lower_op::__expand::<F32>, "Lower", ref_ops_cmp);
|
||||
binary_test!(
|
||||
cube_can_le,
|
||||
lower_equal_op_expand,
|
||||
lower_equal_op::__expand::<F32>,
|
||||
"LowerEqual",
|
||||
ref_ops_cmp
|
||||
);
|
||||
binary_test!(
|
||||
cube_can_ge,
|
||||
greater_equal_op_expand,
|
||||
greater_equal_op::__expand::<F32>,
|
||||
"GreaterEqual",
|
||||
ref_ops_cmp
|
||||
);
|
||||
binary_test!(cube_can_gt, greater_op_expand, "Greater", ref_ops_cmp);
|
||||
binary_test!(cube_can_max, max_op_expand, "Max", ref_ops_binary);
|
||||
binary_test!(cube_can_min, min_op_expand, "Min", ref_ops_binary);
|
||||
binary_test!(
|
||||
cube_can_gt,
|
||||
greater_op::__expand::<F32>,
|
||||
"Greater",
|
||||
ref_ops_cmp
|
||||
);
|
||||
binary_test!(cube_can_max, max_op::__expand::<F32>, "Max", ref_ops_binary);
|
||||
binary_test!(cube_can_min, min_op::__expand::<F32>, "Min", ref_ops_binary);
|
||||
binary_test!(
|
||||
cube_can_add_assign,
|
||||
add_assign_op_expand,
|
||||
add_assign_op::__expand::<F32>,
|
||||
"Add",
|
||||
ref_ops_binary
|
||||
);
|
||||
binary_test!(
|
||||
cube_can_sub_assign,
|
||||
sub_assign_op_expand,
|
||||
sub_assign_op::__expand::<F32>,
|
||||
"Sub",
|
||||
ref_ops_binary
|
||||
);
|
||||
binary_test!(
|
||||
cube_can_mul_assign,
|
||||
mul_assign_op_expand,
|
||||
mul_assign_op::__expand::<F32>,
|
||||
"Mul",
|
||||
ref_ops_binary
|
||||
);
|
||||
binary_test!(
|
||||
cube_can_div_assign,
|
||||
div_assign_op_expand,
|
||||
div_assign_op::__expand::<F32>,
|
||||
"Div",
|
||||
ref_ops_binary
|
||||
);
|
||||
binary_boolean_test!(cube_can_and, and_op_expand, "And");
|
||||
binary_boolean_test!(cube_can_or, or_op_expand, "Or");
|
||||
binary_uint_test!(cube_can_bitand, bitand_op_expand, "BitwiseAnd");
|
||||
binary_uint_test!(cube_can_bitxor, bitxor_op_expand, "BitwiseXor");
|
||||
binary_uint_test!(cube_can_shl, shl_op_expand, "ShiftLeft");
|
||||
binary_uint_test!(cube_can_shr, shr_op_expand, "ShiftRight");
|
||||
binary_uint_test!(cube_can_mod, modulo_op_expand, "Modulo");
|
||||
binary_boolean_test!(cube_can_and, and_op::__expand, "And");
|
||||
binary_boolean_test!(cube_can_or, or_op::__expand, "Or");
|
||||
binary_uint_test!(cube_can_bitand, bitand_op::__expand, "BitwiseAnd");
|
||||
binary_uint_test!(cube_can_bitxor, bitxor_op::__expand, "BitwiseXor");
|
||||
binary_uint_test!(cube_can_shl, shl_op::__expand, "ShiftLeft");
|
||||
binary_uint_test!(cube_can_shr, shr_op::__expand, "ShiftRight");
|
||||
binary_uint_test!(cube_can_mod, modulo_op::__expand, "Modulo");
|
||||
binary_test!(
|
||||
cube_can_rem,
|
||||
remainder_op_expand,
|
||||
remainder_op::__expand::<F32>,
|
||||
"Remainder",
|
||||
ref_ops_binary
|
||||
);
|
||||
|
@ -343,7 +358,7 @@ mod tests {
|
|||
let mut context = CubeContext::root();
|
||||
let x = context.create_local(Item::new(Elem::Bool));
|
||||
|
||||
not_op_expand(&mut context, x);
|
||||
not_op::__expand(&mut context, x);
|
||||
|
||||
assert_eq!(
|
||||
format!("{:?}", context.into_scope().operations),
|
||||
|
|
|
@ -22,7 +22,7 @@ mod tests {
|
|||
let y = context.create_local(Item::new(ElemType::as_elem()));
|
||||
let z = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
parenthesis_expand::<ElemType>(&mut context, x, y, z);
|
||||
parenthesis::__expand::<ElemType>(&mut context, x, y, z);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref());
|
||||
|
|
|
@ -53,7 +53,7 @@ mod tests {
|
|||
|
||||
let x = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
redeclare_same_scope_expand::<ElemType>(&mut context, x);
|
||||
redeclare_same_scope::__expand::<ElemType>(&mut context, x);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
@ -68,7 +68,7 @@ mod tests {
|
|||
|
||||
let x = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
redeclare_same_scope_other_type_expand::<ElemType, F32>(&mut context, x);
|
||||
redeclare_same_scope_other_type::__expand::<ElemType, F32>(&mut context, x);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
@ -83,7 +83,7 @@ mod tests {
|
|||
|
||||
let x = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
redeclare_different_scope_expand::<ElemType>(&mut context, x);
|
||||
redeclare_different_scope::__expand::<ElemType>(&mut context, x);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
@ -98,7 +98,7 @@ mod tests {
|
|||
|
||||
let x = context.create_local(Item::new(UInt::as_elem()));
|
||||
|
||||
redeclare_two_for_loops_expand(&mut context, x);
|
||||
redeclare_two_for_loops::__expand(&mut context, x);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
|
|
@ -32,7 +32,7 @@ mod tests {
|
|||
|
||||
let x = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
reuse_expand::<ElemType>(&mut context, x);
|
||||
reuse::__expand::<ElemType>(&mut context, x);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_assign());
|
||||
|
@ -44,7 +44,7 @@ mod tests {
|
|||
|
||||
let x = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
reuse_incr_expand::<ElemType>(&mut context, x);
|
||||
reuse_incr::__expand::<ElemType>(&mut context, x);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_incr());
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use burn_cube::prelude::*;
|
||||
|
||||
#[cube]
|
||||
fn shared_memory_read_write<T: Numeric>(sm_size: Comptime<u32>) {
|
||||
pub fn shared_memory_read_write<T: Numeric>(sm_size: Comptime<u32>) {
|
||||
let mut shared = SharedMemory::<T>::new(sm_size);
|
||||
shared[0] = T::from_int(3);
|
||||
let _ = shared[0];
|
||||
|
@ -20,7 +20,7 @@ mod tests {
|
|||
fn cube_support_shared_memory() {
|
||||
let mut context = CubeContext::root();
|
||||
|
||||
shared_memory_read_write_expand::<ElemType>(&mut context, 512);
|
||||
shared_memory_read_write::__expand::<ElemType>(&mut context, 512);
|
||||
assert_eq!(
|
||||
format!("{:?}", context.into_scope().operations),
|
||||
inline_macro_ref()
|
||||
|
|
|
@ -1,25 +1,25 @@
|
|||
use burn_cube::prelude::*;
|
||||
|
||||
#[derive(CubeType)]
|
||||
struct State<T: Numeric> {
|
||||
pub struct State<T: Numeric> {
|
||||
first: T,
|
||||
second: T,
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn state_receiver_with_reuse<T: Numeric>(state: State<T>) -> T {
|
||||
pub fn state_receiver_with_reuse<T: Numeric>(state: State<T>) -> T {
|
||||
let x = state.first + state.second;
|
||||
state.second + x + state.first
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn attribute_modifier_reuse_field<T: Numeric>(mut state: State<T>) -> T {
|
||||
pub fn attribute_modifier_reuse_field<T: Numeric>(mut state: State<T>) -> T {
|
||||
state.first = T::from_int(4);
|
||||
state.first
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn attribute_modifier_reuse_struct<T: Numeric>(mut state: State<T>) -> State<T> {
|
||||
pub fn attribute_modifier_reuse_struct<T: Numeric>(mut state: State<T>) -> State<T> {
|
||||
state.first = T::from_int(4);
|
||||
state
|
||||
}
|
||||
|
@ -48,7 +48,7 @@ mod tests {
|
|||
let x = context.create_local(Item::new(ElemType::as_elem()));
|
||||
let y = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
creator_expand::<ElemType>(&mut context, x, y);
|
||||
creator::__expand::<ElemType>(&mut context, x, y);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
@ -68,7 +68,7 @@ mod tests {
|
|||
first: x,
|
||||
second: y,
|
||||
};
|
||||
state_receiver_with_reuse_expand::<ElemType>(&mut context, expanded_state);
|
||||
state_receiver_with_reuse::__expand::<ElemType>(&mut context, expanded_state);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
@ -88,7 +88,7 @@ mod tests {
|
|||
first: x,
|
||||
second: y,
|
||||
};
|
||||
attribute_modifier_reuse_field_expand::<ElemType>(&mut context, expanded_state);
|
||||
attribute_modifier_reuse_field::__expand::<ElemType>(&mut context, expanded_state);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
@ -108,7 +108,7 @@ mod tests {
|
|||
first: x,
|
||||
second: y,
|
||||
};
|
||||
attribute_modifier_reuse_struct_expand::<ElemType>(&mut context, expanded_state);
|
||||
attribute_modifier_reuse_struct::__expand::<ElemType>(&mut context, expanded_state);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use burn_cube::prelude::*;
|
||||
|
||||
#[cube]
|
||||
fn kernel<T: Numeric>(input: &Tensor<T>) {
|
||||
pub fn kernel<T: Numeric>(input: &Tensor<T>) {
|
||||
let _shape = input.shape(1);
|
||||
let _stride = input.stride(1);
|
||||
let _length = input.len();
|
||||
|
@ -21,7 +21,7 @@ mod tests {
|
|||
let mut context = CubeContext::root();
|
||||
let input = context.input(0, Item::new(ElemType::as_elem()));
|
||||
|
||||
kernel_expand::<ElemType>(&mut context, input.into());
|
||||
kernel::__expand::<ElemType>(&mut context, input.into());
|
||||
assert_eq!(
|
||||
format!("{:?}", context.into_scope().operations),
|
||||
inline_macro_ref()
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
use burn_cube::prelude::*;
|
||||
|
||||
#[cube]
|
||||
fn topology_kernel<T: Numeric>(input: Tensor<T>) {
|
||||
pub fn topology_kernel<T: Numeric>(input: Tensor<T>) {
|
||||
let x = ABSOLUTE_POS + UInt::new(4);
|
||||
let _ = input[x];
|
||||
}
|
||||
|
@ -20,7 +20,7 @@ mod tests {
|
|||
let mut context = CubeContext::root();
|
||||
let input = context.input(0, Item::new(ElemType::as_elem()));
|
||||
|
||||
topology_kernel_expand::<ElemType>(&mut context, input.into());
|
||||
topology_kernel::__expand::<ElemType>(&mut context, input.into());
|
||||
assert_eq!(
|
||||
format!("{:?}", context.into_scope().operations),
|
||||
inline_macro_ref()
|
||||
|
|
|
@ -4,7 +4,7 @@ use burn_cube::prelude::*;
|
|||
/// for all their methods. However, one does not need to provide its
|
||||
/// implementation, see examples below.
|
||||
#[cube]
|
||||
trait Strategy<T: Numeric> {
|
||||
pub trait Strategy<T: Numeric> {
|
||||
fn operation(input_1: T, input_2: T) -> T;
|
||||
}
|
||||
|
||||
|
@ -13,7 +13,7 @@ struct AddStrategy;
|
|||
#[cube]
|
||||
/// The actual implementation of AddStrategy's operation
|
||||
/// Automatically generated an _expand variant
|
||||
fn add_strategy_operation<T: Numeric>(input_1: T, input_2: T) -> T {
|
||||
pub fn add_strategy_operation<T: Numeric>(input_1: T, input_2: T) -> T {
|
||||
input_1 + input_2
|
||||
}
|
||||
|
||||
|
@ -34,19 +34,19 @@ impl<T: Numeric> Strategy<T> for SubStrategy {
|
|||
}
|
||||
|
||||
#[cube]
|
||||
fn with_strategy_trait<S: Strategy<T>, T: Numeric>(x: T, y: T) -> T {
|
||||
pub fn with_strategy_trait<S: Strategy<T>, T: Numeric>(x: T, y: T) -> T {
|
||||
S::operation(x, y)
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn two_strategy_traits<S1: Strategy<F>, S2: Strategy<F>, F: Float>(x: F, y: F) -> F {
|
||||
pub fn two_strategy_traits<S1: Strategy<F>, S2: Strategy<F>, F: Float>(x: F, y: F) -> F {
|
||||
let z = S1::operation(x, y);
|
||||
S2::operation(z, y)
|
||||
}
|
||||
|
||||
trait MethodTypedStrategy {
|
||||
pub trait MethodTypedStrategy {
|
||||
fn operation<T: Numeric>(input_1: T, input_2: T) -> T;
|
||||
fn operation_expand<T: Numeric>(
|
||||
fn __expand_operation<T: Numeric>(
|
||||
_context: &mut CubeContext,
|
||||
input_1: <T as CubeType>::ExpandType,
|
||||
input_2: <T as CubeType>::ExpandType,
|
||||
|
@ -58,17 +58,17 @@ impl MethodTypedStrategy for AddStrategy {
|
|||
add_strategy_operation(input_1, input_2)
|
||||
}
|
||||
|
||||
fn operation_expand<T: Numeric>(
|
||||
fn __expand_operation<T: Numeric>(
|
||||
context: &mut CubeContext,
|
||||
input_1: <T as CubeType>::ExpandType,
|
||||
input_2: <T as CubeType>::ExpandType,
|
||||
) -> <T as CubeType>::ExpandType {
|
||||
add_strategy_operation_expand::<T>(context, input_1, input_2)
|
||||
add_strategy_operation::__expand::<T>(context, input_1, input_2)
|
||||
}
|
||||
}
|
||||
|
||||
#[cube]
|
||||
fn with_trait_generic_method<S: MethodTypedStrategy, T: Numeric>(x: T, y: T) -> T {
|
||||
pub fn with_trait_generic_method<S: MethodTypedStrategy, T: Numeric>(x: T, y: T) -> T {
|
||||
S::operation::<T>(x, y)
|
||||
}
|
||||
|
||||
|
@ -87,7 +87,7 @@ mod tests {
|
|||
let x = context.create_local(Item::new(ElemType::as_elem()));
|
||||
let y = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
with_strategy_trait_expand::<AddStrategy, ElemType>(&mut context, x, y);
|
||||
with_strategy_trait::__expand::<AddStrategy, ElemType>(&mut context, x, y);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
@ -103,7 +103,7 @@ mod tests {
|
|||
let x = context.create_local(Item::new(ElemType::as_elem()));
|
||||
let y = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
with_strategy_trait_expand::<SubStrategy, ElemType>(&mut context, x, y);
|
||||
with_strategy_trait::__expand::<SubStrategy, ElemType>(&mut context, x, y);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
@ -119,7 +119,7 @@ mod tests {
|
|||
let x = context.create_local(Item::new(ElemType::as_elem()));
|
||||
let y = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
two_strategy_traits_expand::<SubStrategy, AddStrategy, ElemType>(&mut context, x, y);
|
||||
two_strategy_traits::__expand::<SubStrategy, AddStrategy, ElemType>(&mut context, x, y);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(format!("{:?}", scope.operations), inline_macro_ref_two());
|
||||
|
@ -132,7 +132,7 @@ mod tests {
|
|||
let x = context.create_local(Item::new(ElemType::as_elem()));
|
||||
let y = context.create_local(Item::new(ElemType::as_elem()));
|
||||
|
||||
with_trait_generic_method_expand::<AddStrategy, ElemType>(&mut context, x, y);
|
||||
with_trait_generic_method::__expand::<AddStrategy, ElemType>(&mut context, x, y);
|
||||
let scope = context.into_scope();
|
||||
|
||||
assert_eq!(
|
||||
|
|
|
@ -22,7 +22,7 @@ mod tests {
|
|||
|
||||
let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 2));
|
||||
|
||||
vectorization_binary_expand::<ElemType>(&mut context, lhs);
|
||||
vectorization_binary::__expand::<ElemType>(&mut context, lhs);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -32,7 +32,7 @@ mod tests {
|
|||
|
||||
let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 4));
|
||||
|
||||
vectorization_binary_expand::<ElemType>(&mut context, lhs);
|
||||
vectorization_binary::__expand::<ElemType>(&mut context, lhs);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -41,7 +41,7 @@ mod tests {
|
|||
|
||||
let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 2));
|
||||
|
||||
vectorization_cmp_expand::<ElemType>(&mut context, lhs);
|
||||
vectorization_cmp::__expand::<ElemType>(&mut context, lhs);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -51,7 +51,7 @@ mod tests {
|
|||
|
||||
let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 4));
|
||||
|
||||
vectorization_cmp_expand::<ElemType>(&mut context, lhs);
|
||||
vectorization_cmp::__expand::<ElemType>(&mut context, lhs);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -60,6 +60,6 @@ mod tests {
|
|||
|
||||
let lhs = context.create_local(Item::vectorized(ElemType::as_elem(), 1));
|
||||
|
||||
vectorization_cmp_expand::<ElemType>(&mut context, lhs);
|
||||
vectorization_cmp::__expand::<ElemType>(&mut context, lhs);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,23 +3,23 @@ use burn_cube::prelude::*;
|
|||
use crate::kernel::{launch_unary, UnaryOp};
|
||||
use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
|
||||
|
||||
#[derive(CubeLaunch)]
|
||||
struct Options<C: Numeric> {
|
||||
min_value: C,
|
||||
max_value: C,
|
||||
}
|
||||
|
||||
pub(crate) fn clamp<R: JitRuntime, E: JitElement, const D: usize>(
|
||||
input: JitTensor<R, E, D>,
|
||||
min_value: E,
|
||||
max_value: E,
|
||||
) -> JitTensor<R, E, D> {
|
||||
#[derive(CubeLaunch)]
|
||||
struct Options<C: Numeric> {
|
||||
min_value: C,
|
||||
max_value: C,
|
||||
}
|
||||
|
||||
struct ClampOp;
|
||||
|
||||
impl<C: Numeric> UnaryOp<C> for ClampOp {
|
||||
type Options = Options<C>;
|
||||
|
||||
fn execute_expand(
|
||||
fn __expand_execute(
|
||||
context: &mut CubeContext,
|
||||
input: C::ExpandType,
|
||||
options: OptionsExpand<C>,
|
||||
|
@ -29,7 +29,7 @@ pub(crate) fn clamp<R: JitRuntime, E: JitElement, const D: usize>(
|
|||
C::clamp(input, options.min_value, options.max_value)
|
||||
}
|
||||
|
||||
execute_expand(context, input, options)
|
||||
execute::__expand(context, input, options)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use super::{index_offset_with_layout, index_offset_with_layout_expand, Kernel};
|
||||
use super::{index_offset_with_layout, Kernel};
|
||||
use crate::{element::JitElement, tensor::JitTensor, JitRuntime};
|
||||
use burn_cube::{
|
||||
calculate_cube_count_elemwise, prelude::*, tensor_vectorization_factor, Runtime,
|
||||
|
@ -146,7 +146,7 @@ pub(crate) fn launch_cmp<
|
|||
|
||||
let same_tensor_type = core::any::TypeId::of::<E>() == core::any::TypeId::of::<UInt>();
|
||||
if same_tensor_type && lhs.can_mut_broadcast(&rhs) {
|
||||
kernel_cmp_launch::<E::Primitive, O, R>(
|
||||
kernel_cmp::launch::<E::Primitive, O, R>(
|
||||
client,
|
||||
cube_count,
|
||||
CubeDim::default(),
|
||||
|
@ -170,7 +170,7 @@ pub(crate) fn launch_cmp<
|
|||
|
||||
JitTensor::new(lhs.client, lhs.handle, lhs.shape, lhs.device, lhs.strides)
|
||||
} else if same_tensor_type && rhs.can_mut_broadcast(&lhs) {
|
||||
kernel_cmp_launch::<E::Primitive, O, R>(
|
||||
kernel_cmp::launch::<E::Primitive, O, R>(
|
||||
client,
|
||||
cube_count,
|
||||
CubeDim::default(),
|
||||
|
@ -199,7 +199,7 @@ pub(crate) fn launch_cmp<
|
|||
let to_contiguous_rhs = !rhs.is_contiguous();
|
||||
let output = JitTensor::new_contiguous(lhs.client.clone(), lhs.device, shape_out, buffer);
|
||||
|
||||
kernel_cmp_launch::<E::Primitive, O, R>(
|
||||
kernel_cmp::launch::<E::Primitive, O, R>(
|
||||
client,
|
||||
cube_count,
|
||||
CubeDim::default(),
|
||||
|
@ -251,7 +251,7 @@ pub(crate) fn launch_scalar_cmp<
|
|||
|
||||
let same_tensor_type = core::any::TypeId::of::<E>() == core::any::TypeId::of::<UInt>();
|
||||
if same_tensor_type && tensor.can_mut() {
|
||||
kernel_scalar_cmp_launch::<E::Primitive, O, R>(
|
||||
kernel_scalar_cmp::launch::<E::Primitive, O, R>(
|
||||
client,
|
||||
cube_count,
|
||||
CubeDim::default(),
|
||||
|
@ -282,7 +282,7 @@ pub(crate) fn launch_scalar_cmp<
|
|||
tensor.strides,
|
||||
);
|
||||
|
||||
kernel_scalar_cmp_launch::<E::Primitive, O, R>(
|
||||
kernel_scalar_cmp::launch::<E::Primitive, O, R>(
|
||||
client,
|
||||
cube_count,
|
||||
CubeDim::default(),
|
||||
|
|
|
@ -87,7 +87,7 @@ pub fn into_contiguous<R: JitRuntime, E: JitElement, const D: usize>(
|
|||
SUBCUBE_DIM_APPROX,
|
||||
);
|
||||
|
||||
into_contiguous_kernel_launch::<E::Primitive, R>(
|
||||
into_contiguous_kernel::launch::<E::Primitive, R>(
|
||||
client,
|
||||
cube_count,
|
||||
CubeDim::default(),
|
||||
|
|
|
@ -163,7 +163,7 @@ pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
|
|||
let num_elems_output = output.shape.num_elements();
|
||||
let cube_dim = calculate_cube_count_elemwise(num_elems_output, SUBCUBE_DIM_APPROX);
|
||||
|
||||
conv2d_kernel_launch::<E::FloatPrimitive, R>(
|
||||
conv2d_kernel::launch::<E::FloatPrimitive, R>(
|
||||
input.client,
|
||||
cube_dim,
|
||||
CubeDim::default(),
|
||||
|
|
|
@ -188,7 +188,7 @@ pub(crate) fn conv3d<R: JitRuntime, E: FloatElement>(
|
|||
}
|
||||
};
|
||||
|
||||
conv3d_kernel_launch::<E::FloatPrimitive, R>(
|
||||
conv3d_kernel::launch::<E::FloatPrimitive, R>(
|
||||
input.client,
|
||||
calculate_cube_count_elemwise(output.shape.num_elements(), SUBCUBE_DIM_APPROX),
|
||||
CubeDim::default(),
|
||||
|
|
|
@ -116,7 +116,7 @@ pub fn matmul_simple<R: JitRuntime, E: FloatElement, const D: usize>(
|
|||
false => 1,
|
||||
};
|
||||
|
||||
matmul_kernel_launch::<E::FloatPrimitive, R>(
|
||||
matmul_kernel::launch::<E::FloatPrimitive, R>(
|
||||
lhs.client,
|
||||
cube_count,
|
||||
CubeDim::new(cube_dim_x as u32, cube_dim_y as u32, 1),
|
||||
|
|
|
@ -2,11 +2,11 @@ use burn_cube::prelude::*;
|
|||
|
||||
use crate::kernel::matmul::config::CubeTiling2dConfig;
|
||||
|
||||
use super::block_loop::{block_loop, block_loop_expand};
|
||||
use super::block_loop::block_loop;
|
||||
|
||||
#[cube(launch)]
|
||||
#[allow(unused_mut)]
|
||||
fn tiling2d_cube<F: Float>(
|
||||
pub fn tiling2d_cube_kernel<F: Float>(
|
||||
lhs: &Tensor<F>,
|
||||
rhs: &Tensor<F>,
|
||||
out: &mut Tensor<F>,
|
||||
|
|
|
@ -4,10 +4,10 @@ use crate::kernel::matmul::config::CubeTiling2dConfig;
|
|||
|
||||
use super::{
|
||||
base::{BatchOffsets, Coordinates, Dimensions, SharedMemories},
|
||||
compute_loop::{compute_loop, compute_loop_expand},
|
||||
load_shared_memory::{load_to_shared_memories, load_to_shared_memories_expand},
|
||||
compute_loop::compute_loop,
|
||||
load_shared_memory::load_to_shared_memories,
|
||||
tile::{loader::TileLoader, writer::TileWriter},
|
||||
write_output::{write_to_output, write_to_output_expand},
|
||||
write_output::write_to_output,
|
||||
};
|
||||
|
||||
#[cube]
|
||||
|
|
|
@ -2,10 +2,7 @@ use burn_cube::prelude::*;
|
|||
|
||||
use crate::kernel::matmul::config::CubeTiling2dConfig;
|
||||
|
||||
use super::{
|
||||
base::Coordinates,
|
||||
outer_product::{tile_outer_product, tile_outer_product_expand},
|
||||
};
|
||||
use super::{base::Coordinates, outer_product::tile_outer_product};
|
||||
|
||||
#[cube]
|
||||
#[allow(unused_mut)]
|
||||
|
@ -105,7 +102,7 @@ pub mod tests {
|
|||
const SOME_DIM: usize = 12;
|
||||
let config = make_config(SOME_DIM, SOME_DIM, SOME_DIM);
|
||||
|
||||
compute_loop_test_launch::<F32, R>(
|
||||
compute_loop_test::launch::<F32, R>(
|
||||
lhs.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
|
@ -134,7 +131,7 @@ pub mod tests {
|
|||
|
||||
let config = make_config(4, 8, 4);
|
||||
|
||||
compute_loop_test_launch::<F32, R>(
|
||||
compute_loop_test::launch::<F32, R>(
|
||||
lhs.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
|
|
|
@ -7,6 +7,7 @@ use crate::{
|
|||
into_contiguous,
|
||||
matmul::{
|
||||
config::{tiling2d_cube_count, tiling2d_cube_dim, CubeTiling2dConfig},
|
||||
tiling2d_cube::base::tiling2d_cube_kernel,
|
||||
Tiling2dConfig,
|
||||
},
|
||||
},
|
||||
|
@ -14,8 +15,6 @@ use crate::{
|
|||
FloatElement, JitRuntime,
|
||||
};
|
||||
|
||||
use super::base::tiling2d_cube_launch;
|
||||
|
||||
/// Matrix multiplication using tiling 2d algorithm
|
||||
pub fn matmul_tiling_2d_cube<R: JitRuntime, E: FloatElement, const D: usize>(
|
||||
lhs: JitTensor<R, E, D>,
|
||||
|
@ -69,7 +68,7 @@ pub fn matmul_tiling_2d_cube<R: JitRuntime, E: FloatElement, const D: usize>(
|
|||
let cube_dim = tiling2d_cube_dim(&config);
|
||||
let cube_config = CubeTiling2dConfig::new(&config, m, k, n, lhs_transposed, rhs_transposed);
|
||||
|
||||
tiling2d_cube_launch::<E::FloatPrimitive, R>(
|
||||
tiling2d_cube_kernel::launch::<E::FloatPrimitive, R>(
|
||||
client,
|
||||
cube_count,
|
||||
cube_dim,
|
||||
|
|
|
@ -385,7 +385,7 @@ pub mod tests {
|
|||
|
||||
let config = make_config(16, 16, 8);
|
||||
|
||||
load_tensor_test_launch::<F32, R>(
|
||||
load_tensor_test::launch::<F32, R>(
|
||||
lhs.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
|
@ -417,7 +417,7 @@ pub mod tests {
|
|||
|
||||
let config = make_config(5, 1, 1);
|
||||
|
||||
load_tensor_multiple_tiles_test_launch::<F32, R>(
|
||||
load_tensor_multiple_tiles_test::launch::<F32, R>(
|
||||
lhs.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
|
@ -451,7 +451,7 @@ pub mod tests {
|
|||
|
||||
let config = make_config(8, 8, 8);
|
||||
|
||||
load_tensor_multiple_tiles_test_launch::<F32, R>(
|
||||
load_tensor_multiple_tiles_test::launch::<F32, R>(
|
||||
lhs.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
|
@ -481,7 +481,7 @@ pub mod tests {
|
|||
|
||||
let config = make_config(8, 8, 16);
|
||||
|
||||
load_tensor_multiple_tiles_test_launch::<F32, R>(
|
||||
load_tensor_multiple_tiles_test::launch::<F32, R>(
|
||||
lhs.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
|
@ -511,7 +511,7 @@ pub mod tests {
|
|||
|
||||
let config = make_config(8, 16, 16);
|
||||
|
||||
load_tensor_test_launch::<F32, R>(
|
||||
load_tensor_test::launch::<F32, R>(
|
||||
rhs.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
|
@ -543,7 +543,7 @@ pub mod tests {
|
|||
|
||||
let config = make_config(8, 8, 8);
|
||||
|
||||
load_tensor_multiple_tiles_test_launch::<F32, R>(
|
||||
load_tensor_multiple_tiles_test::launch::<F32, R>(
|
||||
rhs.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
|
@ -573,7 +573,7 @@ pub mod tests {
|
|||
|
||||
let config = make_config(16, 16, 8);
|
||||
|
||||
load_tensor_multiple_tiles_test_launch::<F32, R>(
|
||||
load_tensor_multiple_tiles_test::launch::<F32, R>(
|
||||
rhs.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
|
@ -603,7 +603,7 @@ pub mod tests {
|
|||
|
||||
let config = make_config(16, 16, 8);
|
||||
|
||||
load_tensor_permuted_test_launch::<F32, R>(
|
||||
load_tensor_permuted_test::launch::<F32, R>(
|
||||
lhs.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
|
@ -636,7 +636,7 @@ pub mod tests {
|
|||
|
||||
let config = make_config(m, k, 8);
|
||||
|
||||
load_tensor_permuted_test_launch::<F32, R>(
|
||||
load_tensor_permuted_test::launch::<F32, R>(
|
||||
lhs.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
|
@ -667,7 +667,7 @@ pub mod tests {
|
|||
|
||||
let config = make_config(16, 16, 8);
|
||||
|
||||
load_tensor_permuted_test_launch::<F32, R>(
|
||||
load_tensor_permuted_test::launch::<F32, R>(
|
||||
rhs.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
|
@ -699,7 +699,7 @@ pub mod tests {
|
|||
|
||||
let config = make_config(8, k, n);
|
||||
|
||||
load_tensor_permuted_test_launch::<F32, R>(
|
||||
load_tensor_permuted_test::launch::<F32, R>(
|
||||
rhs.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
|
|
|
@ -70,7 +70,7 @@ pub mod tests {
|
|||
const SOME_DIM: usize = 12;
|
||||
let config = make_config(SOME_DIM, SOME_DIM, SOME_DIM);
|
||||
|
||||
tile_outer_product_test_launch::<F32, R>(
|
||||
tile_outer_product_test::launch::<F32, R>(
|
||||
client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
|
@ -99,7 +99,7 @@ pub mod tests {
|
|||
const SOME_DIM: usize = 12;
|
||||
let config = make_config(SOME_DIM, SOME_DIM, SOME_DIM);
|
||||
|
||||
tile_outer_product_test_launch::<F32, R>(
|
||||
tile_outer_product_test::launch::<F32, R>(
|
||||
client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
|
|
|
@ -14,10 +14,7 @@ use crate::kernel::matmul::{
|
|||
},
|
||||
};
|
||||
|
||||
use super::base::{
|
||||
all_zeros_comptime, all_zeros_comptime_expand, all_zeros_runtime, all_zeros_runtime_expand,
|
||||
BlockLoader, BlockWriter,
|
||||
};
|
||||
use super::base::{all_zeros_comptime, all_zeros_runtime, BlockLoader, BlockWriter};
|
||||
|
||||
pub(crate) struct HorizontalCheckBlockIO;
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ use crate::kernel::matmul::{
|
|||
},
|
||||
};
|
||||
|
||||
use super::base::{all_zeros_runtime, all_zeros_runtime_expand, BlockLoader, BlockWriter};
|
||||
use super::base::{all_zeros_runtime, BlockLoader, BlockWriter};
|
||||
|
||||
pub(crate) struct VerticalCheckBlockIO;
|
||||
|
||||
|
|
|
@ -14,10 +14,7 @@ use crate::kernel::matmul::{
|
|||
},
|
||||
};
|
||||
|
||||
use super::base::{
|
||||
all_zeros_comptime, all_zeros_comptime_expand, all_zeros_runtime, all_zeros_runtime_expand,
|
||||
BlockLoader, BlockWriter,
|
||||
};
|
||||
use super::base::{all_zeros_comptime, all_zeros_runtime, BlockLoader, BlockWriter};
|
||||
|
||||
pub(crate) struct WholeCheckBlockIO;
|
||||
|
||||
|
|
|
@ -130,7 +130,7 @@ pub mod tests {
|
|||
|
||||
let config = make_config(6, 8, 8);
|
||||
|
||||
write_to_output_test_launch::<F32, R>(
|
||||
write_to_output_test::launch::<F32, R>(
|
||||
out.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
|
@ -156,7 +156,7 @@ pub mod tests {
|
|||
|
||||
let config = make_config(8, 8, 4);
|
||||
|
||||
write_to_output_test_launch::<F32, R>(
|
||||
write_to_output_test::launch::<F32, R>(
|
||||
out.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
|
@ -182,7 +182,7 @@ pub mod tests {
|
|||
|
||||
let config = make_config(8, 8, 8);
|
||||
|
||||
write_to_output_test_launch::<F32, R>(
|
||||
write_to_output_test::launch::<F32, R>(
|
||||
out.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
|
@ -215,7 +215,7 @@ pub mod tests {
|
|||
|
||||
let config = make_config(8, 8, 8);
|
||||
|
||||
write_to_output_test_launch::<F32, R>(
|
||||
write_to_output_test::launch::<F32, R>(
|
||||
out.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
|
@ -248,7 +248,7 @@ pub mod tests {
|
|||
|
||||
let config = make_config(5, 8, 1);
|
||||
|
||||
write_results_to_output_out_of_bounds_test_launch::<F32, R>(
|
||||
write_results_to_output_out_of_bounds_test::launch::<F32, R>(
|
||||
out.client.clone(),
|
||||
cube_count,
|
||||
cube_dim,
|
||||
|
|
|
@ -11,7 +11,7 @@ pub use binary::*;
|
|||
pub use cast::*;
|
||||
pub use contiguous::*;
|
||||
pub use mask::*;
|
||||
pub use unary::*;
|
||||
pub(crate) use unary::*;
|
||||
|
||||
pub use burn_cube::{Kernel, SUBCUBE_DIM_APPROX};
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ use burn_cube::{
|
|||
SUBCUBE_DIM_APPROX,
|
||||
};
|
||||
|
||||
use super::{index_offset_with_layout, index_offset_with_layout_expand, Kernel};
|
||||
use super::{index_offset_with_layout, Kernel};
|
||||
|
||||
pub(crate) trait UnaryOp<C: CubePrimitive>: 'static + Send + Sync {
|
||||
type Options: LaunchArg;
|
||||
|
@ -13,7 +13,7 @@ pub(crate) trait UnaryOp<C: CubePrimitive>: 'static + Send + Sync {
|
|||
fn execute(_input: C, _options: &Self::Options) -> C {
|
||||
unexpanded!();
|
||||
}
|
||||
fn execute_expand(
|
||||
fn __expand_execute(
|
||||
context: &mut CubeContext,
|
||||
input: C::ExpandType,
|
||||
options: <Self::Options as CubeType>::ExpandType,
|
||||
|
@ -78,7 +78,7 @@ where
|
|||
let is_contiguous = tensor.is_contiguous();
|
||||
|
||||
if tensor.can_mut() && is_contiguous {
|
||||
unary_kernel_launch::<E::Primitive, O, R>(
|
||||
unary_kernel::launch::<E::Primitive, O, R>(
|
||||
client,
|
||||
cube_count,
|
||||
CubeDim::default(),
|
||||
|
@ -104,7 +104,7 @@ where
|
|||
buffer,
|
||||
);
|
||||
|
||||
unary_kernel_launch::<E::Primitive, O, R>(
|
||||
unary_kernel::launch::<E::Primitive, O, R>(
|
||||
client,
|
||||
cube_count,
|
||||
CubeDim::default(),
|
||||
|
@ -136,7 +136,7 @@ macro_rules! unary_op {
|
|||
type Options = ();
|
||||
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
fn execute_expand(
|
||||
fn __expand_execute(
|
||||
context: &mut CubeContext,
|
||||
input: C::ExpandType,
|
||||
_options: <Self::Options as CubeType>::ExpandType,
|
||||
|
@ -152,7 +152,7 @@ macro_rules! unary_op {
|
|||
type Options = C;
|
||||
|
||||
#[allow(clippy::redundant_closure_call)]
|
||||
fn execute_expand(
|
||||
fn __expand_execute(
|
||||
context: &mut CubeContext,
|
||||
input: C::ExpandType,
|
||||
scalar: C::ExpandType,
|
||||
|
|
|
@ -334,7 +334,7 @@ where
|
|||
fn execute<C: Float>(input: C) -> C {
|
||||
C::exp(input)
|
||||
}
|
||||
execute_expand::<C>(context, input)
|
||||
execute::__expand::<C>(context, input)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -344,7 +344,7 @@ where
|
|||
fn execute<C: Float>(input: C) -> C {
|
||||
C::log(input)
|
||||
}
|
||||
execute_expand::<C>(context, tensor)
|
||||
execute::__expand::<C>(context, tensor)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -354,7 +354,7 @@ where
|
|||
fn execute<C: Float>(input: C) -> C {
|
||||
C::log1p(input)
|
||||
}
|
||||
execute_expand::<C>(context, tensor)
|
||||
execute::__expand::<C>(context, tensor)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -367,7 +367,7 @@ where
|
|||
fn execute<C: Float>(input: C, scalar: C) -> C {
|
||||
C::powf(input, scalar)
|
||||
}
|
||||
execute_expand::<C>(context, tensor, scalar)
|
||||
execute::__expand::<C>(context, tensor, scalar)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -377,7 +377,7 @@ where
|
|||
fn execute<C: Float>(input: C) -> C {
|
||||
C::sqrt(input)
|
||||
}
|
||||
execute_expand::<C>(context, tensor)
|
||||
execute::__expand::<C>(context, tensor)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -387,7 +387,7 @@ where
|
|||
fn execute<C: Float>(input: C) -> C {
|
||||
C::abs(input)
|
||||
}
|
||||
execute_expand::<C>(context, tensor)
|
||||
execute::__expand::<C>(context, tensor)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -397,7 +397,7 @@ where
|
|||
fn execute<C: Float>(input: C) -> C {
|
||||
C::cos(input)
|
||||
}
|
||||
execute_expand::<C>(context, tensor)
|
||||
execute::__expand::<C>(context, tensor)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -407,7 +407,7 @@ where
|
|||
fn execute<C: Float>(input: C) -> C {
|
||||
C::sin(input)
|
||||
}
|
||||
execute_expand::<C>(context, tensor)
|
||||
execute::__expand::<C>(context, tensor)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -417,7 +417,7 @@ where
|
|||
fn execute<C: Float>(input: C) -> C {
|
||||
C::tanh(input)
|
||||
}
|
||||
execute_expand::<C>(context, tensor)
|
||||
execute::__expand::<C>(context, tensor)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -427,7 +427,7 @@ where
|
|||
fn execute<C: Float>(input: C) -> C {
|
||||
C::erf(input)
|
||||
}
|
||||
execute_expand::<C>(context, tensor)
|
||||
execute::__expand::<C>(context, tensor)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -463,7 +463,7 @@ where
|
|||
fn execute<C: Float>(input: C) -> C {
|
||||
C::recip(input)
|
||||
}
|
||||
execute_expand::<C>(context, tensor)
|
||||
execute::__expand::<C>(context, tensor)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -299,7 +299,7 @@ where
|
|||
fn execute<C: Numeric>(input: C) -> C {
|
||||
C::abs(input)
|
||||
}
|
||||
execute_expand::<C>(context, tensor)
|
||||
execute::__expand::<C>(context, tensor)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ pub fn full_device<R: JitRuntime, E: JitElement, const D: usize>(
|
|||
let empty = empty_device(client, device, shape);
|
||||
|
||||
#[cube(launch)]
|
||||
pub(crate) fn full_kernel<C: Numeric + Vectorized>(tensor: &mut Tensor<C>, value: C) {
|
||||
pub fn full_kernel<C: Numeric + Vectorized>(tensor: &mut Tensor<C>, value: C) {
|
||||
if ABSOLUTE_POS >= tensor.len() {
|
||||
return;
|
||||
}
|
||||
|
@ -42,7 +42,7 @@ pub fn full_device<R: JitRuntime, E: JitElement, const D: usize>(
|
|||
SUBCUBE_DIM_APPROX,
|
||||
);
|
||||
|
||||
full_kernel_launch::<E::Primitive, R>(
|
||||
full_kernel::launch::<E::Primitive, R>(
|
||||
empty.client.clone(),
|
||||
cube_count,
|
||||
CubeDim::default(),
|
||||
|
@ -127,7 +127,7 @@ pub fn add_scalar<R: JitRuntime, E: JitElement, const D: usize>(
|
|||
fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
|
||||
lhs + rhs
|
||||
}
|
||||
execute_expand::<C>(context, lhs, rhs)
|
||||
execute::__expand::<C>(context, lhs, rhs)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -156,7 +156,7 @@ pub fn sub_scalar<R: JitRuntime, E: JitElement, const D: usize>(
|
|||
fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
|
||||
lhs - rhs
|
||||
}
|
||||
execute_expand::<C>(context, lhs, rhs)
|
||||
execute::__expand::<C>(context, lhs, rhs)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -185,7 +185,7 @@ pub fn mul_scalar<R: JitRuntime, E: JitElement, const D: usize>(
|
|||
fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
|
||||
lhs * rhs
|
||||
}
|
||||
execute_expand::<C>(context, lhs, rhs)
|
||||
execute::__expand::<C>(context, lhs, rhs)
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -214,7 +214,7 @@ pub fn div_scalar<R: JitRuntime, E: JitElement, const D: usize>(
|
|||
fn execute<C: Numeric>(lhs: C, rhs: C) -> C {
|
||||
lhs / rhs
|
||||
}
|
||||
execute_expand::<C>(context, lhs, rhs)
|
||||
execute::__expand::<C>(context, lhs, rhs)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -146,7 +146,7 @@ where
|
|||
fn execute<C: Numeric>(input: C) -> C {
|
||||
input
|
||||
}
|
||||
execute_expand::<C>(context, tensor)
|
||||
execute::__expand::<C>(context, tensor)
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ pub fn launch<R: Runtime>(device: &R::Device) {
|
|||
let input_handle = client.create(f32::as_bytes(input));
|
||||
let output_handle = client.empty(input.len() * core::mem::size_of::<f32>());
|
||||
|
||||
gelu_launch::<F32, R>(
|
||||
gelu::launch::<F32, R>(
|
||||
client.clone(),
|
||||
CubeCount::Static(1, 1, 1),
|
||||
CubeDim::default(),
|
||||
|
|
Loading…
Reference in New Issue