Fix for the RWKV models. (#1955)
* Fix for the RWKV models. * More general fix + revert the rwkv hack. * Remove the old hack.
This commit is contained in:
parent
ada5d7c096
commit
b3484e7a5e
|
@ -62,8 +62,8 @@ pub enum CudaError {
|
||||||
|
|
||||||
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
|
#[error("matmul is only supported for contiguous tensors lstride: {lhs_stride:?} rstride: {rhs_stride:?} mnk: {mnk:?}")]
|
||||||
MatMulNonContiguous {
|
MatMulNonContiguous {
|
||||||
lhs_stride: Vec<usize>,
|
lhs_stride: Layout,
|
||||||
rhs_stride: Vec<usize>,
|
rhs_stride: Layout,
|
||||||
mnk: (usize, usize, usize),
|
mnk: (usize, usize, usize),
|
||||||
},
|
},
|
||||||
|
|
||||||
|
@ -1653,28 +1653,28 @@ fn gemm_config<T>(
|
||||||
// The a tensor has dims batching, k, n (rhs)
|
// The a tensor has dims batching, k, n (rhs)
|
||||||
// We also allow for the case where the stride on the minor dimension is not as expected but
|
// We also allow for the case where the stride on the minor dimension is not as expected but
|
||||||
// there is a single element.
|
// there is a single element.
|
||||||
let (lda, transa) = if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) {
|
let (lda, transa) = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
|
||||||
(n as i32, cublasOperation_t::CUBLAS_OP_N)
|
(n as i32, cublasOperation_t::CUBLAS_OP_N)
|
||||||
} else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 {
|
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
|
||||||
(k as i32, cublasOperation_t::CUBLAS_OP_T)
|
(k as i32, cublasOperation_t::CUBLAS_OP_T)
|
||||||
} else {
|
} else {
|
||||||
Err(CudaError::MatMulNonContiguous {
|
Err(CudaError::MatMulNonContiguous {
|
||||||
lhs_stride: lhs_stride.to_vec(),
|
lhs_stride: lhs_l.clone(),
|
||||||
rhs_stride: rhs_stride.to_vec(),
|
rhs_stride: rhs_l.clone(),
|
||||||
mnk: (m, n, k),
|
mnk: (m, n, k),
|
||||||
})?
|
})?
|
||||||
};
|
};
|
||||||
// The b tensor has dims batching, m, k (lhs)
|
// The b tensor has dims batching, m, k (lhs)
|
||||||
// We also allow for the case where the stride on the minor dimension is not as expected but
|
// We also allow for the case where the stride on the minor dimension is not as expected but
|
||||||
// there is a single element.
|
// there is a single element.
|
||||||
let (ldb, transb) = if lhs_m1 == 1 && (lhs_m2 == k || b * m == 1) {
|
let (ldb, transb) = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
|
||||||
(k as i32, cublasOperation_t::CUBLAS_OP_N)
|
(k as i32, cublasOperation_t::CUBLAS_OP_N)
|
||||||
} else if (lhs_m1 == m || b * k == 1) && lhs_m2 == 1 {
|
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
|
||||||
(m as i32, cublasOperation_t::CUBLAS_OP_T)
|
(m as i32, cublasOperation_t::CUBLAS_OP_T)
|
||||||
} else {
|
} else {
|
||||||
Err(CudaError::MatMulNonContiguous {
|
Err(CudaError::MatMulNonContiguous {
|
||||||
lhs_stride: lhs_stride.to_vec(),
|
lhs_stride: lhs_l.clone(),
|
||||||
rhs_stride: rhs_stride.to_vec(),
|
rhs_stride: rhs_l.clone(),
|
||||||
mnk: (m, n, k),
|
mnk: (m, n, k),
|
||||||
})?
|
})?
|
||||||
};
|
};
|
||||||
|
@ -1698,8 +1698,8 @@ fn gemm_config<T>(
|
||||||
[stride] => stride,
|
[stride] => stride,
|
||||||
[] => m * k,
|
[] => m * k,
|
||||||
_ => Err(CudaError::MatMulNonContiguous {
|
_ => Err(CudaError::MatMulNonContiguous {
|
||||||
lhs_stride: lhs_stride.to_vec(),
|
lhs_stride: lhs_l.clone(),
|
||||||
rhs_stride: rhs_stride.to_vec(),
|
rhs_stride: rhs_l.clone(),
|
||||||
mnk: (m, n, k),
|
mnk: (m, n, k),
|
||||||
})?,
|
})?,
|
||||||
};
|
};
|
||||||
|
@ -1708,8 +1708,8 @@ fn gemm_config<T>(
|
||||||
[stride] => stride,
|
[stride] => stride,
|
||||||
[] => n * k,
|
[] => n * k,
|
||||||
_ => Err(CudaError::MatMulNonContiguous {
|
_ => Err(CudaError::MatMulNonContiguous {
|
||||||
lhs_stride: lhs_stride.to_vec(),
|
lhs_stride: lhs_l.clone(),
|
||||||
rhs_stride: rhs_stride.to_vec(),
|
rhs_stride: rhs_l.clone(),
|
||||||
mnk: (m, n, k),
|
mnk: (m, n, k),
|
||||||
})?,
|
})?,
|
||||||
};
|
};
|
||||||
|
|
|
@ -1454,9 +1454,9 @@ pub fn call_gemm(
|
||||||
// lhs has shape b, m, k
|
// lhs has shape b, m, k
|
||||||
// We also allow for the case where the stride on the minor dimension is not as expected but
|
// We also allow for the case where the stride on the minor dimension is not as expected but
|
||||||
// there is a single element.
|
// there is a single element.
|
||||||
let a_trans = if lhs_m1 == 1 && (lhs_m2 == k || b * m == 1) {
|
let a_trans = if (lhs_m1 == 1 || k == 1) && (lhs_m2 == k || m == 1) {
|
||||||
false
|
false
|
||||||
} else if (lhs_m1 == m || b * k == 1) && lhs_m2 == 1 {
|
} else if (lhs_m1 == m || k == 1) && (lhs_m2 == 1 || m == 1) {
|
||||||
true
|
true
|
||||||
} else {
|
} else {
|
||||||
return Err(MetalKernelError::MatMulNonContiguous {
|
return Err(MetalKernelError::MatMulNonContiguous {
|
||||||
|
@ -1466,9 +1466,9 @@ pub fn call_gemm(
|
||||||
})?;
|
})?;
|
||||||
};
|
};
|
||||||
// rhs has shape b, k, n
|
// rhs has shape b, k, n
|
||||||
let b_trans = if rhs_m1 == 1 && (rhs_m2 == n || b * k == 1) {
|
let b_trans = if (rhs_m1 == 1 || n == 1) && (rhs_m2 == n || k == 1) {
|
||||||
false
|
false
|
||||||
} else if (rhs_m1 == k || b * n == 1) && rhs_m2 == 1 {
|
} else if (rhs_m1 == k || n == 1) && (rhs_m2 == 1 || k == 1) {
|
||||||
true
|
true
|
||||||
} else {
|
} else {
|
||||||
return Err(MetalKernelError::MatMulNonContiguous {
|
return Err(MetalKernelError::MatMulNonContiguous {
|
||||||
|
|
Loading…
Reference in New Issue