mirror of https://github.com/vllm-project/vllm
wip
Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
parent 615921deff
commit 68543e17aa
@@ -16,6 +16,11 @@ logger = init_logger(__name__)

class BasicScheduler(SchedulerInterface):
    """
    Mixed prefill and decode: X
    Chunked prefill: X
    Prefix caching: O
    (X = not supported, O = supported)
    """

    def __init__(
        self,
@@ -65,13 +70,100 @@ class BasicScheduler(SchedulerInterface):

        num_scheduled_tokens: Dict[str, int] = {}
        token_budget = self.max_num_scheduled_tokens
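        # (Illustration: with max_num_scheduled_tokens = 2048, one step can
        #  prefill a single 2048-token prompt, or e.g. a 1500- and a
        #  500-token prompt, but a third prompt would wait for a later step.)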

        # Schedule prefill requests.
        while self.waiting:
            if len(self.running) == self.max_num_running_reqs:
                break
            if token_budget == 0:
                break

            request = self.waiting[0]
            # Get already-cached tokens.
            computed_blocks, num_computed_tokens = \
                self.kv_cache_manager.get_computed_blocks(request)
            # Number of tokens to be scheduled.
            num_new_tokens = request.num_tokens - num_computed_tokens
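            # (Illustration: a 100-token prompt whose first 64 tokens hit
            #  the prefix cache leaves num_new_tokens = 36 to prefill.)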
            if num_new_tokens == 0:
                # This happens when the prompt length is divisible by the
                # block size and all of its blocks are cached. In that case
                # we force recomputation of the last block. Note that an
                # entire block must be recomputed because allocate_slots()
                # assumes num_computed_tokens is always a multiple of the
                # block size. This limitation could be lifted in the future
                # to slightly improve performance.
                num_computed_tokens -= self.block_size
                num_new_tokens = self.block_size
                computed_blocks.pop()
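                # (Illustration, with block_size = 16: a fully cached
                #  32-token prompt yields num_computed_tokens = 32 and
                #  num_new_tokens = 0; after the adjustment we have
                #  num_computed_tokens = 16 and num_new_tokens = 16, so the
                #  last block is recomputed and the model can sample the
                #  first output token.)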
            # NOTE: This scheduler does not support chunked prefills.
            if num_new_tokens > token_budget:
                # The request cannot be scheduled.
                break
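            # (Without chunking, a 900-token prompt facing a remaining
            #  token_budget of 512 is not split across steps; it waits
            #  until a step has enough budget for all 900 tokens.)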

            new_blocks = self.kv_cache_manager.allocate_slots(
                request, num_new_tokens, computed_blocks)
            if new_blocks is None:
                # The request cannot be scheduled.
                break

            self.waiting.popleft()
            self.running.append(request)
            if request.status == RequestStatus.WAITING:
                scheduled_new_reqs.append(request)
            elif request.status == RequestStatus.PREEMPTED:
                scheduled_resumed_reqs.append(request)
            else:
                raise RuntimeError(
                    f"Invalid request status: {request.status}")
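            # (WAITING marks a first-time schedule; PREEMPTED marks a request
            #  that was evicted earlier and now reruns its prefill, reusing
            #  whatever blocks get_computed_blocks() still finds cached.)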

            req_to_new_block_ids[request.request_id] = [
                b.block_id for b in computed_blocks + new_blocks
            ]
            num_scheduled_tokens[request.request_id] = num_new_tokens
            token_budget -= num_new_tokens
            request.status = RequestStatus.RUNNING
            request.num_computed_tokens = num_computed_tokens
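        # (The waiting queue is served strictly FCFS: once the head request
        #  cannot fit the budget, the KV cache, or the running-request cap,
        #  no later request is considered, at the cost of head-of-line
        #  blocking.)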

        # If no prefill requests are scheduled, schedule decode requests.
        if not (scheduled_new_reqs or scheduled_resumed_reqs):
            req_index = 0
            while req_index < len(self.running):
                request = self.running[req_index]
                while True:
                    new_blocks = self.kv_cache_manager.append_slots(
                        request, num_tokens=1)
                    if new_blocks is None:
                        # The request cannot be scheduled; preempt the most
                        # recently admitted request to free its blocks.
                        preempted_req = self.running.pop()
                        self.kv_cache_manager.free(preempted_req)
                        preempted_req.status = RequestStatus.PREEMPTED
                        preempted_req.num_computed_tokens = 0

                        self.waiting.appendleft(preempted_req)
                        preempted_reqs.append(preempted_req)
                        if preempted_req == request:
                            # No more requests to preempt.
                            can_schedule = False
                            break
                    else:
                        can_schedule = True
                        break
                if not can_schedule:
                    break
                assert new_blocks is not None

                # Schedule the request.
                scheduled_running_reqs.append(request)
                req_to_new_block_ids[request.request_id] = [
                    b.block_id for b in new_blocks
                ]
                num_scheduled_tokens[request.request_id] = 1
                token_budget -= 1
                req_index += 1
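            # (Preemption evicts from the back of self.running, i.e. the
            #  most recently admitted requests, and requeues victims at the
            #  front of self.waiting so they are resumed first. Each decode
            #  step grants exactly one token per surviving running request.)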

        # Check if the scheduling constraints are satisfied.
        total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
        assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
        assert token_budget >= 0
        assert len(self.running) <= self.max_num_running_reqs
        assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
                len(scheduled_running_reqs) == len(self.running))

        # Construct the scheduler output.
        new_reqs_data = self.common_states.make_new_req_data(
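Taken together, the hunk implements two simple policies: prefills are admitted FCFS under a per-step token budget, and decode steps grant one token per running request, preempting from the back under memory pressure. Below is a minimal, self-contained sketch of the same loop against toy bookkeeping; it is an illustration, not vLLM's API: every name (ToyScheduler, ToyRequest, free_blocks, and so on) is hypothetical, prefix caching is omitted, and block accounting is a plain counter rather than a KVCacheManager.

from collections import deque
from dataclasses import dataclass


@dataclass
class ToyRequest:
    req_id: str
    num_prompt_tokens: int
    num_computed_tokens: int = 0


class ToyScheduler:
    """Toy model of the BasicScheduler policy (hypothetical, simplified)."""

    def __init__(self, max_tokens_per_step: int, max_running: int,
                 num_blocks: int, block_size: int) -> None:
        self.max_tokens_per_step = max_tokens_per_step
        self.max_running = max_running
        self.block_size = block_size
        self.free_blocks = num_blocks
        self.blocks_of: dict[str, int] = {}  # blocks held per request
        self.waiting: deque = deque()
        self.running: list = []

    def _blocks_needed(self, req: ToyRequest, new_tokens: int) -> int:
        total = req.num_computed_tokens + new_tokens
        need = -(-total // self.block_size)  # ceil(total / block_size)
        return need - self.blocks_of.get(req.req_id, 0)

    def step(self) -> dict:
        scheduled: dict[str, int] = {}
        budget = self.max_tokens_per_step

        # Prefill: strictly FCFS; stop at the first request that does not fit.
        while self.waiting and len(self.running) < self.max_running:
            req = self.waiting[0]
            new_tokens = req.num_prompt_tokens - req.num_computed_tokens
            need = self._blocks_needed(req, new_tokens)
            if new_tokens > budget or need > self.free_blocks:
                break  # no chunking: the whole prompt must fit this step
            self.waiting.popleft()
            self.free_blocks -= need
            self.blocks_of[req.req_id] = self.blocks_of.get(req.req_id, 0) + need
            self.running.append(req)
            req.num_computed_tokens += new_tokens
            scheduled[req.req_id] = new_tokens
            budget -= new_tokens
        if scheduled:
            return scheduled  # prefill and decode are never mixed

        # Decode: one token per running request, preempting from the back
        # of self.running when the KV cache cannot grow by one slot.
        i = 0
        while i < len(self.running):
            req = self.running[i]
            while self._blocks_needed(req, 1) > self.free_blocks:
                victim = self.running.pop()
                self.free_blocks += self.blocks_of.pop(victim.req_id)
                victim.num_computed_tokens = 0
                self.waiting.appendleft(victim)  # victims resume first
                if victim is req:
                    return scheduled  # preempted ourselves: nothing left
            need = self._blocks_needed(req, 1)
            self.free_blocks -= need
            self.blocks_of[req.req_id] += need
            req.num_computed_tokens += 1
            scheduled[req.req_id] = 1
            i += 1
        return scheduled


# Example: a 48-token and an 8-token prompt, 50-token budget, 4 blocks of 16.
sched = ToyScheduler(max_tokens_per_step=50, max_running=2,
                     num_blocks=4, block_size=16)
sched.waiting.extend([ToyRequest("a", 48), ToyRequest("b", 8)])
print(sched.step())  # {'a': 48}: "b" exceeds the remaining 2-token budget
print(sched.step())  # {'b': 8}: a prefill-only step for "b"
print(sched.step())  # {'a': 1}: "a" needs a 4th block, so "b" is preempted

The three printed steps show both policies in sequence: the budget defers "b" in step one, and in step three "b" (the most recently admitted request) is preempted so that "a" can grow by one block.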
@@ -202,7 +202,7 @@ class Scheduler(SchedulerInterface):
             # which have output tokens.
             num_new_tokens = request.num_tokens - num_computed_tokens
             if num_new_tokens == 0:
-                # The happens when prompt length is divisible by the block
+                # This happens when prompt length is divisible by the block
                 # size and all blocks are cached. Now we force to recompute
                 # the last block. Note that we have to re-compute an entire
                 # block because allocate_slots() assumes num_computed_tokens