Fix image-classsification-web + autotune flag usage (#2011)

This commit is contained in:
Guillaume Lagrange 2024-07-15 09:31:54 -04:00 committed by GitHub
parent 3afff434bd
commit 7661deb258
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 18 additions and 5 deletions

View File

@ -3,10 +3,13 @@ use burn_cube::prelude::*;
use burn_tensor::Shape;
use super::{
config::Tiling2dConfig, init_matmul_output, matmul_autotune, matmul_simple, matmul_tiling_2d,
config::Tiling2dConfig, init_matmul_output, matmul_simple, matmul_tiling_2d,
matmul_tiling_2d_cube, matmul_tiling_2d_padded,
};
#[cfg(feature = "autotune")]
use super::matmul_autotune;
/// The strategy to be used when launching a matmul kernel.
pub enum MatmulStrategy {
/// A simple kernel will be used with memory coalescing optimization.
@ -27,11 +30,14 @@ pub enum MatmulStrategy {
Tiling2dCube(Tiling2dConfig),
}
#[allow(clippy::derivable_impls)] // Necessary otherwise the feature flags dont' work.
#[cfg(feature = "autotune")]
impl Default for MatmulStrategy {
fn default() -> Self {
MatmulStrategy::Autotune
// if autotune is enabled, default to autotune
#[cfg(feature = "autotune")]
return MatmulStrategy::Autotune;
#[cfg(not(feature = "autotune"))]
MatmulStrategy::Tiling2d(Tiling2dConfig::default())
}
}

View File

@ -1,5 +1,7 @@
#[cfg(feature = "autotune")]
mod base;
mod key;
#[cfg(feature = "autotune")]
pub use base::*;
pub use key::*;

View File

@ -7,6 +7,7 @@ use super::{
shared::{base::ReduceDimShared, shader::reduce_dim_shared},
};
#[allow(dead_code)]
pub(crate) trait ReduceDimAlgorithm<E: JitElement>:
ReduceDimNaive<E> + ReduceDimShared<E>
{

View File

@ -1,5 +1,7 @@
#[cfg(feature = "autotune")]
mod base;
mod key;
#[cfg(feature = "autotune")]
pub(crate) use base::*;
pub use key::*;

View File

@ -17,7 +17,9 @@ half_precision = []
burn = { path = "../../crates/burn", version = "0.14.0", default-features = false, features = [
"ndarray",
] }
burn-wgpu = { path = "../../crates/burn-wgpu", version = "0.14.0", default-features = false }
burn-wgpu = { path = "../../crates/burn-wgpu", version = "0.14.0", default-features = false, features = [
"autotune",
] }
burn-candle = { path = "../../crates/burn-candle", version = "0.14.0", default-features = false }
log = { workspace = true }