Faster mask implementation for mixformers. (#2017)

* Faster mask implementation for mixformers.

* Clippy.
Laurent Mazare 2024-04-05 09:38:26 +02:00 committed by GitHub
parent 88f7793598
commit b869a659ec
1 changed file with 6 additions and 21 deletions


@@ -126,20 +126,11 @@ impl Module for Embedding {
     }
 }
 
-fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
+fn get_mask(size: usize, dtype: DType, device: &Device) -> Result<Tensor> {
     let mask: Vec<_> = (0..size)
-        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+        .flat_map(|i| (0..size).map(move |j| if j > i { f32::NEG_INFINITY } else { 0. }))
         .collect();
-    Tensor::from_slice(&mask, (size, size), device)
-}
-
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
-    let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?
-        .to_dtype(on_false.dtype())?
-        .broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
-    Ok(m)
+    Tensor::from_slice(&mask, (size, size), device)?.to_dtype(dtype)
 }
 
 #[derive(Debug, Clone)]
@@ -252,7 +243,6 @@ struct MHA {
     rotary_emb: RotaryEmbedding,
     kv_cache: Option<(Tensor, Tensor)>,
     head_dim: usize,
-    n_head: usize,
     softmax_scale: f64,
     span: tracing::Span,
     span_rope: tracing::Span,
@@ -273,7 +263,6 @@ impl MHA {
             wqkv,
             out_proj,
             head_dim,
-            n_head: cfg.n_head,
             kv_cache: None,
             rotary_emb,
             softmax_scale,
@@ -321,11 +310,7 @@ impl MHA {
             None => attn_weights,
             Some(mask) => {
                 let _enter = self.span_mask.enter();
-                masked_fill(
-                    &attn_weights,
-                    &mask.broadcast_left(b_size * self.n_head)?,
-                    f32::NEG_INFINITY,
-                )?
+                attn_weights.broadcast_add(mask)?
             }
         };
         let attn_weights = {
@@ -435,7 +420,7 @@ impl MixFormerSequentialForCausalLM {
         let mask = if seq_len <= 1 {
             None
         } else {
-            Some(get_mask(seq_len, xs.device())?)
+            Some(get_mask(seq_len, xs.dtype(), xs.device())?)
        };
         for block in self.blocks.iter_mut() {
             xs = block.forward(&xs, mask.as_ref())?
@@ -456,7 +441,7 @@ impl MixFormerSequentialForCausalLM {
         // https://github.com/vikhyat/moondream/blob/a9d788a20d1543fb1479edc54106e88cff7759d3/moondream/moondream.py#L43-L56
         let mut xs = Tensor::cat(&[bos_token, img_embeds.clone(), xs], 1)?;
         let (_b_size, seq_len, _embds) = xs.dims3()?;
-        let mask = Some(get_mask(seq_len, xs.device())?);
+        let mask = Some(get_mask(seq_len, xs.dtype(), xs.device())?);
         for block in self.blocks.iter_mut() {
             xs = block.forward(&xs, mask.as_ref())?
         }
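
In short, the commit replaces the where_cond based masked_fill with an additive causal mask: get_mask now bakes f32::NEG_INFINITY into the upper triangle and casts the mask to the model dtype once, so applying it is a single broadcast_add over the attention logits instead of a broadcast_left over batch * n_head followed by a select, which also lets the now-unused n_head field and the masked_fill helper be dropped. Below is a minimal, self-contained sketch of the technique, assuming the candle-core and candle-nn crates; additive_causal_mask and the toy shapes are illustrative, not part of the commit:

use candle_core::{DType, Device, Result, Tensor};

// Build a (size, size) additive causal mask: 0 on and below the diagonal,
// f32::NEG_INFINITY above it, materialized once in the requested dtype.
fn additive_causal_mask(size: usize, dtype: DType, device: &Device) -> Result<Tensor> {
    let mask: Vec<f32> = (0..size)
        .flat_map(|i| (0..size).map(move |j| if j > i { f32::NEG_INFINITY } else { 0. }))
        .collect();
    Tensor::from_slice(&mask, (size, size), device)?.to_dtype(dtype)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Toy attention logits shaped (batch * n_head, seq, seq).
    let attn_weights = Tensor::rand(0f32, 1., (2, 4, 4), &device)?;
    // Adding the mask broadcasts the (seq, seq) -inf pattern over the leading
    // dimension, so no per-head expansion or where_cond select is needed.
    let masked = attn_weights.broadcast_add(&additive_causal_mask(4, DType::F32, &device)?)?;
    // After softmax, the -inf entries come out as exactly zero probability.
    let probs = candle_nn::ops::softmax_last_dim(&masked)?;
    println!("{probs}");
    Ok(())
}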