mirror of https://github.com/vllm-project/vllm
Make _prepare_sample non-blocking and use pinned memory for input buffers (#2207)
This commit is contained in:
parent ba4f826738
commit 31bff69151
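The change applies a standard CUDA transfer pattern: host-to-device copies can only overlap with other work when the source tensor lives in pinned (page-locked) host memory, so the CPU-side staging tensors are pinned and the copies are issued with non_blocking=True. A minimal standalone sketch of that pattern (plain PyTorch with illustrative names, not the vLLM code itself):

import torch

# Stage the data in pinned (page-locked) host memory; from pageable memory,
# .to()/.copy_() would fall back to a synchronous transfer.
host_ids = torch.tensor([0, 3, 7, 8], dtype=torch.long, pin_memory=True)

# Launch the H2D copy without blocking the CPU. The copy is ordered on the
# current CUDA stream, so kernels enqueued afterwards see the transferred data.
gpu_ids = host_ids.to(device="cuda", non_blocking=True)

# The CPU is free to do other work while the copy is in flight; synchronize
# only if the host itself needs to wait for completion.
torch.cuda.synchronize()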
@@ -10,6 +10,7 @@ from vllm.logger import init_logger
 from vllm.model_executor import get_model, InputMetadata, SamplingMetadata
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata
+from vllm.utils import in_wsl
 
 logger = init_logger(__name__)
 
@@ -52,6 +53,8 @@ class ModelRunner:
         # The shape of the cached block table will be
         # (max batch size to capture, max context len to capture / block size).
         self.graph_block_tables = None  # Set after initial profiling.
+        # cache in_wsl result
+        self.in_wsl = in_wsl()
 
     def load_model(self) -> None:
         self.model = get_model(self.model_config)
@@ -203,24 +206,29 @@ class ModelRunner:
         # When using CUDA graph, we don't need to make the tensors on the GPU
         # because they will be eventually copied to the designated GPU buffer.
         device = "cpu" if use_captured_graph else "cuda"
+        pin_memory = use_captured_graph and not self.in_wsl
         input_tokens = _make_tensor_with_pad(input_tokens,
                                              max_len=1,
                                              pad=0,
                                              dtype=torch.long,
-                                             device=device)
+                                             device=device,
+                                             pin_memory=pin_memory)
         input_positions = _make_tensor_with_pad(input_positions,
                                                 max_len=1,
                                                 pad=0,
                                                 dtype=torch.long,
-                                                device=device)
+                                                device=device,
+                                                pin_memory=pin_memory)
         slot_mapping = _make_tensor_with_pad(slot_mapping,
                                              max_len=1,
                                              pad=_PAD_SLOT_ID,
                                              dtype=torch.long,
-                                             device=device)
+                                             device=device,
+                                             pin_memory=pin_memory)
         context_lens = torch.tensor(context_lens,
                                     dtype=torch.int,
-                                    device=device)
+                                    device=device,
+                                    pin_memory=pin_memory)
 
         if use_captured_graph:
             # The shape of graph_block_tables is
@@ -229,7 +237,7 @@ class ModelRunner:
             for i, block_table in enumerate(block_tables):
                 if block_table:
                     input_block_tables[i, :len(block_table)] = block_table
-            block_tables = torch.from_numpy(input_block_tables).to(device)
+            block_tables = torch.tensor(input_block_tables, device=device)
         else:
             block_tables = _make_tensor_with_pad(
                 block_tables,
@@ -297,11 +305,11 @@ class ModelRunner:
                     categorized_sample_indices_start_idx + num_seqs))
                 categorized_sample_indices_start_idx += num_seqs
 
-        selected_token_indices = torch.tensor(selected_token_indices,
-                                              dtype=torch.long,
-                                              device="cuda")
+        selected_token_indices = _async_h2d(selected_token_indices,
+                                            dtype=torch.long,
+                                            pin_memory=not self.in_wsl)
         categorized_sample_indices = {
-            t: torch.tensor(seq_ids, dtype=torch.int, device="cuda")
+            t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl)
             for t, seq_ids in categorized_sample_indices.items()
         }
 
@@ -334,8 +342,6 @@ class ModelRunner:
         else:
             inputs = self._prepare_decode(seq_group_metadata_list)
             input_tokens, input_positions, input_metadata = inputs
-        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
-                                                 input_metadata.prompt_lens)
 
         # Execute the model.
         if input_metadata.use_cuda_graph:
@@ -350,6 +356,9 @@ class ModelRunner:
             input_metadata=input_metadata,
         )
 
+        sampling_metadata = self._prepare_sample(seq_group_metadata_list,
+                                                 input_metadata.prompt_lens)
+
         # Sample the next token.
         output = self.model.sample(
             hidden_states=hidden_states,
@@ -502,11 +511,14 @@ class CUDAGraphRunner:
         del kv_caches
 
         # Copy the input tensors to the input buffers.
-        self.input_buffers["input_ids"].copy_(input_ids)
-        self.input_buffers["positions"].copy_(positions)
-        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping)
-        self.input_buffers["context_lens"].copy_(input_metadata.context_lens)
-        self.input_buffers["block_tables"].copy_(input_metadata.block_tables)
+        self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
+        self.input_buffers["positions"].copy_(positions, non_blocking=True)
+        self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping,
+                                                 non_blocking=True)
+        self.input_buffers["context_lens"].copy_(input_metadata.context_lens,
+                                                 non_blocking=True)
+        self.input_buffers["block_tables"].copy_(input_metadata.block_tables,
+                                                 non_blocking=True)
 
         # Run the graph.
         self.graph.replay()
@@ -529,9 +541,13 @@ def _make_tensor_with_pad(
     pad: int,
     dtype: torch.dtype,
     device: Union[str, torch.device] = "cuda",
+    pin_memory: bool = False,
 ) -> torch.Tensor:
     padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x]
-    return torch.tensor(padded_x, dtype=dtype, device=device)
+    return torch.tensor(padded_x,
+                        dtype=dtype,
+                        device=device,
+                        pin_memory=pin_memory and str(device) == "cpu")
 
 
 def _get_graph_batch_size(batch_size: int) -> int:
@@ -541,3 +557,8 @@ def _get_graph_batch_size(batch_size: int) -> int:
         return 4
     else:
         return (batch_size + 7) // 8 * 8
+
+
+def _async_h2d(data: list, dtype, pin_memory):
+    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
+    return t.to(device="cuda", non_blocking=True)
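For reference, a hedged sketch of how the new _async_h2d helper is meant to be used, mirroring the _prepare_sample call sites above (the sequence IDs are illustrative, and the helper is restated only to make the snippet self-contained):

import torch

def _async_h2d(data: list, dtype, pin_memory):
    # Same shape as the helper added in this commit: stage the data in an
    # (optionally pinned) CPU tensor, then start a non-blocking H2D copy.
    t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory)
    return t.to(device="cuda", non_blocking=True)

# The diff skips pinning under WSL (pin_memory=not self.in_wsl), presumably
# because page-locked allocations are not reliably supported there; in that
# case the copy simply degrades to a blocking one.
seq_ids = [2, 5, 11]
gpu_seq_ids = _async_h2d(seq_ids, dtype=torch.int, pin_memory=True)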