From 3754b834f42cc49360aed7a1095d2ff8f732858b Mon Sep 17 00:00:00 2001
From: laurent
Date: Wed, 17 Apr 2024 10:23:15 +0200
Subject: [PATCH] More prep work for phi.

---
 .../src/models/quantized_llama.rs | 179 +++++++++++++-----
 1 file changed, 130 insertions(+), 49 deletions(-)

diff --git a/candle-transformers/src/models/quantized_llama.rs b/candle-transformers/src/models/quantized_llama.rs
index bac9e7e7..d16300c6 100644
--- a/candle-transformers/src/models/quantized_llama.rs
+++ b/candle-transformers/src/models/quantized_llama.rs
@@ -1,6 +1,5 @@
 use std::collections::HashMap;
 
-use crate::quantized_nn::RmsNorm;
 use candle::quantized::QTensor;
 use candle::quantized::{ggml_file, gguf_file};
 use candle::{DType, Device, IndexOp, Result, Tensor};
@@ -140,15 +139,42 @@ impl Module for Mlp {
     }
 }
 
+#[derive(Debug, Clone)]
+enum Norm {
+    Rms(crate::quantized_nn::RmsNorm),
+    Layer(candle_nn::LayerNorm),
+}
+
+impl Module for Norm {
+    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
+        match self {
+            Self::Rms(m) => m.forward(xs),
+            Self::Layer(m) => m.forward(xs),
+        }
+    }
+}
+
+fn rms_norm(q: QTensor, eps: f64) -> Result<Norm> {
+    let rms = crate::quantized_nn::RmsNorm::from_qtensor(q, eps)?;
+    Ok(Norm::Rms(rms))
+}
+
+fn layer_norm(w: QTensor, b: QTensor, eps: f64) -> Result<Norm> {
+    let w = w.dequantize(&w.device())?;
+    let b = b.dequantize(&b.device())?;
+    let ln = candle_nn::LayerNorm::new(w, b, eps);
+    Ok(Norm::Layer(ln))
+}
+
 #[derive(Debug, Clone)]
 struct LayerWeights {
     attention_wq: QMatMul,
     attention_wk: QMatMul,
     attention_wv: QMatMul,
     attention_wo: QMatMul,
-    attention_norm: RmsNorm,
+    attention_norm: Norm,
     mlp: Mlp,
-    ffn_norm: RmsNorm,
+    ffn_norm: Norm,
     n_head: usize,
     n_kv_head: usize,
     head_dim: usize,
@@ -246,7 +272,7 @@ impl LayerWeights {
 pub struct ModelWeights {
     tok_embeddings: Embedding,
     layers: Vec<LayerWeights>,
-    norm: RmsNorm,
+    norm: Norm,
     output: QMatMul,
     masks: HashMap<usize, Tensor>,
     span: tracing::Span,
@@ -272,6 +298,12 @@ fn precomput_freqs_cis(
     Ok((cos, sin))
 }
 
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum Architecture {
+    Llama,
+    Phi2,
+}
+
 #[derive(Debug, Clone)]
 struct MetadataConfig {
     n_expert: usize,
@@ -283,6 +315,7 @@ struct MetadataConfig {
     rope_dim: usize,
     rms_norm_eps: f64,
     rope_freq_base: f32,
+    architecture: Architecture,
 }
 
 impl MetadataConfig {
@@ -292,35 +325,69 @@ impl MetadataConfig {
             Some(v) => Ok(v),
         };
 
-        // Parameter extraction from metadata.
-        let n_expert = md_get("llama.expert_count")
-            .and_then(|v| v.to_u32())
-            .unwrap_or(0) as usize;
-        let n_expert_used = md_get("llama.expert_used_count")
-            .and_then(|v| v.to_u32())
-            .unwrap_or(0) as usize;
-        let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
-        let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
-        let block_count = md_get("llama.block_count")?.to_u32()? as usize;
-        let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
-        let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
-        // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
-        let rms_norm_eps = md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
+        let architecture = match md_get("general.architecture")
+            .and_then(|v| v.to_string())
+            .map(|v| v.as_str())
+        {
+            Ok("phi2") => Architecture::Phi2,
+            Err(_) | Ok(_) => Architecture::Llama,
+        };
 
-        let rope_freq_base = md_get("llama.rope.freq_base")
-            .and_then(|m| m.to_f32())
-            .unwrap_or(10000f32);
-        Ok(Self {
-            n_expert,
-            n_expert_used,
-            head_count,
-            head_count_kv,
-            block_count,
-            embedding_length,
-            rope_freq_base,
-            rope_dim,
-            rms_norm_eps,
-        })
+        let config = match architecture {
+            Architecture::Phi2 => {
+                let head_count = md_get("phi2.attention.head_count")?.to_u32()? as usize;
+                let head_count_kv = md_get("phi2.attention.head_count_kv")?.to_u32()? as usize;
+                let block_count = md_get("phi2.block_count")?.to_u32()? as usize;
+                let embedding_length = md_get("phi2.embedding_length")?.to_u32()? as usize;
+                let rope_dim = md_get("phi2.rope.dimension_count")?.to_u32()? as usize;
+                let rms_norm_eps = md_get("phi2.attention.layer_norm_epsilon")?.to_f32()? as f64;
+                Self {
+                    n_expert: 1,
+                    n_expert_used: 1,
+                    head_count,
+                    head_count_kv,
+                    block_count,
+                    embedding_length,
+                    rope_freq_base: 10_000.,
+                    rope_dim,
+                    rms_norm_eps,
+                    architecture,
+                }
+            }
+            Architecture::Llama => {
+                let n_expert = md_get("llama.expert_count")
+                    .and_then(|v| v.to_u32())
+                    .unwrap_or(0) as usize;
+                let n_expert_used = md_get("llama.expert_used_count")
+                    .and_then(|v| v.to_u32())
+                    .unwrap_or(0) as usize;
+                let head_count = md_get("llama.attention.head_count")?.to_u32()? as usize;
+                let head_count_kv = md_get("llama.attention.head_count_kv")?.to_u32()? as usize;
+                let block_count = md_get("llama.block_count")?.to_u32()? as usize;
+                let embedding_length = md_get("llama.embedding_length")?.to_u32()? as usize;
+                let rope_dim = md_get("llama.rope.dimension_count")?.to_u32()? as usize;
+                // Strangely this value is generally 1e-6 in GGUF file but used to be 1e-5 by default.
+                let rms_norm_eps =
+                    md_get("llama.attention.layer_norm_rms_epsilon")?.to_f32()? as f64;
+
+                let rope_freq_base = md_get("llama.rope.freq_base")
+                    .and_then(|m| m.to_f32())
+                    .unwrap_or(10000f32);
+                Self {
+                    n_expert,
+                    n_expert_used,
+                    head_count,
+                    head_count_kv,
+                    block_count,
+                    embedding_length,
+                    rope_freq_base,
+                    rope_dim,
+                    rms_norm_eps,
+                    architecture,
+                }
+            }
+        };
+        Ok(config)
     }
 }
 
@@ -331,7 +398,7 @@ impl ModelWeights {
         let neg_inf = Tensor::new(f32::NEG_INFINITY, &ct.device)?;
         let tok_embeddings = ct.remove("tok_embeddings.weight")?;
         let tok_embeddings = tok_embeddings.dequantize(&ct.device)?;
-        let norm = RmsNorm::from_qtensor(ct.remove("norm.weight")?, 1e-5)?;
+        let norm = rms_norm(ct.remove("norm.weight")?, 1e-5)?;
        let output = ct.remove("output.weight")?;
         let mut layers = Vec::with_capacity(ct.hparams.n_layer as usize);
         for layer_idx in 0..ct.hparams.n_layer {
@@ -360,9 +427,9 @@ impl ModelWeights {
                 attention_wk: QMatMul::from_qtensor(attention_wk)?,
                 attention_wv: QMatMul::from_qtensor(attention_wv)?,
                 attention_wo: QMatMul::from_qtensor(attention_wo)?,
-                attention_norm: RmsNorm::from_qtensor(attention_norm, 1e-5)?,
+                attention_norm: rms_norm(attention_norm, 1e-5)?,
                 mlp,
-                ffn_norm: RmsNorm::from_qtensor(ffn_norm, 1e-5)?,
+                ffn_norm: rms_norm(ffn_norm, 1e-5)?,
                 n_head: ct.hparams.n_head as usize,
                 n_kv_head: ct.hparams.n_head as usize / gqa,
                 head_dim: (ct.hparams.n_embd / ct.hparams.n_head) as usize,
@@ -400,7 +467,7 @@ impl ModelWeights {
 
         let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
         let tok_embeddings = tok_embeddings.dequantize(device)?;
-        let norm = RmsNorm::from_qtensor(
+        let norm = rms_norm(
             ct.tensor(reader, "output_norm.weight", device)?,
             cfg.rms_norm_eps,
         )?;
@@ -414,17 +481,31 @@ impl ModelWeights {
             let attention_wo =
                 ct.tensor(reader, &format!("{prefix}.attn_output.weight"), device)?;
             let mlp = if cfg.n_expert <= 1 {
-                let feed_forward_w1 =
-                    ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
-                let feed_forward_w2 =
-                    ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
-                let feed_forward_w3 =
-                    ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
-                Mlp::Silu(MlpSilu {
-                    feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
-                    feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
-                    feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
-                })
+                match cfg.architecture {
+                    Architecture::Llama => {
+                        let feed_forward_w1 =
+                            ct.tensor(reader, &format!("{prefix}.ffn_gate.weight"), device)?;
+                        let feed_forward_w2 =
+                            ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
+                        let feed_forward_w3 =
+                            ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
+                        Mlp::Silu(MlpSilu {
+                            feed_forward_w1: QMatMul::from_qtensor(feed_forward_w1)?,
+                            feed_forward_w2: QMatMul::from_qtensor(feed_forward_w2)?,
+                            feed_forward_w3: QMatMul::from_qtensor(feed_forward_w3)?,
+                        })
+                    }
+                    Architecture::Phi2 => {
+                        let fc1 = ct.tensor(reader, &format!("{prefix}.ffn_up.weight"), device)?;
+                        let fc2 =
+                            ct.tensor(reader, &format!("{prefix}.ffn_down.weight"), device)?;
+                        Mlp::Simple(MlpSimple {
+                            fc1: QMatMul::from_qtensor(fc1)?,
+                            fc2: QMatMul::from_qtensor(fc2)?,
+                            act: candle_nn::Activation::NewGelu,
+                        })
+                    }
+                }
             } else {
                 let feed_forward_gate_inp =
                     ct.tensor(reader, &format!("{prefix}.ffn_gate_inp.weight"), device)?;
@@ -459,9 +540,9 @@ impl ModelWeights {
                 attention_wk: QMatMul::from_qtensor(attention_wk)?,
                 attention_wv: QMatMul::from_qtensor(attention_wv)?,
                 attention_wo: QMatMul::from_qtensor(attention_wo)?,
-                attention_norm: RmsNorm::from_qtensor(attention_norm, cfg.rms_norm_eps)?,
+                attention_norm: rms_norm(attention_norm, cfg.rms_norm_eps)?,
                 mlp,
-                ffn_norm: RmsNorm::from_qtensor(ffn_norm, cfg.rms_norm_eps)?,
+                ffn_norm: rms_norm(ffn_norm, cfg.rms_norm_eps)?,
                 n_head: cfg.head_count,
                 n_kv_head: cfg.head_count_kv,
                 head_dim: cfg.embedding_length / cfg.head_count,