burn/examples/onnx-inference
Guillaume Lagrange cdd1fa1672
Refactor tensor data (#1916)
* Move distribution to module

* Add new TensorData with serialization support

* Implement display and from for TensorData

* Add missing Cargo.lock

* Add missing bytemuck feature

* Add zeros, ones, full and random TensorData methods

* Refactor Data -> TensorData usage

* Fix tests

Since TensorData is not generic over the element type anymore no type inference can be done by the compiler. We must explicitly cast the expected results to the expected backend type.

* Remove commented line

* Fix import

* Add record-backward-compat

* Remove dim const generic from TensorData

* Support NestedValue de/serialization with TensorData

* Fix burn-jit tests

* Remove eprinln

* Refactor onnx import to use TensorData

* Fix tch from_data

* Fix nested value serialization for u8

* Fix missing import

* Fix reduce min onnx test

* Fix deprecated attribute

* Remove shape getter

* Remove strict assert in tests

* Add tensor data as_bytes

* Add tensor check for rank mismatch

* Fix typo (dimensions plural)

* Fix error message

* Update book examples with from_data and fix Display impl for TensorData

* Add deprecation note
2024-06-26 20:22:19 -04:00
pytorch Add ability to load onnx state to the generated source code (#319) 2023-05-03 13:05:43 -04:00
src Refactor tensor data (#1916) 2024-06-26 20:22:19 -04:00
Cargo.toml [refactor] Move burn crates to their own crates directory (#1336) 2024-02-20 13:57:55 -05:00
README.md Remove _devauto fuctions (#518) (#1110) 2024-01-06 13:36:34 -05:00
build.rs Add support for different record types in ONNX (#816) 2023-09-21 09:06:57 -04:00

README.md

ONNX Inference

This crate provides a simple example of importing an MNIST ONNX model into Burn. The ONNX file is converted into a Rust source file using burn-import, and the weights are stored in and loaded from a binary file.

Usage

cargo run -- 15

Output:

Finished dev [unoptimized + debuginfo] target(s) in 0.13s
    Running `burn/target/debug/onnx-inference 15`

Image index: 15
Success!
Predicted: 5
Actual: 5
See the image online, click the link below:
https://datasets-server.huggingface.co/assets/mnist/--/mnist/test/15/image/image.jpg
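
The run above takes an image index on the command line, reports a predicted digit, and prints a link to the image. Here is a minimal sketch of those three pieces in plain Rust. It assumes the model output is ten scores (one per digit) and that the predicted digit is the index of the largest score. The function names `parse_index`, `argmax`, and `image_url` are hypothetical and not taken from the example's source, and the scores in `main` are made up for illustration:

```rust
/// Parse the image index from the first command-line argument (hypothetical helper).
fn parse_index(arg: Option<&str>) -> Result<usize, String> {
    let raw = arg.ok_or_else(|| "missing image index".to_string())?;
    raw.parse::<usize>()
        .map_err(|e| format!("invalid image index {raw:?}: {e}"))
}

/// Index of the largest score; assumed here to be how a digit is picked.
fn argmax(scores: &[f32]) -> Option<usize> {
    scores
        .iter()
        .enumerate()
        .max_by(|a, b| a.1.total_cmp(b.1))
        .map(|(i, _)| i)
}

/// Build the image link using the URL pattern shown in the output above.
fn image_url(index: usize) -> String {
    format!(
        "https://datasets-server.huggingface.co/assets/mnist/--/mnist/test/{index}/image/image.jpg"
    )
}

fn main() {
    let args: Vec<String> = std::env::args().collect();
    // Fall back to index 15 (the one used above) when no valid argument is given.
    let index = parse_index(args.get(1).map(String::as_str)).unwrap_or(15);

    // Made-up scores for illustration only; a real run would use the model output.
    let scores = [0.1_f32, 0.0, 0.2, 0.1, 0.0, 3.5, 0.3, 0.0, 0.1, 0.2];

    println!("Image index: {index}");
    if let Some(digit) = argmax(&scores) {
        println!("Predicted: {digit}");
    }
    println!("{}", image_url(index));
}
```

Returning `Option` from `argmax` keeps the empty-input case explicit instead of panicking.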

Feature Flags

  • embedded-model (default) - Embed the model weights into the binary. This is useful for small models (e.g. MNIST) but not recommended for very large models because it will increase the binary size significantly and will consume a lot of memory at runtime. If you do not use this feature, the model weights will be loaded from a binary file at runtime.
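
As a rough illustration of how a Cargo feature such as `embedded-model` can switch behavior at compile time, the sketch below uses the standard `cfg!` macro. The `WeightSource` enum and `weight_source` function are hypothetical names introduced only for this example; they do not describe how this crate actually wires the feature.

```rust
/// Where the weights come from for a given build (hypothetical enum).
#[derive(Debug, PartialEq)]
enum WeightSource {
    /// Weights compiled into the executable.
    Embedded,
    /// Weights read from a binary file at runtime.
    File,
}

/// `cfg!` is evaluated at compile time: it is `true` only when the crate
/// is built with the `embedded-model` feature enabled.
fn weight_source() -> WeightSource {
    if cfg!(feature = "embedded-model") {
        WeightSource::Embedded
    } else {
        WeightSource::File
    }
}

fn main() {
    println!("weights: {:?}", weight_source());
}
```

Because the choice is made at compile time, switching between the two modes means rebuilding with a different feature set.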

How to import

  1. Create a model directory under src

  2. Copy the ONNX model to src/model/mnist.onnx

  3. Add the following to src/model/mod.rs:

    pub mod mnist {
        include!(concat!(env!("OUT_DIR"), "/model/mnist.rs"));
    }
    
  4. Add the module to lib.rs:

    pub mod model;
    
    pub use model::mnist::*;
    
  5. Add the following to build.rs:

    use burn_import::onnx::ModelGen;
    
    fn main() {
        // Generate the model code from the ONNX file.
        ModelGen::new()
            .input("src/model/mnist.onnx")
            .out_dir("model/")
            .run_from_script();
    }
    
    
  6. Add a binary that uses your model as a new file in src/bin. In this example, the file is called mnist.rs:

    use burn::tensor;
    use burn::backend::ndarray::NdArray;
    
    use onnx_inference::mnist::Model;
    
    fn main() {
        // Get a default device for the model's backend
        let device = Default::default();
    
        // Create a new model and load the state
        let model: Model<NdArray<f32>> = Model::new(&device).load_state();
    
        // Create a new input tensor (all zeros for demonstration purposes)
        let input = tensor::Tensor::<NdArray<f32>, 4>::zeros([1, 1, 28, 28], &device);
    
        // Run the model
        let output = model.forward(input);
    
        // Print the output
        println!("{:?}", output);
    }
    
  7. Run cargo build to generate the model code, weights, and mnist binary.
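
Steps 3 and 5 have to agree on a path: build.rs writes the generated code under `model/` inside Cargo's `OUT_DIR`, and mod.rs includes `/model/mnist.rs` from that same directory. The sketch below spells out that mapping in plain Rust. It assumes the generated file takes its name from the ONNX file stem, as the `mnist.onnx` and `mnist.rs` pairing in the steps suggests. The helper `generated_source_path` is hypothetical and is not part of burn-import.

```rust
use std::path::{Path, PathBuf};

/// Hypothetical helper: where the generated Rust file is expected to land,
/// given the output directory, the sub-directory passed to `out_dir`, and
/// the ONNX input file.
fn generated_source_path(out_dir: &Path, sub_dir: &str, onnx_file: &Path) -> Option<PathBuf> {
    // "src/model/mnist.onnx" -> "mnist"
    let stem = onnx_file.file_stem()?.to_str()?;
    Some(out_dir.join(sub_dir).join(format!("{stem}.rs")))
}

fn main() {
    // Cargo sets OUT_DIR when it runs a build script; use a placeholder elsewhere.
    let out_dir = std::env::var("OUT_DIR").unwrap_or_else(|_| "target/out".to_string());
    let path = generated_source_path(
        Path::new(&out_dir),
        "model/",
        Path::new("src/model/mnist.onnx"),
    );
    println!("{path:?}");
}
```

If the `include!` path in mod.rs and the `out_dir` value in build.rs drift apart, the build fails because the included file cannot be found, so it helps to change both together.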

How to export PyTorch model to ONNX

The following steps show how to export a PyTorch model to ONNX using the checked-in PyTorch code (see pytorch/mnist.py).

  1. Install dependencies:

    pip install torch torchvision onnx
    
  2. Run the following script to train the MNIST model and export it to ONNX:

    python3 pytorch/mnist.py
    

This will generate pytorch/mnist.onnx.

Resources

  1. PyTorch ONNX
  2. ONNX Intro