[Core] Some simplification of WorkerWrapper changes (#4183)

This commit is contained in:
Nick Hill 2024-04-23 00:49:08 -07:00 committed by GitHub
parent 0ae11f78ab
commit 8f2ea22bde
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 45 additions and 54 deletions

View File

@ -2,6 +2,7 @@ import asyncio
import os
import pickle
from collections import defaultdict
from itertools import islice, repeat
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple
from vllm.engine.ray_utils import RayWorkerWrapper, ray
@ -136,16 +137,14 @@ class RayGPUExecutor(ExecutorBase):
VLLM_INSTANCE_ID = get_vllm_instance_id()
# Set environment variables for the driver and workers.
all_args_to_update_environment_variables = []
for (node_id, _) in worker_node_and_gpu_ids:
all_args_to_update_environment_variables.append([{
"CUDA_VISIBLE_DEVICES":
",".join(map(str, node_gpus[node_id])),
"VLLM_INSTANCE_ID":
VLLM_INSTANCE_ID,
"VLLM_TRACE_FUNCTION":
os.getenv("VLLM_TRACE_FUNCTION", "0"),
}])
all_args_to_update_environment_variables = [({
"CUDA_VISIBLE_DEVICES":
",".join(map(str, node_gpus[node_id])),
"VLLM_INSTANCE_ID":
VLLM_INSTANCE_ID,
"VLLM_TRACE_FUNCTION":
os.getenv("VLLM_TRACE_FUNCTION", "0"),
}, ) for (node_id, _) in worker_node_and_gpu_ids]
self._run_workers("update_environment_variables",
all_args=all_args_to_update_environment_variables)
@ -156,10 +155,9 @@ class RayGPUExecutor(ExecutorBase):
# avoid writing `{"name": value}` manually
return kwargs
init_worker_all_kwargs = []
# Initialize the actual workers inside worker wrapper.
for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids, ):
init_worker_all_kwargs = []
for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
local_rank = node_workers[node_id].index(rank)
init_worker_all_kwargs.append(
collect_arg_helper_func(
@ -265,40 +263,40 @@ class RayGPUExecutor(ExecutorBase):
self,
method: str,
*args,
driver_args: Optional[Tuple[Any]] = None,
driver_args: Optional[Tuple[Any, ...]] = None,
driver_kwargs: Optional[Dict[str, Any]] = None,
all_args: Optional[List[List[Any]]] = None,
all_args: Optional[List[Tuple[Any, ...]]] = None,
all_kwargs: Optional[List[Dict[str, Any]]] = None,
use_dummy_driver: bool = False,
max_concurrent_workers: Optional[int] = None,
use_ray_compiled_dag: bool = False,
**kwargs,
) -> Any:
"""Runs the given method on all workers.
all_args and all_kwargs are used to pass heterogeneous arguments,
i.e. different arguments for each worker.
"""Runs the given method on all workers. Can be used in the following
ways:
- args/kwargs: All workers share the same args/kwargs
- args/kwargs and driver_args/driver_kwargs: Driver worker has
different args
- all_args/all_kwargs: args/kwargs for each worker are specified
individually
"""
if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
# for mypy type checking
assert driver_args is not None
assert driver_kwargs is not None
if all_args is None:
all_args = [driver_args] + [args] * len(self.workers)
if all_kwargs is None:
all_kwargs = [driver_kwargs] + [kwargs] * len(self.workers)
# for mypy type checking
assert all_args is not None
assert all_kwargs is not None
if max_concurrent_workers:
raise NotImplementedError(
"max_concurrent_workers is not supported yet.")
if driver_args is None:
driver_args = args if all_args is None else all_args[0]
if driver_kwargs is None:
driver_kwargs = kwargs if all_kwargs is None else all_kwargs[0]
count = len(self.workers)
all_worker_args = repeat(args, count) if all_args is None \
else islice(all_args, 1, None)
all_worker_kwargs = repeat(kwargs, count) if all_kwargs is None \
else islice(all_kwargs, 1, None)
if use_ray_compiled_dag:
# Right now, compiled DAG can only accept a single
# input. TODO(sang): Fix it.
@ -310,22 +308,17 @@ class RayGPUExecutor(ExecutorBase):
worker.execute_method.remote(method, *worker_args,
**worker_kwargs)
for (worker, worker_args, worker_kwargs
) in zip(self.workers, all_args[1:], all_kwargs[1:])
) in zip(self.workers, all_worker_args, all_worker_kwargs)
]
if driver_args is None:
driver_args = args
if driver_kwargs is None:
driver_kwargs = kwargs
# Start the driver worker after all the ray workers.
if not use_dummy_driver:
driver_worker_output = self.driver_worker.execute_method(
method, *all_args[0], **all_kwargs[0])
method, *driver_args, **driver_kwargs)
else:
driver_worker_output = ray.get(
self.driver_dummy_worker.execute_method.remote(
method, *all_args[0], **all_kwargs[0]))
method, *driver_args, **driver_kwargs))
# Get the results of the ray workers.
if self.workers:
if use_ray_compiled_dag:
@ -383,6 +376,10 @@ class RayGPUExecutor(ExecutorBase):
class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.driver_executor = make_async(self.driver_worker.execute_method)
async def _run_workers_async(
self,
method: str,
@ -399,13 +396,8 @@ class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase):
if driver_kwargs is None:
driver_kwargs = kwargs
# Run the driver worker asynchronously.
def helper():
return self.driver_worker.execute_method(method, *driver_args,
**driver_kwargs)
driver_executor = make_async(helper)
coros.append(driver_executor())
coros.append(
self.driver_executor(method, *driver_args, **driver_kwargs))
# Run the ray workers asynchronously.
for worker in self.workers:

View File

@ -108,7 +108,8 @@ class WorkerWrapperBase:
self.worker_class_name = worker_class_name
self.worker = None
def update_environment_variables(self, envs: Dict[str, str]) -> None:
@staticmethod
def update_environment_variables(envs: Dict[str, str]) -> None:
key = 'CUDA_VISIBLE_DEVICES'
if key in envs and key in os.environ:
# overwriting CUDA_VISIBLE_DEVICES is desired behavior
@ -138,10 +139,8 @@ class WorkerWrapperBase:
def execute_method(self, method, *args, **kwargs):
try:
if hasattr(self, method):
executor = getattr(self, method)
else:
executor = getattr(self.worker, method)
target = self if self.worker is None else self.worker
executor = getattr(target, method)
return executor(*args, **kwargs)
except Exception as e:
# if the driver worker also execute methods,