Fix for the RWKV models. (#1955)

* Fix for the RWKV models.

* More general fix + revert the rwkv hack.

* Remove the old hack.
This commit is contained in:
Laurent Mazare 2024-03-28 10:17:38 +01:00 committed by GitHub
parent ada5d7c096
commit b3484e7a5e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 18 additions and 18 deletions

View File

@ -62,8 +62,8 @@ pub enum CudaError {
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
MatMulNonContiguous {
lhs_stride: Vec<usize>,
rhs_stride: Vec<usize>,
lhs_stride: Layout,
rhs_stride: Layout,
mnk: (usize, usize, usize),
},
@ -1653,28 +1653,28 @@ fn gemm_config<T>(
// The `a` operand is the rhs tensor, with dims (batch, k, n).
// We also accept a stride that differs from the expected value on either of the last two
// dimensions when that dimension holds a single element.
let (lda, transa) = if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) {
let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
(n as i32, cublasOperation_t::CUBLAS_OP_N)
} else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 {
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
(k as i32, cublasOperation_t::CUBLAS_OP_T)
} else {
Err(CudaError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
lhs_stride: lhs_l.clone(),
rhs_stride: rhs_l.clone(),
mnk: (m, n, k),
})?
};
// The `b` operand is the lhs tensor, with dims (batch, m, k).
// We also accept a stride that differs from the expected value on either of the last two
// dimensions when that dimension holds a single element.
let (ldb, transb) = if lhs_m1 == 1 && (lhs_m2 == k || b * m == 1) {
let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
(k as i32, cublasOperation_t::CUBLAS_OP_N)
} else if (lhs_m1 == m || b * k == 1) && lhs_m2 == 1 {
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
(m as i32, cublasOperation_t::CUBLAS_OP_T)
} else {
Err(CudaError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
lhs_stride: lhs_l.clone(),
rhs_stride: rhs_l.clone(),
mnk: (m, n, k),
})?
};
@ -1698,8 +1698,8 @@ fn gemm_config<T>(
[stride] => stride,
[] => m * k,
_ => Err(CudaError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
lhs_stride: lhs_l.clone(),
rhs_stride: rhs_l.clone(),
mnk: (m, n, k),
})?,
};
@ -1708,8 +1708,8 @@ fn gemm_config<T>(
[stride] => stride,
[] => n * k,
_ => Err(CudaError::MatMulNonContiguous {
lhs_stride: lhs_stride.to_vec(),
rhs_stride: rhs_stride.to_vec(),
lhs_stride: lhs_l.clone(),
rhs_stride: rhs_l.clone(),
mnk: (m, n, k),
})?,
};

View File

@ -1454,9 +1454,9 @@ pub fn call_gemm(
// lhs has shape b, m, k
// We also accept a stride that differs from the expected value on either of the last two
// dimensions when that dimension holds a single element.
let a_trans = if lhs_m1 == 1 && (lhs_m2 == k || b * m == 1) {
let a_trans = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
false
} else if (lhs_m1 == m || b * k == 1) && lhs_m2 == 1 {
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
true
} else {
return Err(MetalKernelError::MatMulNonContiguous {
@ -1466,9 +1466,9 @@ pub fn call_gemm(
})?;
};
// rhs has shape b, k, n
let b_trans = if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) {
let b_trans = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
false
} else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 {
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
true
} else {
return Err(MetalKernelError::MatMulNonContiguous {