mirror of https://github.com/vllm-project/vllm
[Bugfix] Allow vllm to still work if triton is not installed. (#6786)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
This commit is contained in:
parent 7f8d612d24
commit 9a7e2d0534
@@ -4,4 +4,3 @@
 # Dependencies for x86_64 CPUs
 torch == 2.4.0; platform_machine != "ppc64le"
 torchvision; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch
-triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.

@@ -5,5 +5,3 @@
 torch >= 2.1.2
 openvino ~= 2024.3.0.dev
 optimum-intel[openvino] >= 1.18.1
-
-triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.

@@ -5,4 +5,3 @@
 # Currently, the TPU backend uses a nightly version of PyTorch XLA.
 # You can install the dependencies in Dockerfile.tpu.
 ray
-triton # To avoid import errors

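The three requirements hunks above drop the unconditional triton pin for the CPU, OpenVINO, and TPU builds. A quick sanity check for such an environment (a hypothetical verification step, not part of the commit) is to read the flag that the rest of this change introduces:

# Hypothetical sanity check (not part of this commit): in an environment
# without triton installed, vllm should still import, and the new flag
# introduced below should report False.
from vllm.triton_utils import HAS_TRITON

print(HAS_TRITON)  # expected: False when triton is absent
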
@@ -5,11 +5,12 @@ import torch
 import triton
 import triton.language as tl
 
-from vllm.model_executor.layers.ops.sample import (
-    MAX_TRITON_N_COLS, _uniform_to_exponential, get_num_triton_sampler_splits,
-    sample)
+from vllm.model_executor.layers.ops.sample import (_uniform_to_exponential,
+                                                   sample)
 from vllm.model_executor.sampling_metadata import SamplingTensors
 from vllm.model_executor.utils import set_random_seed
+from vllm.triton_utils.sample import (MAX_TRITON_N_COLS,
+                                      get_num_triton_sampler_splits)
 
 SINGLE_SPLIT_VOCAB_SIZE = 32000 # llama/mistral/mixtral vocab size
 MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100

@@ -4,7 +4,10 @@ from typing import List, Optional, Tuple
 import torch
 
 from vllm import _custom_ops as ops
-from vllm.attention.ops.prefix_prefill import context_attention_fwd
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    from vllm.attention.ops.prefix_prefill import context_attention_fwd
 
 # Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
 _PARTITION_SIZE = 512

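The hunk above is the guard-at-import pattern used throughout this change: the Triton-backed kernel is imported only when HAS_TRITON is true, so importing the module no longer requires triton. A minimal sketch of the idea, assuming a caller that wants a clear failure only when the kernel is actually requested (the wrapper function is hypothetical; only the two imports come from the diff):

# Sketch of the gating pattern above; run_with_prefix_prefill is hypothetical.
from vllm.triton_utils import HAS_TRITON

if HAS_TRITON:
    from vllm.attention.ops.prefix_prefill import context_attention_fwd


def run_with_prefix_prefill(*args, **kwargs):
    # Fail at call time rather than at import time.
    if not HAS_TRITON:
        raise RuntimeError("context_attention_fwd requires triton, "
                           "which is not installed.")
    return context_attention_fwd(*args, **kwargs)
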
@@ -1,14 +1,22 @@
-from vllm.model_executor.layers.fused_moe.fused_moe import (
-    fused_experts, fused_moe, fused_topk, get_config_file_name, grouped_topk)
 from vllm.model_executor.layers.fused_moe.layer import (FusedMoE,
                                                         FusedMoEMethodBase)
+from vllm.triton_utils import HAS_TRITON
 
 __all__ = [
-    "fused_moe",
-    "fused_topk",
-    "fused_experts",
-    "get_config_file_name",
-    "grouped_topk",
     "FusedMoE",
     "FusedMoEMethodBase",
 ]
+
+if HAS_TRITON:
+
+    from vllm.model_executor.layers.fused_moe.fused_moe import (
+        fused_experts, fused_moe, fused_topk, get_config_file_name,
+        grouped_topk)
+
+    __all__ += [
+        "fused_moe",
+        "fused_topk",
+        "fused_experts",
+        "get_config_file_name",
+        "grouped_topk",
+    ]

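With this hunk, the function-level fused_moe symbols are re-exported only when Triton is present, so importing them directly can raise ImportError in a Triton-less environment. A hedged consumer-side sketch (the try/except fallback is illustrative, not from the commit):

# Illustrative feature detection (not part of this commit): the function-level
# fused_moe symbols exist in the package namespace only when triton is installed.
try:
    from vllm.model_executor.layers.fused_moe import fused_moe
except ImportError:
    fused_moe = None  # fall back to a non-Triton path, or raise at call time
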
@@ -1,4 +1,3 @@
-import math
 from typing import Optional, Tuple
 
 import torch

@@ -6,21 +5,10 @@ import triton
 import triton.language as tl
 
 from vllm.model_executor.layers.ops.rand import seeded_uniform
+from vllm.triton_utils.sample import get_num_triton_sampler_splits
 
 _EPS = 1e-6
 
-# This is a hardcoded limit in Triton (max block size).
-MAX_TRITON_N_COLS = 131072
-
-
-def get_num_triton_sampler_splits(n_cols: int) -> int:
-    """Get the number of splits to use for Triton sampling.
-
-    Triton has a limit on the number of columns it can handle, so we need to
-    split the tensor and call the kernel multiple times if it's too large.
-    """
-    return math.ceil(n_cols / MAX_TRITON_N_COLS)
-
 
 def _multi_split_sample(
     probs: torch.Tensor,

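The two hunks above move the pure-Python split helper out of the Triton-dependent kernel module; it now lives in vllm.triton_utils.sample, so call sites only need an import-path change (a later hunk below makes exactly this change at one call site). For example:

# Before this commit (importing this module required triton):
# from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
# After this commit (pure Python, usable without triton):
from vllm.triton_utils.sample import get_num_triton_sampler_splits
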
@@ -6,8 +6,7 @@ from torch.nn.parameter import Parameter
 
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
-                                                  fused_moe)
+from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase
 from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.quantization.base_config import (

@@ -404,6 +403,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             num_expert_group: Optional[int] = None,
             topk_group: Optional[int] = None) -> torch.Tensor:
 
+        from vllm.model_executor.layers.fused_moe import fused_moe
         return fused_moe(x,
                          layer.w13_weight,
                          layer.w2_weight,

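Here the module-level fused_moe import is replaced by one inside the method body, so the quantization module stays importable when Triton is missing and only fails if the FP8 MoE path actually runs. A minimal sketch of the deferred-import idea (the surrounding function is hypothetical; only the import path comes from the diff):

# Hypothetical sketch of the deferred-import pattern used in the FP8 MoE path.
def apply_fused_moe(*args, **kwargs):
    # Importing here rather than at module level keeps the enclosing module
    # importable even when the Triton-backed kernels are unavailable.
    from vllm.model_executor.layers.fused_moe import fused_moe
    return fused_moe(*args, **kwargs)
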
@@ -6,7 +6,11 @@ from typing import Dict, List, Optional, Tuple
 import torch
 import torch.nn as nn
 
-from vllm.model_executor.layers.ops.sample import sample as sample_triton
+from vllm.triton_utils import HAS_TRITON
+
+if HAS_TRITON:
+    from vllm.model_executor.layers.ops.sample import sample as sample_triton
+
 from vllm.model_executor.sampling_metadata import (SamplingMetadata,
                                                    SamplingTensors,
                                                    SequenceGroupToSample)

@@ -5,9 +5,9 @@ from typing import Dict, List, Optional, Tuple
 
 import torch
 
-from vllm.model_executor.layers.ops.sample import get_num_triton_sampler_splits
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SequenceData, SequenceGroupMetadata
+from vllm.triton_utils.sample import get_num_triton_sampler_splits
 from vllm.utils import (async_tensor_h2d, is_pin_memory_available,
                         make_tensor_with_pad, maybe_expand_dim)
 

@@ -1,6 +1,10 @@
-from vllm.triton_utils.custom_cache_manager import (
-    maybe_set_triton_cache_manager)
+from vllm.triton_utils.importing import HAS_TRITON
 
-__all__ = [
-    "maybe_set_triton_cache_manager",
-]
+__all__ = ["HAS_TRITON"]
+
+if HAS_TRITON:
+
+    from vllm.triton_utils.custom_cache_manager import (
+        maybe_set_triton_cache_manager)
+
+    __all__ += ["maybe_set_triton_cache_manager"]

@@ -0,0 +1,11 @@
+from importlib.util import find_spec
+
+from vllm.logger import init_logger
+
+logger = init_logger(__name__)
+
+HAS_TRITON = find_spec("triton") is not None
+
+if not HAS_TRITON:
+    logger.info("Triton not installed; certain GPU-related functions"
+                " will not be available.")

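The new module detects Triton with importlib.util.find_spec, which checks whether the package can be located without executing it, so the detection itself cannot fail. A short standard-library illustration (the second lookup uses a made-up package name only to show the None case):

# find_spec returns a ModuleSpec if the top-level package is importable,
# else None, and it does not run the package's code.
from importlib.util import find_spec

has_triton = find_spec("triton") is not None
has_missing = find_spec("package_that_does_not_exist") is not None  # False
print(has_triton, has_missing)
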
@@ -0,0 +1,13 @@
+import math
+
+# This is a hardcoded limit in Triton (max block size).
+MAX_TRITON_N_COLS = 131072
+
+
+def get_num_triton_sampler_splits(n_cols: int) -> int:
+    """Get the number of splits to use for Triton sampling.
+
+    Triton has a limit on the number of columns it can handle, so we need to
+    split the tensor and call the kernel multiple times if it's too large.
+    """
+    return math.ceil(n_cols / MAX_TRITON_N_COLS)

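The helper is a ceiling division against the Triton block-size limit, which is why the test hunk earlier pairs SINGLE_SPLIT_VOCAB_SIZE = 32000 with MULTI_SPLIT_VOCAB_SIZE = MAX_TRITON_N_COLS + 100: the first fits in one split, the second just tips over into two. A small worked check:

# Worked example of get_num_triton_sampler_splits for the two test vocab sizes.
import math

MAX_TRITON_N_COLS = 131072
assert math.ceil(32000 / MAX_TRITON_N_COLS) == 1                      # single split
assert math.ceil((MAX_TRITON_N_COLS + 100) / MAX_TRITON_N_COLS) == 2  # multi split
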