mirror of https://github.com/vllm-project/vllm
[Model] Initialize support for InternVL2 series models (#6514)
Co-authored-by: Roger Wang <ywang@roblox.com>
This commit is contained in:
parent
3eeb148f46
commit
7cbd9ec7a9
|
@ -200,6 +200,10 @@ Vision Language Models
|
|||
- Fuyu
|
||||
- :code:`adept/fuyu-8b` etc.
|
||||
-
|
||||
* - :code:`InternVLChatModel`
|
||||
- InternVL2
|
||||
- :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc.
|
||||
-
|
||||
* - :code:`LlavaForConditionalGeneration`
|
||||
- LLaVA-1.5
|
||||
- :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc.
|
||||
|
|
|
@ -106,6 +106,20 @@ def run_minicpmv(question):
|
|||
return llm, prompt
|
||||
|
||||
|
||||
# InternVL
|
||||
def run_internvl(question):
|
||||
# Generally, InternVL can use chatml template for conversation
|
||||
TEMPLATE = "<|im_start|>User\n{prompt}<|im_end|>\n<|im_start|>Assistant\n"
|
||||
prompt = f"<image>\n{question}\n"
|
||||
prompt = TEMPLATE.format(prompt=prompt)
|
||||
llm = LLM(
|
||||
model="OpenGVLab/InternVL2-4B",
|
||||
trust_remote_code=True,
|
||||
max_num_seqs=5,
|
||||
)
|
||||
return llm, prompt
|
||||
|
||||
|
||||
# BLIP-2
|
||||
def run_blip2(question):
|
||||
|
||||
|
@ -125,6 +139,7 @@ model_example_map = {
|
|||
"chameleon": run_chameleon,
|
||||
"minicpmv": run_minicpmv,
|
||||
"blip-2": run_blip2,
|
||||
"internvl_chat": run_internvl,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -42,6 +42,7 @@ chat_completion_from_url = client.chat.completions.create(
|
|||
],
|
||||
}],
|
||||
model=model,
|
||||
max_tokens=64,
|
||||
)
|
||||
|
||||
result = chat_completion_from_url.choices[0].message.content
|
||||
|
@ -78,6 +79,7 @@ chat_completion_from_base64 = client.chat.completions.create(
|
|||
],
|
||||
}],
|
||||
model=model,
|
||||
max_tokens=64,
|
||||
)
|
||||
|
||||
result = chat_completion_from_base64.choices[0].message.content
|
||||
|
|
|
@ -16,6 +16,7 @@ ray
|
|||
sentence-transformers # required for embedding
|
||||
sparseml==1.8.0 # required for compressed-tensors
|
||||
compressed-tensors==0.4.0 # required for compressed-tensors
|
||||
timm # required for internvl test
|
||||
|
||||
# Benchmarking
|
||||
aiohttp
|
||||
|
|
|
@ -0,0 +1,201 @@
|
|||
import types
|
||||
from typing import List, Optional, Type
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from PIL.Image import Image
|
||||
|
||||
from vllm.model_executor.models.internvl import (IMG_CONTEXT, IMG_END,
|
||||
IMG_START,
|
||||
image_to_pixel_values)
|
||||
from vllm.multimodal.utils import rescale_image_size
|
||||
from vllm.utils import is_cpu
|
||||
|
||||
from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
|
||||
from .utils import check_logprobs_close
|
||||
|
||||
pytestmark = pytest.mark.vlm
|
||||
|
||||
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
|
||||
"stop_sign":
|
||||
"<|im_start|>User\n<image>\nWhat's the content in the center of the image?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501
|
||||
"cherry_blossom":
|
||||
"<|im_start|>User\n<image>\nWhat is the season?<|im_end|>\n<|im_start|>Assistant\n", # noqa: E501
|
||||
})
|
||||
|
||||
# we use snapshot_download to prevent conflicts between
|
||||
# dynamic_module and trust_remote_code for hf_runner
|
||||
models = [
|
||||
snapshot_download("OpenGVLab/InternVL2-1B"),
|
||||
snapshot_download("OpenGVLab/InternVL2-2B"),
|
||||
# snapshot_download("OpenGVLab/InternVL2-4B"), # broken
|
||||
]
|
||||
|
||||
|
||||
class InternVLProcessor:
|
||||
"""A simple processor for InternVL2 HF model which misses a processor."""
|
||||
|
||||
def __init__(self, hf_runner: HfRunner):
|
||||
self.num_image_token = hf_runner.model.num_image_token
|
||||
self.tokenizer = hf_runner.tokenizer
|
||||
self.dtype = hf_runner.model.dtype
|
||||
|
||||
def __call__(self, text: str, images: Image, **kwargs):
|
||||
pixel_values = image_to_pixel_values(images).to(self.dtype)
|
||||
num_patches_list = [pixel_values.shape[0]]
|
||||
for num_patches in num_patches_list:
|
||||
context_tokens = IMG_CONTEXT * self.num_image_token * num_patches
|
||||
image_tokens = IMG_START + context_tokens + IMG_END
|
||||
text = text.replace('<image>', image_tokens, 1)
|
||||
prompt = self.tokenizer(text, return_tensors="pt")
|
||||
prompt.update({"pixel_values": pixel_values})
|
||||
return prompt
|
||||
|
||||
|
||||
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/modeling_internvl_chat.py
|
||||
def generate(
|
||||
self,
|
||||
pixel_values: torch.FloatTensor,
|
||||
input_ids: torch.FloatTensor,
|
||||
attention_mask: Optional[torch.LongTensor] = None,
|
||||
**generate_kwargs,
|
||||
) -> torch.LongTensor:
|
||||
"""Generate method for InternVL2 model without fixed use_cache."""
|
||||
assert self.img_context_token_id is not None
|
||||
vit_embeds = self.extract_feature(pixel_values)
|
||||
input_embeds = self.language_model.get_input_embeddings()(input_ids)
|
||||
B, N, C = input_embeds.shape
|
||||
input_embeds = input_embeds.reshape(B * N, C)
|
||||
|
||||
input_ids = input_ids.reshape(B * N)
|
||||
selected = (input_ids == self.img_context_token_id)
|
||||
assert selected.sum() != 0
|
||||
input_embeds[selected] = vit_embeds.reshape(-1, C).to(input_embeds.device)
|
||||
|
||||
input_embeds = input_embeds.reshape(B, N, C)
|
||||
|
||||
outputs = self.language_model.generate(
|
||||
inputs_embeds=input_embeds,
|
||||
attention_mask=attention_mask,
|
||||
**generate_kwargs,
|
||||
)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def run_test(
|
||||
hf_runner: Type[HfRunner],
|
||||
vllm_runner: Type[VllmRunner],
|
||||
image_assets: _ImageAssets,
|
||||
model: str,
|
||||
*,
|
||||
size_factors: List[float],
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
tensor_parallel_size: int,
|
||||
distributed_executor_backend: Optional[str] = None,
|
||||
):
|
||||
"""Inference result should be the same between hf and vllm.
|
||||
|
||||
All the image fixtures for the test is under tests/images.
|
||||
For huggingface runner, we provide the PIL images as input.
|
||||
For vllm runner, we provide MultiModalDataDict objects
|
||||
and corresponding vision language config as input.
|
||||
Note, the text input is also adjusted to abide by vllm contract.
|
||||
The text output is sanitized to be able to compare with hf.
|
||||
"""
|
||||
images = [asset.pil_image for asset in image_assets]
|
||||
|
||||
inputs_per_image = [(
|
||||
[prompt for _ in size_factors],
|
||||
[rescale_image_size(image, factor) for factor in size_factors],
|
||||
) for image, prompt in zip(images, HF_IMAGE_PROMPTS)]
|
||||
|
||||
# NOTE: take care of the order. run vLLM first, and then run HF.
|
||||
# vLLM needs a fresh new process without cuda initialization.
|
||||
# if we run HF first, the cuda initialization will be done and it
|
||||
# will hurt multiprocessing backend with fork method (the default method).
|
||||
|
||||
# max_model_len should be greater than image_feature_size
|
||||
with vllm_runner(model,
|
||||
max_model_len=4096,
|
||||
dtype=dtype,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
distributed_executor_backend=distributed_executor_backend,
|
||||
enforce_eager=True) as vllm_model:
|
||||
vllm_outputs_per_image = [
|
||||
vllm_model.generate_greedy_logprobs(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=images)
|
||||
for prompts, images in inputs_per_image
|
||||
]
|
||||
|
||||
with hf_runner(model, dtype=dtype) as hf_model:
|
||||
img_context_token_id = hf_model.tokenizer.convert_tokens_to_ids(
|
||||
"<IMG_CONTEXT>")
|
||||
hf_model.model.img_context_token_id = img_context_token_id
|
||||
hf_model.processor = InternVLProcessor(hf_model)
|
||||
hf_model.model.get_output_embeddings = lambda: \
|
||||
hf_model.model.language_model.get_output_embeddings()
|
||||
hf_model.model.generate = types.MethodType(generate, hf_model.model)
|
||||
eos_token_id = hf_model.tokenizer.eos_token_id
|
||||
hf_outputs_per_image = [
|
||||
hf_model.generate_greedy_logprobs_limit(prompts,
|
||||
max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
images=hf_images,
|
||||
eos_token_id=eos_token_id)
|
||||
for prompts, hf_images in inputs_per_image
|
||||
]
|
||||
|
||||
for hf_outputs, vllm_outputs in zip(hf_outputs_per_image,
|
||||
vllm_outputs_per_image):
|
||||
# TODO: Check whether using original CLIPVisionModel can improve
|
||||
# consistency against HF
|
||||
check_logprobs_close(
|
||||
outputs_0_lst=hf_outputs,
|
||||
outputs_1_lst=vllm_outputs,
|
||||
name_0="hf",
|
||||
name_1="vllm",
|
||||
)
|
||||
|
||||
|
||||
target_dtype = "half"
|
||||
if is_cpu():
|
||||
target_dtype = "bfloat16"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", models)
|
||||
@pytest.mark.parametrize(
|
||||
"size_factors",
|
||||
[
|
||||
# No image
|
||||
[],
|
||||
# Single-scale
|
||||
[1.0],
|
||||
# Single-scale, batched
|
||||
[1.0, 1.0, 1.0],
|
||||
# Multi-scale
|
||||
[0.25, 0.5, 1.0],
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("dtype", [target_dtype])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
@pytest.mark.parametrize("num_logprobs", [5])
|
||||
@torch.inference_mode()
|
||||
def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
|
||||
dtype: str, max_tokens: int, num_logprobs: int) -> None:
|
||||
run_test(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
image_assets,
|
||||
model,
|
||||
size_factors=size_factors,
|
||||
dtype=dtype,
|
||||
max_tokens=max_tokens,
|
||||
num_logprobs=num_logprobs,
|
||||
tensor_parallel_size=1,
|
||||
)
|
|
@ -107,7 +107,7 @@ def _image_token_str(model_config: ModelConfig,
|
|||
return None
|
||||
if model_type.startswith("llava"):
|
||||
return tokenizer.decode(model_config.hf_config.image_token_index)
|
||||
if model_type == "chameleon":
|
||||
if model_type in ("chameleon", "internvl_chat"):
|
||||
return "<image>"
|
||||
raise TypeError(f"Unknown model type: {model_type}")
|
||||
|
||||
|
|
|
@ -37,6 +37,7 @@ _GENERATION_MODELS = {
|
|||
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
|
||||
"InternLMForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"),
|
||||
"InternVLChatModel": ("internvl", "InternVLChatModel"),
|
||||
"JAISLMHeadModel": ("jais", "JAISLMHeadModel"),
|
||||
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"LlavaForConditionalGeneration":
|
||||
|
|
|
@ -0,0 +1,270 @@
|
|||
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
|
||||
# --------------------------------------------------------
|
||||
# InternVL
|
||||
# Copyright (c) 2023 OpenGVLab
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.model_executor.layers.activation import get_act_fn
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
|
||||
NORM2FN = {
|
||||
'rms_norm': RMSNorm,
|
||||
'layer_norm': nn.LayerNorm,
|
||||
}
|
||||
|
||||
|
||||
class InternVisionEmbeddings(nn.Module):
|
||||
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.class_embedding = nn.Parameter(torch.randn(1, 1, self.embed_dim))
|
||||
|
||||
self.patch_embedding = nn.Conv2d(in_channels=3,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size)
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size)**2
|
||||
self.num_positions = self.num_patches + 1
|
||||
|
||||
self.position_embedding = nn.Parameter(
|
||||
torch.randn(1, self.num_positions, self.embed_dim))
|
||||
|
||||
def _get_pos_embed(self, pos_embed, H, W):
|
||||
target_dtype = pos_embed.dtype
|
||||
pos_embed = pos_embed.float().reshape(
|
||||
1, self.image_size // self.patch_size,
|
||||
self.image_size // self.patch_size, -1).permute(0, 3, 1, 2)
|
||||
pos_embed = F.interpolate(pos_embed,
|
||||
size=(H, W),
|
||||
mode='bicubic',
|
||||
align_corners=False)
|
||||
pos_embed = pos_embed.reshape(1, -1, H * W).permute(0, 2,
|
||||
1).to(target_dtype)
|
||||
return pos_embed
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(pixel_values.to(
|
||||
target_dtype)) # shape = [*, channel, width, height]
|
||||
batch_size, _, height, width = patch_embeds.shape
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1,
|
||||
-1).to(target_dtype)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
position_embedding = torch.cat([
|
||||
self.position_embedding[:, :1, :],
|
||||
self._get_pos_embed(self.position_embedding[:, 1:, :], height,
|
||||
width)
|
||||
],
|
||||
dim=1)
|
||||
embeddings = embeddings + position_embedding.to(target_dtype)
|
||||
return embeddings
|
||||
|
||||
|
||||
class InternAttention(nn.Module):
|
||||
"""Multi-headed attention from 'Attention Is All You Need' paper"""
|
||||
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
if self.head_dim * self.num_heads != self.embed_dim:
|
||||
raise ValueError(
|
||||
f'embed_dim must be divisible by num_heads '
|
||||
f'(got `embed_dim`: {self.embed_dim} and `num_heads`:'
|
||||
f' {self.num_heads}).')
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.qkv = nn.Linear(self.embed_dim,
|
||||
3 * self.embed_dim,
|
||||
bias=config.qkv_bias)
|
||||
|
||||
self.qk_normalization = config.qk_normalization
|
||||
|
||||
if self.qk_normalization:
|
||||
self.q_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.k_norm = RMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
self.proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
|
||||
C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
|
||||
if self.qk_normalization:
|
||||
B_, H_, N_, D_ = q.shape
|
||||
q = self.q_norm(q.transpose(1, 2).flatten(-2, -1)).view(
|
||||
B_, N_, H_, D_).transpose(1, 2)
|
||||
k = self.k_norm(k.transpose(1, 2).flatten(-2, -1)).view(
|
||||
B_, N_, H_, D_).transpose(1, 2)
|
||||
|
||||
x = F.scaled_dot_product_attention(q, k, v, scale=self.scale)
|
||||
x = x.transpose(1, 2).reshape(B, N, C)
|
||||
|
||||
x = self.proj(x)
|
||||
return x
|
||||
|
||||
|
||||
class InternMLP(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.activation_fn = get_act_fn(config.hidden_act)
|
||||
self.fc1 = ColumnParallelLinear(config.hidden_size,
|
||||
config.intermediate_size,
|
||||
bias=True,
|
||||
quant_config=quant_config)
|
||||
self.fc2 = RowParallelLinear(config.intermediate_size,
|
||||
config.hidden_size,
|
||||
bias=True,
|
||||
quant_config=quant_config)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states, _ = self.fc1(hidden_states)
|
||||
hidden_states = self.activation_fn(hidden_states)
|
||||
hidden_states, _ = self.fc2(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class InternVisionEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.norm_type = config.norm_type
|
||||
|
||||
self.attn = InternAttention(config)
|
||||
self.mlp = InternMLP(config, quant_config=quant_config)
|
||||
self.norm1 = NORM2FN[self.norm_type](self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
self.norm2 = NORM2FN[self.norm_type](self.embed_dim,
|
||||
eps=config.layer_norm_eps)
|
||||
|
||||
self.ls1 = nn.Parameter(config.initializer_factor *
|
||||
torch.ones(self.embed_dim))
|
||||
self.ls2 = nn.Parameter(config.initializer_factor *
|
||||
torch.ones(self.embed_dim))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
):
|
||||
hidden_states = hidden_states + self.attn(
|
||||
self.norm1(hidden_states)) * self.ls1
|
||||
|
||||
hidden_states = hidden_states + self.mlp(
|
||||
self.norm2(hidden_states)) * self.ls2
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class InternVisionEncoder(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
num_hidden_layers_override: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
if num_hidden_layers_override is None:
|
||||
num_hidden_layers = config.num_hidden_layers
|
||||
else:
|
||||
num_hidden_layers = num_hidden_layers_override
|
||||
self.layers = nn.ModuleList([
|
||||
InternVisionEncoderLayer(config=config, quant_config=quant_config)
|
||||
for _ in range(num_hidden_layers)
|
||||
])
|
||||
|
||||
def forward(self, inputs_embeds: torch.Tensor):
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
for encoder_layer in self.layers:
|
||||
hidden_states = encoder_layer(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class InternVisionModel(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
num_hidden_layers_override: Optional[int] = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
self.embeddings = InternVisionEmbeddings(config)
|
||||
self.encoder = InternVisionEncoder(
|
||||
config=config,
|
||||
quant_config=quant_config,
|
||||
num_hidden_layers_override=num_hidden_layers_override)
|
||||
|
||||
def resize_pos_embeddings(self, old_size, new_size, patch_size):
|
||||
pos_emb = self.embeddings.position_embedding
|
||||
_, num_positions, embed_dim = pos_emb.shape
|
||||
cls_emb = pos_emb[:, :1, :]
|
||||
pos_emb = pos_emb[:, 1:, :].reshape(1, old_size // patch_size,
|
||||
old_size // patch_size,
|
||||
-1).permute(0, 3, 1, 2)
|
||||
pos_emb = F.interpolate(pos_emb.float(),
|
||||
size=new_size // patch_size,
|
||||
mode='bicubic',
|
||||
align_corners=False)
|
||||
pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim,
|
||||
-1).permute(0, 2, 1)
|
||||
pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
|
||||
self.embeddings.position_embedding = nn.Parameter(pos_emb)
|
||||
self.embeddings.image_size = new_size
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.Tensor] = None,
|
||||
pixel_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.FloatTensor:
|
||||
if pixel_values is None and pixel_embeds is None:
|
||||
raise ValueError(
|
||||
'You have to specify pixel_values or pixel_embeds')
|
||||
|
||||
if pixel_embeds is not None:
|
||||
hidden_states = pixel_embeds
|
||||
elif pixel_values is not None:
|
||||
if pixel_values.ndim == 4:
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
else:
|
||||
raise ValueError(
|
||||
f'wrong pixel_values size: {pixel_values.shape}')
|
||||
|
||||
encoder_outputs = self.encoder(inputs_embeds=hidden_states)
|
||||
|
||||
return encoder_outputs
|
|
@ -219,14 +219,22 @@ class InternLM2Model(nn.Module):
|
|||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.tok_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: IntermediateTensors = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.tok_embeddings(input_ids)
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.tok_embeddings(input_ids)
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
|
|
|
@ -0,0 +1,471 @@
|
|||
# adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_internvl_chat.py
|
||||
# --------------------------------------------------------
|
||||
# InternVL
|
||||
# Copyright (c) 2023 OpenGVLab
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torchvision.transforms as T
|
||||
from PIL import Image
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import CacheConfig, MultiModalConfig
|
||||
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.model_executor.models.intern_vit import InternVisionModel
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY, BatchedTensors
|
||||
from vllm.multimodal.base import MultiModalInputs
|
||||
from vllm.multimodal.image import cached_get_tokenizer
|
||||
from vllm.sequence import IntermediateTensors, SamplerOutput
|
||||
|
||||
from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip,
|
||||
get_clip_num_patches)
|
||||
from .interfaces import SupportsVision
|
||||
from .utils import merge_vision_embeddings
|
||||
|
||||
IMG_START = '<img>'
|
||||
IMG_END = '</img>'
|
||||
IMG_CONTEXT = '<IMG_CONTEXT>'
|
||||
|
||||
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
||||
IMAGENET_STD = (0.229, 0.224, 0.225)
|
||||
|
||||
MAX_IMAGE_FEATURE_SIZE_WIDTH = 3000
|
||||
MAX_IMAGE_FEATURE_SIZE_HEIGHT = 500
|
||||
|
||||
|
||||
class InternVLImagePixelInputs(TypedDict):
|
||||
type: Literal["pixel_values"]
|
||||
data: BatchedTensors
|
||||
"""
|
||||
Shape: `(batch_size, 1 + num_patches, num_channels, height, width)`
|
||||
|
||||
Note that `num_patches` may be different for each batch, in which case
|
||||
the data is passed as a list instead of a batched tensor.
|
||||
"""
|
||||
|
||||
|
||||
# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
|
||||
def build_transform(input_size):
|
||||
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
|
||||
transform = T.Compose([
|
||||
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
||||
T.Resize((input_size, input_size),
|
||||
interpolation=T.InterpolationMode.BICUBIC),
|
||||
T.ToTensor(),
|
||||
T.Normalize(mean=MEAN, std=STD)
|
||||
])
|
||||
return transform
|
||||
|
||||
|
||||
# copied from https://huggingface.co/OpenGVLab/InternVL2-1B
|
||||
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
|
||||
image_size):
|
||||
best_ratio_diff = float('inf')
|
||||
best_ratio = (1, 1)
|
||||
area = width * height
|
||||
for ratio in target_ratios:
|
||||
target_aspect_ratio = ratio[0] / ratio[1]
|
||||
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
||||
if ratio_diff < best_ratio_diff:
|
||||
best_ratio_diff = ratio_diff
|
||||
best_ratio = ratio
|
||||
elif ratio_diff == best_ratio_diff:
|
||||
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
|
||||
best_ratio = ratio
|
||||
return best_ratio
|
||||
|
||||
|
||||
def calculate_num_blocks(orig_width: int,
|
||||
orig_height: int,
|
||||
min_num=1,
|
||||
max_num=6,
|
||||
image_size=448):
|
||||
aspect_ratio = orig_width / orig_height
|
||||
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = set((i, j) for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1) for j in range(1, n + 1)
|
||||
if i * j <= max_num and i * j >= min_num)
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
|
||||
target_ratios, orig_width,
|
||||
orig_height, image_size)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = image_size * target_aspect_ratio[0]
|
||||
target_height = image_size * target_aspect_ratio[1]
|
||||
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
return blocks, target_width, target_height
|
||||
|
||||
|
||||
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
|
||||
def dynamic_preprocess(image,
|
||||
min_num=1,
|
||||
max_num=6,
|
||||
image_size=448,
|
||||
use_thumbnail=False):
|
||||
orig_width, orig_height = image.size
|
||||
|
||||
blocks, target_width, target_height = calculate_num_blocks(
|
||||
orig_width, orig_height, min_num, max_num, image_size)
|
||||
# resize the image
|
||||
resized_img = image.resize((target_width, target_height))
|
||||
processed_images = []
|
||||
for i in range(blocks):
|
||||
box = ((i % (target_width // image_size)) * image_size,
|
||||
(i // (target_width // image_size)) * image_size,
|
||||
((i % (target_width // image_size)) + 1) * image_size,
|
||||
((i // (target_width // image_size)) + 1) * image_size)
|
||||
# split the image
|
||||
split_img = resized_img.crop(box)
|
||||
processed_images.append(split_img)
|
||||
assert len(processed_images) == blocks
|
||||
if use_thumbnail and len(processed_images) != 1:
|
||||
thumbnail_img = image.resize((image_size, image_size))
|
||||
processed_images.append(thumbnail_img)
|
||||
return processed_images
|
||||
|
||||
|
||||
# adapted from https://huggingface.co/OpenGVLab/InternVL2-1B
|
||||
def image_to_pixel_values(image: Image.Image, input_size=448, max_num=6):
|
||||
transform = build_transform(input_size=input_size)
|
||||
images = dynamic_preprocess(image,
|
||||
image_size=input_size,
|
||||
use_thumbnail=True,
|
||||
max_num=max_num)
|
||||
pixel_values = [transform(image) for image in images]
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
return pixel_values
|
||||
|
||||
|
||||
def get_internvl_num_patches(image_size: int, patch_size: int,
|
||||
downsample_ratio: float):
|
||||
return int(
|
||||
get_clip_num_patches(image_size=image_size, patch_size=patch_size) *
|
||||
(downsample_ratio**2))
|
||||
|
||||
|
||||
def get_max_internvl_image_tokens(ctx: InputContext):
|
||||
hf_config = ctx.get_hf_config(PretrainedConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
image_size = vision_config.image_size
|
||||
patch_size = vision_config.patch_size
|
||||
downsample_ratio = hf_config.downsample_ratio
|
||||
num_patches = get_internvl_num_patches(image_size, patch_size,
|
||||
downsample_ratio)
|
||||
return num_patches * 7
|
||||
|
||||
|
||||
def input_processor_for_internvl(ctx: InputContext, llm_inputs: LLMInputs):
|
||||
multi_modal_data = llm_inputs.get("multi_modal_data")
|
||||
if multi_modal_data is None or "image" not in multi_modal_data:
|
||||
return llm_inputs
|
||||
|
||||
model_config = ctx.model_config
|
||||
hf_config = ctx.get_hf_config(PretrainedConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
|
||||
image_data = multi_modal_data["image"]
|
||||
if isinstance(image_data, Image.Image):
|
||||
width, height = image_data.size
|
||||
num_blocks, _, _ = calculate_num_blocks(width, height)
|
||||
elif isinstance(image_data, torch.Tensor):
|
||||
raise NotImplementedError("Embeddings input is not supported yet")
|
||||
else:
|
||||
raise TypeError(f"Invalid image type: {type(image_data)}")
|
||||
|
||||
image_size = vision_config.image_size
|
||||
patch_size = vision_config.patch_size
|
||||
downsample_ratio = hf_config.downsample_ratio
|
||||
num_patches = get_internvl_num_patches(image_size, patch_size,
|
||||
downsample_ratio)
|
||||
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer,
|
||||
trust_remote_code=True)
|
||||
|
||||
prompt = llm_inputs["prompt"]
|
||||
prompt_token_ids = llm_inputs["prompt_token_ids"]
|
||||
if prompt is None:
|
||||
prompt = tokenizer.decode(prompt_token_ids)
|
||||
image_prompt = IMG_START + IMG_CONTEXT * (num_blocks +
|
||||
1) * num_patches + IMG_END
|
||||
new_prompt = prompt.replace('<image>', image_prompt, 1)
|
||||
new_prompt_token_ids = tokenizer.encode(new_prompt)
|
||||
|
||||
return LLMInputs(prompt=prompt,
|
||||
prompt_token_ids=new_prompt_token_ids,
|
||||
multi_modal_data=multi_modal_data)
|
||||
|
||||
|
||||
def input_mapper_for_internvl(ctx: InputContext, data: object):
|
||||
if isinstance(data, Image.Image):
|
||||
data = image_to_pixel_values(data)
|
||||
model_config = ctx.model_config
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer,
|
||||
trust_remote_code=True)
|
||||
image_token_id = tokenizer.encode(IMG_CONTEXT,
|
||||
add_special_tokens=False,
|
||||
return_tensors="pt")[0]
|
||||
|
||||
return MultiModalInputs({
|
||||
"pixel_values": data,
|
||||
"image_token_id": image_token_id
|
||||
})
|
||||
|
||||
|
||||
def dummy_data_for_internvl(ctx: InputContext, seq_len: int):
|
||||
|
||||
image_feature_size = get_max_internvl_image_tokens(ctx)
|
||||
model_config = ctx.model_config
|
||||
hf_config = ctx.get_hf_config(PretrainedConfig)
|
||||
vision_config = hf_config.vision_config
|
||||
tokenizer = cached_get_tokenizer(model_config.tokenizer,
|
||||
trust_remote_code=True)
|
||||
|
||||
seq_data = dummy_seq_data_for_clip(
|
||||
vision_config,
|
||||
seq_len,
|
||||
image_token_id=tokenizer.encode(IMG_CONTEXT,
|
||||
add_special_tokens=False)[0],
|
||||
image_feature_size_override=image_feature_size,
|
||||
)
|
||||
mm_data = dummy_image_for_clip(
|
||||
vision_config,
|
||||
image_width_override=MAX_IMAGE_FEATURE_SIZE_WIDTH,
|
||||
image_height_override=MAX_IMAGE_FEATURE_SIZE_HEIGHT,
|
||||
)
|
||||
|
||||
return seq_data, mm_data
|
||||
|
||||
|
||||
@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_internvl)
|
||||
@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens)
|
||||
@INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl)
|
||||
@INPUT_REGISTRY.register_input_processor(input_processor_for_internvl)
|
||||
class InternVLChatModel(nn.Module, SupportsVision):
|
||||
|
||||
def __init__(self,
|
||||
config: PretrainedConfig,
|
||||
multimodal_config: MultiModalConfig,
|
||||
cache_config: Optional[CacheConfig] = None,
|
||||
quant_config: Optional[QuantizationConfig] = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self.multimodal_config = multimodal_config
|
||||
|
||||
image_size = config.force_image_size or config.vision_config.image_size
|
||||
patch_size = config.vision_config.patch_size
|
||||
self.patch_size = patch_size
|
||||
self.select_layer = config.select_layer
|
||||
self.num_image_token = int(
|
||||
(image_size // patch_size)**2 * (config.downsample_ratio**2))
|
||||
self.downsample_ratio = config.downsample_ratio
|
||||
self.ps_version = config.ps_version
|
||||
|
||||
vision_feature_layer = self.select_layer
|
||||
if vision_feature_layer < 0:
|
||||
num_hidden_layers = config.vision_config.num_hidden_layers \
|
||||
+ vision_feature_layer + 1
|
||||
else:
|
||||
num_hidden_layers = vision_feature_layer + 1
|
||||
self.vision_model = InternVisionModel(
|
||||
config.vision_config, num_hidden_layers_override=num_hidden_layers)
|
||||
|
||||
llm_class = ModelRegistry.load_model_cls(
|
||||
config.text_config.architectures[0])
|
||||
self.language_model = llm_class(config.text_config, cache_config,
|
||||
quant_config)
|
||||
|
||||
vit_hidden_size = config.vision_config.hidden_size
|
||||
llm_hidden_size = config.text_config.hidden_size
|
||||
|
||||
self.mlp1 = nn.Sequential(
|
||||
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio)**2),
|
||||
nn.Linear(vit_hidden_size * int(1 / self.downsample_ratio)**2,
|
||||
llm_hidden_size), nn.GELU(),
|
||||
nn.Linear(llm_hidden_size, llm_hidden_size))
|
||||
|
||||
self.img_context_token_id = None
|
||||
|
||||
def pixel_shuffle(self, x, scale_factor=0.5):
|
||||
n, w, h, c = x.size()
|
||||
# N, W, H, C --> N, W, H * scale, C // scale
|
||||
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
|
||||
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
x = x.view(n, int(h * scale_factor), int(w * scale_factor),
|
||||
int(c / (scale_factor * scale_factor)))
|
||||
if self.ps_version == 'v1':
|
||||
pass
|
||||
else:
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
return x
|
||||
|
||||
def extract_feature(self, pixel_values):
|
||||
vit_embeds = self.vision_model(pixel_values=pixel_values)
|
||||
vit_embeds = vit_embeds[:, 1:, :]
|
||||
|
||||
h = w = int(vit_embeds.shape[1]**0.5)
|
||||
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
||||
vit_embeds = self.pixel_shuffle(vit_embeds,
|
||||
scale_factor=self.downsample_ratio)
|
||||
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1,
|
||||
vit_embeds.shape[-1])
|
||||
vit_embeds = self.mlp1(vit_embeds)
|
||||
return vit_embeds
|
||||
|
||||
def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor:
|
||||
if list(data.shape[1:]) != [2]:
|
||||
raise ValueError(
|
||||
f"The expected image sizes shape is batch dimension plus "
|
||||
f"{[2]}. You supplied {data.shape}.")
|
||||
|
||||
return data
|
||||
|
||||
def _validate_pixel_values(
|
||||
self, data: Union[torch.Tensor, List[torch.Tensor]]
|
||||
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
||||
|
||||
h = w = self.config.vision_config.image_size
|
||||
expected_dims = (3, h, w)
|
||||
|
||||
def _validate_shape(d: torch.Tensor):
|
||||
actual_dims = tuple(d.shape)
|
||||
|
||||
if actual_dims != expected_dims:
|
||||
expected_expr = ("num_patches", *map(str, expected_dims))
|
||||
raise ValueError(
|
||||
"The expected shape of pixel values in each batch element "
|
||||
f"is {expected_expr}. You supplied {tuple(d.shape)}.")
|
||||
|
||||
for d in data:
|
||||
_validate_shape(d)
|
||||
|
||||
return data
|
||||
|
||||
def _parse_and_validate_image_input(
|
||||
self, **kwargs: object) -> Optional[InternVLImagePixelInputs]:
|
||||
pixel_values = kwargs.pop("pixel_values", None)
|
||||
image_token_id = kwargs.pop("image_token_id", None)
|
||||
|
||||
if pixel_values is None:
|
||||
return None
|
||||
|
||||
self.img_context_token_id = image_token_id[0]
|
||||
|
||||
if not isinstance(pixel_values, (torch.Tensor, list)):
|
||||
raise ValueError("Incorrect type of pixel values. "
|
||||
f"Got type: {type(pixel_values)}")
|
||||
|
||||
return InternVLImagePixelInputs(
|
||||
type="pixel_values",
|
||||
data=self._validate_pixel_values(pixel_values),
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
**kwargs: object,
|
||||
) -> SamplerOutput:
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is not None:
|
||||
inputs_embeds = self.language_model.model.get_input_embeddings(
|
||||
input_ids)
|
||||
vit_embeds = self.extract_feature(image_input["data"])
|
||||
inputs_embeds = merge_vision_embeddings(input_ids, inputs_embeds,
|
||||
vit_embeds,
|
||||
self.img_context_token_id)
|
||||
input_ids = None
|
||||
else:
|
||||
inputs_embeds = None
|
||||
|
||||
hidden_states = self.language_model.model(input_ids,
|
||||
positions,
|
||||
kv_caches,
|
||||
attn_metadata,
|
||||
None,
|
||||
inputs_embeds=inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(self, hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata) -> torch.Tensor:
|
||||
return self.language_model.compute_logits(hidden_states,
|
||||
sampling_metadata)
|
||||
|
||||
def sample(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
return self.language_model.sample(logits, sampling_metadata)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
(".qkv_proj", ".k_proj", "k"),
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
(".gate_up_proj", ".gate_proj", 0),
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
(".gate_up_proj", ".w1", 0),
|
||||
(".gate_up_proj", ".w3", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if self.config.text_config.tie_word_embeddings \
|
||||
and "lm_head.weight" in name:
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
# We only do sharding for language model
|
||||
# and not vision model for now.
|
||||
if "vision_embed_tokens" in name and self.vision_embed_tokens:
|
||||
continue
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
if "wqkv" in name:
|
||||
config = self.config.text_config
|
||||
kv_groups = (config.num_attention_heads //
|
||||
config.num_key_value_heads)
|
||||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
loaded_weight = loaded_weight.view(-1, 2 + kv_groups,
|
||||
head_dim,
|
||||
loaded_weight.shape[-1])
|
||||
wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1],
|
||||
dim=1)
|
||||
wq = wq.reshape(-1, wq.shape[-1])
|
||||
wk = wk.reshape(-1, wk.shape[-1])
|
||||
wv = wv.reshape(-1, wv.shape[-1])
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, wq, 'q')
|
||||
weight_loader(param, wk, 'k')
|
||||
weight_loader(param, wv, 'v')
|
||||
continue
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
|
@ -243,14 +243,22 @@ class Qwen2Model(nn.Module):
|
|||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.embed_tokens(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[torch.Tensor],
|
||||
attn_metadata: AttentionMetadata,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
if inputs_embeds is not None:
|
||||
hidden_states = inputs_embeds
|
||||
else:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
layer = self.layers[i]
|
||||
|
|
|
@ -6,9 +6,10 @@ from transformers import GenerationConfig, PretrainedConfig
|
|||
from vllm.envs import VLLM_USE_MODELSCOPE
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
|
||||
JAISConfig, MedusaConfig,
|
||||
MLPSpeculatorConfig, MPTConfig,
|
||||
NemotronConfig, RWConfig)
|
||||
InternVLChatConfig, JAISConfig,
|
||||
MedusaConfig, MLPSpeculatorConfig,
|
||||
MPTConfig, NemotronConfig,
|
||||
RWConfig)
|
||||
|
||||
if VLLM_USE_MODELSCOPE:
|
||||
from modelscope import AutoConfig
|
||||
|
@ -26,6 +27,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
|||
"jais": JAISConfig,
|
||||
"mlp_speculator": MLPSpeculatorConfig,
|
||||
"medusa": MedusaConfig,
|
||||
"internvl_chat": InternVLChatConfig,
|
||||
"nemotron": NemotronConfig,
|
||||
}
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
|||
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
|
||||
# `FalconConfig` class from the official HuggingFace transformers library.
|
||||
from vllm.transformers_utils.configs.falcon import RWConfig
|
||||
from vllm.transformers_utils.configs.internvl import InternVLChatConfig
|
||||
from vllm.transformers_utils.configs.jais import JAISConfig
|
||||
from vllm.transformers_utils.configs.medusa import MedusaConfig
|
||||
from vllm.transformers_utils.configs.mlp_speculator import MLPSpeculatorConfig
|
||||
|
@ -15,6 +16,7 @@ __all__ = [
|
|||
"DbrxConfig",
|
||||
"MPTConfig",
|
||||
"RWConfig",
|
||||
"InternVLChatConfig",
|
||||
"JAISConfig",
|
||||
"MedusaConfig",
|
||||
"MLPSpeculatorConfig",
|
||||
|
|
|
@ -0,0 +1,51 @@
|
|||
# Adapted from
|
||||
# https://huggingface.co/OpenGVLab/InternVL2-1B/blob/main/configuration_internvl_chat.py
|
||||
# --------------------------------------------------------
|
||||
# InternVL
|
||||
# Copyright (c) 2024 OpenGVLab
|
||||
# Licensed under The MIT License [see LICENSE for details]
|
||||
# --------------------------------------------------------
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class InternVLChatConfig(PretrainedConfig):
|
||||
model_type = 'internvl_chat'
|
||||
is_composition = True
|
||||
|
||||
def __init__(self,
|
||||
vision_config=None,
|
||||
llm_config=None,
|
||||
use_backbone_lora=0,
|
||||
use_llm_lora=0,
|
||||
select_layer=-1,
|
||||
force_image_size=None,
|
||||
downsample_ratio=0.5,
|
||||
template=None,
|
||||
dynamic_image_size=False,
|
||||
use_thumbnail=False,
|
||||
ps_version='v1',
|
||||
min_dynamic_patch=1,
|
||||
max_dynamic_patch=6,
|
||||
**kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if vision_config is None:
|
||||
vision_config = {}
|
||||
|
||||
if llm_config is None:
|
||||
llm_config = {}
|
||||
|
||||
self.vision_config = PretrainedConfig(**vision_config)
|
||||
self.text_config = PretrainedConfig(**llm_config)
|
||||
|
||||
self.use_backbone_lora = use_backbone_lora
|
||||
self.use_llm_lora = use_llm_lora
|
||||
self.select_layer = select_layer
|
||||
self.force_image_size = force_image_size
|
||||
self.downsample_ratio = downsample_ratio
|
||||
self.template = template
|
||||
self.dynamic_image_size = dynamic_image_size
|
||||
self.use_thumbnail = use_thumbnail
|
||||
self.ps_version = ps_version # pixel shuffle version
|
||||
self.min_dynamic_patch = min_dynamic_patch
|
||||
self.max_dynamic_patch = max_dynamic_patch
|
Loading…
Reference in New Issue