Make _prepare_sample non-blocking and use pinned memory for input buffers (#2207)

Hanzhi Zhou authored 2023-12-19 16:52:46 -08:00 (committed by GitHub)
parent ba4f826738
commit 31bff69151
1 changed file with 38 additions and 17 deletions

vllm/worker/model_runner.py

@@ -10,6 +10,7 @@ from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
+from vllm.utils import in_wsl
 
 logger = init_logger(__name__)
@@ -52,6 +53,8 @@ class ModelRunner:
         # The shape of the cached block table will be
         # (max batch size to capture, max context len to capture / block size).
         self.graph_block_tables = None  # Set after initial profiling.
+        # cache in_wsl result
+        self.in_wsl = in_wsl()
 
     def load_model(self) -> None:
         self.model = get_model(self.model_config)
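Two points of context for the `in_wsl` caching above (an editor's sketch, not part of the commit): pinned (page-locked) host memory is what allows a host-to-device copy to proceed asynchronously, and pinned allocation is known to be unreliable under WSL, so the runner checks `in_wsl()` once and disables pinning there. A minimal illustration, assuming a CUDA-capable, non-WSL machine:

```python
import torch

# Pinned host memory lets the copy below overlap with other host work;
# from ordinary pageable memory, non_blocking=True has little effect.
if torch.cuda.is_available():
    src = torch.ones(1024, dtype=torch.long).pin_memory()
    dst = src.to(device="cuda", non_blocking=True)  # enqueued, returns at once
    torch.cuda.synchronize()  # wait for the copy before reading dst
```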
@@ -203,24 +206,29 @@
         # When using CUDA graph, we don't need to make the tensors on the GPU
         # because they will be eventually copied to the designated GPU buffer.
         device = "cpu" if use_captured_graph else "cuda"
+        pin_memory = use_captured_graph and not self.in_wsl
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_len=1,
                                              pad=0,
                                              dtype=torch.long,
-                                             device=device)
+                                             device=device,
+                                             pin_memory=pin_memory)
         input_positions = _make_tensor_with_pad(input_positions,
                                                 max_len=1,
                                                 pad=0,
                                                 dtype=torch.long,
-                                                device=device)
+                                                device=device,
+                                                pin_memory=pin_memory)
         slot_mapping = _make_tensor_with_pad(slot_mapping,
                                              max_len=1,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long,
-                                             device=device)
+                                             device=device,
+                                             pin_memory=pin_memory)
         context_lens = torch.tensor(context_lens,
                                     dtype=torch.int,
-                                    device=device)
+                                    device=device,
+                                    pin_memory=pin_memory)
 
         if use_captured_graph:
             # The shape of graph_block_tables is
@@ -229,7 +237,7 @@
             for i, block_table in enumerate(block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
-            block_tables = torch.from_numpy(input_block_tables).to(device)
+            block_tables = torch.tensor(input_block_tables, device=device)
         else:
             block_tables = _make_tensor_with_pad(
                 block_tables,
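A plausible reason for the `torch.from_numpy(...)` → `torch.tensor(...)` change above: on this path the target device is "cpu", so `.to(device)` is a no-op and `from_numpy` would return a view aliasing the reused `graph_block_tables` buffer, whereas `torch.tensor` snapshots it. A standalone sketch of that aliasing distinction:

```python
import numpy as np
import torch

buf = np.zeros((2, 4), dtype=np.int64)  # stand-in for graph_block_tables

aliased = torch.from_numpy(buf)   # shares memory: later writes leak in
snapshot = torch.tensor(buf)      # independent copy of the current contents

buf[0, 0] = 7
print(aliased[0, 0].item())   # 7 -- follows the underlying buffer
print(snapshot[0, 0].item())  # 0 -- frozen at copy time
```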
@@ -297,11 +305,11 @@
                     categorized_sample_indices_start_idx + num_seqs))
                 categorized_sample_indices_start_idx += num_seqs
 
-        selected_token_indices = torch.tensor(selected_token_indices,
-                                              dtype=torch.long,
-                                              device="cuda")
+        selected_token_indices = _async_h2d(selected_token_indices,
+                                            dtype=torch.long,
+                                            pin_memory=not self.in_wsl)
         categorized_sample_indices = {
-            t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
+            t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl)
             for t, seq_ids in categorized_sample_indices.items()
         }
@@ -334,8 +342,6 @@
         else:
             inputs = self._prepare_decode(seq_group_metadata_list)
             input_tokens, input_positions, input_metadata = inputs
-        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
-                                                 input_metadata.prompt_lens)
 
         # Execute the model.
         if input_metadata.use_cuda_graph:
@@ -350,6 +356,9 @@
             input_metadata=input_metadata,
         )
 
+        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
+                                                 input_metadata.prompt_lens)
+
         # Sample the next token.
         output = self.model.sample(
             hidden_states=hidden_states,
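Moving `_prepare_sample` after the forward pass works because CUDA kernel launches are asynchronous: control returns to Python while the GPU is still executing, so the CPU-side sampling-metadata work (and its `_async_h2d` copies) overlaps with model execution instead of delaying it. A minimal sketch of the overlap, assuming a CUDA device:

```python
import torch

if torch.cuda.is_available():
    a = torch.randn(4096, 4096, device="cuda")
    out = a @ a  # launched asynchronously: returns while the GPU works
    # CPU-side prep here runs concurrently with the matmul above.
    metadata = [i * i for i in range(100_000)]
    torch.cuda.synchronize()  # block only once the GPU result is needed
```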
@@ -502,11 +511,14 @@ class CUDAGraphRunner:
         del kv_caches
 
         # Copy the input tensors to the input buffers.
-        self.input_buffers["input_ids"].copy_(input_ids)
-        self.input_buffers["positions"].copy_(positions)
-        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping)
-        self.input_buffers["context_lens"].copy_(input_metadata.context_lens)
-        self.input_buffers["block_tables"].copy_(input_metadata.block_tables)
+        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
+        self.input_buffers["positions"].copy_(positions, non_blocking=True)
+        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping,
+                                                 non_blocking=True)
+        self.input_buffers["context_lens"].copy_(input_metadata.context_lens,
+                                                 non_blocking=True)
+        self.input_buffers["block_tables"].copy_(input_metadata.block_tables,
+                                                 non_blocking=True)
 
         # Run the graph.
         self.graph.replay()
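The `non_blocking=True` copies above pair with the pinned input tensors built in `_prepare_decode`: a non-blocking copy from pinned host memory overlaps with other host work, while stream ordering still guarantees the data lands in the buffers before `graph.replay()` executes. A self-contained sketch of the same pattern (hypothetical buffer names, not the commit's code):

```python
import torch

if torch.cuda.is_available():
    # Static GPU buffer, standing in for a graph-captured input tensor.
    input_ids_buf = torch.zeros(8, dtype=torch.long, device="cuda")
    # Pinned host source; from pageable memory the copy would synchronize.
    host = torch.arange(8, dtype=torch.long).pin_memory()
    input_ids_buf.copy_(host, non_blocking=True)
    # A graph replay issued next on the same stream would observe the
    # copied values, since a stream executes its work in issue order.
    torch.cuda.synchronize()
```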
@@ -529,9 +541,13 @@ def _make_tensor_with_pad(
     pad: int,
     dtype: torch.dtype,
     device: Union[str, torch.device] = "cuda",
+    pin_memory: bool = False,
 ) -> torch.Tensor:
     padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
-    return torch.tensor(padded_x, dtype=dtype, device=device)
+    return torch.tensor(padded_x,
+                        dtype=dtype,
+                        device=device,
+                        pin_memory=pin_memory and str(device) == "cpu")
 
 
 def _get_graph_batch_size(batch_size: int) -> int:
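The `pin_memory and str(device) == "cpu"` guard above reflects a PyTorch constraint: only CPU tensors can be pinned, and requesting `pin_memory=True` together with a CUDA device raises an error. A brief illustration (pinning itself also requires CUDA to be available, hence the check below):

```python
import torch

xs = [[1, 2, 0], [3, 0, 0]]  # rows already padded to the same length
if torch.cuda.is_available():
    ok = torch.tensor(xs, dtype=torch.long, device="cpu", pin_memory=True)
    # torch.tensor(xs, device="cuda", pin_memory=True) would raise instead:
    # only dense CPU tensors can be pinned.
```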
@@ -541,3 +557,8 @@ def _get_graph_batch_size(batch_size: int) -> int:
         return 4
     else:
         return (batch_size + 7) // 8 * 8
+
+
+def _async_h2d(data: list, dtype, pin_memory):
+    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
+    return t.to(device="cuda", non_blocking=True)
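For illustration, a hypothetical call mirroring `_prepare_sample`'s new call sites (assumes the surrounding module, not part of the commit): the tensor is materialized in pinned host memory and the device copy is merely enqueued, so the caller keeps running while the transfer is in flight.

```python
# Sketch only: _async_h2d and in_wsl come from the module patched above.
selected_token_indices = _async_h2d([0, 3, 5],
                                    dtype=torch.long,
                                    pin_memory=not in_wsl())
```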