mirror of https://github.com/vllm-project/vllm
Fix tie_word_embeddings for Qwen2. (#3344)
This commit is contained in:
parent
429284dc37
commit
a7c871680e
|
@ -299,7 +299,11 @@ class Qwen2ForCausalLM(nn.Module):
|
|||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.model = Qwen2Model(config, linear_method)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
|
||||
if not config.tie_word_embeddings:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size)
|
||||
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
|
@ -318,7 +322,11 @@ class Qwen2ForCausalLM(nn.Module):
|
|||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> Optional[SamplerOutput]:
|
||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||
if self.config.tie_word_embeddings:
|
||||
lm_head_weight = self.model.embed_tokens.weight
|
||||
else:
|
||||
lm_head_weight = self.lm_head.weight
|
||||
next_tokens = self.sampler(lm_head_weight, hidden_states,
|
||||
sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
|
@ -340,6 +348,8 @@ class Qwen2ForCausalLM(nn.Module):
|
|||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if self.config.tie_word_embeddings and "lm_head.weight" in name:
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
|
Loading…
Reference in New Issue