mirror of https://github.com/tracel-ai/burn.git
Feat/cube/compile error (#1909)
This commit is contained in:
parent
d50bac165e
commit
efc13d9a38
|
@ -536,6 +536,7 @@ dependencies = [
|
|||
"log",
|
||||
"num-traits",
|
||||
"serde",
|
||||
"trybuild",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
@ -5629,6 +5630,20 @@ version = "0.2.5"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
|
||||
|
||||
[[package]]
|
||||
name = "trybuild"
|
||||
version = "1.0.96"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "33a5f13f11071020bb12de7a16b925d2d58636175c20c11dc5f96cb64bb6c9b3"
|
||||
dependencies = [
|
||||
"glob",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"termcolor",
|
||||
"toml",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.17.0"
|
||||
|
|
|
@ -128,7 +128,7 @@ impl VariableAnalyzer {
|
|||
if let syn::Expr::Block(expr_block) = &**expr {
|
||||
self.find_occurrences_in_stmts(&expr_block.block.stmts, depth);
|
||||
} else {
|
||||
todo!("Analysis: Only block else expr is supported")
|
||||
// Unsupported: handled in codegen.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -190,17 +190,12 @@ impl VariableAnalyzer {
|
|||
syn::Expr::Break(_) => {}
|
||||
syn::Expr::Return(expr) => {
|
||||
if expr.expr.is_some() {
|
||||
todo!("Analysis: only void return supported")
|
||||
// Unsupported: handled in codegen.
|
||||
}
|
||||
}
|
||||
syn::Expr::Paren(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
|
||||
syn::Expr::Array(expr) => {
|
||||
for element in expr.elems.iter() {
|
||||
match element {
|
||||
syn::Expr::Lit(_) => {}
|
||||
_ => todo!("Analysis: only array of literals is supported"),
|
||||
}
|
||||
}
|
||||
syn::Expr::Array(_expr) => {
|
||||
// No analysis since only literals are supported
|
||||
}
|
||||
syn::Expr::Reference(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
|
||||
syn::Expr::Closure(expr) => {
|
||||
|
@ -251,7 +246,12 @@ impl VariableAnalyzer {
|
|||
self.find_occurrences_in_expr(&field.expr, depth)
|
||||
}
|
||||
}
|
||||
_ => todo!("Analysis: unsupported expr {expr:?}"),
|
||||
syn::Expr::Range(_range) => {
|
||||
// Error is handled during codegen.
|
||||
}
|
||||
_ => {
|
||||
// Error is handled during codegen.
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -125,7 +125,14 @@ pub(crate) fn codegen_expr_with_comptime(
|
|||
syn::Expr::Unary(op) => codegen_unary(op, loop_level, variable_tracker),
|
||||
syn::Expr::Field(field) => codegen_field(field, loop_level, variable_tracker),
|
||||
syn::Expr::Struct(struct_) => codegen_struct(struct_, loop_level, variable_tracker),
|
||||
_ => panic!("Codegen: Unsupported {:?}", expr),
|
||||
syn::Expr::Range(range) => syn::Error::new_spanned(
|
||||
range,
|
||||
"Range is not supported, use [range](cubecl::prelude::range) instead.",
|
||||
)
|
||||
.to_compile_error(),
|
||||
_ => {
|
||||
syn::Error::new_spanned(expr, "Expression is not supported").to_compile_error()
|
||||
}
|
||||
};
|
||||
|
||||
(tokens, false)
|
||||
|
|
|
@ -24,14 +24,24 @@ pub(crate) fn codegen_for_loop(
|
|||
variable_tracker.codegen_declare(id.to_string(), loop_level as u8 + 1);
|
||||
}
|
||||
|
||||
let invalid_for_loop = || {
|
||||
syn::Error::new_spanned(
|
||||
&for_loop.expr,
|
||||
"Invalid for loop: use [range](cubecl::prelude::range] instead.",
|
||||
)
|
||||
.into_compile_error()
|
||||
};
|
||||
|
||||
match for_loop.expr.as_ref() {
|
||||
syn::Expr::Call(call) => {
|
||||
let func_name = match call.func.as_ref() {
|
||||
syn::Expr::Path(path) => path
|
||||
.path
|
||||
.get_ident()
|
||||
.expect("Codegen: func in for loop should have ident"),
|
||||
_ => todo!("Codegen: Only path call supported"),
|
||||
syn::Expr::Path(path) => match path.path.get_ident() {
|
||||
Some(ident) => ident,
|
||||
None => return invalid_for_loop(),
|
||||
},
|
||||
_ => {
|
||||
return invalid_for_loop();
|
||||
}
|
||||
};
|
||||
|
||||
if &func_name.to_string() == "range" {
|
||||
|
@ -64,10 +74,10 @@ pub(crate) fn codegen_for_loop(
|
|||
}
|
||||
}
|
||||
} else {
|
||||
todo!("Codegen: Only range is supported")
|
||||
invalid_for_loop()
|
||||
}
|
||||
}
|
||||
_ => todo!("Codegen: Only call is supported {for_loop:?}"),
|
||||
_ => invalid_for_loop(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -96,8 +106,10 @@ pub(crate) fn codegen_break() -> TokenStream {
|
|||
/// Codegen for return statement
|
||||
pub(crate) fn codegen_return(expr_return: &syn::ExprReturn) -> TokenStream {
|
||||
if expr_return.expr.is_some() {
|
||||
panic!("Codegen: Only void return is supported.")
|
||||
return syn::Error::new_spanned(expr_return, "Codegen: Only void return is supported.")
|
||||
.into_compile_error();
|
||||
}
|
||||
|
||||
quote::quote! {
|
||||
burn_cube::frontend::branch::return_expand(context);
|
||||
}
|
||||
|
@ -131,7 +143,11 @@ pub(crate) fn codegen_if(
|
|||
burn_cube::frontend::branch::if_else_expand(context, #comptime_bool, _cond.into(), |context| #then_block, |context| #else_block);
|
||||
}
|
||||
} else {
|
||||
todo!("Codegen: Only block else expr is supported")
|
||||
syn::Error::new_spanned(
|
||||
expr,
|
||||
"Unsupported: only `else` block is allowed after an `if` statement.",
|
||||
)
|
||||
.into_compile_error()
|
||||
}
|
||||
} else {
|
||||
quote::quote! {
|
||||
|
|
|
@ -41,11 +41,12 @@ pub(crate) fn codegen_closure(
|
|||
if let syn::Pat::Ident(ident) = &*pat_type.pat {
|
||||
&ident.ident
|
||||
} else {
|
||||
panic!("Codegen: Unsupported {:?}", input);
|
||||
return syn::Error::new_spanned(pat_type, "Unsupported input")
|
||||
.into_compile_error();
|
||||
},
|
||||
Some(pat_type.ty.clone()),
|
||||
),
|
||||
_ => panic!("Codegen: Unsupported {:?}", input),
|
||||
_ => return syn::Error::new_spanned(input, "Unsupported input").into_compile_error(),
|
||||
};
|
||||
|
||||
if let Some(ty) = ty {
|
||||
|
@ -92,7 +93,12 @@ pub(crate) fn codegen_call(
|
|||
}
|
||||
path
|
||||
}
|
||||
_ => todo!("Codegen: func call {:?} not supported", call.func),
|
||||
_ => {
|
||||
return (
|
||||
syn::Error::new_spanned(&call.func, "Unsupported").into_compile_error(),
|
||||
false,
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
// Path
|
||||
|
|
|
@ -83,7 +83,7 @@ impl Codegen {
|
|||
codegen.state_inputs.push((ident.clone(), *ty));
|
||||
}
|
||||
}
|
||||
_ => todo!("Only Typed inputs are supported"),
|
||||
_ => panic!("Only Typed inputs are supported"),
|
||||
};
|
||||
}
|
||||
|
||||
|
|
|
@ -29,7 +29,10 @@ pub(crate) fn codegen_array_lit(array: &syn::ExprArray) -> TokenStream {
|
|||
for element in array.elems.iter() {
|
||||
let token = match element {
|
||||
syn::Expr::Lit(lit) => codegen_lit(lit),
|
||||
_ => todo!("Codegen: Only arrays of literals are supported"),
|
||||
_ => {
|
||||
return syn::Error::new_spanned(array, "Only arrays of literals are supported")
|
||||
.into_compile_error()
|
||||
}
|
||||
};
|
||||
tokens.extend(quote::quote! { #token, });
|
||||
}
|
||||
|
|
|
@ -49,7 +49,12 @@ pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream {
|
|||
|
||||
let mut variable_tracker = VariableAnalyzer::create_tracker(&func);
|
||||
|
||||
let cube = codegen_cube(&func, &mut variable_tracker);
|
||||
let mut cube = codegen_cube(&func, &mut variable_tracker);
|
||||
|
||||
for err in variable_tracker.errors.drain(..) {
|
||||
cube.extend(err.into_compile_error());
|
||||
}
|
||||
|
||||
let code: TokenStream = if launch {
|
||||
let launch = codegen_launch(&func.sig);
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@ pub(crate) struct VariableTracker {
|
|||
analysis_repeats: HashMap<VariableKey, u8>,
|
||||
codegen_repeats: HashMap<VariableKey, u8>,
|
||||
variable_uses: HashMap<VariableIdent, VariableUse>,
|
||||
pub errors: Vec<syn::Error>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
|
|
|
@ -32,3 +32,6 @@ derive-new = { workspace = true }
|
|||
num-traits = { workspace = true }
|
||||
|
||||
log = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
trybuild = "1"
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
use burn_cube::prelude::*;
|
||||
|
||||
#[cube]
|
||||
fn range(x: UInt, y: UInt) {
|
||||
let _array = [x, y];
|
||||
}
|
||||
|
||||
fn main() {}
|
|
@ -0,0 +1,5 @@
|
|||
error: Only arrays of literals are supported
|
||||
--> tests/error/array_variable.rs:5:18
|
||||
|
|
||||
5 | let _array = [x, y];
|
||||
| ^^^^^^
|
|
@ -0,0 +1,8 @@
|
|||
use burn_cube::prelude::*;
|
||||
|
||||
#[cube]
|
||||
fn range() {
|
||||
for _ in 0..10 {}
|
||||
}
|
||||
|
||||
fn main() {}
|
|
@ -0,0 +1,5 @@
|
|||
error: Invalid for loop: use [range](cubecl::prelude::range] instead.
|
||||
--> tests/error/for_loop_range.rs:5:14
|
||||
|
|
||||
5 | for _ in 0..10 {}
|
||||
| ^^^^^
|
|
@ -0,0 +1,10 @@
|
|||
use burn_cube::prelude::*;
|
||||
|
||||
#[cube]
|
||||
fn range(x: UInt, y: UInt) {
|
||||
if x == y {
|
||||
} else if x != y {
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {}
|
|
@ -0,0 +1,7 @@
|
|||
error: Unsupported: only `else` block is allowed after an `if` statement.
|
||||
--> tests/error/if_else_if.rs:6:12
|
||||
|
|
||||
6 | } else if x != y {
|
||||
| ____________^
|
||||
7 | | }
|
||||
| |_____^
|
|
@ -0,0 +1,8 @@
|
|||
use burn_cube::prelude::*;
|
||||
|
||||
#[cube]
|
||||
fn range() {
|
||||
0..10;
|
||||
}
|
||||
|
||||
fn main() {}
|
|
@ -0,0 +1,5 @@
|
|||
error: Range is not supported, use [range](cubecl::prelude::range) instead.
|
||||
--> tests/error/range.rs:5:5
|
||||
|
|
||||
5 | 0..10;
|
||||
| ^^^^^
|
|
@ -0,0 +1,12 @@
|
|||
use burn_cube::prelude::*;
|
||||
|
||||
#[cube]
|
||||
fn range(x: UInt, y: UInt) -> UInt {
|
||||
if x == y {
|
||||
return x;
|
||||
}
|
||||
|
||||
y
|
||||
}
|
||||
|
||||
fn main() {}
|
|
@ -0,0 +1,5 @@
|
|||
error: Codegen: Only void return is supported.
|
||||
--> tests/error/return_value.rs:6:9
|
||||
|
|
||||
6 | return x;
|
||||
| ^^^^^^^^
|
|
@ -1 +1,7 @@
|
|||
mod frontend;
|
||||
|
||||
#[test]
|
||||
fn compile_fail_tests() {
|
||||
let t = trybuild::TestCases::new();
|
||||
t.compile_fail("tests/error/*.rs");
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue