diff --git a/burn-book/src/import/onnx-model.md b/burn-book/src/import/onnx-model.md index c8218c077..667e1b910 100644 --- a/burn-book/src/import/onnx-model.md +++ b/burn-book/src/import/onnx-model.md @@ -1,137 +1,218 @@ -# Import ONNX Model +# Importing ONNX Models in Burn -## Why Importing Models is Necessary +## Table of Contents -In the realm of deep learning, it's common to switch between different frameworks depending on your -project's specific needs. Maybe you've painstakingly fine-tuned a model in TensorFlow or PyTorch and -now you want to reap the benefits of Burn's unique features for deployment or further testing. This -is precisely the scenario where importing models into Burn can be a game-changer. +1. [Introduction](#introduction) +2. [Why Import Models?](#why-import-models) +3. [Understanding ONNX](#understanding-onnx) +4. [Burn's ONNX Support](#burns-onnx-support) +5. [Step-by-Step Guide](#step-by-step-guide) +6. [Advanced Configuration](#advanced-configuration) +7. [Loading and Using Models](#loading-and-using-models) +8. [Troubleshooting](#troubleshooting) +9. [Examples and Resources](#examples-and-resources) +10. [Conclusion](#conclusion) -## Traditional Methods: The Drawbacks +## Introduction -If you've been working with other deep learning frameworks like PyTorch, it's likely that you've -exported model weights before. PyTorch, for instance, lets you save model weights using its -`torch.save()` function. Yet, to port this model to another framework, you face the arduous task of -manually recreating the architecture in the destination framework before loading in the weights. Not -only is this method tedious, but it's also error-prone and hinders smooth interoperability between -frameworks. +As the field of deep learning continues to evolve, the need for interoperability between different +frameworks becomes increasingly important. Burn, a modern deep learning framework in Rust, +recognizes this need and provides robust support for importing models from other popular frameworks. +This section focuses on importing +[ONNX (Open Neural Network Exchange)](https://onnx.ai/onnx/intro/index.html) models into Burn, +enabling you to leverage pre-trained models and seamlessly integrate them into your Rust-based deep +learning projects. -It's worth noting that for models using cutting-edge, framework-specific features, manual porting -might be the only option, as standards like ONNX might not yet support these new innovations. +## Why Import Models? -## Enter ONNX +Importing pre-trained models offers several advantages: -[ONNX (Open Neural Network Exchange)](https://onnx.ai/onnx/intro/index.html) is designed to solve -such complications. It's an open-standard format that exports both the architecture and the weights -of a deep learning model. This feature makes it exponentially easier to move models between -different frameworks, thereby significantly aiding interoperability. ONNX is supported by a number -of frameworks including but not limited to TensorFlow, PyTorch, Caffe2, and Microsoft Cognitive -Toolkit. +1. **Time-saving**: Avoid the need to train models from scratch, which can be time-consuming and + resource-intensive. +2. **Access to state-of-the-art architectures**: Utilize cutting-edge models developed by + researchers and industry leaders. +3. **Transfer learning**: Fine-tune imported models for your specific tasks, benefiting from + knowledge transfer. +4. **Consistency across frameworks**: Ensure consistent performance when moving from one framework + to another. 
-### Advantages of ONNX +## Understanding ONNX -ONNX stands out for encapsulating two key elements: +ONNX (Open Neural Network Exchange) is an open format designed to represent machine learning models. +Key features include: -1. **Model Information**: It captures the architecture, detailing the layers, their connections, and - configurations. -2. **Weights**: ONNX also contains the trained model's weights. +- **Framework agnostic**: ONNX provides a common format that works across various deep learning + frameworks. +- **Comprehensive representation**: It captures both the model architecture and trained weights. +- **Wide support**: Many popular frameworks like PyTorch, TensorFlow, and scikit-learn support ONNX + export. -This dual encapsulation not only simplifies the porting of models between frameworks but also allows -seamless deployment across different environments without compatibility concerns. +By using ONNX, you can easily move models between different frameworks and deployment environments. -## Burn's ONNX Support: Importing Made Easy +## Burn's ONNX Support -Understanding the important role that ONNX plays in the contemporary deep learning landscape, Burn -simplifies the process of importing ONNX models via an intuitive API designed to mesh well with -Burn's ecosystem. +Burn takes a unique approach to ONNX import, offering several advantages: -Burn's solution is to translate ONNX files into Rust source code as well as Burn-compatible weights. -This transformation is carried out through the burn-import crate's code generator during build time, -providing advantages for both executing and further training ONNX models. +1. **Native Rust code generation**: ONNX models are translated into Rust source code, allowing for + deep integration with Burn's ecosystem. +2. **Compile-time optimization**: The generated Rust code can be optimized by the Rust compiler, + potentially improving performance. +3. **No runtime dependency**: Unlike some solutions that require an ONNX runtime, Burn's approach + eliminates this dependency. +4. **Trainability**: Imported models can be further trained or fine-tuned using Burn. +5. **Portability**: The generated Rust code can be compiled for various targets, including + WebAssembly and embedded devices. +6. **Any Burn Backend**: The imported models can be used with any of Burn's backends. -### Advantages of Burn's ONNX Approach +## Step-by-Step Guide -1. **Native Integration**: The generated Rust code is fully integrated into Burn's architecture, - enabling your model to run on various backends without the need for a separate ONNX runtime. +Let's walk through the process of importing an ONNX model into a Burn project: -2. **Trainability**: The imported model is not just for inference; it can be further trained or - fine-tuned using Burn's native training loop. +### Step 1: Update `build.rs` -3. **Portability**: As the model is converted to Rust source code, it can be compiled into - WebAssembly for browser execution. Likewise, this approach is beneficial for no-std embedded - devices. +First, add the `burn-import` crate to your `Cargo.toml`: -4. **Optimization**: Rust's compiler can further optimize the generated code for target - architectures, thereby improving performance. 
+```toml
+[build-dependencies]
+burn-import = "0.14.0"
+```
 
-### Sample Code for Importing ONNX Model
+Then, in your `build.rs` file:
 
-Below is a step-by-step guide to importing an ONNX model into a Burn-based project:
-
-#### Step 1: Update `build.rs`
-
-Include the `burn-import` crate and use the following Rust code in your `build.rs`:
-
-```rust, ignore
+```rust
 use burn_import::onnx::ModelGen;
 
 fn main() {
-    // Generate Rust code from the ONNX model file
     ModelGen::new()
-        .input("src/model/mnist.onnx")
+        .input("src/model/my_model.onnx")
         .out_dir("model/")
         .run_from_script();
 }
 ```
 
-#### Step 2: Modify `mod.rs`
+This script uses `ModelGen` to generate Rust code from your ONNX model during the build process.
 
-Add this code to the `mod.rs` file located in `src/model`:
+### Step 2: Modify `mod.rs`
 
-```rust, ignore
-pub mod mnist {
-    include!(concat!(env!("OUT_DIR"), "/model/mnist.rs"));
+In your `src/model/mod.rs` file, include the generated code:
+
+```rust
+pub mod my_model {
+    include!(concat!(env!("OUT_DIR"), "/model/my_model.rs"));
 }
 ```
 
-#### Step 3: Utilize Imported Model
+This makes the generated model code available in your project.
 
-Here's how to use the imported model in your application:
+### Step 3: Use the Imported Model
 
-```rust, ignore
-mod model;
+Now you can use the imported model in your Rust code:
 
+```rust
 use burn::tensor;
 use burn_ndarray::{NdArray, NdArrayDevice};
 
-use model::mnist::Model;
+use model::my_model::Model;
 
 fn main() {
-    // Initialize a new model instance
     let device = NdArrayDevice::default();
-    let model: Model<NdArray<f32>> = Model::new(&device);
-    // Create a sample input tensor (zeros for demonstration)
-    let input = tensor::Tensor::<NdArray<f32>, 4>::zeros([1, 1, 28, 28], &device);
+    // Create a model instance and load the weights from the build output directory onto the default device
+    // (see more loading options in the "Loading and Using Models" section below).
+    let model: Model<NdArray<f32>> = Model::default();
+
+    // Create an input tensor (replace with your actual input)
+    let input = tensor::Tensor::<NdArray<f32>, 4>::zeros([1, 3, 224, 224], &device);
 
     // Perform inference
     let output = model.forward(input);
 
-    // Print the output
-    println!("{:?}", output);
+    println!("Model output: {:?}", output);
 }
 ```
 
-### Working Examples
+## Advanced Configuration
 
-For practical examples, please refer to:
+The `ModelGen` struct offers several configuration options:
+
+```rust
+ModelGen::new()
+    .input("path/to/model.onnx")
+    .out_dir("model/")
+    .record_type(RecordType::NamedMpk)
+    .half_precision(false)
+    .embed_states(false)
+    .run_from_script();
+```
+
+- `record_type`: Specifies the format for storing weights (Bincode, NamedMpk, NamedMpkGz, or
+  PrettyJson).
+- `half_precision`: Use half-precision (f16) for weights to reduce model size.
+- `embed_states`: Embed model weights directly in the generated Rust code. Note: This requires
+  record type `Bincode`.
+
+## Loading and Using Models
+
+Depending on your configuration, you can load models in different ways:
+
+```rust
+// Create a new model instance on the given device. Weights are initialized randomly and lazily;
+// you can load weights via `load_record` afterwards.
+let model = Model::<Backend>::new(&device);
+
+// Load the weights from a file (the weights file must be in the build output directory, or copied from there).
+// The file type should match the record type specified in `ModelGen`.
+let model = Model::<Backend>::from_file("path/to/weights", &device);
+
+// Load from embedded weights (if embed_states was true)
+let model = Model::<Backend>::from_embedded();
+
+// Load from the build output directory (OUT_DIR) onto the default device (useful for testing)
+let model = Model::<Backend>::default();
+```
+
+## Troubleshooting
+
+Here are some common issues and their solutions:
+
+1. **Unsupported ONNX operator**: If you encounter an error about an unsupported operator, check the
+   [list of supported ONNX operators](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md).
+   You may need to simplify your model or wait for support to be added.
+
+2. **Build errors**: Ensure that your `burn-import` version matches your Burn version. Also, check
+   that the ONNX file path in `build.rs` is correct.
+
+3. **Runtime errors**: If you get errors when running your model, double-check that your input
+   tensors match the expected shape and data type of your model.
+
+4. **Performance issues**: If your imported model is slower than expected, try using the
+   `half_precision` option to reduce memory usage, or experiment with different `record_type`
+   options.
+
+5. **Artifact files**: You can view the generated Rust code and weights files in the `OUT_DIR`
+   directory specified in `build.rs` (usually `target/debug/build/<your_project>/out`).
+
+## Examples and Resources
+
+For more detailed examples, check out:
 
 1. [MNIST Inference Example](https://github.com/tracel-ai/burn/tree/main/examples/onnx-inference)
 2. [SqueezeNet Image Classification](https://github.com/tracel-ai/models/tree/main/squeezenet-burn)
 
-By combining ONNX's robustness with Burn's unique features, you'll have the flexibility and power to
-streamline your deep learning workflows like never before.
+These examples demonstrate real-world usage of ONNX import in Burn projects.
+
+## Conclusion
+
+Importing ONNX models into Burn opens up a world of possibilities, allowing you to leverage
+pre-trained models from other frameworks while taking advantage of Burn's performance and Rust's
+safety features. By following this guide, you should be able to seamlessly integrate ONNX models
+into your Burn projects, whether for inference, fine-tuning, or as a starting point for further
+development.
+
+Remember that the `burn-import` crate is actively developed, with ongoing work to support more ONNX
+operators and improve performance. Stay tuned to the Burn repository for updates and new features!
 
 ---
 
-> 🚨**Note**: `burn-import` crate is in active development and currently supports a
-> [limited set of ONNX operators](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md).
+> 🚨**Note**: The `burn-import` crate is in active development. For the most up-to-date information
+> on supported ONNX operators, please refer to the
+> [official documentation](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md).
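+
+As a closing illustration, the `load_record` route mentioned in the loading section can be spelled
+out roughly as follows. This is a minimal sketch rather than generated documentation: it assumes a
+`Backend` type alias, a `device` already in scope, and weights saved with the default `NamedMpk`
+record type.
+
+```rust
+use burn::module::Module;
+use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
+
+// `Backend`, `Model`, and `device` are assumed to be defined as in the examples above.
+let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::default();
+let record = recorder
+    .load("path/to/weights".into(), &device)
+    .expect("should load the model record");
+let model = Model::<Backend>::new(&device).load_record(record);
+```
+
+In most cases the generated `from_file` and `default` constructors shown earlier are more
+convenient; reaching for `load_record` directly is mainly useful when you want to choose the
+recorder and precision settings yourself.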
diff --git a/crates/burn-train/Cargo.toml b/crates/burn-train/Cargo.toml index 8b35a8ed8..2f1c7b04c 100644 --- a/crates/burn-train/Cargo.toml +++ b/crates/burn-train/Cargo.toml @@ -17,7 +17,7 @@ metrics = ["nvml-wrapper", "sysinfo", "systemstat"] tui = ["ratatui", "crossterm"] [dependencies] -burn-core = { path = "../burn-core", version = "0.14.0", features = ["dataset"] } +burn-core = { path = "../burn-core", version = "0.14.0", features = ["dataset"], default-features = false } log = { workspace = true } tracing-subscriber = { workspace = true } diff --git a/crates/burn-wgpu/Cargo.toml b/crates/burn-wgpu/Cargo.toml index c6a2b9a54..3674c97d6 100644 --- a/crates/burn-wgpu/Cargo.toml +++ b/crates/burn-wgpu/Cargo.toml @@ -11,12 +11,12 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-wgpu" version.workspace = true [features] -default = ["fusion", "burn-jit/default", "cubecl/default"] +default = ["std", "autotune", "fusion", "burn-jit/default", "cubecl/default"] fusion = ["burn-fusion", "burn-jit/fusion"] autotune = ["burn-jit/autotune"] template = ["burn-jit/template", "cubecl/template"] doc = ["burn-jit/doc"] -std = ["burn-jit/std"] +std = ["burn-jit/std", "cubecl/std"] [dependencies] cubecl = { workspace = true, features = ["wgpu"] } diff --git a/examples/text-classification/Cargo.toml b/examples/text-classification/Cargo.toml index 1c5928e75..fe9eff670 100644 --- a/examples/text-classification/Cargo.toml +++ b/examples/text-classification/Cargo.toml @@ -20,7 +20,7 @@ cuda-jit = ["burn/cuda-jit"] [dependencies] # Burn -burn = {path = "../../crates/burn", features=["train", "ndarray", "fusion"]} +burn = {path = "../../crates/burn", features=["train", "ndarray", "std", "tui", "metrics", "autotune", "fusion", "default"], default-features = false} # Tokenizer tokenizers = { version = "0.19.1", default-features = false, features = [