Changed scheduler to use deques instead of lists (#2290)

Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Nadav Shmayovits 2024-01-07 19:48:07 +02:00 committed by GitHub
parent d0215a58e7
commit 05921a9a7a
3 changed files with 28 additions and 24 deletions
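
The motivation for the change: the scheduler's waiting, running, and swapped queues are consumed from the front, and list.pop(0) must shift every remaining element, while collections.deque gives O(1) popleft() and appendleft(). A minimal standalone sketch (not part of this commit, timings are machine-dependent) illustrating the difference:

    from collections import deque
    import time

    n = 100_000
    lst, dq = list(range(n)), deque(range(n))

    start = time.perf_counter()
    while lst:
        lst.pop(0)           # O(n) per pop -> O(n^2) to drain the whole queue
    print(f"list.pop(0):     {time.perf_counter() - start:.3f}s")

    start = time.perf_counter()
    while dq:
        dq.popleft()         # O(1) per pop -> O(n) to drain the whole queue
    print(f"deque.popleft(): {time.perf_counter() - start:.3f}s")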

vllm/core/policy.py

@@ -1,4 +1,5 @@
-from typing import List
+from collections import deque
+from typing import Deque
 
 from vllm.sequence import SequenceGroup
@@ -15,13 +16,14 @@ class Policy:
     def sort_by_priority(
         self,
         now: float,
-        seq_groups: List[SequenceGroup],
-    ) -> List[SequenceGroup]:
-        return sorted(
-            seq_groups,
-            key=lambda seq_group: self.get_priority(now, seq_group),
-            reverse=True,
-        )
+        seq_groups: Deque[SequenceGroup],
+    ) -> Deque[SequenceGroup]:
+        return deque(
+            sorted(
+                seq_groups,
+                key=lambda seq_group: self.get_priority(now, seq_group),
+                reverse=True,
+            ))
 
 
 class FCFS(Policy):
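
(Not part of the diff.) The built-in sorted() always returns a list regardless of the input container, so sort_by_priority() wraps the result back in a deque to keep the queue type the scheduler now expects. A minimal sketch of that pattern, using plain integers as stand-ins for sequence groups:

    from collections import deque

    q = deque([3, 1, 2])
    assert isinstance(sorted(q), list)    # sorted() always produces a list
    q = deque(sorted(q, reverse=True))    # re-wrap so callers still get a deque
    assert isinstance(q, deque) and list(q) == [3, 2, 1]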

vllm/core/scheduler.py

@@ -1,6 +1,7 @@
+from collections import deque
 import enum
 import time
-from typing import Dict, Iterable, List, Optional, Tuple, Union
+from typing import Deque, Dict, Iterable, List, Optional, Tuple, Union
 
 from vllm.config import CacheConfig, SchedulerConfig
 from vllm.core.block_manager import AllocStatus, BlockSpaceManager
@@ -29,7 +30,7 @@ class SchedulerOutputs:
     def __init__(
         self,
-        scheduled_seq_groups: List[SequenceGroup],
+        scheduled_seq_groups: Iterable[SequenceGroup],
         prompt_run: bool,
         num_batched_tokens: int,
         blocks_to_swap_in: Dict[int, int],
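
(Not part of the diff.) scheduled_seq_groups is loosened from List to Iterable so SchedulerOutputs accepts whichever container the scheduler hands over; both a list and a deque satisfy the annotation. A minimal sketch with a hypothetical helper and placeholder strings:

    from collections import deque
    from typing import Iterable

    def count_groups(groups: Iterable[str]) -> int:
        # Any iterable works; we never index or concatenate it.
        return sum(1 for _ in groups)

    assert count_groups(["a", "b"]) == 2           # plain list
    assert count_groups(deque(["a", "b"])) == 2    # deque from the scheduler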
@@ -75,13 +76,12 @@ class Scheduler:
             num_cpu_blocks=self.cache_config.num_cpu_blocks,
             sliding_window=self.cache_config.sliding_window)
 
-        # TODO(zhuohan): Use deque instead of list for better performance.
         # Sequence groups in the WAITING state.
-        self.waiting: List[SequenceGroup] = []
+        self.waiting: Deque[SequenceGroup] = deque()
         # Sequence groups in the RUNNING state.
-        self.running: List[SequenceGroup] = []
+        self.running: Deque[SequenceGroup] = deque()
         # Sequence groups in the SWAPPED state.
-        self.swapped: List[SequenceGroup] = []
+        self.swapped: Deque[SequenceGroup] = deque()
 
     def add_seq_group(self, seq_group: SequenceGroup) -> None:
         # Add sequence groups to the waiting queue.
@@ -152,7 +152,7 @@
                     for seq in waiting_seqs:
                         seq.status = SequenceStatus.FINISHED_IGNORED
                     ignored_seq_groups.append(seq_group)
-                    self.waiting.pop(0)
+                    self.waiting.popleft()
                     continue
 
                 # If the sequence group cannot be allocated, stop.
@@ -166,7 +166,7 @@
                     for seq in waiting_seqs:
                         seq.status = SequenceStatus.FINISHED_IGNORED
                     ignored_seq_groups.append(seq_group)
-                    self.waiting.pop(0)
+                    self.waiting.popleft()
                     continue
 
                 # If the number of batched tokens exceeds the limit, stop.
@@ -188,7 +188,7 @@
                     break
                 seq_lens = new_seq_lens
 
-                seq_group = self.waiting.pop(0)
+                seq_group = self.waiting.popleft()
                 self._allocate(seq_group)
                 self.running.append(seq_group)
                 num_curr_seqs += num_new_seqs
@@ -214,14 +214,14 @@
         self.running = self.policy.sort_by_priority(now, self.running)
 
         # Reserve new token slots for the running sequence groups.
-        running: List[SequenceGroup] = []
+        running: Deque[SequenceGroup] = deque()
         preempted: List[SequenceGroup] = []
         while self.running:
-            seq_group = self.running.pop(0)
+            seq_group = self.running.popleft()
             while not self.block_manager.can_append_slot(seq_group):
                 if self.running:
                     # Preempt the lowest-priority sequence groups.
-                    victim_seq_group = self.running.pop(-1)
+                    victim_seq_group = self.running.pop()
                     self._preempt(victim_seq_group, blocks_to_swap_out)
                     preempted.append(victim_seq_group)
                 else:
@@ -255,7 +255,7 @@
                         self.scheduler_config.max_num_seqs):
                     break
 
-                seq_group = self.swapped.pop(0)
+                seq_group = self.swapped.popleft()
                 self._swap_in(seq_group, blocks_to_swap_in)
                 self._append_slot(seq_group, blocks_to_copy)
                 num_curr_seqs += num_new_seqs
@@ -376,7 +376,7 @@
             self.block_manager.free(seq)
         # NOTE: For FCFS, we insert the preempted sequence group to the front
        # of the waiting queue.
-        self.waiting.insert(0, seq_group)
+        self.waiting.appendleft(seq_group)
 
     def _preempt_by_swap(
         self,
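
(Not part of the diff.) A deque covers every access pattern the scheduler uses: popleft() takes the highest-priority group from the front, pop() takes the lowest-priority victim from the back when preemption is needed, and appendleft() puts a preempted group back at the front of the waiting queue so FCFS order is preserved. A minimal sketch with placeholder strings instead of SequenceGroup objects:

    from collections import deque

    running = deque(["high", "mid", "low"])   # sorted highest priority first

    seq_group = running.popleft()             # schedule the highest-priority group, O(1)
    victim = running.pop()                    # preempt the lowest-priority group, O(1)

    waiting = deque(["older_request"])
    waiting.appendleft(victim)                # preempted group rejoins the FRONT of waiting
    assert list(waiting) == ["low", "older_request"]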

vllm/engine/llm_engine.py

@@ -601,8 +601,10 @@
         # Create the outputs.
         request_outputs: List[RequestOutput] = []
-        for seq_group in (scheduled_seq_groups +
-                          scheduler_outputs.ignored_seq_groups):
+        for seq_group in scheduled_seq_groups:
+            request_output = RequestOutput.from_seq_group(seq_group)
+            request_outputs.append(request_output)
+        for seq_group in scheduler_outputs.ignored_seq_groups:
             request_output = RequestOutput.from_seq_group(seq_group)
             request_outputs.append(request_output)
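
(Not part of the diff.) The single loop over the concatenation of scheduled_seq_groups and scheduler_outputs.ignored_seq_groups is split in two because scheduled_seq_groups can now be a deque (for example the scheduler's running queue), and a deque cannot be concatenated with a list using the + operator. A minimal sketch of the failure mode and the workaround, using placeholder strings:

    from collections import deque

    scheduled = deque(["a", "b"])
    ignored = ["c"]

    try:
        scheduled + ignored               # deque + list raises TypeError
    except TypeError as exc:
        print(exc)

    outputs = []                          # iterating each container separately works
    for group in scheduled:
        outputs.append(group)
    for group in ignored:
        outputs.append(group)
    assert outputs == ["a", "b", "c"]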