diff --git a/burn-dataset/src/source/huggingface/downloader.rs b/burn-dataset/src/source/huggingface/downloader.rs index 6bb6e20a5..4c03cfbd9 100644 --- a/burn-dataset/src/source/huggingface/downloader.rs +++ b/burn-dataset/src/source/huggingface/downloader.rs @@ -1,6 +1,8 @@ use crate::InMemDataset; use dirs::home_dir; +use std::collections::hash_map::DefaultHasher; use std::fs; +use std::hash::Hasher; use std::process::Command; use thiserror::Error; @@ -83,13 +85,20 @@ impl HuggingfaceDatasetLoader { } pub fn load_file(self) -> Result { - let path_file = format!("{}/{}-{}", cache_dir(), self.name, self.split); + let mut hasher = DefaultHasher::new(); + hasher.write(format!("{:?}", self.extractors).as_bytes()); + hasher.write(format!("{:?}", self.config).as_bytes()); + hasher.write(format!("{:?}", self.config_named).as_bytes()); + let hash = hasher.finish(); + + let base_file = format!("{}/{}-{}", cache_dir(), self.name, hash); + let path_file = format!("{}-{}", base_file, self.split); if !std::path::Path::new(&path_file).exists() { download( self.name.clone(), vec![self.split], - self.name, + base_file, self.extractors, self.config, self.config_named, @@ -280,7 +289,6 @@ class ImageFieldExtractor(Extractor): def download( name: str, keys: List[str], - download_dir: str, download_file: str, extractors: List[Extractor], *config, @@ -289,7 +297,7 @@ def download( dataset_all = load_dataset(name, *config, **kwargs) for key in keys: dataset = dataset_all[key] - dataset_file = os.path.join(download_dir, f"{download_file}-{key}") + dataset_file = f"{download_file}-{key}" print(f"Saving dataset: {name} - {key}") with open(dataset_file, "w") as file: @@ -359,8 +367,6 @@ def run(): extractors.append(RawFieldExtractor(field_name)) home = os.path.expanduser("~") - download_dir = str(os.path.join(home, DOWNLOAD_DIR)) - os.makedirs(download_dir, exist_ok=True) kwargs = {} for config_named in args.config_named: @@ -369,7 +375,6 @@ def run(): download( args.name, args.split, - download_dir, args.file, extractors, *args.config,