mirror of https://github.com/tracel-ai/burn.git
Improve ONNX import book section (#2059)
* Improve ONNX importing section
* Update onnx-model.md

This commit is contained in: parent 62a30e973c, commit bb13729b20

# Importing ONNX Models in Burn

## Table of Contents

1. [Introduction](#introduction)
2. [Why Import Models?](#why-import-models)
3. [Understanding ONNX](#understanding-onnx)
4. [Burn's ONNX Support](#burns-onnx-support)
5. [Step-by-Step Guide](#step-by-step-guide)
6. [Advanced Configuration](#advanced-configuration)
7. [Loading and Using Models](#loading-and-using-models)
8. [Troubleshooting](#troubleshooting)
9. [Examples and Resources](#examples-and-resources)
10. [Conclusion](#conclusion)

## Introduction

As the field of deep learning continues to evolve, the need for interoperability between different
frameworks becomes increasingly important. Burn, a modern deep learning framework in Rust,
recognizes this need and provides robust support for importing models from other popular frameworks.
This section focuses on importing
[ONNX (Open Neural Network Exchange)](https://onnx.ai/onnx/intro/index.html) models into Burn,
enabling you to leverage pre-trained models and seamlessly integrate them into your Rust-based deep
learning projects.
## Why Import Models?

Importing pre-trained models offers several advantages:

1. **Time-saving**: Avoid the need to train models from scratch, which can be time-consuming and
   resource-intensive.
2. **Access to state-of-the-art architectures**: Utilize cutting-edge models developed by
   researchers and industry leaders.
3. **Transfer learning**: Fine-tune imported models for your specific tasks, benefiting from
   knowledge transfer.
4. **Consistency across frameworks**: Ensure consistent performance when moving from one framework
   to another.
## Understanding ONNX

ONNX (Open Neural Network Exchange) is an open format designed to represent machine learning models.
Key features include:

- **Framework agnostic**: ONNX provides a common format that works across various deep learning
  frameworks.
- **Comprehensive representation**: It captures both the model architecture and trained weights.
- **Wide support**: Many popular frameworks like PyTorch, TensorFlow, and scikit-learn support ONNX
  export.

By using ONNX, you can easily move models between different frameworks and deployment environments.
It's worth noting that for models using cutting-edge, framework-specific features, manual porting
might be the only option, as standards like ONNX might not yet support these new innovations.

## Burn's ONNX Support
Burn takes a unique approach to ONNX import, offering several advantages:

1. **Native Rust code generation**: ONNX models are translated into Rust source code, allowing for
   deep integration with Burn's ecosystem (see the illustrative sketch after this list).
2. **Compile-time optimization**: The generated Rust code can be optimized by the Rust compiler,
   potentially improving performance.
3. **No runtime dependency**: Unlike some solutions that require an ONNX runtime, Burn's approach
   eliminates this dependency.
4. **Trainability**: Imported models can be further trained or fine-tuned using Burn.
5. **Portability**: The generated Rust code can be compiled for various targets, including
   WebAssembly and embedded devices.
6. **Any Burn Backend**: The imported models can be used with any of Burn's backends.
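To make the first point concrete, here is a rough, hand-written sketch of the shape the generated
code takes: a plain Rust module exposing a `Model` struct that is generic over the backend. This is
an illustration only, not actual `burn-import` output; the real generated file contains your model's
specific layers and weight-loading logic, and the layer names and dimensions below are made up.

```rust
// Illustrative sketch only: hand-written to show the general shape of generated code,
// not produced by burn-import. The real file is created in OUT_DIR at build time.
use burn::nn::{Linear, LinearConfig};
use burn::prelude::*;

#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    fc1: Linear<B>,
}

impl<B: Backend> Model<B> {
    /// Create the model with freshly initialized parameters on the given device.
    pub fn new(device: &B::Device) -> Self {
        Self {
            fc1: LinearConfig::new(784, 10).init(device),
        }
    }

    /// Run a forward pass on a batch of flattened inputs.
    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        self.fc1.forward(input)
    }
}
```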
## Step-by-Step Guide

Let's walk through the process of importing an ONNX model into a Burn project:

### Step 1: Update `build.rs`

First, add the `burn-import` crate to your `Cargo.toml`:
```toml
[build-dependencies]
burn-import = "0.14.0"
```

Then, in your `build.rs` file:
```rust
use burn_import::onnx::ModelGen;

fn main() {
    // Generate Rust code from the ONNX model file
    ModelGen::new()
        .input("src/model/my_model.onnx")
        .out_dir("model/")
        .run_from_script();
}
```

This script uses `ModelGen` to generate Rust code from your ONNX model during the build process.
### Step 2: Modify `mod.rs`

In your `src/model/mod.rs` file, include the generated code:

```rust
pub mod my_model {
    include!(concat!(env!("OUT_DIR"), "/model/my_model.rs"));
}
```
This makes the generated model code available in your project.

### Step 3: Use the Imported Model

Now you can use the imported model in your Rust code:
```rust
mod model;

use burn::tensor;
use burn_ndarray::{NdArray, NdArrayDevice};
use model::my_model::Model;

fn main() {
    let device = NdArrayDevice::default();

    // Create a model instance and load the weights from the target directory
    // onto the default device (see more loading options in the
    // "Loading and Using Models" section below).
    let model: Model<NdArray<f32>> = Model::default();

    // Create an input tensor (replace with your actual input)
    let input = tensor::Tensor::<NdArray<f32>, 4>::zeros([1, 3, 224, 224], &device);

    // Perform inference
    let output = model.forward(input);

    // Print the output
    println!("Model output: {:?}", output);
}
```
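The example above runs on the `NdArray` backend, but because the generated `Model` is generic over
the backend (advantage 6 above), switching backends only changes the type parameters. Below is a
minimal sketch under the assumption that the `wgpu` feature of the `burn` crate is enabled and that
`Wgpu`/`WgpuDevice` are available at these paths in your Burn version:

```rust
mod model;

use burn::backend::wgpu::WgpuDevice;
use burn::backend::Wgpu;
use model::my_model::Model;

fn main() {
    // Same generated model, now running through the Wgpu backend.
    let device = WgpuDevice::default();
    let model: Model<Wgpu> = Model::default();

    let input = burn::tensor::Tensor::<Wgpu, 4>::zeros([1, 3, 224, 224], &device);
    let output = model.forward(input);
    println!("Model output: {:?}", output);
}
```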
## Advanced Configuration

The `ModelGen` struct offers several configuration options:
```rust
ModelGen::new()
    .input("path/to/model.onnx")
    .out_dir("model/")
    .record_type(RecordType::NamedMpk)
    .half_precision(false)
    .embed_states(false)
    .run_from_script();
```
- `record_type`: Specifies the format for storing weights (Bincode, NamedMpk, NamedMpkGz, or
  PrettyJson).
- `half_precision`: Use half-precision (f16) for weights to reduce model size.
- `embed_states`: Embed model weights directly in the generated Rust code. Note: This requires
  record type `Bincode` (see the sketch below).
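For example, a `build.rs` that embeds the weights directly in the generated code might look like the
following sketch. It assumes `RecordType` can be imported from `burn_import::onnx` alongside
`ModelGen`; check your `burn-import` version for the exact path, and replace the model file name
with your own.

```rust
// Sketch: embed weights directly in the generated Rust code.
// Assumes `RecordType` is exposed from `burn_import::onnx`; adjust the import if needed.
use burn_import::onnx::{ModelGen, RecordType};

fn main() {
    ModelGen::new()
        .input("src/model/my_model.onnx")
        .out_dir("model/")
        // `embed_states(true)` requires the Bincode record type.
        .record_type(RecordType::Bincode)
        .embed_states(true)
        .run_from_script();
}
```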
## Loading and Using Models

Depending on your configuration, you can load models in different ways:
```rust
// Create a new model instance with device. Initializes weights randomly and lazily.
// You can load weights via `load_record` afterwards.
let model = Model::<Backend>::new(&device);

// Load from a file (must specify the weights file in the target output directory or copy it from there).
// The file type should match the record type specified in `ModelGen`.
let model = Model::<Backend>::from_file("path/to/weights", &device);

// Load from embedded weights (if embed_states was true)
let model = Model::<Backend>::from_embedded();

// Load from the output directory onto the default device (useful for testing)
let model = Model::<Backend>::default();
```
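As the first comment above mentions, a model created with `new` starts with random weights and can
be populated afterwards with `load_record`. Here is a minimal sketch, assuming the weights were
exported with the default `NamedMpk` record type and that Burn's file recorder API matches your
version; the path is a placeholder:

```rust
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};

// Load the record produced by burn-import (path and extension depend on your configuration).
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
let record = recorder
    .load("path/to/weights".into(), &device)
    .expect("Failed to load model record");

// Populate a freshly created model with the loaded weights.
let model = Model::<Backend>::new(&device).load_record(record);
```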
## Troubleshooting

Here are some common issues and their solutions:

1. **Unsupported ONNX operator**: If you encounter an error about an unsupported operator, check the
   [list of supported ONNX operators](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md).
   You may need to simplify your model or wait for support to be added.

2. **Build errors**: Ensure that your `burn-import` version matches your Burn version. Also, check
   that the ONNX file path in `build.rs` is correct.

3. **Runtime errors**: If you get errors when running your model, double-check that your input
   tensors match the expected shape and data type of your model.

4. **Performance issues**: If your imported model is slower than expected, try using the
   `half_precision` option to reduce memory usage, or experiment with different `record_type`
   options.

5. **Artifact Files**: You can view the generated Rust code and weights files in the `OUT_DIR`
   directory specified in `build.rs` (usually `target/debug/build/<project>/out`).
## Examples and Resources

For more detailed examples, check out:

1. [MNIST Inference Example](https://github.com/tracel-ai/burn/tree/main/examples/onnx-inference)
2. [SqueezeNet Image Classification](https://github.com/tracel-ai/models/tree/main/squeezenet-burn)

These examples demonstrate real-world usage of ONNX import in Burn projects.
## Conclusion

Importing ONNX models into Burn opens up a world of possibilities, allowing you to leverage
pre-trained models from other frameworks while taking advantage of Burn's performance and Rust's
safety features. By following this guide, you should be able to seamlessly integrate ONNX models
into your Burn projects, whether for inference, fine-tuning, or as a starting point for further
development.

Remember that the `burn-import` crate is actively developed, with ongoing work to support more ONNX
operators and improve performance. Stay tuned to the Burn repository for updates and new features!

---

> 🚨**Note**: The `burn-import` crate is in active development. For the most up-to-date information
> on supported ONNX operators, please refer to the
> [official documentation](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md).