diff --git a/crates/burn-dataset/src/vision/image_folder.rs b/crates/burn-dataset/src/vision/image_folder.rs
index e98f81b95..f850714f3 100644
--- a/crates/burn-dataset/src/vision/image_folder.rs
+++ b/crates/burn-dataset/src/vision/image_folder.rs
@@ -57,11 +57,13 @@ impl TryFrom<PixelDepth> for f32 {
     }
 }
 
-/// Image target for different tasks.
+/// Annotation type for different tasks.
 #[derive(Debug, Clone, PartialEq)]
 pub enum Annotation {
     /// Image-level label.
     Label(usize),
+    /// Multiple image-level labels.
+    MultiLabel(Vec<usize>),
     /// Object bounding boxes.
     BoundingBoxes(Vec<BoundingBox>),
     /// Segmentation mask.
@@ -97,14 +99,30 @@ pub struct ImageDatasetItem {
     pub annotation: Annotation,
 }
 
+/// Raw annotation types.
+#[derive(Deserialize, Serialize, Debug, Clone)]
+enum AnnotationRaw {
+    Label(String),
+    MultiLabel(Vec<String>),
+    // TODO: bounding boxes and segmentation mask
+}
+
 #[derive(Deserialize, Serialize, Debug, Clone)]
 struct ImageDatasetItemRaw {
     /// Image path.
-    pub image_path: PathBuf,
+    image_path: PathBuf,
 
     /// Image annotation.
-    /// The annotation bytes can represent a string (category name) or path to annotation file.
-    pub annotation: Vec<u8>,
+    annotation: AnnotationRaw,
+}
+
+impl ImageDatasetItemRaw {
+    fn new<P: AsRef<Path>>(image_path: P, annotation: AnnotationRaw) -> ImageDatasetItemRaw {
+        ImageDatasetItemRaw {
+            image_path: image_path.as_ref().to_path_buf(),
+            annotation,
+        }
+    }
 }
 
 struct PathToImageDatasetItem {
@@ -112,15 +130,25 @@ struct PathToImageDatasetItem {
 }
 
 /// Parse the image annotation to the corresponding type.
-fn parse_image_annotation(annotation: &[u8], classes: &HashMap<String, usize>) -> Annotation {
+fn parse_image_annotation(
+    annotation: &AnnotationRaw,
+    classes: &HashMap<String, usize>,
+) -> Annotation {
     // TODO: add support for other annotations
     // - [ ] Object bounding boxes
     // - [ ] Segmentation mask
     // For now, only image classification labels are supported.
 
     // Map class string to label id
-    let name = std::str::from_utf8(annotation).unwrap();
-    Annotation::Label(*classes.get(name).unwrap())
+    match annotation {
+        AnnotationRaw::Label(name) => Annotation::Label(*classes.get(name).unwrap()),
+        AnnotationRaw::MultiLabel(names) => Annotation::MultiLabel(
+            names
+                .iter()
+                .map(|name| *classes.get(name).unwrap())
+                .collect(),
+        ),
+    }
 }
 
 impl Mapper<ImageDatasetItemRaw, ImageDatasetItem> for PathToImageDatasetItem {
@@ -212,7 +240,7 @@ pub enum ImageLoaderError {
 }
 
 type ImageDatasetMapper =
     MapperDataset<InMemDataset<ImageDatasetItemRaw>, PathToImageDatasetItem, ImageDatasetItemRaw>;
-/// A generic dataset to load classification images from disk.
+/// A generic dataset to load images from disk.
 pub struct ImageFolderDataset {
     dataset: ImageDatasetMapper,
 }
@@ -259,18 +287,6 @@ impl ImageFolderDataset {
         P: AsRef<Path>,
         S: AsRef<str>,
     {
-        /// Check if extension is supported.
-        fn check_extension<S: AsRef<str>>(extension: &S) -> Result<String, ImageLoaderError> {
-            let extension = extension.as_ref();
-            if !SUPPORTED_FILES.contains(&extension) {
-                Err(ImageLoaderError::InvalidFileExtensionError(
-                    extension.to_string(),
-                ))
-            } else {
-                Ok(extension.to_string())
-            }
-        }
-
         // Glob all images with extensions
         let walker = globwalk::GlobWalkerBuilder::from_patterns(
             root.as_ref(),
@@ -278,7 +294,7 @@ impl ImageFolderDataset {
                 "*.{{{}}}", // "*.{ext1,ext2,ext3}
                 extensions
                     .iter()
-                    .map(check_extension)
+                    .map(Self::check_extension)
                     .collect::<Result<Vec<_>, _>>()?
                    .join(",")
             )],
@@ -312,21 +328,102 @@ impl ImageFolderDataset {
 
             classes.insert(label.clone());
 
-            items.push(ImageDatasetItemRaw {
-                image_path: image_path.to_path_buf(),
-                annotation: label.into_bytes(),
-            })
+            items.push(ImageDatasetItemRaw::new(
+                image_path,
+                AnnotationRaw::Label(label),
+            ))
         }
 
+        // Sort class names
+        let mut classes = classes.into_iter().collect::<Vec<_>>();
+        classes.sort();
+
+        Self::with_items(items, &classes)
+    }
+
+    /// Create an image classification dataset with the specified items.
+    ///
+    /// # Arguments
+    ///
+    /// * `items` - List of dataset items, each item represented by a tuple `(image path, label)`.
+    /// * `classes` - Dataset class names.
+    ///
+    /// # Returns
+    /// A new dataset instance.
+    pub fn new_classification_with_items<P: AsRef<Path>, S: AsRef<str>>(
+        items: Vec<(P, String)>,
+        classes: &[S],
+    ) -> Result<Self, ImageLoaderError> {
+        // Parse items and check valid image extension types
+        let items = items
+            .into_iter()
+            .map(|(path, label)| {
+                // Map image path and label
+                let path = path.as_ref();
+                let label = AnnotationRaw::Label(label);
+
+                Self::check_extension(&path.extension().unwrap().to_str().unwrap())?;
+
+                Ok(ImageDatasetItemRaw::new(path, label))
+            })
+            .collect::<Result<Vec<_>, _>>()?;
+
+        Self::with_items(items, classes)
+    }
+
+    /// Create a multi-label image classification dataset with the specified items.
+    ///
+    /// # Arguments
+    ///
+    /// * `items` - List of dataset items, each item represented by a tuple `(image path, labels)`.
+    /// * `classes` - Dataset class names.
+    ///
+    /// # Returns
+    /// A new dataset instance.
+    pub fn new_multilabel_classification_with_items<P: AsRef<Path>, S: AsRef<str>>(
+        items: Vec<(P, Vec<String>)>,
+        classes: &[S],
+    ) -> Result<Self, ImageLoaderError> {
+        // Parse items and check valid image extension types
+        let items = items
+            .into_iter()
+            .map(|(path, labels)| {
+                // Map image path and multi-label
+                let path = path.as_ref();
+                let labels = AnnotationRaw::MultiLabel(labels);
+
+                Self::check_extension(&path.extension().unwrap().to_str().unwrap())?;
+
+                Ok(ImageDatasetItemRaw::new(path, labels))
+            })
+            .collect::<Result<Vec<_>, _>>()?;
+
+        Self::with_items(items, classes)
+    }
+
+    /// Create an image dataset with the specified items.
+    ///
+    /// # Arguments
+    ///
+    /// * `items` - Raw dataset items.
+    /// * `classes` - Dataset class names.
+    ///
+    /// # Returns
+    /// A new dataset instance.
+    fn with_items<S: AsRef<str>>(
+        items: Vec<ImageDatasetItemRaw>,
+        classes: &[S],
+    ) -> Result<Self, ImageLoaderError> {
+        // NOTE: right now we don't need to validate the supported image files since
+        // the method is private. We assume it's already validated.
         let dataset = InMemDataset::new(items);
 
         // Class names to index map
-        let mut classes = classes.into_iter().collect::<Vec<_>>();
-        classes.sort();
+        let classes = classes.iter().map(|c| c.as_ref()).collect::<Vec<_>>();
         let classes_map: HashMap<_, _> = classes
             .into_iter()
             .enumerate()
-            .map(|(idx, cls)| (cls, idx))
+            .map(|(idx, cls)| (cls.to_string(), idx))
             .collect();
 
         let mapper = PathToImageDatasetItem {
@@ -336,6 +433,18 @@ impl ImageFolderDataset {
 
         Ok(Self { dataset })
     }
+
+    /// Check if extension is supported.
+    fn check_extension<S: AsRef<str>>(extension: &S) -> Result<String, ImageLoaderError> {
+        let extension = extension.as_ref();
+        if !SUPPORTED_FILES.contains(&extension) {
+            Err(ImageLoaderError::InvalidFileExtensionError(
+                extension.to_string(),
+            ))
+        } else {
+            Ok(extension.to_string())
+        }
+    }
 }
 
 #[cfg(test)]
@@ -370,6 +479,69 @@ mod tests {
         assert_eq!(dataset.get(1).unwrap().annotation, Annotation::Label(1));
     }
 
+    #[test]
+    pub fn image_folder_dataset_with_items() {
+        let root = Path::new(DATASET_ROOT);
+        let items = vec![
+            (root.join("orange").join("dot.jpg"), "orange".to_string()),
+            (root.join("red").join("dot.jpg"), "red".to_string()),
+            (root.join("red").join("dot.png"), "red".to_string()),
+        ];
+        let dataset =
+            ImageFolderDataset::new_classification_with_items(items, &["orange", "red"]).unwrap();
+
+        // Dataset has 3 elements
+        assert_eq!(dataset.len(), 3);
+        assert_eq!(dataset.get(3), None);
+
+        // Dataset elements should be: orange (0), red (1), red (1)
+        assert_eq!(dataset.get(0).unwrap().annotation, Annotation::Label(0));
+        assert_eq!(dataset.get(1).unwrap().annotation, Annotation::Label(1));
+        assert_eq!(dataset.get(2).unwrap().annotation, Annotation::Label(1));
+    }
+
+    #[test]
+    pub fn image_folder_dataset_multilabel() {
+        let root = Path::new(DATASET_ROOT);
+        let items = vec![
+            (
+                root.join("orange").join("dot.jpg"),
+                vec!["dot".to_string(), "orange".to_string()],
+            ),
+            (
+                root.join("red").join("dot.jpg"),
+                vec!["dot".to_string(), "red".to_string()],
+            ),
+            (
+                root.join("red").join("dot.png"),
+                vec!["dot".to_string(), "red".to_string()],
+            ),
+        ];
+        let dataset = ImageFolderDataset::new_multilabel_classification_with_items(
+            items,
+            &["dot", "orange", "red"],
+        )
+        .unwrap();
+
+        // Dataset has 3 elements
+        assert_eq!(dataset.len(), 3);
+        assert_eq!(dataset.get(3), None);
+
+        // Dataset elements should be: [dot, orange] (0, 1), [dot, red] (0, 2), [dot, red] (0, 2)
+        assert_eq!(
+            dataset.get(0).unwrap().annotation,
+            Annotation::MultiLabel(vec![0, 1])
+        );
+        assert_eq!(
+            dataset.get(1).unwrap().annotation,
+            Annotation::MultiLabel(vec![0, 2])
+        );
+        assert_eq!(
+            dataset.get(2).unwrap().annotation,
+            Annotation::MultiLabel(vec![0, 2])
+        );
+    }
+
     #[test]
     #[should_panic]
     pub fn image_folder_dataset_invalid_extension() {
@@ -417,11 +589,26 @@ mod tests {
     }
 
     #[test]
-    pub fn parse_image_annotation_string() {
+    pub fn parse_image_annotation_label_string() {
        let classes = HashMap::from([("0".to_string(), 0_usize), ("1".to_string(), 1_usize)]);
+        let anno = AnnotationRaw::Label("0".to_string());
         assert_eq!(
-            parse_image_annotation(&"0".to_string().into_bytes(), &classes),
+            parse_image_annotation(&anno, &classes),
             Annotation::Label(0)
         );
     }
+
+    #[test]
+    pub fn parse_image_annotation_multilabel_string() {
+        let classes = HashMap::from([
+            ("0".to_string(), 0_usize),
+            ("1".to_string(), 1_usize),
+            ("2".to_string(), 2_usize),
+        ]);
+        let anno = AnnotationRaw::MultiLabel(vec!["0".to_string(), "2".to_string()]);
+        assert_eq!(
+            parse_image_annotation(&anno, &classes),
+            Annotation::MultiLabel(vec![0, 2])
+        );
+    }
 }
diff --git a/crates/burn-train/src/learner/classification.rs b/crates/burn-train/src/learner/classification.rs
index f6b415fa2..ee86a0575 100644
--- a/crates/burn-train/src/learner/classification.rs
+++ b/crates/burn-train/src/learner/classification.rs
@@ -1,4 +1,4 @@
-use crate::metric::{AccuracyInput, Adaptor, LossInput};
+use crate::metric::{AccuracyInput, Adaptor, HammingScoreInput, LossInput};
 use burn_core::tensor::backend::Backend;
 use burn_core::tensor::{Int, Tensor};
 
@@ -26,3 +26,28 @@ impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
         LossInput::new(self.loss.clone())
     }
 }
+
+/// Multi-label classification output adapted for multiple metrics.
+#[derive(new)]
+pub struct MultiLabelClassificationOutput<B: Backend> {
+    /// The loss.
+    pub loss: Tensor<B, 1>,
+
+    /// The output.
+    pub output: Tensor<B, 2>,
+
+    /// The targets.
+    pub targets: Tensor<B, 2, Int>,
+}
+
+impl<B: Backend> Adaptor<HammingScoreInput<B>> for MultiLabelClassificationOutput<B> {
+    fn adapt(&self) -> HammingScoreInput<B> {
+        HammingScoreInput::new(self.output.clone(), self.targets.clone())
+    }
+}
+
+impl<B: Backend> Adaptor<LossInput<B>> for MultiLabelClassificationOutput<B> {
+    fn adapt(&self) -> LossInput<B> {
+        LossInput::new(self.loss.clone())
+    }
+}
diff --git a/crates/burn-train/src/metric/acc.rs b/crates/burn-train/src/metric/acc.rs
index efdbf88cf..ad24fe077 100644
--- a/crates/burn-train/src/metric/acc.rs
+++ b/crates/burn-train/src/metric/acc.rs
@@ -1,3 +1,5 @@
+use core::marker::PhantomData;
+
 use super::state::{FormatOptions, NumericMetricState};
 use super::{MetricEntry, MetricMetadata};
 use crate::metric::{Metric, Numeric};
@@ -9,7 +11,7 @@ use burn_core::tensor::{ElementConversion, Int, Tensor};
 pub struct AccuracyMetric<B: Backend> {
     state: NumericMetricState,
     pad_token: Option<usize>,
-    _b: B,
+    _b: PhantomData<B>,
 }
 
 /// The [accuracy metric](AccuracyMetric) input type.
diff --git a/crates/burn-train/src/metric/hamming.rs b/crates/burn-train/src/metric/hamming.rs
new file mode 100644
index 000000000..6291833de
--- /dev/null
+++ b/crates/burn-train/src/metric/hamming.rs
@@ -0,0 +1,155 @@
+use core::marker::PhantomData;
+
+use super::state::{FormatOptions, NumericMetricState};
+use super::{MetricEntry, MetricMetadata};
+use crate::metric::{Metric, Numeric};
+use burn_core::tensor::{activation::sigmoid, backend::Backend, ElementConversion, Int, Tensor};
+
+/// The hamming score, sometimes referred to as multi-label or label-based accuracy.
+pub struct HammingScore<B: Backend> {
+    state: NumericMetricState,
+    threshold: f32,
+    sigmoid: bool,
+    _b: PhantomData<B>,
+}
+
+/// The [hamming score](HammingScore) input type.
+#[derive(new)]
+pub struct HammingScoreInput<B: Backend> {
+    outputs: Tensor<B, 2>,
+    targets: Tensor<B, 2, Int>,
+}
+
+impl<B: Backend> HammingScore<B> {
+    /// Creates the metric.
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Sets the threshold.
+    pub fn with_threshold(mut self, threshold: f32) -> Self {
+        self.threshold = threshold;
+        self
+    }
+
+    /// Sets the sigmoid activation function usage.
+    pub fn with_sigmoid(mut self, sigmoid: bool) -> Self {
+        self.sigmoid = sigmoid;
+        self
+    }
+}
+
+impl<B: Backend> Default for HammingScore<B> {
+    /// Creates a new metric instance with default values.
+    fn default() -> Self {
+        Self {
+            state: NumericMetricState::default(),
+            threshold: 0.5,
+            sigmoid: false,
+            _b: PhantomData,
+        }
+    }
+}
+
+impl<B: Backend> Metric for HammingScore<B> {
+    const NAME: &'static str = "Hamming Score";
+
+    type Input = HammingScoreInput<B>;
+
+    fn update(&mut self, input: &HammingScoreInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
+        let [batch_size, _n_classes] = input.outputs.dims();
+
+        let targets = input.targets.clone();
+
+        let mut outputs = input.outputs.clone();
+
+        if self.sigmoid {
+            outputs = sigmoid(outputs);
+        }
+
+        let score = outputs
+            .greater_elem(self.threshold)
+            .equal(targets.bool())
+            .float()
+            .mean()
+            .into_scalar()
+            .elem::<f64>();
+
+        self.state.update(
+            100.0 * score,
+            batch_size,
+            FormatOptions::new(Self::NAME).unit("%").precision(2),
+        )
+    }
+
+    fn clear(&mut self) {
+        self.state.reset()
+    }
+}
+
+impl<B: Backend> Numeric for HammingScore<B> {
+    fn value(&self) -> f64 {
+        self.state.value()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::TestBackend;
+
+    #[test]
+    fn test_hamming_score() {
+        let device = Default::default();
+        let mut metric = HammingScore::<TestBackend>::new();
+
+        let x = Tensor::from_data(
+            [
+                [0.32, 0.52, 0.38, 0.68, 0.61], // with x > 0.5: [0, 1, 0, 1, 1]
+                [0.43, 0.31, 0.21, 0.63, 0.53], // [0, 0, 0, 1, 1]
+                [0.44, 0.25, 0.71, 0.39, 0.73], // [0, 0, 1, 0, 1]
+                [0.49, 0.37, 0.68, 0.39, 0.31], // [0, 0, 1, 0, 0]
+            ],
+            &device,
+        );
+        let y = Tensor::from_data(
+            [
+                [0, 1, 0, 1, 1],
+                [0, 0, 0, 1, 1],
+                [0, 0, 1, 0, 1],
+                [0, 0, 1, 0, 0],
+            ],
+            &device,
+        );
+
+        let _entry = metric.update(
+            &HammingScoreInput::new(x.clone(), y.clone()),
+            &MetricMetadata::fake(),
+        );
+        assert_eq!(100.0, metric.value());
+
+        // Invert all targets: y = (1 - y)
+        let y = y.neg().add_scalar(1);
+        let _entry = metric.update(
+            &HammingScoreInput::new(x.clone(), y), // fully inverted targets (1 - y)
+            &MetricMetadata::fake(),
+        );
+        assert_eq!(0.0, metric.value());
+
+        // Invert 5 target values -> 1 - (5/20) = 0.75
+        let y = Tensor::from_data(
+            [
+                [0, 1, 1, 0, 1],
+                [0, 0, 0, 0, 1],
+                [0, 0, 0, 0, 1],
+                [0, 1, 1, 0, 0],
+            ],
+            &device,
+        );
+        let _entry = metric.update(
+            &HammingScoreInput::new(x, y), // targets with 5 flipped values
+            &MetricMetadata::fake(),
+        );
+        assert_eq!(75.0, metric.value());
+    }
+}
diff --git a/crates/burn-train/src/metric/mod.rs b/crates/burn-train/src/metric/mod.rs
index 37ad5af73..7d443da06 100644
--- a/crates/burn-train/src/metric/mod.rs
+++ b/crates/burn-train/src/metric/mod.rs
@@ -9,6 +9,7 @@ mod cpu_temp;
 mod cpu_use;
 #[cfg(feature = "metrics")]
 mod cuda;
+mod hamming;
 mod learning_rate;
 mod loss;
 #[cfg(feature = "metrics")]
@@ -22,6 +23,7 @@ pub use cpu_temp::*;
 pub use cpu_use::*;
 #[cfg(feature = "metrics")]
 pub use cuda::*;
+pub use hamming::*;
 pub use learning_rate::*;
 pub use loss::*;
 #[cfg(feature = "metrics")]
diff --git a/examples/custom-image-dataset/Cargo.toml b/examples/custom-image-dataset/Cargo.toml
index 13e27c165..eabd5e43d 100644
--- a/examples/custom-image-dataset/Cargo.toml
+++ b/examples/custom-image-dataset/Cargo.toml
@@ -17,7 +17,4 @@ burn = { path = "../../crates/burn", features = ["train", "vision", "network"] }
 
 # File download
 flate2 = { workspace = true }
-indicatif = { workspace = true }
-reqwest = { workspace = true }
 tar = "0.4.40"
-tokio = { workspace = true }
diff --git a/examples/custom-image-dataset/README.md b/examples/custom-image-dataset/README.md
index d9eee42b2..ac49957e4 100644
--- a/examples/custom-image-dataset/README.md
+++ b/examples/custom-image-dataset/README.md
@@ -51,6 +51,15 @@ The CNN model and training recipe used in this example are fairly simple since t
 demonstrate how to load a custom image classification dataset from disk. Nonetheless, it still
 achieves 70-80% accuracy on the test set after just 30 epochs.
 
+Run it with the Torch GPU backend:
+
 ```sh
-cargo run --example custom-image-dataset
+export TORCH_CUDA_VERSION=cu121
+cargo run --example custom-image-dataset --release --features tch-gpu
+```
+
+Run it with our WGPU backend:
+
+```sh
+cargo run --example custom-image-dataset --release --features wgpu
 ```
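
For reference, a minimal sketch of how the new multi-label constructor could be consumed downstream. The file paths, class names, and the `burn_dataset::vision` import path are illustrative assumptions, not part of the diff; the listed files must exist on disk and use a supported extension, otherwise construction returns an `ImageLoaderError`.

```rust
use burn_dataset::vision::ImageFolderDataset;
use burn_dataset::Dataset;

fn main() {
    // Hypothetical image paths paired with their label sets.
    let items = vec![
        ("images/cat_indoor.jpg", vec!["cat".to_string(), "indoor".to_string()]),
        ("images/dog_outdoor.png", vec!["dog".to_string(), "outdoor".to_string()]),
    ];

    let dataset = ImageFolderDataset::new_multilabel_classification_with_items(
        items,
        // Class index = position in this slice (cat = 0, dog = 1, indoor = 2, outdoor = 3).
        &["cat", "dog", "indoor", "outdoor"],
    )
    .expect("valid items and extensions");

    assert_eq!(dataset.len(), 2);
    // dataset.get(0) decodes the first image from disk and yields
    // Annotation::MultiLabel(vec![0, 2]) for its ["cat", "indoor"] labels.
}
```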
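The Hamming score added above is the element-wise accuracy over the (sample, class) label matrix: predictions are thresholded, compared against the binary targets, and the fraction of matching positions is averaged. A small dependency-free sketch of the same computation (the helper name and plain `Vec` types are illustrative; it mirrors `HammingScore::update` with sigmoid disabled):

```rust
/// Fraction of (sample, class) positions where the thresholded prediction matches the target.
fn hamming_score(outputs: &[Vec<f32>], targets: &[Vec<u8>], threshold: f32) -> f32 {
    let mut correct = 0usize;
    let mut total = 0usize;
    for (row_out, row_tgt) in outputs.iter().zip(targets) {
        for (out, tgt) in row_out.iter().zip(row_tgt) {
            let predicted = (*out > threshold) as u8;
            correct += (predicted == *tgt) as usize;
            total += 1;
        }
    }
    correct as f32 / total as f32
}

fn main() {
    // First two rows of the metric's unit test: every label slot matches after thresholding.
    let outputs = vec![
        vec![0.32, 0.52, 0.38, 0.68, 0.61],
        vec![0.43, 0.31, 0.21, 0.63, 0.53],
    ];
    let targets = vec![vec![0, 1, 0, 1, 1], vec![0, 0, 0, 1, 1]];
    assert_eq!(hamming_score(&outputs, &targets, 0.5), 1.0); // reported as 100%
}
```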