Fix for the RWKV models. (#1955)

* Fix for the RWKV models.

* More general fix + revert the rwkv hack.

* Remove the old hack.
Laurent Mazare authored 2024-03-28 10:17:38 +01:00, committed by GitHub
parent ada5d7c096
commit b3484e7a5e
2 changed files with 18 additions and 18 deletions


@@ -62,8 +62,8 @@ pub enum CudaError {
     #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
     MatMulNonContiguous {
-        lhs_stride: Vec<usize>,
-        rhs_stride: Vec<usize>,
+        lhs_stride: Layout,
+        rhs_stride: Layout,
         mnk: (usize, usize, usize),
     },
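A side note on the field type change: candle's Layout carries the shape together with the strides, so the error now reports both instead of a bare stride vector. A minimal self-contained sketch of the pattern used here (the Layout struct below is a stand-in for illustration, not the real candle type):

use thiserror::Error;

// Stand-in for candle's Layout; assumed to carry at least shape and strides.
#[derive(Debug, Clone)]
struct Layout {
    shape: Vec<usize>,
    stride: Vec<usize>,
}

#[derive(Debug, Error)]
enum GemmError {
    #[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
    MatMulNonContiguous {
        lhs_stride: Layout,
        rhs_stride: Layout,
        mnk: (usize, usize, usize),
    },
}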
@@ -1653,28 +1653,28 @@ fn gemm_config<T>(
     // The a tensor has dims batching, k, n (rhs)
     // We also allow for the case where the stride on the minor dimension is not as expected but
     // there is a single element.
-    let (lda, transa) = if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) {
+    let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
         (n as i32, cublasOperation_t::CUBLAS_OP_N)
-    } else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 {
+    } else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
         (k as i32, cublasOperation_t::CUBLAS_OP_T)
     } else {
         Err(CudaError::MatMulNonContiguous {
-            lhs_stride: lhs_stride.to_vec(),
-            rhs_stride: rhs_stride.to_vec(),
+            lhs_stride: lhs_l.clone(),
+            rhs_stride: rhs_l.clone(),
             mnk: (m, n, k),
         })?
     };
     // The b tensor has dims batching, m, k (lhs)
     // We also allow for the case where the stride on the minor dimension is not as expected but
     // there is a single element.
-    let (ldb, transb) = if lhs_m1 == 1 && (lhs_m2 == k || b * m == 1) {
+    let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
         (k as i32, cublasOperation_t::CUBLAS_OP_N)
-    } else if (lhs_m1 == m || b * k == 1) && lhs_m2 == 1 {
+    } else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
         (m as i32, cublasOperation_t::CUBLAS_OP_T)
     } else {
         Err(CudaError::MatMulNonContiguous {
-            lhs_stride: lhs_stride.to_vec(),
-            rhs_stride: rhs_stride.to_vec(),
+            lhs_stride: lhs_l.clone(),
+            rhs_stride: rhs_l.clone(),
             mnk: (m, n, k),
         })?
     };
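The relaxed conditions can be read as a small decision function over one operand. A minimal free-standing sketch (hypothetical helper, not the candle API) for a rows x cols operand whose last two strides are (s_outer, s_inner): a stride over a dimension of size 1 is never dereferenced, so it is allowed to be anything.

// Sketch of the relaxed contiguity check: returns the leading dimension and
// whether cuBLAS should treat the operand as transposed; None means the
// layout is genuinely non-contiguous for GEMM purposes.
fn gemm_operand(rows: usize, cols: usize, s_outer: usize, s_inner: usize) -> Option<(usize, bool)> {
    if (s_inner == 1 || cols == 1) && (s_outer == cols || rows == 1) {
        // Plain row-major storage: leading dimension is the number of columns.
        Some((cols, false))
    } else if (s_inner == rows || cols == 1) && (s_outer == 1 || rows == 1) {
        // Column-major storage: pass it to the kernel as a transposed operand.
        Some((rows, true))
    } else {
        None
    }
}

With cols == 1 (the degenerate n == 1 / k == 1 shapes the RWKV graphs produce) the first branch now matches regardless of s_inner, whereas the old check insisted on the innermost stride being exactly 1 or on a unit batch.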
@@ -1698,8 +1698,8 @@ fn gemm_config<T>(
         [stride] => stride,
         [] => m * k,
         _ => Err(CudaError::MatMulNonContiguous {
-            lhs_stride: lhs_stride.to_vec(),
-            rhs_stride: rhs_stride.to_vec(),
+            lhs_stride: lhs_l.clone(),
+            rhs_stride: rhs_l.clone(),
             mnk: (m, n, k),
         })?,
     };
@@ -1708,8 +1708,8 @@ fn gemm_config<T>(
         [stride] => stride,
         [] => n * k,
         _ => Err(CudaError::MatMulNonContiguous {
-            lhs_stride: lhs_stride.to_vec(),
-            rhs_stride: rhs_stride.to_vec(),
+            lhs_stride: lhs_l.clone(),
+            rhs_stride: rhs_l.clone(),
             mnk: (m, n, k),
         })?,
     };
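The match arms in these two hunks pick the stride between consecutive batches from the strides of the leading (batch) dimensions; the no-batch fallback is m * k for the lhs and n * k for the rhs. A rough sketch of that logic (hypothetical helper, not candle's exact code):

// Choose the batch stride for one operand of shape (batch.., rows, cols),
// given only the strides of the batch dimensions. One batch dimension uses
// its stride directly; no batch dimension means a single matrix, so the
// "stride" is simply the matrix size; anything else is rejected above.
fn batch_stride(batch_strides: &[usize], rows: usize, cols: usize) -> Option<usize> {
    match batch_strides {
        [stride] => Some(*stride),
        [] => Some(rows * cols),
        _ => None,
    }
}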


@@ -1454,9 +1454,9 @@ pub fn call_gemm(
     // lhs has shape b, m, k
     // We also allow for the case where the stride on the minor dimension is not as expected but
     // there is a single element.
-    let a_trans = if lhs_m1 == 1 && (lhs_m2 == k || b * m == 1) {
+    let a_trans = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
         false
-    } else if (lhs_m1 == m || b * k == 1) && lhs_m2 == 1 {
+    } else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
         true
     } else {
         return Err(MetalKernelError::MatMulNonContiguous {
@@ -1466,9 +1466,9 @@ pub fn call_gemm(
         })?;
     };
     // rhs has shape b, k, n
-    let b_trans = if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) {
+    let b_trans = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
         false
-    } else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 {
+    } else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
         true
     } else {
         return Err(MetalKernelError::MatMulNonContiguous {
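To make the behavioral change concrete, here is a hedged, self-contained comparison of the old and new lhs checks (free functions with made-up names; the real code is the gemm_config / call_gemm logic above). The new form accepts layouts where a size-1 dimension carries an arbitrary stride, which is the kind of layout the RWKV models produce:

// lhs has shape (b, m, k); lhs_m2 / lhs_m1 are the strides of its last two dims.
fn old_lhs_ok(b: usize, m: usize, k: usize, lhs_m2: usize, lhs_m1: usize) -> bool {
    (lhs_m1 == 1 && (lhs_m2 == k || b * m == 1))
        || ((lhs_m1 == m || b * k == 1) && lhs_m2 == 1)
}

fn new_lhs_ok(m: usize, k: usize, lhs_m2: usize, lhs_m1: usize) -> bool {
    ((lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1))
        || ((lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1))
}

fn main() {
    // A batched lhs of shape (2, 64, 1) whose size-1 dimension kept a stray
    // stride (7 here, e.g. left over from a narrow): rejected by the old
    // check, accepted by the new one since that stride is never dereferenced.
    assert!(!old_lhs_ok(2, 64, 1, 1, 7));
    assert!(new_lhs_ok(64, 1, 1, 7));
}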