ONNX Tests

This crate contains ONNX models used to test the conversion of ONNX to Burn source code via the burn-import crate. The tests are end-to-end: they verify that each ONNX model is converted into Burn source code that compiles without errors and produces the same output as the original ONNX model.

Here is the directory structure of this crate:

  • tests/<model>: This directory contains the ONNX model and the Python script to generate it.
  • tests/<model>/<model>.onnx: The ONNX model generated by the script.
  • tests/<model>/<model>.py: This is the Python script responsible for generating the ONNX model using PyTorch.
  • tests/onnx_tests.rs: This is the main test file, where all the tests are contained.
  • build.rs: This build script generates the ONNX models and is executed by cargo test before running the actual tests.

Setting up your Python environment

With rye

You can use rye to set up a Python environment with the necessary dependencies. Assuming you are in the top-level burn directory, run:

cd crates/burn-import/onnx-tests
rye sync # or rye sync -f

This will create a .venv in the onnx-tests directory.

To add a new test, you need onnx==1.15.0 and torch==2.1.1 installed in your Python environment.
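To confirm which versions are actually installed in the active environment, a small stdlib-only check works (the `installed_version` helper below is ours, not part of the crate):

```python
import importlib.metadata

def installed_version(pkg: str):
    """Return the installed version string for pkg, or None if it is absent."""
    try:
        return importlib.metadata.version(pkg)
    except importlib.metadata.PackageNotFoundError:
        return None

# Compare against the versions this crate expects.
for pkg, wanted in (("onnx", "1.15.0"), ("torch", "2.1.1")):
    have = installed_version(pkg)
    status = "OK" if have == wanted else f"found {have}"
    print(f"{pkg}=={wanted}: {status}")
```

Running this inside the rye-managed `.venv` should report `OK` for both packages.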

Adding new tests

Here are the steps to add a new test:

  1. Add your Python script to the tests/<model> directory. Refer to existing scripts for examples.
  2. Run your Python script to generate the ONNX model, and inspect the model's output for the test data. Use these inputs and outputs in your test.
  3. Make sure the exported ONNX model contains the desired operators by inspecting it with the Netron app. PyTorch sometimes optimizes the model during export and removes operators that are not strictly necessary for it to run. If this happens, you can disable that optimization by passing do_constant_folding=False to torch.onnx.export.
  4. Add an entry to the build.rs file to account for the generation of the new ONNX model.
  5. Add an entry to include_models! in tests/onnx_tests.rs to include the new ONNX model in the tests.
  6. Include a test in tests/onnx_tests.rs to test the new ONNX model.
  7. Run cargo test to ensure your test passes.