mirror of https://github.com/tracel-ai/burn.git
Fix image-classsification-web + autotune flag usage (#2011)
This commit is contained in:
parent
3afff434bd
commit
7661deb258
|
@ -3,10 +3,13 @@ use burn_cube::prelude::*;
|
|||
use burn_tensor::Shape;
|
||||
|
||||
use super::{
|
||||
config::Tiling2dConfig, init_matmul_output, matmul_autotune, matmul_simple, matmul_tiling_2d,
|
||||
config::Tiling2dConfig, init_matmul_output, matmul_simple, matmul_tiling_2d,
|
||||
matmul_tiling_2d_cube, matmul_tiling_2d_padded,
|
||||
};
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
use super::matmul_autotune;
|
||||
|
||||
/// The strategy to be used when launching a matmul kernel.
|
||||
pub enum MatmulStrategy {
|
||||
/// A simple kernel will be used with memory coalescing optimization.
|
||||
|
@ -27,11 +30,14 @@ pub enum MatmulStrategy {
|
|||
Tiling2dCube(Tiling2dConfig),
|
||||
}
|
||||
|
||||
#[allow(clippy::derivable_impls)] // Necessary otherwise the feature flags dont' work.
|
||||
#[cfg(feature = "autotune")]
|
||||
impl Default for MatmulStrategy {
|
||||
fn default() -> Self {
|
||||
MatmulStrategy::Autotune
|
||||
// if autotune is enabled, default to autotune
|
||||
#[cfg(feature = "autotune")]
|
||||
return MatmulStrategy::Autotune;
|
||||
|
||||
#[cfg(not(feature = "autotune"))]
|
||||
MatmulStrategy::Tiling2d(Tiling2dConfig::default())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
#[cfg(feature = "autotune")]
|
||||
mod base;
|
||||
mod key;
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
pub use base::*;
|
||||
pub use key::*;
|
||||
|
|
|
@ -7,6 +7,7 @@ use super::{
|
|||
shared::{base::ReduceDimShared, shader::reduce_dim_shared},
|
||||
};
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub(crate) trait ReduceDimAlgorithm<E: JitElement>:
|
||||
ReduceDimNaive<E> + ReduceDimShared<E>
|
||||
{
|
||||
|
|
|
@ -1,5 +1,7 @@
|
|||
#[cfg(feature = "autotune")]
|
||||
mod base;
|
||||
mod key;
|
||||
|
||||
#[cfg(feature = "autotune")]
|
||||
pub(crate) use base::*;
|
||||
pub use key::*;
|
||||
|
|
|
@ -17,7 +17,9 @@ half_precision = []
|
|||
burn = { path = "../../crates/burn", version = "0.14.0", default-features = false, features = [
|
||||
"ndarray",
|
||||
] }
|
||||
burn-wgpu = { path = "../../crates/burn-wgpu", version = "0.14.0", default-features = false }
|
||||
burn-wgpu = { path = "../../crates/burn-wgpu", version = "0.14.0", default-features = false, features = [
|
||||
"autotune",
|
||||
] }
|
||||
burn-candle = { path = "../../crates/burn-candle", version = "0.14.0", default-features = false }
|
||||
|
||||
log = { workspace = true }
|
||||
|
|
Loading…
Reference in New Issue