mirror of https://github.com/tracel-ai/burn.git
Feat/cube/compile error (#1909)
This commit is contained in:
parent
d50bac165e
commit
efc13d9a38
|
@ -536,6 +536,7 @@ dependencies = [
|
||||||
"log",
|
"log",
|
||||||
"num-traits",
|
"num-traits",
|
||||||
"serde",
|
"serde",
|
||||||
|
"trybuild",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
@ -5629,6 +5630,20 @@ version = "0.2.5"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "trybuild"
|
||||||
|
version = "1.0.96"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "33a5f13f11071020bb12de7a16b925d2d58636175c20c11dc5f96cb64bb6c9b3"
|
||||||
|
dependencies = [
|
||||||
|
"glob",
|
||||||
|
"serde",
|
||||||
|
"serde_derive",
|
||||||
|
"serde_json",
|
||||||
|
"termcolor",
|
||||||
|
"toml",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "typenum"
|
name = "typenum"
|
||||||
version = "1.17.0"
|
version = "1.17.0"
|
||||||
|
|
|
@ -128,7 +128,7 @@ impl VariableAnalyzer {
|
||||||
if let syn::Expr::Block(expr_block) = &**expr {
|
if let syn::Expr::Block(expr_block) = &**expr {
|
||||||
self.find_occurrences_in_stmts(&expr_block.block.stmts, depth);
|
self.find_occurrences_in_stmts(&expr_block.block.stmts, depth);
|
||||||
} else {
|
} else {
|
||||||
todo!("Analysis: Only block else expr is supported")
|
// Unsupported: handled in codegen.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -190,17 +190,12 @@ impl VariableAnalyzer {
|
||||||
syn::Expr::Break(_) => {}
|
syn::Expr::Break(_) => {}
|
||||||
syn::Expr::Return(expr) => {
|
syn::Expr::Return(expr) => {
|
||||||
if expr.expr.is_some() {
|
if expr.expr.is_some() {
|
||||||
todo!("Analysis: only void return supported")
|
// Unsupported: handled in codegen.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
syn::Expr::Paren(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
|
syn::Expr::Paren(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
|
||||||
syn::Expr::Array(expr) => {
|
syn::Expr::Array(_expr) => {
|
||||||
for element in expr.elems.iter() {
|
// No analysis since only literals are supported
|
||||||
match element {
|
|
||||||
syn::Expr::Lit(_) => {}
|
|
||||||
_ => todo!("Analysis: only array of literals is supported"),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
syn::Expr::Reference(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
|
syn::Expr::Reference(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
|
||||||
syn::Expr::Closure(expr) => {
|
syn::Expr::Closure(expr) => {
|
||||||
|
@ -251,7 +246,12 @@ impl VariableAnalyzer {
|
||||||
self.find_occurrences_in_expr(&field.expr, depth)
|
self.find_occurrences_in_expr(&field.expr, depth)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => todo!("Analysis: unsupported expr {expr:?}"),
|
syn::Expr::Range(_range) => {
|
||||||
|
// Error is handled during codegen.
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
// Error is handled during codegen.
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -125,7 +125,14 @@ pub(crate) fn codegen_expr_with_comptime(
|
||||||
syn::Expr::Unary(op) => codegen_unary(op, loop_level, variable_tracker),
|
syn::Expr::Unary(op) => codegen_unary(op, loop_level, variable_tracker),
|
||||||
syn::Expr::Field(field) => codegen_field(field, loop_level, variable_tracker),
|
syn::Expr::Field(field) => codegen_field(field, loop_level, variable_tracker),
|
||||||
syn::Expr::Struct(struct_) => codegen_struct(struct_, loop_level, variable_tracker),
|
syn::Expr::Struct(struct_) => codegen_struct(struct_, loop_level, variable_tracker),
|
||||||
_ => panic!("Codegen: Unsupported {:?}", expr),
|
syn::Expr::Range(range) => syn::Error::new_spanned(
|
||||||
|
range,
|
||||||
|
"Range is not supported, use [range](cubecl::prelude::range) instead.",
|
||||||
|
)
|
||||||
|
.to_compile_error(),
|
||||||
|
_ => {
|
||||||
|
syn::Error::new_spanned(expr, "Expression is not supported").to_compile_error()
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
(tokens, false)
|
(tokens, false)
|
||||||
|
|
|
@ -24,14 +24,24 @@ pub(crate) fn codegen_for_loop(
|
||||||
variable_tracker.codegen_declare(id.to_string(), loop_level as u8 + 1);
|
variable_tracker.codegen_declare(id.to_string(), loop_level as u8 + 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let invalid_for_loop = || {
|
||||||
|
syn::Error::new_spanned(
|
||||||
|
&for_loop.expr,
|
||||||
|
"Invalid for loop: use [range](cubecl::prelude::range] instead.",
|
||||||
|
)
|
||||||
|
.into_compile_error()
|
||||||
|
};
|
||||||
|
|
||||||
match for_loop.expr.as_ref() {
|
match for_loop.expr.as_ref() {
|
||||||
syn::Expr::Call(call) => {
|
syn::Expr::Call(call) => {
|
||||||
let func_name = match call.func.as_ref() {
|
let func_name = match call.func.as_ref() {
|
||||||
syn::Expr::Path(path) => path
|
syn::Expr::Path(path) => match path.path.get_ident() {
|
||||||
.path
|
Some(ident) => ident,
|
||||||
.get_ident()
|
None => return invalid_for_loop(),
|
||||||
.expect("Codegen: func in for loop should have ident"),
|
},
|
||||||
_ => todo!("Codegen: Only path call supported"),
|
_ => {
|
||||||
|
return invalid_for_loop();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if &func_name.to_string() == "range" {
|
if &func_name.to_string() == "range" {
|
||||||
|
@ -64,10 +74,10 @@ pub(crate) fn codegen_for_loop(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
todo!("Codegen: Only range is supported")
|
invalid_for_loop()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => todo!("Codegen: Only call is supported {for_loop:?}"),
|
_ => invalid_for_loop(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -96,8 +106,10 @@ pub(crate) fn codegen_break() -> TokenStream {
|
||||||
/// Codegen for return statement
|
/// Codegen for return statement
|
||||||
pub(crate) fn codegen_return(expr_return: &syn::ExprReturn) -> TokenStream {
|
pub(crate) fn codegen_return(expr_return: &syn::ExprReturn) -> TokenStream {
|
||||||
if expr_return.expr.is_some() {
|
if expr_return.expr.is_some() {
|
||||||
panic!("Codegen: Only void return is supported.")
|
return syn::Error::new_spanned(expr_return, "Codegen: Only void return is supported.")
|
||||||
|
.into_compile_error();
|
||||||
}
|
}
|
||||||
|
|
||||||
quote::quote! {
|
quote::quote! {
|
||||||
burn_cube::frontend::branch::return_expand(context);
|
burn_cube::frontend::branch::return_expand(context);
|
||||||
}
|
}
|
||||||
|
@ -131,7 +143,11 @@ pub(crate) fn codegen_if(
|
||||||
burn_cube::frontend::branch::if_else_expand(context, #comptime_bool, _cond.into(), |context| #then_block, |context| #else_block);
|
burn_cube::frontend::branch::if_else_expand(context, #comptime_bool, _cond.into(), |context| #then_block, |context| #else_block);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
todo!("Codegen: Only block else expr is supported")
|
syn::Error::new_spanned(
|
||||||
|
expr,
|
||||||
|
"Unsupported: only `else` block is allowed after an `if` statement.",
|
||||||
|
)
|
||||||
|
.into_compile_error()
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
quote::quote! {
|
quote::quote! {
|
||||||
|
|
|
@ -41,11 +41,12 @@ pub(crate) fn codegen_closure(
|
||||||
if let syn::Pat::Ident(ident) = &*pat_type.pat {
|
if let syn::Pat::Ident(ident) = &*pat_type.pat {
|
||||||
&ident.ident
|
&ident.ident
|
||||||
} else {
|
} else {
|
||||||
panic!("Codegen: Unsupported {:?}", input);
|
return syn::Error::new_spanned(pat_type, "Unsupported input")
|
||||||
|
.into_compile_error();
|
||||||
},
|
},
|
||||||
Some(pat_type.ty.clone()),
|
Some(pat_type.ty.clone()),
|
||||||
),
|
),
|
||||||
_ => panic!("Codegen: Unsupported {:?}", input),
|
_ => return syn::Error::new_spanned(input, "Unsupported input").into_compile_error(),
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(ty) = ty {
|
if let Some(ty) = ty {
|
||||||
|
@ -92,7 +93,12 @@ pub(crate) fn codegen_call(
|
||||||
}
|
}
|
||||||
path
|
path
|
||||||
}
|
}
|
||||||
_ => todo!("Codegen: func call {:?} not supported", call.func),
|
_ => {
|
||||||
|
return (
|
||||||
|
syn::Error::new_spanned(&call.func, "Unsupported").into_compile_error(),
|
||||||
|
false,
|
||||||
|
)
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Path
|
// Path
|
||||||
|
|
|
@ -83,7 +83,7 @@ impl Codegen {
|
||||||
codegen.state_inputs.push((ident.clone(), *ty));
|
codegen.state_inputs.push((ident.clone(), *ty));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
_ => todo!("Only Typed inputs are supported"),
|
_ => panic!("Only Typed inputs are supported"),
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -29,7 +29,10 @@ pub(crate) fn codegen_array_lit(array: &syn::ExprArray) -> TokenStream {
|
||||||
for element in array.elems.iter() {
|
for element in array.elems.iter() {
|
||||||
let token = match element {
|
let token = match element {
|
||||||
syn::Expr::Lit(lit) => codegen_lit(lit),
|
syn::Expr::Lit(lit) => codegen_lit(lit),
|
||||||
_ => todo!("Codegen: Only arrays of literals are supported"),
|
_ => {
|
||||||
|
return syn::Error::new_spanned(array, "Only arrays of literals are supported")
|
||||||
|
.into_compile_error()
|
||||||
|
}
|
||||||
};
|
};
|
||||||
tokens.extend(quote::quote! { #token, });
|
tokens.extend(quote::quote! { #token, });
|
||||||
}
|
}
|
||||||
|
|
|
@ -49,7 +49,12 @@ pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream {
|
||||||
|
|
||||||
let mut variable_tracker = VariableAnalyzer::create_tracker(&func);
|
let mut variable_tracker = VariableAnalyzer::create_tracker(&func);
|
||||||
|
|
||||||
let cube = codegen_cube(&func, &mut variable_tracker);
|
let mut cube = codegen_cube(&func, &mut variable_tracker);
|
||||||
|
|
||||||
|
for err in variable_tracker.errors.drain(..) {
|
||||||
|
cube.extend(err.into_compile_error());
|
||||||
|
}
|
||||||
|
|
||||||
let code: TokenStream = if launch {
|
let code: TokenStream = if launch {
|
||||||
let launch = codegen_launch(&func.sig);
|
let launch = codegen_launch(&func.sig);
|
||||||
|
|
||||||
|
|
|
@ -23,6 +23,7 @@ pub(crate) struct VariableTracker {
|
||||||
analysis_repeats: HashMap<VariableKey, u8>,
|
analysis_repeats: HashMap<VariableKey, u8>,
|
||||||
codegen_repeats: HashMap<VariableKey, u8>,
|
codegen_repeats: HashMap<VariableKey, u8>,
|
||||||
variable_uses: HashMap<VariableIdent, VariableUse>,
|
variable_uses: HashMap<VariableIdent, VariableUse>,
|
||||||
|
pub errors: Vec<syn::Error>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Default)]
|
#[derive(Debug, Default)]
|
||||||
|
|
|
@ -32,3 +32,6 @@ derive-new = { workspace = true }
|
||||||
num-traits = { workspace = true }
|
num-traits = { workspace = true }
|
||||||
|
|
||||||
log = { workspace = true }
|
log = { workspace = true }
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
trybuild = "1"
|
||||||
|
|
|
@ -0,0 +1,8 @@
|
||||||
|
use burn_cube::prelude::*;
|
||||||
|
|
||||||
|
#[cube]
|
||||||
|
fn range(x: UInt, y: UInt) {
|
||||||
|
let _array = [x, y];
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
|
@ -0,0 +1,5 @@
|
||||||
|
error: Only arrays of literals are supported
|
||||||
|
--> tests/error/array_variable.rs:5:18
|
||||||
|
|
|
||||||
|
5 | let _array = [x, y];
|
||||||
|
| ^^^^^^
|
|
@ -0,0 +1,8 @@
|
||||||
|
use burn_cube::prelude::*;
|
||||||
|
|
||||||
|
#[cube]
|
||||||
|
fn range() {
|
||||||
|
for _ in 0..10 {}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
|
@ -0,0 +1,5 @@
|
||||||
|
error: Invalid for loop: use [range](cubecl::prelude::range] instead.
|
||||||
|
--> tests/error/for_loop_range.rs:5:14
|
||||||
|
|
|
||||||
|
5 | for _ in 0..10 {}
|
||||||
|
| ^^^^^
|
|
@ -0,0 +1,10 @@
|
||||||
|
use burn_cube::prelude::*;
|
||||||
|
|
||||||
|
#[cube]
|
||||||
|
fn range(x: UInt, y: UInt) {
|
||||||
|
if x == y {
|
||||||
|
} else if x != y {
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
|
@ -0,0 +1,7 @@
|
||||||
|
error: Unsupported: only `else` block is allowed after an `if` statement.
|
||||||
|
--> tests/error/if_else_if.rs:6:12
|
||||||
|
|
|
||||||
|
6 | } else if x != y {
|
||||||
|
| ____________^
|
||||||
|
7 | | }
|
||||||
|
| |_____^
|
|
@ -0,0 +1,8 @@
|
||||||
|
use burn_cube::prelude::*;
|
||||||
|
|
||||||
|
#[cube]
|
||||||
|
fn range() {
|
||||||
|
0..10;
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
|
@ -0,0 +1,5 @@
|
||||||
|
error: Range is not supported, use [range](cubecl::prelude::range) instead.
|
||||||
|
--> tests/error/range.rs:5:5
|
||||||
|
|
|
||||||
|
5 | 0..10;
|
||||||
|
| ^^^^^
|
|
@ -0,0 +1,12 @@
|
||||||
|
use burn_cube::prelude::*;
|
||||||
|
|
||||||
|
#[cube]
|
||||||
|
fn range(x: UInt, y: UInt) -> UInt {
|
||||||
|
if x == y {
|
||||||
|
return x;
|
||||||
|
}
|
||||||
|
|
||||||
|
y
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {}
|
|
@ -0,0 +1,5 @@
|
||||||
|
error: Codegen: Only void return is supported.
|
||||||
|
--> tests/error/return_value.rs:6:9
|
||||||
|
|
|
||||||
|
6 | return x;
|
||||||
|
| ^^^^^^^^
|
|
@ -1 +1,7 @@
|
||||||
mod frontend;
|
mod frontend;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn compile_fail_tests() {
|
||||||
|
let t = trybuild::TestCases::new();
|
||||||
|
t.compile_fail("tests/error/*.rs");
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue