Dataset Improvements: Add Sqlite storage backend and HF importer improvements (#353)

Dilshod Tadjibaev 2023-05-20 13:24:55 -05:00 committed by GitHub
parent 18cb19cd03
commit 6fece7e4cb
17 changed files with 753 additions and 367 deletions

View File

@ -23,6 +23,7 @@ members = [
[workspace.dependencies]
bytemuck = "1.13"
const-random = "0.1.15"
csv = "1.2.1"
dashmap = "5.4.0"
dirs = "5.0.0"
fake = "2.5.0"
@ -36,14 +37,18 @@ pretty_assertions = "1.3"
proc-macro2 = "1.0.56"
protobuf-codegen = "3.2"
quote = "1.0.26"
r2d2 = "0.8.10"
r2d2_sqlite = "0.21.0"
rayon = "1.7.0"
rstest = "0.17.0"
sanitize-filename = "0.4.0"
serde_rusqlite = "0.31.0"
spin = {version = "0.9.8", features = ["mutex", "spin_mutex"]}
strum = "0.24"
strum_macros = "0.24"
syn = "2.0"
thiserror = "1.0.40"
topological-sort = "0.2.2"
#
# The following packages disable the "std" feature for no_std compatibility
#

View File

@ -18,12 +18,20 @@ default = ["fake"]
fake = ["dep:fake"]
[dependencies]
derive-new = {workspace = true}
dirs = {workspace = true}
fake = {workspace = true, optional = true}
image = {version = "0.24.6", features = ["png"]}
r2d2 = {workspace = true}
r2d2_sqlite = {workspace = true}
rand = {workspace = true, features = ["std"]}
sanitize-filename = {workspace = true}
serde = {workspace = true, features = ["std", "derive"]}
serde_json = {workspace = true, features = ["std"]}
serde_rusqlite = {workspace = true}
thiserror = {workspace = true}
derive-new = {workspace = true}
csv = {workspace = true}
[dev-dependencies]
rayon = {workspace = true}
rstest = {workspace = true}

View File

@ -1,8 +1,11 @@
use std::{
fs::File,
io::{BufRead, BufReader},
path::Path,
};
use serde::de::DeserializeOwned;
use crate::Dataset;
/// Dataset where all items are stored in ram.
@ -30,10 +33,19 @@ where
impl<I> InMemDataset<I>
where
I: Clone + serde::de::DeserializeOwned,
I: Clone + DeserializeOwned,
{
pub fn from_file(file: &str) -> Result<Self, std::io::Error> {
let file = File::open(file)?;
/// Create from a dataset. All items are loaded in memory.
pub fn from_dataset(dataset: &impl Dataset<I>) -> Self {
let items: Vec<I> = dataset.iter().collect();
Self::new(items)
}
/// Create from a json rows file (one JSON object per line).
///
/// Supported field types: https://docs.rs/serde_json/latest/serde_json/value/enum.Value.html
pub fn from_json_rows<P: AsRef<Path>>(path: P) -> Result<Self, std::io::Error> {
let file = File::open(path)?;
let reader = BufReader::new(file);
let mut items = Vec::new();
@ -46,12 +58,105 @@ where
Ok(dataset)
}
/// Create from a csv file.
///
/// The first line of the csv file must be the header. The header must contain the names of the fields in the struct.
///
/// The supported field types are: String, integer, float, and bool.
///
/// See: https://docs.rs/csv/latest/csv/tutorial/index.html#reading-with-serde
pub fn from_csv<P: AsRef<Path>>(path: P) -> Result<Self, std::io::Error> {
let file = File::open(path)?;
let reader = BufReader::new(file);
let mut rdr = csv::Reader::from_reader(reader);
let mut items = Vec::new();
for result in rdr.deserialize() {
let item: I = result?;
items.push(item);
}
let dataset = Self::new(items);
Ok(dataset)
}
}
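For reference, a minimal usage sketch of the two new constructors. The `Row` struct below is hypothetical; its fields are assumed to match the CSV header and the JSON keys, and the paths point at the test fixtures added in this commit.

```rust
use burn_dataset::{Dataset, InMemDataset};
use serde::Deserialize;

// Hypothetical row type: field names must match the CSV header / JSON keys.
#[derive(Debug, Clone, Deserialize)]
struct Row {
    column_str: String,
    column_int: i64,
    column_bool: bool,
    column_float: f64,
}

fn main() -> Result<(), std::io::Error> {
    // One JSON object per line, e.g. {"column_str":"HI1","column_int":1,...}
    let json_ds = InMemDataset::<Row>::from_json_rows("tests/data/dataset.json")?;
    // The first CSV line is the header; remaining lines are deserialized into `Row`.
    let csv_ds = InMemDataset::<Row>::from_csv("tests/data/dataset.csv")?;

    println!("{} json rows, {} csv rows", json_ds.len(), csv_ds.len());
    Ok(())
}
```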
#[cfg(test)]
mod tests {
use super::*;
use crate::test_data;
use crate::{test_data, SqliteDataset};
use rstest::{fixture, rstest};
use serde::{Deserialize, Serialize};
const DB_FILE: &str = "tests/data/sqlite-dataset.db";
const JSON_FILE: &str = "tests/data/dataset.json";
const CSV_FILE: &str = "tests/data/dataset.csv";
type SqlDs = SqliteDataset<Sample>;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Sample {
column_str: String,
column_bytes: Vec<u8>,
column_int: i64,
column_bool: bool,
column_float: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct SampleCsv {
column_str: String,
column_int: i64,
column_bool: bool,
column_float: f64,
}
#[fixture]
fn train_dataset() -> SqlDs {
SqliteDataset::from_db_file(DB_FILE, "train")
}
#[rstest]
pub fn from_dataset(train_dataset: SqlDs) {
let dataset = InMemDataset::from_dataset(&train_dataset);
let non_existing_record_index: usize = 10;
let record_index: usize = 0;
assert_eq!(train_dataset.get(non_existing_record_index), None);
assert_eq!(dataset.get(record_index).unwrap().column_str, "HI1");
}
#[test]
pub fn from_json_rows() {
let dataset = InMemDataset::<Sample>::from_json_rows(JSON_FILE).unwrap();
let non_existing_record_index: usize = 10;
let record_index: usize = 1;
assert_eq!(dataset.get(non_existing_record_index), None);
assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2");
assert!(!dataset.get(record_index).unwrap().column_bool);
}
#[test]
pub fn from_csv_rows() {
let dataset = InMemDataset::<SampleCsv>::from_csv(CSV_FILE).unwrap();
let non_existing_record_index: usize = 10;
let record_index: usize = 1;
assert_eq!(dataset.get(non_existing_record_index), None);
assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2");
assert_eq!(dataset.get(record_index).unwrap().column_int, 1);
assert!(!dataset.get(record_index).unwrap().column_bool);
assert_eq!(dataset.get(record_index).unwrap().column_float, 1.0);
}
#[test]
pub fn given_in_memory_dataset_when_iterate_should_iterate_though_all_items() {

View File

@ -3,9 +3,11 @@ mod base;
mod fake;
mod in_memory;
mod iterator;
mod sqlite;
#[cfg(feature = "fake")]
pub use self::fake::*;
pub use base::*;
pub use in_memory::*;
pub use iterator::*;
pub use sqlite::*;

View File

@ -0,0 +1,221 @@
use std::marker::PhantomData;
use crate::Dataset;
use r2d2::Pool;
use r2d2_sqlite::{rusqlite::OpenFlags, SqliteConnectionManager};
use serde::de::DeserializeOwned;
use serde_rusqlite::*;
/// Dataset where all items are stored in a sqlite database.
///
/// Note: The database must have a table with the same name as the split.
/// The table must have a primary key column named `row_id` which is used to index the rows.
/// `row_id` starts with 1 (one) and `index` starts with 0 (zero) (`row_id` = `index` + 1).
/// The column names must match the field names of the item struct `I`. However, the field names
/// can be a subset of column names and can be in any order.
///
/// Supported serialization field types: https://docs.rs/serde_rusqlite/latest/serde_rusqlite and
/// Sqlite3 types: https://www.sqlite.org/datatype3.html
///
/// Item struct example:
///
/// ```rust
///
/// use serde::{Deserialize, Serialize};
///
/// #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
/// pub struct Sample {
/// column_str: String, // column name: column_str with type TEXT
/// column_bytes: Vec<u8>, // column name: column_bytes with type BLOB
/// column_int: i64, // column name: column_int with type INTEGER
/// column_bool: bool, // column name: column_bool with type INTEGER
/// column_float: f64, // column name: column_float with type REAL
/// }
/// ```
///
/// Sqlite table example:
///
/// ```sql
///
/// CREATE TABLE train (
/// column_str TEXT,
/// column_bytes BLOB,
/// column_int INTEGER,
/// column_bool BOOLEAN,
/// column_float FLOAT,
/// row_id INTEGER NOT NULL,
/// PRIMARY KEY (row_id)
/// );
///
/// ```
#[derive(Debug)]
pub struct SqliteDataset<I> {
db_file: String,
split: String,
conn_pool: Pool<SqliteConnectionManager>,
columns: Vec<String>,
len: usize,
select_statement: String,
phantom: PhantomData<I>,
}
impl<I> SqliteDataset<I> {
pub fn from_db_file(db_file: &str, split: &str) -> Self {
// Create a connection pool
let conn_pool = create_conn_pool(db_file);
// Create a select statement and save it
let select_statement = format!("select * from {split} where row_id = ?");
// Save the column names and the number of rows
let (columns, len) = fetch_columns_and_len(&conn_pool, &select_statement, split);
SqliteDataset {
db_file: db_file.to_string(),
split: split.to_string(),
conn_pool,
columns,
len,
select_statement,
phantom: PhantomData::default(),
}
}
/// Get the database file name.
pub fn db_file(&self) -> &str {
self.db_file.as_str()
}
/// Get the split name.
pub fn split(&self) -> &str {
self.split.as_str()
}
}
impl<I> Dataset<I> for SqliteDataset<I>
where
I: Clone + Send + Sync + DeserializeOwned,
{
/// Get an item from the dataset.
fn get(&self, index: usize) -> Option<I> {
// Row ids start with 1 (one) and index starts with 0 (zero)
let row_id = index + 1;
// Get a connection from the pool
let connection = self.conn_pool.get().unwrap();
let mut statement = connection.prepare(self.select_statement.as_str()).unwrap();
// Query the row with the given row_id and deserialize it into I using column names (fast option)
let mut rows = statement
.query_and_then([row_id], |row| {
from_row_with_columns::<I>(row, &self.columns)
})
.unwrap();
// Return the first row if found, mapping deserialization errors to None
rows.next().and_then(|res| match res {
Ok(val) => Some(val),
Err(_) => None,
})
}
/// Return the number of rows in the dataset.
fn len(&self) -> usize {
self.len
}
}
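A brief usage sketch of the new backend. The `Sample` struct and database path mirror the test fixture further down; treat the exact columns as illustrative.

```rust
use burn_dataset::{Dataset, SqliteDataset};
use serde::Deserialize;

// The fields may be any subset of the table's columns, in any order.
#[derive(Debug, Clone, Deserialize)]
struct Sample {
    column_str: String,
    column_int: i64,
}

fn main() {
    // "train" is both the split name and the table name inside the database.
    let ds: SqliteDataset<Sample> =
        SqliteDataset::from_db_file("tests/data/sqlite-dataset.db", "train");

    // Index 0 maps to row_id 1.
    if let Some(first) = ds.get(0) {
        println!("{} rows, first item: {:?}", ds.len(), first);
    }
}
```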
/// Fetch the column names and the number of rows from the database.
fn fetch_columns_and_len(
conn_pool: &Pool<SqliteConnectionManager>,
select_statement: &str,
split: &str,
) -> (Vec<String>, usize) {
// Save the column names
let connection = conn_pool.get().unwrap();
let statement = connection.prepare(select_statement).unwrap();
let columns = columns_from_statement(&statement);
// Count the number of rows and save it as len
let mut statement = connection
.prepare(format!("select count(*) from {split}").as_str())
.unwrap();
let len = statement
.query_row([], |row| {
let len: usize = row.get(0)?;
Ok(len)
})
.unwrap();
(columns, len)
}
/// Create a connection pool and make sure the connections are read only
fn create_conn_pool(db_file: &str) -> Pool<SqliteConnectionManager> {
// Create a connection pool and make sure the connections are read only
let manager =
SqliteConnectionManager::file(db_file).with_flags(OpenFlags::SQLITE_OPEN_READ_ONLY);
let conn_pool: Pool<SqliteConnectionManager> = Pool::new(manager).unwrap();
conn_pool
}
#[cfg(test)]
mod tests {
use rayon::prelude::*;
use rstest::{fixture, rstest};
use serde::{Deserialize, Serialize};
use super::*;
type SqlDs = SqliteDataset<Sample>;
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Sample {
column_str: String,
column_bytes: Vec<u8>,
column_int: i64,
column_bool: bool,
column_float: f64,
}
#[fixture]
fn train_dataset() -> SqlDs {
SqliteDataset::from_db_file("tests/data/sqlite-dataset.db", "train")
}
#[rstest]
pub fn len(train_dataset: SqlDs) {
assert_eq!(train_dataset.len(), 2);
}
#[rstest]
pub fn get_some(train_dataset: SqlDs) {
let item = train_dataset.get(0).unwrap();
assert_eq!(item.column_str, "HI1");
assert_eq!(item.column_bytes, vec![55, 231, 159]);
assert_eq!(item.column_int, 1);
assert!(item.column_bool);
assert_eq!(item.column_float, 1.0);
}
#[rstest]
pub fn get_none(train_dataset: SqlDs) {
assert_eq!(train_dataset.get(10), None);
}
#[rstest]
pub fn multi_thread(train_dataset: SqlDs) {
let indices: Vec<usize> = vec![0, 1, 1, 3, 4, 5, 6, 0, 8, 1];
let results: Vec<Option<Sample>> =
indices.par_iter().map(|&i| train_dataset.get(i)).collect();
let mut match_count = 0;
for (_index, result) in indices.iter().zip(results.iter()) {
match result {
Some(_val) => match_count += 1,
None => (),
}
}
assert_eq!(match_count, 5);
}
}

View File

@ -8,6 +8,7 @@ pub mod transform;
mod dataset;
pub use dataset::*;
pub use source::huggingface::downloader::*;
#[cfg(test)]
mod test_data {

View File

@ -1,152 +0,0 @@
import abc
import json
import argparse
import numpy as np
from datasets import load_dataset
from typing import List, Any, Tuple
from tqdm import tqdm
from json import JSONEncoder
class CustomEncoder(JSONEncoder):
def default(self, obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
return JSONEncoder.default(self, obj)
DOWNLOAD_DIR = ".cache/burn-dataset"
class Extractor(abc.ABC):
def extract(self, item: Any) -> Any:
pass
@abc.abstractproperty
def name(self) -> str:
pass
class RawFieldExtractor(Extractor):
def __init__(self, field_name: str):
self.field_name = field_name
def extract(self, item: Any) -> Any:
return item[self.field_name]
@property
def name(self) -> str:
return self.field_name
class ImageFieldExtractor(Extractor):
def __init__(self, field_name: str):
self.field_name = field_name
def extract(self, item: Any) -> Any:
image = item[self.field_name]
return np.array(image).tolist()
@property
def name(self) -> str:
return self.field_name
def download(
name: str,
keys: List[str],
download_file: str,
extractors: List[Extractor],
*config,
**kwargs,
):
dataset_all = load_dataset(name, *config, **kwargs)
for key in keys:
dataset = dataset_all[key]
dataset_file = f"{download_file}-{key}"
print(f"Saving dataset: {name} - {key}")
with open(dataset_file, "w") as file:
for item in tqdm(dataset):
payload = {}
for extactor in extractors:
payload[extactor.name] = extactor.extract(item)
payload = json.dumps(payload, cls=CustomEncoder)
line = f"{payload}\n"
file.write(line)
def config_named(value: str) -> Tuple[str, str]:
try:
key, value = value.split("=")
return {key: value}
except:
raise argparse.ArgumentTypeError("config_named must be key=value")
def parse_args():
parser = argparse.ArgumentParser(
description="Huggingface datasets downloader to use with burn-dataset"
)
parser.add_argument(
"--name", type=str, help="Name of the dataset to download", required=True
)
parser.add_argument(
"--file", type=str, help="Base file name where the data is saved", required=True
)
parser.add_argument(
"--split", type=str, help="Splits to downloads", nargs="+", required=True
)
parser.add_argument(
"--config", type=str, help="Config of the dataset", nargs="+", default=[]
)
parser.add_argument(
"--config-named",
type=config_named,
help="Named config of the dataset",
nargs="+",
default=[],
)
parser.add_argument(
"--extract-image",
type=str,
help="Image field to extract",
nargs="+",
default=[],
)
parser.add_argument(
"--extract-raw", type=str, help="Raw field to extract", nargs="+", default=[]
)
return parser.parse_args()
def run():
args = parse_args()
extractors = []
for field_name in args.extract_image:
extractors.append(ImageFieldExtractor(field_name))
for field_name in args.extract_raw:
extractors.append(RawFieldExtractor(field_name))
kwargs = {}
for config_named in args.config_named:
kwargs = kwargs | config_named
download(
args.name,
args.split,
args.file,
extractors,
*args.config,
**kwargs,
)
if __name__ == "__main__":
run()

View File

@ -1,240 +1,227 @@
use crate::InMemDataset;
use dirs::home_dir;
use std::collections::hash_map::DefaultHasher;
use std::fs;
use std::hash::Hasher;
use std::fs::{self, create_dir_all};
use std::path::Path;
use std::process::Command;
use crate::SqliteDataset;
use sanitize_filename::sanitize;
use serde::de::DeserializeOwned;
use thiserror::Error;
const PYTHON: &str = "python3";
const PYTHON_SOURCE: &str = include_str!("dataset.py");
const PYTHON_SOURCE: &str = include_str!("importer.py");
#[derive(Error, Debug)]
pub enum DownloaderError {
pub enum ImporterError {
#[error("unknown: `{0}`")]
Unknown(String),
#[error("fail to download python dependencies: `{0}`")]
FailToDownloadPythonDependencies(String),
}
/// Load datasets from [huggingface datasets](https://huggingface.co/datasets).
/// Load a dataset from [huggingface datasets](https://huggingface.co/datasets).
///
/// The dataset with all splits is stored in a single sqlite database (see [SqliteDataset](SqliteDataset)).
///
/// # Example
/// ```no_run
/// use burn_dataset::HuggingfaceDatasetLoader;
/// use burn_dataset::SqliteDataset;
/// use serde::{Deserialize, Serialize};
///
/// #[derive(Deserialize, Debug, Clone)]
/// struct MNISTItemRaw {
/// pub image_bytes: Vec<u8>,
/// pub label: usize,
/// }
///
/// let train_ds: SqliteDataset<MNISTItemRaw> = HuggingfaceDatasetLoader::new("mnist")
/// .dataset("train")
/// .unwrap();
/// ```
pub struct HuggingfaceDatasetLoader {
name: String,
split: String,
extractors: Vec<Extractor>,
config: Vec<String>,
config_named: Vec<(String, String)>,
deps: Vec<String>,
subset: Option<String>,
base_dir: Option<String>,
huggingface_token: Option<String>,
huggingface_cache_dir: Option<String>,
}
impl HuggingfaceDatasetLoader {
/// Create a huggingface dataset loader.
pub fn new(name: &str, split: &str) -> Self {
pub fn new(name: &str) -> Self {
Self {
name: name.to_string(),
split: split.to_string(),
extractors: Vec::new(),
config: Vec::new(),
config_named: Vec::new(),
deps: Vec::new(),
subset: None,
base_dir: None,
huggingface_token: None,
huggingface_cache_dir: None,
}
}
pub fn config(mut self, config: &str) -> Self {
self.config.push(config.to_string());
/// Create a huggingface dataset loader for a subset of the dataset.
///
/// The subset name must be one of the subsets listed on the dataset page.
///
/// If no subset names are listed, then do not use this method.
pub fn with_subset(mut self, subset: &str) -> Self {
self.subset = Some(subset.to_string());
self
}
pub fn config_named(mut self, name: &str, config: &str) -> Self {
self.config_named
.push((name.to_string(), config.to_string()));
/// Specify a base directory to store the dataset.
///
/// If not specified, the dataset will be stored in `~/.cache/burn-dataset`.
pub fn with_base_dir(mut self, base_dir: &str) -> Self {
self.base_dir = Some(base_dir.to_string());
self
}
pub fn deps(mut self, deps: &[&str]) -> Self {
self.deps
.append(&mut deps.iter().copied().map(String::from).collect());
/// Specify a huggingface token to download datasets that require authentication.
///
/// You can get a token from https://huggingface.co/settings/tokens
pub fn with_huggingface_token(mut self, huggingface_token: &str) -> Self {
self.huggingface_token = Some(huggingface_token.to_string());
self
}
pub fn dep(mut self, dep: &str) -> Self {
self.deps.push(dep.to_string());
/// Specify a huggingface cache directory to store the downloaded datasets.
///
/// If not specified, the dataset will be stored in `~/.cache/huggingface/datasets`.
pub fn with_huggingface_cache_dir(mut self, huggingface_cache_dir: &str) -> Self {
self.huggingface_cache_dir = Some(huggingface_cache_dir.to_string());
self
}
pub fn extract_image(mut self, field_name: &str) -> Self {
self.extractors
.push(Extractor::Image(field_name.to_string()));
self
}
pub fn extract_number(self, field_name: &str) -> Self {
self.extract_raw(field_name)
}
pub fn extract_string(self, field_name: &str) -> Self {
self.extract_raw(field_name)
}
pub fn load_in_memory<I: serde::de::DeserializeOwned + Clone>(
/// Load the dataset.
pub fn dataset<I: DeserializeOwned + Clone>(
self,
) -> Result<InMemDataset<I>, DownloaderError> {
let path_file = self.load_file()?;
let dataset = InMemDataset::from_file(path_file.as_str()).unwrap();
split: &str,
) -> Result<SqliteDataset<I>, ImporterError> {
let db_file = self.db_file()?;
let dataset = SqliteDataset::from_db_file(db_file.as_str(), split);
Ok(dataset)
}
pub fn load_file(self) -> Result<String, DownloaderError> {
let mut hasher = DefaultHasher::new();
hasher.write(format!("{:?}", self.extractors).as_bytes());
hasher.write(format!("{:?}", self.config).as_bytes());
hasher.write(format!("{:?}", self.config_named).as_bytes());
let hash = hasher.finish();
/// Get the path to the sqlite database file.
///
/// If the database file does not exist, the dataset is downloaded from Huggingface and imported into it.
pub fn db_file(self) -> Result<String, ImporterError> {
// determine (and create if needed) the base directory
let base_dir = base_dir(self.base_dir);
let base_file = format!("{}/{}-{}", cache_dir(), self.name, hash);
let path_file = format!("{}-{}", base_file, self.split);
// Sanitize the name and subset
let name = sanitize(self.name.as_str());
if !std::path::Path::new(&path_file).exists() {
download(
self.name.clone(),
vec![self.split],
base_file,
self.extractors,
self.config,
self.config_named,
&self.deps,
// create the db file path
let db_file = if let Some(subset) = self.subset.clone() {
format!("{}/{}-{}.db", base_dir, name, sanitize(subset.as_str()))
} else {
format!("{}/{}.db", base_dir, name)
};
// import the dataset if needed
if !Path::new(&db_file).exists() {
import(
self.name,
self.subset,
db_file.clone(),
base_dir,
self.huggingface_token,
self.huggingface_cache_dir,
)?;
}
Ok(path_file)
}
fn extract_raw(mut self, field_name: &str) -> Self {
self.extractors.push(Extractor::Raw(field_name.to_string()));
self
Ok(db_file)
}
}
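To illustrate the reworked loader end to end, here is a sketch of the builder chain. The `AgNewsRow` struct and the `/tmp` base directory are hypothetical, and the field names are assumptions based on the ag_news columns; only `new` and `dataset` are required.

```rust
use burn_dataset::{Dataset, HuggingfaceDatasetLoader, SqliteDataset};
use serde::Deserialize;

// Hypothetical item struct; fields must match the columns the importer creates.
#[derive(Debug, Clone, Deserialize)]
struct AgNewsRow {
    text: String,
    label: usize,
}

fn main() {
    // All builder methods besides `new` and `dataset` are optional.
    let train: SqliteDataset<AgNewsRow> = HuggingfaceDatasetLoader::new("ag_news")
        .with_base_dir("/tmp/burn-dataset") // defaults to ~/.cache/burn-dataset
        // .with_subset("...")              // only for datasets that define subsets
        // .with_huggingface_token("hf_...") // for datasets that require authentication
        .dataset("train")
        .unwrap();

    println!("{} training rows", train.len());
}
```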
fn download(
/// Import a dataset from huggingface. The transformed dataset is stored as a sqlite database.
fn import(
name: String,
splits: Vec<String>,
subset: Option<String>,
base_file: String,
extractors: Vec<Extractor>,
config: Vec<String>,
config_named: Vec<(String, String)>,
deps: &[String],
) -> Result<(), DownloaderError> {
download_python_deps(deps)?;
base_dir: String,
huggingface_token: Option<String>,
huggingface_cache_dir: Option<String>,
) -> Result<(), ImporterError> {
install_python_deps()?;
let mut command = Command::new(PYTHON);
command.arg(dataset_downloader_file_path());
command.arg("--file");
command.arg(base_file);
command.arg(importer_script_path(base_dir));
command.arg("--name");
command.arg(name);
command.arg("--split");
for split in splits {
command.arg(split);
command.arg("--file");
command.arg(base_file);
if let Some(subset) = subset {
command.arg("--subset");
command.arg(subset);
}
let mut extracted_raw = Vec::new();
let mut extracted_images = Vec::new();
for extractor in extractors {
match extractor {
Extractor::Raw(field) => extracted_raw.push(field),
Extractor::Image(field) => extracted_images.push(field),
};
if let Some(huggingface_token) = huggingface_token {
command.arg("--token");
command.arg(huggingface_token);
}
if !extracted_raw.is_empty() {
command.arg("--extract-raw");
for field in extracted_raw {
command.arg(field);
}
}
if !extracted_images.is_empty() {
command.arg("--extract-image");
for field in extracted_images {
command.arg(field);
}
}
if !config.is_empty() {
command.arg("--config");
for config in config {
command.arg(config);
}
}
if !config_named.is_empty() {
command.arg("--config-named");
for (key, value) in config_named {
command.arg(format!("{key}={value}"));
}
if let Some(huggingface_cache_dir) = huggingface_cache_dir {
command.arg("--cache_dir");
command.arg(huggingface_cache_dir);
}
let mut handle = command.spawn().unwrap();
handle
.wait()
.map_err(|err| DownloaderError::Unknown(format!("{err:?}")))?;
.map_err(|err| ImporterError::Unknown(format!("{err:?}")))?;
Ok(())
}
fn cache_dir() -> String {
let home_dir = home_dir().unwrap();
let home_dir = home_dir.to_str().map(|s| s.to_string());
let home_dir = home_dir.unwrap();
let cache_dir = format!("{home_dir}/.cache/burn-dataset");
std::fs::create_dir_all(&cache_dir).ok();
cache_dir
/// Determine the base directory to store the dataset.
fn base_dir(base_dir: Option<String>) -> String {
let base_dir = if let Some(base_dir) = base_dir {
base_dir
} else {
let home_dir = home_dir().unwrap();
let home_dir = home_dir.to_str().map(|s| s.to_string());
let home_dir = home_dir.unwrap();
let cache_dir = format!("{home_dir}/.cache/burn-dataset");
cache_dir
};
create_dir_all(&base_dir).ok();
base_dir
}
fn dataset_downloader_file_path() -> String {
let path_dir = cache_dir();
let path_file = format!("{path_dir}/dataset.py");
fn importer_script_path(base_dir: String) -> String {
let path_file = format!("{base_dir}/importer.py");
fs::write(path_file.as_str(), PYTHON_SOURCE).expect("Failed to write the python importer script");
path_file
}
fn download_python_deps(deps: &[String]) -> Result<(), DownloaderError> {
fn install_python_deps() -> Result<(), ImporterError> {
let mut command = Command::new(PYTHON);
command.args([
"-m",
"pip",
"--quiet",
"install",
"pyarrow",
"sqlalchemy",
"Pillow",
"soundfile",
"datasets",
]);
command
.args(["-m", "pip", "install", "datasets"])
.args(deps);
command
.spawn()
.map_err(|err| {
DownloaderError::FailToDownloadPythonDependencies(format!(
"{} | error: {}",
deps.to_vec().join(", "),
err
))
})?
.wait()
.map_err(|err| {
DownloaderError::FailToDownloadPythonDependencies(format!(
"{} | error: {}",
deps.to_vec().join(", "),
err
))
})?;
// Spawn the process and wait for it to complete.
let mut handle = command.spawn().unwrap();
handle.wait().map_err(|err| {
ImporterError::FailToDownloadPythonDependencies(format!(" error: {}", err))
})?;
Ok(())
}
#[derive(Debug)]
enum Extractor {
Raw(String),
Image(String),
}

View File

@ -0,0 +1,176 @@
import argparse
import pyarrow as pa
from datasets import Audio, Image, load_dataset
from sqlalchemy import Column, Integer, Table, create_engine, event, inspect
from sqlalchemy.types import LargeBinary
def download_and_export(name: str, subset: str, db_file: str, token: str, cache_dir: str):
"""
Download a dataset using the HuggingFace datasets library and export it to a sqlite database.
"""
# TODO For media columns (Image and Audio) sometimes when decode=False,
# bytes can be none {'bytes': None, 'path': 'healthy_train.265.jpg'}
# We should handle this case, but we have not yet come across it in practice to test against.
print("*"*80)
print("Starting huggingface dataset download and export")
print(f"Dataset Name: {name}")
print(f"Subset Name: {subset}")
print(f"Sqlite database file: {db_file}")
if cache_dir is not None:
print(f"Custom cache dir: {cache_dir}")
print("*"*80)
# Load the dataset
dataset_all = load_dataset(
name, subset, cache_dir=cache_dir, use_auth_token=token)
print(f"Dataset: {dataset_all}")
# Create the database connection descriptor (sqlite)
engine = create_engine(f"sqlite:///{db_file}")
# Set some sqlite pragmas to speed up the database
event.listen(engine, 'connect', set_sqlite_pragma)
# Add a row_id column to each table as the primary key (the datasets library has no API for this)
event.listen(Table, 'before_create', add_pk_column)
# Export each split in the dataset
for key in dataset_all.keys():
dataset = dataset_all[key]
# Disable decoding for audio and image fields
dataset = disable_decoding(dataset)
# Flatten the dataset
dataset = dataset.flatten()
# Rename columns to remove dots from the names
dataset = rename_columns(dataset)
print(f"Saving dataset: {name} - {key}")
print(f"Dataset features: {dataset.features}")
# Save the dataset to a sqlite database
dataset.to_sql(
key, # table name
engine,
# don't save the index, use row_id instead (index is not unique)
index=False,
dtype=blob_columns(dataset), # save binary columns as blob
)
# Print the schema of the database so we can reference the columns in the rust code
print_table_info(engine)
def disable_decoding(dataset):
"""
Disable decoding for audio and image fields. The fields will be saved as raw file bytes.
"""
for k, v in dataset.features.items():
if isinstance(v, Audio):
dataset = dataset.cast_column(k, Audio(decode=False))
elif isinstance(v, Image):
dataset = dataset.cast_column(k, Image(decode=False))
return dataset
def rename_columns(dataset):
"""
Rename columns to remove dots from the names. Dots appear in the column names because of the flattening.
Dots are not allowed in column names in rust and sql (unless quoted). So we replace them with underscores.
This way there is an easy name mapping between the rust and sql columns.
"""
for name in dataset.features.keys():
if '.' in name:
dataset = dataset.rename_column(name, name.replace('.', '_'))
return dataset
def blob_columns(dataset):
"""
Make sure all binary columns are blob columns in the database because
`to_sql` exports binary values as TEXT instead of BLOB.
"""
type_mapping = {}
for name, value in dataset.features.items():
if pa.types.is_binary(value.pa_type):
type_mapping[name] = LargeBinary
return type_mapping
def set_sqlite_pragma(dbapi_connection, connection_record):
"""
Set some sqlite pragmas to speed up the database
"""
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA synchronous = OFF")
cursor.execute("PRAGMA journal_mode = OFF")
cursor.close()
def add_pk_column(target, connection, **kw):
"""
Add a row_id column to each table as the primary key.
"""
target.append_column(Column("row_id", Integer, primary_key=True))
def print_table_info(engine):
"""
Print the schema of the database so we can reference the columns in the rust code
"""
print(f"Printing table schema for sqlite3 db ({engine})")
inspector = inspect(engine)
for table_name in inspector.get_table_names():
print(f"Table: {table_name}")
for column in inspector.get_columns(table_name):
print(f"Column: {column['name']} - {column['type']}")
print("")
def parse_args():
parser = argparse.ArgumentParser(
description="Huggingface datasets downloader to use with burn-dataset"
)
parser.add_argument(
"--name", type=str, help="Name of the dataset to download", required=True
)
parser.add_argument(
"--file", type=str, help="Base file name where the data is saved", required=True
)
parser.add_argument(
"--subset", type=str, help="Subset name", required=False, default=None
)
parser.add_argument(
"--token", type=str, help="HuggingFace authentication token", required=False, default=None
)
parser.add_argument(
"--cache_dir", type=str, help="Cache directory", required=False, default=None
)
return parser.parse_args()
def run():
args = parse_args()
download_and_export(
args.name,
args.subset,
args.file,
args.token,
args.cache_dir,
)
if __name__ == "__main__":
run()

View File

@ -1,15 +1,58 @@
use crate::source::huggingface::downloader::HuggingfaceDatasetLoader;
use crate::{Dataset, InMemDataset};
use crate::transform::{Mapper, MapperDataset};
use crate::{Dataset, SqliteDataset};
use image;
use serde::{Deserialize, Serialize};
const WIDTH: usize = 28;
const HEIGHT: usize = 28;
#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct MNISTItem {
pub image: [[f32; 28]; 28],
pub image: [[f32; WIDTH]; HEIGHT],
pub label: usize,
}
#[derive(Deserialize, Debug, Clone)]
struct MNISTItemRaw {
pub image_bytes: Vec<u8>,
pub label: usize,
}
struct BytesToImage;
impl Mapper<MNISTItemRaw, MNISTItem> for BytesToImage {
/// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image).
fn map(&self, item: &MNISTItemRaw) -> MNISTItem {
let image = image::load_from_memory(&item.image_bytes).unwrap();
let image = image.as_luma8().unwrap();
// Ensure the image dimensions are correct.
debug_assert_eq!(image.dimensions(), (WIDTH as u32, HEIGHT as u32));
// Convert the image to a 2D array of floats.
let mut image_array = [[0f32; WIDTH]; HEIGHT];
for (i, pixel) in image.as_raw().iter().enumerate() {
let x = i % WIDTH;
let y = i / WIDTH;
image_array[y][x] = *pixel as f32;
}
MNISTItem {
image: image_array,
label: item.label,
}
}
}
type MappedDataset = MapperDataset<SqliteDataset<MNISTItemRaw>, BytesToImage, MNISTItemRaw>;
/// MNIST dataset from Huggingface.
///
/// The data is downloaded from Huggingface and stored in a SQLite database.
pub struct MNISTDataset {
dataset: InMemDataset<MNISTItem>,
dataset: MappedDataset,
}
impl Dataset<MNISTItem> for MNISTDataset {
@ -32,13 +75,12 @@ impl MNISTDataset {
}
fn new(split: &str) -> Self {
let dataset = HuggingfaceDatasetLoader::new("mnist", split)
.extract_image("image")
.extract_number("label")
.deps(&["pillow", "numpy"])
.load_in_memory()
let dataset = HuggingfaceDatasetLoader::new("mnist")
.dataset(split)
.unwrap();
let dataset = MapperDataset::new(dataset, BytesToImage);
Self { dataset }
}
}
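A short sketch of consuming the refactored dataset through the `Dataset` trait, assuming the `train()` constructor kept by this file and the usual `burn_dataset::source::huggingface` re-export; the first call triggers the download and sqlite import.

```rust
use burn_dataset::source::huggingface::MNISTDataset;
use burn_dataset::Dataset;

fn main() {
    // The first call downloads MNIST from Huggingface, imports it into the
    // sqlite cache, and then serves items through the BytesToImage mapper.
    let train = MNISTDataset::train();
    let item = train.get(0).unwrap();
    println!("{} items, first label: {}", train.len(), item.label);
}
```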

View File

@ -1,4 +1,5 @@
pub mod downloader;
mod mnist;
pub use downloader::*;
pub use mnist::*;

View File

@ -0,0 +1,3 @@
column_str,column_int,column_bool,column_float
HI1,1,true,1.0
HI2,1,false,1.0

View File

@ -0,0 +1,2 @@
{"column_str":"HI1","column_bytes":[1,2,3,3],"column_int":1,"column_bool":true,"column_float":1.0}
{"column_str":"HI2","column_bytes":[1,2,3,3],"column_int":1,"column_bool":false,"column_float":1.0}

Binary file not shown.

View File

@ -44,7 +44,7 @@ burn-tensor = {path = "../burn-tensor", version = "0.8.0", default-features = fa
matrixmultiply = {version = "0.3.6", default-features = false}
rayon = {version= "1.7.0", optional = true}
rayon = {workspace = true, optional = true}
blas-src = {version = "0.8.0", default-features = false, optional = true}# no-std compatible

View File

@ -1,5 +1,5 @@
use burn::data::dataset::{
source::huggingface::downloader::HuggingfaceDatasetLoader, Dataset, InMemDataset,
source::huggingface::downloader::HuggingfaceDatasetLoader, Dataset, SqliteDataset,
};
#[derive(new, Clone, Debug)]
@ -20,7 +20,7 @@ pub struct AgNewsItem {
}
pub struct AgNewsDataset {
dataset: InMemDataset<AgNewsItem>,
dataset: SqliteDataset<AgNewsItem>,
}
impl Dataset<TextClassificationItem> for AgNewsDataset {
@ -37,19 +37,16 @@ impl Dataset<TextClassificationItem> for AgNewsDataset {
impl AgNewsDataset {
pub fn train() -> Self {
let dataset: InMemDataset<AgNewsItem> = HuggingfaceDatasetLoader::new("ag_news", "train")
.extract_string("text")
.extract_number("label")
.load_in_memory()
.unwrap();
Self { dataset }
Self::new("train")
}
pub fn test() -> Self {
let dataset: InMemDataset<AgNewsItem> = HuggingfaceDatasetLoader::new("ag_news", "test")
.extract_string("text")
.extract_number("label")
.load_in_memory()
Self::new("test")
}
pub fn new(split: &str) -> Self {
let dataset: SqliteDataset<AgNewsItem> = HuggingfaceDatasetLoader::new("ag_news")
.dataset(split)
.unwrap();
Self { dataset }
}
@ -80,7 +77,7 @@ pub struct DbPediaItem {
}
pub struct DbPediaDataset {
dataset: InMemDataset<DbPediaItem>,
dataset: SqliteDataset<DbPediaItem>,
}
impl Dataset<TextClassificationItem> for DbPediaDataset {
@ -100,24 +97,16 @@ impl Dataset<TextClassificationItem> for DbPediaDataset {
impl DbPediaDataset {
pub fn train() -> Self {
let dataset: InMemDataset<DbPediaItem> =
HuggingfaceDatasetLoader::new("dbpedia_14", "train")
.extract_string("title")
.extract_string("content")
.extract_number("label")
.load_in_memory()
.unwrap();
Self { dataset }
Self::new("train")
}
pub fn test() -> Self {
let dataset: InMemDataset<DbPediaItem> =
HuggingfaceDatasetLoader::new("dbpedia_14", "test")
.extract_string("title")
.extract_string("content")
.extract_number("label")
.load_in_memory()
.unwrap();
Self::new("test")
}
pub fn new(split: &str) -> Self {
let dataset: SqliteDataset<DbPediaItem> = HuggingfaceDatasetLoader::new("dbpedia_14")
.dataset(split)
.unwrap();
Self { dataset }
}
}

View File

@ -1,5 +1,5 @@
use burn::data::dataset::{
source::huggingface::downloader::HuggingfaceDatasetLoader, Dataset, InMemDataset,
source::huggingface::downloader::HuggingfaceDatasetLoader, Dataset, SqliteDataset,
};
#[derive(new, Clone, Debug)]
@ -13,7 +13,7 @@ pub struct DbPediaItem {
}
pub struct DbPediaDataset {
dataset: InMemDataset<DbPediaItem>,
dataset: SqliteDataset<DbPediaItem>,
}
impl Dataset<TextGenerationItem> for DbPediaDataset {
@ -30,20 +30,16 @@ impl Dataset<TextGenerationItem> for DbPediaDataset {
impl DbPediaDataset {
pub fn train() -> Self {
let dataset: InMemDataset<DbPediaItem> =
HuggingfaceDatasetLoader::new("dbpedia_14", "train")
.extract_string("content")
.load_in_memory()
.unwrap();
Self { dataset }
Self::new("train")
}
pub fn test() -> Self {
let dataset: InMemDataset<DbPediaItem> =
HuggingfaceDatasetLoader::new("dbpedia_14", "test")
.extract_string("content")
.load_in_memory()
.unwrap();
Self::new("test")
}
pub fn new(split: &str) -> Self {
let dataset: SqliteDataset<DbPediaItem> = HuggingfaceDatasetLoader::new("dbpedia_14")
.dataset(split)
.unwrap();
Self { dataset }
}
}