Moondream tracing. (#2016)

* Moondream tracing.

* A bit more tracing.
Laurent Mazare · 2024-04-05 09:11:08 +02:00 · committed by GitHub
parent 2ac302a5d1
commit 88f7793598
2 changed files with 48 additions and 13 deletions
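
The change applies one pattern in both files: each module gains a tracing::Span field, built once in its constructor, and forward() enters that span so the time spent in the call is attributed to a named scope; the guard returned by enter() closes the scope when it is dropped at the end of forward (or of the enclosing block, for the finer-grained rope/mask/softmax spans). A minimal sketch of the pattern, using a hypothetical Layer struct that is not part of this commit:

use candle::{Module, Result, Tensor};
use candle_nn::Linear;

struct Layer {
    proj: Linear,
    // Created once; entering it on each forward call is cheap.
    span: tracing::Span,
}

impl Layer {
    fn new(proj: Linear) -> Self {
        let span = tracing::span!(tracing::Level::TRACE, "layer");
        Self { proj, span }
    }
}

impl Module for Layer {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // The guard keeps the span entered for the rest of the method.
        let _enter = self.span.enter();
        xs.apply(&self.proj)
    }
}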

candle-transformers/src/models/mixformer.rs

@@ -198,6 +198,7 @@ struct MLP {
     fc1: Linear,
     fc2: Linear,
     act: Activation,
+    span: tracing::Span,
 }

 impl MLP {
@@ -209,12 +210,14 @@ impl MLP {
             fc1,
             fc2,
             act: cfg.activation_function,
+            span: tracing::span!(tracing::Level::TRACE, "mlp"),
         })
     }
 }

 impl Module for MLP {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
     }
 }
@@ -252,6 +255,9 @@ struct MHA {
     n_head: usize,
     softmax_scale: f64,
     span: tracing::Span,
+    span_rope: tracing::Span,
+    span_mask: tracing::Span,
+    span_softmax: tracing::Span,
 }

 impl MHA {
@@ -272,6 +278,9 @@ impl MHA {
             rotary_emb,
             softmax_scale,
             span: tracing::span!(tracing::Level::TRACE, "mha"),
+            span_rope: tracing::span!(tracing::Level::TRACE, "rope"),
+            span_mask: tracing::span!(tracing::Level::TRACE, "mask"),
+            span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
         })
     }
@@ -287,7 +296,10 @@ impl MHA {
             Some((prev_k, _)) => prev_k.dim(1)?,
         };
         // In the python implementation, a single tensor is returned with the third axis of size 3.
-        let (q, k, v) = self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?;
+        let (q, k, v) = {
+            let _enter = self.span_rope.enter();
+            self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?
+        };
         let (k, v) = match &self.kv_cache {
             None => (k, v),
             Some((prev_k, prev_v)) => {
@@ -307,13 +319,19 @@ impl MHA {
         // scores = scores + causal_mask.to(dtype=scores.dtype)
         let attn_weights = match mask {
             None => attn_weights,
-            Some(mask) => masked_fill(
-                &attn_weights,
-                &mask.broadcast_left(b_size * self.n_head)?,
-                f32::NEG_INFINITY,
-            )?,
+            Some(mask) => {
+                let _enter = self.span_mask.enter();
+                masked_fill(
+                    &attn_weights,
+                    &mask.broadcast_left(b_size * self.n_head)?,
+                    f32::NEG_INFINITY,
+                )?
+            }
         };
-        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+        let attn_weights = {
+            let _enter = self.span_softmax.enter();
+            candle_nn::ops::softmax_last_dim(&attn_weights)?
+        };
         // output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
         // attn_weights: b*h,t,s, v: b*h,s,d

candle-transformers/src/models/moondream.rs

@@ -1,6 +1,7 @@
 use crate::models::mixformer::{Config as PhiConfig, MixFormerSequentialForCausalLM as PhiModel};
-use candle::{IndexOp, Result, Tensor, D};
-use candle_nn::{layer_norm, linear_b, Linear, Module, VarBuilder};
+use crate::models::with_tracing::{layer_norm, linear_b, LayerNorm, Linear};
+use candle::{IndexOp, Module, Result, Tensor, D};
+use candle_nn::VarBuilder;

 pub struct Config {
     pub phi_config: PhiConfig,
@@ -76,6 +77,7 @@ struct Attention {
     head_dim: usize,
     qkv: Linear,
     proj: Linear,
+    span: tracing::Span,
 }

 impl Attention {
@@ -87,12 +89,14 @@ impl Attention {
             head_dim: dim / num_heads,
             qkv,
             proj,
+            span: tracing::span!(tracing::Level::TRACE, "vit-attn"),
         })
     }
 }

 impl Module for Attention {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b, n, c) = xs.dims3()?;
         let qkv = xs
             .apply(&self.qkv)?
@@ -114,8 +118,9 @@ impl Module for Attention {
 struct VitBlock {
     attn: Attention,
     mlp: Mlp,
-    norm1: candle_nn::LayerNorm,
-    norm2: candle_nn::LayerNorm,
+    norm1: LayerNorm,
+    norm2: LayerNorm,
+    span: tracing::Span,
 }

 impl VitBlock {
@@ -129,12 +134,14 @@ impl VitBlock {
             mlp,
             norm1,
             norm2,
+            span: tracing::span!(tracing::Level::TRACE, "vit-block"),
         })
     }
 }

 impl Module for VitBlock {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let ys = xs.apply(&self.norm1)?.apply(&self.attn)?;
         let xs = (xs + &ys)?;
         let ys = xs.apply(&self.norm2)?.apply(&self.mlp)?;
@@ -148,7 +155,8 @@ struct VisionTransformer {
     patch_embed: LinearPatchEmbedding,
     pos_embed: Tensor,
     blocks: Vec<VitBlock>,
-    norm: candle_nn::LayerNorm,
+    norm: LayerNorm,
+    span: tracing::Span,
 }

 impl VisionTransformer {
@@ -171,12 +179,14 @@ impl VisionTransformer {
             pos_embed,
             blocks,
             norm,
+            span: tracing::span!(tracing::Level::TRACE, "vit"),
         })
     }
 }

 impl Module for VisionTransformer {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = (&xs.apply(&self.patch_embed)? + &self.pos_embed)?;
         for block in self.blocks.iter() {
             xs = xs.apply(block)?;
@@ -208,6 +218,7 @@ struct Mlp {
     fc1: Linear,
     act: candle_nn::Activation,
     fc2: Linear,
+    span: tracing::Span,
 }

 impl Mlp {
@@ -220,12 +231,18 @@ impl Mlp {
     ) -> Result<Self> {
         let fc1 = linear_b(in_features, hidden_features, true, vb.pp("fc1"))?;
         let fc2 = linear_b(hidden_features, out_features, true, vb.pp("fc2"))?;
-        Ok(Self { fc1, act, fc2 })
+        Ok(Self {
+            fc1,
+            act,
+            fc2,
+            span: tracing::span!(tracing::Level::TRACE, "mlp"),
+        })
     }
 }

 impl Module for Mlp {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
     }
 }
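
These spans are only recorded when a tracing subscriber is installed, as several candle examples do behind a --tracing flag using the tracing-chrome crate. A minimal sketch of that setup (the dependency and output file name are assumptions, not part of this diff):

use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Keep the guard alive for the whole run: the trace file is flushed
    // when it is dropped.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // ... load the Moondream model and run a forward pass here; the
    // "mlp", "mha", "rope", "mask", "softmax", "vit", "vit-block" and
    // "vit-attn" spans added by this commit then show up in the generated
    // trace-*.json, which can be opened in chrome://tracing or Perfetto.
}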