[Core] Pipeline parallel with Ray ADAG (#6837)

Support pipeline parallelism with Ray accelerated DAG (aDAG).

Signed-off-by: Rui Qiao <ruisearch42@gmail.com>
Rui Qiao 2024-08-02 13:55:40 -07:00 committed by GitHub
parent a8d604ca2a
commit 05308891e2
10 changed files with 199 additions and 77 deletions
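
Note on usage: a minimal sketch (not part of the diff) of how the new pieces fit together, mirroring the updated pipeline-parallel test below. The CLI argument lists are illustrative; the three environment variables and the compare_two_settings helper (which gains environment support in tests/utils.py in this commit) are the names introduced here.

# Hedged sketch: compare a PP=2/TP=2 server running with Ray aDAG against a
# plain TP=4 server. Assumes the repo's tests/ directory is importable.
from tests.utils import compare_two_settings

pp_args = [
    "--pipeline-parallel-size", "2",
    "--tensor-parallel-size", "2",
    "--distributed-executor-backend", "ray",
]
tp_args = ["--tensor-parallel-size", "4"]

# Environment for the pipeline-parallel server: enable the SPMD worker and
# Ray's compiled DAG; the NCCL flag selects NCCL channels ("1", the default)
# or the "auto" transport ("0") for tensors passed between PP stages.
pp_env = {
    "VLLM_USE_RAY_SPMD_WORKER": "1",
    "VLLM_USE_RAY_COMPILED_DAG": "1",
    "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
}

compare_two_settings("meta-llama/Meta-Llama-3-8B", pp_args, tp_args, pp_env)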

Dockerfile

@@ -42,6 +42,7 @@ WORKDIR /workspace

 # install build and runtime dependencies
 COPY requirements-common.txt requirements-common.txt
+COPY requirements-adag.txt requirements-adag.txt
 COPY requirements-cuda.txt requirements-cuda.txt
 RUN --mount=type=cache,target=/root/.cache/pip \
     python3 -m pip install -r requirements-cuda.txt
@@ -78,6 +79,7 @@ COPY setup.py setup.py
 COPY cmake cmake
 COPY CMakeLists.txt CMakeLists.txt
 COPY requirements-common.txt requirements-common.txt
+COPY requirements-adag.txt requirements-adag.txt
 COPY requirements-cuda.txt requirements-cuda.txt
 COPY pyproject.toml pyproject.toml
 COPY vllm vllm

MANIFEST.in

@@ -1,4 +1,5 @@
 include LICENSE
+include requirements-adag.txt
 include requirements-common.txt
 include requirements-cuda.txt
 include requirements-rocm.txt

requirements-adag.txt (new file)

@@ -0,0 +1,3 @@
+# Dependencies for Ray accelerated DAG
+cupy-cuda12x
+ray >= 2.32

requirements-test.txt

@@ -1,3 +1,6 @@
+# Needed for Ray accelerated DAG tests
+-r requirements-adag.txt
+
 # testing
 pytest
 tensorizer>=2.9.0

tests/distributed/test_pipeline_parallel.py

@@ -15,22 +15,31 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1"

 @pytest.mark.parametrize(
-    "TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME, DIST_BACKEND",
-    [
-        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
-        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
-        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
-        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray"),
-        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"),
-        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
-        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
-        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
-        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp"),
-        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"),
-    ])
+    ("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, "
+     "MODEL_NAME, DIST_BACKEND, USE_RAY_ADAG, USE_RAY_ADAG_NCCL"), [
+        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
+        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
+        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
+        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
+        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", False, False),
+        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
+        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
+        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
+        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
+        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, False),
+        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, True),
+        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True),
+        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True),
+        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "ray", True, True),
+        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray", True, True),
+        (2, 2, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
+        (2, 2, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
+        (1, 3, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
+        (1, 4, 0, 1, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
+        (1, 4, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp", False, False),
+    ])
 @fork_new_process_for_each_test
 def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
-                    DIST_BACKEND):
+                    DIST_BACKEND, USE_RAY_ADAG, USE_RAY_ADAG_NCCL):
     if VLLM_MULTI_NODE and DIST_BACKEND == "mp":
         pytest.skip("Skipping multi-node pipeline parallel test for "
                     "multiprocessing distributed backend")
@@ -67,8 +76,18 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, MODEL_NAME,
     if EAGER_MODE:
         pp_args.append("--enforce-eager")
         tp_args.append("--enforce-eager")

+    pp_env = None
+    if USE_RAY_ADAG:
+        assert DIST_BACKEND == "ray", (
+            "Ray ADAG is only supported with Ray distributed backend")
+        pp_env = {
+            "VLLM_USE_RAY_COMPILED_DAG": "1",
+            "VLLM_USE_RAY_SPMD_WORKER": "1",
+            "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL":
+            str(int(USE_RAY_ADAG_NCCL)),
+        }
-    compare_two_settings(MODEL_NAME, pp_args, tp_args)
+    compare_two_settings(MODEL_NAME, pp_args, tp_args, pp_env)


 @pytest.mark.parametrize("PP_SIZE, MODEL_NAME", [

tests/utils.py

@@ -7,7 +7,7 @@ import time
 import warnings
 from contextlib import contextmanager
 from pathlib import Path
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional

 import openai
 import ray
@@ -57,6 +57,7 @@
         model: str,
         cli_args: List[str],
         *,
+        env_dict: Optional[Dict[str, str]] = None,
         auto_port: bool = True,
     ) -> None:
         if auto_port:
@@ -77,6 +78,8 @@
         # the current process might initialize cuda,
         # to be safe, we should use spawn method
         env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
+        if env_dict is not None:
+            env.update(env_dict)
         self.proc = subprocess.Popen(["vllm", "serve"] + [model] + cli_args,
                                      env=env,
                                      stdout=sys.stdout,
@@ -89,6 +92,11 @@

     def __exit__(self, exc_type, exc_value, traceback):
         self.proc.terminate()
+        try:
+            self.proc.wait(3)
+        except subprocess.TimeoutExpired:
+            # force kill if needed
+            self.proc.kill()

     def _wait_for_server(self, *, url: str, timeout: float):
         # run health check
@@ -127,10 +135,21 @@
     )


-def compare_two_settings(model: str, arg1: List[str], arg2: List[str]):
+def compare_two_settings(model: str,
+                         arg1: List[str],
+                         arg2: List[str],
+                         env1: Optional[Dict[str, str]] = None,
+                         env2: Optional[Dict[str, str]] = None):
     """
-    Launch API server with two different sets of arguments and compare the
-    results of the API calls. The arguments are after the model name.
+    Launch API server with two different sets of arguments/environments
+    and compare the results of the API calls.
+
+    Args:
+        model: The model to test.
+        arg1: The first set of arguments to pass to the API server.
+        arg2: The second set of arguments to pass to the API server.
+        env1: The first set of environment variables to pass to the API server.
+        env2: The second set of environment variables to pass to the API server.
     """

     tokenizer = AutoTokenizer.from_pretrained(model)
@@ -138,8 +157,8 @@ def compare_two_settings(model: str, arg1: List[str], arg2: List[str]):
     prompt = "Hello, my name is"
     token_ids = tokenizer(prompt)["input_ids"]
     results = []
-    for args in (arg1, arg2):
-        with RemoteOpenAIServer(model, args) as server:
+    for args, env in ((arg1, env1), (arg2, env2)):
+        with RemoteOpenAIServer(model, args, env_dict=env) as server:
             client = server.get_client()

             # test models list

vllm/envs.py

@@ -38,6 +38,7 @@ if TYPE_CHECKING:
     VLLM_FUSED_MOE_CHUNK_SIZE: int = 64 * 1024
     VLLM_USE_RAY_SPMD_WORKER: bool = False
     VLLM_USE_RAY_COMPILED_DAG: bool = False
+    VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL: bool = True
     VLLM_WORKER_MULTIPROC_METHOD: str = "fork"
     VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets")
     VLLM_IMAGE_FETCH_TIMEOUT: int = 5
@@ -273,13 +274,20 @@ environment_variables: Dict[str, Callable[[], Any]] = {
     # execution on all workers.
     # Run vLLM with VLLM_USE_RAY_SPMD_WORKER=1 to enable it.
     "VLLM_USE_RAY_SPMD_WORKER":
-    lambda: bool(os.getenv("VLLM_USE_RAY_SPMD_WORKER", 0)),
+    lambda: bool(int(os.getenv("VLLM_USE_RAY_SPMD_WORKER", "0"))),

     # If the env var is set, it uses the Ray's compiled DAG API
     # which optimizes the control plane overhead.
     # Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
     "VLLM_USE_RAY_COMPILED_DAG":
-    lambda: bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)),
+    lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0"))),
+
+    # If the env var is set, it uses NCCL for communication in
+    # Ray's compiled DAG. This flag is ignored if
+    # VLLM_USE_RAY_COMPILED_DAG is not set.
+    "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL":
+    lambda: bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL", "1"))
+                 ),

     # Use dedicated multiprocess context for workers.
     # Both spawn and fork work
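
Aside (not part of the diff): the old bool(os.getenv(..., 0)) form treats any non-empty string, including "0", as truthy, so the flags could not be switched off by setting them to "0". Wrapping with int() first fixes that; a quick standalone illustration in plain Python:

import os

os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "0"

# Old parsing: bool("0") is True, so an explicit "0" still enabled the flag.
print(bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)))          # True
# New parsing: convert to int first, then to bool.
print(bool(int(os.getenv("VLLM_USE_RAY_COMPILED_DAG", "0"))))   # False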

vllm/executor/ray_gpu_executor.py

@@ -105,12 +105,19 @@ class RayGPUExecutor(DistributedGPUExecutor):
         # The remaining workers are the actual ray actors.
         self.workers: List[RayWorkerWrapper] = []

+        # Used in ray compiled DAG: indexed first by PP rank,
+        # and then TP rank. In other words, the inner list is
+        # the TP group of workers for a PP rank.
+        self.pp_tp_workers: List[List[RayWorkerWrapper]] = []
+
         if self.parallel_config.ray_workers_use_nsight:
             ray_remote_kwargs = self._configure_ray_workers_use_nsight(
                 ray_remote_kwargs)

+        logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
         # Create the workers.
         driver_ip = get_ip()
+        logger.info("driver_ip: %s", driver_ip)
         worker_wrapper_kwargs = self._get_worker_wrapper_args()
         for bundle_id, bundle in enumerate(placement_group.bundle_specs):
             if not bundle.get("GPU", 0):
@@ -142,42 +149,49 @@
                 # Else, added to the list of workers.
                 self.workers.append(worker)

+        logger.debug("workers: %s", self.workers)
+        logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
         if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
             raise ValueError(
                 "Ray does not allocate any GPUs on the driver node. Consider "
                 "adjusting the Ray placement group or running the driver on a "
                 "GPU node.")

+        worker_ips = [
+            ray.get(worker.get_node_ip.remote())  # type: ignore[attr-defined]
+            for worker in self.workers
+        ]
+        ip_counts: Dict[str, int] = {}
+        for ip in worker_ips:
+            ip_counts[ip] = ip_counts.get(ip, 0) + 1
+
+        def sort_by_driver_then_worker_ip(worker):
+            """
+            Sort the workers based on 3 properties:
+            1. If the worker is on the same node as the driver (vllm engine),
+                it should be placed first.
+            2. Then, if the worker is on a node with fewer workers, it should
+                be placed first.
+            3. Finally, if the work is on a node with smaller IP address, it
+                should be placed first.
+            """
+            ip = ray.get(worker.get_node_ip.remote())
+            return (ip != driver_ip, ip_counts[ip], ip)
+
+        # After sorting, the workers on the same node will be
+        # close to each other, and the workers on the driver
+        # node will be placed first.
+        self.workers = sorted(self.workers, key=sort_by_driver_then_worker_ip)
+
         # Get the set of GPU IDs used on each node.
         worker_node_and_gpu_ids = self._run_workers("get_node_and_gpu_ids",
                                                     use_dummy_driver=True)

-        # the order in `worker_node_and_gpu_ids` does not necessarily match
-        # the machine boundaries. We need to make sure that workers in the
-        # same node are assigned consecutive ranks.
-        # examples:
-        # [('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [0]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [0]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [1]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [2]), ('dfaad7adfdae57a694cc74490db45bd112c9f31243523e43ddc2e7f0', [3]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [1]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [2]), ('852a09a13c7503ef126d7c828454c741494b1be33a8627a5206604d9', [3])] # noqa
-
-        # initialize worker ranks with -1 (unassigned)
-        worker_ranks = [-1 for x in worker_node_and_gpu_ids]
-        current_rank = 0
-        while -1 in worker_ranks:
-            # whenever we find an unassigned worker, find the node
-            index = worker_ranks.index(-1)
-            current_node_id = worker_node_and_gpu_ids[index][0]
-            # assign ranks to all workers in the same node
-            for i, (node_id, _) in enumerate(worker_node_and_gpu_ids):
-                if node_id == current_node_id:
-                    worker_ranks[i] = current_rank
-                    current_rank += 1
-        # with the above example, worker_ranks will be [0, 4, 5, 6, 7, 1, 2, 3]
-
         node_workers = defaultdict(list)  # node id -> list of worker ranks
         node_gpus = defaultdict(list)  # node id -> list of gpu ids

-        for worker_rank, (node_id, gpu_ids) in zip(worker_ranks,
-                                                   worker_node_and_gpu_ids):
-            node_workers[node_id].append(worker_rank)
+        for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
+            node_workers[node_id].append(i)
             # `gpu_ids` can be a list of strings or integers.
             # convert them to integers for consistency.
             # NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
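
Aside (not part of the diff): the new sort key can be sanity-checked in isolation. A toy sketch with hypothetical IPs, using bare strings in place of Ray worker handles:

driver_ip = "10.0.0.1"
worker_ips = ["10.0.0.2", "10.0.0.1", "10.0.0.2", "10.0.0.3"]

# Per-node worker tally, mirroring ip_counts above.
ip_counts = {}
for ip in worker_ips:
    ip_counts[ip] = ip_counts.get(ip, 0) + 1

def sort_by_driver_then_worker_ip(ip):
    # (is the worker off the driver node?, nodes with fewer workers first,
    #  then by IP) -- the same tuple ordering as the executor's key.
    return (ip != driver_ip, ip_counts[ip], ip)

print(sorted(worker_ips, key=sort_by_driver_then_worker_ip))
# ['10.0.0.1', '10.0.0.3', '10.0.0.2', '10.0.0.2']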
@@ -202,16 +216,6 @@ class RayGPUExecutor(DistributedGPUExecutor):
         self._run_workers("update_environment_variables",
                           all_args=all_args_to_update_environment_variables)

-        if len(node_gpus) == 1:
-            # in single node case, we don't need to get the IP address.
-            # the loopback address is sufficient
-            # NOTE: a node may have several IP addresses, one for each
-            # network interface. `get_ip()` might return any of them,
-            # while they might not work for communication inside the node
-            # if the network setup is complicated. Using the loopback address
-            # solves this issue, as it always works for communication inside
-            # the node.
-            driver_ip = "127.0.0.1"
         distributed_init_method = get_distributed_init_method(
             driver_ip, get_open_port())
@@ -221,8 +225,7 @@
                 local_rank=node_workers[node_id].index(rank),
                 rank=rank,
                 distributed_init_method=distributed_init_method,
-            ) for rank, (node_id,
-                         _) in zip(worker_ranks, worker_node_and_gpu_ids)
+            ) for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids)
         ]
         self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs)
@@ -231,6 +234,19 @@
                           max_concurrent_workers=self.parallel_config.
                           max_parallel_loading_workers)

+        if self.use_ray_spmd_worker:
+            for pp_rank in range(self.parallel_config.pipeline_parallel_size):
+                self.pp_tp_workers.append([])
+                for tp_rank in range(
+                        self.parallel_config.tensor_parallel_size):
+                    # PP=2, TP=4
+                    # pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
+                    rank = (pp_rank * self.parallel_config.tensor_parallel_size
+                            ) + tp_rank
+                    assert len(self.pp_tp_workers[pp_rank]) == tp_rank
+                    assert pp_rank < len(self.pp_tp_workers)
+                    self.pp_tp_workers[pp_rank].append(self.workers[rank])
+
         # This is the list of workers that are rank 0 of each TP group EXCEPT
         # global rank 0. These are the workers that will broadcast to the
         # rest of the workers.
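
Aside (not part of the diff): the PP-major, TP-minor rank layout built above can be previewed standalone; integers stand in for the RayWorkerWrapper actors, with PP=2 and TP=4 as in the comment:

pipeline_parallel_size, tensor_parallel_size = 2, 4
workers = list(range(pipeline_parallel_size * tensor_parallel_size))

pp_tp_workers = []
for pp_rank in range(pipeline_parallel_size):
    pp_tp_workers.append([])
    for tp_rank in range(tensor_parallel_size):
        # Global rank: PP-major, TP-minor.
        rank = pp_rank * tensor_parallel_size + tp_rank
        pp_tp_workers[pp_rank].append(workers[rank])

print(pp_tp_workers)  # [[0, 1, 2, 3], [4, 5, 6, 7]]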
@@ -241,9 +257,9 @@
         self.non_driver_workers: List[RayWorkerWrapper] = []

         # Enforce rank order for correct rank to return final output.
-        for rank, worker in sorted(zip(worker_ranks[1:], self.workers)):
-            # We need to skip the driver worker, which we
-            # do by skipping worker_ranks[0] which is always 0.
+        for index, worker in enumerate(self.workers):
+            # The driver worker is rank 0 and not in self.workers.
+            rank = index + 1
             if rank % self.parallel_config.tensor_parallel_size == 0:
                 self.tp_driver_workers.append(worker)
             else:
@@ -376,16 +392,47 @@
             raise ValueError(f"Ray version {required_version} or greater is "
                              f"required, but found {current_version}")

-        from ray.dag import InputNode, MultiOutputNode
         assert self.parallel_config.use_ray
+        from ray.dag import InputNode, MultiOutputNode
+        from ray.experimental.channel.torch_tensor_type import TorchTensorType

-        # Right now, compiled DAG requires at least 1 arg. We send
-        # a dummy value for now. It will be fixed soon.
+        logger.info("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = %s",
+                    envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL)
+
         with InputNode() as input_data:
-            forward_dag = MultiOutputNode([
-                worker.execute_model_spmd.bind(  # type: ignore[attr-defined]
-                    input_data) for worker in self.workers
-            ])
+            # Example DAG: PP=2, TP=4
+            # (ExecuteModelReq, None) -> 0 -> (ExecuteModelReq, IntermediateOutput) -> 4 -> SamplerOutput   # noqa: E501
+            #                         -> 1 -> (ExecuteModelReq, IntermediateOutput) -> 5 -> SamplerOutput   # noqa: E501
+            #                         -> 2 -> (ExecuteModelReq, IntermediateOutput) -> 6 -> SamplerOutput   # noqa: E501
+            #                         -> 3 -> (ExecuteModelReq, IntermediateOutput) -> 7 -> SamplerOutput   # noqa: E501

+            # All workers in the first TP group will take in the
+            # ExecuteModelRequest as input.
+            outputs = [input_data for _ in self.pp_tp_workers[0]]
+            for pp_rank, tp_group in enumerate(self.pp_tp_workers):
+                # Each PP worker takes in the output of the previous PP worker,
+                # and the TP group executes in SPMD fashion.
+                outputs = [
+                    worker.execute_model_spmd.
+                    bind(  # type: ignore[attr-defined]
+                        outputs[i]) for i, worker in enumerate(tp_group)
+                ]
+
+                last_pp_rank = len(self.pp_tp_workers) - 1
+                if pp_rank < last_pp_rank:
+                    # Specify how intermediate tensors should be passed
+                    # between pp stages, no need to specify for the last
+                    # pp stage.
+                    transport = "nccl" \
+                        if envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL \
+                        else "auto"
+                    outputs = [
+                        output.with_type_hint(
+                            TorchTensorType(transport=transport))
+                        for output in outputs
+                    ]
+
+            forward_dag = MultiOutputNode(outputs)
+
         return forward_dag.experimental_compile(enable_asyncio=enable_asyncio)

     def __del__(self):

vllm/executor/ray_utils.py

@@ -1,8 +1,8 @@
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Union

 from vllm.config import ParallelConfig
 from vllm.logger import init_logger
-from vllm.sequence import ExecuteModelRequest
+from vllm.sequence import ExecuteModelRequest, IntermediateTensors
 from vllm.utils import get_ip, is_hip, is_tpu, is_xpu
 from vllm.worker.worker_base import WorkerWrapperBase
@@ -31,9 +31,17 @@
             gpu_ids = ray.get_gpu_ids()
             return node_id, gpu_ids

-        def execute_model_spmd(self, execute_model_req: ExecuteModelRequest):
-            """Used only when SPMD worker and compiled DAG are both
-            enabled."""
+        def execute_model_spmd(
+            self, req_or_tuple: Union[ExecuteModelRequest,
+                                      Tuple[ExecuteModelRequest,
+                                            IntermediateTensors]]):
+            """Execute model in SPMD fashion: used only when SPMD worker and
+            compiled DAG are both enabled.
+
+            Args:
+                req_or_tuple: The request to execute the model, or a tuple
+                    containing the request and intermediate tensors.
+            """
             # TODO(swang): This is needed right now because Ray aDAG executes
             # on a background thread, so we need to reset torch's current
             # device.
@@ -42,7 +50,17 @@
                 torch.cuda.set_device(self.worker.device)
                 self.compiled_dag_cuda_device_set = True

-            return self.worker._execute_model_spmd(execute_model_req)
+            if isinstance(req_or_tuple, tuple):
+                execute_model_req, intermediate_tensors = req_or_tuple
+            else:
+                execute_model_req = req_or_tuple
+                intermediate_tensors = None
+
+            output = self.worker._execute_model_spmd(execute_model_req,
+                                                     intermediate_tensors)
+            if isinstance(output, IntermediateTensors):
+                return execute_model_req, output
+            return output

     ray_import_err = None
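
Aside (not part of the diff): between pipeline stages the wrapper follows a simple hand-off convention, namely non-final stages return (request, intermediate tensors) and the final stage returns the sampler output. A toy sketch of that protocol with plain Python stand-ins for the vLLM types:

from typing import Any, Tuple, Union

def stage(req_or_tuple: Union[str, Tuple[str, Any]], is_last_stage: bool):
    # First stage receives just the request; later stages receive the
    # (request, intermediate tensors) tuple produced by the previous stage.
    if isinstance(req_or_tuple, tuple):
        request, intermediates = req_or_tuple
    else:
        request, intermediates = req_or_tuple, None

    if is_last_stage:
        return f"sampler_output({request})"
    # Non-final stages forward the request along with their activations.
    return request, f"hidden_states({request}, prev={intermediates})"

out = stage("req", is_last_stage=False)
print(stage(out, is_last_stage=True))  # sampler_output(req)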

vllm/worker/worker_base.py

@@ -285,7 +285,9 @@ class LocalOrDistributedWorkerBase(WorkerBase):
         return output

     def _execute_model_spmd(
-        self, execute_model_req: ExecuteModelRequest
+        self,
+        execute_model_req: ExecuteModelRequest,
+        intermediate_tensors: Optional[IntermediateTensors] = None
     ) -> Optional[List[SamplerOutput]]:
         """
         Execute model in Single Program Multiple Data (SPMD) fashion.
@@ -309,7 +311,7 @@ class LocalOrDistributedWorkerBase(WorkerBase):
         return self.model_runner.execute_model(
             model_input, self.kv_cache[worker_input.virtual_engine]
-            if self.kv_cache is not None else None)
+            if self.kv_cache is not None else None, intermediate_tensors)


 class WorkerWrapperBase: