From 33c9b6655459bd1086574cef9ba8f2e72a8804c8 Mon Sep 17 00:00:00 2001 From: Laurent Mazare Date: Sat, 6 Apr 2024 21:25:38 +0200 Subject: [PATCH] Add the new gemma models. (#2023) * Add the new gemma models. * Revert the lightning changes. * Support for the 1.1 models. --- candle-examples/examples/gemma/main.rs | 35 ++++++++++++++++++++----- candle-transformers/src/models/gemma.rs | 1 + 2 files changed, 29 insertions(+), 7 deletions(-) diff --git a/candle-examples/examples/gemma/main.rs b/candle-examples/examples/gemma/main.rs index e1df8790..0e37f5cd 100644 --- a/candle-examples/examples/gemma/main.rs +++ b/candle-examples/examples/gemma/main.rs @@ -16,6 +16,22 @@ use candle_transformers::generation::LogitsProcessor; use hf_hub::{api::sync::Api, Repo, RepoType}; use tokenizers::Tokenizer; +#[derive(Clone, Debug, Copy, PartialEq, Eq, clap::ValueEnum)] +enum Which { + #[value(name = "2b")] + Base2B, + #[value(name = "7b")] + Base7B, + #[value(name = "2b-it")] + Instruct2B, + #[value(name = "7b-it")] + Instruct7B, + #[value(name = "1.1-2b-it")] + InstructV1_1_2B, + #[value(name = "1.1-7b-it")] + InstructV1_1_7B, +} + struct TextGeneration { model: Model, device: Device, @@ -165,6 +181,10 @@ struct Args { /// The context size to consider for the repeat penalty. #[arg(long, default_value_t = 64)] repeat_last_n: usize, + + /// The model to use. + #[arg(long, default_value = "2b")] + which: Which, } fn main() -> Result<()> { @@ -196,14 +216,15 @@ fn main() -> Result<()> { let start = std::time::Instant::now(); let api = Api::new()?; let model_id = match &args.model_id { - Some(model_id) => match model_id.as_str() { - "7b-it" => "google/gemma-7b-it".to_string(), - "7b" => "google/gemma-7b".to_string(), - "2b-it" => "google/gemma-2b-it".to_string(), - "2b" => "google/gemma-2b".to_string(), - _ => model_id.to_string(), + Some(model_id) => model_id.to_string(), + None => match args.which { + Which::InstructV1_1_2B => "google/gemma-1.1-2b-it".to_string(), + Which::InstructV1_1_7B => "google/gemma-1.1-7b-it".to_string(), + Which::Base2B => "google/gemma-2b".to_string(), + Which::Base7B => "google/gemma-7b".to_string(), + Which::Instruct2B => "google/gemma-2b-it".to_string(), + Which::Instruct7B => "google/gemma-7b-it".to_string(), }, - None => "google/gemma-2b".to_string(), }; let repo = api.repo(Repo::with_revision( model_id, diff --git a/candle-transformers/src/models/gemma.rs b/candle-transformers/src/models/gemma.rs index 282d5eb2..ab2a9582 100644 --- a/candle-transformers/src/models/gemma.rs +++ b/candle-transformers/src/models/gemma.rs @@ -11,6 +11,7 @@ fn default_max_position_embeddings() -> usize { pub struct Config { pub attention_bias: bool, pub head_dim: usize, + #[serde(alias = "hidden_activation")] pub hidden_act: candle_nn::Activation, pub hidden_size: usize, pub intermediate_size: usize,