mirror of https://github.com/tracel-ai/burn.git
Cube: cleaner use of topology values (#1835)
* constant keyword parsing * works
This commit is contained in:
parent
a2ad424fc8
commit
61c9fdbbc8
|
@ -4,6 +4,8 @@ use syn::{PathArguments, Stmt};
|
|||
|
||||
use crate::VariableKey;
|
||||
|
||||
pub const KEYWORDS: [&str; 1] = ["ABSOLUTE_INDEX"];
|
||||
|
||||
#[derive(Debug)]
|
||||
/// Information about a single variable's use in Cube code
|
||||
/// Information about a single variable's use in Cube code
|
||||
|
@ -200,8 +202,10 @@ impl CodeAnalysisBuilder {
|
|||
.get_ident()
|
||||
.expect("Analysis: only ident path are supported.");
|
||||
|
||||
// Use
|
||||
self.var_uses.push(ident.into());
|
||||
if !KEYWORDS.contains(&ident.to_string().as_str()) {
|
||||
// Use
|
||||
self.var_uses.push(ident.into());
|
||||
}
|
||||
}
|
||||
syn::Expr::Binary(expr) => {
|
||||
self.find_occurrences_in_expr(&expr.left, depth);
|
||||
|
|
|
@ -2,7 +2,10 @@ use proc_macro2::TokenStream;
|
|||
use quote::ToTokens;
|
||||
use syn::Lit;
|
||||
|
||||
use crate::{analysis::CodeAnalysis, codegen::base::codegen_expr};
|
||||
use crate::{
|
||||
analysis::{CodeAnalysis, KEYWORDS},
|
||||
codegen::base::codegen_expr,
|
||||
};
|
||||
|
||||
/// Codegen for literals
|
||||
pub(crate) fn codegen_lit(lit: &syn::ExprLit) -> TokenStream {
|
||||
|
@ -140,15 +143,21 @@ pub(crate) fn codegen_path_rhs(
|
|||
.get_ident()
|
||||
.expect("Codegen: Only ident path are supported.");
|
||||
|
||||
let will_be_used_again = variable_analyses.should_clone(ident, loop_level);
|
||||
|
||||
if will_be_used_again {
|
||||
if KEYWORDS.contains(&ident.to_string().as_str()) {
|
||||
quote::quote! {
|
||||
#ident.clone()
|
||||
#ident :: expand(context)
|
||||
}
|
||||
} else {
|
||||
quote::quote! {
|
||||
#ident
|
||||
let will_be_used_again = variable_analyses.should_clone(ident, loop_level);
|
||||
|
||||
if will_be_used_again {
|
||||
quote::quote! {
|
||||
#ident.clone()
|
||||
}
|
||||
} else {
|
||||
quote::quote! {
|
||||
#ident
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -49,7 +49,7 @@ impl<R: Runtime> ArgSettings<R> for u32 {
|
|||
impl Numeric for UInt {}
|
||||
|
||||
impl UInt {
|
||||
pub fn new(val: u32) -> Self {
|
||||
pub const fn new(val: u32) -> Self {
|
||||
Self {
|
||||
val,
|
||||
vectorization: 1,
|
||||
|
|
|
@ -1,20 +1,17 @@
|
|||
use crate::{unexpanded, CubeContext, CubeType, ExpandElement, UInt};
|
||||
use crate::UInt;
|
||||
|
||||
/// In this file we use a trick where the constant has the same name as the module containing
|
||||
/// the expand function, so that a user implicitly imports the expand function when importing the constant.
|
||||
|
||||
/// The index of the working unit in the whole cube kernel, without regards to blocks.
|
||||
pub struct AbsoluteIndex {}
|
||||
pub const ABSOLUTE_INDEX: UInt = UInt::new(0u32);
|
||||
|
||||
impl AbsoluteIndex {
|
||||
/// Obtain the absolute index
|
||||
pub fn get() -> UInt {
|
||||
unexpanded!();
|
||||
}
|
||||
#[allow(non_snake_case)]
|
||||
pub mod ABSOLUTE_INDEX {
|
||||
use crate::{CubeContext, ExpandElement};
|
||||
|
||||
/// Obtain the absolute index
|
||||
pub fn get_expand(_context: &mut CubeContext) -> ExpandElement {
|
||||
/// Expanded version of ABSOLUTE_INDEX
|
||||
pub fn expand(_context: &mut CubeContext) -> ExpandElement {
|
||||
ExpandElement::Plain(crate::dialect::Variable::Id)
|
||||
}
|
||||
}
|
||||
|
||||
impl CubeType for AbsoluteIndex {
|
||||
type ExpandType = ExpandElement;
|
||||
}
|
||||
|
|
|
@ -13,5 +13,6 @@ mod parenthesis;
|
|||
mod reuse;
|
||||
mod shared_memory;
|
||||
mod tensor;
|
||||
mod topology;
|
||||
mod r#trait;
|
||||
mod vectorization;
|
||||
|
|
|
@ -0,0 +1,47 @@
|
|||
use burn_cube::{cube, Numeric, Tensor, UInt, ABSOLUTE_INDEX};
|
||||
|
||||
#[cube]
|
||||
fn topology_kernel<T: Numeric>(input: Tensor<T>) {
|
||||
let x = ABSOLUTE_INDEX + UInt::new(4);
|
||||
let _ = input[x];
|
||||
}
|
||||
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_cube::{
|
||||
cpa,
|
||||
dialect::{Elem, Item, Variable},
|
||||
CubeContext, CubeElem, F32,
|
||||
};
|
||||
|
||||
type ElemType = F32;
|
||||
|
||||
#[test]
|
||||
fn cube_support_topology() {
|
||||
let mut context = CubeContext::root();
|
||||
let input = context.input(0, Item::new(ElemType::as_elem()));
|
||||
|
||||
topology_kernel_expand::<ElemType>(&mut context, input);
|
||||
assert_eq!(
|
||||
format!("{:?}", context.into_scope().operations),
|
||||
inline_macro_ref()
|
||||
);
|
||||
}
|
||||
|
||||
fn inline_macro_ref() -> String {
|
||||
let mut context = CubeContext::root();
|
||||
let item = Item::new(ElemType::as_elem());
|
||||
let input = context.input(0, item);
|
||||
|
||||
let mut scope = context.into_scope();
|
||||
let input: Variable = input.into();
|
||||
let x = scope.create_local(Item::new(Elem::UInt));
|
||||
let y = scope.create_local(item);
|
||||
|
||||
let id = Variable::Id;
|
||||
cpa!(&mut scope, x = id + 4u32);
|
||||
cpa!(&mut scope, y = input[x]);
|
||||
|
||||
format!("{:?}", scope.operations)
|
||||
}
|
||||
}
|
|
@ -31,7 +31,7 @@ fn kernel<F: Float>(
|
|||
kernel_size_0_unroll: Comptime<Option<UInt>>,
|
||||
kernel_size_1_unroll: Comptime<Option<UInt>>,
|
||||
) {
|
||||
if AbsoluteIndex::get() >= output.len() {
|
||||
if ABSOLUTE_INDEX >= output.len() {
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -42,10 +42,10 @@ fn kernel<F: Float>(
|
|||
let kernel_size_1 = Comptime::unwrap_or_else(kernel_size_1_unroll, || weight.shape(3));
|
||||
let unroll_1 = Comptime::is_some(kernel_size_1_unroll);
|
||||
|
||||
let b = AbsoluteIndex::get() / output.stride(0) % output.shape(0);
|
||||
let oc = AbsoluteIndex::get() / output.stride(1) % output.shape(1);
|
||||
let oh = AbsoluteIndex::get() / output.stride(2) % output.shape(2);
|
||||
let ow = AbsoluteIndex::get() / output.stride(3) % output.shape(3);
|
||||
let b = ABSOLUTE_INDEX / output.stride(0) % output.shape(0);
|
||||
let oc = ABSOLUTE_INDEX / output.stride(1) % output.shape(1);
|
||||
let oh = ABSOLUTE_INDEX / output.stride(2) % output.shape(2);
|
||||
let ow = ABSOLUTE_INDEX / output.stride(3) % output.shape(3);
|
||||
|
||||
let g = (weight.shape(0) + oc) % groups;
|
||||
let ic_start = in_channels * g;
|
||||
|
@ -107,7 +107,7 @@ fn kernel<F: Float>(
|
|||
}
|
||||
}
|
||||
|
||||
output[AbsoluteIndex::get()] = sum;
|
||||
output[ABSOLUTE_INDEX] = sum;
|
||||
}
|
||||
|
||||
pub(crate) fn conv2d<R: JitRuntime, E: FloatElement>(
|
||||
|
|
Loading…
Reference in New Issue