mirror of https://github.com/tracel-ai/burn.git
refactor: wgpu reductions (#471)
This commit is contained in:
parent d78f25f922
commit 04ad14a32a
@@ -634,7 +634,10 @@ pub trait IntTensorOps<B: Backend> {
     /// # Returns
     ///
     /// The mean of all elements in the tensor.
-    fn int_mean<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<1>;
+    fn int_mean<const D: usize>(tensor: B::IntTensorPrimitive<D>) -> B::IntTensorPrimitive<1> {
+        let num_elems = B::int_shape(&tensor).num_elements();
+        B::int_div_scalar(B::int_sum(tensor), (num_elems as i64).elem())
+    }

     /// Computes the mean of all elements in the tensor along a dimension.
     ///
@@ -716,7 +716,10 @@ pub trait TensorOps<B: Backend> {
     /// # Returns
     ///
     /// A scalar tensor with the mean of all elements in `tensor`.
-    fn mean<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<1>;
+    fn mean<const D: usize>(tensor: B::TensorPrimitive<D>) -> B::TensorPrimitive<1> {
+        let num_elems = B::shape(&tensor).num_elements();
+        B::div_scalar(B::sum(tensor), (num_elems as i64).elem())
+    }

     /// Mean of all elements in a tensor along a dimension.
     ///
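The two new trait defaults express mean as sum divided by element count, so any backend that implements sum gets mean for free. A minimal standalone sketch of the same identity (my example, not code from the commit); note that the integer variant inherits truncating division from int_div_scalar:

fn mean_f32(values: &[f32]) -> f32 {
    // Mean as sum / count, mirroring the trait default.
    values.iter().sum::<f32>() / values.len() as f32
}

fn mean_i64(values: &[i64]) -> i64 {
    // Integer division truncates: the mean of [1, 2] is reported as 1.
    values.iter().sum::<i64>() / values.len() as i64
}
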
@@ -47,6 +47,12 @@ pub struct WorkGroup {
     pub z: u32,
 }

+impl WorkGroup {
+    pub fn num_invocations(&self) -> usize {
+        (self.x * self.y * self.z) as usize
+    }
+}
+
 impl Context {
     /// Create a new context where computing tasks will be executed on the given
     /// [device](WgpuDevice).
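num_invocations counts the cells of the 3-D dispatch grid, which the new recursive sum below uses to size its intermediate buffers. An illustrative use (my example, relying only on the public x/y/z fields shown above):

// Illustrative only: an 8 x 4 x 1 grid dispatches 32 workgroups.
let workgroup = WorkGroup { x: 8, y: 4, z: 1 };
assert_eq!(workgroup.num_invocations(), 32);
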
@@ -99,6 +99,10 @@ impl<
             .register("workgroup_size_x", WORKGROUP_X_SIZE.to_string())
             .register("workgroup_size_y", WORKGROUP_Y_SIZE.to_string())
             .register("workgroup_size_z", WORKGROUP_Z_SIZE.to_string())
+            .register(
+                "workgroup_size",
+                (WORKGROUP_X_SIZE * WORKGROUP_Y_SIZE * WORKGROUP_Z_SIZE).to_string(),
+            )
             .register("elem", E::type_name())
             .register("int", I::type_name())
     }
@@ -123,6 +127,10 @@ impl<K: StaticKernel, E: WgpuElement, I: WgpuElement> DynamicKernel
             .register("workgroup_size_x", self.workgroup_x_size.to_string())
             .register("workgroup_size_y", self.workgroup_y_size.to_string())
             .register("workgroup_size_z", self.workgroup_z_size.to_string())
+            .register(
+                "workgroup_size",
+                (self.workgroup_x_size * self.workgroup_y_size * self.workgroup_z_size).to_string(),
+            )
             .register("elem", E::type_name())
             .register("int", I::type_name())
     }
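Both the static and dynamic settings now expose a workgroup_size template key holding the x*y*z product, so a shader can size workgroup-shared arrays without multiplying template values on the WGSL side. A hedged illustration of what the key renders to with the 32x32x1 workgroups this commit uses for reductions:

// With (x, y, z) = (32, 32, 1), "workgroup_size" renders as "1024".
let (x, y, z): (u32, u32, u32) = (32, 32, 1);
assert_eq!((x * y * z).to_string(), "1024");
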
@@ -1,5 +1,5 @@
 use super::{build_info, KernelSettings, SourceTemplate, StaticKernel};
-use crate::{context::WorkGroup, element::WgpuElement, kernel_wgsl, tensor::WgpuTensor};
+use crate::{element::WgpuElement, kernel::elemwise_workgroup, kernel_wgsl, tensor::WgpuTensor};
 use burn_tensor::Shape;

 kernel_wgsl!(RecursiveSumRaw, "../template/reduction/recursive_sum.wgsl");
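elemwise_workgroup is imported here in place of the hand-rolled WorkGroup::new calls below; its implementation is outside this diff. A hedged sketch of what such a helper plausibly computes, spreading num_elems invocations over a roughly square x/y grid of workgroup_size x workgroup_size workgroups (names and rounding are my assumptions, not the crate's actual code):

// Hedged sketch, not the crate's implementation: pick a 2-D grid so
// that x * y workgroups cover num_elems invocations, with each
// workgroup handling workgroup_size^2 of them.
struct WorkGroup {
    x: u32,
    y: u32,
    z: u32,
}

fn elemwise_workgroup(num_elems: usize, workgroup_size: usize) -> WorkGroup {
    let per_group = workgroup_size * workgroup_size; // invocations per workgroup
    let groups = ((num_elems + per_group - 1) / per_group).max(1); // ceil division
    let x = (groups as f64).sqrt().ceil() as u32; // roughly square grid
    let y = (groups as u32 + x - 1) / x; // enough rows to cover all groups
    WorkGroup { x, y, z: 1 }
}

Splitting the dispatch over two axes keeps each grid dimension small, which matters because wgpu caps the workgroup count per dispatch dimension.
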
@@ -13,7 +13,7 @@ pub struct MeanDim;

 impl StaticKernel for SumDim {
     fn source_template() -> SourceTemplate {
-        ReductionDimRaw::source_template().register("assign", "output[global_id.x] = sum;")
+        ReductionDimRaw::source_template().register("assign", "output[id] = sum;")
     }
 }

@@ -25,7 +25,7 @@ impl StaticKernel for MeanDim {
                 return sum / {{ elem }}(dim);
             }",
         )
-        .register("assign", "output[global_id.x] = mean_dim(sum, shape_dim);")
+        .register("assign", "output[id] = mean_dim(sum, shape_dim);")
     }
 }

@@ -45,50 +45,70 @@ impl StaticKernel for ArgsMin {
     }
 }

-pub fn reduction_sum<E: WgpuElement, const D: usize>(input: WgpuTensor<E, D>) -> WgpuTensor<E, 1> {
-    const WORKGROUP: usize = 256;
+/// Sum all elements in the input buffer.
+pub fn sum<E: WgpuElement, const D: usize>(input: WgpuTensor<E, D>) -> WgpuTensor<E, 1> {
+    const WORKGROUP: usize = 32;

     let mut input_buffer = input.buffer;
-    let mut num_invocations =
-        f32::ceil(input.shape.num_elements() as f32 / WORKGROUP as f32) as usize;
+    let mut workgroup = elemwise_workgroup(input.shape.num_elements(), WORKGROUP);

     let kernel = input
         .context
-        .compile_static::<KernelSettings<RecursiveSumRaw, E, i32, WORKGROUP, 1, 1>>();
+        .compile_static::<KernelSettings<RecursiveSumRaw, E, i32, WORKGROUP, WORKGROUP, 1>>();

     loop {
+        let num_invocations = workgroup.num_invocations();
         let buffer = input
             .context
             .create_buffer(core::mem::size_of::<E>() * num_invocations);
-        let workgroup = WorkGroup::new(num_invocations as u32, 1, 1);

         input
             .context
-            .execute(workgroup, kernel.clone(), &[&input_buffer, &buffer]);
+            .execute(workgroup.clone(), kernel.clone(), &[&input_buffer, &buffer]);

-        if num_invocations == 1 {
+        if num_invocations <= 1 {
             return WgpuTensor::new(input.context, Shape::new([1]), buffer);
         }

         input_buffer = buffer;
-        num_invocations = f32::ceil(num_invocations as f32 / WORKGROUP as f32) as usize;
+        workgroup = elemwise_workgroup(num_invocations, WORKGROUP);
     }
 }

-pub fn reduction_dim<K: StaticKernel, E: WgpuElement, const D: usize>(
+/// Execute the sum dim kernel.
+pub fn sum_dim<E: WgpuElement, const D: usize>(
+    input: WgpuTensor<E, D>,
+    dim: usize,
+) -> WgpuTensor<E, D> {
+    reduction_dim::<SumDim, E, D>(input, dim)
+}
+
+/// Execute the mean dim kernel.
+pub fn mean_dim<E: WgpuElement, const D: usize>(
+    input: WgpuTensor<E, D>,
+    dim: usize,
+) -> WgpuTensor<E, D> {
+    reduction_dim::<MeanDim, E, D>(input, dim)
+}
+
+fn reduction_dim<K: StaticKernel, E: WgpuElement, const D: usize>(
     input: WgpuTensor<E, D>,
     dim: usize,
 ) -> WgpuTensor<E, D> {
+    const WORKGROUP: usize = 32;
+
     let mut shape_out = input.shape.clone();
     shape_out.dims[dim] = 1;
+    let num_elems = shape_out.num_elements();
     let buffer = input
         .context
-        .create_buffer(shape_out.num_elements() * core::mem::size_of::<E>());
+        .create_buffer(num_elems * core::mem::size_of::<E>());
     let output = WgpuTensor::new(input.context.clone(), shape_out, buffer);

     let kernel = input
         .context
-        .compile_static::<KernelSettings<K, E, i32, 256, 1, 1>>();
+        .compile_static::<KernelSettings<K, E, i32, WORKGROUP, WORKGROUP, 1>>();

     let mut info = build_info(&[&input, &output]);
     info.push(dim as u32);
     let info_buffers = input
@@ -96,11 +116,7 @@ pub fn reduction_dim<K: StaticKernel, E: WgpuElement, const D: usize>(
         .create_buffer_with_data(bytemuck::cast_slice(&info));

     input.context.execute(
-        WorkGroup::new(
-            f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
-            1,
-            1,
-        ),
+        elemwise_workgroup(num_elems, WORKGROUP),
         kernel,
         &[&input.buffer, &output.buffer, &info_buffers],
     );
@@ -108,20 +124,39 @@ pub fn reduction_dim<K: StaticKernel, E: WgpuElement, const D: usize>(
     output
 }

-pub fn reduction_args_dim<K: StaticKernel, E: WgpuElement, I: WgpuElement, const D: usize>(
+/// Execute the argmax kernel.
+pub fn argmax<E: WgpuElement, I: WgpuElement, const D: usize>(
+    input: WgpuTensor<E, D>,
+    dim: usize,
+) -> WgpuTensor<I, D> {
+    reduction_args_dim::<ArgsMax, E, I, D>(input, dim)
+}
+
+/// Execute the argmin kernel.
+pub fn argmin<E: WgpuElement, I: WgpuElement, const D: usize>(
+    input: WgpuTensor<E, D>,
+    dim: usize,
+) -> WgpuTensor<I, D> {
+    reduction_args_dim::<ArgsMin, E, I, D>(input, dim)
+}
+
+fn reduction_args_dim<K: StaticKernel, E: WgpuElement, I: WgpuElement, const D: usize>(
     input: WgpuTensor<E, D>,
     dim: usize,
 ) -> WgpuTensor<I, D> {
+    const WORKGROUP: usize = 32;
+
     let mut shape_out = input.shape.clone();
     shape_out.dims[dim] = 1;
+    let num_elems = shape_out.num_elements();
     let buffer = input
         .context
-        .create_buffer(shape_out.num_elements() * core::mem::size_of::<I>());
+        .create_buffer(num_elems * core::mem::size_of::<I>());
     let output = WgpuTensor::new(input.context.clone(), shape_out, buffer);

     let kernel = input
         .context
-        .compile_static::<KernelSettings<K, E, I, 256, 1, 1>>();
+        .compile_static::<KernelSettings<K, E, I, WORKGROUP, WORKGROUP, 1>>();
     let mut info = build_info(&[&input, &output]);
     info.push(dim as u32);
     let info_buffers = input
@@ -129,14 +164,53 @@ pub fn reduction_args_dim<K: StaticKernel, E: WgpuElement, I: WgpuElement, const
         .create_buffer_with_data(bytemuck::cast_slice(&info));

     input.context.execute(
-        WorkGroup::new(
-            f32::ceil(output.shape.num_elements() as f32 / 256_f32) as u32,
-            1,
-            1,
-        ),
+        elemwise_workgroup(num_elems, WORKGROUP),
         kernel,
         &[&input.buffer, &output.buffer, &info_buffers],
     );

     WgpuTensor::new(output.context, output.shape, output.buffer)
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::tests::{ReferenceBackend, TestBackend};
+    use burn_tensor::{Distribution, Int, Tensor};
+
+    #[test]
+    fn reduction_sum_should_work_with_multiple_invocations() {
+        let tensor = Tensor::<TestBackend, 2>::random([6, 256], Distribution::Default);
+        let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
+
+        let val = Tensor::<TestBackend, 1>::from_primitive(sum(tensor.into_primitive()));
+        let val_ref = tensor_ref.sum();
+
+        val_ref.into_data().assert_approx_eq(&val.into_data(), 3);
+    }
+
+    #[test]
+    fn reduction_sum_dim_should_work_with_multiple_invocations() {
+        let tensor = Tensor::<TestBackend, 2>::random([6, 1024], Distribution::Default);
+        let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
+
+        let val = Tensor::<TestBackend, 2>::from_primitive(reduction_dim::<SumDim, f32, 2>(
+            tensor.into_primitive(),
+            1,
+        ));
+        let val_ref = tensor_ref.sum_dim(1);
+
+        val_ref.into_data().assert_approx_eq(&val.into_data(), 3);
+    }
+
+    #[test]
+    fn reduction_args_dim_should_work_with_multiple_invocations() {
+        let tensor = Tensor::<TestBackend, 2>::random([6, 1024], Distribution::Default);
+        let tensor_ref = Tensor::<ReferenceBackend, 2>::from_data(tensor.to_data());
+
+        let val = Tensor::<TestBackend, 2, Int>::from_primitive(argmax(tensor.into_primitive(), 1));
+        let val_ref = tensor_ref.argmax(1);
+
+        assert_eq!(val_ref.into_data().convert(), val.into_data());
+    }
+}
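The new sum converges because every dispatch shrinks the data by a fixed factor: each 32x32 workgroup sums the 1024 elements it loads into one output value, and the loop re-dispatches over the shrinking buffer until a single value remains. A CPU model of that loop (my sketch, not the commit's code):

// CPU model of the recursive reduction: each "dispatch" sums
// fixed-size chunks (1024 elements for a 32x32 workgroup) into a
// smaller buffer, repeating until one value remains.
fn recursive_sum(mut values: Vec<f32>, chunk: usize) -> f32 {
    loop {
        values = values
            .chunks(chunk)
            .map(|c| c.iter().sum())
            .collect();
        if values.len() <= 1 {
            return values.first().copied().unwrap_or(0.0);
        }
    }
}

For example, the 6 x 256 = 1536 elements in the first test above reduce in two passes with chunk = 1024: 1536 -> 2 -> 1, which is why the test name stresses multiple invocations.
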
@@ -295,19 +295,15 @@ where
     }

     fn sum<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
-        NumericOps::<G>::sum(tensor)
+        kernel::sum(tensor)
     }

     fn sum_dim<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> FloatTensor<Self, D> {
-        NumericOps::<G>::sum_dim(tensor, dim)
-    }
-
-    fn mean<const D: usize>(tensor: FloatTensor<Self, D>) -> FloatTensor<Self, 1> {
-        NumericOps::<G>::mean(tensor)
+        kernel::sum_dim(tensor, dim)
     }

     fn mean_dim<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> FloatTensor<Self, D> {
-        NumericOps::<G>::mean_dim(tensor, dim)
+        kernel::mean_dim(tensor, dim)
     }

     fn to_full_precision<const D: usize>(
@@ -427,10 +423,10 @@ where
     }

     fn argmax<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
-        NumericOps::<G>::argmax(tensor, dim)
+        kernel::argmax(tensor, dim)
     }

     fn argmin<const D: usize>(tensor: FloatTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
-        NumericOps::<G>::argmin(tensor, dim)
+        kernel::argmin(tensor, dim)
     }
 }
@@ -1,3 +1,4 @@
+use super::{numeric::NumericOps, BaseOps, BoolTensor, Device, IntElem, IntTensor};
 use crate::{
     element::{FloatElement, IntElement},
     kernel, GraphicsApi, WgpuBackend,
@@ -5,8 +6,6 @@ use crate::{
 use burn_tensor::{ops::IntTensorOps, Data, Shape};
 use std::ops::Range;

-use super::{numeric::NumericOps, BaseOps, BoolTensor, Device, IntElem, IntTensor};
-
 impl<G, F, I> IntTensorOps<WgpuBackend<G, F, I>> for WgpuBackend<G, F, I>
 where
     G: GraphicsApi + 'static,
@@ -254,25 +253,22 @@ where
     }

     fn int_sum<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
-        NumericOps::<G>::sum(tensor)
+        kernel::sum(tensor)
     }

     fn int_sum_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
-        NumericOps::<G>::sum_dim(tensor, dim)
-    }
-    fn int_mean<const D: usize>(tensor: IntTensor<Self, D>) -> IntTensor<Self, 1> {
-        NumericOps::<G>::mean(tensor)
+        kernel::sum_dim(tensor, dim)
     }

     fn int_mean_dim<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
-        NumericOps::<G>::mean_dim(tensor, dim)
+        kernel::mean_dim(tensor, dim)
     }

     fn int_argmax<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
-        NumericOps::<G>::argmax(tensor, dim)
+        kernel::argmax(tensor, dim)
     }

     fn int_argmin<const D: usize>(tensor: IntTensor<Self, D>, dim: usize) -> IntTensor<Self, D> {
-        NumericOps::<G>::argmin(tensor, dim)
+        kernel::argmin(tensor, dim)
     }
 }
@@ -1,7 +1,6 @@
 use crate::kernel::{
-    binary_elemwise_default, binary_elemwise_inplace_default, reduction_args_dim, reduction_dim,
-    reduction_sum, unary_scalar_default, unary_scalar_inplace_default, ArgsMax, ArgsMin, MeanDim,
-    SumDim,
+    binary_elemwise_default, binary_elemwise_inplace_default, unary_scalar_default,
+    unary_scalar_inplace_default,
 };
 use crate::pool::get_context;
 use crate::{
@@ -158,45 +157,4 @@ impl<G: GraphicsApi> NumericOps<G> {

         unary_scalar_default::<DivScalar, E, D>(lhs, rhs)
     }
-
-    pub fn sum<E: WgpuElement + Element, const D: usize>(
-        tensor: WgpuTensor<E, D>,
-    ) -> WgpuTensor<E, 1> {
-        reduction_sum(tensor)
-    }
-
-    pub fn sum_dim<E: WgpuElement + Element, const D: usize>(
-        tensor: WgpuTensor<E, D>,
-        dim: usize,
-    ) -> WgpuTensor<E, D> {
-        reduction_dim::<SumDim, E, D>(tensor, dim)
-    }
-
-    pub fn mean<E: WgpuElement + Element, const D: usize>(
-        tensor: WgpuTensor<E, D>,
-    ) -> WgpuTensor<E, 1> {
-        let num_elems = tensor.shape.num_elements();
-        Self::div_scalar(Self::sum(tensor), (num_elems as f32).elem())
-    }
-
-    pub fn mean_dim<E: WgpuElement + Element, const D: usize>(
-        tensor: WgpuTensor<E, D>,
-        dim: usize,
-    ) -> WgpuTensor<E, D> {
-        reduction_dim::<MeanDim, E, D>(tensor, dim)
-    }
-
-    pub fn argmax<E: WgpuElement + Element, I: WgpuElement, const D: usize>(
-        tensor: WgpuTensor<E, D>,
-        dim: usize,
-    ) -> WgpuTensor<I, D> {
-        reduction_args_dim::<ArgsMax, E, I, D>(tensor, dim)
-    }
-
-    pub fn argmin<E: WgpuElement + Element, I: WgpuElement, const D: usize>(
-        tensor: WgpuTensor<E, D>,
-        dim: usize,
-    ) -> WgpuTensor<I, D> {
-        reduction_args_dim::<ArgsMin, E, I, D>(tensor, dim)
-    }
 }
@@ -10,11 +10,15 @@ var<storage, read_write> output: array<{{ int }}>;
 @binding(2)
 var<storage, read> info: array<u32>;

+const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
+
 @compute
-@workgroup_size({{ workgroup_size_x }}, 1, 1)
+@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
 fn main(
     @builtin(global_invocation_id) global_id: vec3<u32>,
+    @builtin(num_workgroups) num_workgroups: vec3<u32>,
 ) {
+    let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
     let dim: u32 = info[0];
     let dim_reduce = info[4u * dim + 1u];
     var index_offset: u32 = 0u;
@@ -26,7 +30,7 @@ fn main(
         let stride_output = info[i + dim];
         let shape_output = info[i + 3u * dim];

-        let num_block = global_id.x / stride_output % shape_output;
+        let num_block = id / stride_output % shape_output;

         if i - 1u != dim_reduce {
             index_offset += num_block * stride_input;
@@ -52,5 +56,5 @@ fn main(
         }
     }

-    output[global_id.x] = index;
+    output[id] = index;
 }
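The id rewrite above matters because every info-buffer lookup is keyed to the output element an invocation owns. From the reads in this kernel (info[0] is the rank, strides at i and i + dim, output shapes at i + 3u * dim, the reduced dimension at 4u * dim + 1u), the buffer that build_info produces plausibly packs rank, input strides, output strides, input shapes, and output shapes, with the Rust side pushing the reduced dimension last. A hedged reconstruction of that layout (field order inferred from the kernel, names mine, not taken from build_info's source):

// Hedged reconstruction of the info-buffer layout; d is the rank.
fn pack_info(
    strides_in: &[u32],  // info[1 ..= d]
    strides_out: &[u32], // info[d + 1 ..= 2d]
    shape_in: &[u32],    // info[2d + 1 ..= 3d]
    shape_out: &[u32],   // info[3d + 1 ..= 4d]
    dim_reduce: u32,     // info[4d + 1], pushed by reduction_args_dim
) -> Vec<u32> {
    let d = strides_in.len() as u32;
    let mut info = vec![d]; // info[0]: tensor rank
    info.extend_from_slice(strides_in);
    info.extend_from_slice(strides_out);
    info.extend_from_slice(shape_in);
    info.extend_from_slice(shape_out);
    info.push(dim_reduce);
    info
}
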
@@ -1,4 +1,5 @@
-const BLOCK_SIZE = {{ workgroup_size_x }}u;
+const WORKGROUP_SIZE = {{ workgroup_size }}u;
+const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;

 @group(0)
 @binding(0)
@@ -8,26 +9,30 @@ var<storage, read> input: array<{{ elem }}>;
 @binding(1)
 var<storage, read_write> output: array<{{ elem }}>;

-var<workgroup> data: array<{{ elem }}, BLOCK_SIZE>;
+var<workgroup> data: array<{{ elem }}, WORKGROUP_SIZE>;

 @compute
-@workgroup_size({{ workgroup_size_x }}, 1, 1)
+@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
 fn main(
     @builtin(global_invocation_id) global_id: vec3<u32>,
     @builtin(local_invocation_id) local_id: vec3<u32>,
     @builtin(workgroup_id) workgroup_id: vec3<u32>,
+    @builtin(num_workgroups) num_workgroups: vec3<u32>,
 ) {
-    data[local_id.x] = input[global_id.x];
+    let id_global = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
+    let id_local = local_id.y * WORKGROUP_SIZE_X + local_id.x;
+
+    data[id_local] = input[id_global];

     workgroupBarrier();

-    if local_id.x == 0u {
+    if id_local == 0u {
         var sum = {{ elem }}(0);
-        for (var i: u32 = 0u; i < BLOCK_SIZE; i++) {
+        for (var i: u32 = 0u; i < WORKGROUP_SIZE; i++) {
             sum += data[i];
         }

-        output[workgroup_id.x] = sum;
+        let id_output = workgroup_id.y * num_workgroups.x + workgroup_id.x;
+        output[id_output] = sum;
     }
 }
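All three shaders now linearize a 2-D invocation id the same way, presumably so large tensors stay under wgpu's per-dimension workgroup limit rather than overflowing a single x axis. The mapping in Rust form (my sketch of the arithmetic, not code from the commit):

// Flat index of an invocation in a 2-D dispatch: grid_w is the grid
// width in invocations, i.e. num_workgroups.x * WORKGROUP_SIZE_X.
fn flat_id(global_x: u32, global_y: u32, grid_w: u32) -> u32 {
    global_y * grid_w + global_x
}

// Example: 3 workgroups wide at 32 threads each gives grid_w = 96, so
// the invocation at global (5, 40) owns flat element 40 * 96 + 5 = 3845.
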
@@ -10,11 +10,15 @@ var<storage, read_write> output: array<{{ elem }}>;
 @binding(2)
 var<storage, read> info: array<u32>;

+const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
+
 @compute
-@workgroup_size({{ workgroup_size_x }}, 1, 1)
+@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
 fn main(
     @builtin(global_invocation_id) global_id: vec3<u32>,
+    @builtin(num_workgroups) num_workgroups: vec3<u32>,
 ) {
+    let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
     let dim: u32 = info[0];
     let dim_reduce = info[4u * dim + 1u];
     var index_offset: u32 = 0u;
@@ -26,7 +30,7 @@ fn main(
         let stride_output = info[i + dim];
         let shape_output = info[i + 3u * dim];

-        let num_block = global_id.x / stride_output % shape_output;
+        let num_block = id / stride_output % shape_output;

         if i - 1u != dim_reduce {
             index_offset += num_block * stride_input;