bugfix for macos test (#503)

This commit is contained in:
Louis Fortier-Dubois 2023-07-18 16:15:00 -04:00 committed by GitHub
parent 5ece894e02
commit 57a5476c89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 9 additions and 6 deletions

View File

@ -44,9 +44,12 @@ pub type TestBackend = burn_ndarray::NdArrayBackend<f32>;
#[cfg(all(test, feature = "test-tch"))]
pub type TestBackend = burn_tch::TchBackend<f32>;
#[cfg(all(test, feature = "test-wgpu"))]
#[cfg(all(test, feature = "test-wgpu", not(target_os = "macos")))]
pub type TestBackend = burn_wgpu::WgpuBackend<burn_wgpu::Vulkan, f32, i32>;
#[cfg(all(test, feature = "test-wgpu", target_os = "macos"))]
pub type TestBackend = burn_wgpu::WgpuBackend<burn_wgpu::Metal, f32, i32>;
#[cfg(feature = "std")]
#[cfg(test)]
pub type TestADBackend = burn_autodiff::ADBackendDecorator<TestBackend>;

View File

@ -149,7 +149,6 @@ macro_rules! matmul_tile_2d {
use $crate::kernel::matmul::utils::tests::same_as_reference;
#[test]
#[ignore]
pub fn test_matmul_tiling_2d_large_blocks() {
test_with_params::<128, 128, 8, 4, 4, 32, 32>(8, 8, 8, 1, 1);
}

View File

@ -75,8 +75,6 @@ fn main(
let thread_offset = local_idx * T_M_X_T_N;
for (var k = 0u; k < K; k += B_K) {
// tile: let lhs_sm_position = current_row * B_K + current_col;
// tile_vec: let lhs_sm_position = current_row + current_col * B_M;
for (var load_index = 0u; load_index < T_M_X_T_N; load_index ++) {
let lhs_sm_position = thread_offset + load_index;
let block_row = lhs_sm_position % B_M;
@ -85,6 +83,9 @@ fn main(
if block_col < B_K {
shared_lhs[lhs_sm_position] = lhs[lhs_position];
} else {
// Patch for mac os bugfix
output[offset_output + row * n_cols + col] = 0.0;
}
}

View File

@ -85,8 +85,8 @@ std_func() {
echo "Test burn-core with tch backend"
cargo test --features test-tch
# echo "Test burn-core with wgpu backend"
# cargo test --features test-wgpu
echo "Test burn-core with wgpu backend"
cargo test --features test-wgpu
cd .. || exit
}