Add top-k accuracy (#2097)

* make contacts deterministic across Worlds

* add top k acc

* update book

---------

Co-authored-by: Charles Bournhonesque <cbournhonesque@snapchat.com>
This commit is contained in:
Periwink 2024-08-06 13:01:28 -04:00 committed by GitHub
parent a53f459f20
commit ade664d4d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 158 additions and 0 deletions

View File

@ -6,6 +6,7 @@ throughout the training process. We currently offer a restricted range of metric
| Metric | Description | | Metric | Description |
| ---------------- | ------------------------------------------------------- | | ---------------- | ------------------------------------------------------- |
| Accuracy | Calculate the accuracy in percentage | | Accuracy | Calculate the accuracy in percentage |
| TopKAccuracy | Calculate the top-k accuracy in percentage |
| Loss | Output the loss used for the backward pass | | Loss | Output the loss used for the backward pass |
| CPU Temperature | Fetch the temperature of CPUs | | CPU Temperature | Fetch the temperature of CPUs |
| CPU Usage | Fetch the CPU utilization | | CPU Usage | Fetch the CPU utilization |

View File

@ -15,6 +15,9 @@ mod loss;
#[cfg(feature = "metrics")] #[cfg(feature = "metrics")]
mod memory_use; mod memory_use;
#[cfg(feature = "metrics")]
mod top_k_acc;
pub use acc::*; pub use acc::*;
pub use base::*; pub use base::*;
#[cfg(feature = "metrics")] #[cfg(feature = "metrics")]
@ -28,6 +31,8 @@ pub use learning_rate::*;
pub use loss::*; pub use loss::*;
#[cfg(feature = "metrics")] #[cfg(feature = "metrics")]
pub use memory_use::*; pub use memory_use::*;
#[cfg(feature = "metrics")]
pub use top_k_acc::*;
pub(crate) mod processor; pub(crate) mod processor;
/// Module responsible to save and exposes data collected during training. /// Module responsible to save and exposes data collected during training.

View File

@ -0,0 +1,152 @@
use core::marker::PhantomData;
use super::state::{FormatOptions, NumericMetricState};
use super::{MetricEntry, MetricMetadata};
use crate::metric::{Metric, Numeric};
use burn_core::tensor::backend::Backend;
use burn_core::tensor::{ElementConversion, Int, Tensor};
/// The Top-K accuracy metric.
///
/// For K=1, this is equivalent to the [accuracy metric](`super::acc::AccuracyMetric`).
#[derive(Default)]
pub struct TopKAccuracyMetric<B: Backend> {
k: usize,
state: NumericMetricState,
/// If specified, targets equal to this value will be considered padding and will not count
/// towards the metric
pad_token: Option<usize>,
_b: PhantomData<B>,
}
/// The [top-k accuracy metric](TopKAccuracyMetric) input type.
#[derive(new)]
pub struct TopKAccuracyInput<B: Backend> {
/// The outputs (batch_size, num_classes)
outputs: Tensor<B, 2>,
/// The labels (batch_size)
targets: Tensor<B, 1, Int>,
}
impl<B: Backend> TopKAccuracyMetric<B> {
/// Creates the metric.
pub fn new(k: usize) -> Self {
Self {
k,
..Default::default()
}
}
/// Sets the pad token.
pub fn with_pad_token(mut self, index: usize) -> Self {
self.pad_token = Some(index);
self
}
}
impl<B: Backend> Metric for TopKAccuracyMetric<B> {
const NAME: &'static str = "Top-K Accuracy";
type Input = TopKAccuracyInput<B>;
fn update(&mut self, input: &TopKAccuracyInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
let [batch_size, _n_classes] = input.outputs.dims();
let targets = input.targets.clone().to_device(&B::Device::default());
let outputs = input
.outputs
.clone()
.argsort_descending(1)
.narrow(1, 0, self.k)
.to_device(&B::Device::default())
.reshape([batch_size, self.k]);
let (targets, num_pad) = match self.pad_token {
Some(pad_token) => {
// we ignore the samples where the target is equal to the pad token
let mask = targets.clone().equal_elem(pad_token as i64);
let num_pad = mask.clone().int().sum().into_scalar().elem::<f64>();
(targets.clone().mask_fill(mask, -1_i64), num_pad)
}
None => (targets.clone(), 0_f64),
};
let accuracy = targets
.reshape([batch_size, 1])
.repeat_dim(1, self.k)
.equal(outputs)
.int()
.sum()
.into_scalar()
.elem::<f64>()
/ (batch_size as f64 - num_pad);
self.state.update(
100.0 * accuracy,
batch_size,
FormatOptions::new(Self::NAME).unit("%").precision(2),
)
}
fn clear(&mut self) {
self.state.reset()
}
}
impl<B: Backend> Numeric for TopKAccuracyMetric<B> {
fn value(&self) -> f64 {
self.state.value()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::TestBackend;
#[test]
fn test_accuracy_without_padding() {
let device = Default::default();
let mut metric = TopKAccuracyMetric::<TestBackend>::new(2);
let input = TopKAccuracyInput::new(
Tensor::from_data(
[
[0.0, 0.2, 0.8], // 2, 1
[1.0, 2.0, 0.5], // 1, 0
[0.4, 0.1, 0.2], // 0, 2
[0.6, 0.7, 0.2], // 1, 0
],
&device,
),
Tensor::from_data([2, 2, 1, 1], &device),
);
let _entry = metric.update(&input, &MetricMetadata::fake());
assert_eq!(50.0, metric.value());
}
#[test]
fn test_accuracy_with_padding() {
let device = Default::default();
let mut metric = TopKAccuracyMetric::<TestBackend>::new(2).with_pad_token(3);
let input = TopKAccuracyInput::new(
Tensor::from_data(
[
[0.0, 0.2, 0.8, 0.0], // 2, 1
[1.0, 2.0, 0.5, 0.0], // 1, 0
[0.4, 0.1, 0.2, 0.0], // 0, 2
[0.6, 0.7, 0.2, 0.0], // 1, 0
[0.0, 0.1, 0.2, 5.0], // Predicted padding should not count
[0.0, 0.1, 0.2, 0.0], // Error on padding should not count
[0.6, 0.0, 0.2, 0.0], // Error on padding should not count
],
&device,
),
Tensor::from_data([2, 2, 1, 1, 3, 3, 3], &device),
);
let _entry = metric.update(&input, &MetricMetadata::fake());
assert_eq!(50.0, metric.value());
}
}