Faster mask implementation for mixformers. (#2017)

* Faster mask implementation for mixformers.
* Clippy.

parent 88f7793598
commit b869a659ec

@@ -126,20 +126,11 @@ impl Module for Embedding {
     }
 }
 
-fn get_mask(size: usize, device: &Device) -> Result<Tensor> {
+fn get_mask(size: usize, dtype: DType, device: &Device) -> Result<Tensor> {
     let mask: Vec<_> = (0..size)
-        .flat_map(|i| (0..size).map(move |j| u8::from(j > i)))
+        .flat_map(|i| (0..size).map(move |j| if j > i { f32::NEG_INFINITY } else { 0. }))
         .collect();
-    Tensor::from_slice(&mask, (size, size), device)
-}
-
-fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
-    let shape = mask.shape();
-    let on_true = Tensor::new(on_true, on_false.device())?
-        .to_dtype(on_false.dtype())?
-        .broadcast_as(shape.dims())?;
-    let m = mask.where_cond(&on_true, on_false)?;
-    Ok(m)
+    Tensor::from_slice(&mask, (size, size), device)?.to_dtype(dtype)
 }
 
 #[derive(Debug, Clone)]
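The change replaces the boolean mask plus masked_fill combination with a single additive mask: positions a query may not attend to hold f32::NEG_INFINITY and all others hold 0, so applying the mask is one broadcast add over the attention scores instead of a where_cond against a broadcast scalar tensor. A minimal standalone sketch of the new helper, assuming the candle_core crate this file builds on:

    use candle_core::{DType, Device, Result, Tensor};

    fn get_mask(size: usize, dtype: DType, device: &Device) -> Result<Tensor> {
        // 0 on and below the diagonal, -inf strictly above it.
        let mask: Vec<_> = (0..size)
            .flat_map(|i| (0..size).map(move |j| if j > i { f32::NEG_INFINITY } else { 0. }))
            .collect();
        Tensor::from_slice(&mask, (size, size), device)?.to_dtype(dtype)
    }

    fn main() -> Result<()> {
        // 4x4 causal mask: row i leaves positions 0..=i untouched and blocks the rest.
        let mask = get_mask(4, DType::F32, &Device::Cpu)?;
        println!("{mask}");
        Ok(())
    }

Because the mask is now added directly to the scores rather than used as a select condition, it must match their dtype (possibly f16 or bf16), hence the new DType parameter and the to_dtype cast.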
@@ -252,7 +243,6 @@ struct MHA {
     rotary_emb: RotaryEmbedding,
     kv_cache: Option<(Tensor, Tensor)>,
     head_dim: usize,
-    n_head: usize,
     softmax_scale: f64,
     span: tracing::Span,
     span_rope: tracing::Span,
@@ -273,7 +263,6 @@ impl MHA {
             wqkv,
             out_proj,
             head_dim,
-            n_head: cfg.n_head,
             kv_cache: None,
             rotary_emb,
             softmax_scale,
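The n_head field becomes dead once masked_fill is removed: its only remaining consumer was the mask.broadcast_left(b_size * self.n_head) call dropped in the next hunk, so both the field and its initializer go away.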
@@ -321,11 +310,7 @@ impl MHA {
             None => attn_weights,
             Some(mask) => {
                 let _enter = self.span_mask.enter();
-                masked_fill(
-                    &attn_weights,
-                    &mask.broadcast_left(b_size * self.n_head)?,
-                    f32::NEG_INFINITY,
-                )?
+                attn_weights.broadcast_add(mask)?
             }
         };
         let attn_weights = {
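After softmax the two formulations agree: masked_fill wrote f32::NEG_INFINITY into the masked slots via where_cond, the additive mask puts -inf there with a single broadcast_add, and exp(-inf) = 0 zeroes those probabilities either way. The explicit broadcast_left(b_size * self.n_head) also becomes unnecessary because broadcast_add aligns the (seq, seq) mask against the trailing dimensions of the score tensor on its own. A sketch checking the equivalence, assuming the candle_core and candle_nn crates:

    use candle_core::{D, Device, Result, Tensor};
    use candle_nn::ops::softmax;

    fn main() -> Result<()> {
        let device = Device::Cpu;
        // Dummy attention scores shaped like (batch * heads, seq, seq).
        let scores = Tensor::rand(0f32, 1f32, (2, 4, 4), &device)?;

        // Removed path: u8 condition mask selecting -inf through where_cond.
        let cond: Vec<u8> = (0..4)
            .flat_map(|i| (0..4).map(move |j| u8::from(j > i)))
            .collect();
        let cond = Tensor::from_slice(&cond, (4, 4), &device)?.broadcast_as((2, 4, 4))?;
        let neg_inf = Tensor::new(f32::NEG_INFINITY, &device)?.broadcast_as((2, 4, 4))?;
        let old = softmax(&cond.where_cond(&neg_inf, &scores)?, D::Minus1)?;

        // New path: additive mask, one broadcast_add.
        let add: Vec<f32> = (0..4)
            .flat_map(|i| (0..4).map(move |j| if j > i { f32::NEG_INFINITY } else { 0. }))
            .collect();
        let add = Tensor::from_slice(&add, (4, 4), &device)?;
        let new = softmax(&scores.broadcast_add(&add)?, D::Minus1)?;

        // Prints a sum of absolute differences at (or numerically near) zero.
        println!("{}", (old - new)?.abs()?.sum_all()?);
        Ok(())
    }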
@@ -435,7 +420,7 @@ impl MixFormerSequentialForCausalLM {
         let mask = if seq_len <= 1 {
             None
         } else {
-            Some(get_mask(seq_len, xs.device())?)
+            Some(get_mask(seq_len, xs.dtype(), xs.device())?)
         };
         for block in self.blocks.iter_mut() {
             xs = block.forward(&xs, mask.as_ref())?
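With a populated kv cache the incremental decode step has seq_len == 1, so the None arm stays as-is and no mask is built at all. The only change here, and in the moondream-style image-prompt path in the next hunk, is threading xs.dtype() through so the additive mask matches the dtype of the attention scores it is added to.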
@@ -456,7 +441,7 @@ impl MixFormerSequentialForCausalLM {
         // https://github.com/vikhyat/moondream/blob/a9d788a20d1543fb1479edc54106e88cff7759d3/moondream/moondream.py#L43-L56
         let mut xs = Tensor::cat(&[bos_token, img_embeds.clone(), xs], 1)?;
         let (_b_size, seq_len, _embds) = xs.dims3()?;
-        let mask = Some(get_mask(seq_len, xs.device())?);
+        let mask = Some(get_mask(seq_len, xs.dtype(), xs.device())?);
         for block in self.blocks.iter_mut() {
             xs = block.forward(&xs, mask.as_ref())?
         }