Add support for phi-1.0 (#1093)

* Add support for phi-1.0
* Update the readme.

parent 29c7f2565d
commit 8921d5027c
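
With this change the phi example gains a `--model` flag for picking the variant. A typical invocation, assuming the repo's usual cargo example setup and a placeholder prompt (the `--prompt` flag is not shown in the hunks below and is an assumption):

    cargo run --example phi --release -- --model 1 --prompt "A short prompt"

Omitting `--model` keeps the previous behaviour of running phi-1.5.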
--- a/README.md
+++ b/README.md
@@ -61,7 +61,7 @@ We also provide some command line based examples using state of the art models:
 - [LLaMA and LLaMA-v2](./candle-examples/examples/llama/): general LLM.
 - [Falcon](./candle-examples/examples/falcon/): general LLM.
-- [Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
+- [Phi-v1 and Phi-v1.5](./candle-examples/examples/phi/): a 1.3b general LLM with performance on par with LLaMA-v2 7b.
 - [StableLM-3B-4E1T](./candle-examples/examples/stable-lm/): a 3b general LLM
   pre-trained on 1T tokens of English and code datasets.
 - [Mistral7b-v0.1](./candle-examples/examples/mistral/): a 7b general LLM with
--- a/candle-examples/examples/phi/src/main.rs
+++ b/candle-examples/examples/phi/src/main.rs
@@ -5,7 +5,7 @@ extern crate intel_mkl_src;
 extern crate accelerate_src;

 use anyhow::{Error as E, Result};
-use clap::Parser;
+use clap::{Parser, ValueEnum};

 use candle_transformers::models::mixformer::{Config, MixFormerSequentialForCausalLM as MixFormer};
 use candle_transformers::models::quantized_mixformer::MixFormerSequentialForCausalLM as QMixFormer;
@@ -110,6 +110,14 @@ impl TextGeneration {
     }
 }

+#[derive(Clone, Copy, Debug, ValueEnum)]
+enum WhichModel {
+    #[value(name = "1")]
+    V1,
+    #[value(name = "1.5")]
+    V1_5,
+}
+
 #[derive(Parser, Debug)]
 #[command(author, version, about, long_about = None)]
 struct Args {
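
For context on the `ValueEnum` derive used above: it is what lets clap map the literal strings "1" and "1.5" on the command line to the enum variants. A minimal standalone sketch, assuming clap 4 with the `derive` feature (not part of this commit):

    use clap::{Parser, ValueEnum};

    // Mirrors the enum introduced in this commit.
    #[derive(Clone, Copy, Debug, ValueEnum)]
    enum WhichModel {
        #[value(name = "1")]
        V1,
        #[value(name = "1.5")]
        V1_5,
    }

    #[derive(Parser, Debug)]
    struct Args {
        // "--model 1" parses to V1, "--model 1.5" to V1_5; when the flag is
        // omitted, the string default is run through the same value parser.
        #[arg(long, default_value = "1.5")]
        model: WhichModel,
    }

    fn main() {
        let args = Args::parse();
        println!("selected: {:?}", args.model);
    }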
@@ -140,11 +148,14 @@ struct Args {
     #[arg(long, short = 'n', default_value_t = 100)]
     sample_len: usize,

-    #[arg(long, default_value = "microsoft/phi-1_5")]
-    model_id: String,
+    #[arg(long)]
+    model_id: Option<String>,

-    #[arg(long, default_value = "refs/pr/18")]
-    revision: String,
+    #[arg(long, default_value = "1.5")]
+    model: WhichModel,
+
+    #[arg(long)]
+    revision: Option<String>,

     #[arg(long)]
     weight_file: Option<String>,
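
Note the design choice here: `model_id` and `revision` lose their compile-time defaults and become `Option`s, because the right default now depends on the combination of `--quantized` and `--model` and has to be resolved at runtime, as the next hunk shows.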
@@ -189,18 +200,42 @@ fn main() -> Result<()> {
     let start = std::time::Instant::now();
     let api = Api::new()?;
-    let repo = api.repo(Repo::with_revision(
-        args.model_id,
-        RepoType::Model,
-        args.revision,
-    ));
+    let model_id = match args.model_id {
+        Some(model_id) => model_id.to_string(),
+        None => {
+            if args.quantized {
+                "lmz/candle-quantized-phi".to_string()
+            } else {
+                match args.model {
+                    WhichModel::V1 => "microsoft/phi-1".to_string(),
+                    WhichModel::V1_5 => "microsoft/phi-1_5".to_string(),
+                }
+            }
+        }
+    };
+    let revision = match args.revision {
+        Some(rev) => rev.to_string(),
+        None => {
+            if args.quantized {
+                "main".to_string()
+            } else {
+                match args.model {
+                    WhichModel::V1 => "refs/pr/2".to_string(),
+                    WhichModel::V1_5 => "refs/pr/18".to_string(),
+                }
+            }
+        }
+    };
+    let repo = api.repo(Repo::with_revision(model_id, RepoType::Model, revision));
     let tokenizer_filename = repo.get("tokenizer.json")?;
     let filename = match args.weight_file {
         Some(weight_file) => std::path::PathBuf::from(weight_file),
         None => {
             if args.quantized {
-                api.model("lmz/candle-quantized-phi".to_string())
-                    .get("model-q4k.gguf")?
+                match args.model {
+                    WhichModel::V1 => repo.get("model-v1-q4k.gguf")?,
+                    WhichModel::V1_5 => repo.get("model-q4k.gguf")?,
+                }
             } else {
                 repo.get("model.safetensors")?
             }
         }
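
Net effect of the resolution logic above, when neither `--model-id` nor `--revision` nor `--weight-file` is given:

- `--model 1`: microsoft/phi-1 at refs/pr/2, loading model.safetensors
- `--model 1.5` (the default): microsoft/phi-1_5 at refs/pr/18, loading model.safetensors
- `--quantized --model 1`: lmz/candle-quantized-phi at main, loading model-v1-q4k.gguf
- `--quantized --model 1.5`: lmz/candle-quantized-phi at main, loading model-q4k.gguf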