mirror of https://github.com/vllm-project/vllm
[Spec Decoding] Use target model max length as default for draft model (#7706)
parent 6925cdbeea
commit 9b73a2f498
vllm/config.py

@@ -127,6 +127,7 @@ class ModelConfig:
         rope_theta: Optional[float] = None,
         tokenizer_revision: Optional[str] = None,
         max_model_len: Optional[int] = None,
+        spec_target_max_model_len: Optional[int] = None,
         quantization: Optional[str] = None,
         quantization_param_path: Optional[str] = None,
         enforce_eager: Optional[bool] = None,
@@ -210,7 +211,8 @@ class ModelConfig:
             hf_config=self.hf_text_config,
             max_model_len=max_model_len,
             disable_sliding_window=self.disable_sliding_window,
-            sliding_window_len=self.get_hf_config_sliding_window())
+            sliding_window_len=self.get_hf_config_sliding_window(),
+            spec_target_max_model_len=spec_target_max_model_len)
         self.served_model_name = get_served_model_name(model,
                                                        served_model_name)
         self.multimodal_config = self._init_multimodal_config(
@@ -1134,6 +1136,7 @@ class SpeculativeConfig:
                 code_revision=draft_code_revision,
                 tokenizer_revision=target_model_config.tokenizer_revision,
                 max_model_len=None,
+                spec_target_max_model_len=target_model_config.max_model_len,
                 quantization=draft_quantization,
                 enforce_eager=target_model_config.enforce_eager,
                 max_seq_len_to_capture=target_model_config.
@@ -1563,6 +1566,7 @@ def _get_and_verify_max_len(
     max_model_len: Optional[int],
     disable_sliding_window: bool,
     sliding_window_len: Optional[int],
+    spec_target_max_model_len: Optional[int] = None,
 ) -> int:
     """Get and verify the model's maximum length."""
     derived_max_model_len = float("inf")
@@ -1605,6 +1609,11 @@ def _get_and_verify_max_len(
         # If max_model_len is specified, we use it.
         return max_model_len
 
+    if spec_target_max_model_len is not None:
+        # If this is a speculative draft model, we use the max model len
+        # from the target model.
+        return spec_target_max_model_len
+
     default_max_len = 2048
     logger.warning(
         "The model's config.json does not contain any of the following "
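In effect, _get_and_verify_max_len gains one more fallback before the hard-coded 2048 default: a speculative draft model whose config.json lacks usable length keys now inherits the target model's max_model_len. The following is a minimal sketch of the resulting fallback order; resolve_default_max_len is a hypothetical stand-in for the branch patched above, not the real vLLM function, which also derives the length from config.json keys, handles sliding-window limits, and logs a warning before assuming 2048.

from typing import Optional


def resolve_default_max_len(
    max_model_len: Optional[int],
    spec_target_max_model_len: Optional[int] = None,
) -> int:
    # Sketch of the branch taken when no length key is found in config.json.
    if max_model_len is not None:
        # An explicitly specified max_model_len still takes priority.
        return max_model_len
    if spec_target_max_model_len is not None:
        # New in this commit: a draft model falls back to the target
        # model's max length instead of the generic default.
        return spec_target_max_model_len
    # Last resort, unchanged: assume the historical default of 2048.
    return 2048


# Example: a draft model with no length info and no explicit override now
# defaults to the target model's limit (say 4096) rather than 2048.
assert resolve_default_max_len(None, spec_target_max_model_len=4096) == 4096
assert resolve_default_max_len(None) == 2048

SpeculativeConfig wires this in by passing target_model_config.max_model_len as spec_target_max_model_len when it builds the draft ModelConfig, so draft models with missing position-embedding metadata are no longer capped at 2048 by default.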