Fix clippy lints + minor cleanups. (#1957)

* Fix clippy lints + minor cleanups.

* fmt.

* Derive clone.
This commit is contained in:
Laurent Mazare 2024-03-28 14:17:46 +01:00 committed by GitHub
parent b0340d72ec
commit cdc8b57b5c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 41 additions and 100 deletions

View File

@ -10,13 +10,11 @@ use self::{
vision_model::ClipVisionTransformer,
};
use candle::{Result, Tensor, D};
use candle_nn::Module;
use tracing::warn;
pub mod text_model;
pub mod vision_model;
#[derive(Clone, Debug)]
pub struct ClipModel {
text_model: ClipTextTransformer,
vision_model: ClipVisionTransformer,
@ -25,6 +23,7 @@ pub struct ClipModel {
logit_scale: Tensor,
}
#[derive(Clone, Debug)]
pub enum EncoderConfig {
Text(text_model::ClipTextConfig),
Vision(vision_model::ClipVisionConfig),
@ -67,6 +66,7 @@ impl EncoderConfig {
}
}
#[derive(Clone, Debug)]
pub struct ClipConfig {
pub text_config: text_model::ClipTextConfig,
pub vision_config: vision_model::ClipVisionConfig,
@ -111,7 +111,6 @@ impl ClipModel {
let logit_scale = if vs.contains_tensor("logit_scale") {
vs.get(&[], "logit_scale")?
} else {
warn!("Creating logit_scale tensor, results may vary.");
Tensor::new(&[c.logit_scale_init_value], vs.device())?
};
@ -125,38 +124,26 @@ impl ClipModel {
}
pub fn get_text_features(&self, input_ids: &Tensor) -> Result<Tensor> {
let text_outputs = self.text_model.forward(input_ids)?;
let text_features = self.text_projection.forward(&text_outputs)?;
Ok(text_features)
input_ids
.apply(&self.text_model)?
.apply(&self.text_projection)
}
pub fn get_image_features(&self, pixel_values: &Tensor) -> Result<Tensor> {
let image_features = self.vision_model.forward(pixel_values)?;
let image_features = self.visual_projection.forward(&image_features)?;
Ok(image_features)
pixel_values
.apply(&self.vision_model)?
.apply(&self.visual_projection)
}
pub fn forward(&self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<(Tensor, Tensor)> {
let image_features = self.get_image_features(pixel_values)?;
let text_features = self.get_text_features(input_ids)?;
let image_features_normalized = div_l2_norm(&image_features)?;
let text_features_normalized = div_l2_norm(&text_features)?;
let logits_per_text = text_features_normalized.matmul(&image_features_normalized.t()?)?;
let logit_scale = &self.logit_scale.exp()?;
let logit_scale = self.logit_scale.exp()?;
let logits_per_text = logits_per_text.broadcast_mul(&logit_scale)?;
let logits_per_image = logits_per_text.t()?;
Ok((logits_per_text, logits_per_image))
}
}

View File

@ -59,7 +59,7 @@ impl ClipTextConfig {
// ClipTextEmbeddings mostly based on the existing implementation in the stable diffision model.
// TODO rewrite to be more similar to https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L142
#[derive(Debug)]
#[derive(Clone, Debug)]
struct ClipTextEmbeddings {
token_embedding: candle_nn::Embedding,
position_embedding: candle_nn::Embedding,
@ -70,16 +70,13 @@ impl ClipTextEmbeddings {
fn new(vs: candle_nn::VarBuilder, c: &ClipTextConfig) -> Result<Self> {
let token_embedding =
candle_nn::embedding(c.vocab_size, c.embed_dim, vs.pp("token_embedding"))?;
let position_embedding: nn::Embedding = candle_nn::embedding(
c.max_position_embeddings,
c.embed_dim,
vs.pp("position_embedding"),
)?;
let position_ids =
Tensor::arange(0u32, c.max_position_embeddings as u32, vs.device())?.unsqueeze(0)?;
Ok(ClipTextEmbeddings {
token_embedding,
position_embedding,
@ -91,20 +88,14 @@ impl ClipTextEmbeddings {
impl Module for ClipTextEmbeddings {
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let seq_length = input_ids.dim(D::Minus1)?;
let inputs_embeds = &self.token_embedding.forward(input_ids)?;
let postion_ids = &self.position_ids.narrow(1, 0, seq_length)?;
let position_embedding = &self.position_embedding.forward(&postion_ids)?;
let inputs_embeds = inputs_embeds.broadcast_add(&position_embedding)?;
Ok(inputs_embeds)
let inputs_embeds = self.token_embedding.forward(input_ids)?;
let position_ids = self.position_ids.narrow(1, 0, seq_length)?;
let position_embedding = self.position_embedding.forward(&position_ids)?;
inputs_embeds.broadcast_add(&position_embedding)
}
}
#[derive(Debug)]
#[derive(Clone, Debug)]
struct ClipAttention {
k_proj: candle_nn::Linear,
v_proj: candle_nn::Linear,
@ -166,15 +157,10 @@ impl ClipAttention {
let src_len = key_states.dim(1)?;
let attn_weights = if let Some(causal_attention_mask) = causal_attention_mask {
let attn_reshape =
attn_weights.reshape((bsz, self.num_attention_heads, seq_len, src_len))?;
let attn_weights = attn_reshape.broadcast_add(causal_attention_mask)?;
let attn_weights =
attn_weights.reshape((bsz * self.num_attention_heads, seq_len, src_len))?;
attn_weights
.reshape((bsz, self.num_attention_heads, seq_len, src_len))?
.broadcast_add(causal_attention_mask)?
.reshape((bsz * self.num_attention_heads, seq_len, src_len))?
} else {
attn_weights
};
@ -190,7 +176,7 @@ impl ClipAttention {
}
}
#[derive(Debug)]
#[derive(Clone, Debug)]
struct ClipMlp {
fc1: candle_nn::Linear,
fc2: candle_nn::Linear,
@ -217,7 +203,7 @@ impl ClipMlp {
}
}
#[derive(Debug)]
#[derive(Clone, Debug)]
struct ClipEncoderLayer {
self_attn: ClipAttention,
layer_norm1: candle_nn::LayerNorm,
@ -253,7 +239,7 @@ impl ClipEncoderLayer {
}
}
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct ClipEncoder {
layers: Vec<ClipEncoderLayer>,
}
@ -271,7 +257,6 @@ impl ClipEncoder {
pub fn forward(&self, xs: &Tensor, causal_attention_mask: Option<&Tensor>) -> Result<Tensor> {
let mut xs = xs.clone();
for layer in self.layers.iter() {
xs = layer.forward(&xs, causal_attention_mask)?;
}
@ -280,7 +265,7 @@ impl ClipEncoder {
}
/// A CLIP transformer based model.
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct ClipTextTransformer {
embeddings: ClipTextEmbeddings,
encoder: ClipEncoder,
@ -292,7 +277,6 @@ impl ClipTextTransformer {
let embeddings = ClipTextEmbeddings::new(vs.pp("embeddings"), c)?;
let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Text(c.clone()))?;
let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("final_layer_norm"))?;
Ok(ClipTextTransformer {
embeddings,
encoder,
@ -325,7 +309,6 @@ impl ClipTextTransformer {
pub fn forward_with_mask(&self, input_ids: &Tensor, mask_after: usize) -> Result<Tensor> {
let (bsz, seq_len) = input_ids.dims2()?;
let input_ids = self.embeddings.forward(input_ids)?;
let causal_attention_mask =
Self::build_causal_attention_mask(bsz, seq_len, mask_after, input_ids.device())?;
let input_ids = self
@ -338,18 +321,13 @@ impl ClipTextTransformer {
impl Module for ClipTextTransformer {
fn forward(&self, input_ids: &Tensor) -> Result<Tensor> {
let output = self.forward_with_mask(input_ids, usize::MAX)?;
let sequence_max_indices = input_ids.argmax(D::Minus1)?.to_dtype(DType::I64)?;
let mut indices: Vec<Tensor> = Vec::new();
let mut indices = Vec::new();
for (batch_idx, &seq_idx) in sequence_max_indices.to_vec1::<i64>()?.iter().enumerate() {
let index = output.i((batch_idx, seq_idx as usize))?.unsqueeze(0)?;
indices.push(index);
}
let pooled_output = Tensor::cat(&indices, 0)?;
Ok(pooled_output)
Tensor::cat(&indices, 0)
}
}

View File

@ -10,7 +10,6 @@ use candle::{IndexOp, Result, Shape, Tensor, D};
use candle_nn as nn;
use candle_nn::Module;
use nn::Conv2dConfig;
use tracing::warn;
use super::{
text_model::{Activation, ClipEncoder},
@ -50,7 +49,7 @@ impl ClipVisionConfig {
}
// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L112
#[derive(Debug)]
#[derive(Clone, Debug)]
struct ClipVisionEmbeddings {
patch_embedding: candle_nn::Conv2d,
position_ids: Tensor,
@ -64,14 +63,11 @@ impl ClipVisionEmbeddings {
let class_embedding = if vs.contains_tensor("class_embedding") {
vs.get(c.embed_dim, "class_embedding")?
} else {
warn!("class_embedding not found in the. Initializing a new one.");
Tensor::randn(0.0 as f32, 1.0 as f32, &[c.embed_dim], vs.device())?
Tensor::randn(0f32, 1f32, c.embed_dim, vs.device())?
};
let num_patches = (c.image_size / c.patch_size).pow(2);
let num_positions = num_patches + 1;
let position_ids = Tensor::arange(0, num_positions as i64, vs.device())?;
let conv2dconfig = Conv2dConfig {
@ -80,7 +76,6 @@ impl ClipVisionEmbeddings {
};
let position_embedding =
candle_nn::embedding(num_positions, c.embed_dim, vs.pp("position_embedding"))?;
let patch_embedding = candle_nn::conv2d_no_bias(
c.num_channels,
c.embed_dim,
@ -88,7 +83,6 @@ impl ClipVisionEmbeddings {
conv2dconfig,
vs.pp("patch_embedding"),
)?;
Ok(Self {
patch_embedding,
position_ids,
@ -101,31 +95,21 @@ impl ClipVisionEmbeddings {
impl Module for ClipVisionEmbeddings {
fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
let batch_size = pixel_values.shape().dims();
let patch_embeds = self.patch_embedding.forward(&pixel_values)?;
let patch_embeds = patch_embeds.flatten_from(2)?;
let patch_embeds = patch_embeds.transpose(1, 2)?;
let class_embedding = self.class_embedding.clone();
let shape = Shape::from(vec![batch_size[0], 1, class_embedding.dim(D::Minus1)?]);
let class_embeds = class_embedding.expand(shape)?;
let patch_embeds = self
.patch_embedding
.forward(pixel_values)?
.flatten_from(2)?
.transpose(1, 2)?;
let shape = Shape::from((batch_size[0], 1, self.class_embedding.dim(D::Minus1)?));
let class_embeds = self.class_embedding.expand(shape)?;
let embeddings = Tensor::cat(&[class_embeds, patch_embeds], 1)?;
let position_embedding = self.position_embedding.forward(&self.position_ids)?;
let embeddings = embeddings.broadcast_add(&position_embedding)?;
Ok(embeddings)
embeddings.broadcast_add(&position_embedding)
}
}
// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L743
#[derive(Debug)]
#[derive(Clone, Debug)]
pub struct ClipVisionTransformer {
embeddings: ClipVisionEmbeddings,
encoder: ClipEncoder,
@ -136,13 +120,9 @@ pub struct ClipVisionTransformer {
impl ClipVisionTransformer {
pub fn new(vs: candle_nn::VarBuilder, c: &ClipVisionConfig) -> Result<Self> {
let embeddings = ClipVisionEmbeddings::new(vs.pp("embeddings"), c)?;
let pre_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("pre_layrnorm"))?;
let encoder = ClipEncoder::new(vs.pp("encoder"), &EncoderConfig::Vision(c.clone()))?;
let final_layer_norm = candle_nn::layer_norm(c.embed_dim, 1e-5, vs.pp("post_layernorm"))?;
Ok(Self {
embeddings,
encoder,
@ -154,18 +134,14 @@ impl ClipVisionTransformer {
impl Module for ClipVisionTransformer {
fn forward(&self, pixel_values: &Tensor) -> Result<Tensor> {
let hidden_states = self.embeddings.forward(pixel_values)?;
let hidden_states = self.pre_layer_norm.forward(&hidden_states)?;
let hidden_states = pixel_values
.apply(&self.embeddings)?
.apply(&self.pre_layer_norm)?;
let encoder_outputs = self.encoder.forward(&hidden_states, None)?;
// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L787
// pooled_output = encoder_outputs[:, 0, :]
let pooled_output = encoder_outputs.i((.., 0, ..))?;
let output = self.final_layer_norm.forward(&pooled_output)?;
Ok(output)
self.final_layer_norm.forward(&pooled_output)
}
}

View File

@ -3,6 +3,7 @@ pub mod bigcode;
pub mod blip;
pub mod blip_text;
pub mod chatglm;
pub mod clip;
pub mod convmixer;
pub mod convnext;
pub mod dinov2;
@ -12,7 +13,6 @@ pub mod efficientvit;
pub mod encodec;
pub mod falcon;
pub mod gemma;
pub mod clip;
pub mod jina_bert;
pub mod llama;
pub mod llama2_c;