Add docstrings for LLM (#137)

Woosuk Kwon 2023-06-04 12:52:41 -07:00 committed by GitHub
parent 62ec38ea41
commit 8274ca23ac
4 changed files with 66 additions and 10 deletions

View File

@@ -30,7 +30,6 @@ def main(args: argparse.Namespace):
         max_tokens=args.output_len,
     )
     print(sampling_params)
-    dummy_prompts = [""] * args.batch_size
     dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size

     def run_to_completion(profile: bool = False):
@@ -38,7 +37,8 @@ def main(args: argparse.Namespace):
             torch.cuda.cudart().cudaProfilerStart()
         start_time = time.time()
-        llm.generate(dummy_prompts, sampling_params, dummy_prompt_token_ids,
+        llm.generate(prompt_token_ids=dummy_prompt_token_ids,
+                     sampling_params=sampling_params,
                      use_tqdm=False)
         end_time = time.time()

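Since `generate` now accepts `prompt_token_ids` as a keyword argument, the latency benchmark no longer needs placeholder prompt strings. A minimal sketch of the new call pattern, assuming the `llm`, `sampling_params`, and `args` objects set up earlier in this script:

import time

dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
start_time = time.time()
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
             sampling_params=sampling_params,
             use_tqdm=False)
# Wall-clock latency of one batched generate() call.
print(f"Batch latency: {time.time() - start_time:.3f} s")
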
View File

@@ -72,9 +72,9 @@ def main(args: argparse.Namespace):
         )
         # FIXME(woosuk): Do not use internal method.
         llm._add_request(
-            prompt="",
-            sampling_params=sampling_params,
+            prompt=None,
             prompt_token_ids=prompt_token_ids,
+            sampling_params=sampling_params,
         )
     start = time.time()
@@ -85,7 +85,9 @@ def main(args: argparse.Namespace):
         len(prompt_token_ids) + output_len
         for prompt_token_ids, output_len in requests
     )
-    print(f"Throughput: {total_num_tokens / (end - start):.2f} tokens/s")
+    elapsed_time = end - start
+    print(f"Throughput: {len(requests) / elapsed_time:.2f} requests/s, "
+          f"{total_num_tokens / elapsed_time:.2f} tokens/s")
 if __name__ == "__main__":

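The throughput benchmark now reports both requests per second and tokens per second over the same elapsed wall-clock time. A small worked example with hypothetical numbers, purely to illustrate the two metrics:

# Hypothetical numbers, for illustration only.
num_requests = 1000
tokens_per_request = 512 + 128                        # prompt + output tokens
total_num_tokens = num_requests * tokens_per_request  # 640,000
elapsed_time = 80.0                                   # seconds
print(f"Throughput: {num_requests / elapsed_time:.2f} requests/s, "   # 12.50
      f"{total_num_tokens / elapsed_time:.2f} tokens/s")              # 8000.00
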
View File

@@ -11,6 +11,28 @@ from cacheflow.utils import Counter
 class LLM:
+    """An LLM for generating texts from given prompts and sampling parameters.
+
+    This class includes a tokenizer, a language model (possibly distributed
+    across multiple GPUs), and GPU memory space allocated for intermediate
+    states (aka KV cache). Given a batch of prompts and sampling parameters,
+    this class generates texts from the model, using an intelligent batching
+    mechanism and efficient memory management.
+
+    NOTE: This class is intended to be used for offline inference. For online
+    serving, use the `AsyncLLMServer` class instead.
+    NOTE: For the comprehensive list of arguments, see `ServerArgs`.
+
+    Args:
+        model: The name or path of a HuggingFace Transformers model.
+        tensor_parallel_size: The number of GPUs to use for distributed
+            execution with tensor parallelism.
+        dtype: The data type for the model weights and activations. Currently,
+            we support `float16` and `bfloat16`. If `default`, we use the
+            `torch_dtype` attribute of the model config. If the `torch_dtype`
+            is `float32`, we use `float16` instead.
+        seed: The seed to initialize the random number generator for sampling.
+    """

     def __init__(
         self,
@@ -39,19 +61,50 @@ class LLM:
     def generate(
         self,
-        prompts: Union[str, List[str]],
+        prompts: Optional[Union[str, List[str]]] = None,
         sampling_params: Optional[SamplingParams] = None,
         prompt_token_ids: Optional[List[List[int]]] = None,
         use_tqdm: bool = True,
     ) -> List[RequestOutput]:
+        """Generates the completions for the input prompts.
+
+        NOTE: This class automatically batches the given prompts, considering
+        the memory constraint. For the best performance, put all of your prompts
+        into a single list and pass it to this method.
+
+        Args:
+            prompts: A list of prompts to generate completions for.
+            sampling_params: The sampling parameters for text generation. If
+                None, we use the default sampling parameters.
+            prompt_token_ids: A list of token IDs for the prompts. If None, we
+                use the tokenizer to convert the prompts to token IDs.
+            use_tqdm: Whether to use tqdm to display the progress bar.
+
+        Returns:
+            A list of `RequestOutput` objects containing the generated
+            completions in the same order as the input prompts.
+        """
+        if prompts is None and prompt_token_ids is None:
+            raise ValueError("Either prompts or prompt_token_ids must be "
+                             "provided.")
         if isinstance(prompts, str):
             # Convert a single prompt to a list.
             prompts = [prompts]
+        if prompts is not None and prompt_token_ids is not None:
+            if len(prompts) != len(prompt_token_ids):
+                raise ValueError("The lengths of prompts and prompt_token_ids "
+                                 "must be the same.")
         if sampling_params is None:
             # Use default sampling params.
             sampling_params = SamplingParams()
         # Add requests to the server.
-        for i in range(len(prompts)):
-            prompt = prompts[i]
+        if prompts is not None:
+            num_requests = len(prompts)
+        else:
+            num_requests = len(prompt_token_ids)
+        for i in range(num_requests):
+            prompt = prompts[i] if prompts is not None else None
             if prompt_token_ids is None:
                 token_ids = None
             else:
@@ -61,7 +114,7 @@ class LLM:
     def _add_request(
         self,
-        prompt: str,
+        prompt: Optional[str],
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]],
     ) -> None:

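Putting the new docstrings together, a minimal offline-inference sketch of the documented API. The import path, model name, and sampling values below are assumptions for illustration, not part of this commit:

# Assumed top-level exports; the commit only shows cacheflow.entrypoints.
from cacheflow import LLM, SamplingParams

# Offline batched generation, as described in the LLM class docstring.
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1, seed=0)
sampling_params = SamplingParams(temperature=0.8, max_tokens=64)

# Either pass plain prompts ...
outputs = llm.generate(
    ["Hello, my name is", "The capital of France is"],
    sampling_params=sampling_params,
)
# ... or pre-tokenized inputs, now that `prompts` defaults to None.
outputs = llm.generate(
    prompt_token_ids=[[1, 2, 3], [4, 5, 6]],
    sampling_params=sampling_params,
)
for output in outputs:
    print(output)  # RequestOutput objects, in input order
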
View File

@@ -126,7 +126,7 @@ class LLMServer:
     def add_request(
         self,
         request_id: str,
-        prompt: str,
+        prompt: Optional[str],
         sampling_params: SamplingParams,
         prompt_token_ids: Optional[List[int]] = None,
         arrival_time: Optional[float] = None,
@@ -134,6 +134,7 @@ class LLMServer:
         if arrival_time is None:
             arrival_time = time.time()
         if prompt_token_ids is None:
+            assert prompt is not None
             prompt_token_ids = self.tokenizer.encode(prompt)
         # Create the sequences.
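The new assert spells out the contract: a request must carry either the raw prompt (which the server tokenizes) or pre-computed token IDs. A sketch of the two valid call patterns, assuming `server` is an LLMServer instance, `params` a SamplingParams object, and the token IDs are illustrative:

server.add_request("req-0", prompt="Hello, world!", sampling_params=params)
server.add_request("req-1", prompt=None, sampling_params=params,
                   prompt_token_ids=[15496, 11, 995])
# Passing prompt=None without prompt_token_ids would trip the new assert.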