diff --git a/candle-core/src/cuda_backend.rs b/candle-core/src/cuda_backend.rs index 97dc346e..23487330 100644 --- a/candle-core/src/cuda_backend.rs +++ b/candle-core/src/cuda_backend.rs @@ -62,8 +62,8 @@ pub enum CudaError { #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")] MatMulNonContiguous { - lhs_stride: Vec, - rhs_stride: Vec, + lhs_stride: Layout, + rhs_stride: Layout, mnk: (usize, usize, usize), }, @@ -1653,28 +1653,28 @@ fn gemm_config( // The a tensor has dims batching, k, n (rhs) // We also allow for the case where the stride on the minor dimension is not as expected but // there is a single element. - let (lda, transa) = if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) { + let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { (n as i32, cublasOperation_t::CUBLAS_OP_N) - } else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 { + } else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) { (k as i32, cublasOperation_t::CUBLAS_OP_T) } else { Err(CudaError::MatMulNonContiguous { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), + lhs_stride: lhs_l.clone(), + rhs_stride: rhs_l.clone(), mnk: (m, n, k), })? }; // The b tensor has dims batching, m, k (lhs) // We also allow for the case where the stride on the minor dimension is not as expected but // there is a single element. - let (ldb, transb) = if lhs_m1 == 1 && (lhs_m2 == k || b * m == 1) { + let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { (k as i32, cublasOperation_t::CUBLAS_OP_N) - } else if (lhs_m1 == m || b * k == 1) && lhs_m2 == 1 { + } else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) { (m as i32, cublasOperation_t::CUBLAS_OP_T) } else { Err(CudaError::MatMulNonContiguous { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), + lhs_stride: lhs_l.clone(), + rhs_stride: rhs_l.clone(), mnk: (m, n, k), })? }; @@ -1698,8 +1698,8 @@ fn gemm_config( [stride] => stride, [] => m * k, _ => Err(CudaError::MatMulNonContiguous { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), + lhs_stride: lhs_l.clone(), + rhs_stride: rhs_l.clone(), mnk: (m, n, k), })?, }; @@ -1708,8 +1708,8 @@ fn gemm_config( [stride] => stride, [] => n * k, _ => Err(CudaError::MatMulNonContiguous { - lhs_stride: lhs_stride.to_vec(), - rhs_stride: rhs_stride.to_vec(), + lhs_stride: lhs_l.clone(), + rhs_stride: rhs_l.clone(), mnk: (m, n, k), })?, }; diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index 3f452331..140927e3 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -1454,9 +1454,9 @@ pub fn call_gemm( // lhs has shape b, m, k // We also allow for the case where the stride on the minor dimension is not as expected but // there is a single element. - let a_trans = if lhs_m1 == 1 && (lhs_m2 == k || b * m == 1) { + let a_trans = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) { false - } else if (lhs_m1 == m || b * k == 1) && lhs_m2 == 1 { + } else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) { true } else { return Err(MetalKernelError::MatMulNonContiguous { @@ -1466,9 +1466,9 @@ pub fn call_gemm( })?; }; // rhs has shape b, k, n - let b_trans = if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) { + let b_trans = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) { false - } else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 { + } else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) { true } else { return Err(MetalKernelError::MatMulNonContiguous {