Faster mask implementation for mixformers. (#2017)
* Faster mask implementation for mixformers. * Clippy.
This commit is contained in:
parent
88f7793598
commit
b869a659ec
|
@ -126,20 +126,11 @@ impl Module for Embedding {
|
|||
}
|
||||
}
|
||||
|
||||
fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
|
||||
fn get_mask(size: usize, dtype: DType, device: &Device) -> Result<Tensor> {
|
||||
let mask: Vec<_> = (0..size)
|
||||
.flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
|
||||
.flat_map(|i| (0..size).map(move |j| if j > i { f32::NEG_INFINITY } else { 0. }))
|
||||
.collect();
|
||||
Tensor::from_slice(&mask, (size, size), device)
|
||||
}
|
||||
|
||||
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
|
||||
let shape = mask.shape();
|
||||
let on_true = Tensor::new(on_true, on_false.device())?
|
||||
.to_dtype(on_false.dtype())?
|
||||
.broadcast_as(shape.dims())?;
|
||||
let m = mask.where_cond(&on_true, on_false)?;
|
||||
Ok(m)
|
||||
Tensor::from_slice(&mask, (size, size), device)?.to_dtype(dtype)
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
@ -252,7 +243,6 @@ struct MHA {
|
|||
rotary_emb: RotaryEmbedding,
|
||||
kv_cache: Option<(Tensor, Tensor)>,
|
||||
head_dim: usize,
|
||||
n_head: usize,
|
||||
softmax_scale: f64,
|
||||
span: tracing::Span,
|
||||
span_rope: tracing::Span,
|
||||
|
@ -273,7 +263,6 @@ impl MHA {
|
|||
wqkv,
|
||||
out_proj,
|
||||
head_dim,
|
||||
n_head: cfg.n_head,
|
||||
kv_cache: None,
|
||||
rotary_emb,
|
||||
softmax_scale,
|
||||
|
@ -321,11 +310,7 @@ impl MHA {
|
|||
None => attn_weights,
|
||||
Some(mask) => {
|
||||
let _enter = self.span_mask.enter();
|
||||
masked_fill(
|
||||
&attn_weights,
|
||||
&mask.broadcast_left(b_size * self.n_head)?,
|
||||
f32::NEG_INFINITY,
|
||||
)?
|
||||
attn_weights.broadcast_add(mask)?
|
||||
}
|
||||
};
|
||||
let attn_weights = {
|
||||
|
@ -435,7 +420,7 @@ impl MixFormerSequentialForCausalLM {
|
|||
let mask = if seq_len <= 1 {
|
||||
None
|
||||
} else {
|
||||
Some(get_mask(seq_len, xs.device())?)
|
||||
Some(get_mask(seq_len, xs.dtype(), xs.device())?)
|
||||
};
|
||||
for block in self.blocks.iter_mut() {
|
||||
xs = block.forward(&xs, mask.as_ref())?
|
||||
|
@ -456,7 +441,7 @@ impl MixFormerSequentialForCausalLM {
|
|||
// https://github.com/vikhyat/moondream/blob/a9d788a20d1543fb1479edc54106e88cff7759d3/moondream/moondream.py#L43-L56
|
||||
let mut xs = Tensor::cat(&[bos_token, img_embeds.clone(), xs], 1)?;
|
||||
let (_b_size, seq_len, _embds) = xs.dims3()?;
|
||||
let mask = Some(get_mask(seq_len, xs.device())?);
|
||||
let mask = Some(get_mask(seq_len, xs.dtype(), xs.device())?);
|
||||
for block in self.blocks.iter_mut() {
|
||||
xs = block.forward(&xs, mask.as_ref())?
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue