parent 2ac302a5d1
commit 88f7793598
@@ -198,6 +198,7 @@ struct MLP {
     fc1: Linear,
     fc2: Linear,
     act: Activation,
+    span: tracing::Span,
 }
 
 impl MLP {
@@ -209,12 +210,14 @@ impl MLP {
             fc1,
             fc2,
             act: cfg.activation_function,
+            span: tracing::span!(tracing::Level::TRACE, "mlp"),
         })
     }
 }
 
 impl Module for MLP {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
     }
 }
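These spans are inert until the binary driving the model installs a `tracing` subscriber. A minimal sketch of how a caller could surface them as a Chrome trace, assuming the `tracing-subscriber` and `tracing-chrome` crates (this mirrors the setup candle's examples use; it is not part of this commit):

```rust
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // The guard flushes a trace-<timestamp>.json file when dropped; open it
    // in chrome://tracing or Perfetto to see one timed slice per span entry.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // ... build the model and run forward passes here; every
    // `self.span.enter()` in the diff now shows up in the trace.
}
```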
@@ -252,6 +255,9 @@ struct MHA {
     n_head: usize,
     softmax_scale: f64,
     span: tracing::Span,
+    span_rope: tracing::Span,
+    span_mask: tracing::Span,
+    span_softmax: tracing::Span,
 }
 
 impl MHA {
@@ -272,6 +278,9 @@ impl MHA {
             rotary_emb,
             softmax_scale,
             span: tracing::span!(tracing::Level::TRACE, "mha"),
+            span_rope: tracing::span!(tracing::Level::TRACE, "rope"),
+            span_mask: tracing::span!(tracing::Level::TRACE, "mask"),
+            span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
         })
     }
 
@@ -287,7 +296,10 @@ impl MHA {
             Some((prev_k, _)) => prev_k.dim(1)?,
         };
         // In the python implementation, a single tensor is returned with the third axis of size 3.
-        let (q, k, v) = self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?;
+        let (q, k, v) = {
+            let _enter = self.span_rope.enter();
+            self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?
+        };
         let (k, v) = match &self.kv_cache {
             None => (k, v),
             Some((prev_k, prev_v)) => {
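The block around the rotary-embedding call is what scopes the measurement: `enter()` returns a guard, and the span ends when that guard drops at the closing brace. A standalone sketch of the idiom (hypothetical span and workload, not from the diff):

```rust
fn timed_step(x: u64) -> u64 {
    let span = tracing::span!(tracing::Level::TRACE, "step");
    // Only the code between enter() and the end of the block is attributed
    // to the span; the guard drops (and timing stops) at the `}`.
    let y = {
        let _enter = span.enter();
        x.wrapping_mul(2) // stand-in for the real work
    };
    y
}
```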
@@ -307,13 +319,19 @@ impl MHA {
         // scores = scores + causal_mask.to(dtype=scores.dtype)
         let attn_weights = match mask {
             None => attn_weights,
-            Some(mask) => masked_fill(
-                &attn_weights,
-                &mask.broadcast_left(b_size * self.n_head)?,
-                f32::NEG_INFINITY,
-            )?,
+            Some(mask) => {
+                let _enter = self.span_mask.enter();
+                masked_fill(
+                    &attn_weights,
+                    &mask.broadcast_left(b_size * self.n_head)?,
+                    f32::NEG_INFINITY,
+                )?
+            }
         };
-        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+        let attn_weights = {
+            let _enter = self.span_softmax.enter();
+            candle_nn::ops::softmax_last_dim(&attn_weights)?
+        };
 
         // output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
         // attn_weights: b*h,t,s, v: b*h,s,d
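`masked_fill` is a helper defined elsewhere in this file; candle models typically implement it with `Tensor::where_cond`, roughly like the following sketch (illustrative, not part of the diff):

```rust
use candle::{Result, Tensor};

// Fill `on_false` with `on_true` wherever the boolean mask is set; used above
// to push masked attention scores to -inf before the softmax.
fn masked_fill(on_false: &Tensor, mask: &Tensor, on_true: f32) -> Result<Tensor> {
    let shape = mask.shape();
    let on_true = Tensor::new(on_true, on_false.device())?.broadcast_as(shape.dims())?;
    mask.where_cond(&on_true, on_false)
}
```

The hunks below come from the commit's second file (the line numbers restart at 1): the vision side of the model, which switches from plain `candle_nn` layers to the `with_tracing` wrappers.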
@@ -1,6 +1,7 @@
 use crate::models::mixformer::{Config as PhiConfig, MixFormerSequentialForCausalLM as PhiModel};
-use candle::{IndexOp, Result, Tensor, D};
-use candle_nn::{layer_norm, linear_b, Linear, Module, VarBuilder};
+use crate::models::with_tracing::{layer_norm, linear_b, LayerNorm, Linear};
+use candle::{IndexOp, Module, Result, Tensor, D};
+use candle_nn::VarBuilder;
 
 pub struct Config {
     pub phi_config: PhiConfig,
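`with_tracing` provides drop-in replacements for `candle_nn`'s `Linear` and `LayerNorm` that enter a span on every forward pass, which is why the rest of the diff barely changes. A sketch of what such a wrapper looks like (modeled on candle's `with_tracing` module; details may differ):

```rust
use candle::{Module, Result, Tensor};
use candle_nn::VarBuilder;

pub struct Linear {
    inner: candle_nn::Linear,
    span: tracing::Span,
}

pub fn linear_b(in_dim: usize, out_dim: usize, bias: bool, vb: VarBuilder) -> Result<Linear> {
    let inner = candle_nn::linear_b(in_dim, out_dim, bias, vb)?;
    let span = tracing::span!(tracing::Level::TRACE, "linear");
    Ok(Linear { inner, span })
}

impl Module for Linear {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Same computation as candle_nn::Linear, but timed.
        let _enter = self.span.enter();
        self.inner.forward(xs)
    }
}
```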
@@ -76,6 +77,7 @@ struct Attention {
     head_dim: usize,
     qkv: Linear,
     proj: Linear,
+    span: tracing::Span,
 }
 
 impl Attention {
@@ -87,12 +89,14 @@ impl Attention {
             head_dim: dim / num_heads,
             qkv,
             proj,
+            span: tracing::span!(tracing::Level::TRACE, "vit-attn"),
         })
     }
 }
 
 impl Module for Attention {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b, n, c) = xs.dims3()?;
         let qkv = xs
             .apply(&self.qkv)?
@@ -114,8 +118,9 @@ impl Module for Attention {
 struct VitBlock {
     attn: Attention,
     mlp: Mlp,
-    norm1: candle_nn::LayerNorm,
-    norm2: candle_nn::LayerNorm,
+    norm1: LayerNorm,
+    norm2: LayerNorm,
+    span: tracing::Span,
 }
 
 impl VitBlock {
@@ -129,12 +134,14 @@ impl VitBlock {
             mlp,
             norm1,
             norm2,
+            span: tracing::span!(tracing::Level::TRACE, "vit-block"),
         })
     }
 }
 
 impl Module for VitBlock {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let ys = xs.apply(&self.norm1)?.apply(&self.attn)?;
         let xs = (xs + &ys)?;
         let ys = xs.apply(&self.norm2)?.apply(&self.mlp)?;
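Because `VitBlock::forward` enters `vit-block` before calling into its attention and MLP layers, the `vit-attn` and `mlp` spans record as children of `vit-block` in the trace: `tracing` parents a span to whichever span is current when it is entered. A minimal sketch of that contextual nesting (hypothetical span names):

```rust
fn outer_step() {
    let outer = tracing::span!(tracing::Level::TRACE, "outer");
    let _o = outer.enter();
    inner_step(); // recorded as a child of "outer" while the guard is live
}

fn inner_step() {
    let inner = tracing::span!(tracing::Level::TRACE, "inner");
    let _i = inner.enter();
    // ... work ...
}
```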
@@ -148,7 +155,8 @@ struct VisionTransformer {
     patch_embed: LinearPatchEmbedding,
     pos_embed: Tensor,
     blocks: Vec<VitBlock>,
-    norm: candle_nn::LayerNorm,
+    norm: LayerNorm,
+    span: tracing::Span,
 }
 
 impl VisionTransformer {
@@ -171,12 +179,14 @@ impl VisionTransformer {
             pos_embed,
             blocks,
             norm,
+            span: tracing::span!(tracing::Level::TRACE, "vit"),
         })
     }
 }
 
 impl Module for VisionTransformer {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = (&xs.apply(&self.patch_embed)? + &self.pos_embed)?;
         for block in self.blocks.iter() {
             xs = xs.apply(block)?;
@@ -208,6 +218,7 @@ struct Mlp {
     fc1: Linear,
     act: candle_nn::Activation,
     fc2: Linear,
+    span: tracing::Span,
 }
 
 impl Mlp {
@@ -220,12 +231,18 @@ impl Mlp {
     ) -> Result<Self> {
         let fc1 = linear_b(in_features, hidden_features, true, vb.pp("fc1"))?;
         let fc2 = linear_b(hidden_features, out_features, true, vb.pp("fc2"))?;
-        Ok(Self { fc1, act, fc2 })
+        Ok(Self {
+            fc1,
+            act,
+            fc2,
+            span: tracing::span!(tracing::Level::TRACE, "mlp"),
+        })
     }
 }
 
 impl Module for Mlp {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
     }
 }
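If a full Chrome trace is overkill, the same spans can be printed with per-span busy/idle times via `tracing-subscriber`'s fmt layer; a sketch assuming the crate's `env-filter` feature is enabled:

```rust
use tracing_subscriber::fmt::format::FmtSpan;

fn main() {
    // Logs a line when each span closes, including time.busy / time.idle.
    tracing_subscriber::fmt()
        .with_env_filter("trace")
        .with_span_events(FmtSpan::CLOSE)
        .init();

    // ... run the model; close events for "mlp", "mha", "rope", "vit", etc.
    // are printed to stdout.
}
```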