Fix tie_word_embeddings for Qwen2. (#3344)

Yang Fan 2024-03-16 00:36:53 +08:00 committed by GitHub
parent 429284dc37
commit a7c871680e
1 changed file with 12 additions and 2 deletions


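When a checkpoint sets tie_word_embeddings, the output projection shares its weight with the input embedding, so no separate lm_head parameter should be created or loaded. The previous Qwen2 code always allocated a ParallelLMHead and looked for lm_head.weight in the checkpoint, which does not work for tied-embedding Qwen2 models. As orientation, the flag can be read from the Hugging Face config; this is a minimal sketch, not part of the commit, and the checkpoint name is only an example:

    from transformers import AutoConfig

    # Example only: any Qwen2 checkpoint exposes the same flag in its config.json.
    cfg = AutoConfig.from_pretrained("Qwen/Qwen1.5-0.5B")
    print(cfg.tie_word_embeddings)  # True when input and output embeddings are shared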
@@ -299,7 +299,11 @@ class Qwen2ForCausalLM(nn.Module):
         self.config = config
         self.linear_method = linear_method
         self.model = Qwen2Model(config, linear_method)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+
+        if not config.tie_word_embeddings:
+            self.lm_head = ParallelLMHead(config.vocab_size,
+                                          config.hidden_size)
+
         self.sampler = Sampler(config.vocab_size)

     def forward(
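Why the separate lm_head can be skipped in the tied case: the logits are just the hidden states projected against the token-embedding matrix, so embed_tokens.weight can serve as the output weight directly. A standalone PyTorch illustration with toy sizes (not vLLM code):

    import torch
    import torch.nn.functional as F

    vocab_size, hidden_size = 1000, 64  # toy sizes for illustration
    embed_tokens = torch.nn.Embedding(vocab_size, hidden_size)
    hidden_states = torch.randn(2, hidden_size)

    # With tied embeddings, the "lm_head" weight is embed_tokens.weight itself.
    logits = F.linear(hidden_states, embed_tokens.weight)  # shape (2, vocab_size)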
@@ -318,7 +322,11 @@ class Qwen2ForCausalLM(nn.Module):
         hidden_states: torch.Tensor,
         sampling_metadata: SamplingMetadata,
     ) -> Optional[SamplerOutput]:
-        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
+        if self.config.tie_word_embeddings:
+            lm_head_weight = self.model.embed_tokens.weight
+        else:
+            lm_head_weight = self.lm_head.weight
+        next_tokens = self.sampler(lm_head_weight, hidden_states,
                                    sampling_metadata)
         return next_tokens

@@ -340,6 +348,8 @@ class Qwen2ForCausalLM(nn.Module):
                 model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
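With the weight-loading skip above, a tied-embedding Qwen2 checkpoint should load without a dangling lm_head.weight. A rough smoke test through vLLM's offline API; the model name is only an illustrative tied-embedding checkpoint, not something specified by this commit:

    from vllm import LLM, SamplingParams

    llm = LLM(model="Qwen/Qwen1.5-0.5B-Chat")  # illustrative checkpoint name
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=16))
    print(outputs[0].outputs[0].text)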