[1/n] Triton sampling kernel (#3186)

Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
Antoni Baum 2024-03-20 14:45:08 -07:00 committed by GitHub
parent 80e254834d
commit 426ec4ec67
10 changed files with 1072 additions and 24 deletions


@ -0,0 +1,51 @@
import torch
import pytest
import random
from vllm.model_executor.layers.ops.rand import seeded_uniform
from vllm.model_executor.utils import set_random_seed
@pytest.mark.parametrize("dtype",
[torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_3d", [True, False])
def test_seeded_uniform(dtype: torch.dtype, use_3d: bool):
device = "cuda"
for seed in range(512):
set_random_seed(seed)
rows = random.randint(1, 512)
cols = random.randint(1, 64000)
if use_3d:
third_dim = random.randint(2, 10)
dims = [rows, third_dim, cols]
else:
dims = [rows, cols]
seeds = torch.randint(torch.iinfo(torch.long).min,
torch.iinfo(torch.long).max, (rows, ),
device=device)
# Test that the same seed produces the same output
out = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device)
out2 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device)
torch.testing.assert_close(out, out2)
# del to save memory
del out2
out3 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device)
torch.testing.assert_close(out, out3)
# del to save memory
del out3
# Initialize out tensor with garbage to ensure that it is overwritten
out_with_tensor = seeded_uniform(
*dims,
out=torch.full(
(*dims, ),
-1,
dtype=dtype,
device=device,
),
seeds=seeds,
dtype=dtype,
)
torch.testing.assert_close(out, out_with_tensor)


@ -0,0 +1,196 @@
import gc
import torch
import pytest
import triton
import triton.language as tl
from vllm.model_executor.layers.ops.sample import (
_uniform_to_exponential, sample, get_num_triton_sampler_splits,
MAX_TRITON_N_COLS)
from vllm.model_executor.utils import set_random_seed
from vllm.model_executor.sampling_metadata import SamplingTensors
SINGLE_SPLIT_VOCAB_SIZE = 32000 # llama/mistral/mixtral vocab size
MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100
@pytest.fixture(autouse=True)
def _cleanup():
yield
gc.collect()
torch.cuda.empty_cache()
@triton.jit
def _uniform_to_exponential_kernel(input, output, n: tl.constexpr):
idx = tl.arange(0, n)
x = tl.load(input + idx)
y = _uniform_to_exponential(x)
tl.store(output + idx, y)
def test_uniform_to_exponential():
"""Test that we can convert uniform to exponential without div by 0."""
input = torch.tensor([0.0, 1.0 - torch.finfo(torch.float32).eps],
dtype=torch.float32,
device="cuda")
output = torch.zeros(input.shape, dtype=torch.float32, device="cuda")
_uniform_to_exponential_kernel[(1, )](input, output, 2)
assert torch.all(torch.isfinite(output))
assert torch.all(output > 0)
assert torch.all(torch.isfinite(torch.full_like(output, 1.0) / output))
@pytest.mark.parametrize("random_sampling", [True, False, "mixed"])
@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("modify_greedy_probs", [True, False])
@pytest.mark.parametrize("seed", [1337])
@pytest.mark.parametrize("vocab_size",
[SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE])
@pytest.mark.parametrize("save_logprobs", [True, False])
def test_sample_decoding_only(random_sampling, max_best_of,
modify_greedy_probs, seed, vocab_size,
save_logprobs):
set_random_seed(seed)
bs = 8
probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda")
for i in range(bs):
probs[i, i * (vocab_size // bs)] = 1.0
logprobs = torch.rand_like(probs)
sample_indices = torch.arange(bs, dtype=torch.long, device="cuda")
n_splits = get_num_triton_sampler_splits(probs.shape[1])
if random_sampling == "mixed":
random_sampling_mask = (torch.rand(
(1, bs), device="cuda") < 0.5).expand(n_splits, bs)
elif random_sampling:
random_sampling_mask = torch.ones((n_splits, bs),
dtype=torch.bool,
device="cuda")
else:
random_sampling_mask = torch.zeros((n_splits, bs),
dtype=torch.bool,
device="cuda")
seeds = torch.randint(1,
torch.iinfo(torch.long).max, (n_splits, bs),
device="cuda").mul_(random_sampling_mask)
sampled_tokens, sampled_logprobs, sampled_modified_probs = sample(
probs=probs,
logprobs=logprobs,
sample_indices=sample_indices,
seeds=seeds,
max_best_of=max_best_of,
modify_greedy_probs=modify_greedy_probs,
save_logprobs=save_logprobs,
_save_modified_probs=True)
assert sampled_tokens.shape == (bs, max_best_of)
for i in range(bs):
assert torch.all(sampled_tokens[i] == i * (vocab_size // bs))
request_uses_random_sampling = random_sampling_mask[0, i]
if modify_greedy_probs and not request_uses_random_sampling:
# If we are modifying greedy probs and the request is greedy,
# we want to make sure the probs tensor is modified in place
assert torch.allclose(
probs[i][sampled_tokens[i]],
torch.full_like(probs[i][sampled_tokens[i]], 1.0))
assert torch.sum(probs[i]) == 1.0
assert torch.allclose(
sampled_modified_probs[i][0],
torch.full_like(sampled_modified_probs[i][0], 1.0))
elif request_uses_random_sampling:
# If the request is random, we want to make sure
# sampled_modified_probs tensor has noise added
# (and thus is different from probs tensor)
assert not torch.allclose(sampled_modified_probs[i][0],
probs[i][sampled_tokens[i]])
elif not request_uses_random_sampling:
# If the request is greedy and we are not modifying greedy probs,
# we want to make sure sampled_modified_probs tensor is the same as
# the probs tensor.
assert torch.allclose(sampled_modified_probs[i][0],
probs[i][sampled_tokens[i]])
if save_logprobs:
assert sampled_logprobs.shape == (bs, max_best_of)
for i in range(bs):
for best_of in range(max_best_of):
assert torch.all(sampled_logprobs[i] == logprobs[i][
sampled_tokens[i, best_of]])
else:
assert sampled_logprobs is None
@pytest.mark.parametrize("random_sampling", [True, False, "mixed"])
@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("modify_greedy_probs", [True, False])
@pytest.mark.parametrize("seed", [1337])
@pytest.mark.parametrize("vocab_size",
[SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE])
def test_sample_prompt_logprobs(random_sampling, max_best_of,
modify_greedy_probs, seed, vocab_size):
set_random_seed(seed)
prompt_sizes = [16, 32, 64, 128] * 2
samples = 8
bs = samples + sum(prompt_sizes)
probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda")
for i in range(bs):
probs[i, i * (vocab_size // bs)] = 1.0
logprobs = torch.rand_like(probs)
sample_indices = torch.tensor(prompt_sizes,
dtype=torch.long,
device="cuda").cumsum_(0)
n_splits = get_num_triton_sampler_splits(probs.shape[1])
if random_sampling == "mixed":
random_sampling_mask = torch.rand(
(n_splits, samples), device="cuda") < 0.5
elif random_sampling:
random_sampling_mask = torch.ones((n_splits, samples),
dtype=torch.bool,
device="cuda")
else:
random_sampling_mask = torch.zeros((n_splits, samples),
dtype=torch.bool,
device="cuda")
seeds = torch.randint(1,
torch.iinfo(torch.long).max, (n_splits, samples),
device="cuda").mul_(random_sampling_mask)
sampled_tokens, sampled_logprobs, _ = sample(
probs=probs,
logprobs=logprobs,
sample_indices=sample_indices,
seeds=seeds,
max_best_of=max_best_of,
modify_greedy_probs=modify_greedy_probs,
save_logprobs=True)
assert sampled_tokens.shape == (samples, max_best_of)
assert sampled_logprobs.shape == (samples, max_best_of)
for i, t in enumerate(sample_indices):
assert torch.all(sampled_tokens[i] == t * (vocab_size // bs))
for best_of in range(max_best_of):
assert torch.all(sampled_logprobs[i] == logprobs[sample_indices[i]]
[sampled_tokens[i, best_of]])
@pytest.mark.parametrize("seed", list(range(16)))
def test_get_sequence_seeds(seed):
"""Ensure that we get a different child seed from base
seed + extra entropy"""
starting_seed = seed
seq_seed = None
extra_entropy = 1
for i in range(512):
new_seq_seed = SamplingTensors._get_sequence_seeds(starting_seed,
i,
seeds_to_generate=1,
is_greedy=False)[0]
new_seq_seed_extra_entropy = SamplingTensors._get_sequence_seeds(
starting_seed,
i,
extra_entropy,
seeds_to_generate=1,
is_greedy=False)[0]
assert new_seq_seed_extra_entropy != new_seq_seed
assert seq_seed != new_seq_seed
seq_seed = new_seq_seed


@ -302,11 +302,11 @@ def test_sampler_logits_processors(seed: int, device: str):
batch_size = random.randint(1, 256)
input_tensor, _, sampler, model_runner = _prepare_test(batch_size)
# This sample logits processor gives infinite score to the i-th token,
# This sample logits processor gives maximum score to the i-th token,
# where i is the length of the input sequence.
# We therefore expect the output token sequence to be [0, 1, 2, ...]
def pick_ith(token_ids, logits):
logits[len(token_ids)] = float("inf")
logits[len(token_ids)] = torch.finfo(logits.dtype).max
return logits
seq_group_metadata_list = []
@ -385,7 +385,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):
sample_probs = None
def mock_sample(probs, logprobs, sampling_metadata):
def mock_sample(probs, *args, **kwargs):
nonlocal sample_probs
sample_probs = probs
return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]


@ -0,0 +1,157 @@
import torch
import triton
import triton.language as tl
from typing import Optional, Union
def seeded_uniform(
*size,
seeds: torch.Tensor,
out: Optional[torch.Tensor] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[torch.device, str]] = None,
pin_memory: Optional[bool] = False,
) -> torch.Tensor:
"""Similar to torch.rand, but allows for seeds to be set per row.
seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d.
If it is 3d, the additional seeds needed will be derived automatically
in a deterministic fashion:
[
row 0: [columns_with_seed_0], [columns_with_seed_0^1], ...
]
"""
n_dims = len(size)
if n_dims > 3:
raise ValueError("seeded_uniform only supports up to 3D tensors")
if out is None:
out = torch.empty(*size,
dtype=dtype,
device=device,
pin_memory=pin_memory)
elif out.shape != size:
raise ValueError("shape of out and size must be the same")
if n_dims == 3:
n_rows, n_3d, n_cols = out.shape
stride_row = out.stride(0)
stride_3d = out.stride(1)
elif n_dims == 2:
n_rows, n_cols = out.shape
n_3d = 1
stride_row = out.stride(0)
stride_3d = 1
else:
n_cols = out.shape[0]
n_rows = 1
n_3d = 1
stride_row = 1
stride_3d = 1
if seeds.ndim != 1:
raise ValueError("seeds must be a 1D tensor")
if seeds.numel() != n_rows:
raise ValueError(
"seeds must have the same number of elements as out has rows")
# The Philox PRNG that Triton uses generates 4 random numbers at once.
# Therefore, the most efficient use of it is to divide the
# block size by 4, and then save the generated random numbers to
# each of the 4 slices of the tensor.
full_block_size = triton.next_power_of_2(n_cols)
philox_block_size = max(full_block_size // 4, 1)
n_slices = full_block_size // philox_block_size
num_warps = 4
# Manual tuning. This seems to give best performance on A100 for
# simple kernels like this.
if philox_block_size >= 8192:
num_warps = 32
elif philox_block_size >= 4096:
num_warps = 16
elif philox_block_size >= 2048:
num_warps = 8
_seeded_uniform_triton[(n_rows, n_3d)](
out,
seeds,
stride_row,
stride_3d,
seeds.stride(0),
n_rows,
n_3d,
n_cols,
n_slices=n_slices,
num_warps=num_warps,
block_size=philox_block_size,
)
return out
@triton.jit
def _seeded_uniform_triton(
out_ptr: torch.Tensor,
seed_ptr: torch.Tensor,
out_row_stride: int,
out_3d_stride: int,
seed_row_stride: int,
n_rows: int,
n_3d: int,
n_cols: int,
n_slices: tl.constexpr,
block_size: tl.constexpr,
):
"""
Generate a random float32 number in [0, 1) for each element in the output
tensor. The random numbers in a row are generated using the seed for that row.
Args:
out_ptr: The output tensor.
seed_ptr: The per-row seeds to use for random number generation.
out_row_stride: The stride between rows of the output tensor.
out_3d_stride: The stride between 3D slices of the output tensor.
seed_row_stride: The stride between rows of the seed tensor.
n_rows: The number of rows in the output tensor.
n_3d: The size of second dimension of the output tensor,
if output tensor is 3D.
n_cols: The number of columns in the output tensor.
n_slices: The number of philox outputs to use.
"""
tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4")
# Get the row index.
row_idx = tl.program_id(axis=0)
three_d_idx = tl.program_id(axis=1)
philox_offsets = tl.arange(0, block_size)
# Get the seed for the current element.
seed = tl.load(seed_ptr + row_idx * seed_row_stride)
if three_d_idx > 0:
seed ^= three_d_idx
# Generate random numbers in [0, 1).
out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets)
output_row_start_ptr = (out_ptr + row_idx * out_row_stride +
three_d_idx * out_3d_stride)
out1_offsets = philox_offsets
tl.store(output_row_start_ptr + out1_offsets,
out1,
mask=out1_offsets < n_cols)
if n_slices > 1:
out2_offsets = tl.arange(block_size, block_size * 2)
tl.store(output_row_start_ptr + out2_offsets,
out2,
mask=out2_offsets < n_cols)
if n_slices > 2:
out3_offsets = tl.arange(block_size * 2, block_size * 3)
tl.store(output_row_start_ptr + out3_offsets,
out3,
mask=out3_offsets < n_cols)
if n_slices > 3:
out4_offsets = tl.arange(block_size * 3, block_size * 4)
tl.store(output_row_start_ptr + out4_offsets,
out4,
mask=out4_offsets < n_cols)
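For orientation, a minimal usage sketch of seeded_uniform (illustrative only; it assumes a CUDA device with Triton available): rows that share a seed produce identical values, and repeated calls with the same seeds are deterministic.
import torch
from vllm.model_executor.layers.ops.rand import seeded_uniform

seeds = torch.tensor([1, 2, 1], dtype=torch.long, device="cuda")
a = seeded_uniform(3, 1000, seeds=seeds, dtype=torch.float32, device="cuda")
b = seeded_uniform(3, 1000, seeds=seeds, dtype=torch.float32, device="cuda")
torch.testing.assert_close(a, b)  # same per-row seeds -> identical output
assert torch.equal(a[0], a[2])    # rows 0 and 2 share seed 1
# With n_cols = 1000: full_block_size = 1024, philox_block_size = 256 and
# n_slices = 4, so one rand4x call fills four 256-wide slices of each row.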


@ -0,0 +1,405 @@
import math
from typing import Tuple, Optional
import torch
import triton
import triton.language as tl
from vllm.model_executor.layers.ops.rand import seeded_uniform
_EPS = 1e-6
# This is a hardcoded limit in Triton (max block size).
MAX_TRITON_N_COLS = 131072
def get_num_triton_sampler_splits(n_cols: int) -> int:
"""Get the number of splits to use for Triton sampling.
Triton has a limit on the number of columns it can handle, so we need to
split the tensor and call the kernel multiple times if it's too large.
"""
return math.ceil(n_cols / MAX_TRITON_N_COLS)
def _multi_split_sample(
probs: torch.Tensor,
seeds: torch.Tensor,
n_splits: int,
sampled_tokens_size: Tuple[int, int],
sampled_logprobs_size: Tuple[int, int],
sample_indices: torch.Tensor,
*,
logprobs: Optional[torch.Tensor] = None,
modify_greedy_probs: bool = False,
save_logprobs: bool = False,
):
"""Sample tokens where vocab size is split into multiple parts
(too large for Triton otherwise)."""
assert seeds.ndim == 2 and seeds.shape[0] == n_splits
split_probs = probs.tensor_split(n_splits, 1)
split_logprobs = logprobs.tensor_split(n_splits, 1)
sampled_tokens_tmp = [
torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device)
for _ in range(n_splits)
]
sampled_logprobs_tmp = [
torch.empty(sampled_logprobs_size,
dtype=probs.dtype,
device=probs.device) for _ in range(n_splits)
]
# We are purposefully using sampled_tokens_size as we need to always
# save modified probs in this case.
sampled_modified_probs_tmp = [
torch.empty(sampled_tokens_size,
dtype=probs.dtype,
device=probs.device) for _ in range(n_splits)
]
for i in range(n_splits):
n_samples = sample_indices.shape[0]
n_cols = split_probs[i].shape[1]
n_best = sampled_tokens_tmp[i].shape[1]
uniform_noise = seeded_uniform(n_samples,
n_best,
n_cols,
seeds=seeds[i].flatten(),
device=split_probs[i].device,
dtype=split_probs[i].dtype)
# TODO(yard1): See if we can remove the contiguous() calls.
# Will need kernel support.
_sample(
split_probs[i].contiguous(),
split_logprobs[i].contiguous(),
sample_indices,
sampled_tokens_tmp[i],
sampled_logprobs_tmp[i],
sampled_modified_probs_tmp[i],
seeds[i],
uniform_noise,
modify_greedy_probs=False,
save_logprobs=save_logprobs,
save_modified_probs=True,
)
if i > 0:
# Add offset to sampled tokens
sampled_tokens_tmp[i].add_(i * split_probs[i - 1].shape[1])
sampled_tokens = torch.stack(sampled_tokens_tmp)
sampled_modified_probs = torch.stack(sampled_modified_probs_tmp)
# Reduce the results from the splits.
sampled_modified_probs, indices = torch.max(sampled_modified_probs,
dim=0,
keepdim=True)
sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0)
if save_logprobs:
sampled_logprobs = torch.stack(sampled_logprobs_tmp)
sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0)
else:
sampled_logprobs = None
sampled_modified_probs = sampled_modified_probs.squeeze(0)
if modify_greedy_probs:
# We need to modify the greedy probs for the sampled tokens.
# We can't do this in the kernel as we need to know the
# sampled tokens.
probs.fill_(0.0)
probs.scatter_(1, sampled_tokens, 1.0)
return (sampled_tokens, sampled_logprobs, sampled_modified_probs)
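To make the reduction above concrete, here is a toy sketch of the stack/max/gather step (shapes and values are assumptions, not taken from the kernel): each split reports the token it would pick and the noisy probability that token won with; the max over the split dimension chooses the winning split and gather pulls out its token.
import torch
tokens = torch.tensor([[[3], [10]], [[7], [2]]])    # [n_splits=2, n=2, best=1]
mod_probs = torch.tensor([[[0.4], [0.9]], [[0.6], [0.1]]])
_, idx = torch.max(mod_probs, dim=0, keepdim=True)  # winning split per sample
winner = tokens.gather(0, idx).squeeze(0)
assert winner.tolist() == [[7], [10]]  # split 1 wins sample 0, split 0 wins sample 1
# In the real code the token ids are already offset into the global vocab
# before stacking, so the gathered ids are final.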
def sample(
probs: torch.Tensor,
seeds: torch.Tensor,
*,
max_best_of: int = 1,
sample_indices: Optional[torch.Tensor] = None,
logprobs: Optional[torch.Tensor] = None,
modify_greedy_probs: bool = False,
save_logprobs: bool = False,
_save_modified_probs: bool = False, # pylint: disable=invalid-name
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""Sample tokens from probs. with per-sequence seeds.
Can sample from a subset of sequences through sample_indices.
Args:
probs: Probabilities to sample from.
shape = [batch_size, vocab_size]
seeds: Per-sequence seed values.
shape = [math.ceil(vocab_size / MAX_TRITON_N_COLS), n]
max_best_of: Number of samples to generate per sequence.
Sequence seed will be incremented by 1 each time.
sample_indices: Indices of sequences to sample from.
If not provided, will sample from all sequences.
shape = [n]
logprobs: Log-probabilities of the sampled tokens.
Only used for saving the logprobs if save_logprobs is True.
shape = [batch_size, vocab_size]
modify_greedy_probs: Whether to modify the greedy probabilities
for speculative sampling (sampled token = 1.0,
everything else = 0.0).
save_logprobs: Whether to save the log-probabilities of the
sampled tokens to a tensor.
_save_modified_probs: Whether to save the modified probabilities
(including gumbel noise) of the sampled tokens to a tensor.
DOES NOT include the modification done by modify_greedy_probs
(because we want to use the unmodified probs to pick the best
split in case of multi-split sampling).
This is exposed only for testing.
Returns:
sampled_tokens: shape = [n, max_best_of]
sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None
sampled_modified_probs: shape = [n, max_best_of]
if save_modified_probs else None
"""
if sample_indices is None:
sample_indices = torch.arange(0, probs.shape[0], device=probs.device)
sampled_tokens_size = (sample_indices.size(0), max_best_of)
if save_logprobs:
if logprobs is None:
raise ValueError(
"logprobs tensor must be provided if save_logprobs is True")
sampled_logprobs_size = sampled_tokens_size
else:
# Empty tensors to invoke the kernel
sampled_logprobs_size = (0, 0)
logprobs = probs
if _save_modified_probs:
sampled_modified_probs_size = sampled_tokens_size
else:
# Empty tensors to invoke the kernel
sampled_modified_probs_size = (0, 0)
# If the number of columns in probs is too large for Triton to handle,
# we split the tensor and sample from each split separately, and then
# do an argmax+gather to combine the results.
n_splits = get_num_triton_sampler_splits(probs.shape[1])
if n_splits > 1:
(sampled_tokens, sampled_logprobs,
sampled_modified_probs) = _multi_split_sample(
probs,
seeds,
n_splits,
sampled_tokens_size,
sampled_logprobs_size,
sample_indices,
logprobs=logprobs,
modify_greedy_probs=modify_greedy_probs,
save_logprobs=save_logprobs)
else:
sampled_tokens = torch.empty(sampled_tokens_size,
dtype=torch.long,
device=probs.device)
sampled_logprobs = torch.empty(sampled_logprobs_size,
dtype=probs.dtype,
device=probs.device)
sampled_modified_probs = torch.empty(sampled_modified_probs_size,
dtype=probs.dtype,
device=probs.device)
n_samples = sample_indices.shape[0]
n_cols = probs.shape[1]
uniform_noise = seeded_uniform(n_samples,
max_best_of,
n_cols,
seeds=seeds.flatten(),
device=probs.device,
dtype=probs.dtype)
_sample(
probs,
logprobs,
sample_indices,
sampled_tokens,
sampled_logprobs,
sampled_modified_probs,
seeds,
uniform_noise,
modify_greedy_probs=modify_greedy_probs,
save_logprobs=save_logprobs,
save_modified_probs=_save_modified_probs,
)
return (sampled_tokens, sampled_logprobs if save_logprobs else None,
sampled_modified_probs if _save_modified_probs else None)
def _sample(probs: torch.Tensor,
logprobs: torch.Tensor,
sample_indices: torch.Tensor,
output_samples: torch.Tensor,
output_logprobs: torch.Tensor,
output_modified_probs: torch.Tensor,
seeds: torch.Tensor,
uniform_noise: torch.Tensor,
*,
modify_greedy_probs: bool = False,
save_logprobs: bool = True,
save_modified_probs: bool = False) -> torch.Tensor:
"""Sample tokens from probs.
Args:
probs [batch_size, vocab_size]: probs to sample from.
logprobs [batch_size, vocab_size]: logprobs (used when
save_logprobs is True).
sample_indices [n]: Indices of the samples to use for each row of probs.
output_samples [n, n_best]: Output tensor to store samples in.
output_logprobs [n, n_best]: Output tensor to store logprobs in.
output_modified_probs [n, n_best]: Output tensor to store
probs of chosen tokens in (modified with noise).
seeds [n]: Seeds to use for sampling. If the seed is 0, we use
greedy sampling. Note this is ONLY used for determining
whether to use random sampling or not. The actual random
noise should be passed as uniform_noise.
uniform_noise [batch_size, n_best, vocab_size]: Uniform
noise to use for random sampling (will be converted
to exponential gumbel noise by the kernel).
modify_greedy_probs: If True, we modify the probs tensor in-place
to encode the sampling method used for each row. This is used
in speculative decoding. Only applies in greedy decoding.
save_logprobs: If True, we save the logprobs of the sampled tokens
in the output_logprobs tensor.
save_modified_probs: If True, we save the modified probs (with noise)
of the sampled tokens in the output_modified_probs tensor.
DOES NOT include the modification done by modify_greedy_probs
(because we want to use the unmodified probs to pick the best
split in case of multi-split sampling).
"""
n_samples = sample_indices.shape[0]
n_cols = probs.shape[1]
n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1
# The block size is the smallest power of two greater than or equal to
# the number of columns in probs
block_size = triton.next_power_of_2(n_cols)
num_warps = 4
# Manual tuning. This seems to give best performance on A100 for
# simple kernels like this.
if block_size >= 8192:
num_warps = 32
elif block_size >= 4096:
num_warps = 16
elif block_size >= 2048:
num_warps = 8
# Enqueue kernel. The 1D launch grid is simple: we have one kernel
# instance per row of the probs matrix
_sample_triton[(n_samples, n_best)](
sample_indices,
output_samples,
output_logprobs,
output_modified_probs,
probs,
logprobs,
seeds,
uniform_noise,
output_samples.stride(0),
probs.stride(0),
uniform_noise.stride(0),
uniform_noise.stride(1) if n_best > 1 else 1,
n_samples,
n_cols,
n_best,
num_warps=num_warps,
block_size=block_size,
modify_greedy_probs=modify_greedy_probs,
save_logprobs=save_logprobs,
save_modified_probs=save_modified_probs,
)
return output_samples, output_logprobs, output_modified_probs
@triton.jit
def _uniform_to_exponential(uniform_noise):
"""Convert uniform samples to exponential samples."""
# tl.rand returns values in [0, 1), so we clamp lower bound
# to _EPS to avoid log(0) and thus division by 0 later
lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)
uniform_noise = tl.maximum(uniform_noise, lb)
# Use the inversion method to turn uniform samples
# into exponential samples
exponential_noise = -tl.log(uniform_noise)
return exponential_noise
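The sampling kernel below divides each probability row by this exponential noise and takes the argmax, which is the exponential-race form of the Gumbel-max trick: the winner is distributed proportionally to probs. A rough torch-only reference of the same idea (the helper name is ours, not part of the kernel):
def _exponential_race_reference(probs: torch.Tensor) -> torch.Tensor:
    """Sample one token id per row of probs via the exponential-race trick."""
    uniform = torch.rand_like(probs).clamp_min_(_EPS)  # avoid log(0), as above
    exponential = -torch.log(uniform)                  # Exp(1) noise per entry
    return torch.argmax(probs / exponential, dim=-1)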
@triton.jit
def _sample_triton(
sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,
output_logprobs_ptr: torch.Tensor,
output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,
logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,
uniform_noise_ptr: torch.Tensor, output_row_stride: int,
probs_row_stride: int, uniform_noise_row_stride: int,
uniform_noise_best_stride: int, n_samples: int, n_cols: int,
n_best: int, block_size: tl.constexpr,
modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,
save_modified_probs: tl.constexpr):
# The rows are independent, so we parallelize across those
sample_idx = tl.program_id(0)
best_idx = tl.program_id(1)
# Load the row index from DRAM
row_idx = tl.load(sample_indices_ptr + sample_idx)
seed = tl.load(seeds_ptr + sample_idx)
uses_random_sampling = seed != 0
# The stride represents how much we need to increase the
# pointer to advance 1 row
row_start_ptr = probs_ptr + row_idx * probs_row_stride
# The block size is the next power of two greater than or equal to n_cols,
# so we can fit each row in a single block
col_offsets = tl.arange(0, block_size)
# Load the row into SRAM, using a mask since block_size may be larger than n_cols
row = tl.load(row_start_ptr + col_offsets,
mask=col_offsets < n_cols,
other=float("-inf"))
if uses_random_sampling:
uniform_noise_start_ptr = (uniform_noise_ptr +
sample_idx * uniform_noise_row_stride +
best_idx * uniform_noise_best_stride)
uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets,
mask=col_offsets < n_cols,
other=0.5)
exponential_noise = _uniform_to_exponential(uniform_noise)
row /= exponential_noise
sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True)
# clamp sampled token to n_cols - 1
# this should not be necessary, but we do it
# just in case
if sampled_token >= n_cols:
sampled_token = n_cols - 1
# Write back output to DRAM
output_row_start_ptr = (output_ptr + sample_idx * output_row_stride +
best_idx)
tl.store(output_row_start_ptr, sampled_token)
if modify_greedy_probs: # noqa
if not uses_random_sampling:
# Set the probability of the sampled token to 1, all other
# tokens to zero. This is used in speculative decoding where
# the sampling method must be encoded within the sampled
# probability distributions.
row = tl.where(col_offsets == sampled_token, 1.0, 0.0)
tl.store(row_start_ptr + col_offsets,
row,
mask=col_offsets < n_cols)
if save_modified_probs:
output_row_start_ptr = (output_modified_probs_ptr +
sample_idx * output_row_stride + best_idx)
tl.store(output_row_start_ptr, sampled_value)
if save_logprobs:
# Load the logprob of the sampled token for this row
sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride +
sampled_token)
# Write back output to DRAM
output_row_start_ptr = (output_logprobs_ptr +
sample_idx * output_row_stride + best_idx)
tl.store(output_row_start_ptr, sampled_logprob)
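Putting it together, a minimal end-to-end sketch of the public sample entry point (illustrative only; it assumes a CUDA device with Triton available and uses made-up shapes). Seed 0 marks a row as greedy, non-zero seeds give seeded random sampling.
import torch
from vllm.model_executor.layers.ops.sample import (
    sample, get_num_triton_sampler_splits)

bs, vocab_size = 4, 32000
probs = torch.softmax(torch.randn(bs, vocab_size, device="cuda"), dim=-1)
logprobs = torch.log(probs)
n_splits = get_num_triton_sampler_splits(vocab_size)  # 1 for a 32k vocab
seeds = torch.tensor([[0, 42, 42, 7]], dtype=torch.long,
                     device="cuda").expand(n_splits, bs).contiguous()
tokens, token_logprobs, _ = sample(probs=probs,
                                   logprobs=logprobs,
                                   seeds=seeds,
                                   max_best_of=2,
                                   save_logprobs=True)
assert tokens.shape == (bs, 2)
assert token_logprobs.shape == (bs, 2)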


@ -12,6 +12,7 @@ from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
SamplerOutput, SequenceData, SequenceGroupOutput,
SequenceOutput)
from vllm.model_executor.layers.ops.sample import (sample as sample_triton)
from vllm.utils import is_neuron
@ -114,7 +115,8 @@ class Sampler(nn.Module):
logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)
# Sample the next tokens.
sample_results = _sample(probs, logprobs, sampling_metadata)
sample_results = _sample(probs, logprobs, sampling_metadata,
sampling_tensors)
# Get the logprobs query results.
prompt_logprobs, sample_logprobs = _get_logprobs(
logprobs, sampling_metadata, sample_results)
@ -375,7 +377,7 @@ def _multinomial(
return probs.div_(q).argmax(dim=1).view(-1, num_samples)
def _sample(
def _sample_with_torch(
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
@ -394,7 +396,7 @@ def _sample(
# Counterintuitively, having two loops here is actually faster.
# The first loop can run without waiting on GPU<->CPU sync.
for sampling_type in SamplingType:
sample_indices = categorized_sample_indices[sampling_type]
sample_indices = categorized_sample_indices[sampling_type][:, 0]
num_tokens = len(sample_indices)
if num_tokens == 0:
continue
@ -407,17 +409,19 @@ def _sample(
greedy_samples = torch.argmax(logprobs[sample_indices.long()],
dim=-1)
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
max_best_of = 1
max_best_of_in_batch = 1
for seq_group, is_prompt in zip(seq_groups, is_prompts):
if is_prompt:
_, sampling_params = seq_group
max_best_of = max(max_best_of, sampling_params.best_of)
max_best_of_in_batch = max(max_best_of_in_batch,
sampling_params.best_of)
seeded_args = {} if sampling_type == SamplingType.RANDOM else {
"seq_groups": seq_groups,
"generators": sampling_metadata.generators,
}
multinomial_samples[sampling_type] = _multinomial(
probs[sample_indices.long()], max_best_of, **seeded_args)
probs[sample_indices.long()], max_best_of_in_batch,
**seeded_args)
elif sampling_type == SamplingType.BEAM:
beam_search_logprobs = logprobs[sample_indices]
else:
@ -448,6 +452,99 @@ def _sample(
return sample_results
def _sample_with_triton_kernel(
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
) -> List[Tuple[List[int], List[int]]]:
categorized_seq_group_ids = {t: [] for t in SamplingType}
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
_, sampling_params = seq_group
sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i)
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
sample_metadata = {}
max_best_of_in_batch = 1
# Counterintuitively, having two loops here is actually faster.
# The first loop can run without waiting on GPU<->CPU sync.
for sampling_type in SamplingType:
sample_indices = categorized_sample_indices[sampling_type][:, 0]
sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
num_tokens = len(sample_indices)
if num_tokens == 0:
continue
seq_group_ids = categorized_seq_group_ids[sampling_type]
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
is_prompts, sample_indices,
sampled_token_indices)
if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
SamplingType.RANDOM_SEED):
for seq_group, is_prompt in zip(seq_groups, is_prompts):
if is_prompt:
_, sampling_params = seq_group
max_best_of_in_batch = max(max_best_of_in_batch,
sampling_params.best_of)
elif sampling_type == SamplingType.BEAM:
beam_search_logprobs = logprobs[sample_indices]
else:
raise ValueError(f"Unsupported sampling type: {sampling_type}")
sampled_tokens, _, _ = sample_triton(
probs=probs,
seeds=sampling_tensors.sampling_seeds,
max_best_of=max_best_of_in_batch,
sample_indices=sampling_tensors.sample_indices,
logprobs=logprobs,
# don't save logprobs because we have logic for that below
# TODO: use this instead of the CPU-based logic below
save_logprobs=False,
)
# GPU<->CPU sync happens in the loop below.
for sampling_type in SamplingType:
if sampling_type not in sample_metadata:
continue
(seq_group_ids, seq_groups, is_prompts, sample_indices,
sampled_token_indices) = sample_metadata[sampling_type]
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(
seq_groups, sampled_tokens[sampled_token_indices][:, 0])
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
sample_results = _random_sample(
seq_groups, is_prompts, sampled_tokens[sampled_token_indices])
elif sampling_type == SamplingType.BEAM:
sample_results = _beam_search_sample(seq_groups, is_prompts,
sampling_metadata.seq_data,
beam_search_logprobs)
sample_results_dict.update(zip(seq_group_ids, sample_results))
sample_results = [
sample_results_dict[i]
for i in range(len(sampling_metadata.seq_groups))
]
return sample_results
def _sample(
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
) -> List[Tuple[List[int], List[int]]]:
return _sample_with_torch(probs, logprobs, sampling_metadata)
# TODO: Enable once Triton kernel & associated code is faster.
# return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
# sampling_tensors)
def _get_logprobs(
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,


@ -2,12 +2,16 @@ from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
import torch
import random
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.sequence import SequenceData
from vllm.utils import in_wsl, is_neuron
from vllm.model_executor.layers.ops.sample import (
get_num_triton_sampler_splits)
_SAMPLING_EPS = 1e-5
_SEED_0_REPLACEMENT = 3403598558
class SamplingMetadata:
@ -67,14 +71,28 @@ class SamplingTensors:
presence_penalties: torch.Tensor
frequency_penalties: torch.Tensor
repetition_penalties: torch.Tensor
sampling_seeds: torch.Tensor
sample_indices: torch.Tensor
extra_seeds: Optional[torch.Tensor]
prompt_tokens: torch.Tensor
output_tokens: torch.Tensor
@classmethod
def from_sampling_metadata(
cls, sampling_metadata: "SamplingMetadata", vocab_size: int,
device: torch.device,
dtype: torch.dtype) -> Tuple["SamplingTensors", bool, bool, bool]:
cls,
sampling_metadata: "SamplingMetadata",
vocab_size: int,
device: torch.device,
dtype: torch.dtype,
*,
extra_seeds_to_generate: int = 0,
extra_entropy: Optional[Tuple[int, ...]] = None
) -> Tuple["SamplingTensors", bool, bool, bool]:
"""
extra_seeds_to_generate: extra seeds to generate using the
user-defined seed for each sequence.
extra_entropy: extra entropy to use when generating seeds.
"""
prompt_tokens: List[List[int]] = []
output_tokens: List[List[int]] = []
top_ks: List[int] = []
@ -84,9 +102,18 @@ class SamplingTensors:
presence_penalties: List[float] = []
frequency_penalties: List[float] = []
repetition_penalties: List[float] = []
sampling_seeds: List[int] = []
sample_indices: List[int] = []
prompt_best_of: List[int] = []
do_penalties = False
do_top_p_top_k = False
do_min_p = False
# We need one base seed per Triton slice.
seeds_to_generate = (extra_seeds_to_generate +
get_num_triton_sampler_splits(vocab_size))
sample_indices_start_idx = 0
for i, seq_group in enumerate(sampling_metadata.seq_groups):
seq_ids, sampling_params = seq_group
temperature = sampling_params.temperature
@ -95,6 +122,10 @@ class SamplingTensors:
r = sampling_params.repetition_penalty
top_p = sampling_params.top_p
min_p = sampling_params.min_p
seed = sampling_params.seed
is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
# k should not be greater than the vocab size.
top_k = min(sampling_params.top_k, vocab_size)
top_k = vocab_size if top_k == -1 else top_k
@ -112,6 +143,7 @@ class SamplingTensors:
or abs(f) >= _SAMPLING_EPS
or abs(r - 1.0) >= _SAMPLING_EPS):
do_penalties = True
if (i < sampling_metadata.num_prompts
and sampling_params.prompt_logprobs is not None):
# For tokens in the prompt that we only need to get
@ -138,10 +170,34 @@ class SamplingTensors:
frequency_penalties += [f] * len(seq_ids)
repetition_penalties += [r] * len(seq_ids)
is_prompt = i < sampling_metadata.num_prompts
if is_prompt:
prompt_best_of.append(sampling_params.best_of)
prompt_len = sampling_metadata.prompt_lens[i]
if sampling_params.prompt_logprobs is not None:
# NOTE: the sampling position is the last token
# in the prompt
sample_indices_start_idx += prompt_len - 1
for seq_id in seq_ids:
seq_data = sampling_metadata.seq_data[seq_id]
extra_entropy = extra_entropy or ()
seq_seeds = cls._get_sequence_seeds(
seed,
seq_data.get_len(),
*extra_entropy,
seq_id,
seeds_to_generate=seeds_to_generate,
is_greedy=is_greedy)
sampling_seeds.append(seq_seeds)
sample_indices.append(sample_indices_start_idx)
sample_indices_start_idx += 1
sampling_tensors = SamplingTensors.from_lists(
temperatures, top_ps, top_ks, min_ps, presence_penalties,
frequency_penalties, repetition_penalties, prompt_tokens,
output_tokens, vocab_size, device, dtype)
frequency_penalties, repetition_penalties, sampling_seeds,
sample_indices, prompt_tokens, output_tokens, vocab_size,
extra_seeds_to_generate, device, dtype)
return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)
@classmethod
@ -150,9 +206,10 @@ class SamplingTensors:
presence_penalties: List[float],
frequency_penalties: List[float],
repetition_penalties: List[float],
sampling_seeds: List[int], sample_indices: List[int],
prompt_tokens: List[List[int]],
output_tokens: List[List[int]], vocab_size: int,
device: torch.device,
extra_seeds_to_generate: int, device: torch.device,
dtype: torch.dtype) -> "SamplingTensors":
# Note that the performance will be very bad without
# pinned memory.
@ -210,6 +267,12 @@ class SamplingTensors:
dtype=torch.int,
pin_memory=pin_memory,
)
sample_indices_t = torch.tensor(
sample_indices,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
)
prompt_tensor = torch.tensor(
prompt_padded_tokens,
device="cpu",
@ -222,8 +285,28 @@ class SamplingTensors:
dtype=torch.long,
pin_memory=pin_memory,
)
# need to transpose and make contiguous to
# copy the tensor correctly.
# [batch_size, n_seeds] -> [n_seeds, batch_size]
sampling_seeds_t = torch.tensor(
sampling_seeds,
device="cpu",
dtype=torch.long,
pin_memory=pin_memory,
).T.contiguous()
# Because the memory is pinned, we can do non-blocking
# transfer to device.
# How many seeds the sample operation itself will need.
num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate
sampling_seeds_gpu = sampling_seeds_t.to(device=device,
non_blocking=True)
extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:]
if not extra_seeds_gpu.numel():
extra_seeds_gpu = None
sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
return cls(
temperatures=temperatures_t.to(device=device, non_blocking=True),
top_ps=top_ps_t.to(device=device, non_blocking=True),
@ -237,4 +320,38 @@ class SamplingTensors:
non_blocking=True),
prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
output_tokens=output_tensor.to(device=device, non_blocking=True),
sampling_seeds=sampling_seeds_gpu,
sample_indices=sample_indices_t.to(device=device,
non_blocking=True),
extra_seeds=extra_seeds_gpu,
)
@staticmethod
def _get_sequence_seeds(
seed: int,
*extra_entropy: int,
seeds_to_generate: int,
is_greedy: bool,
):
"""Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
if not is_greedy:
if seed is None:
randint_fn = random.randint
else:
generator = random.Random(str((seed, ) + extra_entropy))
randint_fn = generator.randint
lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
# If the user/random draw sets seed = 0 but the request should
# use random sampling, we need to change it to something
# else. We use a constant in that case.
# This way we don't need to create and load a bool
# matrix in the sampling kernel, which reduces CPU
# overhead and latency.
seq_seeds = [
randint_fn(lo, hi) or _SEED_0_REPLACEMENT
for _ in range(seeds_to_generate)
]
else:
# For the kernel, seed == 0 means greedy decoding.
seq_seeds = [0] * seeds_to_generate
return seq_seeds
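For illustration, the child-seed derivation above is deterministic for a given (seed, extra entropy) pair and collapses to zeros for greedy requests (a sketch; the argument values are made up):
from vllm.model_executor.sampling_metadata import SamplingTensors

a = SamplingTensors._get_sequence_seeds(1234, 7, seeds_to_generate=2,
                                        is_greedy=False)
b = SamplingTensors._get_sequence_seeds(1234, 7, seeds_to_generate=2,
                                        is_greedy=False)
c = SamplingTensors._get_sequence_seeds(1234, 8, seeds_to_generate=2,
                                        is_greedy=False)
assert a == b   # same base seed + extra entropy -> same child seeds
assert a != c   # different extra entropy -> different child seeds
assert SamplingTensors._get_sequence_seeds(
    1234, 7, seeds_to_generate=2, is_greedy=True) == [0, 0]  # greedy sentinel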


@ -242,6 +242,9 @@ class Sequence:
def get_token_ids(self) -> List[int]:
return self.data.get_token_ids()
def get_prompt_token_ids(self) -> List[int]:
return self.data.get_prompt_token_ids()
def get_last_token_id(self) -> int:
return self.data.get_last_token_id()


@ -408,6 +408,7 @@ class ModelRunner:
selected_token_start_idx = 0
categorized_sample_indices = {t: [] for t in SamplingType}
categorized_sample_indices_start_idx = 0
categorized_sampled_token_indices_start_idx = 0
pin_memory = not self.in_wsl and not self.device_config.is_neuron
max_subquery_len = max(subquery_lens) if subquery_lens else 1
@ -425,9 +426,12 @@ class ModelRunner:
categorized_sample_indices_start_idx += subquery_len - 1
categorized_sample_indices[
sampling_params.sampling_type].append(
categorized_sample_indices_start_idx)
sampling_params.sampling_type].append([
categorized_sample_indices_start_idx,
categorized_sampled_token_indices_start_idx
])
categorized_sample_indices_start_idx += 1
categorized_sampled_token_indices_start_idx += 1
if sampling_params.prompt_logprobs is not None:
selected_token_indices.extend(
@ -449,9 +453,17 @@ class ModelRunner:
categorized_sample_indices[
sampling_params.sampling_type].extend(
range(categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx + num_seqs))
zip(
range(
categorized_sample_indices_start_idx,
categorized_sample_indices_start_idx +
num_seqs),
range(
categorized_sampled_token_indices_start_idx,
categorized_sampled_token_indices_start_idx +
num_seqs)))
categorized_sample_indices_start_idx += num_seqs
categorized_sampled_token_indices_start_idx += num_seqs
if sampling_params.seed is not None:
generators.append(seq_group_metadata.state.generator)
@ -459,12 +471,14 @@ class ModelRunner:
selected_token_indices = _async_h2d(selected_token_indices,
dtype=torch.long,
target_device=self.device,
pin_memory=pin_memory)
pin_memory=not self.in_wsl)
categorized_sample_indices = {
t: _async_h2d(seq_ids,
dtype=torch.int,
target_device=self.device,
pin_memory=pin_memory)
t: _maybe_expand_dim(
_async_h2d(seq_ids,
dtype=torch.int,
target_device=self.device,
pin_memory=pin_memory), 2, 2)
for t, seq_ids in categorized_sample_indices.items()
}
@ -884,3 +898,11 @@ def _async_h2d(
) -> torch.Tensor:
t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
return t.to(device=target_device, non_blocking=True)
def _maybe_expand_dim(tensor: torch.Tensor,
target_dims: int,
size: int = 1) -> torch.Tensor:
if tensor.ndim < target_dims:
tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
return tensor
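A quick illustration of the new [sample_idx, sampled_token_idx] pair format and the helper above (values are assumptions; it relies on the _maybe_expand_dim defined in this file):
import torch

pairs = torch.tensor([[0, 0], [1, 1], [5, 2]], dtype=torch.int)  # already 2D
assert _maybe_expand_dim(pairs, 2, 2).shape == (3, 2)            # passes through
flat = torch.tensor([7], dtype=torch.int)                        # 1D fallback
assert _maybe_expand_dim(flat, 2, 1).shape == (1, 1)             # expanded to 2D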