# ONNX Tests

This crate contains ONNX models used to test the conversion of ONNX models to Burn source code
through the `burn-import` crate. The tests are end-to-end: each ONNX model is converted into Burn
source code. Most importantly, the tests verify that the generated Burn code compiles without errors
and produces the same output as the original ONNX model.

Here is the directory structure of this crate:

- `tests/<model>`: Contains the ONNX model and the Python script that generates it.
- `tests/<model>/<model>.onnx`: The ONNX model generated by the script.
- `tests/<model>/<model>.py`: The Python script that generates the ONNX model using PyTorch.
- `tests/onnx_tests.rs`: The main test file, which contains all the tests.
- `build.rs`: The build script that converts the ONNX models into Burn source code. Cargo runs it
before compiling and running the tests.

## Setting up your Python environment

### With rye

You can use [`rye`](https://rye.astral.sh/) to set up a Python environment with the necessary dependencies. To do so, `cd` into the `onnx-tests` directory and run `rye sync`. From the top-level `burn` directory, run:
```sh
cd crates/burn-import/onnx-tests
rye sync # or rye sync -f
```

This creates a `.venv` in the `onnx-tests` directory.

To add a new test, your Python environment needs `onnx==1.15.0` and `torch==2.1.1` installed.
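
If you are unsure whether your environment (for example, the `.venv` created by `rye sync`) has the pinned versions, a quick check like the one below can help. This is a minimal sketch that uses only the Python standard library (`importlib.metadata`). The script and its `version_mismatches` helper are hypothetical and not part of the crate. It ignores local build suffixes such as `+cpu`, which some PyTorch wheels append to their version string.

```python
"""Hypothetical helper: check that the Python packages pinned in the README are installed."""
from __future__ import annotations

from importlib.metadata import PackageNotFoundError, version

# Versions the README asks for when adding new tests.
REQUIRED = {"onnx": "1.15.0", "torch": "2.1.1"}


def version_mismatches(required: dict[str, str]) -> dict[str, str | None]:
    """Return {package: installed version, or None if missing} for each package
    whose installed version differs from the required one. An empty dict means
    everything matches."""
    mismatches: dict[str, str | None] = {}
    for package, wanted in required.items():
        try:
            # Drop a local build suffix such as "+cpu" or "+cu121" before comparing.
            installed: str | None = version(package).split("+")[0]
        except PackageNotFoundError:
            installed = None
        if installed != wanted:
            mismatches[package] = installed
    return mismatches


if __name__ == "__main__":
    problems = version_mismatches(REQUIRED)
    for package, installed in problems.items():
        print(f"{package}: need {REQUIRED[package]}, found {installed or 'nothing'}")
    if not problems:
        print("Python environment matches the pinned versions.")
```

Run it with the environment's interpreter, for example `.venv/bin/python check_env.py`, where `check_env.py` is a hypothetical file name. If nothing is reported as mismatched, both packages match.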
## Adding new tests
Here are the steps to add a new test:
1. Add your Python script to the `tests/<model>` directory. Refer to existing scripts for examples.
2. Run your Python script to generate the ONNX model, and note the model's output for the test
input data. Use these inputs and outputs in your test.
3. Verify that the exported ONNX model contains the desired operators by opening it in the
[Netron](https://github.com/lutzroeder/netron) app. PyTorch sometimes optimizes the model during
export and removes operators that are not needed to run it. If this happens, you can disable
constant folding with `torch.onnx.export(..., do_constant_folding=False)`.
4. Add an entry for the new ONNX model to the `build.rs` file.
5. Add an entry to `include_models!` in `tests/onnx_tests.rs` to include the new ONNX model in the
tests.
6. Add a test to `tests/onnx_tests.rs` that runs the new model on the inputs from step 2 and checks
the expected outputs.
7. Run `cargo test` to ensure your test passes.