Refactor/burn benchmark (#829)

This commit is contained in:
Louis Fortier-Dubois 2023-09-28 09:38:21 -04:00 committed by GitHub
parent d06cc2f239
commit aa90fe8efb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 467 additions and 351 deletions

View File

@ -23,6 +23,7 @@ members = [
"burn-train",
"xtask",
"examples/*",
"backend-comparison",
]
exclude = ["examples/notebook"]

View File

@ -0,0 +1,48 @@
[package]
authors = ["louisfd <louisfd94@gmail.com>"]
categories = ["science"]
description = "This crate is used to time the execution of various computations, from operation kernels to complex model scenarios."
edition = "2021"
license = "MIT OR Apache-2.0"
name = "backend-comparison"
readme = "README.md"
repository = "https://github.com/burn-rs/burn/tree/main/backend-comparison"
version = "0.10.0"
[features]
default = ["std"]
std = []
ndarray = ["burn/ndarray"]
ndarray-blas-accelerate = ["burn/ndarray-blas-accelerate"]
ndarray-blas-netlib = ["burn/ndarray-blas-netlib"]
ndarray-blas-openblas = ["burn/ndarray-blas-openblas"]
tch-cpu = ["burn/tch"]
tch-gpu = ["burn/tch"]
wgpu = ["burn/wgpu"]
[dependencies]
burn = { path = "../burn" }
derive-new = { workspace = true }
rand = { workspace = true }
burn-tensor = { path = "../burn-tensor", version = "0.10.0", features = [
"benchmark",
] }
[dev-dependencies]
[[bench]]
name = "unary"
harness = false
[[bench]]
name = "binary"
harness = false
[[bench]]
name = "matmul"
harness = false
[[bench]]
name = "data"
harness = false

View File

@ -0,0 +1,8 @@
# Burn Benchmark
This crate is used with `cargo bench --features <backend>`
to compare backend computation times, from tensor operations to complex models.
Note: to compare different backend-specific implementations of a tensor
operation (for autotuning purposes, for instance), write the benchmark
within the corresponding backend crate instead.

View File

@ -0,0 +1,51 @@
use std::marker::PhantomData;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_tensor::benchmark::{run_benchmark, Benchmark};
/// Benchmark description for an element-wise binary tensor operation on backend `B`.
pub struct BinaryBenchmark<B: Backend, const D: usize> {
// Shape given to both randomly generated operands in `prepare`.
shape: Shape<D>,
// Number of times the operation is issued inside a single `execute` call.
num_repeats: usize,
// Zero-sized marker tying the benchmark to backend `B`.
backend: PhantomData<B>,
}
impl<B: Backend, const D: usize> Benchmark<B> for BinaryBenchmark<B, D> {
    /// Left-hand and right-hand operands, both created with `self.shape`.
    type Args = (Tensor<B, D>, Tensor<B, D>);

    /// Label printed next to the benchmark results.
    fn name(&self) -> String {
        "Binary Ops".into()
    }

    /// Issues the binary operation `num_repeats` times on clones of the operands.
    fn execute(&self, (lhs, rhs): Self::Args) {
        for _ in 0..self.num_repeats {
            // Choice of add is arbitrary
            B::add(lhs.clone().into_primitive(), rhs.clone().into_primitive());
        }
    }

    /// Creates the two random operands on `device`.
    fn prepare(&self, device: &B::Device) -> Self::Args {
        // Generate directly on the target device, like the other benchmarks of this
        // crate, instead of creating the tensors elsewhere and moving them afterwards.
        let lhs = Tensor::random_device(self.shape.clone(), Distribution::Default, device);
        let rhs = Tensor::random_device(self.shape.clone(), Distribution::Default, device);
        (lhs, rhs)
    }
}
/// Runs the binary-operation benchmark on `device` with a `[32, 512, 1024]`
/// tensor shape and 10 repetitions per sample.
// `bench` is only referenced by the expansion of `bench_on_backend!`, which is
// empty when no backend feature is enabled.
#[allow(dead_code)]
fn bench<B: Backend>(device: &B::Device) {
    const D: usize = 3;

    run_benchmark(
        BinaryBenchmark::<B, D> {
            shape: [32, 512, 1024].into(),
            num_repeats: 10,
            backend: PhantomData,
        },
        device,
    )
}
// Entry point of the `binary` bench target: the macro calls `bench` once for
// every backend whose Cargo feature is enabled.
fn main() {
backend_comparison::bench_on_backend!();
}

View File

@ -0,0 +1,79 @@
use std::marker::PhantomData;
use burn::tensor::{backend::Backend, Data, Distribution, Shape, Tensor};
use burn_tensor::benchmark::{run_benchmark, Benchmark};
use derive_new::new;
/// Benchmark description for reading a tensor back with `to_data` on backend `B`.
#[derive(new)]
struct ToDataBenchmark<B: Backend, const D: usize> {
// Shape of the random tensor created in `prepare`.
shape: Shape<D>,
// Number of `to_data` calls issued inside a single `execute` call.
num_repeats: usize,
// Zero-sized marker tying the benchmark to backend `B`.
backend: PhantomData<B>,
}
impl<B: Backend, const D: usize> Benchmark<B> for ToDataBenchmark<B, D> {
    /// The tensor whose content is read back.
    type Args = Tensor<B, D>;

    /// Label made of the tensor dimensions and the repetition count.
    fn name(&self) -> String {
        let dims = &self.shape.dims;
        let repeats = self.num_repeats;
        format!("to-data-{:?}-{}", dims, repeats)
    }

    /// Calls `to_data` on the tensor `num_repeats` times, discarding each result.
    fn execute(&self, args: Self::Args) {
        (0..self.num_repeats).for_each(|_| {
            let _data = args.to_data();
        });
    }

    /// Creates a random tensor of shape `self.shape` on `device`.
    fn prepare(&self, device: &B::Device) -> Self::Args {
        let shape = self.shape.clone();
        Tensor::random_device(shape, Distribution::Default, device)
    }
}
/// Benchmark description for building a tensor with `from_data_device` on backend `B`.
#[derive(new)]
struct FromDataBenchmark<B: Backend, const D: usize> {
// Shape of the random `Data` created in `prepare`.
shape: Shape<D>,
// Number of tensors built inside a single `execute` call.
num_repeats: usize,
// Zero-sized marker tying the benchmark to backend `B`.
backend: PhantomData<B>,
}
impl<B: Backend, const D: usize> Benchmark<B> for FromDataBenchmark<B, D> {
    /// The source data together with the device the tensors are created on.
    type Args = (Data<B::FloatElem, D>, B::Device);

    /// Label made of the data dimensions and the repetition count.
    fn name(&self) -> String {
        let dims = &self.shape.dims;
        let repeats = self.num_repeats;
        format!("from-data-{:?}-{}", dims, repeats)
    }

    /// Builds a tensor from a clone of the data `num_repeats` times, discarding each result.
    fn execute(&self, (data, device): Self::Args) {
        (0..self.num_repeats).for_each(|_| {
            let _data = Tensor::<B, D>::from_data_device(data.clone(), &device);
        });
    }

    /// Generates random data of shape `self.shape` and pairs it with a clone of `device`.
    fn prepare(&self, device: &B::Device) -> Self::Args {
        let mut rng = rand::thread_rng();
        let data = Data::random(self.shape.clone(), Distribution::Default, &mut rng);
        (data, device.clone())
    }
}
/// Runs the `to_data` benchmark, then the `from_data` benchmark, on `device`,
/// both with a `[32, 512, 1024]` shape and 10 repetitions per sample.
// `bench` is only referenced by the expansion of `bench_on_backend!`, which is
// empty when no backend feature is enabled.
#[allow(dead_code)]
fn bench<B: Backend>(device: &B::Device) {
    const D: usize = 3;
    const NUM_REPEATS: usize = 10;

    let dims: Shape<D> = [32, 512, 1024].into();

    run_benchmark(
        ToDataBenchmark::<B, D>::new(dims.clone(), NUM_REPEATS),
        device,
    );
    run_benchmark(FromDataBenchmark::<B, D>::new(dims, NUM_REPEATS), device)
}
// Entry point of the `data` bench target: the macro calls `bench` once for
// every backend whose Cargo feature is enabled.
fn main() {
backend_comparison::bench_on_backend!();
}

View File

@ -0,0 +1,59 @@
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_tensor::benchmark::{run_benchmark, Benchmark};
use derive_new::new;
use std::marker::PhantomData;
/// Benchmark description for a matrix multiplication on backend `B`.
#[derive(new)]
struct MatmulBenchmark<B, const D: usize> {
// Shape of the random left-hand operand created in `prepare`.
shape_lhs: Shape<D>,
// Shape of the random right-hand operand created in `prepare`.
shape_rhs: Shape<D>,
// Number of multiplications issued inside a single `execute` call.
num_repeats: usize,
// Zero-sized marker tying the benchmark to backend `B`.
backend: PhantomData<B>,
}
impl<B: Backend, const D: usize> Benchmark<B> for MatmulBenchmark<B, D> {
type Args = (Tensor<B, D>, Tensor<B, D>);
fn name(&self) -> String {
format!(
"Matmul {:?} x {:?}",
self.shape_lhs.dims, self.shape_rhs.dims
)
}
fn num_samples(&self) -> usize {
10
}
fn execute(&self, (lhs, rhs): Self::Args) {
for _ in 0..self.num_repeats {
lhs.clone().matmul(rhs.clone());
}
}
fn prepare(&self, device: &B::Device) -> Self::Args {
let lhs = Tensor::random_device(self.shape_lhs.clone(), Distribution::Default, device);
let rhs = Tensor::random_device(self.shape_rhs.clone(), Distribution::Default, device);
(lhs, rhs)
}
}
#[allow(dead_code)]
fn bench<B: Backend>(device: &B::Device) {
const D: usize = 3;
let num_repeats = 3;
let batch_size = 3;
let m = 1024;
let k = 2048;
let n = 1024;
let shape_lhs = [batch_size, m, k].into();
let shape_rhs = [batch_size, k, n].into();
let benchmark = MatmulBenchmark::<B, D>::new(shape_lhs, shape_rhs, num_repeats);
run_benchmark(benchmark, device);
}
// Entry point of the `matmul` bench target: the macro calls `bench` once for
// every backend whose Cargo feature is enabled.
fn main() {
backend_comparison::bench_on_backend!();
}

View File

@ -0,0 +1,46 @@
use std::marker::PhantomData;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_tensor::benchmark::{run_benchmark, Benchmark};
use derive_new::new;
/// Benchmark description for an element-wise unary tensor operation on backend `B`.
#[derive(new)]
struct UnaryBenchmark<B: Backend, const D: usize> {
// Shape of the random tensor created in `prepare`.
shape: Shape<D>,
// Number of times the operation is issued inside a single `execute` call.
num_repeats: usize,
// Zero-sized marker tying the benchmark to backend `B`.
backend: PhantomData<B>,
}
impl<B: Backend, const D: usize> Benchmark<B> for UnaryBenchmark<B, D> {
    /// The tensor the unary operation is applied to.
    type Args = Tensor<B, D>;

    /// Label printed next to the benchmark results.
    fn name(&self) -> String {
        String::from("Unary Ops")
    }

    /// Applies the unary operation `num_repeats` times on clones of the tensor.
    fn execute(&self, args: Self::Args) {
        (0..self.num_repeats).for_each(|_| {
            // Choice of tanh is arbitrary
            B::tanh(args.clone().into_primitive());
        });
    }

    /// Creates a random tensor of shape `self.shape` on `device`.
    fn prepare(&self, device: &B::Device) -> Self::Args {
        let shape = self.shape.clone();
        Tensor::random_device(shape, Distribution::Default, device)
    }
}
/// Runs the unary-operation benchmark on `device` with a `[32, 512, 1024]`
/// tensor shape and 10 repetitions per sample.
// `bench` is only referenced by the expansion of `bench_on_backend!`, which is
// empty when no backend feature is enabled.
#[allow(dead_code)]
fn bench<B: Backend>(device: &B::Device) {
    const D: usize = 3;

    run_benchmark(
        UnaryBenchmark::<B, D>::new([32, 512, 1024].into(), 10),
        device,
    )
}
// Entry point of the `unary` bench target: the macro calls `bench` once for
// every backend whose Cargo feature is enabled.
fn main() {
backend_comparison::bench_on_backend!();
}

View File

@ -0,0 +1,44 @@
/// Calls the caller's `bench` function once for every backend selected through
/// Cargo features.
///
/// The macro expands to statements, so it has to be invoked inside a function
/// body. It refers to `bench` by name: a generic function
/// `bench::<B>(device: &B::Device)` must be in scope at the call site.
///
/// Backend selection, by feature of the crate in which the macro is expanded:
/// - `wgpu`: `WgpuBackend<AutoGraphicsApi, f32, i32>` on the default wgpu device.
/// - `tch-gpu`: `TchBackend` on `Cuda(0)`, or on `Mps` when targeting macOS.
/// - `tch-cpu`: `TchBackend` on the CPU.
/// - `ndarray` and its `ndarray-blas-*` variants: `NdArrayBackend` on the CPU.
///
/// Several features may be enabled at once; each matching backend is run in the
/// order above. When none is enabled, a hint is printed on stderr instead of
/// silently doing nothing.
#[macro_export]
macro_rules! bench_on_backend {
    () => {
        #[cfg(feature = "wgpu")]
        {
            use burn::backend::wgpu::{AutoGraphicsApi, WgpuBackend, WgpuDevice};

            bench::<WgpuBackend<AutoGraphicsApi, f32, i32>>(&WgpuDevice::default());
        }

        #[cfg(feature = "tch-gpu")]
        {
            use burn::backend::{tch::TchDevice, TchBackend};

            #[cfg(not(target_os = "macos"))]
            let device = TchDevice::Cuda(0);
            #[cfg(target_os = "macos")]
            let device = TchDevice::Mps;

            bench::<TchBackend>(&device);
        }

        #[cfg(feature = "tch-cpu")]
        {
            use burn::backend::{tch::TchDevice, TchBackend};

            let device = TchDevice::Cpu;
            bench::<TchBackend>(&device);
        }

        #[cfg(any(
            feature = "ndarray",
            feature = "ndarray-blas-netlib",
            feature = "ndarray-blas-openblas",
            feature = "ndarray-blas-accelerate",
        ))]
        {
            use burn::backend::ndarray::NdArrayDevice;
            use burn::backend::NdArrayBackend;

            let device = NdArrayDevice::Cpu;
            bench::<NdArrayBackend>(&device);
        }

        // Without any backend feature every block above is compiled out; tell the
        // user why nothing was measured.
        #[cfg(not(any(
            feature = "wgpu",
            feature = "tch-gpu",
            feature = "tch-cpu",
            feature = "ndarray",
            feature = "ndarray-blas-netlib",
            feature = "ndarray-blas-openblas",
            feature = "ndarray-blas-accelerate",
        )))]
        {
            eprintln!(
                "No backend feature enabled, nothing to benchmark. Re-run with `cargo bench --features <backend>`."
            );
        }
    };
}

View File

@ -13,18 +13,16 @@ version = "0.10.0"
[features]
default = ["std"]
std = [
"rand/std",
]
std = ["rand/std"]
[dependencies]
# ** Please make sure all dependencies support no_std when std is disabled **
const-random = {workspace = true}
rand = {workspace = true}
spin = {workspace = true}# using in place of use std::sync::Mutex;
uuid = {workspace = true}
const-random = { workspace = true }
rand = { workspace = true }
spin = { workspace = true } # using in place of use std::sync::Mutex;
uuid = { workspace = true }
[dev-dependencies]
dashmap = {workspace = true}
dashmap = { workspace = true }

View File

@ -95,6 +95,8 @@ impl<E: TchElement> Backend for TchBackend<E> {
/// Synchronizes `device`.
///
/// For a CUDA device this calls `tch::Cuda::synchronize` with the device index.
/// Any device other than CUDA and MPS is left untouched.
///
/// # Panics
///
/// Panics when `device` is `TchDevice::Mps`.
fn sync(device: &Self::Device) {
    match device {
        TchDevice::Cuda(index) => tch::Cuda::synchronize(*index as i64),
        TchDevice::Mps => panic!("Can't sync MPS device"),
        _ => {}
    }
}
}

View File

@ -14,26 +14,24 @@ version = "0.10.0"
default = ["std"]
experimental-named-tensor = []
export_tests = ["burn-tensor-testgen"]
std = [
"rand/std",
"half/std",
]
std = ["rand/std", "half/std"]
benchmark = []
[dependencies]
burn-tensor-testgen = {path = "../burn-tensor-testgen", version = "0.10.0", optional = true}
burn-tensor-testgen = { path = "../burn-tensor-testgen", version = "0.10.0", optional = true }
derive-new = {workspace = true}
half = {workspace = true}
libm = {workspace = true}# no_std is supported by default
num-traits = {workspace = true}
rand = {workspace = true}
rand_distr = {workspace = true}# use instead of statrs because it supports no_std
derive-new = { workspace = true }
half = { workspace = true }
libm = { workspace = true } # no_std is supported by default
num-traits = { workspace = true }
rand = { workspace = true }
rand_distr = { workspace = true } # use instead of statrs because it supports no_std
# The same implementation of HashMap in std but with no_std support (only needs alloc crate)
hashbrown = {workspace = true}# no_std compatible
hashbrown = { workspace = true } # no_std compatible
# Serialization
serde = {workspace = true}
serde = { workspace = true }
[dev-dependencies]
rand = {workspace = true, features = ["std", "std_rng"]}# Default enables std
rand = { workspace = true, features = ["std", "std_rng"] } # Default enables std

View File

@ -1,9 +1,10 @@
use crate::{compute::compute_client, GraphicsApi, WgpuDevice};
use std::{
fmt::Display,
time::{Duration, Instant},
};
use crate::backend::Backend;
/// Results of a benchmark run.
#[derive(Debug)]
pub struct BenchmarkResult {
@ -62,7 +63,7 @@ impl Display for BenchmarkResult {
}
/// Benchmark trait.
pub trait Benchmark<G: GraphicsApi> {
pub trait Benchmark<B: Backend> {
/// Benchmark arguments.
type Args;
@ -73,7 +74,7 @@ pub trait Benchmark<G: GraphicsApi> {
///
/// This should not include warmup, the benchmark will be run at least one time without
/// measuring the execution time.
fn prepare(&self, device: &WgpuDevice) -> Self::Args;
fn prepare(&self, device: &B::Device) -> Self::Args;
/// Execute the benchmark and returns the time it took to complete.
fn execute(&self, args: Self::Args);
/// Number of samples required to have a statistical significance.
@ -83,24 +84,22 @@ pub trait Benchmark<G: GraphicsApi> {
/// Name of the benchmark.
fn name(&self) -> String;
/// Run the benchmark a number of times.
fn run(&self, device: &WgpuDevice) -> BenchmarkResult {
let client = compute_client::<G>(device);
fn run(&self, device: &B::Device) -> BenchmarkResult {
// Warmup
self.execute(self.prepare(device));
client.sync();
B::sync(device);
let mut durations = Vec::with_capacity(self.num_samples());
for _ in 0..self.num_samples() {
// Prepare
let args = self.prepare(device);
client.sync();
B::sync(device);
// Execute the benchmark
let start = Instant::now();
self.execute(args);
client.sync();
B::sync(device);
let end = Instant::now();
// Register the duration
@ -111,51 +110,23 @@ pub trait Benchmark<G: GraphicsApi> {
}
}
/// Run a benchmark on all platforms.
#[macro_export]
macro_rules! run_benchmark {
($bench:expr) => {{
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.as_millis();
let output = std::process::Command::new("git")
.args(&["rev-parse", "HEAD"])
.output()
.unwrap();
let git_hash = String::from_utf8(output.stdout).unwrap();
println!("Timestamp: {}", timestamp);
println!("Git Hash: {}", str::trim(&git_hash));
#[cfg(any(target_os = "linux", target_os = "windows"))]
{
println!(
"Vulkan - {}{}",
Benchmark::<burn_wgpu::Vulkan>::name(&$bench),
Benchmark::<burn_wgpu::Vulkan>::run(&$bench, &WgpuDevice::DiscreteGpu(0))
);
}
#[cfg(target_os = "windows")]
{
println!(
"Dx11 - {}{}",
Benchmark::<burn_wgpu::Dx11>::name(&$bench),
Benchmark::<burn_wgpu::Dx11>::run(&$bench, &WgpuDevice::DiscreteGpu(0))
);
println!(
"Dx12 - {}{}",
Benchmark::<burn_wgpu::Dx12>::name(&$bench),
Benchmark::<burn_wgpu::Dx12>::run(&$bench, &WgpuDevice::DiscreteGpu(0))
);
}
#[cfg(target_os = "macos")]
{
println!(
"Metal - {}{}",
Benchmark::<burn_wgpu::Metal>::name(&$bench),
Benchmark::<burn_wgpu::Metal>::run(&$bench, &WgpuDevice::IntegratedGpu(0))
);
}
}};
/// Runs the given benchmark on the device and prints result and information.
///
/// The printed lines are, in order: the current timestamp in milliseconds since
/// the Unix epoch, the git hash of `HEAD`, the backend name and the benchmark
/// name followed by its results.
///
/// The git hash is gathered on a best-effort basis: when `git` cannot be
/// started, exits with an error (for instance outside of a repository) or the
/// hash cannot be read, `unknown` is printed instead of aborting the benchmark.
///
/// # Panics
///
/// Panics if the system clock is set before the Unix epoch.
pub fn run_benchmark<B: Backend, BM: Benchmark<B>>(benchmark: BM, device: &B::Device) {
    let timestamp = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .expect("system clock should not be set before the Unix epoch")
        .as_millis();

    // The hash is only informative, so a missing or failing `git` must not
    // prevent the benchmark from running.
    let git_hash = std::process::Command::new("git")
        .args(["rev-parse", "HEAD"])
        .output()
        .ok()
        .filter(|output| output.status.success())
        .and_then(|output| String::from_utf8(output.stdout).ok())
        .map(|hash| hash.trim().to_owned())
        .unwrap_or_else(|| String::from("unknown"));

    println!("Timestamp: {}", timestamp);
    println!("Git Hash: {}", git_hash);
    println!("Backend: {}", B::name());
    println!(
        "Benchmarking - {}{}",
        benchmark.name(),
        benchmark.run(device)
    );
}

View File

@ -0,0 +1,2 @@
mod base;
pub use base::*;

View File

@ -17,3 +17,11 @@ mod tests;
pub use half::{bf16, f16};
pub use tensor::*;
#[cfg(feature = "benchmark")]
/// This module provides benchmark utilities to easily and reliably run
/// benches on any function that is generic over a backend.
///
/// This can be useful to compare backends on inference or training speed
/// for your models.
pub mod benchmark;

View File

@ -33,7 +33,10 @@ serde = { workspace = true }
text_placeholder = { version = "0.5.0", features = ["struct_context"] }
hashbrown = { workspace = true }
burn-compute = { path = "../burn-compute", version = "0.10.0", default-features = false, features=["channel-mutex", "std"] }
burn-compute = { path = "../burn-compute", version = "0.10.0", default-features = false, features = [
"channel-mutex",
"std",
] }
[dev-dependencies]
@ -42,22 +45,11 @@ burn-autodiff = { path = "../burn-autodiff", version = "0.10.0", default-feature
] }
burn-tensor = { path = "../burn-tensor", version = "0.10.0", default-features = false, features = [
"export_tests",
"benchmark",
] }
burn-ndarray = { path = "../burn-ndarray", version = "0.10.0" }
serial_test = "2.0.0"
[[bench]]
name = "unary"
harness = false
[[bench]]
name = "binary"
harness = false
[[bench]]
name = "matmul"
harness = false
[[bench]]
name = "data"
harness = false

View File

@ -1,66 +0,0 @@
use burn_tensor::{Distribution, Shape, Tensor};
use burn_wgpu::{
benchmark::Benchmark,
binary_elemwise, binary_elemwise_inplace,
kernel::{binary_elemwise_default, binary_elemwise_inplace_default},
run_benchmark, GraphicsApi, WgpuBackend, WgpuDevice,
};
binary_elemwise!(TestKernel, "+");
binary_elemwise_inplace!(TestKernelInplace, "+");
struct BinaryBenchmark<const D: usize> {
inplace: bool,
shape: Shape<D>,
num_repeats: usize,
}
impl<const D: usize, G: GraphicsApi> Benchmark<G> for BinaryBenchmark<D> {
type Args = (
Tensor<WgpuBackend<G, f32, i32>, D>,
Tensor<WgpuBackend<G, f32, i32>, D>,
);
fn name(&self) -> String {
match self.inplace {
true => "Binary Inplace Ops",
false => "Binary Ops",
}
.into()
}
fn execute(&self, (lhs, rhs): Self::Args) {
for _ in 0..self.num_repeats {
if self.inplace {
binary_elemwise_inplace_default::<TestKernelInplace, f32, D>(
lhs.clone().into_primitive(),
rhs.clone().into_primitive(),
);
} else {
binary_elemwise_default::<TestKernel, f32, D>(
lhs.clone().into_primitive(),
rhs.clone().into_primitive(),
);
}
}
}
fn prepare(&self, device: &WgpuDevice) -> Self::Args {
let lhs = Tensor::random(self.shape.clone(), Distribution::Default).to_device(device);
let rhs = Tensor::random(self.shape.clone(), Distribution::Default).to_device(device);
(lhs, rhs)
}
}
fn main() {
run_benchmark!(BinaryBenchmark::<3> {
inplace: false,
shape: [32, 512, 1024].into(),
num_repeats: 10,
});
run_benchmark!(BinaryBenchmark::<3> {
inplace: true,
shape: [32, 512, 1024].into(),
num_repeats: 10,
});
}

View File

@ -1,77 +0,0 @@
use burn_tensor::{Data, Distribution, Shape, Tensor};
use burn_wgpu::{benchmark::Benchmark, run_benchmark, GraphicsApi, WgpuBackend, WgpuDevice};
struct ToDataBenchmark<const D: usize> {
shape: Shape<D>,
num_repeats: usize,
}
impl<const D: usize, G: GraphicsApi> Benchmark<G> for ToDataBenchmark<D> {
type Args = Tensor<WgpuBackend<G, f32, i32>, D>;
fn name(&self) -> String {
format!("to-data-{:?}-{}", self.shape.dims, self.num_repeats)
}
fn execute(&self, args: Self::Args) {
for _ in 0..self.num_repeats {
let _data = args.to_data();
}
}
fn prepare(&self, device: &WgpuDevice) -> Self::Args {
Tensor::random(self.shape.clone(), Distribution::Default).to_device(device)
}
}
struct FromDataBenchmark<const D: usize> {
shape: Shape<D>,
num_repeats: usize,
}
impl<const D: usize, G: GraphicsApi> Benchmark<G> for FromDataBenchmark<D> {
type Args = (Data<f32, D>, WgpuDevice);
fn name(&self) -> String {
format!("from-data-{:?}-{}", self.shape.dims, self.num_repeats)
}
fn execute(&self, (data, device): Self::Args) {
for _ in 0..self.num_repeats {
let _data =
Tensor::<WgpuBackend<G, f32, i32>, D>::from_data_device(data.clone(), &device);
}
}
fn prepare(&self, device: &WgpuDevice) -> Self::Args {
(
Data::random(
self.shape.clone(),
Distribution::Default,
&mut rand::thread_rng(),
),
device.clone(),
)
}
}
fn main() {
let num_repeats = 3;
run_benchmark!(ToDataBenchmark::<3> {
shape: [32, 256, 512].into(),
num_repeats,
});
run_benchmark!(ToDataBenchmark::<3> {
shape: [32, 512, 1024].into(),
num_repeats,
});
run_benchmark!(FromDataBenchmark::<3> {
shape: [32, 256, 512].into(),
num_repeats,
});
run_benchmark!(FromDataBenchmark::<3> {
shape: [32, 512, 1024].into(),
num_repeats,
});
}

View File

@ -1,18 +1,24 @@
use burn_tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_wgpu::WgpuDevice;
use std::marker::PhantomData;
use burn_tensor::{
benchmark::{run_benchmark, Benchmark},
Distribution, Shape, Tensor,
};
use derive_new::new;
use burn_wgpu::{
benchmark::Benchmark,
kernel::matmul::{
contiguous, contiguous_vectorized, matmul_mem_coalescing_default, matmul_naive_default,
tile, tile_vectorized,
},
run_benchmark, GraphicsApi, WgpuBackend, WgpuDevice,
AutoGraphicsApi, GraphicsApi, WgpuBackend,
};
use std::marker::PhantomData;
trait MatmulFunction<B: Backend, const D: usize> {
fn run(lhs: Tensor<B, D>, rhs: Tensor<B, D>) -> Tensor<B, D>;
}
type WTensor<G, const D: usize> = Tensor<WgpuBackend<G, f32, i32>, D>;
#[derive(new)]
struct MatmulBenchmark<F, const D: usize> {
shape_lhs: Shape<D>,
shape_rhs: Shape<D>,
@ -20,15 +26,16 @@ struct MatmulBenchmark<F, const D: usize> {
matmul: PhantomData<F>,
}
impl<F, const D: usize, G> Benchmark<G> for MatmulBenchmark<F, D>
trait MatmulFunction<G: GraphicsApi, const D: usize> {
fn run(lhs: WTensor<G, D>, rhs: WTensor<G, D>) -> WTensor<G, D>;
}
impl<F, const D: usize, G> Benchmark<WgpuBackend<G, f32, i32>> for MatmulBenchmark<F, D>
where
F: MatmulFunction<WgpuBackend<G, f32, i32>, D>,
F: MatmulFunction<G, D>,
G: GraphicsApi,
{
type Args = (
Tensor<WgpuBackend<G, f32, i32>, D>,
Tensor<WgpuBackend<G, f32, i32>, D>,
);
type Args = (WTensor<G, D>, WTensor<G, D>);
fn name(&self) -> String {
format!(
@ -50,79 +57,80 @@ where
}
fn prepare(&self, device: &WgpuDevice) -> Self::Args {
let lhs = Tensor::random_device(self.shape_lhs.clone(), Distribution::Default, device);
let rhs = Tensor::random_device(self.shape_rhs.clone(), Distribution::Default, device);
let lhs = WTensor::random_device(self.shape_lhs.clone(), Distribution::Default, device);
let rhs = WTensor::random_device(self.shape_rhs.clone(), Distribution::Default, device);
(lhs, rhs)
}
}
macro_rules! benchmark {
($name:ident, $func:expr) => {
struct $name;
impl<const D: usize, G: GraphicsApi> MatmulFunction<WgpuBackend<G, f32, i32>, D> for $name {
fn run(
lhs: Tensor<WgpuBackend<G, f32, i32>, D>,
rhs: Tensor<WgpuBackend<G, f32, i32>, D>,
) -> Tensor<WgpuBackend<G, f32, i32>, D> {
macro_rules! bench_matmul {
($benchmark:ident, $matmul_name:ident, $func:expr) => {
struct $matmul_name {}
impl<G: GraphicsApi, const D: usize> MatmulFunction<G, D> for $matmul_name {
fn run(lhs: WTensor<G, D>, rhs: WTensor<G, D>) -> WTensor<G, D> {
Tensor::from_primitive($func(lhs.into_primitive(), rhs.into_primitive()))
}
}
type $benchmark<const D: usize> = MatmulBenchmark<$matmul_name, D>;
};
}
benchmark!(NaiveMatmul, matmul_naive_default);
benchmark!(MemCoalescingMatmul, matmul_mem_coalescing_default);
benchmark!(
bench_matmul!(NaiveMatmulBenchmark, NaiveMatmul, matmul_naive_default);
bench_matmul!(
MemCoalescingMatmulBenchmark,
MemCoalescingMatmul,
matmul_mem_coalescing_default
);
bench_matmul!(
Tiling2DMatmulContiguousBenchmark,
Tiling2DMatmulContiguous,
contiguous::matmul_tiling_2d_default
);
benchmark!(Tiling2DMatmulTile, tile::matmul_tiling_2d_default);
benchmark!(
bench_matmul!(
Tiling2DMatmulTileBenchmark,
Tiling2DMatmulTile,
tile::matmul_tiling_2d_default
);
bench_matmul!(
Tiling2DMatmulTileVectorizedBenchmark,
Tiling2DMatmulTileVectorized,
tile_vectorized::matmul_tiling_2d_default
);
benchmark!(
bench_matmul!(
Tiling2DMatmulContiguousVectorizedBenchmark,
Tiling2DMatmulContiguousVectorized,
contiguous_vectorized::matmul_tiling_2d_default
);
fn main() {
#[allow(dead_code)]
/// Runs the benchmarks for wgpu matmul implementations
pub fn bench(device: &WgpuDevice) {
const D: usize = 3;
let num_repeats = 3;
let batch_size = 3;
let m = 1024;
let k = 2048;
let n = 1024;
let shape_lhs = Shape::new([batch_size, m, k]);
let shape_rhs = Shape::new([batch_size, k, n]);
run_benchmark!(MatmulBenchmark::<MemCoalescingMatmul, 3> {
shape_lhs: [batch_size, m, k].into(),
shape_rhs: [batch_size, k, n].into(),
num_repeats,
matmul: PhantomData
});
run_benchmark!(MatmulBenchmark::<Tiling2DMatmulContiguous, 3> {
shape_lhs: [batch_size, m, k].into(),
shape_rhs: [batch_size, k, n].into(),
num_repeats,
matmul: PhantomData
});
run_benchmark!(MatmulBenchmark::<Tiling2DMatmulContiguousVectorized, 3> {
shape_lhs: [batch_size, m, k].into(),
shape_rhs: [batch_size, k, n].into(),
num_repeats,
matmul: PhantomData
});
run_benchmark!(MatmulBenchmark::<Tiling2DMatmulTile, 3> {
shape_lhs: [batch_size, m, k].into(),
shape_rhs: [batch_size, k, n].into(),
num_repeats,
matmul: PhantomData
});
run_benchmark!(MatmulBenchmark::<Tiling2DMatmulTileVectorized, 3> {
shape_lhs: [batch_size, m, k].into(),
shape_rhs: [batch_size, k, n].into(),
num_repeats,
matmul: PhantomData
});
macro_rules! run_matmul_benchmark {
($benchmark:ident) => {
run_benchmark::<WgpuBackend<AutoGraphicsApi, f32, i32>, $benchmark<D>>(
$benchmark::new(shape_lhs.clone(), shape_rhs.clone(), num_repeats),
device,
);
};
}
run_matmul_benchmark!(NaiveMatmulBenchmark);
run_matmul_benchmark!(MemCoalescingMatmulBenchmark);
run_matmul_benchmark!(Tiling2DMatmulContiguousBenchmark);
run_matmul_benchmark!(Tiling2DMatmulTileBenchmark);
run_matmul_benchmark!(Tiling2DMatmulTileVectorizedBenchmark);
run_matmul_benchmark!(Tiling2DMatmulContiguousVectorizedBenchmark);
}
fn main() {
bench(&WgpuDevice::BestAvailable)
}

View File

@ -1,54 +0,0 @@
use burn_tensor::{Distribution, Shape, Tensor};
use burn_wgpu::{
benchmark::Benchmark,
kernel::{unary_default, unary_inplace_default},
run_benchmark, unary, unary_inplace, GraphicsApi, WgpuBackend, WgpuDevice,
};
unary!(TestKernel, func "log");
unary_inplace!(TestKernelInplace, func "log");
struct UnaryBenchmark<const D: usize> {
inplace: bool,
shape: Shape<D>,
num_repeats: usize,
}
impl<const D: usize, G: GraphicsApi> Benchmark<G> for UnaryBenchmark<D> {
type Args = Tensor<WgpuBackend<G, f32, i32>, D>;
fn name(&self) -> String {
match self.inplace {
true => "Unary Inplace Ops",
false => "Unary Ops",
}
.into()
}
fn execute(&self, args: Self::Args) {
for _ in 0..self.num_repeats {
if self.inplace {
unary_inplace_default::<TestKernelInplace, f32, D>(args.clone().into_primitive());
} else {
unary_default::<TestKernel, f32, D>(args.clone().into_primitive());
}
}
}
fn prepare(&self, device: &WgpuDevice) -> Self::Args {
Tensor::random(self.shape.clone(), Distribution::Default).to_device(device)
}
}
fn main() {
run_benchmark!(UnaryBenchmark::<3> {
inplace: false,
shape: [32, 512, 1024].into(),
num_repeats: 10,
});
run_benchmark!(UnaryBenchmark::<3> {
inplace: true,
shape: [32, 512, 1024].into(),
num_repeats: 10,
});
}

View File

@ -8,8 +8,6 @@ extern crate alloc;
mod ops;
/// Benchmark module
pub mod benchmark;
/// Compute related module.
pub mod compute;
/// Kernel module