burn/examples/pytorch-import
tiruka 64b57792e0
modified mnist image link in the Hugging face (#2134)
2024-08-08 11:15:08 -05:00
model Bump next version of Burn to 0.14.0 (#1618) 2024-04-12 17:14:45 -04:00
pytorch Add support for loading PyTorch `.pt` (weights/states) files directly to model's record (#1085) 2024-01-25 10:20:09 -05:00
src modified mnist image link in the Hugging face (#2134) 2024-08-08 11:15:08 -05:00
Cargo.toml Fix `DataSerialize` conversion for elements of the same type (#1832) 2024-05-28 18:12:44 -04:00
README.md modified mnist image link in the Hugging face (#2134) 2024-08-08 11:15:08 -05:00
build.rs Upgrade to candle 0.4.1 (#1382) 2024-02-29 11:29:11 -06:00

README.md

Import PyTorch Weights

This crate provides a simple example of importing PyTorch-generated weights into Burn.

The .pt file is converted into a Burn-consumable file (MessagePack format) using burn-import. The conversion is done in the build.rs file.

The model is separated into a sub-crate because build.rs needs it for the conversion, and build.rs cannot import modules from the same crate it builds.
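The build-script plumbing around that conversion can be sketched as follows. This is a minimal, standard-library-only illustration, not the example's actual build.rs: the input path `pytorch/mnist.pt`, the output name `model/mnist.mpk`, and the helper `converted_record_path` are assumptions made for illustration, and the conversion itself (which depends on the Burn crates) is left as a comment.

```rust
// build.rs (sketch): plumbing around the conversion step only.
use std::env;
use std::path::{Path, PathBuf};

/// Hypothetical helper: where the converted record would be written.
/// `model/mnist.mpk` is an assumed file name, not taken from the example.
fn converted_record_path(out_dir: &Path) -> PathBuf {
    out_dir.join("model").join("mnist.mpk")
}

fn main() {
    // Assumed input location, based on the `pytorch` directory of this example.
    let input = Path::new("pytorch/mnist.pt");

    // Ask Cargo to rerun this script only when the weights file changes.
    println!("cargo:rerun-if-changed={}", input.display());

    // Cargo sets OUT_DIR for build scripts; fall back to the system temp
    // directory so the sketch also runs outside Cargo.
    let out_dir = env::var_os("OUT_DIR")
        .map(PathBuf::from)
        .unwrap_or_else(env::temp_dir);
    let output = converted_record_path(&out_dir);

    // The actual conversion (load the `.pt` file with burn-import, then save
    // the record in MessagePack format) would go here. It needs the model's
    // record type, which is why the model lives in its own sub-crate.
    println!("would convert {} -> {}", input.display(), output.display());
}
```

Writing the converted file under `OUT_DIR` keeps generated artifacts out of the source tree, and the `rerun-if-changed` line avoids repeating the conversion on every build.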

Usage

cargo run -- 15

Output:

Finished dev [unoptimized + debuginfo] target(s) in 0.13s
    Running `burn/target/debug/onnx-inference 15`

Image index: 15
Success!
Predicted: 5
Actual: 5
See the image online, click the link below:
https://huggingface.co/datasets/ylecun/mnist/viewer/mnist/test?row=15
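
The number passed after `--` in the usage command is the index of an image in the MNIST test split. How such an argument might be parsed and range-checked is sketched below, using only the standard library. The helper `parse_image_index`, the fallback to index 0 when no argument is given, and the error messages are assumptions for illustration and may differ from what the example's `src` code does.

```rust
use std::env;

/// Number of images in the MNIST test split.
const MNIST_TEST_LEN: usize = 10_000;

/// Hypothetical helper: parse the image index and check that it is in range.
fn parse_image_index(arg: &str) -> Result<usize, String> {
    let index: usize = arg
        .trim()
        .parse()
        .map_err(|e| format!("invalid image index {arg:?}: {e}"))?;
    if index < MNIST_TEST_LEN {
        Ok(index)
    } else {
        Err(format!(
            "image index {index} is out of range (0..{MNIST_TEST_LEN})"
        ))
    }
}

fn main() {
    // `cargo run -- 15` passes "15" as the first argument after the program name.
    // Falling back to "0" when the argument is missing is an assumption.
    let arg = env::args().nth(1).unwrap_or_else(|| "0".to_string());
    match parse_image_index(&arg) {
        Ok(index) => println!("Image index: {index}"),
        Err(msg) => {
            eprintln!("{msg}");
            std::process::exit(1);
        }
    }
}
```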