mirror of https://github.com/vllm-project/vllm
[TPU] Support single and multi-host TPUs on GKE (#7613)
This commit is contained in:
parent dc13e99348
commit 2148441fd3

@@ -4,4 +4,4 @@
 # Dependencies for TPU
 # Currently, the TPU backend uses a nightly version of PyTorch XLA.
 # You can install the dependencies in Dockerfile.tpu.
-ray
+ray[default]

@@ -123,7 +123,10 @@ class PallasAttentionBackendImpl(AttentionImpl):
             raise NotImplementedError("TPU version must be 4 or higher.")
 
         self.megacore_mode = None
-        tpu_type = torch_xla.tpu.get_tpu_env()["TYPE"].lower()
+        tpu_env = torch_xla.tpu.get_tpu_env()
+        tpu_type = tpu_env.get("TYPE") or tpu_env.get("ACCELERATOR_TYPE")
+        tpu_type = tpu_type.lower()
+
         if "lite" not in tpu_type:
             if self.num_kv_heads % 2 == 0:
                 self.megacore_mode = "kv_head"
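
The hunk above stops indexing the TPU env with ["TYPE"] directly and falls back to "ACCELERATOR_TYPE", since the dictionary returned by torch_xla.tpu.get_tpu_env() does not expose the same keys in every environment (the key differs between GKE and plain GCE deployments). A minimal sketch of the fallback, where the example env dicts are assumptions for illustration rather than values captured from a real deployment:

def resolve_tpu_type(tpu_env: dict) -> str:
    # Prefer "TYPE" and fall back to "ACCELERATOR_TYPE", mirroring the diff.
    tpu_type = tpu_env.get("TYPE") or tpu_env.get("ACCELERATOR_TYPE")
    if tpu_type is None:
        raise RuntimeError("TPU type not found in TPU env")
    return tpu_type.lower()

# Hypothetical env dicts, for illustration only.
assert resolve_tpu_type({"TYPE": "V4"}) == "v4"
assert resolve_tpu_type({"ACCELERATOR_TYPE": "v5litepod-4"}) == "v5litepod-4"

Lowercasing keeps the existing "lite" check working, so TPU types containing "lite" still skip megacore mode.
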
@@ -1,3 +1,5 @@
+import os
+
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup

@@ -5,11 +7,12 @@ from torch.distributed import ProcessGroup
 from vllm.platforms import current_platform
 
 if current_platform.is_tpu():
-    import ray
     import torch_xla.core.xla_model as xm
     import torch_xla.runtime as xr
     from torch_xla._internal import pjrt
+
+    from vllm.executor import ray_utils
 
 
 class TpuCommunicator:

@@ -24,9 +27,29 @@ class TpuCommunicator:
         # be simply calculated as follows.
         global_rank = dist.get_rank(group)
         global_world_size = dist.get_world_size(group)
-        num_nodes = len(ray.nodes())
+
+        # Calculate how many TPU nodes are in the current deployment. This
+        # is the Ray placement group if it is deployed with Ray. Default
+        # to the number of TPU nodes in the Ray cluster. The number of TPU
+        # nodes is computed by the total number of TPUs divided by the
+        # number of TPU accelerators per node, to account for clusters
+        # with both CPUs and TPUs.
+        num_nodes = ray_utils.get_num_tpu_nodes()
+        num_nodes_in_pg = ray_utils.get_num_nodes_in_placement_group()
+        if num_nodes_in_pg > 0:
+            num_nodes = num_nodes_in_pg
 
         local_world_size = global_world_size // num_nodes
         local_rank = global_rank % local_world_size
 
+        # Ensure environment variables are set for multihost deployments.
+        # On GKE, this is needed for libtpu and TPU driver to know which TPU
+        # chip is actually visible. Otherwise the TPU driver will fail to
+        # initialize because the number of devices would be different from
+        # the number of visible worker addresses.
+        os.environ["CLOUD_TPU_TASK_ID"] = str(global_rank)
+        os.environ["TPU_VISIBLE_CHIPS"] = str(local_rank)
+
         pjrt.initialize_multiprocess(local_rank, local_world_size)
         xr._init_world_size_ordinal()
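
As a worked example of the rank arithmetic above (a sketch with assumed numbers, not vLLM code): take a deployment with 2 TPU hosts and 4 chips per host, so the global world size is 8.

global_world_size = 8   # 2 hosts x 4 chips, assumed for illustration
num_nodes = 2
local_world_size = global_world_size // num_nodes  # 4 chips visible per host

for global_rank in range(global_world_size):
    local_rank = global_rank % local_world_size
    # Mirroring the diff, each worker would export CLOUD_TPU_TASK_ID=global_rank
    # and TPU_VISIBLE_CHIPS=local_rank before pjrt.initialize_multiprocess().
    print(f"global_rank={global_rank} -> local_rank={local_rank}")

Ranks 0 to 3 map to chips 0 to 3 on the first host, and ranks 4 to 7 map to chips 0 to 3 on the second.
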
@@ -71,6 +71,19 @@ class RayTPUExecutor(TPUExecutor):
             worker_module_name = "vllm.worker.tpu_worker"
             worker_class_name = "TPUWorker"
 
+            # GKE does not fetch environment information from metadata server
+            # and instead sets these from within the Ray process. Therefore we
+            # need to override the Ray environment variables manually.
+            override_env = {}
+            if "TPU_CHIPS_PER_HOST_BOUNDS" in os.environ:
+                override_env.update({
+                    "TPU_CHIPS_PER_HOST_BOUNDS":
+                    os.environ["TPU_CHIPS_PER_HOST_BOUNDS"]
+                })
+            if "TPU_HOST_BOUNDS" in os.environ:
+                override_env.update(
+                    {"TPU_HOST_BOUNDS": os.environ["TPU_HOST_BOUNDS"]})
+
             worker = ray.remote(
                 num_cpus=0,
                 resources={"TPU": 1},

@@ -81,6 +94,8 @@
                 worker_class_name=worker_class_name,
                 trust_remote_code=self.model_config.trust_remote_code,
             )
+            if override_env:
+                worker.override_env_vars.remote(override_env)
 
             worker_ip = ray.get(worker.get_node_ip.remote())
             if worker_ip == driver_ip and self.driver_dummy_worker is None:
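
Together with the previous hunk, this collects the TPU topology variables on the driver and pushes them to each Ray worker right after the actor is created. A standalone sketch of the collection step, limited to the two variables named in the diff; the helper name is illustrative and not part of vLLM:

import os

# The two topology variables named in the diff.
TOPOLOGY_VARS = ("TPU_CHIPS_PER_HOST_BOUNDS", "TPU_HOST_BOUNDS")


def collect_override_env() -> dict:
    # Forward only the variables that are actually set on the driver host,
    # so non-GKE deployments keep relying on the metadata server as before.
    return {key: os.environ[key] for key in TOPOLOGY_VARS if key in os.environ}


print(collect_override_env())  # {} when neither variable is set
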
@@ -1,3 +1,4 @@
+import os
 import time
 from collections import defaultdict
 from typing import Dict, List, Optional, Tuple, Union

@@ -84,6 +85,9 @@ try:
 
             return output
 
+        def override_env_vars(self, vars: Dict[str, str]):
+            os.environ.update(vars)
+
     ray_import_err = None
 
 except ImportError as e:
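
override_env_vars runs inside the Ray actor, so the update lands in the worker's process environment rather than the driver's. A self-contained sketch of that round trip using a hypothetical actor (not vLLM's RayWorkerWrapper) and an illustrative value; it only assumes a local ray installation:

import os

import ray


@ray.remote
class EnvWorker:

    def override_env_vars(self, vars: dict) -> None:
        # Executes in the worker process, so it mutates that process's env.
        os.environ.update(vars)

    def read(self, key: str) -> str:
        return os.environ.get(key, "<unset>")


if __name__ == "__main__":
    ray.init()
    worker = EnvWorker.remote()
    ray.get(worker.override_env_vars.remote({"TPU_HOST_BOUNDS": "1,1,1"}))
    print(ray.get(worker.read.remote("TPU_HOST_BOUNDS")))  # -> 1,1,1
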
@@ -291,3 +295,28 @@ def initialize_ray_cluster(
     _verify_bundles(current_placement_group, parallel_config, device_str)
     # Set the placement group in the parallel config
     parallel_config.placement_group = current_placement_group
+
+
+def get_num_tpu_nodes() -> int:
+    from ray._private.accelerators import TPUAcceleratorManager
+    cluster_resources = ray.cluster_resources()
+    total_tpus = int(cluster_resources["TPU"])
+    tpus_per_node = TPUAcceleratorManager.get_current_node_num_accelerators()
+    assert total_tpus % tpus_per_node == 0
+    return total_tpus // tpus_per_node
+
+
+def get_num_nodes_in_placement_group() -> int:
+    pg_table = ray.util.placement_group_table()
+    current_pg = ray.util.get_current_placement_group()
+    num_nodes = 0
+
+    if current_pg:
+        nodes_in_pg = set()
+        for pg_key, pg in pg_table.items():
+            if pg_key == current_pg.id.hex():
+                for _, node in pg["bundles_to_node_id"].items():
+                    nodes_in_pg.add(node)
+        num_nodes = len(nodes_in_pg)
+
+    return num_nodes
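
A worked example of get_num_tpu_nodes() above with made-up numbers instead of a live Ray cluster: one CPU-only head node plus two TPU hosts with 4 chips each. The dict stands in for ray.cluster_resources() and the per-node count for TPUAcceleratorManager.get_current_node_num_accelerators():

cluster_resources = {"CPU": 32.0, "TPU": 8.0}  # assumed cluster, illustration only
tpus_per_node = 4                              # assumed chips per TPU host

total_tpus = int(cluster_resources["TPU"])
assert total_tpus % tpus_per_node == 0
num_tpu_nodes = total_tpus // tpus_per_node
print(num_tpu_nodes)  # -> 2; the CPU-only head node is not counted

get_num_nodes_in_placement_group() then narrows this to the nodes actually referenced by the current placement group's bundles, which TpuCommunicator prefers whenever a placement group exists.
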