Move more models to candle-transformers (#796)

* Move dinov2.

* Move efficientnet.

* Move the quantized llama model.

* Move segment-anything.
Laurent Mazare 2023-09-10 10:20:18 +01:00 committed by GitHub
parent d3f05eae8c
commit 35f72514f5
21 changed files with 773 additions and 759 deletions
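
In short, the example crates now pull these models from candle-transformers instead of carrying their own copies. The new import paths, as they appear in the updated examples below, look like this:

use candle_transformers::models::dinov2;
use candle_transformers::models::efficientnet::{EfficientNet, MBConvConfig};
use candle_transformers::models::quantized_llama as model;
use candle_transformers::models::segment_anything::sam;
// The shared NMS helpers also move out of candle-examples:
use candle_transformers::object_detection::{non_maximum_suppression, Bbox};

// e.g. building the DINOv2 vit_small classifier in the dinov2 example:
let model = dinov2::vit_small(vb)?;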


@ -9,285 +9,10 @@ extern crate accelerate_src;
use clap::Parser;
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::dinov2;
const IMG_SIZE: usize = 518;
const PATCH_SIZE: usize = 14;
const NUM_CLASSES: usize = 1000;
fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
if bias {
candle_nn::linear(in_dim, out_dim, vb)
} else {
candle_nn::linear_no_bias(in_dim, out_dim, vb)
}
}
#[derive(Debug)]
struct Attention {
qkv: Linear,
proj: Linear,
num_heads: usize,
scale: f64,
}
impl Attention {
fn new(
vb: VarBuilder,
dim: usize,
num_heads: usize,
qkv_bias: bool,
proj_bias: bool,
) -> Result<Self> {
let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
let scale = 1. / ((dim / num_heads) as f64).sqrt();
Ok(Self {
qkv,
proj,
num_heads,
scale,
})
}
}
impl Module for Attention {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (b, n, c) = xs.dims3()?;
let qkv = self
.qkv
.forward(xs)?
.reshape((b, n, 3, self.num_heads, c / self.num_heads))?
.transpose(1, 2)? // 02134
.transpose(0, 1)? // 20134
.transpose(2, 3)?; // 20314
let q = (qkv.i(0)? * self.scale)?;
let k = qkv.i(1)?;
let v = qkv.i(2)?;
let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
self.proj.forward(&attn)
}
}
#[derive(Debug)]
struct LayerScale {
gamma: Tensor,
}
impl LayerScale {
fn new(vb: VarBuilder, dim: usize) -> Result<Self> {
let gamma = vb.get(dim, "gamma")?;
Ok(Self { gamma })
}
}
impl Module for LayerScale {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.broadcast_mul(&self.gamma)
}
}
#[derive(Debug)]
struct Mlp {
fc1: Linear,
fc2: Linear,
}
impl Mlp {
fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
let out_features = in_features;
let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?;
let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
Ok(Self { fc1, fc2 })
}
}
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.fc1.forward(xs)?.gelu()?;
self.fc2.forward(&xs)
}
}
#[derive(Debug)]
struct Block {
norm1: LayerNorm,
attn: Attention,
ls1: LayerScale,
norm2: LayerNorm,
mlp: Mlp,
ls2: LayerScale,
}
impl Block {
fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
let ls2 = LayerScale::new(vb.pp("ls2"), dim)?;
Ok(Self {
norm1,
attn,
ls1,
norm2,
mlp,
ls2,
})
}
}
impl Module for Block {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
let xs = self
.ls1
.forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = self
.ls2
.forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?;
xs + residual
}
}
#[derive(Debug)]
struct PatchEmbed {
proj: candle_nn::Conv2d,
patch_size: (usize, usize),
num_patches: usize,
}
impl PatchEmbed {
fn new(
vb: VarBuilder,
img_size: usize,
patch_size: usize,
in_chans: usize,
embed_dim: usize,
) -> Result<Self> {
let config = candle_nn::Conv2dConfig {
stride: patch_size,
..Default::default()
};
let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
let num_patches = (img_size / patch_size) * (img_size / patch_size);
Ok(Self {
proj,
patch_size: (patch_size, patch_size),
num_patches,
})
}
}
impl Module for PatchEmbed {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (_b, _c, h, w) = xs.dims4()?;
let (patch_h, patch_w) = self.patch_size;
if (h % patch_h) != 0 {
candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
}
if (w % patch_w) != 0 {
candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
}
let xs = self.proj.forward(xs)?;
let (b, c, h, w) = xs.dims4()?;
// flatten embeddings.
xs.reshape((b, c, h * w))?.transpose(1, 2)
}
}
#[derive(Debug)]
pub struct DinoVisionTransformer {
patch_embed: PatchEmbed,
cls_token: Tensor,
pos_embed: Tensor,
blocks: Vec<Block>,
norm: LayerNorm,
head: Linear,
}
impl DinoVisionTransformer {
pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
let patch_embed =
PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
let num_tokens = 1;
let pos_embed = vb.get(
(1, patch_embed.num_patches + num_tokens, embed_dim),
"pos_embed",
)?;
let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?;
let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
let vb_b = vb.pp("blocks");
let blocks = (0..depth)
.map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
patch_embed,
cls_token,
pos_embed,
blocks,
norm,
head,
})
}
fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
let npatch = xs.dim(1)? - 1;
let n = self.pos_embed.dim(1)? - 1;
let sqrt_n = (n as f64).sqrt();
if npatch == n && w == h {
return Ok(xs.clone());
}
let class_pos_embed = self.pos_embed.i((.., ..1))?;
let patch_pos_embed = self.pos_embed.i((.., 1..))?;
let dim = xs.dim(D::Minus1)?;
let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
let patch_pos_embed = patch_pos_embed
.reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
.transpose(2, 3)?
.transpose(1, 2)?;
// This uses bicubic interpolation in the original implementation.
let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
let el_count = patch_pos_embed.shape().elem_count();
let patch_pos_embed =
patch_pos_embed
.transpose(1, 2)?
.transpose(2, 3)?
.reshape((1, el_count / dim, dim))?;
Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
}
fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
let (_b, _nc, w, h) = xs.dims4()?;
let xs = self.patch_embed.forward(xs)?;
let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
&xs + &self.interpolate_pos_encoding(&xs, w, h)?
}
}
impl Module for DinoVisionTransformer {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.prepare_tokens_with_mask(xs)?;
for blk in self.blocks.iter() {
xs = blk.forward(&xs)?
}
let xs = self.norm.forward(&xs)?;
let xs_norm_clstoken = xs.i((.., 0))?;
let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?;
let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?;
self.head.forward(&xs)
}
}
pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> {
DinoVisionTransformer::new(vb, 12, 384, 6)
}
#[derive(Parser)]
struct Args {
#[arg(long)]
@ -320,7 +45,7 @@ pub fn main() -> anyhow::Result<()> {
let weights = unsafe { candle::safetensors::MmapedFile::new(model_file)? };
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let model = vit_small(vb)?;
let model = dinov2::vit_small(vb)?;
println!("model built");
let logits = model.forward(&image.unsqueeze(0)?)?;
let prs = candle_nn::ops::softmax(&logits, D::Minus1)?


@ -8,340 +8,11 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle::{DType, IndexOp, D};
use candle_nn::{Module, VarBuilder};
use candle_transformers::models::efficientnet::{EfficientNet, MBConvConfig};
use clap::{Parser, ValueEnum};
use candle::{DType, IndexOp, Result, Tensor, D};
use candle_nn as nn;
use nn::{Module, VarBuilder};
// Based on the Python version from torchvision.
// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47
#[derive(Debug, Clone, Copy)]
pub struct MBConvConfig {
expand_ratio: f64,
kernel: usize,
stride: usize,
input_channels: usize,
out_channels: usize,
num_layers: usize,
}
fn make_divisible(v: f64, divisor: usize) -> usize {
let min_value = divisor;
let new_v = usize::max(
min_value,
(v + divisor as f64 * 0.5) as usize / divisor * divisor,
);
if (new_v as f64) < 0.9 * v {
new_v + divisor
} else {
new_v
}
}
fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
let bneck_conf = |e, k, s, i, o, n| {
let input_channels = make_divisible(i as f64 * width_mult, 8);
let out_channels = make_divisible(o as f64 * width_mult, 8);
let num_layers = (n as f64 * depth_mult).ceil() as usize;
MBConvConfig {
expand_ratio: e,
kernel: k,
stride: s,
input_channels,
out_channels,
num_layers,
}
};
vec![
bneck_conf(1., 3, 1, 32, 16, 1),
bneck_conf(6., 3, 2, 16, 24, 2),
bneck_conf(6., 5, 2, 24, 40, 2),
bneck_conf(6., 3, 2, 40, 80, 3),
bneck_conf(6., 5, 1, 80, 112, 3),
bneck_conf(6., 5, 2, 112, 192, 4),
bneck_conf(6., 3, 1, 192, 320, 1),
]
}
impl MBConvConfig {
fn b0() -> Vec<Self> {
bneck_confs(1.0, 1.0)
}
fn b1() -> Vec<Self> {
bneck_confs(1.0, 1.1)
}
fn b2() -> Vec<Self> {
bneck_confs(1.1, 1.2)
}
fn b3() -> Vec<Self> {
bneck_confs(1.2, 1.4)
}
fn b4() -> Vec<Self> {
bneck_confs(1.4, 1.8)
}
fn b5() -> Vec<Self> {
bneck_confs(1.6, 2.2)
}
fn b6() -> Vec<Self> {
bneck_confs(1.8, 2.6)
}
fn b7() -> Vec<Self> {
bneck_confs(2.0, 3.1)
}
}
/// Conv2D with same padding.
#[derive(Debug)]
struct Conv2DSame {
conv2d: nn::Conv2d,
s: usize,
k: usize,
}
impl Conv2DSame {
fn new(
vb: VarBuilder,
i: usize,
o: usize,
k: usize,
stride: usize,
groups: usize,
bias: bool,
) -> Result<Self> {
let conv_config = nn::Conv2dConfig {
stride,
groups,
..Default::default()
};
let conv2d = if bias {
nn::conv2d(i, o, k, conv_config, vb)?
} else {
nn::conv2d_no_bias(i, o, k, conv_config, vb)?
};
Ok(Self {
conv2d,
s: stride,
k,
})
}
}
impl Module for Conv2DSame {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let s = self.s;
let k = self.k;
let (_, _, ih, iw) = xs.dims4()?;
let oh = (ih + s - 1) / s;
let ow = (iw + s - 1) / s;
let pad_h = usize::max((oh - 1) * s + k - ih, 0);
let pad_w = usize::max((ow - 1) * s + k - iw, 0);
if pad_h > 0 || pad_w > 0 {
let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?;
let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?;
self.conv2d.forward(&xs)
} else {
self.conv2d.forward(xs)
}
}
}
#[derive(Debug)]
struct ConvNormActivation {
conv2d: Conv2DSame,
bn2d: nn::BatchNorm,
activation: bool,
}
impl ConvNormActivation {
fn new(
vb: VarBuilder,
i: usize,
o: usize,
k: usize,
stride: usize,
groups: usize,
) -> Result<Self> {
let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?;
let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?;
Ok(Self {
conv2d,
bn2d,
activation: true,
})
}
fn no_activation(self) -> Self {
Self {
activation: false,
..self
}
}
}
impl Module for ConvNormActivation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.conv2d.forward(xs)?;
let xs = self.bn2d.forward(&xs)?;
if self.activation {
swish(&xs)
} else {
Ok(xs)
}
}
}
#[derive(Debug)]
struct SqueezeExcitation {
fc1: Conv2DSame,
fc2: Conv2DSame,
}
impl SqueezeExcitation {
fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> {
let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?;
let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?;
Ok(Self { fc1, fc2 })
}
}
impl Module for SqueezeExcitation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
// equivalent to adaptive_avg_pool2d([1, 1])
let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
let xs = self.fc1.forward(&xs)?;
let xs = swish(&xs)?;
let xs = self.fc2.forward(&xs)?;
let xs = nn::ops::sigmoid(&xs)?;
residual.broadcast_mul(&xs)
}
}
#[derive(Debug)]
struct MBConv {
expand_cna: Option<ConvNormActivation>,
depthwise_cna: ConvNormActivation,
squeeze_excitation: SqueezeExcitation,
project_cna: ConvNormActivation,
config: MBConvConfig,
}
impl MBConv {
fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> {
let vb = vb.pp("block");
let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
let expand_cna = if exp != c.input_channels {
Some(ConvNormActivation::new(
vb.pp("0"),
c.input_channels,
exp,
1,
1,
1,
)?)
} else {
None
};
let start_index = if expand_cna.is_some() { 1 } else { 0 };
let depthwise_cna =
ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?;
let squeeze_channels = usize::max(1, c.input_channels / 4);
let squeeze_excitation =
SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?;
let project_cna =
ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)?
.no_activation();
Ok(Self {
expand_cna,
depthwise_cna,
squeeze_excitation,
project_cna,
config: c,
})
}
}
impl Module for MBConv {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let use_res_connect =
self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
let ys = match &self.expand_cna {
Some(expand_cna) => expand_cna.forward(xs)?,
None => xs.clone(),
};
let ys = self.depthwise_cna.forward(&ys)?;
let ys = self.squeeze_excitation.forward(&ys)?;
let ys = self.project_cna.forward(&ys)?;
if use_res_connect {
ys + xs
} else {
Ok(ys)
}
}
}
fn swish(s: &Tensor) -> Result<Tensor> {
s * nn::ops::sigmoid(s)?
}
#[derive(Debug)]
struct EfficientNet {
init_cna: ConvNormActivation,
blocks: Vec<MBConv>,
final_cna: ConvNormActivation,
classifier: nn::Linear,
}
impl EfficientNet {
fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
let f_p = p.pp("features");
let first_in_c = configs[0].input_channels;
let last_out_c = configs.last().unwrap().out_channels;
let final_out_c = 4 * last_out_c;
let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
let nconfigs = configs.len();
let mut blocks = vec![];
for (index, cnf) in configs.into_iter().enumerate() {
let f_p = f_p.pp(index + 1);
for r_index in 0..cnf.num_layers {
let cnf = if r_index == 0 {
cnf
} else {
MBConvConfig {
input_channels: cnf.out_channels,
stride: 1,
..cnf
}
};
blocks.push(MBConv::new(f_p.pp(r_index), cnf)?)
}
}
let final_cna =
ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?;
let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?;
Ok(Self {
init_cna,
blocks,
final_cna,
classifier,
})
}
}
impl Module for EfficientNet {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.init_cna.forward(xs)?;
for block in self.blocks.iter() {
xs = block.forward(&xs)?
}
let xs = self.final_cna.forward(&xs)?;
// Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1)
let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?;
self.classifier.forward(&xs)
}
}
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
B0,


@ -12,7 +12,7 @@ use candle::quantized::{ggml_file, gguf_file};
use candle::{Device, Tensor};
use candle_transformers::generation::LogitsProcessor;
mod model;
use candle_transformers::models::quantized_llama as model;
use model::ModelWeights;
const DEFAULT_PROMPT: &str = "My favorite theorem is ";


@ -7,108 +7,11 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
pub mod model_image_encoder;
pub mod model_mask_decoder;
pub mod model_prompt_encoder;
pub mod model_sam;
pub mod model_tiny_vit;
pub mod model_transformer;
use candle::{DType, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use candle::DType;
use candle_nn::VarBuilder;
use candle_transformers::models::segment_anything::sam;
use clap::Parser;
pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
let inner = if bias {
candle_nn::linear(in_dim, out_dim, vb)?
} else {
candle_nn::linear_no_bias(in_dim, out_dim, vb)?
};
let span = tracing::span!(tracing::Level::TRACE, "linear");
Ok(Linear { inner, span })
}
#[derive(Debug)]
pub struct LayerNorm2d {
weight: Tensor,
bias: Tensor,
num_channels: usize,
eps: f64,
}
impl LayerNorm2d {
pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let weight = vb.get(num_channels, "weight")?;
let bias = vb.get(num_channels, "bias")?;
Ok(Self {
weight,
bias,
num_channels,
eps,
})
}
}
impl Module for LayerNorm2d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let u = xs.mean_keepdim(1)?;
let xs = xs.broadcast_sub(&u)?;
let s = xs.sqr()?.mean_keepdim(1)?;
let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?;
xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)?
.broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?)
}
}
#[derive(Debug)]
pub struct MlpBlock {
lin1: Linear,
lin2: Linear,
activation: candle_nn::Activation,
span: tracing::Span,
}
impl MlpBlock {
pub fn new(
embedding_dim: usize,
mlp_dim: usize,
activation: candle_nn::Activation,
vb: VarBuilder,
) -> Result<Self> {
let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?;
let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?;
let span = tracing::span!(tracing::Level::TRACE, "mlp-block");
Ok(Self {
lin1,
lin2,
activation,
span,
})
}
}
impl Module for MlpBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
xs.apply(&self.lin1)?
.apply(&self.activation)?
.apply(&self.lin2)
}
}
#[derive(Debug)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
}
impl Module for Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}
#[derive(Parser)]
struct Args {
#[arg(long)]
@ -173,7 +76,7 @@ pub fn main() -> anyhow::Result<()> {
let (_c, h, w) = image.dims3()?;
(image, h, w)
} else {
let (image, h, w) = candle_examples::load_image(&args.image, Some(model_sam::IMAGE_SIZE))?;
let (image, h, w) = candle_examples::load_image(&args.image, Some(sam::IMAGE_SIZE))?;
(image.to_device(&device)?, h, w)
};
println!("loaded image {image:?}");
@ -195,9 +98,9 @@ pub fn main() -> anyhow::Result<()> {
let weights = weights.deserialize()?;
let vb = VarBuilder::from_safetensors(vec![weights], DType::F32, &device);
let sam = if args.use_tiny {
model_sam::Sam::new_tiny(vb)? // tiny vit_t
sam::Sam::new_tiny(vb)? // tiny vit_t
} else {
model_sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
sam::Sam::new(768, 12, 12, &[2, 5, 8, 11], vb)? // sam_vit_b
};
if args.generate_masks {


@ -4,7 +4,7 @@ extern crate intel_mkl_src;
#[cfg(feature = "accelerate")]
extern crate accelerate_src;
use candle_examples::object_detection::{non_maximum_suppression, Bbox};
use candle_transformers::object_detection::{non_maximum_suppression, Bbox};
mod darknet;
use anyhow::Result;


@ -8,8 +8,8 @@ mod model;
use model::{Multiples, YoloV8, YoloV8Pose};
use candle::{DType, Device, IndexOp, Result, Tensor};
use candle_examples::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
use candle_nn::{Module, VarBuilder};
use candle_transformers::object_detection::{non_maximum_suppression, Bbox, KeyPoint};
use clap::{Parser, ValueEnum};
use image::DynamicImage;


@ -1,6 +1,5 @@
pub mod coco_classes;
pub mod imagenet;
pub mod object_detection;
use candle::{Device, Result, Tensor};


@ -16,6 +16,7 @@ candle-nn = { path = "../candle-nn", version = "0.2.1" }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
rand = { workspace = true }
rayon = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tracing = { workspace = true }


@ -1,4 +1,5 @@
pub mod generation;
pub mod models;
pub mod object_detection;
pub mod pipelines;
pub mod utils;


@ -0,0 +1,279 @@
use candle::{IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, LayerNorm, Linear, Module, VarBuilder};
const IMG_SIZE: usize = 518;
const PATCH_SIZE: usize = 14;
const NUM_CLASSES: usize = 1000;
fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
if bias {
candle_nn::linear(in_dim, out_dim, vb)
} else {
candle_nn::linear_no_bias(in_dim, out_dim, vb)
}
}
#[derive(Debug)]
struct Attention {
qkv: Linear,
proj: Linear,
num_heads: usize,
scale: f64,
}
impl Attention {
fn new(
vb: VarBuilder,
dim: usize,
num_heads: usize,
qkv_bias: bool,
proj_bias: bool,
) -> Result<Self> {
let qkv = linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
let proj = linear(vb.pp("proj"), dim, dim, proj_bias)?;
let scale = 1. / ((dim / num_heads) as f64).sqrt();
Ok(Self {
qkv,
proj,
num_heads,
scale,
})
}
}
impl Module for Attention {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (b, n, c) = xs.dims3()?;
let qkv = self
.qkv
.forward(xs)?
.reshape((b, n, 3, self.num_heads, c / self.num_heads))?
.transpose(1, 2)? // 02134
.transpose(0, 1)? // 20134
.transpose(2, 3)?; // 20314
let q = (qkv.i(0)? * self.scale)?;
let k = qkv.i(1)?;
let v = qkv.i(2)?;
let attn = candle_nn::ops::softmax(&q.matmul(&k.t()?)?, D::Minus1)?;
let attn = attn.matmul(&v)?.transpose(1, 2)?.reshape((b, n, c))?;
self.proj.forward(&attn)
}
}
#[derive(Debug)]
struct LayerScale {
gamma: Tensor,
}
impl LayerScale {
fn new(vb: VarBuilder, dim: usize) -> Result<Self> {
let gamma = vb.get(dim, "gamma")?;
Ok(Self { gamma })
}
}
impl Module for LayerScale {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
xs.broadcast_mul(&self.gamma)
}
}
#[derive(Debug)]
struct Mlp {
fc1: Linear,
fc2: Linear,
}
impl Mlp {
fn new(vb: VarBuilder, in_features: usize, hidden_features: usize, bias: bool) -> Result<Self> {
let out_features = in_features;
let fc1 = linear(vb.pp("fc1"), in_features, hidden_features, bias)?;
let fc2 = linear(vb.pp("fc2"), hidden_features, out_features, bias)?;
Ok(Self { fc1, fc2 })
}
}
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.fc1.forward(xs)?.gelu()?;
self.fc2.forward(&xs)
}
}
#[derive(Debug)]
struct Block {
norm1: LayerNorm,
attn: Attention,
ls1: LayerScale,
norm2: LayerNorm,
mlp: Mlp,
ls2: LayerScale,
}
impl Block {
fn new(vb: VarBuilder, dim: usize, num_heads: usize) -> Result<Self> {
let norm1 = layer_norm(dim, 1e-5, vb.pp("norm1"))?;
let attn = Attention::new(vb.pp("attn"), dim, num_heads, true, true)?;
let ls1 = LayerScale::new(vb.pp("ls1"), dim)?;
let norm2 = layer_norm(dim, 1e-5, vb.pp("norm2"))?;
let mlp = Mlp::new(vb.pp("mlp"), dim, dim * 4, true)?;
let ls2 = LayerScale::new(vb.pp("ls2"), dim)?;
Ok(Self {
norm1,
attn,
ls1,
norm2,
mlp,
ls2,
})
}
}
impl Module for Block {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
let xs = self
.ls1
.forward(&self.attn.forward(&self.norm1.forward(xs)?)?)?;
let xs = (xs + residual)?;
let residual = &xs;
let xs = self
.ls2
.forward(&self.mlp.forward(&self.norm2.forward(&xs)?)?)?;
xs + residual
}
}
#[derive(Debug)]
struct PatchEmbed {
proj: candle_nn::Conv2d,
patch_size: (usize, usize),
num_patches: usize,
}
impl PatchEmbed {
fn new(
vb: VarBuilder,
img_size: usize,
patch_size: usize,
in_chans: usize,
embed_dim: usize,
) -> Result<Self> {
let config = candle_nn::Conv2dConfig {
stride: patch_size,
..Default::default()
};
let proj = candle_nn::conv2d(in_chans, embed_dim, patch_size, config, vb.pp("proj"))?;
let num_patches = (img_size / patch_size) * (img_size / patch_size);
Ok(Self {
proj,
patch_size: (patch_size, patch_size),
num_patches,
})
}
}
impl Module for PatchEmbed {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let (_b, _c, h, w) = xs.dims4()?;
let (patch_h, patch_w) = self.patch_size;
if (h % patch_h) != 0 {
candle::bail!("image height {h} is not a multiple of patch height {patch_h}")
}
if (w % patch_w) != 0 {
candle::bail!("image width {w} is not a multiple of patch width {patch_w}")
}
let xs = self.proj.forward(xs)?;
let (b, c, h, w) = xs.dims4()?;
// flatten embeddings.
xs.reshape((b, c, h * w))?.transpose(1, 2)
}
}
#[derive(Debug)]
pub struct DinoVisionTransformer {
patch_embed: PatchEmbed,
cls_token: Tensor,
pos_embed: Tensor,
blocks: Vec<Block>,
norm: LayerNorm,
head: Linear,
}
impl DinoVisionTransformer {
pub fn new(vb: VarBuilder, depth: usize, embed_dim: usize, num_heads: usize) -> Result<Self> {
let patch_embed =
PatchEmbed::new(vb.pp("patch_embed"), IMG_SIZE, PATCH_SIZE, 3, embed_dim)?;
let cls_token = vb.get((1, 1, embed_dim), "cls_token")?;
let num_tokens = 1;
let pos_embed = vb.get(
(1, patch_embed.num_patches + num_tokens, embed_dim),
"pos_embed",
)?;
let head = linear(vb.pp("head"), 2 * embed_dim, NUM_CLASSES, true)?;
let norm = layer_norm(embed_dim, 1e-5, vb.pp("norm"))?;
let vb_b = vb.pp("blocks");
let blocks = (0..depth)
.map(|i| Block::new(vb_b.pp(&i.to_string()), embed_dim, num_heads))
.collect::<Result<Vec<_>>>()?;
Ok(Self {
patch_embed,
cls_token,
pos_embed,
blocks,
norm,
head,
})
}
fn interpolate_pos_encoding(&self, xs: &Tensor, w: usize, h: usize) -> Result<Tensor> {
let npatch = xs.dim(1)? - 1;
let n = self.pos_embed.dim(1)? - 1;
let sqrt_n = (n as f64).sqrt();
if npatch == n && w == h {
return Ok(xs.clone());
}
let class_pos_embed = self.pos_embed.i((.., ..1))?;
let patch_pos_embed = self.pos_embed.i((.., 1..))?;
let dim = xs.dim(D::Minus1)?;
let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
let patch_pos_embed = patch_pos_embed
.reshape((1, sqrt_n as usize, sqrt_n as usize, dim))?
.transpose(2, 3)?
.transpose(1, 2)?;
// This uses bicubic interpolation in the original implementation.
let patch_pos_embed = patch_pos_embed.upsample_nearest2d(h0 as usize, w0 as usize)?;
let el_count = patch_pos_embed.shape().elem_count();
let patch_pos_embed =
patch_pos_embed
.transpose(1, 2)?
.transpose(2, 3)?
.reshape((1, el_count / dim, dim))?;
Tensor::cat(&[&class_pos_embed, &patch_pos_embed], 1)
}
fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Result<Tensor> {
let (_b, _nc, w, h) = xs.dims4()?;
let xs = self.patch_embed.forward(xs)?;
let xs = Tensor::cat(&[&self.cls_token, &xs], 1)?;
&xs + &self.interpolate_pos_encoding(&xs, w, h)?
}
}
impl Module for DinoVisionTransformer {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.prepare_tokens_with_mask(xs)?;
for blk in self.blocks.iter() {
xs = blk.forward(&xs)?
}
let xs = self.norm.forward(&xs)?;
let xs_norm_clstoken = xs.i((.., 0))?;
let xs_norm_patchtokens = xs.i((.., 1..))?.mean(1)?;
let xs = Tensor::cat(&[xs_norm_clstoken, xs_norm_patchtokens], D::Minus1)?;
self.head.forward(&xs)
}
}
pub fn vit_small(vb: VarBuilder) -> Result<DinoVisionTransformer> {
DinoVisionTransformer::new(vb, 12, 384, 6)
}


@ -0,0 +1,331 @@
use candle::{Result, Tensor, D};
use candle_nn as nn;
use nn::{Module, VarBuilder};
// Based on the Python version from torchvision.
// https://github.com/pytorch/vision/blob/0d75d9e5516f446c9c0ef93bd4ed9fea13992d06/torchvision/models/efficientnet.py#L47
#[derive(Debug, Clone, Copy)]
pub struct MBConvConfig {
expand_ratio: f64,
kernel: usize,
stride: usize,
input_channels: usize,
out_channels: usize,
num_layers: usize,
}
fn make_divisible(v: f64, divisor: usize) -> usize {
let min_value = divisor;
let new_v = usize::max(
min_value,
(v + divisor as f64 * 0.5) as usize / divisor * divisor,
);
if (new_v as f64) < 0.9 * v {
new_v + divisor
} else {
new_v
}
}
fn bneck_confs(width_mult: f64, depth_mult: f64) -> Vec<MBConvConfig> {
let bneck_conf = |e, k, s, i, o, n| {
let input_channels = make_divisible(i as f64 * width_mult, 8);
let out_channels = make_divisible(o as f64 * width_mult, 8);
let num_layers = (n as f64 * depth_mult).ceil() as usize;
MBConvConfig {
expand_ratio: e,
kernel: k,
stride: s,
input_channels,
out_channels,
num_layers,
}
};
vec![
bneck_conf(1., 3, 1, 32, 16, 1),
bneck_conf(6., 3, 2, 16, 24, 2),
bneck_conf(6., 5, 2, 24, 40, 2),
bneck_conf(6., 3, 2, 40, 80, 3),
bneck_conf(6., 5, 1, 80, 112, 3),
bneck_conf(6., 5, 2, 112, 192, 4),
bneck_conf(6., 3, 1, 192, 320, 1),
]
}
impl MBConvConfig {
pub fn b0() -> Vec<Self> {
bneck_confs(1.0, 1.0)
}
pub fn b1() -> Vec<Self> {
bneck_confs(1.0, 1.1)
}
pub fn b2() -> Vec<Self> {
bneck_confs(1.1, 1.2)
}
pub fn b3() -> Vec<Self> {
bneck_confs(1.2, 1.4)
}
pub fn b4() -> Vec<Self> {
bneck_confs(1.4, 1.8)
}
pub fn b5() -> Vec<Self> {
bneck_confs(1.6, 2.2)
}
pub fn b6() -> Vec<Self> {
bneck_confs(1.8, 2.6)
}
pub fn b7() -> Vec<Self> {
bneck_confs(2.0, 3.1)
}
}
/// Conv2D with same padding.
#[derive(Debug)]
struct Conv2DSame {
conv2d: nn::Conv2d,
s: usize,
k: usize,
}
impl Conv2DSame {
fn new(
vb: VarBuilder,
i: usize,
o: usize,
k: usize,
stride: usize,
groups: usize,
bias: bool,
) -> Result<Self> {
let conv_config = nn::Conv2dConfig {
stride,
groups,
..Default::default()
};
let conv2d = if bias {
nn::conv2d(i, o, k, conv_config, vb)?
} else {
nn::conv2d_no_bias(i, o, k, conv_config, vb)?
};
Ok(Self {
conv2d,
s: stride,
k,
})
}
}
impl Module for Conv2DSame {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let s = self.s;
let k = self.k;
let (_, _, ih, iw) = xs.dims4()?;
let oh = (ih + s - 1) / s;
let ow = (iw + s - 1) / s;
let pad_h = usize::max((oh - 1) * s + k - ih, 0);
let pad_w = usize::max((ow - 1) * s + k - iw, 0);
if pad_h > 0 || pad_w > 0 {
let xs = xs.pad_with_zeros(2, pad_h / 2, pad_h - pad_h / 2)?;
let xs = xs.pad_with_zeros(3, pad_w / 2, pad_w - pad_w / 2)?;
self.conv2d.forward(&xs)
} else {
self.conv2d.forward(xs)
}
}
}
#[derive(Debug)]
struct ConvNormActivation {
conv2d: Conv2DSame,
bn2d: nn::BatchNorm,
activation: bool,
}
impl ConvNormActivation {
fn new(
vb: VarBuilder,
i: usize,
o: usize,
k: usize,
stride: usize,
groups: usize,
) -> Result<Self> {
let conv2d = Conv2DSame::new(vb.pp("0"), i, o, k, stride, groups, false)?;
let bn2d = nn::batch_norm(o, 1e-3, vb.pp("1"))?;
Ok(Self {
conv2d,
bn2d,
activation: true,
})
}
fn no_activation(self) -> Self {
Self {
activation: false,
..self
}
}
}
impl Module for ConvNormActivation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let xs = self.conv2d.forward(xs)?;
let xs = self.bn2d.forward(&xs)?;
if self.activation {
swish(&xs)
} else {
Ok(xs)
}
}
}
#[derive(Debug)]
struct SqueezeExcitation {
fc1: Conv2DSame,
fc2: Conv2DSame,
}
impl SqueezeExcitation {
fn new(vb: VarBuilder, in_channels: usize, squeeze_channels: usize) -> Result<Self> {
let fc1 = Conv2DSame::new(vb.pp("fc1"), in_channels, squeeze_channels, 1, 1, 1, true)?;
let fc2 = Conv2DSame::new(vb.pp("fc2"), squeeze_channels, in_channels, 1, 1, 1, true)?;
Ok(Self { fc1, fc2 })
}
}
impl Module for SqueezeExcitation {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let residual = xs;
// equivalent to adaptive_avg_pool2d([1, 1])
let xs = xs.mean_keepdim(D::Minus2)?.mean_keepdim(D::Minus1)?;
let xs = self.fc1.forward(&xs)?;
let xs = swish(&xs)?;
let xs = self.fc2.forward(&xs)?;
let xs = nn::ops::sigmoid(&xs)?;
residual.broadcast_mul(&xs)
}
}
#[derive(Debug)]
struct MBConv {
expand_cna: Option<ConvNormActivation>,
depthwise_cna: ConvNormActivation,
squeeze_excitation: SqueezeExcitation,
project_cna: ConvNormActivation,
config: MBConvConfig,
}
impl MBConv {
fn new(vb: VarBuilder, c: MBConvConfig) -> Result<Self> {
let vb = vb.pp("block");
let exp = make_divisible(c.input_channels as f64 * c.expand_ratio, 8);
let expand_cna = if exp != c.input_channels {
Some(ConvNormActivation::new(
vb.pp("0"),
c.input_channels,
exp,
1,
1,
1,
)?)
} else {
None
};
let start_index = if expand_cna.is_some() { 1 } else { 0 };
let depthwise_cna =
ConvNormActivation::new(vb.pp(start_index), exp, exp, c.kernel, c.stride, exp)?;
let squeeze_channels = usize::max(1, c.input_channels / 4);
let squeeze_excitation =
SqueezeExcitation::new(vb.pp(start_index + 1), exp, squeeze_channels)?;
let project_cna =
ConvNormActivation::new(vb.pp(start_index + 2), exp, c.out_channels, 1, 1, 1)?
.no_activation();
Ok(Self {
expand_cna,
depthwise_cna,
squeeze_excitation,
project_cna,
config: c,
})
}
}
impl Module for MBConv {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let use_res_connect =
self.config.stride == 1 && self.config.input_channels == self.config.out_channels;
let ys = match &self.expand_cna {
Some(expand_cna) => expand_cna.forward(xs)?,
None => xs.clone(),
};
let ys = self.depthwise_cna.forward(&ys)?;
let ys = self.squeeze_excitation.forward(&ys)?;
let ys = self.project_cna.forward(&ys)?;
if use_res_connect {
ys + xs
} else {
Ok(ys)
}
}
}
fn swish(s: &Tensor) -> Result<Tensor> {
s * nn::ops::sigmoid(s)?
}
#[derive(Debug)]
pub struct EfficientNet {
init_cna: ConvNormActivation,
blocks: Vec<MBConv>,
final_cna: ConvNormActivation,
classifier: nn::Linear,
}
impl EfficientNet {
pub fn new(p: VarBuilder, configs: Vec<MBConvConfig>, nclasses: usize) -> Result<Self> {
let f_p = p.pp("features");
let first_in_c = configs[0].input_channels;
let last_out_c = configs.last().unwrap().out_channels;
let final_out_c = 4 * last_out_c;
let init_cna = ConvNormActivation::new(f_p.pp(0), 3, first_in_c, 3, 2, 1)?;
let nconfigs = configs.len();
let mut blocks = vec![];
for (index, cnf) in configs.into_iter().enumerate() {
let f_p = f_p.pp(index + 1);
for r_index in 0..cnf.num_layers {
let cnf = if r_index == 0 {
cnf
} else {
MBConvConfig {
input_channels: cnf.out_channels,
stride: 1,
..cnf
}
};
blocks.push(MBConv::new(f_p.pp(r_index), cnf)?)
}
}
let final_cna =
ConvNormActivation::new(f_p.pp(nconfigs + 1), last_out_c, final_out_c, 1, 1, 1)?;
let classifier = nn::linear(final_out_c, nclasses, p.pp("classifier.1"))?;
Ok(Self {
init_cna,
blocks,
final_cna,
classifier,
})
}
}
impl Module for EfficientNet {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let mut xs = self.init_cna.forward(xs)?;
for block in self.blocks.iter() {
xs = block.forward(&xs)?
}
let xs = self.final_cna.forward(&xs)?;
// Equivalent to adaptive_avg_pool2d([1, 1]) -> squeeze(-1) -> squeeze(-1)
let xs = xs.mean(D::Minus1)?.mean(D::Minus1)?;
self.classifier.forward(&xs)
}
}


@ -1,5 +1,9 @@
pub mod bert;
pub mod bigcode;
pub mod dinov2;
pub mod efficientnet;
pub mod falcon;
pub mod llama;
pub mod quantized_llama;
pub mod segment_anything;
pub mod whisper;


@ -100,8 +100,8 @@ impl candle::CustomOp3 for Add3 {
#[derive(Debug)]
struct Attention {
qkv: crate::Linear,
proj: crate::Linear,
qkv: super::Linear,
proj: super::Linear,
num_heads: usize,
scale: f64,
rel_pos_hw: Option<(Tensor, Tensor)>,
@ -124,8 +124,8 @@ impl Attention {
let span_matmul = tracing::span!(tracing::Level::TRACE, "attn-matmul");
let span_rel_pos = tracing::span!(tracing::Level::TRACE, "attn-rel-pos");
let span_softmax = tracing::span!(tracing::Level::TRACE, "attn-sm");
let qkv = crate::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
let proj = crate::linear(vb.pp("proj"), dim, dim, true)?;
let qkv = super::linear(vb.pp("qkv"), dim, dim * 3, qkv_bias)?;
let proj = super::linear(vb.pp("proj"), dim, dim, true)?;
let head_dim = dim / num_heads;
let scale = 1. / (head_dim as f64).sqrt();
let rel_pos_hw = if use_rel_pos {
@ -251,7 +251,7 @@ struct Block {
norm1: LayerNorm,
attn: Attention,
norm2: LayerNorm,
mlp: crate::MlpBlock,
mlp: super::MlpBlock,
window_size: usize,
span: tracing::Span,
}
@ -281,7 +281,7 @@ impl Block {
input_size_attn,
vb.pp("attn"),
)?;
let mlp = crate::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?;
let mlp = super::MlpBlock::new(dim, dim * 4, candle_nn::Activation::Gelu, vb.pp("mlp"))?;
let span = tracing::span!(tracing::Level::TRACE, "ie-block");
Ok(Self {
norm1,
@ -375,9 +375,9 @@ pub struct ImageEncoderViT {
patch_embed: PatchEmbed,
blocks: Vec<Block>,
neck_conv1: candle_nn::Conv2d,
neck_ln1: crate::LayerNorm2d,
neck_ln1: super::LayerNorm2d,
neck_conv2: candle_nn::Conv2d,
neck_ln2: crate::LayerNorm2d,
neck_ln2: super::LayerNorm2d,
pos_embed: Option<Tensor>,
span: tracing::Span,
}
@ -433,13 +433,13 @@ impl ImageEncoderViT {
Default::default(),
vb.pp("neck.0"),
)?;
let neck_ln1 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?;
let neck_ln1 = super::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.1"))?;
let cfg = candle_nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let neck_conv2 = candle_nn::conv2d_no_bias(out_chans, out_chans, 3, cfg, vb.pp("neck.2"))?;
let neck_ln2 = crate::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?;
let neck_ln2 = super::LayerNorm2d::new(out_chans, 1e-6, vb.pp("neck.3"))?;
let pos_embed = if use_abs_pos {
let p = vb.get(
(1, img_size / patch_size, img_size / patch_size, embed_dim),


@ -1,11 +1,11 @@
use candle::{IndexOp, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use crate::model_transformer::TwoWayTransformer;
use super::transformer::TwoWayTransformer;
#[derive(Debug)]
struct MlpMaskDecoder {
layers: Vec<crate::Linear>,
layers: Vec<super::Linear>,
sigmoid_output: bool,
span: tracing::Span,
}
@ -28,7 +28,7 @@ impl MlpMaskDecoder {
} else {
hidden_dim
};
let layer = crate::linear(vb.pp(i), in_dim, out_dim, true)?;
let layer = super::linear(vb.pp(i), in_dim, out_dim, true)?;
layers.push(layer)
}
let span = tracing::span!(tracing::Level::TRACE, "mlp-mask-decoder");
@ -64,7 +64,7 @@ pub struct MaskDecoder {
mask_tokens: candle_nn::Embedding,
iou_prediction_head: MlpMaskDecoder,
output_upscaling_conv1: candle_nn::ConvTranspose2d,
output_upscaling_ln: crate::LayerNorm2d,
output_upscaling_ln: super::LayerNorm2d,
output_upscaling_conv2: candle_nn::ConvTranspose2d,
num_mask_tokens: usize,
output_hypernetworks_mlps: Vec<MlpMaskDecoder>,
@ -104,7 +104,7 @@ impl MaskDecoder {
vb.pp("output_upscaling.0"),
)?;
let output_upscaling_ln =
crate::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
super::LayerNorm2d::new(transformer_dim / 4, 1e-6, vb.pp("output_upscaling.1"))?;
let output_upscaling_conv2 = candle_nn::conv_transpose2d(
transformer_dim / 4,
transformer_dim / 8,


@ -0,0 +1,100 @@
use candle::{Result, Tensor};
use candle_nn::{Module, VarBuilder};
pub mod image_encoder;
pub mod mask_decoder;
pub mod prompt_encoder;
pub mod sam;
pub mod tiny_vit;
pub mod transformer;
pub fn linear(vb: VarBuilder, in_dim: usize, out_dim: usize, bias: bool) -> Result<Linear> {
let inner = if bias {
candle_nn::linear(in_dim, out_dim, vb)?
} else {
candle_nn::linear_no_bias(in_dim, out_dim, vb)?
};
let span = tracing::span!(tracing::Level::TRACE, "linear");
Ok(Linear { inner, span })
}
#[derive(Debug)]
pub struct LayerNorm2d {
weight: Tensor,
bias: Tensor,
num_channels: usize,
eps: f64,
}
impl LayerNorm2d {
pub fn new(num_channels: usize, eps: f64, vb: VarBuilder) -> Result<Self> {
let weight = vb.get(num_channels, "weight")?;
let bias = vb.get(num_channels, "bias")?;
Ok(Self {
weight,
bias,
num_channels,
eps,
})
}
}
impl Module for LayerNorm2d {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let u = xs.mean_keepdim(1)?;
let xs = xs.broadcast_sub(&u)?;
let s = xs.sqr()?.mean_keepdim(1)?;
let xs = xs.broadcast_div(&(s + self.eps)?.sqrt()?)?;
xs.broadcast_mul(&self.weight.reshape((1, self.num_channels, 1, 1))?)?
.broadcast_add(&self.bias.reshape((1, self.num_channels, 1, 1))?)
}
}
#[derive(Debug)]
pub struct MlpBlock {
lin1: Linear,
lin2: Linear,
activation: candle_nn::Activation,
span: tracing::Span,
}
impl MlpBlock {
pub fn new(
embedding_dim: usize,
mlp_dim: usize,
activation: candle_nn::Activation,
vb: VarBuilder,
) -> Result<Self> {
let lin1 = linear(vb.pp("lin1"), embedding_dim, mlp_dim, true)?;
let lin2 = linear(vb.pp("lin2"), mlp_dim, embedding_dim, true)?;
let span = tracing::span!(tracing::Level::TRACE, "mlp-block");
Ok(Self {
lin1,
lin2,
activation,
span,
})
}
}
impl Module for MlpBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
xs.apply(&self.lin1)?
.apply(&self.activation)?
.apply(&self.lin2)
}
}
#[derive(Debug)]
pub struct Linear {
inner: candle_nn::Linear,
span: tracing::Span,
}
impl Module for Linear {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
self.inner.forward(x)
}
}


@ -56,9 +56,9 @@ pub struct PromptEncoder {
point_embeddings: Vec<candle_nn::Embedding>,
not_a_point_embed: candle_nn::Embedding,
mask_downscaling_conv1: candle_nn::Conv2d,
mask_downscaling_ln1: crate::LayerNorm2d,
mask_downscaling_ln1: super::LayerNorm2d,
mask_downscaling_conv2: candle_nn::Conv2d,
mask_downscaling_ln2: crate::LayerNorm2d,
mask_downscaling_ln2: super::LayerNorm2d,
mask_downscaling_conv3: candle_nn::Conv2d,
no_mask_embed: candle_nn::Embedding,
image_embedding_size: (usize, usize),
@ -100,9 +100,9 @@ impl PromptEncoder {
vb.pp("mask_downscaling.6"),
)?;
let mask_downscaling_ln1 =
crate::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
super::LayerNorm2d::new(mask_in_chans / 4, 1e-6, vb.pp("mask_downscaling.1"))?;
let mask_downscaling_ln2 =
crate::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
super::LayerNorm2d::new(mask_in_chans, 1e-6, vb.pp("mask_downscaling.4"))?;
let mut point_embeddings = Vec::with_capacity(num_points_embeddings);
let vb_e = vb.pp("point_embeddings");
for i in 0..num_points_embeddings {


@ -1,10 +1,10 @@
use candle::{DType, IndexOp, Result, Tensor};
use candle_nn::{Module, VarBuilder};
use crate::model_image_encoder::ImageEncoderViT;
use crate::model_mask_decoder::MaskDecoder;
use crate::model_prompt_encoder::PromptEncoder;
use crate::model_tiny_vit::{tiny_vit_5m, TinyViT};
use super::image_encoder::ImageEncoderViT;
use super::mask_decoder::MaskDecoder;
use super::prompt_encoder::PromptEncoder;
use super::tiny_vit::{tiny_vit_5m, TinyViT};
const PROMPT_EMBED_DIM: usize = 256;
pub const IMAGE_SIZE: usize = 1024;
@ -186,7 +186,7 @@ impl Sam {
img: &Tensor,
cb: CropBox,
point_grids: &[(f64, f64)],
) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
) -> Result<Vec<crate::object_detection::Bbox<Tensor>>> {
// Crop the image and calculate embeddings.
let img = img.i((.., cb.y0..cb.y1, cb.x0..cb.x1))?;
let img = self.preprocess(&img)?.unsqueeze(0)?;
@ -259,7 +259,7 @@ impl Sam {
let min_max_x = min_max_indexes(&low_res_mask_per_x);
let min_max_y = min_max_indexes(&low_res_mask_per_y);
if let Some(((x0, x1), (y0, y1))) = min_max_x.zip(min_max_y) {
let bbox = candle_examples::object_detection::Bbox {
let bbox = crate::object_detection::Bbox {
xmin: x0 as f32,
ymin: y0 as f32,
xmax: x1 as f32,
@ -277,7 +277,7 @@ impl Sam {
let mut bboxes = vec![bboxes];
// Remove duplicates within this crop.
candle_examples::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH);
crate::object_detection::non_maximum_suppression(&mut bboxes, CROP_NMS_THRESH);
// TODO: Return to the original image frame.
Ok(bboxes.remove(0))
@ -290,7 +290,7 @@ impl Sam {
crop_n_layer: usize,
crop_overlap_ratio: f64,
crop_n_points_downscale_factor: usize,
) -> Result<Vec<candle_examples::object_detection::Bbox<Tensor>>> {
) -> Result<Vec<crate::object_detection::Bbox<Tensor>>> {
let (_c, h, w) = img.dims3()?;
let point_grids = build_all_layer_point_grids(
points_per_side,


@ -215,16 +215,16 @@ impl Module for ConvLayer {
#[derive(Debug)]
struct Mlp {
norm: candle_nn::LayerNorm,
fc1: crate::Linear,
fc2: crate::Linear,
fc1: super::Linear,
fc2: super::Linear,
span: tracing::Span,
}
impl Mlp {
fn new(in_: usize, hidden: usize, vb: VarBuilder) -> Result<Self> {
let norm = candle_nn::layer_norm(in_, 1e-5, vb.pp("norm"))?;
let fc1 = crate::linear(vb.pp("fc1"), in_, hidden, true)?;
let fc2 = crate::linear(vb.pp("fc2"), hidden, in_, true)?;
let fc1 = super::linear(vb.pp("fc1"), in_, hidden, true)?;
let fc2 = super::linear(vb.pp("fc2"), hidden, in_, true)?;
let span = tracing::span!(tracing::Level::TRACE, "mlp");
Ok(Self {
norm,
@ -248,8 +248,8 @@ impl Module for Mlp {
#[derive(Debug)]
struct Attention {
norm: candle_nn::LayerNorm,
qkv: crate::Linear,
proj: crate::Linear,
qkv: super::Linear,
proj: super::Linear,
ab: Tensor,
key_dim: usize,
num_heads: usize,
@ -275,8 +275,8 @@ impl Attention {
let nh_kd = key_dim * num_heads;
let h = dh + nh_kd * 2;
let norm = candle_nn::layer_norm(dim, 1e-5, vb.pp("norm"))?;
let qkv = crate::linear(vb.pp("qkv"), dim, h, true)?;
let proj = crate::linear(vb.pp("proj"), dh, dim, true)?;
let qkv = super::linear(vb.pp("qkv"), dim, h, true)?;
let proj = super::linear(vb.pp("proj"), dh, dim, true)?;
let points = (0..resolution.0)
.flat_map(|x| (0..resolution.1).map(move |y| (x as i64, y as i64)))
@ -526,9 +526,9 @@ pub struct TinyViT {
// norm_head: candle_nn::LayerNorm,
// head: candle_nn::Linear,
neck_conv1: candle_nn::Conv2d,
neck_ln1: crate::LayerNorm2d,
neck_ln1: super::LayerNorm2d,
neck_conv2: candle_nn::Conv2d,
neck_ln2: crate::LayerNorm2d,
neck_ln2: super::LayerNorm2d,
span: tracing::Span,
span_neck: tracing::Span,
}
@ -578,13 +578,13 @@ impl TinyViT {
// let head = candle_nn::linear(last_embed_dim, num_classes, vb.pp("head"))?;
let neck_conv1 =
candle_nn::conv2d_no_bias(last_embed_dim, 256, 1, Default::default(), vb.pp("neck.0"))?;
let neck_ln1 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?;
let neck_ln1 = super::LayerNorm2d::new(256, 1e-6, vb.pp("neck.1"))?;
let cfg = candle_nn::Conv2dConfig {
padding: 1,
..Default::default()
};
let neck_conv2 = candle_nn::conv2d_no_bias(256, 256, 3, cfg, vb.pp("neck.2"))?;
let neck_ln2 = crate::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?;
let neck_ln2 = super::LayerNorm2d::new(256, 1e-6, vb.pp("neck.3"))?;
let span = tracing::span!(tracing::Level::TRACE, "tiny-vit");
let span_neck = tracing::span!(tracing::Level::TRACE, "neck");


@ -68,7 +68,7 @@ struct TwoWayAttentionBlock {
norm1: LayerNorm,
cross_attn_token_to_image: Attention,
norm2: LayerNorm,
mlp: crate::MlpBlock,
mlp: super::MlpBlock,
norm3: LayerNorm,
norm4: LayerNorm,
cross_attn_image_to_token: Attention,
@ -100,7 +100,7 @@ impl TwoWayAttentionBlock {
2,
vb.pp("cross_attn_image_to_token"),
)?;
let mlp = crate::MlpBlock::new(
let mlp = super::MlpBlock::new(
embedding_dim,
mlp_dim,
candle_nn::Activation::Relu,