mirror of https://github.com/vllm-project/vllm

[NVIDIA] Support nvfp4 quantization (#12784)

commit 4fc5c23bb6 (parent 9f9704dca6)
@@ -264,6 +264,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/custom_all_reduce.cu"
     "csrc/permute_cols.cu"
     "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
+    "csrc/quantization/fp4/nvfp4_quant_entry.cu"
     "csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
     "csrc/sparse/cutlass/sparse_compressor_entry.cu"
     "csrc/cutlass_extensions/common.cpp")
@@ -377,6 +378,23 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     endif()
   endif()
 
+  # FP4 Archs and flags
+  cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}")
+  if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS)
+    set(SRCS
+      "csrc/quantization/fp4/nvfp4_quant_kernels.cu"
+    )
+    set_gencode_flags_for_srcs(
+      SRCS "${SRCS}"
+      CUDA_ARCHS "${FP4_ARCHS}")
+    list(APPEND VLLM_EXT_SRC "${SRCS}")
+    list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4=1")
+    message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
+  else()
+    message(STATUS "Not building NVFP4 as no compatible archs were found.")
+    # clear FP4_ARCHS
+    set(FP4_ARCHS)
+  endif()
 
   #
   # Machete kernels
@@ -257,9 +257,9 @@ endmacro()
 # where `<=` is the version comparison operator.
 # In other words, for each version in `TGT_CUDA_ARCHS` find the highest version
 # in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
-# We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is
-# in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add
-# 9.0a to the result (and remove 9.0 from TGT_CUDA_ARCHS).
+# We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is
+# in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add
+# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
 # The result is stored in `OUT_CUDA_ARCHS`.
 #
 # Example:
@@ -272,8 +272,8 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
   list(REMOVE_DUPLICATES SRC_CUDA_ARCHS)
   set(TGT_CUDA_ARCHS_ ${TGT_CUDA_ARCHS})
 
-  # if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in CUDA_ARCHS then we should
-  # remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS
+  # if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
+  # remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
   set(_CUDA_ARCHS)
   if ("9.0a" IN_LIST SRC_CUDA_ARCHS)
     list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a")
@@ -283,6 +283,14 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
     endif()
   endif()
 
+  if ("10.0a" IN_LIST SRC_CUDA_ARCHS)
+    list(REMOVE_ITEM SRC_CUDA_ARCHS "10.0a")
+    if ("10.0" IN_LIST TGT_CUDA_ARCHS)
+      list(REMOVE_ITEM TGT_CUDA_ARCHS_ "10.0")
+      set(_CUDA_ARCHS "10.0a")
+    endif()
+  endif()
+
   list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
 
   # for each ARCH in TGT_CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that
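
Editor's note: the loose-intersection behavior documented in the comments above (pick, for each requested arch, the highest offered arch that is less than or equal to it, preferring the "x.0a" variant when plain "x.0" is requested) can be sketched in Python. This is an illustration of the documented semantics only, not the CMake macro; the helper names are made up.

def ver(a: str) -> float:
    # "9.0a" -> 9.0, "8.6" -> 8.6
    return float(a[:-1] if a.endswith("a") else a)


def loose_intersection(src_archs, tgt_archs):
    src, tgt, result = set(src_archs), set(tgt_archs), set()
    # Special case: if "x.0a" is offered and plain "x.0" is requested, keep
    # the accelerated variant and drop both from the generic matching below.
    for accel in ("9.0a", "10.0a"):
        if accel in src:
            src.discard(accel)
            if accel[:-1] in tgt:
                tgt.discard(accel[:-1])
                result.add(accel)
    # Generic rule: for each requested arch, take the highest offered arch
    # that is <= it.
    for t in tgt:
        candidates = [s for s in src if ver(s) <= ver(t)]
        if candidates:
            result.add(max(candidates, key=ver))
    return sorted(result, key=ver)


# e.g. loose_intersection({"7.5", "8.0", "10.0a"}, {"8.6", "10.0"})
#      -> ["8.0", "10.0a"]
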
@@ -1,5 +1,7 @@
 #pragma once
 
+#include <stdio.h>
+
 #if defined(__CUDACC__) || defined(_NVHPC_CUDA)
 #define HOST_DEVICE_INLINE __forceinline__ __host__ __device__
 #define DEVICE_INLINE __forceinline__ __device__
@@ -10,6 +12,16 @@
 #define HOST_INLINE inline
 #endif
 
+#define CUDA_CHECK(cmd)                                              \
+  do {                                                               \
+    cudaError_t e = cmd;                                             \
+    if (e != cudaSuccess) {                                          \
+      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__,  \
+             cudaGetErrorString(e));                                 \
+      exit(EXIT_FAILURE);                                            \
+    }                                                                \
+  } while (0)
+
 int64_t get_device_attribute(int64_t attribute, int64_t device_id);
 
 int64_t get_max_shared_memory_per_block_device_attribute(int64_t device_id);
@@ -1,16 +1,22 @@
+#include "cuda_utils.h"
 #ifdef USE_ROCM
   #include <hip/hip_runtime.h>
   #include <hip/hip_runtime_api.h>
 #endif
 
 int64_t get_device_attribute(int64_t attribute, int64_t device_id) {
-  int device, value;
-  if (device_id < 0) {
-    cudaGetDevice(&device);
-  } else {
-    device = device_id;
-  }
-  cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute),
-                         device);
+  // Return the cached value on subsequent calls
+  static int value = [=]() {
+    int device = static_cast<int>(device_id);
+    if (device < 0) {
+      CUDA_CHECK(cudaGetDevice(&device));
+    }
+    int value;
+    CUDA_CHECK(cudaDeviceGetAttribute(
+        &value, static_cast<cudaDeviceAttr>(attribute), device));
+    return static_cast<int>(value);
+  }();
+
   return value;
 }
@@ -195,6 +195,10 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
 
 void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
 
+void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
+                      torch::Tensor& output_scale,
+                      torch::Tensor const& input_scale);
+
 void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
                              torch::Tensor const& scale);
@@ -0,0 +1,32 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <torch/all.h>
+
+#if defined ENABLE_NVFP4 && ENABLE_NVFP4
+void scaled_fp4_quant_sm100a(torch::Tensor const& output,
+                             torch::Tensor const& input,
+                             torch::Tensor const& output_sf,
+                             torch::Tensor const& input_sf);
+#endif
+
+void scaled_fp4_quant(torch::Tensor& output, torch::Tensor const& input,
+                      torch::Tensor& output_sf, torch::Tensor const& input_sf) {
+#if defined ENABLE_NVFP4 && ENABLE_NVFP4
+  return scaled_fp4_quant_sm100a(output, input, output_sf, input_sf);
+#endif
+  TORCH_CHECK_NOT_IMPLEMENTED(false, "No compiled nvfp4 quantization");
+}
@@ -0,0 +1,379 @@
+/*
+ * Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <torch/all.h>
+
+#include <cuda_runtime_api.h>
+#include <cuda_runtime.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+
+#include <cuda_fp8.h>
+
+#include "cuda_utils.h"
+
+// Get type2 from type or vice versa (applied to half and bfloat16)
+template <typename T>
+struct TypeConverter {
+  using Type = half2;
+};  // keep for generality
+
+template <>
+struct TypeConverter<half2> {
+  using Type = half;
+};
+
+template <>
+struct TypeConverter<half> {
+  using Type = half2;
+};
+
+template <>
+struct TypeConverter<__nv_bfloat162> {
+  using Type = __nv_bfloat16;
+};
+
+template <>
+struct TypeConverter<__nv_bfloat16> {
+  using Type = __nv_bfloat162;
+};
+
+#define ELTS_PER_THREAD 8
+
+constexpr int CVT_FP4_ELTS_PER_THREAD = 8;
+constexpr int CVT_FP4_SF_VEC_SIZE = 16;
+
+// Convert 8 float32 values into 8 e2m1 values (represented as one uint32_t).
+inline __device__ uint32_t fp32_vec_to_e2m1(float (&array)[8]) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  uint32_t val;
+  asm volatile(
+      "{\n"
+      ".reg .b8 byte0;\n"
+      ".reg .b8 byte1;\n"
+      ".reg .b8 byte2;\n"
+      ".reg .b8 byte3;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
+      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
+      "}"
+      : "=r"(val)
+      : "f"(array[0]), "f"(array[1]), "f"(array[2]), "f"(array[3]),
+        "f"(array[4]), "f"(array[5]), "f"(array[6]), "f"(array[7]));
+  return val;
+#else
+  return 0;
+#endif
+}
+
+// Convert 4 float2 values into 8 e2m1 values (represented as one uint32_t).
+inline __device__ uint32_t fp32_vec_to_e2m1(float2 (&array)[4]) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  uint32_t val;
+  asm volatile(
+      "{\n"
+      ".reg .b8 byte0;\n"
+      ".reg .b8 byte1;\n"
+      ".reg .b8 byte2;\n"
+      ".reg .b8 byte3;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte0, %2, %1;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte1, %4, %3;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte2, %6, %5;\n"
+      "cvt.rn.satfinite.e2m1x2.f32 byte3, %8, %7;\n"
+      "mov.b32 %0, {byte0, byte1, byte2, byte3};\n"
+      "}"
+      : "=r"(val)
+      : "f"(array[0].x), "f"(array[0].y), "f"(array[1].x), "f"(array[1].y),
+        "f"(array[2].x), "f"(array[2].y), "f"(array[3].x), "f"(array[3].y));
+  return val;
+#else
+  return 0;
+#endif
+}
+
+// Fast reciprocal.
+inline __device__ float reciprocal_approximate_ftz(float a) {
+  float b;
+  asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a));
+  return b;
+}
+
+template <class SFType, int CVT_FP4_NUM_THREADS_PER_SF>
+__device__ uint8_t* cvt_quant_to_fp4_get_sf_out_offset(int rowIdx, int colIdx,
+                                                       int numCols,
+                                                       SFType* SFout) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  static_assert(CVT_FP4_NUM_THREADS_PER_SF == 1 ||
+                CVT_FP4_NUM_THREADS_PER_SF == 2);
+
+  // One pair of threads write one SF to global memory.
+  // TODO: stage through smem for packed STG.32
+  // is it better than STG.8 from 4 threads ?
+  if (threadIdx.x % CVT_FP4_NUM_THREADS_PER_SF == 0) {
+    // SF vector index (16 elements share one SF in the K dimension).
+    int32_t kIdx = colIdx / CVT_FP4_NUM_THREADS_PER_SF;
+    int32_t mIdx = rowIdx;
+
+    // SF layout [numMTiles, numKTiles, 32 (mTile), 4 (mTile), 4(kTile)]
+    // --> index [mTileIdx, kTileIdx, outerMIdx, innerMIdx, innerKIdx]
+
+    int32_t mTileIdx = mIdx / (32 * 4);
+    // SF vector size 16.
+    int factor = CVT_FP4_SF_VEC_SIZE * 4;
+    int32_t numKTiles = (numCols + factor - 1) / factor;
+    int64_t mTileStride = numKTiles * 32 * 4 * 4;
+
+    int32_t kTileIdx = (kIdx / 4);
+    int64_t kTileStride = 32 * 4 * 4;
+
+    // M tile layout [32, 4] is column-major.
+    int32_t outerMIdx = (mIdx % 32);
+    int64_t outerMStride = 4 * 4;
+
+    int32_t innerMIdx = (mIdx % (32 * 4)) / 32;
+    int64_t innerMStride = 4;
+
+    int32_t innerKIdx = (kIdx % 4);
+    int64_t innerKStride = 1;
+
+    // Compute the global offset.
+    int64_t SFOffset = mTileIdx * mTileStride + kTileIdx * kTileStride +
+                       outerMIdx * outerMStride + innerMIdx * innerMStride +
+                       innerKIdx * innerKStride;
+
+    return reinterpret_cast<uint8_t*>(SFout) + SFOffset;
+  }
+#endif
+  return nullptr;
+}
+
+// Define a 16 bytes packed data type.
+template <class Type>
+struct PackedVec {
+  typename TypeConverter<Type>::Type elts[4];
+};
+
+template <>
+struct PackedVec<__nv_fp8_e4m3> {
+  __nv_fp8x2_e4m3 elts[8];
+};
+
+// Quantizes the provided PackedVec into the uint32_t output
+template <class Type, bool UE8M0_SF = false>
+__device__ uint32_t cvt_warp_fp16_to_fp4(PackedVec<Type>& vec, float SFScaleVal,
+                                         uint8_t* SFout) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  // Get absolute maximum values among the local 8 values.
+  auto localMax = __habs2(vec.elts[0]);
+
+  // Local maximum value.
+#pragma unroll
+  for (int i = 1; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
+    localMax = __hmax2(localMax, __habs2(vec.elts[i]));
+  }
+
+  // Get the absolute maximum among all 16 values (two threads).
+  localMax = __hmax2(__shfl_xor_sync(uint32_t(-1), localMax, 1), localMax);
+  // Get the final absolute maximum values.
+  float vecMax = float(__hmax(localMax.x, localMax.y));
+
+  // Get the SF (max value of the vector / max value of e2m1).
+  // maximum value of e2m1 = 6.0.
+  // TODO: use half as compute data type.
+  float SFValue = SFScaleVal * (vecMax * reciprocal_approximate_ftz(6.0f));
+  // 8 bits representation of the SF.
+  uint8_t fp8SFVal;
+  // Write the SF to global memory (STG.8).
+  if constexpr (UE8M0_SF) {
+    // Extract the 8 exponent bits from float32.
+    // float 32bits = 1 sign bit + 8 exponent bits + 23 mantissa bits.
+    uint32_t tmp = reinterpret_cast<uint32_t&>(SFValue) >> 23;
+    fp8SFVal = tmp & 0xff;
+    // Convert back to fp32.
+    reinterpret_cast<uint32_t&>(SFValue) = tmp << 23;
+  } else {
+    // Here SFValue is always positive, so E4M3 is the same as UE4M3.
+    __nv_fp8_e4m3 tmp = __nv_fp8_e4m3(SFValue);
+    reinterpret_cast<__nv_fp8_e4m3&>(fp8SFVal) = tmp;
+    // Convert back to fp32.
+    SFValue = float(tmp);
+  }
+  // Get the output scale.
+  // Recipe: final_scale = reciprocal(fp32(fp8(SFValue * SFScaleVal))) *
+  //                       reciprocal(SFScaleVal))
+  float outputScale =
+      SFValue != 0 ? reciprocal_approximate_ftz(
+                         SFValue * reciprocal_approximate_ftz(SFScaleVal))
+                   : 0.0f;
+
+  if (SFout) {
+    // Write the SF to global memory (STG.8).
+    *SFout = fp8SFVal;
+  }
+
+  // Convert the input to float.
+  float2 fp2Vals[CVT_FP4_ELTS_PER_THREAD / 2];
+
+#pragma unroll
+  for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) {
+    if constexpr (std::is_same_v<Type, half>) {
+      fp2Vals[i] = __half22float2(vec.elts[i]);
+    } else {
+      fp2Vals[i] = __bfloat1622float2(vec.elts[i]);
+    }
+    fp2Vals[i].x *= outputScale;
+    fp2Vals[i].y *= outputScale;
+  }
+
+  // Convert to e2m1 values.
+  uint32_t e2m1Vec = fp32_vec_to_e2m1(fp2Vals);
+
+  // Write the e2m1 values to global memory.
+  return e2m1Vec;
+#else
+  return 0;
+#endif
+}
+
+// Use UE4M3 by default.
+template <class Type, bool UE8M0_SF = false>
+__global__ void
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+__launch_bounds__(512, 4) cvt_fp16_to_fp4(
+#else
+cvt_fp16_to_fp4(
+#endif
+    int32_t numRows, int32_t numCols, Type const* in, float const* SFScale,
+    uint32_t* out, uint32_t* SFout) {
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
+  using PackedVec = PackedVec<Type>;
+  static constexpr int CVT_FP4_NUM_THREADS_PER_SF =
+      (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD);
+  static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD,
+                "Vec size is not matched.");
+
+  // Get the global scaling factor, which will be applied to the SF.
+  // Note SFScale is the same as next GEMM's alpha, which is
+  // (448.f / (Alpha_A / 6.f)).
+  float const SFScaleVal = SFScale == nullptr ? 1.0f : SFScale[0];
+
+  // Input tensor row/col loops.
+  for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) {
+    for (int colIdx = threadIdx.x; colIdx < numCols / CVT_FP4_ELTS_PER_THREAD;
+         colIdx += blockDim.x) {
+      int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx;
+      PackedVec in_vec = reinterpret_cast<PackedVec const*>(in)[inOffset];
+      // Get the output tensor offset.
+      // Same as inOffset because 8 elements are packed into one uint32_t.
+      int64_t outOffset = inOffset;
+      auto& out_pos = out[outOffset];
+
+      auto sf_out =
+          cvt_quant_to_fp4_get_sf_out_offset<uint32_t,
+                                             CVT_FP4_NUM_THREADS_PER_SF>(
+              rowIdx, colIdx, numCols, SFout);
+
+      out_pos =
+          cvt_warp_fp16_to_fp4<Type, UE8M0_SF>(in_vec, SFScaleVal, sf_out);
+    }
+  }
+#endif
+}
+
+template <typename T>
+void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale,
+                           int64_t* output, int32_t* SFOuput, bool useUE8M0,
+                           int multiProcessorCount, cudaStream_t stream) {
+  // Grid, Block size.
+  // Each thread converts 8 values.
+  dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
+  // Get number of blocks per SM (assume we can fully utilize the SM).
+  int const numBlocksPerSM = 2048 / block.x;
+  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));
+
+  // Launch the cvt kernel.
+  if (useUE8M0) {
+    cvt_fp16_to_fp4<T, true><<<grid, block, 0, stream>>>(
+        m, n, input, SFScale, reinterpret_cast<uint32_t*>(output),
+        reinterpret_cast<uint32_t*>(SFOuput));
+  } else {
+    cvt_fp16_to_fp4<T, false><<<grid, block, 0, stream>>>(
+        m, n, input, SFScale, reinterpret_cast<uint32_t*>(output),
+        reinterpret_cast<uint32_t*>(SFOuput));
+  }
+}
+
+// Instantiate the function.
+template void invokeFP4Quantization(int m, int n, half const* input,
+                                    float const* SFScale, int64_t* output,
+                                    int32_t* SFOuput, bool useUE8M0,
+                                    int multiProcessorCount,
+                                    cudaStream_t stream);
+
+template void invokeFP4Quantization(int m, int n, __nv_bfloat16 const* input,
+                                    float const* SFScale, int64_t* output,
+                                    int32_t* SFOuput, bool useUE8M0,
+                                    int multiProcessorCount,
+                                    cudaStream_t stream);
+
+void scaled_fp4_quant_sm100a(torch::Tensor const& output,
+                             torch::Tensor const& input,
+                             torch::Tensor const& output_sf,
+                             torch::Tensor const& input_sf) {
+  int32_t m = input.size(0);
+  int32_t n = input.size(1);
+
+  TORCH_CHECK(n % 16 == 0, "The N dimension must be multiple of 16.");
+
+  int multiProcessorCount =
+      get_device_attribute(cudaDevAttrMultiProcessorCount, -1);
+
+  auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
+  auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
+  auto output_ptr = static_cast<int64_t*>(output.data_ptr());
+  at::cuda::CUDAGuard device_guard{(char)input.get_device()};
+  auto stream = at::cuda::getStreamFromPool(false, input.get_device());
+  if (stream == nullptr) {
+    std::cerr << "Warning: Null CUDA stream" << std::endl;
+  }
+
+  // We don't support e8m0 scales at this moment.
+  bool useUE8M0 = false;
+
+  switch (input.scalar_type()) {
+    case torch::kHalf: {
+      auto input_ptr = reinterpret_cast<half const*>(input.data_ptr());
+      invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out,
+                            useUE8M0, multiProcessorCount, stream);
+      break;
+    }
+    case torch::kBFloat16: {
+      auto input_ptr = reinterpret_cast<__nv_bfloat16 const*>(input.data_ptr());
+      invokeFP4Quantization(m, n, input_ptr, input_sf_ptr, output_ptr, sf_out,
+                            useUE8M0, multiProcessorCount, stream);
+      break;
+    }
+    default: {
+      std::cerr << "Observing: " << input.scalar_type()
+                << " for the input datatype which is invalid";
+      throw std::runtime_error(
+          "Unsupported input data type for quantize_to_fp4.");
+    }
+  }
+}
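
Editor's note: the PTX path above relies on cvt.rn.satfinite.e2m1x2.f32 to round pairs of fp32 values to e2m1 and pack them two per byte. The Python sketch below emulates that rounding and packing on the host for illustration only; it is not the PTX instruction. The code points and rounding boundaries match the cast_to_fp4 / cast_from_fp4 helpers in the test file later in this diff, with the first element of each pair stored in the low nibble.

# E2M1 code points (sign bit + 3-bit magnitude index).
E2M1_VALUES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def to_e2m1_code(v: float) -> int:
    """Round one float to its 4-bit e2m1 code: round-to-nearest with ties
    going to the code whose magnitude index is even, saturating at +/-6."""
    sign = 0x8 if v < 0 else 0x0
    mag = min(abs(v), 6.0)
    idx = min(range(8), key=lambda i: (abs(mag - E2M1_VALUES[i]), i % 2))
    return sign | idx


def pack8_to_u32(vals) -> int:
    """Pack eight e2m1 codes into one 32-bit word, element 0 in the lowest
    nibble (the ordering cast_from_fp4 in the test unpacks)."""
    word = 0
    for i, v in enumerate(vals):
        word |= to_e2m1_code(v) << (4 * i)
    return word


# Example: the eight representable positive values map to codes 0..7:
# pack8_to_u32([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]) == 0x76543210
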
@@ -423,6 +423,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
   ops.impl("dynamic_per_token_scaled_fp8_quant", torch::kCUDA,
            &dynamic_per_token_scaled_fp8_quant);
 
+  // Compute NVFP4 block quantized tensor.
+  ops.def(
+      "scaled_fp4_quant(Tensor! output, Tensor input,"
+      " Tensor! output_scale, Tensor input_scale) -> ()");
+  ops.impl("scaled_fp4_quant", torch::kCUDA, &scaled_fp4_quant);
+
   // Compute int8 quantized tensor for given scaling factor.
   ops.def(
       "static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale,"
@@ -0,0 +1,149 @@
+# SPDX-License-Identifier: Apache-2.0
+import pytest
+import torch
+
+from vllm import _custom_ops as ops
+from vllm.platforms import current_platform
+from vllm.scalar_type import scalar_types
+
+if not current_platform.has_device_capability(100):
+    pytest.skip(reason="Nvfp4 Requires compute capability of 10 or above.",
+                allow_module_level=True)
+
+DTYPES = [torch.float16, torch.bfloat16]
+SHAPES = [(128, 64), (128, 128), (256, 64), (256, 128)]
+PAD_SHAPES = [(90, 64), (150, 64), (128, 48), (128, 80), (150, 80), (90, 48),
+              (90, 128), (150, 128), (150, 48), (90, 80)]
+SEEDS = [42]
+CUDA_DEVICES = ['cuda:0']
+
+FLOAT4_E2M1_MAX = scalar_types.float4_e2m1fn.max()
+FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
+
+# E2M1 to float
+# 0111 -> 6
+# 0110 -> 4
+# 0101 -> 3
+# 0100 -> 2
+# 0011 -> 1.5
+# 0010 -> 1
+# 0001 -> 0.5
+# 0000 -> 0
+E2M1_TO_FLOAT32 = [
+    0., 0.5, 1., 1.5, 2., 3., 4., 6., 0., -0.5, -1., -1.5, -2., -3., -4., -6.
+]
+BLOCK_SIZE = 16
+
+
+def cast_from_fp4(x, m, n):
+    # The fp4 values are packed in uint8 as [v_1st | v_2nd]
+    v_2nd = x & 0xF
+    v_1st = (x >> 4) & 0xF
+    c = torch.stack((v_2nd, v_1st), dim=-1)
+    out = torch.tensor([E2M1_TO_FLOAT32[x] for x in c.flatten()])
+    out = out.reshape(m, n).to(torch.float32)
+    return out
+
+
+def cast_to_fp4(x):
+    sign = torch.sign(x)
+    x = torch.abs(x)
+    x[(x >= 0.0) & (x <= 0.25)] = 0.0
+    x[(x > 0.25) & (x < 0.75)] = 0.5
+    x[(x >= 0.75) & (x <= 1.25)] = 1.0
+    x[(x > 1.25) & (x < 1.75)] = 1.5
+    x[(x >= 1.75) & (x <= 2.5)] = 2.0
+    x[(x > 2.5) & (x < 3.5)] = 3.0
+    x[(x >= 3.5) & (x <= 5.0)] = 4.0
+    x[x > 5.0] = 6.0
+    return x * sign
+
+
+def get_reciprocal(x):
+    if isinstance(x, torch.Tensor):
+        return torch.where(x == 0, torch.tensor(0.0, dtype=x.dtype), 1.0 / x)
+    elif isinstance(x, (float, int)):
+        return 0.0 if x == 0 else 1.0 / x
+    else:
+        raise TypeError("Input must be a float, int, or a torch.Tensor.")
+
+
+def ref_nvfp4_quant(x, global_scale):
+    assert global_scale.dtype == torch.float32
+    assert x.ndim == 2
+    m, n = x.shape
+    x = torch.reshape(x, (m, n // BLOCK_SIZE, BLOCK_SIZE))
+    vec_max = torch.max(torch.abs(x), dim=-1,
+                        keepdim=True)[0].to(torch.float32)
+    scale = global_scale * (vec_max * get_reciprocal(FLOAT4_E2M1_MAX))
+    scale = scale.to(torch.float8_e4m3fn).to(torch.float32)
+    output_scale = get_reciprocal(scale * get_reciprocal(global_scale))
+
+    scaled_x = x.to(torch.float32) * output_scale
+    clipped_x = torch.clamp(scaled_x, -6.0, 6.0).reshape(m, n)
+    return cast_to_fp4(clipped_x), scale.squeeze(-1)
+
+
+def recover_swizzled_scales(scale, m, n):
+    round_up = lambda x, y: (x + y - 1) // y * y
+    rounded_m = round_up(m, 128)
+    scale_n = n // BLOCK_SIZE
+    rounded_n = round_up(scale_n, 4)
+    # Recover the swizzled scaling factor to linear layout
+    tmp = torch.reshape(scale, (1, rounded_m // 128, rounded_n // 4, 32, 4, 4))
+    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
+    result = torch.reshape(tmp, (rounded_m, rounded_n)).to(torch.float32)
+    return result[:m, :scale_n]
+
+
+@pytest.mark.parametrize("dtype", DTYPES)
+@pytest.mark.parametrize("shape", SHAPES)
+@pytest.mark.parametrize("seed", SEEDS)
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@torch.inference_mode()
+def test_quantize_to_fp4(
+    dtype: torch.dtype,
+    shape: tuple[int, int],
+    seed: int,
+    device: str,
+) -> None:
+    current_platform.seed_everything(seed)
+    torch.set_default_device(device)
+
+    m, n = shape
+
+    x = torch.randn((m, n), dtype=dtype)
+    tensor_amax = torch.abs(x).max().to(torch.float32)
+    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
+    out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)
+
+    out, out_scale = ops.scaled_fp4_quant(x, global_scale)
+    scale_ans = recover_swizzled_scales(out_scale, m, n)
+    out_ans = cast_from_fp4(out, m, n)
+
+    torch.testing.assert_close(out_ans, out_ref)
+    torch.testing.assert_close(scale_ans, scale_ref)
+
+
+@pytest.mark.parametrize("pad_shape", PAD_SHAPES)
+@torch.inference_mode()
+def test_quantize_to_fp4_padded(pad_shape: tuple[int, int]) -> None:
+    dtype = torch.float16
+    current_platform.seed_everything(42)
+    torch.set_default_device('cuda:0')
+
+    m, n = pad_shape
+
+    x = torch.randn((m, n), dtype=dtype)
+
+    tensor_amax = torch.abs(x).max().to(torch.float32)
+    global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / tensor_amax
+    out_ref, scale_ref = ref_nvfp4_quant(x, global_scale)
+
+    out, out_scale = ops.scaled_fp4_quant(x, global_scale)
+
+    scale_ans = recover_swizzled_scales(out_scale, m, n)
+    out_ans = cast_from_fp4(out, m, n)
+
+    torch.testing.assert_close(out_ans, out_ref)
+    torch.testing.assert_close(scale_ans, scale_ref)
@@ -11,6 +11,7 @@ from vllm.scalar_type import scalar_types
     (0, 15, scalar_types.uint4),
     (-8, 7, scalar_types.uint4b8),
     (-128, 127, scalar_types.uint8b128),
+    (-6., 6., scalar_types.float4_e2m1fn),
     (-28., 28., scalar_types.float6_e3m2f),
     (torch.int8, scalar_types.int8),
     (torch.uint8, scalar_types.uint8),
@@ -765,6 +765,63 @@ def permute_cols(a: torch.Tensor, perm: torch.Tensor) -> torch.Tensor:
     return torch.ops._C.permute_cols(a, perm)
 
 
+# fp4
+def scaled_fp4_quant(
+    input: torch.Tensor,
+    input_global_scale: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to FP4 and return quantized tensor and scale.
+
+    This function quantizes the last dimension of the given tensor `input`. For
+    every 16 consecutive elements, a single dynamically computed scaling factor
+    is shared. This scaling factor is quantized using the `input_global_scale`
+    and is stored in a swizzled layout (see
+    https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x).
+
+    Args:
+        input: The input tensor to be quantized to FP4
+        input_global_scale: A scalar scaling factor for the entire tensor.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: The output tensor in FP4 (every
+        two values packed into a uint8) and the float8_e4m3 scaling factors
+        in the swizzled layout.
+    """
+    assert input.ndim >= 1, (
+        f'input.ndim needs to be >= 1, but got {input.ndim}.')
+    other_dims = 1 if input.ndim == 1 else -1
+    input = input.reshape(other_dims, input.shape[-1])
+    m, n = input.shape
+    block_size = 16
+    device = input.device
+
+    assert n % block_size == 0, (
+        f'last dim has to be multiple of 16, but got {n}.')
+    assert input.dtype in (torch.float16, torch.bfloat16), (
+        f'input.dtype needs to be fp16 or bf16 but got {input.dtype}.')
+
+    # Two fp4 values will be packed into an uint8.
+    output = torch.empty((m, n // 2), device=device, dtype=torch.uint8)
+
+    # We use the rounded values to store the swizzled values. Due to the
+    # requirement of the Tensor Core, the minimum tile is 128x4 for the scales.
+    # So, we first pad the scales to multiples of 128 and 4. Then, the scales
+    # (in float8_e4m3fn) are packed into an int32 for every 4 values. More:
+    # https://docs.nvidia.com/cuda/parallel-thread-execution/#tcgen05-mma-scale-factor-b-layout-4x
+    round_up = lambda x, y: (x + y - 1) // y * y
+    rounded_m = round_up(m, 128)
+    scale_n = n // block_size
+    rounded_n = round_up(scale_n, 4)
+    output_scale = torch.empty((rounded_m, rounded_n // 4),
+                               device=device,
+                               dtype=torch.int32)
+
+    torch.ops._C.scaled_fp4_quant(output, input, output_scale,
+                                  input_global_scale)
+    output_scale = output_scale.view(torch.float8_e4m3fn)
+    return output, output_scale
+
+
 # fp8
 def scaled_fp8_quant(
     input: torch.Tensor,
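
Editor's note: a minimal usage sketch of the Python wrapper added above (not part of the commit). It assumes a compute-capability-10.0 device and a build with ENABLE_NVFP4; the global-scale recipe follows the unit test earlier in this diff.

import torch

from vllm import _custom_ops as ops

x = torch.randn(128, 64, dtype=torch.float16, device="cuda")

# global_scale = (e4m3_max * e2m1_max) / amax(x) = (448 * 6) / amax(x)
global_scale = (448.0 * 6.0) / torch.abs(x).max().to(torch.float32)

out, out_scale = ops.scaled_fp4_quant(x, global_scale)
print(out.shape, out.dtype)              # (128, 32), torch.uint8 (two fp4 per byte)
print(out_scale.shape, out_scale.dtype)  # (128, 4), torch.float8_e4m3fn (swizzled)
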
@@ -321,6 +321,9 @@ class scalar_types:
     # fp6, https://github.com/usyd-fsalab/fp6_llm/tree/main
     float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)
 
+    # fp4, https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+    float4_e2m1fn = ScalarType.float_(2, 1, True, NanRepr.NONE)
+
     # "gptq" types
     uint2b2 = ScalarType.uint(2, 2)
     uint3b4 = ScalarType.uint(3, 4)
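
Editor's note: the new scalar type's range matches the (-6., 6.) min/max case added to the test above; a quick sanity check (not part of the commit):

from vllm.scalar_type import scalar_types

# e2m1 with no NaN representation: values are +/- {0, 0.5, 1, 1.5, 2, 3, 4, 6}
assert scalar_types.float4_e2m1fn.min() == -6.0
assert scalar_types.float4_e2m1fn.max() == 6.0
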