mirror of https://github.com/tracel-ai/burn.git
Add `topk` tensor operation (#1497)
* Add topk and topk_with_indices * Change topk_with_indices test to guarantee order (previously equal elements)
This commit is contained in:
parent
dd699a90a2
commit
dc45cf1700
|
@ -237,6 +237,8 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
|
|||
| `tensor.sort_descending_with_indices(dim)` | `tensor.sort(dim, descending=True)` |
|
||||
| `tensor.argsort(dim)` | `tensor.argsort(dim)` |
|
||||
| `tensor.argsort_descending(dim)` | `tensor.argsort(dim, descending=True)` |
|
||||
| `tensor.topk(k, dim)` | `tensor.topk(k, dim).values` |
|
||||
| `tensor.topk_with_indices(k, dim)` | `tensor.topk(k, dim)` |
|
||||
|
||||
### Float Operations
|
||||
|
||||
|
|
|
@ -273,6 +273,14 @@ where
|
|||
Tensor::new(sort::<B, D, Float>(self.primitive, dim, /*descending*/ false).await)
|
||||
}
|
||||
|
||||
/// Sort the elements by value in descending order along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// # Arguments
///
/// * `dim` - The dimension along which the elements are sorted. It is validated with
///   `TensorCheck::sort_dim` before the sort is started.
///
/// # Returns
///
/// The tensor produced by the descending sort along `dim`.
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn sort_descending(self, dim: usize) -> Tensor<B, D> {
    // Validate `dim` up front, matching `sort_descending_with_indices` and
    // `argsort_descending`, instead of leaving an invalid dimension to the sort routine.
    check!(TensorCheck::sort_dim::<D>("Sort", dim));
    Tensor::new(sort::<B, D, Float>(self.primitive, dim, /*descending*/ true).await)
}
|
||||
|
||||
/// Sort the elements by value in ascending order along a given dimension.
|
||||
/// Also returns the indices.
|
||||
///
|
||||
|
@ -285,6 +293,21 @@ where
|
|||
(Tensor::new(values), Tensor::new(indices))
|
||||
}
|
||||
|
||||
/// Sort the elements by value in descending order along a given dimension,
/// returning both the sorted values and the matching indices.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// # Arguments
///
/// * `dim` - The dimension along which the elements are sorted.
///
/// # Returns
///
/// A `(values, indices)` pair: the sorted float tensor and an `Int` tensor of indices.
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn sort_descending_with_indices(
    self,
    dim: usize,
) -> (Tensor<B, D>, Tensor<B, D, Int>) {
    check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
    let sorted =
        sort_with_indices::<B, D, Float>(self.primitive, dim, /*descending*/ true).await;
    // First component holds the values, second the indices.
    (Tensor::new(sorted.0), Tensor::new(sorted.1))
}
|
||||
|
||||
/// Returns the indices that sort the elements by value in ascending order along a given dimension.
|
||||
///
|
||||
/// This sort is unstable (i.e., may reorder equal elements).
|
||||
|
@ -293,4 +316,36 @@ where
|
|||
check!(TensorCheck::sort_dim::<D>("Argsort", dim));
|
||||
Tensor::new(argsort::<B, D, Float>(self.primitive, dim, /*descending*/ false).await)
|
||||
}
|
||||
|
||||
/// Returns the indices that sort the elements by value in descending order along a given
/// dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// # Arguments
///
/// * `dim` - The dimension along which the elements are ordered.
///
/// # Returns
///
/// An `Int` tensor holding the sorting indices.
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn argsort_descending(self, dim: usize) -> Tensor<B, D, Int> {
    check!(TensorCheck::sort_dim::<D>("Argsort", dim));
    let order = argsort::<B, D, Float>(self.primitive, dim, /*descending*/ true).await;
    Tensor::new(order)
}
|
||||
|
||||
/// Returns the `k` largest elements of the given input tensor along a given dimension.
///
/// The tensor is sorted in descending order along `dim`, then positions `0..k` of that
/// dimension are selected.
///
/// # Arguments
///
/// * `k` - The number of elements to keep along `dim`.
/// * `dim` - The dimension along which the largest elements are taken.
///
/// NOTE(review): `k` is not validated here; what happens when `k` exceeds the size of `dim`
/// depends on `select` - confirm.
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn topk(self, k: usize, dim: usize) -> Tensor<B, D> {
    // Built before `self` is consumed by the sort, since it needs the tensor's device.
    let leading = Tensor::arange(0..k as i64, &self.device());
    let sorted = self.sort_descending(dim).await;
    sorted.select(dim, leading)
}
|
||||
|
||||
/// Returns the `k` largest elements of the given input tensor along a given dimension,
/// together with their indices.
///
/// The tensor is sorted in descending order along `dim` (keeping the indices), then
/// positions `0..k` of that dimension are selected from both results.
///
/// # Arguments
///
/// * `k` - The number of elements to keep along `dim`.
/// * `dim` - The dimension along which the largest elements are taken.
///
/// # Returns
///
/// A `(values, indices)` pair of tensors.
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn topk_with_indices(
    self,
    k: usize,
    dim: usize,
) -> (Tensor<B, D>, Tensor<B, D, Int>) {
    // Built before `self` is consumed by the sort, since it needs the tensor's device.
    let leading = Tensor::arange(0..k as i64, &self.device());
    let (sorted_values, sorted_indices) = self.sort_descending_with_indices(dim).await;
    let top_values = sorted_values.select(dim, leading.clone());
    let top_indices = sorted_indices.select(dim, leading);
    (top_values, top_indices)
}
|
||||
}
|
||||
|
|
|
@ -70,36 +70,87 @@ where
|
|||
Tensor::new(B::int_into_float(self.primitive))
|
||||
}
|
||||
|
||||
/// Sort the elements by value along a given dimension.
|
||||
/// Sort the elements by value in ascending order along a given dimension.
|
||||
///
|
||||
/// This sort is unstable (i.e., may reorder equal elements).
|
||||
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
|
||||
pub async fn sort(self, dim: usize, descending: bool) -> Tensor<B, D, Int> {
|
||||
Tensor::new(sort::<B, D, Int>(self.primitive, dim, descending).await)
|
||||
pub async fn sort(self, dim: usize) -> Tensor<B, D, Int> {
|
||||
Tensor::new(sort::<B, D, Int>(self.primitive, dim, /* descending */ false).await)
|
||||
}
|
||||
|
||||
/// Sort the elements by value along a given dimension.
|
||||
/// Sort the elements by value in descending order along a given dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// # Arguments
///
/// * `dim` - The dimension along which the elements are sorted. It is validated with
///   `TensorCheck::sort_dim` before the sort is started.
///
/// # Returns
///
/// The `Int` tensor produced by the descending sort along `dim`.
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn sort_descending(self, dim: usize) -> Tensor<B, D, Int> {
    // Validate `dim` up front, matching `sort_descending_with_indices` and
    // `argsort_descending`, instead of leaving an invalid dimension to the sort routine.
    check!(TensorCheck::sort_dim::<D>("Sort", dim));
    Tensor::new(sort::<B, D, Int>(self.primitive, dim, /* descending */ true).await)
}
|
||||
|
||||
/// Sort the elements by value in ascending order along a given dimension.
|
||||
/// Also returns the indices.
|
||||
///
|
||||
/// This sort is unstable (i.e., may reorder equal elements).
|
||||
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
|
||||
pub async fn sort_with_indices(
|
||||
self,
|
||||
dim: usize,
|
||||
descending: bool,
|
||||
) -> (Tensor<B, D, Int>, Tensor<B, D, Int>) {
|
||||
pub async fn sort_with_indices(self, dim: usize) -> (Tensor<B, D, Int>, Tensor<B, D, Int>) {
|
||||
check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
|
||||
let (values, indices) =
|
||||
sort_with_indices::<B, D, Int>(self.primitive, dim, descending).await;
|
||||
sort_with_indices::<B, D, Int>(self.primitive, dim, /*descending*/ false).await;
|
||||
(Tensor::new(values), Tensor::new(indices))
|
||||
}
|
||||
|
||||
/// Returns the indices that sort the elements by value along a given dimension.
|
||||
/// Sort the elements by value in descending order along a given dimension.
|
||||
/// Also returns the indices.
|
||||
///
|
||||
/// This sort is unstable (i.e., may reorder equal elements).
|
||||
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
|
||||
pub async fn argsort(self, dim: usize, descending: bool) -> Tensor<B, D, Int> {
|
||||
pub async fn sort_descending_with_indices(
|
||||
self,
|
||||
dim: usize,
|
||||
) -> (Tensor<B, D, Int>, Tensor<B, D, Int>) {
|
||||
check!(TensorCheck::sort_dim::<D>("Sort_with_indices", dim));
|
||||
let (values, indices) =
|
||||
sort_with_indices::<B, D, Int>(self.primitive, dim, /*descending*/ true).await;
|
||||
(Tensor::new(values), Tensor::new(indices))
|
||||
}
|
||||
|
||||
/// Returns the indices that sort the elements by value in ascending order along a given dimension.
|
||||
///
|
||||
/// This sort is unstable (i.e., may reorder equal elements).
|
||||
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
|
||||
pub async fn argsort(self, dim: usize) -> Tensor<B, D, Int> {
|
||||
check!(TensorCheck::sort_dim::<D>("Argsort", dim));
|
||||
Tensor::new(argsort::<B, D, Int>(self.primitive, dim, descending).await)
|
||||
Tensor::new(argsort::<B, D, Int>(self.primitive, dim, /*descending*/ false).await)
|
||||
}
|
||||
|
||||
/// Returns the indices that sort the elements by value in descending order along a given
/// dimension.
///
/// This sort is unstable (i.e., may reorder equal elements).
///
/// # Arguments
///
/// * `dim` - The dimension along which the elements are ordered.
///
/// # Returns
///
/// An `Int` tensor holding the sorting indices.
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn argsort_descending(self, dim: usize) -> Tensor<B, D, Int> {
    check!(TensorCheck::sort_dim::<D>("Argsort", dim));
    let order = argsort::<B, D, Int>(self.primitive, dim, /*descending*/ true).await;
    Tensor::new(order)
}
|
||||
|
||||
/// Returns the `k` largest elements of the given input tensor along a given dimension.
///
/// The tensor is sorted in descending order along `dim`, then positions `0..k` of that
/// dimension are selected.
///
/// # Arguments
///
/// * `k` - The number of elements to keep along `dim`.
/// * `dim` - The dimension along which the largest elements are taken.
///
/// NOTE(review): `k` is not validated here; what happens when `k` exceeds the size of `dim`
/// depends on `select` - confirm.
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn topk(self, k: usize, dim: usize) -> Tensor<B, D, Int> {
    // Built before `self` is consumed by the sort, since it needs the tensor's device.
    let leading = Tensor::arange(0..k as i64, &self.device());
    let sorted = self.sort_descending(dim).await;
    sorted.select(dim, leading)
}
|
||||
|
||||
/// Returns the `k` largest elements of the given input tensor along a given dimension,
/// together with their indices.
///
/// The tensor is sorted in descending order along `dim` (keeping the indices), then
/// positions `0..k` of that dimension are selected from both results.
///
/// # Arguments
///
/// * `k` - The number of elements to keep along `dim`.
/// * `dim` - The dimension along which the largest elements are taken.
///
/// # Returns
///
/// A `(values, indices)` pair of `Int` tensors.
#[cfg(all(not(feature = "wasm-sync"), target_family = "wasm"))]
pub async fn topk_with_indices(
    self,
    k: usize,
    dim: usize,
) -> (Tensor<B, D, Int>, Tensor<B, D, Int>) {
    // Built before `self` is consumed by the sort, since it needs the tensor's device.
    let leading = Tensor::arange(0..k as i64, &self.device());
    let (sorted_values, sorted_indices) = self.sort_descending_with_indices(dim).await;
    let top_values = sorted_values.select(dim, leading.clone());
    let top_indices = sorted_indices.select(dim, leading);
    (top_values, top_indices)
}
|
||||
}
|
||||
|
|
|
@ -717,6 +717,25 @@ where
|
|||
check!(TensorCheck::sort_dim::<D>("Argsort", dim));
|
||||
Tensor::new(K::argsort(self.primitive, dim, /*descending*/ true))
|
||||
}
|
||||
|
||||
/// Returns the `k` largest elements of the given input tensor along a given dimension.
|
||||
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
|
||||
pub fn topk(self, k: usize, dim: usize) -> Tensor<B, D, K> {
|
||||
let k_indices = Tensor::arange(0..k as i64, &self.device());
|
||||
self.sort_descending(dim).select(dim, k_indices)
|
||||
}
|
||||
|
||||
/// Returns the `k` largest elements of the given input tensor along a given dimension.
|
||||
/// Also returns the indices.
|
||||
#[cfg(any(feature = "wasm-sync", not(target_family = "wasm")))]
|
||||
pub fn topk_with_indices(self, k: usize, dim: usize) -> (Tensor<B, D, K>, Tensor<B, D, Int>) {
|
||||
let k_indices = Tensor::arange(0..k as i64, &self.device());
|
||||
let (values, indices) = self.sort_descending_with_indices(dim);
|
||||
(
|
||||
values.select(dim, k_indices.clone()),
|
||||
indices.select(dim, k_indices),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B, K> Tensor<B, 2, K>
|
||||
|
|
|
@ -92,6 +92,7 @@ macro_rules! testgen_all {
|
|||
burn_tensor::testgen_sign!();
|
||||
burn_tensor::testgen_tri_mask!();
|
||||
burn_tensor::testgen_sort_argsort!();
|
||||
burn_tensor::testgen_topk!();
|
||||
|
||||
// test stats
|
||||
burn_tensor::testgen_var!();
|
||||
|
|
|
@ -51,6 +51,7 @@ mod squeeze;
|
|||
mod stack;
|
||||
mod sub;
|
||||
mod tanh;
|
||||
mod topk;
|
||||
mod transpose;
|
||||
mod tri;
|
||||
mod tri_mask;
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
#[burn_tensor_testgen::testgen(topk)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_tensor::{Data, Shape, Tensor};
|
||||
|
||||
#[test]
|
||||
fn test_topk_1d() {
|
||||
// Int
|
||||
let tensor = TestTensorInt::from([1, 2, 3, 4, 5]);
|
||||
|
||||
let values = tensor.topk(3, /*dim*/ 0);
|
||||
let values_actual = values.into_data();
|
||||
|
||||
let values_expected = Data::from([5, 4, 3]);
|
||||
assert_eq!(values_expected, values_actual);
|
||||
|
||||
// Float
|
||||
let tensor = TestTensor::from([1., 2., 3., 4., 5.]);
|
||||
|
||||
let values = tensor.topk(3, /*dim*/ 0);
|
||||
let values_actual = values.into_data();
|
||||
|
||||
let values_expected = Data::from([5., 4., 3.]);
|
||||
values_expected.assert_approx_eq(&values_actual, 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_topk() {
|
||||
// 2D Int
|
||||
let tensor = TestTensorInt::from([[[1, 4, 7], [2, 5, 6]], [[3, 0, 9], [8, 2, 8]]]);
|
||||
|
||||
let values = tensor.topk(2, /*dim*/ 2);
|
||||
let values_actual = values.into_data();
|
||||
|
||||
let values_expected = Data::from([[[7, 4], [6, 5]], [[9, 3], [8, 8]]]);
|
||||
assert_eq!(values_expected, values_actual);
|
||||
|
||||
// 2D Float
|
||||
let tensor = TestTensor::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 8.]]]);
|
||||
|
||||
let values = tensor.topk(2, /*dim*/ 2);
|
||||
let values_actual = values.into_data();
|
||||
|
||||
let values_expected = Data::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 8.]]]);
|
||||
values_expected.assert_approx_eq(&values_actual, 5);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_topk_with_indices() {
|
||||
// 1D
|
||||
let tensor = TestTensorInt::from([1, 2, 3, 4, 5]);
|
||||
|
||||
let (values, indices) = tensor.topk_with_indices(3, /*dim*/ 0);
|
||||
let values_actual = values.into_data();
|
||||
let indices_actual = indices.into_data();
|
||||
|
||||
let values_expected = Data::from([5, 4, 3]);
|
||||
assert_eq!(values_expected, values_actual);
|
||||
|
||||
let indices_expected = Data::from([4, 3, 2]);
|
||||
assert_eq!(indices_expected, indices_actual);
|
||||
|
||||
// 2D
|
||||
let tensor = TestTensor::from([[[1., 4., 7.], [2., 5., 6.]], [[3., 0., 9.], [8., 2., 7.]]]);
|
||||
|
||||
let (values, indices) = tensor.topk_with_indices(2, /*dim*/ 2);
|
||||
let values_actual = values.into_data();
|
||||
let indices_actual = indices.into_data();
|
||||
|
||||
let values_expected = Data::from([[[7., 4.], [6., 5.]], [[9., 3.], [8., 7.]]]);
|
||||
values_expected.assert_approx_eq(&values_actual, 5);
|
||||
|
||||
let indices_expected = Data::from([[[2, 1], [2, 1]], [[2, 0], [0, 2]]]);
|
||||
assert_eq!(indices_expected, indices_actual);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue