mirror of https://github.com/tracel-ai/burn.git
Add multi-label classification dataset and metric (#1572)
* Add multilabel classification dataset - Add MultiLabel annotation support - Refactor de/serialize annotation with AnnotationRaw - Add ImageFolderDataset::with_items methods * Fix custom-image-classification example deps * Add image_folder_dataset_multilabel test * Do not change class names order when provided * Add hamming score and multi-label classification output * Add new_classification_with_items test * Fix clippy suggestions * Implement default trait for hamming score * Remove de/serialization and use AnnotationRaw as type * Fix clippy * Fix metric backend phantom data
This commit is contained in:
parent
f5159b6d22
commit
f3e0aa6689
|
@ -57,11 +57,13 @@ impl TryFrom<PixelDepth> for f32 {
|
|||
}
|
||||
}
|
||||
|
||||
/// Image target for different tasks.
|
||||
/// Annotation type for different tasks.
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum Annotation {
|
||||
/// Image-level label.
|
||||
Label(usize),
|
||||
/// Multiple image-level labels.
|
||||
MultiLabel(Vec<usize>),
|
||||
/// Object bounding boxes.
|
||||
BoundingBoxes(Vec<BoundingBox>),
|
||||
/// Segmentation mask.
|
||||
|
@ -97,14 +99,30 @@ pub struct ImageDatasetItem {
|
|||
pub annotation: Annotation,
|
||||
}
|
||||
|
||||
/// Raw annotation types.
|
||||
#[derive(Deserialize, Serialize, Debug, Clone)]
|
||||
enum AnnotationRaw {
|
||||
Label(String),
|
||||
MultiLabel(Vec<String>),
|
||||
// TODO: bounding boxes and segmentation mask
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize, Debug, Clone)]
|
||||
struct ImageDatasetItemRaw {
|
||||
/// Image path.
|
||||
pub image_path: PathBuf,
|
||||
image_path: PathBuf,
|
||||
|
||||
/// Image annotation.
|
||||
/// The annotation bytes can represent a string (category name) or path to annotation file.
|
||||
pub annotation: Vec<u8>,
|
||||
annotation: AnnotationRaw,
|
||||
}
|
||||
|
||||
impl ImageDatasetItemRaw {
|
||||
fn new<P: AsRef<Path>>(image_path: P, annotation: AnnotationRaw) -> ImageDatasetItemRaw {
|
||||
ImageDatasetItemRaw {
|
||||
image_path: image_path.as_ref().to_path_buf(),
|
||||
annotation,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct PathToImageDatasetItem {
|
||||
|
@ -112,15 +130,25 @@ struct PathToImageDatasetItem {
|
|||
}
|
||||
|
||||
/// Parse the image annotation to the corresponding type.
|
||||
fn parse_image_annotation(annotation: &[u8], classes: &HashMap<String, usize>) -> Annotation {
|
||||
fn parse_image_annotation(
|
||||
annotation: &AnnotationRaw,
|
||||
classes: &HashMap<String, usize>,
|
||||
) -> Annotation {
|
||||
// TODO: add support for other annotations
|
||||
// - [ ] Object bounding boxes
|
||||
// - [ ] Segmentation mask
|
||||
// For now, only image classification labels are supported.
|
||||
|
||||
// Map class string to label id
|
||||
let name = std::str::from_utf8(annotation).unwrap();
|
||||
Annotation::Label(*classes.get(name).unwrap())
|
||||
match annotation {
|
||||
AnnotationRaw::Label(name) => Annotation::Label(*classes.get(name).unwrap()),
|
||||
AnnotationRaw::MultiLabel(names) => Annotation::MultiLabel(
|
||||
names
|
||||
.iter()
|
||||
.map(|name| *classes.get(name).unwrap())
|
||||
.collect(),
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
impl Mapper<ImageDatasetItemRaw, ImageDatasetItem> for PathToImageDatasetItem {
|
||||
|
@ -212,7 +240,7 @@ pub enum ImageLoaderError {
|
|||
type ImageDatasetMapper =
|
||||
MapperDataset<InMemDataset<ImageDatasetItemRaw>, PathToImageDatasetItem, ImageDatasetItemRaw>;
|
||||
|
||||
/// A generic dataset to load classification images from disk.
|
||||
/// A generic dataset to load images from disk.
|
||||
pub struct ImageFolderDataset {
|
||||
dataset: ImageDatasetMapper,
|
||||
}
|
||||
|
@ -259,18 +287,6 @@ impl ImageFolderDataset {
|
|||
P: AsRef<Path>,
|
||||
S: AsRef<str>,
|
||||
{
|
||||
/// Check if extension is supported.
|
||||
fn check_extension<S: AsRef<str>>(extension: &S) -> Result<String, ImageLoaderError> {
|
||||
let extension = extension.as_ref();
|
||||
if !SUPPORTED_FILES.contains(&extension) {
|
||||
Err(ImageLoaderError::InvalidFileExtensionError(
|
||||
extension.to_string(),
|
||||
))
|
||||
} else {
|
||||
Ok(extension.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
// Glob all images with extensions
|
||||
let walker = globwalk::GlobWalkerBuilder::from_patterns(
|
||||
root.as_ref(),
|
||||
|
@ -278,7 +294,7 @@ impl ImageFolderDataset {
|
|||
"*.{{{}}}", // "*.{ext1,ext2,ext3}
|
||||
extensions
|
||||
.iter()
|
||||
.map(check_extension)
|
||||
.map(Self::check_extension)
|
||||
.collect::<Result<Vec<_>, _>>()?
|
||||
.join(",")
|
||||
)],
|
||||
|
@ -312,21 +328,102 @@ impl ImageFolderDataset {
|
|||
|
||||
classes.insert(label.clone());
|
||||
|
||||
items.push(ImageDatasetItemRaw {
|
||||
image_path: image_path.to_path_buf(),
|
||||
annotation: label.into_bytes(),
|
||||
})
|
||||
items.push(ImageDatasetItemRaw::new(
|
||||
image_path,
|
||||
AnnotationRaw::Label(label),
|
||||
))
|
||||
}
|
||||
|
||||
// Sort class names
|
||||
let mut classes = classes.into_iter().collect::<Vec<_>>();
|
||||
classes.sort();
|
||||
|
||||
Self::with_items(items, &classes)
|
||||
}
|
||||
|
||||
/// Create an image classification dataset with the specified items.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `items` - List of dataset items, each item represented by a tuple `(image path, label)`.
|
||||
/// * `classes` - Dataset class names.
|
||||
///
|
||||
/// # Returns
|
||||
/// A new dataset instance.
|
||||
pub fn new_classification_with_items<P: AsRef<Path>, S: AsRef<str>>(
|
||||
items: Vec<(P, String)>,
|
||||
classes: &[S],
|
||||
) -> Result<Self, ImageLoaderError> {
|
||||
// Parse items and check valid image extension types
|
||||
let items = items
|
||||
.into_iter()
|
||||
.map(|(path, label)| {
|
||||
// Map image path and label
|
||||
let path = path.as_ref();
|
||||
let label = AnnotationRaw::Label(label);
|
||||
|
||||
Self::check_extension(&path.extension().unwrap().to_str().unwrap())?;
|
||||
|
||||
Ok(ImageDatasetItemRaw::new(path, label))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
Self::with_items(items, classes)
|
||||
}
|
||||
|
||||
/// Create a multi-label image classification dataset with the specified items.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `items` - List of dataset items, each item represented by a tuple `(image path, labels)`.
|
||||
/// * `classes` - Dataset class names.
|
||||
///
|
||||
/// # Returns
|
||||
/// A new dataset instance.
|
||||
pub fn new_multilabel_classification_with_items<P: AsRef<Path>, S: AsRef<str>>(
|
||||
items: Vec<(P, Vec<String>)>,
|
||||
classes: &[S],
|
||||
) -> Result<Self, ImageLoaderError> {
|
||||
// Parse items and check valid image extension types
|
||||
let items = items
|
||||
.into_iter()
|
||||
.map(|(path, labels)| {
|
||||
// Map image path and multi-label
|
||||
let path = path.as_ref();
|
||||
let labels = AnnotationRaw::MultiLabel(labels);
|
||||
|
||||
Self::check_extension(&path.extension().unwrap().to_str().unwrap())?;
|
||||
|
||||
Ok(ImageDatasetItemRaw::new(path, labels))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
Self::with_items(items, classes)
|
||||
}
|
||||
|
||||
/// Create an image dataset with the specified items.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `items` - Raw dataset items.
|
||||
/// * `classes` - Dataset class names.
|
||||
///
|
||||
/// # Returns
|
||||
/// A new dataset instance.
|
||||
fn with_items<S: AsRef<str>>(
|
||||
items: Vec<ImageDatasetItemRaw>,
|
||||
classes: &[S],
|
||||
) -> Result<Self, ImageLoaderError> {
|
||||
// NOTE: right now we don't need to validate the supported image files since
|
||||
// the method is private. We assume it's already validated.
|
||||
let dataset = InMemDataset::new(items);
|
||||
|
||||
// Class names to index map
|
||||
let mut classes = classes.into_iter().collect::<Vec<_>>();
|
||||
classes.sort();
|
||||
let classes = classes.iter().map(|c| c.as_ref()).collect::<Vec<_>>();
|
||||
let classes_map: HashMap<_, _> = classes
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(idx, cls)| (cls, idx))
|
||||
.map(|(idx, cls)| (cls.to_string(), idx))
|
||||
.collect();
|
||||
|
||||
let mapper = PathToImageDatasetItem {
|
||||
|
@ -336,6 +433,18 @@ impl ImageFolderDataset {
|
|||
|
||||
Ok(Self { dataset })
|
||||
}
|
||||
|
||||
/// Check if extension is supported.
|
||||
fn check_extension<S: AsRef<str>>(extension: &S) -> Result<String, ImageLoaderError> {
|
||||
let extension = extension.as_ref();
|
||||
if !SUPPORTED_FILES.contains(&extension) {
|
||||
Err(ImageLoaderError::InvalidFileExtensionError(
|
||||
extension.to_string(),
|
||||
))
|
||||
} else {
|
||||
Ok(extension.to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
@ -370,6 +479,69 @@ mod tests {
|
|||
assert_eq!(dataset.get(1).unwrap().annotation, Annotation::Label(1));
|
||||
}
|
||||
|
||||
#[test]
pub fn image_folder_dataset_with_items() {
    let root = Path::new(DATASET_ROOT);
    let orange_dir = root.join("orange");
    let red_dir = root.join("red");

    // Each item pairs an image path with its class name.
    let items = vec![
        (orange_dir.join("dot.jpg"), "orange".to_string()),
        (red_dir.join("dot.jpg"), "red".to_string()),
        (red_dir.join("dot.png"), "red".to_string()),
    ];
    let class_names = ["orange", "red"];
    let dataset =
        ImageFolderDataset::new_classification_with_items(items, &class_names).unwrap();

    // Three items were provided, so index 3 is out of range
    assert_eq!(dataset.len(), 3);
    assert_eq!(dataset.get(3), None);

    // Expected labels in item order: orange (0), red (1), red (1)
    assert_eq!(dataset.get(0).unwrap().annotation, Annotation::Label(0));
    assert_eq!(dataset.get(1).unwrap().annotation, Annotation::Label(1));
    assert_eq!(dataset.get(2).unwrap().annotation, Annotation::Label(1));
}
|
||||
|
||||
#[test]
pub fn image_folder_dataset_multilabel() {
    let root = Path::new(DATASET_ROOT);
    let orange_dir = root.join("orange");
    let red_dir = root.join("red");

    // Each item pairs an image path with the list of its class names.
    let dot_orange = vec!["dot".to_string(), "orange".to_string()];
    let dot_red = vec!["dot".to_string(), "red".to_string()];
    let items = vec![
        (orange_dir.join("dot.jpg"), dot_orange),
        (red_dir.join("dot.jpg"), dot_red.clone()),
        (red_dir.join("dot.png"), dot_red),
    ];
    let class_names = ["dot", "orange", "red"];
    let dataset =
        ImageFolderDataset::new_multilabel_classification_with_items(items, &class_names)
            .unwrap();

    // Three items were provided, so index 3 is out of range
    assert_eq!(dataset.len(), 3);
    assert_eq!(dataset.get(3), None);

    // Expected labels in item order:
    // [dot, orange] (0, 1), [dot, red] (0, 2), [dot, red] (0, 2)
    assert_eq!(
        dataset.get(0).unwrap().annotation,
        Annotation::MultiLabel(vec![0, 1])
    );
    assert_eq!(
        dataset.get(1).unwrap().annotation,
        Annotation::MultiLabel(vec![0, 2])
    );
    assert_eq!(
        dataset.get(2).unwrap().annotation,
        Annotation::MultiLabel(vec![0, 2])
    );
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
pub fn image_folder_dataset_invalid_extension() {
|
||||
|
@ -417,11 +589,26 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
pub fn parse_image_annotation_string() {
|
||||
pub fn parse_image_annotation_label_string() {
|
||||
let classes = HashMap::from([("0".to_string(), 0_usize), ("1".to_string(), 1_usize)]);
|
||||
let anno = AnnotationRaw::Label("0".to_string());
|
||||
assert_eq!(
|
||||
parse_image_annotation(&"0".to_string().into_bytes(), &classes),
|
||||
parse_image_annotation(&anno, &classes),
|
||||
Annotation::Label(0)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
pub fn parse_image_annotation_multilabel_string() {
    // Class name -> label index
    let classes = HashMap::from([
        ("0".to_string(), 0_usize),
        ("1".to_string(), 1_usize),
        ("2".to_string(), 2_usize),
    ]);
    let raw = AnnotationRaw::MultiLabel(vec!["0".to_string(), "2".to_string()]);

    // Each class name is replaced by its index, order preserved
    let parsed = parse_image_annotation(&raw, &classes);
    assert_eq!(parsed, Annotation::MultiLabel(vec![0, 2]));
}
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
use crate::metric::{AccuracyInput, Adaptor, LossInput};
|
||||
use crate::metric::{AccuracyInput, Adaptor, HammingScoreInput, LossInput};
|
||||
use burn_core::tensor::backend::Backend;
|
||||
use burn_core::tensor::{Int, Tensor};
|
||||
|
||||
|
@ -26,3 +26,28 @@ impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
|
|||
LossInput::new(self.loss.clone())
|
||||
}
|
||||
}
|
||||
|
||||
/// Multi-label classification output adapted for multiple metrics.
|
||||
#[derive(new)]
|
||||
pub struct MultiLabelClassificationOutput<B: Backend> {
|
||||
/// The loss.
|
||||
pub loss: Tensor<B, 1>,
|
||||
|
||||
/// The output.
|
||||
pub output: Tensor<B, 2>,
|
||||
|
||||
/// The targets.
|
||||
pub targets: Tensor<B, 2, Int>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<HammingScoreInput<B>> for MultiLabelClassificationOutput<B> {
|
||||
fn adapt(&self) -> HammingScoreInput<B> {
|
||||
HammingScoreInput::new(self.output.clone(), self.targets.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<LossInput<B>> for MultiLabelClassificationOutput<B> {
|
||||
fn adapt(&self) -> LossInput<B> {
|
||||
LossInput::new(self.loss.clone())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,3 +1,5 @@
|
|||
use core::marker::PhantomData;
|
||||
|
||||
use super::state::{FormatOptions, NumericMetricState};
|
||||
use super::{MetricEntry, MetricMetadata};
|
||||
use crate::metric::{Metric, Numeric};
|
||||
|
@ -9,7 +11,7 @@ use burn_core::tensor::{ElementConversion, Int, Tensor};
|
|||
pub struct AccuracyMetric<B: Backend> {
|
||||
state: NumericMetricState,
|
||||
pad_token: Option<usize>,
|
||||
_b: B,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
/// The [accuracy metric](AccuracyMetric) input type.
|
||||
|
|
|
@ -0,0 +1,155 @@
|
|||
use core::marker::PhantomData;
|
||||
|
||||
use super::state::{FormatOptions, NumericMetricState};
|
||||
use super::{MetricEntry, MetricMetadata};
|
||||
use crate::metric::{Metric, Numeric};
|
||||
use burn_core::tensor::{activation::sigmoid, backend::Backend, ElementConversion, Int, Tensor};
|
||||
|
||||
/// The hamming score, sometimes referred to as multi-label or label-based accuracy.
|
||||
pub struct HammingScore<B: Backend> {
|
||||
state: NumericMetricState,
|
||||
threshold: f32,
|
||||
sigmoid: bool,
|
||||
_b: PhantomData<B>,
|
||||
}
|
||||
|
||||
/// The [hamming score](HammingScore) input type.
|
||||
#[derive(new)]
|
||||
pub struct HammingScoreInput<B: Backend> {
|
||||
outputs: Tensor<B, 2>,
|
||||
targets: Tensor<B, 2, Int>,
|
||||
}
|
||||
|
||||
impl<B: Backend> HammingScore<B> {
    /// Creates the metric.
    ///
    /// Equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the threshold.
    pub fn with_threshold(self, threshold: f32) -> Self {
        Self { threshold, ..self }
    }

    /// Sets the sigmoid activation function usage.
    pub fn with_sigmoid(self, sigmoid: bool) -> Self {
        Self { sigmoid, ..self }
    }
}
|
||||
|
||||
impl<B: Backend> Default for HammingScore<B> {
|
||||
/// Creates a new metric instance with default values.
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
state: NumericMetricState::default(),
|
||||
threshold: 0.5,
|
||||
sigmoid: false,
|
||||
_b: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Metric for HammingScore<B> {
|
||||
const NAME: &'static str = "Hamming Score";
|
||||
|
||||
type Input = HammingScoreInput<B>;
|
||||
|
||||
fn update(&mut self, input: &HammingScoreInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
|
||||
let [batch_size, _n_classes] = input.outputs.dims();
|
||||
|
||||
let targets = input.targets.clone();
|
||||
|
||||
let mut outputs = input.outputs.clone();
|
||||
|
||||
if self.sigmoid {
|
||||
outputs = sigmoid(outputs);
|
||||
}
|
||||
|
||||
let score = outputs
|
||||
.greater_elem(self.threshold)
|
||||
.equal(targets.bool())
|
||||
.float()
|
||||
.mean()
|
||||
.into_scalar()
|
||||
.elem::<f64>();
|
||||
|
||||
self.state.update(
|
||||
100.0 * score,
|
||||
batch_size,
|
||||
FormatOptions::new(Self::NAME).unit("%").precision(2),
|
||||
)
|
||||
}
|
||||
|
||||
fn clear(&mut self) {
|
||||
self.state.reset()
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Numeric for HammingScore<B> {
    /// Returns the value held by the metric state.
    fn value(&self) -> f64 {
        let state = &self.state;
        state.value()
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::TestBackend;

    /// Checks the hamming score on a 4x5 batch against matching, fully inverted and
    /// partially flipped targets, using the default threshold (0.5) without sigmoid.
    #[test]
    fn test_hamming_score() {
        let device = Default::default();
        let mut metric = HammingScore::<TestBackend>::new();

        let x = Tensor::from_data(
            [
                [0.32, 0.52, 0.38, 0.68, 0.61], // with x > 0.5: [0, 1, 0, 1, 1]
                [0.43, 0.31, 0.21, 0.63, 0.53], //               [0, 0, 0, 1, 1]
                [0.44, 0.25, 0.71, 0.39, 0.73], //               [0, 0, 1, 0, 1]
                [0.49, 0.37, 0.68, 0.39, 0.31], //               [0, 0, 1, 0, 0]
            ],
            &device,
        );
        // Targets identical to the thresholded outputs
        let y = Tensor::from_data(
            [
                [0, 1, 0, 1, 1],
                [0, 0, 0, 1, 1],
                [0, 0, 1, 0, 1],
                [0, 0, 1, 0, 0],
            ],
            &device,
        );

        let _entry = metric.update(
            &HammingScoreInput::new(x.clone(), y.clone()),
            &MetricMetadata::fake(),
        );
        assert_eq!(100.0, metric.value());

        // Invert all targets: y = (1 - y)
        let y = y.neg().add_scalar(1);
        let _entry = metric.update(
            &HammingScoreInput::new(x.clone(), y), // every target is now wrong
            &MetricMetadata::fake(),
        );
        assert_eq!(0.0, metric.value());

        // Flip 5 of the 20 target values -> 1 - (5/20) = 0.75
        let y = Tensor::from_data(
            [
                [0, 1, 1, 0, 1], // 2 flipped
                [0, 0, 0, 0, 1], // 1 flipped
                [0, 0, 0, 0, 1], // 1 flipped
                [0, 1, 1, 0, 0], // 1 flipped
            ],
            &device,
        );
        let _entry = metric.update(
            &HammingScoreInput::new(x, y), // partially flipped targets
            &MetricMetadata::fake(),
        );
        assert_eq!(75.0, metric.value());
    }
}
|
|
@ -9,6 +9,7 @@ mod cpu_temp;
|
|||
mod cpu_use;
|
||||
#[cfg(feature = "metrics")]
|
||||
mod cuda;
|
||||
mod hamming;
|
||||
mod learning_rate;
|
||||
mod loss;
|
||||
#[cfg(feature = "metrics")]
|
||||
|
@ -22,6 +23,7 @@ pub use cpu_temp::*;
|
|||
pub use cpu_use::*;
|
||||
#[cfg(feature = "metrics")]
|
||||
pub use cuda::*;
|
||||
pub use hamming::*;
|
||||
pub use learning_rate::*;
|
||||
pub use loss::*;
|
||||
#[cfg(feature = "metrics")]
|
||||
|
|
|
@ -17,7 +17,4 @@ burn = { path = "../../crates/burn", features = ["train", "vision", "network"] }
|
|||
|
||||
# File download
|
||||
flate2 = { workspace = true }
|
||||
indicatif = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
tar = "0.4.40"
|
||||
tokio = { workspace = true }
|
||||
|
|
|
@ -51,6 +51,15 @@ The CNN model and training recipe used in this example are fairly simple since t
|
|||
demonstrate how to load a custom image classification dataset from disk. Nonetheless, it still
|
||||
achieves 70-80% accuracy on the test set after just 30 epochs.
|
||||
|
||||
Run it with the Torch GPU backend:
|
||||
|
||||
```sh
|
||||
cargo run --example custom-image-dataset
|
||||
export TORCH_CUDA_VERSION=cu121
|
||||
cargo run --example custom-image-dataset --release --features tch-gpu
|
||||
```
|
||||
|
||||
Run it with our WGPU backend:
|
||||
|
||||
```sh
|
||||
cargo run --example custom-image-dataset --release --features wgpu
|
||||
```
|
||||
|
|
Loading…
Reference in New Issue