mirror of https://github.com/tracel-ai/burn.git
Add ResNet benchmark (#1534)
This commit is contained in:
parent
5631afb3a0
commit
81ec64a929
|
@ -103,6 +103,11 @@ name = "custom-gelu"
|
|||
path = "benches/custom_gelu.rs"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "resnet50"
|
||||
path = "benches/resnet.rs"
|
||||
harness = false
|
||||
|
||||
[[bench]]
|
||||
name = "autodiff"
|
||||
harness = false
|
||||
|
|
|
@ -43,6 +43,7 @@ Available Benchmarks:
|
|||
- custom-gelu
|
||||
- data
|
||||
- matmul
|
||||
- resnet50
|
||||
- unary
|
||||
```
|
||||
|
||||
|
@ -144,7 +145,7 @@ Then it must be registered in the `BenchmarkValues` enumeration:
|
|||
```rs
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, ValueEnum, Display, EnumIter)]
|
||||
pub(crate) enum BackendValues {
|
||||
pub(crate) enum BenchmarkValues {
|
||||
// ...
|
||||
#[strum(to_string = "mybench")]
|
||||
MyBench,
|
||||
|
|
|
@ -0,0 +1,73 @@
|
|||
use backend_comparison::persistence::save;
|
||||
use burn::tensor::{backend::Backend, Distribution, Shape, Tensor};
|
||||
use burn_common::benchmark::{run_benchmark, Benchmark};
|
||||
use cubecl::client::SyncType;
|
||||
|
||||
// Files retrieved during build to avoid reimplementing ResNet for benchmarks
|
||||
mod block {
|
||||
extern crate alloc;
|
||||
include!(concat!(env!("OUT_DIR"), "/block.rs"));
|
||||
}
|
||||
|
||||
mod model {
|
||||
include!(concat!(env!("OUT_DIR"), "/resnet.rs"));
|
||||
}
|
||||
|
||||
pub struct ResNetBenchmark<B: Backend> {
|
||||
shape: Shape<4>,
|
||||
device: B::Device,
|
||||
}
|
||||
|
||||
impl<B: Backend> Benchmark for ResNetBenchmark<B> {
|
||||
type Args = (model::ResNet<B>, Tensor<B, 4>);
|
||||
|
||||
fn name(&self) -> String {
|
||||
"resnet50".into()
|
||||
}
|
||||
|
||||
fn shapes(&self) -> Vec<Vec<usize>> {
|
||||
vec![self.shape.dims.into()]
|
||||
}
|
||||
|
||||
fn execute(&self, (model, input): Self::Args) {
|
||||
let _out = model.forward(input);
|
||||
}
|
||||
|
||||
fn prepare(&self) -> Self::Args {
|
||||
// 1k classes like ImageNet
|
||||
let model = model::ResNet::resnet50(1000, &self.device);
|
||||
let input = Tensor::random(self.shape.clone(), Distribution::Default, &self.device);
|
||||
|
||||
(model, input)
|
||||
}
|
||||
|
||||
fn sync(&self) {
|
||||
B::sync(&self.device, SyncType::Wait)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn bench<B: Backend>(
|
||||
device: &B::Device,
|
||||
feature_name: &str,
|
||||
url: Option<&str>,
|
||||
token: Option<&str>,
|
||||
) {
|
||||
let benchmark = ResNetBenchmark::<B> {
|
||||
shape: [1, 3, 224, 224].into(),
|
||||
device: device.clone(),
|
||||
};
|
||||
|
||||
save::<B>(
|
||||
vec![run_benchmark(benchmark)],
|
||||
device,
|
||||
feature_name,
|
||||
url,
|
||||
token,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
fn main() {
|
||||
backend_comparison::bench_on_backend!();
|
||||
}
|
|
@ -0,0 +1,278 @@
|
|||
use std::env;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::process::Command;
|
||||
|
||||
const MODELS_DIR: &str = "/tmp/models";
|
||||
const MODELS_REPO: &str = "https://github.com/tracel-ai/models.git";
|
||||
|
||||
// Patch resnet code (remove pretrained feature code)
|
||||
const PATCH: &str = r#"diff --git a/resnet-burn/resnet/src/resnet.rs b/resnet-burn/resnet/src/resnet.rs
|
||||
index e7f8787..3967049 100644
|
||||
--- a/resnet-burn/resnet/src/resnet.rs
|
||||
+++ b/resnet-burn/resnet/src/resnet.rs
|
||||
@@ -12,13 +12,6 @@ use burn::{
|
||||
|
||||
use super::block::{LayerBlock, LayerBlockConfig};
|
||||
|
||||
-#[cfg(feature = "pretrained")]
|
||||
-use {
|
||||
- super::weights::{self, WeightsMeta},
|
||||
- burn::record::{FullPrecisionSettings, Recorder, RecorderError},
|
||||
- burn_import::pytorch::{LoadArgs, PyTorchFileRecorder},
|
||||
-};
|
||||
-
|
||||
// ResNet residual layer block configs
|
||||
const RESNET18_BLOCKS: [usize; 4] = [2, 2, 2, 2];
|
||||
const RESNET34_BLOCKS: [usize; 4] = [3, 4, 6, 3];
|
||||
@@ -77,29 +70,6 @@ impl<B: Backend> ResNet<B> {
|
||||
ResNetConfig::new(RESNET18_BLOCKS, num_classes, 1).init(device)
|
||||
}
|
||||
|
||||
- /// ResNet-18 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385)
|
||||
- /// with pre-trained weights.
|
||||
- ///
|
||||
- /// # Arguments
|
||||
- ///
|
||||
- /// * `weights`: Pre-trained weights to load.
|
||||
- /// * `device` - Device to create the module on.
|
||||
- ///
|
||||
- /// # Returns
|
||||
- ///
|
||||
- /// A ResNet-18 module with pre-trained weights.
|
||||
- #[cfg(feature = "pretrained")]
|
||||
- pub fn resnet18_pretrained(
|
||||
- weights: weights::ResNet18,
|
||||
- device: &Device<B>,
|
||||
- ) -> Result<Self, RecorderError> {
|
||||
- let weights = weights.weights();
|
||||
- let record = Self::load_weights_record(&weights, device)?;
|
||||
- let model = ResNet::<B>::resnet18(weights.num_classes, device).load_record(record);
|
||||
-
|
||||
- Ok(model)
|
||||
- }
|
||||
-
|
||||
/// ResNet-34 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385).
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -114,29 +84,6 @@ impl<B: Backend> ResNet<B> {
|
||||
ResNetConfig::new(RESNET34_BLOCKS, num_classes, 1).init(device)
|
||||
}
|
||||
|
||||
- /// ResNet-34 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385)
|
||||
- /// with pre-trained weights.
|
||||
- ///
|
||||
- /// # Arguments
|
||||
- ///
|
||||
- /// * `weights`: Pre-trained weights to load.
|
||||
- /// * `device` - Device to create the module on.
|
||||
- ///
|
||||
- /// # Returns
|
||||
- ///
|
||||
- /// A ResNet-34 module with pre-trained weights.
|
||||
- #[cfg(feature = "pretrained")]
|
||||
- pub fn resnet34_pretrained(
|
||||
- weights: weights::ResNet34,
|
||||
- device: &Device<B>,
|
||||
- ) -> Result<Self, RecorderError> {
|
||||
- let weights = weights.weights();
|
||||
- let record = Self::load_weights_record(&weights, device)?;
|
||||
- let model = ResNet::<B>::resnet34(weights.num_classes, device).load_record(record);
|
||||
-
|
||||
- Ok(model)
|
||||
- }
|
||||
-
|
||||
/// ResNet-50 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385).
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -151,29 +98,6 @@ impl<B: Backend> ResNet<B> {
|
||||
ResNetConfig::new(RESNET50_BLOCKS, num_classes, 4).init(device)
|
||||
}
|
||||
|
||||
- /// ResNet-50 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385)
|
||||
- /// with pre-trained weights.
|
||||
- ///
|
||||
- /// # Arguments
|
||||
- ///
|
||||
- /// * `weights`: Pre-trained weights to load.
|
||||
- /// * `device` - Device to create the module on.
|
||||
- ///
|
||||
- /// # Returns
|
||||
- ///
|
||||
- /// A ResNet-50 module with pre-trained weights.
|
||||
- #[cfg(feature = "pretrained")]
|
||||
- pub fn resnet50_pretrained(
|
||||
- weights: weights::ResNet50,
|
||||
- device: &Device<B>,
|
||||
- ) -> Result<Self, RecorderError> {
|
||||
- let weights = weights.weights();
|
||||
- let record = Self::load_weights_record(&weights, device)?;
|
||||
- let model = ResNet::<B>::resnet50(weights.num_classes, device).load_record(record);
|
||||
-
|
||||
- Ok(model)
|
||||
- }
|
||||
-
|
||||
/// ResNet-101 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385).
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -188,29 +112,6 @@ impl<B: Backend> ResNet<B> {
|
||||
ResNetConfig::new(RESNET101_BLOCKS, num_classes, 4).init(device)
|
||||
}
|
||||
|
||||
- /// ResNet-101 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385)
|
||||
- /// with pre-trained weights.
|
||||
- ///
|
||||
- /// # Arguments
|
||||
- ///
|
||||
- /// * `weights`: Pre-trained weights to load.
|
||||
- /// * `device` - Device to create the module on.
|
||||
- ///
|
||||
- /// # Returns
|
||||
- ///
|
||||
- /// A ResNet-101 module with pre-trained weights.
|
||||
- #[cfg(feature = "pretrained")]
|
||||
- pub fn resnet101_pretrained(
|
||||
- weights: weights::ResNet101,
|
||||
- device: &Device<B>,
|
||||
- ) -> Result<Self, RecorderError> {
|
||||
- let weights = weights.weights();
|
||||
- let record = Self::load_weights_record(&weights, device)?;
|
||||
- let model = ResNet::<B>::resnet101(weights.num_classes, device).load_record(record);
|
||||
-
|
||||
- Ok(model)
|
||||
- }
|
||||
-
|
||||
/// ResNet-152 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385).
|
||||
///
|
||||
/// # Arguments
|
||||
@@ -225,29 +126,6 @@ impl<B: Backend> ResNet<B> {
|
||||
ResNetConfig::new(RESNET152_BLOCKS, num_classes, 4).init(device)
|
||||
}
|
||||
|
||||
- /// ResNet-152 from [`Deep Residual Learning for Image Recognition`](https://arxiv.org/abs/1512.03385)
|
||||
- /// with pre-trained weights.
|
||||
- ///
|
||||
- /// # Arguments
|
||||
- ///
|
||||
- /// * `weights`: Pre-trained weights to load.
|
||||
- /// * `device` - Device to create the module on.
|
||||
- ///
|
||||
- /// # Returns
|
||||
- ///
|
||||
- /// A ResNet-152 module with pre-trained weights.
|
||||
- #[cfg(feature = "pretrained")]
|
||||
- pub fn resnet152_pretrained(
|
||||
- weights: weights::ResNet152,
|
||||
- device: &Device<B>,
|
||||
- ) -> Result<Self, RecorderError> {
|
||||
- let weights = weights.weights();
|
||||
- let record = Self::load_weights_record(&weights, device)?;
|
||||
- let model = ResNet::<B>::resnet152(weights.num_classes, device).load_record(record);
|
||||
-
|
||||
- Ok(model)
|
||||
- }
|
||||
-
|
||||
/// Re-initialize the last layer with the specified number of output classes.
|
||||
pub fn with_classes(mut self, num_classes: usize) -> Self {
|
||||
let [d_input, _d_output] = self.fc.weight.dims();
|
||||
@@ -256,32 +134,6 @@ impl<B: Backend> ResNet<B> {
|
||||
}
|
||||
}
|
||||
|
||||
-#[cfg(feature = "pretrained")]
|
||||
-impl<B: Backend> ResNet<B> {
|
||||
- /// Load specified pre-trained PyTorch weights as a record.
|
||||
- fn load_weights_record(
|
||||
- weights: &weights::Weights,
|
||||
- device: &Device<B>,
|
||||
- ) -> Result<ResNetRecord<B>, RecorderError> {
|
||||
- // Download torch weights
|
||||
- let torch_weights = weights.download().map_err(|err| {
|
||||
- RecorderError::Unknown(format!("Could not download weights.\nError: {err}"))
|
||||
- })?;
|
||||
-
|
||||
- // Load weights from torch state_dict
|
||||
- let load_args = LoadArgs::new(torch_weights)
|
||||
- // Map *.downsample.0.* -> *.downsample.conv.*
|
||||
- .with_key_remap("(.+)\\.downsample\\.0\\.(.+)", "$1.downsample.conv.$2")
|
||||
- // Map *.downsample.1.* -> *.downsample.bn.*
|
||||
- .with_key_remap("(.+)\\.downsample\\.1\\.(.+)", "$1.downsample.bn.$2")
|
||||
- // Map layer[i].[j].* -> layer[i].blocks.[j].*
|
||||
- .with_key_remap("(layer[1-4])\\.([0-9]+)\\.(.+)", "$1.blocks.$2.$3");
|
||||
- let record = PyTorchFileRecorder::<FullPrecisionSettings>::new().load(load_args, device)?;
|
||||
-
|
||||
- Ok(record)
|
||||
- }
|
||||
-}
|
||||
-
|
||||
/// [ResNet](ResNet) configuration.
|
||||
struct ResNetConfig {
|
||||
conv1: Conv2dConfig,
|
||||
"#;
|
||||
|
||||
fn run<F>(name: &str, mut configure: F)
|
||||
where
|
||||
F: FnMut(&mut Command) -> &mut Command,
|
||||
{
|
||||
let mut command = Command::new(name);
|
||||
let configured = configure(&mut command);
|
||||
println!("Executing {:?}", configured);
|
||||
if !configured.status().unwrap().success() {
|
||||
panic!("failed to execute {:?}", configured);
|
||||
}
|
||||
println!("Command {:?} finished successfully", configured);
|
||||
}
|
||||
|
||||
fn main() {
|
||||
// Checkout ResNet code from models repo
|
||||
let models_dir = Path::new(MODELS_DIR);
|
||||
if !models_dir.join(".git").exists() {
|
||||
run("git", |command| {
|
||||
command
|
||||
.arg("clone")
|
||||
.arg("--depth=1")
|
||||
.arg("--no-checkout")
|
||||
.arg(MODELS_REPO)
|
||||
.arg(MODELS_DIR)
|
||||
});
|
||||
|
||||
run("git", |command| {
|
||||
command
|
||||
.current_dir(models_dir)
|
||||
.arg("sparse-checkout")
|
||||
.arg("set")
|
||||
.arg("resnet-burn")
|
||||
});
|
||||
|
||||
run("git", |command| {
|
||||
command.current_dir(models_dir).arg("checkout")
|
||||
});
|
||||
|
||||
let patch_file = models_dir.join("benchmark.patch");
|
||||
|
||||
fs::write(&patch_file, PATCH).expect("should write to file successfully");
|
||||
|
||||
// Apply patch
|
||||
run("git", |command| {
|
||||
command
|
||||
.current_dir(models_dir)
|
||||
.arg("apply")
|
||||
.arg(patch_file.to_str().unwrap())
|
||||
});
|
||||
}
|
||||
|
||||
// Copy contents to output dir
|
||||
let out_dir = env::var("OUT_DIR").unwrap();
|
||||
let source_path = models_dir.join("resnet-burn").join("resnet").join("src");
|
||||
let dest_path = Path::new(&out_dir);
|
||||
|
||||
for file in fs::read_dir(source_path).unwrap() {
|
||||
let source_file = file.unwrap().path();
|
||||
let dest_file = dest_path.join(source_file.file_name().unwrap());
|
||||
fs::copy(source_file, dest_file).expect("should copy file successfully");
|
||||
}
|
||||
|
||||
// Delete cloned repository contents
|
||||
fs::remove_dir_all(models_dir.join(".git")).unwrap();
|
||||
fs::remove_dir_all(models_dir).unwrap();
|
||||
}
|
|
@ -98,6 +98,8 @@ enum BenchmarkValues {
|
|||
Unary,
|
||||
#[strum(to_string = "max-pool2d")]
|
||||
MaxPool2d,
|
||||
#[strum(to_string = "resnet50")]
|
||||
Resnet50,
|
||||
#[strum(to_string = "load-record")]
|
||||
LoadRecord,
|
||||
#[strum(to_string = "autodiff")]
|
||||
|
|
Loading…
Reference in New Issue