[Bugfix] Allow vllm to still work if triton is not installed. (#6786)

Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
Authored by Thomas Parnell on 2024-07-29 23:51:27 +02:00; committed by GitHub.
parent 7f8d612d24
commit 9a7e2d0534
13 changed files with 65 additions and 37 deletions

requirements-cpu.txt

@@ -4,4 +4,3 @@
 # Dependencies for x86_64 CPUs
 torch == 2.4.0; platform_machine != "ppc64le"
 torchvision; platform_machine != "ppc64le"  # required for the image processor of phi3v, this must be updated alongside torch
-triton >= 2.2.0  # FIXME(woosuk): This is a hack to avoid import error.

requirements-openvino.txt

@@ -5,5 +5,3 @@
 torch >= 2.1.2
 openvino ~= 2024.3.0.dev
 optimum-intel[openvino] >= 1.18.1
-triton >= 2.2.0  # FIXME(woosuk): This is a hack to avoid import error.

requirements-tpu.txt

@@ -5,4 +5,3 @@
 # Currently, the TPU backend uses a nightly version of PyTorch XLA.
 # You can install the dependencies in Dockerfile.tpu.
 ray
-triton  # To avoid import errors

tests/kernels/test_sampler.py

@@ -5,11 +5,12 @@ import torch
 import triton
 import triton.language as tl
 
-from vllm.model_executor.layers.ops.sample import (
-    MAX_TRITON_N_COLS, _uniform_to_exponential, get_num_triton_sampler_splits,
-    sample)
+from vllm.model_executor.layers.ops.sample import (_uniform_to_exponential,
+                                                   sample)
 from vllm.model_executor.sampling_metadata import SamplingTensors
 from vllm.model_executor.utils import set_random_seed
+from vllm.triton_utils.sample import (MAX_TRITON_N_COLS,
+                                      get_num_triton_sampler_splits)
 
 SINGLE_SPLIT_VOCAB_SIZE = 32000  # llama/mistral/mixtral vocab size
 MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100

vllm/attention/ops/paged_attn.py

@@ -4,7 +4,10 @@ from typing import List, Optional, Tuple
 import torch
 
 from vllm import _custom_ops as ops
-from vllm.attention.ops.prefix_prefill import context_attention_fwd
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    from vllm.attention.ops.prefix_prefill import context_attention_fwd
 
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
 _PARTITION_SIZE = 512
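
This guarded import is the heart of the fix: Triton-backed kernels are imported only when the package is actually present, so importing vLLM no longer fails on Triton-free installs, and only the code paths that need a kernel fail, at call time rather than import time. A minimal standalone sketch of the idiom; the ceil_div wrapper and its error message are illustrative, not part of the patch:

from importlib.util import find_spec

HAS_TRITON = find_spec("triton") is not None

if HAS_TRITON:
    import triton  # only evaluated when the package is present

def ceil_div(a: int, b: int) -> int:
    # Fail at call time with a clear message, not at import time.
    if not HAS_TRITON:
        raise RuntimeError("This operation requires Triton; "
                           "install it to use this code path.")
    return triton.cdiv(a, b)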

vllm/model_executor/layers/fused_moe/__init__.py

@@ -1,14 +1,22 @@
-from vllm.model_executor.layers.fused_moe.fused_moe import (
-    fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk)
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                         FusedMoEMethodBase)
+from vllm.triton_utils import HAS_TRITON
 
 __all__ = [
-    "fused_moe",
-    "fused_topk",
-    "fused_experts",
-    "get_config_file_name",
-    "grouped_topk",
     "FusedMoE",
     "FusedMoEMethodBase",
 ]
+
+if HAS_TRITON:
+    from vllm.model_executor.layers.fused_moe.fused_moe import (
+        fused_experts, fused_moe, fused_topk, get_config_file_name,
+        grouped_topk)
+
+    __all__ += [
+        "fused_moe",
+        "fused_topk",
+        "fused_experts",
+        "get_config_file_name",
+        "grouped_topk",
+    ]
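
Because the Triton-dependent names are now re-exported only conditionally, downstream code can branch on HAS_TRITON instead of catching ImportError. A usage sketch against the names this file exports; the None fallback is illustrative, not code from the patch:

from vllm.triton_utils import HAS_TRITON

if HAS_TRITON:
    from vllm.model_executor.layers.fused_moe import fused_moe
else:
    fused_moe = None  # callers must take a non-Triton path instead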

vllm/model_executor/layers/ops/sample.py

@@ -1,4 +1,3 @@
-import math
 from typing import Optional, Tuple
 
 import torch
@@ -6,21 +5,10 @@ import triton
 import triton.language as tl
 
 from vllm.model_executor.layers.ops.rand import seeded_uniform
+from vllm.triton_utils.sample import get_num_triton_sampler_splits
 
 _EPS = 1e-6
 
-# This is a hardcoded limit in Triton (max block size).
-MAX_TRITON_N_COLS = 131072
-
-
-def get_num_triton_sampler_splits(n_cols: int) -> int:
-    """Get the number of splits to use for Triton sampling.
-
-    Triton has a limit on the number of columns it can handle, so we need to
-    split the tensor and call the kernel multiple times if it's too large.
-    """
-    return math.ceil(n_cols / MAX_TRITON_N_COLS)
-
-
 def _multi_split_sample(
     probs: torch.Tensor,

vllm/model_executor/layers/quantization/fp8.py

@@ -6,8 +6,7 @@ from torch.nn.parameter import Parameter
 
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
-                                                  fused_moe)
+from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (
@@ -404,6 +403,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 num_expert_group: Optional[int] = None,
                 topk_group: Optional[int] = None) -> torch.Tensor:
 
+        from vllm.model_executor.layers.fused_moe import fused_moe
         return fused_moe(x,
                          layer.w13_weight,
                          layer.w2_weight,
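
The second hunk uses a different technique from the HAS_TRITON guard: the fused_moe import moves from module scope into the method body, so the Triton-backed module is loaded only if the FP8 MoE path actually runs. A generic sketch of this lazy-import idiom, with a stdlib module standing in for the heavy dependency:

def split_words(text: str) -> list:
    # Lazy import: loaded on the first call rather than when this module
    # is imported; later calls hit the sys.modules cache, so the cost is
    # paid once.
    import re
    return re.findall(r"\w+", text)

print(split_words("lazy imports defer heavy dependencies"))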

vllm/model_executor/layers/sampler.py

@@ -6,7 +6,11 @@ from typing import Dict, List, Optional, Tuple
 import torch
 import torch.nn as nn
 
-from vllm.model_executor.layers.ops.sample import sample as sample_triton
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    from vllm.model_executor.layers.ops.sample import sample as sample_triton
+
 from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                    SamplingTensors,
                                                    SequenceGroupToSample)

vllm/model_executor/sampling_metadata.py

@@ -5,9 +5,9 @@ from typing import Dict, List, Optional, Tuple
 import torch
 
-from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SequenceData, SequenceGroupMetadata
+from vllm.triton_utils.sample import get_num_triton_sampler_splits
 from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
                         make_tensor_with_pad, maybe_expand_dim)

vllm/triton_utils/__init__.py

@@ -1,6 +1,10 @@
-from vllm.triton_utils.custom_cache_manager import (
-    maybe_set_triton_cache_manager)
+from vllm.triton_utils.importing import HAS_TRITON
 
-__all__ = [
-    "maybe_set_triton_cache_manager",
-]
+__all__ = ["HAS_TRITON"]
+
+if HAS_TRITON:
+
+    from vllm.triton_utils.custom_cache_manager import (
+        maybe_set_triton_cache_manager)
+
+    __all__ += ["maybe_set_triton_cache_manager"]

vllm/triton_utils/importing.py (new file)

@@ -0,0 +1,11 @@
+from importlib.util import find_spec
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+HAS_TRITON = find_spec("triton") is not None
+
+if not HAS_TRITON:
+    logger.info("Triton not installed; certain GPU-related functions"
+                " will not be available.")

vllm/triton_utils/sample.py (new file)

@@ -0,0 +1,13 @@
+import math
+
+# This is a hardcoded limit in Triton (max block size).
+MAX_TRITON_N_COLS = 131072
+
+
+def get_num_triton_sampler_splits(n_cols: int) -> int:
+    """Get the number of splits to use for Triton sampling.
+
+    Triton has a limit on the number of columns it can handle, so we need to
+    split the tensor and call the kernel multiple times if it's too large.
+    """
+    return math.ceil(n_cols / MAX_TRITON_N_COLS)
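
As a sanity check of the split arithmetic, the two vocab sizes used in the test file above land on either side of the Triton column limit:

import math

MAX_TRITON_N_COLS = 131072

def get_num_triton_sampler_splits(n_cols: int) -> int:
    return math.ceil(n_cols / MAX_TRITON_N_COLS)

assert get_num_triton_sampler_splits(32000) == 1   # SINGLE_SPLIT_VOCAB_SIZE
assert get_num_triton_sampler_splits(131072) == 1  # exactly at the limit
assert get_num_triton_sampler_splits(131172) == 2  # MULTI_SPLIT_VOCAB_SIZE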