diff --git a/candle-examples/Cargo.toml b/candle-examples/Cargo.toml index 5b90f140..fa5c620a 100644 --- a/candle-examples/Cargo.toml +++ b/candle-examples/Cargo.toml @@ -25,6 +25,8 @@ hf-hub = { workspace = true, features = ["tokio"] } image = { workspace = true } intel-mkl-src = { workspace = true, optional = true } num-traits = { workspace = true } +palette = { version = "0.7.6", optional = true } +enterpolation = { version = "0.2.1", optional = true} pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true } rayon = { workspace = true } rubato = { version = "0.15.0", optional = true } @@ -65,6 +67,7 @@ onnx = ["candle-onnx"] metal = ["candle/metal", "candle-nn/metal"] microphone = ["cpal"] encodec = ["cpal", "symphonia", "rubato"] +depth_anything_v2 = ["palette", "enterpolation"] [[example]] name = "llama_multiprocess" @@ -101,3 +104,7 @@ required-features = ["candle-datasets"] [[example]] name = "encodec" required-features = ["encodec"] + +[[example]] +name = "depth_anything_v2" +required-features = ["depth_anything_v2"] diff --git a/candle-examples/examples/depth_anything_v2/README.md b/candle-examples/examples/depth_anything_v2/README.md new file mode 100644 index 00000000..163b398b --- /dev/null +++ b/candle-examples/examples/depth_anything_v2/README.md @@ -0,0 +1,13 @@ +# candle-dinov2 + +[Depth Anything V2] is a model for Monocular Depth Estimation (MDE, i.e. just using a single image) which +builds on the [DINOv2](https://github.com/facebookresearch/dinov2) vision transformer. + +This example first instantiates the DINOv2 model and then proceeds to create DepthAnythingV2 and run it. + +## Running an example with color map and CUDA + +```bash +cargo run --features cuda,depth_anything_v2 --package candle-examples --example depth_anything_v2 -- --color-map --image candle-examples/examples/yolo-v8/assets/bike.jpg +``` + diff --git a/candle-examples/examples/depth_anything_v2/color_map.rs b/candle-examples/examples/depth_anything_v2/color_map.rs new file mode 100644 index 00000000..94be326f --- /dev/null +++ b/candle-examples/examples/depth_anything_v2/color_map.rs @@ -0,0 +1,50 @@ +use enterpolation::linear::ConstEquidistantLinear; +use enterpolation::Generator; +use palette::LinSrgb; + +use candle::Tensor; + +pub struct SpectralRColormap { + gradient: ConstEquidistantLinear, +} + +impl SpectralRColormap { + pub(crate) fn new() -> Self { + // Define a colormap similar to 'Spectral_r' by specifying key colors. + // got the colors from ChatGPT-4o + let gradient = ConstEquidistantLinear::::equidistant_unchecked([ + LinSrgb::new(0.3686, 0.3098, 0.6353), // Dark blue + LinSrgb::new(0.1961, 0.5333, 0.7412), // Blue + LinSrgb::new(0.4000, 0.7608, 0.6471), // Cyan + LinSrgb::new(0.6706, 0.8667, 0.6431), // Green + LinSrgb::new(0.9020, 0.9608, 0.5961), // Yellow + LinSrgb::new(0.9961, 0.8784, 0.5451), // Orange + LinSrgb::new(0.9922, 0.6824, 0.3804), // Red + LinSrgb::new(0.9569, 0.4275, 0.2627), // Dark red + LinSrgb::new(0.8353, 0.2431, 0.3098), // Dark purple + ]); + Self { gradient } + } + + fn get_color(&self, value: f32) -> LinSrgb { + self.gradient.gen(value) + } + + pub fn gray2color(&self, gray: &Tensor) -> candle::Result { + println!("Gray: {:?}", gray.dims()); + let gray_values: Vec = gray.flatten_all()?.to_vec1()?; + let rgb_values: Vec = gray_values + .iter() + .map(|g| self.get_color(*g)) + .flat_map(|rgb| [rgb.red, rgb.green, rgb.blue]) + .collect(); + + let [.., height, width] = gray.dims() else { + candle::bail!("Not enough dims!") + }; + + let color = Tensor::from_vec(rgb_values, (*height, *width, 3), gray.device())?; + + color.permute((2, 0, 1)) + } +} diff --git a/candle-examples/examples/depth_anything_v2/main.rs b/candle-examples/examples/depth_anything_v2/main.rs new file mode 100644 index 00000000..ef337eba --- /dev/null +++ b/candle-examples/examples/depth_anything_v2/main.rs @@ -0,0 +1,187 @@ +//! Depth Anything V2 +//! https://huggingface.co/spaces/depth-anything/Depth-Anything-V2 + +#[cfg(feature = "accelerate")] +extern crate accelerate_src; +#[cfg(feature = "mkl")] +extern crate intel_mkl_src; + +use std::ffi::OsString; +use std::path::PathBuf; + +use clap::Parser; + +use candle::DType::{F32, U8}; +use candle::{DType, Device, Module, Result, Tensor}; +use candle_examples::{load_image, load_image_and_resize, save_image}; +use candle_nn::VarBuilder; +use candle_transformers::models::depth_anything_v2::{DepthAnythingV2, DepthAnythingV2Config}; +use candle_transformers::models::dinov2; + +use crate::color_map::SpectralRColormap; + +mod color_map; + +// taken these from: https://huggingface.co/spaces/depth-anything/Depth-Anything-V2/blob/main/depth_anything_v2/dpt.py#L207 +const MAGIC_MEAN: [f32; 3] = [0.485, 0.456, 0.406]; +const MAGIC_STD: [f32; 3] = [0.229, 0.224, 0.225]; + +const DINO_IMG_SIZE: usize = 518; + +#[derive(Parser)] +struct Args { + #[arg(long)] + dinov2_model: Option, + + #[arg(long)] + depth_anything_v2_model: Option, + + #[arg(long)] + image: PathBuf, + + #[arg(long)] + output_dir: Option, + + #[arg(long)] + cpu: bool, + + #[arg(long)] + color_map: bool, +} + +pub fn main() -> anyhow::Result<()> { + let args = Args::parse(); + let device = candle_examples::device(args.cpu)?; + + let dinov2_model_file = match args.dinov2_model { + None => { + let api = hf_hub::api::sync::Api::new()?; + let api = api.model("lmz/candle-dino-v2".into()); + api.get("dinov2_vits14.safetensors")? + } + Some(dinov2_model) => dinov2_model, + }; + println!("Using file {:?}", dinov2_model_file); + + let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[dinov2_model_file], F32, &device)? }; + let dinov2 = dinov2::vit_small(vb)?; + println!("DinoV2 model built"); + + let depth_anything_model_file = match args.depth_anything_v2_model { + None => { + let api = hf_hub::api::sync::Api::new()?; + let api = api.model("jeroenvlek/depth-anything-v2-safetensors".into()); + api.get("depth_anything_v2_vits.safetensors")? + } + Some(depth_anything_model) => depth_anything_model, + }; + println!("Using file {:?}", depth_anything_model_file); + + let vb = unsafe { + VarBuilder::from_mmaped_safetensors(&[depth_anything_model_file], DType::F32, &device)? + }; + + let config = DepthAnythingV2Config::vit_small(); + let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?; + + let (original_height, original_width, image) = load_and_prep_image(&args.image, &device)?; + + println!("Loaded image {image:?}"); + + let depth = depth_anything.forward(&image)?; + + println!("Got predictions {:?}", depth.shape()); + + let output_image = post_process_image(&depth, original_height, original_width, args.color_map)?; + + let output_path = full_output_path(&args.image, &args.output_dir); + println!("Saving image to {}", output_path.to_string_lossy()); + save_image(&output_image, output_path)?; + + Ok(()) +} + +fn full_output_path(image_path: &PathBuf, output_dir: &Option) -> PathBuf { + let input_file_name = image_path.file_name().unwrap(); + let mut output_file_name = OsString::from("depth_"); + output_file_name.push(input_file_name); + let mut output_path = match output_dir { + None => image_path.parent().unwrap().to_path_buf(), + Some(output_path) => output_path.clone(), + }; + output_path.push(output_file_name); + + output_path +} + +fn load_and_prep_image( + image_path: &PathBuf, + device: &Device, +) -> anyhow::Result<(usize, usize, Tensor)> { + let (_original_image, original_height, original_width) = load_image(&image_path, None)?; + + let image = load_image_and_resize(&image_path, DINO_IMG_SIZE, DINO_IMG_SIZE)? + .unsqueeze(0)? + .to_dtype(F32)? + .to_device(&device)?; + + let max_pixel_val = Tensor::try_from(255.0f32)? + .to_device(&device)? + .broadcast_as(image.shape())?; + let image = (image / max_pixel_val)?; + let image = normalize_image(&image, &MAGIC_MEAN, &MAGIC_STD)?; + + Ok((original_height, original_width, image)) +} + +fn normalize_image(image: &Tensor, mean: &[f32; 3], std: &[f32; 3]) -> Result { + let mean_tensor = + Tensor::from_vec(mean.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?; + let std_tensor = + Tensor::from_vec(std.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?; + image.sub(&mean_tensor)?.div(&std_tensor) +} + +fn post_process_image( + image: &Tensor, + original_height: usize, + original_width: usize, + color_map: bool, +) -> Result { + let out = image.interpolate2d(original_height, original_width)?; + let out = scale_image(&out)?; + + let out = if color_map { + let spectral_r = SpectralRColormap::new(); + spectral_r.gray2color(&out)? + } else { + let rgb_slice = [&out, &out, &out]; + Tensor::cat(&rgb_slice, 0)?.squeeze(1)? + }; + + let max_pixel_val = Tensor::try_from(255.0f32)? + .to_device(out.device())? + .broadcast_as(out.shape())?; + let out = (out * max_pixel_val)?; + + out.to_dtype(U8) +} + +fn scale_image(depth: &Tensor) -> Result { + let flat_values: Vec = depth.flatten_all()?.to_vec1()?; + + let min_val = flat_values.iter().min_by(|a, b| a.total_cmp(b)).unwrap(); + let max_val = flat_values.iter().max_by(|a, b| a.total_cmp(b)).unwrap(); + + let min_val_tensor = Tensor::try_from(*min_val)? + .to_device(depth.device())? + .broadcast_as(depth.shape())?; + let depth = (depth - min_val_tensor)?; + + let range = max_val - min_val; + let range_tensor = Tensor::try_from(range)? + .to_device(depth.device())? + .broadcast_as(depth.shape())?; + + depth / range_tensor +} diff --git a/candle-nn/src/ops.rs b/candle-nn/src/ops.rs index 2a76ee5e..9a360c47 100644 --- a/candle-nn/src/ops.rs +++ b/candle-nn/src/ops.rs @@ -1,4 +1,4 @@ -use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor, D}; +use candle::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor, D}; use rayon::prelude::*; /// Applies the softmax function to the input tensor, rescaling the element so that elements on @@ -926,3 +926,24 @@ pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result { n => candle::bail!("replication-pad with a size of {n} is not supported"), } } + +#[derive(Clone, Debug)] +pub struct Identity; + +impl Identity { + pub fn new() -> Identity { + Self + } +} + +impl Default for Identity { + fn default() -> Self { + Self + } +} + +impl Module for Identity { + fn forward(&self, xs: &Tensor) -> Result { + Ok(xs.clone()) + } +} diff --git a/candle-transformers/src/models/depth_anything_v2.rs b/candle-transformers/src/models/depth_anything_v2.rs new file mode 100644 index 00000000..9eee6d11 --- /dev/null +++ b/candle-transformers/src/models/depth_anything_v2.rs @@ -0,0 +1,553 @@ +use candle::D::Minus1; +use candle::{Module, Result, Tensor}; +use candle_nn::ops::Identity; +use candle_nn::{ + batch_norm, conv2d, conv2d_no_bias, conv_transpose2d, linear, seq, Activation, BatchNorm, + BatchNormConfig, Conv2d, Conv2dConfig, ConvTranspose2dConfig, Sequential, VarBuilder, +}; + +use crate::models::dinov2::DinoVisionTransformer; + +pub struct DepthAnythingV2Config { + out_channel_sizes: [usize; 4], + in_channel_size: usize, // embed_dim in the Dino model + num_features: usize, + use_batch_norm: bool, + use_class_token: bool, + layer_ids_vits: Vec, + input_image_size: usize, + target_patch_size: usize, +} + +impl DepthAnythingV2Config { + #[allow(clippy::too_many_arguments)] + pub fn new( + out_channel_sizes: [usize; 4], + in_channel_size: usize, + num_features: usize, + use_batch_norm: bool, + use_class_token: bool, + layer_ids_vits: Vec, + input_image_size: usize, + target_patch_size: usize, + ) -> Self { + Self { + out_channel_sizes, + in_channel_size, + num_features, + use_batch_norm, + use_class_token, + layer_ids_vits, + input_image_size, + target_patch_size, + } + } + + pub fn vit_small() -> Self { + Self { + out_channel_sizes: [48, 96, 192, 384], + in_channel_size: 384, + num_features: 64, + use_batch_norm: false, + use_class_token: false, + layer_ids_vits: vec![2, 5, 8, 11], + input_image_size: 518, + target_patch_size: 518 / 14, + } + } + + pub fn vit_base() -> Self { + Self { + out_channel_sizes: [96, 192, 384, 768], + in_channel_size: 768, + num_features: 128, + use_batch_norm: false, + use_class_token: false, + layer_ids_vits: vec![2, 5, 8, 11], + input_image_size: 518, + target_patch_size: 518 / 14, + } + } + + pub fn vit_large() -> Self { + Self { + out_channel_sizes: [256, 512, 1024, 1024], + in_channel_size: 1024, + num_features: 256, + use_batch_norm: false, + use_class_token: false, + layer_ids_vits: vec![4, 11, 17, 23], + input_image_size: 518, + target_patch_size: 518 / 14, + } + } + + pub fn vit_giant() -> Self { + Self { + out_channel_sizes: [1536, 1536, 1536, 1536], + in_channel_size: 1536, + num_features: 384, + use_batch_norm: false, + use_class_token: false, + layer_ids_vits: vec![9, 19, 29, 39], + input_image_size: 518, + target_patch_size: 518 / 14, + } + } +} + +pub struct ResidualConvUnit { + activation: Activation, + conv1: Conv2d, + conv2: Conv2d, + batch_norm1: Option, + batch_norm2: Option, +} + +impl ResidualConvUnit { + pub fn new( + conf: &DepthAnythingV2Config, + activation: Activation, + vb: VarBuilder, + ) -> Result { + const KERNEL_SIZE: usize = 3; + let conv_cfg = Conv2dConfig { + padding: 1, + stride: 1, + dilation: 1, + groups: 1, + }; + let conv1 = conv2d( + conf.num_features, + conf.num_features, + KERNEL_SIZE, + conv_cfg, + vb.pp("conv1"), + )?; + let conv2 = conv2d( + conf.num_features, + conf.num_features, + KERNEL_SIZE, + conv_cfg, + vb.pp("conv2"), + )?; + + let (batch_norm1, batch_norm2) = match conf.use_batch_norm { + true => { + let batch_norm_cfg = BatchNormConfig { + eps: 1e-05, + remove_mean: false, + affine: true, + momentum: 0.1, + }; + ( + Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn1"))?), + Some(batch_norm(conf.num_features, batch_norm_cfg, vb.pp("bn2"))?), + ) + } + false => (None, None), + }; + + Ok(Self { + activation, + conv1, + conv2, + batch_norm1, + batch_norm2, + }) + } +} + +impl Module for ResidualConvUnit { + fn forward(&self, xs: &Tensor) -> Result { + let out = self.activation.forward(xs)?; + let out = self.conv1.forward(&out)?; + let out = if let Some(batch_norm1) = &self.batch_norm1 { + batch_norm1.forward_train(&out)? + } else { + out + }; + + let out = self.activation.forward(&out)?; + let out = self.conv2.forward(&out)?; + let out = if let Some(batch_norm2) = &self.batch_norm2 { + batch_norm2.forward_train(&out)? + } else { + out + }; + + out + xs + } +} + +pub struct FeatureFusionBlock { + res_conv_unit1: ResidualConvUnit, + res_conv_unit2: ResidualConvUnit, + output_conv: Conv2d, + target_patch_size: usize, +} + +impl FeatureFusionBlock { + pub fn new( + conf: &DepthAnythingV2Config, + target_patch_size: usize, + activation: Activation, + vb: VarBuilder, + ) -> Result { + const KERNEL_SIZE: usize = 1; + let conv_cfg = Conv2dConfig { + padding: 0, + stride: 1, + dilation: 1, + groups: 1, + }; + let output_conv = conv2d( + conf.num_features, + conf.num_features, + KERNEL_SIZE, + conv_cfg, + vb.pp("out_conv"), + )?; + let res_conv_unit1 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit1"))?; + let res_conv_unit2 = ResidualConvUnit::new(conf, activation, vb.pp("resConfUnit2"))?; + + Ok(Self { + res_conv_unit1, + res_conv_unit2, + output_conv, + target_patch_size, + }) + } +} + +impl Module for FeatureFusionBlock { + fn forward(&self, xs: &Tensor) -> Result { + let out = self.res_conv_unit2.forward(xs)?; + let out = out.interpolate2d(self.target_patch_size, self.target_patch_size)?; + + self.output_conv.forward(&out) + } +} + +pub struct Scratch { + layer1_rn: Conv2d, + layer2_rn: Conv2d, + layer3_rn: Conv2d, + layer4_rn: Conv2d, + refine_net1: FeatureFusionBlock, + refine_net2: FeatureFusionBlock, + refine_net3: FeatureFusionBlock, + refine_net4: FeatureFusionBlock, + output_conv1: Conv2d, + output_conv2: Sequential, +} + +impl Scratch { + pub fn new(conf: &DepthAnythingV2Config, vb: VarBuilder) -> Result { + const KERNEL_SIZE: usize = 3; + let conv_cfg = Conv2dConfig { + padding: 1, + stride: 1, + dilation: 1, + groups: 1, + }; + + let layer1_rn = conv2d_no_bias( + conf.out_channel_sizes[0], + conf.num_features, + KERNEL_SIZE, + conv_cfg, + vb.pp("layer1_rn"), + )?; + let layer2_rn = conv2d_no_bias( + conf.out_channel_sizes[1], + conf.num_features, + KERNEL_SIZE, + conv_cfg, + vb.pp("layer2_rn"), + )?; + let layer3_rn = conv2d_no_bias( + conf.out_channel_sizes[2], + conf.num_features, + KERNEL_SIZE, + conv_cfg, + vb.pp("layer3_rn"), + )?; + let layer4_rn = conv2d_no_bias( + conf.out_channel_sizes[3], + conf.num_features, + KERNEL_SIZE, + conv_cfg, + vb.pp("layer4_rn"), + )?; + + let refine_net1 = FeatureFusionBlock::new( + conf, + conf.target_patch_size * 8, + Activation::Relu, + vb.pp("refinenet1"), + )?; + let refine_net2 = FeatureFusionBlock::new( + conf, + conf.target_patch_size * 4, + Activation::Relu, + vb.pp("refinenet2"), + )?; + let refine_net3 = FeatureFusionBlock::new( + conf, + conf.target_patch_size * 2, + Activation::Relu, + vb.pp("refinenet3"), + )?; + let refine_net4 = FeatureFusionBlock::new( + conf, + conf.target_patch_size, + Activation::Relu, + vb.pp("refinenet4"), + )?; + + let conv_cfg = Conv2dConfig { + padding: 1, + stride: 1, + dilation: 1, + groups: 1, + }; + let output_conv1 = conv2d( + conf.num_features, + conf.num_features / 2, + KERNEL_SIZE, + conv_cfg, + vb.pp("output_conv1"), + )?; + + let output_conv2 = seq(); + const HEAD_FEATURES_2: usize = 32; + const OUT_CHANNELS_2: usize = 1; + const KERNEL_SIZE_2: usize = 1; + let output_conv2 = output_conv2.add(conv2d( + conf.num_features / 2, + HEAD_FEATURES_2, + KERNEL_SIZE, + conv_cfg, + vb.pp("output_conv2").pp("0"), + )?); + let output_conv2 = output_conv2 + .add(Activation::Relu) + .add(conv2d( + HEAD_FEATURES_2, + OUT_CHANNELS_2, + KERNEL_SIZE_2, + conv_cfg, + vb.pp("output_conv2").pp("2"), + )?) + .add(Activation::Relu); + + Ok(Self { + layer1_rn, + layer2_rn, + layer3_rn, + layer4_rn, + refine_net1, + refine_net2, + refine_net3, + refine_net4, + output_conv1, + output_conv2, + }) + } +} + +const NUM_CHANNELS: usize = 4; + +pub struct DPTHead<'a> { + conf: &'a DepthAnythingV2Config, + projections: Vec, + resize_layers: Vec>, + readout_projections: Vec, + scratch: Scratch, +} + +impl<'a> DPTHead<'a> { + pub fn new(conf: &'a DepthAnythingV2Config, vb: VarBuilder) -> Result { + let mut projections: Vec = Vec::with_capacity(conf.out_channel_sizes.len()); + for (conv_index, out_channel_size) in conf.out_channel_sizes.iter().enumerate() { + projections.push(conv2d( + conf.in_channel_size, + *out_channel_size, + 1, + Default::default(), + vb.pp("projects").pp(conv_index.to_string()), + )?); + } + + let resize_layers: Vec> = vec![ + Box::new(conv_transpose2d( + conf.out_channel_sizes[0], + conf.out_channel_sizes[0], + 4, + ConvTranspose2dConfig { + padding: 0, + stride: 4, + dilation: 1, + output_padding: 0, + }, + vb.pp("resize_layers").pp("0"), + )?), + Box::new(conv_transpose2d( + conf.out_channel_sizes[1], + conf.out_channel_sizes[1], + 2, + ConvTranspose2dConfig { + padding: 0, + stride: 2, + dilation: 1, + output_padding: 0, + }, + vb.pp("resize_layers").pp("1"), + )?), + Box::new(Identity::new()), + Box::new(conv2d( + conf.out_channel_sizes[3], + conf.out_channel_sizes[3], + 3, + Conv2dConfig { + padding: 1, + stride: 2, + dilation: 1, + groups: 1, + }, + vb.pp("resize_layers").pp("3"), + )?), + ]; + + let readout_projections = if conf.use_class_token { + let rop = Vec::with_capacity(NUM_CHANNELS); + for rop_index in 0..NUM_CHANNELS { + seq() + .add(linear( + 2 * conf.in_channel_size, + conf.in_channel_size, + vb.pp("readout_projects").pp(rop_index.to_string()), + )?) + .add(Activation::Gelu); + } + rop + } else { + vec![] + }; + + let scratch = Scratch::new(conf, vb.pp("scratch"))?; + + Ok(Self { + conf, + projections, + resize_layers, + readout_projections, + scratch, + }) + } +} + +impl Module for DPTHead<'_> { + fn forward(&self, xs: &Tensor) -> Result { + let mut out: Vec = Vec::with_capacity(NUM_CHANNELS); + for i in 0..NUM_CHANNELS { + let x = if self.conf.use_class_token { + let x = xs.get(i)?.get(0)?; + let class_token = xs.get(i)?.get(1)?; + let readout = class_token.unsqueeze(1)?.expand(x.shape())?; + let to_cat = [x, readout]; + let cat = Tensor::cat(&to_cat, Minus1)?; + self.readout_projections[i].forward(&cat)? + } else { + xs.get(i)? + }; + let x_dims = x.dims(); + + let x = x.permute((0, 2, 1))?.reshape(( + x_dims[0], + x_dims[x_dims.len() - 1], + self.conf.target_patch_size, + self.conf.target_patch_size, + ))?; + let x = self.projections[i].forward(&x)?; + + let x = self.resize_layers[i].forward(&x)?; + out.push(x); + } + + let layer_1_rn = self.scratch.layer1_rn.forward(&out[0])?; + let layer_2_rn = self.scratch.layer2_rn.forward(&out[1])?; + let layer_3_rn = self.scratch.layer3_rn.forward(&out[2])?; + let layer_4_rn = self.scratch.layer4_rn.forward(&out[3])?; + + let path4 = self.scratch.refine_net4.forward(&layer_4_rn)?; + + let res3_out = self + .scratch + .refine_net3 + .res_conv_unit1 + .forward(&layer_3_rn)?; + let res3_out = path4.add(&res3_out)?; + let path3 = self.scratch.refine_net3.forward(&res3_out)?; + + let res2_out = self + .scratch + .refine_net2 + .res_conv_unit1 + .forward(&layer_2_rn)?; + let res2_out = path3.add(&res2_out)?; + let path2 = self.scratch.refine_net2.forward(&res2_out)?; + + let res1_out = self + .scratch + .refine_net1 + .res_conv_unit1 + .forward(&layer_1_rn)?; + let res1_out = path2.add(&res1_out)?; + let path1 = self.scratch.refine_net1.forward(&res1_out)?; + + let out = self.scratch.output_conv1.forward(&path1)?; + + let out = out.interpolate2d(self.conf.input_image_size, self.conf.input_image_size)?; + + self.scratch.output_conv2.forward(&out) + } +} + +pub struct DepthAnythingV2<'a> { + pretrained: &'a DinoVisionTransformer, + depth_head: DPTHead<'a>, + conf: &'a DepthAnythingV2Config, +} + +impl<'a> DepthAnythingV2<'a> { + pub fn new( + pretrained: &'a DinoVisionTransformer, + conf: &'a DepthAnythingV2Config, + vb: VarBuilder, + ) -> Result { + let depth_head = DPTHead::new(conf, vb.pp("depth_head"))?; + + Ok(Self { + pretrained, + depth_head, + conf, + }) + } +} + +impl<'a> Module for DepthAnythingV2<'a> { + fn forward(&self, xs: &Tensor) -> Result { + let features = self.pretrained.get_intermediate_layers( + xs, + &self.conf.layer_ids_vits, + false, + false, + true, + )?; + let depth = self.depth_head.forward(&features)?; + + depth.relu() + } +} diff --git a/candle-transformers/src/models/dinov2.rs b/candle-transformers/src/models/dinov2.rs index 757aa88a..00e501ce 100644 --- a/candle-transformers/src/models/dinov2.rs +++ b/candle-transformers/src/models/dinov2.rs @@ -258,6 +258,84 @@ impl DinoVisionTransformer { let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?; &xs + &self.interpolate_pos_encoding(&xs, w, h)? } + + fn get_intermediate_layers_not_chunked( + &self, + xs: &Tensor, + blocks_to_take: &[usize], + ) -> Result> { + let mut xs = self.prepare_tokens_with_mask(xs)?; + let mut output = Vec::new(); + for (i, blk) in self.blocks.iter().enumerate() { + xs = blk.forward(&xs)?; + if blocks_to_take.contains(&i) { + output.push(xs.clone()); + } + } + if output.len() != blocks_to_take.len() { + candle::bail!( + "only {} / {} blocks found", + output.len(), + blocks_to_take.len() + ); + } + Ok(output) + } + + pub fn get_intermediate_layers( + &self, + xs: &Tensor, + blocks_to_take: &[usize], + reshape: bool, + return_class_token: bool, + norm: bool, + ) -> Result { + let outputs = self.get_intermediate_layers_not_chunked(xs, blocks_to_take)?; + let outputs = if norm { + outputs + .iter() + .map(|out| self.norm.forward(out)) + .collect::>>()? + } else { + outputs + }; + let class_tokens = outputs + .iter() + .map(|out| out.i((.., 0))) + .collect::>>()?; + let outputs = outputs + .iter() + .map(|out| out.i((.., 1..))) + .collect::>>()?; + + let outputs = if reshape { + let (b, _c, w, h) = xs.dims4()?; + let patch_size = self.patch_embed.patch_size.0; + let num_channels = outputs[0].elem_count() / (b * (w / patch_size) * (h / patch_size)); + outputs + .iter() + .map(|out| { + out.reshape((b, w / patch_size, h / patch_size, num_channels))? + .transpose(2, 3)? + .transpose(1, 2) + }) + .collect::>>()? + } else { + outputs + }; + + let outputs = if return_class_token { + outputs + .iter() + .zip(class_tokens.iter()) + .map(|(out, class_token)| Tensor::cat(&[out, class_token], D::Minus1)) + .collect::>>()? + } else { + outputs + }; + + Tensor::stack(&outputs[..], 0) + } } impl Module for DinoVisionTransformer { diff --git a/candle-transformers/src/models/mod.rs b/candle-transformers/src/models/mod.rs index 4628a3de..89ae0f8a 100644 --- a/candle-transformers/src/models/mod.rs +++ b/candle-transformers/src/models/mod.rs @@ -6,6 +6,7 @@ pub mod chatglm; pub mod clip; pub mod convmixer; pub mod convnext; +pub mod depth_anything_v2; pub mod dinov2; pub mod distilbert; pub mod efficientnet;