# ONNX Tests

This crate contains ONNX models used to test the conversion of ONNX to Burn source code through
the `burn-import` crate. The tests are designed as end-to-end tests, ensuring that ONNX models are
accurately converted into Burn source code. Most importantly, they verify that the converted Burn
source code compiles without errors and produces the same output as the original ONNX model.

Here is the directory structure of this crate:

- `tests/<model>`: This directory contains the ONNX model and the Python script to generate it.
- `tests/<model>/<model>.onnx`: The ONNX model generated by the script.
- `tests/<model>/<model>.py`: The Python script responsible for generating the ONNX model using
  PyTorch.
- `tests/onnx_tests.rs`: The main test file, where all the tests are contained.
- `build.rs`: The build script that generates the ONNX models; it is executed by `cargo test`
  before running the actual tests.

## Setting up your Python environment

### With rye

You can use [`rye`](https://rye.astral.sh/) to set up a Python environment with the necessary
dependencies. Assuming you are in the top-level `burn` directory, `cd` into the `onnx-tests`
directory and run `rye sync`:

```sh
cd crates/burn-import/onnx-tests
rye sync # or rye sync -f
```

This will create a `.venv` in the `onnx-tests` directory.

You need `onnx==1.15.0` and `torch==2.1.1` installed in your Python environment to add a new test.

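If you want to confirm that the environment resolved the pinned versions, a quick sanity check
along these lines (run from inside the `.venv`) should print the versions noted above:

```python
# Sanity check of the pinned dependencies; expected versions are noted above.
import onnx
import torch

print(f"onnx  {onnx.__version__}")   # expected: 1.15.0
print(f"torch {torch.__version__}")  # expected: 2.1.1
```
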
## Adding new tests

Here are the steps to add a new test:

1. Add your Python script to the `tests/<model>` directory. Refer to the existing scripts for
   examples, or see the sketch after this list.
2. Run your Python script to generate the ONNX model, and inspect the model's output on the test
   data. Use these inputs and outputs in your test.
3. Make sure the ONNX output contains the desired operators by verifying it with the
   [Netron](https://github.com/lutzroeder/netron) app. Sometimes PyTorch will optimize the model
   and remove operators that are not necessary for the model to run. If this happens, you can
   disable optimization by setting `torch.onnx.export(..., do_constant_folding=False)`.
4. Add an entry to the `build.rs` file to account for the generation of the new ONNX model.
5. Add an entry to `include_models!` in `tests/onnx_tests.rs` to include the new ONNX model in
   the tests.
6. Include a test in `tests/onnx_tests.rs` that exercises the new ONNX model.
7. Run `cargo test` to ensure your test passes.
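
For reference, a generation script for a new test might look like the following sketch. The model,
tensor shape, opset version, and `relu.onnx` file name are illustrative placeholders rather than an
existing test; use the scripts already under `tests/` as the authoritative pattern.

```python
#!/usr/bin/env python3

# Illustrative sketch: generates a tiny ONNX model exercising a single
# operator (ReLU) and prints the test data to copy into the Rust test.
# The model, shapes, and file name are placeholders, not an existing test.

import torch
import torch.nn as nn


class Model(nn.Module):
    def forward(self, x):
        return torch.relu(x)


def main():
    # Fixed seed so the printed inputs/outputs are reproducible.
    torch.manual_seed(42)

    model = Model()
    model.eval()

    test_input = torch.randn(1, 4)

    # Disable constant folding so the exported graph keeps the operators
    # you want to test (see step 3 above).
    torch.onnx.export(
        model,
        test_input,
        "relu.onnx",
        opset_version=16,
        do_constant_folding=False,
    )

    # Copy these values into the corresponding test in tests/onnx_tests.rs.
    print("Test input:", test_input)
    print("Test output:", model(test_input))


if __name__ == "__main__":
    main()
```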