mirror of https://github.com/vllm-project/vllm
Change the name to vLLM (#150)
This commit is contained in:
parent e5464ee484
commit 0b98ba15c7
@@ -1,6 +1,6 @@
-# Contributing to CacheFlow
+# Contributing to vLLM
-Thank you for your interest in contributing to CacheFlow!
+Thank you for your interest in contributing to vLLM!
 Our community is open to everyone and welcomes all kinds of contributions, no matter how small or large.
 There are several ways you can contribute to the project:
@@ -11,9 +11,9 @@ There are several ways you can contribute to the project:
 However, remember that contributions aren't just about code.
 We believe in the power of community support; thus, answering queries, assisting others, and enhancing the documentation are highly regarded and beneficial contributions.
-Finally, one of the most impactful ways to support us is by raising awareness about CacheFlow.
+Finally, one of the most impactful ways to support us is by raising awareness about vLLM.
 Talk about it in your blog posts, highlighting how it's driving your incredible projects.
-Express your support on Twitter if CacheFlow aids you, or simply offer your appreciation by starring our repository.
+Express your support on Twitter if vLLM aids you, or simply offer your appreciation by starring our repository.
 ## Setup for development
@@ -70,5 +70,5 @@ If a comment isn't clear or you disagree with a suggestion, feel free to ask for
 ### Thank You
-Finally, thank you for taking the time to read these guidelines and for your interest in contributing to CacheFlow.
-Your contributions make CacheFlow a great tool for everyone!
+Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM.
+Your contributions make vLLM a great tool for everyone!
README.md (10 changes)
@@ -1,4 +1,4 @@
-# CacheFlow
+# vLLM
 ## Build from source
@@ -28,7 +28,7 @@ python examples/simple_server.py --help
 To start the server:
 ```bash
 ray start --head
-python -m cacheflow.entrypoints.fastapi_server # --model <your_model>
+python -m vllm.entrypoints.fastapi_server # --model <your_model>
 ```
 To test the server:
@@ -45,9 +45,9 @@ pip install gradio
 Start the server:
 ```bash
-python -m cacheflow.http_frontend.fastapi_frontend
+python -m vllm.http_frontend.fastapi_frontend
 # At another terminal
-python -m cacheflow.http_frontend.gradio_webserver
+python -m vllm.http_frontend.gradio_webserver
 ```
 ## Load LLaMA weights
@@ -62,5 +62,5 @@ Since LLaMA weight is not fully public, we cannot directly download the LLaMA we
 2. For all the commands above, specify the model with `--model /output/path/llama-7b` to load the model. For example:
 ```bash
 python simple_server.py --model /output/path/llama-7b
-python -m cacheflow.http_frontend.fastapi_frontend --model /output/path/llama-7b
+python -m vllm.http_frontend.fastapi_frontend --model /output/path/llama-7b
 ```
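For orientation only: a minimal sketch of a client for the FastAPI server launched in the README above. It assumes the `/generate` endpoint and the payload fields (`prompt`, `n`, `max_tokens`, `stream`) that appear in the benchmark scripts later in this diff; the host, port, and values are placeholders, not part of this commit.

```python
# Hypothetical client sketch for the /generate endpoint; the endpoint shape and
# payload fields are assumed from the benchmark scripts in this diff.
import requests

api_url = "http://localhost:8001/generate"  # placeholder host/port
headers = {"User-Agent": "vLLM Benchmark Client"}
payload = {
    "prompt": "Hello, my name is",
    "n": 1,
    "max_tokens": 64,
    "stream": False,
}
response = requests.post(api_url, headers=headers, json=payload)
print(response.json())
```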
@@ -1,4 +1,4 @@
-# Benchmarking CacheFlow
+# Benchmarking vLLM
 ## Downloading the ShareGPT dataset
@@ -11,7 +11,7 @@ def main(args: argparse.Namespace):
 for i in range(args.n_threads)]
 api_url = f"http://{args.host}:{args.port}/generate"
-headers = {"User-Agent": "CacheFlow Benchmark Client"}
+headers = {"User-Agent": "vLLM Benchmark Client"}
 ploads = [{
 "prompt": p,
 "max_tokens": args.max_tokens,
@@ -6,7 +6,7 @@ import numpy as np
 import torch
 from tqdm import tqdm
-from cacheflow import LLM, SamplingParams
+from vllm import LLM, SamplingParams
 def main(args: argparse.Namespace):
@@ -1,8 +1,8 @@
 """Benchmark online serving throughput.
 On the server side, run one of the following commands:
-(CacheFlow backend)
-python -m cacheflow.entrypoints.api_server \
+(vLLM backend)
+python -m vllm.entrypoints.api_server \
 --disable-log-requests --model <your_model>
 (TGI backend)
@@ -114,7 +114,7 @@ async def send_request(
 request_start_time = time.time()
 headers = {"User-Agent": "Benchmark Client"}
-if backend == "cacheflow":
+if backend == "vllm":
 pload = {
 "prompt": prompt,
 "n": 1,
@@ -213,8 +213,8 @@ def main(args: argparse.Namespace):
 if __name__ == "__main__":
 parser = argparse.ArgumentParser(
 description="Benchmark the online serving throughput.")
-parser.add_argument("--backend", type=str, default="cacheflow",
-choices=["cacheflow", "tgi"])
+parser.add_argument("--backend", type=str, default="vllm",
+choices=["vllm", "tgi"])
 parser.add_argument("--host", type=str, default="localhost")
 parser.add_argument("--port", type=int, default=8001)
 parser.add_argument("--dataset", type=str, required=True,
@@ -5,12 +5,13 @@ import random
 import time
 from typing import List, Tuple
-from cacheflow import LLM, SamplingParams
 import torch
 from transformers import (AutoConfig, AutoTokenizer, AutoModelForCausalLM,
 PreTrainedTokenizerBase)
 from tqdm import tqdm
+from vllm import LLM, SamplingParams
 def get_tokenizer(model_name: str) -> PreTrainedTokenizerBase:
 config = AutoConfig.from_pretrained(model_name)
@@ -70,7 +71,7 @@ def sample_requests(
 return sampled_requests
-def run_cacheflow(
+def run_vllm(
 requests: List[Tuple[str, int, int]],
 model: str,
 tensor_parallel_size: int,
@@ -172,8 +173,8 @@ def main(args: argparse.Namespace):
 tokenizer = get_tokenizer(args.model)
 requests = sample_requests(args.dataset, args.num_prompts, tokenizer)
-if args.backend == "cacheflow":
-elapsed_time = run_cacheflow(
+if args.backend == "vllm":
+elapsed_time = run_vllm(
 requests, args.model, args.tensor_parallel_size, args.seed, args.n,
 args.use_beam_search)
 elif args.backend == "hf":
@@ -192,8 +193,8 @@ def main(args: argparse.Namespace):
 if __name__ == "__main__":
 parser = argparse.ArgumentParser(description="Benchmark the throughput.")
-parser.add_argument("--backend", type=str, choices=["cacheflow", "hf"],
-default="cacheflow")
+parser.add_argument("--backend", type=str, choices=["vllm", "hf"],
+default="vllm")
 parser.add_argument("--dataset", type=str, required=True,
 help="Path to the dataset.")
 parser.add_argument("--model", type=str, default="facebook/opt-125m")
@@ -207,7 +208,7 @@ if __name__ == "__main__":
 parser.add_argument("--hf-max-batch-size", type=int, default=None,
 help="Maximum batch size for HF backend.")
 args = parser.parse_args()
-if args.backend == "cacheflow":
+if args.backend == "vllm":
 if args.hf_max_batch_size is not None:
 raise ValueError("HF max batch size is only for HF backend.")
 elif args.backend == "hf":
@@ -1,18 +0,0 @@
-from cacheflow.engine.arg_utils import EngineArgs
-from cacheflow.engine.llm_engine import LLMEngine
-from cacheflow.engine.ray_utils import initialize_cluster
-from cacheflow.entrypoints.llm import LLM
-from cacheflow.outputs import CompletionOutput, RequestOutput
-from cacheflow.sampling_params import SamplingParams
-__version__ = "0.1.0"
-__all__ = [
-"LLM",
-"SamplingParams",
-"RequestOutput",
-"CompletionOutput",
-"LLMEngine",
-"EngineArgs",
-"initialize_cluster",
-]
@@ -1,10 +0,0 @@
-from cacheflow.model_executor.input_metadata import InputMetadata
-from cacheflow.model_executor.model_loader import get_model
-from cacheflow.model_executor.utils import set_random_seed
-__all__ = [
-"InputMetadata",
-"get_model",
-"set_random_seed",
-]
@@ -1,12 +0,0 @@
-from cacheflow.model_executor.models.gpt_neox import GPTNeoXForCausalLM
-from cacheflow.model_executor.models.gpt2 import GPT2LMHeadModel
-from cacheflow.model_executor.models.llama import LlamaForCausalLM
-from cacheflow.model_executor.models.opt import OPTForCausalLM
-__all__ = [
-"GPT2LMHeadModel",
-"GPTNeoXForCausalLM",
-"LlamaForCausalLM",
-"OPTForCausalLM",
-]
@@ -1,7 +1,7 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
-namespace cacheflow {
+namespace vllm {
 template<typename T>
 __device__ __forceinline__ T silu(const T& x) {
@@ -22,7 +22,7 @@ __global__ void silu_and_mul_kernel(
 }
 }
-} // namespace cacheflow
+} // namespace vllm
 void silu_and_mul(
 torch::Tensor& out, // [num_tokens, d]
@@ -40,7 +40,7 @@ void silu_and_mul(
 input.scalar_type(),
 "silu_and_mul_kernel",
 [&] {
-cacheflow::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
+vllm::silu_and_mul_kernel<scalar_t><<<grid, block, 0, stream>>>(
 out.data_ptr<scalar_t>(),
 input.data_ptr<scalar_t>(),
 d);
@@ -1,6 +1,6 @@
 /*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
-* Copyright (c) 2023, The CacheFlow team.
+* Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,7 +19,7 @@
 #include <stdint.h>
-namespace cacheflow {
+namespace vllm {
 // A vector type to store Q, K, V elements.
 template<typename T, int VEC_SIZE>
@@ -61,4 +61,4 @@ inline __device__ void zero(T& dst) {
 dst = tmp.raw;
 }
-} // namespace cacheflow
+} // namespace vllm
@@ -1,6 +1,6 @@
 /*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
-* Copyright (c) 2023, The CacheFlow team.
+* Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -27,7 +27,7 @@
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 #define MIN(a, b) ((a) < (b) ? (a) : (b))
-namespace cacheflow {
+namespace vllm {
 // Utility function for attention softmax.
 template<int NUM_WARPS>
@@ -315,10 +315,10 @@ __global__ void single_query_cached_kv_attention_kernel(
 }
 }
-} // namespace cacheflow
+} // namespace vllm
 #define LAUNCH_ATTENTION_KERNEL(T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS) \
-cacheflow::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
+vllm::single_query_cached_kv_attention_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
 <<<grid, block, shared_mem_size, stream>>>( \
 out_ptr, \
 query_ptr, \
@@ -1,6 +1,6 @@
 /*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
-* Copyright (c) 2023, The CacheFlow team.
+* Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -22,7 +22,7 @@
 #include <float.h>
 #include <type_traits>
-namespace cacheflow {
+namespace vllm {
 // Q*K^T operation.
 template<int THREAD_GROUP_SIZE, typename Vec, int N>
@@ -52,4 +52,4 @@ struct Qk_dot {
 }
 };
-} // namespace cacheflow
+} // namespace vllm
@@ -1,7 +1,7 @@
 /*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
-* Copyright (c) 2023, The CacheFlow team.
+* Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -25,7 +25,7 @@
 #include <cuda_fp16.h>
 #include <stdint.h>
-namespace cacheflow {
+namespace vllm {
 // Define custom BF16 vector data types.
 struct bf16_4_t {
@@ -420,4 +420,4 @@ inline __device__ void from_float(bf16_8_t& dst, Float8_ src) {
 #endif
 }
-} // namespace cacheflow
+} // namespace vllm
@@ -1,7 +1,7 @@
 /*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
-* Copyright (c) 2023, The CacheFlow team.
+* Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -23,7 +23,7 @@
 #include <stdint.h>
-namespace cacheflow {
+namespace vllm {
 // FP16 vector types for Q, K, V.
 template<>
@@ -441,4 +441,4 @@ inline __device__ Float8_ to_float(uint4 u) {
 return tmp;
 }
-} // namespace cacheflow
+} // namespace vllm
@@ -1,7 +1,7 @@
 /*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
 * and https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
-* Copyright (c) 2023, The CacheFlow team.
+* Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -22,7 +22,7 @@
 #include <stdint.h>
-namespace cacheflow {
+namespace vllm {
 // Define custom FP32 vector data types.
 struct Float4_ {
@@ -265,4 +265,4 @@ inline __device__ Float8_ to_float(Float8_ u) {
 return u;
 }
-} // namespace cacheflow
+} // namespace vllm
@@ -46,7 +46,7 @@ void swap_blocks(
 }
 }
-namespace cacheflow {
+namespace vllm {
 // Grid: (num_layers, num_pairs)
 template<typename scalar_t>
@@ -77,7 +77,7 @@ __global__ void copy_blocks_kernel(
 }
 }
-} // namespace cacheflow
+} // namespace vllm
 void copy_blocks(
 std::vector<torch::Tensor>& key_caches,
@@ -129,7 +129,7 @@ void copy_blocks(
 at::ScalarType::Half,
 at::ScalarType::BFloat16,
 key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
-cacheflow::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
+vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
 key_cache_ptrs_tensor.data_ptr<int64_t>(),
 value_cache_ptrs_tensor.data_ptr<int64_t>(),
 block_mapping_tensor.data_ptr<int>(),
@@ -137,7 +137,7 @@ void copy_blocks(
 }));
 }
-namespace cacheflow {
+namespace vllm {
 template<typename scalar_t>
 __global__ void reshape_and_cache_kernel(
@@ -181,7 +181,7 @@ __global__ void reshape_and_cache_kernel(
 }
 }
-} // namespace cacheflow
+} // namespace vllm
 void reshape_and_cache(
 torch::Tensor& key, // [num_tokens, num_heads, head_size]
@@ -208,7 +208,7 @@ void reshape_and_cache(
 key.scalar_type(),
 "reshape_and_cache_kernel",
 [&] {
-cacheflow::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
+vllm::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
 key.data_ptr<scalar_t>(),
 value.data_ptr<scalar_t>(),
 key_cache.data_ptr<scalar_t>(),
@@ -223,7 +223,7 @@ void reshape_and_cache(
 });
 }
-namespace cacheflow {
+namespace vllm {
 // Grid: (num_blocks, block_size).
 template<typename scalar_t>
@@ -343,7 +343,7 @@ __global__ void gather_cached_kv_kernel_optimized(
 }
 }
-} // namespace cacheflow
+} // namespace vllm
 void gather_cached_kv(
 torch::Tensor& key, // [out] [num_tokens, num_heads, head_size]
@@ -370,7 +370,7 @@ void gather_cached_kv(
 key.scalar_type(),
 "gather_cached_kv_kernel_optimized",
 [&] {
-cacheflow::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
+vllm::gather_cached_kv_kernel_optimized<scalar_t><<<grid, block, 0, stream>>>(
 key.data_ptr<scalar_t>(),
 value.data_ptr<scalar_t>(),
 key_cache.data_ptr<scalar_t>(),
@@ -3,7 +3,7 @@
 #include "reduction_utils.cuh"
-namespace cacheflow {
+namespace vllm {
 // TODO(woosuk): Further optimize this kernel.
 template<typename scalar_t>
@@ -33,7 +33,7 @@ __global__ void rms_norm_kernel(
 }
 }
-} // namespace cacheflow
+} // namespace vllm
 void rms_norm(
 torch::Tensor& out, // [num_tokens, hidden_size]
@@ -52,7 +52,7 @@ void rms_norm(
 input.scalar_type(),
 "rms_norm_kernel",
 [&] {
-cacheflow::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
+vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
 out.data_ptr<scalar_t>(),
 input.data_ptr<scalar_t>(),
 weight.data_ptr<scalar_t>(),
@@ -1,7 +1,7 @@
 #include <torch/extension.h>
 #include <ATen/cuda/CUDAContext.h>
-namespace cacheflow {
+namespace vllm {
 template<typename scalar_t>
 __global__ void rotary_embedding_neox_kernel(
@@ -46,7 +46,7 @@ __global__ void rotary_embedding_neox_kernel(
 }
 }
-} // namespace cacheflow
+} // namespace vllm
 void rotary_embedding_neox(
 torch::Tensor& positions, // [num_tokens]
@@ -70,7 +70,7 @@ void rotary_embedding_neox(
 query.scalar_type(),
 "rotary_embedding_neox",
 [&] {
-cacheflow::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
+vllm::rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
 positions.data_ptr<int64_t>(),
 query.data_ptr<scalar_t>(),
 key.data_ptr<scalar_t>(),
@@ -1,6 +1,6 @@
 /*
 * Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
-* Copyright (c) 2023, The CacheFlow team.
+* Copyright (c) 2023, The vLLM team.
 * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,7 +17,7 @@
 */
 #pragma once
-namespace cacheflow {
+namespace vllm {
 template<typename T>
 __inline__ __device__ T warpReduceSum(T val) {
@@ -48,4 +48,4 @@ __inline__ __device__ T blockReduceSum(T val) {
 return val;
 }
-} // namespace cacheflow
+} // namespace vllm
@@ -1,4 +1,4 @@
-# CacheFlow documents
+# vLLM documents
 ## Build the docs
@@ -17,9 +17,9 @@
 # -- Project information -----------------------------------------------------
-project = 'CacheFlow'
-copyright = '2023, CacheFlow Team'
-author = 'the CacheFlow Team'
+project = 'vLLM'
+copyright = '2023, vLLM Team'
+author = 'the vLLM Team'
 # -- General configuration ---------------------------------------------------
@@ -55,7 +55,7 @@ html_title = project
 html_theme = 'sphinx_book_theme'
 html_theme_options = {
 'path_to_docs': 'docs/source',
-'repository_url': 'https://github.com/WoosukKwon/cacheflow',
+'repository_url': 'https://github.com/WoosukKwon/vllm',
 'use_repository_button': True,
 }
@@ -1,8 +1,8 @@
 Installation
 ============
-CacheFlow is a Python library that includes some C++ and CUDA code.
-CacheFlow can run on systems that meet the following requirements:
+vLLM is a Python library that includes some C++ and CUDA code.
+vLLM can run on systems that meet the following requirements:
 * OS: Linux
 * Python: 3.8 or higher
@@ -10,23 +10,23 @@ CacheFlow can run on systems that meet the following requirements:
 * GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, etc.)
 .. note::
-As of now, CacheFlow does not support CUDA 12.
+As of now, vLLM does not support CUDA 12.
 If you are using Hopper or Lovelace GPUs, please use CUDA 11.8.
 .. tip::
-If you have trouble installing CacheFlow, we recommend using the NVIDIA PyTorch Docker image.
+If you have trouble installing vLLM, we recommend using the NVIDIA PyTorch Docker image.
 .. code-block:: console
 $ # Pull the Docker image with CUDA 11.8.
 $ docker run --gpus all -it --rm --shm-size=8g nvcr.io/nvidia/pytorch:22.12-py3
-Inside the Docker container, please execute :code:`pip uninstall torch` before installing CacheFlow.
+Inside the Docker container, please execute :code:`pip uninstall torch` before installing vLLM.
 Install with pip
 ----------------
-You can install CacheFlow using pip:
+You can install vLLM using pip:
 .. code-block:: console
@@ -34,8 +34,8 @@ You can install CacheFlow using pip:
 $ conda create -n myenv python=3.8 -y
 $ conda activate myenv
-$ # Install CacheFlow.
-$ pip install cacheflow # This may take 5-10 minutes.
+$ # Install vLLM.
+$ pip install vllm # This may take 5-10 minutes.
 .. _build_from_source:
@@ -43,10 +43,10 @@ You can install CacheFlow using pip:
 Build from source
 -----------------
-You can also build and install CacheFlow from source.
+You can also build and install vLLM from source.
 .. code-block:: console
-$ git clone https://github.com/WoosukKwon/cacheflow.git
-$ cd cacheflow
+$ git clone https://github.com/WoosukKwon/vllm.git
+$ cd vllm
 $ pip install -e . # This may take 5-10 minutes.
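A quick sanity check after either install path (a sketch, not part of this commit): the package is now importable as `vllm`, and its `__init__.py` later in this diff defines `__version__ = "0.1.0"` and re-exports the top-level API.

```python
# Sketch: verify the renamed package; the names are taken from vllm/__init__.py in this diff.
import vllm
from vllm import LLM, SamplingParams

print(vllm.__version__)  # expected to print 0.1.0 for this commit
```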
@@ -8,7 +8,7 @@ Placeholder.
 .. code-block:: python
-from cacheflow import LLM, SamplingParams
+from vllm import LLM, SamplingParams
 # Sample prompts.
 prompts = [
@@ -1,5 +1,5 @@
-Welcome to CacheFlow!
-=====================
+Welcome to vLLM!
+================
 Documentation
 -------------
@@ -3,30 +3,30 @@
 Adding a New Model
 ==================
-This document provides a high-level guide on integrating a `HuggingFace Transformers <https://github.com/huggingface/transformers>`_ model into CacheFlow.
+This document provides a high-level guide on integrating a `HuggingFace Transformers <https://github.com/huggingface/transformers>`_ model into vLLM.
 .. note::
 The complexity of adding a new model depends heavily on the model's architecture.
-The process is considerably straightforward if the model shares a similar architecture with an existing model in CacheFlow.
+The process is considerably straightforward if the model shares a similar architecture with an existing model in vLLM.
 However, for models that include new operators (e.g., a new attention mechanism), the process can be a bit more complex.
 .. tip::
-If you are encountering issues while integrating your model into CacheFlow, feel free to open an issue on our `GitHub <https://github.com/WoosukKwon/cacheflow/issues>`_ repository.
+If you are encountering issues while integrating your model into vLLM, feel free to open an issue on our `GitHub <https://github.com/WoosukKwon/vllm/issues>`_ repository.
 We will be happy to help you out!
-0. Fork the CacheFlow repository
+0. Fork the vLLM repository
 --------------------------------
-Start by forking our `GitHub <https://github.com/WoosukKwon/cacheflow/issues>`_ repository and then :ref:`build it from source <build_from_source>`.
+Start by forking our `GitHub <https://github.com/WoosukKwon/vllm/issues>`_ repository and then :ref:`build it from source <build_from_source>`.
 This gives you the ability to modify the codebase and test your model.
 1. Bring your model code
 ------------------------
-Clone the PyTorch model code from the HuggingFace Transformers repository and put it into the `cacheflow/model_executor/models <https://github.com/WoosukKwon/cacheflow/tree/main/cacheflow/model_executor/models>`_ directory.
-For instance, CacheFlow's `OPT model <https://github.com/WoosukKwon/cacheflow/blob/main/cacheflow/model_executor/models/opt.py>`_ was adpated from the HuggingFace's `modeling_opt.py <https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py>`_ file.
+Clone the PyTorch model code from the HuggingFace Transformers repository and put it into the `vllm/model_executor/models <https://github.com/WoosukKwon/vllm/tree/main/vllm/model_executor/models>`_ directory.
+For instance, vLLM's `OPT model <https://github.com/WoosukKwon/vllm/blob/main/vllm/model_executor/models/opt.py>`_ was adpated from the HuggingFace's `modeling_opt.py <https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py>`_ file.
 .. warning::
 When copying the model code, make sure to review and adhere to the code's copyright and licensing terms.
@@ -62,11 +62,11 @@ Next, you need to rewrite the :code:`forward` methods of your model by following
 +) -> Dict[int, SequenceOutputs]:
 3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
-4. Replace the attention operation with either :code:`GPTCacheFlowAttention` or :code:`GPTNeoXCacheFlowAttention`, depending on the model's architecture.
+4. Replace the attention operation with either :code:`GPTPagedAttention` or :code:`GPTNeoXPagedAttention`, depending on the model's architecture.
 .. note::
-Currently, CacheFlow supports the basic multi-head attention mechanism and its variant with rotary positional embeddings.
-If your model employs a different attention mechanism, you will need to implement a new attention layer in CacheFlow.
+Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings.
+If your model employs a different attention mechanism, you will need to implement a new attention layer in vLLM.
 3. (Optional) Implement tensor parallelism support
@@ -91,4 +91,4 @@ While the process is straightforward for most layers, the tensor-parallel layers
 5. Register your model
 ----------------------
-Finally, include your :code:`*ForCausalLM` class in `cacheflow/model_executor/models/__init__.py <https://github.com/WoosukKwon/cacheflow/blob/main/cacheflow/model_executor/models/__init__.py>`_ and register it to the :code:`_MODEL_REGISTRY` in `cacheflow/model_executor/model_loader.py <https://github.com/WoosukKwon/cacheflow/blob/main/cacheflow/model_executor/model_loader.py>`_.
+Finally, include your :code:`*ForCausalLM` class in `vllm/model_executor/models/__init__.py <https://github.com/WoosukKwon/vllm/blob/main/vllm/model_executor/models/__init__.py>`_ and register it to the :code:`_MODEL_REGISTRY` in `vllm/model_executor/model_loader.py <https://github.com/WoosukKwon/vllm/blob/main/vllm/model_executor/model_loader.py>`_.
@@ -3,8 +3,8 @@
 Supported Models
 ================
-CacheFlow supports a variety of generative Transformer models in `HuggingFace Transformers <https://github.com/huggingface/transformers>`_.
-The following is the list of model architectures that are currently supported by CacheFlow.
+vLLM supports a variety of generative Transformer models in `HuggingFace Transformers <https://github.com/huggingface/transformers>`_.
+The following is the list of model architectures that are currently supported by vLLM.
 Alongside each architecture, we include some popular models that use it.
 .. list-table::
@@ -22,19 +22,19 @@ Alongside each architecture, we include some popular models that use it.
 * - :code:`OPTForCausalLM`
 - OPT, OPT-IML
-If your model uses one of the above model architectures, you can seamlessly run your model with CacheFlow.
+If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
 Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model.
-Alternatively, you can raise an issue on our `GitHub <https://github.com/WoosukKwon/cacheflow/issues>`_ project.
+Alternatively, you can raise an issue on our `GitHub <https://github.com/WoosukKwon/vllm/issues>`_ project.
 .. tip::
 The easiest way to check if your model is supported is to run the program below:
 .. code-block:: python
-from cacheflow import LLM
+from vllm import LLM
 llm = LLM(model=...) # Name or path of your model
 output = llm.generate("Hello, my name is")
 print(output)
-If CacheFlow successfully generates text, it indicates that your model is supported.
+If vLLM successfully generates text, it indicates that your model is supported.
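For completeness, a fuller offline-inference sketch built on the same top-level API that the quickstart and supported-models pages above reference (`LLM` and `SamplingParams` from `vllm`); the prompts, sampling settings, and the small OPT checkpoint are illustrative placeholders rather than content of this commit.

```python
# Sketch of offline generation with the renamed package; values are illustrative.
from vllm import LLM, SamplingParams

prompts = ["Hello, my name is", "The future of AI is"]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

llm = LLM(model="facebook/opt-125m")  # any supported architecture from the table above
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    print(output.prompt, "->", output.outputs[0].text)
```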
@@ -1,4 +1,4 @@
-"""Example Python client for cacheflow.entrypoints.api_server"""
+"""Example Python client for vllm.entrypoints.api_server"""
 import argparse
 import json
@@ -6,7 +6,7 @@ import requests
 def http_bot(prompt):
-headers = {"User-Agent": "Cacheflow Client"}
+headers = {"User-Agent": "vLLM Client"}
 pload = {
 "prompt": prompt,
 "stream": True,
@@ -24,7 +24,7 @@ def http_bot(prompt):
 def build_demo():
 with gr.Blocks() as demo:
 gr.Markdown(
-"# Cacheflow text completion demo\n"
+"# vLLM text completion demo\n"
 )
 inputbox = gr.Textbox(label="Input", placeholder="Enter text and press ENTER")
 outputbox = gr.Textbox(label="Output", placeholder="Generated result from the model")
@@ -1,6 +1,6 @@
 import argparse
-from cacheflow import EngineArgs, LLMEngine, SamplingParams
+from vllm import EngineArgs, LLMEngine, SamplingParams
 def main(args: argparse.Namespace):
@@ -1,4 +1,4 @@
-from cacheflow import LLM, SamplingParams
+from vllm import LLM, SamplingParams
 # Sample prompts.
@@ -1,6 +1,6 @@
 import openai
-# Modify OpenAI's API key and API base to use CacheFlow's API server.
+# Modify OpenAI's API key and API base to use vLLM's API server.
 openai.api_key = "EMPTY"
 openai.api_base = "http://localhost:8000/v1"
 model = "facebook/opt-125m"
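To round out the OpenAI-style example just above, a minimal completion call could look like the following sketch. It assumes the legacy `openai` 0.x Python client used in that example and a running `vllm.entrypoints.openai.api_server` on port 8000; the prompt and `max_tokens` values are placeholders.

```python
# Sketch: completion request against the OpenAI-compatible server (openai 0.x client).
import openai

openai.api_key = "EMPTY"
openai.api_base = "http://localhost:8000/v1"

completion = openai.Completion.create(
    model="facebook/opt-125m",
    prompt="Hello, my name is",
    max_tokens=32,  # placeholder
)
print(completion.choices[0].text)
```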
mypy.ini (4 changes)
@@ -3,6 +3,6 @@ python_version = 3.8
 ignore_missing_imports = True
-files = cacheflow
+files = vllm
 # TODO(woosuk): Include the code from Megatron and HuggingFace.
-exclude = cacheflow/model_executor/parallel_utils/|cacheflow/model_executor/models/
+exclude = vllm/model_executor/parallel_utils/|vllm/model_executor/models/
setup.py (26 changes)
@@ -75,7 +75,7 @@ ext_modules = []
 # Cache operations.
 cache_extension = CUDAExtension(
-name="cacheflow.cache_ops",
+name="vllm.cache_ops",
 sources=["csrc/cache.cpp", "csrc/cache_kernels.cu"],
 extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
 )
@@ -83,7 +83,7 @@ ext_modules.append(cache_extension)
 # Attention kernels.
 attention_extension = CUDAExtension(
-name="cacheflow.attention_ops",
+name="vllm.attention_ops",
 sources=["csrc/attention.cpp", "csrc/attention/attention_kernels.cu"],
 extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
 )
@@ -91,7 +91,7 @@ ext_modules.append(attention_extension)
 # Positional encoding kernels.
 positional_encoding_extension = CUDAExtension(
-name="cacheflow.pos_encoding_ops",
+name="vllm.pos_encoding_ops",
 sources=["csrc/pos_encoding.cpp", "csrc/pos_encoding_kernels.cu"],
 extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
 )
@@ -99,7 +99,7 @@ ext_modules.append(positional_encoding_extension)
 # Layer normalization kernels.
 layernorm_extension = CUDAExtension(
-name="cacheflow.layernorm_ops",
+name="vllm.layernorm_ops",
 sources=["csrc/layernorm.cpp", "csrc/layernorm_kernels.cu"],
 extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
 )
@@ -107,7 +107,7 @@ ext_modules.append(layernorm_extension)
 # Activation kernels.
 activation_extension = CUDAExtension(
-name="cacheflow.activation_ops",
+name="vllm.activation_ops",
 sources=["csrc/activation.cpp", "csrc/activation_kernels.cu"],
 extra_compile_args={"cxx": CXX_FLAGS, "nvcc": NVCC_FLAGS},
 )
@@ -144,18 +144,18 @@ def get_requirements() -> List[str]:
 setuptools.setup(
-name="cacheflow",
-version=find_version(get_path("cacheflow", "__init__.py")),
-author="CacheFlow Team",
-author_email="cacheflow@gmail.com",
+name="vllm",
+version=find_version(get_path("vllm", "__init__.py")),
+author="vLLM Team",
+author_email="vllm@gmail.com", # FIXME
 license="Apache 2.0",
-description="CacheFlow: A high-performance LLM Serving System",
+description="vLLM: Easy, Fast, and Cheap LLM Serving with PagedAttention", # FIXME
 long_description=read_readme(),
 long_description_content_type="text/markdown",
-url="https://github.com/WoosukKwon/cacheflow",
+url="https://github.com/WoosukKwon/vllm",
 project_urls={
-"Homepage": "https://github.com/WoosukKwon/cacheflow",
-"Documentation": "https://cacheflow.readthedocs.io/en/latest/",
+"Homepage": "https://github.com/WoosukKwon/vllm",
+"Documentation": "https://vllm.readthedocs.io/en/latest/", # FIXME
 },
 classifiers=[
 "Programming Language :: Python :: 3.8",
@@ -1,7 +1,7 @@
 import torch
 import torch.nn.functional as F
-from cacheflow import activation_ops
+from vllm import activation_ops
 def ref_silu_and_mul(x: torch.Tensor) -> torch.Tensor:
@@ -5,7 +5,7 @@ import torch
 from xformers import ops as xops
 from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
-from cacheflow import attention_ops
+from vllm import attention_ops
 MAX_SEQ_LEN = 4096
 TEST_SEED = 0
@@ -2,7 +2,7 @@ import random
 import torch
-from cacheflow import cache_ops
+from vllm import cache_ops
 @torch.inference_mode()
@@ -1,7 +1,7 @@
 import torch
 import torch.nn as nn
-from cacheflow import layernorm_ops
+from vllm import layernorm_ops
 class RefRMSNorm(nn.Module):
@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from cacheflow import pos_encoding_ops
+from vllm import pos_encoding_ops
 def rotate_half(x: torch.Tensor) -> torch.Tensor:
@@ -0,0 +1,21 @@
+from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.llm_engine import LLMEngine
+from vllm.engine.ray_utils import initialize_cluster
+from vllm.entrypoints.llm import LLM
+from vllm.outputs import CompletionOutput, RequestOutput
+from vllm.sampling_params import SamplingParams
+__version__ = "0.1.0"
+__all__ = [
+"LLM",
+"SamplingParams",
+"RequestOutput",
+"CompletionOutput",
+"LLMEngine",
+"EngineArgs",
+"AsyncLLMEngine",
+"AsyncEngineArgs",
+"initialize_cluster",
+]
@@ -1,7 +1,7 @@
 """Token blocks."""
 from typing import List
-from cacheflow.utils import Device
+from vllm.utils import Device
 _BLANK_TOKEN_ID = -1
@@ -3,8 +3,8 @@ from typing import Optional
 import torch
 from transformers import AutoConfig, PretrainedConfig
-from cacheflow.logger import init_logger
-from cacheflow.utils import get_cpu_memory
+from vllm.logger import init_logger
+from vllm.utils import get_cpu_memory
 logger = init_logger(__name__)
@@ -87,7 +87,7 @@ class CacheConfig:
 Args:
 block_size: Size of a cache block in number of tokens.
 gpu_memory_utilization: Fraction of GPU memory to use for the
-CacheFlow execution.
+vLLM execution.
 swap_space: Size of the CPU swap space per GPU (in GiB).
 """
 def __init__(
@@ -1,9 +1,9 @@
 """A block manager that manages token blocks."""
 from typing import Dict, List, Optional, Set, Tuple
-from cacheflow.block import PhysicalTokenBlock
-from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
-from cacheflow.utils import Device
+from vllm.block import PhysicalTokenBlock
+from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
+from vllm.utils import Device
 class BlockAllocator:
@@ -1,6 +1,6 @@
 from typing import List
-from cacheflow.sequence import SequenceGroup
+from vllm.sequence import SequenceGroup
 class Policy:
@@ -2,13 +2,13 @@ import enum
 import time
 from typing import Dict, List, Optional, Tuple
-from cacheflow.config import CacheConfig, SchedulerConfig
-from cacheflow.core.block_manager import BlockSpaceManager
-from cacheflow.core.policy import PolicyFactory
-from cacheflow.logger import init_logger
-from cacheflow.sequence import (Sequence, SequenceData, SequenceGroup,
-SequenceGroupMetadata, SequenceOutputs,
-SequenceStatus)
+from vllm.config import CacheConfig, SchedulerConfig
+from vllm.core.block_manager import BlockSpaceManager
+from vllm.core.policy import PolicyFactory
+from vllm.logger import init_logger
+from vllm.sequence import (Sequence, SequenceData, SequenceGroup,
+SequenceGroupMetadata, SequenceOutputs,
+SequenceStatus)
 logger = init_logger(__name__)
@@ -3,13 +3,13 @@ import dataclasses
 from dataclasses import dataclass
 from typing import Optional, Tuple
-from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
-SchedulerConfig)
+from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
+SchedulerConfig)
 @dataclass
 class EngineArgs:
-"""Arguments for CacheFlow engine."""
+"""Arguments for vLLM engine."""
 model: str
 download_dir: Optional[str] = None
 use_np_weights: bool = False
@@ -33,7 +33,7 @@ class EngineArgs:
 def add_cli_args(
 parser: argparse.ArgumentParser,
 ) -> argparse.ArgumentParser:
-"""Shared CLI arguments for CacheFlow engine."""
+"""Shared CLI arguments for vLLM engine."""
 # Model arguments
 parser.add_argument('--model', type=str, default='facebook/opt-125m',
 help='name or path of the huggingface model to use')
@@ -118,7 +118,7 @@ class EngineArgs:
 @dataclass
 class AsyncEngineArgs(EngineArgs):
-"""Arguments for asynchronous CacheFlow engine."""
+"""Arguments for asynchronous vLLM engine."""
 engine_use_ray: bool = False
 disable_log_requests: bool = False
@@ -2,12 +2,12 @@ import asyncio
 import time
 from typing import Dict, List, Optional
-from cacheflow.engine.arg_utils import AsyncEngineArgs
-from cacheflow.engine.llm_engine import LLMEngine
-from cacheflow.engine.ray_utils import initialize_cluster, ray
-from cacheflow.logger import init_logger
-from cacheflow.outputs import RequestOutput
-from cacheflow.sampling_params import SamplingParams
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.llm_engine import LLMEngine
+from vllm.engine.ray_utils import initialize_cluster, ray
+from vllm.logger import init_logger
+from vllm.outputs import RequestOutput
+from vllm.sampling_params import SamplingParams
 logger = init_logger(__name__)
@@ -104,7 +104,7 @@ class AsyncLLMEngine:
 arrival_time = time.time()
 # Create an event to notify us that there is new output from the
-# cacheflow engine.
+# vLLM engine.
 request_event = asyncio.Event()
 self.request_events[request_id] = request_event
@@ -114,7 +114,7 @@ class AsyncLLMEngine:
 f"sampling params: {sampling_params}, "
 f"prompt token ids: {prompt_token_ids}.")
-# Add the request into the cacheflow engine's waiting queue.
+# Add the request into the vLLM engine's waiting queue.
 if self.engine_use_ray:
 await self.engine.add_request.remote(
 request_id, prompt, sampling_params,
@@ -126,7 +126,7 @@ class AsyncLLMEngine:
 prompt_token_ids=prompt_token_ids,
 arrival_time=arrival_time)
-# The cacheflow engine does not have a background loop that keeps
+# The vLLM engine does not have a background loop that keeps
 # processing incoming requests. Therefore, we need to keep kicking
 # the engine to process the requests.
 while True:
@@ -1,19 +1,18 @@
 import time
 from typing import Any, List, Optional
-from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
-SchedulerConfig)
-from cacheflow.core.scheduler import Scheduler
-from cacheflow.engine.arg_utils import EngineArgs
-from cacheflow.engine.ray_utils import DeviceID, initialize_cluster, ray
-from cacheflow.engine.tokenizer_utils import (detokenize_incrementally,
-get_tokenizer)
-from cacheflow.logger import init_logger
-from cacheflow.outputs import RequestOutput
-from cacheflow.sampling_params import SamplingParams
-from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus
-from cacheflow.utils import Counter
-from cacheflow.worker.worker import Worker
+from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
+SchedulerConfig)
+from vllm.core.scheduler import Scheduler
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.ray_utils import DeviceID, initialize_cluster, ray
+from vllm.engine.tokenizer_utils import detokenize_incrementally, get_tokenizer
+from vllm.logger import init_logger
+from vllm.outputs import RequestOutput
+from vllm.sampling_params import SamplingParams
+from vllm.sequence import Sequence, SequenceGroup, SequenceStatus
+from vllm.utils import Counter
+from vllm.worker.worker import Worker
 logger = init_logger(__name__)
@@ -21,7 +20,7 @@ logger = init_logger(__name__)
 class LLMEngine:
 """An LLM engine that receives requests and generates texts.
-This is the main class for the CacheFlow LLM engine. It receives requests
+This is the main class for the vLLM engine. It receives requests
 from clients and generates texts from the LLM. It includes a tokenizer, a
 language model (possibly distributed across multiple GPUs), and GPU memory
 space allocated for intermediate states (aka KV cache). This class utilizes
@@ -6,7 +6,7 @@ try:
 except ImportError:
 ray = None
-from cacheflow.config import ParallelConfig
+from vllm.config import ParallelConfig
 DeviceID = Tuple[int, Optional[str], int] # rank, node resource (node IP), device id
@@ -3,7 +3,7 @@ from typing import List, Tuple, Union
 from transformers import (AutoConfig, AutoTokenizer, PreTrainedTokenizer,
 PreTrainedTokenizerFast)
-from cacheflow.logger import init_logger
+from vllm.logger import init_logger
 logger = init_logger(__name__)
@@ -6,10 +6,10 @@ from fastapi import BackgroundTasks, FastAPI, Request
 from fastapi.responses import Response, StreamingResponse
 import uvicorn
-from cacheflow.engine.arg_utils import AsyncEngineArgs
-from cacheflow.engine.async_llm_engine import AsyncLLMEngine
-from cacheflow.sampling_params import SamplingParams
-from cacheflow.utils import random_uuid
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.sampling_params import SamplingParams
+from vllm.utils import random_uuid
 TIMEOUT_KEEP_ALIVE = 5 # seconds.
 TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds
@@ -3,11 +3,11 @@ from typing import List, Optional, Union
 from tqdm import tqdm
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
-from cacheflow.engine.arg_utils import EngineArgs
-from cacheflow.engine.llm_engine import LLMEngine
-from cacheflow.outputs import RequestOutput
-from cacheflow.sampling_params import SamplingParams
-from cacheflow.utils import Counter
+from vllm.engine.arg_utils import EngineArgs
+from vllm.engine.llm_engine import LLMEngine
+from vllm.outputs import RequestOutput
+from vllm.sampling_params import SamplingParams
+from vllm.utils import Counter
 class LLM:
@@ -13,17 +13,17 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
 import uvicorn
-from cacheflow.engine.arg_utils import AsyncEngineArgs
-from cacheflow.engine.async_llm_engine import AsyncLLMEngine
-from cacheflow.engine.tokenizer_utils import get_tokenizer
-from cacheflow.entrypoints.openai.protocol import (
+from vllm.engine.arg_utils import AsyncEngineArgs
+from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.engine.tokenizer_utils import get_tokenizer
+from vllm.entrypoints.openai.protocol import (
 CompletionRequest, CompletionResponse, CompletionResponseChoice,
 CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse,
 LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo)
-from cacheflow.logger import init_logger
-from cacheflow.outputs import RequestOutput
-from cacheflow.sampling_params import SamplingParams
-from cacheflow.utils import random_uuid
+from vllm.logger import init_logger
+from vllm.outputs import RequestOutput
+from vllm.sampling_params import SamplingParams
+from vllm.utils import random_uuid
 TIMEOUT_KEEP_ALIVE = 5 # seconds
@@ -93,11 +93,11 @@ async def create_completion(raw_request: Request):
 for the API specification. This API mimics the OpenAI Completion API.
 NOTE: Currently we do not support the following features:
-- echo (since the cacheflow engine does not currently support
+- echo (since the vLLM engine does not currently support
 getting the logprobs of prompt tokens)
 - suffix (the language models we currently support do not support
 suffix)
-- logit_bias (to be supported in cacheflow engine)
+- logit_bias (to be supported by vLLM engine)
 """
 request = CompletionRequest(**await raw_request.json())
 logger.info(f"Received completion request: {request}")
@@ -107,7 +107,7 @@ async def create_completion(raw_request: Request):
 return error_check_ret
 if request.echo:
-# We do not support echo since the cacheflow engine does not
+# We do not support echo since the vLLM engine does not
 # currently support getting the logprobs of prompt tokens.
 return create_error_response(HTTPStatus.BAD_REQUEST,
 "echo is not currently supported")
@@ -118,7 +118,7 @@ async def create_completion(raw_request: Request):
 "suffix is not currently supported")
 if request.logit_bias is not None:
-# TODO: support logit_bias in cacheflow engine.
+# TODO: support logit_bias in vLLM engine.
 return create_error_response(HTTPStatus.BAD_REQUEST,
 "logit_bias is not currently supported")
@@ -274,7 +274,7 @@ async def create_completion(raw_request: Request):
 if __name__ == "__main__":
 parser = argparse.ArgumentParser(
-description="CacheFlow OpenAI-Compatible RESTful API server."
+description="vLLM OpenAI-Compatible RESTful API server."
 )
 parser.add_argument("--host", type=str, default="localhost", help="host name")
 parser.add_argument("--port", type=int, default=8000, help="port number")
@@ -4,7 +4,7 @@ from typing import Dict, List, Literal, Optional, Union
 from pydantic import BaseModel, Field
-from cacheflow.utils import random_uuid
+from vllm.utils import random_uuid
 class ErrorResponse(BaseModel):
@@ -34,7 +34,7 @@ class ModelCard(BaseModel):
 id: str
 object: str = "model"
 created: int = Field(default_factory=lambda: int(time.time()))
-owned_by: str = "cacheflow"
+owned_by: str = "vllm"
 root: Optional[str] = None
 parent: Optional[str] = None
 permission: List[ModelPermission] = Field(default_factory=list)
@@ -82,7 +82,7 @@ class CompletionRequest(BaseModel):
 best_of: Optional[int] = None
 logit_bias: Optional[Dict[str, float]] = None
 user: Optional[str] = None
-# Additional parameters supported by cacheflow
+# Additional parameters supported by vLLM
 top_k: Optional[int] = -1
 ignore_eos: Optional[bool] = False
 use_beam_search: Optional[bool] = False
@@ -22,7 +22,7 @@ class NewLineFormatter(logging.Formatter):
 return msg
-_root_logger = logging.getLogger("cacheflow")
+_root_logger = logging.getLogger("vllm")
 _default_handler = None
|
|||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.model_loader import get_model
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
|
||||
|
||||
__all__ = [
|
||||
"InputMetadata",
|
||||
"get_model",
|
||||
"set_random_seed",
|
||||
]
|
|
@ -3,8 +3,8 @@ from typing import Dict, List, Tuple
|
|||
import torch
|
||||
from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask
|
||||
|
||||
from cacheflow.sampling_params import SamplingParams
|
||||
from cacheflow.sequence import SequenceData
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import SequenceData
|
||||
|
||||
|
||||
class InputMetadata:
|
|
@ -2,7 +2,7 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from cacheflow import activation_ops
|
||||
from vllm import activation_ops
|
||||
|
||||
_ACTIVATION_REGISTRY = {
|
||||
"gelu": nn.GELU(),
|
|
@ -5,16 +5,16 @@ import torch
|
|||
import torch.nn as nn
|
||||
from xformers import ops as xops
|
||||
|
||||
from cacheflow import attention_ops
|
||||
from cacheflow import cache_ops
|
||||
from cacheflow import pos_encoding_ops
|
||||
from cacheflow.model_executor.input_metadata import InputMetadata
|
||||
from vllm import attention_ops
|
||||
from vllm import cache_ops
|
||||
from vllm import pos_encoding_ops
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
|
||||
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 128]
|
||||
|
||||
|
||||
class GPTCacheFlowAttention(nn.Module):
|
||||
"""GPT-style multi-head attention.
|
||||
class PagedAttention(nn.Module):
|
||||
"""GPT-style multi-head PagedAttention.
|
||||
|
||||
This class takes flattened 1D query, key, and value tensors as input. The
|
||||
input 1D tensors can be split into three parts: the prompt tokens, the
|
||||
|
@ -164,8 +164,8 @@ class GPTCacheFlowAttention(nn.Module):
|
|||
return output.view(-1, self.num_heads * self.head_size)
|
||||
|
||||
|
||||
class GPTNeoXCacheFlowAttention(GPTCacheFlowAttention):
|
||||
"""Attention with GPT-NeoX style rotary embedding."""
|
||||
class PagedAttentionWithRoPE(PagedAttention):
|
||||
"""PagedAttention with GPT-NeoX style rotary embedding."""
|
||||
|
||||
def __init__(
|
||||
self,
|
|
@ -2,7 +2,7 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from cacheflow import layernorm_ops
|
||||
from vllm import layernorm_ops
|
||||
|
||||
|
||||
class RMSNorm(nn.Module):
|
|
@ -5,11 +5,11 @@ import numpy as np
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from cacheflow.model_executor.input_metadata import InputMetadata
|
||||
from cacheflow.model_executor.parallel_utils.tensor_parallel import (
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.parallel_utils.tensor_parallel import (
|
||||
gather_from_tensor_model_parallel_region)
|
||||
from cacheflow.sampling_params import SamplingParams
|
||||
from cacheflow.sequence import SequenceOutputs
|
||||
from vllm.sampling_params import SamplingParams
|
||||
from vllm.sequence import SequenceOutputs
|
||||
|
||||
|
||||
class Sampler(nn.Module):
|
|
@@ -5,10 +5,10 @@ import torch
 import torch.nn as nn
 from transformers import PretrainedConfig
 
-from cacheflow.config import ModelConfig
-from cacheflow.model_executor.models import (
-    GPT2LMHeadModel, GPTNeoXForCausalLM, LlamaForCausalLM, OPTForCausalLM)
-from cacheflow.model_executor.weight_utils import initialize_dummy_weights
+from vllm.config import ModelConfig
+from vllm.model_executor.models import (GPT2LMHeadModel, GPTNeoXForCausalLM,
+                                        LlamaForCausalLM, OPTForCausalLM)
+from vllm.model_executor.weight_utils import initialize_dummy_weights
 
 # TODO(woosuk): Lazy-load the model classes.
 _MODEL_REGISTRY = {
@@ -0,0 +1,12 @@
+from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
+from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
+from vllm.model_executor.models.llama import LlamaForCausalLM
+from vllm.model_executor.models.opt import OPTForCausalLM
+
+
+__all__ = [
+    "GPT2LMHeadModel",
+    "GPTNeoXForCausalLM",
+    "LlamaForCausalLM",
+    "OPTForCausalLM",
+]
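The new `vllm.model_executor.models` package gathers the four architectures supported at this point (GPT-2, GPT-NeoX, LLaMA, OPT); `_MODEL_REGISTRY` in the model-loader hunk above maps architecture names onto these classes. A small sketch of the resulting import surface; the dictionary below is illustrative and not code from the diff:

```python
# Only the class names are taken from the __all__ list above; how a model is
# actually instantiated (config, weights) is not shown in this diff.
from vllm.model_executor.models import (GPT2LMHeadModel, GPTNeoXForCausalLM,
                                        LlamaForCausalLM, OPTForCausalLM)

supported = {cls.__name__: cls for cls in
             (GPT2LMHeadModel, GPTNeoXForCausalLM, LlamaForCausalLM, OPTForCausalLM)}
print(sorted(supported))
```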
@@ -1,6 +1,6 @@
 # coding=utf-8
 # Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt2/modeling_gpt2.py
-# Copyright 2023 The CacheFlow team.
+# Copyright 2023 The vLLM team.
 # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
 # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
 #
@@ -26,17 +26,17 @@ import torch
 from torch import nn
 from transformers import GPT2Config
 
-from cacheflow.model_executor.input_metadata import InputMetadata
-from cacheflow.model_executor.layers.activation import get_act_fn
-from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
-from cacheflow.model_executor.layers.sampler import Sampler
-from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
-                                                    load_tensor_parallel_weights)
-from cacheflow.model_executor.parallel_utils.parallel_state import (
+from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.attention import PagedAttention
+from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
+                                               load_tensor_parallel_weights)
+from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from cacheflow.model_executor.parallel_utils.tensor_parallel import (
+from vllm.model_executor.parallel_utils.tensor_parallel import (
     VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
-from cacheflow.sequence import SequenceOutputs
+from vllm.sequence import SequenceOutputs
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -53,14 +53,14 @@ class GPT2Attention(nn.Module):
         self.head_dim = self.hidden_size // total_num_heads
         self.scale = self.head_dim ** -0.5
 
-        self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size, bias=True,
-                                           gather_output=False,
+        self.c_attn = ColumnParallelLinear(self.hidden_size, 3 * self.hidden_size,
+                                           bias=True, gather_output=False,
                                            perform_initialization=False)
-        self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size, bias=True,
-                                        input_is_parallel=True,
+        self.c_proj = RowParallelLinear(self.hidden_size, self.hidden_size,
+                                        bias=True, input_is_parallel=True,
                                         perform_initialization=False)
-        self.attn = GPTCacheFlowAttention(self.num_heads, self.head_dim,
-                                          scale=self.scale)
+        self.attn = PagedAttention(self.num_heads, self.head_dim,
+                                   scale=self.scale)
 
     def forward(
         self,
@@ -1,6 +1,6 @@
 # coding=utf-8
 # Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/gpt_neox/modeling_gpt_neox.py
-# Copyright 2023 The CacheFlow team.
+# Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -25,17 +25,17 @@ import torch
 from torch import nn
 from transformers import GPTNeoXConfig
 
-from cacheflow.model_executor.input_metadata import InputMetadata
-from cacheflow.model_executor.layers.activation import get_act_fn
-from cacheflow.model_executor.layers.attention import GPTNeoXCacheFlowAttention
-from cacheflow.model_executor.layers.sampler import Sampler
-from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
-                                                    load_tensor_parallel_weights)
-from cacheflow.model_executor.parallel_utils.parallel_state import (
+from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
+                                               load_tensor_parallel_weights)
+from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from cacheflow.model_executor.parallel_utils.tensor_parallel import (
+from vllm.model_executor.parallel_utils.tensor_parallel import (
     VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
-from cacheflow.sequence import SequenceOutputs
+from vllm.sequence import SequenceOutputs
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -63,8 +63,8 @@ class GPTNeoXAttention(nn.Module):
         scaling = self.head_size ** -0.5
         rotary_dim = int(self.head_size * config.rotary_pct)
         assert rotary_dim % 2 == 0
-        self.attn = GPTNeoXCacheFlowAttention(self.num_heads, self.head_size,
-                                              scaling, rotary_dim)
+        self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_size,
+                                           scaling, rotary_dim)
 
     def forward(
         self,
@@ -149,6 +149,7 @@ class GPTNeoXLayer(nn.Module):
 
+
 class GPTNeoXModel(nn.Module):
 
     def __init__(self, config: GPTNeoXConfig):
         super().__init__()
         self.config = config
@@ -1,6 +1,6 @@
 # coding=utf-8
 # Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
-# Copyright 2023 The CacheFlow team.
+# Copyright 2023 The vLLM team.
 # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
 #
 # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
@@ -30,19 +30,19 @@ import torch
 from torch import nn
 from transformers import LlamaConfig
 
-from cacheflow.sequence import SequenceOutputs
-from cacheflow.model_executor.input_metadata import InputMetadata
-from cacheflow.model_executor.layers.activation import SiluAndMul
-from cacheflow.model_executor.layers.layernorm import RMSNorm
-from cacheflow.model_executor.layers.attention import GPTNeoXCacheFlowAttention
-from cacheflow.model_executor.layers.sampler import Sampler
-from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
-                                                    load_tensor_parallel_weights)
-from cacheflow.model_executor.parallel_utils.parallel_state import (
+from vllm.sequence import SequenceOutputs
+from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.activation import SiluAndMul
+from vllm.model_executor.layers.layernorm import RMSNorm
+from vllm.model_executor.layers.attention import PagedAttentionWithRoPE
+from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
+                                               load_tensor_parallel_weights)
+from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from cacheflow.model_executor.parallel_utils.tensor_parallel import (
+from vllm.model_executor.parallel_utils.tensor_parallel import (
     VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
-from cacheflow.sequence import SequenceOutputs
+from vllm.sequence import SequenceOutputs
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -104,8 +104,8 @@ class LlamaAttention(nn.Module):
             input_is_parallel=True,
             perform_initialization=False,
         )
-        self.attn = GPTNeoXCacheFlowAttention(self.num_heads, self.head_dim,
-                                              self.scaling, rotary_dim=self.head_dim)
+        self.attn = PagedAttentionWithRoPE(self.num_heads, self.head_dim,
+                                           self.scaling, rotary_dim=self.head_dim)
 
     def forward(
         self,
@@ -1,6 +1,6 @@
 # coding=utf-8
 # Adapted from https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/opt/modeling_opt.py
-# Copyright 2023 The CacheFlow team.
+# Copyright 2023 The vLLM team.
 # Copyright 2022 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -25,17 +25,17 @@ import torch
 from torch import nn
 from transformers import OPTConfig
 
-from cacheflow.model_executor.input_metadata import InputMetadata
-from cacheflow.model_executor.layers.activation import get_act_fn
-from cacheflow.model_executor.layers.attention import GPTCacheFlowAttention
-from cacheflow.model_executor.layers.sampler import Sampler
-from cacheflow.model_executor.weight_utils import (hf_model_weights_iterator,
-                                                    load_tensor_parallel_weights)
-from cacheflow.model_executor.parallel_utils.parallel_state import (
+from vllm.model_executor.input_metadata import InputMetadata
+from vllm.model_executor.layers.activation import get_act_fn
+from vllm.model_executor.layers.attention import PagedAttention
+from vllm.model_executor.layers.sampler import Sampler
+from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
+                                               load_tensor_parallel_weights)
+from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from cacheflow.model_executor.parallel_utils.tensor_parallel import (
+from vllm.model_executor.parallel_utils.tensor_parallel import (
     VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
-from cacheflow.sequence import SequenceOutputs
+from vllm.sequence import SequenceOutputs
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -75,8 +75,8 @@ class OPTAttention(nn.Module):
         self.out_proj = RowParallelLinear(embed_dim, embed_dim, bias=bias,
                                           input_is_parallel=True,
                                           perform_initialization=False)
-        self.attn = GPTCacheFlowAttention(self.num_heads, self.head_dim,
-                                          scale=self.scaling)
+        self.attn = PagedAttention(self.num_heads, self.head_dim,
+                                   scale=self.scaling)
 
     def forward(
         self,
@@ -1,5 +1,5 @@
-import cacheflow.model_executor.parallel_utils.parallel_state
-import cacheflow.model_executor.parallel_utils.tensor_parallel
+import vllm.model_executor.parallel_utils.parallel_state
+import vllm.model_executor.parallel_utils.tensor_parallel
 
 # Alias parallel_state as mpu, its legacy name
 mpu = parallel_state
@@ -1,4 +1,4 @@
-# Copyright 2023 The CacheFlow team.
+# Copyright 2023 The vLLM team.
 # Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 
@@ -1,4 +1,4 @@
-# Copyright 2023 The CacheFlow team.
+# Copyright 2023 The vLLM team.
 # Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/layers.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 
@@ -11,7 +11,7 @@ import torch.nn.functional as F
 import torch.nn.init as init
 from torch.nn.parameter import Parameter
 
-from cacheflow.model_executor.parallel_utils.parallel_state import (
+from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_all_reduce_launcher,
@@ -1,10 +1,10 @@
-# Copyright 2023 The CacheFlow team.
+# Copyright 2023 The vLLM team.
 # Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/mappings.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 
 import torch
 
-from cacheflow.model_executor.parallel_utils.parallel_state import (
+from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_tensor_model_parallel_group,
@@ -1,4 +1,4 @@
-# Copyright 2023 The CacheFlow team.
+# Copyright 2023 The vLLM team.
 # Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/random.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 
@@ -11,7 +11,7 @@ import torch
 from torch import _C
 from torch.cuda import _lazy_call, device as device_ctx_manager
 
-from cacheflow.model_executor.parallel_utils.parallel_state import (
+from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank,
 )
 
@@ -1,4 +1,4 @@
-# Copyright 2023 The CacheFlow team.
+# Copyright 2023 The vLLM team.
 # Adapted from https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 
@@ -4,8 +4,8 @@ import random
 import numpy as np
 import torch
 
-from cacheflow.model_executor.parallel_utils.parallel_state import model_parallel_is_initialized
-from cacheflow.model_executor.parallel_utils.tensor_parallel import model_parallel_cuda_manual_seed
+from vllm.model_executor.parallel_utils.parallel_state import model_parallel_is_initialized
+from vllm.model_executor.parallel_utils.tensor_parallel import model_parallel_cuda_manual_seed
 
 
 def set_random_seed(seed: int) -> None:
@@ -1,6 +1,6 @@
 from typing import Dict, List, Optional
 
-from cacheflow.sequence import SequenceGroup, SequenceStatus
+from vllm.sequence import SequenceGroup, SequenceStatus
 
 
 class CompletionOutput:
@@ -2,8 +2,8 @@ import copy
 import enum
 from typing import Dict, List, Optional, Union
 
-from cacheflow.block import LogicalTokenBlock
-from cacheflow.sampling_params import SamplingParams
+from vllm.block import LogicalTokenBlock
+from vllm.sampling_params import SamplingParams
 
 
 class SequenceStatus(enum.Enum):
@@ -3,8 +3,8 @@ from typing import Dict, List, Tuple
 
 import torch
 
-from cacheflow import cache_ops
-from cacheflow.config import CacheConfig, ModelConfig, ParallelConfig
+from vllm import cache_ops
+from vllm.config import CacheConfig, ModelConfig, ParallelConfig
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
@@ -3,16 +3,15 @@ from typing import Dict, List, Tuple
 
 import torch
 
-from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig,
-                              SchedulerConfig)
-from cacheflow.model_executor import get_model, InputMetadata, set_random_seed
-from cacheflow.model_executor.parallel_utils.parallel_state import (
+from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
+                         SchedulerConfig)
+from vllm.model_executor import get_model, InputMetadata, set_random_seed
+from vllm.model_executor.parallel_utils.parallel_state import (
     initialize_model_parallel, initialize_all_reduce_launcher)
-from cacheflow.sampling_params import SamplingParams
-from cacheflow.sequence import (SequenceData, SequenceGroupMetadata,
-                                SequenceOutputs)
-from cacheflow.worker.cache_engine import CacheEngine
-from cacheflow.utils import get_gpu_memory
+from vllm.sampling_params import SamplingParams
+from vllm.sequence import SequenceData, SequenceGroupMetadata, SequenceOutputs
+from vllm.worker.cache_engine import CacheEngine
+from vllm.utils import get_gpu_memory
 
 
 class Worker: