mirror of https://github.com/tracel-ai/burn.git
Dataset Improvements: Add Sqlite storage backend and HF importer improvements (#353)
This commit is contained in:
parent 18cb19cd03
commit 6fece7e4cb
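For orientation before the diff itself: this commit replaces the JSON-rows/`InMemDataset` download path with a sqlite-backed `SqliteDataset` and a simplified `HuggingfaceDatasetLoader`. Below is a minimal usage sketch of the reworked loader, assembled from the doc example and tests added in this diff; the `MNISTItemRaw` struct mirrors the one added to mnist.rs, and running it assumes `python3` with pip is available, since the loader shells out to the new importer.py and installs its Python dependencies itself.

```rust
use burn_dataset::{Dataset, HuggingfaceDatasetLoader, SqliteDataset};
use serde::Deserialize;

// Raw MNIST row as stored by the importer: PNG bytes plus the label.
#[derive(Deserialize, Debug, Clone)]
struct MNISTItemRaw {
    image_bytes: Vec<u8>,
    label: usize,
}

fn main() {
    // The first call downloads the dataset via importer.py and writes the
    // sqlite file under ~/.cache/burn-dataset; later calls reuse that file.
    let train: SqliteDataset<MNISTItemRaw> = HuggingfaceDatasetLoader::new("mnist")
        .dataset("train")
        .expect("failed to import mnist into sqlite");

    println!("train rows: {}", train.len());
    if let Some(item) = train.get(0) {
        println!("first label: {}", item.label);
    }
}
```

All splits of a dataset land in the same database file, one table per split, so loading the "test" split is just another `.dataset("test")` call on a fresh loader.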
@@ -23,6 +23,7 @@ members = [
[workspace.dependencies]
bytemuck = "1.13"
const-random = "0.1.15"
csv = "1.2.1"
dashmap = "5.4.0"
dirs = "5.0.0"
fake = "2.5.0"
@@ -36,14 +37,18 @@ pretty_assertions = "1.3"
proc-macro2 = "1.0.56"
protobuf-codegen = "3.2"
quote = "1.0.26"
r2d2 = "0.8.10"
r2d2_sqlite = "0.21.0"
rayon = "1.7.0"
rstest = "0.17.0"
sanitize-filename = "0.4.0"
serde_rusqlite = "0.31.0"
spin = {version = "0.9.8", features = ["mutex", "spin_mutex"]}
strum = "0.24"
strum_macros = "0.24"
syn = "2.0"
thiserror = "1.0.40"
topological-sort = "0.2.2"

#
# The following packages disable the "std" feature for no_std compatibility
#
@@ -18,12 +18,20 @@ default = ["fake"]
fake = ["dep:fake"]

[dependencies]
derive-new = {workspace = true}
dirs = {workspace = true}
fake = {workspace = true, optional = true}
image = {version = "0.24.6", features = ["png"]}
r2d2 = {workspace = true}
r2d2_sqlite = {workspace = true}
rand = {workspace = true, features = ["std"]}
sanitize-filename = {workspace = true}
serde = {workspace = true, features = ["std", "derive"]}
serde_json = {workspace = true, features = ["std"]}
serde_rusqlite = {workspace = true}
thiserror = {workspace = true}
derive-new = {workspace = true}
csv = {workspace = true}

[dev-dependencies]
rayon = {workspace = true}
rstest = {workspace = true}
@@ -1,8 +1,11 @@
use std::{
    fs::File,
    io::{BufRead, BufReader},
    path::Path,
};

use serde::de::DeserializeOwned;

use crate::Dataset;

/// Dataset where all items are stored in ram.
@@ -30,10 +33,19 @@ where

impl<I> InMemDataset<I>
where
    I: Clone + serde::de::DeserializeOwned,
    I: Clone + DeserializeOwned,
{
    pub fn from_file(file: &str) -> Result<Self, std::io::Error> {
        let file = File::open(file)?;
    /// Create from a dataset. All items are loaded in memory.
    pub fn from_dataset(dataset: &impl Dataset<I>) -> Self {
        let items: Vec<I> = dataset.iter().collect();
        Self::new(items)
    }

    /// Create from a json rows file (one json per line).
    ///
    /// Supported field types: https://docs.rs/serde_json/latest/serde_json/value/enum.Value.html
    pub fn from_json_rows<P: AsRef<Path>>(path: P) -> Result<Self, std::io::Error> {
        let file = File::open(path)?;
        let reader = BufReader::new(file);
        let mut items = Vec::new();

@@ -46,12 +58,105 @@ where

        Ok(dataset)
    }

    /// Create from a csv file.
    ///
    /// The first line of the csv file must be the header. The header must contain the name of the fields in the struct.
    ///
    /// The supported field types are: String, integer, float, and bool.
    ///
    /// See: https://docs.rs/csv/latest/csv/tutorial/index.html#reading-with-serde
    pub fn from_csv<P: AsRef<Path>>(path: P) -> Result<Self, std::io::Error> {
        let file = File::open(path)?;
        let reader = BufReader::new(file);
        let mut rdr = csv::Reader::from_reader(reader);

        let mut items = Vec::new();

        for result in rdr.deserialize() {
            let item: I = result?;
            items.push(item);
        }

        let dataset = Self::new(items);

        Ok(dataset)
    }
}

#[cfg(test)]
mod tests {

    use super::*;
    use crate::test_data;
    use crate::{test_data, SqliteDataset};

    use rstest::{fixture, rstest};
    use serde::{Deserialize, Serialize};

    const DB_FILE: &str = "tests/data/sqlite-dataset.db";
    const JSON_FILE: &str = "tests/data/dataset.json";
    const CSV_FILE: &str = "tests/data/dataset.csv";

    type SqlDs = SqliteDataset<Sample>;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct Sample {
        column_str: String,
        column_bytes: Vec<u8>,
        column_int: i64,
        column_bool: bool,
        column_float: f64,
    }

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct SampleCvs {
        column_str: String,
        column_int: i64,
        column_bool: bool,
        column_float: f64,
    }

    #[fixture]
    fn train_dataset() -> SqlDs {
        SqliteDataset::from_db_file(DB_FILE, "train")
    }

    #[rstest]
    pub fn from_dataset(train_dataset: SqlDs) {
        let dataset = InMemDataset::from_dataset(&train_dataset);

        let non_existing_record_index: usize = 10;
        let record_index: usize = 0;

        assert_eq!(train_dataset.get(non_existing_record_index), None);
        assert_eq!(dataset.get(record_index).unwrap().column_str, "HI1");
    }

    #[test]
    pub fn from_json_rows() {
        let dataset = InMemDataset::<Sample>::from_json_rows(JSON_FILE).unwrap();

        let non_existing_record_index: usize = 10;
        let record_index: usize = 1;

        assert_eq!(dataset.get(non_existing_record_index), None);
        assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2");
        assert!(!dataset.get(record_index).unwrap().column_bool);
    }

    #[test]
    pub fn from_csv_rows() {
        let dataset = InMemDataset::<SampleCvs>::from_csv(CSV_FILE).unwrap();

        let non_existing_record_index: usize = 10;
        let record_index: usize = 1;

        assert_eq!(dataset.get(non_existing_record_index), None);
        assert_eq!(dataset.get(record_index).unwrap().column_str, "HI2");
        assert_eq!(dataset.get(record_index).unwrap().column_int, 1);
        assert!(!dataset.get(record_index).unwrap().column_bool);
        assert_eq!(dataset.get(record_index).unwrap().column_float, 1.0);
    }

    #[test]
    pub fn given_in_memory_dataset_when_iterate_should_iterate_though_all_items() {
@@ -3,9 +3,11 @@ mod base;
mod fake;
mod in_memory;
mod iterator;
mod sqlite;

#[cfg(feature = "fake")]
pub use self::fake::*;
pub use base::*;
pub use in_memory::*;
pub use iterator::*;
pub use sqlite::*;
@@ -0,0 +1,221 @@
use std::marker::PhantomData;

use crate::Dataset;

use r2d2::Pool;
use r2d2_sqlite::{rusqlite::OpenFlags, SqliteConnectionManager};
use serde::de::DeserializeOwned;
use serde_rusqlite::*;

/// Dataset where all items are stored in a sqlite database.
///
/// Note: The database must have a table with the same name as the split.
/// The table must have a primary key column named `row_id` which is used to index the rows.
/// `row_id` starts with 1 (one) and `index` starts with 0 (zero) (`row_id` = `index` + 1).
/// The column names must match the field names of the <I> struct. However, the field names
/// can be a subset of column names and can be in any order.
///
/// Supported serialization field types: https://docs.rs/serde_rusqlite/latest/serde_rusqlite and
/// Sqlite3 types: https://www.sqlite.org/datatype3.html
///
/// Item struct example:
///
/// ```rust
///
/// use serde::{Deserialize, Serialize};
///
/// #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
/// pub struct Sample {
///     column_str: String, // column name: column_str with type TEXT
///     column_bytes: Vec<u8>, // column name: column_bytes with type BLOB
///     column_int: i64, // column name: column_int with type INTEGER
///     column_bool: bool, // column name: column_bool with type INTEGER
///     column_float: f64, // column name: column_float with type REAL
/// }
/// ```
///
/// Sqlite table example:
///
/// ```sql
///
/// CREATE TABLE train (
///     column_str TEXT,
///     column_bytes BLOB,
///     column_int INTEGER,
///     column_bool BOOLEAN,
///     column_float FLOAT,
///     row_id INTEGER NOT NULL,
///     PRIMARY KEY (row_id)
/// );
///
/// ```
#[derive(Debug)]
pub struct SqliteDataset<I> {
    db_file: String,
    split: String,
    conn_pool: Pool<SqliteConnectionManager>,
    columns: Vec<String>,
    len: usize,
    select_statement: String,
    phantom: PhantomData<I>,
}

impl<I> SqliteDataset<I> {
    pub fn from_db_file(db_file: &str, split: &str) -> Self {
        // Create a connection pool
        let conn_pool = create_conn_pool(db_file);

        // Create a select statement and save it
        let select_statement = format!("select * from {split} where row_id = ?");

        // Save the column names and the number of rows
        let (columns, len) = fetch_columns_and_len(&conn_pool, &select_statement, split);

        SqliteDataset {
            db_file: db_file.to_string(),
            split: split.to_string(),
            conn_pool,
            columns,
            len,
            select_statement,
            phantom: PhantomData::default(),
        }
    }

    /// Get the database file name.
    pub fn db_file(&self) -> &str {
        self.db_file.as_str()
    }

    /// Get the split name.
    pub fn split(&self) -> &str {
        self.split.as_str()
    }
}

impl<I> Dataset<I> for SqliteDataset<I>
where
    I: Clone + Send + Sync + DeserializeOwned,
{
    /// Get an item from the dataset.
    fn get(&self, index: usize) -> Option<I> {
        // Row ids start with 1 (one) and index starts with 0 (zero)
        let row_id = index + 1;

        // Get a connection from the pool
        let connection = self.conn_pool.get().unwrap();
        let mut statement = connection.prepare(self.select_statement.as_str()).unwrap();

        // Query the row with the given row_id and deserialize it into I using column names (fast option)
        let mut rows = statement
            .query_and_then([row_id], |row| {
                from_row_with_columns::<I>(row, &self.columns)
            })
            .unwrap();

        // Return the first row if found else None (error)
        rows.next().and_then(|res| match res {
            Ok(val) => Some(val),
            Err(_) => None,
        })
    }

    /// Return the number of rows in the dataset.
    fn len(&self) -> usize {
        self.len
    }
}

/// Fetch the column names and the number of rows from the database.
fn fetch_columns_and_len(
    conn_pool: &Pool<SqliteConnectionManager>,
    select_statement: &str,
    split: &str,
) -> (Vec<String>, usize) {
    // Save the column names
    let connection = conn_pool.get().unwrap();
    let statement = connection.prepare(select_statement).unwrap();
    let columns = columns_from_statement(&statement);

    // Count the number of rows and save it as len
    let mut statement = connection
        .prepare(format!("select count(*) from {split}").as_str())
        .unwrap();
    let len = statement
        .query_row([], |row| {
            let len: usize = row.get(0)?;
            Ok(len)
        })
        .unwrap();
    (columns, len)
}

/// Create a connection pool and make sure the connections are read only
fn create_conn_pool(db_file: &str) -> Pool<SqliteConnectionManager> {
    // Create a connection pool and make sure the connections are read only
    let manager =
        SqliteConnectionManager::file(db_file).with_flags(OpenFlags::SQLITE_OPEN_READ_ONLY);
    let conn_pool: Pool<SqliteConnectionManager> = Pool::new(manager).unwrap();
    conn_pool
}

#[cfg(test)]
mod tests {
    use rayon::prelude::*;
    use rstest::{fixture, rstest};
    use serde::{Deserialize, Serialize};

    use super::*;

    type SqlDs = SqliteDataset<Sample>;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    pub struct Sample {
        column_str: String,
        column_bytes: Vec<u8>,
        column_int: i64,
        column_bool: bool,
        column_float: f64,
    }
    #[fixture]
    fn train_dataset() -> SqlDs {
        SqliteDataset::from_db_file("tests/data/sqlite-dataset.db", "train")
    }

    #[rstest]
    pub fn len(train_dataset: SqlDs) {
        assert_eq!(train_dataset.len(), 2);
    }

    #[rstest]
    pub fn get_some(train_dataset: SqlDs) {
        let item = train_dataset.get(0).unwrap();
        assert_eq!(item.column_str, "HI1");
        assert_eq!(item.column_bytes, vec![55, 231, 159]);
        assert_eq!(item.column_int, 1);
        assert!(item.column_bool);
        assert_eq!(item.column_float, 1.0);
    }

    #[rstest]
    pub fn get_none(train_dataset: SqlDs) {
        assert_eq!(train_dataset.get(10), None);
    }

    #[rstest]
    pub fn multi_thread(train_dataset: SqlDs) {
        let indices: Vec<usize> = vec![0, 1, 1, 3, 4, 5, 6, 0, 8, 1];
        let results: Vec<Option<Sample>> =
            indices.par_iter().map(|&i| train_dataset.get(i)).collect();

        let mut match_count = 0;
        for (_index, result) in indices.iter().zip(results.iter()) {
            match result {
                Some(_val) => match_count += 1,
                None => (),
            }
        }

        assert_eq!(match_count, 5);
    }
}
@@ -8,6 +8,7 @@ pub mod transform;

mod dataset;
pub use dataset::*;
pub use source::huggingface::downloader::*;

#[cfg(test)]
mod test_data {
@@ -1,152 +0,0 @@
import abc
import json
import argparse
import numpy as np

from datasets import load_dataset
from typing import List, Any, Tuple
from tqdm import tqdm

from json import JSONEncoder


class CustomEncoder(JSONEncoder):
    def default(self, obj):
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return JSONEncoder.default(self, obj)


DOWNLOAD_DIR = ".cache/burn-dataset"


class Extractor(abc.ABC):
    def extract(self, item: Any) -> Any:
        pass

    @abc.abstractproperty
    def name(self) -> str:
        pass


class RawFieldExtractor(Extractor):
    def __init__(self, field_name: str):
        self.field_name = field_name

    def extract(self, item: Any) -> Any:
        return item[self.field_name]

    @property
    def name(self) -> str:
        return self.field_name


class ImageFieldExtractor(Extractor):
    def __init__(self, field_name: str):
        self.field_name = field_name

    def extract(self, item: Any) -> Any:
        image = item[self.field_name]
        return np.array(image).tolist()

    @property
    def name(self) -> str:
        return self.field_name


def download(
    name: str,
    keys: List[str],
    download_file: str,
    extractors: List[Extractor],
    *config,
    **kwargs,
):
    dataset_all = load_dataset(name, *config, **kwargs)
    for key in keys:
        dataset = dataset_all[key]
        dataset_file = f"{download_file}-{key}"
        print(f"Saving dataset: {name} - {key}")

        with open(dataset_file, "w") as file:
            for item in tqdm(dataset):
                payload = {}
                for extactor in extractors:
                    payload[extactor.name] = extactor.extract(item)

                payload = json.dumps(payload, cls=CustomEncoder)
                line = f"{payload}\n"
                file.write(line)


def config_named(value: str) -> Tuple[str, str]:
    try:
        key, value = value.split("=")
        return {key: value}
    except:
        raise argparse.ArgumentTypeError("config_named must be key=value")


def parse_args():
    parser = argparse.ArgumentParser(
        description="Huggingface datasets downloader to use with burn-dataset"
    )
    parser.add_argument(
        "--name", type=str, help="Name of the dataset to download", required=True
    )
    parser.add_argument(
        "--file", type=str, help="Base file name where the data is saved", required=True
    )
    parser.add_argument(
        "--split", type=str, help="Splits to downloads", nargs="+", required=True
    )
    parser.add_argument(
        "--config", type=str, help="Config of the dataset", nargs="+", default=[]
    )
    parser.add_argument(
        "--config-named",
        type=config_named,
        help="Named config of the dataset",
        nargs="+",
        default=[],
    )
    parser.add_argument(
        "--extract-image",
        type=str,
        help="Image field to extract",
        nargs="+",
        default=[],
    )
    parser.add_argument(
        "--extract-raw", type=str, help="Raw field to extract", nargs="+", default=[]
    )

    return parser.parse_args()


def run():
    args = parse_args()
    extractors = []

    for field_name in args.extract_image:
        extractors.append(ImageFieldExtractor(field_name))

    for field_name in args.extract_raw:
        extractors.append(RawFieldExtractor(field_name))

    kwargs = {}
    for config_named in args.config_named:
        kwargs = kwargs | config_named

    download(
        args.name,
        args.split,
        args.file,
        extractors,
        *args.config,
        **kwargs,
    )


if __name__ == "__main__":
    run()
@@ -1,240 +1,227 @@
use crate::InMemDataset;
use dirs::home_dir;
use std::collections::hash_map::DefaultHasher;
use std::fs;
use std::hash::Hasher;
use std::fs::{self, create_dir_all};
use std::path::Path;
use std::process::Command;

use crate::SqliteDataset;

use sanitize_filename::sanitize;
use serde::de::DeserializeOwned;
use thiserror::Error;

const PYTHON: &str = "python3";

const PYTHON_SOURCE: &str = include_str!("dataset.py");
const PYTHON_SOURCE: &str = include_str!("importer.py");

#[derive(Error, Debug)]
pub enum DownloaderError {
pub enum ImporterError {
    #[error("unknown: `{0}`")]
    Unknown(String),
    #[error("fail to download python dependencies: `{0}`")]
    FailToDownloadPythonDependencies(String),
}

/// Load datasets from [huggingface datasets](https://huggingface.co/datasets).
/// Load a dataset from [huggingface datasets](https://huggingface.co/datasets).
///
/// The dataset with all splits is stored in a single sqlite database (see [SqliteDataset](SqliteDataset)).
///
/// # Example
/// ```no_run
/// use burn_dataset::HuggingfaceDatasetLoader;
/// use burn_dataset::SqliteDataset;
/// use serde::{Deserialize, Serialize};
///
/// #[derive(Deserialize, Debug, Clone)]
/// struct MNISTItemRaw {
///     pub image_bytes: Vec<u8>,
///     pub label: usize,
/// }
///
/// let train_ds:SqliteDataset<MNISTItemRaw> = HuggingfaceDatasetLoader::new("mnist")
///     .dataset("train")
///     .unwrap();
pub struct HuggingfaceDatasetLoader {
    name: String,
    split: String,
    extractors: Vec<Extractor>,
    config: Vec<String>,
    config_named: Vec<(String, String)>,
    deps: Vec<String>,
    subset: Option<String>,
    base_dir: Option<String>,
    huggingface_token: Option<String>,
    huggingface_cache_dir: Option<String>,
}

impl HuggingfaceDatasetLoader {
    /// Create a huggingface dataset loader.
    pub fn new(name: &str, split: &str) -> Self {
    pub fn new(name: &str) -> Self {
        Self {
            name: name.to_string(),
            split: split.to_string(),
            extractors: Vec::new(),
            config: Vec::new(),
            config_named: Vec::new(),
            deps: Vec::new(),
            subset: None,
            base_dir: None,
            huggingface_token: None,
            huggingface_cache_dir: None,
        }
    }

    pub fn config(mut self, config: &str) -> Self {
        self.config.push(config.to_string());
    /// Create a huggingface dataset loader for a subset of the dataset.
    ///
    /// The subset name must be one of the subsets listed in the dataset page.
    ///
    /// If no subset names are listed, then do not use this method.
    pub fn with_subset(mut self, subset: &str) -> Self {
        self.subset = Some(subset.to_string());
        self
    }

    pub fn config_named(mut self, name: &str, config: &str) -> Self {
        self.config_named
            .push((name.to_string(), config.to_string()));
    /// Specify a base directory to store the dataset.
    ///
    /// If not specified, the dataset will be stored in `~/.cache/burn-dataset`.
    pub fn with_base_dir(mut self, base_dir: &str) -> Self {
        self.base_dir = Some(base_dir.to_string());
        self
    }

    pub fn deps(mut self, deps: &[&str]) -> Self {
        self.deps
            .append(&mut deps.iter().copied().map(String::from).collect());
    /// Specify a huggingface token to download datasets behind authentication.
    ///
    /// You can get a token from https://huggingface.co/settings/tokens
    pub fn with_huggingface_token(mut self, huggingface_token: &str) -> Self {
        self.huggingface_token = Some(huggingface_token.to_string());
        self
    }

    pub fn dep(mut self, dep: &str) -> Self {
        self.deps.push(dep.to_string());
    /// Specify a huggingface cache directory to store the downloaded datasets.
    ///
    /// If not specified, the dataset will be stored in `~/.cache/huggingface/datasets`.
    pub fn with_huggingface_cache_dir(mut self, huggingface_cache_dir: &str) -> Self {
        self.huggingface_cache_dir = Some(huggingface_cache_dir.to_string());
        self
    }

    pub fn extract_image(mut self, field_name: &str) -> Self {
        self.extractors
            .push(Extractor::Image(field_name.to_string()));
        self
    }

    pub fn extract_number(self, field_name: &str) -> Self {
        self.extract_raw(field_name)
    }

    pub fn extract_string(self, field_name: &str) -> Self {
        self.extract_raw(field_name)
    }

    pub fn load_in_memory<I: serde::de::DeserializeOwned + Clone>(
    /// Load the dataset.
    pub fn dataset<I: DeserializeOwned + Clone>(
        self,
    ) -> Result<InMemDataset<I>, DownloaderError> {
        let path_file = self.load_file()?;
        let dataset = InMemDataset::from_file(path_file.as_str()).unwrap();

        split: &str,
    ) -> Result<SqliteDataset<I>, ImporterError> {
        let db_file = self.db_file()?;
        let dataset = SqliteDataset::from_db_file(db_file.as_str(), split);
        Ok(dataset)
    }

    pub fn load_file(self) -> Result<String, DownloaderError> {
        let mut hasher = DefaultHasher::new();
        hasher.write(format!("{:?}", self.extractors).as_bytes());
        hasher.write(format!("{:?}", self.config).as_bytes());
        hasher.write(format!("{:?}", self.config_named).as_bytes());
        let hash = hasher.finish();
    /// Get the path to the sqlite database file.
    ///
    /// If the database file does not exist, it will be downloaded and imported.
    pub fn db_file(self) -> Result<String, ImporterError> {
        // determine (and create if needed) the base directory
        let base_dir = base_dir(self.base_dir);

        let base_file = format!("{}/{}-{}", cache_dir(), self.name, hash);
        let path_file = format!("{}-{}", base_file, self.split);
        //sanitize the name and subset
        let name = sanitize(self.name.as_str());

        if !std::path::Path::new(&path_file).exists() {
            download(
                self.name.clone(),
                vec![self.split],
                base_file,
                self.extractors,
                self.config,
                self.config_named,
                &self.deps,
        // create the db file path
        let db_file = if let Some(subset) = self.subset.clone() {
            format!("{}/{}-{}.db", base_dir, name, sanitize(subset.as_str()))
        } else {
            format!("{}/{}.db", base_dir, name)
        };

        // import the dataset if needed
        if !Path::new(&db_file).exists() {
            import(
                self.name,
                self.subset,
                db_file.clone(),
                base_dir,
                self.huggingface_token,
                self.huggingface_cache_dir,
            )?;
        }

        Ok(path_file)
    }

    fn extract_raw(mut self, field_name: &str) -> Self {
        self.extractors.push(Extractor::Raw(field_name.to_string()));
        self
        Ok(db_file)
    }
}

fn download(
/// Import a dataset from huggingface. The transformed dataset is stored as sqlite database.
fn import(
    name: String,
    splits: Vec<String>,
    subset: Option<String>,
    base_file: String,
    extractors: Vec<Extractor>,
    config: Vec<String>,
    config_named: Vec<(String, String)>,
    deps: &[String],
) -> Result<(), DownloaderError> {
    download_python_deps(deps)?;
    base_dir: String,
    huggingface_token: Option<String>,
    huggingface_cache_dir: Option<String>,
) -> Result<(), ImporterError> {
    install_python_deps()?;

    let mut command = Command::new(PYTHON);

    command.arg(dataset_downloader_file_path());

    command.arg("--file");
    command.arg(base_file);
    command.arg(importer_script_path(base_dir));

    command.arg("--name");
    command.arg(name);

    command.arg("--split");
    for split in splits {
        command.arg(split);
    command.arg("--file");
    command.arg(base_file);

    if let Some(subset) = subset {
        command.arg("--subset");
        command.arg(subset);
    }

    let mut extracted_raw = Vec::new();
    let mut extracted_images = Vec::new();

    for extractor in extractors {
        match extractor {
            Extractor::Raw(field) => extracted_raw.push(field),
            Extractor::Image(field) => extracted_images.push(field),
        };
    if let Some(huggingface_token) = huggingface_token {
        command.arg("--token");
        command.arg(huggingface_token);
    }

    if !extracted_raw.is_empty() {
        command.arg("--extract-raw");
        for field in extracted_raw {
            command.arg(field);
        }
    }

    if !extracted_images.is_empty() {
        command.arg("--extract-image");
        for field in extracted_images {
            command.arg(field);
        }
    }

    if !config.is_empty() {
        command.arg("--config");
        for config in config {
            command.arg(config);
        }
    }
    if !config_named.is_empty() {
        command.arg("--config-named");
        for (key, value) in config_named {
            command.arg(format!("{key}={value}"));
        }
    if let Some(huggingface_cache_dir) = huggingface_cache_dir {
        command.arg("--cache_dir");
        command.arg(huggingface_cache_dir);
    }

    let mut handle = command.spawn().unwrap();
    handle
        .wait()
        .map_err(|err| DownloaderError::Unknown(format!("{err:?}")))?;
        .map_err(|err| ImporterError::Unknown(format!("{err:?}")))?;

    Ok(())
}

fn cache_dir() -> String {
    let home_dir = home_dir().unwrap();
    let home_dir = home_dir.to_str().map(|s| s.to_string());
    let home_dir = home_dir.unwrap();
    let cache_dir = format!("{home_dir}/.cache/burn-dataset");
    std::fs::create_dir_all(&cache_dir).ok();
    cache_dir
/// Determine the base directory to store the dataset.
fn base_dir(base_dir: Option<String>) -> String {
    let base_dir = if let Some(base_dir) = base_dir {
        base_dir
    } else {
        let home_dir = home_dir().unwrap();
        let home_dir = home_dir.to_str().map(|s| s.to_string());
        let home_dir = home_dir.unwrap();
        let cache_dir = format!("{home_dir}/.cache/burn-dataset");
        cache_dir
    };

    create_dir_all(&base_dir).ok();

    base_dir
}

fn dataset_downloader_file_path() -> String {
    let path_dir = cache_dir();
    let path_file = format!("{path_dir}/dataset.py");

fn importer_script_path(base_dir: String) -> String {
    let path_file = format!("{base_dir}/importer.py");
    fs::write(path_file.as_str(), PYTHON_SOURCE).expect("Write python dataset downloader");
    path_file
}

fn download_python_deps(deps: &[String]) -> Result<(), DownloaderError> {
fn install_python_deps() -> Result<(), ImporterError> {
    let mut command = Command::new(PYTHON);
    command.args([
        "-m",
        "pip",
        "--quiet",
        "install",
        "pyarrow",
        "sqlalchemy",
        "Pillow",
        "soundfile",
        "datasets",
    ]);

    command
        .args(["-m", "pip", "install", "datasets"])
        .args(deps);

    command
        .spawn()
        .map_err(|err| {
            DownloaderError::FailToDownloadPythonDependencies(format!(
                "{} | error: {}",
                deps.to_vec().join(", "),
                err
            ))
        })?
        .wait()
        .map_err(|err| {
            DownloaderError::FailToDownloadPythonDependencies(format!(
                "{} | error: {}",
                deps.to_vec().join(", "),
                err
            ))
        })?;
    // Spawn the process and wait for it to complete.
    let mut handle = command.spawn().unwrap();
    handle.wait().map_err(|err| {
        ImporterError::FailToDownloadPythonDependencies(format!(" error: {}", err))
    })?;

    Ok(())
}

#[derive(Debug)]
enum Extractor {
    Raw(String),
    Image(String),
}
@@ -0,0 +1,176 @@
import argparse

import pyarrow as pa
from datasets import Audio, Image, load_dataset
from sqlalchemy import Column, Integer, Table, create_engine, event, inspect
from sqlalchemy.types import LargeBinary


def download_and_export(name: str, subset: str, db_file: str, token: str, cache_dir: str):
    """
    Download a dataset using the HuggingFace datasets library and export it to a sqlite database.
    """

    # TODO For media columns (Image and Audio) sometimes when decode=False,
    # bytes can be none {'bytes': None, 'path': 'healthy_train.265.jpg'}
    # We should handle this case, but unfortunately we did not come across this case yet to test it.

    print("*"*80)
    print("Starting huggingface dataset download and export")
    print(f"Dataset Name: {name}")
    print(f"Subset Name: {subset}")
    print(f"Sqlite database file: {db_file}")
    if cache_dir is not None:
        print(f"Custom cache dir: {cache_dir}")
    print("*"*80)

    # Load the dataset
    dataset_all = load_dataset(
        name, subset, cache_dir=cache_dir, use_auth_token=token)

    print(f"Dataset: {dataset_all}")

    # Create the database connection descriptor (sqlite)
    engine = create_engine(f"sqlite:///{db_file}")

    # Set some sqlite pragmas to speed up the database
    event.listen(engine, 'connect', set_sqlite_pragma)

    # Add a row_id column to each table as primary key (datasets does not have API for this)
    event.listen(Table, 'before_create', add_pk_column)

    # Export each split in the dataset
    for key in dataset_all.keys():
        dataset = dataset_all[key]

        # Disable decoding for audio and image fields
        dataset = disable_decoding(dataset)

        # Flatten the dataset
        dataset = dataset.flatten()

        # Rename columns to remove dots from the names
        dataset = rename_colums(dataset)

        print(f"Saving dataset: {name} - {key}")
        print(f"Dataset features: {dataset.features}")

        # Save the dataset to a sqlite database
        dataset.to_sql(
            key,  # table name
            engine,
            # don't save the index, use row_id instead (index is not unique)
            index=False,
            dtype=blob_columns(dataset),  # save binary columns as blob
        )

    # Print the schema of the database so we can reference the columns in the rust code
    print_table_info(engine)


def disable_decoding(dataset):
    """
    Disable decoding for audio and image fields. The fields will be saved as raw file bytes.
    """
    for k, v in dataset.features.items():
        if isinstance(v, Audio):
            dataset = dataset.cast_column(k, Audio(decode=False))
        elif isinstance(v, Image):
            dataset = dataset.cast_column(k, Image(decode=False))

    return dataset


def rename_colums(dataset):
    """
    Rename columns to remove dots from the names. Dots appear in the column names because of the flattening.
    Dots are not allowed in column names in rust and sql (unless quoted). So we replace them with underscores.
    This way there is an easy name mapping between the rust and sql columns.
    """

    for name in dataset.features.keys():
        if '.' in name:
            dataset = dataset.rename_column(name, name.replace('.', '_'))

    return dataset


def blob_columns(dataset):
    """
    Make sure all binary columns are blob columns in the database because
    `to_sql` exports binary values as TEXT instead of BLOB.
    """
    type_mapping = {}
    for name, value in dataset.features.items():
        if pa.types.is_binary(value.pa_type):
            type_mapping[name] = LargeBinary
    return type_mapping


def set_sqlite_pragma(dbapi_connection, connection_record):
    """
    Set some sqlite pragmas to speed up the database
    """
    cursor = dbapi_connection.cursor()
    cursor.execute("PRAGMA synchronous = OFF")
    cursor.execute("PRAGMA journal_mode = OFF")
    cursor.close()


def add_pk_column(target, connection, **kw):
    """
    Add an id column to each table.
    """
    target.append_column(Column("row_id", Integer, primary_key=True))


def print_table_info(engine):
    """
    Print the schema of the database so we can reference the columns in the rust code
    """
    print(f"Printing table schema for sqlite3 db ({engine})")
    inspector = inspect(engine)
    for table_name in inspector.get_table_names():
        print(f"Table: {table_name}")
        for column in inspector.get_columns(table_name):
            print(f"Column: {column['name']} - {column['type']}")
        print("")


def parse_args():
    parser = argparse.ArgumentParser(
        description="Huggingface datasets downloader to use with burn-dataset"
    )
    parser.add_argument(
        "--name", type=str, help="Name of the dataset to download", required=True
    )
    parser.add_argument(
        "--file", type=str, help="Base file name where the data is saved", required=True
    )
    parser.add_argument(
        "--subset", type=str, help="Subset name", required=False, default=None
    )
    parser.add_argument(
        "--token", type=str, help="HuggingFace authentication token", required=False, default=None
    )
    parser.add_argument(
        "--cache_dir", type=str, help="Cache directory", required=False, default=None
    )

    return parser.parse_args()


def run():
    args = parse_args()

    download_and_export(
        args.name,
        args.subset,
        args.file,
        args.token,
        args.cache_dir,
    )


if __name__ == "__main__":
    run()
@@ -1,15 +1,58 @@
use crate::source::huggingface::downloader::HuggingfaceDatasetLoader;
use crate::{Dataset, InMemDataset};
use crate::transform::{Mapper, MapperDataset};
use crate::{Dataset, SqliteDataset};

use image;
use serde::{Deserialize, Serialize};

const WIDTH: usize = 28;
const HEIGHT: usize = 28;

#[derive(Deserialize, Serialize, Debug, Clone)]
pub struct MNISTItem {
    pub image: [[f32; 28]; 28],
    pub image: [[f32; WIDTH]; HEIGHT],
    pub label: usize,
}

#[derive(Deserialize, Debug, Clone)]
struct MNISTItemRaw {
    pub image_bytes: Vec<u8>,
    pub label: usize,
}

struct BytesToImage;

impl Mapper<MNISTItemRaw, MNISTItem> for BytesToImage {
    /// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image).
    fn map(&self, item: &MNISTItemRaw) -> MNISTItem {
        let image = image::load_from_memory(&item.image_bytes).unwrap();
        let image = image.as_luma8().unwrap();

        // Ensure the image dimensions are correct.
        debug_assert_eq!(image.dimensions(), (WIDTH as u32, HEIGHT as u32));

        // Convert the image to a 2D array of floats.
        let mut image_array = [[0f32; WIDTH]; HEIGHT];
        for (i, pixel) in image.as_raw().iter().enumerate() {
            let x = i % WIDTH;
            let y = i / HEIGHT;
            image_array[y][x] = *pixel as f32;
        }

        MNISTItem {
            image: image_array,
            label: item.label,
        }
    }
}

type MappedDataset = MapperDataset<SqliteDataset<MNISTItemRaw>, BytesToImage, MNISTItemRaw>;

/// MNIST dataset from Huggingface.
///
/// The data is downloaded from Huggingface and stored in a SQLite database.
pub struct MNISTDataset {
    dataset: InMemDataset<MNISTItem>,
    dataset: MappedDataset,
}

impl Dataset<MNISTItem> for MNISTDataset {

@@ -32,13 +75,12 @@ impl MNISTDataset {
    }

    fn new(split: &str) -> Self {
        let dataset = HuggingfaceDatasetLoader::new("mnist", split)
            .extract_image("image")
            .extract_number("label")
            .deps(&["pillow", "numpy"])
            .load_in_memory()
        let dataset = HuggingfaceDatasetLoader::new("mnist")
            .dataset(split)
            .unwrap();

        let dataset = MapperDataset::new(dataset, BytesToImage);

        Self { dataset }
    }
}
@@ -1,4 +1,5 @@
pub mod downloader;
mod mnist;

pub use downloader::*;
pub use mnist::*;
@@ -0,0 +1,3 @@
column_str,column_int,column_bool,column_float
HI1,1,true,1.0
HI2,1,false,1.0
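This two-row fixture is what the new `InMemDataset::from_csv` constructor (added earlier in this diff) is exercised against. A minimal usage sketch outside the test suite, assuming a struct whose fields match the CSV header and a path resolved relative to the burn-dataset crate:

```rust
use burn_dataset::{Dataset, InMemDataset};
use serde::Deserialize;

// Field names match the CSV header: column_str,column_int,column_bool,column_float
#[derive(Deserialize, Debug, Clone)]
struct CsvRow {
    column_str: String,
    column_int: i64,
    column_bool: bool,
    column_float: f64,
}

fn main() -> Result<(), std::io::Error> {
    // from_csv expects the first line to be a header naming the struct fields.
    let dataset = InMemDataset::<CsvRow>::from_csv("tests/data/dataset.csv")?;

    assert_eq!(dataset.len(), 2);
    assert_eq!(dataset.get(0).unwrap().column_str, "HI1");
    Ok(())
}
```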
@@ -0,0 +1,2 @@
{"column_str":"HI1","column_bytes":[1,2,3,3],"column_int":1,"column_bool":true,"column_float":1.0}
{"column_str":"HI2","column_bytes":[1,2,3,3],"column_int":1,"column_bool":false,"column_float":1.0}
Binary file not shown.
@@ -44,7 +44,7 @@ burn-tensor = {path = "../burn-tensor", version = "0.8.0", default-features = fa

matrixmultiply = {version = "0.3.6", default-features = false}
rayon = {version= "1.7.0", optional = true}
rayon = {workspace = true, optional = true}

blas-src = {version = "0.8.0", default-features = false, optional = true}# no-std compatible
@@ -1,5 +1,5 @@
use burn::data::dataset::{
    source::huggingface::downloader::HuggingfaceDatasetLoader, Dataset, InMemDataset,
    source::huggingface::downloader::HuggingfaceDatasetLoader, Dataset, SqliteDataset,
};

#[derive(new, Clone, Debug)]

@@ -20,7 +20,7 @@ pub struct AgNewsItem {
}

pub struct AgNewsDataset {
    dataset: InMemDataset<AgNewsItem>,
    dataset: SqliteDataset<AgNewsItem>,
}

impl Dataset<TextClassificationItem> for AgNewsDataset {

@@ -37,19 +37,16 @@ impl Dataset<TextClassificationItem> for AgNewsDataset {

impl AgNewsDataset {
    pub fn train() -> Self {
        let dataset: InMemDataset<AgNewsItem> = HuggingfaceDatasetLoader::new("ag_news", "train")
            .extract_string("text")
            .extract_number("label")
            .load_in_memory()
            .unwrap();
        Self { dataset }
        Self::new("train")
    }

    pub fn test() -> Self {
        let dataset: InMemDataset<AgNewsItem> = HuggingfaceDatasetLoader::new("ag_news", "test")
            .extract_string("text")
            .extract_number("label")
            .load_in_memory()
        Self::new("test")
    }

    pub fn new(split: &str) -> Self {
        let dataset: SqliteDataset<AgNewsItem> = HuggingfaceDatasetLoader::new("ag_news")
            .dataset(split)
            .unwrap();
        Self { dataset }
    }
@@ -80,7 +77,7 @@ pub struct DbPediaItem {
}

pub struct DbPediaDataset {
    dataset: InMemDataset<DbPediaItem>,
    dataset: SqliteDataset<DbPediaItem>,
}

impl Dataset<TextClassificationItem> for DbPediaDataset {

@@ -100,24 +97,16 @@ impl Dataset<TextClassificationItem> for DbPediaDataset {

impl DbPediaDataset {
    pub fn train() -> Self {
        let dataset: InMemDataset<DbPediaItem> =
            HuggingfaceDatasetLoader::new("dbpedia_14", "train")
                .extract_string("title")
                .extract_string("content")
                .extract_number("label")
                .load_in_memory()
                .unwrap();
        Self { dataset }
        Self::new("train")
    }

    pub fn test() -> Self {
        let dataset: InMemDataset<DbPediaItem> =
            HuggingfaceDatasetLoader::new("dbpedia_14", "test")
                .extract_string("title")
                .extract_string("content")
                .extract_number("label")
                .load_in_memory()
                .unwrap();
        Self::new("test")
    }
    pub fn new(split: &str) -> Self {
        let dataset: SqliteDataset<DbPediaItem> = HuggingfaceDatasetLoader::new("dbpedia_14")
            .dataset(split)
            .unwrap();
        Self { dataset }
    }
}
@@ -1,5 +1,5 @@
use burn::data::dataset::{
    source::huggingface::downloader::HuggingfaceDatasetLoader, Dataset, InMemDataset,
    source::huggingface::downloader::HuggingfaceDatasetLoader, Dataset, SqliteDataset,
};

#[derive(new, Clone, Debug)]

@@ -13,7 +13,7 @@ pub struct DbPediaItem {
}

pub struct DbPediaDataset {
    dataset: InMemDataset<DbPediaItem>,
    dataset: SqliteDataset<DbPediaItem>,
}

impl Dataset<TextGenerationItem> for DbPediaDataset {

@@ -30,20 +30,16 @@ impl Dataset<TextGenerationItem> for DbPediaDataset {

impl DbPediaDataset {
    pub fn train() -> Self {
        let dataset: InMemDataset<DbPediaItem> =
            HuggingfaceDatasetLoader::new("dbpedia_14", "train")
                .extract_string("content")
                .load_in_memory()
                .unwrap();
        Self { dataset }
        Self::new("train")
    }

    pub fn test() -> Self {
        let dataset: InMemDataset<DbPediaItem> =
            HuggingfaceDatasetLoader::new("dbpedia_14", "test")
                .extract_string("content")
                .load_in_memory()
                .unwrap();
        Self::new("test")
    }
    pub fn new(split: &str) -> Self {
        let dataset: SqliteDataset<DbPediaItem> = HuggingfaceDatasetLoader::new("dbpedia_14")
            .dataset(split)
            .unwrap();
        Self { dataset }
    }
}