[Spec Decoding] Use target model max length as default for draft model (#7706)

This commit is contained in:
Nick Hill 2024-08-21 12:23:22 -04:00 committed by GitHub
parent 6925cdbeea
commit 9b73a2f498
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 10 additions and 1 deletions

View File

@ -127,6 +127,7 @@ class ModelConfig:
rope_theta: Optional[float] = None,
tokenizer_revision: Optional[str] = None,
max_model_len: Optional[int] = None,
spec_target_max_model_len: Optional[int] = None,
quantization: Optional[str] = None,
quantization_param_path: Optional[str] = None,
enforce_eager: Optional[bool] = None,
@ -210,7 +211,8 @@ class ModelConfig:
hf_config=self.hf_text_config,
max_model_len=max_model_len,
disable_sliding_window=self.disable_sliding_window,
sliding_window_len=self.get_hf_config_sliding_window())
sliding_window_len=self.get_hf_config_sliding_window(),
spec_target_max_model_len=spec_target_max_model_len)
self.served_model_name = get_served_model_name(model,
served_model_name)
self.multimodal_config = self._init_multimodal_config(
@ -1134,6 +1136,7 @@ class SpeculativeConfig:
code_revision=draft_code_revision,
tokenizer_revision=target_model_config.tokenizer_revision,
max_model_len=None,
spec_target_max_model_len=target_model_config.max_model_len,
quantization=draft_quantization,
enforce_eager=target_model_config.enforce_eager,
max_seq_len_to_capture=target_model_config.
@ -1563,6 +1566,7 @@ def _get_and_verify_max_len(
max_model_len: Optional[int],
disable_sliding_window: bool,
sliding_window_len: Optional[int],
spec_target_max_model_len: Optional[int] = None,
) -> int:
"""Get and verify the model's maximum length."""
derived_max_model_len = float("inf")
@ -1605,6 +1609,11 @@ def _get_and_verify_max_len(
# If max_model_len is specified, we use it.
return max_model_len
if spec_target_max_model_len is not None:
# If this is a speculative draft model, we use the max model len
# from the target model.
return spec_target_max_model_len
default_max_len = 2048
logger.warning(
"The model's config.json does not contain any of the following "