Refactor/autotune/key (#924)

Louis Fortier-Dubois 2023-11-03 08:46:25 -04:00 committed by GitHub
parent 8c80c9b94a
commit 1cc1844d32
15 changed files with 203 additions and 115 deletions

View File

@@ -12,7 +12,7 @@ use spin::Mutex;
 /// The ComputeClient is the entry point to require tasks from the ComputeServer.
 /// It should be obtained for a specific device via the Compute struct.
 #[derive(Debug)]
-pub struct ComputeClient<Server, Channel> {
+pub struct ComputeClient<Server: ComputeServer, Channel> {
     channel: Channel,
     tuner: Arc<Mutex<Tuner<Server, Channel>>>,
     _server: PhantomData<Server>,
@@ -72,7 +72,10 @@ where
     }

     /// Executes the fastest kernel in the autotune operation, using (cached) runtime benchmarks
-    pub fn execute_autotune(&self, autotune_operation_set: Box<dyn AutotuneOperationSet>) {
+    pub fn execute_autotune(
+        &self,
+        autotune_operation_set: Box<dyn AutotuneOperationSet<Server::AutotuneKey>>,
+    ) {
         self.tuner
             .lock()
             .execute_autotune(autotune_operation_set, self);
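
The heart of the refactor is visible in this hunk: the key is no longer the hardcoded `AutotuneKey` struct but whatever type the server chooses via an associated type. The pattern in miniature, with toy stand-ins for the burn-compute traits (self-contained sketch; none of these names are burn's actual API):

use std::fmt::Display;

// Toy stand-ins showing how the key type flows from the server.
trait OperationSet<K>: Send {
    fn key(&self) -> K;
}

trait Server {
    type AutotuneKey: Clone + Eq + std::hash::Hash + Display;
}

struct MyServer;
impl Server for MyServer {
    type AutotuneKey = String;
}

struct AddSet;
impl OperationSet<String> for AddSet {
    fn key(&self) -> String {
        "add-4-4".to_string()
    }
}

// Mirrors execute_autotune: the set's key type is pinned to the server's.
fn execute_autotune<S: Server>(set: Box<dyn OperationSet<S::AutotuneKey>>) {
    println!("tuning for key {}", set.key());
}

fn main() {
    execute_autotune::<MyServer>(Box::new(AddSet));
}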

View File

@@ -4,7 +4,7 @@ use hashbrown::HashMap;
 /// The compute type has the responsibility to retrieve the correct compute client based on the
 /// given device.
-pub struct Compute<Device, Server, Channel> {
+pub struct Compute<Device, Server: ComputeServer, Channel> {
     clients: spin::Mutex<Option<HashMap<Device, ComputeClient<Server, Channel>>>>,
 }

View File

@@ -1,6 +1,9 @@
+use core::fmt::Debug;
+
 use crate::{
     memory_management::{MemoryHandle, MemoryManagement},
     storage::ComputeStorage,
+    tune::AutotuneKey,
 };
 use alloc::vec::Vec;
 use burn_common::reader::Reader;
@@ -19,6 +22,8 @@ where
     type Storage: ComputeStorage;
     /// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type.
     type MemoryManagement: MemoryManagement<Self::Storage>;
+    /// The key used to cache operations used on specific inputs in autotune
+    type AutotuneKey: AutotuneKey;
     /// Given a handle, returns the owned resource as bytes.
     fn read(&mut self, handle: &Handle<Self>) -> Reader<Vec<u8>>;

View File

@@ -1,13 +1,13 @@
 use alloc::boxed::Box;
-use alloc::format;
 use alloc::string::String;
 use alloc::vec::Vec;
-use core::fmt::Display;
+use core::fmt::{Debug, Display};
+use core::hash::Hash;

 /// Groups operations of the same type for autotune
-pub trait AutotuneOperationSet: Send {
+pub trait AutotuneOperationSet<K>: Send {
     /// The key used in the tune cache
-    fn key(&self) -> AutotuneKey;
+    fn key(&self) -> K;

     /// All candidate operations for autotuning this operation type
     /// Operations can run on toy tensors of relevant size
@@ -32,16 +32,6 @@ pub trait AutotuneOperation {
     fn clone(&self) -> Box<dyn AutotuneOperation>;
 }

-#[derive(new, Clone, Debug, PartialEq, Eq, Hash)]
-/// The key used in the tune cache, referring to the operation type,
-/// generally hardcoded for an autotune operation, and to the input shape
-pub struct AutotuneKey {
-    operation: String,
-    input_description: String,
-}
-
-impl Display for AutotuneKey {
-    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
-        f.write_str(format!("{}-{}", self.operation, self.input_description).as_str())
-    }
-}
+/// Trait alias
+pub trait AutotuneKey: Clone + Debug + PartialEq + Eq + Hash + Display {}
+impl AutotuneKey for String {}
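
Since `AutotuneKey` is now just a trait alias, any type satisfying the listed bounds can serve as a cache key. A minimal custom key might look like the sketch below (`ConvAutotuneKey` and its fields are hypothetical, purely for illustration; the trait is restated so the snippet stands alone):

use core::fmt::Display;
use core::hash::Hash;

// Restatement of the alias from the diff, so this compiles on its own.
trait AutotuneKey: Clone + core::fmt::Debug + PartialEq + Eq + Hash + Display {}

// A hypothetical key for a convolution family (not part of this commit).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct ConvAutotuneKey {
    kernel_size: usize,
    anchored_channels: usize,
}

impl Display for ConvAutotuneKey {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "Conv - k:{} c:{}", self.kernel_size, self.anchored_channels)
    }
}

impl AutotuneKey for ConvAutotuneKey {}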

View File

@@ -7,8 +7,6 @@ use crate::server::ComputeServer;
 use super::AutotuneOperation;
 use alloc::boxed::Box;
 use alloc::string::{String, ToString};
-use alloc::vec;
-use alloc::vec::Vec;

 /// A benchmark that runs on server handles
 #[derive(new)]
@@ -18,20 +16,17 @@ pub struct TuneBenchmark<S: ComputeServer, C> {
 }

 impl<S: ComputeServer, C: ComputeChannel<S>> Benchmark for TuneBenchmark<S, C> {
-    // list of operations
-    type Args = Vec<Box<dyn AutotuneOperation>>;
+    type Args = Box<dyn AutotuneOperation>;

     fn prepare(&self) -> Self::Args {
-        vec![self.operation.clone()]
+        self.operation.clone()
     }

     fn num_samples(&self) -> usize {
         10
     }

-    fn execute(&self, args: Self::Args) {
-        let operation = args[0].clone(); // TODO rm 0
+    fn execute(&self, operation: Self::Args) {
         AutotuneOperation::execute(operation);
     }
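
The `Vec` wrapper (and the `args[0]` TODO) existed only because `Benchmark::Args` was treated as a list; with a single boxed operation, `prepare` hands each sample its own clone and `execute` consumes it. The ownership pattern in miniature (simplified trait, not burn-common's actual `Benchmark`):

// Each sample gets an owned value from prepare(), which execute()
// is then free to consume (illustrative only).
trait MiniBenchmark {
    type Args;
    fn prepare(&self) -> Self::Args;
    fn execute(&self, args: Self::Args);
    fn num_samples(&self) -> usize {
        10
    }
    fn run(&self) {
        for _ in 0..self.num_samples() {
            let args = self.prepare(); // fresh clone per sample
            self.execute(args); // consumed here
        }
    }
}

struct Noop;
impl MiniBenchmark for Noop {
    type Args = Box<dyn Fn()>;
    fn prepare(&self) -> Self::Args {
        Box::new(|| {})
    }
    fn execute(&self, run: Self::Args) {
        run();
    }
}

fn main() {
    Noop.run();
}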

View File

@@ -1,40 +1,35 @@
-use core::marker::PhantomData;
-use hashbrown::HashMap;
-
+use super::AutotuneKey;
 use super::AutotuneOperation;
 use super::AutotuneOperationSet;
 use alloc::boxed::Box;
+use hashbrown::HashMap;

 /// Use to find and reuse the best kernel for some input
 #[derive(Debug, Default)]
-pub(crate) struct TuneCache<S> {
-    cache: HashMap<AutotuneKey, usize>,
-    _server: PhantomData<S>,
+pub(crate) struct TuneCache<K> {
+    cache: HashMap<K, usize>,
 }

 /// Result of the cache try
-pub enum TuneCacheResult {
+pub enum TuneCacheResult<K> {
     /// An operation is found and given
     Hit(Box<dyn AutotuneOperation>),
     /// No operation is found and the set is given back for ownership
-    Miss(Box<dyn AutotuneOperationSet>),
+    Miss(Box<dyn AutotuneOperationSet<K>>),
 }

-impl<S> TuneCache<S> {
+impl<K: AutotuneKey> TuneCache<K> {
     pub(crate) fn new() -> Self {
         TuneCache {
             cache: HashMap::new(),
-            _server: PhantomData,
         }
     }

     #[allow(clippy::borrowed_box)]
     pub(crate) fn try_cache(
         &self,
-        autotune_operation_set: Box<dyn AutotuneOperationSet>,
-    ) -> TuneCacheResult {
+        autotune_operation_set: Box<dyn AutotuneOperationSet<K>>,
+    ) -> TuneCacheResult<K> {
         let index = self.cache.get(&autotune_operation_set.key());
         if let Some(&i) = index {
             return TuneCacheResult::Hit(autotune_operation_set.fastest(i));
@@ -42,7 +37,7 @@ impl<S> TuneCache<S> {
         TuneCacheResult::Miss(autotune_operation_set)
     }

-    pub(crate) fn cache_insert(&mut self, key: AutotuneKey, fastest_index: usize) {
+    pub(crate) fn cache_insert(&mut self, key: K, fastest_index: usize) {
         self.cache.insert(key, fastest_index);
     }
 }
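
The cache is now a plain `HashMap<K, usize>` mapping a key to the index of the fastest operation in its set. The hit/miss flow, reduced to standard-library types (illustrative sketch; `TuneCache` itself is crate-private, and the key string here is made up):

use std::collections::HashMap;

fn main() {
    // key -> index of the fastest candidate for inputs matching that key
    let mut cache: HashMap<String, usize> = HashMap::new();

    // Miss: nothing cached yet, so the caller benchmarks all candidates
    // and records the winner, as in cache_insert(key, fastest_index).
    let key = "matmul-512-512-512".to_string();
    assert!(cache.get(&key).is_none());
    cache.insert(key.clone(), 2);

    // Hit: later lookups return the stored index directly, mirroring
    // TuneCacheResult::Hit(set.fastest(i)).
    assert_eq!(cache.get(&key), Some(&2));
}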

View File

@@ -13,9 +13,8 @@ use crate::tune::{AutotuneOperation, AutotuneOperationSet, TuneBenchmark, TuneCache};
 #[derive(Debug, Default)]
 /// Executes autotune benchmarking and caching
-pub struct Tuner<S, C> {
-    tune_cache: TuneCache<S>,
-    _server: PhantomData<S>,
+pub struct Tuner<S: ComputeServer, C> {
+    tune_cache: TuneCache<S::AutotuneKey>,
     _channel: PhantomData<C>,
 }
@@ -24,14 +23,13 @@ impl<S: ComputeServer, C: ComputeChannel<S>> Tuner<S, C> {
     pub fn new() -> Self {
         Self {
             tune_cache: TuneCache::new(),
-            _server: PhantomData,
             _channel: PhantomData,
         }
     }

     pub(crate) fn execute_autotune(
         &mut self,
-        autotune_operation_set: Box<dyn AutotuneOperationSet>,
+        autotune_operation_set: Box<dyn AutotuneOperationSet<S::AutotuneKey>>,
         client: &ComputeClient<S, C>,
     ) {
         let operation = match self.tune_cache.try_cache(autotune_operation_set) {
@@ -44,7 +42,7 @@ impl<S: ComputeServer, C: ComputeChannel<S>> Tuner<S, C> {
     fn autotuning(
         &mut self,
-        autotune_operation_set: Box<dyn AutotuneOperationSet>,
+        autotune_operation_set: Box<dyn AutotuneOperationSet<S::AutotuneKey>>,
         client: &ComputeClient<S, C>,
     ) -> Box<dyn AutotuneOperation> {
         let key = autotune_operation_set.key();
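
On a miss, `autotuning` benchmarks every candidate and caches the index of the fastest one under the key. The selection step reduces to an argmin over measured durations, roughly as below (a self-contained sketch of the idea, not the commit's exact code):

use std::time::Duration;

// Pick the index of the candidate with the smallest median runtime,
// which is what gets handed to cache_insert.
fn fastest_index(median_durations: &[Duration]) -> usize {
    median_durations
        .iter()
        .enumerate()
        .min_by_key(|(_, duration)| **duration)
        .map(|(index, _)| index)
        .expect("autotune needs at least one candidate operation")
}

fn main() {
    let medians = [
        Duration::from_micros(120),
        Duration::from_micros(80),
        Duration::from_micros(95),
    ];
    assert_eq!(fastest_index(&medians), 1);
}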

View File

@@ -24,6 +24,7 @@ where
     type Kernel = Arc<dyn DummyKernel>;
     type Storage = BytesStorage;
     type MemoryManagement = MM;
+    type AutotuneKey = String;

     fn read(&mut self, handle: &Handle<Self>) -> Reader<Vec<u8>> {
         let bytes = self.memory_management.get(&handle.memory);

View File

@@ -2,7 +2,7 @@ use std::sync::Arc;
 use burn_compute::{
     server::Handle,
-    tune::{AutotuneKey, AutotuneOperation, AutotuneOperationSet},
+    tune::{AutotuneOperation, AutotuneOperationSet},
 };

 use crate::dummy::{
@@ -15,7 +15,7 @@ use super::DummyElementwiseAdditionSlowWrong;
 pub struct AdditionAutotuneOperationSet {
     client: DummyClient,
-    key: AutotuneKey,
+    key: String,
     shapes: Vec<Vec<usize>>,
     handles: Vec<Handle<DummyServer>>,
 }
@@ -28,15 +28,15 @@ impl AdditionAutotuneOperationSet {
     ) -> Self {
         Self {
             client,
-            key: AutotuneKey::new("add".to_string(), log_shape_input_key(&shapes)),
+            key: format!("{}-{}", "add", log_shape_input_key(&shapes)),
             shapes,
             handles,
         }
     }
 }

-impl AutotuneOperationSet for AdditionAutotuneOperationSet {
-    fn key(&self) -> AutotuneKey {
+impl AutotuneOperationSet<String> for AdditionAutotuneOperationSet {
+    fn key(&self) -> String {
         self.key.clone()
     }
@@ -64,7 +64,7 @@ impl AutotuneOperationSet for AdditionAutotuneOperationSet {
 pub struct MultiplicationAutotuneOperationSet {
     client: DummyClient,
-    key: AutotuneKey,
+    key: String,
     shapes: Vec<Vec<usize>>,
     handles: Vec<Handle<DummyServer>>,
 }
@@ -77,14 +77,14 @@ impl MultiplicationAutotuneOperationSet {
     ) -> Self {
         Self {
             client,
-            key: AutotuneKey::new("mul".to_string(), log_shape_input_key(&shapes)),
+            key: format!("{}-{}", "mul", log_shape_input_key(&shapes)),
             shapes,
             handles,
         }
     }
 }

-impl AutotuneOperationSet for MultiplicationAutotuneOperationSet {
-    fn key(&self) -> AutotuneKey {
+impl AutotuneOperationSet<String> for MultiplicationAutotuneOperationSet {
+    fn key(&self) -> String {
         self.key.clone()
     }
@@ -112,7 +112,7 @@ impl AutotuneOperationSet for MultiplicationAutotuneOperationSet {
 pub struct CacheTestAutotuneOperationSet {
     client: DummyClient,
-    key: AutotuneKey,
+    key: String,
     shapes: Vec<Vec<usize>>,
     handles: Vec<Handle<DummyServer>>,
 }
@@ -125,14 +125,14 @@ impl CacheTestAutotuneOperationSet {
     ) -> Self {
         Self {
             client,
-            key: AutotuneKey::new("cache_test".to_string(), log_shape_input_key(&shapes)),
+            key: format!("{}-{}", "cache_test", log_shape_input_key(&shapes)),
             shapes,
             handles,
         }
     }
 }

-impl AutotuneOperationSet for CacheTestAutotuneOperationSet {
-    fn key(&self) -> AutotuneKey {
+impl AutotuneOperationSet<String> for CacheTestAutotuneOperationSet {
+    fn key(&self) -> String {
         self.key.clone()
     }

View File

@@ -2,8 +2,10 @@ mod base;
 mod kernel;
 mod server;
 mod storage;
+mod tune_key;

 pub use base::*;
 pub use kernel::*;
 pub use server::*;
 pub use storage::*;
+pub use tune_key::*;

View File

@@ -1,4 +1,4 @@
-use super::{WgpuStorage, WorkGroup};
+use super::{WgpuAutotuneKey, WgpuStorage, WorkGroup};
 use crate::kernel::SourceTemplate;
 use alloc::{borrow::Cow, sync::Arc};
 use burn_compute::{
@@ -254,6 +254,7 @@ where
     type Kernel = Box<dyn Kernel>;
     type Storage = WgpuStorage;
     type MemoryManagement = MM;
+    type AutotuneKey = WgpuAutotuneKey;

     fn read(&mut self, handle: &server::Handle<Self>) -> Reader<Vec<u8>> {
         #[cfg(target_family = "wasm")]

View File

@@ -0,0 +1,22 @@
+use std::fmt::Display;
+
+use burn_compute::tune::AutotuneKey;
+
+use crate::kernel::matmul::MatmulAutotuneKey;
+
+#[derive(Hash, Eq, PartialEq, Debug, Clone)]
+/// Key for all autotune-enabled operations
+pub enum WgpuAutotuneKey {
+    /// Key for matmul operation
+    Matmul(MatmulAutotuneKey),
+}
+
+impl Display for WgpuAutotuneKey {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            WgpuAutotuneKey::Matmul(matmul_key) => std::fmt::Display::fmt(&matmul_key, f),
+        }
+    }
+}
+
+impl AutotuneKey for WgpuAutotuneKey {}
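
Adding autotune support for another wgpu operation family would mean giving this enum a new variant plus a matching `Display` arm. A sketch of what that could look like (`ReduceAutotuneKey` is hypothetical, not part of this commit):

use std::fmt::Display;

// Hypothetical key for a second autotuned operation family,
// shown only to illustrate how the enum would grow.
#[derive(Hash, Eq, PartialEq, Debug, Clone)]
pub struct ReduceAutotuneKey {
    pub anchored_len: usize,
}

impl Display for ReduceAutotuneKey {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "Reduce - len:{}", self.anchored_len)
    }
}

// WgpuAutotuneKey would then gain a `Reduce(ReduceAutotuneKey)` variant,
// and its Display impl a matching arm:
//     WgpuAutotuneKey::Reduce(key) => std::fmt::Display::fmt(&key, f),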

View File

@@ -1,28 +1,27 @@
-use burn_compute::tune::{AutotuneKey, AutotuneOperation, AutotuneOperationSet};
+use burn_compute::tune::{AutotuneOperation, AutotuneOperationSet};
 use burn_tensor::Element;

 use crate::{
+    compute::WgpuAutotuneKey,
     element::WgpuElement,
    kernel::matmul::{tune::utils::autotune_tensors, utils::init_matmul_output},
     tensor::WgpuTensor,
 };

+use super::key::MatmulAutotuneKey;
+
 /// Set of matmul implementations available for autotune
-/// Autotune key is given by concatenating the closest upper power of 2 of m, k and n
 pub struct MatmulAutotuneOperationSet<E: WgpuElement, const D: usize> {
-    key: AutotuneKey,
+    key: WgpuAutotuneKey,
     lhs: WgpuTensor<E, D>,
     rhs: WgpuTensor<E, D>,
     out: WgpuTensor<E, D>,
 }

 impl<E: WgpuElement, const D: usize> MatmulAutotuneOperationSet<E, D> {
     fn new(lhs: WgpuTensor<E, D>, rhs: WgpuTensor<E, D>, out: WgpuTensor<E, D>) -> Self {
-        let m = lhs.shape.dims[D - 2];
-        let k = lhs.shape.dims[D - 1];
-        let n = rhs.shape.dims[D - 1];
-
         Self {
-            key: AutotuneKey::new("matmul".to_string(), log_mkn_input_key(m, k, n)),
+            key: WgpuAutotuneKey::Matmul(MatmulAutotuneKey::new(&lhs.shape, &rhs.shape)),
             lhs,
             rhs,
             out,
@@ -30,34 +29,10 @@ impl<E: WgpuElement, const D: usize> MatmulAutotuneOperationSet<E, D> {
     }
 }

-fn log_mkn_input_key(m: usize, k: usize, n: usize) -> String {
-    let mut desc = String::new();
-    let mut diff = false;
-
-    for size in [m, k, n] {
-        if !desc.is_empty() {
-            desc.push('-');
-        }
-
-        let exp = f32::ceil(f32::log2(size as f32)) as u32;
-        let updated = 2_u32.pow(exp);
-        if updated != size as u32 {
-            diff = true;
-        }
-
-        desc.push_str(updated.to_string().as_str());
-    }
-
-    if diff {
-        desc.push_str("-uneven");
-    }
-
-    desc
-}
-
-impl<E: WgpuElement + Element, const D: usize> AutotuneOperationSet
+impl<E: WgpuElement + Element, const D: usize> AutotuneOperationSet<WgpuAutotuneKey>
     for MatmulAutotuneOperationSet<E, D>
 {
-    fn key(&self) -> AutotuneKey {
+    fn key(&self) -> WgpuAutotuneKey {
         self.key.clone()
     }
@@ -187,23 +162,3 @@ matmul_tune_ops!(
     Vec4TilingMatmulUnpaddedDefault,
     crate::kernel::matmul::unpadded::matmul_tiling_2d_unpadded
 );
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-
-    #[test]
-    fn matmul_autotune_mkn_key() {
-        let key = log_mkn_input_key(512, 512, 512);
-        assert_eq!(key, "512-512-512");
-
-        let key = log_mkn_input_key(512, 256, 512);
-        assert_eq!(key, "512-256-512");
-
-        let key = log_mkn_input_key(512, 256, 127);
-        assert_eq!(key, "512-256-128-uneven");
-
-        let key = log_mkn_input_key(2, 149, 2344);
-        assert_eq!(key, "2-256-4096-uneven");
-    }
-}
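
For the same operands, the old and new keys differ visibly: where `log_mkn_input_key(512, 256, 127)` produced the string `512-256-128-uneven`, the structured key for a 512x256 by 256x127 matmul would display as `Matmul - Round:false Broadcast:false m:512 k:256 n:128 batch:1`. These values follow from the `anchor` logic and `Display` impl in the new `key.rs` shown below; the rendered string is a worked example, not output copied from a test.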

View File

@@ -0,0 +1,119 @@
+use burn_tensor::Shape;
+use core::fmt::Debug;
+use std::{
+    cmp::{max, min},
+    fmt::Display,
+    hash::Hash,
+};
+
+#[derive(Hash, Eq, PartialEq, Debug, Clone)]
+/// Autotune key representative of matmul versions
+pub struct MatmulAutotuneKey {
+    round: bool,     // True when all matmul dims are multiples of 64
+    broadcast: bool, // True when there are differences in batch size
+    anchored_m: usize,
+    anchored_k: usize,
+    anchored_n: usize,
+    anchored_batch: usize,
+}
+
+impl Display for MatmulAutotuneKey {
+    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
+        f.write_str(
+            format!(
+                "Matmul - Round:{:?} Broadcast:{:?} m:{:?} k:{:?} n:{:?} batch:{:?}",
+                self.round,
+                self.broadcast,
+                self.anchored_m,
+                self.anchored_k,
+                self.anchored_n,
+                self.anchored_batch
+            )
+            .as_str(),
+        )
+    }
+}
+
+impl MatmulAutotuneKey {
+    /// Create a matmul autotune key from the input shapes
+    pub fn new<const D: usize>(lhs_shape: &Shape<D>, rhs_shape: &Shape<D>) -> Self {
+        let m = lhs_shape.dims[D - 2];
+        let k = lhs_shape.dims[D - 1];
+        let n = rhs_shape.dims[D - 1];
+
+        let mut broadcast = false;
+        let mut batch_product_lhs = 1;
+        let mut batch_product_rhs = 1;
+
+        for b in 0..D - 2 {
+            batch_product_lhs *= lhs_shape.dims[b];
+            batch_product_rhs *= rhs_shape.dims[b];
+            if lhs_shape.dims[b] != rhs_shape.dims[b] {
+                broadcast = true;
+            }
+        }
+        let batch_product = max(batch_product_lhs, batch_product_rhs);
+
+        let round = m % 64 == 0 && k % 64 == 0 && n % 64 == 0;
+
+        Self {
+            round,
+            broadcast,
+            anchored_m: anchor(m, None),
+            anchored_k: anchor(k, None),
+            anchored_n: anchor(n, None),
+            anchored_batch: anchor(batch_product, Some(256)),
+        }
+    }
+}
+
+fn anchor(x: usize, max: Option<usize>) -> usize {
+    let exp = f32::ceil(f32::log2(x as f32)) as u32;
+    let power_of_2 = 2_u32.pow(exp) as usize;
+    if let Some(max) = max {
+        min(power_of_2, max)
+    } else {
+        power_of_2
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn matmul_autotune_key_all_same_and_round() {
+        let lhs_shape: Shape<3> = [4, 512, 512].into();
+        let rhs_shape: Shape<3> = [4, 512, 512].into();
+
+        let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape);
+
+        assert!(key.round);
+        assert!(!key.broadcast);
+        assert!(key.anchored_m == 512);
+        assert!(key.anchored_k == 512);
+        assert!(key.anchored_n == 512);
+    }
+
+    #[test]
+    fn matmul_autotune_key_all_different() {
+        let lhs_shape: Shape<4> = [2, 3, 511, 512].into();
+        let rhs_shape: Shape<4> = [3, 2, 512, 513].into();
+
+        let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape);
+
+        assert!(!key.round);
+        assert!(key.broadcast);
+        assert!(key.anchored_m == 512);
+        assert!(key.anchored_k == 512);
+        assert!(key.anchored_n == 1024);
+        assert!(key.anchored_batch == 8);
+    }
+
+    #[test]
+    fn matmul_autotune_key_large_batch() {
+        let lhs_shape: Shape<4> = [128, 512, 511, 512].into();
+        let rhs_shape: Shape<4> = [200, 400, 512, 513].into();
+
+        let key = MatmulAutotuneKey::new(&lhs_shape, &rhs_shape);
+
+        assert!(key.anchored_batch == 256);
+    }
+}
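
To make the anchoring concrete, here is a standalone restatement of the `anchor` function above with a few worked values; the function body is copied from the diff, the assertions are illustrative:

// Anchor rounds a dimension up to the next power of two, optionally
// clamped, so nearby shapes share one cache entry.
fn anchor(x: usize, max: Option<usize>) -> usize {
    let exp = f32::ceil(f32::log2(x as f32)) as u32;
    let power_of_2 = 2_u32.pow(exp) as usize;
    match max {
        Some(max) => power_of_2.min(max),
        None => power_of_2,
    }
}

fn main() {
    assert_eq!(anchor(512, None), 512); // already a power of two
    assert_eq!(anchor(149, None), 256); // rounded up to the next power of two
    assert_eq!(anchor(300, Some(256)), 256); // clamped by the optional max
}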

View File

@@ -1,4 +1,6 @@
 mod base;
+mod key;
 mod utils;

 pub use base::*;
+pub use key::*;