mirror of https://github.com/vllm-project/vllm
[Misc] Load FP8 kv-cache scaling factors from checkpoints (#4893)
The second PR for #4532. This PR adds support for loading FP8 kv-cache scaling factors from an FP8 checkpoint (via the .kv_scale parameter).
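In practice, an FP8 checkpoint that carries per-layer .kv_scale tensors can now supply the FP8 KV-cache scaling factors at load time, instead of going through --quantization-param-path. A minimal usage sketch, mirroring the test added in this commit (the checkpoint name is the test model used below; treat the exact flags as illustrative for this vLLM revision):

from vllm import LLM, SamplingParams

# FP8 checkpoint that ships per-layer kv_scale parameters (the checkpoint used by the new test).
llm = LLM(model="nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV",
          quantization="fp8",
          kv_cache_dtype="fp8",  # "fp8" maps to fp8_e4m3 on both CUDA and ROCm
          enforce_eager=True)

outputs = llm.generate(["What is the capital of France?"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)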
This commit is contained in:
parent 8674f9880e
commit a3a73ab069
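The weight-loading changes below treat .kv_scale entries in the checkpoint as per-layer scalars and remap them onto the attention module's parameter. A hypothetical illustration of that renaming (the tensor name is an assumed Llama-style example, not taken from a real checkpoint):

# Checkpoint tensor name as it might appear in an FP8 checkpoint state dict.
name = "model.layers.0.self_attn.kv_scale"
# Remapping applied by load_weights in the diff below.
remapped = name.replace(".kv_scale", ".attn.kv_scale")
# -> "model.layers.0.self_attn.attn.kv_scale"
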
@@ -153,15 +153,13 @@ if __name__ == '__main__':
action='store_true',
help='enforce eager mode and disable CUDA graph')
parser.add_argument(
"--kv-cache-dtype",
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8'],
default='auto',
help=
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.')
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default="auto",
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=str,

@@ -323,15 +323,13 @@ if __name__ == "__main__":
action="store_true",
help="enforce eager execution")
parser.add_argument(
"--kv-cache-dtype",
'--kv-cache-dtype',
type=str,
choices=["auto", "fp8"],
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=str,

@@ -183,13 +183,11 @@ if __name__ == '__main__':
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8"],
choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
help="Data type for kv cache storage. If 'auto', will use model "
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
args = parser.parse_args()
print(args)

@@ -16,31 +16,55 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
MAX_MODEL_LEN = 1024

MODELS = [
"nm-testing/Meta-Llama-3-8B-Instruct-FP8",
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV",
"meta-llama/Meta-Llama-3-8B-Instruct",
]

EXPECTED_STRS_MAP = {
"nm-testing/Meta-Llama-3-8B-Instruct-FP8": [
'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
'Zeta-5, a highly advanced robot designed for menial labor, whirred to a',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o',
],
"meta-llama/Meta-Llama-3-8B-Instruct": [
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'
],
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV": {
"auto": [
'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) process information in distinct ways, with both',
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya aki no tori, nemuri no'
],
"fp8": [
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system made up of several basic components that work together to enable it to',
'Zeta-5, a highly advanced robot designed for menial labor, had never experienced anything like',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya kotori wa mushi o tsuk'
]
},
"meta-llama/Meta-Llama-3-8B-Instruct": {
"auto": [
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'
],
"fp8": [
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
'In the year 2154, robotics engineer Dr. Rachel Kim had spent years perfecting her latest',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya tori, mushi o tsukamu'
]
},
}

capability = torch.cuda.get_device_capability()

@@ -52,14 +76,14 @@ fp8_not_supported = (capability <
@pytest.mark.skipif(fp8_not_supported,
reason="fp8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_name", MODELS)
def test_models(
example_prompts,
model_name,
) -> None:
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
def test_models(example_prompts, model_name, kv_cache_dtype) -> None:
model = LLM(model=model_name,
max_model_len=MAX_MODEL_LEN,
trust_remote_code=True,
enforce_eager=True,
quantization="fp8")
quantization="fp8",
kv_cache_dtype=kv_cache_dtype)

tokenizer = AutoTokenizer.from_pretrained(model_name)
formatted_prompts = [

@@ -81,8 +105,8 @@ def test_models(
generations.append(outputs[0].outputs[0].text)
del model

print(generations)
expected_strs = EXPECTED_STRS_MAP[model_name]
print(model_name, kv_cache_dtype, generations)
expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype]
for i in range(len(example_prompts)):
generated_str = generations[i]
expected_str = expected_strs[i]

@@ -7,6 +7,8 @@ import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)


class Attention(nn.Module):

@@ -30,6 +32,7 @@ class Attention(nn.Module):
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
if cache_config is not None:

@@ -40,6 +43,27 @@ class Attention(nn.Module):
block_size = 16
if num_kv_heads is None:
num_kv_heads = num_heads

# The default kv_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
# expect the pre-quantized kv_scale to be loaded along
# with the model weights.
self.kv_cache_dtype = kv_cache_dtype
self._kv_scale = 1.0
quant_method = quant_config.get_quant_method(
self) if quant_config else None
if quant_method is not None:
if self.kv_cache_dtype == "fp8_e5m2":
raise ValueError("fp8_e5m2 kv-cache is not supported with "
"fp8 checkpoints.")
# When FP8 quantization is enabled, we make a parameter
# "kv_scale" so that it can be loaded from FP8 checkpoint.
# The kv_scale will then be converted back
# to self._kv_scale in a native float32 value after weight loading.
self.quant_method = quant_method
self.quant_method.create_weights(self)

# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()

@@ -57,10 +81,9 @@ class Attention(nn.Module):
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata,
kv_scale: float = 1.0,
) -> torch.Tensor:
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
kv_scale)
self._kv_scale)

def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore

@@ -355,14 +355,12 @@ class CacheConfig:
def _verify_cache_dtype(self) -> None:
if self.cache_dtype == "auto":
pass
elif self.cache_dtype == "fp8":
elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
logger.info(
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
"But it may cause slight accuracy drop without scaling "
"factors. FP8_E5M2 (without scaling) is only supported on "
"cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 "
"is instead supported for common inference criteria.")
"Meanwhile, it may cause accuracy drop without a proper "
"scaling factor")
else:
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

@@ -191,12 +191,11 @@ class EngineArgs:
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8'],
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default=EngineArgs.kv_cache_dtype,
help='Data type for kv cache storage. If "auto", will use model '
'data type. FP8_E5M2 (without scaling) is only supported on cuda '
'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
'supported for common inference criteria.')
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=nullable_str,

@@ -8,8 +8,9 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import print_warning_once

ACTIVATION_SCHEMES = ["static", "dynamic"]

@@ -58,9 +59,13 @@ class Fp8Config(QuantizationConfig):
activation_scheme=activation_scheme)

def get_quant_method(
self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]:
self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import

if isinstance(layer, LinearBase):
return Fp8LinearMethod(self)
if isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
return None

def get_scaled_act_names(self) -> List[str]:

@@ -251,6 +256,44 @@ class Fp8LinearMethod(LinearMethodBase):
return torch.narrow(output, 0, 0, x.shape[0])


class Fp8KVCacheMethod(QuantizeMethodBase):
"""Supports loading kv-cache scaling factors from FP8 checkpoints.
"""

def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config

def create_weights(self, layer: torch.nn.Module):
"""Create "weight" (aka kv_scale) for an attention layer.

Args:
layer: The layer that is using the QuantizeMethodBase factory.
"""
# Initialize the KV cache scale to 1.0 as the default value.
# If the kv_scale appears in the checkpoint, it will be
# overwritten when loading weights.
layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)

def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")

def process_weights_after_loading(self, layer: Module) -> None:
# If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
if layer.kv_cache_dtype != "auto":
kv_scale = layer.kv_scale.to("cpu").tolist()
if not isinstance(kv_scale, float):
raise ValueError("Only support per-tensor scaling factor "
"for fp8 KV cache")
layer._kv_scale = kv_scale
if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
print_warning_once(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
"cause accuracy issues. Please make sure kv-cache scaling "
"factor is available in the fp8 checkpoint.")
del layer.kv_scale


def all_close_1d(x: torch.Tensor) -> bool:
assert len(x.shape) == 1
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))

@@ -268,7 +268,8 @@ class ArcticAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -154,7 +154,8 @@ class BaiChuanAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
scaling,
alibi_slopes=alibi_slopes)
alibi_slopes=alibi_slopes,
quant_config=quant_config)
else:
self.rotary_emb = get_rope(
self.head_dim,

@@ -166,7 +167,8 @@ class BaiChuanAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -111,7 +111,8 @@ class BloomAttention(nn.Module):
self.head_dim,
scaling,
alibi_slopes=alibi_slopes,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -86,13 +86,12 @@ class GLMAttention(nn.Module):
base=10000 * rope_ratio,
is_neox_style=False,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -177,13 +177,12 @@ class CohereAttention(nn.Module):
rope_scaling=self.rope_scaling,
is_neox_style=False,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
if self.use_qk_norm:
self.q_norm = LayerNorm(param_shape=(self.num_heads,
self.head_dim),

@@ -218,13 +218,12 @@ class DbrxAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -232,7 +232,8 @@ class DeepseekAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -153,7 +153,8 @@ class FalconAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
self.inv_norm_factor,
num_kv_heads=self.num_kv_heads)
num_kv_heads=self.num_kv_heads,
quant_config=quant_config)
elif self.use_alibi:
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads

@@ -165,13 +166,15 @@ class FalconAttention(nn.Module):
self.head_dim,
self.inv_norm_factor,
num_kv_heads=self.num_kv_heads,
alibi_slopes=alibi_slopes)
alibi_slopes=alibi_slopes,
quant_config=quant_config)
else:
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.inv_norm_factor,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -157,7 +157,8 @@ class GemmaAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -75,7 +75,8 @@ class GPT2Attention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scale,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -88,7 +88,8 @@ class GPTBigCodeAttention(nn.Module):
self.head_dim,
scale=self.scale,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -88,7 +88,8 @@ class GPTJAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_size,
scaling,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -89,7 +89,8 @@ class GPTNeoXAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_size,
scaling,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -117,7 +117,8 @@ class InternLM2Attention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -105,13 +105,12 @@ class JAISAttention(nn.Module):
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end]
self.attn = Attention(
self.num_heads,
self.head_dim,
scale=self.scale,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scale,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -47,7 +47,7 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.utils import is_hip
from vllm.utils import is_hip, print_warning_once


class LlamaMLP(nn.Module):

@@ -119,15 +119,6 @@ class LlamaAttention(nn.Module):
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings

# This will be overwritten by model initialization if we are using it.
# N.B. currently we only support per tensor scalar scaling factors
# & only applicable to ROCm (AMD GPU).
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
self.kv_scale = 1.0

self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,

@@ -155,7 +146,8 @@ class LlamaAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=sliding_window,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -167,8 +159,7 @@ class LlamaAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
self.kv_scale)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output

@@ -421,6 +412,19 @@ class LlamaForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
print_warning_once(
f"Found kv scale in the checkpoint (e.g. {name}), "
"but not found the expected name in the model "
f"(e.g. {remapped_kv_scale_name}). kv-scale is "
"not loaded.")
continue
else:
name = remapped_kv_scale_name
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)

@@ -445,7 +449,7 @@ class LlamaForCausalLM(nn.Module):
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.kv_scale = scaling_factor
layer_self_attn.attn._kv_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")

@@ -236,7 +236,8 @@ class MiniCPMAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -308,14 +308,13 @@ class MixtralAttention(nn.Module):
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -581,6 +580,20 @@ class MixtralForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
print_warning_once(
"Found kv scale in the checkpoint "
f"(e.g. {name}), but not found the expected "
f"name in the model "
f"(e.g. {remapped_kv_scale_name}). "
"kv-scale is not loaded.")
continue
else:
name = remapped_kv_scale_name
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)

@@ -213,14 +213,13 @@ class MixtralAttention(nn.Module):
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -110,7 +110,8 @@ class MPTAttention(nn.Module):
scaling,
alibi_slopes=alibi_slopes,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -96,7 +96,8 @@ class OlmoAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scaling,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

# Attention output projection.
self.o_proj = RowParallelLinear(

@@ -91,7 +91,8 @@ class OPTAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scaling,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -121,7 +121,8 @@ class OrionAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -110,7 +110,8 @@ class PhiAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_size,
scaling,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -106,7 +106,8 @@ class QWenAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -141,7 +141,8 @@ class Qwen2Attention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -241,7 +241,8 @@ class Qwen2MoeAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -127,7 +127,8 @@ class StablelmAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_key_value_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -97,14 +97,13 @@ class Starcoder2Attention(nn.Module):
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -135,7 +135,8 @@ class XverseAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=sliding_window,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)

def forward(
self,

@@ -31,6 +31,8 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"bfloat16": torch.bfloat16,
"float": torch.float,
"fp8": torch.uint8,
"fp8_e4m3": torch.uint8,
"fp8_e5m2": torch.uint8,
}

@@ -1,4 +1,5 @@
import time
import warnings
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union

import numpy as np

@@ -168,11 +169,21 @@ class ModelRunner:
self.model = self.lora_manager.create_lora_manager(self.model)

if self.kv_cache_dtype == "fp8" and is_hip():
# Currently scaled KV cache is only enabled on ROCm
# Currently only ROCm accepts kv-cache scaling factors
# via quantization_param_path and this will be deprecated
# in the future.
if self.model_config.quantization_param_path is not None:
if callable(getattr(self.model, "load_kv_cache_scales", None)):
warnings.warn(
"Loading kv cache scaling factor from JSON is "
"deprecated and will be removed. Please include "
"kv cache scaling factors in the model checkpoint.",
FutureWarning,
stacklevel=2)
self.model.load_kv_cache_scales(
self.model_config.quantization_param_path)
logger.info("Loaded KV cache scaling factors from %s",
self.model_config.quantization_param_path)
else:
raise RuntimeError(
"Using FP8 KV cache and scaling factors provided but "

@@ -183,10 +194,6 @@ class ModelRunner:
"Using FP8 KV cache but no scaling factors "
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")
elif self.model_config.quantization_param_path is not None:
logger.warning("KV cache scaling factors provided, "
"but the KV cache data type is not FP8. "
"KV cache scaling factors will not be used.")

def save_sharded_state(
self,