Change for the encoder-only ProstT5 model (#2045)

* This change avoids crashes when running T5 models with F16 tensors on CPU.

* This enables running ProstT5's (https://huggingface.co/Rostlab/ProstT5) encoder-only mode in Candle. This ProstT5 mode stores it's embed_tokens weights within the encoder, as its decoding stage was replaced with a CNN.  This alone is not sufficient to run ProstT5 within Candle examples. We will develop a ProstT5 runner outside candle for now, but would be willing to upstream it to candle-examples at a later point.
This commit is contained in:
Victor-Mihaila 2024-04-13 11:06:24 +02:00 committed by GitHub
parent e6d412b156
commit 79e3bec789
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 3 additions and 1 deletions

View File

@ -709,8 +709,10 @@ impl T5EncoderModel {
pub fn load(vb: VarBuilder, cfg: &Config) -> Result<Self> {
let shared_vb = if vb.contains_tensor("shared.weight") {
vb.pp("shared")
} else {
} else if vb.contains_tensor("decoder.embed_tokens") {
vb.pp("decoder").pp("embed_tokens")
} else {
vb.pp("encoder").pp("embed_tokens")
};
let shared = Embedding::new(cfg.vocab_size, cfg.d_model, shared_vb)?;
let shared = Arc::new(shared);