Feat/cube/compile error (#1909)

This commit is contained in:
Nathaniel Simard 2024-06-19 17:21:32 -04:00 committed by GitHub
parent d50bac165e
commit efc13d9a38
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 161 additions and 26 deletions

15
Cargo.lock generated
View File

@ -536,6 +536,7 @@ dependencies = [
"log",
"num-traits",
"serde",
"trybuild",
]
[[package]]
@ -5629,6 +5630,20 @@ version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b"
[[package]]
name = "trybuild"
version = "1.0.96"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "33a5f13f11071020bb12de7a16b925d2d58636175c20c11dc5f96cb64bb6c9b3"
dependencies = [
"glob",
"serde",
"serde_derive",
"serde_json",
"termcolor",
"toml",
]
[[package]]
name = "typenum"
version = "1.17.0"

View File

@ -128,7 +128,7 @@ impl VariableAnalyzer {
if let syn::Expr::Block(expr_block) = &**expr {
self.find_occurrences_in_stmts(&expr_block.block.stmts, depth);
} else {
todo!("Analysis: Only block else expr is supported")
// Unsupported: handled in codegen.
}
}
}
@ -190,17 +190,12 @@ impl VariableAnalyzer {
syn::Expr::Break(_) => {}
syn::Expr::Return(expr) => {
if expr.expr.is_some() {
todo!("Analysis: only void return supported")
// Unsupported: handled in codegen.
}
}
syn::Expr::Paren(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
syn::Expr::Array(expr) => {
for element in expr.elems.iter() {
match element {
syn::Expr::Lit(_) => {}
_ => todo!("Analysis: only array of literals is supported"),
}
}
syn::Expr::Array(_expr) => {
// No analysis since only literals are supported
}
syn::Expr::Reference(expr) => self.find_occurrences_in_expr(&expr.expr, depth),
syn::Expr::Closure(expr) => {
@ -251,7 +246,12 @@ impl VariableAnalyzer {
self.find_occurrences_in_expr(&field.expr, depth)
}
}
_ => todo!("Analysis: unsupported expr {expr:?}"),
syn::Expr::Range(_range) => {
// Error is handled during codegen.
}
_ => {
// Error is handled during codegen.
}
}
}
}

View File

@ -125,7 +125,14 @@ pub(crate) fn codegen_expr_with_comptime(
syn::Expr::Unary(op) => codegen_unary(op, loop_level, variable_tracker),
syn::Expr::Field(field) => codegen_field(field, loop_level, variable_tracker),
syn::Expr::Struct(struct_) => codegen_struct(struct_, loop_level, variable_tracker),
_ => panic!("Codegen: Unsupported {:?}", expr),
syn::Expr::Range(range) => syn::Error::new_spanned(
range,
"Range is not supported, use [range](cubecl::prelude::range) instead.",
)
.to_compile_error(),
_ => {
syn::Error::new_spanned(expr, "Expression is not supported").to_compile_error()
}
};
(tokens, false)

View File

@ -24,14 +24,24 @@ pub(crate) fn codegen_for_loop(
variable_tracker.codegen_declare(id.to_string(), loop_level as u8 + 1);
}
let invalid_for_loop = || {
syn::Error::new_spanned(
&for_loop.expr,
"Invalid for loop: use [range](cubecl::prelude::range] instead.",
)
.into_compile_error()
};
match for_loop.expr.as_ref() {
syn::Expr::Call(call) => {
let func_name = match call.func.as_ref() {
syn::Expr::Path(path) => path
.path
.get_ident()
.expect("Codegen: func in for loop should have ident"),
_ => todo!("Codegen: Only path call supported"),
syn::Expr::Path(path) => match path.path.get_ident() {
Some(ident) => ident,
None => return invalid_for_loop(),
},
_ => {
return invalid_for_loop();
}
};
if &func_name.to_string() == "range" {
@ -64,10 +74,10 @@ pub(crate) fn codegen_for_loop(
}
}
} else {
todo!("Codegen: Only range is supported")
invalid_for_loop()
}
}
_ => todo!("Codegen: Only call is supported {for_loop:?}"),
_ => invalid_for_loop(),
}
}
@ -96,8 +106,10 @@ pub(crate) fn codegen_break() -> TokenStream {
/// Codegen for return statement
pub(crate) fn codegen_return(expr_return: &syn::ExprReturn) -> TokenStream {
if expr_return.expr.is_some() {
panic!("Codegen: Only void return is supported.")
return syn::Error::new_spanned(expr_return, "Codegen: Only void return is supported.")
.into_compile_error();
}
quote::quote! {
burn_cube::frontend::branch::return_expand(context);
}
@ -131,7 +143,11 @@ pub(crate) fn codegen_if(
burn_cube::frontend::branch::if_else_expand(context, #comptime_bool, _cond.into(), |context| #then_block, |context| #else_block);
}
} else {
todo!("Codegen: Only block else expr is supported")
syn::Error::new_spanned(
expr,
"Unsupported: only `else` block is allowed after an `if` statement.",
)
.into_compile_error()
}
} else {
quote::quote! {

View File

@ -41,11 +41,12 @@ pub(crate) fn codegen_closure(
if let syn::Pat::Ident(ident) = &*pat_type.pat {
&ident.ident
} else {
panic!("Codegen: Unsupported {:?}", input);
return syn::Error::new_spanned(pat_type, "Unsupported input")
.into_compile_error();
},
Some(pat_type.ty.clone()),
),
_ => panic!("Codegen: Unsupported {:?}", input),
_ => return syn::Error::new_spanned(input, "Unsupported input").into_compile_error(),
};
if let Some(ty) = ty {
@ -92,7 +93,12 @@ pub(crate) fn codegen_call(
}
path
}
_ => todo!("Codegen: func call {:?} not supported", call.func),
_ => {
return (
syn::Error::new_spanned(&call.func, "Unsupported").into_compile_error(),
false,
)
}
};
// Path

View File

@ -83,7 +83,7 @@ impl Codegen {
codegen.state_inputs.push((ident.clone(), *ty));
}
}
_ => todo!("Only Typed inputs are supported"),
_ => panic!("Only Typed inputs are supported"),
};
}

View File

@ -29,7 +29,10 @@ pub(crate) fn codegen_array_lit(array: &syn::ExprArray) -> TokenStream {
for element in array.elems.iter() {
let token = match element {
syn::Expr::Lit(lit) => codegen_lit(lit),
_ => todo!("Codegen: Only arrays of literals are supported"),
_ => {
return syn::Error::new_spanned(array, "Only arrays of literals are supported")
.into_compile_error()
}
};
tokens.extend(quote::quote! { #token, });
}

View File

@ -49,7 +49,12 @@ pub fn cube(attr: TokenStream, tokens: TokenStream) -> TokenStream {
let mut variable_tracker = VariableAnalyzer::create_tracker(&func);
let cube = codegen_cube(&func, &mut variable_tracker);
let mut cube = codegen_cube(&func, &mut variable_tracker);
for err in variable_tracker.errors.drain(..) {
cube.extend(err.into_compile_error());
}
let code: TokenStream = if launch {
let launch = codegen_launch(&func.sig);

View File

@ -23,6 +23,7 @@ pub(crate) struct VariableTracker {
analysis_repeats: HashMap<VariableKey, u8>,
codegen_repeats: HashMap<VariableKey, u8>,
variable_uses: HashMap<VariableIdent, VariableUse>,
pub errors: Vec<syn::Error>,
}
#[derive(Debug, Default)]

View File

@ -32,3 +32,6 @@ derive-new = { workspace = true }
num-traits = { workspace = true }
log = { workspace = true }
[dev-dependencies]
trybuild = "1"

View File

@ -0,0 +1,8 @@
use burn_cube::prelude::*;
#[cube]
fn range(x: UInt, y: UInt) {
let _array = [x, y];
}
fn main() {}

View File

@ -0,0 +1,5 @@
error: Only arrays of literals are supported
--> tests/error/array_variable.rs:5:18
|
5 | let _array = [x, y];
| ^^^^^^

View File

@ -0,0 +1,8 @@
use burn_cube::prelude::*;
#[cube]
fn range() {
for _ in 0..10 {}
}
fn main() {}

View File

@ -0,0 +1,5 @@
error: Invalid for loop: use [range](cubecl::prelude::range] instead.
--> tests/error/for_loop_range.rs:5:14
|
5 | for _ in 0..10 {}
| ^^^^^

View File

@ -0,0 +1,10 @@
use burn_cube::prelude::*;
#[cube]
fn range(x: UInt, y: UInt) {
if x == y {
} else if x != y {
}
}
fn main() {}

View File

@ -0,0 +1,7 @@
error: Unsupported: only `else` block is allowed after an `if` statement.
--> tests/error/if_else_if.rs:6:12
|
6 | } else if x != y {
| ____________^
7 | | }
| |_____^

View File

@ -0,0 +1,8 @@
use burn_cube::prelude::*;
#[cube]
fn range() {
0..10;
}
fn main() {}

View File

@ -0,0 +1,5 @@
error: Range is not supported, use [range](cubecl::prelude::range) instead.
--> tests/error/range.rs:5:5
|
5 | 0..10;
| ^^^^^

View File

@ -0,0 +1,12 @@
use burn_cube::prelude::*;
#[cube]
fn range(x: UInt, y: UInt) -> UInt {
if x == y {
return x;
}
y
}
fn main() {}

View File

@ -0,0 +1,5 @@
error: Codegen: Only void return is supported.
--> tests/error/return_value.rs:6:9
|
6 | return x;
| ^^^^^^^^

View File

@ -1 +1,7 @@
mod frontend;
#[test]
fn compile_fail_tests() {
let t = trybuild::TestCases::new();
t.compile_fail("tests/error/*.rs");
}