mirror of https://github.com/tracel-ai/burn.git
merge main
This commit is contained in:
commit
86dbd333c2
|
@ -40,7 +40,7 @@ impl StreamId {
|
|||
#[cfg(feature = "std")]
|
||||
fn id() -> std::thread::ThreadId {
|
||||
std::thread_local! {
|
||||
static ID: std::cell::OnceCell::<std::thread::ThreadId> = std::cell::OnceCell::new();
|
||||
static ID: std::cell::OnceCell::<std::thread::ThreadId> = const { std::cell::OnceCell::new() };
|
||||
};
|
||||
|
||||
// Getting the current thread is expensive, so we cache the value into a thread local
|
||||
|
|
|
@ -22,4 +22,6 @@ pub trait Compiler: Sync + Send + 'static + Clone + Default + core::fmt::Debug {
|
|||
fn compile(shader: gpu::ComputeShader) -> Self::Representation;
|
||||
/// The size of the given element in bytes.
|
||||
fn elem_size(elem: gpu::Elem) -> usize;
|
||||
/// The maximal size of a shared memory
|
||||
fn max_shared_memory_size() -> usize;
|
||||
}
|
||||
|
|
|
@ -119,3 +119,15 @@ impl Loop {
|
|||
parent_scope.register(Branch::Loop(op));
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(missing_docs)]
|
||||
pub struct UnrolledRangeLoop;
|
||||
|
||||
impl UnrolledRangeLoop {
|
||||
/// Registers an unrolled range loop to the given scope.
|
||||
pub fn register<F: Fn(Variable, &mut Scope)>(scope: &mut Scope, start: u32, end: u32, func: F) {
|
||||
for i in start..end {
|
||||
func(i.into(), scope);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -299,6 +299,17 @@ macro_rules! gpu {
|
|||
gpu!(unary $input, $out)
|
||||
));
|
||||
};
|
||||
// out = vec4(a, b, c, d)
|
||||
($scope:expr, $out:ident = vec4($a:ident,$b:ident,$c:ident,$d:ident)) => {
|
||||
let i = $scope.zero(Elem::UInt);
|
||||
gpu!($scope, $out[i] = $a);
|
||||
gpu!($scope, i = i + 1u32);
|
||||
gpu!($scope, $out[i] = $b);
|
||||
gpu!($scope, i = i + 1u32);
|
||||
gpu!($scope, $out[i] = $c);
|
||||
gpu!($scope, i = i + 1u32);
|
||||
gpu!($scope, $out[i] = $d);
|
||||
};
|
||||
// out = input
|
||||
($scope:expr, $out:ident = $input:ident) => {
|
||||
gpu!($scope, $out = cast($input))
|
||||
|
@ -332,10 +343,18 @@ macro_rules! gpu {
|
|||
out: $out.into(),
|
||||
});
|
||||
};
|
||||
// range(start, end).for_each(|scope| { ... })
|
||||
// range(start, end).for_each(|i, scope| { ... })
|
||||
($scope:expr, range($start:expr, $end:expr).for_each($arg:expr)) => {
|
||||
$crate::codegen::dialect::gpu::RangeLoop::register($scope, $start.into(), $end.into(), $arg);
|
||||
};
|
||||
// range(start, end, unroll).for_each(|i, scope| { ... })
|
||||
($scope:expr, range($start:expr, $end:expr, $unroll:expr).for_each($arg:expr)) => {
|
||||
if $unroll {
|
||||
$crate::codegen::dialect::gpu::UnrolledRangeLoop::register($scope, $start.into(), $end.into(), $arg);
|
||||
} else {
|
||||
$crate::codegen::dialect::gpu::RangeLoop::register($scope, $start.into(), $end.into(), $arg);
|
||||
}
|
||||
};
|
||||
// loop(|scope| { ... })
|
||||
($scope:expr, loop($arg:expr)) => {
|
||||
$crate::codegen::dialect::gpu::Loop::register($scope, $arg);
|
||||
|
|
|
@ -20,7 +20,8 @@ pub struct Scope {
|
|||
pub depth: u8,
|
||||
pub operations: Vec<Operation>,
|
||||
locals: Vec<Variable>,
|
||||
shared: Vec<Variable>,
|
||||
shared_memories: Vec<Variable>,
|
||||
local_arrays: Vec<Variable>,
|
||||
reads_global: Vec<(Variable, ReadingStrategy, Variable)>,
|
||||
index_offset_with_output_layout_position: Vec<usize>,
|
||||
writes_global: Vec<(Variable, Variable)>,
|
||||
|
@ -48,7 +49,8 @@ impl Scope {
|
|||
depth: 0,
|
||||
operations: Vec::new(),
|
||||
locals: Vec::new(),
|
||||
shared: Vec::new(),
|
||||
local_arrays: Vec::new(),
|
||||
shared_memories: Vec::new(),
|
||||
reads_global: Vec::new(),
|
||||
index_offset_with_output_layout_position: Vec::new(),
|
||||
writes_global: Vec::new(),
|
||||
|
@ -213,7 +215,8 @@ impl Scope {
|
|||
depth: self.depth + 1,
|
||||
operations: Vec::new(),
|
||||
locals: Vec::new(),
|
||||
shared: Vec::new(),
|
||||
shared_memories: Vec::new(),
|
||||
local_arrays: Vec::new(),
|
||||
reads_global: Vec::new(),
|
||||
index_offset_with_output_layout_position: Vec::new(),
|
||||
writes_global: Vec::new(),
|
||||
|
@ -308,7 +311,11 @@ impl Scope {
|
|||
}
|
||||
|
||||
fn new_shared_index(&self) -> u16 {
|
||||
self.shared.len() as u16
|
||||
self.shared_memories.len() as u16
|
||||
}
|
||||
|
||||
fn new_local_array_index(&self) -> u16 {
|
||||
self.local_arrays.len() as u16
|
||||
}
|
||||
|
||||
fn read_input_strategy(
|
||||
|
@ -339,7 +346,16 @@ impl Scope {
|
|||
let item = item.into();
|
||||
let index = self.new_shared_index();
|
||||
let shared_memory = Variable::SharedMemory(index, item, shared_memory_size);
|
||||
self.shared.push(shared_memory);
|
||||
self.shared_memories.push(shared_memory);
|
||||
shared_memory
|
||||
}
|
||||
|
||||
/// Create a local array of the given [item type](Item).
|
||||
pub fn create_local_array<I: Into<Item>>(&mut self, item: I, array_size: u32) -> Variable {
|
||||
let item = item.into();
|
||||
let index = self.new_local_array_index();
|
||||
let local_array = Variable::LocalArray(index, item, self.depth, array_size);
|
||||
self.local_arrays.push(local_array);
|
||||
local_array
|
||||
}
|
||||
}
|
||||
|
|
|
@ -11,6 +11,7 @@ pub enum Variable {
|
|||
LocalScalar(u16, Elem, u8),
|
||||
ConstantScalar(f64, Elem),
|
||||
SharedMemory(u16, Item, u32),
|
||||
LocalArray(u16, Item, u8, u32),
|
||||
Id,
|
||||
LocalInvocationIndex,
|
||||
LocalInvocationIdX,
|
||||
|
@ -41,6 +42,7 @@ impl Variable {
|
|||
Variable::GlobalOutputArray(idx, _) => Some(*idx),
|
||||
Variable::ConstantScalar(_, _) => None,
|
||||
Variable::SharedMemory(idx, _, _) => Some(*idx),
|
||||
Variable::LocalArray(idx, _, _, _) => Some(*idx),
|
||||
Variable::Id => None,
|
||||
Variable::LocalInvocationIndex => None,
|
||||
Variable::LocalInvocationIdX => None,
|
||||
|
@ -70,6 +72,7 @@ impl Variable {
|
|||
Variable::LocalScalar(_, elem, _) => Item::Scalar(*elem),
|
||||
Variable::ConstantScalar(_, elem) => Item::Scalar(*elem),
|
||||
Variable::SharedMemory(_, item, _) => *item,
|
||||
Variable::LocalArray(_, item, _, _) => *item,
|
||||
Variable::Id => Item::Scalar(Elem::UInt),
|
||||
Variable::Rank => Item::Scalar(Elem::UInt),
|
||||
Variable::LocalInvocationIndex => Item::Scalar(Elem::UInt),
|
||||
|
|
|
@ -132,6 +132,12 @@ impl Variable {
|
|||
item.vectorize(vectorize),
|
||||
item.vectorized_size(vectorize, *size),
|
||||
),
|
||||
Variable::LocalArray(index, item, name, size) => Variable::LocalArray(
|
||||
*index,
|
||||
item.vectorize(vectorize),
|
||||
*name,
|
||||
item.vectorized_size(vectorize, *size),
|
||||
),
|
||||
Variable::ConstantScalar(_, _) => *self,
|
||||
Variable::GlobalScalar(_, _) => *self,
|
||||
Variable::Id => *self,
|
||||
|
|
|
@ -1,24 +1,98 @@
|
|||
use crate::{tensor::JitTensor, JitElement, Runtime};
|
||||
use std::cmp::{max, min};
|
||||
|
||||
use burn_tensor::Shape;
|
||||
|
||||
use crate::{compute::WorkGroup, tensor::JitTensor, Compiler, JitElement, Runtime};
|
||||
|
||||
use super::{
|
||||
init_matmul_output, matmul_autotune, matmul_mem_coalescing,
|
||||
unpadded::matmul_tiling_2d_unpadded, vec4::matmul_tiling_2d_vec4,
|
||||
init_matmul_output, matmul_autotune, matmul_simple, matmul_tiling_2d, matmul_tiling_2d_padded,
|
||||
};
|
||||
|
||||
/// Tiling 2D parameters
///
/// Groups the number of invocations per workgroup with the block and tile sizes used by
/// the tiling 2D matmul kernels.
#[derive(Debug, Clone)]
pub struct Tiling2dConfig {
    /// Number of invocations in x
    pub grid_x: usize,
    /// Number of invocations in y
    pub grid_y: usize,
    /// Block size along the dimension of lhs (m)
    pub block_size_m: usize,
    /// Block size along the common dimension (k)
    pub block_size_k: usize,
    /// Block size along the dimension of rhs (n)
    pub block_size_n: usize,
    /// Tile size along the dimension of lhs (m)
    pub tile_size_m: usize,
    /// Tile size along the dimension of rhs (n)
    pub tile_size_n: usize,
}
|
||||
|
||||
impl Tiling2dConfig {
|
||||
#[allow(unused)]
|
||||
fn new<R: Runtime>(
|
||||
grid_x: usize,
|
||||
grid_y: usize,
|
||||
block_size_m: usize,
|
||||
block_size_k: usize,
|
||||
block_size_n: usize,
|
||||
tile_size_m: usize,
|
||||
tile_size_n: usize,
|
||||
) -> Self {
|
||||
assert!(grid_x == f32::ceil(block_size_m as f32 / tile_size_m as f32) as usize);
|
||||
assert!(grid_y == f32::ceil(block_size_n as f32 / tile_size_n as f32) as usize);
|
||||
assert!(
|
||||
block_size_k <= min(block_size_m, block_size_n),
|
||||
"Not enough invocations to fill shared memory"
|
||||
);
|
||||
assert!(
|
||||
block_size_k * max(block_size_m, block_size_n)
|
||||
<= <R::Compiler as Compiler>::max_shared_memory_size(),
|
||||
"Shared memory limit will be busted. "
|
||||
);
|
||||
assert!(
|
||||
block_size_m % tile_size_m == 0 && block_size_n % tile_size_n == 0,
|
||||
"Tile size must divide block size in m and n dimensions"
|
||||
);
|
||||
Self {
|
||||
grid_x,
|
||||
grid_y,
|
||||
block_size_m,
|
||||
block_size_k,
|
||||
block_size_n,
|
||||
tile_size_m,
|
||||
tile_size_n,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Tiling2dConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
grid_x: 16,
|
||||
grid_y: 16,
|
||||
block_size_m: 64,
|
||||
block_size_k: 32,
|
||||
block_size_n: 64,
|
||||
tile_size_m: 4,
|
||||
tile_size_n: 4,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The strategy to be used when launching a matmul kernel.
|
||||
#[derive(Default)]
|
||||
pub enum MatmulStrategy {
|
||||
/// A simple kernel will be used with memory coalescing optimization.
|
||||
Simple {
|
||||
/// Grad size x
|
||||
/// Number of invocations in x
|
||||
grid_x: usize,
|
||||
/// Grad size y
|
||||
/// Number of invocations in y
|
||||
grid_y: usize,
|
||||
},
|
||||
/// A tiling 2d kernel will be used, with support for any matrix size without padding.
|
||||
Tiling2d,
|
||||
Tiling2d(Tiling2dConfig),
|
||||
/// A tiling 2d kernel will be used, with support for any matrix size with padding.
|
||||
Tiling2dPadded,
|
||||
Tiling2dPadded(Tiling2dConfig),
|
||||
#[cfg(feature = "autotune")]
|
||||
/// Using autotune to choose the best kernel based on runtime information.
|
||||
#[default]
|
||||
|
@ -42,17 +116,56 @@ pub fn matmul<R: Runtime, E: JitElement, const D: usize>(
|
|||
match strategy {
|
||||
MatmulStrategy::Simple { grid_x, grid_y } => {
|
||||
let out = init_matmul_output(&lhs, &rhs);
|
||||
matmul_mem_coalescing(lhs, rhs, out, grid_x, grid_y)
|
||||
matmul_simple(lhs, rhs, out, grid_x, grid_y)
|
||||
}
|
||||
MatmulStrategy::Tiling2d => {
|
||||
MatmulStrategy::Tiling2d(config) => {
|
||||
let out = init_matmul_output(&lhs, &rhs);
|
||||
matmul_tiling_2d_unpadded(lhs, rhs, out)
|
||||
matmul_tiling_2d(lhs, rhs, out, config)
|
||||
}
|
||||
MatmulStrategy::Tiling2dPadded => {
|
||||
MatmulStrategy::Tiling2dPadded(config) => {
|
||||
let out = init_matmul_output(&lhs, &rhs);
|
||||
matmul_tiling_2d_vec4(lhs, rhs, out)
|
||||
matmul_tiling_2d_padded(lhs, rhs, out, config)
|
||||
}
|
||||
#[cfg(feature = "autotune")]
|
||||
MatmulStrategy::Autotune => matmul_autotune(lhs, rhs),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn simple_launch_options<const D: usize>(
|
||||
lhs_shape: &Shape<D>,
|
||||
rhs_shape: &Shape<D>,
|
||||
output_shape: &Shape<D>,
|
||||
workgroup_size_x: usize,
|
||||
workgroup_size_y: usize,
|
||||
) -> WorkGroup {
|
||||
let num_rows = lhs_shape.dims[D - 2];
|
||||
let num_cols = rhs_shape.dims[D - 1];
|
||||
|
||||
// set number of workgroups
|
||||
let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32;
|
||||
let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32;
|
||||
let mut num_iter = 1;
|
||||
for i in 0..D - 2 {
|
||||
num_iter *= output_shape.dims[i];
|
||||
}
|
||||
|
||||
WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32)
|
||||
}
|
||||
|
||||
pub(crate) fn tiling2d_launch_options<const D: usize>(
|
||||
output_shape: &Shape<D>,
|
||||
config: Tiling2dConfig,
|
||||
) -> WorkGroup {
|
||||
let num_rows = output_shape.dims[D - 2];
|
||||
let num_cols = output_shape.dims[D - 1];
|
||||
|
||||
// set number of workgroups
|
||||
let blocks_needed_in_x = f32::ceil(num_rows as f32 / config.block_size_m as f32) as u32;
|
||||
let blocks_needed_in_y = f32::ceil(num_cols as f32 / config.block_size_n as f32) as u32;
|
||||
let mut num_iter = 1;
|
||||
for i in 0..D - 2 {
|
||||
num_iter *= output_shape.dims[i];
|
||||
}
|
||||
|
||||
WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32)
|
||||
}
|
||||
|
|
|
@ -1,13 +1,22 @@
|
|||
mod base;
|
||||
mod mem_coalescing;
|
||||
mod simple;
|
||||
mod tiling2d;
|
||||
mod tiling2d_shader;
|
||||
mod tune;
|
||||
|
||||
/// Contains utilities for the matmul operation
|
||||
pub mod utils;
|
||||
|
||||
pub use base::*;
|
||||
pub use mem_coalescing::*;
|
||||
pub use tiling2d::*;
|
||||
pub use simple::*;
|
||||
pub use tune::*;
|
||||
pub use utils::*;
|
||||
|
||||
#[cfg(feature = "export_tests")]
|
||||
#[allow(missing_docs)]
|
||||
pub mod padding;
|
||||
|
||||
#[cfg(not(feature = "export_tests"))]
|
||||
mod padding;
|
||||
|
||||
pub use tiling2d::*;
|
||||
|
|
|
@ -6,15 +6,15 @@ use crate::{
|
|||
dialect::gpu, execute_dynamic, Compilation, CompilationInfo, CompilationSettings, Compiler,
|
||||
EagerHandle, InputInfo, OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
compute::WorkGroup,
|
||||
element::JitElement,
|
||||
kernel::{into_contiguous, DynamicKernelSource, SourceTemplate, WORKGROUP_DEFAULT},
|
||||
tensor::JitTensor,
|
||||
Runtime,
|
||||
};
|
||||
use burn_tensor::Shape;
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use super::simple_launch_options;
|
||||
|
||||
#[derive(new, Debug)]
|
||||
struct MatmulEagerKernel<R: Runtime> {
|
||||
workgroup_size_x: usize,
|
||||
|
@ -213,11 +213,11 @@ pub fn matmul_mem_coalescing_default<R: Runtime, E: JitElement, const D: usize>(
|
|||
rhs: JitTensor<R, E, D>,
|
||||
out: JitTensor<R, E, D>,
|
||||
) -> JitTensor<R, E, D> {
|
||||
matmul_mem_coalescing::<R, E, D>(lhs, rhs, out, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT)
|
||||
matmul_simple::<R, E, D>(lhs, rhs, out, WORKGROUP_DEFAULT, WORKGROUP_DEFAULT)
|
||||
}
|
||||
|
||||
/// Matrix multiplication using memory coalescing algorithm with custom workgroup sizes
|
||||
pub fn matmul_mem_coalescing<R: Runtime, E: JitElement, const D: usize>(
|
||||
pub fn matmul_simple<R: Runtime, E: JitElement, const D: usize>(
|
||||
lhs: JitTensor<R, E, D>,
|
||||
rhs: JitTensor<R, E, D>,
|
||||
out: JitTensor<R, E, D>,
|
||||
|
@ -228,7 +228,7 @@ pub fn matmul_mem_coalescing<R: Runtime, E: JitElement, const D: usize>(
|
|||
let lhs = into_contiguous(lhs);
|
||||
let rhs = into_contiguous(rhs);
|
||||
|
||||
let workgroup = launch_options(
|
||||
let workgroup = simple_launch_options(
|
||||
&lhs.shape,
|
||||
&rhs.shape,
|
||||
&out.shape,
|
||||
|
@ -242,9 +242,8 @@ pub fn matmul_mem_coalescing<R: Runtime, E: JitElement, const D: usize>(
|
|||
&[
|
||||
EagerHandle::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
|
||||
EagerHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
|
||||
EagerHandle::new(&out.handle, &out.strides, &out.shape.dims),
|
||||
],
|
||||
&[],
|
||||
&[EagerHandle::new(&out.handle, &out.strides, &out.shape.dims)],
|
||||
None,
|
||||
kernel,
|
||||
WorkgroupLaunch::Custom(workgroup),
|
||||
|
@ -253,24 +252,3 @@ pub fn matmul_mem_coalescing<R: Runtime, E: JitElement, const D: usize>(
|
|||
|
||||
out
|
||||
}
|
||||
|
||||
fn launch_options<const D: usize>(
|
||||
lhs_shape: &Shape<D>,
|
||||
rhs_shape: &Shape<D>,
|
||||
output_shape: &Shape<D>,
|
||||
workgroup_size_x: usize,
|
||||
workgroup_size_y: usize,
|
||||
) -> WorkGroup {
|
||||
let num_rows = lhs_shape.dims[D - 2];
|
||||
let num_cols = rhs_shape.dims[D - 1];
|
||||
|
||||
// set number of workgroups
|
||||
let blocks_needed_in_x = f32::ceil(num_rows as f32 / workgroup_size_x as f32) as u32;
|
||||
let blocks_needed_in_y = f32::ceil(num_cols as f32 / workgroup_size_y as f32) as u32;
|
||||
let mut num_iter = 1;
|
||||
for i in 0..D - 2 {
|
||||
num_iter *= output_shape.dims[i];
|
||||
}
|
||||
|
||||
WorkGroup::new(blocks_needed_in_x, blocks_needed_in_y, num_iter as u32)
|
||||
}
|
|
@ -0,0 +1,190 @@
|
|||
use burn_tensor::{Element, Shape};
|
||||
|
||||
use crate::{
|
||||
codegen::{
|
||||
dialect::gpu, Compilation, CompilationInfo, CompilationSettings, Compiler, EagerHandle,
|
||||
Execution, InputInfo, OutputInfo, WorkgroupLaunch,
|
||||
},
|
||||
element::JitElement,
|
||||
kernel::{into_contiguous, DynamicKernelSource, SourceTemplate},
|
||||
tensor::JitTensor,
|
||||
Runtime,
|
||||
};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use super::{
|
||||
padding::{crop, pad_round, PaddingOutput},
|
||||
shape_out, tiling2d_launch_options,
|
||||
tiling2d_shader::MatmulTiling2dShader,
|
||||
Tiling2dConfig,
|
||||
};
|
||||
|
||||
/// Marker type parameterized by the element type `E`; its only field is a `PhantomData`.
#[derive(new, Debug)]
|
||||
struct MatmulTiling2d<E: JitElement> {
|
||||
_elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
/// Kernel source for the tiling 2D matmul, tied to a runtime `R`.
#[derive(new, Debug)]
|
||||
struct MatmulTiling2dEagerKernel<R: Runtime> {
|
||||
// Tiling parameters (grid, block and tile sizes) of the kernel.
config: Tiling2dConfig,
|
||||
// Forwarded as-is to the shader when the source is generated.
bounds_check_required: bool,
|
||||
_runtime: PhantomData<R>,
|
||||
}
|
||||
|
||||
impl<R: Runtime> DynamicKernelSource for MatmulTiling2dEagerKernel<R> {
|
||||
fn source(&self) -> SourceTemplate {
|
||||
let mut scope = gpu::Scope::root();
|
||||
let lhs = gpu::Variable::GlobalInputArray(0, gpu::Elem::Float.into());
|
||||
let rhs = gpu::Variable::GlobalInputArray(1, gpu::Elem::Float.into());
|
||||
let out = gpu::Variable::GlobalOutputArray(0, gpu::Elem::Float.into());
|
||||
|
||||
scope.write_global_custom(out);
|
||||
|
||||
MatmulTiling2dShader {
|
||||
variables: gpu::BinaryOperator { lhs, rhs, out },
|
||||
config: self.config.clone(),
|
||||
bounds_check_required: self.bounds_check_required,
|
||||
unroll: true,
|
||||
}
|
||||
.expand(&mut scope);
|
||||
|
||||
let lhs = InputInfo::Array {
|
||||
item: gpu::Elem::Float.into(),
|
||||
visibility: gpu::Visibility::Read,
|
||||
};
|
||||
let rhs = InputInfo::Array {
|
||||
item: gpu::Elem::Float.into(),
|
||||
visibility: gpu::Visibility::Read,
|
||||
};
|
||||
let out = OutputInfo::Array {
|
||||
item: gpu::Elem::Float.into(),
|
||||
};
|
||||
|
||||
let info = CompilationInfo {
|
||||
inputs: vec![lhs, rhs],
|
||||
outputs: vec![out],
|
||||
scope,
|
||||
};
|
||||
|
||||
let settings = CompilationSettings::default().workgroup_size(gpu::WorkgroupSize::new(
|
||||
self.config.grid_x as u32,
|
||||
self.config.grid_y as u32,
|
||||
1,
|
||||
));
|
||||
let shader = Compilation::new(info).compile(settings);
|
||||
let shader = <R::Compiler as Compiler>::compile(shader);
|
||||
SourceTemplate::new(shader.to_string())
|
||||
}
|
||||
|
||||
fn id(&self) -> String {
|
||||
format!(
|
||||
"{:?}config={:?}boundcheck={:?}",
|
||||
core::any::TypeId::of::<Self>(),
|
||||
self.config,
|
||||
self.bounds_check_required
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/// Matrix multiplication using tiling 2d algorithm with
|
||||
/// vec4 primitive on both lhs and rhs, with no padding needed
|
||||
pub fn matmul_tiling_2d<R: Runtime, E: JitElement + Element, const D: usize>(
|
||||
lhs: JitTensor<R, E, D>,
|
||||
rhs: JitTensor<R, E, D>,
|
||||
out: JitTensor<R, E, D>,
|
||||
config: Tiling2dConfig,
|
||||
) -> JitTensor<R, E, D> {
|
||||
let bounds_check_required = check_bound_requirement(&lhs.shape, &rhs.shape, &config);
|
||||
|
||||
let kernel = MatmulTiling2dEagerKernel::<R>::new(config.clone(), bounds_check_required);
|
||||
let client = lhs.client.clone();
|
||||
|
||||
let lhs = match lhs.batch_swapped_with_row_col() {
|
||||
true => into_contiguous(lhs),
|
||||
false => lhs,
|
||||
};
|
||||
let rhs = match rhs.batch_swapped_with_row_col() {
|
||||
true => into_contiguous(rhs),
|
||||
false => rhs,
|
||||
};
|
||||
|
||||
Execution::start(kernel, client)
|
||||
.inputs(&[
|
||||
EagerHandle::<R>::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
|
||||
EagerHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
|
||||
])
|
||||
.outputs(&[EagerHandle::new(&out.handle, &out.strides, &out.shape.dims)])
|
||||
.execute(WorkgroupLaunch::Custom(tiling2d_launch_options(
|
||||
&out.shape, config,
|
||||
)));
|
||||
|
||||
out
|
||||
}
|
||||
|
||||
/// Matrix multiplication using tiling 2d algorithm with padding needed
|
||||
pub fn matmul_tiling_2d_padded<R: Runtime, E: JitElement + Element, const D: usize>(
|
||||
lhs: JitTensor<R, E, D>,
|
||||
rhs: JitTensor<R, E, D>,
|
||||
out: JitTensor<R, E, D>,
|
||||
config: Tiling2dConfig,
|
||||
) -> JitTensor<R, E, D> {
|
||||
let kernel = MatmulTiling2dEagerKernel::<R>::new(config.clone(), false);
|
||||
let client = lhs.client.clone();
|
||||
|
||||
// A tensor may need to be padded, in which case it will implicitly become contiguous
|
||||
// If not needed, it is only turned into contiguous if some batch dim has been swapped with row or col dim.
|
||||
// If batches were swapped among themselves, or if the last two dims are transposed, the underlying
|
||||
// kernel handles it without needing to turn it into contiguous.
|
||||
let round_lhs = pad_round::<R, E, D>(lhs, config.block_size_m, config.block_size_k);
|
||||
let lhs = match round_lhs {
|
||||
PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => {
|
||||
into_contiguous(tensor)
|
||||
}
|
||||
_ => round_lhs.into_tensor(),
|
||||
};
|
||||
let round_rhs = pad_round::<R, E, D>(rhs, config.block_size_k, config.block_size_n);
|
||||
let rhs = match round_rhs {
|
||||
PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => {
|
||||
into_contiguous(tensor)
|
||||
}
|
||||
_ => round_rhs.into_tensor(),
|
||||
};
|
||||
|
||||
let rounded_output_shape = shape_out(&lhs, &rhs);
|
||||
|
||||
let num_elems = rounded_output_shape.num_elements();
|
||||
let buffer = client.empty(num_elems * core::mem::size_of::<E>());
|
||||
let rounded_output = JitTensor::new(
|
||||
rhs.client.clone(),
|
||||
rhs.device.clone(),
|
||||
rounded_output_shape.clone(),
|
||||
buffer,
|
||||
);
|
||||
|
||||
Execution::start(kernel, client)
|
||||
.inputs(&[
|
||||
EagerHandle::<R>::new(&lhs.handle, &lhs.strides, &lhs.shape.dims),
|
||||
EagerHandle::new(&rhs.handle, &rhs.strides, &rhs.shape.dims),
|
||||
])
|
||||
.outputs(&[EagerHandle::new(
|
||||
&rounded_output.handle,
|
||||
&rounded_output.strides,
|
||||
&rounded_output.shape.dims,
|
||||
)])
|
||||
.execute(WorkgroupLaunch::Custom(tiling2d_launch_options(
|
||||
&rounded_output.shape,
|
||||
config,
|
||||
)));
|
||||
|
||||
crop(rounded_output, out)
|
||||
}
|
||||
|
||||
fn check_bound_requirement<const D: usize>(
|
||||
lhs_shape: &Shape<D>,
|
||||
rhs_shape: &Shape<D>,
|
||||
config: &Tiling2dConfig,
|
||||
) -> bool {
|
||||
lhs_shape.dims[D - 2] % config.block_size_m != 0
|
||||
|| lhs_shape.dims[D - 1] % config.block_size_k != 0
|
||||
|| rhs_shape.dims[D - 1] % config.block_size_n != 0
|
||||
}
|
|
@ -1,91 +0,0 @@
|
|||
use super::padding::{crop, pad_round, PaddingOutput};
|
||||
use crate::{
|
||||
compute::{DynamicKernel, WorkGroup},
|
||||
element::JitElement,
|
||||
kernel::{build_info, into_contiguous, matmul::utils::shape_out, DynamicKernelSource},
|
||||
ops::numeric::empty_device,
|
||||
tensor::JitTensor,
|
||||
Runtime,
|
||||
};
|
||||
use burn_compute::server::Handle;
|
||||
use burn_tensor::Shape;
|
||||
|
||||
/// Block size used for the rows of lhs / of the output (registered as `b_m` in the templates).
pub(crate) const B_M: usize = 64;
|
||||
/// Block size used for the columns of rhs / of the output (registered as `b_n` in the templates).
pub(crate) const B_N: usize = 64;
|
||||
/// Block size used for the dimension shared by lhs and rhs (registered as `b_k` in the templates).
pub(crate) const B_K: usize = 32;
|
||||
/// Workgroup size registered for both x and y in the templates.
pub(crate) const WORKGROUP_SIZE: usize = 16;
|
||||
|
||||
pub(super) fn make_workgroup<const D: usize>(output_shape: &Shape<D>) -> WorkGroup {
|
||||
let num_blocks_x = f32::ceil(output_shape.dims[D - 2] as f32 / B_M as f32) as u32;
|
||||
let num_blocks_y = f32::ceil(output_shape.dims[D - 1] as f32 / B_N as f32) as u32;
|
||||
let mut num_blocks_z = 1;
|
||||
for i in 0..D - 2 {
|
||||
num_blocks_z *= output_shape.dims[i];
|
||||
}
|
||||
|
||||
WorkGroup::new(num_blocks_x, num_blocks_y, num_blocks_z as u32)
|
||||
}
|
||||
|
||||
pub(super) fn make_info_handle<R: Runtime, E: JitElement, const D: usize>(
|
||||
lhs: &JitTensor<R, E, D>,
|
||||
rhs: &JitTensor<R, E, D>,
|
||||
output: &JitTensor<R, E, D>,
|
||||
) -> Handle<R::Server> {
|
||||
let info = build_info(&[lhs, rhs, output]);
|
||||
rhs.client.create(bytemuck::cast_slice(&info))
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(super) fn matmul_tiling_2d_launch<
|
||||
R: Runtime,
|
||||
E: JitElement,
|
||||
const D: usize,
|
||||
K: DynamicKernelSource + 'static,
|
||||
>(
|
||||
lhs: JitTensor<R, E, D>,
|
||||
rhs: JitTensor<R, E, D>,
|
||||
output: JitTensor<R, E, D>,
|
||||
kernel: K,
|
||||
) -> JitTensor<R, E, D> {
|
||||
// A tensor may need to be padded, in which case it will implicitly become contiguous
|
||||
// If not needed, it is only turned into contiguous if some batch dim has been swapped with row or col dim.
|
||||
// If batches were swapped among themselves, or if the last two dims are transposed, the underlying
|
||||
// kernel handles it without needing to turn it into contiguous.
|
||||
let round_lhs = pad_round::<R, E, D>(lhs, B_M, B_K);
|
||||
let lhs = match round_lhs {
|
||||
PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => {
|
||||
into_contiguous(tensor)
|
||||
}
|
||||
_ => round_lhs.into_tensor(),
|
||||
};
|
||||
let round_rhs = pad_round::<R, E, D>(rhs, B_K, B_N);
|
||||
let rhs = match round_rhs {
|
||||
PaddingOutput::Unchanged(tensor) if tensor.batch_swapped_with_row_col() => {
|
||||
into_contiguous(tensor)
|
||||
}
|
||||
_ => round_rhs.into_tensor(),
|
||||
};
|
||||
|
||||
let rounded_output_shape = shape_out(&lhs, &rhs);
|
||||
|
||||
let rounded_output = empty_device(
|
||||
rhs.client.clone(),
|
||||
rhs.device.clone(),
|
||||
rounded_output_shape.clone(),
|
||||
);
|
||||
|
||||
let workgroup = make_workgroup(&rounded_output_shape);
|
||||
let info_handle = make_info_handle(&lhs, &rhs, &rounded_output);
|
||||
|
||||
lhs.client.execute(
|
||||
Box::new(DynamicKernel::new(kernel, workgroup)),
|
||||
&[
|
||||
&lhs.handle,
|
||||
&rhs.handle,
|
||||
&rounded_output.handle,
|
||||
&info_handle,
|
||||
],
|
||||
);
|
||||
|
||||
crop(rounded_output, output)
|
||||
}
|
|
@ -1,14 +0,0 @@
|
|||
mod base;
|
||||
|
||||
#[cfg(feature = "export_tests")]
|
||||
#[allow(missing_docs)]
|
||||
pub mod padding;
|
||||
|
||||
#[cfg(not(feature = "export_tests"))]
|
||||
mod padding;
|
||||
|
||||
/// WGSL vec4 primitives are used on left and right hand tensor,
|
||||
/// padding is avoided through the use of conditions in the kernel
|
||||
pub mod unpadded;
|
||||
/// WGSL vec4 primitives are used on left and right hand tensor
|
||||
pub mod vec4;
|
|
@ -1,74 +0,0 @@
|
|||
use burn_tensor::Element;
|
||||
|
||||
use crate::{
|
||||
compute::DynamicKernel,
|
||||
element::JitElement,
|
||||
kernel::{into_contiguous, DynamicKernelSource, SourceTemplate, StaticKernelSource},
|
||||
tensor::JitTensor,
|
||||
Runtime,
|
||||
};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
use crate::kernel_wgsl;
|
||||
|
||||
use super::base::{make_info_handle, make_workgroup, B_K, B_M, B_N, WORKGROUP_SIZE};
|
||||
|
||||
kernel_wgsl!(
|
||||
MatmulTiling2DUnpaddedRaw,
|
||||
"../../../template/matmul/blocktiling_2d/unpadded.wgsl"
|
||||
);
|
||||
|
||||
/// Kernel source for the unpadded tiling 2D matmul, parameterized by the element type `E`.
#[derive(new, Debug)]
|
||||
struct MatmulTiling2DUnpadded<E: JitElement> {
|
||||
_elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
impl<E: JitElement> DynamicKernelSource for MatmulTiling2DUnpadded<E> {
|
||||
fn source(&self) -> SourceTemplate {
|
||||
MatmulTiling2DUnpaddedRaw::source()
|
||||
.register("b_m", B_M.to_string())
|
||||
.register("b_n", B_N.to_string())
|
||||
.register("b_k", B_K.to_string())
|
||||
.register("bm_x_bk_4", (B_M * B_K / 4).to_string())
|
||||
.register("bk_x_bn_4", (B_K * B_N / 4).to_string())
|
||||
.register("workgroup_size_x", WORKGROUP_SIZE.to_string())
|
||||
.register("workgroup_size_y", WORKGROUP_SIZE.to_string())
|
||||
.register("workgroup_size_z", "1".to_string())
|
||||
.register("elem", E::type_name())
|
||||
.register("int", "i32")
|
||||
}
|
||||
|
||||
fn id(&self) -> String {
|
||||
std::format!("{:?}", self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Matrix multiplication using tiling 2d algorithm with
|
||||
/// vec4 primitive on both lhs and rhs, with no padding needed
|
||||
pub fn matmul_tiling_2d_unpadded<R: Runtime, E: JitElement + Element, const D: usize>(
|
||||
lhs: JitTensor<R, E, D>,
|
||||
rhs: JitTensor<R, E, D>,
|
||||
out: JitTensor<R, E, D>,
|
||||
) -> JitTensor<R, E, D> {
|
||||
let lhs = match lhs.batch_swapped_with_row_col() {
|
||||
true => into_contiguous(lhs),
|
||||
false => lhs,
|
||||
};
|
||||
let rhs = match rhs.batch_swapped_with_row_col() {
|
||||
true => into_contiguous(rhs),
|
||||
false => rhs,
|
||||
};
|
||||
|
||||
let workgroup = make_workgroup(&out.shape);
|
||||
let info_handle = make_info_handle(&lhs, &rhs, &out);
|
||||
|
||||
lhs.client.execute(
|
||||
Box::new(DynamicKernel::new(
|
||||
MatmulTiling2DUnpadded::<E>::new(),
|
||||
workgroup,
|
||||
)),
|
||||
&[&lhs.handle, &rhs.handle, &out.handle, &info_handle],
|
||||
);
|
||||
|
||||
out
|
||||
}
|
|
@ -1,49 +0,0 @@
|
|||
use super::base::{matmul_tiling_2d_launch, B_K, B_M, B_N, WORKGROUP_SIZE};
|
||||
use crate::{
|
||||
element::JitElement,
|
||||
kernel::{DynamicKernelSource, SourceTemplate, StaticKernelSource},
|
||||
tensor::JitTensor,
|
||||
};
|
||||
use crate::{kernel_wgsl, Runtime};
|
||||
use std::marker::PhantomData;
|
||||
|
||||
kernel_wgsl!(
|
||||
MatmulTiling2Dvec4Raw,
|
||||
"../../../template/matmul/blocktiling_2d/vec4.wgsl"
|
||||
);
|
||||
|
||||
/// Kernel source for the vec4 tiling 2D matmul, parameterized by the element type `E`.
#[derive(new, Debug)]
|
||||
struct MatmulTiling2Dvec4<E: JitElement> {
|
||||
_elem: PhantomData<E>,
|
||||
}
|
||||
|
||||
impl<E: JitElement> DynamicKernelSource for MatmulTiling2Dvec4<E> {
|
||||
fn source(&self) -> SourceTemplate {
|
||||
MatmulTiling2Dvec4Raw::source()
|
||||
.register("b_m", B_M.to_string())
|
||||
.register("b_n", B_N.to_string())
|
||||
.register("b_k", B_K.to_string())
|
||||
.register("bm_x_bk_4", (B_M * B_K / 4).to_string())
|
||||
.register("bk_x_bn_4", (B_K * B_N / 4).to_string())
|
||||
.register("workgroup_size_x", WORKGROUP_SIZE.to_string())
|
||||
.register("workgroup_size_y", WORKGROUP_SIZE.to_string())
|
||||
.register("workgroup_size_z", "1".to_string())
|
||||
.register("elem", E::type_name())
|
||||
.register("int", "i32")
|
||||
}
|
||||
|
||||
fn id(&self) -> String {
|
||||
std::format!("{:?}", self)
|
||||
}
|
||||
}
|
||||
|
||||
/// Matrix multiplication using tiling 2d algorithm with
|
||||
/// vec4 primitive on both lhs and rhs
|
||||
pub fn matmul_tiling_2d_vec4<R: Runtime, E: JitElement, const D: usize>(
|
||||
lhs: JitTensor<R, E, D>,
|
||||
rhs: JitTensor<R, E, D>,
|
||||
out: JitTensor<R, E, D>,
|
||||
) -> JitTensor<R, E, D> {
|
||||
let kernel = MatmulTiling2Dvec4::<E>::new();
|
||||
matmul_tiling_2d_launch::<R, _, D, _>(lhs, rhs, out, kernel)
|
||||
}
|
|
@ -0,0 +1,68 @@
|
|||
use crate::gpu::{gpu, BinaryOperator, Scope, Synchronization, Variable};
|
||||
|
||||
use crate::kernel::matmul::tiling2d_shader::{
|
||||
computation_loop, gather_shader_information, load_shared_memory, write_to_output,
|
||||
};
|
||||
use crate::kernel::matmul::Tiling2dConfig;
|
||||
|
||||
/// Description of the tiling 2D matmul shader; consumed by `expand`.
pub(crate) struct MatmulTiling2dShader {
|
||||
/// The `lhs`, `rhs` and `out` variables of the matmul.
pub variables: BinaryOperator,
|
||||
/// Tiling parameters (grid, block and tile sizes).
pub config: Tiling2dConfig,
|
||||
/// NOTE(review): presumably enables guards against out-of-bounds accesses in the
/// generated shader — confirm in the helpers that read it.
pub bounds_check_required: bool,
|
||||
/// NOTE(review): presumably selects unrolled loops in the generated shader — confirm
/// in the helpers that read it.
pub unroll: bool,
|
||||
}
|
||||
|
||||
/// Variables shared between the stages of the tiling 2D shader expansion.
///
/// NOTE(review): the meaning of each field is inferred from its name only (offsets,
/// dimensions, strides, shared memories, registers, results) — confirm against
/// `gather_shader_information`, which builds this state.
pub(crate) struct Tiling2dState {
|
||||
pub n_loops: Variable,
|
||||
pub k: Variable,
|
||||
pub lhs: Variable,
|
||||
pub rhs: Variable,
|
||||
pub out: Variable,
|
||||
pub offset_lhs: Variable,
|
||||
pub offset_rhs: Variable,
|
||||
pub offset_output: Variable,
|
||||
pub row: Variable,
|
||||
pub col: Variable,
|
||||
pub dim_m: Variable,
|
||||
pub dim_k: Variable,
|
||||
pub dim_n: Variable,
|
||||
pub thread_col: Variable,
|
||||
pub thread_row: Variable,
|
||||
pub shared_lhs: Variable,
|
||||
pub shared_rhs: Variable,
|
||||
pub register_m: Variable,
|
||||
pub register_n: Variable,
|
||||
pub results: Variable,
|
||||
pub lhs_stride_col: Variable,
|
||||
pub lhs_stride_row: Variable,
|
||||
pub rhs_stride_col: Variable,
|
||||
pub rhs_stride_row: Variable,
|
||||
pub out_stride_row: Variable,
|
||||
pub out_stride_col: Variable,
|
||||
}
|
||||
|
||||
impl MatmulTiling2dShader {
|
||||
pub(crate) fn expand(self, scope: &mut Scope) {
|
||||
let shader_state = gather_shader_information(scope, &self);
|
||||
|
||||
let block_size_k: Variable = self.config.block_size_k.into();
|
||||
gpu!(
|
||||
scope,
|
||||
range(0u32, shader_state.n_loops).for_each(|i, scope| {
|
||||
// From 0 to K with steps block_size_k
|
||||
let k = shader_state.k;
|
||||
gpu!(scope, k = i * block_size_k);
|
||||
|
||||
load_shared_memory(scope, &self, &shader_state);
|
||||
|
||||
scope.register(Synchronization::WorkgroupBarrier);
|
||||
|
||||
computation_loop(scope, &self, &shader_state);
|
||||
|
||||
scope.register(Synchronization::WorkgroupBarrier);
|
||||
})
|
||||
);
|
||||
|
||||
write_to_output(scope, &self, &shader_state);
|
||||
}
|
||||
}
|
|
@ -0,0 +1,82 @@
|
|||
use crate::gpu::{gpu, Elem, Scope, Variable};
|
||||
|
||||
use super::{MatmulTiling2dShader, Tiling2dState};
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn computation_loop(
|
||||
scope: &mut Scope,
|
||||
shader: &MatmulTiling2dShader,
|
||||
shader_state: &Tiling2dState,
|
||||
) {
|
||||
let thread_col = shader_state.thread_col;
|
||||
let thread_row = shader_state.thread_row;
|
||||
let shared_lhs = shader_state.shared_lhs;
|
||||
let shared_rhs = shader_state.shared_rhs;
|
||||
let register_m = shader_state.register_m;
|
||||
let register_n = shader_state.register_n;
|
||||
let results = shader_state.results;
|
||||
|
||||
let block_size_k: Variable = shader.config.block_size_k.into();
|
||||
let block_size_n: Variable = shader.config.block_size_n.into();
|
||||
let elem = results.item().elem();
|
||||
|
||||
let lhs_sm_position = scope.create_local(Elem::UInt);
|
||||
let rhs_sm_position = scope.create_local(Elem::UInt);
|
||||
|
||||
let registered_m = scope.create_local(elem);
|
||||
let registered_n = scope.create_local(elem);
|
||||
|
||||
let multiplied = scope.create_local(elem);
|
||||
let results_position = scope.create_local(Elem::UInt);
|
||||
let results_before = scope.create_local(elem);
|
||||
let results_after = scope.create_local(elem);
|
||||
|
||||
gpu!(
|
||||
scope,
|
||||
range(0u32, shader.config.block_size_k as u32, shader.unroll).for_each(
|
||||
|dot_index, scope| {
|
||||
// Load a subcolumn of values from lhs
|
||||
gpu!(scope, lhs_sm_position = thread_row / 4u32);
|
||||
gpu!(scope, lhs_sm_position *= block_size_k);
|
||||
gpu!(scope, lhs_sm_position += dot_index);
|
||||
gpu!(scope, register_m = shared_lhs[lhs_sm_position]);
|
||||
|
||||
// Load a subrow of values from rhs
|
||||
gpu!(scope, rhs_sm_position = dot_index * block_size_n);
|
||||
gpu!(scope, rhs_sm_position += thread_col);
|
||||
gpu!(scope, rhs_sm_position = rhs_sm_position / 4u32);
|
||||
gpu!(scope, register_n = shared_rhs[rhs_sm_position]);
|
||||
|
||||
gpu!(
|
||||
scope,
|
||||
range(0u32, shader.config.tile_size_m as u32, shader.unroll).for_each(
|
||||
|res_idx_m, scope| {
|
||||
gpu!(
|
||||
scope,
|
||||
range(0u32, shader.config.tile_size_n as u32, shader.unroll)
|
||||
.for_each(|res_idx_n, scope| {
|
||||
gpu!(scope, registered_m = register_m[res_idx_m]);
|
||||
gpu!(scope, registered_n = register_n[res_idx_n]);
|
||||
|
||||
gpu!(scope, multiplied = registered_m * registered_n);
|
||||
|
||||
gpu!(
|
||||
scope,
|
||||
results_position =
|
||||
res_idx_m * shader.config.tile_size_n
|
||||
);
|
||||
gpu!(scope, results_position += res_idx_n);
|
||||
|
||||
gpu!(scope, results_before = results[results_position]);
|
||||
gpu!(scope, results_after = results_before + multiplied);
|
||||
|
||||
gpu!(scope, results[results_position] = results_after);
|
||||
})
|
||||
);
|
||||
}
|
||||
)
|
||||
);
|
||||
}
|
||||
)
|
||||
);
|
||||
}
|
|
@ -0,0 +1,278 @@
|
|||
use crate::gpu::{gpu, Elem, Scope, Variable};
|
||||
|
||||
use super::{MatmulTiling2dShader, Tiling2dState};
|
||||
|
||||
/// Selects which matmul operand a shared memory load works on.
enum InputIdentifier {
    /// The left-hand side input.
    Lhs,
    /// The right-hand side input.
    Rhs,
}
|
||||
|
||||
pub(crate) fn load_shared_memory(
|
||||
scope: &mut Scope,
|
||||
shader: &MatmulTiling2dShader,
|
||||
shader_state: &Tiling2dState,
|
||||
) {
|
||||
if shader.bounds_check_required {
|
||||
load_shared_memory_with_bound_check(scope, shader, shader_state, InputIdentifier::Lhs);
|
||||
load_shared_memory_with_bound_check(scope, shader, shader_state, InputIdentifier::Rhs);
|
||||
} else {
|
||||
load_shared_memory_no_bound_check(scope, shader, shader_state, InputIdentifier::Lhs);
|
||||
load_shared_memory_no_bound_check(scope, shader, shader_state, InputIdentifier::Rhs);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn load_shared_memory_with_bound_check(
|
||||
scope: &mut Scope,
|
||||
shader: &MatmulTiling2dShader,
|
||||
shader_state: &Tiling2dState,
|
||||
input_identifier: InputIdentifier,
|
||||
) {
|
||||
let (
|
||||
input,
|
||||
input_offset,
|
||||
shared_memory,
|
||||
thread_idx_1,
|
||||
thread_idx_2,
|
||||
stride_1,
|
||||
stride_2,
|
||||
dim,
|
||||
pos_in_dim,
|
||||
) = match input_identifier {
|
||||
InputIdentifier::Lhs => (
|
||||
shader_state.lhs,
|
||||
shader_state.offset_lhs,
|
||||
shader_state.shared_lhs,
|
||||
shader_state.thread_col,
|
||||
shader_state.thread_row,
|
||||
shader_state.lhs_stride_col,
|
||||
shader_state.lhs_stride_row,
|
||||
shader_state.dim_m,
|
||||
shader_state.row,
|
||||
),
|
||||
InputIdentifier::Rhs => (
|
||||
shader_state.rhs,
|
||||
shader_state.offset_rhs,
|
||||
shader_state.shared_rhs,
|
||||
shader_state.thread_row,
|
||||
shader_state.thread_col,
|
||||
shader_state.rhs_stride_row,
|
||||
shader_state.rhs_stride_col,
|
||||
shader_state.dim_n,
|
||||
shader_state.col,
|
||||
),
|
||||
};
|
||||
let k = shader_state.k;
|
||||
let dim_k = shader_state.dim_k;
|
||||
|
||||
// How close is the thread to the end of the matrix.
|
||||
// If < 4 then it is an edge case
|
||||
let remain = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, remain = dim - pos_in_dim);
|
||||
|
||||
let block_size_k: Variable = shader.config.block_size_k.into();
|
||||
let block_size_n: Variable = shader.config.block_size_n.into();
|
||||
let elem = input.item().elem();
|
||||
|
||||
let current = scope.create_local(Elem::UInt);
|
||||
let aligned_with_shared_memory = scope.create_local(Elem::Bool);
|
||||
let sm_position = scope.create_local(Elem::UInt);
|
||||
let within_input = scope.create_local(Elem::Bool);
|
||||
let current_with_k = scope.create_local(Elem::UInt);
|
||||
let remain_at_least_1 = scope.create_local(Elem::Bool);
|
||||
let read_condition = scope.create_local(Elem::Bool);
|
||||
let val_vec4 = scope.create_local(shared_memory.item());
|
||||
|
||||
let tmp = scope.create_local(Elem::UInt);
|
||||
let position_0 = scope.create_local(Elem::UInt);
|
||||
let position_1 = scope.create_local(Elem::UInt);
|
||||
let position_2 = scope.create_local(Elem::UInt);
|
||||
let position_3 = scope.create_local(Elem::UInt);
|
||||
let remain_n = scope.create_local(Elem::Bool);
|
||||
|
||||
let val_0 = scope.create_local(elem);
|
||||
let val_1 = scope.create_local(elem);
|
||||
let val_2 = scope.create_local(elem);
|
||||
let val_3 = scope.create_local(elem);
|
||||
let zero: Variable = 0u32.into();
|
||||
|
||||
gpu!(
|
||||
scope,
|
||||
range(0_u32, 4u32, shader.unroll).for_each(|j, scope| {
|
||||
gpu!(scope, current = thread_idx_1 + j);
|
||||
|
||||
gpu!(scope, aligned_with_shared_memory = current < block_size_k);
|
||||
|
||||
// To avoid overwriting following row in shared memory
|
||||
gpu!(scope, if(aligned_with_shared_memory).then(|scope|{
|
||||
|
||||
// Position in shared memory
|
||||
match input_identifier {
|
||||
InputIdentifier::Lhs => {
|
||||
gpu!(scope, sm_position = thread_idx_2 / 4u32);
|
||||
gpu!(scope, sm_position *= block_size_k);
|
||||
gpu!(scope, sm_position += current);
|
||||
},
|
||||
InputIdentifier::Rhs => {
|
||||
gpu!(scope, sm_position = current * block_size_n);
|
||||
gpu!(scope, sm_position += thread_idx_2);
|
||||
gpu!(scope, sm_position = sm_position / 4u32);
|
||||
}
|
||||
}
|
||||
|
||||
// To pad with zeros if outside lhs
|
||||
gpu!(scope, current_with_k = current + k);
|
||||
gpu!(scope, within_input = current_with_k < dim_k);
|
||||
gpu!(scope, remain_at_least_1 = remain >= 1u32);
|
||||
gpu!(scope, read_condition = within_input && remain_at_least_1);
|
||||
|
||||
gpu!(scope, if(read_condition).then(|scope| {
|
||||
gpu!(scope, position_0 = k + current);
|
||||
gpu!(scope, position_0 *= stride_1);
|
||||
gpu!(scope, tmp = thread_idx_2 * stride_2);
|
||||
gpu!(scope, position_0 += tmp);
|
||||
gpu!(scope, position_0 += input_offset);
|
||||
gpu!(scope, position_1 = position_0 + stride_2);
|
||||
gpu!(scope, position_2 = position_1 + stride_2);
|
||||
gpu!(scope, position_3 = position_2 + stride_2);
|
||||
|
||||
gpu!(scope, remain_n = remain >= 4u32);
|
||||
gpu!(scope, if(remain_n).then(|scope|{
|
||||
gpu!(scope, val_0 = input[position_0]);
|
||||
gpu!(scope, val_1 = input[position_1]);
|
||||
gpu!(scope, val_2 = input[position_2]);
|
||||
gpu!(scope, val_3 = input[position_3]);
|
||||
|
||||
}).else(|scope|{
|
||||
gpu!(scope, remain_n = remain == 3u32);
|
||||
gpu!(scope, if(remain_n).then(|scope|{
|
||||
gpu!(scope, val_0 = input[position_0]);
|
||||
gpu!(scope, val_1 = input[position_1]);
|
||||
gpu!(scope, val_2 = input[position_2]);
|
||||
gpu!(scope, val_3 = zero);
|
||||
|
||||
}).else(|scope|{
|
||||
gpu!(scope, remain_n = remain == 2u32);
|
||||
gpu!(scope, if(remain_n).then(|scope|{
|
||||
gpu!(scope, val_0 = input[position_0]);
|
||||
gpu!(scope, val_1 = input[position_1]);
|
||||
gpu!(scope, val_2 = zero);
|
||||
gpu!(scope, val_3 = zero);
|
||||
|
||||
}).else(|scope|{
|
||||
gpu!(scope, remain_n = remain == 1u32);
|
||||
gpu!(scope, if(remain_n).then(|scope|{
|
||||
gpu!(scope, val_0 = input[position_0]);
|
||||
gpu!(scope, val_1 = zero);
|
||||
gpu!(scope, val_2 = zero);
|
||||
gpu!(scope, val_3 = zero);
|
||||
}));
|
||||
}));
|
||||
}));
|
||||
}));
|
||||
|
||||
gpu!(scope, val_vec4 = vec4(val_0, val_1, val_2, val_3));
|
||||
gpu!(scope, shared_memory[sm_position] = val_vec4);
|
||||
|
||||
}).else(|scope|{
|
||||
gpu!(scope, val_0 = zero);
|
||||
gpu!(scope, val_vec4 = vec4(val_0, val_0, val_0, val_0));
|
||||
gpu!(scope, shared_memory[sm_position] = val_vec4);
|
||||
}));
|
||||
}));
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn load_shared_memory_no_bound_check(
|
||||
scope: &mut Scope,
|
||||
shader: &MatmulTiling2dShader,
|
||||
shader_state: &Tiling2dState,
|
||||
input_identifier: InputIdentifier,
|
||||
) {
|
||||
let (input, input_offset, shared_memory, thread_idx_1, thread_idx_2, stride_1, stride_2) =
|
||||
match input_identifier {
|
||||
InputIdentifier::Lhs => (
|
||||
shader_state.lhs,
|
||||
shader_state.offset_lhs,
|
||||
shader_state.shared_lhs,
|
||||
shader_state.thread_col,
|
||||
shader_state.thread_row,
|
||||
shader_state.lhs_stride_col,
|
||||
shader_state.lhs_stride_row,
|
||||
),
|
||||
InputIdentifier::Rhs => (
|
||||
shader_state.rhs,
|
||||
shader_state.offset_rhs,
|
||||
shader_state.shared_rhs,
|
||||
shader_state.thread_row,
|
||||
shader_state.thread_col,
|
||||
shader_state.rhs_stride_row,
|
||||
shader_state.rhs_stride_col,
|
||||
),
|
||||
};
|
||||
let k = shader_state.k;
|
||||
|
||||
let block_size_k: Variable = shader.config.block_size_k.into();
|
||||
let block_size_n: Variable = shader.config.block_size_n.into();
|
||||
let elem = input.item().elem();
|
||||
|
||||
let current = scope.create_local(Elem::UInt);
|
||||
let aligned_with_shared_memory = scope.create_local(Elem::Bool);
|
||||
let sm_position = scope.create_local(Elem::UInt);
|
||||
|
||||
let tmp = scope.create_local(Elem::UInt);
|
||||
let position_0 = scope.create_local(Elem::UInt);
|
||||
let position_1 = scope.create_local(Elem::UInt);
|
||||
let position_2 = scope.create_local(Elem::UInt);
|
||||
let position_3 = scope.create_local(Elem::UInt);
|
||||
let val_0 = scope.create_local(elem);
|
||||
let val_1 = scope.create_local(elem);
|
||||
let val_2 = scope.create_local(elem);
|
||||
let val_3 = scope.create_local(elem);
|
||||
let val_vec4 = scope.create_local(shared_memory.item());
|
||||
|
||||
gpu!(
|
||||
scope,
|
||||
range(0_u32, 4u32, shader.unroll).for_each(|j, scope| {
|
||||
gpu!(scope, current = thread_idx_1 + j);
|
||||
|
||||
gpu!(scope, aligned_with_shared_memory = current < block_size_k);
|
||||
|
||||
// To avoid overwriting following row in shared memory
|
||||
gpu!(scope, if(aligned_with_shared_memory).then(|scope|{
|
||||
|
||||
match input_identifier {
|
||||
InputIdentifier::Lhs => {
|
||||
gpu!(scope, sm_position = thread_idx_2 / 4u32);
|
||||
gpu!(scope, sm_position *= block_size_k);
|
||||
gpu!(scope, sm_position += current);
|
||||
},
|
||||
InputIdentifier::Rhs => {
|
||||
gpu!(scope, sm_position = current * block_size_n);
|
||||
gpu!(scope, sm_position += thread_idx_2);
|
||||
gpu!(scope, sm_position = sm_position / 4u32);
|
||||
}
|
||||
}
|
||||
|
||||
gpu!(scope, position_0 = k + current);
|
||||
gpu!(scope, position_0 *= stride_1);
|
||||
gpu!(scope, tmp = thread_idx_2 * stride_2);
|
||||
gpu!(scope, position_0 += tmp);
|
||||
gpu!(scope, position_0 += input_offset);
|
||||
gpu!(scope, position_1 = position_0 + stride_2);
|
||||
gpu!(scope, position_2 = position_1 + stride_2);
|
||||
gpu!(scope, position_3 = position_2 + stride_2);
|
||||
|
||||
gpu!(scope, val_0 = input[position_0]);
|
||||
gpu!(scope, val_1 = input[position_1]);
|
||||
gpu!(scope, val_2 = input[position_2]);
|
||||
gpu!(scope, val_3 = input[position_3]);
|
||||
|
||||
gpu!(scope, val_vec4 = vec4(val_0, val_1, val_2, val_3));
|
||||
gpu!(scope, shared_memory[sm_position] = val_vec4);
|
||||
}));
|
||||
})
|
||||
);
|
||||
}
|
|
@ -0,0 +1,11 @@
|
|||
mod base;
|
||||
mod computation;
|
||||
mod load_shared_memory;
|
||||
mod shader_information;
|
||||
mod write_output;
|
||||
|
||||
pub(crate) use base::*;
|
||||
pub(crate) use computation::*;
|
||||
pub(crate) use load_shared_memory::*;
|
||||
pub(crate) use shader_information::*;
|
||||
pub(crate) use write_output::*;
|
|
@ -0,0 +1,180 @@
|
|||
use crate::gpu::{gpu, Elem, Item, Scope, Variable};
|
||||
|
||||
use super::{MatmulTiling2dShader, Tiling2dState};
|
||||
|
||||
pub(crate) fn gather_shader_information(
|
||||
scope: &mut Scope,
|
||||
shader: &MatmulTiling2dShader,
|
||||
) -> Tiling2dState {
|
||||
// Inputs
|
||||
let lhs = shader.variables.lhs;
|
||||
let rhs = shader.variables.rhs;
|
||||
let out = shader.variables.out;
|
||||
|
||||
// Config variables
|
||||
let block_size_m: Variable = shader.config.block_size_m.into();
|
||||
let block_size_k: Variable = shader.config.block_size_k.into();
|
||||
let block_size_n: Variable = shader.config.block_size_n.into();
|
||||
let tile_size_m: Variable = shader.config.tile_size_m.into();
|
||||
let tile_size_n: Variable = shader.config.tile_size_n.into();
|
||||
let n_threads_per_row: Variable =
|
||||
(((shader.config.block_size_n - 1) / shader.config.tile_size_n) + 1).into();
|
||||
let results_size = (shader.config.tile_size_m * shader.config.tile_size_n) as u32;
|
||||
|
||||
// Shader info
|
||||
let local_idx = Variable::LocalInvocationIndex;
|
||||
let batch = Variable::GlobalInvocationIdZ;
|
||||
|
||||
// Shapes
|
||||
let rank = Variable::Rank;
|
||||
let last_dim = scope.create_local(Elem::UInt);
|
||||
let second_to_last_dim = scope.create_local(Elem::UInt);
|
||||
let dim_m = scope.create_local(Elem::UInt);
|
||||
let dim_k = scope.create_local(Elem::UInt);
|
||||
let dim_n = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, last_dim = rank - 1u32);
|
||||
gpu!(scope, second_to_last_dim = rank - 2u32);
|
||||
gpu!(scope, dim_m = shape(lhs, second_to_last_dim));
|
||||
gpu!(scope, dim_k = shape(lhs, last_dim));
|
||||
gpu!(scope, dim_n = shape(rhs, last_dim));
|
||||
|
||||
// Strides
|
||||
let lhs_stride_row = scope.create_local(Elem::UInt);
|
||||
let lhs_stride_col = scope.create_local(Elem::UInt);
|
||||
let rhs_stride_row = scope.create_local(Elem::UInt);
|
||||
let rhs_stride_col = scope.create_local(Elem::UInt);
|
||||
let out_stride_row = scope.create_local(Elem::UInt);
|
||||
let out_stride_col = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, lhs_stride_row = stride(lhs, second_to_last_dim));
|
||||
gpu!(scope, lhs_stride_col = stride(lhs, last_dim));
|
||||
gpu!(scope, rhs_stride_row = stride(rhs, second_to_last_dim));
|
||||
gpu!(scope, rhs_stride_col = stride(rhs, last_dim));
|
||||
gpu!(scope, out_stride_row = stride(out, second_to_last_dim));
|
||||
gpu!(scope, out_stride_col = stride(out, last_dim));
|
||||
|
||||
// Workgroup offset
|
||||
let skip_row = scope.create_local(Elem::UInt);
|
||||
let skip_col = scope.create_local(Elem::UInt);
|
||||
let workgroup_id_x = Variable::WorkgroupIdX;
|
||||
let workgroup_id_y = Variable::WorkgroupIdY;
|
||||
gpu!(scope, skip_row = workgroup_id_x);
|
||||
gpu!(scope, skip_row *= block_size_m);
|
||||
gpu!(scope, skip_col = workgroup_id_y);
|
||||
gpu!(scope, skip_col *= block_size_n);
|
||||
|
||||
// Position of the first element of the thread, relative to the block
|
||||
let thread_row = scope.create_local(Elem::UInt);
|
||||
let thread_col = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, thread_row = local_idx / n_threads_per_row);
|
||||
gpu!(scope, thread_row *= tile_size_m);
|
||||
gpu!(scope, thread_col = local_idx % n_threads_per_row);
|
||||
gpu!(scope, thread_col *= tile_size_n);
|
||||
|
||||
// Position of the first element of the thread, in absolute (in one batch)
|
||||
let row = scope.create_local(Elem::UInt);
|
||||
let col = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, row = skip_row + thread_row);
|
||||
gpu!(scope, col = skip_col + thread_col);
|
||||
|
||||
// Calculate offset.
|
||||
let offset_lhs = scope.create_local(Elem::UInt);
|
||||
let offset_rhs = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, offset_lhs = skip_row * lhs_stride_row);
|
||||
gpu!(scope, offset_rhs = skip_col * rhs_stride_col);
|
||||
|
||||
// Batch offset for the output.
|
||||
let offset_output = scope.create_local(Elem::UInt);
|
||||
let batch_dims = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, offset_output = dim_m * dim_n);
|
||||
gpu!(scope, offset_output = offset_output * batch);
|
||||
|
||||
// Batch offset for the lhs & rhs matrices.
|
||||
let stride_lhs = scope.create_local(Elem::UInt);
|
||||
let stride_rhs = scope.create_local(Elem::UInt);
|
||||
let stride_output = scope.create_local(Elem::UInt);
|
||||
let shape_lhs = scope.create_local(Elem::UInt);
|
||||
let shape_rhs = scope.create_local(Elem::UInt);
|
||||
let tmp = scope.create_local(Elem::UInt);
|
||||
let tmp_lhs = scope.create_local(Elem::UInt);
|
||||
let tmp_rhs = scope.create_local(Elem::UInt);
|
||||
gpu!(scope, batch_dims = rank - 2u32);
|
||||
gpu!(
|
||||
scope,
|
||||
range(0u32, batch_dims).for_each(|b, scope| {
|
||||
gpu!(scope, stride_lhs = stride(lhs, b));
|
||||
gpu!(scope, stride_rhs = stride(rhs, b));
|
||||
gpu!(scope, stride_output = stride(out, b));
|
||||
gpu!(scope, shape_lhs = shape(lhs, b));
|
||||
gpu!(scope, shape_rhs = shape(rhs, b));
|
||||
|
||||
gpu!(scope, tmp = offset_output / stride_output);
|
||||
gpu!(scope, tmp_lhs = tmp % shape_lhs);
|
||||
gpu!(scope, tmp_lhs = tmp_lhs * stride_lhs);
|
||||
gpu!(scope, offset_lhs += tmp_lhs);
|
||||
|
||||
gpu!(scope, tmp_rhs = tmp % shape_rhs);
|
||||
gpu!(scope, tmp_rhs = tmp_rhs * stride_rhs);
|
||||
gpu!(scope, offset_rhs += tmp_rhs);
|
||||
})
|
||||
);
|
||||
|
||||
let elem = lhs.item().elem();
|
||||
|
||||
// Registers used in the compute pass
|
||||
let results = scope.create_local_array(elem, results_size);
|
||||
let register_m = scope.create_local(Item::Vec4(elem));
|
||||
let register_n = scope.create_local(Item::Vec4(elem));
|
||||
let shared_lhs = scope.create_shared(
|
||||
Item::Vec4(elem),
|
||||
shader.config.block_size_m as u32 * shader.config.block_size_k as u32 / 4u32,
|
||||
);
|
||||
let shared_rhs = scope.create_shared(
|
||||
Item::Vec4(elem),
|
||||
shader.config.block_size_k as u32 * shader.config.block_size_n as u32 / 4u32,
|
||||
);
|
||||
|
||||
// Calculate exact number of loop iterations
|
||||
let n_loops = scope.create_local(Elem::UInt);
|
||||
let k = scope.create_local(Elem::UInt);
|
||||
if shader.bounds_check_required {
|
||||
let dim_k_float = scope.create_local(elem);
|
||||
let block_size_k_float = scope.create_local(elem);
|
||||
let n_loops_float = scope.create_local(elem);
|
||||
gpu!(scope, dim_k_float = dim_k);
|
||||
gpu!(scope, block_size_k_float = block_size_k);
|
||||
gpu!(scope, n_loops_float = dim_k_float / block_size_k_float);
|
||||
gpu!(scope, n_loops_float = ceil(n_loops_float));
|
||||
gpu!(scope, n_loops = n_loops_float);
|
||||
} else {
|
||||
gpu!(scope, n_loops = dim_k / block_size_k);
|
||||
}
|
||||
|
||||
Tiling2dState {
|
||||
n_loops,
|
||||
k,
|
||||
lhs,
|
||||
rhs,
|
||||
out,
|
||||
offset_lhs,
|
||||
offset_rhs,
|
||||
offset_output,
|
||||
row,
|
||||
col,
|
||||
dim_m,
|
||||
dim_k,
|
||||
dim_n,
|
||||
thread_col,
|
||||
thread_row,
|
||||
shared_lhs,
|
||||
shared_rhs,
|
||||
register_m,
|
||||
register_n,
|
||||
results,
|
||||
lhs_stride_col,
|
||||
lhs_stride_row,
|
||||
rhs_stride_col,
|
||||
rhs_stride_row,
|
||||
out_stride_row,
|
||||
out_stride_col,
|
||||
}
|
||||
}
|
|
@ -0,0 +1,121 @@
|
|||
use crate::gpu::{gpu, Elem, Scope, Variable};
|
||||
|
||||
use super::{MatmulTiling2dShader, Tiling2dState};
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn write_to_output(
|
||||
scope: &mut Scope,
|
||||
shader: &MatmulTiling2dShader,
|
||||
shader_state: &Tiling2dState,
|
||||
) {
|
||||
let row = shader_state.row;
|
||||
let col = shader_state.col;
|
||||
|
||||
let row_index = scope.create_local(Elem::UInt);
|
||||
let col_index = scope.create_local(Elem::UInt);
|
||||
|
||||
if shader.bounds_check_required {
|
||||
let dim_m = shader_state.dim_m;
|
||||
let dim_n = shader_state.dim_n;
|
||||
|
||||
let within_output = scope.create_local(Elem::Bool);
|
||||
let within_output_tmp = scope.create_local(Elem::Bool);
|
||||
|
||||
gpu!(
|
||||
scope,
|
||||
range(0u32, shader.config.tile_size_m as u32, shader.unroll).for_each(
|
||||
|res_idx_m, scope| {
|
||||
gpu!(
|
||||
scope,
|
||||
range(0u32, shader.config.tile_size_n as u32, shader.unroll).for_each(
|
||||
|res_idx_n, scope| {
|
||||
gpu!(scope, row_index = row + res_idx_m);
|
||||
gpu!(scope, col_index = col + res_idx_n);
|
||||
|
||||
gpu!(scope, within_output = row_index < dim_m);
|
||||
gpu!(scope, within_output_tmp = col_index < dim_n);
|
||||
gpu!(scope, within_output = within_output && within_output_tmp);
|
||||
|
||||
gpu!(scope, if(within_output).then(|scope|{
|
||||
write_inner(
|
||||
scope,
|
||||
shader,
|
||||
shader_state,
|
||||
res_idx_m,
|
||||
res_idx_n,
|
||||
row_index,
|
||||
col_index,
|
||||
);
|
||||
}));
|
||||
}
|
||||
)
|
||||
);
|
||||
}
|
||||
)
|
||||
);
|
||||
} else {
|
||||
gpu!(
|
||||
scope,
|
||||
range(0u32, shader.config.tile_size_m as u32, shader.unroll).for_each(
|
||||
|res_idx_m, scope| {
|
||||
gpu!(
|
||||
scope,
|
||||
range(0u32, shader.config.tile_size_n as u32, shader.unroll).for_each(
|
||||
|res_idx_n, scope| {
|
||||
gpu!(scope, row_index = row + res_idx_m);
|
||||
gpu!(scope, col_index = col + res_idx_n);
|
||||
|
||||
write_inner(
|
||||
scope,
|
||||
shader,
|
||||
shader_state,
|
||||
res_idx_m,
|
||||
res_idx_n,
|
||||
row_index,
|
||||
col_index,
|
||||
)
|
||||
}
|
||||
)
|
||||
);
|
||||
}
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn write_inner(
|
||||
scope: &mut Scope,
|
||||
shader: &MatmulTiling2dShader,
|
||||
shader_state: &Tiling2dState,
|
||||
res_idx_m: Variable,
|
||||
res_idx_n: Variable,
|
||||
row_index: Variable,
|
||||
col_index: Variable,
|
||||
) {
|
||||
let offset_output = shader_state.offset_output;
|
||||
let out = shader_state.out;
|
||||
let out_stride_row = shader_state.out_stride_row;
|
||||
let out_stride_col = shader_state.out_stride_col;
|
||||
let results = shader_state.results;
|
||||
|
||||
let elem = results.item().elem();
|
||||
let results_position = scope.create_local(Elem::UInt);
|
||||
let result = scope.create_local(elem);
|
||||
let output_position = scope.create_local(Elem::UInt);
|
||||
|
||||
gpu!(
|
||||
scope,
|
||||
results_position = res_idx_m * shader.config.tile_size_n
|
||||
);
|
||||
gpu!(scope, results_position += res_idx_n);
|
||||
|
||||
gpu!(scope, result = results[results_position]);
|
||||
|
||||
gpu!(scope, row_index *= out_stride_row);
|
||||
gpu!(scope, col_index *= out_stride_col);
|
||||
gpu!(scope, output_position = row_index + col_index);
|
||||
gpu!(scope, output_position += offset_output);
|
||||
|
||||
gpu!(scope, out[output_position] = result);
|
||||
}
|
|
@ -4,7 +4,10 @@ use burn_tensor::{Element, ElementConversion};
|
|||
use crate::{
|
||||
compute::JitAutotuneKey,
|
||||
element::JitElement,
|
||||
kernel::{matmul::utils::init_matmul_output, prng::random_like_uniform},
|
||||
kernel::{
|
||||
matmul::{utils::init_matmul_output, Tiling2dConfig},
|
||||
prng::random_like_uniform,
|
||||
},
|
||||
ops::numeric::empty_device,
|
||||
tensor::JitTensor,
|
||||
Runtime,
|
||||
|
@ -50,22 +53,14 @@ impl<R: Runtime, E: JitElement + Element, const D: usize> AutotuneOperationSet<J
|
|||
);
|
||||
|
||||
vec![
|
||||
Box::new(MemoryCoalescingMatmulDefault::new(
|
||||
Box::new(SimpleMatmul::new(lhs.clone(), rhs.clone(), out.clone())),
|
||||
Box::new(SimpleMatmul16x16::new(
|
||||
lhs.clone(),
|
||||
rhs.clone(),
|
||||
out.clone(),
|
||||
)),
|
||||
Box::new(MemoryCoalescingMatmulW16x16::new(
|
||||
lhs.clone(),
|
||||
rhs.clone(),
|
||||
out.clone(),
|
||||
)),
|
||||
Box::new(Vec4TilingMatmulDefault::new(
|
||||
lhs.clone(),
|
||||
rhs.clone(),
|
||||
out.clone(),
|
||||
)),
|
||||
Box::new(Vec4TilingMatmulUnpaddedDefault::new(
|
||||
Box::new(Tiling2dMatmul::new(lhs.clone(), rhs.clone(), out.clone())),
|
||||
Box::new(Tiling2dMatmulPadded::new(
|
||||
lhs.clone(),
|
||||
rhs.clone(),
|
||||
out.clone(),
|
||||
|
@ -75,16 +70,10 @@ impl<R: Runtime, E: JitElement + Element, const D: usize> AutotuneOperationSet<J
|
|||
|
||||
fn fastest(self: Box<Self>, fastest_index: usize) -> Box<dyn AutotuneOperation> {
|
||||
match fastest_index {
|
||||
0 => Box::new(MemoryCoalescingMatmulDefault::new(
|
||||
self.lhs, self.rhs, self.out,
|
||||
)),
|
||||
1 => Box::new(MemoryCoalescingMatmulW16x16::new(
|
||||
self.lhs, self.rhs, self.out,
|
||||
)),
|
||||
2 => Box::new(Vec4TilingMatmulDefault::new(self.lhs, self.rhs, self.out)),
|
||||
3 => Box::new(Vec4TilingMatmulUnpaddedDefault::new(
|
||||
self.lhs, self.rhs, self.out,
|
||||
)),
|
||||
0 => Box::new(SimpleMatmul::new(self.lhs, self.rhs, self.out)),
|
||||
1 => Box::new(SimpleMatmul16x16::new(self.lhs, self.rhs, self.out)),
|
||||
2 => Box::new(Tiling2dMatmul::new(self.lhs, self.rhs, self.out)),
|
||||
3 => Box::new(Tiling2dMatmulPadded::new(self.lhs, self.rhs, self.out)),
|
||||
_ => panic!("Fastest index is out of bound"),
|
||||
}
|
||||
}
|
||||
|
@ -134,23 +123,21 @@ macro_rules! matmul_tune_ops {
|
|||
|
||||
// Potentially better for small matrices.
|
||||
matmul_tune_ops!(
|
||||
MemoryCoalescingMatmulDefault,
|
||||
SimpleMatmul,
|
||||
crate::kernel::matmul::matmul_mem_coalescing_default
|
||||
);
|
||||
|
||||
// Potentially better for small matrices.
|
||||
matmul_tune_ops!(MemoryCoalescingMatmulW16x16, |lhs, rhs, out| {
|
||||
crate::kernel::matmul::matmul_mem_coalescing(lhs, rhs, out, 16, 16)
|
||||
matmul_tune_ops!(SimpleMatmul16x16, |lhs, rhs, out| {
|
||||
crate::kernel::matmul::matmul_simple(lhs, rhs, out, 16, 16)
|
||||
});
|
||||
|
||||
// Probably the fastest when fixed sizes.
|
||||
matmul_tune_ops!(
|
||||
Vec4TilingMatmulDefault,
|
||||
crate::kernel::matmul::vec4::matmul_tiling_2d_vec4
|
||||
);
|
||||
matmul_tune_ops!(Tiling2dMatmulPadded, |lhs, rhs, out| {
|
||||
crate::kernel::matmul::matmul_tiling_2d_padded(lhs, rhs, out, Tiling2dConfig::default())
|
||||
});
|
||||
|
||||
// Probably the fastest otherwise.
|
||||
matmul_tune_ops!(
|
||||
Vec4TilingMatmulUnpaddedDefault,
|
||||
crate::kernel::matmul::unpadded::matmul_tiling_2d_unpadded
|
||||
);
|
||||
// Probably the fastest in the general case
|
||||
matmul_tune_ops!(Tiling2dMatmul, |lhs, rhs, out| {
|
||||
crate::kernel::matmul::matmul_tiling_2d(lhs, rhs, out, Tiling2dConfig::default())
|
||||
});
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#[burn_tensor_testgen::testgen(matmul)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use burn_jit::kernel::matmul::{matmul, MatmulStrategy};
|
||||
use burn_jit::kernel::matmul::{matmul, MatmulStrategy, Tiling2dConfig};
|
||||
use burn_tensor::{Shape, Tensor};
|
||||
|
||||
mod simple {
|
||||
|
@ -174,7 +174,7 @@ mod tests {
|
|||
let shape_lhs = [3, 2, 4, 4];
|
||||
let shape_rhs = [3, 2, 4, 4];
|
||||
same_as_reference_swapped_dims(
|
||||
MatmulStrategy::Tiling2dPadded,
|
||||
MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()),
|
||||
swap,
|
||||
swap,
|
||||
shape_lhs,
|
||||
|
@ -189,7 +189,7 @@ mod tests {
|
|||
let shape_lhs = [3, 2, 4, 4];
|
||||
let shape_rhs = [3, 2, 4, 4];
|
||||
same_as_reference_swapped_dims(
|
||||
MatmulStrategy::Tiling2dPadded,
|
||||
MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()),
|
||||
swap_lhs,
|
||||
swap_rhs,
|
||||
shape_lhs,
|
||||
|
@ -204,7 +204,7 @@ mod tests {
|
|||
let shape_lhs = [4, 4, 4, 4];
|
||||
let shape_rhs = [4, 4, 4, 4];
|
||||
same_as_reference_swapped_dims(
|
||||
MatmulStrategy::Tiling2dPadded,
|
||||
MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()),
|
||||
swap_lhs,
|
||||
swap_rhs,
|
||||
shape_lhs,
|
||||
|
@ -212,10 +212,56 @@ mod tests {
|
|||
);
|
||||
}
|
||||
|
||||
/// A 2x3 by 3x2 product computed with the padded tiling 2d strategy must
/// match the reference backend.
#[test]
fn stable_test() {
    let reference_device = Default::default();
    let lhs = ReferenceTensor::from_floats([[0., 1., 2.], [3., 4., 5.]], &reference_device);
    let rhs = ReferenceTensor::from_floats([[0., 1.], [2., 3.], [4., 5.]], &reference_device);

    let jit_device = Default::default();
    let lhs_jit = TestTensor::from_data(lhs.to_data(), &jit_device);
    let rhs_jit = TestTensor::from_data(rhs.to_data(), &jit_device);

    let expected = lhs.matmul(rhs);
    let strategy = MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default());
    let actual = Tensor::<TestBackend, 2>::from_primitive(matmul(
        lhs_jit.into_primitive(),
        rhs_jit.into_primitive(),
        strategy,
    ));

    expected.into_data().assert_approx_eq(&actual.into_data(), 3);
}
|
||||
|
||||
/// A 3x2 by 2x3 product computed with the padded tiling 2d strategy must
/// match the reference backend.
#[test]
fn stable_test_2() {
    let reference_device = Default::default();
    let lhs = ReferenceTensor::from_floats([[0., 1.], [2., 3.], [4., 5.]], &reference_device);
    let rhs = ReferenceTensor::from_floats([[0., 1., 2.], [3., 4., 5.]], &reference_device);

    let jit_device = Default::default();
    let lhs_jit = TestTensor::from_data(lhs.to_data(), &jit_device);
    let rhs_jit = TestTensor::from_data(rhs.to_data(), &jit_device);

    let expected = lhs.matmul(rhs);
    let strategy = MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default());
    let actual = Tensor::<TestBackend, 2>::from_primitive(matmul(
        lhs_jit.into_primitive(),
        rhs_jit.into_primitive(),
        strategy,
    ));

    expected.into_data().assert_approx_eq(&actual.into_data(), 3);
}
|
||||
|
||||
fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) {
|
||||
let shape_lhs = [batch_1, batch_2, m, k];
|
||||
let shape_rhs = [batch_1, batch_2, k, n];
|
||||
same_as_reference(MatmulStrategy::Tiling2dPadded, shape_lhs, shape_rhs);
|
||||
same_as_reference(
|
||||
MatmulStrategy::Tiling2dPadded(Tiling2dConfig::default()),
|
||||
shape_lhs,
|
||||
shape_rhs,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -308,7 +354,7 @@ mod tests {
|
|||
let shape_lhs = [3, 2, 4, 4];
|
||||
let shape_rhs = [3, 2, 4, 4];
|
||||
same_as_reference_swapped_dims(
|
||||
MatmulStrategy::Tiling2d,
|
||||
MatmulStrategy::Tiling2d(Tiling2dConfig::default()),
|
||||
swap,
|
||||
swap,
|
||||
shape_lhs,
|
||||
|
@ -323,7 +369,7 @@ mod tests {
|
|||
let shape_lhs = [3, 2, 4, 4];
|
||||
let shape_rhs = [3, 2, 4, 4];
|
||||
same_as_reference_swapped_dims(
|
||||
MatmulStrategy::Tiling2d,
|
||||
MatmulStrategy::Tiling2d((Tiling2dConfig::default())),
|
||||
swap_lhs,
|
||||
swap_rhs,
|
||||
shape_lhs,
|
||||
|
@ -338,7 +384,7 @@ mod tests {
|
|||
let shape_lhs = [4, 4, 4, 4];
|
||||
let shape_rhs = [4, 4, 4, 4];
|
||||
same_as_reference_swapped_dims(
|
||||
MatmulStrategy::Tiling2d,
|
||||
MatmulStrategy::Tiling2d(Tiling2dConfig::default()),
|
||||
swap_lhs,
|
||||
swap_rhs,
|
||||
shape_lhs,
|
||||
|
@ -349,7 +395,11 @@ mod tests {
|
|||
fn test_with_params(m: usize, k: usize, n: usize, batch_1: usize, batch_2: usize) {
|
||||
let shape_lhs = [batch_1, batch_2, m, k];
|
||||
let shape_rhs = [batch_1, batch_2, k, n];
|
||||
same_as_reference(MatmulStrategy::Tiling2d, shape_lhs, shape_rhs);
|
||||
same_as_reference(
|
||||
MatmulStrategy::Tiling2d(Tiling2dConfig::default()),
|
||||
shape_lhs,
|
||||
shape_rhs,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -18,6 +18,7 @@ pub enum Variable {
|
|||
scope_depth: u8,
|
||||
},
|
||||
SharedMemory(u16, Item, u32),
|
||||
LocalArray(u16, Item, u8, u32),
|
||||
Id,
|
||||
LocalInvocationIndex,
|
||||
LocalInvocationIdX,
|
||||
|
@ -79,6 +80,7 @@ impl Variable {
|
|||
Variable::GlobalInputArray(_, _) => false,
|
||||
Variable::GlobalOutputArray(_, _) => false,
|
||||
Variable::SharedMemory(_, _, _) => false,
|
||||
Variable::LocalArray(_, _, _, _) => false,
|
||||
Variable::Local {
|
||||
index: _,
|
||||
item: _,
|
||||
|
@ -110,6 +112,7 @@ impl Variable {
|
|||
Self::GlobalInputArray(_, e) => *e,
|
||||
Self::GlobalOutputArray(_, e) => *e,
|
||||
Self::SharedMemory(_, e, _) => *e,
|
||||
Self::LocalArray(_, e, _, _) => *e,
|
||||
Self::Local {
|
||||
index: _,
|
||||
item,
|
||||
|
@ -222,6 +225,9 @@ impl Display for Variable {
|
|||
Variable::SharedMemory(number, _, _) => {
|
||||
f.write_fmt(format_args!("shared_memory_{number}"))
|
||||
}
|
||||
Variable::LocalArray(number, _, scope_depth, _) => {
|
||||
f.write_fmt(format_args!("a_{scope_depth}_{number}"))
|
||||
}
|
||||
Variable::Id => f.write_str("id"),
|
||||
Variable::LocalInvocationIndex => f.write_str("local_idx"),
|
||||
Variable::LocalInvocationIdX => f.write_str("local_invocation_id.x"),
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
use super::LocalArray;
|
||||
use super::{shader::ComputeShader, Item, SharedMemory};
|
||||
use crate::compiler::wgsl;
|
||||
use crate::{FloatElement, IntElement};
|
||||
|
@ -19,6 +20,7 @@ pub struct WgslCompiler<F: FloatElement, I: IntElement> {
|
|||
shape: bool,
|
||||
num_workgroups: bool,
|
||||
shared_memories: Vec<SharedMemory>,
|
||||
local_arrays: Vec<LocalArray>,
|
||||
_float: PhantomData<F>,
|
||||
_int: PhantomData<I>,
|
||||
}
|
||||
|
@ -44,6 +46,7 @@ impl<F: FloatElement, I: IntElement> Default for WgslCompiler<F, I> {
|
|||
shape: false,
|
||||
num_workgroups: false,
|
||||
shared_memories: Vec::default(),
|
||||
local_arrays: Vec::default(),
|
||||
_float: PhantomData,
|
||||
_int: PhantomData,
|
||||
}
|
||||
|
@ -64,6 +67,10 @@ impl<F: FloatElement, I: IntElement> burn_jit::Compiler for WgslCompiler<F, I> {
|
|||
fn elem_size(elem: gpu::Elem) -> usize {
|
||||
Self::compile_elem(elem).size()
|
||||
}
|
||||
|
||||
/// Maximal shared memory size reported by the WGSL compiler.
// NOTE(review): the unit (bytes vs. elements) is not visible here — confirm
// against the callers of `Compiler::max_shared_memory_size`.
fn max_shared_memory_size() -> usize {
    const MAX_SHARED_MEMORY_SIZE: usize = 8192;
    MAX_SHARED_MEMORY_SIZE
}
|
||||
}
|
||||
|
||||
impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
|
||||
|
@ -98,6 +105,7 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
|
|||
.map(|(name, binding)| (name, Self::compile_binding(binding)))
|
||||
.collect(),
|
||||
shared_memories: self.shared_memories.clone(),
|
||||
local_arrays: self.local_arrays.clone(),
|
||||
workgroup_size: value.workgroup_size,
|
||||
global_invocation_id: self.global_invocation_id || self.id,
|
||||
local_invocation_index: self.local_invocation_index,
|
||||
|
@ -159,6 +167,14 @@ impl<F: FloatElement, I: IntElement> WgslCompiler<F, I> {
|
|||
}
|
||||
wgsl::Variable::SharedMemory(index, item, size)
|
||||
}
|
||||
gpu::Variable::LocalArray(index, item, scope_depth, size) => {
|
||||
let item = Self::compile_item(item);
|
||||
if !self.local_arrays.iter().any(|s| s.index == index) {
|
||||
self.local_arrays
|
||||
.push(LocalArray::new(index, item, scope_depth, size));
|
||||
}
|
||||
wgsl::Variable::LocalArray(index, item, scope_depth, size)
|
||||
}
|
||||
gpu::Variable::Id => {
|
||||
self.id = true;
|
||||
wgsl::Variable::Id
|
||||
|
|
|
@ -5,7 +5,6 @@ use std::fmt::Display;
|
|||
/// Location qualifier of a shader variable.
///
/// The shader's `Display` implementation writes a shared memory's location
/// inside `var<...>` when declaring it.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum Location {
    Storage,
    Workgroup,
}
|
||||
|
||||
|
@ -42,12 +41,32 @@ impl SharedMemory {
|
|||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, PartialEq, Eq, Clone)]
|
||||
pub struct LocalArray {
|
||||
pub index: u16,
|
||||
item: Item,
|
||||
name: u8,
|
||||
size: u32,
|
||||
}
|
||||
|
||||
impl LocalArray {
|
||||
pub fn new(index: u16, item: Item, name: u8, size: u32) -> Self {
|
||||
Self {
|
||||
index,
|
||||
item,
|
||||
name,
|
||||
size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ComputeShader {
|
||||
pub inputs: Vec<Binding>,
|
||||
pub outputs: Vec<Binding>,
|
||||
pub named: Vec<(String, Binding)>,
|
||||
pub shared_memories: Vec<SharedMemory>,
|
||||
pub local_arrays: Vec<LocalArray>,
|
||||
pub workgroup_size: WorkgroupSize,
|
||||
pub global_invocation_id: bool,
|
||||
pub local_invocation_index: bool,
|
||||
|
@ -72,10 +91,10 @@ impl Display for ComputeShader {
|
|||
)?;
|
||||
}
|
||||
|
||||
for shared_memory in self.shared_memories.iter() {
|
||||
for array in self.shared_memories.iter() {
|
||||
f.write_fmt(format_args!(
|
||||
"var<{}> shared_memory_{}: array<{}, {}>;\n\n",
|
||||
shared_memory.location, shared_memory.index, shared_memory.item, shared_memory.size
|
||||
array.location, array.index, array.item, array.size
|
||||
))?;
|
||||
}
|
||||
|
||||
|
@ -115,12 +134,22 @@ fn main(
|
|||
f.write_str(" @builtin(workgroup_id) workgroup_id: vec3<u32>,\n")?;
|
||||
}
|
||||
|
||||
f.write_fmt(format_args!(
|
||||
") {{
|
||||
{}
|
||||
}}",
|
||||
self.body
|
||||
))?;
|
||||
// Open body
|
||||
f.write_fmt(format_args!(") {{"))?;
|
||||
|
||||
// Local arrays
|
||||
for array in self.local_arrays.iter() {
|
||||
f.write_fmt(format_args!(
|
||||
"var a_{}_{}: array<{}, {}>;\n\n",
|
||||
array.name, array.index, array.item, array.size
|
||||
))?;
|
||||
}
|
||||
|
||||
// Body
|
||||
f.write_fmt(format_args!("{}", self.body))?;
|
||||
|
||||
// Close body
|
||||
f.write_fmt(format_args!("}}"))?;
|
||||
|
||||
for extension in self.extensions.iter() {
|
||||
f.write_fmt(format_args!("{extension}\n\n"))?;
|
||||
|
|
|
@ -1,7 +1,13 @@
|
|||
use std::process::Command;
|
||||
use std::{path::Path, process::Command};
|
||||
|
||||
use serde_json::Value;
|
||||
|
||||
// Scheme prefix that precedes the filesystem path in a cargo workspace member
// spec. The Windows variant has one more `/` than the other platforms.
// NOTE(review): presumably so that what remains starts at the drive letter on
// Windows — confirm.
#[cfg(target_os = "windows")]
const MEMBER_PATH_PREFIX: &str = "path+file:///";
#[cfg(not(target_os = "windows"))]
const MEMBER_PATH_PREFIX: &str = "path+file://";
|
||||
|
||||
pub(crate) enum WorkspaceMemberType {
|
||||
Crate,
|
||||
Example,
|
||||
|
@ -26,43 +32,27 @@ pub(crate) fn get_workspaces(w_type: WorkspaceMemberType) -> Vec<WorkspaceMember
|
|||
.arg("metadata")
|
||||
.output()
|
||||
.expect("Failed to execute command");
|
||||
|
||||
// Parse the JSON output
|
||||
let metadata: Value = serde_json::from_slice(&output.stdout).expect("Failed to parse JSON");
|
||||
|
||||
// Extract workspaces from the metadata, excluding examples/ and xtask
|
||||
let workspaces = metadata["workspace_members"]
|
||||
.as_array()
|
||||
.expect("Expected an array of workspace members")
|
||||
.iter()
|
||||
.filter_map(|member| {
|
||||
let parts: Vec<_> = member.as_str()?.split_whitespace().collect();
|
||||
let (workspace_name, workspace_path) =
|
||||
(parts.first()?.to_owned(), parts.last()?.to_owned());
|
||||
|
||||
let prefix = if cfg!(target_os = "windows") {
|
||||
"(path+file:///"
|
||||
let member_str = member.as_str()?;
|
||||
let has_whitespace = member_str.chars().any(|c| c.is_whitespace());
|
||||
let (name, path) = if has_whitespace {
|
||||
parse_workspace_member0(member_str)?
|
||||
} else {
|
||||
"(path+file://"
|
||||
parse_workspace_member1(member_str)?
|
||||
};
|
||||
let workspace_path = workspace_path.replace(prefix, "").replace(')', "");
|
||||
|
||||
match w_type {
|
||||
WorkspaceMemberType::Crate
|
||||
if workspace_name != "xtask" && !workspace_path.contains("examples/") =>
|
||||
{
|
||||
Some(WorkspaceMember::new(
|
||||
workspace_name.to_string(),
|
||||
workspace_path.to_string(),
|
||||
))
|
||||
WorkspaceMemberType::Crate if name != "xtask" && !path.contains("examples/") => {
|
||||
Some(WorkspaceMember::new(name.to_string(), path.to_string()))
|
||||
}
|
||||
WorkspaceMemberType::Example
|
||||
if workspace_name != "xtask" && workspace_path.contains("examples/") =>
|
||||
{
|
||||
Some(WorkspaceMember::new(
|
||||
workspace_name.to_string(),
|
||||
workspace_path.to_string(),
|
||||
))
|
||||
WorkspaceMemberType::Example if name != "xtask" && path.contains("examples/") => {
|
||||
Some(WorkspaceMember::new(name.to_string(), path.to_string()))
|
||||
}
|
||||
_ => None,
|
||||
}
|
||||
|
@ -71,3 +61,30 @@ pub(crate) fn get_workspaces(w_type: WorkspaceMemberType) -> Vec<WorkspaceMember
|
|||
|
||||
workspaces
|
||||
}
|
||||
|
||||
/// Legacy cargo metadata format for member specs (rust < 1.77)
|
||||
/// Example:
|
||||
/// "backend-comparison 0.13.0 (path+file:///Users/username/burn/backend-comparison)"
|
||||
fn parse_workspace_member0(specs: &str) -> Option<(String, String)> {
|
||||
let parts: Vec<_> = specs.split_whitespace().collect();
|
||||
let (name, path) = (parts.first()?.to_owned(), parts.last()?.to_owned());
|
||||
// skip the first character because it is a '('
|
||||
let path = path
|
||||
.chars()
|
||||
.skip(1)
|
||||
.collect::<String>()
|
||||
.replace(MEMBER_PATH_PREFIX, "")
|
||||
.replace(')', "");
|
||||
Some((name.to_string(), path.to_string()))
|
||||
}
|
||||
|
||||
/// Cargo metadata format for member specs (rust >= 1.77)
|
||||
/// Example:
|
||||
/// "path+file:///Users/username/burn/backend-comparison#0.13.0"
|
||||
fn parse_workspace_member1(specs: &str) -> Option<(String, String)> {
|
||||
let no_prefix = specs.replace(MEMBER_PATH_PREFIX, "").replace(')', "");
|
||||
let path = Path::new(no_prefix.split_once('#')?.0);
|
||||
let name = path.file_name()?.to_str()?;
|
||||
let path = path.to_str()?;
|
||||
Some((name.to_string(), path.to_string()))
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue