Cube: cleaner use of topology values (#1835)

* constant keyword parsing

* works
This commit is contained in:
Louis Fortier-Dubois 2024-05-29 09:08:10 -04:00 committed by GitHub
parent a2ad424fc8
commit 61c9fdbbc8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 87 additions and 29 deletions

View File

@ -4,6 +4,8 @@ use syn::{PathArguments, Stmt};
use crate::VariableKey;
pub const KEYWORDS: [&str; 1] = ["ABSOLUTE_INDEX"];
#[derive(Debug)]
/// Information about a single variable's use in Cube code
/// Information about a single variable's use in Cube code
@ -200,8 +202,10 @@ impl CodeAnalysisBuilder {
.get_ident()
.expect("Analysis: only ident path are supported.");
// Use
self.var_uses.push(ident.into());
if !KEYWORDS.contains(&ident.to_string().as_str()) {
// Use
self.var_uses.push(ident.into());
}
}
syn::Expr::Binary(expr) => {
self.find_occurrences_in_expr(&expr.left, depth);

View File

@ -2,7 +2,10 @@ use proc_macro2::TokenStream;
use quote::ToTokens;
use syn::Lit;
use crate::{analysis::CodeAnalysis, codegen::base::codegen_expr};
use crate::{
analysis::{CodeAnalysis, KEYWORDS},
codegen::base::codegen_expr,
};
/// Codegen for literals
pub(crate) fn codegen_lit(lit: &syn::ExprLit) -> TokenStream {
@ -140,15 +143,21 @@ pub(crate) fn codegen_path_rhs(
.get_ident()
.expect("Codegen: Only ident path are supported.");
let will_be_used_again = variable_analyses.should_clone(ident, loop_level);
if will_be_used_again {
if KEYWORDS.contains(&ident.to_string().as_str()) {
quote::quote! {
#ident.clone()
#ident :: expand(context)
}
} else {
quote::quote! {
#ident
let will_be_used_again = variable_analyses.should_clone(ident, loop_level);
if will_be_used_again {
quote::quote! {
#ident.clone()
}
} else {
quote::quote! {
#ident
}
}
}
}

View File

@ -49,7 +49,7 @@ impl<R: Runtime> ArgSettings<R> for u32 {
impl Numeric for UInt {}
impl UInt {
pub fn new(val: u32) -> Self {
pub const fn new(val: u32) -> Self {
Self {
val,
vectorization: 1,

View File

@ -1,20 +1,17 @@
use crate::{unexpanded, CubeContext, CubeType, ExpandElement, UInt};
use crate::UInt;
/// In this file we use a trick where the constant has the same name as the module containing
/// the expand function, so that a user implicitly imports the expand function when importing the constant.
/// The index of the working unit in the whole cube kernel, without regards to blocks.
pub struct AbsoluteIndex {}
pub const ABSOLUTE_INDEX: UInt = UInt::new(0u32);
impl AbsoluteIndex {
/// Obtain the absolute index
pub fn get() -> UInt {
unexpanded!();
}
#[allow(non_snake_case)]
pub mod ABSOLUTE_INDEX {
use crate::{CubeContext, ExpandElement};
/// Obtain the absolute index
pub fn get_expand(_context: &mut CubeContext) -> ExpandElement {
/// Expanded version of ABSOLUTE_INDEX
pub fn expand(_context: &mut CubeContext) -> ExpandElement {
ExpandElement::Plain(crate::dialect::Variable::Id)
}
}
impl CubeType for AbsoluteIndex {
type ExpandType = ExpandElement;
}

View File

@ -13,5 +13,6 @@ mod parenthesis;
mod reuse;
mod shared_memory;
mod tensor;
mod topology;
mod r#trait;
mod vectorization;

View File

@ -0,0 +1,47 @@
use burn_cube::{cube, Numeric, Tensor, UInt, ABSOLUTE_INDEX};
#[cube]
fn topology_kernel<T: Numeric>(input: Tensor<T>) {
let x = ABSOLUTE_INDEX + UInt::new(4);
let _ = input[x];
}
mod tests {
use super::*;
use burn_cube::{
cpa,
dialect::{Elem, Item, Variable},
CubeContext, CubeElem, F32,
};
type ElemType = F32;
#[test]
fn cube_support_topology() {
let mut context = CubeContext::root();
let input = context.input(0, Item::new(ElemType::as_elem()));
topology_kernel_expand::<ElemType>(&mut context, input);
assert_eq!(
format!("{:?}", context.into_scope().operations),
inline_macro_ref()
);
}
fn inline_macro_ref() -> String {
let mut context = CubeContext::root();
let item = Item::new(ElemType::as_elem());
let input = context.input(0, item);
let mut scope = context.into_scope();
let input: Variable = input.into();
let x = scope.create_local(Item::new(Elem::UInt));
let y = scope.create_local(item);
let id = Variable::Id;
cpa!(&mut scope, x = id + 4u32);
cpa!(&mut scope, y = input[x]);
format!("{:?}", scope.operations)
}
}

View File

@ -31,7 +31,7 @@ fn kernel<F: Float>(
kernel_size_0_unroll: Comptime<Option<UInt>>,
kernel_size_1_unroll: Comptime<Option<UInt>>,
) {
if AbsoluteIndex::get() >= output.len() {
if ABSOLUTE_INDEX >= output.len() {
return;
}
@ -42,10 +42,10 @@ fn kernel<F: Float>(
let kernel_size_1 = Comptime::unwrap_or_else(kernel_size_1_unroll, || weight.shape(3));
let unroll_1 = Comptime::is_some(kernel_size_1_unroll);
let b = AbsoluteIndex::get() / output.stride(0) % output.shape(0);
let oc = AbsoluteIndex::get() / output.stride(1) % output.shape(1);
let oh = AbsoluteIndex::get() / output.stride(2) % output.shape(2);
let ow = AbsoluteIndex::get() / output.stride(3) % output.shape(3);
let b = ABSOLUTE_INDEX / output.stride(0) % output.shape(0);
let oc = ABSOLUTE_INDEX / output.stride(1) % output.shape(1);
let oh = ABSOLUTE_INDEX / output.stride(2) % output.shape(2);
let ow = ABSOLUTE_INDEX / output.stride(3) % output.shape(3);
let g = (weight.shape(0) + oc) % groups;
let ic_start = in_channels * g;
@ -107,7 +107,7 @@ fn kernel<F: Float>(
}
}
output[AbsoluteIndex::get()] = sum;
output[ABSOLUTE_INDEX] = sum;
}
pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(