From 57a5476c8977490f648ccf7ba92cf60d97acbb02 Mon Sep 17 00:00:00 2001 From: Louis Fortier-Dubois Date: Tue, 18 Jul 2023 16:15:00 -0400 Subject: [PATCH] bugfix for macos test (#503) --- burn-core/src/lib.rs | 5 ++++- burn-wgpu/src/kernel/matmul/tiling2d/base.rs | 1 - .../matmul/blocktiling_2d/continuous_vectorized.wgsl | 5 +++-- run-checks.sh | 4 ++-- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/burn-core/src/lib.rs b/burn-core/src/lib.rs index ab9bbb902..6ecfc7d96 100644 --- a/burn-core/src/lib.rs +++ b/burn-core/src/lib.rs @@ -44,9 +44,12 @@ pub type TestBackend = burn_ndarray::NdArrayBackend; #[cfg(all(test, feature = "test-tch"))] pub type TestBackend = burn_tch::TchBackend; -#[cfg(all(test, feature = "test-wgpu"))] +#[cfg(all(test, feature = "test-wgpu", not(target_os = "macos")))] pub type TestBackend = burn_wgpu::WgpuBackend; +#[cfg(all(test, feature = "test-wgpu", target_os = "macos"))] +pub type TestBackend = burn_wgpu::WgpuBackend; + #[cfg(feature = "std")] #[cfg(test)] pub type TestADBackend = burn_autodiff::ADBackendDecorator; diff --git a/burn-wgpu/src/kernel/matmul/tiling2d/base.rs b/burn-wgpu/src/kernel/matmul/tiling2d/base.rs index 022d76b53..b0c36f1d1 100644 --- a/burn-wgpu/src/kernel/matmul/tiling2d/base.rs +++ b/burn-wgpu/src/kernel/matmul/tiling2d/base.rs @@ -149,7 +149,6 @@ macro_rules! matmul_tile_2d { use $crate::kernel::matmul::utils::tests::same_as_reference; #[test] - #[ignore] pub fn test_matmul_tiling_2d_large_blocks() { test_with_params::<128, 128, 8, 4, 4, 32, 32>(8, 8, 8, 1, 1); } diff --git a/burn-wgpu/src/template/matmul/blocktiling_2d/continuous_vectorized.wgsl b/burn-wgpu/src/template/matmul/blocktiling_2d/continuous_vectorized.wgsl index 40faaf5e3..00689b8be 100644 --- a/burn-wgpu/src/template/matmul/blocktiling_2d/continuous_vectorized.wgsl +++ b/burn-wgpu/src/template/matmul/blocktiling_2d/continuous_vectorized.wgsl @@ -75,8 +75,6 @@ fn main( let thread_offset = local_idx * T_M_X_T_N; for (var k = 0u; k < K; k += B_K) { - // tile: let lhs_sm_position = current_row * B_K + current_col; - // tile_vec: let lhs_sm_position = current_row + current_col * B_M; for (var load_index = 0u; load_index < T_M_X_T_N; load_index ++) { let lhs_sm_position = thread_offset + load_index; let block_row = lhs_sm_position % B_M; @@ -85,6 +83,9 @@ fn main( if block_col < B_K { shared_lhs[lhs_sm_position] = lhs[lhs_position]; + } else { + // Patch for mac os bugfix + output[offset_output + row * n_cols + col] = 0.0; } } diff --git a/run-checks.sh b/run-checks.sh index 0126ebb8d..d9269f05b 100755 --- a/run-checks.sh +++ b/run-checks.sh @@ -85,8 +85,8 @@ std_func() { echo "Test burn-core with tch backend" cargo test --features test-tch - # echo "Test burn-core with wgpu backend" - # cargo test --features test-wgpu + echo "Test burn-core with wgpu backend" + cargo test --features test-wgpu cd .. || exit }