Add LLaVA support (#2234)

* first commit

* llava

* clippy and fmt

* some fixes

* minor fixes

* remove useless file

* refactor: Remove llava/constants.rs and update llava/mod.rs

* modify variable name

* modify code after clippy

* Minor tweaks.

---------

Co-authored-by: laurent <laurent.mazare@gmail.com>
This commit is contained in:
chenwanqq 2024-06-03 17:54:09 +08:00 committed by GitHub
parent 03344d3c19
commit cd4d941ed1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 1567 additions and 0 deletions

View File

@ -0,0 +1,4 @@
pub const DEFAULT_IMAGE_TOKEN: &str = "<image>";
pub const DEFAULT_IM_START_TOKEN: &str = "<im_start>";
pub const DEFAULT_IM_END_TOKEN: &str = "<im_end>";
pub const IMAGE_PLACEHOLDER: &str = "<image-placeholder>";

View File

@ -0,0 +1,114 @@
pub enum SeparatorStyle {
Two,
Mpt,
}
pub struct Conversation {
pub system: String,
pub roles: Vec<String>,
pub messages: Vec<(String, Option<String>)>,
pub offset: i32,
pub sep_style: SeparatorStyle,
pub sep: String,
pub sep2: Option<String>,
pub version: String,
}
impl Conversation {
pub fn new(
system: &str,
roles: &[String],
offset: i32,
sep_style: SeparatorStyle,
sep: &str,
sep2: Option<&str>,
version: &str,
) -> Self {
Conversation {
system: system.to_string(),
roles: roles.to_vec(),
messages: Vec::new(),
offset,
sep_style,
sep: sep.to_string(),
sep2: sep2.map(|s| s.to_string()),
version: version.to_string(),
}
}
pub fn conv_chatml_direct() -> Self {
Conversation::new(
"<|im_start|>system\nAnswer the questions.",
&[
"<|im_start|>user\n".to_string(),
"<|im_start|>assistant\n".to_string(),
],
0,
SeparatorStyle::Mpt,
"<|im_end|>",
None,
"mpt",
)
}
pub fn conv_llava_v1() -> Self {
Conversation::new(
"A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.",
&[
"USER".to_string(),
"ASSISTANT".to_string(),
],
0,
SeparatorStyle::Two,
" ",
Some("</s>"),
"v1"
)
}
pub fn append_message(&mut self, role: String, message: Option<&str>) {
self.messages.push((role, message.map(|s| s.to_string())))
}
pub fn append_user_message(&mut self, message: Option<&str>) {
self.append_message(self.roles[0].clone(), message);
}
pub fn append_assistant_message(&mut self, message: Option<&str>) {
self.append_message(self.roles[1].clone(), message);
}
pub fn get_prompt(&self) -> String {
match self.sep_style {
SeparatorStyle::Mpt => {
let mut ret = String::new();
ret.push_str(&self.system);
ret.push_str(&self.sep);
for (role, message) in &self.messages {
ret.push_str(role);
if let Some(message) = message {
ret.push_str(message);
};
ret.push_str(&self.sep);
}
ret
}
SeparatorStyle::Two => {
let seps = [self.sep.clone(), self.sep2.clone().unwrap()];
let mut ret = String::new();
ret.push_str(&self.system);
ret.push_str(&seps[0]);
for (i, (role, message)) in self.messages.iter().enumerate() {
ret.push_str(role);
if let Some(message) = message {
ret.push_str(": "); // strictly follow the python implementation, otherwise it will cause some minor difference between tokens ^_^
ret.push_str(message);
ret.push_str(&seps[i % 2]);
} else {
ret.push(':')
}
}
ret
}
}
}
}

View File

@ -0,0 +1,317 @@
use std::cmp::min;
use candle::{bail, DType, Device, Result, Tensor};
use candle_transformers::models::llava::{
config::{HFPreProcessorConfig, LLaVAConfig},
utils::select_best_resolution,
};
use hf_hub::api::sync::Api;
use image::{imageops::overlay, DynamicImage, GenericImageView, Rgb, RgbImage};
use serde::{Deserialize, Serialize};
//This struct is mainly for LLaVA aplications, hence it's not completely compatible with python transformer CLIPImageProcessor few several preprocess that LLaVA used, including "openai/clip-vit-large-patch14-336" and "openai/clip-vit-large-patch14".
#[derive(Serialize, Deserialize, Debug)]
pub struct ImageProcessor {
#[serde(default = "default_size")]
pub size: u32, // this is not the same as python transformer
#[serde(default = "default_do_resize")]
pub do_resize: bool,
//resample: u32 // 3 for PIL bicubic, equivalent to rust CatmullRom. Hence below we use CatmullRom
#[serde(default = "default_do_center_crop")]
pub do_center_crop: bool,
#[serde(default = "default_crop_size")]
pub crop_size: u32, // this is not the same as python transformer
#[serde(default = "default_do_rescale")]
pub do_rescale: bool,
#[serde(default = "default_rescale_factor")]
pub rescale_factor: f32,
#[serde(default = "default_do_normalize")]
pub do_normalize: bool,
#[serde(default = "default_image_mean")]
pub image_mean: Vec<f32>,
#[serde(default = "default_image_std")]
pub image_std: Vec<f32>,
}
fn default_size() -> u32 {
224
}
fn default_do_resize() -> bool {
true
}
fn default_do_center_crop() -> bool {
true
}
fn default_crop_size() -> u32 {
224
}
fn default_do_rescale() -> bool {
true
}
fn default_rescale_factor() -> f32 {
1.0 / 255.0
}
fn default_do_normalize() -> bool {
true
}
fn default_image_mean() -> Vec<f32> {
vec![0.48145466, 0.4578275, 0.40821073]
}
fn default_image_std() -> Vec<f32> {
vec![0.26862954, 0.2613026, 0.2757771]
}
impl ImageProcessor {
pub fn from_pretrained(clip_id: &str) -> Result<Self> {
let api = Api::new().map_err(|e| candle::Error::Msg(e.to_string()))?;
let api = api.model(clip_id.to_string());
let config_filename = api
.get("preprocessor_config.json")
.map_err(|e| candle::Error::Msg(e.to_string()))?;
let image_processor =
serde_json::from_slice(&std::fs::read(config_filename).map_err(candle::Error::Io)?)
.map_err(|e| candle::Error::Msg(e.to_string()))?;
Ok(image_processor)
}
pub fn from_hf_preprocessor_config(hf_preprocessor_config: &HFPreProcessorConfig) -> Self {
Self {
size: hf_preprocessor_config.size["shortest_edge"] as u32,
do_resize: hf_preprocessor_config.do_resize,
do_center_crop: hf_preprocessor_config.do_center_crop,
crop_size: hf_preprocessor_config.crop_size["height"] as u32,
do_rescale: hf_preprocessor_config.do_rescale,
rescale_factor: hf_preprocessor_config.rescale_factor,
do_normalize: hf_preprocessor_config.do_normalize,
image_mean: hf_preprocessor_config.image_mean.clone(),
image_std: hf_preprocessor_config.image_std.clone(),
}
}
///shortest edge to self.resize, other edge is resized to maintain aspect ratio
pub fn resize(&self, image: &DynamicImage) -> DynamicImage {
let (width, height) = image.dimensions();
let size = self.size;
if width == size && height == size {
image.clone()
} else {
let (new_width, new_height) = if width < height {
(
size,
(((size * height) as f32) / width as f32).ceil() as u32,
)
} else {
(
(((size * width) as f32) / height as f32).ceil() as u32,
size,
)
};
image.resize(
new_width,
new_height,
image::imageops::FilterType::CatmullRom,
)
}
}
pub fn center_crop(&self, image: &DynamicImage) -> DynamicImage {
let (width, height) = image.dimensions();
let crop_size = self.crop_size;
let (left, top) = calculate_middle((width, height), (crop_size, crop_size));
image.crop_imm(left, top, crop_size, crop_size)
}
pub fn to_tensor(&self, image: &DynamicImage) -> Result<Tensor> {
let img = image.to_rgb8().into_raw();
let (width, height) = image.dimensions();
Tensor::from_vec(img, (height as usize, width as usize, 3), &Device::Cpu)?
.to_dtype(DType::F32) // only for internal compute
}
pub fn rescale(&self, tensor: &Tensor) -> Result<Tensor> {
let rescale_factor = self.rescale_factor as f64;
tensor.affine(rescale_factor, 0.0)
}
pub fn normalize(&self, tensor: &Tensor) -> Result<Tensor> {
let image_mean = self.image_mean.clone();
let image_std = self.image_std.clone();
let mean = Tensor::from_vec(image_mean, (3,), &Device::Cpu)?;
let std = Tensor::from_vec(image_std, (3,), &Device::Cpu)?;
tensor.broadcast_sub(&mean)?.broadcast_div(&std)
}
pub fn to_channel_dimension_format(&self, tensor: &Tensor) -> Result<Tensor> {
tensor.permute((2, 0, 1))
}
pub fn preprocess(&self, image: &DynamicImage) -> Result<Tensor> {
let image = if self.do_resize {
self.resize(image)
} else {
image.clone()
};
let image = if self.do_center_crop {
self.center_crop(&image)
} else {
image
};
let tensor = self.to_tensor(&image)?;
let tensor = if self.do_rescale {
self.rescale(&tensor)?
} else {
tensor
};
let tensor = if self.do_normalize {
self.normalize(&tensor)?
} else {
tensor
};
self.to_channel_dimension_format(&tensor)
}
}
pub fn calculate_middle(image_size: (u32, u32), center_size: (u32, u32)) -> (u32, u32) {
let (width, height) = image_size;
let (center_width, center_height) = center_size;
let left = if width <= center_width {
0
} else {
((width as f32 - center_width as f32) / 2.0).ceil() as u32
};
let top = if height <= center_height {
0
} else {
((height as f32 - center_height as f32) / 2.0).ceil() as u32
};
(left, top)
}
pub fn process_image(
image: &DynamicImage,
processor: &ImageProcessor,
llava_config: &LLaVAConfig,
) -> candle::Result<Tensor> {
if llava_config.image_aspect_ratio == *"square" {
processor.preprocess(image)?.unsqueeze(0)
} else if llava_config.image_aspect_ratio == *"anyres" {
process_anyres_image(image, processor, &llava_config.image_grid_pinpoints)
} else if llava_config.image_aspect_ratio == *"pad" {
process_pad_image(image, processor)
} else {
bail!("Invalid image aspect ratio")
}
}
fn process_pad_image(image: &DynamicImage, processor: &ImageProcessor) -> Result<Tensor> {
let mean_color = processor
.image_mean
.iter()
.map(|x| ((*x) * 255.0) as u8)
.collect::<Vec<u8>>();
let mean_color = Rgb::from([mean_color[0], mean_color[1], mean_color[2]]);
let image_padded = expand2square(image, mean_color);
processor.preprocess(&image_padded)
}
fn process_anyres_image(
image: &DynamicImage,
processor: &ImageProcessor,
grid_pinpoints: &[(u32, u32)],
) -> Result<Tensor> {
let original_size = image.dimensions();
let best_resolution = select_best_resolution(original_size, grid_pinpoints);
let image_padded = resize_and_pad_image(image, best_resolution);
let image_original_resize = image.resize_exact(
processor.size,
processor.size,
image::imageops::FilterType::CatmullRom,
);
let mut patches = vec![image_original_resize];
for patch in divide_to_patches(&image_padded, processor.crop_size) {
patches.push(patch);
}
let tensors = patches
.iter()
.map(|patch| processor.preprocess(patch))
.collect::<Result<Vec<Tensor>>>()?;
Tensor::stack(&tensors, 0)
}
fn expand2square(image: &DynamicImage, background_color: Rgb<u8>) -> DynamicImage {
let (width, height) = image.dimensions();
match width.cmp(&height) {
std::cmp::Ordering::Less => {
let mut new_image =
DynamicImage::from(RgbImage::from_pixel(height, height, background_color));
overlay(&mut new_image, image, ((height - width) / 2) as i64, 0);
new_image
}
std::cmp::Ordering::Equal => image.clone(),
std::cmp::Ordering::Greater => {
let mut new_image =
DynamicImage::from(RgbImage::from_pixel(width, width, background_color));
overlay(&mut new_image, image, 0, ((width - height) / 2) as i64);
new_image
}
}
}
fn resize_and_pad_image(image: &DynamicImage, target_resolution: (u32, u32)) -> DynamicImage {
let (original_width, original_height) = image.dimensions();
let original_width_f = original_width as f32;
let original_height_f = original_height as f32;
let (target_width, target_height) = target_resolution;
let target_width_f = target_width as f32;
let target_height_f = target_height as f32;
let scale_w = target_width_f / original_width_f;
let scale_h = target_height_f / original_height_f;
let (new_width, new_height) = if scale_w < scale_h {
(
target_width,
min((original_height_f * scale_w).ceil() as u32, target_height),
)
} else {
(
min((original_width_f * scale_h).ceil() as u32, target_width),
target_height,
)
};
let resized_image = image.resize_exact(
new_width,
new_height,
image::imageops::FilterType::CatmullRom,
);
let mut new_image = DynamicImage::new_rgb8(target_width, target_height);
let (paste_x, paste_y) =
calculate_middle((target_width, target_height), (new_width, new_height));
overlay(
&mut new_image,
&resized_image,
paste_x.into(),
paste_y.into(),
);
new_image
}
fn divide_to_patches(image: &DynamicImage, patch_size: u32) -> Vec<DynamicImage> {
let (width, height) = image.dimensions();
let mut patches = Vec::new();
for y in (0..height).step_by(patch_size as usize) {
for x in (0..width).step_by(patch_size as usize) {
let patch = image.crop_imm(x, y, patch_size, patch_size);
patches.push(patch);
}
}
patches
}

View File

@ -0,0 +1,316 @@
pub mod constants;
pub mod conversation;
pub mod image_processor;
use candle_transformers::generation::{LogitsProcessor, Sampling};
use candle_transformers::models::llama::Cache;
use anyhow::{bail, Error as E, Result};
use candle::{DType, Device, IndexOp, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::llava::config::{
HFGenerationConfig, HFLLaVAConfig, HFPreProcessorConfig,
};
use candle_transformers::models::llava::{config::LLaVAConfig, LLaVA};
use clap::Parser;
use constants::*;
use conversation::Conversation;
use hf_hub::api::sync::Api;
use image_processor::{process_image, ImageProcessor};
use std::io::Write;
use tokenizers::Tokenizer;
#[derive(Parser, Debug)]
#[command(author, version, about,long_about=None)]
struct Args {
#[arg(long, default_value = "llava-hf/llava-v1.6-vicuna-7b-hf")]
model_path: String,
#[arg(long, default_value = "tokenizer/tokenizer.json")]
tokenizer_path: String,
#[arg(long)]
model_base: Option<String>,
#[arg(long)]
image_file: String, // Required
#[arg(long)]
conv_mode: Option<String>,
#[arg(long, default_value_t = 0.2)]
temperature: f32,
#[arg(long, default_value_t = 512)]
max_new_tokens: usize,
#[arg(long, action)]
hf: bool,
#[arg(long, action)]
cpu: bool,
#[arg(long, action)]
no_kv_cache: bool,
#[arg(long)]
prompt: String,
/// The seed to use when generating random samples. Copy from candle llama. Not exist in python llava.
#[arg(long, default_value_t = 299792458)]
seed: u64,
}
//from https://github.com/huggingface/candle/blob/main/candle-examples/examples/clip/main.rs
fn load_image<T: AsRef<std::path::Path>>(
path: T,
processor: &ImageProcessor,
llava_config: &LLaVAConfig,
dtype: DType,
) -> Result<((u32, u32), Tensor)> {
let img = image::io::Reader::open(path)?.decode()?;
let img_tensor = process_image(&img, processor, llava_config)?;
Ok(((img.width(), img.height()), img_tensor.to_dtype(dtype)?))
}
fn get_model_name_from_path(model_path: &str) -> String {
let model_paths: Vec<String> = model_path
.trim_matches('/')
.split('/')
.map(|s| s.to_string())
.collect();
if model_paths.last().unwrap().starts_with("checkpoint-") {
format!(
"{}_{}",
model_paths[model_paths.len() - 2],
model_paths.last().unwrap()
)
} else {
model_paths.last().unwrap().to_string()
}
}
fn duplicate_vec<T>(vec: &[T], n: usize) -> Vec<T>
where
T: Clone,
{
let mut res = Vec::new();
for _ in 0..n {
res.extend(vec.to_owned());
}
res
}
fn insert_separator<T>(x: Vec<Vec<T>>, sep: Vec<T>) -> Vec<Vec<T>>
where
T: Clone,
{
let sep = vec![sep];
let sep = duplicate_vec(&sep, x.len());
let mut res = x
.iter()
.zip(sep.iter())
.flat_map(|(x, y)| vec![x.clone(), y.clone()])
.collect::<Vec<Vec<T>>>();
res.pop();
res
}
fn tokenizer_image_token(
prompt: &str,
tokenizer: &Tokenizer,
image_token_index: i64,
llava_config: &LLaVAConfig,
) -> Result<Tensor> {
let prompt_chunks = prompt
.split("<image>")
.map(|s| {
tokenizer
.encode(s, true)
.unwrap()
.get_ids()
.to_vec()
.iter()
.map(|x| *x as i64)
.collect()
})
.collect::<Vec<Vec<i64>>>();
let mut input_ids = Vec::new();
let mut offset = 0;
if !prompt_chunks.is_empty()
&& !prompt_chunks[0].is_empty()
&& prompt_chunks[0][0] == llava_config.bos_token_id as i64
{
offset = 1;
input_ids.push(prompt_chunks[0][0]);
}
for x in insert_separator(
prompt_chunks,
duplicate_vec(&[image_token_index], offset + 1),
)
.iter()
{
input_ids.extend(x[1..].to_vec())
}
let input_len = input_ids.len();
Tensor::from_vec(input_ids, (1, input_len), &Device::Cpu).map_err(E::msg)
}
fn main() -> Result<()> {
let mut args = Args::parse();
let device = candle_examples::device(args.cpu)?;
println!("Start loading model");
let api = Api::new()?;
let api = api.model(args.model_path.clone());
let (llava_config, tokenizer, clip_vision_config, image_processor) = if args.hf {
let config_filename = api.get("config.json")?;
let hf_llava_config: HFLLaVAConfig =
serde_json::from_slice(&std::fs::read(config_filename)?)?;
let generation_config_filename = api.get("generation_config.json")?;
let generation_config: HFGenerationConfig =
serde_json::from_slice(&std::fs::read(generation_config_filename)?)?;
let preprocessor_config_filename = api.get("preprocessor_config.json")?;
let preprocessor_config: HFPreProcessorConfig =
serde_json::from_slice(&std::fs::read(preprocessor_config_filename)?)?;
let llava_config =
hf_llava_config.to_llava_config(&generation_config, &preprocessor_config);
let tokenizer_filename = api.get("tokenizer.json")?;
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?;
let clip_vision_config = hf_llava_config.to_clip_vision_config();
(
llava_config,
tokenizer,
Some(clip_vision_config),
ImageProcessor::from_hf_preprocessor_config(&preprocessor_config),
)
} else {
let config_filename = api.get("config.json")?;
let llava_config: LLaVAConfig = serde_json::from_slice(&std::fs::read(config_filename)?)?;
let tokenizer = Tokenizer::from_file(&args.tokenizer_path)
.map_err(|e| E::msg(format!("Error loading {}: {}", &args.tokenizer_path, e)))?;
(
llava_config.clone(),
tokenizer,
None,
ImageProcessor::from_pretrained(&llava_config.mm_vision_tower.unwrap())?,
)
};
let llama_config = llava_config.to_llama_config();
let dtype: DType = match llava_config.torch_dtype.as_str() {
"float16" => DType::F16,
"bfloat16" => DType::BF16,
_ => bail!("unsupported dtype"),
};
let eos_token_id = llava_config.eos_token_id;
println!("setting kv cache");
let mut cache = Cache::new(!args.no_kv_cache, dtype, &llama_config, &device)?;
println!("loading model weights");
let weight_filenames =
candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?;
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&weight_filenames, dtype, &device)? };
let llava: LLaVA = LLaVA::load(vb, &llava_config, clip_vision_config)?;
println!("generating conv template");
let image_token_se = format!(
"{}{}{}",
DEFAULT_IM_START_TOKEN, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_END_TOKEN
);
let qs = if args.prompt.contains(IMAGE_PLACEHOLDER) {
if llava_config.mm_use_im_start_end {
args.prompt.replace(IMAGE_PLACEHOLDER, &image_token_se)
} else {
args.prompt.replace(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN)
}
} else if llava_config.mm_use_im_start_end {
format!("{}\n{}", image_token_se, args.prompt)
} else {
format!("{}\n{}", DEFAULT_IMAGE_TOKEN, args.prompt)
};
let model_name = get_model_name_from_path(&args.model_path).to_lowercase();
let conv_mode = if model_name.contains("llama-2") {
"llava_llama_2"
} else if model_name.contains("mistral") {
"mistral_instruct"
} else if model_name.contains("v1.6-34b") {
"chatml_direct"
} else if model_name.contains("v1") {
"llava_v1"
} else if model_name.contains("mpt") {
"mpt"
} else {
"llava_v0"
};
if args.conv_mode.is_some() && args.conv_mode.as_deref() != Some(conv_mode) {
println!(
"Warning: the model is trained with {}, but you are using {}",
conv_mode,
args.conv_mode.as_deref().unwrap()
);
} else {
args.conv_mode = Some(conv_mode.to_string());
}
let mut conv = match args.conv_mode {
Some(conv_mode) => match conv_mode.as_str() {
"chatml_direct" => Conversation::conv_chatml_direct(),
"llava_v1" => Conversation::conv_llava_v1(),
_ => todo!("not implement yet"),
},
None => bail!("conv_mode is required"),
};
conv.append_user_message(Some(&qs));
conv.append_assistant_message(None);
let prompt = conv.get_prompt();
println!("loading image");
let (image_size, image_tensor) =
load_image(&args.image_file, &image_processor, &llava_config, dtype)
.map_err(|e| E::msg(format!("Error loading {}: {}", &args.image_file, e)))?;
let image_tensor = image_tensor.to_device(&device)?;
let mut logits_processor = {
let temperature = f64::from(args.temperature);
let sampling = if temperature <= 0. {
Sampling::ArgMax
} else {
Sampling::All { temperature }
};
LogitsProcessor::from_sampling(args.seed, sampling)
};
// get input tokens
let tokens = tokenizer_image_token(
&prompt,
&tokenizer,
llava_config.image_token_index as i64,
&llava_config,
)?;
let mut input_embeds =
llava.prepare_inputs_labels_for_multimodal(&tokens, &[image_tensor], &[image_size])?;
//inference loop, based on https://github.com/huggingface/candle/blob/main/candle-examples/examples/llama/main.rs
let mut tokenizer = candle_examples::token_output_stream::TokenOutputStream::new(tokenizer);
let mut index_pos = 0;
for index in 0..args.max_new_tokens {
let (_, input_embeds_len, _) = input_embeds.dims3()?;
let (context_size, context_index) = if cache.use_kv_cache && index > 0 {
(1, index_pos)
} else {
(input_embeds_len, 0)
};
let input = input_embeds.i((.., input_embeds_len.saturating_sub(context_size).., ..))?;
let logits = llava.forward(&input, context_index, &mut cache)?; //[1,32000]
let logits = logits.squeeze(0)?;
let (_, input_len, _) = input.dims3()?;
index_pos += input_len;
let next_token = logits_processor.sample(&logits)?;
let next_token_tensor = Tensor::from_vec(vec![next_token], 1, &device)?;
let next_embeds = llava.llama.embed(&next_token_tensor)?.unsqueeze(0)?;
input_embeds = Tensor::cat(&[input_embeds, next_embeds], 1)?;
if next_token == eos_token_id as u32 {
break;
}
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{t}");
std::io::stdout().flush()?;
}
}
if let Some(rest) = tokenizer.decode_rest().map_err(E::msg)? {
print!("{rest}");
}
Ok(())
}

View File

@ -0,0 +1,40 @@
# candle-llava
LLaVA (Large Language-and-Vision Assistant) is an end-to-end trained large
multimodal model. This example is from [candle-llava](https://github.com/chenwanqq/candle-llava)
The code is based on [https://github.com/haotian-liu/LLaVA](https://github.com/haotian-liu/LLaVA), Hence the llava-hf version of config may perform differently.
## model zoo
* [liuhaotian/LLaVA](https://huggingface.co/liuhaotian)
* [llava-hf](https://huggingface.co/llava-hf)
Right now this has been tested on `liuhaotian/llava-v1.6-vicuna-7b` and
`llava-hf/llava-v1.6-vicuna-7b-hf`. Memory usage might have room for optimization.
## Tokenizer Setup
The llava-hf models contain a `tokenizer.json` file so can be used directly with
the `-hf` command line flag.
For the original llava models, you can use the following code to generate the `tokenizer.json` file.
```bash
conda create -n llava python=3.10
pip install transformers protobuf
conda activate llava
python -c "from transformers import AutoTokenizer;tokenizer=AutoTokenizer.from_pretrained('liuhaotian/llava-v1.6-vicuna-7b');tokenizer.save_pretrained('tokenizer')"
```
Then the `tokenizer.json` file should be in `tokenizer/tokenizer.json` (which is the default path).
## eval
```bash
cargo run --example llava --features cuda -- --image-file "llava_logo.png" --prompt "is this a cat?" --hf # default args, use llava-hf/llava-v1.6-vicuna-7b-hf. image-file is required^_^
cargo run --example llava --features cuda -- --model-path liuhaotian/llava-v1.6-vicuna-7b --image-file "llava_logo.png" --prompt "is this a cat?" # use liuhaotian/llava-v1.6-vicuna-7b, tokenizer setup should be done
```
## Major Limitations
1. Currently only support llama-2/vicuna llm. Haven't supoort Mistral yet.
2. There are some ops like split, nonzero and where are not supported by candle.
3. Lack of quantization and LoRA support.

View File

@ -262,6 +262,20 @@ impl ClipEncoder {
}
Ok(xs)
}
// required by LLaVA
pub fn output_hidden_states(
&self,
xs: &Tensor,
causal_attention_mask: Option<&Tensor>,
) -> Result<Vec<Tensor>> {
let mut xs = xs.clone();
let mut hidden_states = Vec::new();
for layer in self.layers.iter() {
xs = layer.forward(&xs, causal_attention_mask)?;
hidden_states.push(xs.clone());
}
Ok(hidden_states)
}
}
/// A CLIP transformer based model.

View File

@ -46,6 +46,19 @@ impl ClipVisionConfig {
patch_size: 32,
}
}
pub fn clip_vit_large_patch14_336() -> Self {
Self {
embed_dim: 1024,
activation: Activation::QuickGelu,
intermediate_size: 4096,
num_hidden_layers: 24,
num_attention_heads: 16,
projection_dim: 768,
num_channels: 3,
image_size: 336,
patch_size: 14,
}
}
}
// https://github.com/huggingface/transformers/blob/f6fa0f0bf0796ac66f201f23bdb8585de1609add/src/transformers/models/clip/modeling_clip.py#L112
@ -130,6 +143,17 @@ impl ClipVisionTransformer {
pre_layer_norm,
})
}
// required by LLaVA
pub fn output_hidden_states(&self, pixel_values: &Tensor) -> Result<Vec<Tensor>> {
let hidden_states = pixel_values
.apply(&self.embeddings)?
.apply(&self.pre_layer_norm)?;
let mut result = self.encoder.output_hidden_states(&hidden_states, None)?;
let encoder_outputs = result.last().unwrap();
let pooled_output = encoder_outputs.i((.., 0, ..))?;
result.push(self.final_layer_norm.forward(&pooled_output)?.clone());
Ok(result)
}
}
impl Module for ClipVisionTransformer {

View File

@ -388,6 +388,28 @@ pub struct Llama {
}
impl Llama {
// required by LLaVA
pub fn embed(&self, x: &Tensor) -> Result<Tensor> {
self.wte.forward(x)
}
// required by LLaVA
pub fn forward_input_embed(
&self,
input_embed: &Tensor,
index_pos: usize,
cache: &mut Cache,
) -> Result<Tensor> {
let (_, seq_len, _) = input_embed.dims3()?;
let mut x = input_embed.clone();
for (block_idx, block) in self.blocks.iter().enumerate() {
x = block.forward(&x, index_pos, block_idx, cache)?;
}
let x = self.ln_f.forward(&x)?;
let x = x.i((.., seq_len - 1, ..))?.contiguous()?;
let logits = self.lm_head.forward(&x)?;
logits.to_dtype(DType::F32)
}
pub fn forward(&self, x: &Tensor, index_pos: usize, cache: &mut Cache) -> Result<Tensor> {
let (_b_sz, seq_len) = x.dims2()?;
let mut x = self.wte.forward(x)?;

View File

@ -0,0 +1,267 @@
use std::collections::HashMap;
use crate::models::{
clip::{text_model::Activation, vision_model::ClipVisionConfig},
llama::Config,
};
use serde::{Deserialize, Serialize};
// original config from liuhaotian/llava
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct LLaVAConfig {
pub architectures: Vec<String>,
pub bos_token_id: usize,
pub eos_token_id: usize,
pub hidden_size: usize,
#[serde(default = "default_image_aspect_ratio")]
pub image_aspect_ratio: String,
pub image_crop_resolution: usize,
pub image_grid_pinpoints: Vec<(u32, u32)>,
pub image_split_resolution: usize,
pub intermediate_size: usize,
pub max_position_embeddings: usize,
pub mm_hidden_size: usize,
#[serde(default = "default_mm_patch_merge_type")]
pub mm_patch_merge_type: String,
pub mm_projector_type: String,
pub mm_use_im_start_end: bool,
pub mm_vision_select_feature: String,
pub mm_vision_select_layer: isize,
pub mm_vision_tower: Option<String>,
pub model_type: String,
pub num_attention_heads: usize,
pub num_hidden_layers: usize,
pub num_key_value_heads: usize,
pub pad_token_id: usize,
pub rms_norm_eps: f32,
pub rope_theta: f32,
pub tokenizer_model_max_length: Option<usize>,
pub torch_dtype: String,
pub use_cache: bool,
pub vocab_size: usize,
#[serde(default = "default_image_token_index")]
pub image_token_index: isize,
#[serde(default = "default_hf")]
pub hf: bool,
}
fn default_hf() -> bool {
false
}
fn default_image_token_index() -> isize {
-200
}
fn default_mm_patch_merge_type() -> String {
"flat".to_string()
}
fn default_image_aspect_ratio() -> String {
"square".to_string()
}
impl LLaVAConfig {
pub fn to_llama_config(&self) -> Config {
Config {
hidden_size: self.hidden_size,
intermediate_size: self.intermediate_size,
vocab_size: self.vocab_size,
num_hidden_layers: self.num_hidden_layers,
num_attention_heads: self.num_attention_heads,
num_key_value_heads: self.num_key_value_heads,
rms_norm_eps: self.rms_norm_eps as f64,
rope_theta: self.rope_theta,
bos_token_id: Some(self.bos_token_id as u32),
eos_token_id: Some(self.eos_token_id as u32),
use_flash_attn: false,
}
}
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HFLLaVATextConfig {
pub architectures: Vec<String>,
#[serde(default = "default_hidden_size")]
pub hidden_size: usize,
#[serde(default = "default_intermediate_size")]
pub intermediate_size: usize,
#[serde(default = "default_max_length")]
pub max_length: usize,
pub max_position_embeddings: usize,
pub model_type: String,
#[serde(default = "default_num_attention_heads")]
pub num_attention_heads: usize,
#[serde(default = "default_num_hidden_layers")]
pub num_hidden_layers: usize,
#[serde(default = "default_num_key_value_heads")]
pub num_key_value_heads: usize,
pub pad_token_id: usize,
pub rms_norm_eps: f32,
#[serde(default = "default_rope_theta")]
pub rope_theta: f32,
pub torch_dtype: String,
#[serde(default = "default_use_cache")]
pub use_cache: bool,
pub vocab_size: usize,
}
fn default_num_hidden_layers() -> usize {
32
}
fn default_use_cache() -> bool {
true
}
fn default_hidden_size() -> usize {
4096
}
fn default_intermediate_size() -> usize {
11008
}
fn default_max_length() -> usize {
4096
}
fn default_num_attention_heads() -> usize {
32
}
fn default_num_key_value_heads() -> usize {
32
}
fn default_rope_theta() -> f32 {
10000.0
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HFLLaVAVisionConfig {
pub hidden_size: usize,
pub image_size: usize,
pub intermediate_size: usize,
pub model_type: String,
pub num_attention_heads: usize,
pub num_hidden_layers: usize,
pub patch_size: usize,
pub projection_dim: usize,
pub vocab_size: usize,
}
// config from llava-v1.6-vicuna-7b-hf
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HFLLaVAConfig {
pub architectures: Vec<String>,
pub ignore_index: isize,
pub image_grid_pinpoints: Vec<(u32, u32)>,
pub image_token_index: isize,
pub model_type: String,
pub projector_hidden_act: String,
pub text_config: HFLLaVATextConfig,
pub torch_dtype: String,
pub use_image_newline_parameter: bool,
pub vision_config: HFLLaVAVisionConfig,
pub vision_feature_layer: isize,
pub vision_feature_select_strategy: String,
pub vocab_size: usize,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HFGenerationConfig {
pub bos_token_id: usize,
pub eos_token_id: usize,
#[serde(default = "default_max_length")]
pub max_length: usize,
pub pad_token_id: usize,
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct HFPreProcessorConfig {
pub aspect_ratio_setting: String,
pub crop_size: HashMap<String, usize>,
pub do_center_crop: bool,
pub do_convert_rgb: bool,
pub do_normalize: bool,
pub do_rescale: bool,
pub do_resize: bool,
pub image_mean: Vec<f32>,
pub image_std: Vec<f32>,
pub resample: u32,
pub rescale_factor: f32,
pub size: HashMap<String, f32>,
}
impl HFLLaVAConfig {
pub fn to_clip_vision_config(&self) -> ClipVisionConfig {
ClipVisionConfig {
embed_dim: self.vision_config.hidden_size,
activation: Activation::QuickGelu,
intermediate_size: self.vision_config.intermediate_size,
num_hidden_layers: self.vision_config.num_hidden_layers,
num_attention_heads: self.vision_config.num_attention_heads,
projection_dim: self.vision_config.projection_dim,
num_channels: 3,
image_size: self.vision_config.image_size,
patch_size: self.vision_config.patch_size,
}
}
fn map_projector_type(s: &str) -> String {
if s == "gelu" {
"mlp2x_gelu".to_string()
} else {
s.to_string()
}
}
fn map_select_feature(s: &str) -> String {
if s == "default" {
"patch".to_string()
} else {
"cls_patch".to_string()
}
}
pub fn to_llava_config(
&self,
generation_config: &HFGenerationConfig,
preprocessor_config: &HFPreProcessorConfig,
) -> LLaVAConfig {
LLaVAConfig {
hf: true,
architectures: self.architectures.clone(),
bos_token_id: generation_config.bos_token_id,
eos_token_id: generation_config.eos_token_id,
hidden_size: self.text_config.hidden_size,
image_aspect_ratio: preprocessor_config.aspect_ratio_setting.clone(),
image_crop_resolution: 224,
image_grid_pinpoints: self.image_grid_pinpoints.clone(),
image_split_resolution: 224,
intermediate_size: self.text_config.intermediate_size,
max_position_embeddings: self.text_config.max_position_embeddings,
mm_hidden_size: 1024,
mm_patch_merge_type: "spatial_unpad".to_string(),
mm_projector_type: Self::map_projector_type(&self.projector_hidden_act),
mm_use_im_start_end: false,
mm_vision_select_feature: Self::map_select_feature(
&self.vision_feature_select_strategy,
),
mm_vision_select_layer: self.vision_feature_layer,
mm_vision_tower: None,
model_type: self.model_type.clone(),
num_attention_heads: self.text_config.num_attention_heads,
num_hidden_layers: self.text_config.num_hidden_layers,
num_key_value_heads: self.text_config.num_key_value_heads,
pad_token_id: self.text_config.pad_token_id,
rms_norm_eps: self.text_config.rms_norm_eps,
rope_theta: self.text_config.rope_theta,
tokenizer_model_max_length: Some(4096),
torch_dtype: self.torch_dtype.clone(),
use_cache: self.text_config.use_cache,
vocab_size: self.vocab_size,
image_token_index: self.image_token_index,
}
}
}

View File

@ -0,0 +1,407 @@
pub mod config;
pub mod utils;
use crate::models::clip::vision_model::{ClipVisionConfig, ClipVisionTransformer};
use crate::models::llama::{Cache, Llama};
use crate::models::with_tracing::linear;
use candle::{bail, Device, IndexOp, Result, Tensor};
use candle_nn::{seq, Activation, Module, Sequential, VarBuilder};
use fancy_regex::Regex;
use utils::get_anyres_image_grid_shape;
use config::LLaVAConfig;
fn mlp_gelu_match(mm_projector_type: &str) -> Option<usize> {
let mlp_gelu_regex = Regex::new(r"^mlp(\d+)x_gelu$").unwrap();
if let Ok(Some(captures)) = mlp_gelu_regex.captures(mm_projector_type) {
if let Some(match_str) = captures.get(1) {
let match_str = match_str.as_str();
match_str.parse::<usize>().ok()
} else {
None
}
} else {
None
}
}
fn unpad_image(tensor: &Tensor, original_size: &(u32, u32)) -> Result<Tensor> {
assert_eq!(tensor.dims().len(), 3);
let (original_width, original_height) = *original_size;
let tensor_dims = tensor.dims();
let current_height = tensor_dims[1];
let current_width = tensor_dims[2];
let original_aspect_ratio = (original_width as f32) / (original_height as f32);
let current_aspect_ratio = (current_width as f32) / (current_height as f32);
if original_aspect_ratio > current_aspect_ratio {
let scale_factor = (current_width as f32) / (original_width as f32);
let new_height = (original_height as f32 * scale_factor).floor() as usize;
let padding = (current_height - new_height) / 2;
tensor.i((.., padding..current_width - padding, ..))
} else {
let scale_factor = (current_height as f32) / (original_height as f32);
let new_width = (original_width as f32 * scale_factor).floor() as usize;
let padding = (current_width - new_width) / 2;
tensor.i((.., .., padding..current_width - padding))
}
}
pub struct IdentityMap {}
impl Module for IdentityMap {
fn forward(&self, x: &Tensor) -> Result<Tensor> {
Ok(x.clone())
}
}
pub struct MMProjector {
pub modules: Sequential,
}
impl MMProjector {
pub fn load(vb: &VarBuilder, config: &LLaVAConfig) -> Result<Self> {
if config.mm_projector_type == "linear" {
let vb_prefix = if config.hf {
"multi_modal_projector.linear_1"
} else {
"model.mm_projector.0"
};
let linear = linear(config.mm_hidden_size, config.hidden_size, vb.pp(vb_prefix))?;
let modules = seq().add(linear);
Ok(Self { modules })
} else if let Some(mlp_depth) = mlp_gelu_match(&config.mm_projector_type) {
let modules = if config.hf {
let mut modules = seq().add(linear(
config.mm_hidden_size,
config.hidden_size,
vb.pp("multi_modal_projector.linear_1"),
)?);
for i in 1..mlp_depth {
modules = modules.add(Activation::Gelu).add(linear(
config.hidden_size,
config.hidden_size,
vb.pp(format!("multi_modal_projector.linear_{}", i + 1)),
)?);
}
modules
} else {
let mut modules = seq().add(linear(
config.mm_hidden_size,
config.hidden_size,
vb.pp("model.mm_projector.0"),
)?);
for i in 1..mlp_depth {
modules = modules.add(Activation::Gelu).add(linear(
config.hidden_size,
config.hidden_size,
vb.pp(format!("model.mm_projector.{}", i * 2)),
)?);
}
modules
};
Ok(Self { modules })
} else if config.mm_projector_type == "identity" {
Ok(Self {
modules: seq().add(IdentityMap {}),
})
} else {
bail!(
"Unsupported MM projector type: {}",
config.mm_projector_type
)
}
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
self.modules.forward(x)
}
}
pub struct ClipVisionTower {
model: ClipVisionTransformer,
select_layer: isize,
select_feature_method: String,
pub config: ClipVisionConfig,
}
impl ClipVisionTower {
pub fn new(
vb: VarBuilder,
select_layer: isize,
select_feature_method: &str,
config: &Option<ClipVisionConfig>,
) -> Result<Self> {
let config = if config.is_none() {
ClipVisionConfig::clip_vit_large_patch14_336()
} else {
config.clone().unwrap()
};
let select_layer = match select_layer {
-1 | -2 => select_layer,
_ => bail!("Unsupported select layer: {}", select_layer),
};
let model = ClipVisionTransformer::new(vb, &config)?;
Ok(Self {
model,
select_layer,
select_feature_method: select_feature_method.to_string(),
config,
})
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
let result = self.model.output_hidden_states(x)?;
let index = result.len() as isize + self.select_layer;
let result = result[index as usize].clone();
if self.select_feature_method == "cls_patch" {
Ok(result)
} else {
result.i((.., 1..))
}
}
pub fn num_patches_per_side(&self) -> usize {
self.config.image_size / self.config.patch_size
}
}
pub struct LLaVA {
pub clip_vision_tower: ClipVisionTower,
pub image_newline: Tensor,
pub mm_projector: MMProjector,
pub llama: Llama,
config: LLaVAConfig,
device: Device,
}
impl LLaVA {
pub fn load(
vb: VarBuilder,
config: &LLaVAConfig,
clip_vision_config: Option<ClipVisionConfig>,
) -> Result<Self> {
let device = vb.device().clone();
let llama_config = config.to_llama_config();
let mm_projector = MMProjector::load(&vb, config)?;
let (clip_vision_tower, image_newline, llama) = if config.hf {
(
ClipVisionTower::new(
vb.pp("vision_tower.vision_model"),
config.mm_vision_select_layer,
&config.mm_vision_select_feature,
&clip_vision_config,
)?,
vb.get(&[config.hidden_size], "image_newline")?
.to_device(&device)?,
Llama::load(vb.pp("language_model"), &llama_config)?,
)
} else {
(
ClipVisionTower::new(
vb.pp("model.vision_tower.vision_tower.vision_model"),
config.mm_vision_select_layer,
&config.mm_vision_select_feature,
&clip_vision_config,
)?,
vb.get(&[config.hidden_size], "model.image_newline")?
.to_device(&device)?,
Llama::load(vb, &llama_config)?,
)
};
Ok(Self {
clip_vision_tower,
image_newline,
mm_projector,
llama,
config: (*config).clone(),
device,
})
}
pub fn encode_images(&self, x: &Tensor) -> Result<Tensor> {
let image_features = self.clip_vision_tower.forward(x)?;
let image_features = self.mm_projector.forward(&image_features)?;
Ok(image_features)
}
// currently only for single image, 4 dim tensor
pub fn prepare_inputs_labels_for_multimodal(
&self,
input_ids: &Tensor,
images: &[Tensor],
image_sizes: &[(u32, u32)],
) -> Result<Tensor> {
//TODO: process of multiple images/ new line
// 576: 336(input size)/14(patch size)=24 24*24+1(class)=577 577-1=576
let concat_images = Tensor::cat(images, 0)?;
let image_features_together = self.encode_images(&concat_images)?;
let split_sizes = images
.iter()
.map(|x| x.shape().dims()[0])
.collect::<Vec<usize>>();
// can be replaced by split
let mut index_pos = 0;
let mut image_features = Vec::new();
for split_size in split_sizes.iter() {
image_features.push(image_features_together.i(index_pos..index_pos + (*split_size))?);
index_pos += *split_size;
}
let mm_patch_merge_type = &self.config.mm_patch_merge_type;
let image_aspect_ratio = &self.config.image_aspect_ratio;
let image_features = if mm_patch_merge_type == "flat" {
image_features
.iter()
.map(|x| x.flatten(0, 1).unwrap())
.collect::<Vec<Tensor>>()
} else if mm_patch_merge_type.starts_with("spatial") {
let mut new_image_features = Vec::new();
for (image_idx, image_feature) in image_features.iter().enumerate() {
let new_image_feature = if image_feature.dims()[0] > 1 {
let base_image_feature = image_feature.get(0).unwrap();
let patch_image_feature = image_feature.i(1..).unwrap();
let height = self.clip_vision_tower.num_patches_per_side();
let width = height;
assert_eq!(height * width, base_image_feature.dims()[0]);
let image_size = image_sizes[image_idx];
let new_image_feature = if image_aspect_ratio == "anyres" {
let (num_patch_width, num_patch_height) = get_anyres_image_grid_shape(
image_size,
&self.config.image_grid_pinpoints,
self.clip_vision_tower.config.image_size as u32,
);
patch_image_feature.reshape((
num_patch_height as usize,
num_patch_width as usize,
height,
width,
(),
))?
} else {
todo!("not implemented in original python LLaVA yet")
};
let new_image_feature = if mm_patch_merge_type.contains("unpad") {
let new_image_feature = new_image_feature
.permute((4, 0, 2, 1, 3))?
.flatten(1, 2)?
.flatten(2, 3)?;
let new_image_feature = unpad_image(&new_image_feature, &image_size)?;
let new_image_feature_dims = new_image_feature.dims();
let image_new_line = self
.image_newline
.reshape((self.config.hidden_size, 1, 1))?
.broadcast_as((
new_image_feature_dims[0],
new_image_feature_dims[1],
1,
))?;
let new_image_feature =
Tensor::cat(&[new_image_feature, image_new_line], 2)?;
new_image_feature.flatten(1, 2)?.transpose(0, 1)?
} else {
new_image_feature.permute((0, 2, 1, 3, 4))?.flatten(0, 3)?
};
Tensor::cat(&[base_image_feature, new_image_feature], 0)?
} else {
let new_image_feature = image_feature.get(0).unwrap();
if mm_patch_merge_type.contains("unpad") {
Tensor::cat(
&[
new_image_feature,
self.image_newline.clone().unsqueeze(0).unwrap(),
],
0,
)
.unwrap()
} else {
new_image_feature
}
};
new_image_features.push(new_image_feature);
}
new_image_features
} else {
bail!("Unexpected mm_patch_merge_type: {mm_patch_merge_type}")
};
// can easily be replaced by nonzero if it is implemented in candle
let input_ids_vec = input_ids.squeeze(0)?.to_vec1::<i64>()?;
let mut image_indices = {
let mut image_indices = vec![0_i64];
image_indices.extend(
input_ids_vec
.iter()
.enumerate()
.filter_map(|(i, x)| {
if *x == self.config.image_token_index as i64 {
Some(i as i64)
} else {
None
}
})
.collect::<Vec<i64>>(),
);
image_indices
};
if image_indices.len() == 1 {
//no image, only [0],
return self.llama.embed(input_ids);
}
let input_ids_noim = input_ids_vec
.iter()
.filter_map(|x| {
if *x != self.config.image_token_index as i64 {
Some(*x)
} else {
None
}
})
.collect::<Vec<i64>>();
let input_ids_noim_len = input_ids_noim.len();
image_indices.push((input_ids_noim_len) as i64);
let input_ids_noim = Tensor::from_vec(input_ids_noim, input_ids_noim_len, &self.device)?;
let cur_input_embeds = self.llama.embed(&input_ids_noim)?;
// can be replace by split if it is implemented in candle
let input_embed_no_ims = {
let mut input_embeds = Vec::new();
for i in 0..image_indices.len() - 1 {
let start = (image_indices[i]) as usize;
let end = image_indices[i + 1] as usize;
input_embeds.push(cur_input_embeds.i((start..end, ..))?)
}
input_embeds
};
let mut cur_new_input_embeds = Vec::new();
for (i, image_feature) in image_features.iter().enumerate() {
cur_new_input_embeds.push(input_embed_no_ims[i].clone());
cur_new_input_embeds.push(image_feature.clone());
}
cur_new_input_embeds.push(input_embed_no_ims[image_features.len()].clone());
let new_input_embeds = Tensor::cat(&cur_new_input_embeds, 0)?;
//trancate
let new_input_embeds =
if let Some(tokenizer_model_max_length) = self.config.tokenizer_model_max_length {
let (new_input_embeds_length, _) = new_input_embeds.shape().dims2()?;
if new_input_embeds_length > tokenizer_model_max_length {
new_input_embeds.i((..tokenizer_model_max_length, ..))?
} else {
new_input_embeds
}
} else {
new_input_embeds
};
new_input_embeds.unsqueeze(0)
}
pub fn forward(
&self,
input_embeds: &Tensor,
position_id: usize,
cache: &mut Cache,
) -> Result<Tensor> {
self.llama
.forward_input_embed(input_embeds, position_id, cache)
}
}

View File

@ -0,0 +1,41 @@
pub fn get_anyres_image_grid_shape(
image_size: (u32, u32),
grid_pinpoints: &[(u32, u32)],
patch_size: u32,
) -> (u32, u32) {
let (width, height) = select_best_resolution(image_size, grid_pinpoints);
(width / patch_size, height / patch_size)
}
pub fn select_best_resolution(
original_size: (u32, u32),
possible_resolutions: &[(u32, u32)],
) -> (u32, u32) {
let (original_width, original_height) = original_size;
let mut best_fit = (0, 0);
let original_width_f = original_width as f32;
let original_height_f = original_height as f32;
let mut max_effective_resolution = 0_u32;
let mut min_wasted_resolution = u32::MAX;
for (width, height) in possible_resolutions {
let width_f = *width as f32;
let height_f = *height as f32;
let scale = (width_f / original_width_f).min(height_f / original_height_f);
let (downscaled_width, downscaled_height) = (
(original_width_f * scale) as u32,
(original_height_f * scale) as u32,
);
let effective_resolution =
std::cmp::min((*width) * (*height), downscaled_width * downscaled_height);
let wasted_resolution = (*width) * (*height) - effective_resolution;
if effective_resolution > max_effective_resolution
|| (effective_resolution == max_effective_resolution
&& wasted_resolution < min_wasted_resolution)
{
best_fit = (*width, *height);
max_effective_resolution = effective_resolution;
min_wasted_resolution = wasted_resolution;
}
}
best_fit
}

View File

@ -17,6 +17,7 @@ pub mod jina_bert;
pub mod llama;
pub mod llama2_c;
pub mod llama2_c_weights;
pub mod llava;
pub mod mamba;
pub mod marian;
pub mod metavoice;