feat: argmax + argmin (#412)

This commit is contained in:
Nathaniel Simard 2023-06-20 10:03:00 -04:00 committed by GitHub
parent 323261b594
commit 4d40bde7b9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 203 additions and 66 deletions

View File

@ -380,31 +380,46 @@ where
NdArrayTensor::new(output_array.into_shared())
}
pub fn argmax<const D: usize>(
tensor: NdArrayTensor<E, D>,
mut tensor: NdArrayTensor<E, D>,
dim: usize,
) -> NdArrayTensor<i64, D> {
arg(tensor, dim, cmp_min)
if dim == D - 1 {
return arg(tensor, cmp_min);
}
tensor.array.swap_axes(dim, D - 1);
let mut tensor = arg(tensor, cmp_min);
tensor.array.swap_axes(dim, D - 1);
tensor
}
pub fn argmin<const D: usize>(
tensor: NdArrayTensor<E, D>,
mut tensor: NdArrayTensor<E, D>,
dim: usize,
) -> NdArrayTensor<i64, D> {
arg(tensor, dim, cmp_max)
if dim == D - 1 {
return arg(tensor, cmp_max);
}
tensor.array.swap_axes(dim, D - 1);
let mut tensor = arg(tensor, cmp_max);
tensor.array.swap_axes(dim, D - 1);
tensor
}
}
fn arg<E: NdArrayElement, F, const D: usize>(
tensor: NdArrayTensor<E, D>,
dim: usize,
cmp: F,
) -> NdArrayTensor<i64, D>
where
F: Fn(&f64, &f64) -> Ordering,
{
let mut shape = tensor.shape();
let batch_size = shape.dims[dim];
let mut end = shape.dims[dim];
let batch_size = shape.dims[D - 1];
let mut end = shape.dims[D - 1];
let mut values = tensor.array.into_iter().collect::<Vec<_>>();
let mut start = 0;
@ -430,7 +445,7 @@ where
start += batch_size;
end += batch_size;
}
shape.dims[dim] = 1;
shape.dims[D - 1] = 1;
NdArrayTensor::from_data(Data::new(output, shape))
}

View File

@ -4,13 +4,46 @@ mod tests {
use burn_tensor::{Data, Tensor};
#[test]
fn test_argmax_2d() {
let data = Data::from([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]);
fn test_argmax_2d_dim0() {
let data = Data::from([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data);
let data_actual = tensor.argmax(0);
let data_expected = Data::from([[0, 0, 1]]);
assert_eq!(data_expected, data_actual.to_data());
}
#[test]
fn test_argmin_2d_dim0() {
let data = Data::from([[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data);
let data_actual = tensor.argmin(0);
let data_expected = Data::from([[0, 1, 0]]);
assert_eq!(data_expected, data_actual.to_data());
}
#[test]
fn test_argmax_2d_dim1() {
let data = Data::from([[10.0, 11.0, 2.0], [3.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data);
let data_actual = tensor.argmax(1);
let data_expected = Data::from([[2], [2]]);
let data_expected = Data::from([[1], [2]]);
assert_eq!(data_expected, data_actual.to_data());
}
#[test]
fn test_argmin_2d_dim1() {
let data = Data::from([[10.0, 11.0, 2.0], [30.0, 4.0, 5.0]]);
let tensor = Tensor::<TestBackend, 2>::from_data(data);
let data_actual = tensor.argmin(1);
let data_expected = Data::from([[2], [1]]);
assert_eq!(data_expected, data_actual.to_data());
}
}

View File

@ -4,18 +4,20 @@ use burn_tensor::Shape;
kernel_wgsl!(RecursiveSumRaw, "../template/reduction/recursive_sum.wgsl");
kernel_wgsl!(ReductionDimRaw, "../template/reduction/reduce_dim.wgsl");
kernel_wgsl!(ReductionArgsRaw, "../template/reduction/args.wgsl");
struct SumDimRaw;
pub struct ArgsMax;
pub struct ArgsMin;
pub struct SumDim;
pub struct MeanDim;
impl StaticKernel for SumDimRaw {
impl StaticKernel for SumDim {
fn source_template() -> SourceTemplate {
ReductionDimRaw::source_template().register("assign", "output[global_id.x] = sum;")
}
}
struct MeanDimRaw;
impl StaticKernel for MeanDimRaw {
impl StaticKernel for MeanDim {
fn source_template() -> SourceTemplate {
ReductionDimRaw::source_template()
.add_template(
@ -27,6 +29,22 @@ impl StaticKernel for MeanDimRaw {
}
}
impl StaticKernel for ArgsMax {
fn source_template() -> SourceTemplate {
ReductionArgsRaw::source_template()
.register("cmp", ">")
.register("initial", (-32767).to_string())
}
}
impl StaticKernel for ArgsMin {
fn source_template() -> SourceTemplate {
ReductionArgsRaw::source_template()
.register("cmp", "<")
.register("initial", 32767.to_string())
}
}
pub fn reduction_sum<E: WgpuElement, const D: usize>(input: WgpuTensor<E, D>) -> WgpuTensor<E, 1> {
const WORKGROUP: usize = 256;
@ -57,21 +75,7 @@ pub fn reduction_sum<E: WgpuElement, const D: usize>(input: WgpuTensor<E, D>) ->
}
}
pub fn reduction_sum_dim<E: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<E, D> {
reduction_dim::<SumDimRaw, E, D>(input, dim)
}
pub fn reduction_mean_dim<E: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<E, D> {
reduction_dim::<MeanDimRaw, E, D>(input, dim)
}
fn reduction_dim<K: StaticKernel, E: WgpuElement, const D: usize>(
pub fn reduction_dim<K: StaticKernel, E: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<E, D> {
@ -103,3 +107,36 @@ fn reduction_dim<K: StaticKernel, E: WgpuElement, const D: usize>(
output
}
pub fn reduction_args_dim<K: StaticKernel, E: WgpuElement, I: WgpuElement, const D: usize>(
input: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<I, D> {
let mut shape_out = input.shape.clone();
shape_out.dims[dim] = 1;
let buffer = input
.context
.create_buffer(shape_out.num_elements() * core::mem::size_of::<I>());
let output = WgpuTensor::new(input.context.clone(), shape_out, buffer);
let kernel = input
.context
.compile_static::<KernelSettings<K, E, I, 256, 1, 1>>();
let mut info = build_info(&[&input, &output]);
info.push(dim as u32);
let info_buffers = input
.context
.create_buffer_with_data(bytemuck::cast_slice(&info));
input.context.execute(
WorkGroup::new(
f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
1,
1,
),
kernel,
&[&input.buffer, &output.buffer, &info_buffers],
);
WgpuTensor::new(output.context, output.shape, output.buffer)
}

View File

@ -52,6 +52,7 @@ mod tests {
burn_tensor::testgen_transpose!();
burn_tensor::testgen_index!();
burn_tensor::testgen_aggregation!();
burn_tensor::testgen_arg!();
burn_tensor::testgen_map_comparison!();
type TestADBackend = burn_autodiff::ADBackendDecorator<TestBackend>;

View File

@ -1,7 +1,7 @@
use std::ops::Range;
use super::numeric::NumericOps;
use super::{BaseOps, BoolTensor, Device, FloatElem, FloatTensor};
use super::{BaseOps, BoolTensor, Device, FloatElem, FloatTensor, IntTensor};
use crate::kernel::{matmul, unary, unary_inplace, unary_scalar, unary_scalar_inplace};
use crate::{
element::{FloatElement, IntElement},
@ -427,17 +427,11 @@ where
todo!()
}
fn argmax<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
_dim: usize,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
todo!()
fn argmax<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
NumericOps::<G>::argmax(tensor, dim)
}
fn argmin<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::TensorPrimitive<D>,
_dim: usize,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
todo!()
fn argmin<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
NumericOps::<G>::argmin(tensor, dim)
}
}

View File

@ -273,17 +273,11 @@ where
NumericOps::<G>::mean_dim(tensor, dim)
}
fn int_argmax<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_dim: usize,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
todo!()
fn int_argmax<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
NumericOps::<G>::argmax(tensor, dim)
}
fn int_argmin<const D: usize>(
_tensor: <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D>,
_dim: usize,
) -> <WGPUBackend<G, F, I> as Backend>::IntTensorPrimitive<D> {
todo!()
fn int_argmin<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
NumericOps::<G>::argmin(tensor, dim)
}
}

View File

@ -1,10 +1,6 @@
use std::marker::PhantomData;
use burn_tensor::{Element, ElementConversion, Shape};
use crate::kernel::{
binary_elemwise, binary_elemwise_inplace, reduction_mean_dim, reduction_sum, reduction_sum_dim,
unary_scalar, unary_scalar_inplace,
binary_elemwise, binary_elemwise_inplace, reduction_args_dim, reduction_dim, reduction_sum,
unary_scalar, unary_scalar_inplace, ArgsMax, ArgsMin, MeanDim, SumDim,
};
use crate::pool::get_context;
use crate::{
@ -12,6 +8,8 @@ use crate::{
unary_scalar, unary_scalar_inplace,
};
use crate::{GraphicsApi, WgpuDevice};
use burn_tensor::{Element, ElementConversion, Shape};
use std::marker::PhantomData;
pub struct NumericOps<G: GraphicsApi> {
_g: PhantomData<G>,
@ -170,20 +168,34 @@ impl<G: GraphicsApi> NumericOps<G> {
tensor: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<E, D> {
reduction_sum_dim(tensor, dim)
reduction_dim::<SumDim, E, D>(tensor, dim)
}
pub fn mean<E: WgpuElement + Element, const D: usize>(
tensor: WgpuTensor<E, D>,
) -> WgpuTensor<E, 1> {
let num_elems = tensor.shape.num_elements();
Self::div_scalar(reduction_sum(tensor), (num_elems as f32).elem())
Self::div_scalar(Self::sum(tensor), (num_elems as f32).elem())
}
pub fn mean_dim<E: WgpuElement + Element, const D: usize>(
tensor: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<E, D> {
reduction_mean_dim(tensor, dim)
reduction_dim::<MeanDim, E, D>(tensor, dim)
}
pub fn argmax<E: WgpuElement + Element, I: WgpuElement, const D: usize>(
tensor: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<I, D> {
reduction_args_dim::<ArgsMax, E, I, D>(tensor, dim)
}
pub fn argmin<E: WgpuElement + Element, I: WgpuElement, const D: usize>(
tensor: WgpuTensor<E, D>,
dim: usize,
) -> WgpuTensor<I, D> {
reduction_args_dim::<ArgsMin, E, I, D>(tensor, dim)
}
}

View File

@ -0,0 +1,56 @@
@group(0)
@binding(0)
var<storage, read> input: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read_write> output: array<{{ int }}>;
@group(0)
@binding(2)
var<storage, read> info: array<u32>;
@compute
@workgroup_size({{ workgroup_size_x }}, 1, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
) {
let dim: u32 = info[0];
let dim_reduce = info[4u * dim + 1u];
var index_offset: u32 = 0u;
var stride_dim: u32 = 0u;
var shape_dim: u32 = 0u;
for (var i: u32 = 1u; i <= dim; i++) {
let stride_input = info[i];
let stride_output = info[i + dim];
let shape_output = info[i + 3u * dim];
let num_block = global_id.x / stride_output % shape_output;
if i - 1u != dim_reduce {
index_offset += num_block * stride_input;
} else {
let shape_input = info[i + 2u * dim];
index_offset += num_block;
stride_dim = stride_input;
shape_dim = shape_input;
}
}
var current_value = {{ elem }}({{ initial }});
var index = {{ int }}(0);
for (var i = 0u; i < shape_dim; i++) {
let index_input = i * stride_dim;
let value = input[index_input + index_offset];
if (value {{ cmp }} current_value) {
current_value = value;
index = {{ int }}(i);
}
}
output[global_id.x] = index;
}

View File

@ -10,15 +10,10 @@ var<storage, read_write> output: array<{{ elem }}>;
@binding(2)
var<storage, read> info: array<u32>;
var<workgroup> data: array<{{ elem }}, {{ workgroup_size_x }}>;
@compute
@workgroup_size({{ workgroup_size_x }}, 1, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let dim: u32 = info[0];
let dim_reduce = info[4u * dim + 1u];