[BUG] fixed fp8 conflict with aqlm (#4307)

Fixes the fp8 interface, which broke in the AQLM merge.
Robert Shaw 2024-04-23 21:26:33 -04:00 committed by GitHub
parent eace8bf0b9
commit 79a268c4ab
3 changed files with 18 additions and 4 deletions
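
Background: the AQLM merge changed the LinearMethodBase.create_weights interface so that callers pass a list of per-logical-weight output sizes (output_partition_sizes) instead of a single fused width (output_size_per_partition), but Fp8LinearMethod still declared the old parameter, so fp8 weight creation no longer matched what the linear layers pass in. A minimal sketch of the interface change, with illustrative values and simplified signatures rather than the actual vLLM call site:

# Simplified sketch of the interface change; not the real vLLM code.
from typing import List

def create_weights_old(output_size_per_partition: int) -> int:
    # Pre-AQLM shape of the argument: one fused output width.
    return output_size_per_partition

def create_weights_new(output_partition_sizes: List[int]) -> int:
    # Post-AQLM shape: one width per logical weight; implementations that
    # keep a single fused weight sum the list (as fp8 now does).
    return sum(output_partition_sizes)

sizes = [1024, 256, 256]            # e.g. Wq, Wk, Wv widths on one rank
assert create_weights_new(sizes) == 1536
# A method still written against the old scalar parameter no longer lines up
# with the arguments the layers supply, which is the breakage fixed for
# Fp8LinearMethod below.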

@@ -96,6 +96,9 @@ steps:
 - label: Metrics Test
   command: pytest -v -s metrics
 
+- label: Quantization Test
+  command: pytest -v -s quantization
+
 - label: Benchmarks
   working_dir: "/vllm-workspace/.buildkite"
   commands:
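
The step added above runs the quantization test suite in CI. As a rough, hypothetical illustration of the kind of end-to-end check such a test performs over the fixed fp8 path (the test body and model name are assumptions, not the actual vLLM suite, and fp8 requires a GPU that supports it):

# Hypothetical smoke test in the spirit of the new Quantization Test step;
# not the actual contents of vLLM's quantization tests.
from vllm import LLM, SamplingParams

def test_fp8_linear_method_smoke():
    # quantization="fp8" routes linear layers through Fp8LinearMethod,
    # which exercises the create_weights path fixed by this commit.
    llm = LLM(model="facebook/opt-125m", quantization="fp8")
    outputs = llm.generate(["Hello, my name is"],
                           SamplingParams(max_tokens=8))
    assert len(outputs) == 1 and outputs[0].outputs[0].text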

@@ -34,9 +34,19 @@ class LinearMethodBase(ABC):
                        output_partition_sizes: List[int], input_size: int,
                        output_size: int, params_dtype: torch.dtype,
                        **extra_weight_attrs):
-        """Create weights for a linear layer.
-           The weights will be set as attributes of the layer."""
+        """Create weights for a linear layer.
+           The weights will be set as attributes of the layer.
+
+        Args:
+            layer: The layer that is using the LinearMethodBase factory.
+            input_size_per_partition: Size of the weight input dim on rank X.
+            output_partition_sizes: Sizes of the output dim of each logical
+                weight on rank X. E.g., output_partition_sizes for QKVLinear
+                is a list containing the widths of Wq, Wk, Wv on rank X.
+            input_size: Size of the input dim of the weight across all ranks.
+            output_size: Size of the output dim of the weight across all ranks.
+            params_dtype: Datatype of the parameters.
+        """
         raise NotImplementedError
 
     @abstractmethod
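
To make the new docstring concrete, here is a small illustrative sketch (made-up numbers, not taken from any particular model) of how output_partition_sizes is typically formed for a fused QKV projection on one tensor-parallel rank, and how a method that only needs the fused width collapses it:

# Illustrative numbers only; names mirror the docstring above.
from typing import List

num_heads, num_kv_heads, head_dim, tp_size = 32, 8, 128, 4

q_width = num_heads * head_dim // tp_size      # 1024: width of Wq on rank X
kv_width = num_kv_heads * head_dim // tp_size  # 256: width of Wk / Wv on rank X

output_partition_sizes: List[int] = [q_width, kv_width, kv_width]

# Implementations that keep one fused weight (as Fp8LinearMethod does in the
# next hunk) just sum the logical widths back together:
output_size_per_partition = sum(output_partition_sizes)  # 1536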

@@ -64,12 +64,13 @@ class Fp8LinearMethod(LinearMethodBase):
         self,
         layer: torch.nn.Module,
         input_size_per_partition: int,
-        output_size_per_partition: int,
+        output_partition_sizes: List[int],
         input_size: int,
         output_size: int,
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        output_size_per_partition = sum(output_partition_sizes)
         weight = Parameter(torch.empty(output_size_per_partition,
                                        input_size_per_partition,
                                        dtype=params_dtype),
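
The hunk above is cut off mid-statement; assembled, the updated method reads roughly as follows. Only the lines visible in the hunk are verbatim; the tail (requires_grad, attaching the weight to the layer) is reconstructed from context and may not match the committed file exactly:

# Approximate reconstruction of Fp8LinearMethod.create_weights after this
# commit; treat everything past the visible hunk as an assumption.
from typing import List

import torch
from torch.nn import Parameter

class Fp8LinearMethod:  # stand-in; the real class subclasses LinearMethodBase

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],  # new: one width per logical weight
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # fp8 keeps a single fused weight, so collapse the logical widths.
        output_size_per_partition = sum(output_partition_sizes)
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        # Assumed: register the weight on the layer, as the LinearMethodBase
        # docstring requires ("weights will be set as attributes of the layer").
        layer.register_parameter("weight", weight)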