Fix clippy lints + minor cleanups. (#1957)
* Fix clippy lints + minor cleanups.
* fmt.
* Derive clone.
parent b0340d72ec
commit cdc8b57b5c
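The "Derive clone" part of this change adds `#[derive(Clone, Debug)]` to the CLIP structs below (ClipModel, EncoderConfig, ClipTextEmbeddings, ClipAttention, and friends). In candle a `Tensor` is a cheap, reference-counted handle to its storage, so cloning a model copies handles rather than weight buffers. A minimal sketch of what the derive enables (the `TinyModel` struct and its shape are hypothetical, not part of this commit):

```rust
use candle::{DType, Device, Result, Tensor};

// Hypothetical model struct; `Clone` works because `Tensor` clones are shallow.
#[derive(Clone, Debug)]
struct TinyModel {
    weight: Tensor,
}

fn main() -> Result<()> {
    let weight = Tensor::zeros((4, 4), DType::F32, &Device::Cpu)?;
    let model = TinyModel { weight };
    // Copies the tensor handle, not the underlying buffer.
    let model_copy = model.clone();
    assert_eq!(model_copy.weight.dims(), model.weight.dims());
    Ok(())
}
```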
@@ -10,13 +10,11 @@ use self::{
     vision_model::ClipVisionTransformer,
 };
 use candle::{Result, Tensor, D};
 use candle_nn::Module;
-
 use tracing::warn;
-
 pub mod text_model;
 pub mod vision_model;
-
+#[derive(Clone, Debug)]
 pub struct ClipModel {
     text_model: ClipTextTransformer,
     vision_model: ClipVisionTransformer,
@@ -25,6 +23,7 @@ pub struct ClipModel {
     logit_scale: Tensor,
 }
 
+#[derive(Clone, Debug)]
 pub enum EncoderConfig {
     Text(text_model::ClipTextConfig),
     Vision(vision_model::ClipVisionConfig),
@@ -67,6 +66,7 @@ impl EncoderConfig
     }
 }
 
+#[derive(Clone, Debug)]
 pub struct ClipConfig {
     pub text_config: text_model::ClipTextConfig,
     pub vision_config: vision_model::ClipVisionConfig,
@@ -111,7 +111,6 @@ impl ClipModel
         let logit_scale = if vs.contains_tensor("logit_scale") {
             vs.get(&[], "logit_scale")?
         } else {
             warn!("Creating logit_scale tensor, results may vary.");
             Tensor::new(&[c.logit_scale_init_value], vs.device())?
         };
-
@@ -125,38 +124,26 @@ impl ClipModel
     }
 
     pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
-        let text_outputs = self.text_model.forward(input_ids)?;
-
-        let text_features = self.text_projection.forward(&text_outputs)?;
-
-        Ok(text_features)
+        input_ids
+            .apply(&self.text_model)?
+            .apply(&self.text_projection)
     }
 
     pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
-        let image_features = self.vision_model.forward(pixel_values)?;
-
-        let image_features = self.visual_projection.forward(&image_features)?;
-
-        Ok(image_features)
+        pixel_values
+            .apply(&self.vision_model)?
+            .apply(&self.visual_projection)
     }
 
     pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
         let image_features = self.get_image_features(pixel_values)?;
-
         let text_features = self.get_text_features(input_ids)?;
-
         let image_features_normalized = div_l2_norm(&image_features)?;
-
         let text_features_normalized = div_l2_norm(&text_features)?;
-
         let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
-
-        let logit_scale = &self.logit_scale.exp()?;
-
+        let logit_scale = self.logit_scale.exp()?;
         let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?;
-
         let logits_per_image = logits_per_text.t()?;
-
         Ok((logits_per_text, logits_per_image))
     }
 }
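The rewritten getters above use `Tensor::apply`, which runs any `Module` and returns a `Result`, so each projection becomes a single `?`-chained expression instead of a ladder of temporaries; the `logit_scale` binding also loses a superfluous `&` on a temporary. A small standalone sketch of the chaining pattern (the `Scale` module is made up for illustration, not part of this diff):

```rust
use candle::{Device, Result, Tensor};
use candle_nn::Module;

// Hypothetical module: multiplies its input by a constant.
struct Scale(f64);

impl Module for Scale {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        xs.affine(self.0, 0.0)
    }
}

fn main() -> Result<()> {
    let xs = Tensor::new(&[1f32, 2., 3.], &Device::Cpu)?;
    // `apply` runs a Module and returns a Result, so forward passes chain with `?`.
    let ys = xs.apply(&Scale(2.0))?.apply(&Scale(0.5))?;
    assert_eq!(ys.to_vec1::<f32>()?, vec![1., 2., 3.]);
    Ok(())
}
```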
@@ -59,7 +59,7 @@ impl ClipTextConfig
 
 // ClipTextEmbeddings mostly based on the existing implementation in the stable diffision model.
 // TODO rewrite to be more similar to https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L142
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 struct ClipTextEmbeddings {
     token_embedding: candle_nn::Embedding,
     position_embedding: candle_nn::Embedding,
@@ -70,16 +70,13 @@ impl ClipTextEmbeddings
     fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result<Self> {
         let token_embedding =
             candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
-
         let position_embedding: nn::Embedding = candle_nn::embedding(
             c.max_position_embeddings,
             c.embed_dim,
             vs.pp("position_embedding"),
         )?;
-
         let position_ids =
             Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
-
         Ok(ClipTextEmbeddings {
             token_embedding,
             position_embedding,
@@ -91,20 +88,14 @@ impl ClipTextEmbeddings
 impl Module for ClipTextEmbeddings {
     fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
         let seq_length = input_ids.dim(D::Minus1)?;
-
-        let inputs_embeds = &self.token_embedding.forward(input_ids)?;
-
-        let postion_ids = &self.position_ids.narrow(1, 0, seq_length)?;
-
-        let position_embedding = &self.position_embedding.forward(&postion_ids)?;
-
-        let inputs_embeds = inputs_embeds.broadcast_add(&position_embedding)?;
-
-        Ok(inputs_embeds)
+        let inputs_embeds = self.token_embedding.forward(input_ids)?;
+        let position_ids = self.position_ids.narrow(1, 0, seq_length)?;
+        let position_embedding = self.position_embedding.forward(&position_ids)?;
+        inputs_embeds.broadcast_add(&position_embedding)
     }
 }
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 struct ClipAttention {
     k_proj: candle_nn::Linear,
     v_proj: candle_nn::Linear,
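The new `ClipTextEmbeddings::forward` above narrows the stored `position_ids` to the current sequence length and broadcast-adds the looked-up position embeddings onto the token embeddings, fixing the `postion_ids` typo and the needless `&` bindings along the way. A shape-only sketch of that broadcast, with made-up sizes:

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (bsz, seq_len, dim, max_positions) = (2usize, 4, 8, 77);
    // Token embeddings for the current batch: (batch, seq, dim).
    let inputs_embeds = Tensor::zeros((bsz, seq_len, dim), DType::F32, &dev)?;
    // Full position-embedding table: (max positions, dim).
    let position_embedding = Tensor::zeros((max_positions, dim), DType::F32, &dev)?;
    // Keep the first `seq_len` rows and add a leading batch dimension: (1, seq, dim).
    let position_embedding = position_embedding.narrow(0, 0, seq_len)?.unsqueeze(0)?;
    let out = inputs_embeds.broadcast_add(&position_embedding)?;
    assert_eq!(out.dims(), &[bsz, seq_len, dim]);
    Ok(())
}
```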
@@ -166,15 +157,10 @@ impl ClipAttention
         let src_len = key_states.dim(1)?;
 
         let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask {
-            let attn_reshape =
-                attn_weights.reshape((bsz, self.num_attention_heads, seq_len, src_len))?;
-
-            let attn_weights = attn_reshape.broadcast_add(causal_attention_mask)?;
-
-            let attn_weights =
-                attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
-
             attn_weights
+                .reshape((bsz, self.num_attention_heads, seq_len, src_len))?
+                .broadcast_add(causal_attention_mask)?
+                .reshape((bsz * self.num_attention_heads, seq_len, src_len))?
         } else {
             attn_weights
         };
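The chained version above applies the optional causal mask by viewing the scores of shape `(bsz * heads, seq, src)` as `(bsz, heads, seq, src)`, broadcast-adding a mask that broadcasts over the head dimension, and flattening back. A standalone sketch with toy shapes (the mask construction here is illustrative, not the crate's `build_causal_attention_mask`):

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (bsz, heads, seq, src) = (1usize, 2, 3, 3);
    // Per-head attention scores, flattened over the batch and head dimensions.
    let attn_weights = Tensor::zeros((bsz * heads, seq, src), DType::F32, &dev)?;
    // Illustrative causal mask: -inf above the diagonal, 0 elsewhere.
    let mask: Vec<f32> = (0..seq)
        .flat_map(|i| (0..src).map(move |j| if j > i { f32::NEG_INFINITY } else { 0. }))
        .collect();
    let mask = Tensor::from_vec(mask, (1, 1, seq, src), &dev)?;
    let attn_weights = attn_weights
        .reshape((bsz, heads, seq, src))?
        .broadcast_add(&mask)?
        .reshape((bsz * heads, seq, src))?;
    assert_eq!(attn_weights.dims(), &[bsz * heads, seq, src]);
    Ok(())
}
```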
@@ -190,7 +176,7 @@ impl ClipAttention
     }
 }
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 struct ClipMlp {
     fc1: candle_nn::Linear,
     fc2: candle_nn::Linear,
@@ -217,7 +203,7 @@ impl ClipMlp
     }
 }
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 struct ClipEncoderLayer {
     self_attn: ClipAttention,
     layer_norm1: candle_nn::LayerNorm,
@@ -253,7 +239,7 @@ impl ClipEncoderLayer
     }
 }
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct ClipEncoder {
     layers: Vec<ClipEncoderLayer>,
 }
@@ -271,7 +257,6 @@ impl ClipEncoder
 
     pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
         let mut xs = xs.clone();
-
         for layer in self.layers.iter() {
             xs = layer.forward(&xs, causal_attention_mask)?;
         }
@@ -280,7 +265,7 @@ impl ClipEncoder
 }
 
 /// A CLIP transformer based model.
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct ClipTextTransformer {
     embeddings: ClipTextEmbeddings,
     encoder: ClipEncoder,
@@ -292,7 +277,6 @@ impl ClipTextTransformer
         let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
         let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Text(c.clone()))?;
         let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
-
         Ok(ClipTextTransformer {
             embeddings,
             encoder,
@@ -325,7 +309,6 @@ impl ClipTextTransformer
     pub fn forward_with_mask(&self, input_ids: &Tensor, mask_after: usize) -> Result<Tensor> {
         let (bsz, seq_len) = input_ids.dims2()?;
         let input_ids = self.embeddings.forward(input_ids)?;
-
         let causal_attention_mask =
             Self::build_causal_attention_mask(bsz, seq_len, mask_after, input_ids.device())?;
         let input_ids = self
@@ -338,18 +321,13 @@ impl ClipTextTransformer
 impl Module for ClipTextTransformer {
     fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
         let output = self.forward_with_mask(input_ids, usize::MAX)?;
-
         let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?;
-
-        let mut indices: Vec<Tensor> = Vec::new();
-
+        let mut indices = Vec::new();
         for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::<i64>()?.iter().enumerate() {
             let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?;
             indices.push(index);
         }
 
-        let pooled_output = Tensor::cat(&indices, 0)?;
-
-        Ok(pooled_output)
+        Tensor::cat(&indices, 0)
     }
 }
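The pooling above now returns `Tensor::cat(&indices, 0)` directly: for each batch row it keeps the hidden state at the position of the largest token id, which for CLIP's tokenizer corresponds to the end-of-text token. A standalone sketch of that gather (token ids and sizes are made up):

```rust
use candle::{DType, Device, IndexOp, Result, Tensor, D};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // Two sequences of token ids; the largest id per row plays the role of EOS.
    let input_ids = Tensor::new(&[[1u32, 5, 2, 0], [1, 7, 9, 2]], &dev)?;
    // Transformer output: (batch, seq, dim).
    let output = Tensor::zeros((2, 4, 8), DType::F32, &dev)?;
    let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?;
    let mut indices = Vec::new();
    for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::<i64>()?.iter().enumerate() {
        // One (1, dim) slice per batch row, taken at that row's argmax position.
        indices.push(output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?);
    }
    let pooled = Tensor::cat(&indices, 0)?;
    assert_eq!(pooled.dims(), &[2, 8]);
    Ok(())
}
```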
@@ -10,7 +10,6 @@ use candle::{IndexOp, Result, Shape, Tensor, D};
 use candle_nn as nn;
 use candle_nn::Module;
 use nn::Conv2dConfig;
 use tracing::warn;
-
 use super::{
     text_model::{Activation, ClipEncoder},
@@ -50,7 +49,7 @@ impl ClipVisionConfig
 }
 
 // https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L112
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 struct ClipVisionEmbeddings {
     patch_embedding: candle_nn::Conv2d,
     position_ids: Tensor,
@@ -64,14 +63,11 @@ impl ClipVisionEmbeddings
         let class_embedding = if vs.contains_tensor("class_embedding") {
             vs.get(c.embed_dim, "class_embedding")?
         } else {
             warn!("class_embedding not found in the. Initializing a new one.");
-            Tensor::randn(0.0 as f32, 1.0 as f32, &[c.embed_dim], vs.device())?
+            Tensor::randn(0f32, 1f32, c.embed_dim, vs.device())?
         };
 
         let num_patches = (c.image_size / c.patch_size).pow(2);
-
         let num_positions = num_patches + 1;
-
         let position_ids = Tensor::arange(0, num_positions as i64, vs.device())?;
-
         let conv2dconfig = Conv2dConfig {
@@ -80,7 +76,6 @@ impl ClipVisionEmbeddings
         };
         let position_embedding =
             candle_nn::embedding(num_positions, c.embed_dim, vs.pp("position_embedding"))?;
-
         let patch_embedding = candle_nn::conv2d_no_bias(
             c.num_channels,
             c.embed_dim,
@@ -88,7 +83,6 @@ impl ClipVisionEmbeddings
             conv2dconfig,
             vs.pp("patch_embedding"),
         )?;
-
         Ok(Self {
             patch_embedding,
             position_ids,
@@ -101,31 +95,21 @@ impl ClipVisionEmbeddings
 impl Module for ClipVisionEmbeddings {
     fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
         let batch_size = pixel_values.shape().dims();
-
-        let patch_embeds = self.patch_embedding.forward(&pixel_values)?;
-
-        let patch_embeds = patch_embeds.flatten_from(2)?;
-
-        let patch_embeds = patch_embeds.transpose(1, 2)?;
-
-        let class_embedding = self.class_embedding.clone();
-
-        let shape = Shape::from(vec![batch_size[0], 1, class_embedding.dim(D::Minus1)?]);
-
-        let class_embeds = class_embedding.expand(shape)?;
-
+        let patch_embeds = self
+            .patch_embedding
+            .forward(pixel_values)?
+            .flatten_from(2)?
+            .transpose(1, 2)?;
+        let shape = Shape::from((batch_size[0], 1, self.class_embedding.dim(D::Minus1)?));
+        let class_embeds = self.class_embedding.expand(shape)?;
         let embeddings = Tensor::cat(&[class_embeds, patch_embeds], 1)?;
-
         let position_embedding = self.position_embedding.forward(&self.position_ids)?;
-
-        let embeddings = embeddings.broadcast_add(&position_embedding)?;
-
-        Ok(embeddings)
+        embeddings.broadcast_add(&position_embedding)
     }
 }
 
 // https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L743
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct ClipVisionTransformer {
     embeddings: ClipVisionEmbeddings,
     encoder: ClipEncoder,
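In the vision embeddings `forward` above, the patch projection is flattened to `(batch, patches, dim)`, the learned class embedding is expanded to `(batch, 1, dim)`, and the two are concatenated before the position embeddings are broadcast in. A shape-only sketch of that expand-and-concat (sizes are made up):

```rust
use candle::{DType, Device, Result, Shape, Tensor, D};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    let (bsz, num_patches, dim) = (2usize, 49, 8);
    // Flattened patch embeddings: (batch, patches, dim).
    let patch_embeds = Tensor::zeros((bsz, num_patches, dim), DType::F32, &dev)?;
    // Learned class embedding: (dim,), expanded to (batch, 1, dim).
    let class_embedding = Tensor::zeros(dim, DType::F32, &dev)?;
    let shape = Shape::from((bsz, 1, class_embedding.dim(D::Minus1)?));
    let class_embeds = class_embedding.expand(shape)?;
    // Prepend the class token to the patch tokens.
    let embeddings = Tensor::cat(&[class_embeds, patch_embeds], 1)?;
    assert_eq!(embeddings.dims(), &[bsz, num_patches + 1, dim]);
    Ok(())
}
```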
@@ -136,13 +120,9 @@ pub struct ClipVisionTransformer
 impl ClipVisionTransformer {
     pub fn new(vs: candle_nn::VarBuilder, c: &ClipVisionConfig) -> Result<Self> {
         let embeddings = ClipVisionEmbeddings::new(vs.pp("embeddings"), c)?;
-
         let pre_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("pre_layrnorm"))?;
-
         let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Vision(c.clone()))?;
-
         let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("post_layernorm"))?;
-
         Ok(Self {
             embeddings,
             encoder,
@@ -154,18 +134,14 @@ impl ClipVisionTransformer
 
 impl Module for ClipVisionTransformer {
     fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
-        let hidden_states = self.embeddings.forward(pixel_values)?;
-
-        let hidden_states = self.pre_layer_norm.forward(&hidden_states)?;
+        let hidden_states = pixel_values
+            .apply(&self.embeddings)?
+            .apply(&self.pre_layer_norm)?;
 
         let encoder_outputs = self.encoder.forward(&hidden_states, None)?;
-
         // https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L787
         // pooled_output = encoder_outputs[:, 0, :]
         let pooled_output = encoder_outputs.i((.., 0, ..))?;
-
-        let output = self.final_layer_norm.forward(&pooled_output)?;
-
-        Ok(output)
+        self.final_layer_norm.forward(&pooled_output)
     }
 }
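The vision transformer's `forward` above now pipes the pixel values through the embeddings and pre-layer-norm with `.apply`, then pools by taking the class token at sequence index 0 (the `encoder_outputs[:, 0, :]` of the referenced PyTorch code) before the final layer norm. A small sketch of that indexing (shapes are made up):

```rust
use candle::{DType, Device, IndexOp, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // (batch, 1 class token + 49 patch tokens, hidden dim)
    let encoder_outputs = Tensor::zeros((2, 50, 8), DType::F32, &dev)?;
    // Keep only the class token: shape (batch, hidden dim).
    let pooled_output = encoder_outputs.i((.., 0, ..))?;
    assert_eq!(pooled_output.dims(), &[2, 8]);
    Ok(())
}
```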
@@ -3,6 +3,7 @@ pub mod bigcode;
 pub mod blip;
 pub mod blip_text;
 pub mod chatglm;
+pub mod clip;
 pub mod convmixer;
 pub mod convnext;
 pub mod dinov2;
@@ -12,7 +13,6 @@ pub mod efficientvit;
 pub mod encodec;
 pub mod falcon;
 pub mod gemma;
-pub mod clip;
 pub mod jina_bert;
 pub mod llama;
 pub mod llama2_c;