Fix dataset clash (#255)

This commit is contained in:
Nathaniel Simard 2023-03-25 11:50:06 -04:00 committed by GitHub
parent 8d03fc2e90
commit ed24db6d3e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 12 additions and 7 deletions

View File

@@ -1,6 +1,8 @@
use crate::InMemDataset;
use dirs::home_dir;
use std::collections::hash_map::DefaultHasher;
use std::fs;
use std::hash::Hasher;
use std::process::Command;
use thiserror::Error;
@@ -83,13 +85,20 @@ impl HuggingfaceDatasetLoader {
}
pub fn load_file(self) -> Result<String, DownloaderError> {
let path_file = format!("{}/{}-{}", cache_dir(), self.name, self.split);
let mut hasher = DefaultHasher::new();
hasher.write(format!("{:?}", self.extractors).as_bytes());
hasher.write(format!("{:?}", self.config).as_bytes());
hasher.write(format!("{:?}", self.config_named).as_bytes());
let hash = hasher.finish();
let base_file = format!("{}/{}-{}", cache_dir(), self.name, hash);
let path_file = format!("{}-{}", base_file, self.split);
if !std::path::Path::new(&path_file).exists() {
download(
self.name.clone(),
vec![self.split],
self.name,
base_file,
self.extractors,
self.config,
self.config_named,
@@ -280,7 +289,6 @@ class ImageFieldExtractor(Extractor):
def download(
name: str,
keys: List[str],
download_dir: str,
download_file: str,
extractors: List[Extractor],
*config,
@@ -289,7 +297,7 @@ def download(
dataset_all = load_dataset(name, *config, **kwargs)
for key in keys:
dataset = dataset_all[key]
dataset_file = os.path.join(download_dir, f"{download_file}-{key}")
dataset_file = f"{download_file}-{key}"
print(f"Saving dataset: {name} - {key}")
with open(dataset_file, "w") as file:
@@ -359,8 +367,6 @@ def run():
extractors.append(RawFieldExtractor(field_name))
home = os.path.expanduser("~")
download_dir = str(os.path.join(home, DOWNLOAD_DIR))
os.makedirs(download_dir, exist_ok=True)
kwargs = {}
for config_named in args.config_named:
@@ -369,7 +375,6 @@ def run():
download(
args.name,
args.split,
download_dir,
args.file,
extractors,
*args.config,