[Bugfix] Add phi3v resize for dynamic shape and fix torchvision requirement (#5772)

Isotr0py authored 2024-06-24 12:11:53 +08:00, committed by GitHub
parent 5d4d90536f
commit edd5fe5fa2
5 changed files with 72 additions and 5 deletions


@@ -3,4 +3,5 @@
# Dependencies for x86_64 CPUs
torch == 2.3.1+cpu
torchvision == 0.18.1+cpu # required for the image processor of phi3v, this must be updated alongside torch
triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.


@@ -5,5 +5,7 @@
ray >= 2.9
nvidia-ml-py # for pynvml package
torch == 2.3.0
# These must be updated alongside torch
torchvision == 0.18.0 # Required for phi3v processor, also see https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version
xformers == 0.0.26.post1 # Requires PyTorch 2.3.0
vllm-flash-attn == 2.5.9 # Requires PyTorch 2.3.0
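
Both requirements files now pin torchvision to the release that pairs with the pinned torch (0.18.1+cpu alongside torch 2.3.1+cpu, 0.18.0 alongside torch 2.3.0); a mismatched pair can fail at import or run time. A minimal sanity check, as a sketch (the version prefixes below are assumptions derived from the pins above):

# Sketch: verify the installed torch/torchvision pair matches the pins above
# (torch 2.3.x must be paired with torchvision 0.18.x).
import torch
import torchvision

print(torch.__version__, torchvision.__version__)
assert torch.__version__.split("+")[0].startswith("2.3."), torch.__version__
assert torchvision.__version__.split("+")[0].startswith("0.18."), torchvision.__version__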


@@ -14,7 +14,6 @@ peft
requests
ray
sentence-transformers # required for embedding
torchvision # required for the image processor of phi3v
# Benchmarking
aiohttp


@@ -22,6 +22,7 @@ assert len(HF_IMAGE_PROMPTS) == len(IMAGE_FILES)

def iter_phi3v_configs(model_name: str):
    image_hw_to_feature_size = {
        (1008, 1344): 1921,
        (2016, 2688): 1933,
    }

    for (h, w), f in image_hw_to_feature_size.items():
@@ -75,6 +76,9 @@ if is_cpu():
# TODO: Add test for `tensor_parallel_size` [ref: PR #3883]
# Since we use _attn_implementation="eager" for hf_runner, there is a
# numeric difference for longer context and the test can't pass
@pytest.mark.xfail(
    reason="Inconsistent image processor being used due to lack "
    "of support for dynamic image token replacement")
@pytest.mark.parametrize("model_and_config", model_and_vl_config)
@pytest.mark.parametrize("dtype", [target_dtype])
@pytest.mark.parametrize("max_tokens", [128])


@@ -13,14 +13,17 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Iterable, List, Literal, Optional, Tuple, TypedDict
from typing import Dict, Iterable, List, Literal, Optional, Tuple, TypedDict
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from transformers import CLIPVisionConfig, PretrainedConfig
from vllm.attention import AttentionMetadata
from vllm.config import CacheConfig, VisionLanguageConfig
from vllm.config import CacheConfig, ModelConfig, VisionLanguageConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)
@@ -32,9 +35,11 @@ from vllm.model_executor.models.llama import LlamaModel
from vllm.model_executor.models.vlm_base import VisionLanguageModelBase
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import get_dummy_image_data
from vllm.multimodal.image import ImagePixelData, get_dummy_image_data
from vllm.sequence import SamplerOutput

logger = init_logger(__name__)

_KEYS_TO_MODIFY_MAPPING = {
    "model.vision_embed_tokens": "vision_embed_tokens",
}
@@ -268,7 +273,63 @@ class Phi3VImagePixelInputs(TypedDict):
    """Shape: (batch_size, 2)"""

@MULTIMODAL_REGISTRY.register_image_pixel_input()
# FIXME(Isotr0py): Remove these after dynamic num_img_tokens is supported
# copied from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
def calc_padded_size(width, height, padding_unit=336):
    target_height = int(np.ceil(height / padding_unit) * padding_unit)
    top_padding = int((target_height - height) / 2)
    bottom_padding = target_height - height - top_padding
    padded_width = width
    padded_height = height + top_padding + bottom_padding
    return padded_width, padded_height

# copied from https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py
def calc_hd_transform_size(width, height, hd_num=16):
    transposed = False
    if width < height:
        width, height = height, width
        transposed = True

    ratio = width / height
    scale = 1
    while scale * np.ceil(scale / ratio) <= hd_num:
        scale += 1
    scale -= 1

    new_width = int(scale * 336)
    new_height = int(new_width / ratio)

    padded_width, padded_height = calc_padded_size(new_width, new_height)

    if transposed:
        padded_width, padded_height = padded_height, padded_width

    return padded_width, padded_height
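
For reference, a few hand-computed examples of what calc_hd_transform_size above returns with the default hd_num=16; sizes are in (width, height) order, and the inputs are arbitrary illustrative values rather than anything taken from this change:

# Sketch, illustrative only: expected results of calc_hd_transform_size above.
assert calc_hd_transform_size(640, 480) == (1344, 1008)    # 4:3 landscape -> 4x3 crop grid
assert calc_hd_transform_size(480, 640) == (1008, 1344)    # 4:3 portrait, transposed back
assert calc_hd_transform_size(2688, 2016) == (1344, 1008)  # large inputs are scaled down
assert calc_hd_transform_size(1280, 720) == (1680, 1008)   # 16:9 -> 5x3 grid, height padded to 1008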

def _image_processor(
    data: ImagePixelData,
    model_config: ModelConfig,
    vlm_config: VisionLanguageConfig,
) -> Dict[str, torch.Tensor]:
    image = data.image

    if isinstance(image, Image.Image):
        # Temporary patch before dynamic number of image tokens is supported
        _, _, h, w = vlm_config.image_input_shape
        if (w, h) != calc_hd_transform_size(image.width, image.height):
            logger.warning(
                "Dynamic image shape is currently not supported. "
                "Resizing input image to (%d, %d).", w, h)
            data.image = image.resize((w, h))

    return MULTIMODAL_REGISTRY._get_plugin_for_data_type(ImagePixelData) \
        ._default_input_processor(data, model_config, vlm_config)
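
In effect, the patch only resizes when the HD transform of the incoming image would not already land on the statically configured input shape; otherwise the data is handed to the registry's default pixel-input processor untouched, and the decorator below now registers this wrapper instead of the default. A small sketch of that decision, assuming a static image_input_shape of (1, 3, 1008, 1344) for phi3v (that shape value is an assumption for illustration, not part of this diff):

# Sketch: which inputs the patch above would resize, assuming
# image_input_shape == (1, 3, 1008, 1344). Uses calc_hd_transform_size from above.
from PIL import Image

_, _, h, w = (1, 3, 1008, 1344)
for size in [(640, 480), (1280, 720)]:
    image = Image.new("RGB", size)
    needs_resize = (w, h) != calc_hd_transform_size(image.width, image.height)
    print(size, "->", (w, h) if needs_resize else size)
# (640, 480)  -> kept: its HD transform is already (1344, 1008)
# (1280, 720) -> resized to (1344, 1008): its HD transform would be (1680, 1008)

Once dynamic image token counts are supported, the FIXME above indicates this resize and the copied helpers can be dropped and the default processor used directly.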
@MULTIMODAL_REGISTRY.register_image_pixel_input(_image_processor)
@MULTIMODAL_REGISTRY.register_dummy_data(get_dummy_image_data)
class Phi3VForCausalLM(VisionLanguageModelBase):