HuggingFace: Clearer error message when a model is not supported
This commit is contained in:
parent
f6d2c59bca
commit
303e601b87
|
@ -217,14 +217,39 @@ impl NewEmbedderError {
|
|||
}
|
||||
|
||||
pub fn deserialize_config(
|
||||
model_name: String,
|
||||
config: String,
|
||||
config_filename: PathBuf,
|
||||
inner: serde_json::Error,
|
||||
) -> NewEmbedderError {
|
||||
let deserialize_config = DeserializeConfig { config, filename: config_filename, inner };
|
||||
Self {
|
||||
kind: NewEmbedderErrorKind::DeserializeConfig(deserialize_config),
|
||||
fault: FaultSource::Runtime,
|
||||
match serde_json::from_str(&config) {
|
||||
Ok(value) => {
|
||||
let value: serde_json::Value = value;
|
||||
let architectures = match value.get("architectures") {
|
||||
Some(serde_json::Value::Array(architectures)) => architectures
|
||||
.iter()
|
||||
.filter_map(|value| match value {
|
||||
serde_json::Value::String(s) => Some(s.to_owned()),
|
||||
_ => None,
|
||||
})
|
||||
.collect(),
|
||||
_ => vec![],
|
||||
};
|
||||
|
||||
let unsupported_model = UnsupportedModel { model_name, inner, architectures };
|
||||
Self {
|
||||
kind: NewEmbedderErrorKind::UnsupportedModel(unsupported_model),
|
||||
fault: FaultSource::User,
|
||||
}
|
||||
}
|
||||
Err(error) => {
|
||||
let deserialize_config =
|
||||
DeserializeConfig { model_name, filename: config_filename, inner: error };
|
||||
Self {
|
||||
kind: NewEmbedderErrorKind::DeserializeConfig(deserialize_config),
|
||||
fault: FaultSource::Runtime,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -252,7 +277,7 @@ impl NewEmbedderError {
|
|||
}
|
||||
|
||||
pub fn safetensor_weight(inner: candle_core::Error) -> Self {
|
||||
Self { kind: NewEmbedderErrorKind::PytorchWeight(inner), fault: FaultSource::Runtime }
|
||||
Self { kind: NewEmbedderErrorKind::SafetensorWeight(inner), fault: FaultSource::Runtime }
|
||||
}
|
||||
|
||||
pub fn load_model(inner: candle_core::Error) -> Self {
|
||||
|
@ -275,13 +300,26 @@ pub struct OpenConfig {
|
|||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("could not deserialize config at {filename}: {inner}. Config follows:\n{config}")]
|
||||
#[error("for model '{model_name}', could not deserialize config at {filename} as JSON: {inner}")]
|
||||
pub struct DeserializeConfig {
|
||||
pub config: String,
|
||||
pub model_name: String,
|
||||
pub filename: PathBuf,
|
||||
pub inner: serde_json::Error,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("model `{model_name}` appears to be unsupported{}\n - inner error: {inner}",
|
||||
if architectures.is_empty() {
|
||||
"\n - Note: only models with architecture \"BertModel\" are supported.".to_string()
|
||||
} else {
|
||||
format!("\n - Note: model has declared architectures `{architectures:?}`, only models with architecture `\"BertModel\"` are supported.")
|
||||
})]
|
||||
pub struct UnsupportedModel {
|
||||
pub model_name: String,
|
||||
pub inner: serde_json::Error,
|
||||
pub architectures: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
#[error("could not open tokenizer at {filename}: {inner}")]
|
||||
pub struct OpenTokenizer {
|
||||
|
@ -298,6 +336,8 @@ pub enum NewEmbedderErrorKind {
|
|||
#[error(transparent)]
|
||||
DeserializeConfig(DeserializeConfig),
|
||||
#[error(transparent)]
|
||||
UnsupportedModel(UnsupportedModel),
|
||||
#[error(transparent)]
|
||||
OpenTokenizer(OpenTokenizer),
|
||||
#[error("could not build weights from Pytorch weights: {0}")]
|
||||
PytorchWeight(candle_core::Error),
|
||||
|
|
|
@ -103,7 +103,12 @@ impl Embedder {
|
|||
let config = std::fs::read_to_string(&config_filename)
|
||||
.map_err(|inner| NewEmbedderError::open_config(config_filename.clone(), inner))?;
|
||||
let config: Config = serde_json::from_str(&config).map_err(|inner| {
|
||||
NewEmbedderError::deserialize_config(config, config_filename, inner)
|
||||
NewEmbedderError::deserialize_config(
|
||||
options.model.clone(),
|
||||
config,
|
||||
config_filename,
|
||||
inner,
|
||||
)
|
||||
})?;
|
||||
let mut tokenizer = Tokenizer::from_file(&tokenizer_filename)
|
||||
.map_err(|inner| NewEmbedderError::open_tokenizer(tokenizer_filename, inner))?;
|
||||
|
|
Loading…
Reference in New Issue