mirror of https://github.com/tracel-ai/burn.git
refactor: burn dataset (#293)
This commit is contained in:
parent 04bcf9550a
commit d4ce825725
@@ -14,9 +14,12 @@ where
     I: Send + Sync + Clone + std::fmt::Debug + 'static,
     O: Send + Sync + Clone + std::fmt::Debug + 'static,
 {
-    pub fn new(batcher: Arc<dyn Batcher<I, O>>) -> Self {
+    pub fn new<B>(batcher: B) -> Self
+    where
+        B: Batcher<I, O> + 'static,
+    {
         Self {
-            batcher,
+            batcher: Arc::new(batcher),
             strategy: None,
             num_threads: None,
             shuffle: None,
@@ -38,10 +41,13 @@ where
         self
     }

-    pub fn build(self, dataset: Arc<dyn Dataset<I>>) -> Arc<dyn DataLoader<O>> {
-        let dataset = match self.shuffle {
+    pub fn build<D>(self, dataset: D) -> Arc<dyn DataLoader<O>>
+    where
+        D: Dataset<I> + 'static,
+    {
+        let dataset: Arc<dyn Dataset<I>> = match self.shuffle {
             Some(seed) => Arc::new(ShuffledDataset::with_seed(dataset, seed)),
-            None => dataset,
+            None => Arc::new(dataset),
         };
         let strategy = match self.strategy {
             Some(strategy) => strategy,
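With both hunks applied, the builder boxes the batcher and the dataset itself, so call sites pass plain values instead of pre-wrapped Arc/Box pointers. A minimal sketch of the new call shape (MyBatcher and MyDataset are hypothetical stand-ins, not types from this commit):

    // Assumes MyBatcher: Batcher<I, O> and MyDataset: Dataset<I>.
    let dataloader = DataLoaderBuilder::new(MyBatcher::new())
        .batch_size(32)
        .shuffle(42)
        .num_workers(4)
        .build(MyDataset::new());

The MNIST and text example updates further down in this commit follow exactly this shape.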
@@ -24,5 +24,6 @@ rand = {workspace = true, features = ["std"]}
 serde = {workspace = true, features = ["std", "derive"]}
 serde_json = {workspace = true, features = ["std"]}
 thiserror = {workspace = true}
+derive-new = {workspace = true}

 [dev-dependencies]
@@ -1,5 +1,8 @@
+use std::sync::Arc;
+
 use crate::DatasetIterator;

+/// The dataset trait defines a basic collection of items with a predefined size.
 pub trait Dataset<I>: Send + Sync {
     fn get(&self, index: usize) -> Option<I>;
     fn len(&self) -> usize;
@@ -13,3 +16,49 @@ pub trait Dataset<I>: Send + Sync {
         DatasetIterator::new(self)
     }
 }
+
+impl<D, I> Dataset<I> for Arc<D>
+where
+    D: Dataset<I>,
+{
+    fn get(&self, index: usize) -> Option<I> {
+        self.as_ref().get(index)
+    }
+
+    fn len(&self) -> usize {
+        self.as_ref().len()
+    }
+}
+
+impl<I> Dataset<I> for Arc<dyn Dataset<I>> {
+    fn get(&self, index: usize) -> Option<I> {
+        self.as_ref().get(index)
+    }
+
+    fn len(&self) -> usize {
+        self.as_ref().len()
+    }
+}
+
+impl<D, I> Dataset<I> for Box<D>
+where
+    D: Dataset<I>,
+{
+    fn get(&self, index: usize) -> Option<I> {
+        self.as_ref().get(index)
+    }
+
+    fn len(&self) -> usize {
+        self.as_ref().len()
+    }
+}
+
+impl<I> Dataset<I> for Box<dyn Dataset<I>> {
+    fn get(&self, index: usize) -> Option<I> {
+        self.as_ref().get(index)
+    }
+
+    fn len(&self) -> usize {
+        self.as_ref().len()
+    }
+}
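These blanket impls make smart-pointer wrappers usable anywhere a D: Dataset<I> bound appears, which is what lets the rest of the commit replace Arc<dyn Dataset<I>> parameters with plain generics. A small sketch, assuming a hypothetical Squares dataset:

    use std::sync::Arc;

    struct Squares;

    impl Dataset<usize> for Squares {
        fn get(&self, index: usize) -> Option<usize> {
            (index < self.len()).then(|| index * index)
        }

        fn len(&self) -> usize {
            10
        }
    }

    // All of these now satisfy a `D: Dataset<usize>` bound.
    let owned = Squares;
    let arced = Arc::new(Squares);
    let boxed: Box<dyn Dataset<usize>> = Box::new(Squares);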
@@ -1,6 +1,7 @@
 use crate::{Dataset, DatasetIterator, InMemDataset};
 use fake::{Dummy, Fake, Faker};

+/// Dataset filled with fake items generated from the [fake](fake) crate.
 pub struct FakeDataset<I> {
     dataset: InMemDataset<I>,
 }
@@ -5,6 +5,7 @@ use std::{

 use crate::Dataset;

+/// Dataset where all items are stored in ram.
 pub struct InMemDataset<I> {
     items: Vec<I>,
 }
@@ -1,6 +1,7 @@
 use crate::dataset::Dataset;
 use std::iter::Iterator;

+/// Dataset iterator.
 pub struct DatasetIterator<'a, I> {
     current: usize,
     dataset: &'a dyn Dataset<I>,
@@ -1,3 +1,6 @@
+#[macro_use]
+extern crate derive_new;
+
 extern crate dirs;

 pub mod source;
@@ -1,17 +1,14 @@
 use crate::Dataset;

-pub struct ComposedDataset<I> {
-    datasets: Vec<Box<dyn Dataset<I>>>,
+/// Compose multiple datasets together to create a bigger one.
+#[derive(new)]
+pub struct ComposedDataset<D> {
+    datasets: Vec<D>,
 }

-impl<I> ComposedDataset<I> {
-    pub fn new(datasets: Vec<Box<dyn Dataset<I>>>) -> Self {
-        Self { datasets }
-    }
-}
-
-impl<I> Dataset<I> for ComposedDataset<I>
+impl<D, I> Dataset<I> for ComposedDataset<D>
 where
+    D: Dataset<I>,
     I: Clone,
 {
     fn get(&self, index: usize) -> Option<I> {
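Since the constructor is now generated by #[derive(new)] (from the derive-new crate added in Cargo.toml above) and the field is Vec<D>, all composed parts must share one concrete type; heterogeneous parts can still be boxed as Box<dyn Dataset<I>> thanks to the new blanket impl. A usage sketch, assuming InMemDataset items as in the tests:

    let part_a = InMemDataset::new(vec![1, 2, 3]);
    let part_b = InMemDataset::new(vec![4, 5, 6]);

    // Generated by #[derive(new)]: ComposedDataset::new(Vec<D>).
    let composed = ComposedDataset::new(vec![part_a, part_b]);
    assert_eq!(composed.len(), 6); // assuming len() sums the parts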
@@ -1,22 +1,22 @@
 use crate::Dataset;
+use std::marker::PhantomData;

-pub trait Mapper<I, O> {
+/// Basic mapper trait to be used with the [mapper dataset](MapperDataset).
+pub trait Mapper<I, O>: Send + Sync {
     fn map(&self, item: &I) -> O;
 }

-pub struct MapperDataset<M, I> {
-    dataset: Box<dyn Dataset<I>>,
+/// Dataset mapping each element in an inner dataset to another element type lazily.
+#[derive(new)]
+pub struct MapperDataset<D, M, I> {
+    dataset: D,
     mapper: M,
+    input: PhantomData<I>,
 }

-impl<M, I> MapperDataset<M, I> {
-    pub fn new(dataset: Box<dyn Dataset<I>>, mapper: M) -> Self {
-        Self { dataset, mapper }
-    }
-}
-
-impl<M, I, O> Dataset<O> for MapperDataset<M, I>
+impl<D, M, I, O> Dataset<O> for MapperDataset<D, M, I>
 where
+    D: Dataset<I>,
     M: Mapper<I, O> + Send + Sync,
     I: Send + Sync,
     O: Send + Sync,
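Note that derive-new skips PhantomData fields, so the generated constructor still takes just the dataset and the mapper; the updated test below relies on this. A minimal sketch with a hypothetical Doubler mapper:

    struct Doubler;

    impl Mapper<i32, i32> for Doubler {
        fn map(&self, item: &i32) -> i32 {
            item * 2
        }
    }

    let dataset = MapperDataset::new(InMemDataset::new(vec![1, 2, 3]), Doubler);
    let doubled: Vec<i32> = dataset.iter().collect(); // [2, 4, 6]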
@@ -38,7 +38,8 @@ mod tests {

     #[test]
     pub fn given_mapper_dataset_when_iterate_should_iterate_though_all_map_items() {
-        struct StringToFirstChar {}
+        struct StringToFirstChar;
+
         impl Mapper<String, String> for StringToFirstChar {
             fn map(&self, item: &String) -> String {
                 let mut item = item.clone();
@@ -46,9 +47,10 @@ mod tests {
                 item
             }
         }
+
         let items_original = test_data::string_items();
         let dataset = InMemDataset::new(items_original);
-        let dataset = MapperDataset::new(Box::new(dataset), StringToFirstChar {});
+        let dataset = MapperDataset::new(dataset, StringToFirstChar);

         let items: Vec<String> = dataset.iter().collect();

@@ -1,21 +1,22 @@
 use crate::Dataset;
-use std::sync::Arc;
+use std::{marker::PhantomData, sync::Arc};

-pub struct PartialDataset<I> {
-    dataset: Arc<dyn Dataset<I>>,
+/// Only use a fraction of an existing dataset lazily.
+#[derive(new)]
+pub struct PartialDataset<D, I> {
+    dataset: D,
     start_index: usize,
     end_index: usize,
+    input: PhantomData<I>,
 }

-impl<I> PartialDataset<I> {
-    pub fn new(dataset: Arc<dyn Dataset<I>>, start_index: usize, end_index: usize) -> Self {
-        Self {
-            dataset,
-            start_index,
-            end_index,
-        }
-    }
-    pub fn split(dataset: Arc<dyn Dataset<I>>, num: usize) -> Vec<PartialDataset<I>> {
+impl<D, I> PartialDataset<D, I>
+where
+    D: Dataset<I>,
+{
+    pub fn split(dataset: D, num: usize) -> Vec<PartialDataset<Arc<D>, I>> {
+        let dataset = Arc::new(dataset); // cheap cloning.
+
         let mut current = 0;
         let mut datasets = Vec::with_capacity(num);

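split now wraps the incoming dataset in an Arc once and hands each shard a clone, which is why it returns PartialDataset<Arc<D>, I>: the shards share a single copy of the data instead of requiring the caller to build an Arc<dyn Dataset<I>> up front. A sketch, assuming an InMemDataset of 100 items:

    let dataset = InMemDataset::new((0..100).collect::<Vec<i32>>());

    // Four shards backed by the same Arc; each should cover ~25 items.
    let shards = PartialDataset::split(dataset, 4);
    for shard in shards {
        println!("shard size: {}", shard.len());
    }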
@@ -39,8 +40,9 @@ impl<I> PartialDataset<I> {
         }
     }

-impl<I> Dataset<I> for PartialDataset<I>
+impl<D, I> Dataset<I> for PartialDataset<D, I>
 where
+    D: Dataset<I>,
     I: Clone + Send + Sync,
 {
     fn get(&self, index: usize) -> Option<I> {
@@ -64,20 +66,18 @@ mod tests {

     #[test]
     fn test_start_from_beginning() {
-        let dataset_original = Arc::new(FakeDataset::<String>::new(27));
-        let dataset_partial = PartialDataset::new(dataset_original.clone(), 0, 10);
+        let dataset_original = FakeDataset::<String>::new(27);

         let mut items_original_1 = HashSet::new();
         let mut items_original_2 = HashSet::new();
         let mut items_partial = HashSet::new();
-
-        for (i, item) in dataset_original.iter().enumerate() {
-            if i >= 10 {
-                items_original_2.insert(item);
-            } else {
-                items_original_1.insert(item);
-            }
-        }
+        dataset_original.iter().enumerate().for_each(|(i, item)| {
+            match i >= 10 {
+                true => items_original_2.insert(item),
+                false => items_original_1.insert(item),
+            };
+        });

+        let dataset_partial = PartialDataset::new(dataset_original, 0, 10);
         for item in dataset_partial.iter() {
             items_partial.insert(item);
@@ -92,21 +92,19 @@ mod tests {

     #[test]
     fn test_start_inside() {
-        let dataset_original = Arc::new(FakeDataset::<String>::new(27));
-        let dataset_partial = PartialDataset::new(dataset_original.clone(), 10, 20);
+        let dataset_original = FakeDataset::<String>::new(27);

         let mut items_original_1 = HashSet::new();
         let mut items_original_2 = HashSet::new();
         let mut items_partial = HashSet::new();
-
-        for (i, item) in dataset_original.iter().enumerate() {
-            if !(10..20).contains(&i) {
-                items_original_2.insert(item);
-            } else {
-                items_original_1.insert(item);
-            }
-        }
+        dataset_original.iter().enumerate().for_each(|(i, item)| {
+            match !(10..20).contains(&i) {
+                true => items_original_2.insert(item),
+                false => items_original_1.insert(item),
+            };
+        });

+        let dataset_partial = PartialDataset::new(dataset_original, 10, 20);
         for item in dataset_partial.iter() {
             items_partial.insert(item);
         }
@@ -120,16 +118,15 @@ mod tests {

     #[test]
     fn test_split_contains_all_items_without_duplicates() {
-        let dataset_original = Arc::new(FakeDataset::<String>::new(27));
-        let dataset_partials = PartialDataset::split(dataset_original.clone(), 4);
+        let dataset_original = FakeDataset::<String>::new(27);

         let mut items_original = Vec::new();
         let mut items_partial = Vec::new();
-
         for item in dataset_original.iter() {
             items_original.push(item);
         }
-
+        let dataset_partials = PartialDataset::split(dataset_original, 4);
+
         for dataset in dataset_partials {
             for item in dataset.iter() {
                 items_partial.push(item);
@@ -1,32 +1,43 @@
 use crate::Dataset;
 use rand::{prelude::SliceRandom, rngs::StdRng, SeedableRng};
-use std::sync::Arc;
+use std::marker::PhantomData;

-pub struct ShuffledDataset<I> {
-    dataset: Arc<dyn Dataset<I>>,
+/// Shuffle a dataset; consider using [sampler dataset](crate::transform::SamplerDataset) if you
+/// want a probability distribution that is computed lazily.
+pub struct ShuffledDataset<D, I> {
+    dataset: D,
     indexes: Vec<usize>,
+    input: PhantomData<I>,
 }

-impl<I> ShuffledDataset<I> {
-    pub fn new(dataset: Arc<dyn Dataset<I>>, rng: &mut StdRng) -> Self {
+impl<D, I> ShuffledDataset<D, I>
+where
+    D: Dataset<I>,
+{
+    pub fn new(dataset: D, rng: &mut StdRng) -> Self {
         let mut indexes = Vec::with_capacity(dataset.len());
         for i in 0..dataset.len() {
             indexes.push(i);
         }
         indexes.shuffle(rng);

-        Self { dataset, indexes }
+        Self {
+            dataset,
+            indexes,
+            input: PhantomData::default(),
+        }
     }

-    pub fn with_seed(dataset: Arc<dyn Dataset<I>>, seed: u64) -> Self {
+    pub fn with_seed(dataset: D, seed: u64) -> Self {
         let mut rng = StdRng::seed_from_u64(seed);
         Self::new(dataset, &mut rng)
     }
 }

-impl<I> Dataset<I> for ShuffledDataset<I>
+impl<D, I> Dataset<I> for ShuffledDataset<D, I>
 where
-    I: Clone,
+    D: Dataset<I>,
+    I: Clone + Send + Sync,
 {
     fn get(&self, index: usize) -> Option<I> {
         let index = match self.indexes.get(index) {
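A usage sketch of the owned-value API; with_seed keeps the permutation reproducible, and the indexes are shuffled once at construction rather than per access:

    let dataset = InMemDataset::new(vec!["a", "b", "c", "d"]);

    // Same seed, same order on every run.
    let shuffled = ShuffledDataset::with_seed(dataset, 42);
    for item in shuffled.iter() {
        println!("{item}");
    }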
@@ -1,32 +1,44 @@
 use crate::Dataset;
 use rand::{distributions::Uniform, rngs::StdRng, Rng, SeedableRng};
-use std::sync::Mutex;
+use std::{marker::PhantomData, sync::Mutex};

-pub struct SamplerDataset<I> {
-    dataset: Box<dyn Dataset<I>>,
+/// Sample items from a dataset with replacement.
+///
+/// This is an efficient way of modeling a dataset as a probability distribution of a fixed size.
+pub struct SamplerDataset<D, I> {
+    dataset: D,
     size: usize,
     rng: Mutex<StdRng>,
+    input: PhantomData<I>,
 }

-impl<I> SamplerDataset<I> {
-    pub fn from_dataset<D: Dataset<I> + 'static>(dataset: D, size: usize) -> Self {
-        Self::new(Box::new(dataset), size)
-    }
-
-    pub fn new(dataset: Box<dyn Dataset<I>>, size: usize) -> Self {
+impl<D, I> SamplerDataset<D, I>
+where
+    D: Dataset<I>,
+    I: Send + Sync,
+{
+    pub fn new(dataset: D, size: usize) -> Self {
         let rng = Mutex::new(StdRng::from_entropy());

-        Self { dataset, size, rng }
+        Self {
+            dataset,
+            size,
+            rng,
+            input: PhantomData::default(),
+        }
     }

     fn index(&self) -> usize {
-        let distribution = Uniform::new(0, self.dataset.len());
         let mut rng = self.rng.lock().unwrap();
-        rng.sample(distribution)
+        rng.sample(Uniform::new(0, self.dataset.len()))
     }
 }

-impl<I> Dataset<I> for SamplerDataset<I> {
+impl<D, I> Dataset<I> for SamplerDataset<D, I>
+where
+    D: Dataset<I>,
+    I: Send + Sync,
+{
     fn get(&self, _index: usize) -> Option<I> {
         self.dataset.get(self.index())
     }
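The rng lives behind a Mutex so that get can sample from &self even when a data loader queries it from multiple worker threads, and the old from_dataset helper is gone because new now takes the dataset by value. A sketch (the fixed size reported by the sampler is an assumption based on the doc comment above):

    let dataset = InMemDataset::new(vec![1, 2, 3]);

    // A fixed-size view that draws uniformly, with replacement, from 3 items.
    let sampled = SamplerDataset::new(dataset, 1000);
    let first = sampled.get(0); // random item each call, since get ignores the index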
@@ -1,5 +1,3 @@
-use std::sync::Arc;
-
 use crate::data::MNISTBatcher;
 use crate::model::Model;

@@ -43,18 +41,19 @@ pub fn run<B: ADBackend>(device: B::Device) {
     B::seed(config.seed);

     // Data
-    let batcher_train = Arc::new(MNISTBatcher::<B>::new(device.clone()));
-    let batcher_valid = Arc::new(MNISTBatcher::<B::InnerBackend>::new(device.clone()));
+    let batcher_train = MNISTBatcher::<B>::new(device.clone());
+    let batcher_valid = MNISTBatcher::<B::InnerBackend>::new(device.clone());
+
     let dataloader_train = DataLoaderBuilder::new(batcher_train)
         .batch_size(config.batch_size)
         .shuffle(config.seed)
         .num_workers(config.num_workers)
-        .build(Arc::new(MNISTDataset::train()));
+        .build(MNISTDataset::train());
     let dataloader_test = DataLoaderBuilder::new(batcher_valid)
         .batch_size(config.batch_size)
         .shuffle(config.seed)
         .num_workers(config.num_workers)
-        .build(Arc::new(MNISTDataset::test()));
+        .build(MNISTDataset::test());

     // Model
     let learner = LearnerBuilder::new(ARTIFACT_DIR)
@@ -37,25 +37,21 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
     config: ExperimentConfig,
     artifact_dir: &str,
 ) {
-    let dataset_train = Arc::new(SamplerDataset::new(Box::new(dataset_train), 50_000));
-    let dataset_test = Arc::new(SamplerDataset::new(Box::new(dataset_test), 5_000));
-    let n_classes = D::num_classes();
-
     let tokenizer = Arc::new(BertCasedTokenizer::default());
-    let batcher_train = Arc::new(TextClassificationBatcher::<B>::new(
+    let batcher_train = TextClassificationBatcher::<B>::new(
         tokenizer.clone(),
         device.clone(),
         config.max_seq_length,
-    ));
-    let batcher_test = Arc::new(TextClassificationBatcher::<B::InnerBackend>::new(
+    );
+    let batcher_test = TextClassificationBatcher::<B::InnerBackend>::new(
         tokenizer.clone(),
         device.clone(),
         config.max_seq_length,
-    ));
+    );

     let model = TextClassificationModelConfig::new(
         config.transformer.clone(),
-        n_classes,
+        D::num_classes(),
         tokenizer.vocab_size(),
         config.max_seq_length,
     )
@@ -64,12 +60,12 @@ pub fn train<B: ADBackend, D: TextClassificationDataset + 'static>(
     let dataloader_train = DataLoaderBuilder::new(batcher_train)
         .batch_size(config.batch_size)
         .num_workers(4)
-        .build(dataset_train);
+        .build(SamplerDataset::new(dataset_train, 50_000));

     let dataloader_test = DataLoaderBuilder::new(batcher_test)
         .batch_size(config.batch_size)
         .num_workers(4)
-        .build(dataset_test);
+        .build(SamplerDataset::new(dataset_test, 5_000));

     let optim = config.optimizer.init();
     let lr_scheduler = NoamLRSchedulerConfig::new(0.25)
@@ -40,18 +40,9 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(
     config: ExperimentConfig,
     artifact_dir: &str,
 ) {
-    let dataset_train = Arc::new(SamplerDataset::new(Box::new(dataset_train), 10_000));
-    let dataset_test = Arc::new(SamplerDataset::new(Box::new(dataset_test), 1000));
-
     let tokenizer = Arc::new(Gpt2Tokenizer::default());
-    let batcher_train = Arc::new(TextGenerationBatcher::new(
-        tokenizer.clone(),
-        config.max_seq_length,
-    ));
-    let batcher_test = Arc::new(TextGenerationBatcher::new(
-        tokenizer.clone(),
-        config.max_seq_length,
-    ));
+    let batcher_train = TextGenerationBatcher::new(tokenizer.clone(), config.max_seq_length);
+    let batcher_test = TextGenerationBatcher::new(tokenizer.clone(), config.max_seq_length);

     let model = TextGenerationModelConfig::new(
         config.transformer.clone(),
@@ -64,12 +55,12 @@ pub fn train<B: ADBackend, D: Dataset<TextGenerationItem> + 'static>(
     let dataloader_train = DataLoaderBuilder::new(batcher_train)
         .batch_size(config.batch_size)
         .num_workers(4)
-        .build(dataset_train);
+        .build(SamplerDataset::new(dataset_train, 10_000));

     let dataloader_test = DataLoaderBuilder::new(batcher_test)
         .batch_size(config.batch_size)
         .num_workers(4)
-        .build(dataset_test);
+        .build(SamplerDataset::new(dataset_test, 1000));

     let accum = 6; // Effective batch size = 6 * 6 = 36.
     let optim = config.optimizer.init();