mirror of https://github.com/tracel-ai/burn.git

feat: argmax + argmin (#412)

parent 323261b594
commit 4d40bde7b9
@@ -380,31 +380,46 @@ where
         NdArrayTensor::new(output_array.into_shared())
     }

     pub fn argmax<const D: usize>(
-        tensor: NdArrayTensor<E, D>,
+        mut tensor: NdArrayTensor<E, D>,
         dim: usize,
     ) -> NdArrayTensor<i64, D> {
-        arg(tensor, dim, cmp_min)
+        if dim == D - 1 {
+            return arg(tensor, cmp_min);
+        }
+
+        tensor.array.swap_axes(dim, D - 1);
+        let mut tensor = arg(tensor, cmp_min);
+        tensor.array.swap_axes(dim, D - 1);
+
+        tensor
     }

     pub fn argmin<const D: usize>(
-        tensor: NdArrayTensor<E, D>,
+        mut tensor: NdArrayTensor<E, D>,
         dim: usize,
     ) -> NdArrayTensor<i64, D> {
-        arg(tensor, dim, cmp_max)
+        if dim == D - 1 {
+            return arg(tensor, cmp_max);
+        }
+
+        tensor.array.swap_axes(dim, D - 1);
+        let mut tensor = arg(tensor, cmp_max);
+        tensor.array.swap_axes(dim, D - 1);
+
+        tensor
     }
 }

 fn arg<E: NdArrayElement, F, const D: usize>(
     tensor: NdArrayTensor<E, D>,
-    dim: usize,
     cmp: F,
 ) -> NdArrayTensor<i64, D>
 where
     F: Fn(&f64, &f64) -> Ordering,
 {
     let mut shape = tensor.shape();
-    let batch_size = shape.dims[dim];
-    let mut end = shape.dims[dim];
+    let batch_size = shape.dims[D - 1];
+    let mut end = shape.dims[D - 1];

     let mut values = tensor.array.into_iter().collect::<Vec<_>>();
     let mut start = 0;
@@ -430,7 +445,7 @@ where
         start += batch_size;
         end += batch_size;
     }
-    shape.dims[dim] = 1;
+    shape.dims[D - 1] = 1;

     NdArrayTensor::from_data(Data::new(output, shape))
 }
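Note: argmax/argmin above reduce the general case to the last axis by swapping strides, folding there, and swapping back, so arg only ever walks contiguous lanes. A minimal standalone sketch of the same idea using ndarray directly; the argmax_dim name and the lane fold are illustrative, not burn code:

use ndarray::{ArrayD, Axis};

// Sketch: argmax over an arbitrary dim via the same swap-axes trick.
fn argmax_dim(mut a: ArrayD<f64>, dim: usize) -> ArrayD<i64> {
    let last = a.ndim() - 1;
    if dim != last {
        a.swap_axes(dim, last); // O(1): only the stride metadata moves
    }
    // One pass per lane along the (now) last axis, keeping the first maximum.
    let out = a.map_axis(Axis(last), |lane| {
        let mut best = 0;
        for (i, &v) in lane.iter().enumerate() {
            if v > lane[best] {
                best = i;
            }
        }
        best as i64
    });
    // Re-insert the reduced axis with size 1, then undo the swap.
    let mut out = out.insert_axis(Axis(last));
    if dim != last {
        out.swap_axes(dim, last);
    }
    out
}

Swapping axes is free on the stride metadata, so the only full data pass is the fold itself; the diff exploits exactly this to avoid a per-dim specialization of arg.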
@@ -4,13 +4,46 @@ mod tests {
     use burn_tensor::{Data, Tensor};

     #[test]
-    fn test_argmax_2d() {
-        let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
+    fn test_argmax_2d_dim0() {
+        let data = Data::from([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]);
         let tensor = Tensor::<TestBackend, 2>::from_data(data);

+        let data_actual = tensor.argmax(0);
+
+        let data_expected = Data::from([[0, 0, 1]]);
+        assert_eq!(data_expected, data_actual.to_data());
+    }
+
+    #[test]
+    fn test_argmin_2d_dim0() {
+        let data = Data::from([[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]]);
+        let tensor = Tensor::<TestBackend, 2>::from_data(data);
+
+        let data_actual = tensor.argmin(0);
+
+        let data_expected = Data::from([[0, 1, 0]]);
+        assert_eq!(data_expected, data_actual.to_data());
+    }
+
+    #[test]
+    fn test_argmax_2d_dim1() {
+        let data = Data::from([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]);
+        let tensor = Tensor::<TestBackend, 2>::from_data(data);
+
         let data_actual = tensor.argmax(1);

-        let data_expected = Data::from([[2], [2]]);
+        let data_expected = Data::from([[1], [2]]);
         assert_eq!(data_expected, data_actual.to_data());
     }
+
+    #[test]
+    fn test_argmin_2d_dim1() {
+        let data = Data::from([[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]]);
+        let tensor = Tensor::<TestBackend, 2>::from_data(data);
+
+        let data_actual = tensor.argmin(1);
+
+        let data_expected = Data::from([[2], [1]]);
+        assert_eq!(data_expected, data_actual.to_data());
+    }
 }
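For reference, both ops shrink the chosen dimension to size 1 and return the index of the extreme along it; the updated expectations follow directly from the new test data:

// Worked example with the data above:
// data = [[10.0, 11.0, 2.0],
//         [ 3.0,  4.0, 5.0]]
// argmax(0) == [[0, 0, 1]]   per column: 10 > 3, 11 > 4, 5 > 2
// argmax(1) == [[1], [2]]    per row: row 0 peaks at 11.0 (index 1),
//                            row 1 at 5.0 (index 2)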
@@ -4,18 +4,20 @@ use burn_tensor::Shape;

 kernel_wgsl!(RecursiveSumRaw, "../template/reduction/recursive_sum.wgsl");
 kernel_wgsl!(ReductionDimRaw, "../template/reduction/reduce_dim.wgsl");
+kernel_wgsl!(ReductionArgsRaw, "../template/reduction/args.wgsl");

-struct SumDimRaw;
+pub struct ArgsMax;
+pub struct ArgsMin;
+pub struct SumDim;
+pub struct MeanDim;

-impl StaticKernel for SumDimRaw {
+impl StaticKernel for SumDim {
     fn source_template() -> SourceTemplate {
         ReductionDimRaw::source_template().register("assign", "output[global_id.x] = sum;")
     }
 }

-struct MeanDimRaw;
-
-impl StaticKernel for MeanDimRaw {
+impl StaticKernel for MeanDim {
     fn source_template() -> SourceTemplate {
         ReductionDimRaw::source_template()
             .add_template(
@@ -27,6 +29,22 @@ impl StaticKernel for MeanDimRaw {
     }
 }

+impl StaticKernel for ArgsMax {
+    fn source_template() -> SourceTemplate {
+        ReductionArgsRaw::source_template()
+            .register("cmp", ">")
+            .register("initial", (-32767).to_string())
+    }
+}
+
+impl StaticKernel for ArgsMin {
+    fn source_template() -> SourceTemplate {
+        ReductionArgsRaw::source_template()
+            .register("cmp", "<")
+            .register("initial", 32767.to_string())
+    }
+}
+
 pub fn reduction_sum<E: WgpuElement, const D: usize>(input: WgpuTensor<E, D>) -> WgpuTensor<E, 1> {
     const WORKGROUP: usize = 256;
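ArgsMax and ArgsMin share one WGSL template: cmp picks the comparison and initial seeds the running extreme (the -32767/32767 sentinels assume input values stay inside that range). A minimal stand-in for the placeholder substitution, to show why one template yields both kernels; this SourceTemplate is a sketch, not burn's actual type:

// Each register() call pins one {{ placeholder }}, so a single WGSL
// source produces both the argmax and the argmin kernel.
struct SourceTemplate(String);

impl SourceTemplate {
    fn register(self, key: &str, value: &str) -> Self {
        Self(self.0.replace(&format!("{{{{ {key} }}}}"), value))
    }
}

fn main() {
    let line = SourceTemplate("if (value {{ cmp }} current_value) {".into());
    let argmax = line.register("cmp", ">");
    assert_eq!(argmax.0, "if (value > current_value) {");
}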
@@ -57,21 +75,7 @@ pub fn reduction_sum<E: WgpuElement, const D: usize>(input: WgpuTensor<E, D>) -> WgpuTensor<E, 1> {
     }
 }

-pub fn reduction_sum_dim<E: WgpuElement, const D: usize>(
-    input: WgpuTensor<E, D>,
-    dim: usize,
-) -> WgpuTensor<E, D> {
-    reduction_dim::<SumDimRaw, E, D>(input, dim)
-}
-
-pub fn reduction_mean_dim<E: WgpuElement, const D: usize>(
-    input: WgpuTensor<E, D>,
-    dim: usize,
-) -> WgpuTensor<E, D> {
-    reduction_dim::<MeanDimRaw, E, D>(input, dim)
-}
-
-fn reduction_dim<K: StaticKernel, E: WgpuElement, const D: usize>(
+pub fn reduction_dim<K: StaticKernel, E: WgpuElement, const D: usize>(
     input: WgpuTensor<E, D>,
     dim: usize,
 ) -> WgpuTensor<E, D> {
@@ -103,3 +107,36 @@ fn reduction_dim<K: StaticKernel, E: WgpuElement, const D: usize>(

     output
 }
+
+pub fn reduction_args_dim<K: StaticKernel, E: WgpuElement, I: WgpuElement, const D: usize>(
+    input: WgpuTensor<E, D>,
+    dim: usize,
+) -> WgpuTensor<I, D> {
+    let mut shape_out = input.shape.clone();
+    shape_out.dims[dim] = 1;
+    let buffer = input
+        .context
+        .create_buffer(shape_out.num_elements() * core::mem::size_of::<I>());
+    let output = WgpuTensor::new(input.context.clone(), shape_out, buffer);
+
+    let kernel = input
+        .context
+        .compile_static::<KernelSettings<K, E, I, 256, 1, 1>>();
+    let mut info = build_info(&[&input, &output]);
+    info.push(dim as u32);
+    let info_buffers = input
+        .context
+        .create_buffer_with_data(bytemuck::cast_slice(&info));
+
+    input.context.execute(
+        WorkGroup::new(
+            f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
+            1,
+            1,
+        ),
+        kernel,
+        &[&input.buffer, &output.buffer, &info_buffers],
+    );
+
+    WgpuTensor::new(output.context, output.shape, output.buffer)
+}
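reduction_args_dim allocates the reduced-shape output, appends the reduced dim to the info buffer, and launches one invocation per output element, rounding the workgroup count up. The dispatch arithmetic as a sketch, using integer ceil-division instead of the f32 round-trip above:

// One thread per output element, in groups matching the workgroup size
// fixed by KernelSettings<K, E, I, 256, 1, 1>.
fn workgroup_count(num_output_elems: usize, workgroup_size: usize) -> u32 {
    ((num_output_elems + workgroup_size - 1) / workgroup_size) as u32
}

fn main() {
    assert_eq!(workgroup_count(1000, 256), 4); // 1000 outputs -> 4 groups
    assert_eq!(workgroup_count(1024, 256), 4);
}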
@@ -52,6 +52,7 @@ mod tests {
     burn_tensor::testgen_transpose!();
     burn_tensor::testgen_index!();
     burn_tensor::testgen_aggregation!();
+    burn_tensor::testgen_arg!();
     burn_tensor::testgen_map_comparison!();

     type TestADBackend = burn_autodiff::ADBackendDecorator<TestBackend>;
@@ -1,7 +1,7 @@
 use std::ops::Range;

 use super::numeric::NumericOps;
-use super::{BaseOps, BoolTensor, Device, FloatElem, FloatTensor};
+use super::{BaseOps, BoolTensor, Device, FloatElem, FloatTensor, IntTensor};
 use crate::kernel::{matmul, unary, unary_inplace, unary_scalar, unary_scalar_inplace};
 use crate::{
     element::{FloatElement, IntElement},
@@ -427,17 +427,11 @@ where
         todo!()
     }

-    fn argmax<const D: usize>(
-        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
-        _dim: usize,
-    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
-        todo!()
+    fn argmax<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
+        NumericOps::<G>::argmax(tensor, dim)
     }

-    fn argmin<const D: usize>(
-        _tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
-        _dim: usize,
-    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
-        todo!()
+    fn argmin<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
+        NumericOps::<G>::argmin(tensor, dim)
     }
 }
@@ -273,17 +273,11 @@ where
         NumericOps::<G>::mean_dim(tensor, dim)
     }

-    fn int_argmax<const D: usize>(
-        _tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
-        _dim: usize,
-    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
-        todo!()
+    fn int_argmax<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
+        NumericOps::<G>::argmax(tensor, dim)
     }

-    fn int_argmin<const D: usize>(
-        _tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
-        _dim: usize,
-    ) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
-        todo!()
+    fn int_argmin<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
+        NumericOps::<G>::argmin(tensor, dim)
     }
 }
@@ -1,10 +1,6 @@
-use std::marker::PhantomData;
-
-use burn_tensor::{Element, ElementConversion, Shape};
-
 use crate::kernel::{
-    binary_elemwise, binary_elemwise_inplace, reduction_mean_dim, reduction_sum, reduction_sum_dim,
-    unary_scalar, unary_scalar_inplace,
+    binary_elemwise, binary_elemwise_inplace, reduction_args_dim, reduction_dim, reduction_sum,
+    unary_scalar, unary_scalar_inplace, ArgsMax, ArgsMin, MeanDim, SumDim,
 };
 use crate::pool::get_context;
 use crate::{
@@ -12,6 +8,8 @@ use crate::{
     unary_scalar, unary_scalar_inplace,
 };
 use crate::{GraphicsApi, WgpuDevice};
+use burn_tensor::{Element, ElementConversion, Shape};
+use std::marker::PhantomData;

 pub struct NumericOps<G: GraphicsApi> {
     _g: PhantomData<G>,
@@ -170,20 +168,34 @@ impl<G: GraphicsApi> NumericOps<G> {
         tensor: WgpuTensor<E, D>,
         dim: usize,
     ) -> WgpuTensor<E, D> {
-        reduction_sum_dim(tensor, dim)
+        reduction_dim::<SumDim, E, D>(tensor, dim)
     }

     pub fn mean<E: WgpuElement + Element, const D: usize>(
         tensor: WgpuTensor<E, D>,
     ) -> WgpuTensor<E, 1> {
         let num_elems = tensor.shape.num_elements();
-        Self::div_scalar(reduction_sum(tensor), (num_elems as f32).elem())
+        Self::div_scalar(Self::sum(tensor), (num_elems as f32).elem())
     }

     pub fn mean_dim<E: WgpuElement + Element, const D: usize>(
         tensor: WgpuTensor<E, D>,
         dim: usize,
     ) -> WgpuTensor<E, D> {
-        reduction_mean_dim(tensor, dim)
+        reduction_dim::<MeanDim, E, D>(tensor, dim)
     }
+
+    pub fn argmax<E: WgpuElement + Element, I: WgpuElement, const D: usize>(
+        tensor: WgpuTensor<E, D>,
+        dim: usize,
+    ) -> WgpuTensor<I, D> {
+        reduction_args_dim::<ArgsMax, E, I, D>(tensor, dim)
+    }
+
+    pub fn argmin<E: WgpuElement + Element, I: WgpuElement, const D: usize>(
+        tensor: WgpuTensor<E, D>,
+        dim: usize,
+    ) -> WgpuTensor<I, D> {
+        reduction_args_dim::<ArgsMin, E, I, D>(tensor, dim)
+    }
 }
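With the *_dim wrappers gone, every dim-reduction funnels through the two generic entry points, and the kernel type parameter does the selecting. A usage sketch; the function and bindings are illustrative, while the imports mirror the ones in the diff:

use crate::kernel::{reduction_args_dim, reduction_dim, ArgsMax, SumDim};

// Sketch: the kernel type parameter picks the reduction.
fn reduce_examples<E: WgpuElement + Element, I: WgpuElement, const D: usize>(
    a: WgpuTensor<E, D>,
    b: WgpuTensor<E, D>,
    dim: usize,
) -> (WgpuTensor<E, D>, WgpuTensor<I, D>) {
    let summed = reduction_dim::<SumDim, E, D>(a, dim); // values
    let indices = reduction_args_dim::<ArgsMax, E, I, D>(b, dim); // positions
    (summed, indices)
}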
@@ -0,0 +1,56 @@
+@group(0)
+@binding(0)
+var<storage, read> input: array<{{ elem }}>;
+
+@group(0)
+@binding(1)
+var<storage, read_write> output: array<{{ int }}>;
+
+@group(0)
+@binding(2)
+var<storage, read> info: array<u32>;
+
+@compute
+@workgroup_size({{ workgroup_size_x }}, 1, 1)
+fn main(
+    @builtin(global_invocation_id) global_id: vec3<u32>,
+) {
+    let dim: u32 = info[0];
+    let dim_reduce = info[4u * dim + 1u];
+    var index_offset: u32 = 0u;
+    var stride_dim: u32 = 0u;
+    var shape_dim: u32 = 0u;
+
+    for (var i: u32 = 1u; i <= dim; i++) {
+        let stride_input = info[i];
+        let stride_output = info[i + dim];
+        let shape_output = info[i + 3u * dim];
+
+        let num_block = global_id.x / stride_output % shape_output;
+
+        if i - 1u != dim_reduce {
+            index_offset += num_block * stride_input;
+        } else {
+            let shape_input = info[i + 2u * dim];
+            index_offset += num_block;
+            stride_dim = stride_input;
+            shape_dim = shape_input;
+        }
+    }
+
+    var current_value = {{ elem }}({{ initial }});
+    var index = {{ int }}(0);
+
+    for (var i = 0u; i < shape_dim; i++) {
+        let index_input = i * stride_dim;
+        let value = input[index_input + index_offset];
+
+        if (value {{ cmp }} current_value) {
+            current_value = value;
+            index = {{ int }}(i);
+
+        }
+    }
+
+    output[global_id.x] = index;
+}
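The indexing above implies this info buffer layout: [rank, input strides, output strides, input shape, output shape, reduced dim]. Each thread recovers its output coordinate axis by axis, freezes it on the non-reduced axes, and walks the reduced axis by its input stride. A CPU-side Rust mirror of the same arithmetic, as a sketch rather than burn code:

// Returns (base offset, stride of reduced axis, length of reduced axis)
// for one output element; the fold then scans
// input[offset + i * stride] for i in 0..len, like the kernel's second loop.
fn lane_for_output(output_index: usize, info: &[usize]) -> (usize, usize, usize) {
    let rank = info[0];
    let reduce_dim = info[4 * rank + 1];
    let (mut offset, mut stride, mut len) = (0, 0, 0);

    for i in 1..=rank {
        let stride_in = info[i];
        let stride_out = info[i + rank];
        let shape_out = info[i + 3 * rank];
        // This output element's coordinate along axis i-1.
        let coord = output_index / stride_out % shape_out;

        if i - 1 != reduce_dim {
            offset += coord * stride_in; // frozen coordinate
        } else {
            offset += coord; // always 0: the output is size 1 on this axis
            stride = stride_in;
            len = info[i + 2 * rank]; // input shape along the reduced axis
        }
    }
    (offset, stride, len)
}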
@@ -10,15 +10,10 @@ var<storage, read_write> output: array<{{ elem }}>;
 @binding(2)
 var<storage, read> info: array<u32>;

-var<workgroup> data: array<{{ elem }}, {{ workgroup_size_x }}>;
-
 @compute
 @workgroup_size({{ workgroup_size_x }}, 1, 1)
 fn main(
     @builtin(global_invocation_id) global_id: vec3<u32>,
-    @builtin(local_invocation_id) local_id: vec3<u32>,
-    @builtin(workgroup_id) workgroup_id: vec3<u32>,
-    @builtin(num_workgroups) num_workgroups: vec3<u32>,
 ) {
     let dim: u32 = info[0];
     let dim_reduce = info[4u * dim + 1u];