Add the SmolLM2 models. (#2595)
* Add the SmolLM2 models.
* More SmolLM2 support.
parent 530ab96036
commit 3fba2b5fc4
In the llama example:

@@ -43,6 +43,18 @@ enum Which {
     Solar10_7B,
     #[value(name = "tiny-llama-1.1b-chat")]
     TinyLlama1_1BChat,
+    #[value(name = "SmolLM2-1.7B")]
+    SmolLM2_1B,
+    #[value(name = "SmolLM2-1.7B-Instruct")]
+    SmolLM2_1BInstruct,
+    #[value(name = "SmolLM2-360M")]
+    SmolLM2_360M,
+    #[value(name = "SmolLM2-360M-Instruct")]
+    SmolLM2_360MInstruct,
+    #[value(name = "SmolLM2-135M")]
+    SmolLM2_135M,
+    #[value(name = "SmolLM2-135M-Instruct")]
+    SmolLM2_135MInstruct,
 }
 
 #[derive(Parser, Debug)]
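The new variants reach the command line through clap's ValueEnum names. A minimal self-contained sketch of how the #[value(name = ...)] attributes above surface as flag values (trimmed to two variants; the Args struct here is illustrative, not the example's full struct):

use clap::{Parser, ValueEnum};

// Trimmed to two variants; the real enum carries many more.
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Debug, ValueEnum)]
enum Which {
    #[value(name = "SmolLM2-360M-Instruct")]
    SmolLM2_360MInstruct,
    #[value(name = "SmolLM2-1.7B-Instruct")]
    SmolLM2_1BInstruct,
}

#[derive(Parser, Debug)]
struct Args {
    /// Which model variant to run; matched against the names above.
    #[arg(long, value_enum)]
    which: Which,
}

fn main() {
    // e.g. `cargo run -- --which SmolLM2-360M-Instruct`
    let args = Args::parse();
    println!("{:?}", args.which);
}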
@@ -134,19 +146,28 @@ fn main() -> Result<()> {
     };
     let (llama, tokenizer_filename, mut cache, config) = {
         let api = Api::new()?;
-        let model_id = args.model_id.unwrap_or_else(|| match args.which {
-            Which::V1 => "Narsil/amall-7b".to_string(),
-            Which::V2 => "meta-llama/Llama-2-7b-hf".to_string(),
-            Which::V3 => "meta-llama/Meta-Llama-3-8B".to_string(),
-            Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct".to_string(),
-            Which::V31 => "meta-llama/Llama-3.1-8B".to_string(),
-            Which::V31Instruct => "meta-llama/Llama-3.1-8B-Instruct".to_string(),
-            Which::V32_1b => "meta-llama/Llama-3.2-1B".to_string(),
-            Which::V32_1bInstruct => "meta-llama/Llama-3.2-1B-Instruct".to_string(),
-            Which::V32_3b => "meta-llama/Llama-3.2-3B".to_string(),
-            Which::V32_3bInstruct => "meta-llama/Llama-3.2-3B-Instruct".to_string(),
-            Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0".to_string(),
-            Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0".to_string(),
-        });
+        let model_id = args.model_id.unwrap_or_else(|| {
+            let str = match args.which {
+                Which::V1 => "Narsil/amall-7b",
+                Which::V2 => "meta-llama/Llama-2-7b-hf",
+                Which::V3 => "meta-llama/Meta-Llama-3-8B",
+                Which::V3Instruct => "meta-llama/Meta-Llama-3-8B-Instruct",
+                Which::V31 => "meta-llama/Llama-3.1-8B",
+                Which::V31Instruct => "meta-llama/Llama-3.1-8B-Instruct",
+                Which::V32_1b => "meta-llama/Llama-3.2-1B",
+                Which::V32_1bInstruct => "meta-llama/Llama-3.2-1B-Instruct",
+                Which::V32_3b => "meta-llama/Llama-3.2-3B",
+                Which::V32_3bInstruct => "meta-llama/Llama-3.2-3B-Instruct",
+                Which::Solar10_7B => "upstage/SOLAR-10.7B-v1.0",
+                Which::TinyLlama1_1BChat => "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
+                Which::SmolLM2_135M => "HuggingFaceTB/SmolLM2-135M",
+                Which::SmolLM2_135MInstruct => "HuggingFaceTB/SmolLM2-135M-Instruct",
+                Which::SmolLM2_360M => "HuggingFaceTB/SmolLM2-360M",
+                Which::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
+                Which::SmolLM2_1B => "HuggingFaceTB/SmolLM2-1.7B",
+                Which::SmolLM2_1BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
+            };
+            str.to_string()
+        });
         println!("loading the model weights from {model_id}");
         let revision = args.revision.unwrap_or("main".to_string());
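The rewritten closure matches to a plain &str and converts once at the end, instead of repeating .to_string() in all eighteen arms. The same shape in a tiny self-contained form (the Choice enum is hypothetical, just to show the pattern):

#[derive(Clone, Copy)]
enum Choice {
    A,
    B,
}

// One allocation at the end instead of one `.to_string()` per match arm.
fn model_id(explicit: Option<String>, which: Choice) -> String {
    explicit.unwrap_or_else(|| {
        let s = match which {
            Choice::A => "org/model-a",
            Choice::B => "org/model-b",
        };
        s.to_string()
    })
}

fn main() {
    assert_eq!(model_id(None, Choice::B), "org/model-b");
    assert_eq!(model_id(Some("me/custom".into()), Choice::A), "me/custom");
}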
@@ -169,7 +190,15 @@ fn main() -> Result<()> {
             | Which::Solar10_7B => {
                 candle_examples::hub_load_safetensors(&api, "model.safetensors.index.json")?
             }
-            Which::V32_1b | Which::V32_1bInstruct | Which::TinyLlama1_1BChat => {
+            Which::SmolLM2_360M
+            | Which::SmolLM2_360MInstruct
+            | Which::SmolLM2_135M
+            | Which::SmolLM2_135MInstruct
+            | Which::SmolLM2_1B
+            | Which::SmolLM2_1BInstruct
+            | Which::V32_1b
+            | Which::V32_1bInstruct
+            | Which::TinyLlama1_1BChat => {
                 vec![api.get("model.safetensors")?]
             }
         };
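All the SmolLM2 checkpoints join the single-file arm: they ship as one model.safetensors file, so one api.get suffices, while the larger models go through hub_load_safetensors, which reads model.safetensors.index.json and fetches every shard it lists. A sketch of the single-file path, assuming the hf-hub sync API (the repo id is only an illustration):

use hf_hub::api::sync::Api;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let repo = Api::new()?.model("HuggingFaceTB/SmolLM2-135M".to_string());
    // Single-file checkpoint: one download, cached locally by hf-hub.
    let weights = repo.get("model.safetensors")?;
    println!("weights at {weights:?}");
    Ok(())
}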
In the quantized example:
@@ -71,6 +71,10 @@ enum Which {
     L8b,
     #[value(name = "phi3")]
     Phi3,
+    #[value(name = "SmolLM2-360M-Instruct")]
+    SmolLM2_360MInstruct,
+    #[value(name = "SmolLM2-1.7B-Instruct")]
+    SmolLM2_1BInstruct,
 }
 
 impl Which {
@@ -88,7 +92,9 @@ impl Which {
             | Self::Leo7b
             | Self::Leo13b
             | Self::L8b
-            | Self::Phi3 => false,
+            | Self::Phi3
+            | Self::SmolLM2_1BInstruct
+            | Self::SmolLM2_360MInstruct => false,
             // Zephyr and OpenChat are fine tuned versions of mistral and should be treated in the
             // same way. Starling is a fine tuned version of OpenChat.
             Self::OpenChat35
@@ -124,6 +130,8 @@ impl Which {
             | Self::OpenChat35
             | Self::Starling7bAlpha
             | Self::L8b
+            | Self::SmolLM2_1BInstruct
+            | Self::SmolLM2_360MInstruct
             | Self::Phi3 => false,
             Self::Zephyr7bAlpha | Self::Zephyr7bBeta => true,
         }
@@ -150,6 +158,8 @@ impl Which {
             | Self::Zephyr7bAlpha
             | Self::Zephyr7bBeta
             | Self::L8b
+            | Self::SmolLM2_1BInstruct
+            | Self::SmolLM2_360MInstruct
             | Self::Phi3 => false,
             Self::OpenChat35 | Self::Starling7bAlpha => true,
         }
@@ -179,6 +189,8 @@ impl Which {
             Self::Starling7bAlpha => "berkeley-nest/Starling-LM-7B-alpha",
             Self::L8b => "meta-llama/Meta-Llama-3-8B",
             Self::Phi3 => "microsoft/Phi-3-mini-4k-instruct",
+            Self::SmolLM2_360MInstruct => "HuggingFaceTB/SmolLM2-360M-Instruct",
+            Self::SmolLM2_1BInstruct => "HuggingFaceTB/SmolLM2-1.7B-Instruct",
         }
     }
 }
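These repo ids look like the per-model tokenizer lookup: the GGUF repos don't carry a tokenizer.json, so it presumably comes from the original model repository. A sketch of that fetch, assuming hf-hub plus the tokenizers crate:

use hf_hub::api::sync::Api;
use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // The tokenizer comes from the base (non-GGUF) repository.
    let repo = Api::new()?.model("HuggingFaceTB/SmolLM2-360M-Instruct".to_string());
    let tokenizer_file = repo.get("tokenizer.json")?;
    let tokenizer = Tokenizer::from_file(tokenizer_file)?;
    println!("vocab size: {}", tokenizer.get_vocab_size(true));
    Ok(())
}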
@@ -343,6 +355,14 @@ impl Args {
                 "microsoft/Phi-3-mini-4k-instruct-gguf",
                 "Phi-3-mini-4k-instruct-q4.gguf",
             ),
+            Which::SmolLM2_360MInstruct => (
+                "HuggingFaceTB/SmolLM2-360M-Instruct-GGUF",
+                "smollm2-360m-instruct-q8_0.gguf",
+            ),
+            Which::SmolLM2_1BInstruct => (
+                "HuggingFaceTB/SmolLM2-1.7B-Instruct-GGUF",
+                "smollm2-1.7b-instruct-q4_k_m.gguf",
+            ),
         };
         let revision = if self.which == Which::Phi3 {
             "5eef2ce24766d31909c0b269fe90c817a8f263fb"
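Each quantized variant maps to a Hub (repo, filename) pair; the 360M model uses a q8_0 file while the 1.7B model uses the smaller q4_k_m quantization. Fetching one of them is a single hf-hub call, sketched here with the pair copied from the table above:

use hf_hub::api::sync::Api;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let (repo, filename) = (
        "HuggingFaceTB/SmolLM2-360M-Instruct-GGUF",
        "smollm2-360m-instruct-q8_0.gguf",
    );
    // Downloads on first use, then serves from the local hf-hub cache.
    let path = Api::new()?.model(repo.to_string()).get(filename)?;
    println!("gguf cached at {path:?}");
    Ok(())
}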
@@ -455,6 +475,8 @@ fn main() -> anyhow::Result<()> {
             | Which::Leo7b
             | Which::Leo13b
             | Which::L8b
+            | Which::SmolLM2_1BInstruct
+            | Which::SmolLM2_360MInstruct
             | Which::Phi3 => 1,
             Which::Mixtral
             | Which::MixtralInstruct
@@ -573,6 +595,7 @@ fn main() -> anyhow::Result<()> {
     }
 
     let eos_token = match args.which {
+        Which::SmolLM2_360MInstruct | Which::SmolLM2_1BInstruct => "<|endoftext|>",
        Which::L8b => "<|end_of_text|>",
         _ => match args.which.is_open_chat() {
             true => "<|end_of_turn|>",
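The SmolLM2 instruct models mark end-of-turn with <|endoftext|>, so sampling has to stop on that token rather than Llama's <|end_of_text|>. A minimal sketch of resolving the string to a token id with the tokenizers crate (file path illustrative, generation loop reduced to a stand-in):

use tokenizers::Tokenizer;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    let tokenizer = Tokenizer::from_file("tokenizer.json")?;
    let eos_token = "<|endoftext|>";
    let eos_id = tokenizer
        .token_to_id(eos_token)
        .ok_or("eos token not in vocab")?;
    // In the generation loop, compare each sampled id against eos_id.
    let sampled: u32 = eos_id; // stand-in for a real sample
    if sampled == eos_id {
        println!("eos reached, stop generating");
    }
    Ok(())
}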
In the quantized llama model weights (ModelWeights):
@@ -351,13 +351,16 @@ impl ModelWeights {
         let (cos, sin) = precomput_freqs_cis(rope_dim, rope_freq_base, device)?;
         let neg_inf = Tensor::new(f32::NEG_INFINITY, device)?;
 
-        let tok_embeddings = ct.tensor(reader, "token_embd.weight", device)?;
-        let tok_embeddings = tok_embeddings.dequantize(device)?;
+        let tok_embeddings_q = ct.tensor(reader, "token_embd.weight", device)?;
+        let tok_embeddings = tok_embeddings_q.dequantize(device)?;
         let norm = RmsNorm::from_qtensor(
             ct.tensor(reader, "output_norm.weight", device)?,
             rms_norm_eps,
         )?;
-        let output = ct.tensor(reader, "output.weight", device)?;
+        let output = match ct.tensor(reader, "output.weight", device) {
+            Ok(tensor) => tensor,
+            Err(_) => tok_embeddings_q,
+        };
         let mut layers = Vec::with_capacity(block_count);
         for layer_idx in 0..block_count {
             let prefix = format!("blk.{layer_idx}");
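The output.weight fallback is what actually lets these GGUFs load: checkpoints with tied word embeddings ship no separate output projection, so the loader reuses the quantized token-embedding tensor as the LM head. That is also why the binding was renamed to tok_embeddings_q: the old code shadowed the quantized tensor with its dequantized copy, consuming it before it could be reused. A dependency-free sketch of the same ownership pattern (the types here are stand-ins, not candle's API):

// Stand-in types; candle's real QTensor/Tensor are not used here.
struct QTensor;
struct Tensor;

impl QTensor {
    // Produce a float copy while leaving the quantized tensor usable.
    fn dequantize(&self) -> Tensor {
        Tensor
    }
}

// Mirrors the loader: dequantize for the embedding layer, then reuse the
// still-owned quantized tensor as the output weight when none exists.
fn load(output_weight: Option<QTensor>, tok_embeddings_q: QTensor) -> (Tensor, QTensor) {
    let tok_embeddings = tok_embeddings_q.dequantize();
    let output = output_weight.unwrap_or(tok_embeddings_q);
    (tok_embeddings, output)
}

fn main() {
    // Tied-embeddings checkpoint: no output.weight in the file.
    let (_embeddings, _lm_head) = load(None, QTensor);
}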