diff --git a/candle-transformers/src/models/mixformer.rs b/candle-transformers/src/models/mixformer.rs
index e9451f0e..d9676a35 100644
--- a/candle-transformers/src/models/mixformer.rs
+++ b/candle-transformers/src/models/mixformer.rs
@@ -198,6 +198,7 @@ struct MLP {
     fc1: Linear,
     fc2: Linear,
     act: Activation,
+    span: tracing::Span,
 }
 
 impl MLP {
@@ -209,12 +210,14 @@ impl MLP {
             fc1,
             fc2,
             act: cfg.activation_function,
+            span: tracing::span!(tracing::Level::TRACE, "mlp"),
         })
     }
 }
 
 impl Module for MLP {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
     }
 }
@@ -252,6 +255,9 @@ struct MHA {
     n_head: usize,
     softmax_scale: f64,
     span: tracing::Span,
+    span_rope: tracing::Span,
+    span_mask: tracing::Span,
+    span_softmax: tracing::Span,
 }
 
 impl MHA {
@@ -272,6 +278,9 @@ impl MHA {
             rotary_emb,
             softmax_scale,
             span: tracing::span!(tracing::Level::TRACE, "mha"),
+            span_rope: tracing::span!(tracing::Level::TRACE, "rope"),
+            span_mask: tracing::span!(tracing::Level::TRACE, "mask"),
+            span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
         })
     }
 
@@ -287,7 +296,10 @@ impl MHA {
             Some((prev_k, _)) => prev_k.dim(1)?,
         };
         // In the python implementation, a single tensor is returned with the third axis of size 3.
-        let (q, k, v) = self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?;
+        let (q, k, v) = {
+            let _enter = self.span_rope.enter();
+            self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?
+        };
         let (k, v) = match &self.kv_cache {
             None => (k, v),
             Some((prev_k, prev_v)) => {
@@ -307,13 +319,19 @@ impl MHA {
         // scores = scores + causal_mask.to(dtype=scores.dtype)
         let attn_weights = match mask {
             None => attn_weights,
-            Some(mask) => masked_fill(
-                &attn_weights,
-                &mask.broadcast_left(b_size * self.n_head)?,
-                f32::NEG_INFINITY,
-            )?,
+            Some(mask) => {
+                let _enter = self.span_mask.enter();
+                masked_fill(
+                    &attn_weights,
+                    &mask.broadcast_left(b_size * self.n_head)?,
+                    f32::NEG_INFINITY,
+                )?
+            }
+        };
+        let attn_weights = {
+            let _enter = self.span_softmax.enter();
+            candle_nn::ops::softmax_last_dim(&attn_weights)?
         };
-        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
 
         // output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
         // attn_weights: b*h,t,s, v: b*h,s,d
diff --git a/candle-transformers/src/models/moondream.rs b/candle-transformers/src/models/moondream.rs
index 717f3bb4..7ad8c921 100644
--- a/candle-transformers/src/models/moondream.rs
+++ b/candle-transformers/src/models/moondream.rs
@@ -1,6 +1,7 @@
 use crate::models::mixformer::{Config as PhiConfig, MixFormerSequentialForCausalLM as PhiModel};
-use candle::{IndexOp, Result, Tensor, D};
-use candle_nn::{layer_norm, linear_b, Linear, Module, VarBuilder};
+use crate::models::with_tracing::{layer_norm, linear_b, LayerNorm, Linear};
+use candle::{IndexOp, Module, Result, Tensor, D};
+use candle_nn::VarBuilder;
 
 pub struct Config {
     pub phi_config: PhiConfig,
@@ -76,6 +77,7 @@ struct Attention {
     head_dim: usize,
     qkv: Linear,
     proj: Linear,
+    span: tracing::Span,
 }
 
 impl Attention {
@@ -87,12 +89,14 @@ impl Attention {
             head_dim: dim / num_heads,
             qkv,
             proj,
+            span: tracing::span!(tracing::Level::TRACE, "vit-attn"),
         })
     }
 }
 
 impl Module for Attention {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b, n, c) = xs.dims3()?;
         let qkv = xs
             .apply(&self.qkv)?
@@ -114,8 +118,9 @@ impl Module for Attention {
 struct VitBlock {
     attn: Attention,
     mlp: Mlp,
-    norm1: candle_nn::LayerNorm,
-    norm2: candle_nn::LayerNorm,
+    norm1: LayerNorm,
+    norm2: LayerNorm,
+    span: tracing::Span,
 }
 
 impl VitBlock {
@@ -129,12 +134,14 @@ impl VitBlock {
             mlp,
             norm1,
             norm2,
+            span: tracing::span!(tracing::Level::TRACE, "vit-block"),
         })
     }
 }
 
 impl Module for VitBlock {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let ys = xs.apply(&self.norm1)?.apply(&self.attn)?;
         let xs = (xs + &ys)?;
         let ys = xs.apply(&self.norm2)?.apply(&self.mlp)?;
@@ -148,7 +155,8 @@ struct VisionTransformer {
     patch_embed: LinearPatchEmbedding,
     pos_embed: Tensor,
     blocks: Vec<VitBlock>,
-    norm: candle_nn::LayerNorm,
+    norm: LayerNorm,
+    span: tracing::Span,
 }
 
 impl VisionTransformer {
@@ -171,12 +179,14 @@ impl VisionTransformer {
             pos_embed,
             blocks,
             norm,
+            span: tracing::span!(tracing::Level::TRACE, "vit"),
         })
     }
 }
 
 impl Module for VisionTransformer {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = (&xs.apply(&self.patch_embed)? + &self.pos_embed)?;
         for block in self.blocks.iter() {
             xs = xs.apply(block)?;
@@ -208,6 +218,7 @@ struct Mlp {
     fc1: Linear,
     act: candle_nn::Activation,
     fc2: Linear,
+    span: tracing::Span,
 }
 
 impl Mlp {
@@ -220,12 +231,18 @@ impl Mlp {
     ) -> Result<Self> {
         let fc1 = linear_b(in_features, hidden_features, true, vb.pp("fc1"))?;
         let fc2 = linear_b(hidden_features, out_features, true, vb.pp("fc2"))?;
-        Ok(Self { fc1, act, fc2 })
+        Ok(Self {
+            fc1,
+            act,
+            fc2,
+            span: tracing::span!(tracing::Level::TRACE, "mlp"),
+        })
     }
 }
 
 impl Module for Mlp {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
     }
 }
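
The spans added above only produce output when the caller installs a tracing subscriber. The following is a minimal sketch, not part of the patch, of one way to capture them; it assumes the `tracing-subscriber` and `tracing-chrome` crates are available, following the pattern used in the candle examples.

// Hypothetical driver program: installs a chrome-trace subscriber so the
// TRACE-level spans ("mlp", "mha", "rope", "mask", "softmax", "vit", ...) are recorded.
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // The returned guard must stay alive for the whole run so the chrome trace
    // file is flushed when it is dropped at the end of main.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // ... load the model and run its forward passes here; every `span.enter()`
    // added in this diff will then appear in the generated trace, which can be
    // inspected in chrome://tracing or Perfetto.
}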