diff --git a/burn-core/src/data/dataloader/builder.rs b/burn-core/src/data/dataloader/builder.rs
index 92d919ed3..0adfd2ed7 100644
--- a/burn-core/src/data/dataloader/builder.rs
+++ b/burn-core/src/data/dataloader/builder.rs
@@ -14,9 +14,12 @@ where
     I: Send + Sync + Clone + std::fmt::Debug + 'static,
     O: Send + Sync + Clone + std::fmt::Debug + 'static,
 {
-    pub fn new(batcher: Arc<dyn Batcher<I, O>>) -> Self {
+    pub fn new<B>(batcher: B) -> Self
+    where
+        B: Batcher<I, O> + 'static,
+    {
         Self {
-            batcher,
+            batcher: Arc::new(batcher),
             strategy: None,
             num_threads: None,
             shuffle: None,
@@ -38,10 +41,13 @@ where
         self
     }
 
-    pub fn build(self, dataset: Arc<dyn Dataset<I>>) -> Arc<dyn DataLoader<O>> {
-        let dataset = match self.shuffle {
+    pub fn build<D>(self, dataset: D) -> Arc<dyn DataLoader<O>>
+    where
+        D: Dataset<I> + 'static,
+    {
+        let dataset: Arc<dyn Dataset<I>> = match self.shuffle {
             Some(seed) => Arc::new(ShuffledDataset::with_seed(dataset, seed)),
-            None => dataset,
+            None => Arc::new(dataset),
         };
         let strategy = match self.strategy {
             Some(strategy) => strategy,
diff --git a/burn-dataset/Cargo.toml b/burn-dataset/Cargo.toml
index 9e86241f8..54fe6de1d 100644
--- a/burn-dataset/Cargo.toml
+++ b/burn-dataset/Cargo.toml
@@ -24,5 +24,6 @@ rand = {workspace = true, features = ["std"]}
 serde = {workspace = true, features = ["std", "derive"]}
 serde_json = {workspace = true, features = ["std"]}
 thiserror = {workspace = true}
+derive-new = {workspace = true}
 
 [dev-dependencies]
diff --git a/burn-dataset/src/dataset/base.rs b/burn-dataset/src/dataset/base.rs
index 063b0f02d..7149fb277 100644
--- a/burn-dataset/src/dataset/base.rs
+++ b/burn-dataset/src/dataset/base.rs
@@ -1,5 +1,8 @@
+use std::sync::Arc;
+
 use crate::DatasetIterator;
 
+/// The dataset trait defines a basic collection of items with a predefined size.
 pub trait Dataset<I>: Send + Sync {
     fn get(&self, index: usize) -> Option<I>;
     fn len(&self) -> usize;
@@ -13,3 +16,49 @@ pub trait Dataset<I>: Send + Sync {
         DatasetIterator::new(self)
     }
 }
+
+impl<D, I> Dataset<I> for Arc<D>
+where
+    D: Dataset<I>,
+{
+    fn get(&self, index: usize) -> Option<I> {
+        self.as_ref().get(index)
+    }
+
+    fn len(&self) -> usize {
+        self.as_ref().len()
+    }
+}
+
+impl<I> Dataset<I> for Arc<dyn Dataset<I>> {
+    fn get(&self, index: usize) -> Option<I> {
+        self.as_ref().get(index)
+    }
+
+    fn len(&self) -> usize {
+        self.as_ref().len()
+    }
+}
+
+impl<D, I> Dataset<I> for Box<D>
+where
+    D: Dataset<I>,
+{
+    fn get(&self, index: usize) -> Option<I> {
+        self.as_ref().get(index)
+    }
+
+    fn len(&self) -> usize {
+        self.as_ref().len()
+    }
+}
+
+impl<I> Dataset<I> for Box<dyn Dataset<I>> {
+    fn get(&self, index: usize) -> Option<I> {
+        self.as_ref().get(index)
+    }
+
+    fn len(&self) -> usize {
+        self.as_ref().len()
+    }
+}
diff --git a/burn-dataset/src/dataset/fake.rs b/burn-dataset/src/dataset/fake.rs
index 4251a3c33..0d761e10d 100644
--- a/burn-dataset/src/dataset/fake.rs
+++ b/burn-dataset/src/dataset/fake.rs
@@ -1,6 +1,7 @@
 use crate::{Dataset, DatasetIterator, InMemDataset};
 use fake::{Dummy, Fake, Faker};
 
+/// Dataset filled with fake items generated from the [fake](fake) crate.
 pub struct FakeDataset<I> {
     dataset: InMemDataset<I>,
 }
diff --git a/burn-dataset/src/dataset/in_memory.rs b/burn-dataset/src/dataset/in_memory.rs
index ffeff20c5..9bb40424f 100644
--- a/burn-dataset/src/dataset/in_memory.rs
+++ b/burn-dataset/src/dataset/in_memory.rs
@@ -5,6 +5,7 @@ use std::{
 };
 use crate::Dataset;
 
+/// Dataset where all items are stored in RAM.
 pub struct InMemDataset<I> {
     items: Vec<I>,
 }
diff --git a/burn-dataset/src/dataset/iterator.rs b/burn-dataset/src/dataset/iterator.rs
index 8e937888a..4b8140521 100644
--- a/burn-dataset/src/dataset/iterator.rs
+++ b/burn-dataset/src/dataset/iterator.rs
@@ -1,6 +1,7 @@
 use crate::dataset::Dataset;
 use std::iter::Iterator;
 
+/// Dataset iterator.
 pub struct DatasetIterator<'a, I> {
     current: usize,
     dataset: &'a dyn Dataset<I>,
diff --git a/burn-dataset/src/lib.rs b/burn-dataset/src/lib.rs
index d1fd81f0d..7268d18db 100644
--- a/burn-dataset/src/lib.rs
+++ b/burn-dataset/src/lib.rs
@@ -1,3 +1,6 @@
+#[macro_use]
+extern crate derive_new;
+
 extern crate dirs;
 
 pub mod source;
diff --git a/burn-dataset/src/transform/composed.rs b/burn-dataset/src/transform/composed.rs
index af3e25d2f..8f26bd597 100644
--- a/burn-dataset/src/transform/composed.rs
+++ b/burn-dataset/src/transform/composed.rs
@@ -1,17 +1,14 @@
 use crate::Dataset;
 
-pub struct ComposedDataset<I> {
-    datasets: Vec<Box<dyn Dataset<I>>>,
+/// Compose multiple datasets together to create a bigger one.
+#[derive(new)]
+pub struct ComposedDataset<D> {
+    datasets: Vec<D>,
 }
 
-impl<I> ComposedDataset<I> {
-    pub fn new(datasets: Vec<Box<dyn Dataset<I>>>) -> Self {
-        Self { datasets }
-    }
-}
-
-impl<I> Dataset<I> for ComposedDataset<I>
+impl<D, I> Dataset<I> for ComposedDataset<D>
 where
+    D: Dataset<I>,
     I: Clone,
 {
     fn get(&self, index: usize) -> Option<I> {
diff --git a/burn-dataset/src/transform/mapper.rs b/burn-dataset/src/transform/mapper.rs
index 1e6ee6e0f..ec1822520 100644
--- a/burn-dataset/src/transform/mapper.rs
+++ b/burn-dataset/src/transform/mapper.rs
@@ -1,22 +1,22 @@
 use crate::Dataset;
+use std::marker::PhantomData;
 
-pub trait Mapper<I, O> {
+/// Basic mapper trait to be used with the [mapper dataset](MapperDataset).
+pub trait Mapper<I, O>: Send + Sync {
     fn map(&self, item: &I) -> O;
 }
 
-pub struct MapperDataset<M, I> {
-    dataset: Box<dyn Dataset<I>>,
+/// Dataset mapping each element in an inner dataset to another element type lazily.
+#[derive(new)]
+pub struct MapperDataset<D, M, I> {
+    dataset: D,
     mapper: M,
+    input: PhantomData<I>,
 }
 
-impl<M, I> MapperDataset<M, I> {
-    pub fn new(dataset: Box<dyn Dataset<I>>, mapper: M) -> Self {
-        Self { dataset, mapper }
-    }
-}
-
-impl<M, I, O> Dataset<O> for MapperDataset<M, I>
+impl<D, M, I, O> Dataset<O> for MapperDataset<D, M, I>
 where
+    D: Dataset<I>,
     M: Mapper<I, O> + Send + Sync,
     I: Send + Sync,
     O: Send + Sync,
@@ -38,7 +38,8 @@ mod tests {
 
     #[test]
     pub fn given_mapper_dataset_when_iterate_should_iterate_though_all_map_items() {
-        struct StringToFirstChar {}
+        struct StringToFirstChar;
+
         impl Mapper<String, String> for StringToFirstChar {
             fn map(&self, item: &String) -> String {
                 let mut item = item.clone();
@@ -46,9 +47,10 @@ mod tests {
                 item
             }
         }
+
         let items_original = test_data::string_items();
         let dataset = InMemDataset::new(items_original);
-        let dataset = MapperDataset::new(Box::new(dataset), StringToFirstChar {});
+        let dataset = MapperDataset::new(dataset, StringToFirstChar);
 
         let items: Vec<String> = dataset.iter().collect();
diff --git a/burn-dataset/src/transform/partial.rs b/burn-dataset/src/transform/partial.rs
index bfd9669c6..9b365444f 100644
--- a/burn-dataset/src/transform/partial.rs
+++ b/burn-dataset/src/transform/partial.rs
@@ -1,21 +1,22 @@
 use crate::Dataset;
-use std::sync::Arc;
+use std::{marker::PhantomData, sync::Arc};
 
-pub struct PartialDataset<I> {
-    dataset: Arc<dyn Dataset<I>>,
+/// Only use a fraction of an existing dataset lazily.
+#[derive(new)]
+pub struct PartialDataset<D, I> {
+    dataset: D,
     start_index: usize,
     end_index: usize,
+    input: PhantomData<I>,
 }
 
-impl<I> PartialDataset<I> {
-    pub fn new(dataset: Arc<dyn Dataset<I>>, start_index: usize, end_index: usize) -> Self {
-        Self {
-            dataset,
-            start_index,
-            end_index,
-        }
-    }
-    pub fn split(dataset: Arc<dyn Dataset<I>>, num: usize) -> Vec<PartialDataset<I>> {
+impl<D, I> PartialDataset<D, I>
+where
+    D: Dataset<I>,
+{
+    pub fn split(dataset: D, num: usize) -> Vec<PartialDataset<Arc<D>, I>> {
+        let dataset = Arc::new(dataset); // cheap cloning.
+
         let mut current = 0;
         let mut datasets = Vec::with_capacity(num);
 
@@ -39,8 +40,9 @@ impl<I> PartialDataset<I> {
     }
 }
 
-impl<I> Dataset<I> for PartialDataset<I>
+impl<D, I> Dataset<I> for PartialDataset<D, I>
 where
+    D: Dataset<I>,
     I: Clone + Send + Sync,
 {
     fn get(&self, index: usize) -> Option<I> {
@@ -64,20 +66,18 @@ mod tests {
 
     #[test]
     fn test_start_from_beginning() {
-        let dataset_original = Arc::new(FakeDataset::<String>::new(27));
-        let dataset_partial = PartialDataset::new(dataset_original.clone(), 0, 10);
-
+        let dataset_original = FakeDataset::<String>::new(27);
         let mut items_original_1 = HashSet::new();
         let mut items_original_2 = HashSet::new();
         let mut items_partial = HashSet::new();
+        dataset_original.iter().enumerate().for_each(|(i, item)| {
+            match i >= 10 {
+                true => items_original_2.insert(item),
+                false => items_original_1.insert(item),
+            };
+        });
 
-        for (i, item) in dataset_original.iter().enumerate() {
-            if i >= 10 {
-                items_original_2.insert(item);
-            } else {
-                items_original_1.insert(item);
-            }
-        }
+        let dataset_partial = PartialDataset::new(dataset_original, 0, 10);
 
         for item in dataset_partial.iter() {
             items_partial.insert(item);
@@ -92,21 +92,19 @@ mod tests {
 
     #[test]
     fn test_start_inside() {
-        let dataset_original = Arc::new(FakeDataset::<String>::new(27));
-        let dataset_partial = PartialDataset::new(dataset_original.clone(), 10, 20);
-
+        let dataset_original = FakeDataset::<String>::new(27);
         let mut items_original_1 = HashSet::new();
         let mut items_original_2 = HashSet::new();
         let mut items_partial = HashSet::new();
 
-        for (i, item) in dataset_original.iter().enumerate() {
-            if !(10..20).contains(&i) {
-                items_original_2.insert(item);
-            } else {
-                items_original_1.insert(item);
-            }
-        }
+        dataset_original.iter().enumerate().for_each(|(i, item)| {
+            match !(10..20).contains(&i) {
+                true => items_original_2.insert(item),
+                false => items_original_1.insert(item),
+            };
+        });
 
+        let dataset_partial = PartialDataset::new(dataset_original, 10, 20);
         for item in dataset_partial.iter() {
             items_partial.insert(item);
         }
@@ -120,16 +118,15 @@ mod tests {
 
     #[test]
     fn test_split_contains_all_items_without_duplicates() {
-        let dataset_original = Arc::new(FakeDataset::<String>::new(27));
-        let dataset_partials = PartialDataset::split(dataset_original.clone(), 4);
-
+        let dataset_original = FakeDataset::<String>::new(27);
         let mut items_original = Vec::new();
         let mut items_partial = Vec::new();
-
        for item in dataset_original.iter() {
             items_original.push(item);
         }
 
+        let dataset_partials = PartialDataset::split(dataset_original, 4);
+
         for dataset in dataset_partials {
             for item in dataset.iter() {
                 items_partial.push(item);
diff --git a/burn-dataset/src/transform/random.rs b/burn-dataset/src/transform/random.rs
index a0f6c2354..e6c13182d 100644
--- a/burn-dataset/src/transform/random.rs
+++ b/burn-dataset/src/transform/random.rs
@@ -1,32 +1,43 @@
 use crate::Dataset;
 use rand::{prelude::SliceRandom, rngs::StdRng, SeedableRng};
-use std::sync::Arc;
+use std::marker::PhantomData;
 
-pub struct ShuffledDataset<I> {
-    dataset: Arc<dyn Dataset<I>>,
+/// Shuffles a dataset; consider using the [sampler dataset](crate::transform::SamplerDataset) if you
+/// want a probability distribution that is computed lazily.
+pub struct ShuffledDataset<D, I> {
+    dataset: D,
     indexes: Vec<usize>,
+    input: PhantomData<I>,
 }
 
-impl<I> ShuffledDataset<I> {
-    pub fn new(dataset: Arc<dyn Dataset<I>>, rng: &mut StdRng) -> Self {
+impl<D, I> ShuffledDataset<D, I>
+where
+    D: Dataset<I>,
+{
+    pub fn new(dataset: D, rng: &mut StdRng) -> Self {
         let mut indexes = Vec::with_capacity(dataset.len());
         for i in 0..dataset.len() {
             indexes.push(i);
         }
         indexes.shuffle(rng);
 
-        Self { dataset, indexes }
+        Self {
+            dataset,
+            indexes,
+            input: PhantomData::default(),
+        }
     }
 
-    pub fn with_seed(dataset: Arc<dyn Dataset<I>>, seed: u64) -> Self {
+    pub fn with_seed(dataset: D, seed: u64) -> Self {
         let mut rng = StdRng::seed_from_u64(seed);
         Self::new(dataset, &mut rng)
     }
 }
 
-impl<I> Dataset<I> for ShuffledDataset<I>
+impl<D, I> Dataset<I> for ShuffledDataset<D, I>
 where
-    I: Clone,
+    D: Dataset<I>,
+    I: Clone + Send + Sync,
 {
     fn get(&self, index: usize) -> Option<I> {
         let index = match self.indexes.get(index) {
diff --git a/burn-dataset/src/transform/sampler.rs b/burn-dataset/src/transform/sampler.rs
index e768b8a06..ab20e910b 100644
--- a/burn-dataset/src/transform/sampler.rs
+++ b/burn-dataset/src/transform/sampler.rs
@@ -1,32 +1,44 @@
 use crate::Dataset;
 use rand::{distributions::Uniform, rngs::StdRng, Rng, SeedableRng};
-use std::sync::Mutex;
+use std::{marker::PhantomData, sync::Mutex};
 
-pub struct SamplerDataset<I> {
-    dataset: Box<dyn Dataset<I>>,
+/// Sample items from a dataset with replacement.
+///
+/// This is an efficient way of modeling a dataset as a probability distribution of a fixed size.
+pub struct SamplerDataset<D, I> {
+    dataset: D,
     size: usize,
     rng: Mutex<StdRng>,
+    input: PhantomData<I>,
 }
 
-impl<I> SamplerDataset<I> {
-    pub fn from_dataset<D: Dataset<I> + 'static>(dataset: D, size: usize) -> Self {
-        Self::new(Box::new(dataset), size)
-    }
-
-    pub fn new(dataset: Box<dyn Dataset<I>>, size: usize) -> Self {
+impl<D, I> SamplerDataset<D, I>
+where
+    D: Dataset<I>,
+    I: Send + Sync,
+{
+    pub fn new(dataset: D, size: usize) -> Self {
         let rng = Mutex::new(StdRng::from_entropy());
 
-        Self { dataset, size, rng }
+        Self {
+            dataset,
+            size,
+            rng,
+            input: PhantomData::default(),
+        }
     }
 
     fn index(&self) -> usize {
-        let distribution = Uniform::new(0, self.dataset.len());
         let mut rng = self.rng.lock().unwrap();
-        rng.sample(distribution)
+        rng.sample(Uniform::new(0, self.dataset.len()))
     }
 }
 
-impl<I> Dataset<I> for SamplerDataset<I> {
+impl<D, I> Dataset<I> for SamplerDataset<D, I>
+where
+    D: Dataset<I>,
+    I: Send + Sync,
+{
     fn get(&self, _index: usize) -> Option<I> {
         self.dataset.get(self.index())
     }
diff --git a/examples/mnist/src/training.rs b/examples/mnist/src/training.rs
index 0878c9fb7..9da6d78bb 100644
--- a/examples/mnist/src/training.rs
+++ b/examples/mnist/src/training.rs
@@ -1,5 +1,3 @@
-use std::sync::Arc;
-
 use crate::data::MNISTBatcher;
 use crate::model::Model;
 
@@ -43,18 +41,19 @@ pub fn run<B: ADBackend>(device: B::Device) {
     B::seed(config.seed);
 
     // Data
-    let batcher_train = Arc::new(MNISTBatcher::<B>::new(device.clone()));
-    let batcher_valid = Arc::new(MNISTBatcher::<B::InnerBackend>::new(device.clone()));
+    let batcher_train = MNISTBatcher::<B>::new(device.clone());
+    let batcher_valid = MNISTBatcher::<B::InnerBackend>::new(device.clone());
+
     let dataloader_train = DataLoaderBuilder::new(batcher_train)
         .batch_size(config.batch_size)
         .shuffle(config.seed)
         .num_workers(config.num_workers)
-        .build(Arc::new(MNISTDataset::train()));
+        .build(MNISTDataset::train());
     let dataloader_test = DataLoaderBuilder::new(batcher_valid)
         .batch_size(config.batch_size)
         .shuffle(config.seed)
         .num_workers(config.num_workers)
-        .build(Arc::new(MNISTDataset::test()));
+        .build(MNISTDataset::test());
 
     // Model
     let learner = LearnerBuilder::new(ARTIFACT_DIR)
diff --git a/examples/text-classification/src/training.rs b/examples/text-classification/src/training.rs
index 7fbeaf487..67ac5f736 100644
--- a/examples/text-classification/src/training.rs
+++ b/examples/text-classification/src/training.rs
@@ -37,25 +37,21 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
     config: ExperimentConfig,
     artifact_dir: &str,
 ) {
-    let dataset_train = Arc::new(SamplerDataset::new(Box::new(dataset_train), 50_000));
-    let dataset_test = Arc::new(SamplerDataset::new(Box::new(dataset_test), 5_000));
-    let n_classes = D::num_classes();
-
     let tokenizer = Arc::new(BertCasedTokenizer::default());
-    let batcher_train = Arc::new(TextClassificationBatcher::<B>::new(
+    let batcher_train = TextClassificationBatcher::<B>::new(
         tokenizer.clone(),
         device.clone(),
         config.max_seq_length,
-    ));
-    let batcher_test = Arc::new(TextClassificationBatcher::<B::InnerBackend>::new(
+    );
+    let batcher_test = TextClassificationBatcher::<B::InnerBackend>::new(
         tokenizer.clone(),
         device.clone(),
         config.max_seq_length,
-    ));
+    );
 
     let model = TextClassificationModelConfig::new(
         config.transformer.clone(),
-        n_classes,
+        D::num_classes(),
         tokenizer.vocab_size(),
         config.max_seq_length,
     )
@@ -64,12 +60,12 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
     let dataloader_train = DataLoaderBuilder::new(batcher_train)
         .batch_size(config.batch_size)
         .num_workers(4)
-        .build(dataset_train);
+        .build(SamplerDataset::new(dataset_train, 50_000));
 
     let dataloader_test = DataLoaderBuilder::new(batcher_test)
         .batch_size(config.batch_size)
         .num_workers(4)
-        .build(dataset_test);
+        .build(SamplerDataset::new(dataset_test, 5_000));
 
     let optim = config.optimizer.init();
     let lr_scheduler = NoamLRSchedulerConfig::new(0.25)
diff --git a/examples/text-generation/src/training.rs b/examples/text-generation/src/training.rs
index eeb279a32..2d8ef45f0 100644
--- a/examples/text-generation/src/training.rs
+++ b/examples/text-generation/src/training.rs
@@ -40,18 +40,9 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(
     config: ExperimentConfig,
     artifact_dir: &str,
 ) {
-    let dataset_train = Arc::new(SamplerDataset::new(Box::new(dataset_train), 10_000));
-    let dataset_test = Arc::new(SamplerDataset::new(Box::new(dataset_test), 1000));
-
     let tokenizer = Arc::new(Gpt2Tokenizer::default());
-    let batcher_train = Arc::new(TextGenerationBatcher::new(
-        tokenizer.clone(),
-        config.max_seq_length,
-    ));
-    let batcher_test = Arc::new(TextGenerationBatcher::new(
-        tokenizer.clone(),
-        config.max_seq_length,
-    ));
+    let batcher_train = TextGenerationBatcher::new(tokenizer.clone(), config.max_seq_length);
+    let batcher_test = TextGenerationBatcher::new(tokenizer.clone(), config.max_seq_length);
 
     let model = TextGenerationModelConfig::new(
         config.transformer.clone(),
@@ -64,12 +55,12 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(
     let dataloader_train = DataLoaderBuilder::new(batcher_train)
         .batch_size(config.batch_size)
         .num_workers(4)
-        .build(dataset_train);
+        .build(SamplerDataset::new(dataset_train, 10_000));
 
     let dataloader_test = DataLoaderBuilder::new(batcher_test)
        .batch_size(config.batch_size)
         .num_workers(4)
-        .build(dataset_test);
+        .build(SamplerDataset::new(dataset_test, 1000));
 
     let accum = 6; // Effective batch size = 6 * 6 = 32.
     let optim = config.optimizer.init();
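
Not part of the patch: a minimal sketch of how the refactored API composes once this change lands. The import paths (burn_dataset::transform::* and the crate-root re-exports) and the toy Double mapper are assumptions made for illustration only; the point is that datasets are now passed by value, and the new blanket Dataset impls for Arc/Box keep cheap sharing available.

use std::sync::Arc;

use burn_dataset::transform::{Mapper, MapperDataset, PartialDataset, ShuffledDataset};
use burn_dataset::{Dataset, InMemDataset};

// Toy mapper, assumed for this example only: doubles every item.
struct Double;

impl Mapper<i32, i32> for Double {
    fn map(&self, item: &i32) -> i32 {
        item * 2
    }
}

fn main() {
    // Datasets are passed by value; no Arc/Box wrapping at the call site.
    let dataset: InMemDataset<i32> = InMemDataset::new((0..100).collect());
    let doubled = MapperDataset::new(dataset, Double);

    // Sharing still works through the blanket `Dataset` impl for `Arc<D>`.
    let shared = Arc::new(doubled);
    let shuffled = ShuffledDataset::with_seed(shared.clone(), 42);
    println!("shared dataset still usable: {} items", shared.len());

    // `split` wraps the dataset in an `Arc` internally and returns partial views.
    for (i, part) in PartialDataset::split(shuffled, 4).iter().enumerate() {
        println!("split {} has {} items", i, part.len());
    }
}

The PhantomData<I> fields on the transform structs only keep the item type nameable now that the Box<dyn Dataset<I>> fields are gone; derive-new skips PhantomData fields in the generated constructors, which is why the updated tests can call MapperDataset::new(dataset, StringToFirstChar) with just two arguments.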