Refactor xtask to use tracel-xtask and refactor CI workflow (#2063)

* Migrate to xtask-common crate

* Fix example crate name for simple-regression

* Refactor CI workflows

* Flatten linux workflows

* Install grcov and typos from binaries

Although xtask-common support auto-installation of these tools via cargo
it is a lot faster to install them via the distributed binaries

* [CI] Update Rust caches on failure

* [CI] Add shell bash to jobs steps

* [CI] Try cache all crates

* Fix no-std tests not executing

* [CI] Add CARGO_INCREMENTAL 0

* Exclude tch and cuda from tests and merge crates and examples steps

* Fix some typos found with typos cli

* Add Windows and MacOS jobs

* Only test no-std with default rust target

* Fix syntax in composite action setup-windows

* Enable incremental build

* Upate cargo alias for xtask

* Bump to github action checkout v4

* Revert to tch 0.15 and disable WGPU on windows

* Fix color in output

* Add Test command

* Test long output errorring

* Build and test workspace before additional builds and tests

* Disable wgpu tests on windows

* Remove tests- prefix in CI workflow jobs name

* Add Checks command

* Rename ci workflow jobs

* Execute windows and macos CI tests on rust stable only

* Rename integration test files with a test_ prefix

* Fix format

* Don't auto-correct "arange" with typos

* Fix typos in code

* Merge unit and integration tests steps

* Fix macos tests

* Fix coverage step

* Name publish-crate workflow

* Fix bad cache name for macos

* Reorganize commands and get rid of the ci command

* Fix dispatch to customized commands for Burn

* Update to last version of tracel-xtask

* Remove unnecessary shell bash in ci workflow

* Update cargo.lock

* Fix format

* Bump tracel-xtask

* Simplify dispatch of base commands using updated macro

* Update to last version of tracel-xtask

* Adapt legacy run_checks script with new xtask commands

* Run xtask in debug for faster compilation time

* Ditch build step in ci and enable coverage for stable linux only

* Freeze tracel-xtask to specific commit rev

* Update cargo.lock

* Update Step 6 of CONTRIBUTING guidelines about run-checks script

* Remove unneeded CI and CD paragraphgs in CONRIBUTING.md

* Change cache version

* Fix typos

* Use centralized actions and workflows

* Update to last version of tracel-xtask

* Update CONTRIBUTING file to mention integration tests

* Add custom build for thumbv6m-none-eabi

* Ignore onnx files for typos check

* Fix action and workflow paths in github workflows

* Fix custom builds on MacOS

* Bump tracel-xtask crate to last version

* Update Cargo.lock

* Update publish workflow to use reusable workflow in tracel repo

* Add --ci flag for build and test commands
This commit is contained in:
Sylvain Benner 2024-08-28 15:57:13 -04:00 committed by GitHub
parent 40d321cc0d
commit a88c69af4a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
44 changed files with 1048 additions and 2106 deletions

View File

@ -1,2 +1,2 @@
[alias] [alias]
xtask = "run --manifest-path ./xtask/Cargo.toml --" xtask = "run --target-dir target/xtask --color always --package xtask --bin xtask --"

View File

@ -1,14 +0,0 @@
name: "Install llvmpipe and lavapipe"
description: "Installs software only Vulkan driver"
runs:
using: "composite"
steps:
- name: Install llvmpipe and lavapipe
shell: bash
run: |
sudo apt-get update -y -qq
for i in {1..5}; do
sudo add-apt-repository ppa:kisak/kisak-mesa -y && break || sleep 5;
done
sudo apt-get update
sudo apt install -y libegl1-mesa libgl1-mesa-dri libxcb-xfixes0-dev mesa-vulkan-drivers

247
.github/workflows/ci.yml vendored Normal file
View File

@ -0,0 +1,247 @@
name: CI
on:
push:
branches:
- main
paths:
- 'Cargo.lock'
- '**.rs'
- '**.sh'
- '**.ps1'
- '**.yml'
- '**.toml'
- '!**.md'
- '!LICENSE-APACHE'
- '!LICENSE-MIT'
pull_request:
types: [opened, synchronize]
paths:
- 'Cargo.lock'
- '**.rs'
- '**.sh'
- '**.ps1'
- '**.yml'
- '**.toml'
- '!**.md'
- '!LICENSE-APACHE'
- '!LICENSE-MIT'
env:
# Note: It is not possible to define env vars in composite actions.
# To work around this issue we use inputs and define all the env vars here.
# Cargo
CARGO_TERM_COLOR: "always"
# Dependency versioning
# from wgpu repo: https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml
# Sourced from https://vulkan.lunarg.com/sdk/home#linux
VULKAN_SDK_VERSION: "1.3.268"
# Sourced from https://archive.mesa3d.org/. Bumping this requires
# updating the mesa build in https://github.com/gfx-rs/ci-build and creating a new release.
MESA_VERSION: "23.3.1"
# Corresponds to https://github.com/gfx-rs/ci-build/releases
MESA_CI_BINARY_BUILD: "build18"
# Sourced from https://www.nuget.org/packages/Microsoft.Direct3D.WARP
WARP_VERSION: "1.0.8"
# Sourced from https://github.com/microsoft/DirectXShaderCompiler/releases
# Must also be changed in shaders.yaml
DXC_RELEASE: "v1.7.2308"
DXC_FILENAME: "dxc_2023_08_14.zip"
# Mozilla Grcov
GRCOV_LINK: "https://github.com/mozilla/grcov/releases/download"
GRCOV_VERSION: "0.8.19"
# Typos version
TYPOS_LINK: "https://github.com/crate-ci/typos/releases/download"
TYPOS_VERSION: "1.23.4"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
code-quality:
runs-on: ubuntu-22.04
strategy:
matrix:
rust: [stable]
include:
- rust: stable
cache-version: stable
steps:
- name: Setup Rust
uses: tracel-ai/github-actions/setup-rust@v1
with:
rust-toolchain: ${{ matrix.rust }}
cache-key: ${{ matrix.cache-version }}-linux
# --------------------------------------------------------------------------------
- name: Audit
run: cargo xtask check audit
# --------------------------------------------------------------------------------
- name: Format
shell: bash
env:
# work around for colors
# see: https://github.com/rust-lang/rustfmt/issues/3385
TERM: xterm-256color
run: cargo xtask check format
# --------------------------------------------------------------------------------
- name: Lint
run: cargo xtask check lint
# --------------------------------------------------------------------------------
- name: Typos
uses: tracel-ai/github-actions/check-typos@v1
documentation:
runs-on: ubuntu-22.04
strategy:
matrix:
rust: [stable]
include:
- rust: stable
cache-version: stable
steps:
- name: Setup Rust
uses: tracel-ai/github-actions/setup-rust@v1
with:
rust-toolchain: ${{ matrix.rust }}
cache-key: ${{ matrix.cache-version }}-linux
# --------------------------------------------------------------------------------
- name: Documentation Build
run: cargo xtask doc build
# --------------------------------------------------------------------------------
- name: Documentation Tests
run: cargo xtask doc tests
linux-std-tests:
runs-on: ubuntu-22.04
strategy:
matrix:
rust: [stable, 1.79.0]
include:
- rust: stable
cache-version: stable
coverage: --enable-coverage
- rust: 1.79.0
cache-version: 1-79-0
steps:
- name: Setup Rust
uses: tracel-ai/github-actions/setup-rust@v1
with:
rust-toolchain: ${{ matrix.rust }}
cache-key: ${{ matrix.cache-version }}-linux
# --------------------------------------------------------------------------------
- name: Setup Linux runner
uses: tracel-ai/github-actions/setup-linux@v1
with:
vulkan-sdk-version: ${{ env.VULKAN_SDK_VERSION }}
mesa-version: ${{ env.MESA_VERSION }}
mesa-ci-build-version: ${{ env.MESA_CI_BINARY_BUILD }}
# --------------------------------------------------------------------------------
- name: Install grcov
if: matrix.rust == 'stable'
shell: bash
run: |
curl -L "$GRCOV_LINK/v$GRCOV_VERSION/grcov-x86_64-unknown-linux-musl.tar.bz2" |
tar xj -C $HOME/.cargo/bin
cargo xtask coverage install
# --------------------------------------------------------------------------------
- name: Tests
run: cargo xtask ${{ matrix.coverage }} test --ci
# --------------------------------------------------------------------------------
- name: Generate lcov.info
if: matrix.rust == 'stable'
# /* is to exclude std library code coverage from analysis
run: cargo xtask coverage generate --ignore "/*,xtask/*,examples/*"
# --------------------------------------------------------------------------------
- name: Codecov upload lcov.info
if: matrix.rust == 'stable'
uses: codecov/codecov-action@v4
with:
files: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
linux-no-std-tests:
runs-on: ubuntu-22.04
strategy:
matrix:
rust: [stable, 1.79.0]
include:
- rust: stable
cache-version: stable
- rust: 1.79.0
cache-version: 1-79-0
steps:
- name: Setup Rust
uses: tracel-ai/github-actions/setup-rust@v1
with:
rust-toolchain: ${{ matrix.rust }}
cache-key: ${{ matrix.cache-version }}-linux-no-std
# --------------------------------------------------------------------------------
- name: Setup Linux runner
uses: tracel-ai/github-actions/setup-linux@v1
with:
vulkan-sdk-version: ${{ env.VULKAN_SDK_VERSION }}
mesa-version: ${{ env.MESA_VERSION }}
mesa-ci-build-version: ${{ env.MESA_CI_BINARY_BUILD }}
# --------------------------------------------------------------------------------
- name: Crates Build
run: cargo xtask --execution-environment no-std build --ci
# --------------------------------------------------------------------------------
- name: Crates Tests
run: cargo xtask --execution-environment no-std test --ci
windows-std-tests:
runs-on: windows-2022
env:
DISABLE_WGPU: '1'
# Keep the stragegy to be able to easily add new rust versions if required
strategy:
matrix:
rust: [stable]
include:
- rust: stable
cache-version: stable
steps:
- name: Setup Rust
uses: tracel-ai/github-actions/setup-rust@v1
with:
rust-toolchain: ${{ matrix.rust }}
cache-key: ${{ matrix.cache-version }}-windows
# --------------------------------------------------------------------------------
- name: Setup Windows runner
if: env.DISABLE_WGPU != '1'
uses: tracel-ai/github-actions/setup-windows@v1
with:
dxc-release: ${{ env.DXC_RELEASE }}
dxc-filename: ${{ env.DXC_FILENAME }}
mesa-version: ${{ env.MESA_VERSION }}
warp-version: ${{ env.WARP_VERSION }}
# --------------------------------------------------------------------------------
- name: Tests
run: cargo xtask test --ci
macos-std-tests:
runs-on: blaze/macos-14
# Keep the stragegy to be able to easily add new rust versions if required
strategy:
matrix:
rust: [stable]
include:
- rust: stable
cache-version: stable
steps:
- name: Setup Rust
uses: tracel-ai/github-actions/setup-rust@v1
with:
rust-toolchain: ${{ matrix.rust }}
cache-key: ${{ matrix.cache-version }}-macos
# --------------------------------------------------------------------------------
- name: Tests
run: cargo xtask test --ci

View File

@ -1,24 +0,0 @@
on:
workflow_call:
inputs:
crate:
required: true
type: string
secrets:
CRATES_IO_API_TOKEN:
required: true
jobs:
publish-crate:
runs-on: ubuntu-latest
steps:
- name: checkout
uses: actions/checkout@v3
- name: install rust
uses: dtolnay/rust-toolchain@stable
- name: publish to crates.io
run: cargo xtask publish ${{ inputs.crate }}
env:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

View File

@ -7,51 +7,57 @@ on:
jobs: jobs:
publish-burn-derive: publish-burn-derive:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
with: with:
crate: burn-derive crate: burn-derive
secrets: inherit secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-dataset: publish-burn-dataset:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
with: with:
crate: burn-dataset crate: burn-dataset
needs: needs:
- publish-burn-common - publish-burn-common
secrets: inherit secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-common: publish-burn-common:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
with: with:
crate: burn-common crate: burn-common
secrets: inherit secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-tensor-testgen: publish-burn-tensor-testgen:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
with: with:
crate: burn-tensor-testgen crate: burn-tensor-testgen
secrets: inherit secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-tensor: publish-burn-tensor:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs: needs:
- publish-burn-tensor-testgen - publish-burn-tensor-testgen
- publish-burn-common - publish-burn-common
with: with:
crate: burn-tensor crate: burn-tensor
secrets: inherit secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-fusion: publish-burn-fusion:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs: needs:
- publish-burn-tensor - publish-burn-tensor
- publish-burn-common - publish-burn-common
with: with:
crate: burn-fusion crate: burn-fusion
secrets: inherit secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-jit: publish-burn-jit:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs: needs:
- publish-burn-common - publish-burn-common
- publish-burn-fusion - publish-burn-fusion
@ -59,39 +65,43 @@ jobs:
- publish-burn-ndarray - publish-burn-ndarray
with: with:
crate: burn-jit crate: burn-jit
secrets: inherit secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-autodiff: publish-burn-autodiff:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs: needs:
- publish-burn-tensor - publish-burn-tensor
- publish-burn-tensor-testgen - publish-burn-tensor-testgen
- publish-burn-common - publish-burn-common
with: with:
crate: burn-autodiff crate: burn-autodiff
secrets: inherit secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-tch: publish-burn-tch:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs: needs:
- publish-burn-tensor - publish-burn-tensor
- publish-burn-autodiff - publish-burn-autodiff
with: with:
crate: burn-tch crate: burn-tch
secrets: inherit secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-ndarray: publish-burn-ndarray:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs: needs:
- publish-burn-tensor - publish-burn-tensor
- publish-burn-autodiff - publish-burn-autodiff
- publish-burn-common - publish-burn-common
with: with:
crate: burn-ndarray crate: burn-ndarray
secrets: inherit secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-wgpu: publish-burn-wgpu:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs: needs:
- publish-burn-tensor - publish-burn-tensor
- publish-burn-autodiff - publish-burn-autodiff
@ -100,10 +110,11 @@ jobs:
- publish-burn-jit - publish-burn-jit
with: with:
crate: burn-wgpu crate: burn-wgpu
secrets: inherit secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-cuda: publish-burn-cuda:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs: needs:
- publish-burn-tensor - publish-burn-tensor
- publish-burn-autodiff - publish-burn-autodiff
@ -112,20 +123,22 @@ jobs:
- publish-burn-jit - publish-burn-jit
with: with:
crate: burn-cuda crate: burn-cuda
secrets: inherit secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-candle: publish-burn-candle:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs: needs:
- publish-burn-tensor - publish-burn-tensor
- publish-burn-autodiff - publish-burn-autodiff
- publish-burn-tch - publish-burn-tch
with: with:
crate: burn-candle crate: burn-candle
secrets: inherit secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-core: publish-burn-core:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs: needs:
- publish-burn-dataset - publish-burn-dataset
- publish-burn-common - publish-burn-common
@ -138,35 +151,40 @@ jobs:
- publish-burn-candle - publish-burn-candle
with: with:
crate: burn-core crate: burn-core
secrets: inherit secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-train: publish-burn-train:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs: needs:
- publish-burn-core - publish-burn-core
with: with:
crate: burn-train crate: burn-train
secrets: inherit secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn: publish-burn:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs: needs:
- publish-burn-core - publish-burn-core
- publish-burn-train - publish-burn-train
with: with:
crate: burn crate: burn
secrets: inherit secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-burn-import: publish-burn-import:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
needs: needs:
- publish-burn - publish-burn
with: with:
crate: burn-import crate: burn-import
secrets: inherit secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}
publish-onnx-ir: publish-onnx-ir:
uses: tracel-ai/burn/.github/workflows/publish-template.yml@main uses: tracel-ai/github-actions/.github/workflows/publish-crate.yml@v1
with: with:
crate: onnx-ir crate: onnx-ir
secrets: inherit secrets:
CRATES_IO_API_TOKEN: ${{ secrets.CRATES_IO_API_TOKEN }}

View File

@ -1,283 +0,0 @@
name: test
on:
push:
branches:
- main
paths:
- 'Cargo.lock'
- '**.rs'
- '**.sh'
- '**.ps1'
- '**.yml'
- '**.toml'
- '!**.md'
- '!LICENSE-APACHE'
- '!LICENSE-MIT'
pull_request:
types: [opened, synchronize]
paths:
- 'Cargo.lock'
- '**.rs'
- '**.sh'
- '**.ps1'
- '**.yml'
- '**.toml'
- '!**.md'
- '!LICENSE-APACHE'
- '!LICENSE-MIT'
env:
#
# Dependency versioning
# from wgpu repo: https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml
#
# Sourced from https://vulkan.lunarg.com/sdk/home#linux
VULKAN_SDK_VERSION: "1.3.268"
# Sourced from https://www.nuget.org/packages/Microsoft.Direct3D.WARP
WARP_VERSION: "1.0.8"
# Sourced from https://github.com/microsoft/DirectXShaderCompiler/releases
#
# Must also be changed in shaders.yaml
DXC_RELEASE: "v1.7.2308"
DXC_FILENAME: "dxc_2023_08_14.zip"
# Sourced from https://archive.mesa3d.org/. Bumping this requires
# updating the mesa build in https://github.com/gfx-rs/ci-build and creating a new release.
MESA_VERSION: "23.3.1"
# Corresponds to https://github.com/gfx-rs/ci-build/releases
CI_BINARY_BUILD: "build18"
# Typos version
TYPOS_VERSION: "1.16.20"
# Grcov version
GRCOV_VERSION: "0.8.18"
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
cancel-in-progress: true
jobs:
tests:
runs-on: ${{ matrix.os }}
strategy:
matrix:
os: [blaze/macos-14, ubuntu-22.04, windows-2022]
# We support both the latest Rust toolchain and the preceding version.
rust: [stable, 1.79.0]
test: ['std', 'no-std', 'examples']
include:
- cache: stable
rust: stable
- cache: 1-79-0
rust: 1.79.0
- os: ubuntu-22.04
coverage-flags: COVERAGE=1
rust: stable
test: std
- os: blaze/macos-14
rust: stable
test: std
- os: windows-2022
wgpu-flags: "DISABLE_WGPU=1"
# not used yet, as wgpu tests are disabled on windows for now
# see issue: https://github.com/tracel-ai/burn/issues/1062
# auto-graphics-backend-flags: "AUTO_GRAPHICS_BACKEND=dx12";'
exclude:
# only need to check this once
- rust: 1.79.0
test: 'examples'
# Do not run no-std tests on macos
- os: blaze/macos-14
test: 'no-std'
# Do not run no-std tests on Windows
- os: windows-2022
test: 'no-std'
steps:
- name: checkout
uses: actions/checkout@v4
- name: install rust
uses: dtolnay/rust-toolchain@master
with:
components: rustfmt, clippy
toolchain: ${{ matrix.rust }}
- name: caching
uses: Swatinem/rust-cache@v2
with:
key: ${{ runner.os }}-${{ matrix.cache }}-${{ matrix.test}}-${{ hashFiles('**/Cargo.toml') }}
prefix-key: "v5-rust"
- name: free disk space
if: runner.os == 'Linux'
run: |
df -h
sudo swapoff -a
sudo rm -f /swapfile
sudo apt clean
df -h
cargo clean --package burn-tch
- name: install llvmpipe and lavapipe
if: runner.os == 'Linux'
uses: ./.github/actions/setup-llvmpipe-lavapipe
- name: Run cargo clippy for stable version
if: runner.os == 'Linux' && matrix.rust == 'stable' && matrix.test == 'std'
uses: giraffate/clippy-action@v1
with:
github_token: ${{ secrets.GITHUB_TOKEN }}
# Run clippy for each workspace, targets, and featrues, considering
# warnings as errors
clippy_flags: --all-targets -- -Dwarnings
# Do not filter results
filter_mode: nofilter
# Report clippy annotations as snippets
reporter: github-pr-check
- name: Install grcov
if: runner.os == 'Linux' && matrix.rust == 'stable' && matrix.test == 'std'
env:
GRCOV_LINK: https://github.com/mozilla/grcov/releases/download
run: |
curl -L "$GRCOV_LINK/v$GRCOV_VERSION/grcov-x86_64-unknown-linux-musl.tar.bz2" |
tar xj -C $HOME/.cargo/bin
# -----------------------------------------------------------------------------------
# BEGIN -- Windows steps disabled as long as DISABLE_WGPU=1 (wgpu tests are disabled)
# -----------------------------------------------------------------------------------
# - name: (windows) install dxc
# # from wgpu repo: https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml
# if: runner.os == 'Windows'
# shell: bash
# run: |
# set -e
# curl.exe -L --retry 5 https://github.com/microsoft/DirectXShaderCompiler/releases/download/$DXC_RELEASE/$DXC_FILENAME -o dxc.zip
# 7z.exe e dxc.zip -odxc bin/x64/{dxc.exe,dxcompiler.dll,dxil.dll}
# # We need to use cygpath to convert PWD to a windows path as we're using bash.
# cygpath --windows "$PWD/dxc" >> "$GITHUB_PATH"
# - name: (windows) install warp
# # from wgpu repo: https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml
# if: runner.os == 'Windows'
# shell: bash
# run: |
# set -e
# # Make sure dxc is in path.
# dxc --version
# curl.exe -L --retry 5 https://www.nuget.org/api/v2/package/Microsoft.Direct3D.WARP/$WARP_VERSION -o warp.zip
# 7z.exe e warp.zip -owarp build/native/amd64/d3d10warp.dll
# mkdir -p target/llvm-cov-target/debug/deps
# cp -v warp/d3d10warp.dll target/llvm-cov-target/debug/
# cp -v warp/d3d10warp.dll target/llvm-cov-target/debug/deps
# - name: (windows) install mesa
# # from wgpu repo: https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml
# if: runner.os == 'Windows'
# shell: bash
# run: |
# set -e
# curl.exe -L --retry 5 https://github.com/pal1000/mesa-dist-win/releases/download/$MESA_VERSION/mesa3d-$MESA_VERSION-release-msvc.7z -o mesa.7z
# 7z.exe e mesa.7z -omesa x64/{opengl32.dll,libgallium_wgl.dll,libglapi.dll,vulkan_lvp.dll,lvp_icd.x86_64.json}
# cp -v mesa/* target/llvm-cov-target/debug/
# cp -v mesa/* target/llvm-cov-target/debug/deps
# # We need to use cygpath to convert PWD to a windows path as we're using bash.
# echo "VK_DRIVER_FILES=`cygpath --windows $PWD/mesa/lvp_icd.x86_64.json`" >> "$GITHUB_ENV"
# echo "GALLIUM_DRIVER=llvmpipe" >> "$GITHUB_ENV"
# -----------------------------------------------------------------------------------
# END -- Windows steps disabled as long as DISABLE_WGPU=1 (wgpu tests are disabled)
# -----------------------------------------------------------------------------------
- name: (linux) install vulkan sdk
# from wgpu repo: https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml
if: runner.os == 'Linux'
shell: bash
run: |
set -e
sudo apt-get update -y -qq
# vulkan sdk
wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | sudo apt-key add -
sudo wget -qO /etc/apt/sources.list.d/lunarg-vulkan-$VULKAN_SDK_VERSION-jammy.list https://packages.lunarg.com/vulkan/$VULKAN_SDK_VERSION/lunarg-vulkan-$VULKAN_SDK_VERSION-jammy.list
sudo apt-get update
sudo apt install -y vulkan-sdk
- name: (linux) install mesa
# from wgpu repo: https://github.com/gfx-rs/wgpu/blob/trunk/.github/workflows/ci.yml
if: runner.os == 'Linux'
shell: bash
run: |
set -e
curl -L --retry 5 https://github.com/gfx-rs/ci-build/releases/download/$CI_BINARY_BUILD/mesa-$MESA_VERSION-linux-x86_64.tar.xz -o mesa.tar.xz
mkdir mesa
tar xpf mesa.tar.xz -C mesa
# The ICD provided by the mesa build is hardcoded to the build environment.
#
# We write out our own ICD file to point to the mesa vulkan
cat <<- EOF > icd.json
{
"ICD": {
"api_version": "1.1.255",
"library_path": "$PWD/mesa/lib/x86_64-linux-gnu/libvulkan_lvp.so"
},
"file_format_version": "1.0.0"
}
EOF
echo "VK_DRIVER_FILES=$PWD/icd.json" >> "$GITHUB_ENV"
echo "LD_LIBRARY_PATH=$PWD/mesa/lib/x86_64-linux-gnu/:$LD_LIBRARY_PATH" >> "$GITHUB_ENV"
echo "LIBGL_DRIVERS_PATH=$PWD/mesa/lib/x86_64-linux-gnu/dri" >> "$GITHUB_ENV"
- name: run checks & tests
shell: bash
run: ${{ matrix.coverage-flags }} ${{ matrix.wgpu-flags }} cargo xtask run-checks ${{ matrix.test }}
- name: Codecov upload
if: runner.os == 'Linux' && matrix.rust == 'stable' && matrix.test == 'std'
uses: codecov/codecov-action@v4
with:
files: lcov.info
token: ${{ secrets.CODECOV_TOKEN }}
check-typos:
runs-on: ubuntu-22.04
steps:
- name: checkout
uses: actions/checkout@v4
- name: caching
uses: Swatinem/rust-cache@v2
with:
key: ${{ runner.os }}-typos-${{ hashFiles('**/Cargo.toml') }}
prefix-key: "v5-rust"
- name: Install typos
env:
TYPOS_LINK: https://github.com/crate-ci/typos/releases/download
run: |
curl -L "$TYPOS_LINK/v$TYPOS_VERSION/typos-v$TYPOS_VERSION-x86_64-unknown-linux-musl.tar.gz" |
tar xz -C $HOME/.cargo/bin
- name: run spelling checks using typos
run: cargo xtask run-checks typos

View File

@ -43,13 +43,21 @@ your changes easier. You can create a new branch by using the command
Once you have set up your local repository and created a new branch, you can start making changes. Once you have set up your local repository and created a new branch, you can start making changes.
Be sure to follow the coding standards and guidelines used in the rest of the project. Be sure to follow the coding standards and guidelines used in the rest of the project.
### Step 6: Run the Pre-Pull Request Script ### Step 6: Validate code before opening a Pull Request
Before you open a pull request, please run [`./run-checks.sh all`](/run-checks.sh). This Before you open a pull request, please run [`./run-checks.sh all`](/run-checks.sh). This
will ensure that your changes are in line with our project's standards and guidelines. You can run will ensure that your changes are in line with our project's standards and guidelines. You can run
this script by opening a terminal, navigating to your local project directory, and typing this script by opening a terminal, navigating to your local project directory, and typing
`./run-checks`. `./run-checks`.
Note that under the hood `run-checks` runs the `cargo xtask validate` command which is powered by
the [tracel-xtask crate](https://github.com/tracel-ai/xtask). It is recommended to get familiar with
it as it provides a wide variety of commands to help you work with the code base.
If you have an error related to `torch` installation, see [Burn Torch Backend Installation](./crates/burn-tch/README.md#Installation)
Format and lint errors can often be fixed automatically using the command `cargo xtask fix all`.
### Step 7: Submit a Pull Request ### Step 7: Submit a Pull Request
After you've made your changes and run the pre-pull request script, you're ready to submit a pull After you've made your changes and run the pre-pull request script, you're ready to submit a pull
@ -87,50 +95,6 @@ You may also want to enable debugging by creating a `.vscode/settings.json` file
4. If you're creating a new library or binary, keep in mind to repeat the step 2 to always keep a fresh list of targets. 4. If you're creating a new library or binary, keep in mind to repeat the step 2 to always keep a fresh list of targets.
## Continuous Integration
### Run checks
On Unix systems, run `run-checks.sh` using this command
```
./run-checks.sh environment
```
On Windows systems, run `run-checks.ps1` using this command:
```
run-checks.ps1 environment
```
The `environment` argument can assume **ONLY** the following values:
- `std` to perform checks using `libstd`
- `no-std` to perform checks on an embedded environment using `libcore`
- `typos` to check for typos in the codebase
- `examples` to check the examples compile
If no `environment` value has been passed, run all checks except examples.
If you have an error related to `torch` installation, see [Burn Torch Backend Installation](./crates/burn-tch/README.md#Installation)
## Continuous Deployment
### Publish crates
Compile `scripts/publish.rs` using this command:
```
rustc scripts/publish.rs --crate-type bin --out-dir scripts
```
Run `scripts/publish` using this command
```
./scripts/publish crate_name
```
where `crate_name` is the name of the crate to publish
## Code Guidelines ## Code Guidelines
We believe in clean and efficient code. While we don't enforce strict coding guidelines, we trust We believe in clean and efficient code. While we don't enforce strict coding guidelines, we trust
@ -150,6 +114,11 @@ _Think of `expect()` messages as guidelines for future you and other developers.
This approach ensures that `expect()` messages are informative and aligned with the intended This approach ensures that `expect()` messages are informative and aligned with the intended
function outcomes, making debugging and maintenance more straightforward for everyone. function outcomes, making debugging and maintenance more straightforward for everyone.
### Writing integration tests
[Integration tests](https://doc.rust-lang.org/rust-by-example/testing/integration_testing.html) should be in a directory called `tests`
besides the `src` directory of a crate. Per convention, they must be implemented in files whose name start with the `test_` prefix.
## Others ## Others
To bump for the next version, install `cargo-edit` if its not on your system, and use this command: To bump for the next version, install `cargo-edit` if its not on your system, and use this command:

341
Cargo.lock generated
View File

@ -23,6 +23,12 @@ version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"
[[package]]
name = "adler2"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "512761e0bb2578dd7380c6baaa0f4ce03e84f95e960231d1dec8bf4d7d6e2627"
[[package]] [[package]]
name = "aes" name = "aes"
version = "0.8.4" version = "0.8.4"
@ -193,9 +199,9 @@ checksum = "bf7d0a018de4f6aa429b9d33d69edf69072b1c5b1cb8d3e4a5f7ef898fc3eb76"
[[package]] [[package]]
name = "arrayvec" name = "arrayvec"
version = "0.7.4" version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
[[package]] [[package]]
name = "ash" name = "ash"
@ -297,7 +303,7 @@ dependencies = [
"os_info", "os_info",
"percent-encoding", "percent-encoding",
"rand", "rand",
"reqwest 0.12.5", "reqwest 0.12.7",
"rstest", "rstest",
"serde", "serde",
"serde_json", "serde_json",
@ -319,7 +325,7 @@ dependencies = [
"cc", "cc",
"cfg-if", "cfg-if",
"libc", "libc",
"miniz_oxide", "miniz_oxide 0.7.4",
"object", "object",
"rustc-demangle", "rustc-demangle",
] ]
@ -403,9 +409,9 @@ checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de"
[[package]] [[package]]
name = "bitstream-io" name = "bitstream-io"
version = "2.5.0" version = "2.5.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3dcde5f311c85b8ca30c2e4198d4326bc342c76541590106f5fa4a50946ea499" checksum = "b81e1519b0d82120d2fd469d5bfb2919a9361c48b02d82d04befc1cdd2002452"
[[package]] [[package]]
name = "blas-src" name = "blas-src"
@ -506,7 +512,7 @@ dependencies = [
"getrandom", "getrandom",
"indicatif", "indicatif",
"rayon", "rayon",
"reqwest 0.12.5", "reqwest 0.12.7",
"tokio", "tokio",
"web-time", "web-time",
] ]
@ -759,18 +765,18 @@ dependencies = [
[[package]] [[package]]
name = "bytemuck" name = "bytemuck"
version = "1.16.3" version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "102087e286b4677862ea56cf8fc58bb2cdfa8725c40ffb80fe3a008eb7f2fc83" checksum = "6fd4c6dcc3b0aea2f5c0b4b82c2b15fe39ddbc76041a310848f4706edf76bb31"
dependencies = [ dependencies = [
"bytemuck_derive", "bytemuck_derive",
] ]
[[package]] [[package]]
name = "bytemuck_derive" name = "bytemuck_derive"
version = "1.7.0" version = "1.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ee891b04274a59bd38b412188e24b849617b2e45a0fd8d057deb63e7403761b" checksum = "0cc8b54b395f2fcfbb3d90c47b01c7f444d94d05bdeb775811dec868ac3bbc26"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -896,12 +902,13 @@ dependencies = [
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.1.10" version = "1.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e9e8aabfac534be767c909e0690571677d49f41bd8465ae876fe043d52ba5292" checksum = "50d2eb3cd3d1bf4529e31c215ee6f93ec5a3d536d9f578f93d9d33ee19562932"
dependencies = [ dependencies = [
"jobserver", "jobserver",
"libc", "libc",
"shlex",
] ]
[[package]] [[package]]
@ -1068,9 +1075,9 @@ dependencies = [
[[package]] [[package]]
name = "cmake" name = "cmake"
version = "0.1.50" version = "0.1.51"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a31c789563b815f77f4250caee12365734369f942439b7defd71e18a48197130" checksum = "fb1e43aa7fd152b1f968787f7dbcdeb306d1867ff373c69955211876c053f91a"
dependencies = [ dependencies = [
"cc", "cc",
] ]
@ -1525,7 +1532,7 @@ version = "0.15.0"
dependencies = [ dependencies = [
"burn", "burn",
"csv", "csv",
"reqwest 0.12.5", "reqwest 0.12.7",
"serde", "serde",
] ]
@ -1965,7 +1972,7 @@ dependencies = [
"flume", "flume",
"half", "half",
"lebe", "lebe",
"miniz_oxide", "miniz_oxide 0.7.4",
"rayon-core", "rayon-core",
"smallvec", "smallvec",
"zune-inflate", "zune-inflate",
@ -2034,12 +2041,12 @@ dependencies = [
[[package]] [[package]]
name = "flate2" name = "flate2"
version = "1.0.31" version = "1.0.32"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f211bbe8e69bbd0cfdea405084f128ae8b4aaa6b0b522fc8f2b009084797920" checksum = "9c0596c1eac1f9e04ed902702e9878208b336edc9d6fddc8a48387349bab3666"
dependencies = [ dependencies = [
"crc32fast", "crc32fast",
"miniz_oxide", "miniz_oxide 0.8.0",
] ]
[[package]] [[package]]
@ -2490,8 +2497,8 @@ dependencies = [
"aho-corasick", "aho-corasick",
"bstr", "bstr",
"log", "log",
"regex-automata", "regex-automata 0.4.7",
"regex-syntax", "regex-syntax 0.8.4",
] ]
[[package]] [[package]]
@ -2608,9 +2615,9 @@ dependencies = [
[[package]] [[package]]
name = "h2" name = "h2"
version = "0.4.5" version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa82e28a107a8cc405f0839610bdc9b15f1e25ec7d696aa5cf173edbcb1486ab" checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205"
dependencies = [ dependencies = [
"atomic-waker", "atomic-waker",
"bytes", "bytes",
@ -2878,7 +2885,7 @@ dependencies = [
"bytes", "bytes",
"futures-channel", "futures-channel",
"futures-util", "futures-util",
"h2 0.4.5", "h2 0.4.6",
"http 1.1.0", "http 1.1.0",
"http-body 1.0.1", "http-body 1.0.1",
"httparse", "httparse",
@ -3004,7 +3011,7 @@ dependencies = [
"globset", "globset",
"log", "log",
"memchr", "memchr",
"regex-automata", "regex-automata 0.4.7",
"same-file", "same-file",
"walkdir", "walkdir",
"winapi-util", "winapi-util",
@ -3244,9 +3251,9 @@ checksum = "03087c2bad5e1034e8cace5926dec053fb3790248370865f5117a7d0213354c8"
[[package]] [[package]]
name = "libc" name = "libc"
version = "0.2.157" version = "0.2.158"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "374af5f94e54fa97cf75e945cce8a6b201e88a1a07e688b47dfd2a59c66dbd86" checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439"
[[package]] [[package]]
name = "libfuzzer-sys" name = "libfuzzer-sys"
@ -3404,6 +3411,15 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "matchers"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558"
dependencies = [
"regex-automata 0.1.10",
]
[[package]] [[package]]
name = "matrixmultiply" name = "matrixmultiply"
version = "0.3.9" version = "0.3.9"
@ -3509,6 +3525,15 @@ dependencies = [
"simd-adler32", "simd-adler32",
] ]
[[package]]
name = "miniz_oxide"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2d80299ef12ff69b16a84bb182e3b9df68b5a91574d3d4fa6e41b65deec4df1"
dependencies = [
"adler2",
]
[[package]] [[package]]
name = "mio" name = "mio"
version = "0.8.11" version = "0.8.11"
@ -4386,7 +4411,7 @@ dependencies = [
"crc32fast", "crc32fast",
"fdeflate", "fdeflate",
"flate2", "flate2",
"miniz_oxide", "miniz_oxide 0.7.4",
] ]
[[package]] [[package]]
@ -4900,9 +4925,9 @@ dependencies = [
[[package]] [[package]]
name = "protobuf" name = "protobuf"
version = "3.5.0" version = "3.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df67496db1a89596beaced1579212e9b7c53c22dca1d9745de00ead76573d514" checksum = "0bcc343da15609eaecd65f8aa76df8dc4209d325131d8219358c0aaaebab0bf6"
dependencies = [ dependencies = [
"bytes", "bytes",
"once_cell", "once_cell",
@ -4912,9 +4937,9 @@ dependencies = [
[[package]] [[package]]
name = "protobuf-codegen" name = "protobuf-codegen"
version = "3.5.0" version = "3.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eab09155fad2d39333d3796f67845d43e29b266eea74f7bc93f153f707f126dc" checksum = "c4d0cde5642ea4df842b13eb9f59ea6fafa26dcb43e3e1ee49120e9757556189"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"once_cell", "once_cell",
@ -4927,9 +4952,9 @@ dependencies = [
[[package]] [[package]]
name = "protobuf-parse" name = "protobuf-parse"
version = "3.5.0" version = "3.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a16027030d4ec33e423385f73bb559821827e9ec18c50e7874e4d6de5a4e96f" checksum = "1b0e9b447d099ae2c4993c0cbb03c7a9d6c937b17f2d56cfc0b1550e6fcfdb76"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"indexmap 2.4.0", "indexmap 2.4.0",
@ -4943,9 +4968,9 @@ dependencies = [
[[package]] [[package]]
name = "protobuf-support" name = "protobuf-support"
version = "3.5.0" version = "3.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "70e2d30ab1878b2e72d1e2fc23ff5517799c9929e2cf81a8516f9f4dcf2b9cf3" checksum = "f0766e3675a627c327e4b3964582594b0e8741305d628a98a5de75a1d15f99b9"
dependencies = [ dependencies = [
"thiserror", "thiserror",
] ]
@ -4961,9 +4986,9 @@ dependencies = [
[[package]] [[package]]
name = "pulp" name = "pulp"
version = "0.18.21" version = "0.18.22"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ec8d02258294f59e4e223b41ad7e81c874aa6b15bc4ced9ba3965826da0eed5" checksum = "a0a01a0dc67cf4558d279f0c25b0962bd08fc6dec0137699eae304103e882fe6"
dependencies = [ dependencies = [
"bytemuck", "bytemuck",
"libm", "libm",
@ -5261,9 +5286,9 @@ dependencies = [
[[package]] [[package]]
name = "redox_users" name = "redox_users"
version = "0.4.5" version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bd283d9651eeda4b2a83a43c1c91b266c40fd76ecd39a50a8c630ae69dc72891" checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43"
dependencies = [ dependencies = [
"getrandom", "getrandom",
"libredox", "libredox",
@ -5278,8 +5303,17 @@ checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619"
dependencies = [ dependencies = [
"aho-corasick", "aho-corasick",
"memchr", "memchr",
"regex-automata", "regex-automata 0.4.7",
"regex-syntax", "regex-syntax 0.8.4",
]
[[package]]
name = "regex-automata"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132"
dependencies = [
"regex-syntax 0.6.29",
] ]
[[package]] [[package]]
@ -5290,24 +5324,21 @@ checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df"
dependencies = [ dependencies = [
"aho-corasick", "aho-corasick",
"memchr", "memchr",
"regex-syntax", "regex-syntax 0.8.4",
] ]
[[package]]
name = "regex-syntax"
version = "0.6.29"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1"
[[package]] [[package]]
name = "regex-syntax" name = "regex-syntax"
version = "0.8.4" version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b"
[[package]]
name = "regression"
version = "0.15.0"
dependencies = [
"burn",
"log",
"serde",
]
[[package]] [[package]]
name = "relative-path" name = "relative-path"
version = "1.9.3" version = "1.9.3"
@ -5349,7 +5380,7 @@ dependencies = [
"serde_json", "serde_json",
"serde_urlencoded", "serde_urlencoded",
"sync_wrapper 0.1.2", "sync_wrapper 0.1.2",
"system-configuration", "system-configuration 0.5.1",
"tokio", "tokio",
"tokio-native-tls", "tokio-native-tls",
"tower-service", "tower-service",
@ -5357,14 +5388,14 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
"wasm-bindgen-futures", "wasm-bindgen-futures",
"web-sys", "web-sys",
"winreg 0.50.0", "winreg",
] ]
[[package]] [[package]]
name = "reqwest" name = "reqwest"
version = "0.12.5" version = "0.12.7"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7d6d2a27d57148378eb5e111173f4276ad26340ecc5c49a4a2152167a2d6a37" checksum = "f8f4955649ef5c38cc7f9e8aa41761d48fb9677197daea9984dc54f56aad5e63"
dependencies = [ dependencies = [
"base64 0.22.1", "base64 0.22.1",
"bytes", "bytes",
@ -5372,7 +5403,7 @@ dependencies = [
"futures-channel", "futures-channel",
"futures-core", "futures-core",
"futures-util", "futures-util",
"h2 0.4.5", "h2 0.4.6",
"http 1.1.0", "http 1.1.0",
"http-body 1.0.1", "http-body 1.0.1",
"http-body-util", "http-body-util",
@ -5393,7 +5424,7 @@ dependencies = [
"serde_json", "serde_json",
"serde_urlencoded", "serde_urlencoded",
"sync_wrapper 1.0.1", "sync_wrapper 1.0.1",
"system-configuration", "system-configuration 0.6.1",
"tokio", "tokio",
"tokio-native-tls", "tokio-native-tls",
"tower-service", "tower-service",
@ -5401,7 +5432,7 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
"wasm-bindgen-futures", "wasm-bindgen-futures",
"web-sys", "web-sys",
"winreg 0.52.0", "windows-registry",
] ]
[[package]] [[package]]
@ -5554,9 +5585,9 @@ dependencies = [
[[package]] [[package]]
name = "rustls-native-certs" name = "rustls-native-certs"
version = "0.7.1" version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a88d6d420651b496bdd98684116959239430022a115c1240e6c3993be0b15fba" checksum = "04182dffc9091a404e0fc069ea5cd60e5b866c3adf881eff99a32d048242dffa"
dependencies = [ dependencies = [
"openssl-probe", "openssl-probe",
"rustls-pemfile 2.1.3", "rustls-pemfile 2.1.3",
@ -5654,9 +5685,9 @@ dependencies = [
[[package]] [[package]]
name = "scc" name = "scc"
version = "2.1.14" version = "2.1.16"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "79da19444d9da7a9a82b80ecf059eceba6d3129d84a8610fd25ff2364f255466" checksum = "aeb7ac86243095b70a7920639507b71d51a63390d1ba26c4f60a552fbb914a37"
dependencies = [ dependencies = [
"sdd", "sdd",
] ]
@ -5865,6 +5896,12 @@ dependencies = [
"lazy_static", "lazy_static",
] ]
[[package]]
name = "shlex"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]] [[package]]
name = "signal-hook" name = "signal-hook"
version = "0.3.17" version = "0.3.17"
@ -5916,6 +5953,15 @@ version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a" checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a"
[[package]]
name = "simple-regression"
version = "0.15.0"
dependencies = [
"burn",
"log",
"serde",
]
[[package]] [[package]]
name = "siphasher" name = "siphasher"
version = "0.3.11" version = "0.3.11"
@ -6025,15 +6071,15 @@ checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
[[package]] [[package]]
name = "stacker" name = "stacker"
version = "0.1.15" version = "0.1.16"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c886bd4480155fd3ef527d45e9ac8dd7118a898a46530b7b94c3e21866259fce" checksum = "95a5daa25ea337c85ed954c0496e3bdd2c7308cc3b24cf7b50d04876654c579f"
dependencies = [ dependencies = [
"cc", "cc",
"cfg-if", "cfg-if",
"libc", "libc",
"psm", "psm",
"winapi", "windows-sys 0.36.1",
] ]
[[package]] [[package]]
@ -6136,6 +6182,9 @@ name = "sync_wrapper"
version = "1.0.1" version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394"
dependencies = [
"futures-core",
]
[[package]] [[package]]
name = "synstructure" name = "synstructure"
@ -6186,7 +6235,18 @@ checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7"
dependencies = [ dependencies = [
"bitflags 1.3.2", "bitflags 1.3.2",
"core-foundation", "core-foundation",
"system-configuration-sys", "system-configuration-sys 0.5.0",
]
[[package]]
name = "system-configuration"
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b"
dependencies = [
"bitflags 2.6.0",
"core-foundation",
"system-configuration-sys 0.6.0",
] ]
[[package]] [[package]]
@ -6199,6 +6259,16 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "system-configuration-sys"
version = "0.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]] [[package]]
name = "system-deps" name = "system-deps"
version = "6.2.2" version = "6.2.2"
@ -6446,7 +6516,7 @@ dependencies = [
"rayon", "rayon",
"rayon-cond", "rayon-cond",
"regex", "regex",
"regex-syntax", "regex-syntax 0.8.4",
"serde", "serde",
"serde_json", "serde_json",
"spm_precompiled", "spm_precompiled",
@ -6458,9 +6528,9 @@ dependencies = [
[[package]] [[package]]
name = "tokio" name = "tokio"
version = "1.39.2" version = "1.39.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1" checksum = "9babc99b9923bfa4804bd74722ff02c0381021eafa4db9949217e3be8e84fff5"
dependencies = [ dependencies = [
"backtrace", "backtrace",
"bytes", "bytes",
@ -6604,6 +6674,34 @@ version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3"
[[package]]
name = "tracel-xtask"
version = "1.0.0"
source = "git+https://github.com/tracel-ai/xtask?rev=921408bc16e74d3ef8ae59356d928fb6706fb8f4#921408bc16e74d3ef8ae59356d928fb6706fb8f4"
dependencies = [
"anyhow",
"clap 4.5.16",
"derive_more",
"env_logger",
"log",
"rand",
"regex",
"serde_json",
"strum",
"tracel-xtask-macros",
"tracing-subscriber",
]
[[package]]
name = "tracel-xtask-macros"
version = "1.0.0"
source = "git+https://github.com/tracel-ai/xtask?rev=921408bc16e74d3ef8ae59356d928fb6706fb8f4#921408bc16e74d3ef8ae59356d928fb6706fb8f4"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.75",
]
[[package]] [[package]]
name = "tracing" name = "tracing"
version = "0.1.40" version = "0.1.40"
@ -6665,10 +6763,14 @@ version = "0.3.18"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b" checksum = "ad0f048c97dbd9faa9b7df56362b8ebcaa52adb06b498c050d2f4e32f90a7a8b"
dependencies = [ dependencies = [
"matchers",
"nu-ansi-term", "nu-ansi-term",
"once_cell",
"regex",
"sharded-slab", "sharded-slab",
"smallvec", "smallvec",
"thread_local", "thread_local",
"tracing",
"tracing-core", "tracing-core",
"tracing-log", "tracing-log",
] ]
@ -6749,9 +6851,9 @@ checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d"
[[package]] [[package]]
name = "unicode-xid" name = "unicode-xid"
version = "0.2.4" version = "0.2.5"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" checksum = "229730647fbc343e3a80e463c1db7f78f3855d3f3739bee0dda773c9a037c90a"
[[package]] [[package]]
name = "unicode_categories" name = "unicode_categories"
@ -7174,6 +7276,49 @@ dependencies = [
"windows-targets 0.52.6", "windows-targets 0.52.6",
] ]
[[package]]
name = "windows-registry"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0"
dependencies = [
"windows-result",
"windows-strings",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-result"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e"
dependencies = [
"windows-targets 0.52.6",
]
[[package]]
name = "windows-strings"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10"
dependencies = [
"windows-result",
"windows-targets 0.52.6",
]
[[package]]
name = "windows-sys"
version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea04155a16a59f9eab786fe12a4a450e75cdb175f9e0d80da1e17db09f55b8d2"
dependencies = [
"windows_aarch64_msvc 0.36.1",
"windows_i686_gnu 0.36.1",
"windows_i686_msvc 0.36.1",
"windows_x86_64_gnu 0.36.1",
"windows_x86_64_msvc 0.36.1",
]
[[package]] [[package]]
name = "windows-sys" name = "windows-sys"
version = "0.48.0" version = "0.48.0"
@ -7244,6 +7389,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
[[package]]
name = "windows_aarch64_msvc"
version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9bb8c3fd39ade2d67e9874ac4f3db21f0d710bee00fe7cab16949ec184eeaa47"
[[package]] [[package]]
name = "windows_aarch64_msvc" name = "windows_aarch64_msvc"
version = "0.48.5" version = "0.48.5"
@ -7256,6 +7407,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
[[package]]
name = "windows_i686_gnu"
version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "180e6ccf01daf4c426b846dfc66db1fc518f074baa793aa7d9b9aaeffad6a3b6"
[[package]] [[package]]
name = "windows_i686_gnu" name = "windows_i686_gnu"
version = "0.48.5" version = "0.48.5"
@ -7274,6 +7431,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
[[package]]
name = "windows_i686_msvc"
version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e2e7917148b2812d1eeafaeb22a97e4813dfa60a3f8f78ebe204bcc88f12f024"
[[package]] [[package]]
name = "windows_i686_msvc" name = "windows_i686_msvc"
version = "0.48.5" version = "0.48.5"
@ -7286,6 +7449,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
[[package]]
name = "windows_x86_64_gnu"
version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4dcd171b8776c41b97521e5da127a2d86ad280114807d0b2ab1e462bc764d9e1"
[[package]] [[package]]
name = "windows_x86_64_gnu" name = "windows_x86_64_gnu"
version = "0.48.5" version = "0.48.5"
@ -7310,6 +7479,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
[[package]]
name = "windows_x86_64_msvc"
version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c811ca4a8c853ef420abd8592ba53ddbbac90410fab6903b3e79972a631f7680"
[[package]] [[package]]
name = "windows_x86_64_msvc" name = "windows_x86_64_msvc"
version = "0.48.5" version = "0.48.5"
@ -7350,16 +7525,6 @@ dependencies = [
"windows-sys 0.48.0", "windows-sys 0.48.0",
] ]
[[package]]
name = "winreg"
version = "0.52.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a277a57398d4bfa075df44f501a17cfdf8542d224f0d36095a2adc7aee4ef0a5"
dependencies = [
"cfg-if",
"windows-sys 0.48.0",
]
[[package]] [[package]]
name = "wrapcenum-derive" name = "wrapcenum-derive"
version = "0.4.1" version = "0.4.1"
@ -7414,16 +7579,12 @@ checksum = "539a77ee7c0de333dcc6da69b177380a0b81e0dacfa4f7344c465a36871ee601"
[[package]] [[package]]
name = "xtask" name = "xtask"
version = "0.6.0" version = "1.0.0"
dependencies = [ dependencies = [
"anyhow",
"clap 4.5.16",
"derive_more",
"env_logger",
"log", "log",
"rand",
"rstest", "rstest",
"serde_json", "strum",
"tracel-xtask",
] ]
[[package]] [[package]]

View File

@ -160,6 +160,9 @@ portable-atomic-util = { version = "0.2.2", features = ["alloc"] }
cubecl = { version="0.2.0", default-features = false } cubecl = { version="0.2.0", default-features = false }
cubecl-common = { version="0.2.0", default-features = false } cubecl-common = { version="0.2.0", default-features = false }
### For xtask crate ###
tracel-xtask = { git = "https://github.com/tracel-ai/xtask", rev = "921408bc16e74d3ef8ae59356d928fb6706fb8f4" }
[profile.dev] [profile.dev]
debug = 0 # Speed up compilation time and not necessary. debug = 0 # Speed up compilation time and not necessary.
opt-level = 2 opt-level = 2

View File

@ -3,6 +3,11 @@ extend-ignore-identifiers-re = ["ratatui", "Ratatui", "NdArray*", "ND"]
[files] [files]
extend-exclude = [ extend-exclude = [
"*.onnx",
"assets/ModuleSerialization.xml", "assets/ModuleSerialization.xml",
"examples/image-classification-web/src/model/label.txt", "examples/image-classification-web/src/model/label.txt",
] ]
[default.extend-words]
# Don't correct "arange" which is intentional
arange = "arange"

View File

@ -10,7 +10,7 @@ pub trait OutputProcessor: Send + Sync + 'static {
fn process_line(&self, line: &str); fn process_line(&self, line: &str);
/// To be called to indicate progress has been made /// To be called to indicate progress has been made
fn progress(&self); fn progress(&self);
/// To be called whent the processor has finished processing /// To be called went the processor has finished processing
fn finish(&self); fn finish(&self);
} }

View File

@ -23,10 +23,10 @@ class Model(nn.Module):
# Subtract a scalar constant from a scalar input # Subtract a scalar constant from a scalar input
d = k - self.b d = k - self.b
# Sutract a scalar from a tensor # Subtract a scalar from a tensor
x = x - d x = x - d
# Sutract a tensor from a scalar # Subtract a tensor from a scalar
x = d - x x = d - x
return x return x

View File

@ -24,10 +24,10 @@ class Model(nn.Module):
# Subtract a scalar constant from a scalar input # Subtract a scalar constant from a scalar input
d = k - self.b d = k - self.b
# Sutract a scalar from a tensor # Subtract a scalar from a tensor
x = x - d x = x - d
# Sutract a tensor from a scalar # Subtract a tensor from a scalar
x = d - x x = d - x
return x return x

View File

@ -2,7 +2,7 @@
authors = ["aasheeshsingh <aasheeshdtu@gmail.com>"] authors = ["aasheeshsingh <aasheeshdtu@gmail.com>"]
edition.workspace = true edition.workspace = true
license.workspace = true license.workspace = true
name = "regression" name = "simple-regression"
publish = false publish = false
version.workspace = true version.workspace = true

View File

@ -1,5 +1,10 @@
# This script runs all `burn` checks locally. It may take around 15 minutes on #!/usr/bin/env pwsh
# the first run.
# Exit immediately if a command exits with a non-zero status.
$ErrorActionPreference = "Stop"
# This script runs all `burn` checks locally. It may take around 15 minutes
# on the first run.
# #
# Run `run-checks` using this command: # Run `run-checks` using this command:
# #
@ -7,16 +12,12 @@
# #
# where `environment` can assume **ONLY** the following values: # where `environment` can assume **ONLY** the following values:
# #
# - `std` to perform checks using `libstd` # - `std` to perform validation using `libstd`
# - `no-std` to perform checks on an embedded environment using `libcore` # - `no-std` to perform validation on an embedded environment using `libcore`
# - `typos` to check for typos in the codebase # - `all` to perform both std and no-std validation
# - `examples` to check the examples compile #
# If no `environment` value has been passed, run all checks except examples. # If no `environment` value has been passed, default to `all`.
$exec_env = if ($args.Count -ge 1) { $args[0] } else { "all" }
# Exit if any command fails # Run the cargo xtask command with the specified environment
$ErrorActionPreference = "Stop" cargo xtask --execution-environment $exec_env validate
# Run binary passing the first input parameter, who is mandatory.
# If the input parameter is missing or wrong, it will be the `run-checks`
# binary which will be responsible of arising an error.
cargo xtask run-checks $args[0]

View File

@ -12,14 +12,11 @@ set -e
# #
# where `environment` can assume **ONLY** the following values: # where `environment` can assume **ONLY** the following values:
# #
# - `std` to perform checks using `libstd` # - `std` to perform validation using `libstd`
# - `no-std` to perform checks on an embedded environment using `libcore` # - `no-std` to perform validation on an embedded environment using `libcore`
# - `typos` to check for typos in the codebase # - `all` to perform both std and no-std validation
# - `examples` to check the examples compile
# #
# If no `environment` value has been passed, run all checks except examples. # If no `environment` value has been passed.
exec_env=${1:-all}
# Run binary passing the first input parameter, who is mandatory. cargo xtask --execution-environment "$exec_env" validate
# If the input parameter is missing or wrong, it will be the `run-checks`
# binary which will be responsible of arising an error.
cargo xtask run-checks $1

View File

@ -1,19 +1,15 @@
[package] [package]
name = "xtask" name = "xtask"
version = "0.6.0" version = "1.0.0"
edition = "2021" edition = "2021"
license = "MIT OR Apache-2.0" license = "MIT OR Apache-2.0"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies] [dependencies]
anyhow = "1.0.86" log = { workspace = true }
clap = { version = "4.5.16", features = ["derive"] } strum = { workspace = true }
derive_more = { version = "0.99.18", features = ["display"], default-features = false } tracel-xtask = { workspace = true }
env_logger = "0.11.5"
log = "0.4.22"
rand = { workspace = true, features = ["std"] }
serde_json = { version = "1" }
[dev-dependencies] [dev-dependencies]
rstest = { workspace = true } rstest = { workspace = true }

View File

@ -1,46 +1,36 @@
use std::{collections::HashMap, path::Path, time::Instant}; use std::path::Path;
use clap::{Args, Subcommand}; use tracel_xtask::prelude::*;
use derive_more::Display;
use crate::{ #[derive(clap::Args)]
endgroup, group, pub struct BooksArgs {
logging::init_logger,
utils::{
cargo::ensure_cargo_crate_is_installed, mdbook::run_mdbook_with_path, process::random_port,
time::format_duration, Params,
},
};
#[derive(Args)]
pub(crate) struct BooksArgs {
#[command(subcommand)] #[command(subcommand)]
book: BookKind, book: BookKind,
} }
#[derive(Subcommand)] #[derive(clap::Subcommand)]
pub(crate) enum BookKind { pub(crate) enum BookKind {
/// Burn Book, a.k.a. the guide, made for the Burn users. /// Burn Book, a.k.a. the guide, made for the Burn users.
Burn(BookKindArgs), Burn(BookKindArgs),
/// Contributor book, made for people willing to get all the technical understanding and advices to contribute actively to the project. /// Contributor book, made for people willing to get all the technical understanding and advice to contribute actively to the project.
Contributor(BookKindArgs), Contributor(BookKindArgs),
} }
#[derive(Args)] #[derive(clap::Args)]
pub(crate) struct BookKindArgs { pub(crate) struct BookKindArgs {
#[command(subcommand)] #[command(subcommand)]
command: BookCommand, command: BookSubCommand,
} }
#[derive(Subcommand, Display)] #[derive(clap::Subcommand, strum::Display)]
pub(crate) enum BookCommand { pub(crate) enum BookSubCommand {
/// Build the book /// Build the book
Build, Build,
/// Open the book on the specified port or random port and rebuild it automatically upon changes /// Open the book on the specified port or random port and rebuild it automatically upon changes
Open(OpenArgs), Open(OpenArgs),
} }
#[derive(Args, Display)] #[derive(clap::Args)]
pub(crate) struct OpenArgs { pub(crate) struct OpenArgs {
/// Specify the port to open the book on (defaults to a random port if not specified) /// Specify the port to open the book on (defaults to a random port if not specified)
#[clap(long, default_value_t = random_port())] #[clap(long, default_value_t = random_port())]
@ -55,15 +45,7 @@ pub(crate) struct Book {
impl BooksArgs { impl BooksArgs {
pub(crate) fn parse(&self) -> anyhow::Result<()> { pub(crate) fn parse(&self) -> anyhow::Result<()> {
init_logger().init(); Book::run(&self.book)
let start = Instant::now();
Book::run(&self.book)?;
let duration = start.elapsed();
info!(
"\x1B[32;1mTime elapsed for the current execution: {}\x1B[0m",
format_duration(&duration)
);
Ok(())
} }
} }
@ -91,37 +73,37 @@ impl Book {
&args.command, &args.command,
), ),
}; };
book.execute(command); book.execute(command)
}
fn execute(&self, command: &BookSubCommand) -> anyhow::Result<()> {
ensure_cargo_crate_is_installed("mdbook", None, None, false)?;
group!("{}: {}", self.name, command);
match command {
BookSubCommand::Build => self.build(),
BookSubCommand::Open(args) => self.open(args),
}?;
endgroup!();
Ok(()) Ok(())
} }
fn execute(&self, command: &BookCommand) { fn build(&self) -> anyhow::Result<()> {
ensure_cargo_crate_is_installed("mdbook"); run_process(
group!("{}: {}", self.name, command); "mdbook",
match command { &vec!["build"],
BookCommand::Build => self.build(), None,
BookCommand::Open(args) => self.open(args),
};
endgroup!();
}
fn build(&self) {
run_mdbook_with_path(
"build",
Params::from([]),
HashMap::new(),
Some(self.path), Some(self.path),
"mdbook should build the book successfully", "mdbook should build the book successfully",
); )
} }
fn open(&self, args: &OpenArgs) { fn open(&self, args: &OpenArgs) -> anyhow::Result<()> {
run_mdbook_with_path( run_process(
"serve", "mdbook",
Params::from(["--open", "--port", &args.port.to_string()]), &vec!["serve", "--open", "--port", &args.port.to_string()],
HashMap::new(), None,
Some(self.path), Some(self.path),
"mdbook should build the book successfully", "mdbook should open the book successfully",
); )
} }
} }

View File

@ -0,0 +1,86 @@
use std::collections::HashMap;
use strum::IntoEnumIterator;
use tracel_xtask::prelude::*;
use crate::{ARM_NO_ATOMIC_PTR_TARGET, ARM_TARGET, NO_STD_CRATES, WASM32_TARGET};
#[macros::extend_command_args(BuildCmdArgs, Target, None)]
pub struct BurnBuildCmdArgs {
/// Build in CI mode which excludes unsupported crates.
#[arg(long)]
pub ci: bool,
}
pub(crate) fn handle_command(
mut args: BurnBuildCmdArgs,
exec_env: ExecutionEnvironment,
) -> anyhow::Result<()> {
match exec_env {
ExecutionEnvironment::NoStd => {
[
"Default",
WASM32_TARGET,
ARM_TARGET,
ARM_NO_ATOMIC_PTR_TARGET,
]
.iter()
.try_for_each(|build_target| {
let mut build_args = vec!["--no-default-features"];
let mut env_vars = HashMap::new();
if *build_target != "Default" {
build_args.extend(vec!["--target", *build_target]);
}
if *build_target == ARM_NO_ATOMIC_PTR_TARGET {
env_vars.insert(
"RUSTFLAGS",
"--cfg portable_atomic_unsafe_assume_single_core",
);
}
helpers::custom_crates_build(
NO_STD_CRATES.to_vec(),
build_args,
Some(env_vars),
None,
&format!("no-std with target {}", *build_target),
)
})?;
Ok(())
}
ExecutionEnvironment::Std => {
if args.ci {
// Exclude crates that are not supported on CI
args.exclude
.extend(vec!["burn-cuda".to_string(), "burn-tch".to_string()]);
if std::env::var("DISABLE_WGPU").is_ok() {
args.exclude.extend(vec!["burn-wgpu".to_string()]);
};
}
// Build workspace
base_commands::build::handle_command(args.try_into().unwrap())?;
// Specific additional commands to test specific features
// burn-dataset
helpers::custom_crates_build(
vec!["burn-dataset"],
vec!["--all-features"],
None,
None,
"std with all features",
)?;
Ok(())
}
ExecutionEnvironment::All => ExecutionEnvironment::iter()
.filter(|env| *env != ExecutionEnvironment::All)
.try_for_each(|env| {
handle_command(
BurnBuildCmdArgs {
target: args.target.clone(),
exclude: args.exclude.clone(),
only: args.only.clone(),
ci: args.ci,
},
env,
)
}),
}
}

23
xtask/src/commands/doc.rs Normal file
View File

@ -0,0 +1,23 @@
use tracel_xtask::prelude::*;
pub(crate) fn handle_command(mut args: DocCmdArgs) -> anyhow::Result<()> {
if args.get_command() == DocSubCommand::Build {
args.exclude.push("burn-cuda".to_string());
}
// Execute documentation command on workspace
base_commands::doc::handle_command(args.clone())?;
// Specific additional commands to build other docs
if args.get_command() == DocSubCommand::Build {
// burn-dataset
helpers::custom_crates_doc_build(
vec!["burn-dataset"],
vec!["--all-features"],
None,
None,
"All features",
)?;
}
Ok(())
}

View File

@ -0,0 +1,5 @@
pub(crate) mod books;
pub(crate) mod build;
pub(crate) mod doc;
pub(crate) mod test;
pub(crate) mod validate;

116
xtask/src/commands/test.rs Normal file
View File

@ -0,0 +1,116 @@
use strum::IntoEnumIterator;
use tracel_xtask::prelude::*;
use crate::NO_STD_CRATES;
#[macros::extend_command_args(TestCmdArgs, Target, TestSubCommand)]
pub struct BurnTestCmdArgs {
/// Test in CI mode which excludes unsupported crates.
#[arg(long)]
pub ci: bool,
}
pub(crate) fn handle_command(
mut args: BurnTestCmdArgs,
exec_env: ExecutionEnvironment,
) -> anyhow::Result<()> {
match exec_env {
ExecutionEnvironment::NoStd => {
["Default"].iter().try_for_each(|test_target| {
let mut test_args = vec!["--no-default-features"];
if *test_target != "Default" {
test_args.extend(vec!["--target", *test_target]);
}
helpers::custom_crates_tests(
NO_STD_CRATES.to_vec(),
test_args,
None,
None,
"no-std",
)
})?;
Ok(())
}
ExecutionEnvironment::Std => {
if args.ci {
// Exclude crates that are not supported on CI
args.exclude
.extend(vec!["burn-cuda".to_string(), "burn-tch".to_string()]);
}
if std::env::var("DISABLE_WGPU").is_ok() {
args.exclude.extend(vec!["burn-wgpu".to_string()]);
};
// test workspace
base_commands::test::handle_command(args.try_into().unwrap())?;
// Specific additional commands to test specific features
// burn-dataset
helpers::custom_crates_tests(
vec!["burn-dataset"],
vec!["--all-features"],
None,
None,
"std all features",
)?;
// burn-core
helpers::custom_crates_tests(
vec!["burn-core"],
vec!["--features", "test-tch,record-item-custom-serde"],
None,
None,
"std with features: test-tch,record-item-custom-serde",
)?;
if std::env::var("DISABLE_WGPU").is_err() {
helpers::custom_crates_tests(
vec!["burn-core"],
vec!["--features", "test-wgpu"],
None,
None,
"std wgpu",
)?;
}
// MacOS specific tests
#[cfg(target_os = "macos")]
{
// burn-candle
helpers::custom_crates_tests(
vec!["burn-candle"],
vec!["--features", "accelerate"],
None,
None,
"std accelerate",
)?;
// burn-ndarray
helpers::custom_crates_tests(
vec!["burn-ndarray"],
vec!["--features", "blas-accelerate"],
None,
None,
"std blas-accelerate",
)?;
}
Ok(())
}
ExecutionEnvironment::All => ExecutionEnvironment::iter()
.filter(|env| *env != ExecutionEnvironment::All)
.try_for_each(|env| {
handle_command(
BurnTestCmdArgs {
command: args.command.clone(),
target: args.target.clone(),
exclude: args.exclude.clone(),
only: args.only.clone(),
threads: args.threads,
jobs: args.jobs,
ci: args.ci,
},
env,
)
}),
}
}

View File

@ -0,0 +1,111 @@
use tracel_xtask::prelude::*;
use crate::commands::{build::BurnBuildCmdArgs, test::BurnTestCmdArgs};
pub fn handle_command(
args: &ValidateCmdArgs,
exec_env: &ExecutionEnvironment,
) -> anyhow::Result<()> {
let target = Target::Workspace;
let exclude = vec![];
let only = vec![];
if *exec_env == ExecutionEnvironment::Std || *exec_env == ExecutionEnvironment::All {
// ==============
// std validation
// ==============
info!("Run validation for std execution environment...");
// checks
[
CheckSubCommand::Audit,
CheckSubCommand::Format,
CheckSubCommand::Lint,
CheckSubCommand::Typos,
]
.iter()
.try_for_each(|c| {
base_commands::check::handle_command(CheckCmdArgs {
target: target.clone(),
exclude: exclude.clone(),
only: only.clone(),
command: Some(c.clone()),
ignore_audit: args.ignore_audit,
})
})?;
// build
super::build::handle_command(
BurnBuildCmdArgs {
target: target.clone(),
exclude: exclude.clone(),
only: only.clone(),
ci: true,
},
ExecutionEnvironment::Std,
)?;
// tests
super::test::handle_command(
BurnTestCmdArgs {
target: target.clone(),
exclude: exclude.clone(),
only: only.clone(),
threads: None,
jobs: None,
command: Some(TestSubCommand::All),
ci: true,
},
ExecutionEnvironment::Std,
)?;
// documentation
[DocSubCommand::Build, DocSubCommand::Tests]
.iter()
.try_for_each(|c| {
super::doc::handle_command(DocCmdArgs {
target: target.clone(),
exclude: exclude.clone(),
only: only.clone(),
command: Some(c.clone()),
})
})?;
}
if *exec_env == ExecutionEnvironment::NoStd || *exec_env == ExecutionEnvironment::All {
// =================
// no-std validation
// =================
info!("Run validation for no-std execution environment...");
#[cfg(target_os = "linux")]
{
// build
super::build::handle_command(
BurnBuildCmdArgs {
target: target.clone(),
exclude: exclude.clone(),
only: only.clone(),
ci: true,
},
ExecutionEnvironment::NoStd,
)?;
// tests
super::test::handle_command(
BurnTestCmdArgs {
target: target.clone(),
exclude: exclude.clone(),
only: only.clone(),
threads: None,
jobs: None,
command: Some(TestSubCommand::All),
ci: true,
},
ExecutionEnvironment::NoStd,
)?;
}
}
Ok(())
}

View File

@ -1,106 +0,0 @@
use std::{collections::HashMap, time::Instant};
use crate::{
endgroup, group,
logging::init_logger,
utils::{
cargo::{ensure_cargo_crate_is_installed, run_cargo},
rustup::is_current_toolchain_nightly,
time::format_duration,
Params,
},
};
#[derive(clap::ValueEnum, Default, Copy, Clone, PartialEq, Eq)]
pub(crate) enum DependencyCheck {
/// Run all dependency checks.
#[default]
All,
/// Perform an audit of all dependencies using the cargo-audit crate `<https://crates.io/crates/cargo-audit>`
Audit,
/// Run cargo-deny check `<https://crates.io/crates/cargo-deny>`
Deny,
/// Run cargo-udeps to find unused dependencies `<https://crates.io/crates/cargo-udeps>`
Unused,
}
impl DependencyCheck {
pub(crate) fn run(&self) -> anyhow::Result<()> {
// Setup logger
init_logger().init();
// Start time measurement
let start = Instant::now();
match self {
Self::Audit => cargo_audit(),
Self::Deny => cargo_deny(),
Self::Unused => cargo_udeps(),
Self::All => {
cargo_audit();
cargo_deny();
cargo_udeps();
}
}
// Stop time measurement
//
// Compute runtime duration
let duration = start.elapsed();
// Print duration
info!(
"\x1B[32;1mTime elapsed for the current execution: {}\x1B[0m",
format_duration(&duration)
);
Ok(())
}
}
/// Run cargo-audit
fn cargo_audit() {
ensure_cargo_crate_is_installed("cargo-audit");
// Run cargo audit
group!("Cargo: run audit checks");
run_cargo(
"audit",
Params::from([]),
HashMap::new(),
"Cargo audit should be installed and it should correctly run",
);
endgroup!();
}
/// Run cargo-deny
fn cargo_deny() {
ensure_cargo_crate_is_installed("cargo-deny");
// Run cargo deny
group!("Cargo: run deny checks");
run_cargo(
"deny",
Params::from(["check"]),
HashMap::new(),
"Cargo deny should be installed and it should correctly run",
);
endgroup!();
}
/// Run cargo-udeps
fn cargo_udeps() {
if is_current_toolchain_nightly() {
ensure_cargo_crate_is_installed("cargo-udeps");
// Run cargo udeps
group!("Cargo: run unused dependencies checks");
run_cargo(
"udeps",
Params::from([]),
HashMap::new(),
"Cargo udeps should be installed and it should correctly run",
);
endgroup!();
} else {
error!(
"You must use 'cargo +nightly' to check for unused dependencies.
Install a nightly toolchain with 'rustup toolchain install nightly'."
)
}
}

View File

@ -1,67 +0,0 @@
use std::io::Write;
/// Initialise and create a `env_logger::Builder` which follows the
/// GitHub Actions logging syntax when running on CI.
pub(crate) fn init_logger() -> env_logger::Builder {
let mut builder = env_logger::Builder::from_default_env();
builder.target(env_logger::Target::Stdout);
// Find and setup the correct log level
builder.filter(None, get_log_level());
builder.write_style(env_logger::WriteStyle::Always);
// Custom Formatter for Github Actions
if std::env::var("CI").is_ok() {
builder.format(|buf, record| match record.level().as_str() {
"DEBUG" => writeln!(buf, "::debug:: {}", record.args()),
"WARN" => writeln!(buf, "::warning:: {}", record.args()),
"ERROR" => {
writeln!(buf, "::error:: {}", record.args())
}
_ => writeln!(buf, "{}", record.args()),
});
}
builder
}
/// Determine the LogLevel for the logger
fn get_log_level() -> log::LevelFilter {
// DEBUG
match std::env::var("DEBUG") {
Ok(_value) => return log::LevelFilter::Debug,
Err(_err) => (),
}
// ACTIONS_RUNNER_DEBUG
match std::env::var("ACTIONS_RUNNER_DEBUG") {
Ok(_value) => return log::LevelFilter::Debug,
Err(_err) => (),
};
log::LevelFilter::Info
}
/// Group Macro
#[macro_export]
macro_rules! group {
// group!()
($($arg:tt)*) => {
let title = format!($($arg)*);
if std::env::var("CI").is_ok() {
log!(log::Level::Info, "::group::{}", title)
} else {
log!(log::Level::Info, "{}", title)
}
};
}
/// End Group Macro
#[macro_export]
macro_rules! endgroup {
// endgroup!()
() => {
if std::env::var("CI").is_ok() {
log!(log::Level::Info, "::endgroup::")
}
};
}

View File

@ -1,61 +1,76 @@
use clap::{Parser, Subcommand}; mod commands;
mod books;
mod dependencies;
mod logging;
mod publish;
mod runchecks;
mod utils;
mod vulnerabilities;
#[macro_use] #[macro_use]
extern crate log; extern crate log;
#[derive(Parser)] use std::time::Instant;
#[command(author, version, about, long_about = None)] use tracel_xtask::prelude::*;
struct Args {
#[command(subcommand)]
command: Command,
}
#[derive(Subcommand)] // no-std
enum Command { const WASM32_TARGET: &str = "wasm32-unknown-unknown";
/// Run commands to manage Burn Books const ARM_TARGET: &str = "thumbv7m-none-eabi";
Books(books::BooksArgs), const ARM_NO_ATOMIC_PTR_TARGET: &str = "thumbv6m-none-eabi";
/// Run the specified dependencies check locally const NO_STD_CRATES: &[&str] = &[
Dependencies { "burn",
/// The dependency check to run "burn-core",
dependency_check: dependencies::DependencyCheck, "burn-common",
}, "burn-tensor",
/// Publish a crate to crates.io "burn-ndarray",
Publish { "burn-no-std-tests",
/// The name of the crate to publish on crates.io ];
name: String,
}, #[macros::base_commands(
/// Run the specified `burn` tests and checks locally. Bump,
RunChecks { Check,
/// The environment to run checks against Compile,
#[clap(value_enum, default_value_t = runchecks::CheckType::default())] Coverage,
env: runchecks::CheckType, Doc,
}, Dependencies,
/// Run the specified vulnerability check locally. These commands must be called with 'cargo +nightly'. Fix,
Vulnerabilities { Publish,
/// The vulnerability check to run. Validate,
/// For the reference visit the page `<https://doc.rust-lang.org/beta/unstable-book/compiler-flags/sanitizer.html>` Vulnerabilities
vulnerability_check: vulnerabilities::VulnerabilityCheck, )]
}, pub enum Command {
/// Run commands to manage Burn Books.
Books(commands::books::BooksArgs),
/// Build Burn in different modes.
Build(commands::build::BurnBuildCmdArgs),
/// Test Burn.
Test(commands::test::BurnTestCmdArgs),
} }
fn main() -> anyhow::Result<()> { fn main() -> anyhow::Result<()> {
let args = Args::parse(); let start = Instant::now();
let args = init_xtask::<Command>()?;
if args.execution_environment == ExecutionEnvironment::NoStd {
// Install additional targets for no-std execution environments
rustup_add_target(WASM32_TARGET)?;
rustup_add_target(ARM_TARGET)?;
rustup_add_target(ARM_NO_ATOMIC_PTR_TARGET)?;
}
match args.command { match args.command {
Command::Books(args) => args.parse(), Command::Books(cmd_args) => cmd_args.parse(),
Command::Dependencies { dependency_check } => dependency_check.run(), Command::Build(cmd_args) => {
Command::Publish { name } => publish::run(name), commands::build::handle_command(cmd_args, args.execution_environment)
Command::RunChecks { env } => env.run(), }
Command::Vulnerabilities { Command::Doc(cmd_args) => commands::doc::handle_command(cmd_args),
vulnerability_check, Command::Test(cmd_args) => {
} => vulnerability_check.run(), commands::test::handle_command(cmd_args, args.execution_environment)
} }
Command::Validate(cmd_args) => {
commands::validate::handle_command(&cmd_args, &args.execution_environment)
}
_ => dispatch_base_commands(args),
}?;
let duration = start.elapsed();
info!(
"\x1B[32;1mTime elapsed for the current execution: {}\x1B[0m",
format_duration(&duration)
);
Ok(())
} }

View File

@ -1,114 +0,0 @@
//! This script publishes a crate on `crates.io`.
//!
//! To run the script:
//!
//! cargo xtask publish INPUT_CRATE
use std::{collections::HashMap, env, process::Command, str};
use crate::{
endgroup, group,
utils::{cargo::run_cargo, Params},
};
// Crates.io API token
const CRATES_IO_API_TOKEN: &str = "CRATES_IO_API_TOKEN";
// Obtain local crate version
fn local_version(crate_name: &str) -> String {
// Obtain local crate version contained in cargo pkgid data
let cargo_pkgid_output = Command::new("cargo")
.args(["pkgid", "-p", crate_name])
.output()
.expect("Failed to run cargo pkgid");
// Convert cargo pkgid output into a str
let cargo_pkgid_str = str::from_utf8(&cargo_pkgid_output.stdout)
.expect("Failed to convert pkgid output into a str");
// Extract only the local crate version from str
let (_, local_version) = cargo_pkgid_str
.split_once('#')
.expect("Failed to get local crate version");
local_version.trim_end().to_string()
}
// Obtain remote crate version
fn remote_version(crate_name: &str) -> Option<String> {
// Obtain remote crate version contained in cargo search data
let cargo_search_output = Command::new("cargo")
.args(["search", crate_name, "--limit", "1"])
.output()
.expect("Failed to run cargo search");
// Cargo search returns an empty string in case of a crate not present on
// crates.io
if cargo_search_output.stdout.is_empty() {
None
} else {
// Convert cargo search output into a str
let remote_version_str = str::from_utf8(&cargo_search_output.stdout)
.expect("Failed to convert cargo search output into a str");
// Extract only the remote crate version from str
remote_version_str
.split_once('=')
.and_then(|(_, second)| second.trim_start().split_once(' '))
.map(|(s, _)| s.trim_matches('"').to_string())
}
}
fn publish(crate_name: String) {
// Perform dry-run to ensure everything is good for publishing
let dry_run_params = Params::from(["-p", &crate_name, "--dry-run"]);
run_cargo(
"publish",
dry_run_params,
HashMap::new(),
"The cargo publish --dry-run should complete successfully, indicating readiness for actual publication",
);
let crates_io_token =
env::var(CRATES_IO_API_TOKEN).expect("Failed to retrieve the crates.io API token");
let envs = HashMap::from([("CRATES_IO_API_TOKEN", crates_io_token.clone())]);
let publish_params = Params::from(vec!["-p", &crate_name, "--token", &crates_io_token]);
// Actually publish the crate
run_cargo(
"publish",
publish_params,
envs,
"The crate should be successfully published",
);
}
pub(crate) fn run(crate_name: String) -> anyhow::Result<()> {
group!("Publishing {}...\n", crate_name);
// Retrieve local version for crate
let local_version = local_version(&crate_name);
info!("{crate_name} local version: {local_version}");
// Retrieve remote version for crate if it exists
match remote_version(&crate_name) {
Some(remote_version) => {
info!("{crate_name} remote version: {remote_version}\n");
// Early return if we don't need to publish the crate
if local_version == remote_version {
info!("Remote version {remote_version} is up to date, skipping deployment");
return Ok(());
}
}
None => info!("\nFirst time publishing {crate_name} on crates.io!\n"),
}
// Publish the crate
publish(crate_name);
endgroup!();
Ok(())
}

View File

@ -1,427 +0,0 @@
//! This script is run before a PR is created.
//!
//! It is used to check that the code compiles and passes all tests.
//!
//! It is also used to check that the code is formatted correctly and passes clippy.
use std::collections::HashMap;
use std::env;
use std::process::{Command, Stdio};
use std::str;
use std::time::Instant;
use crate::logging::init_logger;
use crate::utils::cargo::{run_cargo, run_cargo_with_path};
use crate::utils::process::{handle_child_process, run_command};
use crate::utils::rustup::{rustup_add_component, rustup_add_target};
use crate::utils::time::format_duration;
use crate::utils::workspace::{get_workspace_members, WorkspaceMemberType};
use crate::utils::Params;
use crate::{endgroup, group};
// Targets constants
const WASM32_TARGET: &str = "wasm32-unknown-unknown";
const ARM_TARGET: &str = "thumbv7m-none-eabi";
const ARM_NO_ATOMIC_PTR_TARGET: &str = "thumbv6m-none-eabi";
#[derive(clap::ValueEnum, Default, Copy, Clone, PartialEq, Eq)]
pub(crate) enum CheckType {
/// Run all checks except examples
#[default]
All,
/// Run `std` environment checks
Std,
/// Run `no-std` environment checks
NoStd,
/// Check for typos
Typos,
/// Test the examples
Examples,
}
impl CheckType {
pub(crate) fn run(&self) -> anyhow::Result<()> {
// Setup logger
init_logger().init();
// Start time measurement
let start = Instant::now();
// The environment can assume ONLY "std", "no_std", "typos", "examples"
//
// Depending on the input argument, the respective environment checks
// are run.
//
// If no `environment` value has been passed, run all checks except examples.
match self {
Self::Std => std_checks(),
Self::NoStd => no_std_checks(),
Self::Typos => check_typos(),
Self::Examples => check_examples(),
Self::All => {
/* Run all checks */
check_typos();
std_checks();
no_std_checks();
}
}
// Stop time measurement
//
// Compute runtime duration
let duration = start.elapsed();
// Print duration
info!(
"\x1B[32;1mTime elapsed for the current execution: {}\x1B[0m",
format_duration(&duration)
);
Ok(())
}
}
/// Run cargo build command
fn cargo_build(params: Params, envs: Option<HashMap<&str, String>>) {
// Run cargo build
run_cargo(
"build",
params + "--color=always",
envs.unwrap_or_default(),
"Failed to run cargo build",
);
}
/// Run cargo install command
fn cargo_install(params: Params) {
// Run cargo install
run_cargo(
"install",
params + "--color=always",
HashMap::new(),
"Failed to run cargo install",
);
}
/// Run cargo test command
fn cargo_test(params: Params) {
// Run cargo test
run_cargo(
"test",
params + "--color=always" + "--" + "--color=always",
HashMap::new(),
"Failed to run cargo test",
);
}
/// Run cargo fmt command
fn cargo_fmt() {
group!("Cargo: fmt");
run_cargo(
"fmt",
["--check", "--all", "--", "--color=always"].into(),
HashMap::new(),
"Failed to run cargo fmt",
);
endgroup!();
}
/// Run cargo clippy command
fn cargo_clippy() {
if std::env::var("CI").is_ok() {
return;
}
// Run cargo clippy
run_cargo(
"clippy",
["--color=always", "--all-targets", "--", "-D", "warnings"].into(),
HashMap::new(),
"Failed to run cargo clippy",
);
}
/// Run cargo doc command
fn cargo_doc(params: Params) {
// Run cargo doc
run_cargo(
"doc",
params + "--color=always",
HashMap::new(),
"Failed to run cargo doc",
);
}
// Build and test a crate in a no_std environment
fn build_and_test_no_std<const N: usize>(crate_name: &str, extra_args: [&str; N]) {
group!("Checks: {} (no-std)", crate_name);
// Run cargo build --no-default-features
cargo_build(
Params::from(["-p", crate_name, "--no-default-features"]) + extra_args,
None,
);
// Run cargo test --no-default-features
cargo_test(Params::from(["-p", crate_name, "--no-default-features"]) + extra_args);
// Run cargo build --no-default-features --target wasm32-unknown-unknowns
cargo_build(
Params::from([
"-p",
crate_name,
"--no-default-features",
"--target",
WASM32_TARGET,
]) + extra_args,
None,
);
// Run cargo build --no-default-features --target thumbv7m-none-eabi
cargo_build(
Params::from([
"-p",
crate_name,
"--no-default-features",
"--target",
ARM_TARGET,
]) + extra_args,
None,
);
// Run cargo build --no-default-features --target thumbv6m-none-eabi
cargo_build(
Params::from([
"-p",
crate_name,
"--no-default-features",
"--target",
ARM_NO_ATOMIC_PTR_TARGET,
]) + extra_args,
Some(HashMap::from([(
"RUSTFLAGS",
"--cfg portable_atomic_unsafe_assume_single_core".to_string(),
)])),
);
endgroup!();
}
// Setup code coverage
fn setup_coverage() {
// Install llvm-tools-preview
rustup_add_component("llvm-tools-preview");
// Set coverage environment variables
env::set_var("RUSTFLAGS", "-Cinstrument-coverage");
env::set_var("LLVM_PROFILE_FILE", "burn-%p-%m.profraw");
}
// Run grcov to produce lcov.info
fn run_grcov() {
// grcov arguments
#[rustfmt::skip]
let args = [
".",
"--binary-path", "./target/debug/",
"-s", ".",
"-t", "lcov",
"--branch",
"--ignore-not-existing",
"--ignore", "/*", // It excludes std library code coverage from analysis
"--ignore", "xtask/*",
"--ignore", "examples/*",
"-o", "lcov.info",
];
run_command(
"grcov",
&args,
"Failed to run grcov",
"Failed to wait for grcov child process",
);
}
// Run no_std checks
fn no_std_checks() {
// Install wasm32 target
rustup_add_target(WASM32_TARGET);
// Install ARM target
rustup_add_target(ARM_TARGET);
// Install ARM no atomic ptr target
rustup_add_target(ARM_NO_ATOMIC_PTR_TARGET);
// Run checks for the following crates
build_and_test_no_std("burn", []);
build_and_test_no_std("burn-core", []);
build_and_test_no_std("burn-common", []);
build_and_test_no_std("burn-tensor", []);
build_and_test_no_std("burn-ndarray", []);
build_and_test_no_std("burn-no-std-tests", []);
}
// Test burn-core with tch and wgpu backend
fn burn_core_std() {
// Run cargo test --features test-tch, record-item-custom-serde
group!("Test: burn-core (tch) and record-item-custom-serde");
cargo_test(
[
"-p",
"burn-core",
"--features",
"test-tch,record-item-custom-serde,",
]
.into(),
);
endgroup!();
// Run cargo test --features test-wgpu
if std::env::var("DISABLE_WGPU").is_err() {
group!("Test: burn-core (wgpu)");
cargo_test(["-p", "burn-core", "--features", "test-wgpu"].into());
endgroup!();
}
}
// Test burn-dataset features
fn burn_dataset_features_std() {
group!("Checks: burn-dataset (all-features)");
// Run cargo build --all-features
cargo_build(["-p", "burn-dataset", "--all-features"].into(), None);
// Run cargo test --all-features
cargo_test(["-p", "burn-dataset", "--all-features"].into());
// Run cargo doc --all-features
cargo_doc(["-p", "burn-dataset", "--all-features", "--no-deps"].into());
endgroup!();
}
// macOS only checks
#[cfg(target_os = "macos")]
fn macos_checks() {
// Leverages the macOS Accelerate framework: https://developer.apple.com/documentation/accelerate
group!("Checks: burn-candle (accelerate)");
cargo_test(["-p", "burn-candle", "--features", "accelerate"].into());
endgroup!();
// Leverages the macOS Accelerate framework: https://developer.apple.com/documentation/accelerate
group!("Checks: burn-ndarray (accelerate)");
cargo_test(["-p", "burn-ndarray", "--features", "blas-accelerate"].into());
endgroup!();
}
fn std_checks() {
// Set RUSTDOCFLAGS environment variable to treat warnings as errors
// for the documentation build
env::set_var("RUSTDOCFLAGS", "-D warnings");
// Check if COVERAGE environment variable is set
let is_coverage = std::env::var("COVERAGE").is_ok();
let disable_wgpu = std::env::var("DISABLE_WGPU").is_ok();
// Check format
cargo_fmt();
// Check clippy lints
cargo_clippy();
// Produce documentation for each workspace member
group!("Docs: crates");
let mut params = Params::from(["--workspace", "--no-deps"]);
// Exclude burn-cuda on all platforms
params.params.push("--exclude".to_string());
params.params.push("burn-cuda".to_string());
cargo_doc(params);
endgroup!();
// Setup code coverage
if is_coverage {
setup_coverage();
}
// Build & test each member in workspace
let members = get_workspace_members(WorkspaceMemberType::Crate);
for member in members {
if disable_wgpu && member.name == "burn-wgpu" {
continue;
}
if member.name == "burn-cuda" {
// burn-cuda requires CUDA Toolkit which is not currently setup on our CI runners
continue;
}
if member.name == "burn-tch" {
continue;
}
group!("Checks: {}", member.name);
cargo_build(Params::from(["-p", &member.name]), None);
cargo_test(Params::from(["-p", &member.name]));
endgroup!();
}
// Test burn-candle with accelerate (macOS only)
#[cfg(target_os = "macos")]
macos_checks();
// Test burn-dataset features
burn_dataset_features_std();
// Test burn-core with tch and wgpu backend
burn_core_std();
// Run grcov and produce lcov.info
if is_coverage {
run_grcov();
}
}
fn check_typos() {
// This path defines where typos-cli is installed on different
// operating systems.
let typos_cli_path = std::env::var("CARGO_HOME")
.map(|v| std::path::Path::new(&v).join("bin/typos-cli"))
.unwrap();
// Do not run cargo install on CI to speed up the computation.
// Check whether the file has been installed on
if std::env::var("CI").is_err() && !typos_cli_path.exists() {
// Install typos-cli
cargo_install(["typos-cli", "--version", "1.16.5"].into());
}
info!("Running typos check \n\n");
// Run typos command as child process
let typos = Command::new("typos")
.args(["--exclude", "**/*.onnx"])
.stdout(Stdio::inherit()) // Send stdout directly to terminal
.stderr(Stdio::inherit()) // Send stderr directly to terminal
.spawn()
.expect("Failed to run typos");
// Handle typos child process
handle_child_process(typos, "Failed to wait for typos child process");
}
fn check_examples() {
let members = get_workspace_members(WorkspaceMemberType::Example);
for member in members {
if member.name == "notebook" {
continue;
}
group!("Checks: Example - {}", member.name);
run_cargo_with_path(
"check",
["--examples"].into(),
HashMap::new(),
Some(member.path),
"Failed to check example",
);
endgroup!();
}
}

View File

@ -1,66 +0,0 @@
use std::{
collections::HashMap,
path::Path,
process::{Command, Stdio},
};
use crate::{endgroup, group, utils::process::handle_child_process};
use super::Params;
/// Run a cargo command
pub(crate) fn run_cargo(command: &str, params: Params, envs: HashMap<&str, String>, error: &str) {
run_cargo_with_path::<String>(command, params, envs, None, error)
}
/// Run a cargo command with the passed directory as the current directory
pub(crate) fn run_cargo_with_path<P: AsRef<Path>>(
command: &str,
params: Params,
envs: HashMap<&str, String>,
path: Option<P>,
error: &str,
) {
info!("cargo {} {}\n", command, params.params.join(" "));
let mut cargo = Command::new("cargo");
cargo
.env("CARGO_INCREMENTAL", "0")
.envs(&envs)
.arg(command)
.args(&params.params)
.stdout(Stdio::inherit()) // Send stdout directly to terminal
.stderr(Stdio::inherit()); // Send stderr directly to terminal
if let Some(path) = path {
cargo.current_dir(path);
}
// Handle cargo child process
let cargo_process = cargo.spawn().expect(error);
handle_child_process(cargo_process, "Cargo process should run flawlessly");
}
/// Ensure that a cargo crate is installed
pub(crate) fn ensure_cargo_crate_is_installed(crate_name: &str) {
if !is_cargo_crate_installed(crate_name) {
group!("Cargo: install crate '{}'", crate_name);
run_cargo(
"install",
[crate_name].into(),
HashMap::new(),
&format!("crate '{}' should be installed", crate_name),
);
endgroup!();
}
}
/// Returns true if the passed cargo crate is installed locally
fn is_cargo_crate_installed(crate_name: &str) -> bool {
let output = Command::new("cargo")
.arg("install")
.arg("--list")
.output()
.expect("Should get the list of installed cargo commands");
let output_str = String::from_utf8_lossy(&output.stdout);
output_str.lines().any(|line| line.contains(crate_name))
}

View File

@ -1,35 +0,0 @@
use std::{
collections::HashMap,
path::Path,
process::{Command, Stdio},
};
use crate::utils::process::handle_child_process;
use super::Params;
/// Run a mdbook command with the passed directory as the current directory
pub(crate) fn run_mdbook_with_path<P: AsRef<Path>>(
command: &str,
params: Params,
envs: HashMap<&str, String>,
path: Option<P>,
error: &str,
) {
info!("mdbook {} {}\n", command, params.params.join(" "));
let mut mdbook = Command::new("mdbook");
mdbook
.envs(&envs)
.arg(command)
.args(&params.params)
.stdout(Stdio::inherit()) // Send stdout directly to terminal
.stderr(Stdio::inherit()); // Send stderr directly to terminal
if let Some(path) = path {
mdbook.current_dir(path);
}
// Handle mdbook child process
let mdbook_process = mdbook.spawn().expect(error);
handle_child_process(mdbook_process, "mdbook process should run flawlessly");
}

View File

@ -1,50 +0,0 @@
pub(crate) mod cargo;
pub(crate) mod mdbook;
pub(crate) mod process;
pub(crate) mod rustup;
pub(crate) mod time;
pub(crate) mod workspace;
pub(crate) struct Params {
pub params: Vec<String>,
}
impl<const N: usize> From<[&str; N]> for Params {
fn from(value: [&str; N]) -> Self {
Self {
params: value.iter().map(|v| v.to_string()).collect(),
}
}
}
impl From<&str> for Params {
fn from(value: &str) -> Self {
Self {
params: vec![value.to_string()],
}
}
}
impl From<Vec<&str>> for Params {
fn from(value: Vec<&str>) -> Self {
Self {
params: value.iter().map(|s| s.to_string()).collect(),
}
}
}
impl std::fmt::Display for Params {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(self.params.join(" ").as_str())
}
}
impl<Rhs: Into<Params>> std::ops::Add<Rhs> for Params {
type Output = Params;
fn add(mut self, rhs: Rhs) -> Self::Output {
let rhs: Params = rhs.into();
self.params.extend(rhs.params);
self
}
}

View File

@ -1,38 +0,0 @@
use rand::Rng;
use std::process::{Child, Command, Stdio};
/// Handle child process
pub(crate) fn handle_child_process(mut child: Child, error: &str) {
// Wait for the child process to finish
let status = child.wait().expect(error);
// If exit status is not a success, terminate the process with an error
if !status.success() {
// Use the exit code associated to a command to terminate the process,
// if any exit code had been found, use the default value 1
std::process::exit(status.code().unwrap_or(1));
}
}
/// Run a command
pub(crate) fn run_command(command: &str, args: &[&str], command_error: &str, child_error: &str) {
// Format command
info!("{command} {}\n\n", args.join(" "));
// Run command as child process
let command = Command::new(command)
.args(args)
.stdout(Stdio::inherit()) // Send stdout directly to terminal
.stderr(Stdio::inherit()) // Send stderr directly to terminal
.spawn()
.expect(command_error);
// Handle command child process
handle_child_process(command, child_error);
}
/// Return a random port between 3000 and 9999
pub(crate) fn random_port() -> u16 {
let mut rng = rand::thread_rng();
rng.gen_range(3000..=9999)
}

View File

@ -1,68 +0,0 @@
use std::process::{Command, Stdio};
use crate::{endgroup, group, utils::process::handle_child_process};
use super::Params;
/// Run rustup command
pub(crate) fn rustup(command: &str, params: Params, expected: &str) {
info!("rustup {} {}\n", command, params);
// Run rustup
let mut rustup = Command::new("rustup");
rustup
.arg(command)
.args(params.params)
.stdout(Stdio::inherit()) // Send stdout directly to terminal
.stderr(Stdio::inherit()); // Send stderr directly to terminal
let cargo_process = rustup.spawn().expect(expected);
handle_child_process(cargo_process, "Failed to wait for rustup child process");
}
/// Add a Rust target
pub(crate) fn rustup_add_target(target: &str) {
group!("Rustup: add target {}", target);
rustup(
"target",
Params::from(["add", target]),
"Target should be added",
);
endgroup!();
}
/// Add a Rust component
pub(crate) fn rustup_add_component(component: &str) {
group!("Rustup: add component {}", component);
rustup(
"component",
Params::from(["add", component]),
"Component should be added",
);
endgroup!();
}
// Returns the output of the rustup command to get the installed targets
pub(crate) fn rustup_get_installed_targets() -> String {
let output = Command::new("rustup")
.args(["target", "list", "--installed"])
.stdout(Stdio::piped())
.output()
.expect("Rustup command should execute successfully");
String::from_utf8(output.stdout).expect("Output should be valid UTF-8")
}
/// Returns true if the current toolchain is the nightly
pub(crate) fn is_current_toolchain_nightly() -> bool {
let output = Command::new("rustup")
.arg("show")
.output()
.expect("Should get the list of installed Rust toolchains");
let output_str = String::from_utf8_lossy(&output.stdout);
for line in output_str.lines() {
// look for the "rustc.*-nightly" line
if line.contains("rustc") && line.contains("-nightly") {
return true;
}
}
// assume we are using a stable toolchain if we did not find the nightly compiler
false
}

View File

@ -1,15 +0,0 @@
use std::time::Duration;
/// Print duration as HH:MM:SS format
pub(crate) fn format_duration(duration: &Duration) -> String {
let seconds = duration.as_secs();
let minutes = seconds / 60;
let hours = minutes / 60;
let remaining_minutes = minutes % 60;
let remaining_seconds = seconds % 60;
format!(
"{:02}:{:02}:{:02}",
hours, remaining_minutes, remaining_seconds
)
}

View File

@ -1,90 +0,0 @@
use std::{path::Path, process::Command};
use serde_json::Value;
const MEMBER_PATH_PREFIX: &str = if cfg!(target_os = "windows") {
"path+file:///"
} else {
"path+file://"
};
pub(crate) enum WorkspaceMemberType {
Crate,
Example,
}
#[derive(Debug)]
pub(crate) struct WorkspaceMember {
pub(crate) name: String,
pub(crate) path: String,
}
impl WorkspaceMember {
fn new(name: String, path: String) -> Self {
Self { name, path }
}
}
/// Get workspace crates
pub(crate) fn get_workspace_members(w_type: WorkspaceMemberType) -> Vec<WorkspaceMember> {
// Run `cargo metadata` command to get project metadata
let output = Command::new("cargo")
.arg("metadata")
.output()
.expect("Failed to execute command");
// Parse the JSON output
let metadata: Value = serde_json::from_slice(&output.stdout).expect("Failed to parse JSON");
// Extract workspaces from the metadata, excluding examples/ and xtask
let workspaces = metadata["workspace_members"]
.as_array()
.expect("Expected an array of workspace members")
.iter()
.filter_map(|member| {
let member_str = member.as_str()?;
let has_whitespace = member_str.chars().any(|c| c.is_whitespace());
let (name, path) = if has_whitespace {
parse_workspace_member0(member_str)?
} else {
parse_workspace_member1(member_str)?
};
match w_type {
WorkspaceMemberType::Crate if name != "xtask" && !path.contains("examples/") => {
Some(WorkspaceMember::new(name.to_string(), path.to_string()))
}
WorkspaceMemberType::Example if name != "xtask" && path.contains("examples/") => {
Some(WorkspaceMember::new(name.to_string(), path.to_string()))
}
_ => None,
}
})
.collect();
workspaces
}
/// Legacy cargo metadata format for member specs (rust < 1.77)
/// Example:
/// "backend-comparison 0.13.0 (path+file:///Users/username/burn/backend-comparison)"
fn parse_workspace_member0(specs: &str) -> Option<(String, String)> {
let parts: Vec<_> = specs.split_whitespace().collect();
let (name, path) = (parts.first()?.to_owned(), parts.last()?.to_owned());
// skip the first character because it is a '('
let path = path
.chars()
.skip(1)
.collect::<String>()
.replace(MEMBER_PATH_PREFIX, "")
.replace(')', "");
Some((name.to_string(), path.to_string()))
}
/// Cargo metadata format for member specs (rust >= 1.77)
/// Example:
/// "path+file:///Users/username/burn/backend-comparison#0.13.0"
fn parse_workspace_member1(specs: &str) -> Option<(String, String)> {
let no_prefix = specs.replace(MEMBER_PATH_PREFIX, "").replace(')', "");
let path = Path::new(no_prefix.split_once('#')?.0);
let name = path.file_name()?.to_str()?;
let path = path.to_str()?;
Some((name.to_string(), path.to_string()))
}

View File

@ -1,396 +0,0 @@
use std::collections::HashMap;
use std::time::Instant;
use crate::logging::init_logger;
use crate::utils::cargo::{ensure_cargo_crate_is_installed, run_cargo};
use crate::utils::rustup::{
is_current_toolchain_nightly, rustup_add_component, rustup_get_installed_targets,
};
use crate::utils::time::format_duration;
use crate::utils::Params;
use crate::{endgroup, group};
use std::fmt;
#[derive(clap::ValueEnum, Default, Copy, Clone, PartialEq, Eq)]
pub(crate) enum VulnerabilityCheck {
/// Run all most useful vulnerability checks.
#[default]
All,
/// Run Address sanitizer (memory error detector)
AddressSanitizer,
/// Run LLVM Control Flow Integrity (CFI) (provides forward-edge control flow protection)
ControlFlowIntegrity,
/// Run newer variant of Address sanitizer (memory error detector similar to AddressSanitizer, but based on partial hardware assistance)
HWAddressSanitizer,
/// Run Kernel LLVM Control Flow Integrity (KCFI) (provides forward-edge control flow protection for operating systems kernels)
KernelControlFlowIntegrity,
/// Run Leak sanitizer (run-time memory leak detector)
LeakSanitizer,
/// Run memory sanitizer (detector of uninitialized reads)
MemorySanitizer,
/// Run another address sanitizer (like AddressSanitizer and HardwareAddressSanitizer but with lower overhead suitable for use as hardening for production binaries)
MemTagSanitizer,
/// Run nightly-only checks through cargo-careful `<https://crates.io/crates/cargo-careful>`
NightlyChecks,
/// Run SafeStack check (provides backward-edge control flow protection by separating
/// stack into safe and unsafe regions)
SafeStack,
/// Run ShadowCall check (provides backward-edge control flow protection - aarch64 only)
ShadowCallStack,
/// Run Thread sanitizer (data race detector)
ThreadSanitizer,
}
impl VulnerabilityCheck {
pub(crate) fn run(&self) -> anyhow::Result<()> {
// Setup logger
init_logger().init();
// Start time measurement
let start = Instant::now();
match self {
Self::NightlyChecks => cargo_careful(),
Self::AddressSanitizer => Sanitizer::Address.run_tests(),
Self::ControlFlowIntegrity => Sanitizer::CFI.run_tests(),
Self::HWAddressSanitizer => Sanitizer::HWAddress.run_tests(),
Self::KernelControlFlowIntegrity => Sanitizer::KCFI.run_tests(),
Self::LeakSanitizer => Sanitizer::Leak.run_tests(),
Self::MemorySanitizer => Sanitizer::Memory.run_tests(),
Self::MemTagSanitizer => Sanitizer::MemTag.run_tests(),
Self::SafeStack => Sanitizer::SafeStack.run_tests(),
Self::ShadowCallStack => Sanitizer::ShadowCallStack.run_tests(),
Self::ThreadSanitizer => Sanitizer::Thread.run_tests(),
Self::All => {
cargo_careful();
Sanitizer::Address.run_tests();
Sanitizer::Leak.run_tests();
Sanitizer::Memory.run_tests();
Sanitizer::SafeStack.run_tests();
Sanitizer::Thread.run_tests();
}
}
// Stop time measurement
//
// Compute runtime duration
let duration = start.elapsed();
// Print duration
info!(
"\x1B[32;1mTime elapsed for the current execution: {}\x1B[0m",
format_duration(&duration)
);
Ok(())
}
}
/// Run cargo-careful
fn cargo_careful() {
if is_current_toolchain_nightly() {
ensure_cargo_crate_is_installed("cargo-careful");
rustup_add_component("rust-src");
// prepare careful sysroot
group!("Cargo: careful setup");
run_cargo(
"careful",
Params::from(["setup"]),
HashMap::new(),
"Cargo sysroot should be available",
);
endgroup!();
// Run cargo careful
group!("Cargo: run careful checks");
run_cargo(
"careful",
Params::from(["test"]),
HashMap::new(),
"Cargo careful should be installed and it should correctly run",
);
endgroup!();
} else {
error!(
"You must use 'cargo +nightly' to run nightly checks.
Install a nightly toolchain with 'rustup toolchain install nightly'."
)
}
}
// Represents the various sanitizer available in nightly compiler
// source: https://doc.rust-lang.org/beta/unstable-book/compiler-flags/sanitizer.html
#[allow(clippy::upper_case_acronyms)]
enum Sanitizer {
Address,
CFI,
HWAddress,
KCFI,
Leak,
Memory,
MemTag,
SafeStack,
ShadowCallStack,
Thread,
}
impl fmt::Display for Sanitizer {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Sanitizer::Address => write!(f, "AddressSanitizer"),
Sanitizer::CFI => write!(f, "ControlFlowIntegrity"),
Sanitizer::HWAddress => write!(f, "HWAddressSanitizer"),
Sanitizer::KCFI => write!(f, "KernelControlFlowIntegrity"),
Sanitizer::Leak => write!(f, "LeakSanitizer"),
Sanitizer::Memory => write!(f, "MemorySanitizer"),
Sanitizer::MemTag => write!(f, "MemTagSanitizer"),
Sanitizer::SafeStack => write!(f, "SafeStack"),
Sanitizer::ShadowCallStack => write!(f, "ShadowCallStack"),
Sanitizer::Thread => write!(f, "ThreadSanitizer"),
}
}
}
impl Sanitizer {
const DEFAULT_RUSTFLAGS: &'static str = "-Copt-level=3";
fn run_tests(&self) {
if is_current_toolchain_nightly() {
group!("Sanitizer: {}", self.to_string());
let retriever = RustupTargetRetriever;
if self.is_target_supported(&retriever) {
let envs = vec![
(
"RUSTFLAGS",
format!("{} {}", self.flags(), Sanitizer::DEFAULT_RUSTFLAGS),
),
("RUSTDOCFLAGS", self.flags().to_string()),
];
let features = self.cargo_features();
let mut args = vec!["--", "--color=always", "--no-capture"];
args.extend(features);
run_cargo(
"test",
args.into(),
envs.into_iter().collect(),
"Failed to run cargo test",
);
} else {
info!("No supported target found for this sanitizer.");
}
endgroup!();
} else {
error!(
"You must use 'cargo +nightly' to run this check.
Install a nightly toolchain with 'rustup toolchain install nightly'."
)
}
}
fn flags(&self) -> &'static str {
match self {
Sanitizer::Address => "-Zsanitizer=address",
Sanitizer::CFI => "-Zsanitizer=cfi -Clto",
Sanitizer::HWAddress => "-Zsanitizer=hwaddress -Ctarget-feature=+tagged-globals",
Sanitizer::KCFI => "-Zsanitizer=kcfi",
Sanitizer::Leak => "-Zsanitizer=leak",
Sanitizer::Memory => "-Zsanitizer=memory -Zsanitizer-memory-track-origins",
Sanitizer::MemTag => "--Zsanitizer=memtag -Ctarget-feature=\"+mte\"",
Sanitizer::SafeStack => "-Zsanitizer=safestack",
Sanitizer::ShadowCallStack => "-Zsanitizer=shadow-call-stack",
Sanitizer::Thread => "-Zsanitizer=thread",
}
}
fn cargo_features(&self) -> Vec<&str> {
match self {
Sanitizer::CFI => vec!["-Zbuild-std", "--target x86_64-unknown-linux-gnu"],
_ => vec![],
}
}
fn supported_targets(&self) -> Vec<Target> {
match self {
Sanitizer::Address => vec![
Target::Aarch64AppleDarwin,
Target::Aarch64UnknownFuchsia,
Target::Aarch64UnknownLinuxGnu,
Target::X8664AppleDarwin,
Target::X8664UnknownFuchsia,
Target::X8664UnknownFreebsd,
Target::X8664UnknownLinuxGnu,
],
Sanitizer::CFI => vec![Target::X8664UnknownLinuxGnu],
Sanitizer::HWAddress => {
vec![Target::Aarch64LinuxAndroid, Target::Aarch64UnknownLinuxGnu]
}
Sanitizer::KCFI => vec![
Target::Aarch64LinuxAndroid,
Target::Aarch64UnknownLinuxGnu,
Target::X8664LinuxAndroid,
Target::X8664UnknownLinuxGnu,
],
Sanitizer::Leak => vec![
Target::Aarch64AppleDarwin,
Target::Aarch64UnknownLinuxGnu,
Target::X8664AppleDarwin,
Target::X8664UnknownLinuxGnu,
],
Sanitizer::Memory => vec![
Target::Aarch64UnknownLinuxGnu,
Target::X8664UnknownFreebsd,
Target::X8664UnknownLinuxGnu,
],
Sanitizer::MemTag => vec![Target::Aarch64LinuxAndroid, Target::Aarch64UnknownLinuxGnu],
Sanitizer::SafeStack => vec![Target::X8664UnknownLinuxGnu],
Sanitizer::ShadowCallStack => vec![Target::Aarch64LinuxAndroid],
Sanitizer::Thread => vec![
Target::Aarch64AppleDarwin,
Target::Aarch64UnknownLinuxGnu,
Target::X8664AppleDarwin,
Target::X8664UnknownFreebsd,
Target::X8664UnknownLinuxGnu,
],
}
}
// Returns true if the sanitizer is supported by the currently installed targets
fn is_target_supported<T: TargetRetriever>(&self, retriever: &T) -> bool {
let installed_targets = retriever.get_installed_targets();
let supported = self.supported_targets();
installed_targets.iter().any(|installed| {
let installed_target = Target::from_str(installed.trim()).unwrap_or(Target::Unknown);
supported.iter().any(|target| target == &installed_target)
})
}
}
// Constants for target names
const AARCH64_APPLE_DARWIN: &str = "aarch64-apple-darwin";
const AARCH64_LINUX_ANDROID: &str = "aarch64-linux-android";
const AARCH64_UNKNOWN_FUCHSIA: &str = "aarch64-unknown-fuchsia";
const AARCH64_UNKNOWN_LINUX_GNU: &str = "aarch64-unknown-linux-gnu";
const X8664_APPLE_DARWIN: &str = "x86_64-apple-darwin";
const X8664_LINUX_ANDROID: &str = "x86_64-linux-android";
const X8664_UNKNOWN_FUCHSIA: &str = "x86_64-unknown-fuchsia";
const X8664_UNKNOWN_FREEBSD: &str = "x86_64-unknown-freebsd";
const X8664_UNKNOWN_LINUX_GNU: &str = "x86_64-unknown-linux-gnu";
trait TargetRetriever {
fn get_installed_targets(&self) -> Vec<String>;
}
struct RustupTargetRetriever;
impl TargetRetriever for RustupTargetRetriever {
fn get_installed_targets(&self) -> Vec<String> {
rustup_get_installed_targets()
.lines()
.map(|s| s.to_string())
.collect()
}
}
// Represents Rust targets
// Remark: we list only the targets that are supported by sanitizers
#[derive(Debug, PartialEq)]
enum Target {
Aarch64AppleDarwin,
Aarch64LinuxAndroid,
Aarch64UnknownFuchsia,
Aarch64UnknownLinuxGnu,
X8664AppleDarwin,
X8664LinuxAndroid,
X8664UnknownFuchsia,
X8664UnknownFreebsd,
X8664UnknownLinuxGnu,
Unknown,
}
impl Target {
fn from_str(s: &str) -> Option<Self> {
match s {
AARCH64_APPLE_DARWIN => Some(Self::Aarch64AppleDarwin),
AARCH64_LINUX_ANDROID => Some(Self::Aarch64LinuxAndroid),
AARCH64_UNKNOWN_FUCHSIA => Some(Self::Aarch64UnknownFuchsia),
AARCH64_UNKNOWN_LINUX_GNU => Some(Self::Aarch64UnknownLinuxGnu),
X8664_APPLE_DARWIN => Some(Self::X8664AppleDarwin),
X8664_LINUX_ANDROID => Some(Self::X8664LinuxAndroid),
X8664_UNKNOWN_FUCHSIA => Some(Self::X8664UnknownFuchsia),
X8664_UNKNOWN_FREEBSD => Some(Self::X8664UnknownFreebsd),
X8664_UNKNOWN_LINUX_GNU => Some(Self::X8664UnknownLinuxGnu),
_ => None,
}
}
}
impl fmt::Display for Target {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let target_str = match self {
Target::Aarch64AppleDarwin => AARCH64_APPLE_DARWIN,
Target::Aarch64LinuxAndroid => AARCH64_LINUX_ANDROID,
Target::Aarch64UnknownFuchsia => AARCH64_UNKNOWN_FUCHSIA,
Target::Aarch64UnknownLinuxGnu => AARCH64_UNKNOWN_LINUX_GNU,
Target::X8664AppleDarwin => X8664_APPLE_DARWIN,
Target::X8664LinuxAndroid => X8664_LINUX_ANDROID,
Target::X8664UnknownFuchsia => X8664_UNKNOWN_FUCHSIA,
Target::X8664UnknownFreebsd => X8664_UNKNOWN_FREEBSD,
Target::X8664UnknownLinuxGnu => X8664_UNKNOWN_LINUX_GNU,
Target::Unknown => "",
};
write!(f, "{}", target_str)
}
}
#[cfg(test)]
mod tests {
use super::*;
use rstest::rstest;
struct MockTargetRetriever {
mock_data: Vec<String>,
}
impl MockTargetRetriever {
fn new(mock_data: Vec<String>) -> Self {
Self { mock_data }
}
}
impl TargetRetriever for MockTargetRetriever {
fn get_installed_targets(&self) -> Vec<String> {
self.mock_data.clone()
}
}
#[rstest]
#[case(vec!["".to_string()], false)] // empty string
#[case(vec!["x86_64-pc-windows-msvc".to_string()], false)] // not supported target
#[case(vec!["x86_64-pc-windows-msvc".to_string(), "".to_string()], false)] // not supported target and empty string
#[case(vec!["x86_64-unknown-linux-gnu".to_string()], true)] // one supported target
#[case(vec!["aarch64-apple-darwin".to_string(), "x86_64-unknown-linux-gnu".to_string()], true)] // one unsupported target and one supported
fn test_is_target_supported(#[case] installed_targets: Vec<String>, #[case] expected: bool) {
let mock_retriever = MockTargetRetriever::new(installed_targets);
let sanitizer = Sanitizer::Memory;
assert_eq!(sanitizer.is_target_supported(&mock_retriever), expected);
}
#[test]
fn test_consistency_of_fmt_and_from_str_strings() {
let variants = vec![
Target::Aarch64AppleDarwin,
Target::Aarch64LinuxAndroid,
Target::Aarch64UnknownFuchsia,
Target::Aarch64UnknownLinuxGnu,
Target::X8664AppleDarwin,
Target::X8664LinuxAndroid,
Target::X8664UnknownFuchsia,
Target::X8664UnknownFreebsd,
Target::X8664UnknownLinuxGnu,
];
for variant in variants {
let variant_str = format!("{}", variant);
let parsed_variant = Target::from_str(&variant_str);
assert_eq!(Some(variant), parsed_variant);
}
}
}