Add ColPali (#2524)

* add colpali

* cleanup

* fix clippy
This commit is contained in:
Akshay Ballal 2024-10-01 11:48:39 +02:00 committed by GitHub
parent 6110ad8d4f
commit 888d886dd8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 394 additions and 1 deletions

View File

@ -36,6 +36,7 @@ serde_json = { workspace = true }
symphonia = { version = "0.5.3", features = ["all"], optional = true }
tokenizers = { workspace = true, features = ["onig"] }
cpal = { version = "0.15.2", optional = true }
pdf2image = { version = "0.1.2" , optional = true}
[dev-dependencies]
anyhow = { workspace = true }
@ -117,3 +118,7 @@ required-features = ["depth_anything_v2"]
[[example]]
name = "silero-vad"
required-features = ["onnx"]
[[example]]
name = "colpali"
required-features = ["pdf2image"]

View File

@ -0,0 +1,18 @@
# Colpali
[HuggingFace Model Card](https://huggingface.co/vidore/colpali-v1.2-merged)
```
wget https://arxiv.org/pdf/1706.03762.pdf
cargo run --features cuda,pdf2image --release --example colpali -- --prompt "What is Positional Encoding" --pdf "1706.03762.pdf"
```
```
Prompt: what is position encoding?
top 3 page numbers that contain similarity to the prompt
-----------------------------------
Page: 6
Page: 11
Page: 15
-----------------------------------
```

View File

@ -0,0 +1,268 @@
use anyhow::{Error as E, Result};
use candle::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::colpali::Model;
use candle_transformers::models::{colpali, paligemma};
use clap::Parser;
use hf_hub::{api::sync::Api, Repo, RepoType};
use image::DynamicImage;
use pdf2image::{RenderOptionsBuilder, PDF};
use tokenizers::Tokenizer;
struct PageRetriever {
model: Model,
config: paligemma::Config,
pdf: PDF,
device: Device,
tokenizer: Tokenizer,
range: pdf2image::Pages,
batch_size: usize,
top_k: usize,
}
impl PageRetriever {
fn new(
model: Model,
config: paligemma::Config,
pdf: PDF,
tokenizer: Tokenizer,
device: &Device,
range: Option<pdf2image::Pages>,
batch_size: usize,
top_k: usize,
) -> Self {
let page_count = pdf.page_count();
Self {
model,
config,
pdf,
device: device.clone(),
tokenizer,
range: range.unwrap_or_else(|| pdf2image::Pages::Range(1..=page_count)),
batch_size,
top_k,
}
}
fn get_images_from_pdf(&self) -> Result<Vec<DynamicImage>> {
let pages = self
.pdf
.render(self.range.clone(), RenderOptionsBuilder::default().build()?)?;
Ok(pages)
}
fn tokenize_batch(&self, prompts: Vec<&str>) -> Result<Tensor> {
let tokens = self.tokenizer.encode_batch(prompts, true).map_err(E::msg)?;
let token_ids = tokens
.iter()
.map(|tokens| {
let tokens = tokens.get_ids().to_vec();
Tensor::new(tokens.as_slice(), &self.device)
})
.collect::<candle::Result<Vec<_>>>()?;
let input = Tensor::stack(&token_ids, 0)?;
Ok(input)
}
fn images_to_tensor(
&self,
pages: &[DynamicImage],
image_size: usize,
) -> anyhow::Result<Tensor> {
let mut images = vec![];
for page in pages.iter() {
let img = page.resize_to_fill(
image_size as u32,
image_size as u32,
image::imageops::FilterType::Triangle,
);
let img = img.to_rgb8();
let img = img.into_raw();
let img = Tensor::from_vec(img, (image_size, image_size, 3), &Device::Cpu)?
.permute((2, 0, 1))?
.to_dtype(DType::F32)?
.affine(2. / 255., -1.)?;
images.push(img);
}
let images = Tensor::stack(&images, 0)?;
Ok(images)
}
fn retrieve(&mut self, prompt: &str) -> Result<Vec<usize>> {
let dtype = if self.device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let dummy_prompt: &str = "Describe the image";
let input = self.tokenize_batch(vec![prompt])?;
let dummy_input = self.tokenize_batch(vec![dummy_prompt])?;
let pages = self.get_images_from_pdf()?;
let mut all_scores = Vec::new();
for batch in pages.chunks(self.batch_size) {
let page_images = self
.images_to_tensor(batch, self.config.vision_config.image_size)?
.to_device(&self.device)?
.to_dtype(dtype)?;
let dummy_input = dummy_input.repeat((page_images.dims()[0], 0))?;
let image_embeddings = self.model.forward_images(&page_images, &dummy_input)?;
let text_embeddings = self.model.forward_text(&input)?;
let scores = text_embeddings
.unsqueeze(1)?
.broadcast_matmul(&image_embeddings.unsqueeze(0)?.transpose(3, 2)?)?
.max(3)?
.sum(2)?;
let batch_scores: Vec<f32> = scores
.to_dtype(DType::F32)?
.to_vec2()?
.into_iter()
.flatten()
.collect();
all_scores.extend(batch_scores);
}
let mut indices: Vec<usize> = (0..all_scores.len()).collect();
indices.sort_by(|a, b| all_scores[*b].partial_cmp(&all_scores[*a]).unwrap());
let top_k_indices = indices[0..self.top_k].to_vec();
Ok(top_k_indices)
}
}
#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,
/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,
#[arg(long)]
prompt: String,
/// number of top pages to show.
#[arg(long, default_value_t = 3)]
top_k: usize,
#[arg(long)]
model_id: Option<String>,
#[arg(long, default_value = "main")]
revision: String,
#[arg(long)]
tokenizer_file: Option<String>,
#[arg(long)]
weight_files: Option<String>,
#[arg(long)]
pdf: String,
#[arg(long)]
start: Option<u32>,
#[arg(long)]
end: Option<u32>,
}
fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;
let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};
println!(
"avx: {}, neon: {}, simd128: {}, f16c: {}",
candle::utils::with_avx(),
candle::utils::with_neon(),
candle::utils::with_simd128(),
candle::utils::with_f16c()
);
let api = Api::new()?;
let model_id = match &args.model_id {
Some(model_id) => model_id.to_string(),
None => "vidore/colpali-v1.2-merged".to_string(),
};
let repo = api.repo(Repo::with_revision(
model_id,
RepoType::Model,
args.revision,
));
let tokenizer_filename = match args.tokenizer_file {
Some(file) => std::path::PathBuf::from(file),
None => api
.repo(Repo::with_revision(
"vidore/colpali".to_string(),
RepoType::Model,
"main".to_string(),
))
.get("tokenizer.json")?,
};
let filenames = match args.weight_files {
Some(files) => files
.split(',')
.map(std::path::PathBuf::from)
.collect::<Vec<_>>(),
None => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
};
let start = std::time::Instant::now();
let config: paligemma::Config = paligemma::Config::paligemma_3b_448();
println!("retrieved the files in {:?}", start.elapsed());
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let device = candle_examples::device(false)?;
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
let model = colpali::Model::new(&config, vb)?;
let pdf = PDF::from_file(args.pdf)?;
// check if start and end given in arg
let range = if let (Some(start), Some(end)) = (args.start, args.end) {
pdf2image::Pages::Range(start..=end)
} else {
pdf2image::Pages::Range(1..=pdf.page_count()) // can use pdf2image::Pages::All but there is a bug in the library which causes the first page to rendered twice.
};
let mut retriever =
PageRetriever::new(model, config, pdf, tokenizer, &device, Some(range), 4, 3);
let top_k_indices = retriever.retrieve(&args.prompt)?;
println!("Prompt: {}", args.prompt);
println!(
"top {} page numbers that contain similarity to the prompt",
retriever.top_k
);
println!("-----------------------------------");
for index in top_k_indices {
println!("Page: {:?}", index + 1);
}
println!("-----------------------------------");
Ok(())
}

View File

@ -0,0 +1,42 @@
use candle::{Module, Result, Tensor};
use candle_nn::VarBuilder;
use super::paligemma;
use candle_nn::{linear, Linear};
pub struct Model {
pub model: paligemma::Model,
pub custom_text_projection: Linear,
}
impl Model {
pub fn new(config: &paligemma::Config, vb: VarBuilder) -> Result<Self> {
let model = paligemma::Model::new(config, vb.pp("model"))?;
let custom_text_projection = linear(
config.text_config.hidden_size,
128,
vb.pp("custom_text_proj"),
)?;
Ok(Self {
model,
custom_text_projection,
})
}
pub fn forward_images(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
let outputs = self
.model
.setup_without_projection(pixel_values, input_ids)?;
let outputs = self.custom_text_projection.forward(&outputs)?;
let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
Ok(outputs)
}
pub fn forward_text(&mut self, input_ids: &Tensor) -> Result<Tensor> {
let outputs = self.model.forward_without_projection(input_ids)?;
let outputs = self.custom_text_projection.forward(&outputs)?;
let outputs = outputs.broadcast_div(&outputs.sqr()?.sum_keepdim(2)?.sqrt()?)?;
Ok(outputs)
}
}

View File

@ -403,7 +403,6 @@ impl Model {
.apply(&self.norm)?
.apply(&self.lm_head)
}
pub fn forward_embeds(
&mut self,
xs: &Tensor,
@ -420,6 +419,21 @@ impl Model {
.apply(&self.lm_head)
}
// Forward the model and return the hidden states without the lm_head
pub fn forward_embeds_without_projection(
&mut self,
xs: &Tensor,
attn_mask: Option<&Tensor>,
seqlen_offset: usize,
) -> Result<Tensor> {
let (_, _, _) = xs.dims3()?;
let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attn_mask, seqlen_offset)?
}
Ok(xs)
}
pub fn clear_kv_cache(&mut self) {
for layer in self.layers.iter_mut() {
layer.clear_kv_cache()

View File

@ -7,6 +7,7 @@ pub mod blip_text;
pub mod chatglm;
pub mod clip;
pub mod codegeex4_9b;
pub mod colpali;
pub mod convmixer;
pub mod convnext;
pub mod dac;

View File

@ -33,6 +33,29 @@ impl Config {
projection_dim: 2048,
}
}
pub fn paligemma_3b_448() -> Self {
Self {
vision_config: siglip::VisionConfig::paligemma_3b_448(),
text_config: gemma::Config {
hidden_size: 2048,
intermediate_size: 16384,
num_attention_heads: 8,
num_hidden_layers: 18,
num_key_value_heads: 1,
// Default values.
rope_theta: 10000.,
head_dim: 256,
hidden_act: Some(candle_nn::Activation::GeluPytorchTanh),
hidden_activation: None,
attention_bias: false,
max_position_embeddings: 8192,
rms_norm_eps: 1e-6,
vocab_size: 257216,
},
projection_dim: 2048,
}
}
}
#[derive(Clone, Debug)]
@ -102,6 +125,28 @@ impl Model {
self.language_model.forward(input_ids, pos)
}
pub fn forward_without_projection(&mut self, input_ids: &Tensor) -> Result<Tensor> {
self.clear_kv_cache();
let input_embeds = self.language_model.embed_tokens().forward(input_ids)?;
self.language_model
.forward_embeds_without_projection(&input_embeds, None, 0)
}
pub fn setup_without_projection(
&mut self,
pixel_values: &Tensor,
input_ids: &Tensor,
) -> Result<Tensor> {
self.clear_kv_cache();
let image_features = self
.vision_tower
.forward(pixel_values)?
.apply(&self.multi_modal_projector)?;
let image_features = crate::models::clip::div_l2_norm(&image_features)?;
let text_features = self.language_model.embed_tokens().forward(input_ids)?;
let input_embeds = Tensor::cat(&[image_features, text_features], 1)?;
self.language_model
.forward_embeds_without_projection(&input_embeds, None, 0)
}
pub fn clear_kv_cache(&mut self) {
self.pos = 0;
self.language_model.clear_kv_cache()