More prep work for phi.

laurent 2024-04-17 10:23:15 +02:00
parent d79041d94d
commit 3754b834f4
1 changed file with 130 additions and 49 deletions


@@ -1,6 +1,5 @@
 use std::collections::HashMap;
 
-use crate::quantized_nn::RmsNorm;
 use candle::quantized::QTensor;
 use candle::quantized::{ggml_file, gguf_file};
 use candle::{DType, Device, IndexOp, Result, Tensor};
@@ -140,15 +139,42 @@ impl Module for Mlp {
     }
 }
 
+#[derive(Debug, Clone)]
+enum Norm {
+    Rms(crate::quantized_nn::RmsNorm),
+    Layer(candle_nn::LayerNorm),
+}
+
+impl Module for Norm {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        match self {
+            Self::Rms(m) => m.forward(xs),
+            Self::Layer(m) => m.forward(xs),
+        }
+    }
+}
+
+fn rms_norm(q: QTensor, eps: f64) -> Result<Norm> {
+    let rms = crate::quantized_nn::RmsNorm::from_qtensor(q, eps)?;
+    Ok(Norm::Rms(rms))
+}
+
+fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result<Norm> {
+    let w = w.dequantize(&w.device())?;
+    let b = b.dequantize(&b.device())?;
+    let ln = candle_nn::LayerNorm::new(w, b, eps);
+    Ok(Norm::Layer(ln))
+}
+
 #[derive(Debug, Clone)]
 struct LayerWeights {
     attention_wq: QMatMul,
     attention_wk: QMatMul,
     attention_wv: QMatMul,
     attention_wo: QMatMul,
-    attention_norm: RmsNorm,
+    attention_norm: Norm,
     mlp: Mlp,
-    ffn_norm: RmsNorm,
+    ffn_norm: Norm,
     n_head: usize,
     n_kv_head: usize,
     head_dim: usize,
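The Norm wrapper introduced above gives each block a single Module-typed handle over either an RMS norm (the llama-style path) or a bias-carrying LayerNorm (the phi-2 path). A minimal sketch of the dispatch, hypothetical and assuming it sits inside this module since Norm is private:

fn norm_demo() -> candle::Result<()> {
    use candle::{DType, Device, Module, Tensor};
    let dev = Device::Cpu;
    // Build the LayerNorm variant directly from float tensors, mirroring
    // what layer_norm() produces after dequantizing its two QTensors.
    let w = Tensor::ones(8, DType::F32, &dev)?;
    let b = Tensor::zeros(8, DType::F32, &dev)?;
    let norm = Norm::Layer(candle_nn::LayerNorm::new(w, b, 1e-5));
    let xs = Tensor::randn(0f32, 1f32, (2, 8), &dev)?;
    // forward() dispatches to whichever implementation is wrapped.
    let _ys = norm.forward(&xs)?;
    Ok(())
}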
@@ -246,7 +272,7 @@ impl LayerWeights {
 pub struct ModelWeights {
     tok_embeddings: Embedding,
     layers: Vec<LayerWeights>,
-    norm: RmsNorm,
+    norm: Norm,
     output: QMatMul,
     masks: HashMap<usize, Tensor>,
     span: tracing::Span,
@@ -272,6 +298,12 @@ fn precomput_freqs_cis(
     Ok((cos, sin))
 }
 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum Architecture {
+    Llama,
+    Phi2,
+}
+
 #[derive(Debug, Clone)]
 struct MetadataConfig {
     n_expert: usize,
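The Architecture enum is keyed off the general.architecture entry of the GGUF metadata, as the hunk further below shows. For illustration only, probing that entry with candle's gguf reader might look like this (hypothetical helper, error handling kept minimal):

fn gguf_architecture(path: &str) -> candle::Result<String> {
    use candle::quantized::gguf_file;
    let mut file = std::fs::File::open(path)?;
    let content = gguf_file::Content::read(&mut file)?;
    // "general.architecture" is e.g. "llama" or "phi2"; default to llama
    // when the key is absent, matching the fallback used below.
    let arch = content
        .metadata
        .get("general.architecture")
        .and_then(|v| v.to_string().ok())
        .cloned()
        .unwrap_or_else(|| "llama".to_string());
    Ok(arch)
}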
@@ -283,6 +315,7 @@ struct MetadataConfig {
     rope_dim: usize,
     rms_norm_eps: f64,
     rope_freq_base: f32,
+    architecture: Architecture,
 }
 
 impl MetadataConfig {
@@ -292,35 +325,69 @@ impl MetadataConfig {
             Some(v) => Ok(v),
         };
 
         // Parameter extraction from metadata.
-        let n_expert = md_get("llama.expert_count")
-            .and_then(|v| v.to_u32())
-            .unwrap_or(0) as usize;
-        let n_expert_used = md_get("llama.expert_used_count")
-            .and_then(|v| v.to_u32())
-            .unwrap_or(0) as usize;
-        let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
-        let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
-        let block_count = md_get("llama.block_count")?.to_u32()? as usize;
-        let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
-        let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
-        // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
-        let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
+        let architecture = match md_get("general.architecture")
+            .and_then(|v| v.to_string())
+            .map(|v| v.as_str())
+        {
+            Ok("phi2") => Architecture::Phi2,
+            Err(_) | Ok(_) => Architecture::Llama,
+        };
-        let rope_freq_base = md_get("llama.rope.freq_base")
-            .and_then(|m| m.to_f32())
-            .unwrap_or(10000f32);
-        Ok(Self {
-            n_expert,
-            n_expert_used,
-            head_count,
-            head_count_kv,
-            block_count,
-            embedding_length,
-            rope_freq_base,
-            rope_dim,
-            rms_norm_eps,
-        })
+        let config = match architecture {
+            Architecture::Phi2 => {
+                let head_count = md_get("phi2.attention.head_count")?.to_u32()? as usize;
+                let head_count_kv = md_get("phi2.attention.head_count_kv")?.to_u32()? as usize;
+                let block_count = md_get("phi2.block_count")?.to_u32()? as usize;
+                let embedding_length = md_get("phi2.embedding_length")?.to_u32()? as usize;
+                let rope_dim = md_get("phi2.rope.dimension_count")?.to_u32()? as usize;
+                let rms_norm_eps = md_get("phi2.attention.layer_norm_epsilon")?.to_f32()? as f64;
+                Self {
+                    n_expert: 1,
+                    n_expert_used: 1,
+                    head_count,
+                    head_count_kv,
+                    block_count,
+                    embedding_length,
+                    rope_freq_base: 10_000.,
+                    rope_dim,
+                    rms_norm_eps,
+                    architecture,
+                }
+            }
+            Architecture::Llama => {
+                let n_expert = md_get("llama.expert_count")
+                    .and_then(|v| v.to_u32())
+                    .unwrap_or(0) as usize;
+                let n_expert_used = md_get("llama.expert_used_count")
+                    .and_then(|v| v.to_u32())
+                    .unwrap_or(0) as usize;
+                let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
+                let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
+                let block_count = md_get("llama.block_count")?.to_u32()? as usize;
+                let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
+                let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
+                // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
+                let rms_norm_eps =
+                    md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
+                let rope_freq_base = md_get("llama.rope.freq_base")
+                    .and_then(|m| m.to_f32())
+                    .unwrap_or(10000f32);
+                Self {
+                    n_expert,
+                    n_expert_used,
+                    head_count,
+                    head_count_kv,
+                    block_count,
+                    embedding_length,
+                    rope_freq_base,
+                    rope_dim,
+                    rms_norm_eps,
+                    architecture,
+                }
+            }
+        };
+        Ok(config)
     }
 }
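The md_get helper whose tail is visible at the top of this hunk looks a key up in the GGUF metadata map and turns a missing entry into an error. A hedged reconstruction as a free function (the actual code appears to be a closure whose head the diff elides):

fn md_get<'a>(
    md: &'a std::collections::HashMap<String, candle::quantized::gguf_file::Value>,
    s: &str,
) -> candle::Result<&'a candle::quantized::gguf_file::Value> {
    match md.get(s) {
        None => candle::bail!("cannot find {s} in metadata"),
        Some(v) => Ok(v),
    }
}

With the per-architecture branches, each model family reads its own key namespace ("phi2.block_count" vs "llama.block_count"), and the phi-2 arm pins n_expert to 1 and rope_freq_base to 10_000, presumably because phi-2 GGUF files do not carry those keys.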
@@ -331,7 +398,7 @@ impl ModelWeights {
         let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?;
         let tok_embeddings = ct.remove("tok_embeddings.weight")?;
         let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
-        let norm = RmsNorm::from_qtensor(ct.remove("norm.weight")?, 1e-5)?;
+        let norm = rms_norm(ct.remove("norm.weight")?, 1e-5)?;
         let output = ct.remove("output.weight")?;
         let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
         for layer_idx in 0..ct.hparams.n_layer {
@@ -360,9 +427,9 @@ impl ModelWeights {
                 attention_wk: QMatMul::from_qtensor(attention_wk)?,
                 attention_wv: QMatMul::from_qtensor(attention_wv)?,
                 attention_wo: QMatMul::from_qtensor(attention_wo)?,
-                attention_norm: RmsNorm::from_qtensor(attention_norm, 1e-5)?,
+                attention_norm: rms_norm(attention_norm, 1e-5)?,
                 mlp,
-                ffn_norm: RmsNorm::from_qtensor(ffn_norm, 1e-5)?,
+                ffn_norm: rms_norm(ffn_norm, 1e-5)?,
                 n_head: ct.hparams.n_head as usize,
                 n_kv_head: ct.hparams.n_head as usize / gqa,
                 head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
@@ -400,7 +467,7 @@ impl ModelWeights {
         let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
         let tok_embeddings = tok_embeddings.dequantize(device)?;
-        let norm = RmsNorm::from_qtensor(
+        let norm = rms_norm(
             ct.tensor(reader, "output_norm.weight", device)?,
             cfg.rms_norm_eps,
         )?;
@@ -414,17 +481,31 @@ impl ModelWeights {
             let attention_wo =
                 ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
             let mlp = if cfg.n_expert <= 1 {
-                let feed_forward_w1 =
-                    ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
-                let feed_forward_w2 =
-                    ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
-                let feed_forward_w3 =
-                    ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
-                Mlp::Silu(MlpSilu {
-                    feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
-                    feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
-                    feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
-                })
+                match cfg.architecture {
+                    Architecture::Llama => {
+                        let feed_forward_w1 =
+                            ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
+                        let feed_forward_w2 =
+                            ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
+                        let feed_forward_w3 =
+                            ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
+                        Mlp::Silu(MlpSilu {
+                            feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
+                            feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
+                            feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
+                        })
+                    }
+                    Architecture::Phi2 => {
+                        let fc1 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
+                        let fc2 =
+                            ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
+                        Mlp::Simple(MlpSimple {
+                            fc1: QMatMul::from_qtensor(fc1)?,
+                            fc2: QMatMul::from_qtensor(fc2)?,
+                            act: candle_nn::Activation::NewGelu,
+                        })
+                    }
+                }
             } else {
                 let feed_forward_gate_inp =
                     ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?;
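The phi-2 branch builds a plain two-layer MLP (MlpSimple) instead of the gated SiLU variant. A hedged sketch of the data flow this implies, written as a hypothetical free function (the real forward presumably lives on MlpSimple, which predates this commit):

fn phi2_mlp_forward(
    fc1: &candle::quantized::QMatMul,
    fc2: &candle::quantized::QMatMul,
    xs: &candle::Tensor,
) -> candle::Result<candle::Tensor> {
    use candle::Module;
    let hidden = fc1.forward(xs)?; // up-projection: [.., d_model] -> [.., d_ff]
    let hidden = hidden.apply(&candle_nn::Activation::NewGelu)?; // NewGELU nonlinearity
    fc2.forward(&hidden) // down-projection back to d_model, no gating branch
}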
@@ -459,9 +540,9 @@ impl ModelWeights {
                 attention_wk: QMatMul::from_qtensor(attention_wk)?,
                 attention_wv: QMatMul::from_qtensor(attention_wv)?,
                 attention_wo: QMatMul::from_qtensor(attention_wo)?,
-                attention_norm: RmsNorm::from_qtensor(attention_norm, cfg.rms_norm_eps)?,
+                attention_norm: rms_norm(attention_norm, cfg.rms_norm_eps)?,
                 mlp,
-                ffn_norm: RmsNorm::from_qtensor(ffn_norm, cfg.rms_norm_eps)?,
+                ffn_norm: rms_norm(ffn_norm, cfg.rms_norm_eps)?,
                 n_head: cfg.head_count,
                 n_kv_head: cfg.head_count_kv,
                 head_dim: cfg.embedding_length / cfg.head_count,
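Taken together, loading a phi-2 GGUF checkpoint should go through the same entry point as llama models. A hypothetical end-to-end sketch, assuming the from_gguf(content, reader, device) signature implied by the ct.tensor(reader, ...) calls above (path and file handling are illustrative):

fn load_model(path: &str, device: &candle::Device) -> candle::Result<ModelWeights> {
    use candle::quantized::gguf_file;
    let mut file = std::fs::File::open(path)?;
    let content = gguf_file::Content::read(&mut file)?;
    // The metadata parsing routes on general.architecture internally, so
    // llama- and phi2-style checkpoints take the same code path here.
    ModelWeights::from_gguf(content, &mut file, device)
}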