diff --git a/README.md b/README.md
index c55093f2..8bc324bd 100644
--- a/README.md
+++ b/README.md
@@ -61,7 +61,7 @@ We also provide a some command line based examples using state of the art models
 - [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
 - [Falcon](./candle-examples/examples/falcon/): general LLM.
-- [Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
+- [Phi-v1 and Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
 - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM pre-trained on 1T tokens of English and code datasets.
 - [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
diff --git a/candle-examples/examples/phi/main.rs b/candle-examples/examples/phi/main.rs
index eff329ff..3922b3d5 100644
--- a/candle-examples/examples/phi/main.rs
+++ b/candle-examples/examples/phi/main.rs
@@ -5,7 +5,7 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;
 
 use anyhow::{Error as E, Result};
-use clap::Parser;
+use clap::{Parser, ValueEnum};
 
 use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
 use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
@@ -110,6 +110,14 @@ impl TextGeneration {
     }
 }
 
+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum WhichModel {
+    #[value(name = "1")]
+    V1,
+    #[value(name = "1.5")]
+    V1_5,
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
@@ -140,11 +148,14 @@ struct Args {
     #[arg(long, short = 'n', default_value_t = 100)]
     sample_len: usize,
 
-    #[arg(long, default_value = "microsoft/phi-1_5")]
-    model_id: String,
+    #[arg(long)]
+    model_id: Option<String>,
 
-    #[arg(long, default_value = "refs/pr/18")]
-    revision: String,
+    #[arg(long, default_value = "1.5")]
+    model: WhichModel,
+
+    #[arg(long)]
+    revision: Option<String>,
 
     #[arg(long)]
     weight_file: Option<String>,
@@ -189,18 +200,42 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let api = Api::new()?;
-    let repo = api.repo(Repo::with_revision(
-        args.model_id,
-        RepoType::Model,
-        args.revision,
-    ));
+    let model_id = match args.model_id {
+        Some(model_id) => model_id.to_string(),
+        None => {
+            if args.quantized {
+                "lmz/candle-quantized-phi".to_string()
+            } else {
+                match args.model {
+                    WhichModel::V1 => "microsoft/phi-1".to_string(),
+                    WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
+                }
+            }
+        }
+    };
+    let revision = match args.revision {
+        Some(rev) => rev.to_string(),
+        None => {
+            if args.quantized {
+                "main".to_string()
+            } else {
+                match args.model {
+                    WhichModel::V1 => "refs/pr/2".to_string(),
+                    WhichModel::V1_5 => "refs/pr/18".to_string(),
+                }
+            }
+        }
+    };
+    let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
     let tokenizer_filename = repo.get("tokenizer.json")?;
     let filename = match args.weight_file {
         Some(weight_file) => std::path::PathBuf::from(weight_file),
         None => {
             if args.quantized {
-                api.model("lmz/candle-quantized-phi".to_string())
-                    .get("model-q4k.gguf")?
+                match args.model {
+                    WhichModel::V1 => repo.get("model-v1-q4k.gguf")?,
+                    WhichModel::V1_5 => repo.get("model-q4k.gguf")?,
+                }
             } else {
                 repo.get("model.safetensors")?
             }
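
A hedged usage sketch of the flag this patch introduces: the `--model` value names `1` and `1.5` come from the `ValueEnum` derive above (with `1.5` remaining the default), and `--quantized` switches the download to the `lmz/candle-quantized-phi` GGUF weights. The `--prompt` flag is assumed from the rest of `main.rs` and is not shown in this diff.

    # Phi-v1, quantized (fetches model-v1-q4k.gguf from lmz/candle-quantized-phi at main)
    cargo run --example phi --release -- --model 1 --quantized --prompt "def print_prime(n):"

    # Phi-v1.5 stays the default (microsoft/phi-1_5 at revision refs/pr/18)
    cargo run --example phi --release -- --prompt "def print_prime(n):"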