mirror of https://github.com/tracel-ai/burn.git
bugfix for macos test (#503)
This commit is contained in:
parent
5ece894e02
commit
57a5476c89
|
@ -44,9 +44,12 @@ pub type TestBackend = burn_ndarray::NdArrayBackend<f32>;
|
||||||
#[cfg(all(test, feature = "test-tch"))]
|
#[cfg(all(test, feature = "test-tch"))]
|
||||||
pub type TestBackend = burn_tch::TchBackend<f32>;
|
pub type TestBackend = burn_tch::TchBackend<f32>;
|
||||||
|
|
||||||
#[cfg(all(test, feature = "test-wgpu"))]
|
#[cfg(all(test, feature = "test-wgpu", not(target_os = "macos")))]
|
||||||
pub type TestBackend = burn_wgpu::WgpuBackend<burn_wgpu::Vulkan, f32, i32>;
|
pub type TestBackend = burn_wgpu::WgpuBackend<burn_wgpu::Vulkan, f32, i32>;
|
||||||
|
|
||||||
|
#[cfg(all(test, feature = "test-wgpu", target_os = "macos"))]
|
||||||
|
pub type TestBackend = burn_wgpu::WgpuBackend<burn_wgpu::Metal, f32, i32>;
|
||||||
|
|
||||||
#[cfg(feature = "std")]
|
#[cfg(feature = "std")]
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
pub type TestADBackend = burn_autodiff::ADBackendDecorator<TestBackend>;
|
pub type TestADBackend = burn_autodiff::ADBackendDecorator<TestBackend>;
|
||||||
|
|
|
@ -149,7 +149,6 @@ macro_rules! matmul_tile_2d {
|
||||||
use $crate::kernel::matmul::utils::tests::same_as_reference;
|
use $crate::kernel::matmul::utils::tests::same_as_reference;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[ignore]
|
|
||||||
pub fn test_matmul_tiling_2d_large_blocks() {
|
pub fn test_matmul_tiling_2d_large_blocks() {
|
||||||
test_with_params::<128, 128, 8, 4, 4, 32, 32>(8, 8, 8, 1, 1);
|
test_with_params::<128, 128, 8, 4, 4, 32, 32>(8, 8, 8, 1, 1);
|
||||||
}
|
}
|
||||||
|
|
|
@ -75,8 +75,6 @@ fn main(
|
||||||
let thread_offset = local_idx * T_M_X_T_N;
|
let thread_offset = local_idx * T_M_X_T_N;
|
||||||
|
|
||||||
for (var k = 0u; k < K; k += B_K) {
|
for (var k = 0u; k < K; k += B_K) {
|
||||||
// tile: let lhs_sm_position = current_row * B_K + current_col;
|
|
||||||
// tile_vec: let lhs_sm_position = current_row + current_col * B_M;
|
|
||||||
for (var load_index = 0u; load_index < T_M_X_T_N; load_index ++) {
|
for (var load_index = 0u; load_index < T_M_X_T_N; load_index ++) {
|
||||||
let lhs_sm_position = thread_offset + load_index;
|
let lhs_sm_position = thread_offset + load_index;
|
||||||
let block_row = lhs_sm_position % B_M;
|
let block_row = lhs_sm_position % B_M;
|
||||||
|
@ -85,6 +83,9 @@ fn main(
|
||||||
|
|
||||||
if block_col < B_K {
|
if block_col < B_K {
|
||||||
shared_lhs[lhs_sm_position] = lhs[lhs_position];
|
shared_lhs[lhs_sm_position] = lhs[lhs_position];
|
||||||
|
} else {
|
||||||
|
// Patch for mac os bugfix
|
||||||
|
output[offset_output + row * n_cols + col] = 0.0;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -85,8 +85,8 @@ std_func() {
|
||||||
echo "Test burn-core with tch backend"
|
echo "Test burn-core with tch backend"
|
||||||
cargo test --features test-tch
|
cargo test --features test-tch
|
||||||
|
|
||||||
# echo "Test burn-core with wgpu backend"
|
echo "Test burn-core with wgpu backend"
|
||||||
# cargo test --features test-wgpu
|
cargo test --features test-wgpu
|
||||||
|
|
||||||
cd .. || exit
|
cd .. || exit
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue