mirror of https://github.com/tracel-ai/burn.git
Merge branch 'fix/feature-flags' into fix/cuda/stability
commit fa5cc760b9
@@ -1,137 +1,218 @@
# Importing ONNX Models in Burn

## Table of Contents

1. [Introduction](#introduction)
2. [Why Import Models?](#why-import-models)
3. [Understanding ONNX](#understanding-onnx)
4. [Burn's ONNX Support](#burns-onnx-support)
5. [Step-by-Step Guide](#step-by-step-guide)
6. [Advanced Configuration](#advanced-configuration)
7. [Loading and Using Models](#loading-and-using-models)
8. [Troubleshooting](#troubleshooting)
9. [Examples and Resources](#examples-and-resources)
10. [Conclusion](#conclusion)
## Introduction

As the field of deep learning continues to evolve, the need for interoperability between different
frameworks becomes increasingly important. Burn, a modern deep learning framework in Rust,
recognizes this need and provides robust support for importing models from other popular
frameworks. This section focuses on importing
[ONNX (Open Neural Network Exchange)](https://onnx.ai/onnx/intro/index.html) models into Burn,
enabling you to leverage pre-trained models and seamlessly integrate them into your Rust-based deep
learning projects.
## Why Import Models?

Importing pre-trained models offers several advantages:

1. **Time-saving**: Avoid the need to train models from scratch, which can be time-consuming and
   resource-intensive.
2. **Access to state-of-the-art architectures**: Utilize cutting-edge models developed by
   researchers and industry leaders.
3. **Transfer learning**: Fine-tune imported models for your specific tasks, benefiting from
   knowledge transfer.
4. **Consistency across frameworks**: Ensure consistent performance when moving from one framework
   to another.
## Understanding ONNX

ONNX (Open Neural Network Exchange) is an open format designed to represent machine learning
models. Key features include:

- **Framework agnostic**: ONNX provides a common format that works across various deep learning
  frameworks.
- **Comprehensive representation**: It captures both the model architecture and trained weights.
- **Wide support**: Many popular frameworks like PyTorch, TensorFlow, and scikit-learn support ONNX
  export.

By using ONNX, you can easily move models between different frameworks and deployment environments.
## Burn's ONNX Support

Burn takes a unique approach to ONNX import, offering several advantages:

1. **Native Rust code generation**: ONNX models are translated into Rust source code, allowing for
   deep integration with Burn's ecosystem (a conceptual sketch of this generated code follows the
   list).
2. **Compile-time optimization**: The generated Rust code can be optimized by the Rust compiler,
   potentially improving performance.
3. **No runtime dependency**: Unlike some solutions that require an ONNX runtime, Burn's approach
   eliminates this dependency.
4. **Trainability**: Imported models can be further trained or fine-tuned using Burn.
5. **Portability**: The generated Rust code can be compiled for various targets, including
   WebAssembly and embedded devices.
6. **Any Burn Backend**: The imported models can be used with any of Burn's backends.
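To make the first advantage more concrete, here is a rough sketch of the general *shape* of the
code that `burn-import` produces. It is purely illustrative: the actual generated file depends on
the specific ONNX graph, and the layer types, shapes, and sizes below are invented for this
example rather than taken from any real model.

```rust
use burn::nn;
use burn::prelude::*;

// Illustrative only: a generated model is an ordinary Burn `Module`, with one field per
// translated ONNX layer and a `forward` method that chains the translated operators.
#[derive(Module, Debug)]
pub struct Model<B: Backend> {
    conv1: nn::conv::Conv2d<B>,
    fc1: nn::Linear<B>,
}

impl<B: Backend> Model<B> {
    pub fn new(device: &B::Device) -> Self {
        Self {
            // Hypothetical layer configurations, not derived from a real ONNX file.
            conv1: nn::conv::Conv2dConfig::new([3, 8], [3, 3]).init(device),
            fc1: nn::LinearConfig::new(8 * 222 * 222, 10).init(device),
        }
    }

    pub fn forward(&self, input: Tensor<B, 4>) -> Tensor<B, 2> {
        let x = self.conv1.forward(input); // e.g. [batch, 8, 222, 222] for a 224x224 input
        let x: Tensor<B, 2> = x.flatten(1, 3); // collapse to [batch, features]
        self.fc1.forward(x)
    }
}
```

Because the result is plain Rust, it participates in type checking, borrow checking, and compiler
optimizations just like hand-written Burn code.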
## Step-by-Step Guide

Let's walk through the process of importing an ONNX model into a Burn project:

### Step 1: Update `build.rs`

First, add the `burn-import` crate to your `Cargo.toml`:

```toml
[build-dependencies]
burn-import = "0.14.0"
```

Then, in your `build.rs` file:

```rust
use burn_import::onnx::ModelGen;

fn main() {
    // Generate Rust code from the ONNX model file
    ModelGen::new()
        .input("src/model/my_model.onnx")
        .out_dir("model/")
        .run_from_script();
}
```

This script uses `ModelGen` to generate Rust code from your ONNX model during the build process.
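A hedged, optional note on build-script caching: Cargo decides when to re-run `build.rs` based on
any `cargo:rerun-if-changed` directives it emits (if none are emitted, the script re-runs whenever
the package changes). Whether `burn-import` emits such a directive for the input file is not
covered here, so if you add your own directives, make sure the ONNX file is listed so the model
code is regenerated when it changes:

```rust
// build.rs (sketch): explicitly watch the ONNX model file.
fn main() {
    // Standard Cargo directive; once any rerun-if-changed line is emitted,
    // only the listed paths (plus build.rs itself) trigger re-runs.
    println!("cargo:rerun-if-changed=src/model/my_model.onnx");

    // ...then invoke ModelGen exactly as shown above.
}
```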
### Step 2: Modify `mod.rs`

In your `src/model/mod.rs` file, include the generated code:

```rust
pub mod my_model {
    include!(concat!(env!("OUT_DIR"), "/model/my_model.rs"));
}
```

This makes the generated model code available in your project.
### Step 3: Use the Imported Model

Now you can use the imported model in your Rust code:

```rust
mod model;

use burn::tensor;
use burn_ndarray::{NdArray, NdArrayDevice};
use model::my_model::Model;

fn main() {
    let device = NdArrayDevice::default();

    // Create model instance and load weights from the target dir with the default device.
    // (see more load options below in the "Loading and Using Models" section)
    let model: Model<NdArray<f32>> = Model::default();

    // Create input tensor (replace with your actual input)
    let input = tensor::Tensor::<NdArray<f32>, 4>::zeros([1, 3, 224, 224], &device);

    // Perform inference
    let output = model.forward(input);

    // Print the output
    println!("Model output: {:?}", output);
}
```
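Because the generated model is generic over the backend (see "Any Burn Backend" above), the same
code can target a different Burn backend just by swapping the type parameters. Below is a minimal
sketch assuming the `wgpu` feature of the `burn` crate is enabled; the module paths and the
placeholder input shape are the same assumptions as in the example above:

```rust
mod model;

use burn::backend::wgpu::{Wgpu, WgpuDevice};
use burn::tensor::Tensor;
use model::my_model::Model;

fn main() {
    // Same generated model, but running on the GPU via the wgpu backend.
    let device = WgpuDevice::default();
    let model: Model<Wgpu> = Model::default();

    let input = Tensor::<Wgpu, 4>::zeros([1, 3, 224, 224], &device);
    let output = model.forward(input);

    println!("Model output: {:?}", output);
}
```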
## Advanced Configuration

The `ModelGen` struct offers several configuration options:

```rust
ModelGen::new()
    .input("path/to/model.onnx")
    .out_dir("model/")
    .record_type(RecordType::NamedMpk)
    .half_precision(false)
    .embed_states(false)
    .run_from_script();
```

- `record_type`: Specifies the format for storing weights (Bincode, NamedMpk, NamedMpkGz, or
  PrettyJson).
- `half_precision`: Use half-precision (f16) for weights to reduce model size.
- `embed_states`: Embed model weights directly in the generated Rust code. Note: This requires
  record type `Bincode` (see the sketch after this list).
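For example, to ship a self-contained binary with the weights compiled into it, you could combine
`embed_states(true)` with the `Bincode` record type. The sketch below only rearranges the options
listed above; the exact import path of `RecordType` may differ between `burn-import` versions, so
treat it as an assumption to verify against the crate's documentation:

```rust
// build.rs (sketch)
use burn_import::onnx::{ModelGen, RecordType}; // RecordType path is assumed; check burn-import docs

fn main() {
    ModelGen::new()
        .input("src/model/my_model.onnx")
        .out_dir("model/")
        // Embedding the weights requires the Bincode record type (see the note above).
        .record_type(RecordType::Bincode)
        .embed_states(true)
        .run_from_script();
}
```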
## Loading and Using Models

Depending on your configuration, you can load models in different ways:

```rust
// Create a new model instance with the device. Initializes weights randomly and lazily.
// You can load weights via `load_record` afterwards.
let model = Model::<Backend>::new(&device);

// Load from a file (you must specify the weights file in the target output directory, or copy it from there).
// The file type should match the record type specified in `ModelGen`.
let model = Model::<Backend>::from_file("path/to/weights", &device);

// Load from embedded weights (if embed_states was true)
let model = Model::<Backend>::from_embedded();

// Load from the out_dir location with the default device (useful for testing)
let model = Model::<Backend>::default();
```
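As the first comment above notes, a model created with `new(&device)` can have weights applied
afterwards through `load_record`. Here is a minimal sketch of that path, reusing the `Backend` and
`device` placeholders from the block above and assuming the `NamedMpk` record type together with
Burn's file recorder API; the weights path is a placeholder, not the actual build output location:

```rust
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};

// Load the record produced at build time (placeholder path), then apply it to a fresh model.
let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
let record = recorder
    .load("path/to/weights".into(), &device)
    .expect("Failed to load model record");
let model = Model::<Backend>::new(&device).load_record(record);
```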
## Troubleshooting

Here are some common issues and their solutions:

1. **Unsupported ONNX operator**: If you encounter an error about an unsupported operator, check the
   [list of supported ONNX operators](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md).
   You may need to simplify your model or wait for support to be added.

2. **Build errors**: Ensure that your `burn-import` version matches your Burn version. Also, check
   that the ONNX file path in `build.rs` is correct.

3. **Runtime errors**: If you get errors when running your model, double-check that your input
   tensors match the expected shape and data type of your model (see the snippet after this list).

4. **Performance issues**: If your imported model is slower than expected, try using the
   `half_precision` option to reduce memory usage, or experiment with different `record_type`
   options.

5. **Artifact Files**: You can view the generated Rust code and weights files in the `OUT_DIR`
   directory specified in `build.rs` (usually `target/debug/build/<project>/out`).
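For the shape mismatches described in item 3, a quick sanity check is to print the runtime
dimensions of your tensors and compare them with the shape the original model was exported with.
This fragment assumes the `input` and `output` variables from the Step 3 example:

```rust
// Print runtime tensor shapes to compare against the exported model's expected shape.
println!("input dims:  {:?}", input.dims());
println!("output dims: {:?}", output.dims());
```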
## Examples and Resources

For more detailed examples, check out:

1. [MNIST Inference Example](https://github.com/tracel-ai/burn/tree/main/examples/onnx-inference)
2. [SqueezeNet Image Classification](https://github.com/tracel-ai/models/tree/main/squeezenet-burn)

These examples demonstrate real-world usage of ONNX import in Burn projects.

## Conclusion

Importing ONNX models into Burn opens up a world of possibilities, allowing you to leverage
pre-trained models from other frameworks while taking advantage of Burn's performance and Rust's
safety features. By following this guide, you should be able to seamlessly integrate ONNX models
into your Burn projects, whether for inference, fine-tuning, or as a starting point for further
development.

Remember that the `burn-import` crate is actively developed, with ongoing work to support more ONNX
operators and improve performance. Stay tuned to the Burn repository for updates and new features!

---

> 🚨**Note**: The `burn-import` crate is in active development. For the most up-to-date information
> on supported ONNX operators, please refer to the
> [official documentation](https://github.com/tracel-ai/burn/blob/main/crates/burn-import/SUPPORTED-ONNX-OPS.md).
@@ -17,7 +17,7 @@ metrics = ["nvml-wrapper", "sysinfo", "systemstat"]
 tui = ["ratatui", "crossterm"]

 [dependencies]
-burn-core = { path = "../burn-core", version = "0.14.0", features = ["dataset"] }
+burn-core = { path = "../burn-core", version = "0.14.0", features = ["dataset"], default-features = false }

 log = { workspace = true }
 tracing-subscriber = { workspace = true }
@@ -11,12 +11,12 @@ repository = "https://github.com/tracel-ai/burn/tree/main/crates/burn-wgpu"
 version.workspace = true

 [features]
-default = ["fusion", "burn-jit/default", "cubecl/default"]
+default = ["std", "autotune", "fusion", "burn-jit/default", "cubecl/default"]
 fusion = ["burn-fusion", "burn-jit/fusion"]
 autotune = ["burn-jit/autotune"]
 template = ["burn-jit/template", "cubecl/template"]
 doc = ["burn-jit/doc"]
-std = ["burn-jit/std"]
+std = ["burn-jit/std", "cubecl/std"]

 [dependencies]
 cubecl = { workspace = true, features = ["wgpu"] }
@@ -20,7 +20,7 @@ cuda-jit = ["burn/cuda-jit"]

 [dependencies]
 # Burn
-burn = {path = "../../crates/burn", features=["train", "ndarray", "fusion"]}
+burn = {path = "../../crates/burn", features=["train", "ndarray", "std", "tui", "metrics", "autotune", "fusion", "default"], default-features = false}

 # Tokenizer
 tokenizers = { version = "0.19.1", default-features = false, features = [