Add ResNet benchmark (#1534)

Guillaume Lagrange, 2024-09-16 09:57:15 -04:00, committed by GitHub
parent 5631afb3a0
commit 81ec64a929
5 changed files with 360 additions and 1 deletion

View File

@@ -103,6 +103,11 @@ name = "custom-gelu"
path = "benches/custom_gelu.rs"
harness = false
[[bench]]
name = "resnet50"
path = "benches/resnet.rs"
harness = false
[[bench]]
name = "autodiff"
harness = false

View File

@@ -43,6 +43,7 @@ Available Benchmarks:
- custom-gelu
- data
- matmul
- resnet50
- unary
```
@@ -144,7 +145,7 @@ Then it must be registered in the `BenchmarkValues` enumeration:
```rs
#[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)]
-pub(crate) enum BackendValues {
+pub(crate) enum BenchmarkValues {
// ...
#[strum(to_string = "mybench")]
MyBench,

View File

@@ -0,0 +1,73 @@
use backend_comparison::persistence::save;
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
use burn_common::benchmark::{run_benchmark, Benchmark};
use cubecl::client::SyncType;
// Files retrieved during build to avoid reimplementing ResNet for benchmarks
mod block {
extern crate alloc;
include!(concat!(env!("OUT_DIR"), "/block.rs"));
}
mod model {
include!(concat!(env!("OUT_DIR"), "/resnet.rs"));
}
pub struct ResNetBenchmark<B: Backend> {
shape: Shape<4>,
device: B::Device,
}
impl<B: Backend> Benchmark for ResNetBenchmark<B> {
type Args = (model::ResNet<B>, Tensor<B, 4>);
fn name(&self) -> String {
"resnet50".into()
}
fn shapes(&self) -> Vec<Vec<usize>> {
vec![self.shape.dims.into()]
}
fn execute(&self, (model, input): Self::Args) {
let _out = model.forward(input);
}
fn prepare(&self) -> Self::Args {
// 1k classes like ImageNet
let model = model::ResNet::resnet50(1000, &self.device);
let input = Tensor::random(self.shape.clone(), Distribution::Default, &self.device);
(model, input)
}
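// Block until all queued device work has finished, so timings measure execution rather than just dispatch.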
fn sync(&self) {
B::sync(&self.device, SyncType::Wait)
}
}
#[allow(dead_code)]
fn bench<B: Backend>(
device: &B::Device,
feature_name: &str,
url: Option<&str>,
token: Option<&str>,
) {
let benchmark = ResNetBenchmark::<B> {
shape: [1, 3, 224, 224].into(),
device: device.clone(),
};
save::<B>(
vec![run_benchmark(benchmark)],
device,
feature_name,
url,
token,
)
.unwrap();
}
fn main() {
backend_comparison::bench_on_backend!();
}
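For context, `bench_on_backend!()` roughly expands to feature-gated calls into the `bench` function above, one per enabled backend. The snippet below is a minimal, hypothetical sketch of such a single-backend call; the `NdArray` backend, the default device, and the function name `run_ndarray_bench` are illustrative assumptions, not part of this commit:

```rs
// Hypothetical single-backend invocation (illustration only); the shipped bench
// relies on `backend_comparison::bench_on_backend!()` to generate this dispatch.
fn run_ndarray_bench() {
    use burn::backend::NdArray;

    // Assumes the crate is built with the `ndarray` feature enabled.
    let device = Default::default();
    // Passing `None` for url/token keeps the results local (no upload).
    bench::<NdArray>(&device, "ndarray", None, None);
}
```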

backend-comparison/build.rs (new file, 278 lines)
View File

@@ -0,0 +1,278 @@
use std::env;
use std::fs;
use std::path::Path;
use std::process::Command;
const MODELS_DIR: &str = "/tmp/models";
const MODELS_REPO: &str = "https://github.com/tracel-ai/models.git";
// Patch the ResNet code: strip the `pretrained`-gated weight-loading so the benchmark needs neither burn-import nor downloaded weights
const PATCH: &str = r#"diff --git a/resnet-burn/resnet/src/resnet.rs b/resnet-burn/resnet/src/resnet.rs
index e7f8787..3967049 100644
--- a/resnet-burn/resnet/src/resnet.rs
+++ b/resnet-burn/resnet/src/resnet.rs
@@ -12,13 +12,6 @@ use burn::{
use super::block::{LayerBlock, LayerBlockConfig};
-#[cfg(feature = "pretrained")]
-use {
- super::weights::{self, WeightsMeta},
- burn::record::{FullPrecisionSettings, Recorder, RecorderError},
- burn_import::pytorch::{LoadArgs, PyTorchFileRecorder},
-};
-
// ResNet residual layer block configs
const RESNET18_BLOCKS: [usize; 4] = [2, 2, 2, 2];
const RESNET34_BLOCKS: [usize; 4] = [3, 4, 6, 3];
@@ -77,29 +70,6 @@ impl<B: Backend> ResNet<B> {
ResNetConfig::new(RESNET18_BLOCKS, num_classes, 1).init(device)
}
- /// ResNet-18 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385)
- /// with pre-trained weights.
- ///
- /// # Arguments
- ///
- /// * `weights`: Pre-trained weights to load.
- /// * `device` - Device to create the module on.
- ///
- /// # Returns
- ///
- /// A ResNet-18 module with pre-trained weights.
- #[cfg(feature = "pretrained")]
- pub fn resnet18_pretrained(
- weights: weights::ResNet18,
- device: &Device<B>,
- ) -> Result<Self, RecorderError> {
- let weights = weights.weights();
- let record = Self::load_weights_record(&weights, device)?;
- let model = ResNet::<B>::resnet18(weights.num_classes, device).load_record(record);
-
- Ok(model)
- }
-
/// ResNet-34 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385).
///
/// # Arguments
@@ -114,29 +84,6 @@ impl<B: Backend> ResNet<B> {
ResNetConfig::new(RESNET34_BLOCKS, num_classes, 1).init(device)
}
- /// ResNet-34 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385)
- /// with pre-trained weights.
- ///
- /// # Arguments
- ///
- /// * `weights`: Pre-trained weights to load.
- /// * `device` - Device to create the module on.
- ///
- /// # Returns
- ///
- /// A ResNet-34 module with pre-trained weights.
- #[cfg(feature = "pretrained")]
- pub fn resnet34_pretrained(
- weights: weights::ResNet34,
- device: &Device<B>,
- ) -> Result<Self, RecorderError> {
- let weights = weights.weights();
- let record = Self::load_weights_record(&weights, device)?;
- let model = ResNet::<B>::resnet34(weights.num_classes, device).load_record(record);
-
- Ok(model)
- }
-
/// ResNet-50 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385).
///
/// # Arguments
@@ -151,29 +98,6 @@ impl<B: Backend> ResNet<B> {
ResNetConfig::new(RESNET50_BLOCKS, num_classes, 4).init(device)
}
- /// ResNet-50 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385)
- /// with pre-trained weights.
- ///
- /// # Arguments
- ///
- /// * `weights`: Pre-trained weights to load.
- /// * `device` - Device to create the module on.
- ///
- /// # Returns
- ///
- /// A ResNet-50 module with pre-trained weights.
- #[cfg(feature = "pretrained")]
- pub fn resnet50_pretrained(
- weights: weights::ResNet50,
- device: &Device<B>,
- ) -> Result<Self, RecorderError> {
- let weights = weights.weights();
- let record = Self::load_weights_record(&weights, device)?;
- let model = ResNet::<B>::resnet50(weights.num_classes, device).load_record(record);
-
- Ok(model)
- }
-
/// ResNet-101 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385).
///
/// # Arguments
@@ -188,29 +112,6 @@ impl<B: Backend> ResNet<B> {
ResNetConfig::new(RESNET101_BLOCKS, num_classes, 4).init(device)
}
- /// ResNet-101 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385)
- /// with pre-trained weights.
- ///
- /// # Arguments
- ///
- /// * `weights`: Pre-trained weights to load.
- /// * `device` - Device to create the module on.
- ///
- /// # Returns
- ///
- /// A ResNet-101 module with pre-trained weights.
- #[cfg(feature = "pretrained")]
- pub fn resnet101_pretrained(
- weights: weights::ResNet101,
- device: &Device<B>,
- ) -> Result<Self, RecorderError> {
- let weights = weights.weights();
- let record = Self::load_weights_record(&weights, device)?;
- let model = ResNet::<B>::resnet101(weights.num_classes, device).load_record(record);
-
- Ok(model)
- }
-
/// ResNet-152 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385).
///
/// # Arguments
@@ -225,29 +126,6 @@ impl<B: Backend> ResNet<B> {
ResNetConfig::new(RESNET152_BLOCKS, num_classes, 4).init(device)
}
- /// ResNet-152 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385)
- /// with pre-trained weights.
- ///
- /// # Arguments
- ///
- /// * `weights`: Pre-trained weights to load.
- /// * `device` - Device to create the module on.
- ///
- /// # Returns
- ///
- /// A ResNet-152 module with pre-trained weights.
- #[cfg(feature = "pretrained")]
- pub fn resnet152_pretrained(
- weights: weights::ResNet152,
- device: &Device<B>,
- ) -> Result<Self, RecorderError> {
- let weights = weights.weights();
- let record = Self::load_weights_record(&weights, device)?;
- let model = ResNet::<B>::resnet152(weights.num_classes, device).load_record(record);
-
- Ok(model)
- }
-
/// Re-initialize the last layer with the specified number of output classes.
pub fn with_classes(mut self, num_classes: usize) -> Self {
let [d_input, _d_output] = self.fc.weight.dims();
@@ -256,32 +134,6 @@ impl<B: Backend> ResNet<B> {
}
}
-#[cfg(feature = "pretrained")]
-impl<B: Backend> ResNet<B> {
- /// Load specified pre-trained PyTorch weights as a record.
- fn load_weights_record(
- weights: &weights::Weights,
- device: &Device<B>,
- ) -> Result<ResNetRecord<B>, RecorderError> {
- // Download torch weights
- let torch_weights = weights.download().map_err(|err| {
- RecorderError::Unknown(format!("Could not download weights.\nError: {err}"))
- })?;
-
- // Load weights from torch state_dict
- let load_args = LoadArgs::new(torch_weights)
- // Map *.downsample.0.* -> *.downsample.conv.*
- .with_key_remap("(.+)\\.downsample\\.0\\.(.+)", "$1.downsample.conv.$2")
- // Map *.downsample.1.* -> *.downsample.bn.*
- .with_key_remap("(.+)\\.downsample\\.1\\.(.+)", "$1.downsample.bn.$2")
- // Map layer[i].[j].* -> layer[i].blocks.[j].*
- .with_key_remap("(layer[1-4])\\.([0-9]+)\\.(.+)", "$1.blocks.$2.$3");
- let record = PyTorchFileRecorder::<FullPrecisionSettings>::new().load(load_args, device)?;
-
- Ok(record)
- }
-}
-
/// [ResNet](ResNet) configuration.
struct ResNetConfig {
conv1: Conv2dConfig,
"#;
fn run<F>(name: &str, mut configure: F)
where
F: FnMut(&mut Command) -> &mut Command,
{
let mut command = Command::new(name);
let configured = configure(&mut command);
println!("Executing {:?}", configured);
if !configured.status().unwrap().success() {
panic!("failed to execute {:?}", configured);
}
println!("Command {:?} finished successfully", configured);
}
fn main() {
// Checkout ResNet code from models repo
let models_dir = Path::new(MODELS_DIR);
if !models_dir.join(".git").exists() {
run("git", |command| {
command
.arg("clone")
.arg("--depth=1")
.arg("--no-checkout")
.arg(MODELS_REPO)
.arg(MODELS_DIR)
});
run("git", |command| {
command
.current_dir(models_dir)
.arg("sparse-checkout")
.arg("set")
.arg("resnet-burn")
});
run("git", |command| {
command.current_dir(models_dir).arg("checkout")
});
let patch_file = models_dir.join("benchmark.patch");
fs::write(&patch_file, PATCH).expect("should write to file successfully");
// Apply patch
run("git", |command| {
command
.current_dir(models_dir)
.arg("apply")
.arg(patch_file.to_str().unwrap())
});
}
// Copy the ResNet sources into Cargo's OUT_DIR, where benches/resnet.rs include!s them
let out_dir = env::var("OUT_DIR").unwrap();
let source_path = models_dir.join("resnet-burn").join("resnet").join("src");
let dest_path = Path::new(&out_dir);
for file in fs::read_dir(source_path).unwrap() {
let source_file = file.unwrap().path();
let dest_file = dest_path.join(source_file.file_name().unwrap());
fs::copy(source_file, dest_file).expect("should copy file successfully");
}
// Delete cloned repository contents
fs::remove_dir_all(models_dir.join(".git")).unwrap();
fs::remove_dir_all(models_dir).unwrap();
}
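Since the generated sources only exist under `OUT_DIR`, a quick sanity check can confirm that the copy step produced the files the `include!` calls in `benches/resnet.rs` rely on. The test below is a hypothetical sketch (not part of this commit); it only assumes that Cargo exports `OUT_DIR` at compile time for crates with a build script:

```rs
// Hypothetical sanity check (not in this commit): assert that build.rs copied
// the ResNet sources into OUT_DIR, where the benchmark include!s them.
#[cfg(test)]
mod build_output_tests {
    use std::path::Path;

    #[test]
    fn resnet_sources_are_generated() {
        let out_dir = Path::new(env!("OUT_DIR"));
        assert!(out_dir.join("resnet.rs").exists());
        assert!(out_dir.join("block.rs").exists());
    }
}
```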

View File

@@ -98,6 +98,8 @@ enum BenchmarkValues {
Unary,
#[strum(to_string = "max-pool2d")]
MaxPool2d,
#[strum(to_string = "resnet50")]
Resnet50,
#[strum(to_string = "load-record")]
LoadRecord,
#[strum(to_string = "autodiff")]