mirror of https://github.com/tracel-ai/burn.git
Add top-k accuracy (#2097)
* make contacts deterministic across Worlds * add top k acc * update book --------- Co-authored-by: Charles Bournhonesque <cbournhonesque@snapchat.com>
This commit is contained in:
parent
a53f459f20
commit
ade664d4d8
|
@ -6,6 +6,7 @@ throughout the training process. We currently offer a restricted range of metric
|
||||||
| Metric | Description |
|
| Metric | Description |
|
||||||
| ---------------- | ------------------------------------------------------- |
|
| ---------------- | ------------------------------------------------------- |
|
||||||
| Accuracy | Calculate the accuracy in percentage |
|
| Accuracy | Calculate the accuracy in percentage |
|
||||||
|
| TopKAccuracy | Calculate the top-k accuracy in percentage |
|
||||||
| Loss | Output the loss used for the backward pass |
|
| Loss | Output the loss used for the backward pass |
|
||||||
| CPU Temperature | Fetch the temperature of CPUs |
|
| CPU Temperature | Fetch the temperature of CPUs |
|
||||||
| CPU Usage | Fetch the CPU utilization |
|
| CPU Usage | Fetch the CPU utilization |
|
||||||
|
|
|
@ -15,6 +15,9 @@ mod loss;
|
||||||
#[cfg(feature = "metrics")]
|
#[cfg(feature = "metrics")]
|
||||||
mod memory_use;
|
mod memory_use;
|
||||||
|
|
||||||
|
#[cfg(feature = "metrics")]
|
||||||
|
mod top_k_acc;
|
||||||
|
|
||||||
pub use acc::*;
|
pub use acc::*;
|
||||||
pub use base::*;
|
pub use base::*;
|
||||||
#[cfg(feature = "metrics")]
|
#[cfg(feature = "metrics")]
|
||||||
|
@ -28,6 +31,8 @@ pub use learning_rate::*;
|
||||||
pub use loss::*;
|
pub use loss::*;
|
||||||
#[cfg(feature = "metrics")]
|
#[cfg(feature = "metrics")]
|
||||||
pub use memory_use::*;
|
pub use memory_use::*;
|
||||||
|
#[cfg(feature = "metrics")]
|
||||||
|
pub use top_k_acc::*;
|
||||||
|
|
||||||
pub(crate) mod processor;
|
pub(crate) mod processor;
|
||||||
/// Module responsible to save and exposes data collected during training.
|
/// Module responsible to save and exposes data collected during training.
|
||||||
|
|
|
@ -0,0 +1,152 @@
|
||||||
|
use core::marker::PhantomData;
|
||||||
|
|
||||||
|
use super::state::{FormatOptions, NumericMetricState};
|
||||||
|
use super::{MetricEntry, MetricMetadata};
|
||||||
|
use crate::metric::{Metric, Numeric};
|
||||||
|
use burn_core::tensor::backend::Backend;
|
||||||
|
use burn_core::tensor::{ElementConversion, Int, Tensor};
|
||||||
|
|
||||||
|
/// The Top-K accuracy metric.
|
||||||
|
///
|
||||||
|
/// For K=1, this is equivalent to the [accuracy metric](`super::acc::AccuracyMetric`).
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct TopKAccuracyMetric<B: Backend> {
|
||||||
|
k: usize,
|
||||||
|
state: NumericMetricState,
|
||||||
|
/// If specified, targets equal to this value will be considered padding and will not count
|
||||||
|
/// towards the metric
|
||||||
|
pad_token: Option<usize>,
|
||||||
|
_b: PhantomData<B>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The [top-k accuracy metric](TopKAccuracyMetric) input type.
|
||||||
|
#[derive(new)]
|
||||||
|
pub struct TopKAccuracyInput<B: Backend> {
|
||||||
|
/// The outputs (batch_size, num_classes)
|
||||||
|
outputs: Tensor<B, 2>,
|
||||||
|
/// The labels (batch_size)
|
||||||
|
targets: Tensor<B, 1, Int>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> TopKAccuracyMetric<B> {
|
||||||
|
/// Creates the metric.
|
||||||
|
pub fn new(k: usize) -> Self {
|
||||||
|
Self {
|
||||||
|
k,
|
||||||
|
..Default::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sets the pad token.
|
||||||
|
pub fn with_pad_token(mut self, index: usize) -> Self {
|
||||||
|
self.pad_token = Some(index);
|
||||||
|
self
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> Metric for TopKAccuracyMetric<B> {
|
||||||
|
const NAME: &'static str = "Top-K Accuracy";
|
||||||
|
|
||||||
|
type Input = TopKAccuracyInput<B>;
|
||||||
|
|
||||||
|
fn update(&mut self, input: &TopKAccuracyInput<B>, _metadata: &MetricMetadata) -> MetricEntry {
|
||||||
|
let [batch_size, _n_classes] = input.outputs.dims();
|
||||||
|
|
||||||
|
let targets = input.targets.clone().to_device(&B::Device::default());
|
||||||
|
|
||||||
|
let outputs = input
|
||||||
|
.outputs
|
||||||
|
.clone()
|
||||||
|
.argsort_descending(1)
|
||||||
|
.narrow(1, 0, self.k)
|
||||||
|
.to_device(&B::Device::default())
|
||||||
|
.reshape([batch_size, self.k]);
|
||||||
|
|
||||||
|
let (targets, num_pad) = match self.pad_token {
|
||||||
|
Some(pad_token) => {
|
||||||
|
// we ignore the samples where the target is equal to the pad token
|
||||||
|
let mask = targets.clone().equal_elem(pad_token as i64);
|
||||||
|
let num_pad = mask.clone().int().sum().into_scalar().elem::<f64>();
|
||||||
|
(targets.clone().mask_fill(mask, -1_i64), num_pad)
|
||||||
|
}
|
||||||
|
None => (targets.clone(), 0_f64),
|
||||||
|
};
|
||||||
|
|
||||||
|
let accuracy = targets
|
||||||
|
.reshape([batch_size, 1])
|
||||||
|
.repeat_dim(1, self.k)
|
||||||
|
.equal(outputs)
|
||||||
|
.int()
|
||||||
|
.sum()
|
||||||
|
.into_scalar()
|
||||||
|
.elem::<f64>()
|
||||||
|
/ (batch_size as f64 - num_pad);
|
||||||
|
|
||||||
|
self.state.update(
|
||||||
|
100.0 * accuracy,
|
||||||
|
batch_size,
|
||||||
|
FormatOptions::new(Self::NAME).unit("%").precision(2),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn clear(&mut self) {
|
||||||
|
self.state.reset()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<B: Backend> Numeric for TopKAccuracyMetric<B> {
|
||||||
|
fn value(&self) -> f64 {
|
||||||
|
self.state.value()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::TestBackend;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_accuracy_without_padding() {
|
||||||
|
let device = Default::default();
|
||||||
|
let mut metric = TopKAccuracyMetric::<TestBackend>::new(2);
|
||||||
|
let input = TopKAccuracyInput::new(
|
||||||
|
Tensor::from_data(
|
||||||
|
[
|
||||||
|
[0.0, 0.2, 0.8], // 2, 1
|
||||||
|
[1.0, 2.0, 0.5], // 1, 0
|
||||||
|
[0.4, 0.1, 0.2], // 0, 2
|
||||||
|
[0.6, 0.7, 0.2], // 1, 0
|
||||||
|
],
|
||||||
|
&device,
|
||||||
|
),
|
||||||
|
Tensor::from_data([2, 2, 1, 1], &device),
|
||||||
|
);
|
||||||
|
|
||||||
|
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||||
|
assert_eq!(50.0, metric.value());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_accuracy_with_padding() {
|
||||||
|
let device = Default::default();
|
||||||
|
let mut metric = TopKAccuracyMetric::<TestBackend>::new(2).with_pad_token(3);
|
||||||
|
let input = TopKAccuracyInput::new(
|
||||||
|
Tensor::from_data(
|
||||||
|
[
|
||||||
|
[0.0, 0.2, 0.8, 0.0], // 2, 1
|
||||||
|
[1.0, 2.0, 0.5, 0.0], // 1, 0
|
||||||
|
[0.4, 0.1, 0.2, 0.0], // 0, 2
|
||||||
|
[0.6, 0.7, 0.2, 0.0], // 1, 0
|
||||||
|
[0.0, 0.1, 0.2, 5.0], // Predicted padding should not count
|
||||||
|
[0.0, 0.1, 0.2, 0.0], // Error on padding should not count
|
||||||
|
[0.6, 0.0, 0.2, 0.0], // Error on padding should not count
|
||||||
|
],
|
||||||
|
&device,
|
||||||
|
),
|
||||||
|
Tensor::from_data([2, 2, 1, 1, 3, 3, 3], &device),
|
||||||
|
);
|
||||||
|
|
||||||
|
let _entry = metric.update(&input, &MetricMetadata::fake());
|
||||||
|
assert_eq!(50.0, metric.value());
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue