Faster mask implementation for mixformers. (#2017)

* Faster mask implementation for mixformers.

* Clippy.
This commit is contained in:
Laurent Mazare 2024-04-05 09:38:26 +02:00 committed by GitHub
parent 88f7793598
commit b869a659ec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 6 additions and 21 deletions

View File

@ -126,20 +126,11 @@ impl Module for Embedding {
}
}
// NOTE(review): this hunk shows pre-commit and post-commit lines interleaved, with the diff
// `+`/`-` markers stripped. Going by the hunk header (20 old lines -> 11 new lines), the lines
// that take `dtype` are the additions, and `masked_fill` is deleted outright.
//
// Old signature of `get_mask`: no target dtype was passed in.
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
// New signature of `get_mask`: `dtype` is the element type the returned mask is converted to.
fn get_mask(size: usize, dtype: DType, device: &Device) -> Result<Tensor> {
// Flattened `size * size` values in row-major order; `i` is the row and `j` the column.
let mask: Vec<_> = (0..size)
// Old body: a u8 flag, 1 where `j > i` and 0 elsewhere.
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
// New body: an additive f32 mask, negative infinity where `j > i` and `0.` elsewhere.
.flat_map(|i| (0..size).map(move |j| if j > i { f32::NEG_INFINITY } else { 0. }))
.collect();
// Old tail of `get_mask`: wrap the u8 values in a `(size, size)` tensor on `device`.
Tensor::from_slice(&mask, (size, size), device)
}
// Removed helper. It turns the scalar `on_true` into a tensor on `on_false`'s device, converts it
// to `on_false`'s dtype, broadcasts it to `mask`'s shape, and combines the two via
// `mask.where_cond`.
// NOTE(review): presumably `where_cond` yields `on_true` where `mask` is non-zero and `on_false`
// elsewhere -- confirm against the candle documentation.
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
let shape = mask.shape();
let on_true = Tensor::new(on_true, on_false.device())?
.to_dtype(on_false.dtype())?
.broadcast_as(shape.dims())?;
let m = mask.where_cond(&on_true, on_false)?;
Ok(m)
// New tail of `get_mask`: build the `(size, size)` tensor on `device`, then convert it to
// `dtype`; an error from either step is returned to the caller.
Tensor::from_slice(&mask, (size, size), device)?.to_dtype(dtype)
}
#[derive(Debug, Clone)]
@ -252,7 +243,6 @@ struct MHA {
rotary_emb: RotaryEmbedding,
kv_cache: Option<(Tensor, Tensor)>,
head_dim: usize,
n_head: usize,
softmax_scale: f64,
span: tracing::Span,
span_rope: tracing::Span,
@ -273,7 +263,6 @@ impl MHA {
wqkv,
out_proj,
head_dim,
n_head: cfg.n_head,
kv_cache: None,
rotary_emb,
softmax_scale,
@ -321,11 +310,7 @@ impl MHA {
None => attn_weights,
Some(mask) => {
let _enter = self.span_mask.enter();
masked_fill(
&attn_weights,
&mask.broadcast_left(b_size * self.n_head)?,
f32::NEG_INFINITY,
)?
attn_weights.broadcast_add(mask)?
}
};
let attn_weights = {
@ -435,7 +420,7 @@ impl MixFormerSequentialForCausalLM {
let mask = if seq_len <= 1 {
None
} else {
Some(get_mask(seq_len, xs.device())?)
Some(get_mask(seq_len, xs.dtype(), xs.device())?)
};
for block in self.blocks.iter_mut() {
xs = block.forward(&xs, mask.as_ref())?
@ -456,7 +441,7 @@ impl MixFormerSequentialForCausalLM {
// https://github.com/vikhyat/moondream/blob/a9d788a20d1543fb1479edc54106e88cff7759d3/moondream/moondream.py#L43-L56
let mut xs = Tensor::cat(&[bos_token, img_embeds.clone(), xs], 1)?;
let (_b_size, seq_len, _embds) = xs.dims3()?;
let mask = Some(get_mask(seq_len, xs.device())?);
let mask = Some(get_mask(seq_len, xs.dtype(), xs.device())?);
for block in self.blocks.iter_mut() {
xs = block.forward(&xs, mask.as_ref())?
}