burn/backend-comparison/benches/binary.rs

48 lines
1.2 KiB
Rust
Raw Normal View History

2023-09-28 21:38:21 +08:00
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
2023-10-30 04:44:59 +08:00
use burn_common::benchmark::{run_benchmark, Benchmark};
2023-09-28 21:38:21 +08:00
/// Benchmark that repeatedly applies an element-wise binary tensor operation
/// (addition) to two random operands of the same shape on a single device.
///
/// `B` is the backend under test and `D` the rank of the operand tensors.
pub struct BinaryBenchmark<B: Backend, const D: usize> {
    /// Shape shared by both operand tensors created in `prepare`.
    shape: Shape<D>,
    /// How many times the operation is applied in a single `execute` call.
    num_repeats: usize,
    /// Device on which the operands are allocated and which `sync` targets.
    device: B::Device,
}
2023-10-30 04:44:59 +08:00
impl<B: Backend, const D: usize> Benchmark for BinaryBenchmark<B, D> {
2023-09-28 21:38:21 +08:00
type Args = (Tensor<B, D>, Tensor<B, D>);
fn name(&self) -> String {
"Binary Ops".into()
}
fn execute(&self, (lhs, rhs): Self::Args) {
for _ in 0..self.num_repeats {
// Choice of add is arbitrary
B::add(lhs.clone().into_primitive(), rhs.clone().into_primitive());
}
}
2023-10-30 04:44:59 +08:00
fn prepare(&self) -> Self::Args {
let lhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device);
let rhs = Tensor::random_device(self.shape.clone(), Distribution::Default, &self.device);
2023-09-28 21:38:21 +08:00
(lhs, rhs)
}
2023-10-30 04:44:59 +08:00
fn sync(&self) {
B::sync(&self.device)
}
2023-09-28 21:38:21 +08:00
}
#[allow(dead_code)]
fn bench<B: Backend>(device: &B::Device) {
2023-10-30 04:44:59 +08:00
run_benchmark(BinaryBenchmark::<B, 3> {
shape: [32, 512, 1024].into(),
num_repeats: 10,
device: device.clone(),
})
2023-09-28 21:38:21 +08:00
}
/// Benchmark entry point.
///
/// Expands `backend_comparison::bench_on_backend!()`. NOTE(review): presumably the
/// macro selects the backend(s) to run and calls `bench` with their device(s);
/// verify against the macro definition in the `backend_comparison` crate.
fn main() {
    backend_comparison::bench_on_backend!();
}