#pragma once

#include <torch/extension.h>
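
// PagedAttention, single-pass (v1) kernel: computes scaled dot-product
// attention over a block-based (paged) KV cache. block_tables maps each
// sequence to its physical cache blocks, context_lens holds per-sequence
// context lengths, and alibi_slopes optionally adds ALiBi position biases.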
void paged_attention_v1(
  torch::Tensor& out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  int num_kv_heads,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes);
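
// PagedAttention, partitioned (v2) kernel for longer contexts: the context is
// split into partitions, and the partial results (exp_sums, max_logits,
// tmp_out) are reduced into out. Remaining arguments match paged_attention_v1.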
void paged_attention_v2(
  torch::Tensor& out,
  torch::Tensor& exp_sums,
  torch::Tensor& max_logits,
  torch::Tensor& tmp_out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  int num_kv_heads,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes);
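
// RMSNorm: out = input / RMS(input) * weight, with epsilon added to the mean
// square for numerical stability.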
void rms_norm(
  torch::Tensor& out,
  torch::Tensor& input,
  torch::Tensor& weight,
  float epsilon);
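
// Fused residual add + RMSNorm: adds residual to input, applies RMSNorm, and
// updates both tensors in place.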
void fused_add_rms_norm(
  torch::Tensor& input,
  torch::Tensor& residual,
  torch::Tensor& weight,
  float epsilon);
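
// Applies rotary positional embeddings to query and key in place, using the
// precomputed cos_sin_cache indexed by positions. is_neox selects the GPT-NeoX
// (rotate-half) layout rather than the GPT-J (interleaved) layout.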
void rotary_embedding(
  torch::Tensor& positions,
  torch::Tensor& query,
  torch::Tensor& key,
  int head_size,
  torch::Tensor& cos_sin_cache,
  bool is_neox);
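
// SwiGLU-style gated activation: splits input in half along the last dimension
// and writes silu(x1) * x2 to out.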
void silu_and_mul(
  torch::Tensor& out,
  torch::Tensor& input);
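
// Tanh-based "new" GELU approximation (as used in GPT-2-style models).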
void gelu_new(
  torch::Tensor& out,
  torch::Tensor& input);
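
// Fast tanh-based GELU approximation.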
void gelu_fast(
  torch::Tensor& out,
  torch::Tensor& input);
#ifndef USE_ROCM
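// AWQ 4-bit quantized GEMM: multiplies _in_feats by the packed quantized
// _kernel, dequantizing with _scaling_factors and _zeros; split_k_iters
// controls the split-K decomposition. Not built for ROCm.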
torch::Tensor awq_gemm(
  torch::Tensor _in_feats,
  torch::Tensor _kernel,
  torch::Tensor _scaling_factors,
  torch::Tensor _zeros,
  int split_k_iters);
#endif
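
// SqueezeLLM quantized matrix-vector multiply: vec times the packed quantized
// mat, dequantized via lookup_table, with the result written to mul.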
void squeezellm_gemm(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor lookup_table);