use backend_comparison::persistence::save; use burn::tensor::{backend::Backend, Distribution, Shape, Tensor, TensorData}; use burn_common::benchmark::{run_benchmark, Benchmark}; use derive_new::new; #[derive(new)] struct ToDataBenchmark { shape: Shape, device: B::Device, } impl Benchmark for ToDataBenchmark { type Args = Tensor; fn name(&self) -> String { "to_data".into() } fn shapes(&self) -> Vec> { vec![self.shape.dims.clone()] } fn execute(&self, args: Self::Args) { let _data = args.to_data(); } fn prepare(&self) -> Self::Args { Tensor::random(self.shape.clone(), Distribution::Default, &self.device) } fn sync(&self) { B::sync(&self.device) } } #[derive(new)] struct FromDataBenchmark { shape: Shape, device: B::Device, } impl Benchmark for FromDataBenchmark { type Args = (TensorData, B::Device); fn name(&self) -> String { "from_data".into() } fn shapes(&self) -> Vec> { vec![self.shape.dims.clone()] } fn execute(&self, (data, device): Self::Args) { let _data = Tensor::::from_data(data.clone(), &device); } fn prepare(&self) -> Self::Args { ( TensorData::random::( self.shape.clone(), Distribution::Default, &mut rand::thread_rng(), ), self.device.clone(), ) } fn sync(&self) { B::sync(&self.device) } } #[allow(dead_code)] fn bench( device: &B::Device, feature_name: &str, url: Option<&str>, token: Option<&str>, ) { const D: usize = 3; let shape: Shape = [32, 512, 1024].into(); let to_benchmark = ToDataBenchmark::::new(shape.clone(), device.clone()); let from_benchmark = FromDataBenchmark::::new(shape, device.clone()); save::( vec![run_benchmark(to_benchmark), run_benchmark(from_benchmark)], device, feature_name, url, token, ) .unwrap(); } fn main() { backend_comparison::bench_on_backend!(); }