mirror of https://github.com/tracel-ai/burn.git
Readme updates (#325)
* Update text-generation readme for Mac users * Update root readme to reference import crate * Update import's readme * Update torch backend
This commit is contained in:
parent
314db93b7f
commit
39297a6479
|
@ -42,6 +42,7 @@ __Sections__
|
|||
* [NdArray](https://github.com/burn-rs/burn/tree/main/burn-ndarray) backend featuring [`no_std`](#no_std-support) compatibility for any platform 👌
|
||||
* [Autodiff](https://github.com/burn-rs/burn/tree/main/burn-autodiff) backend enabling differentiability for all backends 🌟
|
||||
* [Dataset](https://github.com/burn-rs/burn/tree/main/burn-dataset) crate with a variety of utilities and sources 📚
|
||||
* [Import](https://github.com/burn-rs/burn/tree/main/burn-import) crate for seamless integration of pretrained models 📦
|
||||
|
||||
## Get Started
|
||||
|
||||
|
|
|
@ -1,11 +1,12 @@
|
|||
# Burn Import
|
||||
|
||||
`burn-import` is a crate designed to facilitate importing models trained in other machine learning
|
||||
frameworks into the Burn framework. This tool generates a Rust source file that aligns the source
|
||||
model with Burn's model and converts tensor data into a format compatible with Burn.
|
||||
`burn-import` is a crate designed to simplify the process of importing models trained in other
|
||||
machine learning frameworks into the Burn framework. This tool generates a Rust source file that
|
||||
aligns the imported model with Burn's model and converts tensor data into a format compatible with
|
||||
Burn.
|
||||
|
||||
Currently under development, `burn-import` supports importing ONNX models with a limited set of
|
||||
operators.
|
||||
Currently, `burn-import` supports importing ONNX models with a limited set of operators, as it is
|
||||
still under development.
|
||||
|
||||
## Supported ONNX Operators
|
||||
|
||||
|
@ -19,92 +20,92 @@ operators.
|
|||
|
||||
### Importing ONNX models
|
||||
|
||||
In `build.rs`, add the following:
|
||||
To import ONNX models, follow these steps:
|
||||
|
||||
```rust
|
||||
use burn_import::onnx::ModelCodeGen;
|
||||
1. Add the following code to your `build.rs` file:
|
||||
|
||||
fn main() {
|
||||
ModelCodeGen::new()
|
||||
.input("src/model/mnist.onnx") // Path to the ONNX model
|
||||
.out_dir("model/") // Directory to output the generated Rust source file (under target/)
|
||||
.run_from_script();
|
||||
}
|
||||
```
|
||||
```rust
|
||||
use burn_import::onnx::ModelCodeGen;
|
||||
|
||||
Then, add the following to mod.rs under `src/model`:
|
||||
fn main() {
|
||||
ModelCodeGen::new()
|
||||
.input("src/model/mnist.onnx") // Path to the ONNX model
|
||||
.out_dir("model/") // Directory for the generated Rust source file (under target/)
|
||||
.run_from_script();
|
||||
}
|
||||
```
|
||||
|
||||
```rust
|
||||
pub mod mnist {
|
||||
include!(concat!(env!("OUT_DIR"), "/model/mnist.rs"));
|
||||
}
|
||||
```
|
||||
2. Add the following code to the `mod.rs` file under `src/model`:
|
||||
|
||||
Finally, in your code, you can use the imported model as follows:
|
||||
```rust
|
||||
pub mod mnist {
|
||||
include!(concat!(env!("OUT_DIR"), "/model/mnist.rs"));
|
||||
}
|
||||
```
|
||||
|
||||
```rust
|
||||
use burn::tensor;
|
||||
use burn_ndarray::NdArrayBackend;
|
||||
use onnx_inference::model::mnist::{Model, INPUT1_SHAPE};
|
||||
3. Use the imported model in your code as shown below:
|
||||
|
||||
fn main() {
|
||||
```rust
|
||||
use burn::tensor;
|
||||
use burn_ndarray::NdArrayBackend;
|
||||
use onnx_inference::model::mnist::{Model, INPUT1_SHAPE};
|
||||
|
||||
// Create a new model
|
||||
let model: Model<NdArrayBackend<f32>> = Model::new();
|
||||
fn main() {
|
||||
|
||||
// Create a new input tensor (all zeros for demonstration purposes)
|
||||
let input = tensor::Tensor::<NdArrayBackend<f32>, 4>::zeros(INPUT1_SHAPE);
|
||||
// Create a new model
|
||||
let model: Model<NdArrayBackend<f32>> = Model::new();
|
||||
|
||||
// Run the model
|
||||
let output = model.forward(input);
|
||||
// Create a new input tensor (all zeros for demonstration purposes)
|
||||
let input = tensor::Tensor::<NdArrayBackend<f32>, 4>::zeros(INPUT1_SHAPE);
|
||||
|
||||
// Print the output
|
||||
println!("{:?}", output);
|
||||
}
|
||||
```
|
||||
// Run the model
|
||||
let output = model.forward(input);
|
||||
|
||||
You can view the working example in the `examples/onnx-inference` directory.
|
||||
// Print the output
|
||||
println!("{:?}", output);
|
||||
}
|
||||
```
|
||||
|
||||
A working example can be found in the [`examples/onnx-inference`](https://github.com/burn-rs/burn/tree/main/examples/onnx-inference) directory.
|
||||
|
||||
### Adding new operators
|
||||
|
||||
This section explains how to add support for new operators to `burn-import`.
|
||||
To add support for new operators to `burn-import`, follow these steps:
|
||||
|
||||
1. Optimize the ONNX model using [onnxoptimizer](https://github.com/onnx/optimizer). It will remove
|
||||
uncessary operator/constants and make the model easier to understand.
|
||||
2. Use [Netron](https://github.com/lutzroeder/netron) app to visualize the ONNX model.
|
||||
3. Generate artifact files to help to see what the ONNX model (`my-model.onnx) looks like and its
|
||||
components.
|
||||
```bash
|
||||
1. Optimize the ONNX model using [onnxoptimizer](https://github.com/onnx/optimizer). This will
|
||||
remove unnecessary operators and constants, making the model easier to understand.
|
||||
2. Use the [Netron](https://github.com/lutzroeder/netron) app to visualize the ONNX model.
|
||||
3. Generate artifact files for the ONNX model (`my-model.onnx`) and its components:
|
||||
```
|
||||
cargo r -- ./my-model.onnx ./
|
||||
```
|
||||
4. You will run into an error saying that the operator is not supported. Implement missing
|
||||
operators. Hopefully, at least `my-model.graph.txt` is generated before the error occurs. This
|
||||
file contains information about the ONNX model.
|
||||
5. The newly generated `my-model.graph.txt` file will contain IR information about the model. This
|
||||
file is useful for understanding the structure of the model and the operators it uses. The
|
||||
`my-model.rs` file will contain an actual Burn model in rust code. `my-model.json` will contain
|
||||
the data of the model.
|
||||
6. The following is the explaination of onnx modules (under `srs/onnx`):
|
||||
- `from_onnx.rs`: This module contains logic for converting ONNX data objects into IR
|
||||
(Intermediate Representation) objects. This module must contain anything that deals with ONNX
|
||||
directly.
|
||||
- `ir.rs`: This module contains the IR objects that are used to represent the ONNX model. These
|
||||
objects are used to generate the Burn model.
|
||||
- `to_burn.rs` - This module contains logic for converting IR objects into Burn model source code
|
||||
and data. Nothing in this module should deal with ONNX directly.
|
||||
- `coalesce.rs`: This module contains the logic to coalesce multiple ONNX operators into a single
|
||||
Burn operator. This is useful for operators that are not supported by Burn, but can be
|
||||
represented by a combination of supported operators.
|
||||
- `op_configuration.rs` - This module contains helper functions for configuring burn operators
|
||||
from operator nodes.
|
||||
- `shape_inference.rs` - This module contains helper functions for inferring shapes of tensors
|
||||
for inputs and outputs of operators.
|
||||
7. Add unit tests for the new operator in `burn-import/tests/onnx_tests.rs` file. Add the ONNX file
|
||||
and expected output to `tests/data` directory. Please be sure the ONNX file is small. If the ONNX
|
||||
file is too large, the repository size will grow too large and will be difficult to maintain and
|
||||
clone. See the existing unit tests for examples.
|
||||
4. Implement the missing operators when you encounter an error stating that the operator is not
|
||||
supported. Ideally, the `my-model.graph.txt` file is generated before the error occurs, providing
|
||||
information about the ONNX model.
|
||||
5. The newly generated `my-model.graph.txt` file contains IR information about the model, while the
|
||||
`my-model.rs` file contains an actual Burn model in Rust code. The `my-model.json` file contains
|
||||
the model data.
|
||||
6. The `srs/onnx` directory contains the following ONNX modules (continued):
|
||||
|
||||
- `coalesce.rs`: Coalesces multiple ONNX operators into a single Burn operator. This is useful
|
||||
for operators that are not supported by Burn but can be represented by a combination of
|
||||
supported operators.
|
||||
- `op_configuration.rs`: Contains helper functions for configuring Burn operators from operator
|
||||
nodes.
|
||||
- `shape_inference.rs`: Contains helper functions for inferring shapes of tensors for inputs and
|
||||
outputs of operators.
|
||||
|
||||
7. Add unit tests for the new operator in the `burn-import/tests/onnx_tests.rs` file. Add the ONNX
|
||||
file and expected output to the `tests/data` directory. Ensure the ONNX file is small, as large
|
||||
files can increase repository size and make it difficult to maintain and clone. Refer to existing
|
||||
unit tests for examples.
|
||||
|
||||
## Resources
|
||||
|
||||
1. [PyTorch ONNX](https://pytorch.org/docs/stable/onnx.html)
|
||||
2. [ONNX Intro](https://onnx.ai/onnx/intro/)
|
||||
1. [PyTorch to ONNX](https://pytorch.org/docs/stable/onnx.html)
|
||||
2. [ONNX to Pytorch](https://github.com/ENOT-AutoDL/onnx2torch)
|
||||
3. [ONNX Intro](https://onnx.ai/onnx/intro/)
|
||||
4. [ONNX Operators](https://onnx.ai/onnx/operators/index.html)
|
||||
5. [ONNX Protos](https://onnx.ai/onnx/api/classes.html)
|
||||
6. [ONNX Optimizer](https://github.com/onnx/optimizer)
|
||||
7. [Netron](https://github.com/lutzroeder/netron)
|
||||
|
|
|
@ -1,6 +1,45 @@
|
|||
# Burn Tch
|
||||
# Burn Torch Backend
|
||||
|
||||
> [Burn](https://github.com/burn-rs/burn) tch backend
|
||||
[Burn](https://github.com/burn-rs/burn) Torch backend
|
||||
|
||||
[![Current Crates.io Version](https://img.shields.io/crates/v/burn-tch.svg)](https://crates.io/crates/burn-tch)
|
||||
[![license](https://shields.io/badge/license-MIT%2FApache--2.0-blue)](https://github.com/burn-rs/burn-tch/blob/master/README.md)
|
||||
|
||||
This crate provides a Torch backend for [Burn](https://github.com/burn-rs/burn) utilizing the
|
||||
[tch-rs](https://github.com/LaurentMazare/tch-rs) crate, which offers a Rust interface to the
|
||||
[PyTorch](https://pytorch.org/) C++ API.
|
||||
|
||||
The backend supports CPU (multithreaded), [CUDA](https://pytorch.org/docs/stable/notes/cuda.html)
|
||||
(multiple GPUs), and [MPS](https://pytorch.org/docs/stable/notes/mps.html) devices (MacOS).
|
||||
|
||||
## Usage Example
|
||||
|
||||
```rust
|
||||
#[cfg(feature = "tch-gpu")]
|
||||
mod tch_gpu {
|
||||
use burn_autodiff::ADBackendDecorator;
|
||||
use burn_tch::{TchBackend, TchDevice};
|
||||
use mnist::training;
|
||||
|
||||
pub fn run() {
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
let device = TchDevice::Cuda(0);
|
||||
#[cfg(target_os = "macos")]
|
||||
let device = TchDevice::Mps;
|
||||
|
||||
training::run::<ADBackendDecorator<TchBackend<f32>>>(device);
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "tch-cpu")]
|
||||
mod tch_cpu {
|
||||
use burn_autodiff::ADBackendDecorator;
|
||||
use burn_tch::{TchBackend, TchDevice};
|
||||
use mnist::training;
|
||||
|
||||
pub fn run() {
|
||||
let device = TchDevice::Cpu;
|
||||
training::run::<ADBackendDecorator<TchBackend<f32>>>(device);
|
||||
}
|
||||
}
|
||||
```
|
||||
|
|
|
@ -2,6 +2,8 @@
|
|||
|
||||
The example can be run like so:
|
||||
|
||||
## CUDA users
|
||||
|
||||
```bash
|
||||
git clone https://github.com/burn-rs/burn.git
|
||||
cd burn
|
||||
|
@ -10,3 +12,13 @@ cd burn
|
|||
export TORCH_CUDA_VERSION=cu113
|
||||
cargo run --example text-generation --release
|
||||
```
|
||||
|
||||
## Mac users
|
||||
|
||||
```bash
|
||||
git clone https://github.com/burn-rs/burn.git
|
||||
cd burn
|
||||
|
||||
# Use the --release flag to really speed up training.
|
||||
cargo run --example text-generation --release
|
||||
```
|
||||
|
|
Loading…
Reference in New Issue