Moondream tracing. (#2016)

* Moondream tracing.

* A bit more tracing.
This commit is contained in:
Laurent Mazare 2024-04-05 09:11:08 +02:00 committed by GitHub
parent 2ac302a5d1
commit 88f7793598
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 48 additions and 13 deletions

View File

@ -198,6 +198,7 @@ struct MLP {
fc1: Linear,
fc2: Linear,
act: Activation,
span: tracing::Span,
}
impl MLP {
@ -209,12 +210,14 @@ impl MLP {
fc1,
fc2,
act: cfg.activation_function,
span: tracing::span!(tracing::Level::TRACE, "mlp"),
})
}
}
impl Module for MLP {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
}
}
@ -252,6 +255,9 @@ struct MHA {
n_head: usize,
softmax_scale: f64,
span: tracing::Span,
span_rope: tracing::Span,
span_mask: tracing::Span,
span_softmax: tracing::Span,
}
impl MHA {
@ -272,6 +278,9 @@ impl MHA {
rotary_emb,
softmax_scale,
span: tracing::span!(tracing::Level::TRACE, "mha"),
span_rope: tracing::span!(tracing::Level::TRACE, "rope"),
span_mask: tracing::span!(tracing::Level::TRACE, "mask"),
span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
})
}
@ -287,7 +296,10 @@ impl MHA {
Some((prev_k, _)) => prev_k.dim(1)?,
};
// In the python implementation, a single tensor is returned with the third axis of size 3.
let (q, k, v) = self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?;
let (q, k, v) = {
let _enter = self.span_rope.enter();
self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?
};
let (k, v) = match &self.kv_cache {
None => (k, v),
Some((prev_k, prev_v)) => {
@ -307,13 +319,19 @@ impl MHA {
// scores = scores + causal_mask.to(dtype=scores.dtype)
let attn_weights = match mask {
None => attn_weights,
Some(mask) => masked_fill(
Some(mask) => {
let _enter = self.span_mask.enter();
masked_fill(
&attn_weights,
&mask.broadcast_left(b_size * self.n_head)?,
f32::NEG_INFINITY,
)?,
)?
}
};
let attn_weights = {
let _enter = self.span_softmax.enter();
candle_nn::ops::softmax_last_dim(&attn_weights)?
};
let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
// output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
// attn_weights: b*h,t,s, v: b*h,s,d

View File

@ -1,6 +1,7 @@
use crate::models::mixformer::{Config as PhiConfig, MixFormerSequentialForCausalLM as PhiModel};
use candle::{IndexOp, Result, Tensor, D};
use candle_nn::{layer_norm, linear_b, Linear, Module, VarBuilder};
use crate::models::with_tracing::{layer_norm, linear_b, LayerNorm, Linear};
use candle::{IndexOp, Module, Result, Tensor, D};
use candle_nn::VarBuilder;
pub struct Config {
pub phi_config: PhiConfig,
@ -76,6 +77,7 @@ struct Attention {
head_dim: usize,
qkv: Linear,
proj: Linear,
span: tracing::Span,
}
impl Attention {
@ -87,12 +89,14 @@ impl Attention {
head_dim: dim / num_heads,
qkv,
proj,
span: tracing::span!(tracing::Level::TRACE, "vit-attn"),
})
}
}
impl Module for Attention {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let (b, n, c) = xs.dims3()?;
let qkv = xs
.apply(&self.qkv)?
@ -114,8 +118,9 @@ impl Module for Attention {
struct VitBlock {
attn: Attention,
mlp: Mlp,
norm1: candle_nn::LayerNorm,
norm2: candle_nn::LayerNorm,
norm1: LayerNorm,
norm2: LayerNorm,
span: tracing::Span,
}
impl VitBlock {
@ -129,12 +134,14 @@ impl VitBlock {
mlp,
norm1,
norm2,
span: tracing::span!(tracing::Level::TRACE, "vit-block"),
})
}
}
impl Module for VitBlock {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let ys = xs.apply(&self.norm1)?.apply(&self.attn)?;
let xs = (xs + &ys)?;
let ys = xs.apply(&self.norm2)?.apply(&self.mlp)?;
@ -148,7 +155,8 @@ struct VisionTransformer {
patch_embed: LinearPatchEmbedding,
pos_embed: Tensor,
blocks: Vec<VitBlock>,
norm: candle_nn::LayerNorm,
norm: LayerNorm,
span: tracing::Span,
}
impl VisionTransformer {
@ -171,12 +179,14 @@ impl VisionTransformer {
pos_embed,
blocks,
norm,
span: tracing::span!(tracing::Level::TRACE, "vit"),
})
}
}
impl Module for VisionTransformer {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
let mut xs = (&xs.apply(&self.patch_embed)? + &self.pos_embed)?;
for block in self.blocks.iter() {
xs = xs.apply(block)?;
@ -208,6 +218,7 @@ struct Mlp {
fc1: Linear,
act: candle_nn::Activation,
fc2: Linear,
span: tracing::Span,
}
impl Mlp {
@ -220,12 +231,18 @@ impl Mlp {
) -> Result<Self> {
let fc1 = linear_b(in_features, hidden_features, true, vb.pp("fc1"))?;
let fc2 = linear_b(hidden_features, out_features, true, vb.pp("fc2"))?;
Ok(Self { fc1, act, fc2 })
Ok(Self {
fc1,
act,
fc2,
span: tracing::span!(tracing::Level::TRACE, "mlp"),
})
}
}
impl Module for Mlp {
fn forward(&self, xs: &Tensor) -> Result<Tensor> {
let _enter = self.span.enter();
xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
}
}