mirror of https://github.com/tracel-ai/burn.git
Refactor/autotune/key (#924)
This commit is contained in:
parent
8c80c9b94a
commit
1cc1844d32
|
@ -12,7 +12,7 @@ use spin::Mutex;
|
|||
/// The ComputeClient is the entry point to require tasks from the ComputeServer.
|
||||
/// It should be obtained for a specific device via the Compute struct.
|
||||
#[derive(Debug)]
|
||||
pub struct ComputeClient<Server, Channel> {
|
||||
pub struct ComputeClient<Server: ComputeServer, Channel> {
|
||||
channel: Channel,
|
||||
tuner: Arc<Mutex<Tuner<Server, Channel>>>,
|
||||
_server: PhantomData<Server>,
|
||||
|
@ -72,7 +72,10 @@ where
|
|||
}
|
||||
|
||||
/// Executes the fastest kernel in the autotune operation, using (cached) runtime benchmarks
|
||||
pub fn execute_autotune(&self, autotune_operation_set: Box<dyn AutotuneOperationSet>) {
|
||||
pub fn execute_autotune(
|
||||
&self,
|
||||
autotune_operation_set: Box<dyn AutotuneOperationSet<Server::AutotuneKey>>,
|
||||
) {
|
||||
self.tuner
|
||||
.lock()
|
||||
.execute_autotune(autotune_operation_set, self);
|
||||
|
|
|
@ -4,7 +4,7 @@ use hashbrown::HashMap;
|
|||
|
||||
/// The compute type has the responsibility to retrieve the correct compute client based on the
|
||||
/// given device.
|
||||
pub struct Compute<Device, Server, Channel> {
|
||||
pub struct Compute<Device, Server: ComputeServer, Channel> {
|
||||
clients: spin::Mutex<Option<HashMap<Device, ComputeClient<Server, Channel>>>>,
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
use core::fmt::Debug;
|
||||
|
||||
use crate::{
|
||||
memory_management::{MemoryHandle, MemoryManagement},
|
||||
storage::ComputeStorage,
|
||||
tune::AutotuneKey,
|
||||
};
|
||||
use alloc::vec::Vec;
|
||||
use burn_common::reader::Reader;
|
||||
|
@ -19,6 +22,8 @@ where
|
|||
type Storage: ComputeStorage;
|
||||
/// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type.
|
||||
type MemoryManagement: MemoryManagement<Self::Storage>;
|
||||
/// The key used to cache operations used on specific inputs in autotune
|
||||
type AutotuneKey: AutotuneKey;
|
||||
|
||||
/// Given a handle, returns the owned resource as bytes.
|
||||
fn read(&mut self, handle: &Handle<Self>) -> Reader<Vec<u8>>;
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
use alloc::boxed::Box;
|
||||
use alloc::format;
|
||||
use alloc::string::String;
|
||||
use alloc::vec::Vec;
|
||||
use core::fmt::Display;
|
||||
use core::fmt::{Debug, Display};
|
||||
use core::hash::Hash;
|
||||
|
||||
/// Groups operations of the same type for autotune
|
||||
pub trait AutotuneOperationSet: Send {
|
||||
pub trait AutotuneOperationSet<K>: Send {
|
||||
/// The key used in the tune cache
|
||||
fn key(&self) -> AutotuneKey;
|
||||
fn key(&self) -> K;
|
||||
|
||||
/// All candidate operations for autotuning this operation type
|
||||
/// Operations can run on toy tensors of relevant size
|
||||
|
@ -32,16 +32,6 @@ pub trait AutotuneOperation {
|
|||
fn clone(&self) -> Box<dyn AutotuneOperation>;
|
||||
}
|
||||
|
||||
#[derive(new, Clone, Debug, PartialEq, Eq, Hash)]
|
||||
/// The key used in the tune cache, referring to the operation type,
|
||||
/// generally hardcoded for an autotune operation, and to the input shape
|
||||
pub struct AutotuneKey {
|
||||
operation: String,
|
||||
input_description: String,
|
||||
}
|
||||
|
||||
impl Display for AutotuneKey {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
f.write_str(format!("{}-{}", self.operation, self.input_description).as_str())
|
||||
}
|
||||
}
|
||||
/// Trait alias
|
||||
pub trait AutotuneKey: Clone + Debug + PartialEq + Eq + Hash + Display {}
|
||||
impl AutotuneKey for String {}
|
||||
|
|
|
@ -7,8 +7,6 @@ use crate::server::ComputeServer;
|
|||
use super::AutotuneOperation;
|
||||
use alloc::boxed::Box;
|
||||
use alloc::string::{String, ToString};
|
||||
use alloc::vec;
|
||||
use alloc::vec::Vec;
|
||||
|
||||
/// A benchmark that runs on server handles
|
||||
#[derive(new)]
|
||||
|
@ -18,20 +16,17 @@ pub struct TuneBenchmark<S: ComputeServer, C> {
|
|||
}
|
||||
|
||||
impl<S: ComputeServer, C: ComputeChannel<S>> Benchmark for TuneBenchmark<S, C> {
|
||||
// list of operations
|
||||
type Args = Vec<Box<dyn AutotuneOperation>>;
|
||||
type Args = Box<dyn AutotuneOperation>;
|
||||
|
||||
fn prepare(&self) -> Self::Args {
|
||||
vec![self.operation.clone()]
|
||||
self.operation.clone()
|
||||
}
|
||||
|
||||
fn num_samples(&self) -> usize {
|
||||
10
|
||||
}
|
||||
|
||||
fn execute(&self, args: Self::Args) {
|
||||
let operation = args[0].clone(); // TODO rm 0
|
||||
|
||||
fn execute(&self, operation: Self::Args) {
|
||||
AutotuneOperation::execute(operation);
|
||||
}
|
||||
|
||||
|
|
|
@ -1,40 +1,35 @@
|
|||
use core::marker::PhantomData;
|
||||
|
||||
use hashbrown::HashMap;
|
||||
|
||||
use super::AutotuneKey;
|
||||
use super::AutotuneOperation;
|
||||
use super::AutotuneOperationSet;
|
||||
use alloc::boxed::Box;
|
||||
use hashbrown::HashMap;
|
||||
|
||||
/// Use to find and reuse the best kernel for some input
|
||||
#[derive(Debug, Default)]
|
||||
pub(crate) struct TuneCache<S> {
|
||||
cache: HashMap<AutotuneKey, usize>,
|
||||
_server: PhantomData<S>,
|
||||
pub(crate) struct TuneCache<K> {
|
||||
cache: HashMap<K, usize>,
|
||||
}
|
||||
|
||||
/// Result of the cache try
|
||||
pub enum TuneCacheResult {
|
||||
pub enum TuneCacheResult<K> {
|
||||
/// An operation is found and given
|
||||
Hit(Box<dyn AutotuneOperation>),
|
||||
/// No operation is found and the set is given back for ownership
|
||||
Miss(Box<dyn AutotuneOperationSet>),
|
||||
Miss(Box<dyn AutotuneOperationSet<K>>),
|
||||
}
|
||||
|
||||
impl<S> TuneCache<S> {
|
||||
impl<K: AutotuneKey> TuneCache<K> {
|
||||
pub(crate) fn new() -> Self {
|
||||
TuneCache {
|
||||
cache: HashMap::new(),
|
||||
_server: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::borrowed_box)]
|
||||
pub(crate) fn try_cache(
|
||||
&self,
|
||||
autotune_operation_set: Box<dyn AutotuneOperationSet>,
|
||||
) -> TuneCacheResult {
|
||||
autotune_operation_set: Box<dyn AutotuneOperationSet<K>>,
|
||||
) -> TuneCacheResult<K> {
|
||||
let index = self.cache.get(&autotune_operation_set.key());
|
||||
if let Some(&i) = index {
|
||||
return TuneCacheResult::Hit(autotune_operation_set.fastest(i));
|
||||
|
@ -42,7 +37,7 @@ impl<S> TuneCache<S> {
|
|||
TuneCacheResult::Miss(autotune_operation_set)
|
||||
}
|
||||
|
||||
pub(crate) fn cache_insert(&mut self, key: AutotuneKey, fastest_index: usize) {
|
||||
pub(crate) fn cache_insert(&mut self, key: K, fastest_index: usize) {
|
||||
self.cache.insert(key, fastest_index);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -13,9 +13,8 @@ use crate::tune::{AutotuneOperation, AutotuneOperationSet, TuneBenchmark, TuneCa
|
|||
|
||||
#[derive(Debug, Default)]
|
||||
/// Executes autotune benchmarking and caching
|
||||
pub struct Tuner<S, C> {
|
||||
tune_cache: TuneCache<S>,
|
||||
_server: PhantomData<S>,
|
||||
pub struct Tuner<S: ComputeServer, C> {
|
||||
tune_cache: TuneCache<S::AutotuneKey>,
|
||||
_channel: PhantomData<C>,
|
||||
}
|
||||
|
||||
|
@ -24,14 +23,13 @@ impl<S: ComputeServer, C: ComputeChannel<S>> Tuner<S, C> {
|
|||
pub fn new() -> Self {
|
||||
Self {
|
||||
tune_cache: TuneCache::new(),
|
||||
_server: PhantomData,
|
||||
_channel: PhantomData,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn execute_autotune(
|
||||
&mut self,
|
||||
autotune_operation_set: Box<dyn AutotuneOperationSet>,
|
||||
autotune_operation_set: Box<dyn AutotuneOperationSet<S::AutotuneKey>>,
|
||||
client: &ComputeClient<S, C>,
|
||||
) {
|
||||
let operation = match self.tune_cache.try_cache(autotune_operation_set) {
|
||||
|
@ -44,7 +42,7 @@ impl<S: ComputeServer, C: ComputeChannel<S>> Tuner<S, C> {
|
|||
|
||||
fn autotuning(
|
||||
&mut self,
|
||||
autotune_operation_set: Box<dyn AutotuneOperationSet>,
|
||||
autotune_operation_set: Box<dyn AutotuneOperationSet<S::AutotuneKey>>,
|
||||
client: &ComputeClient<S, C>,
|
||||
) -> Box<dyn AutotuneOperation> {
|
||||
let key = autotune_operation_set.key();
|
||||
|
|
|
@ -24,6 +24,7 @@ where
|
|||
type Kernel = Arc<dyn DummyKernel>;
|
||||
type Storage = BytesStorage;
|
||||
type MemoryManagement = MM;
|
||||
type AutotuneKey = String;
|
||||
|
||||
fn read(&mut self, handle: &Handle<Self>) -> Reader<Vec<u8>> {
|
||||
let bytes = self.memory_management.get(&handle.memory);
|
||||
|
|
|
@ -2,7 +2,7 @@ use std::sync::Arc;
|
|||
|
||||
use burn_compute::{
|
||||
server::Handle,
|
||||
tune::{AutotuneKey, AutotuneOperation, AutotuneOperationSet},
|
||||
tune::{AutotuneOperation, AutotuneOperationSet},
|
||||
};
|
||||
|
||||
use crate::dummy::{
|
||||
|
@ -15,7 +15,7 @@ use super::DummyElementwiseAdditionSlowWrong;
|
|||
|
||||
pub struct AdditionAutotuneOperationSet {
|
||||
client: DummyClient,
|
||||
key: AutotuneKey,
|
||||
key: String,
|
||||
shapes: Vec<Vec<usize>>,
|
||||
handles: Vec<Handle<DummyServer>>,
|
||||
}
|
||||
|
@ -28,15 +28,15 @@ impl AdditionAutotuneOperationSet {
|
|||
) -> Self {
|
||||
Self {
|
||||
client,
|
||||
key: AutotuneKey::new("add".to_string(), log_shape_input_key(&shapes)),
|
||||
key: format!("{}-{}", "add", log_shape_input_key(&shapes)),
|
||||
shapes,
|
||||
handles,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AutotuneOperationSet for AdditionAutotuneOperationSet {
|
||||
fn key(&self) -> AutotuneKey {
|
||||
impl AutotuneOperationSet<String> for AdditionAutotuneOperationSet {
|
||||
fn key(&self) -> String {
|
||||
self.key.clone()
|
||||
}
|
||||
|
||||
|
@ -64,7 +64,7 @@ impl AutotuneOperationSet for AdditionAutotuneOperationSet {
|
|||
|
||||
pub struct MultiplicationAutotuneOperationSet {
|
||||
client: DummyClient,
|
||||
key: AutotuneKey,
|
||||
key: String,
|
||||
shapes: Vec<Vec<usize>>,
|
||||
handles: Vec<Handle<DummyServer>>,
|
||||
}
|
||||
|
@ -77,14 +77,14 @@ impl MultiplicationAutotuneOperationSet {
|
|||
) -> Self {
|
||||
Self {
|
||||
client,
|
||||
key: AutotuneKey::new("mul".to_string(), log_shape_input_key(&shapes)),
|
||||
key: format!("{}-{}", "mul", log_shape_input_key(&shapes)),
|
||||
shapes,
|
||||
handles,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl AutotuneOperationSet for MultiplicationAutotuneOperationSet {
|
||||
fn key(&self) -> AutotuneKey {
|
||||
impl AutotuneOperationSet<String> for MultiplicationAutotuneOperationSet {
|
||||
fn key(&self) -> String {
|
||||
self.key.clone()
|
||||
}
|
||||
|
||||
|
@ -112,7 +112,7 @@ impl AutotuneOperationSet for MultiplicationAutotuneOperationSet {
|
|||
|
||||
pub struct CacheTestAutotuneOperationSet {
|
||||
client: DummyClient,
|
||||
key: AutotuneKey,
|
||||
key: String,
|
||||
shapes: Vec<Vec<usize>>,
|
||||
handles: Vec<Handle<DummyServer>>,
|
||||
}
|
||||
|
@ -125,14 +125,14 @@ impl CacheTestAutotuneOperationSet {
|
|||
) -> Self {
|
||||
Self {
|
||||
client,
|
||||
key: AutotuneKey::new("cache_test".to_string(), log_shape_input_key(&shapes)),
|
||||
key: format!("{}-{}", "cache_test", log_shape_input_key(&shapes)),
|
||||
shapes,
|
||||
handles,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl AutotuneOperationSet for CacheTestAutotuneOperationSet {
|
||||
fn key(&self) -> AutotuneKey {
|
||||
impl AutotuneOperationSet<String> for CacheTestAutotuneOperationSet {
|
||||
fn key(&self) -> String {
|
||||
self.key.clone()
|
||||
}
|
||||
|
||||
|
|
|
@ -2,8 +2,10 @@ mod base;
|
|||
mod kernel;
|
||||
mod server;
|
||||
mod storage;
|
||||
mod tune_key;
|
||||
|
||||
pub use base::*;
|
||||
pub use kernel::*;
|
||||
pub use server::*;
|
||||
pub use storage::*;
|
||||
pub use tune_key::*;
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use super::{WgpuStorage, WorkGroup};
|
||||
use super::{WgpuAutotuneKey, WgpuStorage, WorkGroup};
|
||||
use crate::kernel::SourceTemplate;
|
||||
use alloc::{borrow::Cow, sync::Arc};
|
||||
use burn_compute::{
|
||||
|
@ -254,6 +254,7 @@ where
|
|||
type Kernel = Box<dyn Kernel>;
|
||||
type Storage = WgpuStorage;
|
||||
type MemoryManagement = MM;
|
||||
type AutotuneKey = WgpuAutotuneKey;
|
||||
|
||||
fn read(&mut self, handle: &server::Handle<Self>) -> Reader<Vec<u8>> {
|
||||
#[cfg(target_family = "wasm")]
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
use std::fmt::Display;
|
||||
|
||||
use burn_compute::tune::AutotuneKey;
|
||||
|
||||
use crate::kernel::matmul::MatmulAutotuneKey;
|
||||
|
||||
#[derive(Hash, Eq, PartialEq, Debug, Clone)]
|
||||
/// Key for all autotune-enabled operations
|
||||
pub enum WgpuAutotuneKey {
|
||||
/// Key for matmul operation
|
||||
Matmul(MatmulAutotuneKey),
|
||||
}
|
||||
|
||||
impl Display for WgpuAutotuneKey {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
WgpuAutotuneKey::Matmul(matmul_key) => std::fmt::Display::fmt(&matmul_key, f),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AutotuneKey for WgpuAutotuneKey {}
|
|
@ -1,28 +1,27 @@
|
|||
use burn_compute::tune::{AutotuneKey, AutotuneOperation, AutotuneOperationSet};
|
||||
use burn_compute::tune::{AutotuneOperation, AutotuneOperationSet};
|
||||
use burn_tensor::Element;
|
||||
|
||||
use crate::{
|
||||
compute::WgpuAutotuneKey,
|
||||
element::WgpuElement,
|
||||
kernel::matmul::{tune::utils::autotune_tensors, utils::init_matmul_output},
|
||||
tensor::WgpuTensor,
|
||||
};
|
||||
|
||||
use super::key::MatmulAutotuneKey;
|
||||
|
||||
/// Set of matmul implementations available for autotune
|
||||
/// Autotune key is given by concatenating the closest upper power of 2 of m, k and n
|
||||
pub struct MatmulAutotuneOperationSet<E: WgpuElement, const D: usize> {
|
||||
key: AutotuneKey,
|
||||
key: WgpuAutotuneKey,
|
||||
lhs: WgpuTensor<E, D>,
|
||||
rhs: WgpuTensor<E, D>,
|
||||
out: WgpuTensor<E, D>,
|
||||
}
|
||||
impl<E: WgpuElement, const D: usize> MatmulAutotuneOperationSet<E, D> {
|
||||
fn new(lhs: WgpuTensor<E, D>, rhs: WgpuTensor<E, D>, out: WgpuTensor<E, D>) -> Self {
|
||||
let m = lhs.shape.dims[D - 2];
|
||||
let k = lhs.shape.dims[D - 1];
|
||||
let n = rhs.shape.dims[D - 1];
|
||||
|
||||
Self {
|
||||
key: AutotuneKey::new("matmul".to_string(), log_mkn_input_key(m, k, n)),
|
||||
key: WgpuAutotuneKey::Matmul(MatmulAutotuneKey::new(&lhs.shape, &rhs.shape)),
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
|
@ -30,34 +29,10 @@ impl<E: WgpuElement, const D: usize> MatmulAutotuneOperationSet<E, D> {
|
|||
}
|
||||
}
|
||||
|
||||
fn log_mkn_input_key(m: usize, k: usize, n: usize) -> String {
|
||||
let mut desc = String::new();
|
||||
let mut diff = false;
|
||||
|
||||
for size in [m, k, n] {
|
||||
if !desc.is_empty() {
|
||||
desc.push('-');
|
||||
}
|
||||
let exp = f32::ceil(f32::log2(size as f32)) as u32;
|
||||
let updated = 2_u32.pow(exp);
|
||||
|
||||
if updated != size as u32 {
|
||||
diff = true;
|
||||
}
|
||||
desc.push_str(updated.to_string().as_str());
|
||||
}
|
||||
|
||||
if diff {
|
||||
desc.push_str("-uneven");
|
||||
}
|
||||
|
||||
desc
|
||||
}
|
||||
|
||||
impl<E: WgpuElement + Element, const D: usize> AutotuneOperationSet
|
||||
impl<E: WgpuElement + Element, const D: usize> AutotuneOperationSet<WgpuAutotuneKey>
|
||||
for MatmulAutotuneOperationSet<E, D>
|
||||
{
|
||||
fn key(&self) -> AutotuneKey {
|
||||
fn key(&self) -> WgpuAutotuneKey {
|
||||
self.key.clone()
|
||||
}
|
||||
|
||||
|
@ -187,23 +162,3 @@ matmul_tune_ops!(
|
|||
Vec4TilingMatmulUnpaddedDefault,
|
||||
crate::kernel::matmul::unpadded::matmul_tiling_2d_unpadded
|
||||
);
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn matmul_autotune_mkn_key() {
|
||||
let key = log_mkn_input_key(512, 512, 512);
|
||||
assert_eq!(key, "512-512-512");
|
||||
|
||||
let key = log_mkn_input_key(512, 256, 512);
|
||||
assert_eq!(key, "512-256-512");
|
||||
|
||||
let key = log_mkn_input_key(512, 256, 127);
|
||||
assert_eq!(key, "512-256-128-uneven");
|
||||
|
||||
let key = log_mkn_input_key(2, 149, 2344);
|
||||
assert_eq!(key, "2-256-4096-uneven");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,119 @@
|
|||
use burn_tensor::Shape;
|
||||
use core::fmt::Debug;
|
||||
use std::{
|
||||
cmp::{max, min},
|
||||
fmt::Display,
|
||||
hash::Hash,
|
||||
};
|
||||
|
||||
#[derive(Hash, Eq, PartialEq, Debug, Clone)]
|
||||
/// Autotune key representative of matmul versions
|
||||
pub struct MatmulAutotuneKey {
|
||||
round: bool, // True when all matmul dims are multiples of 64
|
||||
broadcast: bool, // True when there are differences in batch size
|
||||
anchored_m: usize,
|
||||
anchored_k: usize,
|
||||
anchored_n: usize,
|
||||
anchored_batch: usize,
|
||||
}
|
||||
|
||||
impl Display for MatmulAutotuneKey {
|
||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||
f.write_str(
|
||||
format!(
|
||||
"Matmul - Round:{:?} Broadcast:{:?} m:{:?} k:{:?} n:{:?} batch:{:?}",
|
||||
self.round,
|
||||
self.broadcast,
|
||||
self.anchored_m,
|
||||
self.anchored_k,
|
||||
self.anchored_n,
|
||||
self.anchored_batch
|
||||
)
|
||||
.as_str(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl MatmulAutotuneKey {
|
||||
/// Create a matmul autotune key from the input shapes
|
||||
pub fn new<const D: usize>(lhs_shape: &Shape<D>, rhs_shape: &Shape<D>) -> Self {
|
||||
let m = lhs_shape.dims[D - 2];
|
||||
let k = lhs_shape.dims[D - 1];
|
||||
let n = rhs_shape.dims[D - 1];
|
||||
|
||||
let mut broadcast = false;
|
||||
let mut batch_product_lhs = 1;
|
||||
let mut batch_product_rhs = 1;
|
||||
|
||||
for b in 0..D - 2 {
|
||||
batch_product_lhs *= lhs_shape.dims[b];
|
||||
batch_product_rhs *= rhs_shape.dims[b];
|
||||
if lhs_shape.dims[b] != rhs_shape.dims[b] {
|
||||
broadcast = true;
|
||||
}
|
||||
}
|
||||
let batch_product = max(batch_product_lhs, batch_product_rhs);
|
||||
|
||||
let round = m % 64 == 0 && k % 64 == 0 && n % 64 == 0;
|
||||
|
||||
Self {
|
||||
round,
|
||||
broadcast,
|
||||
anchored_m: anchor(m, None),
|
||||
anchored_k: anchor(k, None),
|
||||
anchored_n: anchor(n, None),
|
||||
anchored_batch: anchor(batch_product, Some(256)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn anchor(x: usize, max: Option<usize>) -> usize {
|
||||
let exp = f32::ceil(f32::log2(x as f32)) as u32;
|
||||
let power_of_2 = 2_u32.pow(exp) as usize;
|
||||
if let Some(max) = max {
|
||||
min(power_of_2, max)
|
||||
} else {
|
||||
power_of_2
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn matmul_autotune_key_all_same_and_round() {
|
||||
let lhs_shape: Shape<3> = [4, 512, 512].into();
|
||||
let rhs_shape: Shape<3> = [4, 512, 512].into();
|
||||
let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape);
|
||||
|
||||
assert!(key.round);
|
||||
assert!(!key.broadcast);
|
||||
assert!(key.anchored_m == 512);
|
||||
assert!(key.anchored_k == 512);
|
||||
assert!(key.anchored_n == 512);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn matmul_autotune_key_all_different() {
|
||||
let lhs_shape: Shape<4> = [2, 3, 511, 512].into();
|
||||
let rhs_shape: Shape<4> = [3, 2, 512, 513].into();
|
||||
let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape);
|
||||
|
||||
assert!(!key.round);
|
||||
assert!(key.broadcast);
|
||||
assert!(key.anchored_m == 512);
|
||||
assert!(key.anchored_k == 512);
|
||||
assert!(key.anchored_n == 1024);
|
||||
assert!(key.anchored_batch == 8);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn matmul_autotune_key_large_batch() {
|
||||
let lhs_shape: Shape<4> = [128, 512, 511, 512].into();
|
||||
let rhs_shape: Shape<4> = [200, 400, 512, 513].into();
|
||||
let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape);
|
||||
|
||||
assert!(key.anchored_batch == 256);
|
||||
}
|
||||
}
|
|
@ -1,4 +1,6 @@
|
|||
mod base;
|
||||
mod key;
|
||||
mod utils;
|
||||
|
||||
pub use base::*;
|
||||
pub use key::*;
|
||||
|
|
Loading…
Reference in New Issue