Add multi-label classification dataset and metric (#1572)

* Add multilabel classification dataset

- Add MultiLabel annotation support
- Refactor de/serialize annotation with AnnotationRaw
- Add ImageFolderDataset::with_items methods

* Fix custom-image-classification example deps

* Add image_folder_dataset_multilabel test

* Do not change class names order when provided

* Add hamming score and multi-label classification output

* Add new_classification_with_items test

* Fix clippy suggestions

* Implement default trait for hamming score

* Remove de/serialization and use AnnotationRaw as type

* Fix clippy

* Fix metric backend phantom data
Guillaume Lagrange, 2024-04-05 13:16:46 -04:00, committed by GitHub
parent f5159b6d22
commit f3e0aa6689
7 changed files with 413 additions and 36 deletions
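Taken together, these commits let a user load a multi-label image dataset and track a matching metric. A minimal usage sketch (image paths, class names, and the `NdArray` backend are hypothetical; import paths assume burn's usual re-exports with the `train` and `vision` features enabled):

```rust
use burn::backend::NdArray;
use burn::data::dataset::vision::ImageFolderDataset;
use burn::train::metric::HammingScore;

fn main() {
    // Hypothetical items: each image path is paired with its label names.
    let items = vec![
        ("img/a.jpg", vec!["dot".to_string(), "orange".to_string()]),
        ("img/b.png", vec!["dot".to_string(), "red".to_string()]),
    ];

    // Class order is preserved as provided (classes are only sorted when
    // they are inferred from directory names).
    let _dataset = ImageFolderDataset::new_multilabel_classification_with_items(
        items,
        &["dot", "orange", "red"],
    )
    .expect("supported image extensions");

    // The companion metric: with_sigmoid(true) applies sigmoid to the
    // outputs before thresholding them at 0.5 (the default threshold).
    let _metric = HammingScore::<NdArray>::new().with_sigmoid(true);
}
```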

File 1 of 7

@@ -57,11 +57,13 @@ impl TryFrom<PixelDepth> for f32 {
}
}
-/// Image target for different tasks.
+/// Annotation type for different tasks.
#[derive(Debug, Clone, PartialEq)]
pub enum Annotation {
/// Image-level label.
Label(usize),
+/// Multiple image-level labels.
+MultiLabel(Vec<usize>),
/// Object bounding boxes.
BoundingBoxes(Vec<BoundingBox>),
/// Segmentation mask.
@@ -97,14 +99,30 @@ pub struct ImageDatasetItem {
pub annotation: Annotation,
}
/// Raw annotation types.
#[derive(Deserialize, Serialize, Debug, Clone)]
enum AnnotationRaw {
Label(String),
MultiLabel(Vec<String>),
// TODO: bounding boxes and segmentation mask
}
#[derive(Deserialize, Serialize, Debug, Clone)]
struct ImageDatasetItemRaw {
/// Image path.
-pub image_path: PathBuf,
+image_path: PathBuf,
/// Image annotation.
-/// The annotation bytes can represent a string (category name) or path to annotation file.
-pub annotation: Vec<u8>,
+annotation: AnnotationRaw,
}
impl ImageDatasetItemRaw {
fn new<P: AsRef<Path>>(image_path: P, annotation: AnnotationRaw) -> ImageDatasetItemRaw {
ImageDatasetItemRaw {
image_path: image_path.as_ref().to_path_buf(),
annotation,
}
}
}
struct PathToImageDatasetItem {
@@ -112,15 +130,25 @@ struct PathToImageDatasetItem {
}
/// Parse the image annotation to the corresponding type.
-fn parse_image_annotation(annotation: &[u8], classes: &HashMap<String, usize>) -> Annotation {
+fn parse_image_annotation(
+annotation: &AnnotationRaw,
+classes: &HashMap<String, usize>,
+) -> Annotation {
// TODO: add support for other annotations
// - [ ] Object bounding boxes
// - [ ] Segmentation mask
// For now, only image classification labels are supported.
// Map class string to label id
-let name = std::str::from_utf8(annotation).unwrap();
-Annotation::Label(*classes.get(name).unwrap())
+match annotation {
+AnnotationRaw::Label(name) => Annotation::Label(*classes.get(name).unwrap()),
+AnnotationRaw::MultiLabel(names) => Annotation::MultiLabel(
+names
+.iter()
+.map(|name| *classes.get(name).unwrap())
+.collect(),
+),
+}
}
impl Mapper<ImageDatasetItemRaw, ImageDatasetItem> for PathToImageDatasetItem {
@@ -212,7 +240,7 @@ pub enum ImageLoaderError {
type ImageDatasetMapper =
MapperDataset<InMemDataset<ImageDatasetItemRaw>, PathToImageDatasetItem, ImageDatasetItemRaw>;
-/// A generic dataset to load classification images from disk.
+/// A generic dataset to load images from disk.
pub struct ImageFolderDataset {
dataset: ImageDatasetMapper,
}
@@ -259,18 +287,6 @@ impl ImageFolderDataset {
P: AsRef<Path>,
S: AsRef<str>,
{
-/// Check if extension is supported.
-fn check_extension<S: AsRef<str>>(extension: &S) -> Result<String, ImageLoaderError> {
-let extension = extension.as_ref();
-if !SUPPORTED_FILES.contains(&extension) {
-Err(ImageLoaderError::InvalidFileExtensionError(
-extension.to_string(),
-))
-} else {
-Ok(extension.to_string())
-}
-}
// Glob all images with extensions
let walker = globwalk::GlobWalkerBuilder::from_patterns(
root.as_ref(),
@@ -278,7 +294,7 @@
"*.{{{}}}", // "*.{ext1,ext2,ext3}
extensions
.iter()
-.map(check_extension)
+.map(Self::check_extension)
.collect::<Result<Vec<_>, _>>()?
.join(",")
)],
@@ -312,21 +328,102 @@ impl ImageFolderDataset {
classes.insert(label.clone());
-items.push(ImageDatasetItemRaw {
-image_path: image_path.to_path_buf(),
-annotation: label.into_bytes(),
-})
+items.push(ImageDatasetItemRaw::new(
+image_path,
+AnnotationRaw::Label(label),
+))
}
// Sort class names
let mut classes = classes.into_iter().collect::<Vec<_>>();
classes.sort();
+Self::with_items(items, &classes)
}
/// Create an image classification dataset with the specified items.
///
/// # Arguments
///
/// * `items` - List of dataset items, each item represented by a tuple `(image path, label)`.
/// * `classes` - Dataset class names.
///
/// # Returns
/// A new dataset instance.
pub fn new_classification_with_items<P: AsRef<Path>, S: AsRef<str>>(
items: Vec<(P, String)>,
classes: &[S],
) -> Result<Self, ImageLoaderError> {
// Parse items and check valid image extension types
let items = items
.into_iter()
.map(|(path, label)| {
// Map image path and label
let path = path.as_ref();
let label = AnnotationRaw::Label(label);
Self::check_extension(&path.extension().unwrap().to_str().unwrap())?;
Ok(ImageDatasetItemRaw::new(path, label))
})
.collect::<Result<Vec<_>, _>>()?;
Self::with_items(items, classes)
}
/// Create a multi-label image classification dataset with the specified items.
///
/// # Arguments
///
/// * `items` - List of dataset items, each item represented by a tuple `(image path, labels)`.
/// * `classes` - Dataset class names.
///
/// # Returns
/// A new dataset instance.
pub fn new_multilabel_classification_with_items<P: AsRef<Path>, S: AsRef<str>>(
items: Vec<(P, Vec<String>)>,
classes: &[S],
) -> Result<Self, ImageLoaderError> {
// Parse items and check valid image extension types
let items = items
.into_iter()
.map(|(path, labels)| {
// Map image path and multi-label
let path = path.as_ref();
let labels = AnnotationRaw::MultiLabel(labels);
Self::check_extension(&path.extension().unwrap().to_str().unwrap())?;
Ok(ImageDatasetItemRaw::new(path, labels))
})
.collect::<Result<Vec<_>, _>>()?;
Self::with_items(items, classes)
}
/// Create an image dataset with the specified items.
///
/// # Arguments
///
/// * `items` - Raw dataset items.
/// * `classes` - Dataset class names.
///
/// # Returns
/// A new dataset instance.
fn with_items<S: AsRef<str>>(
items: Vec<ImageDatasetItemRaw>,
classes: &[S],
) -> Result<Self, ImageLoaderError> {
// NOTE: right now we don't need to validate the supported image files since
// the method is private. We assume it's already validated.
let dataset = InMemDataset::new(items);
// Class names to index map
-let mut classes = classes.into_iter().collect::<Vec<_>>();
-classes.sort();
+let classes = classes.iter().map(|c| c.as_ref()).collect::<Vec<_>>();
let classes_map: HashMap<_, _> = classes
.into_iter()
.enumerate()
-.map(|(idx, cls)| (cls, idx))
+.map(|(idx, cls)| (cls.to_string(), idx))
.collect();
let mapper = PathToImageDatasetItem {
@@ -336,6 +433,18 @@ impl ImageFolderDataset {
Ok(Self { dataset })
}
/// Check if extension is supported.
fn check_extension<S: AsRef<str>>(extension: &S) -> Result<String, ImageLoaderError> {
let extension = extension.as_ref();
if !SUPPORTED_FILES.contains(&extension) {
Err(ImageLoaderError::InvalidFileExtensionError(
extension.to_string(),
))
} else {
Ok(extension.to_string())
}
}
}
#[cfg(test)]
@@ -370,6 +479,69 @@ mod tests {
assert_eq!(dataset.get(1).unwrap().annotation, Annotation::Label(1));
}
#[test]
pub fn image_folder_dataset_with_items() {
let root = Path::new(DATASET_ROOT);
let items = vec![
(root.join("orange").join("dot.jpg"), "orange".to_string()),
(root.join("red").join("dot.jpg"), "red".to_string()),
(root.join("red").join("dot.png"), "red".to_string()),
];
let dataset =
ImageFolderDataset::new_classification_with_items(items, &["orange", "red"]).unwrap();
// Dataset has 3 elements
assert_eq!(dataset.len(), 3);
assert_eq!(dataset.get(3), None);
// Dataset elements should be: orange (0), red (1), red (1)
assert_eq!(dataset.get(0).unwrap().annotation, Annotation::Label(0));
assert_eq!(dataset.get(1).unwrap().annotation, Annotation::Label(1));
assert_eq!(dataset.get(2).unwrap().annotation, Annotation::Label(1));
}
#[test]
pub fn image_folder_dataset_multilabel() {
let root = Path::new(DATASET_ROOT);
let items = vec![
(
root.join("orange").join("dot.jpg"),
vec!["dot".to_string(), "orange".to_string()],
),
(
root.join("red").join("dot.jpg"),
vec!["dot".to_string(), "red".to_string()],
),
(
root.join("red").join("dot.png"),
vec!["dot".to_string(), "red".to_string()],
),
];
let dataset = ImageFolderDataset::new_multilabel_classification_with_items(
items,
&["dot", "orange", "red"],
)
.unwrap();
// Dataset has 3 elements
assert_eq!(dataset.len(), 3);
assert_eq!(dataset.get(3), None);
// Dataset elements should be: [dot, orange] (0, 1), [dot, red] (0, 2), [dot, red] (0, 2)
assert_eq!(
dataset.get(0).unwrap().annotation,
Annotation::MultiLabel(vec![0, 1])
);
assert_eq!(
dataset.get(1).unwrap().annotation,
Annotation::MultiLabel(vec![0, 2])
);
assert_eq!(
dataset.get(2).unwrap().annotation,
Annotation::MultiLabel(vec![0, 2])
);
}
#[test]
#[should_panic]
pub fn image_folder_dataset_invalid_extension() {
@@ -417,11 +589,26 @@
}
#[test]
-pub fn parse_image_annotation_string() {
+pub fn parse_image_annotation_label_string() {
let classes = HashMap::from([("0".to_string(), 0_usize), ("1".to_string(), 1_usize)]);
+let anno = AnnotationRaw::Label("0".to_string());
assert_eq!(
-parse_image_annotation(&"0".to_string().into_bytes(), &classes),
+parse_image_annotation(&anno, &classes),
Annotation::Label(0)
);
}
#[test]
pub fn parse_image_annotation_multilabel_string() {
let classes = HashMap::from([
("0".to_string(), 0_usize),
("1".to_string(), 1_usize),
("2".to_string(), 2_usize),
]);
let anno = AnnotationRaw::MultiLabel(vec!["0".to_string(), "2".to_string()]);
assert_eq!(
parse_image_annotation(&anno, &classes),
Annotation::MultiLabel(vec![0, 2])
);
}
}

File 2 of 7

@@ -1,4 +1,4 @@
-use crate::metric::{AccuracyInput, Adaptor, LossInput};
+use crate::metric::{AccuracyInput, Adaptor, HammingScoreInput, LossInput};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{Int, Tensor};
@@ -26,3 +26,28 @@ impl<B: Backend> Adaptor<LossInput<B>> for ClassificationOutput<B> {
LossInput::new(self.loss.clone())
}
}
/// Multi-label classification output adapted for multiple metrics.
#[derive(new)]
pub struct MultiLabelClassificationOutput<B: Backend> {
/// The loss.
pub loss: Tensor<B, 1>,
/// The output.
pub output: Tensor<B, 2>,
/// The targets.
pub targets: Tensor<B, 2, Int>,
}
impl<B: Backend> Adaptor<HammingScoreInput<B>> for MultiLabelClassificationOutput<B> {
fn adapt(&self) -> HammingScoreInput<B> {
HammingScoreInput::new(self.output.clone(), self.targets.clone())
}
}
impl<B: Backend> Adaptor<LossInput<B>> for MultiLabelClassificationOutput<B> {
fn adapt(&self) -> LossInput<B> {
LossInput::new(self.loss.clone())
}
}
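For illustration, a hedged sketch (not part of this diff) of how a training step might produce this output. The hand-rolled binary cross-entropy is an assumption; any loss that yields a scalar `Tensor<B, 1>` fits:

```rust
use burn_core::tensor::{activation::sigmoid, backend::Backend, Int, Tensor};

/// Build the multi-label output from raw logits and multi-hot targets.
fn step<B: Backend>(
    logits: Tensor<B, 2>,
    targets: Tensor<B, 2, Int>,
) -> MultiLabelClassificationOutput<B> {
    let probs = sigmoid(logits.clone());
    let t = targets.clone().float();
    // Numerically naive BCE: -mean(t * ln(p) + (1 - t) * ln(1 - p)),
    // averaged over both batch and classes.
    let loss = (t.clone() * probs.clone().log()
        + (t.neg().add_scalar(1.0)) * (probs.neg().add_scalar(1.0)).log())
    .mean()
    .neg();
    MultiLabelClassificationOutput::new(loss, logits, targets)
}
```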

File 3 of 7

@@ -1,3 +1,5 @@
+use core::marker::PhantomData;
+
use super::state::{FormatOptions, NumericMetricState};
use super::{MetricEntry, MetricMetadata};
use crate::metric::{Metric, Numeric};
@@ -9,7 +11,7 @@ use burn_core::tensor::{ElementConversion, Int, Tensor};
pub struct AccuracyMetric<B: Backend> {
state: NumericMetricState,
pad_token: Option<usize>,
-_b: B,
+_b: PhantomData<B>,
}
/// The [accuracy metric](AccuracyMetric) input type.

File 4 of 7

@@ -0,0 +1,155 @@
use core::marker::PhantomData;
use super::state::{FormatOptions, NumericMetricState};
use super::{MetricEntry, MetricMetadata};
use crate::metric::{Metric, Numeric};
use burn_core::tensor::{activation::sigmoid, backend::Backend, ElementConversion, Int, Tensor};
/// The hamming score, sometimes referred to as multi-label or label-based accuracy.
pub struct HammingScore<B: Backend> {
state: NumericMetricState,
threshold: f32,
sigmoid: bool,
_b: PhantomData<B>,
}
/// The [hamming score](HammingScore) input type.
#[derive(new)]
pub struct HammingScoreInput<B: Backend> {
outputs: Tensor<B, 2>,
targets: Tensor<B, 2, Int>,
}
impl<B: Backend> HammingScore<B> {
/// Creates the metric.
pub fn new() -> Self {
Self::default()
}
/// Sets the threshold.
pub fn with_threshold(mut self, threshold: f32) -> Self {
self.threshold = threshold;
self
}
/// Sets the sigmoid activation function usage.
pub fn with_sigmoid(mut self, sigmoid: bool) -> Self {
self.sigmoid = sigmoid;
self
}
}
impl<B: Backend> Default for HammingScore<B> {
/// Creates a new metric instance with default values.
fn default() -> Self {
Self {
state: NumericMetricState::default(),
threshold: 0.5,
sigmoid: false,
_b: PhantomData,
}
}
}
impl<B: Backend> Metric for HammingScore<B> {
const NAME: &'static str = "Hamming Score";
type Input = HammingScoreInput<B>;
fn update(&mut self, input: &HammingScoreInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
let [batch_size, _n_classes] = input.outputs.dims();
let targets = input.targets.clone();
let mut outputs = input.outputs.clone();
if self.sigmoid {
outputs = sigmoid(outputs);
}
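// Hamming score = mean over every (sample, class) entry of
// 1[(output > threshold) == target], i.e. the element-wise accuracy of
// the multi-hot predictions (equivalently, 1 - Hamming loss).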
let score = outputs
.greater_elem(self.threshold)
.equal(targets.bool())
.float()
.mean()
.into_scalar()
.elem::<f64>();
self.state.update(
100.0 * score,
batch_size,
FormatOptions::new(Self::NAME).unit("%").precision(2),
)
}
fn clear(&mut self) {
self.state.reset()
}
}
impl<B: Backend> Numeric for HammingScore<B> {
fn value(&self) -> f64 {
self.state.value()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
#[test]
fn test_hamming_score() {
let device = Default::default();
let mut metric = HammingScore::<TestBackend>::new();
let x = Tensor::from_data(
[
[0.32, 0.52, 0.38, 0.68, 0.61], // with x > 0.5: [0, 1, 0, 1, 1]
[0.43, 0.31, 0.21, 0.63, 0.53], // [0, 0, 0, 1, 1]
[0.44, 0.25, 0.71, 0.39, 0.73], // [0, 0, 1, 0, 1]
[0.49, 0.37, 0.68, 0.39, 0.31], // [0, 0, 1, 0, 0]
],
&device,
);
let y = Tensor::from_data(
[
[0, 1, 0, 1, 1],
[0, 0, 0, 1, 1],
[0, 0, 1, 0, 1],
[0, 0, 1, 0, 0],
],
&device,
);
let _entry = metric.update(
&HammingScoreInput::new(x.clone(), y.clone()),
&MetricMetadata::fake(),
);
assert_eq!(100.0, metric.value());
// Invert all targets: y = (1 - y)
let y = y.neg().add_scalar(1);
let _entry = metric.update(
&HammingScoreInput::new(x.clone(), y), // invert targets (1 - y)
&MetricMetadata::fake(),
);
assert_eq!(0.0, metric.value());
// Invert 5 target values -> 1 - (5/20) = 0.75
let y = Tensor::from_data(
[
[0, 1, 1, 0, 1],
[0, 0, 0, 0, 1],
[0, 0, 0, 0, 1],
[0, 1, 1, 0, 0],
],
&device,
);
let _entry = metric.update(
&HammingScoreInput::new(x, y), // targets with 5 flipped values
&MetricMetadata::fake(),
);
assert_eq!(75.0, metric.value());
}
}

File 5 of 7

@@ -9,6 +9,7 @@ mod cpu_temp;
mod cpu_use;
#[cfg(feature = "metrics")]
mod cuda;
+mod hamming;
mod learning_rate;
mod loss;
#[cfg(feature = "metrics")]
@@ -22,6 +23,7 @@ pub use cpu_temp::*;
pub use cpu_use::*;
#[cfg(feature = "metrics")]
pub use cuda::*;
+pub use hamming::*;
pub use learning_rate::*;
pub use loss::*;
#[cfg(feature = "metrics")]
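With the re-export in place, the metric attaches to a learner like any other numeric metric. A hedged sketch of the wiring (builder method names assume the existing `LearnerBuilder` API; `MyBackend`, `model`, `optim`, and `lr` are placeholders):

```rust
use burn::train::metric::HammingScore;
use burn::train::LearnerBuilder;

// Hypothetical wiring; the train/valid step outputs must adapt to
// HammingScoreInput, which MultiLabelClassificationOutput above does.
let learner = LearnerBuilder::new("/tmp/artifacts")
    .metric_train_numeric(HammingScore::<MyBackend>::new().with_sigmoid(true))
    .metric_valid_numeric(HammingScore::<MyBackend>::new().with_sigmoid(true))
    .num_epochs(30)
    .build(model, optim, lr);
```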

File 6 of 7

@@ -17,7 +17,4 @@ burn = { path = "../../crates/burn", features = ["train", "vision", "network"] }
# File download
flate2 = { workspace = true }
indicatif = { workspace = true }
-reqwest = { workspace = true }
-tar = "0.4.40"
-tokio = { workspace = true }

File 7 of 7

@@ -51,6 +51,15 @@ The CNN model and training recipe used in this example are fairly simple
demonstrate how to load a custom image classification dataset from disk. Nonetheless, it still
achieves 70-80% accuracy on the test set after just 30 epochs.
+Run it with the Torch GPU backend:
```sh
-cargo run --example custom-image-dataset
+export TORCH_CUDA_VERSION=cu121
+cargo run --example custom-image-dataset --release --features tch-gpu
```
+Run it with our WGPU backend:
+```sh
+cargo run --example custom-image-dataset --release --features wgpu
+```