Support CPU inference with VSX PowerPC ISA (#5652)

This commit is contained in:
Chip Kerchner 2024-06-26 17:53:04 -04:00 committed by GitHub
parent f5c8628fdc
commit 38a1674abb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 1049 additions and 511 deletions

22
Dockerfile.ppc64le Normal file
View File

@ -0,0 +1,22 @@
FROM mambaorg/micromamba
ARG MAMBA_DOCKERFILE_ACTIVATE=1
USER root
RUN apt-get update -y && apt-get install -y git wget vim numactl gcc-12 g++-12 protobuf-compiler libprotobuf-dev && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
# Some packages in requirements-cpu are installed here
# IBM provides optimized packages for ppc64le processors in the open-ce project for mamba
# Currently these may not be available for venv or pip directly
RUN micromamba install -y -n base -c https://ftp.osuosl.org/pub/open-ce/1.11.0-p10/ -c defaults python=3.10 pytorch-cpu=2.1.2 torchvision-cpu=0.16.2 && micromamba clean --all --yes
COPY ./ /workspace/vllm
WORKDIR /workspace/vllm
# These packages will be in rocketce eventually
RUN pip install -v -r requirements-cpu.txt --prefer-binary --extra-index-url https://repo.fury.io/mgiessing
RUN VLLM_TARGET_DEVICE=cpu python3 setup.py install
WORKDIR /vllm-workspace
ENTRYPOINT ["/opt/conda/bin/python3", "-m", "vllm.entrypoints.openai.api_server"]

View File

@ -46,6 +46,8 @@ is_avx512_disabled(AVX512_DISABLED)
find_isa(${CPUINFO} "avx2" AVX2_FOUND)
find_isa(${CPUINFO} "avx512f" AVX512_FOUND)
find_isa(${CPUINFO} "POWER10" POWER10_FOUND)
find_isa(${CPUINFO} "POWER9" POWER9_FOUND)
if (AVX512_FOUND AND NOT AVX512_DISABLED)
list(APPEND CXX_COMPILE_FLAGS
@ -68,8 +70,15 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
elseif (AVX2_FOUND)
list(APPEND CXX_COMPILE_FLAGS "-mavx2")
message(WARNING "vLLM CPU backend using AVX2 ISA")
elseif (POWER9_FOUND OR POWER10_FOUND)
message(STATUS "PowerPC detected")
# Check for PowerPC VSX support
list(APPEND CXX_COMPILE_FLAGS
"-mvsx"
"-mcpu=native"
"-mtune=native")
else()
message(FATAL_ERROR "vLLM CPU backend requires AVX512 or AVX2 ISA support.")
message(FATAL_ERROR "vLLM CPU backend requires AVX512 or AVX2 or Power9+ ISA support.")
endif()
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")

View File

@ -2,514 +2,14 @@
#ifndef CPU_TYPES_HPP
#define CPU_TYPES_HPP
#include <immintrin.h>
#include <torch/all.h>
#ifndef __AVX2__
static_assert(false, "AVX2 must be supported for the current implementation.");
#endif
namespace vec_op {
// FIXME: FP16 is not fully supported in Torch-CPU
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#if defined(__x86_64__)
//x86 implementation
#include "cpu_types_x86.hpp"
#elif defined(__POWER9_VECTOR__)
//ppc implementation
#include "cpu_types_vsx.hpp"
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
#warning "unsupported vLLM cpu implementation"
#endif
#define FORCE_INLINE __attribute__((always_inline)) inline
namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
(f(std::integral_constant<T, indexes>{}), ...);
}
}; // namespace
template <typename T, T count, typename F,
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F &&f) {
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}
template <typename T> struct Vec {
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
};
struct FP32Vec8;
struct FP32Vec16;
#ifdef __AVX512FP16__
struct FP16Vec8 : public Vec<FP16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
__m128h reg;
explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {}
explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {}
explicit FP16Vec8(__m128h data) : reg(data) {}
FP16Vec8 operator*(const FP16Vec8 &b) const {
return FP16Vec8(_mm_mul_ph(reg, b.reg));
}
FP16Vec8 operator+(const FP16Vec8 &b) const {
return FP16Vec8(_mm_add_ph(reg, b.reg));
}
FP16Vec8 operator-(const FP16Vec8 &b) const {
return FP16Vec8(_mm_sub_ph(reg, b.reg));
}
FP16Vec8 operator/(const FP16Vec8 &b) const {
return FP16Vec8(_mm_div_ph(reg, b.reg));
}
void save(void *ptr) const { _mm_storeu_ph(ptr, reg); }
};
#endif
struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
__m128i reg;
explicit BF16Vec8(const void *ptr)
: reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {}
explicit BF16Vec8(const FP32Vec8 &);
void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; }
};
struct BF16Vec16 : public Vec<BF16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
__m256i reg;
explicit BF16Vec16(const void *ptr)
: reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {}
explicit BF16Vec16(const FP32Vec16 &);
void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
};
#ifdef __AVX512F__
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
__m512i reg;
explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}
explicit BF16Vec32(__m512i data) : reg(data) {}
explicit BF16Vec32(BF16Vec8 &vec8_data)
: reg((__m512i)_mm512_inserti32x4(
_mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512(
(__m128i)vec8_data.reg),
(__m128i)vec8_data.reg, 1),
(__m128i)vec8_data.reg, 2),
(__m128i)vec8_data.reg, 3)) {}
void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; }
};
#else
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
__m256i reg_low;
__m256i reg_high;
explicit BF16Vec32(const void *ptr)
: reg_low(_mm256_loadu_si256((__m256i const *)ptr)),
reg_high(_mm256_loadu_si256((__m256i const *)ptr + 1)) {}
explicit BF16Vec32(__m256i low, __m256i high) : reg_low(low),
reg_high(high) {}
explicit BF16Vec32(BF16Vec8 &vec8_data)
: reg_low((__m256i)_mm256_inserti32x4(
_mm256_castsi128_si256((__m128i)vec8_data.reg),
(__m128i)vec8_data.reg, 1)),
reg_high((__m256i)_mm256_inserti32x4(
_mm256_castsi128_si256((__m128i)vec8_data.reg),
(__m128i)vec8_data.reg, 1)) {}
void save(void *ptr) const {
*reinterpret_cast<__m256i *>(ptr) = reg_low;
*reinterpret_cast<__m256i *>((__m256i *)ptr + 1) = reg_high;
}
};
#endif
struct FP32Vec4 : public Vec<FP32Vec4> {
constexpr static int VEC_ELEM_NUM = 4;
union AliasReg {
__m128 reg;
float values[VEC_ELEM_NUM];
};
__m128 reg;
explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {}
explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {}
explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {}
explicit FP32Vec4(__m128 data) : reg(data) {}
explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {}
};
struct FP32Vec8 : public Vec<FP32Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
union AliasReg {
__m256 reg;
float values[VEC_ELEM_NUM];
};
__m256 reg;
explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {}
explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {}
explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {}
explicit FP32Vec8(__m256 data) : reg(data) {}
explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {}
#ifdef __AVX512FP16__
explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {}
#endif
explicit FP32Vec8(const BF16Vec8 &v)
: reg(_mm256_castsi256_ps(
_mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {}
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
float result = 0;
unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });
return result;
}
FP32Vec8 exp() const {
AliasReg ar;
ar.reg = reg;
return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]),
expf(ar.values[5]), expf(ar.values[4]),
expf(ar.values[3]), expf(ar.values[2]),
expf(ar.values[1]), expf(ar.values[0])));
}
FP32Vec8 tanh() const {
AliasReg ar;
ar.reg = reg;
return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]),
tanhf(ar.values[5]), tanhf(ar.values[4]),
tanhf(ar.values[3]), tanhf(ar.values[2]),
tanhf(ar.values[1]), tanhf(ar.values[0])));
}
FP32Vec8 er() const {
AliasReg ar;
ar.reg = reg;
return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]),
erf(ar.values[5]), erf(ar.values[4]),
erf(ar.values[3]), erf(ar.values[2]),
erf(ar.values[1]), erf(ar.values[0])));
}
FP32Vec8 operator*(const FP32Vec8 &b) const {
return FP32Vec8(_mm256_mul_ps(reg, b.reg));
}
FP32Vec8 operator+(const FP32Vec8 &b) const {
return FP32Vec8(_mm256_add_ps(reg, b.reg));
}
FP32Vec8 operator-(const FP32Vec8 &b) const {
return FP32Vec8(_mm256_sub_ps(reg, b.reg));
}
FP32Vec8 operator/(const FP32Vec8 &b) const {
return FP32Vec8(_mm256_div_ps(reg, b.reg));
}
void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); }
};
#ifdef __AVX512F__
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
__m512 reg;
float values[VEC_ELEM_NUM];
};
__m512 reg;
explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {}
explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {}
explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {}
explicit FP32Vec16(__m512 data) : reg(data) {}
explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {}
explicit FP32Vec16(const FP32Vec4 &data)
: reg((__m512)_mm512_inserti32x4(
_mm512_inserti32x4(
_mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg),
(__m128i)data.reg, 1),
(__m128i)data.reg, 2),
(__m128i)data.reg, 3)) {}
explicit FP32Vec16(const FP32Vec8 &data)
: reg((__m512)_mm512_inserti32x8(
_mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {}
explicit FP32Vec16(const BF16Vec16 &v)
: reg(_mm512_castsi512_ps(
_mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {}
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
FP32Vec16 operator*(const FP32Vec16 &b) const {
return FP32Vec16(_mm512_mul_ps(reg, b.reg));
}
FP32Vec16 operator+(const FP32Vec16 &b) const {
return FP32Vec16(_mm512_add_ps(reg, b.reg));
}
FP32Vec16 operator-(const FP32Vec16 &b) const {
return FP32Vec16(_mm512_sub_ps(reg, b.reg));
}
FP32Vec16 operator/(const FP32Vec16 &b) const {
return FP32Vec16(_mm512_div_ps(reg, b.reg));
}
float reduce_sum() const { return _mm512_reduce_add_ps(reg); }
template <int group_size> float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
__mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size));
return _mm512_mask_reduce_add_ps(mask, reg);
}
void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); }
};
#else
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
__m256 reg;
float values[8];
};
__m256 reg_low;
__m256 reg_high;
explicit FP32Vec16(float v) : reg_low(_mm256_set1_ps(v)),
reg_high(_mm256_set1_ps(v)) {}
explicit FP32Vec16() : reg_low(_mm256_set1_ps(0.0)),
reg_high(_mm256_set1_ps(0.0)) {}
explicit FP32Vec16(const float *ptr) : reg_low(_mm256_loadu_ps(ptr)),
reg_high(_mm256_loadu_ps(ptr + 8)) {}
explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {}
explicit FP32Vec16(const FP32Vec16 &data) : reg_low(data.reg_low),
reg_high(data.reg_high) {}
explicit FP32Vec16(const FP32Vec4 &data)
: reg_low((__m256)_mm256_inserti128_si256(
_mm256_castsi128_si256((__m128i)data.reg),
(__m128i)data.reg, 1)),
reg_high((__m256)_mm256_inserti128_si256(
_mm256_castsi128_si256((__m128i)data.reg),
(__m128i)data.reg, 1)) {}
explicit FP32Vec16(const FP32Vec8 &data)
: reg_low(data.reg), reg_high(data.reg) {}
explicit FP32Vec16(const BF16Vec16 &v) {
__m128i low = _mm256_extractf128_si256(v.reg, 0);
__m128i high = _mm256_extractf128_si256(v.reg, 1);
__m256i v_low_epi32 = _mm256_cvtepu16_epi32(low);
__m256i v_high_epi32 = _mm256_cvtepu16_epi32(high);
__m256i v_low_shifted = _mm256_bslli_epi128(v_low_epi32, 2);
__m256i v_high_shifted = _mm256_bslli_epi128(v_high_epi32, 2);
reg_low = _mm256_castsi256_ps(v_low_shifted);
reg_high = _mm256_castsi256_ps(v_high_shifted);
}
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
FP32Vec16 operator*(const FP32Vec16 &b) const {
return FP32Vec16(_mm256_mul_ps(reg_low, b.reg_low),
_mm256_mul_ps(reg_high, b.reg_high));
}
FP32Vec16 operator+(const FP32Vec16 &b) const {
return FP32Vec16(_mm256_add_ps(reg_low, b.reg_low),
_mm256_add_ps(reg_high, b.reg_high));
}
FP32Vec16 operator-(const FP32Vec16 &b) const {
return FP32Vec16(_mm256_sub_ps(reg_low, b.reg_low),
_mm256_sub_ps(reg_high, b.reg_high));
}
FP32Vec16 operator/(const FP32Vec16 &b) const {
return FP32Vec16(_mm256_div_ps(reg_low, b.reg_low),
_mm256_div_ps(reg_high, b.reg_high));
}
float reduce_sum() const {
FP32Vec8 low = FP32Vec8(reg_low);
FP32Vec8 high = FP32Vec8(reg_high);
return low.reduce_sum() + high.reduce_sum();
}
template <int group_size> float reduce_sub_sum(int idx) {
float sum = 0.0;
static_assert(VEC_ELEM_NUM % group_size == 0);
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
uint32_t mask = base_mask << (idx * group_size);
AliasReg ar;
auto func = [&sum, &mask, &ar](int i) {
int flag = mask & 0x1;
mask = mask >> 1;
if (flag != 0) sum += ar.values[i];
};
ar.reg = reg_low;
unroll_loop<int, 8>(func);
ar.reg = reg_high;
unroll_loop<int, 8>(func);
return sum;
}
void save(float *ptr) const {
_mm256_storeu_ps(ptr, reg_low);
_mm256_storeu_ps(ptr + 8, reg_high);
}
};
#endif
template <typename T> struct VecType { using vec_type = void; };
template <typename T> using vec_t = typename VecType<T>::vec_type;
template <> struct VecType<float> { using vec_type = FP32Vec8; };
#ifdef __AVX512FP16__
template <> struct VecType<c10::Half> { using vec_type = FP16Vec16; };
#endif
template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
#ifdef __AVX512FP16__
template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
*reinterpret_cast<_Float16 *>(ptr) = v;
}
#endif
inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
acc = acc + a * b;
}
#ifdef __AVX512BF16__
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
*reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v);
}
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
: reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {}
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
: reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {}
inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg);
}
#else
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
c10::BFloat16 __attribute__((__may_alias__)) *v_ptr =
reinterpret_cast<c10::BFloat16 *>(&v);
*ptr = *(v_ptr + 1);
}
#ifdef __AVX512F__
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
: reg(_mm256_cvtepi32_epi16(
_mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {}
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
: reg(_mm512_cvtepi32_epi16(
_mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {}
#else
namespace{
__m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) {
__m256i ai = _mm256_castps_si256(a);
ai = _mm256_srli_epi32(ai, 16);
ai = _mm256_packus_epi32(ai, ai);
ai = _mm256_permute4x64_epi64(ai, 0b00111001);
return _mm256_extracti128_si256(ai, 0);
}
}
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
: reg(FP32Vec8_to_BF16Vec8_avx2(v.reg)) {}
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
BF16Vec8 low = BF16Vec8(FP32Vec8(v.reg_low));
BF16Vec8 high = BF16Vec8(FP32Vec8(v.reg_high));
reg = _mm256_insertf128_si256(_mm256_castsi128_si256(low.reg), high.reg, 1);
}
#endif // __AVX512F__
#endif // __AVX512BF16__
inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); }
}; // namespace vec_op
#endif

491
csrc/cpu/cpu_types_vsx.hpp Normal file
View File

@ -0,0 +1,491 @@
#ifndef CPU_TYPES_VSX_HPP
#define CPU_TYPES_VSX_HPP
#include <altivec.h>
#include <cmath>
#include <torch/all.h>
namespace vec_op {
// FIXME: FP16 is not fully supported in Torch-CPU
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
#endif
#define FORCE_INLINE __attribute__((always_inline)) inline
namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
(f(std::integral_constant<T, indexes>{}), ...);
}
}; // namespace
template <typename T, T count, typename F,
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F &&f) {
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}
template <typename T> struct Vec {
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
};
typedef struct ss16x8x2_t {
__vector signed short val[2];
} ss16x8x2_t;
typedef struct ss16x8x4_t {
__vector signed short val[4];
} ss16x8x4_t;
typedef struct f32x4x2_t {
__vector float val[2];
} f32x4x2_t;
typedef struct f32x4x4_t {
__vector float val[4];
} f32x4x4_t;
struct FP32Vec8;
struct FP32Vec16;
struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
__vector signed short reg;
explicit BF16Vec8(const void *ptr)
: reg((__vector signed short)vec_xl(0, (__vector signed short *)ptr)) {}
explicit BF16Vec8(const FP32Vec8 &);
void save(void *ptr) const { *reinterpret_cast<__vector signed short *>(ptr) = reg; }
};
struct BF16Vec16 : public Vec<BF16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
ss16x8x2_t reg;
explicit BF16Vec16(const void *ptr) {
// Load 256 bits in two parts
reg.val[0] = (__vector signed short)vec_xl(0, (signed short *)ptr);
reg.val[1] = (__vector signed short)vec_xl(16, (signed short *)ptr);
}
explicit BF16Vec16(const FP32Vec16 &);
void save(void *ptr) const {
// Save 256 bits in two parts
vec_xst(reg.val[0], 0, (signed short *)ptr);
vec_xst(reg.val[1], 16, (signed short *)ptr);
}
};
const static __vector signed short zero = vec_splats((signed short)0);
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
ss16x8x4_t reg;
explicit BF16Vec32(const void *ptr)
: reg(*reinterpret_cast<const ss16x8x4_t *>(ptr)) {}
explicit BF16Vec32(ss16x8x4_t data) : reg(data) {}
explicit BF16Vec32(const BF16Vec8 &vec8_data) : reg({
vec8_data.reg,
vec8_data.reg,
vec8_data.reg,
vec8_data.reg
}) {}
void save(void *ptr) const { *reinterpret_cast<ss16x8x4_t *>(ptr) = reg; }
};
struct FP32Vec4 : public Vec<FP32Vec4> {
constexpr static int VEC_ELEM_NUM = 4;
union AliasReg {
__vector float reg;
float values[VEC_ELEM_NUM];
};
__vector float reg;
explicit FP32Vec4(float v) : reg(vec_splats(v)) {}
explicit FP32Vec4() : reg(vec_splats(0.0f)) {}
explicit FP32Vec4(const float *ptr) : reg(vec_xl(0, ptr)) {}
explicit FP32Vec4(__vector float data) : reg(data) {}
explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {}
};
struct FP32Vec8 : public Vec<FP32Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
union AliasReg {
f32x4x2_t reg;
float values[VEC_ELEM_NUM];
};
f32x4x2_t reg;
explicit FP32Vec8(float v) {
reg.val[0] = vec_splats(v);
reg.val[1] = vec_splats(v);
}
explicit FP32Vec8() {
reg.val[0] = vec_splats(0.0f);
reg.val[1] = vec_splats(0.0f);
}
explicit FP32Vec8(const float *ptr) {
reg.val[0] = vec_xl(0, ptr);
reg.val[1] = vec_xl(16, ptr);
}
explicit FP32Vec8(f32x4x2_t data) : reg(data) {}
explicit FP32Vec8(const FP32Vec8 &data) {
reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[1];
}
explicit FP32Vec8(const BF16Vec8 &v) {
reg.val[0] = (__vector float)vec_mergeh(zero, v.reg);
reg.val[1] = (__vector float)vec_mergel(zero, v.reg);
}
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
float result = 0;
unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });
return result;
}
FP32Vec8 exp() const {
// TODO: Vectorize this
AliasReg ar;
ar.reg = reg;
f32x4x4_t ret;
ret.val[0][0] = std::exp(ar.values[0]);
ret.val[0][1] = std::exp(ar.values[1]);
ret.val[0][2] = std::exp(ar.values[2]);
ret.val[0][3] = std::exp(ar.values[3]);
ret.val[1][0] = std::exp(ar.values[4]);
ret.val[1][1] = std::exp(ar.values[5]);
ret.val[1][2] = std::exp(ar.values[6]);
ret.val[1][3] = std::exp(ar.values[7]);
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
}
FP32Vec8 tanh() const {
// TODO: Vectorize this
AliasReg ar;
ar.reg = reg;
f32x4x4_t ret;
ret.val[0][0] = std::tanh(ar.values[0]);
ret.val[0][1] = std::tanh(ar.values[1]);
ret.val[0][2] = std::tanh(ar.values[2]);
ret.val[0][3] = std::tanh(ar.values[3]);
ret.val[1][0] = std::tanh(ar.values[4]);
ret.val[1][1] = std::tanh(ar.values[5]);
ret.val[1][2] = std::tanh(ar.values[6]);
ret.val[1][3] = std::tanh(ar.values[7]);
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
}
FP32Vec8 er() const {
// TODO: Vectorize this
AliasReg ar;
ar.reg = reg;
f32x4x4_t ret;
ret.val[0][0] = std::erf(ar.values[0]);
ret.val[0][1] = std::erf(ar.values[1]);
ret.val[0][2] = std::erf(ar.values[2]);
ret.val[0][3] = std::erf(ar.values[3]);
ret.val[1][0] = std::erf(ar.values[4]);
ret.val[1][1] = std::erf(ar.values[5]);
ret.val[1][2] = std::erf(ar.values[6]);
ret.val[1][3] = std::erf(ar.values[7]);
return FP32Vec8(f32x4x2_t({ret.val[0], ret.val[1]}));
}
FP32Vec8 operator*(const FP32Vec8 &b) const {
return FP32Vec8({vec_mul(reg.val[0], b.reg.val[0]), vec_mul(reg.val[1], b.reg.val[1])});
}
FP32Vec8 operator+(const FP32Vec8 &b) const {
return FP32Vec8({vec_add(reg.val[0], b.reg.val[0]), vec_add(reg.val[1], b.reg.val[1])});
}
FP32Vec8 operator-(const FP32Vec8 &b) const {
return FP32Vec8({vec_sub(reg.val[0], b.reg.val[0]), vec_sub(reg.val[1], b.reg.val[1])});
}
FP32Vec8 operator/(const FP32Vec8 &b) const {
return FP32Vec8({vec_div(reg.val[0], b.reg.val[0]), vec_div(reg.val[1], b.reg.val[1])});
}
void save(float *ptr) const {
vec_xst(reg.val[0], 0, ptr);
vec_xst(reg.val[1], 16, ptr);
}
};
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
f32x4x4_t reg;
float values[VEC_ELEM_NUM];
};
f32x4x4_t reg;
explicit FP32Vec16(float v) {
reg.val[0] = vec_splats(v);
reg.val[1] = vec_splats(v);
reg.val[2] = vec_splats(v);
reg.val[3] = vec_splats(v);
}
explicit FP32Vec16() {
reg.val[0] = vec_splats(0.0f);
reg.val[1] = vec_splats(0.0f);
reg.val[2] = vec_splats(0.0f);
reg.val[3] = vec_splats(0.0f);
}
explicit FP32Vec16(const float *ptr) {
reg.val[0] = vec_xl(0, ptr);
reg.val[1] = vec_xl(16, ptr);
reg.val[2] = vec_xl(32, ptr);
reg.val[3] = vec_xl(48, ptr);
}
explicit FP32Vec16(f32x4x4_t data) : reg(data) {}
explicit FP32Vec16(const FP32Vec16 &data) {
reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[1];
reg.val[2] = data.reg.val[2];
reg.val[3] = data.reg.val[3];
}
explicit FP32Vec16(const FP32Vec4 &data) {
reg.val[0] = data.reg;
reg.val[1] = data.reg;
reg.val[2] = data.reg;
reg.val[3] = data.reg;
}
explicit FP32Vec16(const FP32Vec8 &data) {
reg.val[0] = data.reg.val[0];
reg.val[1] = data.reg.val[1];
reg.val[2] = data.reg.val[0];
reg.val[3] = data.reg.val[1];
}
explicit FP32Vec16(const BF16Vec16 &v) {
reg.val[0] = (__vector float)vec_mergeh(zero, v.reg.val[0]);
reg.val[1] = (__vector float)vec_mergel(zero, v.reg.val[0]);
reg.val[2] = (__vector float)vec_mergeh(zero, v.reg.val[1]);
reg.val[3] = (__vector float)vec_mergel(zero, v.reg.val[1]);
}
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
FP32Vec16 operator*(const FP32Vec16 &b) const {
return FP32Vec16(f32x4x4_t({
vec_mul(reg.val[0], b.reg.val[0]),
vec_mul(reg.val[1], b.reg.val[1]),
vec_mul(reg.val[2], b.reg.val[2]),
vec_mul(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 operator+(const FP32Vec16 &b) const {
return FP32Vec16(f32x4x4_t({
vec_add(reg.val[0], b.reg.val[0]),
vec_add(reg.val[1], b.reg.val[1]),
vec_add(reg.val[2], b.reg.val[2]),
vec_add(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 operator-(const FP32Vec16 &b) const {
return FP32Vec16(f32x4x4_t({
vec_sub(reg.val[0], b.reg.val[0]),
vec_sub(reg.val[1], b.reg.val[1]),
vec_sub(reg.val[2], b.reg.val[2]),
vec_sub(reg.val[3], b.reg.val[3])}));
}
FP32Vec16 operator/(const FP32Vec16 &b) const {
return FP32Vec16(f32x4x4_t({
vec_div(reg.val[0], b.reg.val[0]),
vec_div(reg.val[1], b.reg.val[1]),
vec_div(reg.val[2], b.reg.val[2]),
vec_div(reg.val[3], b.reg.val[3])}));
}
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
float result = 0;
unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });
return result;
}
template <int group_size> float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
AliasReg ar;
ar.reg = reg;
float result = 0;
const int start = idx * group_size;
unroll_loop<int, group_size>(
[&result, &start, ar](int i) { result += ar.values[start + i]; });
return result;
}
void save(float *ptr) const {
vec_xst(reg.val[0], 0, ptr);
vec_xst(reg.val[1], 16, ptr);
vec_xst(reg.val[2], 32, ptr);
vec_xst(reg.val[3], 48, ptr);
}
};
template <typename T> struct VecType { using vec_type = void; };
template <typename T> using vec_t = typename VecType<T>::vec_type;
template <> struct VecType<float> { using vec_type = FP32Vec8; };
template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
acc = acc + a * b;
}
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
c10::BFloat16 __attribute__((__may_alias__)) *v_ptr =
reinterpret_cast<c10::BFloat16 *>(&v);
*ptr = *(v_ptr + 1);
}
#ifndef __VEC_CLASS_FP_NAN
#define __VEC_CLASS_FP_NAN (1 << 6)
#endif
const static __vector unsigned char omask = { 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29 };
#ifndef _ARCH_PWR10
const static __vector unsigned int bias = { 0x00007fff, 0x00007fff, 0x00007fff, 0x00007fff };
const static __vector unsigned int nan = { 0x7fc00000, 0x7fc00000, 0x7fc00000, 0x7fc00000 };
const static __vector unsigned int sh16 = { 16, 16, 16, 16 };
const static __vector unsigned int one = { 1, 1, 1, 1 };
#endif
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v) {
#ifdef _ARCH_PWR10
__vector signed short ret[2];
ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]);
ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]);
reg = vec_perm(ret[0], ret[1], omask);
#elif defined(_ARCH_PWR9)
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
__vector unsigned int lsb0 = vec_sr(inp0, sh16);
__vector unsigned int lsb1 = vec_sr(inp1, sh16);
lsb0 = vec_and(lsb0, one);
lsb1 = vec_and(lsb1, one);
__vector unsigned int rnd0 = vec_add(lsb0, bias);
__vector unsigned int rnd1 = vec_add(lsb1, bias);
inp0 = vec_add(inp0, rnd0);
inp1 = vec_add(inp1, rnd1);
__vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
__vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
inp0 = vec_sel(inp0, nan, sel0);
inp1 = vec_sel(inp1, nan, sel1);
inp0 = vec_sr(inp0, sh16);
inp1 = vec_sr(inp1, sh16);
reg = (__vector signed short)vec_perm(inp0, inp1, omask);
#endif
}
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
#ifdef _ARCH_PWR10
__vector signed short ret[4];
ret[0] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[0]);
ret[1] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[1]);
ret[2] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[2]);
ret[3] = (__vector signed short)__builtin_vsx_xvcvspbf16((__vector unsigned char)v.reg.val[3]);
reg.val[0] = vec_perm(ret[0], ret[1], omask);
reg.val[1] = vec_perm(ret[2], ret[3], omask);
#elif defined(_ARCH_PWR9)
__vector unsigned int inp0 = (__vector unsigned int)(v.reg.val[0]);
__vector unsigned int inp1 = (__vector unsigned int)(v.reg.val[1]);
__vector unsigned int inp2 = (__vector unsigned int)(v.reg.val[2]);
__vector unsigned int inp3 = (__vector unsigned int)(v.reg.val[3]);
__vector unsigned int lsb0 = vec_sr(inp0, sh16);
__vector unsigned int lsb1 = vec_sr(inp1, sh16);
__vector unsigned int lsb2 = vec_sr(inp2, sh16);
__vector unsigned int lsb3 = vec_sr(inp3, sh16);
lsb0 = vec_and(lsb0, one);
lsb1 = vec_and(lsb1, one);
lsb2 = vec_and(lsb2, one);
lsb3 = vec_and(lsb3, one);
__vector unsigned int rnd0 = vec_add(lsb0, bias);
__vector unsigned int rnd1 = vec_add(lsb1, bias);
__vector unsigned int rnd2 = vec_add(lsb2, bias);
__vector unsigned int rnd3 = vec_add(lsb3, bias);
inp0 = vec_add(inp0, rnd0);
inp1 = vec_add(inp1, rnd1);
inp2 = vec_add(inp2, rnd2);
inp3 = vec_add(inp3, rnd3);
__vector __bool int sel0 = vec_test_data_class(v.reg.val[0], __VEC_CLASS_FP_NAN);
__vector __bool int sel1 = vec_test_data_class(v.reg.val[1], __VEC_CLASS_FP_NAN);
__vector __bool int sel2 = vec_test_data_class(v.reg.val[2], __VEC_CLASS_FP_NAN);
__vector __bool int sel3 = vec_test_data_class(v.reg.val[3], __VEC_CLASS_FP_NAN);
inp0 = vec_sel(inp0, nan, sel0);
inp1 = vec_sel(inp1, nan, sel1);
inp2 = vec_sel(inp2, nan, sel2);
inp3 = vec_sel(inp3, nan, sel3);
inp0 = vec_sr(inp0, sh16);
inp1 = vec_sr(inp1, sh16);
inp2 = vec_sr(inp2, sh16);
inp3 = vec_sr(inp3, sh16);
reg.val[0] = (__vector signed short)vec_perm(inp0, inp1, omask);
reg.val[1] = (__vector signed short)vec_perm(inp2, inp3, omask);
#endif
}
inline void prefetch(const void *addr) {
__asm__ __volatile__("dcbt 0, %0" : : "r"(addr) : "memory");
}
}; // namespace vec_op
#endif

515
csrc/cpu/cpu_types_x86.hpp Normal file
View File

@ -0,0 +1,515 @@
#ifndef CPU_TYPES_X86_HPP
#define CPU_TYPES_X86_HPP
#include <immintrin.h>
#include <torch/all.h>
#ifndef __AVX2__
static_assert(false, "AVX2 must be supported for the current implementation.");
#endif
namespace vec_op {
// FIXME: FP16 is not fully supported in Torch-CPU
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#ifndef CPU_OP_GUARD
#define CPU_KERNEL_GUARD_IN(NAME)
#define CPU_KERNEL_GUARD_OUT(NAME)
#else
#define CPU_KERNEL_GUARD_IN(NAME) \
std::cout << #NAME << " invoked." << std::endl;
#define CPU_KERNEL_GUARD_OUT(NAME) std::cout << #NAME << " exit." << std::endl;
#endif
#define FORCE_INLINE __attribute__((always_inline)) inline
namespace {
template <typename T, T... indexes, typename F>
constexpr void unroll_loop_item(std::integer_sequence<T, indexes...>, F &&f) {
(f(std::integral_constant<T, indexes>{}), ...);
}
}; // namespace
template <typename T, T count, typename F,
typename = std::enable_if_t<std::is_invocable_v<F, T>>>
constexpr void unroll_loop(F &&f) {
unroll_loop_item(std::make_integer_sequence<T, count>{}, std::forward<F>(f));
}
template <typename T> struct Vec {
constexpr static int get_elem_num() { return T::VEC_ELEM_NUM; }
};
struct FP32Vec8;
struct FP32Vec16;
#ifdef __AVX512FP16__
struct FP16Vec8 : public Vec<FP16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
__m128h reg;
explicit FP16Vec8(_Float16 v) : reg(_mm_set1_ph(v)) {}
explicit FP16Vec8(const void *ptr) : reg(_mm_loadu_ph(ptr)) {}
explicit FP16Vec8(__m128h data) : reg(data) {}
FP16Vec8 operator*(const FP16Vec8 &b) const {
return FP16Vec8(_mm_mul_ph(reg, b.reg));
}
FP16Vec8 operator+(const FP16Vec8 &b) const {
return FP16Vec8(_mm_add_ph(reg, b.reg));
}
FP16Vec8 operator-(const FP16Vec8 &b) const {
return FP16Vec8(_mm_sub_ph(reg, b.reg));
}
FP16Vec8 operator/(const FP16Vec8 &b) const {
return FP16Vec8(_mm_div_ph(reg, b.reg));
}
void save(void *ptr) const { _mm_storeu_ph(ptr, reg); }
};
#endif
struct BF16Vec8 : public Vec<BF16Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
__m128i reg;
explicit BF16Vec8(const void *ptr)
: reg((__m128i)_mm_loadu_si128((__m128i *)ptr)) {}
explicit BF16Vec8(const FP32Vec8 &);
void save(void *ptr) const { *reinterpret_cast<__m128i *>(ptr) = reg; }
};
struct BF16Vec16 : public Vec<BF16Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
__m256i reg;
explicit BF16Vec16(const void *ptr)
: reg((__m256i)_mm256_loadu_si256((__m256i *)ptr)) {}
explicit BF16Vec16(const FP32Vec16 &);
void save(void *ptr) const { *reinterpret_cast<__m256i *>(ptr) = reg; }
};
#ifdef __AVX512F__
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
__m512i reg;
explicit BF16Vec32(const void *ptr) : reg((__m512i)_mm512_loadu_si512(ptr)) {}
explicit BF16Vec32(__m512i data) : reg(data) {}
explicit BF16Vec32(BF16Vec8 &vec8_data)
: reg((__m512i)_mm512_inserti32x4(
_mm512_inserti32x4(_mm512_inserti32x4(_mm512_castsi128_si512(
(__m128i)vec8_data.reg),
(__m128i)vec8_data.reg, 1),
(__m128i)vec8_data.reg, 2),
(__m128i)vec8_data.reg, 3)) {}
void save(void *ptr) const { *reinterpret_cast<__m512i *>(ptr) = reg; }
};
#else
struct BF16Vec32 : public Vec<BF16Vec32> {
constexpr static int VEC_ELEM_NUM = 32;
__m256i reg_low;
__m256i reg_high;
explicit BF16Vec32(const void *ptr)
: reg_low(_mm256_loadu_si256((__m256i const *)ptr)),
reg_high(_mm256_loadu_si256((__m256i const *)ptr + 1)) {}
explicit BF16Vec32(__m256i low, __m256i high) : reg_low(low),
reg_high(high) {}
explicit BF16Vec32(BF16Vec8 &vec8_data)
: reg_low((__m256i)_mm256_inserti32x4(
_mm256_castsi128_si256((__m128i)vec8_data.reg),
(__m128i)vec8_data.reg, 1)),
reg_high((__m256i)_mm256_inserti32x4(
_mm256_castsi128_si256((__m128i)vec8_data.reg),
(__m128i)vec8_data.reg, 1)) {}
void save(void *ptr) const {
*reinterpret_cast<__m256i *>(ptr) = reg_low;
*reinterpret_cast<__m256i *>((__m256i *)ptr + 1) = reg_high;
}
};
#endif
struct FP32Vec4 : public Vec<FP32Vec4> {
constexpr static int VEC_ELEM_NUM = 4;
union AliasReg {
__m128 reg;
float values[VEC_ELEM_NUM];
};
__m128 reg;
explicit FP32Vec4(float v) : reg(_mm_set1_ps(v)) {}
explicit FP32Vec4() : reg(_mm_set1_ps(0.0)) {}
explicit FP32Vec4(const float *ptr) : reg(_mm_loadu_ps(ptr)) {}
explicit FP32Vec4(__m128 data) : reg(data) {}
explicit FP32Vec4(const FP32Vec4 &data) : reg(data.reg) {}
};
struct FP32Vec8 : public Vec<FP32Vec8> {
constexpr static int VEC_ELEM_NUM = 8;
union AliasReg {
__m256 reg;
float values[VEC_ELEM_NUM];
};
__m256 reg;
explicit FP32Vec8(float v) : reg(_mm256_set1_ps(v)) {}
explicit FP32Vec8() : reg(_mm256_set1_ps(0.0)) {}
explicit FP32Vec8(const float *ptr) : reg(_mm256_loadu_ps(ptr)) {}
explicit FP32Vec8(__m256 data) : reg(data) {}
explicit FP32Vec8(const FP32Vec8 &data) : reg(data.reg) {}
#ifdef __AVX512FP16__
explicit FP32Vec8(__m128h v) : reg(_mm256_cvtph_ps(_mm_castph_si128(v))) {}
#endif
explicit FP32Vec8(const BF16Vec8 &v)
: reg(_mm256_castsi256_ps(
_mm256_bslli_epi128(_mm256_cvtepu16_epi32(v.reg), 2))) {}
float reduce_sum() const {
AliasReg ar;
ar.reg = reg;
float result = 0;
unroll_loop<int, VEC_ELEM_NUM>([&result, &ar](int i) { result += ar.values[i]; });
return result;
}
FP32Vec8 exp() const {
AliasReg ar;
ar.reg = reg;
return FP32Vec8(_mm256_set_ps(expf(ar.values[7]), expf(ar.values[6]),
expf(ar.values[5]), expf(ar.values[4]),
expf(ar.values[3]), expf(ar.values[2]),
expf(ar.values[1]), expf(ar.values[0])));
}
FP32Vec8 tanh() const {
AliasReg ar;
ar.reg = reg;
return FP32Vec8(_mm256_set_ps(tanhf(ar.values[7]), tanhf(ar.values[6]),
tanhf(ar.values[5]), tanhf(ar.values[4]),
tanhf(ar.values[3]), tanhf(ar.values[2]),
tanhf(ar.values[1]), tanhf(ar.values[0])));
}
FP32Vec8 er() const {
AliasReg ar;
ar.reg = reg;
return FP32Vec8(_mm256_set_ps(erf(ar.values[7]), erf(ar.values[6]),
erf(ar.values[5]), erf(ar.values[4]),
erf(ar.values[3]), erf(ar.values[2]),
erf(ar.values[1]), erf(ar.values[0])));
}
FP32Vec8 operator*(const FP32Vec8 &b) const {
return FP32Vec8(_mm256_mul_ps(reg, b.reg));
}
FP32Vec8 operator+(const FP32Vec8 &b) const {
return FP32Vec8(_mm256_add_ps(reg, b.reg));
}
FP32Vec8 operator-(const FP32Vec8 &b) const {
return FP32Vec8(_mm256_sub_ps(reg, b.reg));
}
FP32Vec8 operator/(const FP32Vec8 &b) const {
return FP32Vec8(_mm256_div_ps(reg, b.reg));
}
void save(float *ptr) const { _mm256_storeu_ps(ptr, reg); }
};
#ifdef __AVX512F__
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
__m512 reg;
float values[VEC_ELEM_NUM];
};
__m512 reg;
explicit FP32Vec16(float v) : reg(_mm512_set1_ps(v)) {}
explicit FP32Vec16() : reg(_mm512_set1_ps(0.0)) {}
explicit FP32Vec16(const float *ptr) : reg(_mm512_loadu_ps(ptr)) {}
explicit FP32Vec16(__m512 data) : reg(data) {}
explicit FP32Vec16(const FP32Vec16 &data) : reg(data.reg) {}
explicit FP32Vec16(const FP32Vec4 &data)
: reg((__m512)_mm512_inserti32x4(
_mm512_inserti32x4(
_mm512_inserti32x4(_mm512_castsi128_si512((__m128i)data.reg),
(__m128i)data.reg, 1),
(__m128i)data.reg, 2),
(__m128i)data.reg, 3)) {}
explicit FP32Vec16(const FP32Vec8 &data)
: reg((__m512)_mm512_inserti32x8(
_mm512_castsi256_si512((__m256i)data.reg), (__m256i)data.reg, 1)) {}
explicit FP32Vec16(const BF16Vec16 &v)
: reg(_mm512_castsi512_ps(
_mm512_bslli_epi128(_mm512_cvtepu16_epi32(v.reg), 2))) {}
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
FP32Vec16 operator*(const FP32Vec16 &b) const {
return FP32Vec16(_mm512_mul_ps(reg, b.reg));
}
FP32Vec16 operator+(const FP32Vec16 &b) const {
return FP32Vec16(_mm512_add_ps(reg, b.reg));
}
FP32Vec16 operator-(const FP32Vec16 &b) const {
return FP32Vec16(_mm512_sub_ps(reg, b.reg));
}
FP32Vec16 operator/(const FP32Vec16 &b) const {
return FP32Vec16(_mm512_div_ps(reg, b.reg));
}
float reduce_sum() const { return _mm512_reduce_add_ps(reg); }
template <int group_size> float reduce_sub_sum(int idx) {
static_assert(VEC_ELEM_NUM % group_size == 0);
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
__mmask16 mask = _cvtu32_mask16(base_mask << (idx * group_size));
return _mm512_mask_reduce_add_ps(mask, reg);
}
void save(float *ptr) const { _mm512_storeu_ps(ptr, reg); }
};
#else
struct FP32Vec16 : public Vec<FP32Vec16> {
constexpr static int VEC_ELEM_NUM = 16;
union AliasReg {
__m256 reg;
float values[8];
};
__m256 reg_low;
__m256 reg_high;
explicit FP32Vec16(float v) : reg_low(_mm256_set1_ps(v)),
reg_high(_mm256_set1_ps(v)) {}
explicit FP32Vec16() : reg_low(_mm256_set1_ps(0.0)),
reg_high(_mm256_set1_ps(0.0)) {}
explicit FP32Vec16(const float *ptr) : reg_low(_mm256_loadu_ps(ptr)),
reg_high(_mm256_loadu_ps(ptr + 8)) {}
explicit FP32Vec16(__m256 low, __m256 high) : reg_low(low), reg_high(high) {}
explicit FP32Vec16(const FP32Vec16 &data) : reg_low(data.reg_low),
reg_high(data.reg_high) {}
explicit FP32Vec16(const FP32Vec4 &data)
: reg_low((__m256)_mm256_inserti128_si256(
_mm256_castsi128_si256((__m128i)data.reg),
(__m128i)data.reg, 1)),
reg_high((__m256)_mm256_inserti128_si256(
_mm256_castsi128_si256((__m128i)data.reg),
(__m128i)data.reg, 1)) {}
explicit FP32Vec16(const FP32Vec8 &data)
: reg_low(data.reg), reg_high(data.reg) {}
explicit FP32Vec16(const BF16Vec16 &v) {
__m128i low = _mm256_extractf128_si256(v.reg, 0);
__m128i high = _mm256_extractf128_si256(v.reg, 1);
__m256i v_low_epi32 = _mm256_cvtepu16_epi32(low);
__m256i v_high_epi32 = _mm256_cvtepu16_epi32(high);
__m256i v_low_shifted = _mm256_bslli_epi128(v_low_epi32, 2);
__m256i v_high_shifted = _mm256_bslli_epi128(v_high_epi32, 2);
reg_low = _mm256_castsi256_ps(v_low_shifted);
reg_high = _mm256_castsi256_ps(v_high_shifted);
}
explicit FP32Vec16(const BF16Vec8 &v) : FP32Vec16(FP32Vec8(v)) {}
FP32Vec16 operator*(const FP32Vec16 &b) const {
return FP32Vec16(_mm256_mul_ps(reg_low, b.reg_low),
_mm256_mul_ps(reg_high, b.reg_high));
}
FP32Vec16 operator+(const FP32Vec16 &b) const {
return FP32Vec16(_mm256_add_ps(reg_low, b.reg_low),
_mm256_add_ps(reg_high, b.reg_high));
}
FP32Vec16 operator-(const FP32Vec16 &b) const {
return FP32Vec16(_mm256_sub_ps(reg_low, b.reg_low),
_mm256_sub_ps(reg_high, b.reg_high));
}
FP32Vec16 operator/(const FP32Vec16 &b) const {
return FP32Vec16(_mm256_div_ps(reg_low, b.reg_low),
_mm256_div_ps(reg_high, b.reg_high));
}
float reduce_sum() const {
FP32Vec8 low = FP32Vec8(reg_low);
FP32Vec8 high = FP32Vec8(reg_high);
return low.reduce_sum() + high.reduce_sum();
}
template <int group_size> float reduce_sub_sum(int idx) {
float sum = 0.0;
static_assert(VEC_ELEM_NUM % group_size == 0);
constexpr uint32_t base_mask = (0xFFFF >> (16 - group_size));
uint32_t mask = base_mask << (idx * group_size);
AliasReg ar;
auto func = [&sum, &mask, &ar](int i) {
int flag = mask & 0x1;
mask = mask >> 1;
if (flag != 0) sum += ar.values[i];
};
ar.reg = reg_low;
unroll_loop<int, 8>(func);
ar.reg = reg_high;
unroll_loop<int, 8>(func);
return sum;
}
void save(float *ptr) const {
_mm256_storeu_ps(ptr, reg_low);
_mm256_storeu_ps(ptr + 8, reg_high);
}
};
#endif
template <typename T> struct VecType { using vec_type = void; };
template <typename T> using vec_t = typename VecType<T>::vec_type;
template <> struct VecType<float> { using vec_type = FP32Vec8; };
#ifdef __AVX512FP16__
template <> struct VecType<c10::Half> { using vec_type = FP16Vec16; };
#endif
template <> struct VecType<c10::BFloat16> { using vec_type = BF16Vec8; };
template <typename T> void storeFP32(float v, T *ptr) { *ptr = v; }
#ifdef __AVX512FP16__
template <> inline void storeFP32<c10::Half>(float v, c10::Half *ptr) {
*reinterpret_cast<_Float16 *>(ptr) = v;
}
#endif
inline void fma(FP32Vec16 &acc, FP32Vec16 &a, FP32Vec16 &b) {
acc = acc + a * b;
}
#ifdef __AVX512BF16__
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
*reinterpret_cast<__bfloat16 *>(ptr) = _mm_cvtness_sbh(v);
}
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
: reg((__m128i)_mm256_cvtneps_pbh(v.reg)) {}
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
: reg((__m256i)_mm512_cvtneps_pbh(v.reg)) {}
inline void fma(FP32Vec16 &acc, BF16Vec32 &a, BF16Vec32 &b) {
acc.reg = _mm512_dpbf16_ps(acc.reg, (__m512bh)a.reg, (__m512bh)b.reg);
}
#else
template <> inline void storeFP32<c10::BFloat16>(float v, c10::BFloat16 *ptr) {
c10::BFloat16 __attribute__((__may_alias__)) *v_ptr =
reinterpret_cast<c10::BFloat16 *>(&v);
*ptr = *(v_ptr + 1);
}
#ifdef __AVX512F__
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
: reg(_mm256_cvtepi32_epi16(
_mm256_bsrli_epi128(_mm256_castps_si256(v.reg), 2))) {}
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v)
: reg(_mm512_cvtepi32_epi16(
_mm512_bsrli_epi128(_mm512_castps_si512(v.reg), 2))) {}
#else
namespace{
__m128i FP32Vec8_to_BF16Vec8_avx2(__m256 a) {
__m256i ai = _mm256_castps_si256(a);
ai = _mm256_srli_epi32(ai, 16);
ai = _mm256_packus_epi32(ai, ai);
ai = _mm256_permute4x64_epi64(ai, 0b00111001);
return _mm256_extracti128_si256(ai, 0);
}
}
inline BF16Vec8::BF16Vec8(const FP32Vec8 &v)
: reg(FP32Vec8_to_BF16Vec8_avx2(v.reg)) {}
inline BF16Vec16::BF16Vec16(const FP32Vec16 &v) {
BF16Vec8 low = BF16Vec8(FP32Vec8(v.reg_low));
BF16Vec8 high = BF16Vec8(FP32Vec8(v.reg_high));
reg = _mm256_insertf128_si256(_mm256_castsi128_si256(low.reg), high.reg, 1);
}
#endif // __AVX512F__
#endif // __AVX512BF16__
inline void prefetch(const void *addr) { _mm_prefetch(addr, _MM_HINT_T1); }
}; // namespace vec_op
#endif

View File

@ -1,5 +1,6 @@
#pragma once
#include <optional>
#include <torch/library.h>
void paged_attention_v1(

View File

@ -2,6 +2,6 @@
-r requirements-common.txt
# Dependencies for x86_64 CPUs
torch == 2.3.1+cpu
torchvision == 0.18.1+cpu # required for the image processor of phi3v, this must be updated alongside torch
triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.
torch == 2.3.1+cpu; platform_machine != "ppc64le"
torchvision == 0.18.1+cpu; platform_machine != "ppc64le" # required for the image processor of phi3v, this must be updated alongside torch
triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error.