Signed-off-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
This commit is contained in:
Woosuk Kwon 2025-01-29 02:04:07 -08:00
parent 615921deff
commit 68543e17aa
2 changed files with 95 additions and 3 deletions

View File

@ -16,6 +16,11 @@ logger = init_logger(__name__)
class BasicScheduler(SchedulerInterface):
"""
Mixed prefill and decode: X
Chunked prefill: X
Prefix caching: O
"""
def __init__(
self,
@ -65,13 +70,100 @@ class BasicScheduler(SchedulerInterface):
num_scheduled_tokens: Dict[str, int] = {}
token_budget = self.max_num_scheduled_tokens
# Schedule prefill requests.
while self.waiting:
if len(self.running) == self.max_num_running_reqs:
break
if token_budget == 0:
break
request = self.waiting[0]
# Get already-cached tokens.
computed_blocks, num_computed_tokens = \
self.kv_cache_manager.get_computed_blocks(request)
# Number of tokens to be scheduled.
num_new_tokens = request.num_tokens - num_computed_tokens
if num_new_tokens == 0:
# This happens when prompt length is divisible by the block
# size and all blocks are cached. Now we force to recompute
# the last block. Note that we have to re-compute an entire
# block because allocate_slots() assumes num_computed_tokens
# is always a multiple of the block size. This limitation
# can potentially be removed in the future to slightly
# improve the performance.
num_computed_tokens -= self.block_size
num_new_tokens = self.block_size
computed_blocks.pop()
# NOTE: This scheduler does not support chunked prefills.
if num_new_tokens > token_budget:
# The request cannot be scheduled.
break
new_blocks = self.kv_cache_manager.allocate_slots(
request, num_new_tokens, computed_blocks)
if new_blocks is None:
# The request cannot be scheduled.
break
self.waiting.popleft()
self.running.append(request)
if request.status == RequestStatus.WAITING:
scheduled_new_reqs.append(request)
elif request.status == RequestStatus.PREEMPTED:
scheduled_resumed_reqs.append(request)
else:
raise RuntimeError(f"Invalid request status: {request.status}")
req_to_new_block_ids[request.request_id] = [
b.block_id for b in computed_blocks + new_blocks
]
num_scheduled_tokens[request.request_id] = num_new_tokens
token_budget -= num_new_tokens
request.status = RequestStatus.RUNNING
request.num_computed_tokens = num_computed_tokens
# If no prefill requests are scheduled, schedule decode requests.
if not (scheduled_new_reqs or scheduled_resumed_reqs):
req_index = 0
while req_index < len(self.running):
request = self.running[req_index]
while True:
new_blocks = self.kv_cache_manager.append_slots(
request, num_tokens=1)
if new_blocks is None:
# The request cannot be scheduled.
preempted_req = self.running.pop()
self.kv_cache_manager.free(preempted_req)
preempted_req.status = RequestStatus.PREEMPTED
preempted_req.num_computed_tokens = 0
self.waiting.appendleft(preempted_req)
preempted_reqs.append(preempted_req)
if preempted_req == request:
# No more request to preempt.
can_schedule = False
break
else:
can_schedule = True
break
if not can_schedule:
break
assert new_blocks is not None
# Schedule the request.
scheduled_running_reqs.append(request)
req_to_new_block_ids[request.request_id] = [
b.block_id for b in new_blocks
]
num_scheduled_tokens[request.request_id] = 1
token_budget -= 1
req_index += 1
# Check if the scheduling constraints are satisfied.
total_num_scheduled_tokens = sum(num_scheduled_tokens.values())
assert total_num_scheduled_tokens <= self.max_num_scheduled_tokens
assert token_budget >= 0
assert len(self.running) <= self.max_num_running_reqs
assert (len(scheduled_new_reqs) + len(scheduled_resumed_reqs) +
len(scheduled_running_reqs) == len(self.running))
# Construct the scheduler output.
new_reqs_data = self.common_states.make_new_req_data(

View File

@ -202,7 +202,7 @@ class Scheduler(SchedulerInterface):
# which have output tokens.
num_new_tokens = request.num_tokens - num_computed_tokens
if num_new_tokens == 0:
# The happens when prompt length is divisible by the block
# This happens when prompt length is divisible by the block
# size and all blocks are cached. Now we force to recompute
# the last block. Note that we have to re-compute an entire
# block because allocate_slots() assumes num_computed_tokens