"""
|
|
This example shows how to use Ray Data for running offline batch inference
|
|
distributively on a multi-nodes cluster.
|
|
|
|
Learn more about Ray Data in https://docs.ray.io/en/latest/data/data.html
|
|
"""

from typing import Any, Dict, List

import numpy as np
import ray
from packaging.version import Version
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy

from vllm import LLM, SamplingParams

assert Version(ray.__version__) >= Version(
    "2.22.0"), "Ray version must be at least 2.22.0"

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
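
# A minimal sketch of richer sampling settings, kept commented out so the
# example keeps its defaults above; max_tokens is a real SamplingParams field,
# but the value here is only an illustrative assumption.
# sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=256)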

# Set tensor parallelism per instance.
tensor_parallel_size = 1

# Set number of instances. Each instance will use tensor_parallel_size GPUs.
num_instances = 1


# Create a class to do batch inference.
class LLMPredictor:

    def __init__(self):
        # Create an LLM.
        self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf",
                       tensor_parallel_size=tensor_parallel_size)

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
        # Generate texts from the prompts.
        # The output is a list of RequestOutput objects that contain the
        # prompt, generated text, and other information.
        outputs = self.llm.generate(batch["text"], sampling_params)
        prompt: List[str] = []
        generated_text: List[str] = []
        for output in outputs:
            prompt.append(output.prompt)
            generated_text.append(' '.join([o.text for o in output.outputs]))
        return {
            "prompt": prompt,
            "generated_text": generated_text,
        }


# Read one text file from S3. Ray Data supports reading multiple files
# from cloud storage (such as JSONL, Parquet, CSV, binary format).
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")
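
# A sketch of reading other supported formats instead; the paths below are
# hypothetical placeholders, not part of the original example.
# ds = ray.data.read_parquet("s3://<your-input-bucket>/prompts/")
# ds = ray.data.read_json("s3://<your-input-bucket>/prompts.jsonl")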


# For tensor_parallel_size > 1, we need to create placement groups for vLLM
# to use. Every actor has to have its own placement group.
def scheduling_strategy_fn():
    # One bundle per tensor parallel worker
    pg = ray.util.placement_group(
        [{
            "GPU": 1,
            "CPU": 1
        }] * tensor_parallel_size,
        strategy="STRICT_PACK",
    )
    return dict(scheduling_strategy=PlacementGroupSchedulingStrategy(
        pg, placement_group_capture_child_tasks=True))


resources_kwarg: Dict[str, Any] = {}
if tensor_parallel_size == 1:
    # For tensor_parallel_size == 1, we simply set num_gpus=1.
    resources_kwarg["num_gpus"] = 1
else:
    # Otherwise, we have to set num_gpus=0 and provide
    # a function that will create a placement group for
    # each instance.
    resources_kwarg["num_gpus"] = 0
    resources_kwarg["ray_remote_args_fn"] = scheduling_strategy_fn

# Apply batch inference for all input data.
ds = ds.map_batches(
    LLMPredictor,
    # Set the concurrency to the number of LLM instances.
    concurrency=num_instances,
    # Specify the batch size for inference.
    batch_size=32,
    **resources_kwarg,
)

# Peek first 10 results.
# NOTE: This is for local testing and debugging. For production use cases,
# one should write the full results out as shown below.
outputs = ds.take(limit=10)
for output in outputs:
    prompt = output["prompt"]
    generated_text = output["generated_text"]
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

# Write inference output data out as Parquet files to S3.
# Multiple files would be written to the output destination,
# and each task would write one or more files separately.
#
# ds.write_parquet("s3://<your-output-bucket>")
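
# A minimal local alternative for quick testing, assuming the path below is
# reachable from the node(s) doing the writing (hypothetical path, not part
# of the original example):
# ds.write_parquet("/tmp/vllm_batch_inference_output")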