mirror of https://github.com/tracel-ai/burn.git
add dependency management for python (#1887)
This commit is contained in:
parent
8bf1cd60dc
commit
eead748e90
|
@ -0,0 +1,10 @@
|
|||
# python generated files
|
||||
__pycache__/
|
||||
*.py[oc]
|
||||
build/
|
||||
dist/
|
||||
wheels/
|
||||
*.egg-info
|
||||
|
||||
# venv
|
||||
.venv
|
|
@ -0,0 +1 @@
|
|||
3.12.3
|
|
@ -18,6 +18,17 @@ Here is the directory structure of this crate:
|
|||
|
||||
## Setting up your python environment
|
||||
|
||||
## With rye
|
||||
|
||||
You can use [`rye`](https://rye.astral.sh/) to set up a Python environment with the necessary dependencies. To do so, cd into the `onnx-tests` directory and run `rye sync`. Assuming you are in the top-level `burn` directory, you can run the following command:
|
||||
|
||||
```sh
|
||||
cd crates/burn-import/onnx-tests
|
||||
rye sync # or rye sync -f
|
||||
```
|
||||
|
||||
This will create a .venv in the `onnx-tests` directory.
|
||||
|
||||
You need to install `onnx==1.15.0` and `torch==2.1.1` in your python environment to add a new test
|
||||
|
||||
## Adding new tests
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
[project]
|
||||
name = "onnx-tests"
|
||||
version = "0.1.0"
|
||||
description = "project for testing ONNX support"
|
||||
authors = []
|
||||
dependencies = [
|
||||
"torch>=2.3.1",
|
||||
"onnx>=1.16.1",
|
||||
"onnxruntime>=1.18.0",
|
||||
]
|
||||
readme = "README.md"
|
||||
requires-python = ">= 3.8"
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.rye]
|
||||
managed = true
|
||||
dev-dependencies = []
|
||||
|
||||
[tool.hatch.metadata]
|
||||
allow-direct-references = true
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src/onnx_tests"]
|
|
@ -0,0 +1,75 @@
|
|||
# generated by rye
|
||||
# use `rye lock` or `rye sync` to update this lockfile
|
||||
#
|
||||
# last locked with the following flags:
|
||||
# pre: false
|
||||
# features: []
|
||||
# all-features: false
|
||||
# with-sources: false
|
||||
|
||||
-e file:.
|
||||
coloredlogs==15.0.1
|
||||
# via onnxruntime
|
||||
filelock==3.15.1
|
||||
# via torch
|
||||
flatbuffers==24.3.25
|
||||
# via onnxruntime
|
||||
fsspec==2024.6.0
|
||||
# via torch
|
||||
humanfriendly==10.0
|
||||
# via coloredlogs
|
||||
jinja2==3.1.4
|
||||
# via torch
|
||||
markupsafe==2.1.5
|
||||
# via jinja2
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
networkx==3.3
|
||||
# via torch
|
||||
numpy==1.26.4
|
||||
# via onnx
|
||||
# via onnxruntime
|
||||
nvidia-cublas-cu12==12.1.3.1
|
||||
# via nvidia-cudnn-cu12
|
||||
# via nvidia-cusolver-cu12
|
||||
# via torch
|
||||
nvidia-cuda-cupti-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cuda-nvrtc-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cuda-runtime-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cudnn-cu12==8.9.2.26
|
||||
# via torch
|
||||
nvidia-cufft-cu12==11.0.2.54
|
||||
# via torch
|
||||
nvidia-curand-cu12==10.3.2.106
|
||||
# via torch
|
||||
nvidia-cusolver-cu12==11.4.5.107
|
||||
# via torch
|
||||
nvidia-cusparse-cu12==12.1.0.106
|
||||
# via nvidia-cusolver-cu12
|
||||
# via torch
|
||||
nvidia-nccl-cu12==2.20.5
|
||||
# via torch
|
||||
nvidia-nvjitlink-cu12==12.5.40
|
||||
# via nvidia-cusolver-cu12
|
||||
# via nvidia-cusparse-cu12
|
||||
nvidia-nvtx-cu12==12.1.105
|
||||
# via torch
|
||||
onnx==1.16.1
|
||||
# via onnx-tests
|
||||
onnxruntime==1.18.0
|
||||
# via onnx-tests
|
||||
packaging==24.1
|
||||
# via onnxruntime
|
||||
protobuf==5.27.1
|
||||
# via onnx
|
||||
# via onnxruntime
|
||||
sympy==1.12.1
|
||||
# via onnxruntime
|
||||
# via torch
|
||||
torch==2.3.1
|
||||
# via onnx-tests
|
||||
typing-extensions==4.12.2
|
||||
# via torch
|
|
@ -0,0 +1,75 @@
|
|||
# generated by rye
|
||||
# use `rye lock` or `rye sync` to update this lockfile
|
||||
#
|
||||
# last locked with the following flags:
|
||||
# pre: false
|
||||
# features: []
|
||||
# all-features: false
|
||||
# with-sources: false
|
||||
|
||||
-e file:.
|
||||
coloredlogs==15.0.1
|
||||
# via onnxruntime
|
||||
filelock==3.15.1
|
||||
# via torch
|
||||
flatbuffers==24.3.25
|
||||
# via onnxruntime
|
||||
fsspec==2024.6.0
|
||||
# via torch
|
||||
humanfriendly==10.0
|
||||
# via coloredlogs
|
||||
jinja2==3.1.4
|
||||
# via torch
|
||||
markupsafe==2.1.5
|
||||
# via jinja2
|
||||
mpmath==1.3.0
|
||||
# via sympy
|
||||
networkx==3.3
|
||||
# via torch
|
||||
numpy==1.26.4
|
||||
# via onnx
|
||||
# via onnxruntime
|
||||
nvidia-cublas-cu12==12.1.3.1
|
||||
# via nvidia-cudnn-cu12
|
||||
# via nvidia-cusolver-cu12
|
||||
# via torch
|
||||
nvidia-cuda-cupti-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cuda-nvrtc-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cuda-runtime-cu12==12.1.105
|
||||
# via torch
|
||||
nvidia-cudnn-cu12==8.9.2.26
|
||||
# via torch
|
||||
nvidia-cufft-cu12==11.0.2.54
|
||||
# via torch
|
||||
nvidia-curand-cu12==10.3.2.106
|
||||
# via torch
|
||||
nvidia-cusolver-cu12==11.4.5.107
|
||||
# via torch
|
||||
nvidia-cusparse-cu12==12.1.0.106
|
||||
# via nvidia-cusolver-cu12
|
||||
# via torch
|
||||
nvidia-nccl-cu12==2.20.5
|
||||
# via torch
|
||||
nvidia-nvjitlink-cu12==12.5.40
|
||||
# via nvidia-cusolver-cu12
|
||||
# via nvidia-cusparse-cu12
|
||||
nvidia-nvtx-cu12==12.1.105
|
||||
# via torch
|
||||
onnx==1.16.1
|
||||
# via onnx-tests
|
||||
onnxruntime==1.18.0
|
||||
# via onnx-tests
|
||||
packaging==24.1
|
||||
# via onnxruntime
|
||||
protobuf==5.27.1
|
||||
# via onnx
|
||||
# via onnxruntime
|
||||
sympy==1.12.1
|
||||
# via onnxruntime
|
||||
# via torch
|
||||
torch==2.3.1
|
||||
# via onnx-tests
|
||||
typing-extensions==4.12.2
|
||||
# via torch
|
Loading…
Reference in New Issue