commit 88f7793598 (parent 2ac302a5d1)
@@ -198,6 +198,7 @@ struct MLP {
     fc1: Linear,
     fc2: Linear,
     act: Activation,
+    span: tracing::Span,
 }

 impl MLP {
@@ -209,12 +210,14 @@ impl MLP {
             fc1,
             fc2,
             act: cfg.activation_function,
+            span: tracing::span!(tracing::Level::TRACE, "mlp"),
         })
     }
 }

 impl Module for MLP {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
     }
 }
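The pattern added throughout this diff is the same in every module: the struct gains a tracing::Span that is built once in the constructor and entered at the top of forward, so each forward call is attributed to a named TRACE-level span. A minimal self-contained sketch of the idea (TracedLayer and its body are illustrative placeholders, not candle code):

use tracing::Level;

struct TracedLayer {
    span: tracing::Span,
}

impl TracedLayer {
    fn new() -> Self {
        // The span is created once; entering it later is cheap.
        Self {
            span: tracing::span!(Level::TRACE, "mlp"),
        }
    }

    fn forward(&self, x: f32) -> f32 {
        // The guard keeps the span active until it is dropped at the end of forward.
        let _enter = self.span.enter();
        x * 2.0
    }
}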
@@ -252,6 +255,9 @@ struct MHA {
     n_head: usize,
     softmax_scale: f64,
     span: tracing::Span,
+    span_rope: tracing::Span,
+    span_mask: tracing::Span,
+    span_softmax: tracing::Span,
 }

 impl MHA {
@@ -272,6 +278,9 @@ impl MHA {
             rotary_emb,
             softmax_scale,
             span: tracing::span!(tracing::Level::TRACE, "mha"),
+            span_rope: tracing::span!(tracing::Level::TRACE, "rope"),
+            span_mask: tracing::span!(tracing::Level::TRACE, "mask"),
+            span_softmax: tracing::span!(tracing::Level::TRACE, "softmax"),
         })
     }

@@ -287,7 +296,10 @@ impl MHA {
             Some((prev_k, _)) => prev_k.dim(1)?,
         };
         // In the python implementation, a single tensor is returned with the third axis of size 3.
-        let (q, k, v) = self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?;
+        let (q, k, v) = {
+            let _enter = self.span_rope.enter();
+            self.rotary_emb.apply_rotary_emb_qkv(&qkv, seqlen_offset)?
+        };
         let (k, v) = match &self.kv_cache {
             None => (k, v),
             Some((prev_k, prev_v)) => {
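Binding the guard inside a block expression, as above, scopes the span to just that computation: the guard drops when the block ends, while the block still evaluates to the result. A small hedged sketch of the idiom (scoped_step and expensive_step are made-up names):

fn scoped_step(span: &tracing::Span) -> u32 {
    let out = {
        let _enter = span.enter(); // "rope"-style span active from here...
        expensive_step()           // ...until the end of this block
    };
    // The span is closed here; `out` carries the block's value onward.
    out
}

fn expensive_step() -> u32 {
    42
}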
@@ -307,13 +319,19 @@ impl MHA {
         // scores = scores + causal_mask.to(dtype=scores.dtype)
         let attn_weights = match mask {
             None => attn_weights,
-            Some(mask) => masked_fill(
-                &attn_weights,
-                &mask.broadcast_left(b_size * self.n_head)?,
-                f32::NEG_INFINITY,
-            )?,
+            Some(mask) => {
+                let _enter = self.span_mask.enter();
+                masked_fill(
+                    &attn_weights,
+                    &mask.broadcast_left(b_size * self.n_head)?,
+                    f32::NEG_INFINITY,
+                )?
+            }
         };
-        let attn_weights = candle_nn::ops::softmax_last_dim(&attn_weights)?;
+        let attn_weights = {
+            let _enter = self.span_softmax.enter();
+            candle_nn::ops::softmax_last_dim(&attn_weights)?
+        };

         // output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
         // attn_weights: b*h,t,s, v: b*h,s,d
@@ -1,6 +1,7 @@
 use crate::models::mixformer::{Config as PhiConfig, MixFormerSequentialForCausalLM as PhiModel};
-use candle::{IndexOp, Result, Tensor, D};
-use candle_nn::{layer_norm, linear_b, Linear, Module, VarBuilder};
+use crate::models::with_tracing::{layer_norm, linear_b, LayerNorm, Linear};
+use candle::{IndexOp, Module, Result, Tensor, D};
+use candle_nn::VarBuilder;

 pub struct Config {
     pub phi_config: PhiConfig,
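The import swap routes Linear and LayerNorm construction through crate::models::with_tracing, so the layers themselves record spans and the call sites only change their use statements; Module now comes straight from candle. As an illustration only (TracedLinear is a hypothetical type, not the actual with_tracing source), such a wrapper can be as thin as:

use candle::{Module, Result, Tensor};

#[derive(Debug, Clone)]
struct TracedLinear {
    inner: candle_nn::Linear,
    span: tracing::Span,
}

impl Module for TracedLinear {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // Enter the span for the duration of the inner layer's forward pass.
        let _enter = self.span.enter();
        self.inner.forward(xs)
    }
}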
@@ -76,6 +77,7 @@ struct Attention {
     head_dim: usize,
     qkv: Linear,
     proj: Linear,
+    span: tracing::Span,
 }

 impl Attention {
@@ -87,12 +89,14 @@ impl Attention {
             head_dim: dim / num_heads,
             qkv,
             proj,
+            span: tracing::span!(tracing::Level::TRACE, "vit-attn"),
         })
     }
 }

 impl Module for Attention {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let (b, n, c) = xs.dims3()?;
         let qkv = xs
             .apply(&self.qkv)?
@@ -114,8 +118,9 @@ impl Module for Attention {
 struct VitBlock {
     attn: Attention,
     mlp: Mlp,
-    norm1: candle_nn::LayerNorm,
-    norm2: candle_nn::LayerNorm,
+    norm1: LayerNorm,
+    norm2: LayerNorm,
+    span: tracing::Span,
 }

 impl VitBlock {
@@ -129,12 +134,14 @@ impl VitBlock {
             mlp,
             norm1,
             norm2,
+            span: tracing::span!(tracing::Level::TRACE, "vit-block"),
         })
     }
 }

 impl Module for VitBlock {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let ys = xs.apply(&self.norm1)?.apply(&self.attn)?;
         let xs = (xs + &ys)?;
         let ys = xs.apply(&self.norm2)?.apply(&self.mlp)?;
@@ -148,7 +155,8 @@ struct VisionTransformer {
     patch_embed: LinearPatchEmbedding,
     pos_embed: Tensor,
     blocks: Vec<VitBlock>,
-    norm: candle_nn::LayerNorm,
+    norm: LayerNorm,
+    span: tracing::Span,
 }

 impl VisionTransformer {
@@ -171,12 +179,14 @@ impl VisionTransformer {
             pos_embed,
             blocks,
             norm,
+            span: tracing::span!(tracing::Level::TRACE, "vit"),
         })
     }
 }

 impl Module for VisionTransformer {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         let mut xs = (&xs.apply(&self.patch_embed)? + &self.pos_embed)?;
         for block in self.blocks.iter() {
             xs = xs.apply(block)?;
@@ -208,6 +218,7 @@ struct Mlp {
     fc1: Linear,
     act: candle_nn::Activation,
     fc2: Linear,
+    span: tracing::Span,
 }

 impl Mlp {
@@ -220,12 +231,18 @@ impl Mlp {
     ) -> Result<Self> {
         let fc1 = linear_b(in_features, hidden_features, true, vb.pp("fc1"))?;
         let fc2 = linear_b(hidden_features, out_features, true, vb.pp("fc2"))?;
-        Ok(Self { fc1, act, fc2 })
+        Ok(Self {
+            fc1,
+            act,
+            fc2,
+            span: tracing::span!(tracing::Level::TRACE, "mlp"),
+        })
     }
 }

 impl Module for Mlp {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        let _enter = self.span.enter();
         xs.apply(&self.fc1)?.apply(&self.act)?.apply(&self.fc2)
     }
 }
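These TRACE spans only show up if a subscriber is installed. A possible setup, mirroring the --tracing flag used by candle's example binaries (tracing-chrome and tracing-subscriber are assumed as dependencies):

use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

fn main() {
    // Keep the guard alive; the Chrome trace file is flushed when it drops.
    let (chrome_layer, _guard) = ChromeLayerBuilder::new().build();
    tracing_subscriber::registry().with(chrome_layer).init();

    // ... build the model and run generation here; spans such as "mha", "rope",
    // "mask", "softmax", "vit-attn" and "vit-block" appear in the resulting
    // trace JSON, which can be inspected in chrome://tracing or Perfetto.
}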