diff --git a/backend-comparison/Cargo.toml b/backend-comparison/Cargo.toml
index fd2a180b5..aab974ad8 100644
--- a/backend-comparison/Cargo.toml
+++ b/backend-comparison/Cargo.toml
@@ -103,6 +103,11 @@ name = "custom-gelu"
 path = "benches/custom_gelu.rs"
 harness = false
 
+[[bench]]
+name = "resnet50"
+path = "benches/resnet.rs"
+harness = false
+
 [[bench]]
 name = "autodiff"
 harness = false
diff --git a/backend-comparison/README.md b/backend-comparison/README.md
index 7db17dfe2..b6700034e 100644
--- a/backend-comparison/README.md
+++ b/backend-comparison/README.md
@@ -43,6 +43,7 @@ Available Benchmarks:
 - custom-gelu
 - data
 - matmul
+- resnet50
 - unary
 ```
 
@@ -144,7 +145,7 @@ Then it must be registered in the `BenchmarkValues` enumeration:
 
 ```rs
 #[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)]
-pub(crate) enum BackendValues {
+pub(crate) enum BenchmarkValues {
     // ...
     #[strum(to_string = "mybench")]
     MyBench,
diff --git a/backend-comparison/benches/resnet.rs b/backend-comparison/benches/resnet.rs
new file mode 100644
index 000000000..f155a738a
--- /dev/null
+++ b/backend-comparison/benches/resnet.rs
@@ -0,0 +1,73 @@
+use backend_comparison::persistence::save;
+use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
+use burn_common::benchmark::{run_benchmark, Benchmark};
+use cubecl::client::SyncType;
+
+// Files retrieved during build to avoid reimplementing ResNet for benchmarks
+mod block {
+    extern crate alloc;
+    include!(concat!(env!("OUT_DIR"), "/block.rs"));
+}
+
+mod model {
+    include!(concat!(env!("OUT_DIR"), "/resnet.rs"));
+}
+
+pub struct ResNetBenchmark<B: Backend> {
+    shape: Shape<4>,
+    device: B::Device,
+}
+
+impl<B: Backend> Benchmark for ResNetBenchmark<B> {
+    type Args = (model::ResNet<B>, Tensor<B, 4>);
+
+    fn name(&self) -> String {
+        "resnet50".into()
+    }
+
+    fn shapes(&self) -> Vec<Vec<usize>> {
+        vec![self.shape.dims.into()]
+    }
+
+    fn execute(&self, (model, input): Self::Args) {
+        let _out = model.forward(input);
+    }
+
+    fn prepare(&self) -> Self::Args {
+        // 1k classes like ImageNet
+        let model = model::ResNet::resnet50(1000, &self.device);
+        let input = Tensor::random(self.shape.clone(), Distribution::Default, &self.device);
+
+        (model, input)
+    }
+
+    fn sync(&self) {
+        B::sync(&self.device, SyncType::Wait)
+    }
+}
+
+#[allow(dead_code)]
+fn bench<B: Backend>(
+    device: &B::Device,
+    feature_name: &str,
+    url: Option<&str>,
+    token: Option<&str>,
+) {
+    let benchmark = ResNetBenchmark::<B> {
+        shape: [1, 3, 224, 224].into(),
+        device: device.clone(),
+    };
+
+    save::<B>(
+        vec![run_benchmark(benchmark)],
+        device,
+        feature_name,
+        url,
+        token,
+    )
+    .unwrap();
+}
+
+fn main() {
+    backend_comparison::bench_on_backend!();
+}
diff --git a/backend-comparison/build.rs b/backend-comparison/build.rs
new file mode 100644
index 000000000..7ecb328e4
--- /dev/null
+++ b/backend-comparison/build.rs
@@ -0,0 +1,278 @@
+use std::env;
+use std::fs;
+use std::path::Path;
+use std::process::Command;
+
+const MODELS_DIR: &str = "/tmp/models";
+const MODELS_REPO: &str = "https://github.com/tracel-ai/models.git";
+
+// Patch resnet code (remove pretrained feature code)
+const PATCH: &str = r#"diff --git a/resnet-burn/resnet/src/resnet.rs b/resnet-burn/resnet/src/resnet.rs
+index e7f8787..3967049 100644
+--- a/resnet-burn/resnet/src/resnet.rs
++++ b/resnet-burn/resnet/src/resnet.rs
+@@ -12,13 +12,6 @@ use burn::{
+ 
+ use super::block::{LayerBlock, LayerBlockConfig};
+ 
+-#[cfg(feature = "pretrained")]
+-use {
+-    super::weights::{self, WeightsMeta},
+-    burn::record::{FullPrecisionSettings, Recorder, RecorderError},
+-    burn_import::pytorch::{LoadArgs, PyTorchFileRecorder},
+-};
+-
+ // ResNet residual layer block configs
+ const RESNET18_BLOCKS: [usize; 4] = [2, 2, 2, 2];
+ const RESNET34_BLOCKS: [usize; 4] = [3, 4, 6, 3];
+@@ -77,29 +70,6 @@ impl<B: Backend> ResNet<B> {
+         ResNetConfig::new(RESNET18_BLOCKS, num_classes, 1).init(device)
+     }
+ 
+-    /// ResNet-18 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385)
+-    /// with pre-trained weights.
+-    ///
+-    /// # Arguments
+-    ///
+-    /// * `weights`: Pre-trained weights to load.
+-    /// * `device` - Device to create the module on.
+-    ///
+-    /// # Returns
+-    ///
+-    /// A ResNet-18 module with pre-trained weights.
+-    #[cfg(feature = "pretrained")]
+-    pub fn resnet18_pretrained(
+-        weights: weights::ResNet18,
+-        device: &Device<B>,
+-    ) -> Result<Self, RecorderError> {
+-        let weights = weights.weights();
+-        let record = Self::load_weights_record(&weights, device)?;
+-        let model = ResNet::<B>::resnet18(weights.num_classes, device).load_record(record);
+-
+-        Ok(model)
+-    }
+-
+     /// ResNet-34 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385).
+     ///
+     /// # Arguments
+@@ -114,29 +84,6 @@ impl<B: Backend> ResNet<B> {
+         ResNetConfig::new(RESNET34_BLOCKS, num_classes, 1).init(device)
+     }
+ 
+-    /// ResNet-34 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385)
+-    /// with pre-trained weights.
+-    ///
+-    /// # Arguments
+-    ///
+-    /// * `weights`: Pre-trained weights to load.
+-    /// * `device` - Device to create the module on.
+-    ///
+-    /// # Returns
+-    ///
+-    /// A ResNet-34 module with pre-trained weights.
+-    #[cfg(feature = "pretrained")]
+-    pub fn resnet34_pretrained(
+-        weights: weights::ResNet34,
+-        device: &Device<B>,
+-    ) -> Result<Self, RecorderError> {
+-        let weights = weights.weights();
+-        let record = Self::load_weights_record(&weights, device)?;
+-        let model = ResNet::<B>::resnet34(weights.num_classes, device).load_record(record);
+-
+-        Ok(model)
+-    }
+-
+     /// ResNet-50 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385).
+     ///
+     /// # Arguments
+@@ -151,29 +98,6 @@ impl<B: Backend> ResNet<B> {
+         ResNetConfig::new(RESNET50_BLOCKS, num_classes, 4).init(device)
+     }
+ 
+-    /// ResNet-50 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385)
+-    /// with pre-trained weights.
+-    ///
+-    /// # Arguments
+-    ///
+-    /// * `weights`: Pre-trained weights to load.
+-    /// * `device` - Device to create the module on.
+-    ///
+-    /// # Returns
+-    ///
+-    /// A ResNet-50 module with pre-trained weights.
+-    #[cfg(feature = "pretrained")]
+-    pub fn resnet50_pretrained(
+-        weights: weights::ResNet50,
+-        device: &Device<B>,
+-    ) -> Result<Self, RecorderError> {
+-        let weights = weights.weights();
+-        let record = Self::load_weights_record(&weights, device)?;
+-        let model = ResNet::<B>::resnet50(weights.num_classes, device).load_record(record);
+-
+-        Ok(model)
+-    }
+-
+     /// ResNet-101 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385).
+     ///
+     /// # Arguments
+@@ -188,29 +112,6 @@ impl<B: Backend> ResNet<B> {
+         ResNetConfig::new(RESNET101_BLOCKS, num_classes, 4).init(device)
+     }
+ 
+-    /// ResNet-101 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385)
+-    /// with pre-trained weights.
+-    ///
+-    /// # Arguments
+-    ///
+-    /// * `weights`: Pre-trained weights to load.
+-    /// * `device` - Device to create the module on.
+-    ///
+-    /// # Returns
+-    ///
+-    /// A ResNet-101 module with pre-trained weights.
+-    #[cfg(feature = "pretrained")]
+-    pub fn resnet101_pretrained(
+-        weights: weights::ResNet101,
+-        device: &Device<B>,
+-    ) -> Result<Self, RecorderError> {
+-        let weights = weights.weights();
+-        let record = Self::load_weights_record(&weights, device)?;
+-        let model = ResNet::<B>::resnet101(weights.num_classes, device).load_record(record);
+-
+-        Ok(model)
+-    }
+-
+     /// ResNet-152 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385).
+     ///
+     /// # Arguments
+@@ -225,29 +126,6 @@ impl<B: Backend> ResNet<B> {
+         ResNetConfig::new(RESNET152_BLOCKS, num_classes, 4).init(device)
+     }
+ 
+-    /// ResNet-152 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385)
+-    /// with pre-trained weights.
+-    ///
+-    /// # Arguments
+-    ///
+-    /// * `weights`: Pre-trained weights to load.
+-    /// * `device` - Device to create the module on.
+-    ///
+-    /// # Returns
+-    ///
+-    /// A ResNet-152 module with pre-trained weights.
+-    #[cfg(feature = "pretrained")]
+-    pub fn resnet152_pretrained(
+-        weights: weights::ResNet152,
+-        device: &Device<B>,
+-    ) -> Result<Self, RecorderError> {
+-        let weights = weights.weights();
+-        let record = Self::load_weights_record(&weights, device)?;
+-        let model = ResNet::<B>::resnet152(weights.num_classes, device).load_record(record);
+-
+-        Ok(model)
+-    }
+-
+     /// Re-initialize the last layer with the specified number of output classes.
+     pub fn with_classes(mut self, num_classes: usize) -> Self {
+         let [d_input, _d_output] = self.fc.weight.dims();
+@@ -256,32 +134,6 @@
+     }
+ }
+ 
+-#[cfg(feature = "pretrained")]
+-impl<B: Backend> ResNet<B> {
+-    /// Load specified pre-trained PyTorch weights as a record.
+-    fn load_weights_record(
+-        weights: &weights::Weights,
+-        device: &Device<B>,
+-    ) -> Result<ResNetRecord<B>, RecorderError> {
+-        // Download torch weights
+-        let torch_weights = weights.download().map_err(|err| {
+-            RecorderError::Unknown(format!("Could not download weights.\nError: {err}"))
+-        })?;
+-
+-        // Load weights from torch state_dict
+-        let load_args = LoadArgs::new(torch_weights)
+-            // Map *.downsample.0.* -> *.downsample.conv.*
+-            .with_key_remap("(.+)\\.downsample\\.0\\.(.+)", "$1.downsample.conv.$2")
+-            // Map *.downsample.1.* -> *.downsample.bn.*
+-            .with_key_remap("(.+)\\.downsample\\.1\\.(.+)", "$1.downsample.bn.$2")
+-            // Map layer[i].[j].* -> layer[i].blocks.[j].*
+-            .with_key_remap("(layer[1-4])\\.([0-9]+)\\.(.+)", "$1.blocks.$2.$3");
+-        let record = PyTorchFileRecorder::<FullPrecisionSettings>::new().load(load_args, device)?;
+-
+-        Ok(record)
+-    }
+-}
+-
+ /// [ResNet](ResNet) configuration.
+ struct ResNetConfig {
+     conv1: Conv2dConfig,
+"#;
+
+fn run<F>(name: &str, mut configure: F)
+where
+    F: FnMut(&mut Command) -> &mut Command,
+{
+    let mut command = Command::new(name);
+    let configured = configure(&mut command);
+    println!("Executing {:?}", configured);
+    if !configured.status().unwrap().success() {
+        panic!("failed to execute {:?}", configured);
+    }
+    println!("Command {:?} finished successfully", configured);
+}
+
+fn main() {
+    // Checkout ResNet code from models repo
+    let models_dir = Path::new(MODELS_DIR);
+    if !models_dir.join(".git").exists() {
+        run("git", |command| {
+            command
+                .arg("clone")
+                .arg("--depth=1")
+                .arg("--no-checkout")
+                .arg(MODELS_REPO)
+                .arg(MODELS_DIR)
+        });
+
+        run("git", |command| {
+            command
+                .current_dir(models_dir)
+                .arg("sparse-checkout")
+                .arg("set")
+                .arg("resnet-burn")
+        });
+
+        run("git", |command| {
+            command.current_dir(models_dir).arg("checkout")
+        });
+
+        let patch_file = models_dir.join("benchmark.patch");
+
+        fs::write(&patch_file, PATCH).expect("should write to file successfully");
+
+        // Apply patch
+        run("git", |command| {
+            command
+                .current_dir(models_dir)
+                .arg("apply")
+                .arg(patch_file.to_str().unwrap())
+        });
+    }
+
+    // Copy contents to output dir
+    let out_dir = env::var("OUT_DIR").unwrap();
+    let source_path = models_dir.join("resnet-burn").join("resnet").join("src");
+    let dest_path = Path::new(&out_dir);
+
+    for file in fs::read_dir(source_path).unwrap() {
+        let source_file = file.unwrap().path();
+        let dest_file = dest_path.join(source_file.file_name().unwrap());
+        fs::copy(source_file, dest_file).expect("should copy file successfully");
+    }
+
+    // Delete cloned repository contents
+    fs::remove_dir_all(models_dir.join(".git")).unwrap();
+    fs::remove_dir_all(models_dir).unwrap();
+}
diff --git a/backend-comparison/src/burnbenchapp/base.rs b/backend-comparison/src/burnbenchapp/base.rs
index 4a345b7be..f1aa16935 100644
--- a/backend-comparison/src/burnbenchapp/base.rs
+++ b/backend-comparison/src/burnbenchapp/base.rs
@@ -98,6 +98,8 @@ enum BenchmarkValues {
     Unary,
     #[strum(to_string = "max-pool2d")]
     MaxPool2d,
+    #[strum(to_string = "resnet50")]
+    Resnet50,
     #[strum(to_string = "load-record")]
     LoadRecord,
     #[strum(to_string = "autodiff")]
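Note on how the new benchmark plugs into the harness: `benches/resnet.rs` above implements the `Benchmark` trait from `burn_common::benchmark` and hands the result to `run_benchmark`/`save`. The sketch below is only inferred from the impl in this diff (the `Args` associated type and the methods it provides); the actual trait definition in `burn_common` is the source of truth and may include additional provided methods.

```rs
// Hypothetical sketch of the trait surface exercised by ResNetBenchmark above,
// reconstructed from benches/resnet.rs in this diff; see burn_common::benchmark
// for the real definition.
pub trait Benchmark {
    /// Inputs built by `prepare` and consumed by `execute`
    /// (here: the ResNet-50 module and a random [1, 3, 224, 224] tensor).
    type Args;

    /// Build the inputs before the timed runs.
    fn prepare(&self) -> Self::Args;
    /// The workload being measured (here: a single forward pass).
    fn execute(&self, args: Self::Args);
    /// Name used for reporting and filtering (here: "resnet50").
    fn name(&self) -> String;
    /// Tensor shapes reported alongside the timings.
    fn shapes(&self) -> Vec<Vec<usize>>;
    /// Block until the backend has finished all queued work.
    fn sync(&self);
}
```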