syl20bnr 8e78106680 Bump burn version to 0.15.0 2024-08-27 15:13:40 -04:00
model Bump burn version to 0.15.0 2024-08-27 15:13:40 -04:00
pytorch Add support for loading PyTorch `.pt` (weights/states) files directly to model's record (#1085) 2024-01-25 10:20:09 -05:00
src modified mnist image link in the Hugging face (#2134) 2024-08-08 11:15:08 -05:00
Cargo.toml Bump burn version to 0.15.0 2024-08-27 15:13:40 -04:00
README.md modified mnist image link in the Hugging face (#2134) 2024-08-08 11:15:08 -05:00
build.rs Upgrade to candle 0.4.1 (#1382) 2024-02-29 11:29:11 -06:00

README.md

Import PyTorch Weights

This crate provides a simple example of importing PyTorch-generated weights into Burn.

The .pt file is converted into a Burn-consumable file (MessagePack format) using burn-import. The conversion is done in the build.rs file.

The model is defined in a separate sub-crate because build.rs needs the model for the conversion, and a build script cannot import modules from the crate it is building.
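The build-script side of this setup can be sketched with standard Cargo conventions only: Cargo sets `OUT_DIR` for build scripts, and a `cargo:rerun-if-changed` line limits re-runs to when the weights file changes. This is a minimal sketch, not the example's actual build.rs. The path `pytorch/mnist.pt`, the output name `mnist.mpk`, and the helper `converted_record_path` are hypothetical, and the burn-import conversion call itself is left as a comment rather than guessed at.

```rust
use std::env;
use std::path::{Path, PathBuf};

// Hypothetical location of the PyTorch weights; adjust to the real file name.
const WEIGHTS_FILE: &str = "pytorch/mnist.pt";

/// Hypothetical helper: where the converted MessagePack record would go.
/// Build scripts should only write inside OUT_DIR.
fn converted_record_path(out_dir: &Path) -> PathBuf {
    // Hypothetical output name; the real build.rs may choose another.
    out_dir.join("mnist.mpk")
}

fn main() {
    // Re-run this build script only when the weights file changes,
    // so the conversion is not repeated on every build.
    println!("cargo:rerun-if-changed={}", WEIGHTS_FILE);

    // Cargo always sets OUT_DIR when it runs a build script.
    let out_dir = env::var("OUT_DIR").expect("OUT_DIR is set by Cargo for build scripts");
    let target = converted_record_path(Path::new(&out_dir));

    // Placeholder (not implemented here): the real build.rs loads WEIGHTS_FILE
    // with burn-import into the record type defined in the model sub-crate
    // (listed under [build-dependencies]) and saves it to `target`.
    let _ = target;
}
```

The record type has to come from a crate listed under `[build-dependencies]`, which is why the model lives in its own sub-crate rather than in this crate's `src`.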

Usage

cargo run -- 15

Output:

Finished dev [unoptimized + debuginfo] target(s) in 0.13s
    Running `burn/target/debug/onnx-inference 15`

Image index: 15
Success!
Predicted: 5
Actual: 5
See the image online, click the link below:
https://huggingface.co/datasets/ylecun/mnist/viewer/mnist/test?row=15
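In the usage above, the MNIST test-image index is passed as the first command-line argument, and the output ends with a Hugging Face viewer link for that row. Below is a std-only sketch of that argument handling and link construction. `parse_index` and `mnist_viewer_url` are hypothetical helper names, not functions from the example's `src`. Model loading, inference, and range checking against the dataset size are omitted.

```rust
use std::env;

/// Hypothetical helper: builds the dataset-viewer link for one MNIST test
/// image, following the URL pattern printed in the output above.
fn mnist_viewer_url(index: usize) -> String {
    format!(
        "https://huggingface.co/datasets/ylecun/mnist/viewer/mnist/test?row={}",
        index
    )
}

/// Hypothetical helper: parses the image index from the first CLI argument.
fn parse_index(arg: Option<String>) -> Result<usize, String> {
    let raw = arg.ok_or_else(|| "usage: <binary> <image-index>".to_string())?;
    raw.trim()
        .parse::<usize>()
        .map_err(|e| format!("invalid image index {:?}: {}", raw, e))
}

fn main() {
    // `cargo run -- 15` passes "15" as the first argument after the binary name.
    match parse_index(env::args().nth(1)) {
        Ok(index) => {
            println!("Image index: {}", index);
            println!("{}", mnist_viewer_url(index));
        }
        Err(msg) => eprintln!("{}", msg),
    }
}
```

Returning `Result` from `parse_index` keeps error reporting in `main`, so a missing or non-numeric argument prints a message instead of panicking.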