mirror of https://github.com/vllm-project/vllm
[Speculative Decoding] EAGLE Implementation with Top-1 proposer (#6830)
This commit is contained in:
parent b3856bef7d
commit a3fce56b88
@@ -0,0 +1,268 @@
"""This docstring details important information on the testing methodology.

Most of the tests rely on "greedy equality", where we expect the output of
speculative decoding on a sequence to exactly match the output of normal
non-speculative decoding.

Since speculative decoding with rejection sampling guarantees that the output
distribution matches the target model's output distribution (up to hardware
numerics, see https://arxiv.org/pdf/2302.01318.pdf), we can expect greedy
equality.

However, we still need to verify that the following scenarios pass:
* Batch size 1 greedy equality
* Batch size >1 greedy equality
* Greedy equality under preemption
* Greedy equality under various numbers of speculative tokens.

With those tests, we can say that, at the least, EAGLE does not break the
correctness of the target model outputs.
"""

import pytest

from .conftest import run_greedy_equality_correctness_test

# main model
MAIN_MODEL = "JackFram/llama-68m"

# speculative model
SPEC_MODEL = "abhigoyal/vllm-eagle-llama-68m-random"

# max. number of speculative tokens: this corresponds to
# num_heads in the config.json of the speculator model.
MAX_SPEC_TOKENS = 4

# precision
PRECISION = "float32"


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness(baseline_llm_generator,
                                      test_llm_generator, batch_size: int,
                                      output_len: int):
    """Verify greedy equality with different batch sizes."""
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "enforce_eager": False,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Print spec metrics.
        "disable_log_stats": False,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
])
@pytest.mark.parametrize("output_len", [
    128,
])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
                                                 test_llm_generator,
                                                 batch_size: int,
                                                 output_len: int):
    """Verify greedy equality with cuda graph enabled and different
    batch sizes."""
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        "block_size": 8,
        # 2 for small prompt, 256//8 for generated.
        "num_gpu_blocks_override": 2 + 256 // 8,
        "max_model_len": (2 + 256 // 8) * 8,

        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs", [
    {
        "speculative_model": SPEC_MODEL,
        "num_speculative_tokens": MAX_SPEC_TOKENS,
    },
])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use small output len for fast test.
        128,
    ])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
def test_eagle_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
                                                      test_llm_generator,
                                                      batch_size: int,
                                                      output_len: int):
    """Verify greedy equality, even when some sequences are preempted mid-
    generation.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize(
    "test_llm_kwargs",
    [
        {
            "speculative_model": SPEC_MODEL,
            "num_speculative_tokens": k,
        }
        # Try a range of num. speculative tokens
        for k in range(1, 1 + MAX_SPEC_TOKENS)
    ])
@pytest.mark.parametrize("batch_size", [2])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_eagle_different_k(baseline_llm_generator, test_llm_generator,
                           batch_size: int, output_len: int):
    """Verify that EAGLE speculative decoding produces exact equality
    to the output without spec decode, for different values of
    num_speculative_tokens.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
        # Skip cuda graph recording for fast test.
        "enforce_eager": True,

        # Required for spec decode.
        "use_v2_block_manager": True,

        # Precision
        "dtype": PRECISION,

        # Main model
        "model": MAIN_MODEL,
    }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
@pytest.mark.parametrize("test_llm_kwargs",
                         [{
                             "speculative_model": SPEC_MODEL,
                             "num_speculative_tokens": MAX_SPEC_TOKENS,
                             "speculative_disable_by_batch_size": 4
                         }])
@pytest.mark.parametrize("batch_size", [1, 5])
@pytest.mark.parametrize(
    "output_len",
    [
        # Use smaller output len for fast test.
        32,
    ])
@pytest.mark.parametrize("seed", [1])
def test_eagle_disable_queue(baseline_llm_generator, test_llm_generator,
                             batch_size: int, output_len: int):
    """Verify that EAGLE speculative decoding produces exact equality
    to the output without spec decode when speculation is disabled for
    large batch sizes.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
                                         batch_size,
                                         max_output_len=output_len,
                                         force_output_len=True)


if __name__ == "__main__":
    import pytest
    pytest.main([__file__])
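For readers unfamiliar with the conftest helper used above, the following is a minimal sketch of what a greedy-equality check boils down to. It is an illustration only, not the actual run_greedy_equality_correctness_test; the check_greedy_equality name, the generator signatures, and the assumption that each generator returns one list of token ids per prompt are all hypothetical.

from typing import Callable, List

def check_greedy_equality(
        baseline_generate: Callable[[List[str], int], List[List[int]]],
        spec_generate: Callable[[List[str], int], List[List[int]]],
        prompts: List[str],
        max_output_len: int) -> None:
    # Greedy decoding is deterministic, so speculative decoding with
    # rejection sampling must reproduce the baseline token-for-token.
    baseline_outputs = baseline_generate(prompts, max_output_len)
    spec_outputs = spec_generate(prompts, max_output_len)
    for prompt, expected, actual in zip(prompts, baseline_outputs,
                                        spec_outputs):
        assert actual == expected, f"Mismatch for prompt {prompt!r}"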
@@ -70,8 +70,9 @@ PRECISION = "float32"
    ])
@pytest.mark.parametrize("batch_size", [1, 32])
@pytest.mark.parametrize("seed", [1])
-def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
-                                    batch_size: int, output_len: int):
+def test_medusa_e2e_greedy_correctness(baseline_llm_generator,
+                                       test_llm_generator, batch_size: int,
+                                       output_len: int):
    """Verify greedy equality with different batch size."""
    run_greedy_equality_correctness_test(baseline_llm_generator,
                                         test_llm_generator,
@@ -80,6 +81,49 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
                                         force_output_len=True)
+
+
+@pytest.mark.parametrize(
+    "common_llm_kwargs",
+    [{
+        "enforce_eager": False,
+
+        # Required for spec decode.
+        "use_v2_block_manager": True,
+
+        # Print spec metrics.
+        "disable_log_stats": False,
+
+        # Precision
+        "dtype": PRECISION,
+
+        # Main model
+        "model": MAIN_MODEL,
+    }])
+@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
+@pytest.mark.parametrize("baseline_llm_kwargs", [{}])
+@pytest.mark.parametrize("test_llm_kwargs", [
+    {
+        "speculative_model": SPEC_MODEL,
+        "num_speculative_tokens": MAX_SPEC_TOKENS,
+    },
+])
+@pytest.mark.parametrize("output_len", [
+    128,
+])
+@pytest.mark.parametrize("batch_size", [1, 32])
+@pytest.mark.parametrize("seed", [1])
+def test_medusa_e2e_greedy_correctness_cuda_graph(baseline_llm_generator,
+                                                  test_llm_generator,
+                                                  batch_size: int,
+                                                  output_len: int):
+    """Verify greedy equality with cuda graph enabled and different
+    batch sizes."""
+    run_greedy_equality_correctness_test(baseline_llm_generator,
+                                         test_llm_generator,
+                                         batch_size,
+                                         max_output_len=output_len,
+                                         force_output_len=True)
+
+
@pytest.mark.parametrize(
    "common_llm_kwargs",
    [{
@@ -116,10 +160,10 @@ def test_mlp_e2e_greedy_correctness(baseline_llm_generator, test_llm_generator,
    ])
@pytest.mark.parametrize("batch_size", [4])
@pytest.mark.parametrize("seed", [1])
-def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
-                                                    test_llm_generator,
-                                                    batch_size: int,
-                                                    output_len: int):
+def test_medusa_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
+                                                       test_llm_generator,
+                                                       batch_size: int,
+                                                       output_len: int):
    """Verify greedy equality, even when some sequences are preempted mid-
    generation.
    """
@@ -165,9 +209,9 @@ def test_mlp_e2e_greedy_correctness_with_preemption(baseline_llm_generator,
        32,
    ])
@pytest.mark.parametrize("seed", [1])
-def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
-                         batch_size: int, output_len: int):
-    """Verify that mlp speculative decoding produces exact equality
+def test_medusa_different_k(baseline_llm_generator, test_llm_generator,
+                            batch_size: int, output_len: int):
+    """Verify that medusa speculative decoding produces exact equality
    to without spec decode with different values of num_speculative_tokens.
    """
    run_greedy_equality_correctness_test(baseline_llm_generator,
@@ -208,9 +252,9 @@ def test_mlp_different_k(baseline_llm_generator, test_llm_generator,
        32,
    ])
@pytest.mark.parametrize("seed", [1])
-def test_mlp_disable_queue(baseline_llm_generator, test_llm_generator,
-                           batch_size: int, output_len: int):
-    """Verify that mlp speculative decoding produces exact equality
+def test_medusa_disable_queue(baseline_llm_generator, test_llm_generator,
+                              batch_size: int, output_len: int):
+    """Verify that medusa speculative decoding produces exact equality
    to without spec decode when speculation is disabled for large
    batch sizes.
    """
@@ -6,7 +6,8 @@ import pytest
import torch

from vllm.model_executor.utils import set_random_seed
-from vllm.sequence import ExecuteModelRequest, Logprob, SamplerOutput
+from vllm.sequence import (ExecuteModelRequest, HiddenStates, Logprob,
+                           SamplerOutput, get_all_seq_ids)
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.multi_step_worker import MultiStepWorker
from vllm.spec_decode.top1_proposer import Top1Proposer
@@ -690,3 +691,36 @@ def test_use_draft_model_runner_advance_step():
    worker.execute_model(execute_model_req=execute_model_req)
    call_args_list = worker.model_runner._gpu_advance_step.call_args_list
    assert len(call_args_list) == 1
+
+
+@torch.inference_mode()
+def test_expand_execute_model_request_sync_with_expand_hidden_states():
+    """
+    In this test we verify that the logic for expanding the
+    seq_group_metadata_list remains in sync with the expansion logic of
+    the HiddenStates in _expand_execute_model_request.
+    """
+    k = 5
+    batch_size = 16
+    seq_with_bonus_token_in_last_step = [1, 3, 8, 10, 13, 15]
+
+    seq_group_metadata_list, _, _ = create_batch(batch_size, k)
+
+    execute_model_request = ExecuteModelRequest(
+        seq_group_metadata_list,
+        previous_hidden_states=HiddenStates(
+            torch.arange(batch_size), seq_group_metadata_list,
+            torch.arange(batch_size, 2 * batch_size)))
+
+    expanded_execute_model_request, orig_seq_group_ids = MultiStepWorker.\
+        _expand_execute_model_request(execute_model_request,
+                                      seq_with_bonus_token_in_last_step)
+
+    all_seq_ids = torch.tensor(
+        get_all_seq_ids(
+            expanded_execute_model_request.seq_group_metadata_list))
+    ref_expanded_hidden_states = all_seq_ids + batch_size
+    ref_expanded_hidden_states[orig_seq_group_ids] -= batch_size
+
+    assert (ref_expanded_hidden_states == expanded_execute_model_request.
+            previous_hidden_states.hidden_states).all().item()
@@ -60,6 +60,7 @@ _GENERATION_MODELS = {
    "XverseForCausalLM": ("xverse", "XverseForCausalLM"),
    "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"),
    "MedusaModel": ("medusa", "Medusa"),
+    "EAGLEModel": ("eagle", "EAGLE"),
    "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"),
    "JambaForCausalLM": ("jamba", "JambaForCausalLM"),
}
@@ -0,0 +1,161 @@
from typing import Iterable, List, Optional, Tuple

import torch
import torch.nn as nn

from vllm.attention.backends.abstract import AttentionMetadata
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
    DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.transformers_utils.configs.eagle import EAGLEConfig


class EAGLE(nn.Module):
    """This class implements the EAGLE draft model from the paper: https://arxiv.org/pdf/2401.15077
    Reference implementation: https://github.com/SafeAILab/EAGLE

    Differences from the reference implementation:
    1. In the reference, the LlamaDecoderLayer implementation doesn't have an
       input_layernorm for the 1st decoder layer (https://github.com/SafeAILab/EAGLE/blob/7d065d084443fbfd386f88839efd7193c12be869/eagle/model/cnets.py#L427),
       but we keep it, as the HF implementation also does.
    2. We allow any decoder layer to be used in EAGLE, whereas in the
       reference the decoder layer is fixed to LlamaDecoderLayer.
    3. We have an optional token_map which reduces the draft vocab to the most
       frequently used tokens to give some additional speed-up by reducing the
       sampling overhead. This is disabled unless the checkpoint file has an
       explicit token_map tensor and the config has an optional attribute
       truncated_vocab_size < vocab_size. To use this technique, one has to
       find the top-k most frequent tokens in the target dataset and add them
       as a tensor to the draft checkpoint (using the key token_map). The
       draft config also needs truncated_vocab_size (=k) as an attribute."""

    def __init__(self, config: EAGLEConfig, *args, **kwargs) -> None:
        super().__init__()
        self.config = config

        architectures = getattr(self.config.model, "architectures", [])
        model_cls, _ = ModelRegistry.resolve_model_cls(architectures)

        self.model = model_cls(self.config.model, *args, **kwargs)
        self.fc = nn.Linear(config.model.hidden_size * 2,
                            config.model.hidden_size,
                            bias=False)

        self.orig_vocab_size = config.vocab_size
        self.truncated_vocab_size = config.truncated_vocab_size
        self.unpadded_vocab_size = self.truncated_vocab_size

        self.lm_head = ParallelLMHead(
            self.unpadded_vocab_size,
            config.hidden_size,
            org_num_embeddings=self.truncated_vocab_size,
            padding_size=DEFAULT_VOCAB_PADDING_SIZE,
        )

        logit_scale = getattr(config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
                                                self.truncated_vocab_size,
                                                logit_scale)

        # The token map is an idx-to-token mapping used to reduce the vocab
        # size of the draft model. Using a smaller draft vocab containing only
        # the most frequent tokens reduces the speculation overhead without
        # affecting the acceptance rate much, which gives an additional
        # speed-up. By default this is disabled and is only used if the EAGLE
        # checkpoint file has a token_map tensor.
        self.token_map = None

    @property
    def sampler(self):
        return self.model.sampler

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        previous_hidden_states: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
    ) -> torch.Tensor:

        tok_embeds = self.model.model.embed_tokens(input_ids)
        inputs_embeds = self.fc(
            torch.cat([tok_embeds, previous_hidden_states], dim=-1))

        inputs_embeds[positions == 0] = 0  # masking inputs at position=0

        hidden_states = self.model.model(
            input_ids=None,
            inputs_embeds=inputs_embeds,
            positions=positions,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
            intermediate_tensors=intermediate_tensors)
        return hidden_states

    def compute_logits(self, hidden_states: torch.Tensor,
                       sampling_metadata: SamplingMetadata) -> torch.Tensor:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)

        if self.token_map is not None:
            _logits = logits
            logits = -torch.inf * torch.ones(
                size=(*_logits.shape[:-1], self.orig_vocab_size),
                device=_logits.device,
                dtype=_logits.dtype)

            logits[..., self.token_map] = _logits

        return logits

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[SamplerOutput]:
        next_tokens = self.sampler(logits, sampling_metadata)
        return next_tokens

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        # This implementation is incompatible with
        # https://huggingface.co/yuhuili/EAGLE-LLaMA3-Instruct-8B
        # due to missing lm_head weights and its config being that of a
        # Llama model. Here's a compatible version with the same weights:
        # https://huggingface.co/abhigoyal/EAGLE-LLaMA3-Instruct-8B-vllm
        # Also, here's an example script for converting a trained EAGLE
        # checkpoint to a vLLM-compatible version:
        # https://gist.github.com/abhigoyal1997/1e7a4109ccb7704fbc67f625e86b2d6d
        model_weights = {}
        for name, loaded_weight in weights:
            if name == "token_map":
                if self.config.truncated_vocab_size < self.config.vocab_size:
                    self.token_map = nn.Parameter(loaded_weight,
                                                  requires_grad=False)
            elif name.startswith("fc."):
                weight_loader = getattr(self.fc.weight, "weight_loader",
                                        default_weight_loader)
                weight_loader(self.fc.weight, loaded_weight)
            elif name.startswith("model.lm_head.") or name.startswith(
                    "model.model."):
                model_weights[name.split("model.", 1)[-1]] = loaded_weight
            elif name.startswith("lm_head.") or name.startswith("model."):
                model_weights[name] = loaded_weight
            else:
                model_weights[f"model.{name}"] = loaded_weight

        lm_head_weight = model_weights.pop("lm_head.weight")

        if self.token_map is not None and\
            lm_head_weight.shape[0] > self.token_map.shape[0]:

            lm_head_weight = lm_head_weight[self.token_map]

        weight_loader = getattr(self.lm_head.weight, "weight_loader",
                                default_weight_loader)
        weight_loader(self.lm_head.weight, lm_head_weight)

        self.model.load_weights(model_weights.items())
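The token_map technique described in the EAGLE docstring above expects the draft checkpoint to carry a token_map tensor and the draft config to carry truncated_vocab_size. The sketch below shows one way such a tensor could be produced offline from token frequencies; it is not part of the commit. The corpus iterable, the k value, the file paths, and the plain torch.save checkpoint format are assumptions for illustration (the converter gist referenced in load_weights is the authoritative path).

from collections import Counter

import torch

def build_token_map(token_id_stream, k: int) -> torch.Tensor:
    # Count token frequencies over a representative target-domain corpus;
    # token_id_stream is any iterable of token ids.
    counts = Counter(token_id_stream)
    top_k = [token_id for token_id, _ in counts.most_common(k)]
    # token_map[i] is the original vocab id of draft-vocab entry i; sorting is
    # one arbitrary but convenient ordering.
    return torch.tensor(sorted(top_k), dtype=torch.long)

# Hypothetical usage: store the tensor under the key "token_map" in the draft
# checkpoint and set truncated_vocab_size = k in the draft config.json.
# state_dict = torch.load("eagle_draft.bin")
# state_dict["token_map"] = build_token_map(corpus_token_ids, k=32000)
# torch.save(state_dict, "eagle_draft_with_token_map.bin")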
@@ -30,6 +30,19 @@ class ResidualBlock(nn.Module):


class Medusa(nn.Module):
+    """This class implements the Medusa draft model from the paper: https://arxiv.org/abs/2401.10774
+    Reference implementation: https://github.com/FasterDecoding/Medusa
+
+    Differences from the reference implementation:
+    1. Currently this only supports generating proposals from top-1 tokens.
+    2. We have an optional token_map which reduces the draft vocab to the
+       most frequently used tokens to give some additional speed-up by
+       reducing the sampling overhead. This is disabled unless the checkpoint
+       file has an explicit token_map tensor and the config has an optional
+       attribute truncated_vocab_size < vocab_size. To use this technique,
+       one has to find the top-k most frequent tokens in the target dataset
+       and add them as a tensor to the draft checkpoint (using the key
+       token_map). The draft config also needs truncated_vocab_size (=k) as
+       an attribute."""
+
    def __init__(self, config: MedusaConfig, **_) -> None:
        super().__init__()
@@ -57,6 +70,12 @@ class Medusa(nn.Module):
                                                self.truncated_vocab_size,
                                                logit_scale)

+        # The token map is an idx-to-token mapping used to reduce the vocab
+        # size of the draft model. Using a smaller draft vocab containing only
+        # the most frequent tokens reduces the speculation overhead without
+        # affecting the acceptance rate much, which gives an additional
+        # speed-up. By default this is disabled and is only used if the draft
+        # checkpoint file has a token_map tensor.
        self.token_map = None

    def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]:
@@ -1092,6 +1092,10 @@ class SamplerOutput(
    # Optional last hidden states from the model.
    hidden_states: Optional[torch.Tensor] = None

+    # Optional prefill hidden states from the model
+    # (used for models like EAGLE).
+    prefill_hidden_states: Optional[torch.Tensor] = None
+
    # Time taken in the forward pass for this across all workers
    model_forward_time: Optional[float] = None
@@ -1176,40 +1180,87 @@ class HiddenStates(msgspec.Struct, array_like=True,
                   omit_defaults=True):  # type: ignore[call-arg]
    """Hidden states corresponding to in-progress sequences.
    Used in speculative decoding to pass hidden states from
-    the target model to the proposer model in the subsequent step.
+    the target model to the proposer model.

    seq_ids are the sequence ids of each entry of the batch
    dimension of the hidden_states tensor"""

-    seq_group_metadata_list: List[SequenceGroupMetadata]
+    # Scorer hidden states. For the prefill step, it holds the hidden states
+    # of all tokens, whereas for the decode step it is used for the last
+    # accepted tokens.
    hidden_states: torch.Tensor
+    # The sequence group metadata list. Only needed for the decode step.
+    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None
+    # Scorer hidden states of the 2nd last token proposed by the proposer
+    # (irrespective of whether it was accepted or not). Only used when the
+    # last proposed token is accepted (i.e., in case of bonus tokens). For
+    # the case of no bonus tokens, these are ignored.
+    second_last_token_hidden_states: Optional[torch.Tensor] = None

    _seq_ids: List[int] = msgspec.field(default_factory=list)

    def __post_init__(self):
-        self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)
-        assert len(self.seq_group_metadata_list) == len(self.hidden_states)
+        if self.seq_group_metadata_list is not None:
+            assert len(self.seq_group_metadata_list) == len(self.hidden_states)
+            self._seq_ids = get_all_seq_ids(self.seq_group_metadata_list)

    @property
    def seq_ids(self) -> List[int]:
        return self._seq_ids

-    def update(self, seq_group_metadata_list: List[SequenceGroupMetadata],
-               hidden_states: torch.Tensor) -> None:
-        """Update hidden states from target model invocation."""
+    def update(self,
+               hidden_states: torch.Tensor,
+               seq_group_metadata_list: List[SequenceGroupMetadata],
+               second_last_token_hidden_states: Optional[torch.Tensor] = None):
+        """Update hidden states from target model invocation. Only used for
+        decode steps."""
        assert len(seq_group_metadata_list) == len(hidden_states)
        self._seq_ids.extend(get_all_seq_ids(seq_group_metadata_list))
        self.hidden_states = torch.cat([self.hidden_states, hidden_states])

+        if self.second_last_token_hidden_states is not None:
+            # Add dummy hidden_states so that the two tensors keep the same
+            # shape.
+            self.second_last_token_hidden_states = torch.cat([
+                self.second_last_token_hidden_states,
+                torch.zeros_like(hidden_states)
+                if second_last_token_hidden_states is None else
+                second_last_token_hidden_states
+            ])
+
    def prune(self,
              seq_group_metadata_list: List[SequenceGroupMetadata]) -> None:
-        """Prune to provided list of sequence ids."""
+        """Prune to the provided list of sequence ids. Only used for decode
+        steps.
+        """
+        # Currently this prunes all seq_ids not present in
+        # seq_group_metadata_list, which might cause problems where a sequence
+        # may be "paused" and then "resumed" later. This should only prune
+        # sequences which are confirmed to be aborted.
        seq_ids = get_all_seq_ids(seq_group_metadata_list)
        if seq_ids != self._seq_ids:
            # Batch contents changed - prune removed sequences.
            index = [self._seq_ids.index(seq_id) for seq_id in seq_ids]
            self.hidden_states = self.hidden_states[index]
+            if self.second_last_token_hidden_states is not None:
+                self.second_last_token_hidden_states = self\
+                    .second_last_token_hidden_states[index]
            self._seq_ids = seq_ids

+    def expand_with_bonus_tokens(
+            self, seq_with_bonus_token_in_last_step: set) -> None:
+        """Expand hidden states for sequences with bonus tokens. This is in
+        alignment with `MultiStepWorker._expand_execute_model_request`."""
+        if self.second_last_token_hidden_states is None \
+            or not seq_with_bonus_token_in_last_step:
+            return
+
+        index = []
+        for seq_id in self._seq_ids:
+            i = self._seq_ids.index(seq_id)
+            if seq_id in seq_with_bonus_token_in_last_step:
+                index.append(i + len(self._seq_ids))
+            index.append(i)
+
+        self.hidden_states = torch.cat(
+            [self.hidden_states, self.second_last_token_hidden_states])[index]
+

class ExecuteModelRequest(
        msgspec.Struct,
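A small worked example (not part of the commit) of the index pattern built by expand_with_bonus_tokens above, under the assumption of a toy batch with three sequences and one bonus token; the tensor values are made up.

import torch

seq_ids = [0, 1, 2]
bonus = {1}  # sequence 1 had its bonus token accepted last step

hidden_states = torch.tensor([[10.], [11.], [12.]])  # last accepted token
second_last = torch.tensor([[20.], [21.], [22.]])    # 2nd last proposed token

index = []
for i, seq_id in enumerate(seq_ids):
    if seq_id in bonus:
        index.append(i + len(seq_ids))  # row taken from second_last
    index.append(i)

expanded = torch.cat([hidden_states, second_last])[index]
# index == [0, 4, 1, 2] -> rows [10.], [21.], [11.], [12.]
# Sequence 1 becomes two entries (its second-last hidden state first, then its
# last one), matching the duplicated sequence-group entry that
# MultiStepWorker._expand_execute_model_request creates for bonus tokens.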
@@ -203,6 +203,7 @@ class TP1DraftModelRunner(ModelRunner):
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
+        previous_hidden_states: Optional[torch.Tensor] = None,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[List[SamplerOutput]]:
@@ -280,13 +281,30 @@ class TP1DraftModelRunner(ModelRunner):
            graph_batch_size = model_input.input_tokens.shape[0]
            model_executable = (self.graph_runners[model_input.virtual_engine]
                                [graph_batch_size])
+
+            if previous_hidden_states is not None:
+                hidden_states = torch.cat([
+                    previous_hidden_states,
+                    torch.empty([
+                        graph_batch_size - previous_hidden_states.shape[0],
+                        *previous_hidden_states.shape[1:]
+                    ],
+                                dtype=previous_hidden_states.dtype,
+                                device=previous_hidden_states.device)
+                ])
+            else:
+                hidden_states = None
        else:
            model_executable = self.model
+            hidden_states = previous_hidden_states

        outputs: List[SamplerOutput] = []
        for step in range(num_steps):
            multi_modal_kwargs = model_input.multi_modal_kwargs or {}

+            kwargs = {"previous_hidden_states": hidden_states} \
+                if previous_hidden_states is not None else {}
+
            # Run model
            hidden_states = model_executable(
                input_ids=model_input.input_tokens,
@@ -296,6 +314,7 @@ class TP1DraftModelRunner(ModelRunner):
                intermediate_tensors=intermediate_tensors,
                **MultiModalInputs.as_kwargs(multi_modal_kwargs,
                                             device=self.device),
+                **kwargs,
            )

            # Compute the logits.
@@ -4,8 +4,8 @@ from typing import Dict, List, Set, Tuple

import torch

-from vllm.sequence import (ExecuteModelRequest, SamplerOutput, SequenceData,
-                           SequenceGroupMetadata)
+from vllm.sequence import (ExecuteModelRequest, HiddenStates, SamplerOutput,
+                           SequenceData, SequenceGroupMetadata)
from vllm.spec_decode.draft_model_runner import TP1DraftModelRunner
from vllm.spec_decode.interfaces import (SpeculativeProposals,
                                         SpeculativeProposer)
@@ -157,6 +157,12 @@ class MultiStepWorker(Worker, ProposerWorkerBase):

        updated_execute_model_req.seq_group_metadata_list =\
            updated_seq_group_metadata_list
+
+        if isinstance(updated_execute_model_req.previous_hidden_states,
+                      HiddenStates):
+            updated_execute_model_req.previous_hidden_states\
+                .expand_with_bonus_tokens(seq_with_bonus_token_in_last_step)
+
        return updated_execute_model_req, indices_of_original_sequence_groups

    @staticmethod
@@ -147,6 +147,11 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
                draft_worker_kwargs[
                    "model_runner_cls"] = TP1DraftModelRunner
            else:
+                if draft_worker_kwargs[
+                        "model_config"].hf_config.model_type == "eagle":
+                    raise NotImplementedError(
+                        "EAGLE does not support TP > 1 yet")
+
                allow_zero_draft_token_step = False
            proposer_worker = MultiStepWorker(**draft_worker_kwargs)
@@ -355,14 +360,34 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
            execute_model_req)
        num_lookahead_slots = execute_model_req.num_lookahead_slots

+        # Speculative decoding is disabled in the following cases:
+        # 1. Prefill phase: Speculative decoding is not
+        #    used during the prefill phase.
+        # 2. Auto-disable enabled: The running queue size exceeds
+        #    the specified threshold.
+        # 3. No request: There are no requests in the batch.
+        # In any of these cases, the proposer and scorer workers
+        # are called normally.
+        no_spec = num_lookahead_slots == 0 or len(
+            execute_model_req.seq_group_metadata_list
+        ) == 0 or disable_all_speculation
+
        # Broadcast how many lookahead slots are scheduled for this step, and
        # whether all speculation is disabled, to all non-driver workers.

        # This is required as if the number of draft model runs changes
        # dynamically, the non-driver workers won't know unless we perform a
        # communication to inform them.

+        # no_spec is used to signal the non-driver workers about the prefill
+        # vs decode stage. This is needed to ensure that the order of
+        # execution of the proposer and scorer is the same in both driver and
+        # non-driver workers (i.e., scorer -> proposer for prefill and
+        # proposer -> scorer in decode). This order is needed to support
+        # models like EAGLE that take scorer states as inputs.
        broadcast_dict = dict(
            num_lookahead_slots=num_lookahead_slots,
+            no_spec=no_spec,
            disable_all_speculation=disable_all_speculation,
        )
        broadcast_tensor_dict(broadcast_dict, src=self._driver_rank)
@@ -373,17 +398,7 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        self._maybe_disable_speculative_tokens(
            disable_all_speculation, execute_model_req.seq_group_metadata_list)

-        # Speculative decoding is disabled in the following cases:
-        # 1. Prefill phase: Speculative decoding is not
-        #    used during the prefill phase.
-        # 2. Auto-disable enabled: The running queue size exceeds
-        #    the specified threshold.
-        # 3. No request: There are no requests in the batch.
-        # In any of these cases, the proposer and scorer workers
-        # are called normally.
-        if num_lookahead_slots == 0 or len(
-                execute_model_req.seq_group_metadata_list
-        ) == 0 or disable_all_speculation:
+        if no_spec:
            return self._run_no_spec(execute_model_req,
                                     skip_proposer=disable_all_speculation)
        return self._run_speculative_decoding_step(execute_model_req,
@@ -464,8 +479,6 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        not called, meaning that the kv-cache in proposer for requests is not
        updated, so they cannot enable spec decode in the rest decoding.
        """
-        if not skip_proposer:
-            self.proposer_worker.execute_model(execute_model_req)

        sampler_output = self.scorer_worker.execute_model(execute_model_req)
        assert len(sampler_output) == 1
@@ -476,10 +489,20 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
        if hidden_states is not None:
            if self.previous_hidden_states is None:
                self.previous_hidden_states = HiddenStates(
-                    execute_model_req.seq_group_metadata_list, hidden_states)
+                    hidden_states, execute_model_req.seq_group_metadata_list)
            else:
                self.previous_hidden_states.update(
-                    execute_model_req.seq_group_metadata_list, hidden_states)
+                    hidden_states, execute_model_req.seq_group_metadata_list)
+
+        if not skip_proposer:
+            # We prepare the prefill hidden states here so that there is no
+            # additional complexity in the worker for the spec_decode vs
+            # non_spec_decode flow, and execute_model doesn't need additional
+            # modifications.
+            execute_model_req.previous_hidden_states = \
+                prepare_prefill_hidden_states(
+                    sampler_output.prefill_hidden_states)
+
+            self.proposer_worker.execute_model(execute_model_req)

        sampler_output_to_return = (self._serialize_sampler_output_no_logprobs(
            execute_model_req=execute_model_req, sampler_output=sampler_output)
@@ -507,15 +530,23 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
            return False
        num_lookahead_slots = data["num_lookahead_slots"]

-        # Even if num_lookahead_slots is zero, we want to run the proposer model
-        # as it may have KV.
-        #
-        # We run the proposer once per lookahead slot. In the future we should
-        # delegate how many times it runs to the proposer.
-        for _ in range(max(num_lookahead_slots, 1)):
-            self.proposer_worker.execute_model()
-
-        self.scorer_worker.execute_model()
+        # In case of prefill, scorer_worker has to be run before proposer so
+        # that the hidden states can be propagated to proposer when needed.
+        if data["no_spec"]:
+            self.scorer_worker.execute_model()
+
+        if not data["disable_all_speculation"]:
+            # Even if num_lookahead_slots is zero, we want to run the
+            # proposer model as it may have KV.
+            #
+            # We run the proposer once per lookahead slot. In the future we
+            # should delegate how many times it runs to the proposer.
+            for _ in range(max(num_lookahead_slots, 1)):
+                self.proposer_worker.execute_model()
+
+        if not data["no_spec"]:
+            self.scorer_worker.execute_model()
+
        return True

    @nvtx_range("spec_decode_worker._run_speculative_decoding_step")
@@ -546,6 +577,8 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
            raise RuntimeError("Cannot handle cases where distributed draft "
                               "workers generate no tokens")

+        execute_model_req.previous_hidden_states = None
+
        with Timer() as scoring_timer:
            proposal_scores = self.scorer.score_proposals(
                execute_model_req,
@@ -651,10 +684,12 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
            accepted_index = accepted_token_ids + 1  # Convert -1 to 0
            accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)
            index = accepted_index[:, None, None].expand(-1, 1, hs_size)
+            second_last_token_hidden_states = hidden_states[:, -2]  # b x d
            hidden_states = hidden_states.gather(1, index).squeeze(1)  # b x d
            # Store hidden states from target model for subsequent decode step
-            self.previous_hidden_states = HiddenStates(seq_group_metadata_list,
-                                                       hidden_states)
+            self.previous_hidden_states = HiddenStates(
+                hidden_states, seq_group_metadata_list,
+                second_last_token_hidden_states)

        return accepted_token_ids, logprobs
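To make the gather in the hunk above easier to follow, here is a toy decode step with k = 3 proposed tokens per sequence; it is an illustration only, with arbitrary token ids and -1 marking rejected positions, as in the surrounding code.

import torch

# accepted_token_ids: b x (k + 1); -1 marks rejected/invalid positions.
accepted_token_ids = torch.tensor([
    [11, 12, 13, 14],  # all proposals plus the bonus token accepted
    [21, 22, -1, -1],  # two tokens accepted
    [31, -1, -1, -1],  # only the first token accepted
])
hs_size = 2
# hidden_states: b x (k + 1) x d, one hidden state per scored position.
hidden_states = torch.arange(3 * 4 * hs_size,
                             dtype=torch.float).view(3, 4, hs_size)

accepted_index = accepted_token_ids + 1                        # convert -1 to 0
accepted_index = accepted_index.count_nonzero(dim=1).add_(-1)  # [3, 1, 0]
index = accepted_index[:, None, None].expand(-1, 1, hs_size)
second_last_token_hidden_states = hidden_states[:, -2]         # b x d
last_accepted = hidden_states.gather(1, index).squeeze(1)      # b x d
# Row 0 keeps position 3, row 1 keeps position 1, row 2 keeps position 0:
# exactly the hidden state of each sequence's last accepted token.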
@@ -951,3 +986,15 @@ def split_num_cache_blocks_evenly(scorer_cache_block_size_bytes: int,
        (proposer_cache_block_size_bytes + scorer_cache_block_size_bytes))

    return new_num_gpu_blocks
+
+
+def prepare_prefill_hidden_states(
+        prefill_hidden_states: torch.Tensor) -> HiddenStates:
+    # For the prefill step in the proposer, we run the model for N-1 tokens
+    # because the Nth token will be processed in the first decode step. For
+    # those N-1 tokens, the input should be the 0:N-1 hidden states, which
+    # should be concatenated with tokens 1:N (since the output of the scorer
+    # has to be the input for the proposer). Therefore, we shift the hidden
+    # states to align the (n-1)th hidden state with the nth token.
+    return HiddenStates(prefill_hidden_states.roll(
+        shifts=1, dims=0)) if prefill_hidden_states is not None else None
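A small worked example (not part of the commit) of the shift performed by prepare_prefill_hidden_states above, using made-up one-dimensional "hidden states" for a 4-token prompt.

import torch

# Scorer hidden states for prompt positions 0..3.
prefill_hidden_states = torch.tensor([[0.], [1.], [2.], [3.]])

shifted = prefill_hidden_states.roll(shifts=1, dims=0)
# shifted rows: [3.], [0.], [1.], [2.]
# Position p of the proposer's prefill now sees token_p together with the
# scorer hidden state produced at position p-1, which is the pairing EAGLE's
# fc layer expects. Row 0 is the wrapped-around last hidden state and is
# harmless because EAGLE.forward zeroes its inputs where positions == 0.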
@@ -11,10 +11,11 @@ from transformers.models.auto.modeling_auto import (
from vllm.envs import VLLM_USE_MODELSCOPE
from vllm.logger import init_logger
from vllm.transformers_utils.configs import (ChatGLMConfig, DbrxConfig,
-                                             InternVLChatConfig, JAISConfig,
-                                             MedusaConfig, MLPSpeculatorConfig,
-                                             MPTConfig, NemotronConfig,
-                                             RWConfig, UltravoxConfig)
+                                             EAGLEConfig, InternVLChatConfig,
+                                             JAISConfig, MedusaConfig,
+                                             MLPSpeculatorConfig, MPTConfig,
+                                             NemotronConfig, RWConfig,
+                                             UltravoxConfig)

if VLLM_USE_MODELSCOPE:
    from modelscope import AutoConfig
@@ -32,6 +33,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
    "jais": JAISConfig,
    "mlp_speculator": MLPSpeculatorConfig,
    "medusa": MedusaConfig,
+    "eagle": EAGLEConfig,
    "internvl_chat": InternVLChatConfig,
    "nemotron": NemotronConfig,
    "ultravox": UltravoxConfig,
@@ -1,5 +1,6 @@
from vllm.transformers_utils.configs.chatglm import ChatGLMConfig
from vllm.transformers_utils.configs.dbrx import DbrxConfig
+from vllm.transformers_utils.configs.eagle import EAGLEConfig
# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
# `FalconConfig` class from the official HuggingFace transformers library.
@@ -20,6 +21,7 @@ __all__ = [
    "InternVLChatConfig",
    "JAISConfig",
    "MedusaConfig",
+    "EAGLEConfig",
    "MLPSpeculatorConfig",
    "NemotronConfig",
    "UltravoxConfig",
@@ -0,0 +1,49 @@
import os
from typing import Optional, Union

from transformers import AutoConfig, PretrainedConfig


class EAGLEConfig(PretrainedConfig):
    model_type = "eagle"

    def __init__(self,
                 model: Union[PretrainedConfig, dict, None] = None,
                 truncated_vocab_size: Optional[int] = None,
                 **kwargs):

        model_config = None if model is None else (AutoConfig.for_model(
            **model) if isinstance(model, dict) else model)

        for k, v in kwargs.items():
            if k != "architectures" and k != "model_type" and hasattr(
                    model_config, k):
                setattr(model_config, k, v)

        self.model = model_config

        if self.model is None:
            self.truncated_vocab_size = None
        else:
            self.truncated_vocab_size = self.model.vocab_size if \
                truncated_vocab_size is None else truncated_vocab_size

        if "architectures" not in kwargs:
            kwargs["architectures"] = ["EAGLEModel"]

        super().__init__(**kwargs)

        if self.model is not None:
            for k, v in self.model.to_dict().items():
                if not hasattr(self, k):
                    setattr(self, k, v)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        **kwargs,
    ) -> "EAGLEConfig":
        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs)
        return cls.from_dict(config_dict, **kwargs)
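A brief usage sketch of the config class above (not part of the commit). The nested model dict and the vocabulary sizes are hypothetical values; in practice the config is normally read from the draft checkpoint's config.json via EAGLEConfig.from_pretrained.

from vllm.transformers_utils.configs.eagle import EAGLEConfig

# Wrap a hypothetical Llama-style base config; AutoConfig.for_model() is used
# internally when `model` is passed as a dict.
config = EAGLEConfig(
    model={
        "model_type": "llama",
        "hidden_size": 4096,
        "vocab_size": 128256,
    },
    truncated_vocab_size=32000,  # < vocab_size enables the token_map path
)

# The wrapped base config is kept under config.model, and its attributes are
# also mirrored onto the EAGLE config itself for downstream code.
print(config.model_type)            # "eagle"
print(config.model.hidden_size)     # 4096
print(config.truncated_vocab_size)  # 32000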
@@ -1,5 +1,6 @@
import dataclasses
import gc
+import inspect
import itertools
import time
import warnings
@@ -1192,6 +1193,18 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
        max_batch_size = max(_BATCH_SIZES_TO_CAPTURE)
        input_tokens = torch.zeros(max_batch_size, dtype=torch.long).cuda()
        input_positions = torch.zeros(max_batch_size, dtype=torch.long).cuda()
+
+        # Prepare dummy previous_hidden_states only if needed by the model.
+        # This is used by draft models such as EAGLE.
+        previous_hidden_states = None
+        if "previous_hidden_states" in inspect.signature(
+                self.model.forward).parameters:
+            previous_hidden_states = torch.empty(
+                [max_batch_size,
+                 self.model_config.get_hidden_size()],
+                dtype=self.model_config.dtype,
+                device=self.device)
+
        intermediate_inputs = None
        if not get_pp_group().is_first_rank:
            intermediate_inputs = self.model.make_empty_intermediate_tensors(
@@ -1264,6 +1277,11 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                    "stream":
                    graph_capture_context.stream
                }
+                if previous_hidden_states is not None:
+                    capture_inputs[
+                        "previous_hidden_states"] = previous_hidden_states[:
+                                                                           batch_size]
+
                if self.has_seqlen_agnostic:
                    # Only used by Mamba-based models CUDA graph atm (Jamba)
                    capture_inputs.update({
@@ -1462,6 +1480,7 @@ class ModelRunner(GPUModelRunnerBase[ModelInputForGPUWithSamplingMetadata]):
            if model_input.is_prompt:
                hidden_states = hidden_or_intermediate_states.index_select(
                    0, indices)
+                output.prefill_hidden_states = hidden_or_intermediate_states
            elif decode_meta.use_cuda_graph:
                hidden_states = hidden_or_intermediate_states[:len(indices)]
            else:
@ -1510,11 +1529,11 @@ class CUDAGraphRunner:
|
||||||
# Note one iteration is not enough for torch.jit.script
|
# Note one iteration is not enough for torch.jit.script
|
||||||
for _ in range(_NUM_WARMUP_ITERS):
|
for _ in range(_NUM_WARMUP_ITERS):
|
||||||
self.model(
|
self.model(
|
||||||
input_ids,
|
input_ids=input_ids,
|
||||||
positions,
|
positions=positions,
|
||||||
kv_caches,
|
kv_caches=kv_caches,
|
||||||
attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
intermediate_inputs,
|
intermediate_tensors=intermediate_inputs,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
torch.cuda.synchronize()
|
torch.cuda.synchronize()
|
||||||
|
@ -1523,11 +1542,11 @@ class CUDAGraphRunner:
|
||||||
self._graph = torch.cuda.CUDAGraph()
|
self._graph = torch.cuda.CUDAGraph()
|
||||||
with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
|
with torch.cuda.graph(self._graph, pool=memory_pool, stream=stream):
|
||||||
output_hidden_or_intermediate_states = self.model(
|
output_hidden_or_intermediate_states = self.model(
|
||||||
input_ids,
|
input_ids=input_ids,
|
||||||
positions,
|
positions=positions,
|
||||||
kv_caches,
|
kv_caches=kv_caches,
|
||||||
attn_metadata,
|
attn_metadata=attn_metadata,
|
||||||
intermediate_inputs,
|
intermediate_tensors=intermediate_inputs,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
if hidden_or_intermediate_states is not None:
|
if hidden_or_intermediate_states is not None:
|
||||||
|
@ -1588,6 +1607,11 @@ class CUDAGraphRunner:
|
||||||
if "seqlen_agnostic_capture_inputs" in self.input_buffers:
|
if "seqlen_agnostic_capture_inputs" in self.input_buffers:
|
||||||
self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
|
self.model.copy_inputs_before_cuda_graphs(self.input_buffers,
|
||||||
**kwargs)
|
**kwargs)
|
||||||
|
|
||||||
|
if "previous_hidden_states" in self.input_buffers:
|
||||||
|
self.input_buffers["previous_hidden_states"].copy_(
|
||||||
|
kwargs["previous_hidden_states"], non_blocking=True)
|
||||||
|
|
||||||
if intermediate_tensors is not None:
|
if intermediate_tensors is not None:
|
||||||
for key in intermediate_tensors.tensors:
|
for key in intermediate_tensors.tensors:
|
||||||
if key != "model_execute_time" and key != "model_forward_time":
|
if key != "model_execute_time" and key != "model_forward_time":
|
||||||
|
|
|
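The last hunk above follows the usual CUDA-graph replay pattern: the captured graph always reads from the same tensor, so fresh hidden states must be copied into that static buffer rather than rebound. A minimal illustration of the pattern; shapes are arbitrary and plain CPU tensors stand in for graph-captured buffers.

# Illustration of the static-buffer copy used by CUDAGraphRunner.forward;
# names and shapes are illustrative only.
import torch

input_buffers = {"previous_hidden_states": torch.zeros(8, 16)}

# A new step produces new hidden states ...
fresh_hidden_states = torch.randn(8, 16)

# ... which are copied into the buffer the graph was captured against,
# instead of replacing the dict entry with a different tensor.
input_buffers["previous_hidden_states"].copy_(fresh_hidden_states,
                                              non_blocking=True)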
@ -1,5 +1,7 @@
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple
+
+import torch

 from vllm.distributed import broadcast_tensor_dict, get_pp_group
 from vllm.sequence import ExecuteModelRequest, SamplerOutput

@ -43,7 +45,7 @@ class MultiStepWorker(Worker):

     def _get_driver_input_and_broadcast(
         self, execute_model_req: ExecuteModelRequest
-    ) -> Tuple[BroadcastableModelInput, WorkerInput]:
+    ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
         """
         Get the driver input and broadcast it to other workers.
         """

@ -85,7 +87,9 @@ class MultiStepWorker(Worker):
         broadcast_data.update(model_input.as_broadcastable_tensor_dict())
         broadcast_tensor_dict(broadcast_data, src=0)

-        return model_input, worker_input
+        # Returning an empty dict here keeps this compatible with
+        # `LocalOrDistributedWorkerBase._get_driver_input_and_broadcast`
+        return model_input, worker_input, {}

     def _prepare_last_sampled_token_ids_for_tp_workers(
         self,

@ -130,7 +134,8 @@ class MultiStepWorker(Worker):
     def prepare_input(
         self,
         execute_model_req: Optional[ExecuteModelRequest] = None,
-    ) -> Optional[Tuple[StatefulModelInput, WorkerInput]]:
+    ) -> Optional[Tuple[StatefulModelInput, WorkerInput, Dict[str,
+                                                              torch.Tensor]]]:
         """
         Depending on the current state of the request and multi step worker,
         this method may skip the normal _prepare_model_input and

@ -148,8 +153,8 @@ class MultiStepWorker(Worker):
                 return None

             virtual_engine = execute_model_req.virtual_engine
-            model_input, worker_input = self._get_driver_input_and_broadcast(
-                execute_model_req)
+            (model_input, worker_input,
+             kwargs) = self._get_driver_input_and_broadcast(execute_model_req)
             assert isinstance(model_input, StatefulModelInput)
             if execute_model_req.is_first_multi_step:
                 # cache the worker input and model input for the next steps

@ -162,7 +167,7 @@ class MultiStepWorker(Worker):
             # loop
             if broadcast_data is None:
                 return None
-            model_input, worker_input = broadcast_data
+            model_input, worker_input, kwargs = broadcast_data
             assert isinstance(model_input, StatefulModelInput)
             virtual_engine = worker_input.virtual_engine
             if model_input.is_first_multi_step:

@ -186,4 +191,4 @@ class MultiStepWorker(Worker):

         assert model_input is not None
         assert worker_input is not None
-        return model_input, worker_input
+        return model_input, worker_input, kwargs
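The multi-step worker returns an empty dict purely to keep the return arity uniform: every prepare_input override now yields (model_input, worker_input, kwargs), so callers can unpack three values and forward the extras unconditionally. A tiny sketch of that contract with invented stand-in classes.

# Invented stand-ins; only the three-element return contract is the point.
from typing import Any, Dict, Tuple

import torch


class DriverWorker:

    def prepare_input(self) -> Tuple[Any, Any, Dict[str, torch.Tensor]]:
        kwargs = {"previous_hidden_states": torch.zeros(2, 16)}
        return "model_input", "worker_input", kwargs


class MultiStepDriverWorker(DriverWorker):

    def prepare_input(self) -> Tuple[Any, Any, Dict[str, torch.Tensor]]:
        # Nothing extra to forward, but keep the same return shape.
        return "model_input", "worker_input", {}


for worker in (DriverWorker(), MultiStepDriverWorker()):
    model_input, worker_input, kwargs = worker.prepare_input()
    # kwargs can always be splatted into execute_model(..., **kwargs).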
@ -86,7 +86,7 @@ class Worker(LocalOrDistributedWorkerBase):
             or (speculative_config.draft_model_config.model ==
                 model_config.model) \
             or (speculative_config.draft_model_config.hf_config.model_type
-                not in ["medusa", "mlp_speculator"]) \
+                not in ["medusa", "mlp_speculator", "eagle"]) \
             else {"return_hidden_states": True}

         ModelRunnerClass: Type[GPUModelRunnerBase] = ModelRunner
@ -222,7 +222,9 @@ class LocalOrDistributedWorkerBase(WorkerBase):
         raise NotImplementedError

     def _get_worker_input_from_broadcast(
-        self) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]:
+        self
+    ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
+            str, torch.Tensor]]]:
         """ Get the worker input from the broadcasted tensor dict. """
         assert self.do_metadata_broadcast
         assert not self.is_driver_worker

@ -235,11 +237,13 @@ class LocalOrDistributedWorkerBase(WorkerBase):
             self.model_runner.make_model_input_from_broadcasted_tensor_dict(
                 broadcast_data))

-        return model_input, worker_input
+        kwargs = extract_previous_hidden_states(broadcast_data)
+
+        return model_input, worker_input, kwargs

     def _get_driver_input_and_broadcast(
         self, execute_model_req: ExecuteModelRequest
-    ) -> Tuple[BroadcastableModelInput, WorkerInput]:
+    ) -> Tuple[BroadcastableModelInput, WorkerInput, Dict[str, torch.Tensor]]:
         """ Get the driver input and broadcast it to other workers. """
         assert self.is_driver_worker

@ -251,17 +255,21 @@ class LocalOrDistributedWorkerBase(WorkerBase):
                 execute_model_req.virtual_engine,
                 execute_model_req.finished_requests_ids))

+        kwargs = extract_previous_hidden_states(execute_model_req)
+
         if self.do_metadata_broadcast:
             broadcast_data = worker_input.as_broadcastable_tensor_dict()
             broadcast_data.update(model_input.as_broadcastable_tensor_dict())
+            broadcast_data.update(kwargs)
             broadcast_tensor_dict(broadcast_data, src=0)

-        return model_input, worker_input
+        return model_input, worker_input, kwargs

     def prepare_input(
         self,
         execute_model_req: Optional[ExecuteModelRequest] = None
-    ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput]]:
+    ) -> Optional[Tuple[BroadcastableModelInput, WorkerInput, Dict[
+            str, torch.Tensor]]]:
         """
         Prepare the inputs to ModelRunner and workers.
         """

@ -291,7 +299,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
         if inputs is None:
             return None

-        model_input, worker_input = inputs
+        model_input, worker_input, kwargs = inputs
         num_steps = worker_input.num_steps

         self.execute_worker(worker_input)

@ -312,9 +320,14 @@ class LocalOrDistributedWorkerBase(WorkerBase):
                     "model_execute_time", torch.tensor(0)).item()

         output = self.model_runner.execute_model(
-            model_input, self.kv_cache[worker_input.virtual_engine]
-            if self.kv_cache is not None else None, intermediate_tensors,
-            num_steps)
+            model_input=model_input,
+            kv_caches=self.kv_cache[worker_input.virtual_engine]
+            if self.kv_cache is not None else None,
+            intermediate_tensors=intermediate_tensors,
+            num_steps=num_steps,
+            **kwargs,
+        )

         model_execute_time = time.perf_counter() - start_time
         if not get_pp_group().is_last_rank:
             # output is IntermediateTensors
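Switching the execute_model call above from positional to keyword arguments is what lets optional extras, here previous_hidden_states for EAGLE, ride along via **kwargs without touching every call site. A toy sketch of that pattern; ToyModelRunner and its parameter defaults are invented for illustration, not vLLM's real signature.

# Toy sketch of forwarding optional extras via **kwargs.
from typing import Optional

import torch


class ToyModelRunner:

    def execute_model(self,
                      model_input,
                      kv_caches,
                      intermediate_tensors=None,
                      num_steps: int = 1,
                      previous_hidden_states: Optional[torch.Tensor] = None):
        return previous_hidden_states is not None


runner = ToyModelRunner()
kwargs = {"previous_hidden_states": torch.zeros(4, 16)}  # or {} for non-EAGLE
print(runner.execute_model(model_input=None, kv_caches=None, **kwargs))  # True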
@ -360,9 +373,15 @@ class LocalOrDistributedWorkerBase(WorkerBase):
         if worker_input.num_seq_groups == 0:
             return []

+        kwargs = extract_previous_hidden_states(execute_model_req)
+
         return self.model_runner.execute_model(
-            model_input, self.kv_cache[worker_input.virtual_engine]
-            if self.kv_cache is not None else None, intermediate_tensors)
+            model_input=model_input,
+            kv_caches=self.kv_cache[worker_input.virtual_engine]
+            if self.kv_cache is not None else None,
+            intermediate_tensors=intermediate_tensors,
+            **kwargs,
+        )


 class WorkerWrapperBase:

@ -439,3 +458,23 @@ class WorkerWrapperBase:
                     "This might cause deadlock in distributed execution.")
             logger.exception(msg)
             raise e
+
+
+def extract_previous_hidden_states(
+        data: Union[ExecuteModelRequest, Dict[str, torch.Tensor]]) -> \
+            Dict[str, torch.Tensor]:
+    """If data contains previous_hidden_states, extract it. This returns a dict
+    which can be used directly as additional kwargs in any following
+    execute_model calls. This is used in draft models like EAGLE."""
+    output = {}
+
+    # When called from a non-driver worker, data is a dict; when called from
+    # the driver worker, data is an ExecuteModelRequest.
+    if isinstance(data, dict):
+        if "previous_hidden_states" in data:
+            output["previous_hidden_states"] = data["previous_hidden_states"]
+    elif data.previous_hidden_states is not None:
+        output["previous_hidden_states"] = data.previous_hidden_states\
+            .hidden_states
+
+    return output
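A quick sanity check of the helper's dict branch, the path taken by non-driver workers after the metadata broadcast. The import path is an assumption based on the surrounding WorkerWrapperBase context, and the tensor shape is arbitrary.

# Assumes the helper lands in vllm/worker/worker_base.py; shape is arbitrary.
import torch

from vllm.worker.worker_base import extract_previous_hidden_states

broadcast_data = {
    "previous_hidden_states": torch.zeros(4, 16),
    "num_seq_groups": 4,  # unrelated keys are ignored by the helper
}
kwargs = extract_previous_hidden_states(broadcast_data)
assert list(kwargs) == ["previous_hidden_states"]

# The result is splatted straight into the next execute_model call:
# self.model_runner.execute_model(model_input=..., **kwargs)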