mirror of https://github.com/tracel-ai/burn.git
Fix dataset clash (#255)
This commit is contained in:
parent
8d03fc2e90
commit
ed24db6d3e
|
@ -1,6 +1,8 @@
|
|||
use crate::InMemDataset;
|
||||
use dirs::home_dir;
|
||||
use std::collections::hash_map::DefaultHasher;
|
||||
use std::fs;
|
||||
use std::hash::Hasher;
|
||||
use std::process::Command;
|
||||
use thiserror::Error;
|
||||
|
||||
|
@ -83,13 +85,20 @@ impl HuggingfaceDatasetLoader {
|
|||
}
|
||||
|
||||
pub fn load_file(self) -> Result<String, DownloaderError> {
|
||||
let path_file = format!("{}/{}-{}", cache_dir(), self.name, self.split);
|
||||
let mut hasher = DefaultHasher::new();
|
||||
hasher.write(format!("{:?}", self.extractors).as_bytes());
|
||||
hasher.write(format!("{:?}", self.config).as_bytes());
|
||||
hasher.write(format!("{:?}", self.config_named).as_bytes());
|
||||
let hash = hasher.finish();
|
||||
|
||||
let base_file = format!("{}/{}-{}", cache_dir(), self.name, hash);
|
||||
let path_file = format!("{}-{}", base_file, self.split);
|
||||
|
||||
if !std::path::Path::new(&path_file).exists() {
|
||||
download(
|
||||
self.name.clone(),
|
||||
vec![self.split],
|
||||
self.name,
|
||||
base_file,
|
||||
self.extractors,
|
||||
self.config,
|
||||
self.config_named,
|
||||
|
@ -280,7 +289,6 @@ class ImageFieldExtractor(Extractor):
|
|||
def download(
|
||||
name: str,
|
||||
keys: List[str],
|
||||
download_dir: str,
|
||||
download_file: str,
|
||||
extractors: List[Extractor],
|
||||
*config,
|
||||
|
@ -289,7 +297,7 @@ def download(
|
|||
dataset_all = load_dataset(name, *config, **kwargs)
|
||||
for key in keys:
|
||||
dataset = dataset_all[key]
|
||||
dataset_file = os.path.join(download_dir, f"{download_file}-{key}")
|
||||
dataset_file = f"{download_file}-{key}"
|
||||
print(f"Saving dataset: {name} - {key}")
|
||||
|
||||
with open(dataset_file, "w") as file:
|
||||
|
@ -359,8 +367,6 @@ def run():
|
|||
extractors.append(RawFieldExtractor(field_name))
|
||||
|
||||
home = os.path.expanduser("~")
|
||||
download_dir = str(os.path.join(home, DOWNLOAD_DIR))
|
||||
os.makedirs(download_dir, exist_ok=True)
|
||||
|
||||
kwargs = {}
|
||||
for config_named in args.config_named:
|
||||
|
@ -369,7 +375,6 @@ def run():
|
|||
download(
|
||||
args.name,
|
||||
args.split,
|
||||
download_dir,
|
||||
args.file,
|
||||
extractors,
|
||||
*args.config,
|
||||
|
|
Loading…
Reference in New Issue