mirror of https://github.com/vllm-project/vllm
[1/n] Triton sampling kernel (#3186)
Co-authored-by: Roger Wang <136131678+ywang96@users.noreply.github.com>
This commit is contained in: parent 80e254834d, commit 426ec4ec67
@@ -0,0 +1,51 @@
import torch
import pytest
import random

from vllm.model_executor.layers.ops.rand import seeded_uniform
from vllm.model_executor.utils import set_random_seed


@pytest.mark.parametrize("dtype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("use_3d", [True, False])
def test_seeded_uniform(dtype: torch.dtype, use_3d: bool):
    device = "cuda"
    for seed in range(512):
        set_random_seed(seed)
        rows = random.randint(1, 512)
        cols = random.randint(1, 64000)
        if use_3d:
            third_dim = random.randint(2, 10)
            dims = [rows, third_dim, cols]
        else:
            dims = [rows, cols]
        seeds = torch.randint(torch.iinfo(torch.long).min,
                              torch.iinfo(torch.long).max, (rows, ),
                              device=device)

        # Test that the same seed produces the same output
        out = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device)
        out2 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device)
        torch.testing.assert_close(out, out2)
        # del to save memory
        del out2

        out3 = seeded_uniform(*dims, seeds=seeds, dtype=dtype, device=device)
        torch.testing.assert_close(out, out3)
        # del to save memory
        del out3

        # Initialize out tensor with garbage to ensure that it is overwritten
        out_with_tensor = seeded_uniform(
            *dims,
            out=torch.full(
                (*dims, ),
                -1,
                dtype=dtype,
                device=device,
            ),
            seeds=seeds,
            dtype=dtype,
        )
        torch.testing.assert_close(out, out_with_tensor)
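For illustration only (not part of the commit): a minimal sketch of the per-row seeding contract the test above exercises, assuming a CUDA device and the import path added by this PR. Rows that share a seed value are expected to produce identical samples, while rows with different seeds should differ.

# Editorial sketch, not from the diff: per-row seeding contract of seeded_uniform.
import torch

from vllm.model_executor.layers.ops.rand import seeded_uniform

seeds = torch.tensor([42, 42, 7], dtype=torch.long, device="cuda")
x = seeded_uniform(3, 1024, seeds=seeds, dtype=torch.float32, device="cuda")

# Rows 0 and 1 use the same seed, so they should match; row 2 should not.
torch.testing.assert_close(x[0], x[1])
assert not torch.equal(x[0], x[2])
assert torch.all((x >= 0) & (x < 1))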
@@ -0,0 +1,196 @@
import gc

import torch
import pytest
import triton
import triton.language as tl

from vllm.model_executor.layers.ops.sample import (
    _uniform_to_exponential, sample, get_num_triton_sampler_splits,
    MAX_TRITON_N_COLS)
from vllm.model_executor.utils import set_random_seed
from vllm.model_executor.sampling_metadata import SamplingTensors

SINGLE_SPLIT_VOCAB_SIZE = 32000  # llama/mistral/mixtral vocab size
MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100


@pytest.fixture(autouse=True)
def _cleanup():
    yield
    gc.collect()
    torch.cuda.empty_cache()


@triton.jit
def _uniform_to_exponential_kernel(input, output, n: tl.constexpr):
    idx = tl.arange(0, n)
    x = tl.load(input + idx)
    y = _uniform_to_exponential(x)
    tl.store(output + idx, y)


def test_uniform_to_exponential():
    """Test that we can convert uniform to exponential without div by 0."""
    input = torch.tensor([0.0, 1.0 - torch.finfo(torch.float32).eps],
                         dtype=torch.float32,
                         device="cuda")
    output = torch.zeros(input.shape, dtype=torch.float32, device="cuda")
    _uniform_to_exponential_kernel[(1, )](input, output, 2)
    assert torch.all(torch.isfinite(output))
    assert torch.all(output > 0)
    assert torch.all(torch.isfinite(torch.full_like(output, 1.0) / output))


@pytest.mark.parametrize("random_sampling", [True, False, "mixed"])
@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("modify_greedy_probs", [True, False])
@pytest.mark.parametrize("seed", [1337])
@pytest.mark.parametrize("vocab_size",
                         [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE])
@pytest.mark.parametrize("save_logprobs", [True, False])
def test_sample_decoding_only(random_sampling, max_best_of,
                              modify_greedy_probs, seed, vocab_size,
                              save_logprobs):
    set_random_seed(seed)
    bs = 8
    probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda")
    for i in range(bs):
        probs[i, i * (vocab_size // bs)] = 1.0
    logprobs = torch.rand_like(probs)
    sample_indices = torch.arange(bs, dtype=torch.long, device="cuda")
    n_splits = get_num_triton_sampler_splits(probs.shape[1])
    if random_sampling == "mixed":
        random_sampling_mask = (torch.rand(
            (1, bs), device="cuda") < 0.5).expand(n_splits, bs)
    elif random_sampling:
        random_sampling_mask = torch.ones((n_splits, bs),
                                          dtype=torch.bool,
                                          device="cuda")
    else:
        random_sampling_mask = torch.zeros((n_splits, bs),
                                           dtype=torch.bool,
                                           device="cuda")

    seeds = torch.randint(1,
                          torch.iinfo(torch.long).max, (n_splits, bs),
                          device="cuda").mul_(random_sampling_mask)
    sampled_tokens, sampled_logprobs, sampled_modified_probs = sample(
        probs=probs,
        logprobs=logprobs,
        sample_indices=sample_indices,
        seeds=seeds,
        max_best_of=max_best_of,
        modify_greedy_probs=modify_greedy_probs,
        save_logprobs=save_logprobs,
        _save_modified_probs=True)
    assert sampled_tokens.shape == (bs, max_best_of)
    for i in range(bs):
        assert torch.all(sampled_tokens[i] == i * (vocab_size // bs))
        request_uses_random_sampling = random_sampling_mask[0, i]
        if modify_greedy_probs and not request_uses_random_sampling:
            # If we are modifying greedy probs and the request is greedy,
            # we want to make sure the probs tensor is modified in place
            assert torch.allclose(
                probs[i][sampled_tokens[i]],
                torch.full_like(probs[i][sampled_tokens[i]], 1.0))
            assert torch.sum(probs[i]) == 1.0
            assert torch.allclose(
                sampled_modified_probs[i][0],
                torch.full_like(sampled_modified_probs[i][0], 1.0))
        elif request_uses_random_sampling:
            # If the request is random, we want to make sure
            # sampled_modified_probs tensor has noise added
            # (and thus is different from probs tensor)
            assert not torch.allclose(sampled_modified_probs[i][0],
                                      probs[i][sampled_tokens[i]])
        elif not request_uses_random_sampling:
            # If the request is greedy and we are not modifying greedy probs,
            # we want to make sure sampled_modified_probs tensor is the same as
            # the probs tensor.
            assert torch.allclose(sampled_modified_probs[i][0],
                                  probs[i][sampled_tokens[i]])

    if save_logprobs:
        assert sampled_logprobs.shape == (bs, max_best_of)
        for i in range(bs):
            for best_of in range(max_best_of):
                assert torch.all(sampled_logprobs[i] == logprobs[i][
                    sampled_tokens[i, best_of]])
    else:
        assert sampled_logprobs is None


@pytest.mark.parametrize("random_sampling", [True, False, "mixed"])
@pytest.mark.parametrize("max_best_of", [1, 2, 3, 4, 5])
@pytest.mark.parametrize("modify_greedy_probs", [True, False])
@pytest.mark.parametrize("seed", [1337])
@pytest.mark.parametrize("vocab_size",
                         [SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE])
def test_sample_prompt_logprobs(random_sampling, max_best_of,
                                modify_greedy_probs, seed, vocab_size):
    set_random_seed(seed)
    prompt_sizes = [16, 32, 64, 128] * 2
    samples = 8
    bs = samples + sum(prompt_sizes)
    probs = torch.zeros((bs, vocab_size), dtype=torch.float32, device="cuda")
    for i in range(bs):
        probs[i, i * (vocab_size // bs)] = 1.0
    logprobs = torch.rand_like(probs)
    sample_indices = torch.tensor(prompt_sizes,
                                  dtype=torch.long,
                                  device="cuda").cumsum_(0)
    n_splits = get_num_triton_sampler_splits(probs.shape[1])
    if random_sampling == "mixed":
        random_sampling_mask = torch.rand(
            (n_splits, samples), device="cuda") < 0.5
    elif random_sampling:
        random_sampling_mask = torch.ones((n_splits, samples),
                                          dtype=torch.bool,
                                          device="cuda")
    else:
        random_sampling_mask = torch.zeros((n_splits, samples),
                                           dtype=torch.bool,
                                           device="cuda")

    seeds = torch.randint(1,
                          torch.iinfo(torch.long).max, (n_splits, samples),
                          device="cuda").mul_(random_sampling_mask)
    sampled_tokens, sampled_logprobs, _ = sample(
        probs=probs,
        logprobs=logprobs,
        sample_indices=sample_indices,
        seeds=seeds,
        max_best_of=max_best_of,
        modify_greedy_probs=modify_greedy_probs,
        save_logprobs=True)
    assert sampled_tokens.shape == (samples, max_best_of)
    assert sampled_logprobs.shape == (samples, max_best_of)
    for i, t in enumerate(sample_indices):
        assert torch.all(sampled_tokens[i] == t * (vocab_size // bs))
        for best_of in range(max_best_of):
            assert torch.all(sampled_logprobs[i] == logprobs[sample_indices[i]]
                             [sampled_tokens[i, best_of]])


@pytest.mark.parametrize("seed", list(range(16)))
def test_get_sequence_seeds(seed):
    """Ensure that we get a different child seed from base
    seed + extra entropy"""
    starting_seed = seed
    seq_seed = None
    extra_entropy = 1
    for i in range(512):
        new_seq_seed = SamplingTensors._get_sequence_seeds(starting_seed,
                                                           i,
                                                           seeds_to_generate=1,
                                                           is_greedy=False)[0]
        new_seq_seed_extra_entropy = SamplingTensors._get_sequence_seeds(
            starting_seed,
            i,
            extra_entropy,
            seeds_to_generate=1,
            is_greedy=False)[0]
        assert new_seq_seed_extra_entropy != new_seq_seed
        assert seq_seed != new_seq_seed
        seq_seed = new_seq_seed
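For illustration only (not part of the commit): the tests above encode greedy requests by zeroing their seeds with .mul_(random_sampling_mask), matching the kernel convention below that seed == 0 means greedy decoding. A small CPU-only sketch of building such a mixed seeds tensor:

# Editorial sketch, not from the diff: zero seeds mark greedy rows for the kernel.
import torch

n_splits, bs = 1, 4
is_random = torch.tensor([True, False, True, False])
seeds = torch.randint(1, torch.iinfo(torch.long).max, (n_splits, bs))
seeds = seeds.mul_(is_random.view(1, bs))  # greedy rows become 0

assert torch.all((seeds == 0) == ~is_random.view(1, bs))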
@@ -302,11 +302,11 @@ def test_sampler_logits_processors(seed: int, device: str):
     batch_size = random.randint(1, 256)
     input_tensor, _, sampler, model_runner = _prepare_test(batch_size)

-    # This sample logits processor gives infinite score to the i-th token,
+    # This sample logits processor gives maximum score to the i-th token,
     # where i is the length of the input sequence.
     # We therefore expect the output token sequence to be [0, 1, 2, ...]
     def pick_ith(token_ids, logits):
-        logits[len(token_ids)] = float("inf")
+        logits[len(token_ids)] = torch.finfo(logits.dtype).max
         return logits

     seq_group_metadata_list = []

@@ -385,7 +385,7 @@ def test_sampler_top_k_top_p(seed: int, device: str):

     sample_probs = None

-    def mock_sample(probs, logprobs, sampling_metadata):
+    def mock_sample(probs, *args, **kwargs):
         nonlocal sample_probs
         sample_probs = probs
         return [[prob.topk(1, dim=-1).indices.tolist(), [0]] for prob in probs]
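For illustration only (not part of the commit): the test change above swaps float("inf") for torch.finfo(logits.dtype).max, presumably so that every downstream value stays finite while the boosted token still wins the argmax. A quick check of that property:

# Editorial sketch, not from the diff: a finfo.max logit keeps softmax and
# log_softmax finite and still dominates the argmax.
import torch

logits = torch.zeros(8)
logits[3] = torch.finfo(logits.dtype).max

assert torch.all(torch.isfinite(torch.log_softmax(logits, dim=-1)))
assert torch.all(torch.isfinite(torch.softmax(logits, dim=-1)))
assert int(torch.softmax(logits, dim=-1).argmax()) == 3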
@@ -0,0 +1,157 @@
import torch
import triton
import triton.language as tl

from typing import Optional, Union


def seeded_uniform(
    *size,
    seeds: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
    device: Optional[Union[torch.device, str]] = None,
    pin_memory: Optional[bool] = False,
) -> torch.Tensor:
    """Similar to torch.rand, but allows for seeds to be set per row.

    seeds must be a 1d tensor. The output tensor may be 1d, 2d, or 3d.
    If it is 3d, the additional seeds needed will be derived automatically
    in a deterministic fashion:
    [
        row 0: [columns_with_seed_0], [columns_with_seed_0 ^ 1], ...
    ]
    """
    n_dims = len(size)

    if n_dims > 3:
        raise ValueError("seeded_uniform only supports up to 3D tensors")

    if out is None:
        out = torch.empty(*size,
                          dtype=dtype,
                          device=device,
                          pin_memory=pin_memory)
    elif out.shape != size:
        raise ValueError("shape of out and size must be the same")

    if n_dims == 3:
        n_rows, n_3d, n_cols = out.shape
        stride_row = out.stride(0)
        stride_3d = out.stride(1)
    elif n_dims == 2:
        n_rows, n_cols = out.shape
        n_3d = 1
        stride_row = out.stride(0)
        stride_3d = 1
    else:
        n_cols = out.shape[0]
        n_rows = 1
        n_3d = 1
        stride_row = 1
        stride_3d = 1

    if seeds.ndim != 1:
        raise ValueError("seeds must be a 1D tensor")

    if seeds.numel() != n_rows:
        raise ValueError(
            "seeds must have the same number of elements as out has rows")

    # The philox PRNG Triton uses generates 4 random numbers at once.
    # Therefore, the most efficient use of it is to divide the
    # block size by 4, and then save the generated random numbers to
    # each of the 4 slices of the tensor.
    full_block_size = triton.next_power_of_2(n_cols)
    philox_block_size = max(full_block_size // 4, 1)
    n_slices = full_block_size // philox_block_size
    num_warps = 4
    # Manual tuning. This seems to give best performance on A100 for
    # simple kernels like this.
    if philox_block_size >= 8192:
        num_warps = 32
    elif philox_block_size >= 4096:
        num_warps = 16
    elif philox_block_size >= 2048:
        num_warps = 8

    _seeded_uniform_triton[(n_rows, n_3d)](
        out,
        seeds,
        stride_row,
        stride_3d,
        seeds.stride(0),
        n_rows,
        n_3d,
        n_cols,
        n_slices=n_slices,
        num_warps=num_warps,
        block_size=philox_block_size,
    )
    return out


@triton.jit
def _seeded_uniform_triton(
    out_ptr: torch.Tensor,
    seed_ptr: torch.Tensor,
    out_row_stride: int,
    out_3d_stride: int,
    seed_row_stride: int,
    n_rows: int,
    n_3d: int,
    n_cols: int,
    n_slices: tl.constexpr,
    block_size: tl.constexpr,
):
    """
    Generate a random float32 number in [0, 1) for each element in the output
    tensor. The random numbers in a row are generated using the seed for that
    row.

    Args:
        out_ptr: The output tensor.
        seed_ptr: The per-row seeds to use for random number generation.
        out_row_stride: The stride between rows of the output tensor.
        out_3d_stride: The stride between 3D slices of the output tensor.
        seed_row_stride: The stride between rows of the seed tensor.
        n_rows: The number of rows in the output tensor.
        n_3d: The size of second dimension of the output tensor,
            if output tensor is 3D.
        n_cols: The number of columns in the output tensor.
        n_slices: The number of philox outputs to use.
    """
    tl.static_assert(n_slices > 0 and n_slices <= 4, "0 < n_slices <= 4")

    # Get the row index.
    row_idx = tl.program_id(axis=0)
    three_d_idx = tl.program_id(axis=1)

    philox_offsets = tl.arange(0, block_size)
    # Get the seed for the current element.
    seed = tl.load(seed_ptr + row_idx * seed_row_stride)
    if three_d_idx > 0:
        seed ^= three_d_idx
    # Generate random numbers in [0, 1).
    out1, out2, out3, out4 = tl.rand4x(seed, philox_offsets)

    output_row_start_ptr = (out_ptr + row_idx * out_row_stride +
                            three_d_idx * out_3d_stride)
    out1_offsets = philox_offsets
    tl.store(output_row_start_ptr + out1_offsets,
             out1,
             mask=out1_offsets < n_cols)
    if n_slices > 1:
        out2_offsets = tl.arange(block_size, block_size * 2)
        tl.store(output_row_start_ptr + out2_offsets,
                 out2,
                 mask=out2_offsets < n_cols)
    if n_slices > 2:
        out3_offsets = tl.arange(block_size * 2, block_size * 3)
        tl.store(output_row_start_ptr + out3_offsets,
                 out3,
                 mask=out3_offsets < n_cols)
    if n_slices > 3:
        out4_offsets = tl.arange(block_size * 3, block_size * 4)
        tl.store(output_row_start_ptr + out4_offsets,
                 out4,
                 mask=out4_offsets < n_cols)
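For illustration only (not part of the commit): the launch logic above splits each row into up to four Philox slices so a single tl.rand4x call can fill the whole row. A sketch of that arithmetic for a llama/mistral-sized vocabulary, reusing the same formulas (requires triton to be installed, but runs without a GPU):

# Editorial sketch, not from the diff: block/slice arithmetic used by seeded_uniform.
import triton

n_cols = 32000
full_block_size = triton.next_power_of_2(n_cols)  # 32768
philox_block_size = max(full_block_size // 4, 1)  # 8192
n_slices = full_block_size // philox_block_size   # 4

num_warps = 4
if philox_block_size >= 8192:
    num_warps = 32
elif philox_block_size >= 4096:
    num_warps = 16
elif philox_block_size >= 2048:
    num_warps = 8

assert (full_block_size, philox_block_size, n_slices, num_warps) == \
    (32768, 8192, 4, 32)

For 3D outputs no extra seed tensor is needed: slice k > 0 of a row simply reuses that row's seed XOR-ed with k (seed ^= three_d_idx in the kernel above).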
@@ -0,0 +1,405 @@
import math
from typing import Tuple, Optional

import torch
import triton
import triton.language as tl

from vllm.model_executor.layers.ops.rand import seeded_uniform

_EPS = 1e-6

# This is a hardcoded limit in Triton (max block size).
MAX_TRITON_N_COLS = 131072


def get_num_triton_sampler_splits(n_cols: int) -> int:
    """Get the number of splits to use for Triton sampling.

    Triton has a limit on the number of columns it can handle, so we need to
    split the tensor and call the kernel multiple times if it's too large.
    """
    return math.ceil(n_cols / MAX_TRITON_N_COLS)


def _multi_split_sample(
    probs: torch.Tensor,
    seeds: torch.Tensor,
    n_splits: int,
    sampled_tokens_size: Tuple[int, int],
    sampled_logprobs_size: Tuple[int, int],
    sample_indices: torch.Tensor,
    *,
    logprobs: Optional[torch.Tensor] = None,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = False,
):
    """Sample tokens where vocab size is split into multiple parts
    (too large for Triton otherwise)."""
    assert seeds.ndim == 2 and seeds.shape[0] == n_splits
    split_probs = probs.tensor_split(n_splits, 1)
    split_logprobs = logprobs.tensor_split(n_splits, 1)
    sampled_tokens_tmp = [
        torch.empty(sampled_tokens_size, dtype=torch.long, device=probs.device)
        for _ in range(n_splits)
    ]
    sampled_logprobs_tmp = [
        torch.empty(sampled_logprobs_size,
                    dtype=probs.dtype,
                    device=probs.device) for _ in range(n_splits)
    ]
    # We are purposefully using sampled_tokens_size as we need to always
    # save modified probs in this case.
    sampled_modified_probs_tmp = [
        torch.empty(sampled_tokens_size,
                    dtype=probs.dtype,
                    device=probs.device) for _ in range(n_splits)
    ]
    for i in range(n_splits):
        n_samples = sample_indices.shape[0]
        n_cols = split_probs[i].shape[1]
        n_best = sampled_tokens_tmp[i].shape[1]
        uniform_noise = seeded_uniform(n_samples,
                                       n_best,
                                       n_cols,
                                       seeds=seeds[i].flatten(),
                                       device=split_probs[i].device,
                                       dtype=split_probs[i].dtype)
        # TODO(yard1): See if we can remove the contiguous() calls.
        # Will need kernel support.
        _sample(
            split_probs[i].contiguous(),
            split_logprobs[i].contiguous(),
            sample_indices,
            sampled_tokens_tmp[i],
            sampled_logprobs_tmp[i],
            sampled_modified_probs_tmp[i],
            seeds[i],
            uniform_noise,
            modify_greedy_probs=False,
            save_logprobs=save_logprobs,
            save_modified_probs=True,
        )
        if i > 0:
            # Add offset to sampled tokens
            sampled_tokens_tmp[i].add_(i * split_probs[i - 1].shape[1])
    sampled_tokens = torch.stack(sampled_tokens_tmp)
    sampled_modified_probs = torch.stack(sampled_modified_probs_tmp)
    # Reduce the results from the splits.
    sampled_modified_probs, indices = torch.max(sampled_modified_probs,
                                                dim=0,
                                                keepdim=True)
    sampled_tokens = sampled_tokens.gather(0, indices).squeeze(0)
    if save_logprobs:
        sampled_logprobs = torch.stack(sampled_logprobs_tmp)
        sampled_logprobs = sampled_logprobs.gather(0, indices).squeeze(0)
    else:
        sampled_logprobs = None
    sampled_modified_probs = sampled_modified_probs.squeeze(0)

    if modify_greedy_probs:
        # We need to modify the greedy probs for the sampled tokens.
        # We can't do this in the kernel as we need to know the
        # sampled tokens.
        probs.fill_(0.0)
        probs.scatter_(1, sampled_tokens, 1.0)

    return (sampled_tokens, sampled_logprobs, sampled_modified_probs)


def sample(
    probs: torch.Tensor,
    seeds: torch.Tensor,
    *,
    max_best_of: int = 1,
    sample_indices: Optional[torch.Tensor] = None,
    logprobs: Optional[torch.Tensor] = None,
    modify_greedy_probs: bool = False,
    save_logprobs: bool = False,
    _save_modified_probs: bool = False,  # pylint: disable=invalid-name
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Sample tokens from probs with per-sequence seeds.

    Can sample from a subset of sequences through sample_indices.

    Args:
        probs: Probabilities to sample from.
            shape = [batch_size, vocab_size]
        seeds: Per-sequence seed values.
            shape = [n, math.ceil(vocab_size / MAX_TRITON_N_COLS)]
        max_best_of: Number of samples to generate per sequence.
            Sequence seed will be incremented by 1 each time.
        sample_indices: Indices of sequences to sample from.
            If not provided, will sample from all sequences.
            shape = [n]
        logprobs: Log-probabilities of the sampled tokens.
            Only used for saving the logprobs if save_logprobs is True.
            shape = [batch_size, vocab_size]
        modify_greedy_probs: Whether to modify the greedy probabilities
            for speculative sampling (sampled token = 1.0,
            everything else = 0.0).
        save_logprobs: Whether to save the log-probabilities of the
            sampled tokens to a tensor.
        _save_modified_probs: Whether to save the modified probabilities
            (including gumbel noise) of the sampled tokens to a tensor.
            DOES NOT include the modification done by modify_greedy_probs
            (because we want to use the unmodified probs to pick the best
            split in case of multi-split sampling).
            This is exposed only for testing.

    Returns:
        sampled_tokens: shape = [n, max_best_of]
        sampled_logprobs: shape = [n, max_best_of] if save_logprobs else None
        sampled_modified_probs: shape = [n, max_best_of]
            if save_modified_probs else None
    """
    if sample_indices is None:
        sample_indices = torch.arange(0, probs.shape[0], device=probs.device)

    sampled_tokens_size = (sample_indices.size(0), max_best_of)
    if save_logprobs:
        if logprobs is None:
            raise ValueError(
                "logprobs tensor must be provided if save_logprobs is True")
        sampled_logprobs_size = sampled_tokens_size
    else:
        # Empty tensors to invoke the kernel
        sampled_logprobs_size = (0, 0)
        logprobs = probs

    if _save_modified_probs:
        sampled_modified_probs_size = sampled_tokens_size
    else:
        # Empty tensors to invoke the kernel
        sampled_modified_probs_size = (0, 0)

    # If the number of columns in probs is too large for Triton to handle,
    # we split the tensor and sample from each split separately, and then
    # do an argmax+gather to combine the results.
    n_splits = get_num_triton_sampler_splits(probs.shape[1])
    if n_splits > 1:
        (sampled_tokens, sampled_logprobs,
         sampled_modified_probs) = _multi_split_sample(
             probs,
             seeds,
             n_splits,
             sampled_tokens_size,
             sampled_logprobs_size,
             sample_indices,
             logprobs=logprobs,
             modify_greedy_probs=modify_greedy_probs,
             save_logprobs=save_logprobs)
    else:
        sampled_tokens = torch.empty(sampled_tokens_size,
                                     dtype=torch.long,
                                     device=probs.device)
        sampled_logprobs = torch.empty(sampled_logprobs_size,
                                       dtype=probs.dtype,
                                       device=probs.device)
        sampled_modified_probs = torch.empty(sampled_modified_probs_size,
                                             dtype=probs.dtype,
                                             device=probs.device)
        n_samples = sample_indices.shape[0]
        n_cols = probs.shape[1]
        uniform_noise = seeded_uniform(n_samples,
                                       max_best_of,
                                       n_cols,
                                       seeds=seeds.flatten(),
                                       device=probs.device,
                                       dtype=probs.dtype)

        _sample(
            probs,
            logprobs,
            sample_indices,
            sampled_tokens,
            sampled_logprobs,
            sampled_modified_probs,
            seeds,
            uniform_noise,
            modify_greedy_probs=modify_greedy_probs,
            save_logprobs=save_logprobs,
            save_modified_probs=_save_modified_probs,
        )
    return (sampled_tokens, sampled_logprobs if save_logprobs else None,
            sampled_modified_probs if _save_modified_probs else None)


def _sample(probs: torch.Tensor,
            logprobs: torch.Tensor,
            sample_indices: torch.Tensor,
            output_samples: torch.Tensor,
            output_logprobs: torch.Tensor,
            output_modified_probs: torch.Tensor,
            seeds: torch.Tensor,
            uniform_noise: torch.Tensor,
            *,
            modify_greedy_probs: bool = False,
            save_logprobs: bool = True,
            save_modified_probs: bool = False) -> torch.Tensor:
    """Sample tokens from probs.

    Args:
        probs [batch_size, vocab_size]: probs to sample from.
        logprobs [batch_size, vocab_size]: logprobs (used when
            save_logprobs is True).
        sample_indices [n]: Indices of the samples to use for each row of probs.
        output_samples [n, n_best]: Output tensor to store samples in.
        output_logprobs [n, n_best]: Output tensor to store logprobs in.
        output_modified_probs [n, n_best]: Output tensor to store
            probs of chosen tokens in (modified with noise).
        seeds [n]: Seeds to use for sampling. If the seed is 0, we use
            greedy sampling. Note this is ONLY used for determining
            whether to use random sampling or not. The actual random
            noise should be passed as uniform_noise.
        uniform_noise [batch_size, n_best, vocab_size]: Uniform
            noise to use for random sampling (will be converted
            to exponential gumbel noise by the kernel).
        modify_greedy_probs: If True, we modify the probs tensor in-place
            to encode the sampling method used for each row. This is used
            in speculative decoding. Only applies in greedy decoding.
        save_logprobs: If True, we save the logprobs of the sampled tokens
            in the output_logprobs tensor.
        save_modified_probs: If True, we save the modified probs (with noise)
            of the sampled tokens in the output_modified_probs tensor.
            DOES NOT include the modification done by modify_greedy_probs
            (because we want to use the unmodified probs to pick the best
            split in case of multi-split sampling).
    """
    n_samples = sample_indices.shape[0]
    n_cols = probs.shape[1]
    n_best = output_samples.shape[1] if len(output_samples.shape) > 1 else 1

    # The block size is the smallest power of two greater than the number of
    # columns in probs
    block_size = triton.next_power_of_2(n_cols)
    num_warps = 4
    # Manual tuning. This seems to give best performance on A100 for
    # simple kernels like this.
    if block_size >= 8192:
        num_warps = 32
    elif block_size >= 4096:
        num_warps = 16
    elif block_size >= 2048:
        num_warps = 8

    # Enqueue kernel. The 1D launch grid is simple: we have one kernel
    # instance per row of the probs matrix
    _sample_triton[(n_samples, n_best)](
        sample_indices,
        output_samples,
        output_logprobs,
        output_modified_probs,
        probs,
        logprobs,
        seeds,
        uniform_noise,
        output_samples.stride(0),
        probs.stride(0),
        uniform_noise.stride(0),
        uniform_noise.stride(1) if n_best > 1 else 1,
        n_samples,
        n_cols,
        n_best,
        num_warps=num_warps,
        block_size=block_size,
        modify_greedy_probs=modify_greedy_probs,
        save_logprobs=save_logprobs,
        save_modified_probs=save_modified_probs,
    )
    return output_samples, output_logprobs, output_modified_probs


@triton.jit
def _uniform_to_exponential(uniform_noise):
    """Convert uniform samples to exponential samples."""
    # tl.rand returns values in [0, 1), so we clamp lower bound
    # to _EPS to avoid log(0) and thus division by 0 later
    lb = tl.full(uniform_noise.shape, _EPS, uniform_noise.dtype)
    uniform_noise = tl.maximum(uniform_noise, lb)
    # Use the inversion method to turn uniform samples
    # into exponential samples
    exponential_noise = -tl.log(uniform_noise)
    return exponential_noise


@triton.jit
def _sample_triton(
        sample_indices_ptr: torch.Tensor, output_ptr: torch.Tensor,
        output_logprobs_ptr: torch.Tensor,
        output_modified_probs_ptr: torch.Tensor, probs_ptr: torch.Tensor,
        logprobs_ptr: torch.Tensor, seeds_ptr: torch.Tensor,
        uniform_noise_ptr: torch.Tensor, output_row_stride: int,
        probs_row_stride: int, uniform_noise_row_stride: int,
        uniform_noise_best_stride: int, n_samples: int, n_cols: int,
        n_best: int, block_size: tl.constexpr,
        modify_greedy_probs: tl.constexpr, save_logprobs: tl.constexpr,
        save_modified_probs: tl.constexpr):
    # The rows are independent, so we parallelize across those
    sample_idx = tl.program_id(0)
    best_idx = tl.program_id(1)

    # Load the row index from DRAM
    row_idx = tl.load(sample_indices_ptr + sample_idx)
    seed = tl.load(seeds_ptr + sample_idx)
    uses_random_sampling = seed != 0

    # The stride represents how much we need to increase the
    # pointer to advance 1 row
    row_start_ptr = probs_ptr + row_idx * probs_row_stride

    # The block size is the next power of two greater than n_cols,
    # so we can fit each row in a single block
    col_offsets = tl.arange(0, block_size)

    # Load the row into SRAM, using a mask since block_size may be
    # greater than n_cols
    row = tl.load(row_start_ptr + col_offsets,
                  mask=col_offsets < n_cols,
                  other=float("-inf"))

    if uses_random_sampling:
        uniform_noise_start_ptr = (uniform_noise_ptr +
                                   sample_idx * uniform_noise_row_stride +
                                   best_idx * uniform_noise_best_stride)
        uniform_noise = tl.load(uniform_noise_start_ptr + col_offsets,
                                mask=col_offsets < n_cols,
                                other=0.5)
        exponential_noise = _uniform_to_exponential(uniform_noise)
        row /= exponential_noise

    sampled_value, sampled_token = tl.max(row, axis=0, return_indices=True)
    # clamp sampled token to n_cols - 1
    # this should not be necessary, but we do it
    # just in case
    if sampled_token >= n_cols:
        sampled_token = n_cols - 1
    # Write back output to DRAM
    output_row_start_ptr = (output_ptr + sample_idx * output_row_stride +
                            best_idx)
    tl.store(output_row_start_ptr, sampled_token)

    if modify_greedy_probs:  # noqa
        if not uses_random_sampling:
            # Set the probability of the sampled token to 1, all other
            # tokens to zero. This is used in speculative decoding where
            # the sampling method must be encoded within the sampled
            # probability distributions.
            row = tl.where(col_offsets == sampled_token, 1.0, 0.0)
            tl.store(row_start_ptr + col_offsets,
                     row,
                     mask=col_offsets < n_cols)

    if save_modified_probs:
        output_row_start_ptr = (output_modified_probs_ptr +
                                sample_idx * output_row_stride + best_idx)
        tl.store(output_row_start_ptr, sampled_value)

    if save_logprobs:
        # Load the row into SRAM, using a mask since block_size
        # may be greater than n_cols
        sampled_logprob = tl.load(logprobs_ptr + row_idx * probs_row_stride +
                                  sampled_token)
        # Write back output to DRAM
        output_row_start_ptr = (output_logprobs_ptr +
                                sample_idx * output_row_stride + best_idx)
        tl.store(output_row_start_ptr, sampled_logprob)
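For illustration only (not part of the commit): two pieces of the file above in plain PyTorch. The split count is a ceiling division against the Triton column limit, and dividing probabilities by per-token exponential noise before taking the argmax (row /= exponential_noise in the kernel) is the Gumbel-max trick, i.e. it draws tokens proportionally to the probabilities.

# Editorial sketch, not from the diff (CPU, small sizes).
import math

import torch

MAX_TRITON_N_COLS = 131072  # same constant as above

def num_splits(n_cols: int) -> int:
    return math.ceil(n_cols / MAX_TRITON_N_COLS)

assert num_splits(32000) == 1                    # llama/mistral: one split
assert num_splits(MAX_TRITON_N_COLS + 100) == 2  # needs two kernel calls

# argmax(probs / Exponential(1)) samples proportionally to probs.
torch.manual_seed(0)
probs = torch.tensor([0.1, 0.2, 0.7])
uniform = torch.rand(20000, 3).clamp_min(1e-6)
exponential = -torch.log(uniform)
samples = torch.argmax(probs / exponential, dim=-1)
print(torch.bincount(samples, minlength=3) / 20000.0)  # roughly [0.1, 0.2, 0.7]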
@@ -12,6 +12,7 @@ from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs,
                            SamplerOutput, SequenceData, SequenceGroupOutput,
                            SequenceOutput)
+from vllm.model_executor.layers.ops.sample import (sample as sample_triton)
 from vllm.utils import is_neuron


@@ -114,7 +115,8 @@ class Sampler(nn.Module):
         logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float)

         # Sample the next tokens.
-        sample_results = _sample(probs, logprobs, sampling_metadata)
+        sample_results = _sample(probs, logprobs, sampling_metadata,
+                                 sampling_tensors)
         # Get the logprobs query results.
         prompt_logprobs, sample_logprobs = _get_logprobs(
             logprobs, sampling_metadata, sample_results)

@@ -375,7 +377,7 @@ def _multinomial(
     return probs.div_(q).argmax(dim=1).view(-1, num_samples)


-def _sample(
+def _sample_with_torch(
     probs: torch.Tensor,
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,

@@ -394,7 +396,7 @@ def _sample(
     # Counterintuitively, having two loops here is actually faster.
     # The first loop can run without waiting on GPU<->CPU sync.
     for sampling_type in SamplingType:
-        sample_indices = categorized_sample_indices[sampling_type]
+        sample_indices = categorized_sample_indices[sampling_type][:, 0]
         num_tokens = len(sample_indices)
         if num_tokens == 0:
             continue

@@ -407,17 +409,19 @@ def _sample(
             greedy_samples = torch.argmax(logprobs[sample_indices.long()],
                                           dim=-1)
         elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
-            max_best_of = 1
+            max_best_of_in_batch = 1
             for seq_group, is_prompt in zip(seq_groups, is_prompts):
                 if is_prompt:
                     _, sampling_params = seq_group
-                    max_best_of = max(max_best_of, sampling_params.best_of)
+                    max_best_of_in_batch = max(max_best_of_in_batch,
+                                               sampling_params.best_of)
             seeded_args = {} if sampling_type == SamplingType.RANDOM else {
                 "seq_groups": seq_groups,
                 "generators": sampling_metadata.generators,
             }
             multinomial_samples[sampling_type] = _multinomial(
-                probs[sample_indices.long()], max_best_of, **seeded_args)
+                probs[sample_indices.long()], max_best_of_in_batch,
+                **seeded_args)
         elif sampling_type == SamplingType.BEAM:
             beam_search_logprobs = logprobs[sample_indices]
         else:

@@ -448,6 +452,99 @@ def _sample(
     return sample_results


+def _sample_with_triton_kernel(
+    probs: torch.Tensor,
+    logprobs: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+    sampling_tensors: SamplingTensors,
+) -> List[Tuple[List[int], List[int]]]:
+    categorized_seq_group_ids = {t: [] for t in SamplingType}
+    categorized_sample_indices = sampling_metadata.categorized_sample_indices
+    for i, seq_group in enumerate(sampling_metadata.seq_groups):
+        _, sampling_params = seq_group
+        sampling_type = sampling_params.sampling_type
+        categorized_seq_group_ids[sampling_type].append(i)
+
+    sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
+    sample_metadata = {}
+    max_best_of_in_batch = 1
+
+    # Counterintuitively, having two loops here is actually faster.
+    # The first loop can run without waiting on GPU<->CPU sync.
+    for sampling_type in SamplingType:
+        sample_indices = categorized_sample_indices[sampling_type][:, 0]
+        sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
+        num_tokens = len(sample_indices)
+        if num_tokens == 0:
+            continue
+        seq_group_ids = categorized_seq_group_ids[sampling_type]
+        seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_ids]
+        is_prompts = [i < sampling_metadata.num_prompts for i in seq_group_ids]
+        sample_metadata[sampling_type] = (seq_group_ids, seq_groups,
+                                          is_prompts, sample_indices,
+                                          sampled_token_indices)
+        if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
+                             SamplingType.RANDOM_SEED):
+            for seq_group, is_prompt in zip(seq_groups, is_prompts):
+                if is_prompt:
+                    _, sampling_params = seq_group
+                    max_best_of_in_batch = max(max_best_of_in_batch,
+                                               sampling_params.best_of)
+        elif sampling_type == SamplingType.BEAM:
+            beam_search_logprobs = logprobs[sample_indices]
+        else:
+            raise ValueError(f"Unsupported sampling type: {sampling_type}")
+
+    sampled_tokens, _, _ = sample_triton(
+        probs=probs,
+        seeds=sampling_tensors.sampling_seeds,
+        max_best_of=max_best_of_in_batch,
+        sample_indices=sampling_tensors.sample_indices,
+        logprobs=logprobs,
+        # don't save logprobs because we have logic for that below
+        # TODO: use this instead of the CPU-based logic below
+        save_logprobs=False,
+    )
+
+    # GPU<->CPU sync happens in the loop below.
+
+    for sampling_type in SamplingType:
+        if sampling_type not in sample_metadata:
+            continue
+        (seq_group_ids, seq_groups, is_prompts, sample_indices,
+         sampled_token_indices) = sample_metadata[sampling_type]
+        if sampling_type == SamplingType.GREEDY:
+            sample_results = _greedy_sample(
+                seq_groups, sampled_tokens[sampled_token_indices][:, 0])
+        elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
+            sample_results = _random_sample(
+                seq_groups, is_prompts, sampled_tokens[sampled_token_indices])
+        elif sampling_type == SamplingType.BEAM:
+            sample_results = _beam_search_sample(seq_groups, is_prompts,
+                                                 sampling_metadata.seq_data,
+                                                 beam_search_logprobs)
+        sample_results_dict.update(zip(seq_group_ids, sample_results))
+
+    sample_results = [
+        sample_results_dict[i]
+        for i in range(len(sampling_metadata.seq_groups))
+    ]
+    return sample_results
+
+
+def _sample(
+    probs: torch.Tensor,
+    logprobs: torch.Tensor,
+    sampling_metadata: SamplingMetadata,
+    sampling_tensors: SamplingTensors,
+) -> List[Tuple[List[int], List[int]]]:
+    return _sample_with_torch(probs, logprobs, sampling_metadata)
+
+    # TODO: Enable once Triton kernel & associated code is faster.
+    # return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
+    #                                   sampling_tensors)
+
+
 def _get_logprobs(
     logprobs: torch.Tensor,
     sampling_metadata: SamplingMetadata,
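For illustration only (not part of the commit): categorized_sample_indices now carries a pair per sequence instead of a single index, so the Triton path can map each sequence both into the selected-token rows (column 0, as used by _sample_with_torch) and into the packed sampled_tokens output of sample_triton (column 1). A toy layout with made-up values:

# Editorial sketch, not from the diff: the [n, 2] index layout per sampling type.
import torch

greedy_indices = torch.tensor([[0, 0], [3, 1]], dtype=torch.int)
sample_indices = greedy_indices[:, 0]         # rows into probs/logprobs
sampled_token_indices = greedy_indices[:, 1]  # rows into sampled_tokens

sampled_tokens = torch.tensor([[11], [22], [33]])  # pretend kernel output
print(sampled_tokens[sampled_token_indices.long()][:, 0])  # tensor([11, 22])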
@@ -2,12 +2,16 @@ from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple

 import torch
+import random

 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SequenceData
 from vllm.utils import in_wsl, is_neuron
+from vllm.model_executor.layers.ops.sample import (
+    get_num_triton_sampler_splits)

 _SAMPLING_EPS = 1e-5
+_SEED_0_REPLACEMENT = 3403598558


 class SamplingMetadata:

@@ -67,14 +71,28 @@ class SamplingTensors:
     presence_penalties: torch.Tensor
     frequency_penalties: torch.Tensor
     repetition_penalties: torch.Tensor
+    sampling_seeds: torch.Tensor
+    sample_indices: torch.Tensor
+    extra_seeds: Optional[torch.Tensor]
     prompt_tokens: torch.Tensor
     output_tokens: torch.Tensor

     @classmethod
     def from_sampling_metadata(
-            cls, sampling_metadata: "SamplingMetadata", vocab_size: int,
-            device: torch.device,
-            dtype: torch.dtype) -> Tuple["SamplingTensors", bool, bool, bool]:
+        cls,
+        sampling_metadata: "SamplingMetadata",
+        vocab_size: int,
+        device: torch.device,
+        dtype: torch.dtype,
+        *,
+        extra_seeds_to_generate: int = 0,
+        extra_entropy: Optional[Tuple[int, ...]] = None
+    ) -> Tuple["SamplingTensors", bool, bool, bool]:
+        """
+        extra_seeds_to_generate: extra seeds to generate using the
+            user-defined seed for each sequence.
+        extra_entropy: extra entropy to use when generating seeds.
+        """
         prompt_tokens: List[List[int]] = []
         output_tokens: List[List[int]] = []
         top_ks: List[int] = []

@@ -84,9 +102,18 @@ class SamplingTensors:
         presence_penalties: List[float] = []
         frequency_penalties: List[float] = []
         repetition_penalties: List[float] = []
+        sampling_seeds: List[int] = []
+        sample_indices: List[int] = []
+        prompt_best_of: List[int] = []
         do_penalties = False
         do_top_p_top_k = False
         do_min_p = False
+
+        # We need one base seed per Triton slice.
+        seeds_to_generate = (extra_seeds_to_generate +
+                             get_num_triton_sampler_splits(vocab_size))

+        sample_indices_start_idx = 0
         for i, seq_group in enumerate(sampling_metadata.seq_groups):
             seq_ids, sampling_params = seq_group
             temperature = sampling_params.temperature

@@ -95,6 +122,10 @@ class SamplingTensors:
             r = sampling_params.repetition_penalty
             top_p = sampling_params.top_p
             min_p = sampling_params.min_p
+            seed = sampling_params.seed
+
+            is_greedy = sampling_params.sampling_type == SamplingType.GREEDY
+
             # k should not be greater than the vocab size.
             top_k = min(sampling_params.top_k, vocab_size)
             top_k = vocab_size if top_k == -1 else top_k

@@ -112,6 +143,7 @@ class SamplingTensors:
                     or abs(f) >= _SAMPLING_EPS
                     or abs(r - 1.0) >= _SAMPLING_EPS):
                 do_penalties = True
+
             if (i < sampling_metadata.num_prompts
                     and sampling_params.prompt_logprobs is not None):
                 # For tokens in the prompt that we only need to get

@@ -138,10 +170,34 @@ class SamplingTensors:
             frequency_penalties += [f] * len(seq_ids)
             repetition_penalties += [r] * len(seq_ids)

+            is_prompt = i < sampling_metadata.num_prompts
+            if is_prompt:
+                prompt_best_of.append(sampling_params.best_of)
+                prompt_len = sampling_metadata.prompt_lens[i]
+
+                if sampling_params.prompt_logprobs is not None:
+                    # NOTE: the sampling position is the last token
+                    # in the prompt
+                    sample_indices_start_idx += prompt_len - 1
+            for seq_id in seq_ids:
+                seq_data = sampling_metadata.seq_data[seq_id]
+                extra_entropy = extra_entropy or ()
+                seq_seeds = cls._get_sequence_seeds(
+                    seed,
+                    seq_data.get_len(),
+                    *extra_entropy,
+                    seq_id,
+                    seeds_to_generate=seeds_to_generate,
+                    is_greedy=is_greedy)
+                sampling_seeds.append(seq_seeds)
+                sample_indices.append(sample_indices_start_idx)
+                sample_indices_start_idx += 1
+
         sampling_tensors = SamplingTensors.from_lists(
             temperatures, top_ps, top_ks, min_ps, presence_penalties,
-            frequency_penalties, repetition_penalties, prompt_tokens,
-            output_tokens, vocab_size, device, dtype)
+            frequency_penalties, repetition_penalties, sampling_seeds,
+            sample_indices, prompt_tokens, output_tokens, vocab_size,
+            extra_seeds_to_generate, device, dtype)
         return (sampling_tensors, do_penalties, do_top_p_top_k, do_min_p)

     @classmethod

@@ -150,9 +206,10 @@ class SamplingTensors:
                    presence_penalties: List[float],
                    frequency_penalties: List[float],
                    repetition_penalties: List[float],
+                   sampling_seeds: List[int], sample_indices: List[int],
                    prompt_tokens: List[List[int]],
                    output_tokens: List[List[int]], vocab_size: int,
-                   device: torch.device,
+                   extra_seeds_to_generate: int, device: torch.device,
                    dtype: torch.dtype) -> "SamplingTensors":
         # Note that the performance will be very bad without
         # pinned memory.

@@ -210,6 +267,12 @@ class SamplingTensors:
             dtype=torch.int,
             pin_memory=pin_memory,
         )
+        sample_indices_t = torch.tensor(
+            sample_indices,
+            device="cpu",
+            dtype=torch.long,
+            pin_memory=pin_memory,
+        )
         prompt_tensor = torch.tensor(
             prompt_padded_tokens,
             device="cpu",

@@ -222,8 +285,28 @@ class SamplingTensors:
             dtype=torch.long,
             pin_memory=pin_memory,
         )
+        # need to transpose and make contiguous to
+        # copy the tensor correctly.
+        # [batch_size, n_seeds] -> [n_seeds, batch_size]
+        sampling_seeds_t = torch.tensor(
+            sampling_seeds,
+            device="cpu",
+            dtype=torch.long,
+            pin_memory=pin_memory,
+        ).T.contiguous()
+
         # Because the memory is pinned, we can do non-blocking
         # transfer to device.
+
+        # How many seeds the sample operation itself will need.
+        num_base_seeds = sampling_seeds_t.shape[0] - extra_seeds_to_generate
+        sampling_seeds_gpu = sampling_seeds_t.to(device=device,
+                                                 non_blocking=True)
+        extra_seeds_gpu = sampling_seeds_gpu[num_base_seeds:]
+        if not extra_seeds_gpu.numel():
+            extra_seeds_gpu = None
+        sampling_seeds_gpu = sampling_seeds_gpu[:num_base_seeds]
+
         return cls(
             temperatures=temperatures_t.to(device=device, non_blocking=True),
             top_ps=top_ps_t.to(device=device, non_blocking=True),

@@ -237,4 +320,38 @@ class SamplingTensors:
                                                  non_blocking=True),
             prompt_tokens=prompt_tensor.to(device=device, non_blocking=True),
             output_tokens=output_tensor.to(device=device, non_blocking=True),
+            sampling_seeds=sampling_seeds_gpu,
+            sample_indices=sample_indices_t.to(device=device,
+                                               non_blocking=True),
+            extra_seeds=extra_seeds_gpu,
         )
+
+    @staticmethod
+    def _get_sequence_seeds(
+        seed: int,
+        *extra_entropy: int,
+        seeds_to_generate: int,
+        is_greedy: bool,
+    ):
+        """Get `seeds_to_generate` child seeds from `seed` and extra entropy."""
+        if not is_greedy:
+            if seed is None:
+                randint_fn = random.randint
+            else:
+                generator = random.Random(str((seed, ) + extra_entropy))
+                randint_fn = generator.randint
+            lo, hi = torch.iinfo(torch.long).min, torch.iinfo(torch.long).max
+            # If the user/random sets seed = 0 but request should
+            # have sampling, we need to change it to something
+            # else. We use a constant in that case.
+            # This way we don't need to create and load a bool
+            # matrix in the sampling kernel, which reduces CPU
+            # overhead and latency.
+            seq_seeds = [
+                randint_fn(lo, hi) or _SEED_0_REPLACEMENT
+                for _ in range(seeds_to_generate)
+            ]
+        else:
+            # For the kernel, seed == 0 means greedy decoding.
+            seq_seeds = [0] * seeds_to_generate
+        return seq_seeds
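For illustration only (not part of the commit): a rough standalone sketch of the seed recipe in _get_sequence_seeds above. Child seeds are derived deterministically from the user seed plus extra entropy (sequence length, sequence id, ...), 0 is reserved as the kernel's greedy marker, and a drawn 0 is replaced by _SEED_0_REPLACEMENT. The seed-is-None branch (fresh global randomness) is omitted here.

# Editorial sketch, not from the diff.
import random

_SEED_0_REPLACEMENT = 3403598558  # same constant as above
LO, HI = -(2**63), 2**63 - 1      # torch.iinfo(torch.long).min / .max

def child_seeds(seed, *extra_entropy, seeds_to_generate, is_greedy):
    if is_greedy:
        return [0] * seeds_to_generate  # 0 == greedy for the kernel
    rng = random.Random(str((seed, ) + extra_entropy))
    return [rng.randint(LO, HI) or _SEED_0_REPLACEMENT
            for _ in range(seeds_to_generate)]

a = child_seeds(1234, 7, seeds_to_generate=2, is_greedy=False)
b = child_seeds(1234, 7, seeds_to_generate=2, is_greedy=False)
c = child_seeds(1234, 8, seeds_to_generate=2, is_greedy=False)
assert a == b and a != c  # same entropy -> same children, different -> different
assert child_seeds(0, 7, seeds_to_generate=1, is_greedy=True) == [0]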
@@ -242,6 +242,9 @@ class Sequence:
     def get_token_ids(self) -> List[int]:
         return self.data.get_token_ids()

+    def get_prompt_token_ids(self) -> List[int]:
+        return self.data.get_prompt_token_ids()
+
     def get_last_token_id(self) -> int:
         return self.data.get_last_token_id()

@@ -408,6 +408,7 @@ class ModelRunner:
         selected_token_start_idx = 0
         categorized_sample_indices = {t: [] for t in SamplingType}
         categorized_sample_indices_start_idx = 0
+        categorized_sampled_token_indices_start_idx = 0
         pin_memory = not self.in_wsl and not self.device_config.is_neuron

         max_subquery_len = max(subquery_lens) if subquery_lens else 1

@@ -425,9 +426,12 @@ class ModelRunner:
                 categorized_sample_indices_start_idx += subquery_len - 1

                 categorized_sample_indices[
-                    sampling_params.sampling_type].append(
-                        categorized_sample_indices_start_idx)
+                    sampling_params.sampling_type].append([
+                        categorized_sample_indices_start_idx,
+                        categorized_sampled_token_indices_start_idx
+                    ])
                 categorized_sample_indices_start_idx += 1
+                categorized_sampled_token_indices_start_idx += 1

             if sampling_params.prompt_logprobs is not None:
                 selected_token_indices.extend(

@@ -449,9 +453,17 @@ class ModelRunner:

             categorized_sample_indices[
                 sampling_params.sampling_type].extend(
-                    range(categorized_sample_indices_start_idx,
-                          categorized_sample_indices_start_idx + num_seqs))
+                    zip(
+                        range(
+                            categorized_sample_indices_start_idx,
+                            categorized_sample_indices_start_idx +
+                            num_seqs),
+                        range(
+                            categorized_sampled_token_indices_start_idx,
+                            categorized_sampled_token_indices_start_idx +
+                            num_seqs)))
             categorized_sample_indices_start_idx += num_seqs
+            categorized_sampled_token_indices_start_idx += num_seqs

             if sampling_params.seed is not None:
                 generators.append(seq_group_metadata.state.generator)

@@ -459,12 +471,14 @@ class ModelRunner:
         selected_token_indices = _async_h2d(selected_token_indices,
                                             dtype=torch.long,
                                             target_device=self.device,
-                                            pin_memory=pin_memory)
+                                            pin_memory=not self.in_wsl)

         categorized_sample_indices = {
-            t: _async_h2d(seq_ids,
-                          dtype=torch.int,
-                          target_device=self.device,
-                          pin_memory=pin_memory)
+            t: _maybe_expand_dim(
+                _async_h2d(seq_ids,
+                           dtype=torch.int,
+                           target_device=self.device,
+                           pin_memory=pin_memory), 2, 2)
             for t, seq_ids in categorized_sample_indices.items()
         }

@@ -884,3 +898,11 @@ def _async_h2d(
 ) -> torch.Tensor:
     t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
     return t.to(device=target_device, non_blocking=True)
+
+
+def _maybe_expand_dim(tensor: torch.Tensor,
+                      target_dims: int,
+                      size: int = 1) -> torch.Tensor:
+    if tensor.ndim < target_dims:
+        tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
+    return tensor
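For illustration only (not part of the commit): non-empty index lists already reach _maybe_expand_dim as [n, 2] tensors, so in practice the helper mainly protects the empty case, where a 1-D tensor would break the [:, 0] / [:, 1] indexing used by the sampler. A quick check:

# Editorial sketch, not from the diff: behavior of _maybe_expand_dim.
import torch

def _maybe_expand_dim(tensor, target_dims, size=1):
    if tensor.ndim < target_dims:
        tensor = tensor.view(-1, *([size] * (target_dims - tensor.ndim)))
    return tensor

empty = torch.tensor([], dtype=torch.int)
pairs = torch.tensor([[0, 0], [5, 1]], dtype=torch.int)

assert _maybe_expand_dim(empty, 2, 2).shape == (0, 2)
assert _maybe_expand_dim(pairs, 2, 2).shape == (2, 2)  # already 2-D: unchanged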