From b869a659ec678763b4ba03dc73044be2ba9ad562 Mon Sep 17 00:00:00 2001
From: Laurent Mazare
Date: Fri, 5 Apr 2024 09:38:26 +0200
Subject: [PATCH] Faster mask implementation for mixformers. (#2017)

* Faster mask implementation for mixformers.

* Clippy.
---
 candle-transformers/src/models/mixformer.rs | 27 ++++++---------------------
 1 file changed, 6 insertions(+), 21 deletions(-)

diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs
index d9676a35..700829e3 100644
--- a/candle-transformers/src/models/mixformer.rs
+++ b/candle-transformers/src/models/mixformer.rs
@@ -126,20 +126,11 @@ impl Module for Embedding {
     }
 }
 
-fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
+fn get_mask(size: usize, dtype: DType, device: &Device) -> Result<Tensor> {
     let mask: Vec<_> = (0..size)
-        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+        .flat_map(|i| (0..size).map(move |j| if j > i { f32::NEG_INFINITY } else { 0. }))
         .collect();
-    Tensor::from_slice(&mask, (size, size), device)
-}
-
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
-    let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?
-        .to_dtype(on_false.dtype())?
-        .broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
-    Ok(m)
+    Tensor::from_slice(&mask, (size, size), device)?.to_dtype(dtype)
 }
 
 #[derive(Debug, Clone)]
@@ -252,7 +243,6 @@ struct MHA {
     rotary_emb: RotaryEmbedding,
     kv_cache: Option<(Tensor, Tensor)>,
     head_dim: usize,
-    n_head: usize,
     softmax_scale: f64,
     span: tracing::Span,
     span_rope: tracing::Span,
@@ -273,7 +263,6 @@ impl MHA {
             wqkv,
             out_proj,
             head_dim,
-            n_head: cfg.n_head,
             kv_cache: None,
             rotary_emb,
             softmax_scale,
@@ -321,11 +310,7 @@ impl MHA {
             None => attn_weights,
             Some(mask) => {
                 let _enter = self.span_mask.enter();
-                masked_fill(
-                    &attn_weights,
-                    &mask.broadcast_left(b_size * self.n_head)?,
-                    f32::NEG_INFINITY,
-                )?
+                attn_weights.broadcast_add(mask)?
             }
         };
         let attn_weights = {
@@ -435,7 +420,7 @@ impl MixFormerSequentialForCausalLM {
         let mask = if seq_len <= 1 {
             None
         } else {
-            Some(get_mask(seq_len, xs.device())?)
+            Some(get_mask(seq_len, xs.dtype(), xs.device())?)
         };
         for block in self.blocks.iter_mut() {
             xs = block.forward(&xs, mask.as_ref())?
@@ -456,7 +441,7 @@ impl MixFormerSequentialForCausalLM {
         // https://github.com/vikhyat/moondream/blob/a9d788a20d1543fb1479edc54106e88cff7759d3/moondream/moondream.py#L43-L56
         let mut xs = Tensor::cat(&[bos_token, img_embeds.clone(), xs], 1)?;
         let (_b_size, seq_len, _embds) = xs.dims3()?;
-        let mask = Some(get_mask(seq_len, xs.device())?);
+        let mask = Some(get_mask(seq_len, xs.dtype(), xs.device())?);
         for block in self.blocks.iter_mut() {
             xs = block.forward(&xs, mask.as_ref())?
         }
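
The following is a minimal standalone sketch, not part of the patch, of the approach the diff switches to, assuming the candle_core crate; the names causal_mask, attn_weights, and the example shapes are illustrative. The causal mask is built once as an additive tensor with -inf above the diagonal and 0.0 elsewhere, and applied to the attention logits with a single broadcast_add, replacing the removed broadcast_left + where_cond (masked_fill) path.

use candle_core::{DType, Device, Result, Tensor};

// Additive causal mask: 0.0 at or below the diagonal, -inf above it,
// mirroring the patched get_mask.
fn causal_mask(size: usize, dtype: DType, device: &Device) -> Result<Tensor> {
    let mask: Vec<f32> = (0..size)
        .flat_map(|i| (0..size).map(move |j| if j > i { f32::NEG_INFINITY } else { 0. }))
        .collect();
    Tensor::from_slice(&mask, (size, size), device)?.to_dtype(dtype)
}

fn main() -> Result<()> {
    let device = Device::Cpu;
    let seq_len = 4;
    // Stand-in attention logits with shape (batch, heads, seq, seq).
    let attn_weights = Tensor::randn(0f32, 1f32, (1, 2, seq_len, seq_len), &device)?;
    // The (seq, seq) mask broadcasts over the batch and head dimensions,
    // so adding it sends every "future" logit to -inf before softmax.
    let mask = causal_mask(seq_len, DType::F32, &device)?;
    let masked = attn_weights.broadcast_add(&mask)?;
    println!("{masked}");
    Ok(())
}

Because the mask is purely additive, no per-head broadcast_left or where_cond selection is needed, and the mask can be built directly in the model's dtype.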