Burn Import: A Crate for ONNX Model Import into the Burn Framework

burn-import enables importing machine learning models, currently those in the ONNX format, into the Burn deep learning framework. It generates Rust source code that mirrors the model's structure and converts the tensor data into Burn's native format.

Note: This crate is in active development and currently supports a limited set of ONNX operators.

Working Examples

For practical examples, please refer to:

  1. ONNX Inference Example
  2. SqueezeNet Image Classification

Usage

Importing ONNX Models

Follow these steps to import an ONNX model into your Burn project:

  1. Update build.rs: Include the following Rust code in your build.rs file:

    use burn_import::onnx::ModelGen;
    
    fn main() {
        // Generate Rust code from the ONNX model file
        ModelGen::new()
            .input("src/model/model_name.onnx")
            .out_dir("model/")
            .run_from_script();
    }
    
  2. Modify mod.rs: Add this code to the mod.rs file located in src/model:

    pub mod model_name {
        include!(concat!(env!("OUT_DIR"), "/model/model_name.rs"));
    }
    
  3. Use the imported model: The following sample code shows how to call the generated model from your application:

    mod model;
    
    use burn::tensor;
    use burn_ndarray::NdArray;
    use model::model_name::Model;
    
    fn main() {
        // Initialize a new model instance
        let model: Model<NdArray<f32>> = Model::new();
    
        // Create a sample input tensor (zeros for demonstration)
        let input = tensor::Tensor::<NdArray<f32>, 4>::zeros([1, 1, 28, 28]);
    
        // Execute the model
        let output = model.forward(input);
    
        // Display the output
        println!("{:?}", output);
    }
    

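Because the code generation in step 1 runs from `build.rs` at compile time, `burn-import` belongs under `[build-dependencies]`, while `burn` and a backend crate (here `burn-ndarray`, as used in the example above) go under `[dependencies]`. A minimal `Cargo.toml` sketch; the `"*"` version requirements are placeholders, so substitute the Burn release you are targeting, keeping all three crates on matching versions:

```toml
[dependencies]
# Core framework plus the CPU backend used in the sample code above
burn = "*"
burn-ndarray = "*"

[build-dependencies]
# Runs at build time from build.rs to generate the model source
burn-import = "*"
```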
Contribution

Interested in contributing to burn-import? Check out our development guide for more information.