This commit is contained in:
Nicolas Patry 2023-12-13 16:09:20 +01:00
parent ce33d6ad2a
commit 1f23cea90c
5 changed files with 307 additions and 123 deletions

View File

@ -796,101 +796,37 @@ impl BackendStorage for MetalStorage {
) -> Result<Self> {
// Create descriptors
let (type_id, size) = match self.dtype {
DType::F32 => (
metal::mps::MPS_FLOATBIT_ENCODING | 32,
core::mem::size_of::<f32>() as NSUInteger,
),
DType::F16 => (
metal::mps::MPS_FLOATBIT_ENCODING | 16,
core::mem::size_of::<f16>() as NSUInteger,
),
dtype => todo!("Dtype for matmul {dtype:?} is not supported"),
let buffer = self.device.new_buffer(b * m * n, self.dtype);
let name = match self.dtype {
DType::F32 => "sgemm",
DType::F16 => "hgemm",
dtype => {
return Err(MetalError::Message(format!("matmul doesn't support {dtype:?}")).into())
}
};
let lhs_stride = lhs_l.stride();
let rhs_stride = rhs_l.stride();
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
// The a tensor has dims batching, k, n (rhs)
let transpose_left = if lhs_m1 == 1 && lhs_m2 == k {
false
} else if lhs_m1 == m && lhs_m2 == 1 {
true
} else {
Err(MetalError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?
};
let transpose_right = if rhs_m1 == 1 && rhs_m2 == n {
false
} else if rhs_m1 == k && rhs_m2 == 1 {
true
} else {
Err(MetalError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
mnk: (m, n, k),
})?
};
let b = b as NSUInteger;
let m = m as NSUInteger;
let n = n as NSUInteger;
let k = k as NSUInteger;
let left_matrix = self.matrix(
(b, m, k),
transpose_left,
size,
lhs_l.start_offset() as NSUInteger * size,
type_id,
)?;
let right_matrix = rhs.matrix(
(b, k, n),
transpose_right,
size,
rhs_l.start_offset() as NSUInteger * size,
type_id,
)?;
let (result_matrix, out_buffer) =
self.device
.new_matrix((b, m, n), size, type_id, self.dtype)?;
let command_buffer = self.device.command_buffer();
let alpha = 1.0f64;
let beta = 0.0f64;
// Create kernel
let matrix_multiplication = MatrixMultiplication::init(
&self.device,
transpose_left,
transpose_right,
m,
n,
k,
alpha,
beta,
)
.ok_or_else(|| {
MetalError::from("Failed to create matrix multiplication kernel".to_string())
})?;
// Encode kernel to command buffer
matrix_multiplication.encode_to_command_buffer(
&command_buffer,
&left_matrix,
&right_matrix,
&result_matrix,
);
command_buffer.set_label("matmul");
candle_metal_kernels::call_gemm(
&self.device.device,
&command_buffer,
&self.device.kernels,
name,
(b, m, n, k),
&lhs_l.stride(),
lhs_l.start_offset(),
&self.buffer,
&rhs_l.stride(),
rhs_l.start_offset(),
&rhs.buffer,
&buffer,
)
.map_err(MetalError::from)?;
// Create kernel
drop(command_buffer);
self.device.commit();
Ok(Self::new(out_buffer, self.device.clone(), self.dtype()))
Ok(Self::new(buffer, self.device.clone(), self.dtype()))
}
fn copy_strided_src(&self, dst: &mut Self, dst_offset: usize, src_l: &Layout) -> Result<()> {

View File

@ -183,7 +183,7 @@ impl<T> From<std::sync::PoisonError<T>> for MetalKernelError {
#[derive(Debug, PartialEq)]
pub enum Value {
U32(u32),
USize(usize),
Bool(bool),
F32(f32),
U16(u16),
@ -193,7 +193,7 @@ impl std::hash::Hash for Value {
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
match self {
Value::F32(v) => v.to_bits().hash(state),
Value::U32(v) => v.hash(state),
Value::USize(v) => v.hash(state),
Value::U16(v) => v.hash(state),
Value::Bool(v) => v.hash(state),
}
@ -203,7 +203,7 @@ impl std::hash::Hash for Value {
impl Value {
fn data_type(&self) -> MTLDataType {
match self {
Value::U32(_) => MTLDataType::UInt,
Value::USize(_) => MTLDataType::UInt,
Value::F32(_) => MTLDataType::Float,
Value::U16(_) => MTLDataType::UShort,
Value::Bool(_) => MTLDataType::Bool,
@ -227,9 +227,9 @@ impl ConstantValues {
for (index, value) in &self.0 {
let ty = value.data_type();
match value {
Value::U32(v) => {
Value::USize(v) => {
f.set_constant_value_at_index(
v as *const u32 as *const c_void,
v as *const usize as *const c_void,
ty,
*index as u64,
);
@ -824,11 +824,39 @@ pub fn call_gemm(
rhs_buffer: &Buffer,
output: &Buffer,
) -> Result<(), MetalKernelError> {
let a_trans = false;
let b_trans = false;
assert!(rhs_stride.len() >= 2);
assert!(lhs_stride.len() >= 2);
let rhs_m1 = rhs_stride[rhs_stride.len() - 1];
let rhs_m2 = rhs_stride[rhs_stride.len() - 2];
let lhs_m1 = lhs_stride[lhs_stride.len() - 1];
let lhs_m2 = lhs_stride[lhs_stride.len() - 2];
let a_trans = if lhs_m1 == 1 && lhs_m2 == k {
false
} else if lhs_m1 == m && lhs_m2 == 1 {
true
} else {
todo!();
// Err(MetalError::MatMulNonContiguous {
// lhs_stride: lhs_stride.to_vec(),
// rhs_stride: rhs_stride.to_vec(),
// mnk: (m, n, k),
// })?
};
let b_trans = if rhs_m1 == 1 && rhs_m2 == n {
false
} else if rhs_m1 == k && rhs_m2 == 1 {
true
} else {
todo!();
// Err(MetalError::MatMulNonContiguous {
// lhs_stride: lhs_stride.to_vec(),
// rhs_stride: rhs_stride.to_vec(),
// mnk: (m, n, k),
// })?
};
let d_trans = false;
let alpha = 1.0;
let beta = 0.0;
let alpha = 1.0f32;
let beta = 0.0f32;
let batched = b > 1;
let fused_activation = false;
let fused_bias = false;
@ -838,9 +866,9 @@ pub fn call_gemm(
let m_splits = 2;
let n_splits = 2;
let constants = Some(ConstantValues::new(vec![
(0, Value::U32(m as u32)),
(1, Value::U32(n as u32)),
(2, Value::U32(k as u32)),
(0, Value::USize(m)),
(1, Value::USize(n)),
(2, Value::USize(k)),
(10, Value::Bool(a_trans)),
(11, Value::Bool(b_trans)),
(13, Value::Bool(d_trans)),
@ -861,7 +889,7 @@ pub fn call_gemm(
(211, Value::U16(n_splits)),
(50_001, Value::Bool(fused_bias)),
]));
println!("Constants {constants:?}");
// println!("Constants {constants:?}");
let pipeline = kernels.load_pipeline_with_constants(device, Source::Mfa, name, constants)?;
let m_group = m_simd * m_splits;
let n_group = n_simd * n_splits;
@ -895,35 +923,34 @@ pub fn call_gemm(
let encoder = command_buffer.new_compute_command_encoder();
encoder.set_compute_pipeline_state(&pipeline);
println!("Threadgroup {block_bytes}");
encoder.set_threadgroup_memory_length(block_bytes.into(), 0);
// println!("Threadgroup {block_bytes}");
encoder.set_threadgroup_memory_length(0, block_bytes.into());
encoder.set_buffer(0, Some(lhs_buffer), lhs_offset as NSUInteger);
encoder.set_buffer(1, Some(rhs_buffer), rhs_offset as NSUInteger);
encoder.set_buffer(2, Some(output), 0);
// TODO Tensor D
let grid_z = b;
let byte_stride_a: usize = *lhs_stride.get(lhs_stride.len() - 3).unwrap_or(&0) * bytes as usize;
let byte_stride_b = *rhs_stride.get(rhs_stride.len() - 3).unwrap_or(&0) * bytes as usize;
let byte_stride_c = m * n * bytes as usize;
// TODO byte_stride_d
let byte_stride_d = 0;
if batched {
let byte_stride_a: usize = lhs_stride[lhs_stride.len() - 3] * bytes as usize;
let byte_stride_b: usize = rhs_stride[rhs_stride.len() - 3] * bytes as usize;
let byte_stride_c = m * n * bytes as usize;
// TODO byte_stride_d
let byte_stride_d = 0;
let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
for i in 0..b {
buffer.push((i * byte_stride_a) as u64);
buffer.push((i * byte_stride_b) as u64);
buffer.push((i * byte_stride_c) as u64);
buffer.push((i * byte_stride_d) as u64);
let mut buffer: Vec<u64> = Vec::with_capacity(b * 4);
for i in 0..b {
buffer.push((i * byte_stride_a) as u64);
buffer.push((i * byte_stride_b) as u64);
buffer.push((i * byte_stride_c) as u64);
buffer.push((i * byte_stride_d) as u64);
}
encoder.set_bytes(
10,
buffer.len() as NSUInteger * core::mem::size_of::<u64>(),
buffer.as_ptr() as *const NSUInteger as *const c_void,
);
}
println!("A {:?}", lhs_buffer.read_to_vec::<f32>(12));
println!("B {:?}", rhs_buffer.read_to_vec::<f32>(24));
println!("buffer {:?}", buffer);
encoder.set_bytes(
10,
buffer.len() as NSUInteger,
buffer.as_ptr() as *const NSUInteger as *const c_void,
);
let grid_size = MTLSize {
width: divide(n, n_group.into()),
@ -935,7 +962,7 @@ pub fn call_gemm(
height: 1,
depth: 1,
};
println!("grid size {grid_size:?} group size {group_size:?}");
// println!("grid size {grid_size:?} group size {group_size:?}");
encoder.dispatch_thread_groups(grid_size, group_size);
encoder.end_encoding();

View File

@ -0,0 +1,211 @@
import Metal
import MetalPerformanceShadersGraph
let type = MTLDataType.float;
let dataType = type;
var B = 2;
var M = 2;
var N = 4;
var K = 3;
var A_trans = false;
var B_trans = false;
var D_trans = false;
var alpha = Float(1.0);
var beta = Float(0.0);
var batched = B > 1;
var fused_activation = false;
var fused_bias = false;
let constants = MTLFunctionConstantValues()
constants.setConstantValue(&M, type: .uint, index: 0)
constants.setConstantValue(&N, type: .uint, index: 1)
constants.setConstantValue(&K, type: .uint, index: 2)
constants.setConstantValue(&A_trans, type: .bool, index: 10)
constants.setConstantValue(&B_trans, type: .bool, index: 11)
constants.setConstantValue(&D_trans, type: .bool, index: 13)
constants.setConstantValue(&alpha, type: .float, index: 20)
constants.setConstantValue(&beta, type: .float, index: 21)
constants.setConstantValue(&batched, type: .bool, index: 100)
constants.setConstantValue(&fused_activation, type: .bool, index: 101)
constants.setConstantValue(&fused_bias, type: .bool, index: 50001)
var M_simd = UInt16(16)
var N_simd = UInt16(16)
var K_simd = UInt16(32)
var M_splits = UInt16(2)
var N_splits = UInt16(2)
constants.setConstantValue(&M_simd, type: .ushort, index: 200)
constants.setConstantValue(&N_simd, type: .ushort, index: 201)
constants.setConstantValue(&K_simd, type: .ushort, index: 202)
constants.setConstantValue(&M_splits, type: .ushort, index: 210)
constants.setConstantValue(&N_splits, type: .ushort, index: 211)
let M_group = M_simd * M_splits
let N_group = N_simd * N_splits
// Satisfy Metal API validation.
#if DEBUG
do {
var garbage: SIMD4<UInt64> = .zero
constants.setConstantValue(&garbage, type: .bool, index: 102)
constants.setConstantValue(&garbage, type: .bool, index: 103)
constants.setConstantValue(&garbage, type: .bool, index: 113)
constants.setConstantValue(&garbage, type: .bool, index: 50000)
}
#endif
print(constants)
let device = MTLCopyAllDevices().first!
device.shouldMaximizeConcurrentCompilation = true
var libraryURL = URL.init(string: "/Users/nicolas/src/candle/candle-metal-kernels/")!;
libraryURL.append(component: "src")
libraryURL.append(component: "libMetalFlashAttention.metallib")
let library = try! device.makeLibrary(URL: libraryURL)
var name: String
switch dataType {
case .half: name = "hgemm"
case .float: name = "sgemm"
default: fatalError()
}
let function = try! library.makeFunction(
name: name, constantValues: constants)
let A_block_length = M_group * K_simd
let B_block_length = K_simd * N_group
var blockElements = A_block_length + B_block_length;
if (M % 8 != 0) && (N % 8 != 0) {
let C_block_length = M_group * N_group;
blockElements = max(C_block_length, blockElements)
}
if fused_bias {
if D_trans {
blockElements = max(blockElements, M_group)
} else {
blockElements = max(blockElements, N_group)
}
}
// let blockBytes = blockElements * UInt16(dataType.size)
let elementSize = 4
let blockBytes = blockElements * UInt16(elementSize)
func ceilDivide(target: Int, granularity: UInt16) -> Int {
(target + Int(granularity) - 1) / Int(granularity)
}
var gridSize = MTLSize(
width: ceilDivide(target: N, granularity: N_group),
height: ceilDivide(target: M, granularity: M_group),
depth: 1)
let groupSize = MTLSize(
width: Int(32 * M_splits * N_splits),
height: 1,
depth: 1)
let commandQueue = device.makeCommandQueue()!
let commandBuffer = commandQueue.makeCommandBuffer()!
let encoder = commandBuffer.makeComputeCommandEncoder(dispatchType: MTLDispatchType.serial)!
let pipeline = try device.makeComputePipelineState(function: function)
let threadgroupMemoryLength = blockBytes;
print(threadgroupMemoryLength)
encoder.setComputePipelineState(pipeline)
encoder.setThreadgroupMemoryLength(Int(threadgroupMemoryLength), index: 0)
let rowsA = M;
let columnsA = K;
let rowsB = K;
let columnsB = N;
let rowsC = M;
let columnsC = N;
var arrayA = [Float](repeating: 0, count: B * rowsA * columnsA)
var arrayB = [Float](repeating: 0, count: B * rowsB * columnsB)
var arrayC = [Float](repeating: 0, count: B * rowsC * columnsC)
for i in 0..<arrayA.count {
arrayA[i] = Float(i)
}
for i in 0..<arrayB.count {
arrayB[i] = Float(i)
}
let bufferA = device.makeBuffer(bytes: arrayA, length: B * rowsA * columnsA * MemoryLayout<Float>.stride, options: [])
let bufferB = device.makeBuffer(bytes: arrayB, length: B * rowsB * columnsB * MemoryLayout<Float>.stride, options: [])
let bufferC = device.makeBuffer(length: B * rowsC * columnsC * MemoryLayout<Float>.stride, options: [])
print(arrayA)
print(arrayB)
encoder.setBuffer(bufferA, offset: 0, index: 0)
encoder.setBuffer(bufferB, offset: 0, index: 1)
encoder.setBuffer(bufferC, offset: 0, index: 2)
var gridZ: Int = B
if batched{
func byteStride(shape: [Int]) -> Int {
let rank = shape.count
var output = elementSize * shape[rank - 2] * shape[rank - 1]
if shape.dropLast(2).reduce(1, *) == 1 {
output = 0
}
return output
}
let byteStrideA = M*K*elementSize
let byteStrideB = N*K*elementSize
let byteStrideC = M*N*elementSize
let byteStrideD = 0
// if let shapeD = tensors.d?.shape {
// let rank = shapeD.count
// byteStrideD = elementSize * shapeD[rank - 1]
// if shapeD.dropLast(1).reduce(1, *) == 1 {
// byteStrideD = 0
// }
// }
withUnsafeTemporaryAllocation(
of: SIMD4<UInt64>.self, capacity: gridZ
) { buffer in
for i in 0..<buffer.count {
buffer[i] = SIMD4(
UInt64(truncatingIfNeeded: i * byteStrideA),
UInt64(truncatingIfNeeded: i * byteStrideB),
UInt64(truncatingIfNeeded: i * byteStrideC),
UInt64(truncatingIfNeeded: i * byteStrideD))
}
let bufferLength = buffer.count * MemoryLayout<SIMD4<UInt64>>.stride
assert(MemoryLayout<SIMD4<UInt64>>.stride == 8 * 4)
encoder.setBytes(buffer.baseAddress!, length: bufferLength, index: 10)
print("BATCHED")
print(buffer)
}
}
gridSize.depth = gridZ
print(gridSize, groupSize)
encoder.dispatchThreadgroups(
gridSize, threadsPerThreadgroup: groupSize
)
encoder.endEncoding()
commandBuffer.commit()
commandBuffer.waitUntilCompleted()
var contents = bufferC!.contents();
var count = B * rowsA * columnsB;
var typedPointer = contents.bindMemory(to: Float.self, capacity: count)
var bufferedPointer = UnsafeBufferPointer(start: typedPointer, count: count)
print(Array(bufferedPointer))

View File

@ -774,6 +774,16 @@ fn run_gemm<T: Clone>(
#[test]
fn gemm() {
let (b, m, n, k) = (1, 2, 4, 3);
let lhs_stride = vec![m * k, k, 1];
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();
let rhs_stride = vec![n * k, n, 1];
let rhs: Vec<f32> = (0..b * n * k).map(|f| f as f32).collect();
let results = run_gemm((b, m, n, k), &lhs, lhs_stride, &rhs, rhs_stride);
assert_eq!(
approx(results, 4),
vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0]
);
let (b, m, n, k) = (2, 2, 4, 3);
let lhs_stride = vec![m * k, k, 1];
let lhs: Vec<f32> = (0..b * m * k).map(|f| f as f32).collect();