Don't download both safetensor and bin files. (#2480)

Nikola Borisov 2024-01-18 11:05:53 -08:00 committed by GitHub
parent 18473cf498
commit 7e1081139d
1 changed file with 16 additions and 3 deletions


@@ -1,12 +1,13 @@
 """Utilities for downloading and initializing model weights."""
 import filelock
 import glob
+import fnmatch
 import json
 import os
 from collections import defaultdict
 from typing import Any, Iterator, List, Optional, Tuple
-from huggingface_hub import snapshot_download
+from huggingface_hub import snapshot_download, HfFileSystem
 import numpy as np
 from safetensors.torch import load_file, save_file, safe_open
 import torch
@@ -149,6 +150,20 @@ def prepare_hf_model_weights(
         allow_patterns += ["*.pt"]
     if not is_local:
+        # Before we download, we look at what is available:
+        fs = HfFileSystem()
+        file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
+        # Depending on what is available, we download different things.
+        for pattern in allow_patterns:
+            matching = fnmatch.filter(file_list, pattern)
+            if len(matching) > 0:
+                allow_patterns = [pattern]
+                if pattern == "*.safetensors":
+                    use_safetensors = True
+                break
+        logger.info(f"Downloading model weights {allow_patterns}")
         # Use file lock to prevent multiple processes from
         # downloading the same model weights at the same time.
         with get_lock(model_name_or_path, cache_dir):
@@ -163,8 +178,6 @@ def prepare_hf_model_weights(
     for pattern in allow_patterns:
         hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
         if len(hf_weights_files) > 0:
-            if pattern == "*.safetensors":
-                use_safetensors = True
             break
     if not use_safetensors:
         # Exclude files that are not needed for inference.