fix: qwen2 lm_head loading #2443 (#2445)

Co-authored-by: Yi Xu <xuyi@me.com>
This commit is contained in:
ilookee 2024-08-23 22:50:02 +08:00 committed by GitHub
parent ccdbe87639
commit fdc2622686
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 1 addition and 1 deletion

View File

@@ -361,7 +361,7 @@ pub struct ModelForCausalLM {
impl ModelForCausalLM {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let base_model = Model::new(cfg, vb.clone())?;
let lm_head = if vb.contains_tensor("lm_head") {
let lm_head = if vb.contains_tensor("lm_head.weight") {
linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?
} else {
Linear::from_weights(base_model.embed_tokens.embeddings().clone(), None)