Handle the batch dimension in quantized MMV on metal. (#2022)

This commit is contained in:
Laurent Mazare 2024-04-06 20:02:24 +02:00 committed by GitHub
parent e662431acf
commit 9fd52b3b71
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 1 deletions

View File

@ -149,8 +149,11 @@ impl QMetalStorage {
let (n, k) = self_shape.dims2()?;
let mut dst_shape = src_shape.dims().to_vec();
// We always use a single batch dimension and stack all the tensors in the batch on the
// second dimension as the implementation in candle-metal-kernels doesn't handle batch
// properly.
let (b, m) = match dst_shape.len() {
3 => (dst_shape[0], dst_shape[1]),
3 => (1, dst_shape[0] * dst_shape[1]),
2 => (1, dst_shape[0]),
n => crate::bail!("Invalid rank {n} for quantized matmul metal"),
};