mirror of https://github.com/tracel-ai/burn.git
Refactor/autotune/key (#924)
This commit is contained in:
parent
8c80c9b94a
commit
1cc1844d32
|
@ -12,7 +12,7 @@ use spin::Mutex;
|
||||||
/// The ComputeClient is the entry point to require tasks from the ComputeServer.
|
/// The ComputeClient is the entry point to require tasks from the ComputeServer.
|
||||||
/// It should be obtained for a specific device via the Compute struct.
|
/// It should be obtained for a specific device via the Compute struct.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct ComputeClient<Server, Channel> {
|
pub struct ComputeClient<Server: ComputeServer, Channel> {
|
||||||
channel: Channel,
|
channel: Channel,
|
||||||
tuner: Arc<Mutex<Tuner<Server, Channel>>>,
|
tuner: Arc<Mutex<Tuner<Server, Channel>>>,
|
||||||
_server: PhantomData<Server>,
|
_server: PhantomData<Server>,
|
||||||
|
@ -72,7 +72,10 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Executes the fastest kernel in the autotune operation, using (cached) runtime benchmarks
|
/// Executes the fastest kernel in the autotune operation, using (cached) runtime benchmarks
|
||||||
pub fn execute_autotune(&self, autotune_operation_set: Box<dyn AutotuneOperationSet>) {
|
pub fn execute_autotune(
|
||||||
|
&self,
|
||||||
|
autotune_operation_set: Box<dyn AutotuneOperationSet<Server::AutotuneKey>>,
|
||||||
|
) {
|
||||||
self.tuner
|
self.tuner
|
||||||
.lock()
|
.lock()
|
||||||
.execute_autotune(autotune_operation_set, self);
|
.execute_autotune(autotune_operation_set, self);
|
||||||
|
|
|
@ -4,7 +4,7 @@ use hashbrown::HashMap;
|
||||||
|
|
||||||
/// The compute type has the responsibility to retrieve the correct compute client based on the
|
/// The compute type has the responsibility to retrieve the correct compute client based on the
|
||||||
/// given device.
|
/// given device.
|
||||||
pub struct Compute<Device, Server, Channel> {
|
pub struct Compute<Device, Server: ComputeServer, Channel> {
|
||||||
clients: spin::Mutex<Option<HashMap<Device, ComputeClient<Server, Channel>>>>,
|
clients: spin::Mutex<Option<HashMap<Device, ComputeClient<Server, Channel>>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
|
use core::fmt::Debug;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
memory_management::{MemoryHandle, MemoryManagement},
|
memory_management::{MemoryHandle, MemoryManagement},
|
||||||
storage::ComputeStorage,
|
storage::ComputeStorage,
|
||||||
|
tune::AutotuneKey,
|
||||||
};
|
};
|
||||||
use alloc::vec::Vec;
|
use alloc::vec::Vec;
|
||||||
use burn_common::reader::Reader;
|
use burn_common::reader::Reader;
|
||||||
|
@ -19,6 +22,8 @@ where
|
||||||
type Storage: ComputeStorage;
|
type Storage: ComputeStorage;
|
||||||
/// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type.
|
/// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type.
|
||||||
type MemoryManagement: MemoryManagement<Self::Storage>;
|
type MemoryManagement: MemoryManagement<Self::Storage>;
|
||||||
|
/// The key used to cache operations used on specific inputs in autotune
|
||||||
|
type AutotuneKey: AutotuneKey;
|
||||||
|
|
||||||
/// Given a handle, returns the owned resource as bytes.
|
/// Given a handle, returns the owned resource as bytes.
|
||||||
fn read(&mut self, handle: &Handle<Self>) -> Reader<Vec<u8>>;
|
fn read(&mut self, handle: &Handle<Self>) -> Reader<Vec<u8>>;
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
use alloc::boxed::Box;
|
use alloc::boxed::Box;
|
||||||
use alloc::format;
|
|
||||||
use alloc::string::String;
|
use alloc::string::String;
|
||||||
use alloc::vec::Vec;
|
use alloc::vec::Vec;
|
||||||
use core::fmt::Display;
|
use core::fmt::{Debug, Display};
|
||||||
|
use core::hash::Hash;
|
||||||
|
|
||||||
/// Groups operations of the same type for autotune
|
/// Groups operations of the same type for autotune
|
||||||
pub trait AutotuneOperationSet: Send {
|
pub trait AutotuneOperationSet<K>: Send {
|
||||||
/// The key used in the tune cache
|
/// The key used in the tune cache
|
||||||
fn key(&self) -> AutotuneKey;
|
fn key(&self) -> K;
|
||||||
|
|
||||||
/// All candidate operations for autotuning this operation type
|
/// All candidate operations for autotuning this operation type
|
||||||
/// Operations can run on toy tensors of relevant size
|
/// Operations can run on toy tensors of relevant size
|
||||||
|
@ -32,16 +32,6 @@ pub trait AutotuneOperation {
|
||||||
fn clone(&self) -> Box<dyn AutotuneOperation>;
|
fn clone(&self) -> Box<dyn AutotuneOperation>;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(new, Clone, Debug, PartialEq, Eq, Hash)]
|
/// Trait alias
|
||||||
/// The key used in the tune cache, referring to the operation type,
|
pub trait AutotuneKey: Clone + Debug + PartialEq + Eq + Hash + Display {}
|
||||||
/// generally hardcoded for an autotune operation, and to the input shape
|
impl AutotuneKey for String {}
|
||||||
pub struct AutotuneKey {
|
|
||||||
operation: String,
|
|
||||||
input_description: String,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl Display for AutotuneKey {
|
|
||||||
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
|
||||||
f.write_str(format!("{}-{}", self.operation, self.input_description).as_str())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -7,8 +7,6 @@ use crate::server::ComputeServer;
|
||||||
use super::AutotuneOperation;
|
use super::AutotuneOperation;
|
||||||
use alloc::boxed::Box;
|
use alloc::boxed::Box;
|
||||||
use alloc::string::{String, ToString};
|
use alloc::string::{String, ToString};
|
||||||
use alloc::vec;
|
|
||||||
use alloc::vec::Vec;
|
|
||||||
|
|
||||||
/// A benchmark that runs on server handles
|
/// A benchmark that runs on server handles
|
||||||
#[derive(new)]
|
#[derive(new)]
|
||||||
|
@ -18,20 +16,17 @@ pub struct TuneBenchmark<S: ComputeServer, C> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S: ComputeServer, C: ComputeChannel<S>> Benchmark for TuneBenchmark<S, C> {
|
impl<S: ComputeServer, C: ComputeChannel<S>> Benchmark for TuneBenchmark<S, C> {
|
||||||
// list of operations
|
type Args = Box<dyn AutotuneOperation>;
|
||||||
type Args = Vec<Box<dyn AutotuneOperation>>;
|
|
||||||
|
|
||||||
fn prepare(&self) -> Self::Args {
|
fn prepare(&self) -> Self::Args {
|
||||||
vec![self.operation.clone()]
|
self.operation.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn num_samples(&self) -> usize {
|
fn num_samples(&self) -> usize {
|
||||||
10
|
10
|
||||||
}
|
}
|
||||||
|
|
||||||
fn execute(&self, args: Self::Args) {
|
fn execute(&self, operation: Self::Args) {
|
||||||
let operation = args[0].clone(); // TODO rm 0
|
|
||||||
|
|
||||||
AutotuneOperation::execute(operation);
|
AutotuneOperation::execute(operation);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,40 +1,35 @@
|
||||||
use core::marker::PhantomData;
|
|
||||||
|
|
||||||
use hashbrown::HashMap;
|
|
||||||
|
|
||||||
use super::AutotuneKey;
|
use super::AutotuneKey;
|
||||||
use super::AutotuneOperation;
|
use super::AutotuneOperation;
|
||||||
use super::AutotuneOperationSet;
|
use super::AutotuneOperationSet;
|
||||||
use alloc::boxed::Box;
|
use alloc::boxed::Box;
|
||||||
|
use hashbrown::HashMap;
|
||||||
|
|
||||||
/// Use to find and reuse the best kernel for some input
|
/// Use to find and reuse the best kernel for some input
|
||||||
#[derive(Debug, Default)]
|
#[derive(Debug, Default)]
|
||||||
pub(crate) struct TuneCache<S> {
|
pub(crate) struct TuneCache<K> {
|
||||||
cache: HashMap<AutotuneKey, usize>,
|
cache: HashMap<K, usize>,
|
||||||
_server: PhantomData<S>,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Result of the cache try
|
/// Result of the cache try
|
||||||
pub enum TuneCacheResult {
|
pub enum TuneCacheResult<K> {
|
||||||
/// An operation is found and given
|
/// An operation is found and given
|
||||||
Hit(Box<dyn AutotuneOperation>),
|
Hit(Box<dyn AutotuneOperation>),
|
||||||
/// No operation is found and the set is given back for ownership
|
/// No operation is found and the set is given back for ownership
|
||||||
Miss(Box<dyn AutotuneOperationSet>),
|
Miss(Box<dyn AutotuneOperationSet<K>>),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<S> TuneCache<S> {
|
impl<K: AutotuneKey> TuneCache<K> {
|
||||||
pub(crate) fn new() -> Self {
|
pub(crate) fn new() -> Self {
|
||||||
TuneCache {
|
TuneCache {
|
||||||
cache: HashMap::new(),
|
cache: HashMap::new(),
|
||||||
_server: PhantomData,
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::borrowed_box)]
|
#[allow(clippy::borrowed_box)]
|
||||||
pub(crate) fn try_cache(
|
pub(crate) fn try_cache(
|
||||||
&self,
|
&self,
|
||||||
autotune_operation_set: Box<dyn AutotuneOperationSet>,
|
autotune_operation_set: Box<dyn AutotuneOperationSet<K>>,
|
||||||
) -> TuneCacheResult {
|
) -> TuneCacheResult<K> {
|
||||||
let index = self.cache.get(&autotune_operation_set.key());
|
let index = self.cache.get(&autotune_operation_set.key());
|
||||||
if let Some(&i) = index {
|
if let Some(&i) = index {
|
||||||
return TuneCacheResult::Hit(autotune_operation_set.fastest(i));
|
return TuneCacheResult::Hit(autotune_operation_set.fastest(i));
|
||||||
|
@ -42,7 +37,7 @@ impl<S> TuneCache<S> {
|
||||||
TuneCacheResult::Miss(autotune_operation_set)
|
TuneCacheResult::Miss(autotune_operation_set)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn cache_insert(&mut self, key: AutotuneKey, fastest_index: usize) {
|
pub(crate) fn cache_insert(&mut self, key: K, fastest_index: usize) {
|
||||||
self.cache.insert(key, fastest_index);
|
self.cache.insert(key, fastest_index);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -13,9 +13,8 @@ use crate::tune::{AutotuneOperation, AutotuneOperationSet, TuneBenchmark, TuneCa
|
||||||
|
|
||||||
#[derive(Debug, Default)]
|
#[derive(Debug, Default)]
|
||||||
/// Executes autotune benchmarking and caching
|
/// Executes autotune benchmarking and caching
|
||||||
pub struct Tuner<S, C> {
|
pub struct Tuner<S: ComputeServer, C> {
|
||||||
tune_cache: TuneCache<S>,
|
tune_cache: TuneCache<S::AutotuneKey>,
|
||||||
_server: PhantomData<S>,
|
|
||||||
_channel: PhantomData<C>,
|
_channel: PhantomData<C>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -24,14 +23,13 @@ impl<S: ComputeServer, C: ComputeChannel<S>> Tuner<S, C> {
|
||||||
pub fn new() -> Self {
|
pub fn new() -> Self {
|
||||||
Self {
|
Self {
|
||||||
tune_cache: TuneCache::new(),
|
tune_cache: TuneCache::new(),
|
||||||
_server: PhantomData,
|
|
||||||
_channel: PhantomData,
|
_channel: PhantomData,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) fn execute_autotune(
|
pub(crate) fn execute_autotune(
|
||||||
&mut self,
|
&mut self,
|
||||||
autotune_operation_set: Box<dyn AutotuneOperationSet>,
|
autotune_operation_set: Box<dyn AutotuneOperationSet<S::AutotuneKey>>,
|
||||||
client: &ComputeClient<S, C>,
|
client: &ComputeClient<S, C>,
|
||||||
) {
|
) {
|
||||||
let operation = match self.tune_cache.try_cache(autotune_operation_set) {
|
let operation = match self.tune_cache.try_cache(autotune_operation_set) {
|
||||||
|
@ -44,7 +42,7 @@ impl<S: ComputeServer, C: ComputeChannel<S>> Tuner<S, C> {
|
||||||
|
|
||||||
fn autotuning(
|
fn autotuning(
|
||||||
&mut self,
|
&mut self,
|
||||||
autotune_operation_set: Box<dyn AutotuneOperationSet>,
|
autotune_operation_set: Box<dyn AutotuneOperationSet<S::AutotuneKey>>,
|
||||||
client: &ComputeClient<S, C>,
|
client: &ComputeClient<S, C>,
|
||||||
) -> Box<dyn AutotuneOperation> {
|
) -> Box<dyn AutotuneOperation> {
|
||||||
let key = autotune_operation_set.key();
|
let key = autotune_operation_set.key();
|
||||||
|
|
|
@ -24,6 +24,7 @@ where
|
||||||
type Kernel = Arc<dyn DummyKernel>;
|
type Kernel = Arc<dyn DummyKernel>;
|
||||||
type Storage = BytesStorage;
|
type Storage = BytesStorage;
|
||||||
type MemoryManagement = MM;
|
type MemoryManagement = MM;
|
||||||
|
type AutotuneKey = String;
|
||||||
|
|
||||||
fn read(&mut self, handle: &Handle<Self>) -> Reader<Vec<u8>> {
|
fn read(&mut self, handle: &Handle<Self>) -> Reader<Vec<u8>> {
|
||||||
let bytes = self.memory_management.get(&handle.memory);
|
let bytes = self.memory_management.get(&handle.memory);
|
||||||
|
|
|
@ -2,7 +2,7 @@ use std::sync::Arc;
|
||||||
|
|
||||||
use burn_compute::{
|
use burn_compute::{
|
||||||
server::Handle,
|
server::Handle,
|
||||||
tune::{AutotuneKey, AutotuneOperation, AutotuneOperationSet},
|
tune::{AutotuneOperation, AutotuneOperationSet},
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::dummy::{
|
use crate::dummy::{
|
||||||
|
@ -15,7 +15,7 @@ use super::DummyElementwiseAdditionSlowWrong;
|
||||||
|
|
||||||
pub struct AdditionAutotuneOperationSet {
|
pub struct AdditionAutotuneOperationSet {
|
||||||
client: DummyClient,
|
client: DummyClient,
|
||||||
key: AutotuneKey,
|
key: String,
|
||||||
shapes: Vec<Vec<usize>>,
|
shapes: Vec<Vec<usize>>,
|
||||||
handles: Vec<Handle<DummyServer>>,
|
handles: Vec<Handle<DummyServer>>,
|
||||||
}
|
}
|
||||||
|
@ -28,15 +28,15 @@ impl AdditionAutotuneOperationSet {
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
client,
|
client,
|
||||||
key: AutotuneKey::new("add".to_string(), log_shape_input_key(&shapes)),
|
key: format!("{}-{}", "add", log_shape_input_key(&shapes)),
|
||||||
shapes,
|
shapes,
|
||||||
handles,
|
handles,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AutotuneOperationSet for AdditionAutotuneOperationSet {
|
impl AutotuneOperationSet<String> for AdditionAutotuneOperationSet {
|
||||||
fn key(&self) -> AutotuneKey {
|
fn key(&self) -> String {
|
||||||
self.key.clone()
|
self.key.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ impl AutotuneOperationSet for AdditionAutotuneOperationSet {
|
||||||
|
|
||||||
pub struct MultiplicationAutotuneOperationSet {
|
pub struct MultiplicationAutotuneOperationSet {
|
||||||
client: DummyClient,
|
client: DummyClient,
|
||||||
key: AutotuneKey,
|
key: String,
|
||||||
shapes: Vec<Vec<usize>>,
|
shapes: Vec<Vec<usize>>,
|
||||||
handles: Vec<Handle<DummyServer>>,
|
handles: Vec<Handle<DummyServer>>,
|
||||||
}
|
}
|
||||||
|
@ -77,14 +77,14 @@ impl MultiplicationAutotuneOperationSet {
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
client,
|
client,
|
||||||
key: AutotuneKey::new("mul".to_string(), log_shape_input_key(&shapes)),
|
key: format!("{}-{}", "mul", log_shape_input_key(&shapes)),
|
||||||
shapes,
|
shapes,
|
||||||
handles,
|
handles,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
impl AutotuneOperationSet for MultiplicationAutotuneOperationSet {
|
impl AutotuneOperationSet<String> for MultiplicationAutotuneOperationSet {
|
||||||
fn key(&self) -> AutotuneKey {
|
fn key(&self) -> String {
|
||||||
self.key.clone()
|
self.key.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,7 +112,7 @@ impl AutotuneOperationSet for MultiplicationAutotuneOperationSet {
|
||||||
|
|
||||||
pub struct CacheTestAutotuneOperationSet {
|
pub struct CacheTestAutotuneOperationSet {
|
||||||
client: DummyClient,
|
client: DummyClient,
|
||||||
key: AutotuneKey,
|
key: String,
|
||||||
shapes: Vec<Vec<usize>>,
|
shapes: Vec<Vec<usize>>,
|
||||||
handles: Vec<Handle<DummyServer>>,
|
handles: Vec<Handle<DummyServer>>,
|
||||||
}
|
}
|
||||||
|
@ -125,14 +125,14 @@ impl CacheTestAutotuneOperationSet {
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
client,
|
client,
|
||||||
key: AutotuneKey::new("cache_test".to_string(), log_shape_input_key(&shapes)),
|
key: format!("{}-{}", "cache_test", log_shape_input_key(&shapes)),
|
||||||
shapes,
|
shapes,
|
||||||
handles,
|
handles,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
impl AutotuneOperationSet for CacheTestAutotuneOperationSet {
|
impl AutotuneOperationSet<String> for CacheTestAutotuneOperationSet {
|
||||||
fn key(&self) -> AutotuneKey {
|
fn key(&self) -> String {
|
||||||
self.key.clone()
|
self.key.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,8 +2,10 @@ mod base;
|
||||||
mod kernel;
|
mod kernel;
|
||||||
mod server;
|
mod server;
|
||||||
mod storage;
|
mod storage;
|
||||||
|
mod tune_key;
|
||||||
|
|
||||||
pub use base::*;
|
pub use base::*;
|
||||||
pub use kernel::*;
|
pub use kernel::*;
|
||||||
pub use server::*;
|
pub use server::*;
|
||||||
pub use storage::*;
|
pub use storage::*;
|
||||||
|
pub use tune_key::*;
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
use super::{WgpuStorage, WorkGroup};
|
use super::{WgpuAutotuneKey, WgpuStorage, WorkGroup};
|
||||||
use crate::kernel::SourceTemplate;
|
use crate::kernel::SourceTemplate;
|
||||||
use alloc::{borrow::Cow, sync::Arc};
|
use alloc::{borrow::Cow, sync::Arc};
|
||||||
use burn_compute::{
|
use burn_compute::{
|
||||||
|
@ -254,6 +254,7 @@ where
|
||||||
type Kernel = Box<dyn Kernel>;
|
type Kernel = Box<dyn Kernel>;
|
||||||
type Storage = WgpuStorage;
|
type Storage = WgpuStorage;
|
||||||
type MemoryManagement = MM;
|
type MemoryManagement = MM;
|
||||||
|
type AutotuneKey = WgpuAutotuneKey;
|
||||||
|
|
||||||
fn read(&mut self, handle: &server::Handle<Self>) -> Reader<Vec<u8>> {
|
fn read(&mut self, handle: &server::Handle<Self>) -> Reader<Vec<u8>> {
|
||||||
#[cfg(target_family = "wasm")]
|
#[cfg(target_family = "wasm")]
|
||||||
|
|
|
@ -0,0 +1,22 @@
|
||||||
|
use std::fmt::Display;
|
||||||
|
|
||||||
|
use burn_compute::tune::AutotuneKey;
|
||||||
|
|
||||||
|
use crate::kernel::matmul::MatmulAutotuneKey;
|
||||||
|
|
||||||
|
#[derive(Hash, Eq, PartialEq, Debug, Clone)]
|
||||||
|
/// Key for all autotune-enabled operations
|
||||||
|
pub enum WgpuAutotuneKey {
|
||||||
|
/// Key for matmul operation
|
||||||
|
Matmul(MatmulAutotuneKey),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Display for WgpuAutotuneKey {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
match self {
|
||||||
|
WgpuAutotuneKey::Matmul(matmul_key) => std::fmt::Display::fmt(&matmul_key, f),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Marker impl: `AutotuneKey` declares no methods, it only bundles the
// `Clone + Debug + PartialEq + Eq + Hash + Display` bounds required of a key.
impl AutotuneKey for WgpuAutotuneKey {}
|
|
@ -1,28 +1,27 @@
|
||||||
use burn_compute::tune::{AutotuneKey, AutotuneOperation, AutotuneOperationSet};
|
use burn_compute::tune::{AutotuneOperation, AutotuneOperationSet};
|
||||||
use burn_tensor::Element;
|
use burn_tensor::Element;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
compute::WgpuAutotuneKey,
|
||||||
element::WgpuElement,
|
element::WgpuElement,
|
||||||
kernel::matmul::{tune::utils::autotune_tensors, utils::init_matmul_output},
|
kernel::matmul::{tune::utils::autotune_tensors, utils::init_matmul_output},
|
||||||
tensor::WgpuTensor,
|
tensor::WgpuTensor,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use super::key::MatmulAutotuneKey;
|
||||||
|
|
||||||
/// Set of matmul implementations available for autotune
|
/// Set of matmul implementations available for autotune
|
||||||
/// Autotune key is given by concatenating the closest upper power of 2 of m, k and n
|
/// Autotune key is given by concatenating the closest upper power of 2 of m, k and n
|
||||||
pub struct MatmulAutotuneOperationSet<E: WgpuElement, const D: usize> {
|
pub struct MatmulAutotuneOperationSet<E: WgpuElement, const D: usize> {
|
||||||
key: AutotuneKey,
|
key: WgpuAutotuneKey,
|
||||||
lhs: WgpuTensor<E, D>,
|
lhs: WgpuTensor<E, D>,
|
||||||
rhs: WgpuTensor<E, D>,
|
rhs: WgpuTensor<E, D>,
|
||||||
out: WgpuTensor<E, D>,
|
out: WgpuTensor<E, D>,
|
||||||
}
|
}
|
||||||
impl<E: WgpuElement, const D: usize> MatmulAutotuneOperationSet<E, D> {
|
impl<E: WgpuElement, const D: usize> MatmulAutotuneOperationSet<E, D> {
|
||||||
fn new(lhs: WgpuTensor<E, D>, rhs: WgpuTensor<E, D>, out: WgpuTensor<E, D>) -> Self {
|
fn new(lhs: WgpuTensor<E, D>, rhs: WgpuTensor<E, D>, out: WgpuTensor<E, D>) -> Self {
|
||||||
let m = lhs.shape.dims[D - 2];
|
|
||||||
let k = lhs.shape.dims[D - 1];
|
|
||||||
let n = rhs.shape.dims[D - 1];
|
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
key: AutotuneKey::new("matmul".to_string(), log_mkn_input_key(m, k, n)),
|
key: WgpuAutotuneKey::Matmul(MatmulAutotuneKey::new(&lhs.shape, &rhs.shape)),
|
||||||
lhs,
|
lhs,
|
||||||
rhs,
|
rhs,
|
||||||
out,
|
out,
|
||||||
|
@ -30,34 +29,10 @@ impl<E: WgpuElement, const D: usize> MatmulAutotuneOperationSet<E, D> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn log_mkn_input_key(m: usize, k: usize, n: usize) -> String {
|
impl<E: WgpuElement + Element, const D: usize> AutotuneOperationSet<WgpuAutotuneKey>
|
||||||
let mut desc = String::new();
|
|
||||||
let mut diff = false;
|
|
||||||
|
|
||||||
for size in [m, k, n] {
|
|
||||||
if !desc.is_empty() {
|
|
||||||
desc.push('-');
|
|
||||||
}
|
|
||||||
let exp = f32::ceil(f32::log2(size as f32)) as u32;
|
|
||||||
let updated = 2_u32.pow(exp);
|
|
||||||
|
|
||||||
if updated != size as u32 {
|
|
||||||
diff = true;
|
|
||||||
}
|
|
||||||
desc.push_str(updated.to_string().as_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
if diff {
|
|
||||||
desc.push_str("-uneven");
|
|
||||||
}
|
|
||||||
|
|
||||||
desc
|
|
||||||
}
|
|
||||||
|
|
||||||
impl<E: WgpuElement + Element, const D: usize> AutotuneOperationSet
|
|
||||||
for MatmulAutotuneOperationSet<E, D>
|
for MatmulAutotuneOperationSet<E, D>
|
||||||
{
|
{
|
||||||
fn key(&self) -> AutotuneKey {
|
fn key(&self) -> WgpuAutotuneKey {
|
||||||
self.key.clone()
|
self.key.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -187,23 +162,3 @@ matmul_tune_ops!(
|
||||||
Vec4TilingMatmulUnpaddedDefault,
|
Vec4TilingMatmulUnpaddedDefault,
|
||||||
crate::kernel::matmul::unpadded::matmul_tiling_2d_unpadded
|
crate::kernel::matmul::unpadded::matmul_tiling_2d_unpadded
|
||||||
);
|
);
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn matmul_autotune_mkn_key() {
|
|
||||||
let key = log_mkn_input_key(512, 512, 512);
|
|
||||||
assert_eq!(key, "512-512-512");
|
|
||||||
|
|
||||||
let key = log_mkn_input_key(512, 256, 512);
|
|
||||||
assert_eq!(key, "512-256-512");
|
|
||||||
|
|
||||||
let key = log_mkn_input_key(512, 256, 127);
|
|
||||||
assert_eq!(key, "512-256-128-uneven");
|
|
||||||
|
|
||||||
let key = log_mkn_input_key(2, 149, 2344);
|
|
||||||
assert_eq!(key, "2-256-4096-uneven");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
|
@ -0,0 +1,119 @@
|
||||||
|
use burn_tensor::Shape;
|
||||||
|
use core::fmt::Debug;
|
||||||
|
use std::{
|
||||||
|
cmp::{max, min},
|
||||||
|
fmt::Display,
|
||||||
|
hash::Hash,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Autotune key representative of matmul versions
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct MatmulAutotuneKey {
    /// True when all matmul dims are multiples of 64
    round: bool,
    /// True when there are differences in batch size
    broadcast: bool,
    /// `m` dimension after anchoring (see `MatmulAutotuneKey::new`)
    anchored_m: usize,
    /// `k` dimension after anchoring
    anchored_k: usize,
    /// `n` dimension after anchoring
    anchored_n: usize,
    /// Batch product after anchoring
    anchored_batch: usize,
}
|
||||||
|
|
||||||
|
impl Display for MatmulAutotuneKey {
|
||||||
|
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
|
||||||
|
f.write_str(
|
||||||
|
format!(
|
||||||
|
"Matmul - Round:{:?} Broadcast:{:?} m:{:?} k:{:?} n:{:?} batch:{:?}",
|
||||||
|
self.round,
|
||||||
|
self.broadcast,
|
||||||
|
self.anchored_m,
|
||||||
|
self.anchored_k,
|
||||||
|
self.anchored_n,
|
||||||
|
self.anchored_batch
|
||||||
|
)
|
||||||
|
.as_str(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MatmulAutotuneKey {
|
||||||
|
/// Create a matmul autotune key from the input shapes
|
||||||
|
pub fn new<const D: usize>(lhs_shape: &Shape<D>, rhs_shape: &Shape<D>) -> Self {
|
||||||
|
let m = lhs_shape.dims[D - 2];
|
||||||
|
let k = lhs_shape.dims[D - 1];
|
||||||
|
let n = rhs_shape.dims[D - 1];
|
||||||
|
|
||||||
|
let mut broadcast = false;
|
||||||
|
let mut batch_product_lhs = 1;
|
||||||
|
let mut batch_product_rhs = 1;
|
||||||
|
|
||||||
|
for b in 0..D - 2 {
|
||||||
|
batch_product_lhs *= lhs_shape.dims[b];
|
||||||
|
batch_product_rhs *= rhs_shape.dims[b];
|
||||||
|
if lhs_shape.dims[b] != rhs_shape.dims[b] {
|
||||||
|
broadcast = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
let batch_product = max(batch_product_lhs, batch_product_rhs);
|
||||||
|
|
||||||
|
let round = m % 64 == 0 && k % 64 == 0 && n % 64 == 0;
|
||||||
|
|
||||||
|
Self {
|
||||||
|
round,
|
||||||
|
broadcast,
|
||||||
|
anchored_m: anchor(m, None),
|
||||||
|
anchored_k: anchor(k, None),
|
||||||
|
anchored_n: anchor(n, None),
|
||||||
|
anchored_batch: anchor(batch_product, Some(256)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Rounds `x` up to the nearest power of two, optionally capped at `max`.
///
/// Values that already are powers of two are returned unchanged and `0` maps
/// to `1`. When `max` is `Some(cap)`, the result is the smaller of the
/// power of two and `cap`.
///
/// The rounding uses exact integer arithmetic rather than going through
/// `f32::log2`: `f32` cannot represent every integer above 2^24, so large
/// inputs could previously be rounded *down*, and the `u32` power overflowed
/// for inputs above 2^31.
fn anchor(x: usize, max: Option<usize>) -> usize {
    let power_of_2 = x.next_power_of_two();
    match max {
        Some(cap) => power_of_2.min(cap),
        None => power_of_2,
    }
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Identical, 64-aligned shapes: the key is round and not broadcast.
    #[test]
    fn matmul_autotune_key_all_same_and_round() {
        let lhs_shape: Shape<3> = [4, 512, 512].into();
        let rhs_shape: Shape<3> = [4, 512, 512].into();
        let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape);

        assert!(key.round);
        assert!(!key.broadcast);
        // `assert_eq!` reports both values on failure, unlike `assert!(a == b)`.
        assert_eq!(key.anchored_m, 512);
        assert_eq!(key.anchored_k, 512);
        assert_eq!(key.anchored_n, 512);
        assert_eq!(key.anchored_batch, 4);
    }

    /// Differing batch dims and unaligned sizes: broadcast, not round, and
    /// every dimension is rounded up to the next power of two.
    #[test]
    fn matmul_autotune_key_all_different() {
        let lhs_shape: Shape<4> = [2, 3, 511, 512].into();
        let rhs_shape: Shape<4> = [3, 2, 512, 513].into();
        let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape);

        assert!(!key.round);
        assert!(key.broadcast);
        assert_eq!(key.anchored_m, 512);
        assert_eq!(key.anchored_k, 512);
        assert_eq!(key.anchored_n, 1024);
        assert_eq!(key.anchored_batch, 8);
    }

    /// Very large batch products are capped at 256.
    #[test]
    fn matmul_autotune_key_large_batch() {
        let lhs_shape: Shape<4> = [128, 512, 511, 512].into();
        let rhs_shape: Shape<4> = [200, 400, 512, 513].into();
        let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape);

        assert_eq!(key.anchored_batch, 256);
    }
}
|
|
@ -1,4 +1,6 @@
|
||||||
mod base;
|
mod base;
|
||||||
|
mod key;
|
||||||
mod utils;
|
mod utils;
|
||||||
|
|
||||||
pub use base::*;
|
pub use base::*;
|
||||||
|
pub use key::*;
|
||||||
|
|
Loading…
Reference in New Issue