delete some wgsl

louisfd 2024-03-22 08:38:16 -04:00
parent 86dbd333c2
commit ad7d41fdf1
6 changed files with 0 additions and 1055 deletions

@@ -1,100 +0,0 @@
// use crate::{
// compute::StaticKernel,
// element::JitElement,
// kernel::{elemwise_workgroup, KernelSettings, WORKGROUP_DEFAULT},
// kernel_wgsl,
// ops::numeric::empty_device,
// tensor::JitTensor,
// Runtime,
// };
// use burn_compute::server::Handle;
// use burn_tensor::Shape;
// kernel_wgsl!(
// AdaptiveAvgPool2d,
// "../../template/pool/adaptive_avg_pool2d.wgsl"
// );
// kernel_wgsl!(
// AdaptiveAvgPool2dBackward,
// "../../template/pool/adaptive_avg_pool2d_backward.wgsl"
// );
// pub(crate) fn adaptive_avg_pool2d<R: Runtime, E: JitElement>(
// x: JitTensor<R, E, 4>,
// output_size: [usize; 2],
// ) -> JitTensor<R, E, 4> {
// let [batch_size, channels, _, _] = x.shape.dims;
// let output_shape = Shape::new([batch_size, channels, output_size[0], output_size[1]]);
// let output = empty_device(x.client.clone(), x.device.clone(), output_shape);
// let kernel = StaticKernel::<
// KernelSettings<AdaptiveAvgPool2d, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
// >::new(elemwise_workgroup(
// output.shape.num_elements(),
// WORKGROUP_DEFAULT,
// ));
// let info_handle = build_info(&x, &output);
// x.client
// .execute(Box::new(kernel), &[&x.handle, &output.handle, &info_handle]);
// output
// }
// pub(crate) fn adaptive_avg_pool2d_backward<R: Runtime, E: JitElement>(
// x: JitTensor<R, E, 4>,
// out_grad: JitTensor<R, E, 4>,
// ) -> JitTensor<R, E, 4> {
// let output_shape = x.shape.clone();
// let num_elems = output_shape.num_elements();
// let output_buffer = x.client.empty(num_elems * core::mem::size_of::<E>());
// let output = JitTensor::new(
// x.client.clone(),
// x.device.clone(),
// output_shape,
// output_buffer,
// );
// let kernel = StaticKernel::<
// KernelSettings<AdaptiveAvgPool2dBackward, E, i32, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT, 1>,
// >::new(elemwise_workgroup(
// output.shape.num_elements(),
// WORKGROUP_DEFAULT,
// ));
// let info_handle = build_info(&x, &out_grad);
// x.client.execute(
// Box::new(kernel),
// &[&out_grad.handle, &output.handle, &info_handle],
// );
// output
// }
// fn build_info<R: Runtime, E: JitElement>(
// x: &JitTensor<R, E, 4>,
// output: &JitTensor<R, E, 4>,
// ) -> Handle<R::Server> {
// let mut info: [u32; 16] = [0; 16];
// info[0] = x.strides[0] as u32;
// info[1] = x.strides[1] as u32;
// info[2] = x.strides[2] as u32;
// info[3] = x.strides[3] as u32;
// info[4] = x.shape.dims[0] as u32;
// info[5] = x.shape.dims[1] as u32;
// info[6] = x.shape.dims[2] as u32;
// info[7] = x.shape.dims[3] as u32;
// info[8] = output.strides[0] as u32;
// info[9] = output.strides[1] as u32;
// info[10] = output.strides[2] as u32;
// info[11] = output.strides[3] as u32;
// info[12] = output.shape.dims[0] as u32;
// info[13] = output.shape.dims[1] as u32;
// info[14] = output.shape.dims[2] as u32;
// info[15] = output.shape.dims[3] as u32;
// output.client.create(bytemuck::cast_slice(&info))
// }
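Both kernels above dispatch one invocation per output element via `elemwise_workgroup`. A minimal sketch of that grid calculation, assuming a near-square split of ceil(num_elems / WORKGROUP_DEFAULT²) workgroups (the function name and exact rounding here are assumptions, not the crate's verbatim helper):

```rust
// Hypothetical stand-in for `kernel::elemwise_workgroup`: spread
// ceil(num_elems / workgroup_size^2) workgroups over a near-square
// (x, y) grid, since a single grid dimension can exceed device limits.
fn elemwise_grid(num_elems: usize, workgroup_size: usize) -> (u32, u32) {
    let per_group = (workgroup_size * workgroup_size) as f32;
    let num_groups = (num_elems as f32 / per_group).ceil();
    let x = num_groups.sqrt().ceil();
    let y = (num_groups / x).ceil();
    (x as u32, y as u32)
}
```

The WGSL side then flattens the grid back into a linear index, as in the adaptive pool template below: `id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x`.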

@@ -1,403 +0,0 @@
// use burn_tensor::{ops::conv::calculate_pool_output_size, ElementConversion, Shape};
// use std::marker::PhantomData;
// use crate::{
// codegen::{
// dialect::gpu::{gpu, Elem, Item, Scope, Variable, Visibility},
// execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
// InputInfo, OutputInfo, WorkgroupLaunch,
// },
// element::JitElement,
// kernel::{DynamicKernelSource, SourceTemplate},
// ops::numeric::empty_device,
// tensor::JitTensor,
// Runtime, RuntimeInt,
// };
// #[derive(new)]
// struct MaxPool2dEagerKernel<R: Runtime, E: JitElement> {
// kernel_size: [usize; 2],
// _runtime: PhantomData<R>,
// _elem: PhantomData<E>,
// }
// #[derive(new)]
// struct MaxPool2dWithIndicesEagerKernel<R: Runtime, E: JitElement> {
// kernel_size: [usize; 2],
// _runtime: PhantomData<R>,
// _elem: PhantomData<E>,
// }
// struct MaxPool2dComputeShader<E: JitElement> {
// x: Variable,
// output: Variable,
// kernel_size: [usize; 2],
// indices: Option<Variable>,
// _elem: PhantomData<E>,
// }
// impl<E: JitElement> MaxPool2dComputeShader<E> {
// fn expand(self, scope: &mut Scope) {
// let x = self.x;
// let output = self.output;
// let id = Variable::Id;
// let input_stride_0 = scope.create_local(Elem::UInt);
// let input_stride_1 = scope.create_local(Elem::UInt);
// let input_stride_2 = scope.create_local(Elem::UInt);
// let input_stride_3 = scope.create_local(Elem::UInt);
// let input_shape_2 = scope.create_local(Elem::UInt);
// let input_shape_3 = scope.create_local(Elem::UInt);
// let output_stride_0 = scope.create_local(Elem::UInt);
// let output_stride_1 = scope.create_local(Elem::UInt);
// let output_stride_2 = scope.create_local(Elem::UInt);
// let output_stride_3 = scope.create_local(Elem::UInt);
// let output_shape_0 = scope.create_local(Elem::UInt);
// let output_shape_1 = scope.create_local(Elem::UInt);
// let output_shape_2 = scope.create_local(Elem::UInt);
// let output_shape_3 = scope.create_local(Elem::UInt);
// gpu!(scope, input_stride_0 = stride(x, 0u32));
// gpu!(scope, input_stride_1 = stride(x, 1u32));
// gpu!(scope, input_stride_2 = stride(x, 2u32));
// gpu!(scope, input_stride_3 = stride(x, 3u32));
// gpu!(scope, input_shape_2 = shape(x, 2u32));
// gpu!(scope, input_shape_3 = shape(x, 3u32));
// gpu!(scope, output_stride_0 = stride(output, 0u32));
// gpu!(scope, output_stride_1 = stride(output, 1u32));
// gpu!(scope, output_stride_2 = stride(output, 2u32));
// gpu!(scope, output_stride_3 = stride(output, 3u32));
// gpu!(scope, output_shape_0 = shape(output, 0u32));
// gpu!(scope, output_shape_1 = shape(output, 1u32));
// gpu!(scope, output_shape_2 = shape(output, 2u32));
// gpu!(scope, output_shape_3 = shape(output, 3u32));
// let pool_stride_0 = Variable::GlobalScalar(0, Elem::UInt);
// let pool_stride_1 = Variable::GlobalScalar(1, Elem::UInt);
// let dilation_0 = Variable::GlobalScalar(2, Elem::UInt);
// let dilation_1 = Variable::GlobalScalar(3, Elem::UInt);
// let padding_0 = Variable::GlobalScalar(4, Elem::UInt);
// let padding_1 = Variable::GlobalScalar(5, Elem::UInt);
// let [kernel_size_0, kernel_size_1] = self.kernel_size;
// let b = scope.create_local(Elem::UInt);
// let c = scope.create_local(Elem::UInt);
// let oh = scope.create_local(Elem::UInt);
// let ow = scope.create_local(Elem::UInt);
// gpu!(scope, b = id / output_stride_0);
// gpu!(scope, b = b % output_shape_0);
// gpu!(scope, c = id / output_stride_1);
// gpu!(scope, c = c % output_shape_1);
// gpu!(scope, oh = id / output_stride_2);
// gpu!(scope, oh = oh % output_shape_2);
// gpu!(scope, ow = id / output_stride_3);
// gpu!(scope, ow = ow % output_shape_3);
// let tmp = scope.create_local(Elem::UInt);
// let ih = scope.create_local(Elem::UInt);
// let iw = scope.create_local(Elem::UInt);
// let ih_pad = scope.create_local(Elem::UInt);
// let iw_pad = scope.create_local(Elem::UInt);
// let result = scope.create_local(x.item());
// let cond = scope.create_local(Elem::Bool);
// let cond_tmp = scope.create_local(Elem::Bool);
// let index_input = scope.create_local(Elem::UInt);
// let index_input_1 = scope.create_local(Elem::UInt);
// let index_input_2 = scope.create_local(Elem::UInt);
// let index_input_3 = scope.create_local(Elem::UInt);
// let index_input_4 = scope.create_local(Elem::UInt);
// let is_max = scope.create_local(Elem::Bool);
// let max_index = self.indices.map(|_| scope.create_local(Elem::UInt));
// let max_val = scope.create_local(x.item());
// let max_initial =
// Variable::ConstantScalar(E::minimum_value().to_f64().unwrap(), x.item().elem());
// gpu!(scope, max_val = max_initial);
// (0..kernel_size_0).for_each(|kh| {
// gpu!(scope, ih = oh * pool_stride_0);
// gpu!(scope, tmp = kh * dilation_0);
// gpu!(scope, ih += tmp);
// // Up
// gpu!(scope, cond = ih < padding_0);
// // Down
// gpu!(scope, tmp = input_shape_2 + padding_0);
// gpu!(scope, cond_tmp = ih >= tmp);
// gpu!(scope, cond = cond || cond_tmp);
// gpu!(scope, cond = !cond);
// gpu!(scope, if (cond).then(|scope| {
// (0..kernel_size_1).for_each(|kw| {
// gpu!(scope, iw = ow * pool_stride_1);
// gpu!(scope, tmp = kw * dilation_1);
// gpu!(scope, iw = iw + tmp);
// // Left
// gpu!(scope, cond = iw < padding_1);
// // Right
// gpu!(scope, tmp = input_shape_3 + padding_1);
// gpu!(scope, cond_tmp = iw >= tmp);
// gpu!(scope, cond = cond || cond_tmp);
// gpu!(scope, cond = !cond);
// gpu!(scope, if (cond).then(|scope| {
// gpu!(scope, ih_pad = ih - padding_0);
// gpu!(scope, iw_pad = iw - padding_1);
// gpu!(scope, index_input_1 = b * input_stride_0);
// gpu!(scope, index_input_2 = c * input_stride_1);
// gpu!(scope, index_input_3 = ih_pad * input_stride_2);
// gpu!(scope, index_input_4 = iw_pad * input_stride_3);
// gpu!(scope, index_input = index_input_1);
// gpu!(scope, index_input += index_input_2);
// gpu!(scope, index_input += index_input_3);
// gpu!(scope, index_input += index_input_4);
// gpu!(scope, result = x[index_input]);
// gpu!(scope, is_max = result > max_val);
// gpu!(scope, if(is_max).then(|scope|{
// gpu!(scope, max_val = result);
// if let Some(max_index) = max_index {
// gpu!(scope, max_index = ih_pad * input_shape_2);
// gpu!(scope, max_index += iw_pad);
// }
// }));
// }));
// });
// }));
// });
// gpu!(scope, output[id] = max_val);
// if let Some(indices) = self.indices {
// let max_index = max_index.unwrap();
// gpu!(scope, indices[id] = max_index);
// }
// }
// }
// impl<R: Runtime, E: JitElement> DynamicKernelSource for MaxPool2dEagerKernel<R, E> {
// fn source(&self) -> crate::kernel::SourceTemplate {
// let mut scope = Scope::root();
// let item = E::gpu_elem().into();
// let x = Variable::GlobalInputArray(0, item);
// let output = Variable::GlobalOutputArray(0, item);
// scope.write_global_custom(output);
// MaxPool2dComputeShader {
// x,
// output,
// kernel_size: self.kernel_size,
// indices: None,
// _elem: PhantomData::<E>,
// }
// .expand(&mut scope);
// let input = InputInfo::Array {
// item,
// visibility: Visibility::Read,
// };
// let scalars = InputInfo::Scalar {
// elem: Elem::UInt,
// size: 6,
// };
// let output = OutputInfo::Array { item };
// let info = CompilationInfo {
// inputs: vec![input, scalars],
// outputs: vec![output],
// scope,
// };
// let settings = CompilationSettings::default();
// let shader = Compilation::new(info).compile(settings);
// let shader = <R::Compiler as Compiler>::compile(shader);
// SourceTemplate::new(shader.to_string())
// }
// fn id(&self) -> String {
// format!(
// "{:?}k={:?}",
// core::any::TypeId::of::<Self>(),
// self.kernel_size,
// )
// }
// }
// impl<R: Runtime, E: JitElement> DynamicKernelSource for MaxPool2dWithIndicesEagerKernel<R, E> {
// fn source(&self) -> crate::kernel::SourceTemplate {
// let mut scope = Scope::root();
// let item = E::gpu_elem().into();
// let x = Variable::GlobalInputArray(0, item);
// let output = Variable::GlobalOutputArray(0, item);
// let indices = Variable::GlobalOutputArray(1, Item::Scalar(Elem::Int));
// scope.write_global_custom(output);
// MaxPool2dComputeShader {
// x,
// output,
// kernel_size: self.kernel_size,
// indices: Some(indices),
// _elem: PhantomData::<E>,
// }
// .expand(&mut scope);
// let input = InputInfo::Array {
// item,
// visibility: Visibility::Read,
// };
// let scalars = InputInfo::Scalar {
// elem: Elem::UInt,
// size: 6,
// };
// let output = OutputInfo::Array { item };
// let indices = OutputInfo::Array {
// item: Item::Scalar(Elem::Int),
// };
// let info = CompilationInfo {
// inputs: vec![input, scalars],
// outputs: vec![output, indices],
// scope,
// };
// let settings = CompilationSettings::default();
// let shader = Compilation::new(info).compile(settings);
// let shader = <R::Compiler as Compiler>::compile(shader);
// SourceTemplate::new(shader.to_string())
// }
// fn id(&self) -> String {
// format!(
// "{:?}k={:?}",
// core::any::TypeId::of::<Self>(),
// self.kernel_size,
// )
// }
// }
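The shader above recovers `(b, c, oh, ow)` from the linear invocation id with a stride/shape div-mod chain. The same decomposition on the host, as an illustrative sketch:

```rust
// Mirror of the div/mod sequence in MaxPool2dComputeShader::expand:
// each coordinate is (id / stride) % shape along its dimension.
fn decompose_4d(id: usize, strides: [usize; 4], shape: [usize; 4]) -> [usize; 4] {
    let b = id / strides[0] % shape[0];
    let c = id / strides[1] % shape[1];
    let oh = id / strides[2] % shape[2];
    let ow = id / strides[3] % shape[3];
    [b, c, oh, ow]
}
```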
pub(crate) fn max_pool2d_with_indices<R: Runtime, E: JitElement, I: JitElement>(
x: JitTensor<R, E, 4>,
kernel_size: [usize; 2],
stride: [usize; 2],
padding: [usize; 2],
dilation: [usize; 2],
) -> (JitTensor<R, E, 4>, JitTensor<R, I, 4>) {
let [batch_size, channels, _, _] = x.shape.dims;
let size_0 = calculate_pool_output_size(
kernel_size[0],
stride[0],
padding[0],
dilation[0],
x.shape.dims[2],
);
let size_1 = calculate_pool_output_size(
kernel_size[1],
stride[1],
padding[1],
dilation[1],
x.shape.dims[3],
);
let shape_out = Shape::new([batch_size, channels, size_0, size_1]);
let output = empty_device(x.client.clone(), x.device.clone(), shape_out.clone());
let indices = empty_device(x.client.clone(), x.device.clone(), shape_out);
let kernel = MaxPool2dWithIndicesEagerKernel::new(kernel_size);
execute_dynamic::<R, MaxPool2dWithIndicesEagerKernel<R, E>, I>(
&[EagerHandle::new(&x.handle, &x.strides, &x.shape.dims)],
&[
EagerHandle::new(&output.handle, &output.strides, &output.shape.dims),
EagerHandle::new(&indices.handle, &indices.strides, &indices.shape.dims),
],
Some(&[
(stride[0] as i32).elem(),
(stride[1] as i32).elem(),
(dilation[0] as i32).elem(),
(dilation[1] as i32).elem(),
(padding[0] as i32).elem(),
(padding[1] as i32).elem(),
]),
kernel,
WorkgroupLaunch::Output { pos: 0 },
x.client,
);
(output, indices)
}
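`calculate_pool_output_size` comes from `burn_tensor`; a hedged restatement of the standard pooling formula it presumably implements:

```rust
// Standard pooling output size, assuming floor division:
// out = (in + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1.
// Sketch only; the real helper is burn_tensor's calculate_pool_output_size.
fn pool_output_size(
    kernel: usize,
    stride: usize,
    padding: usize,
    dilation: usize,
    input: usize,
) -> usize {
    (input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
}
```

For example, a 3x3 kernel with stride 2, padding 1, dilation 1 over a 224-wide input gives (224 + 2 - 2 - 1) / 2 + 1 = 112.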
// pub(crate) fn max_pool2d<R: Runtime, E: JitElement>(
// x: JitTensor<R, E, 4>,
// kernel_size: [usize; 2],
// stride: [usize; 2],
// padding: [usize; 2],
// dilation: [usize; 2],
// ) -> JitTensor<R, E, 4> {
// let [batch_size, channels, _, _] = x.shape.dims;
// let size_0 = calculate_pool_output_size(
// kernel_size[0],
// stride[0],
// padding[0],
// dilation[0],
// x.shape.dims[2],
// );
// let size_1 = calculate_pool_output_size(
// kernel_size[1],
// stride[1],
// padding[1],
// dilation[1],
// x.shape.dims[3],
// );
// let shape_out = Shape::new([batch_size, channels, size_0, size_1]);
// let output = empty_device(x.client.clone(), x.device.clone(), shape_out);
// let kernel = MaxPool2dEagerKernel::new(kernel_size);
// execute_dynamic::<R, MaxPool2dEagerKernel<R, E>, RuntimeInt<R>>(
// &[EagerHandle::new(&x.handle, &x.strides, &x.shape.dims)],
// &[EagerHandle::new(
// &output.handle,
// &output.strides,
// &output.shape.dims,
// )],
// Some(&[
// (stride[0] as i32).elem(),
// (stride[1] as i32).elem(),
// (dilation[0] as i32).elem(),
// (dilation[1] as i32).elem(),
// (padding[0] as i32).elem(),
// (padding[1] as i32).elem(),
// ]),
// kernel,
// WorkgroupLaunch::Output { pos: 0 },
// x.client,
// );
// output
// }

@@ -1,243 +0,0 @@
@group(0)
@binding(0)
var<storage, read> lhs: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> rhs: array<{{ elem }}>;
@group(0)
@binding(2)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(3)
var<storage, read> info: array<u32>;
const B_M = {{b_m}}u;
const B_N = {{b_n}}u;
const B_K = {{b_k}}u;
const B_M_X_B_K_4 = {{bm_x_bk_4}}u;
const B_K_X_B_N_4 = {{bk_x_bn_4}}u;
const T_M = 4u;
const T_N = 4u;
const T_M_X_T_N = 16u;
var<workgroup> shared_lhs: array<vec4<{{ elem }}>, B_M_X_B_K_4>;
var<workgroup> shared_rhs: array<vec4<{{ elem }}>, B_K_X_B_N_4>;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }})
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_index) local_idx: u32,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
let skip_row = workgroup_id.x * B_M;
let skip_col = workgroup_id.y * B_N;
let n_thread_per_row = ((B_N - 1u) / T_N) + 1u;
// Position of the first element of the thread, relative to the block
let thread_row = (local_idx / n_thread_per_row) * T_M;
let thread_col = (local_idx % n_thread_per_row) * T_N;
// Position of the first element of the thread, in absolute terms (within one batch)
let row = skip_row + thread_row;
let col = skip_col + thread_col;
let batch = global_id.z;
// Basic information
let dim = info[0];
let n_rows = info[6u * dim - 1u];
let n_cols = info[6u * dim];
let K = info[5u * dim - 1u];
// Row / col strides
let lhs_stride_row = info[dim - 1u];
let lhs_stride_col = info[dim];
let rhs_stride_row = info[2u * dim - 1u];
let rhs_stride_col = info[2u * dim];
let out_stride_row = info[3u * dim - 1u];
let out_stride_col = info[3u * dim];
// Calculate the corresponding offsets with support for broadcasting.
let offset_output = batch * n_rows * n_cols;
var offset_lhs: u32 = skip_row * lhs_stride_row;
var offset_rhs: u32 = skip_col * rhs_stride_col;
let batch_dims = dim - 2u;
for (var b: u32 = 1u; b <= batch_dims; b++) {
let stride_lhs = info[b];
let stride_rhs = info[b + dim];
let stride_output = info[b + 2u * dim];
let shape_lhs = info[b + 3u * dim];
let shape_rhs = info[b + 4u * dim];
offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs;
offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;
}
// Registers used in the compute pass
var results: array<{{ elem }}, T_M_X_T_N>;
var register_M: vec4<{{ elem }}>;
var register_N: vec4<{{ elem }}>;
// How close the thread is to the end of the matrix.
// If < 4, it is an edge case
let remain_row_lhs = n_rows - row;
let remain_col_rhs = n_cols - col;
for (var k = 0u; k < K; k += B_K) {
// LHS LOAD PASS
// For the 4 vec4 columns of this thread
for (var j = 0u; j < 4u; j++) {
// The precise column of the block this thread is loading
let current_col = thread_col + j;
// Position of the column vec4 in shared memory
let lhs_sm_position = (thread_row/4u) * B_K + current_col;
// To avoid overwriting the following row in shared memory
if current_col < B_K {
// To pad with zeros if outside lhs
if current_col + k < K && remain_row_lhs >= 1u {
let lhs_position0 = offset_lhs + (k + current_col) * lhs_stride_col + thread_row * lhs_stride_row;
let lhs_position1 = lhs_position0 + lhs_stride_row;
let lhs_position2 = lhs_position1 + lhs_stride_row;
let lhs_position3 = lhs_position2 + lhs_stride_row;
if remain_row_lhs >= 4u {
shared_lhs[lhs_sm_position] = vec4(
lhs[lhs_position0],
lhs[lhs_position1],
lhs[lhs_position2],
lhs[lhs_position3],
);
} else if remain_row_lhs == 3u {
shared_lhs[lhs_sm_position] = vec4(
lhs[lhs_position0],
lhs[lhs_position1],
lhs[lhs_position2],
0.
);
} else if remain_row_lhs == 2u {
shared_lhs[lhs_sm_position] = vec4(
lhs[lhs_position0],
lhs[lhs_position1],
0.,
0.
);
} else if remain_row_lhs == 1u {
shared_lhs[lhs_sm_position] = vec4(
lhs[lhs_position0],
0.,
0.,
0.
);
}
} else {
shared_lhs[lhs_sm_position] = vec4(0.,0.,0.,0.);
}
}
}
// RHS LOAD PASS
for (var i = 0u; i < 4u; i++) {
let current_row = thread_row + i;
let rhs_sm_position = (current_row * B_N + thread_col) / 4u;
if current_row < B_K {
if current_row + k < K && remain_col_rhs >= 1u {
let rhs_position0 = offset_rhs + (k + current_row) * rhs_stride_row + thread_col * rhs_stride_col;
let rhs_position1 = rhs_position0 + rhs_stride_col;
let rhs_position2 = rhs_position1 + rhs_stride_col;
let rhs_position3 = rhs_position2 + rhs_stride_col;
if remain_col_rhs >= 4u {
shared_rhs[rhs_sm_position] = vec4(
rhs[rhs_position0],
rhs[rhs_position1],
rhs[rhs_position2],
rhs[rhs_position3],
);
} else if remain_col_rhs == 3u {
shared_rhs[rhs_sm_position] = vec4(
rhs[rhs_position0],
rhs[rhs_position1],
rhs[rhs_position2],
0.
);
} else if remain_col_rhs == 2u {
shared_rhs[rhs_sm_position] = vec4(
rhs[rhs_position0],
rhs[rhs_position1],
0.,
0.
);
} else if remain_col_rhs == 1u {
shared_rhs[rhs_sm_position] = vec4(
rhs[rhs_position0],
0.,
0.,
0.
);
}
} else {
shared_rhs[rhs_sm_position] = vec4(0.,0.,0.,0.);
}
}
}
workgroupBarrier();
// COMPUTE PASS
// Compute intermediate results
// Results are accumulated in the results array and updated at each block
// Outer loop indicates which subcolumns/subrows to read from shared memories
for (var dot_index = 0u; dot_index < B_K; dot_index++) {
// Load a subcolumn of values from lhs
let lhs_sm_position = (thread_row/4u) * B_K + dot_index;
register_M = shared_lhs[lhs_sm_position];
// Load a subrow of values from rhs
let rhs_sm_position = (dot_index * B_N + thread_col) / 4u;
register_N = shared_rhs[rhs_sm_position];
// Multiply subcolumn and subrow and store results
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
results[res_idx_M * T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N];
}
}
}
workgroupBarrier();
}
// OUTPUT PASS
// Write output matrix
// Each thread is responsible for writing T_M x T_N results
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
let row_index = row + res_idx_M;
let col_index = col + res_idx_N;
if row_index < n_rows && col_index < n_cols {
let result_position = res_idx_M * T_N + res_idx_N;
let output_position = offset_output + row_index * out_stride_row + col_index * out_stride_col;
output[output_position] = results[result_position];
}
}
}
}
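The index arithmetic above implies a fixed layout for the `info` buffer; reconstructed from those expressions (a host-side sketch, not the crate's actual builder):

```rust
// Layout implied by the shader's reads:
//   info[0]                 = dim (tensor rank)
//   info[1 ..= dim]         = lhs strides
//   info[dim+1 ..= 2*dim]   = rhs strides
//   info[2*dim+1 ..= 3*dim] = output strides
//   info[3*dim+1 ..= 4*dim] = lhs shape
//   info[4*dim+1 ..= 5*dim] = rhs shape
//   info[5*dim+1 ..= 6*dim] = output shape
// Hence n_rows = info[6*dim - 1], n_cols = info[6*dim], K = info[5*dim - 1].
fn build_matmul_info(
    strides: [&[u32]; 3], // lhs, rhs, output
    shapes: [&[u32]; 3],  // lhs, rhs, output
) -> Vec<u32> {
    let dim = strides[0].len() as u32;
    let mut info = vec![dim];
    for part in strides.into_iter().chain(shapes) {
        info.extend_from_slice(part);
    }
    info
}
```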

@@ -1,165 +0,0 @@
@group(0)
@binding(0)
var<storage, read> lhs: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> rhs: array<{{ elem }}>;
@group(0)
@binding(2)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(3)
var<storage, read> info: array<u32>;
const B_M = {{b_m}}u;
const B_N = {{b_n}}u;
const B_K = {{b_k}}u;
const B_M_X_B_K_4 = {{bm_x_bk_4}}u;
const B_K_X_B_N_4 = {{bk_x_bn_4}}u;
const T_M = 4u;
const T_N = 4u;
const T_M_X_T_N = 16u;
var<workgroup> shared_lhs: array<vec4<{{ elem }}>, B_M_X_B_K_4>;
var<workgroup> shared_rhs: array<vec4<{{ elem }}>, B_K_X_B_N_4>;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }})
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_index) local_idx: u32,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
let skip_row = workgroup_id.x * B_M;
let skip_col = workgroup_id.y * B_N;
let n_thread_per_row = ((B_N - 1u) / T_N) + 1u;
let thread_row = (local_idx / n_thread_per_row) * T_M;
let thread_col = (local_idx % n_thread_per_row) * T_N;
let row = skip_row + thread_row;
let col = skip_col + thread_col;
let batch = global_id.z;
// Basic information
let dim = info[0];
let n_rows = info[6u * dim - 1u];
let n_cols = info[6u * dim];
let K = info[5u * dim - 1u];
// Row / col strides
let lhs_stride_row = info[dim - 1u];
let lhs_stride_col = info[dim];
let rhs_stride_row = info[2u * dim - 1u];
let rhs_stride_col = info[2u * dim];
let out_stride_row = info[3u * dim - 1u];
let out_stride_col = info[3u * dim];
// Calculate the corresponding offsets with support for broadcasting.
let offset_output = batch * n_rows * n_cols;
var offset_lhs: u32 = skip_row * lhs_stride_row;
var offset_rhs: u32 = skip_col * rhs_stride_col;
let batch_dims = dim - 2u;
for (var b: u32 = 1u; b <= batch_dims; b++) {
let stride_lhs = info[b];
let stride_rhs = info[b + dim];
let stride_output = info[b + 2u * dim];
let shape_lhs = info[b + 3u * dim];
let shape_rhs = info[b + 4u * dim];
offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs;
offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;
}
var results: array<{{ elem }}, T_M_X_T_N>;
var register_M: vec4<{{ elem }}>;
var register_N: vec4<{{ elem }}>;
for (var k = 0u; k < K; k += B_K) {
// Load data into shared memories
// Each thread is responsible for loading T_M x T_N values from both lhs and rhs
for (var j = 0u; j < 4u; j++) {
let current_col = thread_col + j;
if current_col < B_K { // so that threads whose columns fall between B_K and B_N store nothing
let lhs_sm_position = (thread_row/4u) * B_K + current_col;
let lhs_position0 = offset_lhs + (k + current_col) * lhs_stride_col + thread_row * lhs_stride_row;
let lhs_position1 = lhs_position0 + lhs_stride_row;
let lhs_position2 = lhs_position1 + lhs_stride_row;
let lhs_position3 = lhs_position2 + lhs_stride_row;
shared_lhs[lhs_sm_position] = vec4(
lhs[lhs_position0],
lhs[lhs_position1],
lhs[lhs_position2],
lhs[lhs_position3],
);
}
}
for (var i = 0u; i < 4u; i++) {
let current_row = thread_row + i;
if current_row < B_K { // so that threads whose rows fall between B_K and B_N store nothing
let rhs_sm_position = (current_row * B_N + thread_col) / 4u;
let rhs_position0 = offset_rhs + (k + current_row) * rhs_stride_row + thread_col * rhs_stride_col;
let rhs_position1 = rhs_position0 + rhs_stride_col;
let rhs_position2 = rhs_position1 + rhs_stride_col;
let rhs_position3 = rhs_position2 + rhs_stride_col;
shared_rhs[rhs_sm_position] = vec4(
rhs[rhs_position0],
rhs[rhs_position1],
rhs[rhs_position2],
rhs[rhs_position3],
);
}
}
workgroupBarrier();
// Compute intermediate results
// Results are accumulated in the results array and updated at each block
// Outer loop indicates which subcolumns/subrows to read from shared memories
for (var dot_index = 0u; dot_index < B_K; dot_index++) {
// Load a subcolumn of values from lhs
let lhs_sm_position = (thread_row/4u) * B_K + dot_index;
register_M = shared_lhs[lhs_sm_position];
// Load a subrow of values from rhs
let rhs_sm_position = (dot_index * B_N + thread_col) / 4u;
register_N = shared_rhs[rhs_sm_position];
// Multiply subcolumn and subrow and store results
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
results[res_idx_M * T_N + res_idx_N] += register_M[res_idx_M] * register_N[res_idx_N];
}
}
}
workgroupBarrier();
}
// Write output matrix
// Each thread is responsible for writing T_M x T_N results
for (var res_idx_M = 0u; res_idx_M < T_M; res_idx_M++) {
for (var res_idx_N = 0u; res_idx_N < T_N; res_idx_N++) {
let result_position = res_idx_M * T_N + res_idx_N;
let output_position = offset_output + (row + res_idx_M) * out_stride_row + (col + res_idx_N) * out_stride_col;
output[output_position] = results[result_position];
}
}
}

@@ -1,70 +0,0 @@
@group(0)
@binding(0)
var<storage, read> lhs: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read> rhs: array<{{ elem }}>;
@group(0)
@binding(2)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(3)
var<storage, read> info: array<u32>;
const BLOCK_SIZE = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(local_invocation_index) local_idx: u32,
@builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
// Indices
let row = workgroup_id.x * BLOCK_SIZE + (local_idx / BLOCK_SIZE);
let col = workgroup_id.y * BLOCK_SIZE + (local_idx % BLOCK_SIZE);
let batch = global_id.z;
// Basic information
let dim = info[0];
let n_rows = info[6u * dim - 1u];
let n_cols = info[6u * dim];
let K = info[5u * dim - 1u];
// Return early if outside the output dimensions
if row >= n_rows || col >= n_cols {
return;
}
// Calculate the corresponding offsets with support for broadcasting.
let offset_output = batch * n_rows * n_cols;
var offset_lhs: u32 = 0u;
var offset_rhs: u32 = 0u;
let batch_dims = dim - 2u;
for (var b: u32 = 1u; b <= batch_dims; b++) {
let stride_lhs = info[b];
let stride_rhs = info[b + dim];
let stride_output = info[b + 2u * dim];
let shape_lhs = info[b + 3u * dim];
let shape_rhs = info[b + 4u * dim];
offset_lhs += offset_output / stride_output % shape_lhs * stride_lhs;
offset_rhs += offset_output / stride_output % shape_rhs * stride_rhs;
}
// Basic matmul implementation
var sum = 0.0;
for (var k: u32 = 0u; k < K; k++) {
let lhs_index = row * K + k;
let rhs_index = k * n_cols + col;
sum += lhs[offset_lhs + lhs_index] * rhs[offset_rhs + rhs_index];
}
let output_index = row * n_cols + col;
output[offset_output + output_index] = sum;
}
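Stripped of the batch/broadcast offsets, each invocation computes one output element as a dot product over K. A CPU reference of the same loop, for contiguous row-major operands:

```rust
// What one shader invocation does, for every (row, col) pair:
// out[row, col] = sum over k of lhs[row, k] * rhs[k, col].
fn matmul_naive(
    lhs: &[f32],
    rhs: &[f32],
    out: &mut [f32],
    n_rows: usize,
    n_cols: usize,
    k_dim: usize,
) {
    for row in 0..n_rows {
        for col in 0..n_cols {
            let mut sum = 0.0;
            for k in 0..k_dim {
                sum += lhs[row * k_dim + k] * rhs[k * n_cols + col];
            }
            out[row * n_cols + col] = sum;
        }
    }
}
```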

@@ -1,74 +0,0 @@
@group(0)
@binding(0)
var<storage, read> x: array<{{ elem }}>;
@group(0)
@binding(1)
var<storage, read_write> output: array<{{ elem }}>;
@group(0)
@binding(2)
var<storage, read> info: array<u32, 16>;
const WORKGROUP_SIZE_X = {{ workgroup_size_x }}u;
@compute
@workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, 1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>,
@builtin(num_workgroups) num_workgroups: vec3<u32>,
) {
let id = global_id.y * (num_workgroups.x * WORKGROUP_SIZE_X) + global_id.x;
let input_stride_0 = info[0];
let input_stride_1 = info[1];
let input_stride_2 = info[2];
let input_stride_3 = info[3];
let input_shape_0 = info[4];
let input_shape_1 = info[5];
let input_shape_2 = info[6];
let input_shape_3 = info[7];
let output_stride_0 = info[8];
let output_stride_1 = info[9];
let output_stride_2 = info[10];
let output_stride_3 = info[11];
let output_shape_0 = info[12];
let output_shape_1 = info[13];
let output_shape_2 = info[14];
let output_shape_3 = info[15];
let b = id / output_stride_0 % output_shape_0;
let c = id / output_stride_1 % output_shape_1;
let oh = id / output_stride_2 % output_shape_2;
let ow = id / output_stride_3 % output_shape_3;
let ih_start = start_index(oh, output_shape_2, input_shape_2);
let ih_end = end_index(oh, output_shape_2, input_shape_2);
let iw_start = start_index(ow, output_shape_3, input_shape_3);
let iw_end = end_index(ow, output_shape_3, input_shape_3);
var sum = 0.0;
for (var ih = ih_start; ih < ih_end; ih++) {
for (var iw = iw_start; iw < iw_end; iw++) {
let index_input = b * input_stride_0 + c * input_stride_1 + ih * input_stride_2 + iw * input_stride_3;
sum += x[index_input];
}
}
let count = {{ elem }}((ih_end - ih_start) * (iw_end - iw_start));
output[id] = sum / count;
}
fn start_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
return u32(floor((f32(output_size_index) * f32(input_size)) / f32(output_size)));
}
fn end_index(output_size_index: u32, output_size: u32, input_size: u32) -> u32 {
let index = u32(ceil((f32(output_size_index + 1u) * f32(input_size)) / f32(output_size)));
return min(index, input_size);
}
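These helpers implement the usual adaptive-pooling window rule: start = floor(i * in / out), end = min(ceil((i + 1) * in / out), in). An integer-arithmetic mirror in Rust, with a small worked example:

```rust
// Integer equivalents of the f32 floor/ceil helpers in the shader.
fn start_index(i: u32, output_size: u32, input_size: u32) -> u32 {
    (i * input_size) / output_size // floor
}

fn end_index(i: u32, output_size: u32, input_size: u32) -> u32 {
    let index = ((i + 1) * input_size + output_size - 1) / output_size; // ceil
    index.min(input_size)
}

// input_size = 5, output_size = 3 yields the overlapping windows
// [0, 2), [1, 4), [3, 5), which together cover every input element.
```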