#[cfg(test)]
#[cfg(feature = "onnx")]
mod tests {
    //! Codegen snapshot tests: for each listed model, the Rust source generated
    //! from its ONNX file is compared against a checked-in expected `.rs` file.

    use std::fs::read_to_string;
    use std::path::Path;

    use burn::record::FullPrecisionSettings;
    use pretty_assertions::assert_eq;
    use rstest::*;

    /// Generates formatted source code for the ONNX file at `onnx_path`.
    ///
    /// The file is parsed with `burn_import::onnx::parse_onnx`, converted with
    /// `into_burn::<FullPrecisionSettings>()`, configured with blank-space
    /// output and a fixed top comment, and the result of `codegen()` is passed
    /// through `burn_import::format_tokens`.
    fn code<P: AsRef<Path>>(onnx_path: P) -> String {
        let graph = burn_import::onnx::parse_onnx(onnx_path.as_ref());
        let graph = graph
            .into_burn::<FullPrecisionSettings>()
            .with_blank_space(true)
            .with_top_comment(Some("Generated by integration tests".into()));

        burn_import::format_tokens(graph.codegen())
    }

    /// Checks that the code generated for `model_name` matches the expected
    /// source stored next to the model.
    ///
    /// Both fixtures are looked up under `tests/data/{model_name}/`:
    /// `{model_name}.onnx` is the input and `{model_name}.rs` is the expected
    /// output. The paths are relative, so they resolve against the working
    /// directory the test is run from.
    ///
    /// # Panics
    ///
    /// Panics if the expected source file cannot be read, or if the generated
    /// code differs from it.
    #[rstest]
    #[case("model1")]
    // #[case("model2")] <- Add more models here
    fn test_codegen(#[case] model_name: &str) {
        let input_file = format!("tests/data/{model_name}/{model_name}.onnx");
        let source_file = format!("tests/data/{model_name}/{model_name}.rs");

        let source_expected: String =
            read_to_string(source_file).expect("Expected source file is missing");

        // Named distinctly from the `code` helper to avoid shadowing it.
        let source_generated = code(input_file);

        assert_eq!(source_expected, source_generated);
    }
}