[Misc] Load FP8 kv-cache scaling factors from checkpoints (#4893)

The 2nd PR for #4532.

This PR supports loading FP8 kv-cache scaling factors from a FP8 checkpoint (with .kv_scale parameter).
This commit is contained in:
Cody Yu 2024-05-22 13:28:20 -07:00 committed by GitHub
parent 8674f9880e
commit a3a73ab069
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
40 changed files with 284 additions and 158 deletions

View File

@ -153,15 +153,13 @@ if __name__ == '__main__':
action='store_true',
help='enforce eager mode and disable CUDA graph')
parser.add_argument(
"--kv-cache-dtype",
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8'],
default='auto',
help=
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.')
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default="auto",
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=str,

View File

@ -323,15 +323,13 @@ if __name__ == "__main__":
action="store_true",
help="enforce eager execution")
parser.add_argument(
"--kv-cache-dtype",
'--kv-cache-dtype',
type=str,
choices=["auto", "fp8"],
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
help='Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=str,

View File

@ -183,13 +183,11 @@ if __name__ == '__main__':
parser.add_argument(
"--kv-cache-dtype",
type=str,
choices=["auto", "fp8"],
choices=["auto", "fp8", "fp8_e5m2", "fp8_e4m3"],
default="auto",
help=
'Data type for kv cache storage. If "auto", will use model data type. '
'FP8_E5M2 (without scaling) is only supported on cuda version greater '
'than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead supported for '
'common inference criteria.')
help="Data type for kv cache storage. If 'auto', will use model "
"data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. "
"ROCm (AMD GPU) supports fp8 (=fp8_e4m3)")
args = parser.parse_args()
print(args)

View File

@ -16,31 +16,55 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
MAX_MODEL_LEN = 1024
MODELS = [
"nm-testing/Meta-Llama-3-8B-Instruct-FP8",
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV",
"meta-llama/Meta-Llama-3-8B-Instruct",
]
EXPECTED_STRS_MAP = {
"nm-testing/Meta-Llama-3-8B-Instruct-FP8": [
'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
'Zeta-5, a highly advanced robot designed for menial labor, whirred to a',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya aki no tori, guri o',
],
"meta-llama/Meta-Llama-3-8B-Instruct": [
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'
],
"nm-testing/Meta-Llama-3-8B-Instruct-FP8-KV": {
"auto": [
'LLaMA is a high-throughput and memory-efficient inference and serving engine for Large Language Models (',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) process information in distinct ways, with both',
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
'Zeta-5, a highly advanced robot designed for menial labor, whirred and beep',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya aki no tori, nemuri no'
],
"fp8": [
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system made up of several basic components that work together to enable it to',
'Zeta-5, a highly advanced robot designed for menial labor, had never experienced anything like',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. Here',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya kotori wa mushi o tsuk'
]
},
"meta-llama/Meta-Llama-3-8B-Instruct": {
"auto": [
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system modeled after the human brain, composed of interconnected nodes or "ne',
'In the vast, sterile laboratory, Robot 3456-Alpha, or "Alpha" for short',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya aki wa mushi o tsukamu'
],
"fp8": [
'LLM (Large Language Model) is a type of artificial intelligence (AI) model that is trained',
'Here are the major milestones in the development of artificial intelligence (AI) from 1950 to ',
'Artificial intelligence (AI) and human intelligence (HI) differ significantly in how they process information.',
'A neural network is a complex system modeled after the human brain, consisting of interconnected nodes or "ne',
'In the year 2154, robotics engineer Dr. Rachel Kim had spent years perfecting her latest',
'The COVID-19 pandemic has had a profound impact on global economic structures and future business models. The',
'The Mona Lisa, painted by Leonardo da Vinci in the early 16th century, is one of',
'Here are the translations:\n\n**Japanese:** (Haya tori, mushi o tsukamu'
]
},
}
capability = torch.cuda.get_device_capability()
@ -52,14 +76,14 @@ fp8_not_supported = (capability <
@pytest.mark.skipif(fp8_not_supported,
reason="fp8 is not supported on this GPU type.")
@pytest.mark.parametrize("model_name", MODELS)
def test_models(
example_prompts,
model_name,
) -> None:
@pytest.mark.parametrize("kv_cache_dtype", ["auto", "fp8"])
def test_models(example_prompts, model_name, kv_cache_dtype) -> None:
model = LLM(model=model_name,
max_model_len=MAX_MODEL_LEN,
trust_remote_code=True,
enforce_eager=True,
quantization="fp8")
quantization="fp8",
kv_cache_dtype=kv_cache_dtype)
tokenizer = AutoTokenizer.from_pretrained(model_name)
formatted_prompts = [
@ -81,8 +105,8 @@ def test_models(
generations.append(outputs[0].outputs[0].text)
del model
print(generations)
expected_strs = EXPECTED_STRS_MAP[model_name]
print(model_name, kv_cache_dtype, generations)
expected_strs = EXPECTED_STRS_MAP[model_name][kv_cache_dtype]
for i in range(len(example_prompts)):
generated_str = generations[i]
expected_str = expected_strs[i]

View File

@ -7,6 +7,8 @@ import torch.nn as nn
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.selector import get_attn_backend
from vllm.config import CacheConfig
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
class Attention(nn.Module):
@ -30,6 +32,7 @@ class Attention(nn.Module):
alibi_slopes: Optional[List[float]] = None,
sliding_window: Optional[int] = None,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
if cache_config is not None:
@ -40,6 +43,27 @@ class Attention(nn.Module):
block_size = 16
if num_kv_heads is None:
num_kv_heads = num_heads
# The default kv_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
# expect the pre-quantized kv_scale to be loaded along
# with the model weights.
self.kv_cache_dtype = kv_cache_dtype
self._kv_scale = 1.0
quant_method = quant_config.get_quant_method(
self) if quant_config else None
if quant_method is not None:
if self.kv_cache_dtype == "fp8_e5m2":
raise ValueError("fp8_e5m2 kv-cache is not supported with "
"fp8 checkpoints.")
# When FP8 quantization is enabled, we make a parameter
# "kv_scale" so that it can be loaded from FP8 checkpoint.
# The kv_scale will then be converted back
# to self._kv_scale in a native float32 value after weight loading.
self.quant_method = quant_method
self.quant_method.create_weights(self)
# During model initialization, the default dtype is set as the model
# weight and activation dtype.
dtype = torch.get_default_dtype()
@ -57,10 +81,9 @@ class Attention(nn.Module):
value: torch.Tensor,
kv_cache: Optional[torch.Tensor],
attn_metadata: AttentionMetadata,
kv_scale: float = 1.0,
) -> torch.Tensor:
return self.impl.forward(query, key, value, kv_cache, attn_metadata,
kv_scale)
self._kv_scale)
def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore

View File

@ -355,14 +355,12 @@ class CacheConfig:
def _verify_cache_dtype(self) -> None:
if self.cache_dtype == "auto":
pass
elif self.cache_dtype == "fp8":
elif self.cache_dtype in ("fp8", "fp8_e4m3", "fp8_e5m2"):
logger.info(
"Using fp8 data type to store kv cache. It reduces the GPU "
"memory footprint and boosts the performance. "
"But it may cause slight accuracy drop without scaling "
"factors. FP8_E5M2 (without scaling) is only supported on "
"cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 "
"is instead supported for common inference criteria.")
"Meanwhile, it may cause accuracy drop without a proper "
"scaling factor")
else:
raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}")

View File

@ -191,12 +191,11 @@ class EngineArgs:
parser.add_argument(
'--kv-cache-dtype',
type=str,
choices=['auto', 'fp8'],
choices=['auto', 'fp8', 'fp8_e5m2', 'fp8_e4m3'],
default=EngineArgs.kv_cache_dtype,
help='Data type for kv cache storage. If "auto", will use model '
'data type. FP8_E5M2 (without scaling) is only supported on cuda '
'version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is instead '
'supported for common inference criteria.')
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)')
parser.add_argument(
'--quantization-param-path',
type=nullable_str,

View File

@ -8,8 +8,9 @@ from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
from vllm.utils import print_warning_once
ACTIVATION_SCHEMES = ["static", "dynamic"]
@ -58,9 +59,13 @@ class Fp8Config(QuantizationConfig):
activation_scheme=activation_scheme)
def get_quant_method(
self, layer: torch.nn.Module) -> Optional["Fp8LinearMethod"]:
self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
return Fp8LinearMethod(self)
if isinstance(layer, Attention):
return Fp8KVCacheMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
@ -251,6 +256,44 @@ class Fp8LinearMethod(LinearMethodBase):
return torch.narrow(output, 0, 0, x.shape[0])
class Fp8KVCacheMethod(QuantizeMethodBase):
"""Supports loading kv-cache scaling factors from FP8 checkpoints.
"""
def __init__(self, quant_config: Fp8Config):
self.quant_config = quant_config
def create_weights(self, layer: torch.nn.Module):
"""Create "weight" (aka kv_scale) for an attention layer.
Args:
layer: The layer that is using the QuantizeMethodBase factory.
"""
# Initialize the KV cache scale to 1.0 as the default value.
# If the kv_scale appears in the checkpoint, it will be
# overwritten when loading weights.
layer.kv_scale = Parameter(torch.tensor(1.0), requires_grad=False)
def apply(self, layer: torch.nn.Module) -> torch.Tensor:
raise RuntimeError("Fp8KVCacheMethod.apply should not be called.")
def process_weights_after_loading(self, layer: Module) -> None:
# If the kv-cache dtype is auto, we enforce the kv-scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
if layer.kv_cache_dtype != "auto":
kv_scale = layer.kv_scale.to("cpu").tolist()
if not isinstance(kv_scale, float):
raise ValueError("Only support per-tensor scaling factor "
"for fp8 KV cache")
layer._kv_scale = kv_scale
if layer._kv_scale == 1.0 and "e5m2" not in layer.kv_cache_dtype:
print_warning_once(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
"cause accuracy issues. Please make sure kv-cache scaling "
"factor is available in the fp8 checkpoint.")
del layer.kv_scale
def all_close_1d(x: torch.Tensor) -> bool:
assert len(x.shape) == 1
return all(torch.allclose(x[0], x[i]) for i in range(x.shape[0]))

View File

@ -268,7 +268,8 @@ class ArcticAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -154,7 +154,8 @@ class BaiChuanAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
scaling,
alibi_slopes=alibi_slopes)
alibi_slopes=alibi_slopes,
quant_config=quant_config)
else:
self.rotary_emb = get_rope(
self.head_dim,
@ -166,7 +167,8 @@ class BaiChuanAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -111,7 +111,8 @@ class BloomAttention(nn.Module):
self.head_dim,
scaling,
alibi_slopes=alibi_slopes,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -86,13 +86,12 @@ class GLMAttention(nn.Module):
base=10000 * rope_ratio,
is_neox_style=False,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -177,13 +177,12 @@ class CohereAttention(nn.Module):
rope_scaling=self.rope_scaling,
is_neox_style=False,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
if self.use_qk_norm:
self.q_norm = LayerNorm(param_shape=(self.num_heads,
self.head_dim),

View File

@ -218,13 +218,12 @@ class DbrxAttention(nn.Module):
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -232,7 +232,8 @@ class DeepseekAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -153,7 +153,8 @@ class FalconAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
self.inv_norm_factor,
num_kv_heads=self.num_kv_heads)
num_kv_heads=self.num_kv_heads,
quant_config=quant_config)
elif self.use_alibi:
tp_rank = get_tensor_model_parallel_rank()
head_start = tp_rank * self.num_heads
@ -165,13 +166,15 @@ class FalconAttention(nn.Module):
self.head_dim,
self.inv_norm_factor,
num_kv_heads=self.num_kv_heads,
alibi_slopes=alibi_slopes)
alibi_slopes=alibi_slopes,
quant_config=quant_config)
else:
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.inv_norm_factor,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -157,7 +157,8 @@ class GemmaAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -75,7 +75,8 @@ class GPT2Attention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scale,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -88,7 +88,8 @@ class GPTBigCodeAttention(nn.Module):
self.head_dim,
scale=self.scale,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -88,7 +88,8 @@ class GPTJAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_size,
scaling,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -89,7 +89,8 @@ class GPTNeoXAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_size,
scaling,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -117,7 +117,8 @@ class InternLM2Attention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -105,13 +105,12 @@ class JAISAttention(nn.Module):
head_end = (tp_rank + 1) * self.num_heads
alibi_slopes = _get_alibi_slopes(total_num_heads)
alibi_slopes = alibi_slopes[head_start:head_end]
self.attn = Attention(
self.num_heads,
self.head_dim,
scale=self.scale,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scale,
alibi_slopes=alibi_slopes,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -47,7 +47,7 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader, kv_cache_scales_loader)
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplerOutput
from vllm.utils import is_hip
from vllm.utils import is_hip, print_warning_once
class LlamaMLP(nn.Module):
@ -119,15 +119,6 @@ class LlamaAttention(nn.Module):
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
# This will be overwritten by model initialization if we are using it.
# N.B. currently we only support per tensor scalar scaling factors
# & only applicable to ROCm (AMD GPU).
# The scaling factor convention we are assuming is
# quantized_value * scaling_factor ~= true_value
# which is consistent with the practice of setting
# scaling_factor = tensor_amax / FPtype_max
self.kv_scale = 1.0
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
@ -155,7 +146,8 @@ class LlamaAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=sliding_window,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
@ -167,8 +159,7 @@ class LlamaAttention(nn.Module):
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata,
self.kv_scale)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata)
output, _ = self.o_proj(attn_output)
return output
@ -421,6 +412,19 @@ class LlamaForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
print_warning_once(
f"Found kv scale in the checkpoint (e.g. {name}), "
"but not found the expected name in the model "
f"(e.g. {remapped_kv_scale_name}). kv-scale is "
"not loaded.")
continue
else:
name = remapped_kv_scale_name
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
@ -445,7 +449,7 @@ class LlamaForCausalLM(nn.Module):
# scaling_factor = tensor_amax / FPtype_max
scaling_factor *= 2
if hasattr(layer_self_attn, "kv_scale"):
layer_self_attn.kv_scale = scaling_factor
layer_self_attn.attn._kv_scale = scaling_factor
else:
raise RuntimeError("Self attention has no KV cache scaling "
"factor attribute!")

View File

@ -236,7 +236,8 @@ class MiniCPMAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -308,14 +308,13 @@ class MixtralAttention(nn.Module):
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,
@ -581,6 +580,20 @@ class MixtralForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
# Remapping the name of FP8 kv-scale.
if name.endswith("kv_scale"):
remapped_kv_scale_name = name.replace(
".kv_scale", ".attn.kv_scale")
if remapped_kv_scale_name not in params_dict:
print_warning_once(
"Found kv scale in the checkpoint "
f"(e.g. {name}), but not found the expected "
f"name in the model "
f"(e.g. {remapped_kv_scale_name}). "
"kv-scale is not loaded.")
continue
else:
name = remapped_kv_scale_name
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)

View File

@ -213,14 +213,13 @@ class MixtralAttention(nn.Module):
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -110,7 +110,8 @@ class MPTAttention(nn.Module):
scaling,
alibi_slopes=alibi_slopes,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -96,7 +96,8 @@ class OlmoAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scaling,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
# Attention output projection.
self.o_proj = RowParallelLinear(

View File

@ -91,7 +91,8 @@ class OPTAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
scale=self.scaling,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -121,7 +121,8 @@ class OrionAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -110,7 +110,8 @@ class PhiAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_size,
scaling,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -106,7 +106,8 @@ class QWenAttention(nn.Module):
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -141,7 +141,8 @@ class Qwen2Attention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -241,7 +241,8 @@ class Qwen2MoeAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -127,7 +127,8 @@ class StablelmAttention(nn.Module):
self.head_dim,
self.scaling,
num_kv_heads=self.num_key_value_heads,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -97,14 +97,13 @@ class Starcoder2Attention(nn.Module):
base=int(self.rope_theta),
is_neox_style=True,
)
self.attn = Attention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
)
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=self.sliding_window,
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -135,7 +135,8 @@ class XverseAttention(nn.Module):
self.scaling,
num_kv_heads=self.num_kv_heads,
sliding_window=sliding_window,
cache_config=cache_config)
cache_config=cache_config,
quant_config=quant_config)
def forward(
self,

View File

@ -31,6 +31,8 @@ STR_DTYPE_TO_TORCH_DTYPE = {
"bfloat16": torch.bfloat16,
"float": torch.float,
"fp8": torch.uint8,
"fp8_e4m3": torch.uint8,
"fp8_e5m2": torch.uint8,
}

View File

@ -1,4 +1,5 @@
import time
import warnings
from typing import Dict, List, NamedTuple, Optional, Set, Tuple, Union
import numpy as np
@ -168,11 +169,21 @@ class ModelRunner:
self.model = self.lora_manager.create_lora_manager(self.model)
if self.kv_cache_dtype == "fp8" and is_hip():
# Currently scaled KV cache is only enabled on ROCm
# Currently only ROCm accepts kv-cache scaling factors
# via quantization_param_path and this will be deprecated
# in the future.
if self.model_config.quantization_param_path is not None:
if callable(getattr(self.model, "load_kv_cache_scales", None)):
warnings.warn(
"Loading kv cache scaling factor from JSON is "
"deprecated and will be removed. Please include "
"kv cache scaling factors in the model checkpoint.",
FutureWarning,
stacklevel=2)
self.model.load_kv_cache_scales(
self.model_config.quantization_param_path)
logger.info("Loaded KV cache scaling factors from %s",
self.model_config.quantization_param_path)
else:
raise RuntimeError(
"Using FP8 KV cache and scaling factors provided but "
@ -183,10 +194,6 @@ class ModelRunner:
"Using FP8 KV cache but no scaling factors "
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")
elif self.model_config.quantization_param_path is not None:
logger.warning("KV cache scaling factors provided, "
"but the KV cache data type is not FP8. "
"KV cache scaling factors will not be used.")
def save_sharded_state(
self,