mirror of https://github.com/vllm-project/vllm
Don't download both safetensor and bin files. (#2480)
This commit is contained in:
parent
18473cf498
commit
7e1081139d
|
@ -1,12 +1,13 @@
|
||||||
"""Utilities for downloading and initializing model weights."""
|
"""Utilities for downloading and initializing model weights."""
|
||||||
import filelock
|
import filelock
|
||||||
import glob
|
import glob
|
||||||
|
import fnmatch
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Any, Iterator, List, Optional, Tuple
|
from typing import Any, Iterator, List, Optional, Tuple
|
||||||
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download, HfFileSystem
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from safetensors.torch import load_file, save_file, safe_open
|
from safetensors.torch import load_file, save_file, safe_open
|
||||||
import torch
|
import torch
|
||||||
|
@ -149,6 +150,20 @@ def prepare_hf_model_weights(
|
||||||
allow_patterns += ["*.pt"]
|
allow_patterns += ["*.pt"]
|
||||||
|
|
||||||
if not is_local:
|
if not is_local:
|
||||||
|
# Before we download we look at that is available:
|
||||||
|
fs = HfFileSystem()
|
||||||
|
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
|
||||||
|
|
||||||
|
# depending on what is available we download different things
|
||||||
|
for pattern in allow_patterns:
|
||||||
|
matching = fnmatch.filter(file_list, pattern)
|
||||||
|
if len(matching) > 0:
|
||||||
|
allow_patterns = [pattern]
|
||||||
|
if pattern == "*.safetensors":
|
||||||
|
use_safetensors = True
|
||||||
|
break
|
||||||
|
|
||||||
|
logger.info(f"Downloading model weights {allow_patterns}")
|
||||||
# Use file lock to prevent multiple processes from
|
# Use file lock to prevent multiple processes from
|
||||||
# downloading the same model weights at the same time.
|
# downloading the same model weights at the same time.
|
||||||
with get_lock(model_name_or_path, cache_dir):
|
with get_lock(model_name_or_path, cache_dir):
|
||||||
|
@ -163,8 +178,6 @@ def prepare_hf_model_weights(
|
||||||
for pattern in allow_patterns:
|
for pattern in allow_patterns:
|
||||||
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
|
||||||
if len(hf_weights_files) > 0:
|
if len(hf_weights_files) > 0:
|
||||||
if pattern == "*.safetensors":
|
|
||||||
use_safetensors = True
|
|
||||||
break
|
break
|
||||||
if not use_safetensors:
|
if not use_safetensors:
|
||||||
# Exclude files that are not needed for inference.
|
# Exclude files that are not needed for inference.
|
||||||
|
|
Loading…
Reference in New Issue