Bugfixes for marian-mt. (#1219)

* Bugfixes for marian-mt.

* Apply the final decoding head.

* More fixes.
Laurent Mazare, 2023-10-30 12:44:19 +01:00, committed by GitHub
parent 5fc66bd4ba
commit 969960847a
2 changed files with 21 additions and 13 deletions


@@ -36,8 +36,6 @@ struct Args {
     text: String,
 }
 
-const SEP_TOKEN_ID: u32 = 102;
-
 pub fn main() -> anyhow::Result<()> {
     let args = Args::parse();
@@ -62,7 +60,7 @@ pub fn main() -> anyhow::Result<()> {
         model.encoder().forward(&tokens, 0)?
     };
-    let mut token_ids = vec![30522u32];
+    let mut token_ids = vec![config.decoder_start_token_id];
    for index in 0..1000 {
         // TODO: Add a kv cache.
         let context_size = if index >= 1000 { 1 } else { token_ids.len() };
@@ -72,7 +70,8 @@
         let logits = logits.squeeze(0)?;
         let logits = logits.get(logits.dim(0)? - 1)?;
         let token = logits_processor.sample(&logits)?;
-        if token == SEP_TOKEN_ID {
+        println!("{token}");
+        if token == config.eos_token_id || token == config.forced_eos_token_id {
             break;
         }
         token_ids.push(token);
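For context, the example's greedy loop after these two fixes has roughly the following shape. This is an illustrative sketch only: device, config, model, encoder_xs, and logits_processor are assumed to be set up as elsewhere in the example, and the kv-cache-less re-feeding of the whole prefix is kept as-is.

    let mut token_ids = vec![config.decoder_start_token_id];
    for _ in 0..1000 {
        let input = Tensor::new(token_ids.as_slice(), &device)?.unsqueeze(0)?;
        // decode now returns vocabulary logits (see the model changes below).
        let logits = model.decode(&input, &encoder_xs)?.squeeze(0)?;
        let logits = logits.get(logits.dim(0)? - 1)?; // newest position only
        let token = logits_processor.sample(&logits)?;
        if token == config.eos_token_id || token == config.forced_eos_token_id {
            break;
        }
        token_ids.push(token);
    }

The removed constants were BERT leftovers (102 is BERT's [SEP] id); Marian models ship their own special-token ids in config.json, which is why the loop now reads them from config.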


@@ -18,11 +18,11 @@ pub struct Config {
     pub is_encoder_decoder: bool,
     pub activation_function: candle_nn::Activation,
     pub d_model: usize,
-    pub decoder_start_token_id: usize,
+    pub decoder_start_token_id: u32,
     pub scale_embedding: bool,
-    pub pad_token_id: usize,
-    pub eos_token_id: usize,
-    pub forced_eos_token_id: usize,
+    pub pad_token_id: u32,
+    pub eos_token_id: u32,
+    pub forced_eos_token_id: u32,
     pub share_encoder_decoder_embeddings: bool,
 }
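The usize-to-u32 switch is not cosmetic: candle's LogitsProcessor::sample returns u32 token ids, so u32 config fields let the stopping test in the example compare like with like, without casts. A minimal sketch with hypothetical id values (real ones come from the model's config.json):

    let eos_token_id: u32 = 0; // hypothetical value
    let forced_eos_token_id: u32 = 0; // hypothetical value
    let token: u32 = 42; // as returned by LogitsProcessor::sample
    if token == eos_token_id || token == forced_eos_token_id {
        // stop decoding
    }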
@@ -224,7 +224,8 @@ impl DecoderLayer {
         let self_attn = Attention::new(cfg, true, vb.pp("self_attn"))?;
         let self_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?;
         let encoder_attn = Attention::new(cfg, true, vb.pp("encoder_attn"))?;
-        let encoder_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?;
+        let encoder_attn_layer_norm =
+            layer_norm(cfg.d_model, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
         let fc1 = linear(cfg.d_model, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
         let fc2 = linear(cfg.decoder_ffn_dim, cfg.d_model, vb.pp("fc2"))?;
         let final_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("final_layer_norm"))?;
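The constructor bug was a copy-pasted VarBuilder prefix: the cross-attention layer norm was being loaded from the self-attention norm's tensors. A self-contained sketch of how pp scopes lookups, using a zeros-backed VarBuilder and a hypothetical width of 512:

    use candle_core::{DType, Device, Result};
    use candle_nn::VarBuilder;

    fn main() -> Result<()> {
        let vb = VarBuilder::zeros(DType::F32, &Device::Cpu);
        let layer = vb.pp("model").pp("decoder").pp("layers").pp("0");
        // Distinct prefixes resolve to distinct tensor names, e.g.
        // model.decoder.layers.0.self_attn_layer_norm.weight vs
        // model.decoder.layers.0.encoder_attn_layer_norm.weight.
        let _self_norm_w = layer.pp("self_attn_layer_norm").get(512, "weight")?;
        let _cross_norm_w = layer.pp("encoder_attn_layer_norm").get(512, "weight")?;
        Ok(())
    }

Against a real checkpoint a prefix typo like this never fails at load time, because the self_attn_layer_norm tensors do exist; the wrong weights simply load silently.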
@@ -249,7 +250,7 @@ impl DecoderLayer {
             Some(encoder_xs) => {
                 let residual = &xs;
                 let xs = self.encoder_attn.forward(&xs, Some(encoder_xs))?;
-                (residual + xs)?.apply(&self.self_attn_layer_norm)?
+                (residual + xs)?.apply(&self.encoder_attn_layer_norm)?
             }
         };
         let residual = &xs;
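This fix and the constructor fix above are the two halves of the same bug: the cross-attention block now both loads its own encoder_attn_layer_norm weights and applies that norm after the cross-attention residual, instead of reusing self_attn_layer_norm in both places.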
@@ -257,7 +258,8 @@
             .apply(&self.fc1)?
             .apply(&self.activation_fn)?
             .apply(&self.fc2)?;
-        (xs + residual)?.apply(&self.final_layer_norm)
+        let xs = (xs + residual)?.apply(&self.final_layer_norm)?;
+        Ok(xs)
     }
 }
@@ -356,7 +358,7 @@ impl Decoder {
             .unsqueeze(0)?;
         let mut xs = xs.broadcast_add(&embed_pos)?;
         for layer in self.layers.iter() {
-            xs = layer.forward(&xs, encoder_xs)?
+            xs = layer.forward(&xs, encoder_xs)?;
         }
         Ok(xs)
     }
@@ -385,6 +387,7 @@
 #[derive(Debug, Clone)]
 pub struct MTModel {
     model: Model,
+    lm_head: Linear,
     final_logits_bias: Tensor,
 }
@@ -393,8 +396,10 @@ impl MTModel {
         let target_vocab_size = cfg.decoder_vocab_size.unwrap_or(cfg.vocab_size);
         let final_logits_bias = vb.get((1, target_vocab_size), "final_logits_bias")?;
         let model = Model::new(cfg, vb.pp("model"))?;
+        let lm_head = Linear::from_weights(model.shared.embeddings().clone(), None);
         Ok(Self {
             model,
+            lm_head,
             final_logits_bias,
         })
     }
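The new head is tied: instead of loading a separate output matrix, the (vocab_size, d_model) shared embedding table is reused as a bias-free Linear, which is how Marian-style checkpoints tie their output projection. A self-contained sketch of the same tying with toy dimensions (candle_nn::Linear::new plays the role of from_weights here):

    use candle_core::{Device, Module, Result, Tensor};
    use candle_nn::Linear;

    fn main() -> Result<()> {
        let dev = Device::Cpu;
        // Toy embedding table: (vocab_size = 16, d_model = 8).
        let embed = Tensor::randn(0f32, 1f32, (16, 8), &dev)?;
        let lm_head = Linear::new(embed, None); // no bias, weights shared
        // Stand-in for decoder output: (batch = 1, tgt_len = 3, d_model = 8).
        let hidden = Tensor::randn(0f32, 1f32, (1, 3, 8), &dev)?;
        let logits = lm_head.forward(&hidden)?;
        assert_eq!(logits.dims(), &[1, 3, 16]); // per-token vocab logits
        Ok(())
    }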
@@ -408,6 +413,10 @@ impl MTModel {
     }
 
     pub fn decode(&self, xs: &Tensor, encoder_xs: &Tensor) -> Result<Tensor> {
-        self.model.decoder.forward(xs, Some(encoder_xs), 0)
+        self.model
+            .decoder
+            .forward(xs, Some(encoder_xs), 0)?
+            .apply(&self.lm_head)?
+            .broadcast_add(&self.final_logits_bias)
     }
 }
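Before this change decode returned raw decoder hidden states of width d_model, so the example's sampler was effectively drawing from hidden units rather than vocabulary entries. Applying the tied lm_head and the final_logits_bias inside decode (the "Apply the final decoding head" item in the commit message) is what makes the sampled ids meaningful token ids; the (1, vocab) bias broadcasts over the (batch, tgt_len, vocab) logits.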