refactor: burn dataset (#293)

Nathaniel Simard 2023-04-11 18:12:28 -04:00 committed by GitHub
parent 04bcf9550a
commit d4ce825725
15 changed files with 182 additions and 115 deletions

View File

@@ -14,9 +14,12 @@ where
I: Send + Sync + Clone + std::fmt::Debug + 'static,
O: Send + Sync + Clone + std::fmt::Debug + 'static,
{
pub fn new(batcher: Arc<dyn Batcher<I, O>>) -> Self {
pub fn new<B>(batcher: B) -> Self
where
B: Batcher<I, O> + 'static,
{
Self {
batcher,
batcher: Arc::new(batcher),
strategy: None,
num_threads: None,
shuffle: None,
@@ -38,10 +41,13 @@ where
self
}
pub fn build(self, dataset: Arc<dyn Dataset<I>>) -> Arc<dyn DataLoader<O>> {
let dataset = match self.shuffle {
pub fn build<D>(self, dataset: D) -> Arc<dyn DataLoader<O>>
where
D: Dataset<I> + 'static,
{
let dataset: Arc<dyn Dataset<I>> = match self.shuffle {
Some(seed) => Arc::new(ShuffledDataset::with_seed(dataset, seed)),
None => dataset,
None => Arc::new(dataset),
};
let strategy = match self.strategy {
Some(strategy) => strategy,

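The builder no longer takes trait objects: `new` accepts any `B: Batcher<I, O>` and `build` any `D: Dataset<I>`, and both are wrapped in an `Arc` internally. A minimal usage sketch of the new API, assuming hypothetical `MyBatcher` and `MyDataset` types that implement those traits:

let dataloader = DataLoaderBuilder::new(MyBatcher::new())
    .batch_size(32)
    .shuffle(42)
    .num_workers(4)
    .build(MyDataset::new());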
View File

@@ -24,5 +24,6 @@ rand = {workspace = true, features = ["std"]}
serde = {workspace = true, features = ["std", "derive"]}
serde_json = {workspace = true, features = ["std"]}
thiserror = {workspace = true}
derive-new = {workspace = true}
[dev-dependencies]

View File

@@ -1,5 +1,8 @@
use std::sync::Arc;
use crate::DatasetIterator;
/// The dataset trait defines a basic collection of items with a predefined size.
pub trait Dataset<I>: Send + Sync {
fn get(&self, index: usize) -> Option<I>;
fn len(&self) -> usize;
@@ -13,3 +16,49 @@ pub trait Dataset<I>: Send + Sync {
DatasetIterator::new(self)
}
}
impl<D, I> Dataset<I> for Arc<D>
where
D: Dataset<I>,
{
fn get(&self, index: usize) -> Option<I> {
self.as_ref().get(index)
}
fn len(&self) -> usize {
self.as_ref().len()
}
}
impl<I> Dataset<I> for Arc<dyn Dataset<I>> {
fn get(&self, index: usize) -> Option<I> {
self.as_ref().get(index)
}
fn len(&self) -> usize {
self.as_ref().len()
}
}
impl<D, I> Dataset<I> for Box<D>
where
D: Dataset<I>,
{
fn get(&self, index: usize) -> Option<I> {
self.as_ref().get(index)
}
fn len(&self) -> usize {
self.as_ref().len()
}
}
impl<I> Dataset<I> for Box<dyn Dataset<I>> {
fn get(&self, index: usize) -> Option<I> {
self.as_ref().get(index)
}
fn len(&self) -> usize {
self.as_ref().len()
}
}

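These blanket impls let smart pointers stand in wherever a generic `D: Dataset<I>` is expected, so existing `Arc<dyn Dataset<I>>` or `Box<dyn Dataset<I>>` values keep working with the new generic APIs. A small sketch, assuming `InMemDataset` as the concrete dataset:

fn count<D: Dataset<String>>(dataset: &D) -> usize {
    dataset.len()
}

let concrete = InMemDataset::new(vec!["a".to_string(), "b".to_string()]);
let boxed: Box<dyn Dataset<String>> = Box::new(InMemDataset::new(vec!["c".to_string()]));

count(&concrete); // 2
count(&boxed); // 1, through the `Box<dyn Dataset<I>>` impl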
View File

@@ -1,6 +1,7 @@
use crate::{Dataset, DatasetIterator, InMemDataset};
use fake::{Dummy, Fake, Faker};
/// Dataset filled with fake items generated from the [fake](fake) crate.
pub struct FakeDataset<I> {
dataset: InMemDataset<I>,
}

View File

@@ -5,6 +5,7 @@ use std::{
use crate::Dataset;
/// Dataset where all items are stored in RAM.
pub struct InMemDataset<I> {
items: Vec<I>,
}

View File

@@ -1,6 +1,7 @@
use crate::dataset::Dataset;
use std::iter::Iterator;
/// Dataset iterator.
pub struct DatasetIterator<'a, I> {
current: usize,
dataset: &'a dyn Dataset<I>,

View File

@@ -1,3 +1,6 @@
#[macro_use]
extern crate derive_new;
extern crate dirs;
pub mod source;

View File

@@ -1,17 +1,14 @@
use crate::Dataset;
pub struct ComposedDataset<I> {
datasets: Vec<Box<dyn Dataset<I>>>,
/// Compose multiple datasets together to create a bigger one.
#[derive(new)]
pub struct ComposedDataset<D> {
datasets: Vec<D>,
}
impl<I> ComposedDataset<I> {
pub fn new(datasets: Vec<Box<dyn Dataset<I>>>) -> Self {
Self { datasets }
}
}
impl<I> Dataset<I> for ComposedDataset<I>
impl<D, I> Dataset<I> for ComposedDataset<D>
where
D: Dataset<I>,
I: Clone,
{
fn get(&self, index: usize) -> Option<I> {

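Since `ComposedDataset<D>` is now generic over the inner dataset type, composing same-typed datasets needs no boxing; mixed concrete types can still be boxed, because `Box<dyn Dataset<I>>` implements `Dataset<I>`. A sketch under those assumptions:

let first = InMemDataset::new(vec![0, 1, 2]);
let second = InMemDataset::new(vec![3, 4]);
let composed = ComposedDataset::new(vec![first, second]);

// Iterates over the items of both inner datasets.
for item in composed.iter() {
    println!("{item}");
}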
View File

@@ -1,22 +1,22 @@
use crate::Dataset;
use std::marker::PhantomData;
pub trait Mapper<I, O> {
/// Basic mapper trait to be used with the [mapper dataset](MapperDataset).
pub trait Mapper<I, O>: Send + Sync {
fn map(&self, item: &I) -> O;
}
pub struct MapperDataset<M, I> {
dataset: Box<dyn Dataset<I>>,
/// Dataset mapping each element in an inner dataset to another element type lazily.
#[derive(new)]
pub struct MapperDataset<D, M, I> {
dataset: D,
mapper: M,
input: PhantomData<I>,
}
impl<M, I> MapperDataset<M, I> {
pub fn new(dataset: Box<dyn Dataset<I>>, mapper: M) -> Self {
Self { dataset, mapper }
}
}
impl<M, I, O> Dataset<O> for MapperDataset<M, I>
impl<D, M, I, O> Dataset<O> for MapperDataset<D, M, I>
where
D: Dataset<I>,
M: Mapper<I, O> + Send + Sync,
I: Send + Sync,
O: Send + Sync,
@@ -38,7 +38,8 @@ mod tests {
#[test]
pub fn given_mapper_dataset_when_iterate_should_iterate_though_all_map_items() {
struct StringToFirstChar {}
struct StringToFirstChar;
impl Mapper<String, String> for StringToFirstChar {
fn map(&self, item: &String) -> String {
let mut item = item.clone();
@@ -46,9 +47,10 @@ mod tests {
item
}
}
let items_original = test_data::string_items();
let dataset = InMemDataset::new(items_original);
let dataset = MapperDataset::new(Box::new(dataset), StringToFirstChar {});
let dataset = MapperDataset::new(dataset, StringToFirstChar);
let items: Vec<String> = dataset.iter().collect();

View File

@@ -1,21 +1,22 @@
use crate::Dataset;
use std::sync::Arc;
use std::{marker::PhantomData, sync::Arc};
pub struct PartialDataset<I> {
dataset: Arc<dyn Dataset<I>>,
/// Use only a fraction of an existing dataset, lazily.
#[derive(new)]
pub struct PartialDataset<D, I> {
dataset: D,
start_index: usize,
end_index: usize,
input: PhantomData<I>,
}
impl<I> PartialDataset<I> {
pub fn new(dataset: Arc<dyn Dataset<I>>, start_index: usize, end_index: usize) -> Self {
Self {
dataset,
start_index,
end_index,
}
}
pub fn split(dataset: Arc<dyn Dataset<I>>, num: usize) -> Vec<PartialDataset<I>> {
impl<D, I> PartialDataset<D, I>
where
D: Dataset<I>,
{
pub fn split(dataset: D, num: usize) -> Vec<PartialDataset<Arc<D>, I>> {
let dataset = Arc::new(dataset); // cheap cloning.
let mut current = 0;
let mut datasets = Vec::with_capacity(num);
@@ -39,8 +40,9 @@ impl<I> PartialDataset<I> {
}
}
impl<I> Dataset<I> for PartialDataset<I>
impl<D, I> Dataset<I> for PartialDataset<D, I>
where
D: Dataset<I>,
I: Clone + Send + Sync,
{
fn get(&self, index: usize) -> Option<I> {
@@ -64,20 +66,18 @@ mod tests {
#[test]
fn test_start_from_beginning() {
let dataset_original = Arc::new(FakeDataset::<String>::new(27));
let dataset_partial = PartialDataset::new(dataset_original.clone(), 0, 10);
let dataset_original = FakeDataset::<String>::new(27);
let mut items_original_1 = HashSet::new();
let mut items_original_2 = HashSet::new();
let mut items_partial = HashSet::new();
dataset_original.iter().enumerate().for_each(|(i, item)| {
match i >= 10 {
true => items_original_2.insert(item),
false => items_original_1.insert(item),
};
});
for (i, item) in dataset_original.iter().enumerate() {
if i >= 10 {
items_original_2.insert(item);
} else {
items_original_1.insert(item);
}
}
let dataset_partial = PartialDataset::new(dataset_original, 0, 10);
for item in dataset_partial.iter() {
items_partial.insert(item);
@@ -92,21 +92,19 @@ mod tests {
#[test]
fn test_start_inside() {
let dataset_original = Arc::new(FakeDataset::<String>::new(27));
let dataset_partial = PartialDataset::new(dataset_original.clone(), 10, 20);
let dataset_original = FakeDataset::<String>::new(27);
let mut items_original_1 = HashSet::new();
let mut items_original_2 = HashSet::new();
let mut items_partial = HashSet::new();
for (i, item) in dataset_original.iter().enumerate() {
if !(10..20).contains(&i) {
items_original_2.insert(item);
} else {
items_original_1.insert(item);
}
}
dataset_original.iter().enumerate().for_each(|(i, item)| {
match !(10..20).contains(&i) {
true => items_original_2.insert(item),
false => items_original_1.insert(item),
};
});
let dataset_partial = PartialDataset::new(dataset_original, 10, 20);
for item in dataset_partial.iter() {
items_partial.insert(item);
}
@@ -120,16 +118,15 @@ mod tests {
#[test]
fn test_split_contains_all_items_without_duplicates() {
let dataset_original = Arc::new(FakeDataset::<String>::new(27));
let dataset_partials = PartialDataset::split(dataset_original.clone(), 4);
let dataset_original = FakeDataset::<String>::new(27);
let mut items_original = Vec::new();
let mut items_partial = Vec::new();
for item in dataset_original.iter() {
items_original.push(item);
}
let dataset_partials = PartialDataset::split(dataset_original, 4);
for dataset in dataset_partials {
for item in dataset.iter() {
items_partial.push(item);

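`split` now takes the dataset by value, wraps it in an `Arc` internally for cheap cloning, and returns parts that all share the same underlying data. A sketch, assuming a small in-memory dataset:

let full = InMemDataset::new((0..100).collect::<Vec<i32>>());
let parts = PartialDataset::split(full, 4); // Vec<PartialDataset<Arc<InMemDataset<i32>>, i32>>

for part in parts {
    // Each part covers roughly a quarter of the original items.
    let _items: Vec<i32> = part.iter().collect();
}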
View File

@@ -1,32 +1,43 @@
use crate::Dataset;
use rand::{prelude::SliceRandom, rngs::StdRng, SeedableRng};
use std::sync::Arc;
use std::marker::PhantomData;
pub struct ShuffledDataset<I> {
dataset: Arc<dyn Dataset<I>>,
/// Shuffles a dataset. Consider using the [sampler dataset](crate::transform::SamplerDataset) if you
/// want a probability distribution that is computed lazily.
pub struct ShuffledDataset<D, I> {
dataset: D,
indexes: Vec<usize>,
input: PhantomData<I>,
}
impl<I> ShuffledDataset<I> {
pub fn new(dataset: Arc<dyn Dataset<I>>, rng: &mut StdRng) -> Self {
impl<D, I> ShuffledDataset<D, I>
where
D: Dataset<I>,
{
pub fn new(dataset: D, rng: &mut StdRng) -> Self {
let mut indexes = Vec::with_capacity(dataset.len());
for i in 0..dataset.len() {
indexes.push(i);
}
indexes.shuffle(rng);
Self { dataset, indexes }
Self {
dataset,
indexes,
input: PhantomData::default(),
}
}
pub fn with_seed(dataset: Arc<dyn Dataset<I>>, seed: u64) -> Self {
pub fn with_seed(dataset: D, seed: u64) -> Self {
let mut rng = StdRng::seed_from_u64(seed);
Self::new(dataset, &mut rng)
}
}
impl<I> Dataset<I> for ShuffledDataset<I>
impl<D, I> Dataset<I> for ShuffledDataset<D, I>
where
I: Clone,
D: Dataset<I>,
I: Clone + Send + Sync,
{
fn get(&self, index: usize) -> Option<I> {
let index = match self.indexes.get(index) {

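`ShuffledDataset` now wraps any `D: Dataset<I>` and only permutes the access order through its precomputed `indexes`; the inner dataset is left untouched. A deterministic-shuffle sketch with a fixed seed:

let items: Vec<String> = vec!["a", "b", "c", "d"].into_iter().map(String::from).collect();
let shuffled = ShuffledDataset::with_seed(InMemDataset::new(items), 42);

for item in shuffled.iter() {
    println!("{item}");
}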
View File

@@ -1,32 +1,44 @@
use crate::Dataset;
use rand::{distributions::Uniform, rngs::StdRng, Rng, SeedableRng};
use std::sync::Mutex;
use std::{marker::PhantomData, sync::Mutex};
pub struct SamplerDataset<I> {
dataset: Box<dyn Dataset<I>>,
/// Sample items from a dataset with replacement.
///
/// This is an efficient way of modeling a dataset as a probability distribution of a fixed size.
pub struct SamplerDataset<D, I> {
dataset: D,
size: usize,
rng: Mutex<StdRng>,
input: PhantomData<I>,
}
impl<I> SamplerDataset<I> {
pub fn from_dataset<D: Dataset<I> + 'static>(dataset: D, size: usize) -> Self {
Self::new(Box::new(dataset), size)
}
pub fn new(dataset: Box<dyn Dataset<I>>, size: usize) -> Self {
impl<D, I> SamplerDataset<D, I>
where
D: Dataset<I>,
I: Send + Sync,
{
pub fn new(dataset: D, size: usize) -> Self {
let rng = Mutex::new(StdRng::from_entropy());
Self { dataset, size, rng }
Self {
dataset,
size,
rng,
input: PhantomData::default(),
}
}
fn index(&self) -> usize {
let distribution = Uniform::new(0, self.dataset.len());
let mut rng = self.rng.lock().unwrap();
rng.sample(distribution)
rng.sample(Uniform::new(0, self.dataset.len()))
}
}
impl<I> Dataset<I> for SamplerDataset<I> {
impl<D, I> Dataset<I> for SamplerDataset<D, I>
where
D: Dataset<I>,
I: Send + Sync,
{
fn get(&self, _index: usize) -> Option<I> {
self.dataset.get(self.index())
}

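Sampling with replacement decouples the number of items seen per epoch from the size of the wrapped dataset, which the text examples further down use to fix the sample count per run. A sketch, assuming `len` reports the requested `size`:

let small = InMemDataset::new(vec![1, 2, 3]);
let sampled = SamplerDataset::new(small, 1_000);

// `get` draws a random item from the wrapped dataset on every call, so iterating
// yields 1_000 samples backed by only 3 underlying items.
let samples: Vec<i32> = sampled.iter().collect();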
View File

@@ -1,5 +1,3 @@
use std::sync::Arc;
use crate::data::MNISTBatcher;
use crate::model::Model;
@@ -43,18 +41,19 @@ pub fn run<B: ADBackend>(device: B::Device) {
B::seed(config.seed);
// Data
let batcher_train = Arc::new(MNISTBatcher::<B>::new(device.clone()));
let batcher_valid = Arc::new(MNISTBatcher::<B::InnerBackend>::new(device.clone()));
let batcher_train = MNISTBatcher::<B>::new(device.clone());
let batcher_valid = MNISTBatcher::<B::InnerBackend>::new(device.clone());
let dataloader_train = DataLoaderBuilder::new(batcher_train)
.batch_size(config.batch_size)
.shuffle(config.seed)
.num_workers(config.num_workers)
.build(Arc::new(MNISTDataset::train()));
.build(MNISTDataset::train());
let dataloader_test = DataLoaderBuilder::new(batcher_valid)
.batch_size(config.batch_size)
.shuffle(config.seed)
.num_workers(config.num_workers)
.build(Arc::new(MNISTDataset::test()));
.build(MNISTDataset::test());
// Model
let learner = LearnerBuilder::new(ARTIFACT_DIR)

View File

@@ -37,25 +37,21 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
config: ExperimentConfig,
artifact_dir: &str,
) {
let dataset_train = Arc::new(SamplerDataset::new(Box::new(dataset_train), 50_000));
let dataset_test = Arc::new(SamplerDataset::new(Box::new(dataset_test), 5_000));
let n_classes = D::num_classes();
let tokenizer = Arc::new(BertCasedTokenizer::default());
let batcher_train = Arc::new(TextClassificationBatcher::<B>::new(
let batcher_train = TextClassificationBatcher::<B>::new(
tokenizer.clone(),
device.clone(),
config.max_seq_length,
));
let batcher_test = Arc::new(TextClassificationBatcher::<B::InnerBackend>::new(
);
let batcher_test = TextClassificationBatcher::<B::InnerBackend>::new(
tokenizer.clone(),
device.clone(),
config.max_seq_length,
));
);
let model = TextClassificationModelConfig::new(
config.transformer.clone(),
n_classes,
D::num_classes(),
tokenizer.vocab_size(),
config.max_seq_length,
)
@@ -64,12 +60,12 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
let dataloader_train = DataLoaderBuilder::new(batcher_train)
.batch_size(config.batch_size)
.num_workers(4)
.build(dataset_train);
.build(SamplerDataset::new(dataset_train, 50_000));
let dataloader_test = DataLoaderBuilder::new(batcher_test)
.batch_size(config.batch_size)
.num_workers(4)
.build(dataset_test);
.build(SamplerDataset::new(dataset_test, 5_000));
let optim = config.optimizer.init();
let lr_scheduler = NoamLRSchedulerConfig::new(0.25)

View File

@@ -40,18 +40,9 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(
config: ExperimentConfig,
artifact_dir: &str,
) {
let dataset_train = Arc::new(SamplerDataset::new(Box::new(dataset_train), 10_000));
let dataset_test = Arc::new(SamplerDataset::new(Box::new(dataset_test), 1000));
let tokenizer = Arc::new(Gpt2Tokenizer::default());
let batcher_train = Arc::new(TextGenerationBatcher::new(
tokenizer.clone(),
config.max_seq_length,
));
let batcher_test = Arc::new(TextGenerationBatcher::new(
tokenizer.clone(),
config.max_seq_length,
));
let batcher_train = TextGenerationBatcher::new(tokenizer.clone(), config.max_seq_length);
let batcher_test = TextGenerationBatcher::new(tokenizer.clone(), config.max_seq_length);
let model = TextGenerationModelConfig::new(
config.transformer.clone(),
@@ -64,12 +55,12 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(
let dataloader_train = DataLoaderBuilder::new(batcher_train)
.batch_size(config.batch_size)
.num_workers(4)
.build(dataset_train);
.build(SamplerDataset::new(dataset_train, 10_000));
let dataloader_test = DataLoaderBuilder::new(batcher_test)
.batch_size(config.batch_size)
.num_workers(4)
.build(dataset_test);
.build(SamplerDataset::new(dataset_test, 1000));
let accum = 6; // Effective batch size = 6 * 6 = 36.
let optim = config.optimizer.init();