diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b9b3fd15f..db9270013 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,82 +6,49 @@ on: - main pull_request: types: [opened, synchronize] - branches: - - main jobs: - test-burn: - uses: burn-rs/burn/.github/workflows/test-template.yml@main - with: - crate: burn - test-no-default-feature: true - no-std-build-targets: true + test-burn-std: + runs-on: ubuntu-latest + steps: + - name: checkout + uses: actions/checkout@v3 - test-burn-common: - uses: burn-rs/burn/.github/workflows/test-template.yml@main - with: - crate: burn-common - test-no-default-feature: true - no-std-build-targets: true + - name: install rust + uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt, clippy - test-burn-dataset: - uses: burn-rs/burn/.github/workflows/test-template.yml@main - with: - crate: burn-dataset - args-test: --all-features + - name: caching + uses: Swatinem/rust-cache@v2 + with: + key: ${{ runner.os }}-rust-${{ hashFiles('**/Cargo.lock') }} - test-burn-tensor: - uses: burn-rs/burn/.github/workflows/test-template.yml@main - with: - crate: burn-tensor - test-no-default-feature: true - no-std-build-targets: true + - name: install llvmpipe and lavapipe + run: | + sudo apt-get update -y -qq + sudo add-apt-repository ppa:oibaf/graphics-drivers -y + sudo apt-get update + sudo apt install -y libegl1-mesa libgl1-mesa-dri libxcb-xfixes0-dev mesa-vulkan-drivers - test-burn-tch: - uses: burn-rs/burn/.github/workflows/test-template.yml@main - with: - crate: burn-tch - args-doc: --features doc + - name: run checks & tests + run: ./run-checks.sh std - test-burn-ndarray: - uses: burn-rs/burn/.github/workflows/test-template.yml@main - with: - crate: burn-ndarray - test-no-default-feature: true - no-std-build-targets: true + test-burn-no-std: + runs-on: ubuntu-latest + steps: + - name: checkout + uses: actions/checkout@v3 - test-burn-no-std-tests: - uses: burn-rs/burn/.github/workflows/test-template.yml@main - with: - crate: burn-no-std-tests - test-no-default-feature: true - no-std-build-targets: true + - name: install rust + uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt, clippy - test-burn-autodiff: - uses: burn-rs/burn/.github/workflows/test-template.yml@main - with: - crate: burn-autodiff - - test-burn-core: - uses: burn-rs/burn/.github/workflows/test-template.yml@main - with: - crate: burn-core - test-no-default-feature: true - no-std-build-targets: true - - test-burn-core-backend-tch: - uses: burn-rs/burn/.github/workflows/test-template.yml@main - with: - crate: burn-core - args-doc: --features test-tch - - test-burn-train: - uses: burn-rs/burn/.github/workflows/test-template.yml@main - with: - crate: burn-train - - test-burn-import: - uses: burn-rs/burn/.github/workflows/test-template.yml@main - with: - crate: burn-import + - name: caching + uses: Swatinem/rust-cache@v2 + with: + key: ${{ runner.os }}-rust-${{ hashFiles('**/Cargo.lock') }} + - name: run checks & tests + run: ./run-checks.sh no_std diff --git a/burn-wgpu/src/ops/base.rs b/burn-wgpu/src/ops/base.rs index 156cfc032..a144983c5 100644 --- a/burn-wgpu/src/ops/base.rs +++ b/burn-wgpu/src/ops/base.rs @@ -312,7 +312,7 @@ impl BaseOps { .context .compile_static::>(); - let mut shape_tmp = values.shape.clone(); + let mut shape_tmp = values.shape; shape_tmp.dims[dim] = 1; // Just one thread for the dim. tensor.context.execute( diff --git a/run-before-pr.sh b/run-checks.sh similarity index 58% rename from run-before-pr.sh rename to run-checks.sh index ce37aa0cc..105706e12 100755 --- a/run-before-pr.sh +++ b/run-checks.sh @@ -4,6 +4,8 @@ # It is used to check that the code compiles and passes all tests. # It is also used to check that the code is formatted correctly and passes clippy. +# Usage: ./run-checks.sh {all|no_std|std} (default: all) + # Exit immediately if a command exits with a non-zero status. set -euo pipefail @@ -64,31 +66,69 @@ build_and_test_all_features() { # Set RUSTDOCFLAGS to treat warnings as errors for the documentation build export RUSTDOCFLAGS="-D warnings" +# Run the checks for std and all features with std +std_func() { + echo "Running std checks" + + cargo build --workspace + cargo test --workspace + cargo fmt --check --all + cargo clippy -- -D warnings + cargo doc --workspace + + # all features + echo "Running all-features checks" + build_and_test_all_features "burn-dataset" +} + +# Run the checks for no_std +no_std_func() { + echo "Running no_std checks" + + # Add wasm32 target for compiler. + rustup target add wasm32-unknown-unknown + rustup target add thumbv7m-none-eabi + + build_and_test_no_std "burn" + build_and_test_no_std "burn-core" + build_and_test_no_std "burn-common" + build_and_test_no_std "burn-tensor" + build_and_test_no_std "burn-ndarray" + build_and_test_no_std "burn-no-std-tests" +} + # Save the script start time start_time=$(date +%s) -# Add wasm32 target for compiler. -rustup target add wasm32-unknown-unknown -rustup target add thumbv7m-none-eabi +# If no arguments were supplied or if it's empty, set the default as 'all' +if [ -z "${1-}" ]; then + arg="all" +else + arg=$1 +fi -cargo build --workspace -cargo test --workspace -cargo fmt --check --all -cargo clippy -- -D warnings -cargo doc --workspace - -# no_std tests -build_and_test_no_std "burn" -build_and_test_no_std "burn-core" -build_and_test_no_std "burn-common" -build_and_test_no_std "burn-tensor" -build_and_test_no_std "burn-ndarray" -build_and_test_no_std "burn-no-std-tests" - -# all features tests -build_and_test_all_features "burn-dataset" +# Check the argument and call the appropriate functions +case $arg in +all) + no_std_func + std_func + ;; +no_std) + no_std_func + ;; +std) + std_func + ;; +*) + echo "Error: Invalid argument" + echo "Usage: $0 {all|no_std|std}" + exit 1 + ;; +esac # Calculate and print the script execution time end_time=$(date +%s) execution_time=$((end_time - start_time)) echo "Script executed in $execution_time seconds." + +exit 0