Bugfixes for marian-mt. (#1219)
* Bugfixes for marian-mt. * Apply the final decoding head. * More fixes.
This commit is contained in:
parent
5fc66bd4ba
commit
969960847a
|
@ -36,8 +36,6 @@ struct Args {
|
|||
text: String,
|
||||
}
|
||||
|
||||
const SEP_TOKEN_ID: u32 = 102;
|
||||
|
||||
pub fn main() -> anyhow::Result<()> {
|
||||
let args = Args::parse();
|
||||
|
||||
|
@ -62,7 +60,7 @@ pub fn main() -> anyhow::Result<()> {
|
|||
model.encoder().forward(&tokens, 0)?
|
||||
};
|
||||
|
||||
let mut token_ids = vec![30522u32];
|
||||
let mut token_ids = vec![config.decoder_start_token_id];
|
||||
for index in 0..1000 {
|
||||
// TODO: Add a kv cache.
|
||||
let context_size = if index >= 1000 { 1 } else { token_ids.len() };
|
||||
|
@ -72,7 +70,8 @@ pub fn main() -> anyhow::Result<()> {
|
|||
let logits = logits.squeeze(0)?;
|
||||
let logits = logits.get(logits.dim(0)? - 1)?;
|
||||
let token = logits_processor.sample(&logits)?;
|
||||
if token == SEP_TOKEN_ID {
|
||||
println!("{token}");
|
||||
if token == config.eos_token_id || token == config.forced_eos_token_id {
|
||||
break;
|
||||
}
|
||||
token_ids.push(token);
|
||||
|
|
|
@ -18,11 +18,11 @@ pub struct Config {
|
|||
pub is_encoder_decoder: bool,
|
||||
pub activation_function: candle_nn::Activation,
|
||||
pub d_model: usize,
|
||||
pub decoder_start_token_id: usize,
|
||||
pub decoder_start_token_id: u32,
|
||||
pub scale_embedding: bool,
|
||||
pub pad_token_id: usize,
|
||||
pub eos_token_id: usize,
|
||||
pub forced_eos_token_id: usize,
|
||||
pub pad_token_id: u32,
|
||||
pub eos_token_id: u32,
|
||||
pub forced_eos_token_id: u32,
|
||||
pub share_encoder_decoder_embeddings: bool,
|
||||
}
|
||||
|
||||
|
@ -224,7 +224,8 @@ impl DecoderLayer {
|
|||
let self_attn = Attention::new(cfg, true, vb.pp("self_attn"))?;
|
||||
let self_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?;
|
||||
let encoder_attn = Attention::new(cfg, true, vb.pp("encoder_attn"))?;
|
||||
let encoder_attn_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("self_attn_layer_norm"))?;
|
||||
let encoder_attn_layer_norm =
|
||||
layer_norm(cfg.d_model, 1e-5, vb.pp("encoder_attn_layer_norm"))?;
|
||||
let fc1 = linear(cfg.d_model, cfg.decoder_ffn_dim, vb.pp("fc1"))?;
|
||||
let fc2 = linear(cfg.decoder_ffn_dim, cfg.d_model, vb.pp("fc2"))?;
|
||||
let final_layer_norm = layer_norm(cfg.d_model, 1e-5, vb.pp("final_layer_norm"))?;
|
||||
|
@ -249,7 +250,7 @@ impl DecoderLayer {
|
|||
Some(encoder_xs) => {
|
||||
let residual = &xs;
|
||||
let xs = self.encoder_attn.forward(&xs, Some(encoder_xs))?;
|
||||
(residual + xs)?.apply(&self.self_attn_layer_norm)?
|
||||
(residual + xs)?.apply(&self.encoder_attn_layer_norm)?
|
||||
}
|
||||
};
|
||||
let residual = &xs;
|
||||
|
@ -257,7 +258,8 @@ impl DecoderLayer {
|
|||
.apply(&self.fc1)?
|
||||
.apply(&self.activation_fn)?
|
||||
.apply(&self.fc2)?;
|
||||
(xs + residual)?.apply(&self.final_layer_norm)
|
||||
let xs = (xs + residual)?.apply(&self.final_layer_norm)?;
|
||||
Ok(xs)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -356,7 +358,7 @@ impl Decoder {
|
|||
.unsqueeze(0)?;
|
||||
let mut xs = xs.broadcast_add(&embed_pos)?;
|
||||
for layer in self.layers.iter() {
|
||||
xs = layer.forward(&xs, encoder_xs)?
|
||||
xs = layer.forward(&xs, encoder_xs)?;
|
||||
}
|
||||
Ok(xs)
|
||||
}
|
||||
|
@ -385,6 +387,7 @@ impl Model {
|
|||
#[derive(Debug, Clone)]
|
||||
pub struct MTModel {
|
||||
model: Model,
|
||||
lm_head: Linear,
|
||||
final_logits_bias: Tensor,
|
||||
}
|
||||
|
||||
|
@ -393,8 +396,10 @@ impl MTModel {
|
|||
let target_vocab_size = cfg.decoder_vocab_size.unwrap_or(cfg.vocab_size);
|
||||
let final_logits_bias = vb.get((1, target_vocab_size), "final_logits_bias")?;
|
||||
let model = Model::new(cfg, vb.pp("model"))?;
|
||||
let lm_head = Linear::from_weights(model.shared.embeddings().clone(), None);
|
||||
Ok(Self {
|
||||
model,
|
||||
lm_head,
|
||||
final_logits_bias,
|
||||
})
|
||||
}
|
||||
|
@ -408,6 +413,10 @@ impl MTModel {
|
|||
}
|
||||
|
||||
pub fn decode(&self, xs: &Tensor, encoder_xs: &Tensor) -> Result<Tensor> {
|
||||
self.model.decoder.forward(xs, Some(encoder_xs), 0)
|
||||
self.model
|
||||
.decoder
|
||||
.forward(xs, Some(encoder_xs), 0)?
|
||||
.apply(&self.lm_head)?
|
||||
.broadcast_add(&self.final_logits_bias)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue