Don't download both safetensor and bin files. (#2480)

This commit is contained in:
Nikola Borisov 2024-01-18 11:05:53 -08:00 committed by GitHub
parent 18473cf498
commit 7e1081139d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 16 additions and 3 deletions

View File

@ -1,12 +1,13 @@
"""Utilities for downloading and initializing model weights."""
import filelock
import glob
import fnmatch
import json
import os
from collections import defaultdict
from typing import Any, Iterator, List, Optional, Tuple
from huggingface_hub import snapshot_download
from huggingface_hub import snapshot_download, HfFileSystem
import numpy as np
from safetensors.torch import load_file, save_file, safe_open
import torch
@ -149,6 +150,20 @@ def prepare_hf_model_weights(
allow_patterns += ["*.pt"]
if not is_local:
# Before we download we look at what is available:
fs = HfFileSystem()
file_list = fs.ls(model_name_or_path, detail=False, revision=revision)
# depending on what is available we download different things
for pattern in allow_patterns:
matching = fnmatch.filter(file_list, pattern)
if len(matching) > 0:
allow_patterns = [pattern]
if pattern == "*.safetensors":
use_safetensors = True
break
logger.info(f"Downloading model weights {allow_patterns}")
# Use file lock to prevent multiple processes from
# downloading the same model weights at the same time.
with get_lock(model_name_or_path, cache_dir):
@ -163,8 +178,6 @@ def prepare_hf_model_weights(
for pattern in allow_patterns:
hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
if len(hf_weights_files) > 0:
if pattern == "*.safetensors":
use_safetensors = True
break
if not use_safetensors:
# Exclude files that are not needed for inference.