mirror of https://github.com/vllm-project/vllm
Don't download both safetensor and bin files. (#2480)
This commit is contained in:
parent
18473cf498
commit
7e1081139d
|
@ -1,12 +1,13 @@
|
|||
"""Utilities for downloading and initializing model weights."""
|
||||
import filelock
|
||||
import glob
|
||||
import fnmatch
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from typing import Any, Iterator, List, Optional, Tuple
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
from huggingface_hub import snapshot_download, HfFileSystem
|
||||
import numpy as np
|
||||
from safetensors.torch import load_file, save_file, safe_open
|
||||
import torch
|
||||
|
@ -149,6 +150,20 @@ def prepare_hf_model_weights(
|
|||
allow_patterns += ["*.pt"]
|
||||
|
||||
if not is_local:
|
||||
# Before we download we look at that is available:
|
||||
fs = HfFileSystem()
|
||||
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
|
||||
|
||||
# depending on what is available we download different things
|
||||
for pattern in allow_patterns:
|
||||
matching = fnmatch.filter(file_list, pattern)
|
||||
if len(matching) > 0:
|
||||
allow_patterns = [pattern]
|
||||
if pattern == "*.safetensors":
|
||||
use_safetensors = True
|
||||
break
|
||||
|
||||
logger.info(f"Downloading model weights {allow_patterns}")
|
||||
# Use file lock to prevent multiple processes from
|
||||
# downloading the same model weights at the same time.
|
||||
with get_lock(model_name_or_path, cache_dir):
|
||||
|
@ -163,8 +178,6 @@ def prepare_hf_model_weights(
|
|||
for pattern in allow_patterns:
|
||||
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
||||
if len(hf_weights_files) > 0:
|
||||
if pattern == "*.safetensors":
|
||||
use_safetensors = True
|
||||
break
|
||||
if not use_safetensors:
|
||||
# Exclude files that are not needed for inference.
|
||||
|
|
Loading…
Reference in New Issue