More accelerate optimizations (#427)
* Add more tracing to the whisper example.
* Support accelerate in more examples.
* Use accelerate for pointwise functions.
* Use accelerate for binary operations too.
* Bugfix for binary operations: use the rhs before the lhs.

parent 60cd1551ca
commit 9aca398a4f
One of the examples pulls in the accelerate crate next to the existing mkl hook:

```diff
@@ -1,3 +1,6 @@
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
```
A second example gets the same change:

```diff
@@ -1,6 +1,9 @@
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
 use std::str::FromStr;
 
 use anyhow::Result;
```
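Like `intel_mkl_src`, the `accelerate_src` crate only contributes link-time symbols, so every binary that should use the Accelerate framework has to name it in an `extern crate` item or the linker may drop it. A minimal sketch of the boilerplate an example needs to support both backends (feature names follow the diffs above):

```rust
// Link Apple's Accelerate framework when the `accelerate` feature is on.
#[cfg(feature = "accelerate")]
extern crate accelerate_src;

// Link Intel MKL when the `mkl` feature is on.
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;
```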
In the accelerate bindings (candle-core's accelerate.rs), the libc imports grow the types needed by the vDSP signatures:

```diff
@@ -1,5 +1,5 @@
 #![allow(dead_code)]
-use libc::{c_char, c_double, c_float, c_int};
+use libc::{c_char, c_double, c_float, c_int, c_long, c_ulong};
 
 mod ffi {
     use super::*;
```
Bindings for the vForce pointwise functions (vv*) and the vDSP element-wise binary kernels are added to the extern block:

```diff
@@ -39,6 +39,90 @@ mod ffi {
             c: *mut c_double,
             ldc: *const c_int,
         );
+
+        pub fn vvexpf(dst: *mut c_float, src: *const c_float, len: *const c_int);
+        pub fn vvexp(dst: *mut c_double, src: *const c_double, len: *const c_int);
+        pub fn vvsqrtf(dst: *mut c_float, src: *const c_float, len: *const c_int);
+        pub fn vvsqrt(dst: *mut c_double, src: *const c_double, len: *const c_int);
+        pub fn vvsinf(dst: *mut c_float, src: *const c_float, len: *const c_int);
+        pub fn vvsin(dst: *mut c_double, src: *const c_double, len: *const c_int);
+        pub fn vvcosf(dst: *mut c_float, src: *const c_float, len: *const c_int);
+        pub fn vvcos(dst: *mut c_double, src: *const c_double, len: *const c_int);
+        pub fn vvlogf(dst: *mut c_float, src: *const c_float, len: *const c_int);
+        pub fn vvlog(dst: *mut c_double, src: *const c_double, len: *const c_int);
+
+        pub fn vDSP_vaddD(
+            _: *const c_double,
+            _: c_long,
+            _: *const c_double,
+            _: c_long,
+            _: *mut c_double,
+            _: c_long,
+            _: c_ulong,
+        );
+        pub fn vDSP_vadd(
+            _: *const c_float,
+            _: c_long,
+            _: *const c_float,
+            _: c_long,
+            _: *mut c_float,
+            _: c_long,
+            _: c_ulong,
+        );
+        pub fn vDSP_vsubD(
+            _: *const c_double,
+            _: c_long,
+            _: *const c_double,
+            _: c_long,
+            _: *mut c_double,
+            _: c_long,
+            _: c_ulong,
+        );
+        pub fn vDSP_vsub(
+            _: *const c_float,
+            _: c_long,
+            _: *const c_float,
+            _: c_long,
+            _: *mut c_float,
+            _: c_long,
+            _: c_ulong,
+        );
+        pub fn vDSP_vmulD(
+            _: *const c_double,
+            _: c_long,
+            _: *const c_double,
+            _: c_long,
+            _: *mut c_double,
+            _: c_long,
+            _: c_ulong,
+        );
+        pub fn vDSP_vmul(
+            _: *const c_float,
+            _: c_long,
+            _: *const c_float,
+            _: c_long,
+            _: *mut c_float,
+            _: c_long,
+            _: c_ulong,
+        );
+        pub fn vDSP_vdivD(
+            _: *const c_double,
+            _: c_long,
+            _: *const c_double,
+            _: c_long,
+            _: *mut c_double,
+            _: c_long,
+            _: c_ulong,
+        );
+        pub fn vDSP_vdiv(
+            _: *const c_float,
+            _: c_long,
+            _: *const c_float,
+            _: c_long,
+            _: *mut c_float,
+            _: c_long,
+            _: c_ulong,
+        );
     }
 }
```
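These declarations mirror Accelerate's C entry points: the destination pointer comes first for the vv* family and the length is passed by reference. A minimal usage sketch against the raw binding (assuming macOS with the framework linked via `accelerate_src`):

```rust
// Exponentiate a slice in one call through the vForce binding above.
let src = [0.0f32, 1.0, 2.0];
let mut dst = [0.0f32; 3];
let len = src.len() as i32;
unsafe { ffi::vvexpf(dst.as_mut_ptr(), src.as_ptr(), &len) };
// dst is now approximately [1.0, 2.7183, 7.3891]
```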
Safe wrappers over the new bindings follow the existing sgemm/dgemm helpers, with a macro generating the binary operations:

```diff
@@ -109,3 +193,158 @@ pub unsafe fn dgemm(
         &ldc,
     )
 }
+
+#[inline]
+pub fn vs_exp(a: &[f32], y: &mut [f32]) {
+    let a_len = a.len();
+    let y_len = y.len();
+    if a_len != y_len {
+        panic!("a and y have different lengths {a_len} <> {y_len}")
+    }
+    unsafe { ffi::vvexpf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+
+#[inline]
+pub fn vd_exp(a: &[f64], y: &mut [f64]) {
+    let a_len = a.len();
+    let y_len = y.len();
+    if a_len != y_len {
+        panic!("a and y have different lengths {a_len} <> {y_len}")
+    }
+    unsafe { ffi::vvexp(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+
+#[inline]
+pub fn vs_sqrt(a: &[f32], y: &mut [f32]) {
+    let a_len = a.len();
+    let y_len = y.len();
+    if a_len != y_len {
+        panic!("a and y have different lengths {a_len} <> {y_len}")
+    }
+    unsafe { ffi::vvsqrtf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+
+#[inline]
+pub fn vd_sqrt(a: &[f64], y: &mut [f64]) {
+    let a_len = a.len();
+    let y_len = y.len();
+    if a_len != y_len {
+        panic!("a and y have different lengths {a_len} <> {y_len}")
+    }
+    unsafe { ffi::vvsqrt(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+
+#[inline]
+pub fn vs_sin(a: &[f32], y: &mut [f32]) {
+    let a_len = a.len();
+    let y_len = y.len();
+    if a_len != y_len {
+        panic!("a and y have different lengths {a_len} <> {y_len}")
+    }
+    unsafe { ffi::vvsinf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+
+#[inline]
+pub fn vd_sin(a: &[f64], y: &mut [f64]) {
+    let a_len = a.len();
+    let y_len = y.len();
+    if a_len != y_len {
+        panic!("a and y have different lengths {a_len} <> {y_len}")
+    }
+    unsafe { ffi::vvsin(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+
+#[inline]
+pub fn vs_cos(a: &[f32], y: &mut [f32]) {
+    let a_len = a.len();
+    let y_len = y.len();
+    if a_len != y_len {
+        panic!("a and y have different lengths {a_len} <> {y_len}")
+    }
+    unsafe { ffi::vvcosf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+
+#[inline]
+pub fn vd_cos(a: &[f64], y: &mut [f64]) {
+    let a_len = a.len();
+    let y_len = y.len();
+    if a_len != y_len {
+        panic!("a and y have different lengths {a_len} <> {y_len}")
+    }
+    unsafe { ffi::vvcos(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+
+#[inline]
+pub fn vs_ln(a: &[f32], y: &mut [f32]) {
+    let a_len = a.len();
+    let y_len = y.len();
+    if a_len != y_len {
+        panic!("a and y have different lengths {a_len} <> {y_len}")
+    }
+    unsafe { ffi::vvlogf(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+
+#[inline]
+pub fn vd_ln(a: &[f64], y: &mut [f64]) {
+    let a_len = a.len();
+    let y_len = y.len();
+    if a_len != y_len {
+        panic!("a and y have different lengths {a_len} <> {y_len}")
+    }
+    unsafe { ffi::vvlog(y.as_mut_ptr(), a.as_ptr(), &(a_len as i32)) }
+}
+
+#[inline]
+pub fn vs_sqr(a: &[f32], y: &mut [f32]) {
+    let a_len = a.len();
+    let y_len = y.len();
+    if a_len != y_len {
+        panic!("a and y have different lengths {a_len} <> {y_len}")
+    }
+    y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a)
+}
+
+#[inline]
+pub fn vd_sqr(a: &[f64], y: &mut [f64]) {
+    let a_len = a.len();
+    let y_len = y.len();
+    if a_len != y_len {
+        panic!("a and y have different lengths {a_len} <> {y_len}")
+    }
+    y.iter_mut().zip(a.iter()).for_each(|(y, a)| *y = *a * *a)
+}
+
+macro_rules! binary_op {
+    ($fn_name:ident, $ty:ty, $accelerate_name:ident) => {
+        #[inline]
+        pub fn $fn_name(a: &[$ty], b: &[$ty], y: &mut [$ty]) {
+            let a_len = a.len();
+            let b_len = b.len();
+            let y_len = y.len();
+            if a_len != y_len || b_len != y_len {
+                panic!(
+                    "{} a,b,y len mismatch {a_len} {b_len} {y_len}",
+                    stringify!($fn_name)
+                );
+            }
+            unsafe {
+                // Weird quirk of accelerate, the rhs comes before the lhs.
+                ffi::$accelerate_name(
+                    b.as_ptr(),
+                    1,
+                    a.as_ptr(),
+                    1,
+                    y.as_mut_ptr(),
+                    1,
+                    a_len as u64,
+                )
+            }
+        }
+    };
+}
+binary_op!(vs_add, f32, vDSP_vadd);
+binary_op!(vd_add, f64, vDSP_vaddD);
+binary_op!(vs_sub, f32, vDSP_vsub);
+binary_op!(vd_sub, f64, vDSP_vsubD);
+binary_op!(vs_mul, f32, vDSP_vmul);
+binary_op!(vd_mul, f64, vDSP_vmulD);
+binary_op!(vs_div, f32, vDSP_vdiv);
+binary_op!(vd_div, f64, vDSP_vdivD);
```
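The "rhs comes before the lhs" comment is the bugfix called out in the commit message: Apple documents the vDSP binary kernels with the right-hand operand first (vDSP_vsub computes C = A - B with B as the first pointer), so the wrapper passes `b` before `a` to make subtraction and division come out in the expected order. A small usage sketch:

```rust
// With the argument order above, vs_sub computes element-wise a - b
// and vs_div computes a / b, not the other way around.
let a = [3.0f32, 8.0];
let b = [1.0f32, 2.0];
let mut y = [0.0f32; 2];
vs_sub(&a, &b, &mut y); // y == [2.0, 6.0]
vs_div(&a, &b, &mut y); // y == [3.0, 4.0]
```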
In the CPU backend, the bin_op! macro gains an accelerate branch mirroring the mkl one:

```diff
@@ -338,6 +338,21 @@ macro_rules! bin_op {
             fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
                 crate::mkl::$f64_vec(xs1, xs2, ys)
             }
 
+            #[cfg(feature = "accelerate")]
+            const F32_VEC: bool = true;
+            #[cfg(feature = "accelerate")]
+            const F64_VEC: bool = true;
+            #[cfg(feature = "accelerate")]
+            #[inline(always)]
+            fn f32_vec(xs1: &[f32], xs2: &[f32], ys: &mut [f32]) {
+                crate::accelerate::$f32_vec(xs1, xs2, ys)
+            }
+            #[cfg(feature = "accelerate")]
+            #[inline(always)]
+            fn f64_vec(xs1: &[f64], xs2: &[f64], ys: &mut [f64]) {
+                crate::accelerate::$f64_vec(xs1, xs2, ys)
+            }
         }
     };
 }
```
And the same for unary_op!:

```diff
@@ -424,6 +439,21 @@ macro_rules! unary_op {
             fn f64_vec(xs: &[f64], ys: &mut [f64]) {
                 crate::mkl::$f64_vec(xs, ys)
             }
 
+            #[cfg(feature = "accelerate")]
+            const F32_VEC: bool = true;
+            #[cfg(feature = "accelerate")]
+            const F64_VEC: bool = true;
+            #[cfg(feature = "accelerate")]
+            #[inline(always)]
+            fn f32_vec(xs: &[f32], ys: &mut [f32]) {
+                crate::accelerate::$f32_vec(xs, ys)
+            }
+            #[cfg(feature = "accelerate")]
+            #[inline(always)]
+            fn f64_vec(xs: &[f64], ys: &mut [f64]) {
+                crate::accelerate::$f64_vec(xs, ys)
+            }
         }
     };
 }
```
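Both macros advertise the vectorized path through the F32_VEC/F64_VEC constants; the op traits presumably default these to false and fall back to a scalar loop. A simplified sketch of that dispatch pattern (illustrative names, not the exact candle internals):

```rust
// Each op states whether a whole-slice kernel exists; the map function
// picks the vendor path only when the constant says so.
trait UnaryOpT {
    const F32_VEC: bool = false;
    fn f32(x: f32) -> f32;
    fn f32_vec(_xs: &[f32], _ys: &mut [f32]) {
        unreachable!("no vectorized f32 path for this op")
    }
}

fn unary_map<O: UnaryOpT>(xs: &[f32], ys: &mut [f32]) {
    if O::F32_VEC {
        // One call into mkl/accelerate for the whole slice.
        O::f32_vec(xs, ys)
    } else {
        for (y, &x) in ys.iter_mut().zip(xs) {
            *y = O::f32(x)
        }
    }
}
```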
The gradient tests switch to rounded comparisons, since the accelerate kernels can differ from the scalar path in the low-order bits:

```diff
@@ -85,8 +85,14 @@ fn unary_grad(device: &Device) -> Result<()> {
     let y = (x.log()? + 1.)?;
     let grads = y.backward()?;
     let grad_x = grads.get(x).context("no grad for x")?;
-    assert_eq!(y.to_vec1::<f32>()?, [2.0986123, 1.0, 2.3862944, -0.89712]);
-    assert_eq!(grad_x.to_vec1::<f32>()?, [0.33333334, 1.0, 0.25, 6.6666665]);
+    assert_eq!(
+        test_utils::to_vec1_round(&y, 4)?,
+        [2.0986, 1.0, 2.3863, -0.8971]
+    );
+    assert_eq!(
+        test_utils::to_vec1_round(grad_x, 4)?,
+        [0.3333, 1.0, 0.25, 6.6667]
+    );
     let y = x.exp()?;
     let grads = y.backward()?;
     let grad_x = grads.get(x).context("no grad for x")?;
```
```diff
@@ -141,7 +147,7 @@ fn unary_grad(device: &Device) -> Result<()> {
     let grads = y.backward()?;
     let grad_x = grads.get(x).context("no grad for x")?;
     assert_eq!(y.to_vec1::<f32>()?, [3.0, 1.0, 4.0, 0.15]);
-    assert_eq!(grad_x.to_vec1::<f32>()?, [1.0, 1.0, 1.0, 1.0]);
+    assert_eq!(test_utils::to_vec1_round(grad_x, 4)?, [1.0, 1.0, 1.0, 1.0]);
     let y = x.neg()?;
     let grads = y.backward()?;
     let grad_x = grads.get(x).context("no grad for x")?;
```
```diff
@@ -155,7 +161,10 @@ fn unary_grad(device: &Device) -> Result<()> {
     let y = Tensor::new(1f32, device)?.broadcast_div(x)?;
     let grads = y.backward()?;
     let grad_x = grads.get(x).context("no grad for x")?;
-    assert_eq!(y.to_vec1::<f32>()?, [0.33333334, 1.0, 0.25, 6.6666665]);
+    assert_eq!(
+        test_utils::to_vec1_round(&y, 4)?,
+        [0.3333, 1.0, 0.25, 6.6667]
+    );
     assert_eq!(
         grad_x.to_vec1::<f32>()?,
         [-0.11111111, -1.0, -0.0625, -44.444443],
```
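`test_utils::to_vec1_round` presumably rounds each element to the given number of decimal digits before the comparison. A sketch of such a helper (the real signature may differ):

```rust
use candle_core::{Result, Tensor};

// Round every element to `digits` decimal places so that backend-dependent
// last-bit differences do not break exact equality asserts.
pub fn to_vec1_round(t: &Tensor, digits: i32) -> Result<Vec<f32>> {
    let scale = 10f32.powi(digits);
    let v = t.to_vec1::<f32>()?;
    Ok(v.into_iter().map(|x| (x * scale).round() / scale).collect())
}
```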
The shared test utilities link accelerate as well:

```diff
@@ -1,5 +1,8 @@
 #![allow(dead_code)]
 
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
 use candle_core::{Result, Tensor};
 
 #[macro_export]
```
So does another example entry point:

```diff
@@ -1,5 +1,8 @@
 // TODO: Add an offline mode.
 
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
```
In the whisper example's main.rs, the kv-cache TODO is dropped (the attention layer below already carries a kv_cache) and accelerate is linked:

```diff
@@ -1,9 +1,11 @@
 // https://github.com/openai/whisper/blob/main/whisper/model.py
 // TODO:
-// - kv-cache support?
 // - Batch size greater than 1.
 // - More token filters (SuppressBlanks, ApplyTimestampRules).
 
+#[cfg(feature = "accelerate")]
+extern crate accelerate_src;
+
 #[cfg(feature = "mkl")]
 extern crate intel_mkl_src;
 
```
model.rs imports IndexOp for the range indexing used below:

```diff
@@ -1,4 +1,4 @@
-use candle::{Device, Result, Tensor};
+use candle::{Device, IndexOp, Result, Tensor};
 use candle_nn::{ops::softmax, Conv1d, Conv1dConfig, Embedding, LayerNorm, VarBuilder};
 use serde::Deserialize;
 
```
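`IndexOp::i` with a tuple of ranges is shorthand for the chained `narrow` calls it replaces in the mask handling further down; both take the leading n_ctx x n_ctx block. A small equivalence sketch:

```rust
use candle::IndexOp;

// Two equivalent views of the top-left n_ctx x n_ctx block of `mask`:
let m1 = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
let m2 = mask.i((0..n_ctx, 0..n_ctx))?;
assert_eq!(m1.to_vec2::<f32>()?, m2.to_vec2::<f32>()?);
```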
MultiHeadAttention gains dedicated spans for its softmax and matmuls:

```diff
@@ -105,12 +105,16 @@ struct MultiHeadAttention {
     out: Linear,
     n_head: usize,
     span: tracing::Span,
+    softmax_span: tracing::Span,
+    matmul_span: tracing::Span,
     kv_cache: Option<(Tensor, Tensor)>,
 }
 
 impl MultiHeadAttention {
     fn load(n_state: usize, n_head: usize, vb: VarBuilder) -> Result<Self> {
         let span = tracing::span!(tracing::Level::TRACE, "multi-head-attn");
+        let softmax_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-softmax");
+        let matmul_span = tracing::span!(tracing::Level::TRACE, "multi-head-attn-matmul");
         let query = linear(n_state, n_state, vb.pp("q_proj"))?;
         let value = linear(n_state, n_state, vb.pp("v_proj"))?;
         let key = linear_no_bias(n_state, n_state, vb.pp("k_proj"))?;
```
```diff
@@ -122,6 +126,8 @@ impl MultiHeadAttention {
             out,
             n_head,
             span,
+            softmax_span,
+            matmul_span,
             kv_cache: None,
         })
     }
```
The forward path then wraps each matmul and the softmax in its span, and the mask slicing moves to IndexOp:

```diff
@@ -178,13 +184,24 @@ impl MultiHeadAttention {
         let q = (self.reshape_head(q)? * scale)?;
         let k = (self.reshape_head(k)?.transpose(2, 3)? * scale)?;
         let v = self.reshape_head(v)?.contiguous()?;
-        let mut qk = q.matmul(&k)?;
+        let mut qk = {
+            let _enter = self.matmul_span.enter();
+            q.matmul(&k)?
+        };
         if let Some(mask) = mask {
-            let mask = mask.narrow(0, 0, n_ctx)?.narrow(1, 0, n_ctx)?;
+            let mask = mask.i((0..n_ctx, 0..n_ctx))?;
             qk = qk.broadcast_add(&mask)?
         }
-        let w = softmax(&qk, candle::D::Minus1)?;
-        let wv = w.matmul(&v)?.transpose(1, 2)?.flatten_from(2)?;
+        let w = {
+            let _enter = self.softmax_span.enter();
+            softmax(&qk, candle::D::Minus1)?
+        };
+        let wv = {
+            let _enter = self.matmul_span.enter();
+            w.matmul(&v)?
+        }
+        .transpose(1, 2)?
+        .flatten_from(2)?;
         Ok(wv)
     }
 }
```
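These spans are inert until a tracing subscriber is installed. The whisper example wires one up roughly like this (a sketch using the tracing-chrome crate, which candle's traced examples rely on; flag handling omitted):

```rust
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

// Writes a trace-<timestamp>.json that chrome://tracing or
// https://ui.perfetto.dev can open; keep the guard alive while
// the traced code runs.
let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();

// Run the model: every span.enter() above becomes a slice in the timeline.
```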