mirror of https://github.com/tracel-ai/burn.git
Add vision/mnist dataset (#1176)
This commit is contained in:
parent
0368409eb3
commit
b9bd42959b
|
@ -82,6 +82,7 @@ strum_macros = "0.25.3"
|
|||
syn = { version = "2.0", features = ["full", "extra-traits"] }
|
||||
tempfile = "3.8.1"
|
||||
thiserror = "1.0.50"
|
||||
tokio = { version = "1.35.1", features = ["rt", "macros"] }
|
||||
tracing-appender = "0.2.3"
|
||||
tracing-core = "0.1.32"
|
||||
tracing-subscriber = "0.3.18"
|
||||
|
|
|
@ -11,7 +11,7 @@ used as input to our previously defined model.
|
|||
|
||||
```rust , ignore
|
||||
use burn::{
|
||||
data::{dataloader::batcher::Batcher, dataset::source::huggingface::MNISTItem},
|
||||
data::{dataloader::batcher::Batcher, dataset::vision::MNISTItem},
|
||||
tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
|
||||
};
|
||||
|
||||
|
|
|
@ -165,16 +165,17 @@ Tensor {
|
|||
```
|
||||
|
||||
While the previous example is somewhat trivial, the upcoming
|
||||
[basic workflow section](./basic-workflow/README.md) will walk you through a much more relevant
|
||||
example for deep learning applications.
|
||||
[basic workflow section](./basic-workflow/) will walk you through a much more relevant example for
|
||||
deep learning applications.
|
||||
|
||||
## Running examples
|
||||
|
||||
Burn uses a [Python library by HuggingFace](https://huggingface.co/docs/datasets/index) to download
|
||||
datasets. Therefore, in order to run examples, you will need to install Python. Follow the
|
||||
instructions on the [official website](https://www.python.org/downloads/) to install Python on your
|
||||
computer.
|
||||
|
||||
Many Burn examples are available in the
|
||||
Many additional Burn examples are available in the
|
||||
[examples](https://github.com/tracel-ai/burn/tree/main/examples) directory. To run one, please refer
|
||||
to the example's README.md for the specific command to execute.
|
||||
|
||||
Note that some examples use the
|
||||
[`datasets` library by HuggingFace](https://huggingface.co/docs/datasets/index) to download the
|
||||
datasets required in the examples. This is a Python library, which means that you will need to
|
||||
install Python before running these examples. This requirement will be clearly indicated in the
|
||||
example's README when applicable.
|
||||
|
|
|
@ -42,6 +42,7 @@ std = [
|
|||
dataset = ["burn-dataset"]
|
||||
sqlite = ["burn-dataset?/sqlite"]
|
||||
sqlite-bundled = ["burn-dataset?/sqlite-bundled"]
|
||||
vision = ["burn-dataset?/vision"]
|
||||
|
||||
wasm-sync = ["burn-tensor/wasm-sync", "burn-common/wasm-sync"]
|
||||
|
||||
|
|
|
@ -20,6 +20,8 @@ fake = ["dep:fake"]
|
|||
sqlite = ["__sqlite-shared", "dep:rusqlite"]
|
||||
sqlite-bundled = ["__sqlite-shared", "rusqlite/bundled"]
|
||||
|
||||
vision = ["dep:flate2", "dep:indicatif", "dep:reqwest", "dep:tokio"]
|
||||
|
||||
# internal
|
||||
__sqlite-shared = [
|
||||
"dep:r2d2",
|
||||
|
@ -34,12 +36,15 @@ csv = { workspace = true }
|
|||
derive-new = { workspace = true }
|
||||
dirs = { workspace = true }
|
||||
fake = { workspace = true, optional = true }
|
||||
flate2 = { workspace = true, optional = true }
|
||||
gix-tempfile = { workspace = true, optional = true }
|
||||
hound = { version = "3.5.1", optional = true }
|
||||
image = { version = "0.24.7", features = ["png"], optional = true }
|
||||
indicatif = { workspace = true, optional = true }
|
||||
r2d2 = { workspace = true, optional = true }
|
||||
r2d2_sqlite = { workspace = true, optional = true }
|
||||
rand = { workspace = true, features = ["std"] }
|
||||
reqwest = { workspace = true, optional = true }
|
||||
rmp-serde = { workspace = true }
|
||||
rusqlite = { workspace = true, optional = true }
|
||||
sanitize-filename = { workspace = true }
|
||||
|
@ -50,6 +55,7 @@ strum = { workspace = true }
|
|||
strum_macros = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
tokio = { workspace = true, optional = true }
|
||||
|
||||
[dev-dependencies]
|
||||
rayon = { workspace = true }
|
||||
|
|
|
@ -19,6 +19,10 @@ pub mod transform;
|
|||
#[cfg(feature = "audio")]
|
||||
pub mod audio;
|
||||
|
||||
/// Vision datasets.
|
||||
#[cfg(feature = "vision")]
|
||||
pub mod vision;
|
||||
|
||||
mod dataset;
|
||||
pub use dataset::*;
|
||||
#[cfg(any(feature = "sqlite", feature = "sqlite-bundled"))]
|
||||
|
|
|
@ -1,92 +0,0 @@
|
|||
use crate::source::huggingface::downloader::HuggingfaceDatasetLoader;
|
||||
use crate::transform::{Mapper, MapperDataset};
|
||||
use crate::{Dataset, SqliteDataset};
|
||||
|
||||
use image;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Width of an MNIST image in pixels (checked against the decoded image in debug builds).
const WIDTH: usize = 28;
|
||||
/// Height of an MNIST image in pixels (checked against the decoded image in debug builds).
const HEIGHT: usize = 28;
|
||||
|
||||
/// A single MNIST sample: one image together with its class label.
|
||||
#[derive(Deserialize, Serialize, Debug, Clone)]
|
||||
pub struct MNISTItem {
|
||||
/// Image as a 2D array of floats, indexed as `image[row][column]`.
/// Values are the decoded pixel bytes converted to `f32`; no normalization is applied here.
|
||||
pub image: [[f32; WIDTH]; HEIGHT],
|
||||
|
||||
/// Label of the image.
|
||||
pub label: usize,
|
||||
}
|
||||
|
||||
/// Raw MNIST record as stored in the SQLite dataset, before the image is decoded.
#[derive(Deserialize, Debug, Clone)]
|
||||
struct MNISTItemRaw {
|
||||
/// Encoded image file contents; decoded with `image::load_from_memory` by the mapper.
pub image_bytes: Vec<u8>,
|
||||
/// Label of the image.
pub label: usize,
|
||||
}
|
||||
|
||||
struct BytesToImage;
|
||||
|
||||
impl Mapper<MNISTItemRaw, MNISTItem> for BytesToImage {
|
||||
/// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image).
|
||||
fn map(&self, item: &MNISTItemRaw) -> MNISTItem {
|
||||
let image = image::load_from_memory(&item.image_bytes).unwrap();
|
||||
let image = image.as_luma8().unwrap();
|
||||
|
||||
// Ensure the image dimensions are correct.
|
||||
debug_assert_eq!(image.dimensions(), (WIDTH as u32, HEIGHT as u32));
|
||||
|
||||
// Convert the image to a 2D array of floats.
|
||||
let mut image_array = [[0f32; WIDTH]; HEIGHT];
|
||||
for (i, pixel) in image.as_raw().iter().enumerate() {
|
||||
let x = i % WIDTH;
|
||||
let y = i / HEIGHT;
|
||||
image_array[y][x] = *pixel as f32;
|
||||
}
|
||||
|
||||
MNISTItem {
|
||||
image: image_array,
|
||||
label: item.label,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SQLite-backed raw records, decoded lazily into `MNISTItem`s by `BytesToImage`.
type MappedDataset = MapperDataset<SqliteDataset<MNISTItemRaw>, BytesToImage, MNISTItemRaw>;
|
||||
|
||||
/// MNIST dataset from HuggingFace.
|
||||
///
|
||||
/// The data is downloaded from HuggingFace and stored in a SQLite database.
|
||||
pub struct MNISTDataset {
|
||||
// Wrapped dataset that every `Dataset` call is delegated to.
dataset: MappedDataset,
|
||||
}
|
||||
|
||||
/// Exposes the decoded MNIST items by delegating to the wrapped dataset.
impl Dataset<MNISTItem> for MNISTDataset {
    /// Returns the item at `index` as reported by the wrapped dataset.
    fn get(&self, index: usize) -> Option<MNISTItem> {
        let Self { dataset } = self;
        dataset.get(index)
    }

    /// Returns the number of items reported by the wrapped dataset.
    fn len(&self) -> usize {
        let Self { dataset } = self;
        dataset.len()
    }
}
|
||||
|
||||
impl MNISTDataset {
|
||||
/// Creates a new train dataset.
|
||||
pub fn train() -> Self {
|
||||
Self::new("train")
|
||||
}
|
||||
|
||||
/// Creates a new test dataset.
|
||||
pub fn test() -> Self {
|
||||
Self::new("test")
|
||||
}
|
||||
|
||||
fn new(split: &str) -> Self {
|
||||
let dataset = HuggingfaceDatasetLoader::new("mnist")
|
||||
.dataset(split)
|
||||
.unwrap();
|
||||
|
||||
let dataset = MapperDataset::new(dataset, BytesToImage);
|
||||
|
||||
Self { dataset }
|
||||
}
|
||||
}
|
|
@ -1,5 +1,3 @@
|
|||
pub(crate) mod downloader;
|
||||
mod mnist;
|
||||
|
||||
pub use downloader::*;
|
||||
pub use mnist::*;
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
use indicatif::{ProgressBar, ProgressState, ProgressStyle};
|
||||
use reqwest::Client;
|
||||
use std::cmp::min;
|
||||
use std::io::Write;
|
||||
|
||||
/// Download the file at the specified url to a bytes vector.
|
||||
/// File download progress is reported with the help of a [progress bar](indicatif).
|
||||
#[tokio::main(flavor = "current_thread")]
|
||||
pub async fn download_file_as_bytes(url: &str, message: &str) -> Vec<u8> {
|
||||
// Get file from web
|
||||
let mut response = Client::new().get(url).send().await.unwrap();
|
||||
let total_size = response.content_length().unwrap();
|
||||
|
||||
// Pretty progress bar
|
||||
let pb = ProgressBar::new(total_size);
|
||||
let msg = message.to_owned();
|
||||
pb.set_style(
|
||||
ProgressStyle::with_template(
|
||||
"{msg}\n {wide_bar:.cyan/blue} {bytes}/{total_bytes} ({eta})",
|
||||
)
|
||||
.unwrap()
|
||||
.with_key(
|
||||
"eta",
|
||||
|state: &ProgressState, w: &mut dyn std::fmt::Write| {
|
||||
write!(w, "{:.1}s", state.eta().as_secs_f64()).unwrap()
|
||||
},
|
||||
)
|
||||
.progress_chars("▬ "),
|
||||
);
|
||||
pb.set_message(msg.clone());
|
||||
|
||||
// Read stream into bytes
|
||||
let mut downloaded: u64 = 0;
|
||||
let mut bytes: Vec<u8> = Vec::with_capacity(total_size as usize);
|
||||
while let Some(chunk) = response.chunk().await.unwrap() {
|
||||
let num_bytes = bytes.write(&chunk).unwrap();
|
||||
let new = min(downloaded + (num_bytes as u64), total_size);
|
||||
downloaded = new;
|
||||
pb.set_position(new);
|
||||
}
|
||||
pb.finish_with_message(msg);
|
||||
|
||||
bytes
|
||||
}
|
|
@ -0,0 +1,220 @@
|
|||
use std::fs::{create_dir_all, File};
|
||||
use std::io::{Read, Seek, SeekFrom};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use flate2::read::GzDecoder;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::{
|
||||
transform::{Mapper, MapperDataset},
|
||||
Dataset, InMemDataset,
|
||||
};
|
||||
|
||||
// CVDF mirror of http://yann.lecun.com/exdb/mnist/
|
||||
// Base URL; a file name plus the `.gz` suffix is appended to it when downloading.
const URL: &str = "https://storage.googleapis.com/cvdf-datasets/mnist/";
|
||||
// Names of the decompressed files kept in the local cache directory, one pair per split.
const TRAIN_IMAGES: &str = "train-images-idx3-ubyte";
|
||||
const TRAIN_LABELS: &str = "train-labels-idx1-ubyte";
|
||||
const TEST_IMAGES: &str = "t10k-images-idx3-ubyte";
|
||||
const TEST_LABELS: &str = "t10k-labels-idx1-ubyte";
|
||||
|
||||
// Dimensions of a single MNIST image, in pixels.
const WIDTH: usize = 28;
|
||||
const HEIGHT: usize = 28;
|
||||
|
||||
/// A single MNIST sample: one image together with its class label.
|
||||
#[derive(Deserialize, Serialize, Debug, Clone)]
|
||||
pub struct MNISTItem {
|
||||
/// Image as a 2D array of floats, indexed as `image[row][column]`.
/// Values are the raw pixel bytes converted to `f32`; no normalization is applied here.
|
||||
pub image: [[f32; WIDTH]; HEIGHT],
|
||||
|
||||
/// Label of the image.
|
||||
pub label: u8,
|
||||
}
|
||||
|
||||
/// Raw MNIST record held in memory before it is mapped to an [`MNISTItem`].
#[derive(Deserialize, Debug, Clone)]
|
||||
struct MNISTItemRaw {
|
||||
/// Flat pixel bytes of one image; expected to hold `WIDTH * HEIGHT` values.
pub image_bytes: Vec<u8>,
|
||||
/// Label of the image.
pub label: u8,
|
||||
}
|
||||
|
||||
struct BytesToImage;
|
||||
|
||||
impl Mapper<MNISTItemRaw, MNISTItem> for BytesToImage {
|
||||
/// Convert a raw MNIST item (image bytes) to a MNIST item (2D array image).
|
||||
fn map(&self, item: &MNISTItemRaw) -> MNISTItem {
|
||||
// Ensure the image dimensions are correct.
|
||||
debug_assert_eq!(item.image_bytes.len(), WIDTH * HEIGHT);
|
||||
|
||||
// Convert the image to a 2D array of floats.
|
||||
let mut image_array = [[0f32; WIDTH]; HEIGHT];
|
||||
for (i, pixel) in item.image_bytes.iter().enumerate() {
|
||||
let x = i % WIDTH;
|
||||
let y = i / HEIGHT;
|
||||
image_array[y][x] = *pixel as f32;
|
||||
}
|
||||
|
||||
MNISTItem {
|
||||
image: image_array,
|
||||
label: item.label,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// In-memory raw records, converted into `MNISTItem`s by `BytesToImage`.
type MappedDataset = MapperDataset<InMemDataset<MNISTItemRaw>, BytesToImage, MNISTItemRaw>;
|
||||
|
||||
/// The MNIST dataset consists of 70,000 28x28 grayscale images in 10 classes (one for each digit), with an
|
||||
/// average of 7,000 images per class. There are 60,000 training images and 10,000 test images.
|
||||
///
|
||||
/// The data is downloaded from the web from the [CVDF mirror](https://github.com/cvdfoundation/mnist).
|
||||
pub struct MNISTDataset {
|
||||
// Wrapped dataset that every `Dataset` call is delegated to.
dataset: MappedDataset,
|
||||
}
|
||||
|
||||
/// Exposes the MNIST items by delegating to the wrapped in-memory dataset.
impl Dataset<MNISTItem> for MNISTDataset {
    /// Returns the item at `index` as reported by the wrapped dataset.
    fn get(&self, index: usize) -> Option<MNISTItem> {
        let Self { dataset } = self;
        dataset.get(index)
    }

    /// Returns the number of items reported by the wrapped dataset.
    fn len(&self) -> usize {
        let Self { dataset } = self;
        dataset.len()
    }
}
|
||||
|
||||
impl MNISTDataset {
|
||||
/// Creates a new train dataset.
|
||||
pub fn train() -> Self {
|
||||
Self::new("train")
|
||||
}
|
||||
|
||||
/// Creates a new test dataset.
|
||||
pub fn test() -> Self {
|
||||
Self::new("test")
|
||||
}
|
||||
|
||||
fn new(split: &str) -> Self {
|
||||
// Download dataset
|
||||
let root = MNISTDataset::download(split);
|
||||
|
||||
// MNIST is tiny so we can load it in-memory
|
||||
// Train images (u8): 28 * 28 * 60000 = 47.04Mb
|
||||
// Test images (u8): 28 * 28 * 10000 = 7.84Mb
|
||||
let images = MNISTDataset::read_images(&root, split);
|
||||
let labels = MNISTDataset::read_labels(&root, split);
|
||||
|
||||
// Collect as vector of MNISTItemRaw
|
||||
let items: Vec<_> = images
|
||||
.into_iter()
|
||||
.zip(labels)
|
||||
.map(|(image_bytes, label)| MNISTItemRaw { image_bytes, label })
|
||||
.collect();
|
||||
|
||||
let dataset = InMemDataset::new(items);
|
||||
let dataset = MapperDataset::new(dataset, BytesToImage);
|
||||
|
||||
Self { dataset }
|
||||
}
|
||||
|
||||
/// Download the MNIST dataset files from the web.
|
||||
/// Panics if the download cannot be completed or the content of the file cannot be written to disk.
|
||||
fn download(split: &str) -> PathBuf {
|
||||
// Dataset files are stored un the burn-dataset cache directory
|
||||
let cache_dir = dirs::home_dir()
|
||||
.expect("Could not get home directory")
|
||||
.join(".cache")
|
||||
.join("burn-dataset");
|
||||
let split_dir = cache_dir.join("mnist").join(split);
|
||||
|
||||
if !split_dir.exists() {
|
||||
create_dir_all(&split_dir).expect("Failed to create base directory");
|
||||
}
|
||||
|
||||
// Download split files
|
||||
match split {
|
||||
"train" => {
|
||||
MNISTDataset::download_file(TRAIN_IMAGES, &split_dir);
|
||||
MNISTDataset::download_file(TRAIN_LABELS, &split_dir);
|
||||
}
|
||||
"test" => {
|
||||
MNISTDataset::download_file(TEST_IMAGES, &split_dir);
|
||||
MNISTDataset::download_file(TEST_LABELS, &split_dir);
|
||||
}
|
||||
_ => panic!("Invalid split specified {}", split),
|
||||
};
|
||||
|
||||
split_dir
|
||||
}
|
||||
|
||||
/// Download a file from the MNIST dataset URL to the destination directory.
|
||||
/// File download progress is reported with the help of a [progress bar](indicatif).
|
||||
fn download_file<P: AsRef<Path>>(name: &str, dest_dir: &P) -> PathBuf {
|
||||
// Output file name
|
||||
let file_name = dest_dir.as_ref().join(name);
|
||||
|
||||
if !file_name.exists() {
|
||||
// Download gzip file
|
||||
let bytes = super::downloader::download_file_as_bytes(&format!("{URL}{name}.gz"), name);
|
||||
|
||||
// Create file to write the downloaded content to
|
||||
let mut output_file = File::create(&file_name).unwrap();
|
||||
|
||||
// Decode gzip file content and write to disk
|
||||
let mut gz_buffer = GzDecoder::new(&bytes[..]);
|
||||
std::io::copy(&mut gz_buffer, &mut output_file).unwrap();
|
||||
}
|
||||
|
||||
file_name
|
||||
}
|
||||
|
||||
/// Read images at the provided path for the specified split.
|
||||
/// Each image is a vector of bytes.
|
||||
fn read_images<P: AsRef<Path>>(root: &P, split: &str) -> Vec<Vec<u8>> {
|
||||
let file_name = if split == "train" {
|
||||
TRAIN_IMAGES
|
||||
} else {
|
||||
TEST_IMAGES
|
||||
};
|
||||
let file_name = root.as_ref().join(file_name);
|
||||
|
||||
// Read number of images from 16-byte header metadata
|
||||
let mut f = File::open(file_name).unwrap();
|
||||
let mut buf = [0u8; 4];
|
||||
let _ = f.seek(SeekFrom::Start(4)).unwrap();
|
||||
f.read_exact(&mut buf)
|
||||
.expect("Should be able to read image file header");
|
||||
let size = u32::from_be_bytes(buf);
|
||||
|
||||
let mut buf_images: Vec<u8> = vec![0u8; WIDTH * HEIGHT * (size as usize)];
|
||||
let _ = f.seek(SeekFrom::Start(16)).unwrap();
|
||||
f.read_exact(&mut buf_images)
|
||||
.expect("Should be able to read image file header");
|
||||
|
||||
buf_images
|
||||
.chunks(WIDTH * HEIGHT)
|
||||
.map(|chunk| chunk.to_vec())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Read labels at the provided path for the specified split.
|
||||
fn read_labels<P: AsRef<Path>>(root: &P, split: &str) -> Vec<u8> {
|
||||
let file_name = if split == "train" {
|
||||
TRAIN_LABELS
|
||||
} else {
|
||||
TEST_LABELS
|
||||
};
|
||||
let file_name = root.as_ref().join(file_name);
|
||||
|
||||
// Read number of labels from 8-byte header metadata
|
||||
let mut f = File::open(file_name).unwrap();
|
||||
let mut buf = [0u8; 4];
|
||||
let _ = f.seek(SeekFrom::Start(4)).unwrap();
|
||||
f.read_exact(&mut buf)
|
||||
.expect("Should be able to read label file header");
|
||||
let size = u32::from_be_bytes(buf);
|
||||
|
||||
let mut buf_labels: Vec<u8> = vec![0u8; size as usize];
|
||||
let _ = f.seek(SeekFrom::Start(8)).unwrap();
|
||||
f.read_exact(&mut buf_labels)
|
||||
.expect("Should be able to read labels from file");
|
||||
|
||||
buf_labels
|
||||
}
|
||||
}
|
|
@ -0,0 +1,4 @@
|
|||
mod downloader;
|
||||
mod mnist;
|
||||
|
||||
pub use mnist::*;
|
|
@ -33,6 +33,8 @@ dataset = ["burn-core/dataset"]
|
|||
sqlite = ["burn-core/sqlite"]
|
||||
sqlite-bundled = ["burn-core/sqlite-bundled"]
|
||||
|
||||
vision = ["burn-core/vision"]
|
||||
|
||||
# Backends
|
||||
autodiff = ["burn-core/autodiff"]
|
||||
fusion = ["burn-core/fusion"]
|
||||
|
|
|
@ -61,6 +61,7 @@
|
|||
//! - `audio`: Enables audio datasets (SpeechCommandsDataset)
|
||||
//! - `sqlite`: Stores datasets in SQLite database
|
||||
//! - `sqlite-bundled`: Use bundled version of SQLite
|
||||
//! - `vision`: Enables vision datasets (MNISTDataset)
|
||||
//! - Backends
|
||||
//! - `wgpu`: Makes available the WGPU backend
|
||||
//! - `candle`: Makes available the Candle backend
|
||||
|
|
|
@ -8,7 +8,7 @@ publish = false
|
|||
version.workspace = true
|
||||
|
||||
[dependencies]
|
||||
burn = {path = "../../burn", features=["autodiff", "wgpu", "train", "dataset"], default-features=false}
|
||||
burn = {path = "../../burn", features=["autodiff", "wgpu", "train", "dataset", "vision"], default-features=false}
|
||||
guide = {path = "../guide", default-features=false}
|
||||
|
||||
# Serialization
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use burn::data::dataset::source::huggingface::MNISTDataset;
|
||||
use burn::data::dataset::vision::MNISTDataset;
|
||||
use burn::train::renderer::{MetricState, MetricsRenderer, TrainingProgress};
|
||||
use burn::train::LearnerBuilder;
|
||||
use burn::{
|
||||
|
|
|
@ -7,7 +7,7 @@ publish = false
|
|||
version.workspace = true
|
||||
|
||||
[dependencies]
|
||||
burn = {path = "../../burn", features=["autodiff", "wgpu"]}
|
||||
burn = {path = "../../burn", features=["autodiff", "wgpu", "vision"]}
|
||||
guide = {path = "../guide"}
|
||||
|
||||
# Serialization
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
use std::marker::PhantomData;
|
||||
|
||||
use burn::data::dataset::source::huggingface::MNISTDataset;
|
||||
use burn::data::dataset::vision::MNISTDataset;
|
||||
use burn::{
|
||||
config::Config,
|
||||
data::dataloader::DataLoaderBuilder,
|
||||
|
|
|
@ -10,7 +10,7 @@ version.workspace = true
|
|||
default = ["burn/default"]
|
||||
|
||||
[dependencies]
|
||||
burn = {path = "../../burn", features = ["wgpu", "train"]}
|
||||
burn = {path = "../../burn", features = ["wgpu", "train", "vision"]}
|
||||
|
||||
# Serialization
|
||||
log = {workspace = true}
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
# Basic Workflow: From Training to Inference
|
||||
|
||||
This example corresponds to the [book's guide](https://burn.dev/book/basic-workflow/).
|
||||
|
||||
## Example Usage
|
||||
|
||||
```sh
|
||||
cargo run --example guide
|
||||
```
|
|
@ -18,7 +18,7 @@ fn main() {
|
|||
guide::inference::infer::<MyBackend>(
|
||||
artifact_dir,
|
||||
device,
|
||||
burn::data::dataset::source::huggingface::MNISTDataset::test()
|
||||
burn::data::dataset::vision::MNISTDataset::test()
|
||||
.get(42)
|
||||
.unwrap(),
|
||||
);
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use burn::{
|
||||
data::{dataloader::batcher::Batcher, dataset::source::huggingface::MNISTItem},
|
||||
data::{dataloader::batcher::Batcher, dataset::vision::MNISTItem},
|
||||
tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
|
||||
};
|
||||
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use crate::{data::MNISTBatcher, training::TrainingConfig};
|
||||
use burn::data::dataset::source::huggingface::MNISTItem;
|
||||
use burn::data::dataset::vision::MNISTItem;
|
||||
use burn::{
|
||||
config::Config,
|
||||
data::dataloader::batcher::Batcher,
|
||||
|
|
|
@ -2,7 +2,7 @@ use crate::{
|
|||
data::{MNISTBatch, MNISTBatcher},
|
||||
model::{Model, ModelConfig},
|
||||
};
|
||||
use burn::data::dataset::source::huggingface::MNISTDataset;
|
||||
use burn::data::dataset::vision::MNISTDataset;
|
||||
use burn::train::{
|
||||
metric::{AccuracyMetric, LossMetric},
|
||||
ClassificationOutput, LearnerBuilder, TrainOutput, TrainStep, ValidStep,
|
||||
|
|
|
@ -7,7 +7,7 @@ publish = false
|
|||
version.workspace = true
|
||||
|
||||
[features]
|
||||
default = ["burn/dataset", "burn/sqlite-bundled"]
|
||||
default = ["burn/dataset", "burn/vision"]
|
||||
ndarray = ["burn/ndarray"]
|
||||
ndarray-blas-accelerate = ["burn/ndarray", "burn/accelerate"]
|
||||
ndarray-blas-netlib = ["burn/ndarray", "burn/blas-netlib"]
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
use burn::{
|
||||
data::{dataloader::batcher::Batcher, dataset::source::huggingface::MNISTItem},
|
||||
data::{dataloader::batcher::Batcher, dataset::vision::MNISTItem},
|
||||
tensor::{backend::Backend, Data, ElementConversion, Int, Tensor},
|
||||
};
|
||||
|
||||
|
|
|
@ -10,7 +10,7 @@ use burn::train::metric::{CpuMemory, CpuTemperature, CpuUse};
|
|||
use burn::train::{MetricEarlyStoppingStrategy, StoppingCondition};
|
||||
use burn::{
|
||||
config::Config,
|
||||
data::{dataloader::DataLoaderBuilder, dataset::source::huggingface::MNISTDataset},
|
||||
data::{dataloader::DataLoaderBuilder, dataset::vision::MNISTDataset},
|
||||
tensor::backend::AutodiffBackend,
|
||||
train::{
|
||||
metric::{AccuracyMetric, LossMetric},
|
||||
|
|
|
@ -12,7 +12,7 @@ default = ["embedded-model"]
|
|||
embedded-model = []
|
||||
|
||||
[dependencies]
|
||||
burn = { path = "../../burn", features = ["ndarray", "dataset", "sqlite-bundled"] }
|
||||
burn = { path = "../../burn", features = ["ndarray", "dataset", "vision"] }
|
||||
serde = { workspace = true }
|
||||
|
||||
[build-dependencies]
|
||||
|
|
|
@ -3,7 +3,7 @@ use std::env::args;
|
|||
use burn::backend::ndarray::NdArray;
|
||||
use burn::tensor::Tensor;
|
||||
|
||||
use burn::data::dataset::source::huggingface::MNISTDataset;
|
||||
use burn::data::dataset::vision::MNISTDataset;
|
||||
use burn::data::dataset::Dataset;
|
||||
|
||||
use onnx_inference::mnist::Model;
|
||||
|
@ -49,7 +49,7 @@ fn main() {
|
|||
let output = model.forward(input);
|
||||
|
||||
// Get the index of the maximum value
|
||||
let arg_max = output.argmax(1).into_scalar() as usize;
|
||||
let arg_max = output.argmax(1).into_scalar() as u8;
|
||||
|
||||
// Check if the index matches the label
|
||||
assert!(arg_max == item.label);
|
||||
|
|
|
@ -10,7 +10,7 @@ version = "0.12.0"
|
|||
burn = { path = "../../burn", features = [
|
||||
"ndarray",
|
||||
"dataset",
|
||||
"sqlite-bundled",
|
||||
"vision",
|
||||
] }
|
||||
|
||||
model = { path = "./model" }
|
||||
|
|
|
@ -5,7 +5,7 @@ use burn::backend::ndarray::NdArray;
|
|||
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
|
||||
use burn::tensor::Tensor;
|
||||
|
||||
use burn::data::dataset::source::huggingface::MNISTDataset;
|
||||
use burn::data::dataset::vision::MNISTDataset;
|
||||
use burn::data::dataset::Dataset;
|
||||
|
||||
use model::Model;
|
||||
|
@ -57,7 +57,7 @@ fn main() {
|
|||
let output = model.forward(input);
|
||||
|
||||
// Get the index of the maximum value
|
||||
let arg_max = output.argmax(1).into_scalar() as usize;
|
||||
let arg_max = output.argmax(1).into_scalar() as u8;
|
||||
|
||||
// Check if the index matches the label
|
||||
assert!(arg_max == item.label);
|
||||
|
|
|
@ -7,6 +7,11 @@ from HuggingFace hub. The dataset is also available as part of toy regression da
|
|||
- Create a data pipeline from a raw dataset to a batched fast DataLoader with min-max feature scaling.
|
||||
- Define a Simple NN model for regression using Burn Modules.
|
||||
|
||||
> **Note**
|
||||
> This example makes use of the HuggingFace [`datasets`](https://huggingface.co/docs/datasets/index)
|
||||
> library to download the datasets. Make sure you have [Python](https://www.python.org/downloads/)
|
||||
> installed on your computer.
|
||||
|
||||
The example can be run like so:
|
||||
|
||||
```bash
|
||||
|
|
|
@ -3,6 +3,11 @@
|
|||
This project provides an example implementation for training and inferencing text classification
|
||||
models on AG News and DbPedia datasets using the Rust-based Burn Deep Learning Library.
|
||||
|
||||
> **Note**
|
||||
> This example makes use of the HuggingFace [`datasets`](https://huggingface.co/docs/datasets/index)
|
||||
> library to download the datasets. Make sure you have [Python](https://www.python.org/downloads/)
|
||||
> installed on your computer.
|
||||
|
||||
## Dataset Details
|
||||
|
||||
- AG News: The AG News dataset is a collection of news articles from more than 2000 news sources.
|
||||
|
|
|
@ -1,5 +1,10 @@
|
|||
# Text Generation
|
||||
|
||||
> **Note**
|
||||
> This example makes use of the HuggingFace [`datasets`](https://huggingface.co/docs/datasets/index)
|
||||
> library to download the datasets. Make sure you have [Python](https://www.python.org/downloads/)
|
||||
> installed on your computer.
|
||||
|
||||
The example can be run like so:
|
||||
|
||||
## CUDA users
|
||||
|
|
Loading…
Reference in New Issue