mirror of https://github.com/vllm-project/vllm

[Kernel][RFC] Refactor the punica kernel based on Triton (#5036)

This commit is contained in:
parent 7eb0cb4a14
commit 7ecee34321
@@ -13,8 +13,6 @@ $python_executable -m pip install -r requirements-cuda.txt
# Limit the number of parallel jobs to avoid OOM
export MAX_JOBS=1
# Make sure punica is built for the release (for LoRA)
export VLLM_INSTALL_PUNICA_KERNELS=1
# Make sure release wheels are built for the following architectures
export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
# Build
@@ -223,61 +223,7 @@ define_gpu_extension_target(
  USE_SABI 3
  WITH_SOABI)

#
# _punica_C extension
#
set(VLLM_PUNICA_EXT_SRC
  "csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu"
  "csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu"
  "csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu"
  "csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
  "csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
  "csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
  "csrc/punica/punica_ops.cu"
  "csrc/punica/torch_bindings.cpp")

#
# Copy GPU compilation flags+update for punica
#
set(VLLM_PUNICA_GPU_FLAGS ${VLLM_GPU_FLAGS})
list(REMOVE_ITEM VLLM_PUNICA_GPU_FLAGS
  "-D__CUDA_NO_HALF_OPERATORS__"
  "-D__CUDA_NO_HALF_CONVERSIONS__"
  "-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
  "-D__CUDA_NO_HALF2_OPERATORS__")

#
# Filter out CUDA architectures < 8.0 for punica.
#
if (${VLLM_GPU_LANG} STREQUAL "CUDA")
  set(VLLM_PUNICA_GPU_ARCHES)
  foreach(ARCH ${VLLM_GPU_ARCHES})
    string_to_ver(CODE_VER ${ARCH})
    if (CODE_VER GREATER_EQUAL 8.0)
      list(APPEND VLLM_PUNICA_GPU_ARCHES ${ARCH})
    endif()
  endforeach()
  message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
elseif(${VLLM_GPU_LANG} STREQUAL "HIP")
  set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES})
  message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
endif()

if (VLLM_PUNICA_GPU_ARCHES)
  define_gpu_extension_target(
    _punica_C
    DESTINATION vllm
    LANGUAGE ${VLLM_GPU_LANG}
    SOURCES ${VLLM_PUNICA_EXT_SRC}
    COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
    ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES}
    USE_SABI 3
    WITH_SOABI)
else()
  message(WARNING "Unable to create _punica_C target because none of the "
    "requested architectures (${VLLM_GPU_ARCHES}) are supported, i.e. >= 8.0")
endif()

#
# Add the `default` target which detects which extensions should be
@@ -301,12 +247,4 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
  message(STATUS "Enabling moe extension.")
  add_dependencies(default _moe_C)

  # Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or
  # VLLM_INSTALL_PUNICA_KERNELS is set in the environment and
  # there are supported target arches.
  if (VLLM_PUNICA_GPU_ARCHES AND
      (ENV{VLLM_INSTALL_PUNICA_KERNELS} OR VLLM_INSTALL_PUNICA_KERNELS))
    message(STATUS "Enabling punica extension.")
    add_dependencies(default _punica_C)
  endif()
endif()
@@ -88,8 +88,6 @@ ENV MAX_JOBS=${max_jobs}
# number of threads used by nvcc
ARG nvcc_threads=8
ENV NVCC_THREADS=$nvcc_threads
# make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1

ARG buildkite_commit
ENV BUILDKITE_COMMIT=${buildkite_commit}
@@ -131,8 +131,7 @@ COPY . .
RUN --mount=type=cache,target=/root/.cache/pip \
    python3 -m pip install --upgrade numba scipy huggingface-hub[cli]

# Make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1

# Workaround for ray >= 2.10.0
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
# Silences the HF Tokenizers warning
@ -1,217 +0,0 @@
|
|||
Contains code from https://github.com/punica-ai/punica
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "{}"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright {yyyy} {name of copyright owner}
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
------------------------------------------------------------------------------------
|
||||
|
||||
This product bundles various third-party components under other open source licenses.
|
||||
This section summarizes those components and their licenses. See licenses/
|
||||
for text of these licenses.
|
||||
|
||||
|
||||
Apache-2.0
|
||||
* third_party/nvbench (with LLVM exception)
|
||||
* third_party/flashinfer
|
||||
|
||||
BSD-3-Clause:
|
||||
* third_party/cutlass
|
|
@@ -1,5 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
@@ -1,5 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, float, nv_bfloat16)
@ -1,218 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
template <int feat_in, int feat_out, typename in_T, typename out_T,
|
||||
typename W_T>
|
||||
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
|
||||
int64_t layer_idx, float scale);
|
||||
|
||||
// clang-format off
|
||||
|
||||
#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
|
||||
f(in_T, out_T, W_T, narrow, 128) \
|
||||
f(in_T, out_T, W_T, narrow, 256) \
|
||||
f(in_T, out_T, W_T, narrow, 512) \
|
||||
f(in_T, out_T, W_T, narrow, 640) \
|
||||
f(in_T, out_T, W_T, narrow, 768) \
|
||||
f(in_T, out_T, W_T, narrow, 896) \
|
||||
f(in_T, out_T, W_T, narrow, 1024) \
|
||||
f(in_T, out_T, W_T, narrow, 1152) \
|
||||
f(in_T, out_T, W_T, narrow, 1216) \
|
||||
f(in_T, out_T, W_T, narrow, 1280) \
|
||||
f(in_T, out_T, W_T, narrow, 1536) \
|
||||
f(in_T, out_T, W_T, narrow, 1664) \
|
||||
f(in_T, out_T, W_T, narrow, 1728) \
|
||||
f(in_T, out_T, W_T, narrow, 1792) \
|
||||
f(in_T, out_T, W_T, narrow, 2048) \
|
||||
f(in_T, out_T, W_T, narrow, 2240) \
|
||||
f(in_T, out_T, W_T, narrow, 2304) \
|
||||
f(in_T, out_T, W_T, narrow, 2368) \
|
||||
f(in_T, out_T, W_T, narrow, 2432) \
|
||||
f(in_T, out_T, W_T, narrow, 2560) \
|
||||
f(in_T, out_T, W_T, narrow, 2752) \
|
||||
f(in_T, out_T, W_T, narrow, 2816) \
|
||||
f(in_T, out_T, W_T, narrow, 3072) \
|
||||
f(in_T, out_T, W_T, narrow, 3328) \
|
||||
f(in_T, out_T, W_T, narrow, 3456) \
|
||||
f(in_T, out_T, W_T, narrow, 3584) \
|
||||
f(in_T, out_T, W_T, narrow, 3712) \
|
||||
f(in_T, out_T, W_T, narrow, 4096) \
|
||||
f(in_T, out_T, W_T, narrow, 4480) \
|
||||
f(in_T, out_T, W_T, narrow, 4608) \
|
||||
f(in_T, out_T, W_T, narrow, 4736) \
|
||||
f(in_T, out_T, W_T, narrow, 4864) \
|
||||
f(in_T, out_T, W_T, narrow, 5120) \
|
||||
f(in_T, out_T, W_T, narrow, 5504) \
|
||||
f(in_T, out_T, W_T, narrow, 5632) \
|
||||
f(in_T, out_T, W_T, narrow, 5888) \
|
||||
f(in_T, out_T, W_T, narrow, 6144) \
|
||||
f(in_T, out_T, W_T, narrow, 6400) \
|
||||
f(in_T, out_T, W_T, narrow, 6848) \
|
||||
f(in_T, out_T, W_T, narrow, 6912) \
|
||||
f(in_T, out_T, W_T, narrow, 7168) \
|
||||
f(in_T, out_T, W_T, narrow, 7424) \
|
||||
f(in_T, out_T, W_T, narrow, 8192) \
|
||||
f(in_T, out_T, W_T, narrow, 8960) \
|
||||
f(in_T, out_T, W_T, narrow, 9216) \
|
||||
f(in_T, out_T, W_T, narrow, 9472) \
|
||||
f(in_T, out_T, W_T, narrow, 10240) \
|
||||
f(in_T, out_T, W_T, narrow, 11008) \
|
||||
f(in_T, out_T, W_T, narrow, 11264) \
|
||||
f(in_T, out_T, W_T, narrow, 12288) \
|
||||
f(in_T, out_T, W_T, narrow, 13696) \
|
||||
f(in_T, out_T, W_T, narrow, 13824) \
|
||||
f(in_T, out_T, W_T, narrow, 14336) \
|
||||
f(in_T, out_T, W_T, narrow, 14784) \
|
||||
f(in_T, out_T, W_T, narrow, 14848) \
|
||||
f(in_T, out_T, W_T, narrow, 15360) \
|
||||
f(in_T, out_T, W_T, narrow, 16384) \
|
||||
f(in_T, out_T, W_T, narrow, 18944) \
|
||||
f(in_T, out_T, W_T, narrow, 20480) \
|
||||
f(in_T, out_T, W_T, narrow, 22016) \
|
||||
f(in_T, out_T, W_T, narrow, 22528) \
|
||||
f(in_T, out_T, W_T, narrow, 24576) \
|
||||
f(in_T, out_T, W_T, narrow, 27392) \
|
||||
f(in_T, out_T, W_T, narrow, 27648) \
|
||||
f(in_T, out_T, W_T, narrow, 28672) \
|
||||
f(in_T, out_T, W_T, narrow, 29568) \
|
||||
f(in_T, out_T, W_T, narrow, 29696) \
|
||||
f(in_T, out_T, W_T, narrow, 32000) \
|
||||
f(in_T, out_T, W_T, narrow, 32256) \
|
||||
f(in_T, out_T, W_T, narrow, 32512) \
|
||||
f(in_T, out_T, W_T, narrow, 32768) \
|
||||
f(in_T, out_T, W_T, narrow, 33024) \
|
||||
f(in_T, out_T, W_T, narrow, 36864) \
|
||||
f(in_T, out_T, W_T, narrow, 43264) \
|
||||
f(in_T, out_T, W_T, narrow, 49152) \
|
||||
f(in_T, out_T, W_T, narrow, 49408) \
|
||||
f(in_T, out_T, W_T, narrow, 60544) \
|
||||
f(in_T, out_T, W_T, narrow, 60672) \
|
||||
f(in_T, out_T, W_T, narrow, 64000) \
|
||||
f(in_T, out_T, W_T, narrow, 64256) \
|
||||
f(in_T, out_T, W_T, narrow, 64512) \
|
||||
f(in_T, out_T, W_T, narrow, 102400) \
|
||||
f(in_T, out_T, W_T, narrow, 102656) \
|
||||
f(in_T, out_T, W_T, narrow, 102912) \
|
||||
f(in_T, out_T, W_T, narrow, 128000) \
|
||||
f(in_T, out_T, W_T, narrow, 128256) \
|
||||
f(in_T, out_T, W_T, narrow, 128512) \
|
||||
|
||||
|
||||
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
|
||||
// and vllm/tests/lora/test_punica.py
|
||||
|
||||
// Used for defining kernels going from the variety of
|
||||
// dim in to the narrow dim out
|
||||
// Using it for the fully sharded column
|
||||
// parallel LoRA A which splits the rank dim
|
||||
#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \
|
||||
f(in_T, out_T, W_T, 128, narrow) \
|
||||
f(in_T, out_T, W_T, 256, narrow) \
|
||||
f(in_T, out_T, W_T, 512, narrow) \
|
||||
f(in_T, out_T, W_T, 640, narrow) \
|
||||
f(in_T, out_T, W_T, 768, narrow) \
|
||||
f(in_T, out_T, W_T, 896, narrow) \
|
||||
f(in_T, out_T, W_T, 1024, narrow) \
|
||||
f(in_T, out_T, W_T, 1152, narrow) \
|
||||
f(in_T, out_T, W_T, 1216, narrow) \
|
||||
f(in_T, out_T, W_T, 1280, narrow) \
|
||||
f(in_T, out_T, W_T, 1536, narrow) \
|
||||
f(in_T, out_T, W_T, 1664, narrow) \
|
||||
f(in_T, out_T, W_T, 1728, narrow) \
|
||||
f(in_T, out_T, W_T, 1792, narrow) \
|
||||
f(in_T, out_T, W_T, 2048, narrow) \
|
||||
f(in_T, out_T, W_T, 2240, narrow) \
|
||||
f(in_T, out_T, W_T, 2304, narrow) \
|
||||
f(in_T, out_T, W_T, 2368, narrow) \
|
||||
f(in_T, out_T, W_T, 2432, narrow) \
|
||||
f(in_T, out_T, W_T, 2560, narrow) \
|
||||
f(in_T, out_T, W_T, 2752, narrow) \
|
||||
f(in_T, out_T, W_T, 2816, narrow) \
|
||||
f(in_T, out_T, W_T, 3072, narrow) \
|
||||
f(in_T, out_T, W_T, 3328, narrow) \
|
||||
f(in_T, out_T, W_T, 3456, narrow) \
|
||||
f(in_T, out_T, W_T, 3584, narrow) \
|
||||
f(in_T, out_T, W_T, 3712, narrow) \
|
||||
f(in_T, out_T, W_T, 4096, narrow) \
|
||||
f(in_T, out_T, W_T, 4480, narrow) \
|
||||
f(in_T, out_T, W_T, 4608, narrow) \
|
||||
f(in_T, out_T, W_T, 4736, narrow) \
|
||||
f(in_T, out_T, W_T, 4864, narrow) \
|
||||
f(in_T, out_T, W_T, 5120, narrow) \
|
||||
f(in_T, out_T, W_T, 5504, narrow) \
|
||||
f(in_T, out_T, W_T, 5632, narrow) \
|
||||
f(in_T, out_T, W_T, 5888, narrow) \
|
||||
f(in_T, out_T, W_T, 6144, narrow) \
|
||||
f(in_T, out_T, W_T, 6400, narrow) \
|
||||
f(in_T, out_T, W_T, 6848, narrow) \
|
||||
f(in_T, out_T, W_T, 6912, narrow) \
|
||||
f(in_T, out_T, W_T, 7168, narrow) \
|
||||
f(in_T, out_T, W_T, 7424, narrow) \
|
||||
f(in_T, out_T, W_T, 8192, narrow) \
|
||||
f(in_T, out_T, W_T, 8960, narrow) \
|
||||
f(in_T, out_T, W_T, 9216, narrow) \
|
||||
f(in_T, out_T, W_T, 9472, narrow) \
|
||||
f(in_T, out_T, W_T, 10240, narrow) \
|
||||
f(in_T, out_T, W_T, 11008, narrow) \
|
||||
f(in_T, out_T, W_T, 11264, narrow) \
|
||||
f(in_T, out_T, W_T, 12288, narrow) \
|
||||
f(in_T, out_T, W_T, 13696, narrow) \
|
||||
f(in_T, out_T, W_T, 13824, narrow) \
|
||||
f(in_T, out_T, W_T, 14336, narrow) \
|
||||
f(in_T, out_T, W_T, 14784, narrow) \
|
||||
f(in_T, out_T, W_T, 14848, narrow) \
|
||||
f(in_T, out_T, W_T, 15360, narrow) \
|
||||
f(in_T, out_T, W_T, 16384, narrow) \
|
||||
f(in_T, out_T, W_T, 18944, narrow) \
|
||||
f(in_T, out_T, W_T, 20480, narrow) \
|
||||
f(in_T, out_T, W_T, 22016, narrow) \
|
||||
f(in_T, out_T, W_T, 22528, narrow) \
|
||||
f(in_T, out_T, W_T, 24576, narrow) \
|
||||
f(in_T, out_T, W_T, 27392, narrow) \
|
||||
f(in_T, out_T, W_T, 27648, narrow) \
|
||||
f(in_T, out_T, W_T, 28672, narrow) \
|
||||
f(in_T, out_T, W_T, 29568, narrow) \
|
||||
f(in_T, out_T, W_T, 29696, narrow) \
|
||||
f(in_T, out_T, W_T, 32000, narrow) \
|
||||
f(in_T, out_T, W_T, 32256, narrow) \
|
||||
f(in_T, out_T, W_T, 32512, narrow) \
|
||||
f(in_T, out_T, W_T, 32768, narrow) \
|
||||
f(in_T, out_T, W_T, 33024, narrow) \
|
||||
f(in_T, out_T, W_T, 36864, narrow) \
|
||||
f(in_T, out_T, W_T, 43264, narrow) \
|
||||
f(in_T, out_T, W_T, 49152, narrow) \
|
||||
f(in_T, out_T, W_T, 49408, narrow) \
|
||||
f(in_T, out_T, W_T, 60544, narrow) \
|
||||
f(in_T, out_T, W_T, 60672, narrow) \
|
||||
f(in_T, out_T, W_T, 64000, narrow) \
|
||||
f(in_T, out_T, W_T, 64256, narrow) \
|
||||
f(in_T, out_T, W_T, 64512, narrow) \
|
||||
f(in_T, out_T, W_T, 102400, narrow) \
|
||||
f(in_T, out_T, W_T, 102656, narrow) \
|
||||
f(in_T, out_T, W_T, 102912, narrow) \
|
||||
f(in_T, out_T, W_T, 128000, narrow) \
|
||||
f(in_T, out_T, W_T, 128256, narrow) \
|
||||
f(in_T, out_T, W_T, 128512, narrow) \
|
||||
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
|
||||
|
||||
|
||||
// Keep this in sync with vllm/config::LoRAConfig
|
||||
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
|
||||
|
||||
|
||||
#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
|
||||
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \
|
||||
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \
|
||||
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \
|
||||
f(in_T, out_T, W_T, 8, 64) \
|
||||
f(in_T, out_T, W_T, 16, 64) \
|
||||
f(in_T, out_T, W_T, 32, 64) \
|
||||
f(in_T, out_T, W_T, 64, 64)
|
||||
|
||||
// clang-format on
|
|
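Taken together, the FOR_BGMV_WIDE(_NARROW) and FOR_INST_BGMV_* macros above enumerate every (LoRA rank, hidden size) pair the pre-Triton kernels were compiled for: ranks 8/16/32/64 against each width in the long list, plus the one-sided shapes. Combined with the INST_BGMV_TWOSIDE / INST_BGMV_ONESIDE macros defined at the end of bgmv_impl.cuh, each row becomes an explicit template instantiation of bgmv_kernel. As a minimal sketch (only the narrow = 8, wide = 128 slice of the nv_half/nv_half/nv_half file is shown), the expansion looks roughly like:

// expansion of INST_BGMV_TWOSIDE(nv_half, nv_half, nv_half, 8, 128):
template void bgmv_kernel<8, 128>(nv_half *__restrict__ Y, const nv_half *__restrict__ X,
                                  const nv_half *__restrict__ W,
                                  const int64_t *__restrict__ indicies, int64_t y_offset,
                                  int64_t full_y_size, int64_t batch_size, int64_t num_layers,
                                  int64_t layer_idx, float scale);
template void bgmv_kernel<128, 8>(nv_half *__restrict__ Y, const nv_half *__restrict__ X,
                                  const nv_half *__restrict__ W,
                                  const int64_t *__restrict__ indicies, int64_t y_offset,
                                  int64_t full_y_size, int64_t batch_size, int64_t num_layers,
                                  int64_t layer_idx, float scale);

which is why supporting a new model hidden size previously meant extending this list.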
@@ -1,5 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half)
@@ -1,5 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, float, nv_half)
@@ -1,5 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_bfloat16, nv_bfloat16)
@@ -1,5 +0,0 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_half, nv_half)
@ -1,451 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#ifndef USE_ROCM
|
||||
#include <cooperative_groups.h>
|
||||
#else
|
||||
#include <hip/hip_cooperative_groups.h>
|
||||
#endif
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda/pipeline>
|
||||
#endif
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "vec_dtypes.cuh"
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
#ifdef USE_ROCM
|
||||
template <size_t len>
|
||||
__host__ __device__
|
||||
inline void* memcpy_blocking(void *dst, const void *src) {
|
||||
// Does not handle the case of long datatypes
|
||||
char *d = reinterpret_cast<char *>(dst);
|
||||
const char *s = reinterpret_cast<const char *>(src);
|
||||
size_t i = 0;
|
||||
#pragma unroll
|
||||
for (i = 0; i < len; ++i) {
|
||||
d[i] = s[i];
|
||||
}
|
||||
return dst;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
||||
// nthrs = (32, 4)
|
||||
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
|
||||
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
|
||||
typename out_T, typename W_T>
|
||||
__global__ void
|
||||
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
|
||||
float scale) {
|
||||
size_t batch_idx = blockIdx.y;
|
||||
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
|
||||
if (idx < 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto block = cg::this_thread_block();
|
||||
size_t j = blockIdx.x;
|
||||
constexpr size_t num_pipeline_stages = 2;
|
||||
constexpr size_t tile_size = tx * ty * vec_size;
|
||||
__shared__ W_T W_shared[num_pipeline_stages * tile_size];
|
||||
__shared__ in_T X_shared[num_pipeline_stages * tile_size];
|
||||
__shared__ float y_warpwise[ty];
|
||||
|
||||
size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
|
||||
size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
|
||||
auto pipe = cuda::make_pipeline();
|
||||
|
||||
// pipeline load W/X and compute WX;
|
||||
pipe.producer_acquire();
|
||||
cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
W + (idx * feat_out + j) * feat_in +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
|
||||
cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
X + (batch_idx * feat_in) +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
|
||||
pipe.producer_commit();
|
||||
size_t copy_idx, compute_idx;
|
||||
float y = 0.f;
|
||||
vec_t<in_T, vec_size> x_vec;
|
||||
vec_t<W_T, vec_size> w_vec;
|
||||
size_t tile_idx;
|
||||
|
||||
#pragma unroll
|
||||
for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
|
||||
++tile_idx) {
|
||||
copy_idx = tile_idx % num_pipeline_stages;
|
||||
// pipeline stage: async copy W fragment
|
||||
pipe.producer_acquire();
|
||||
if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
|
||||
cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
W + (idx * feat_out + j) * feat_in +
|
||||
tile_idx * tile_size +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
|
||||
cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
X + (batch_idx * feat_in) + tile_idx * tile_size +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
|
||||
}
|
||||
pipe.producer_commit();
|
||||
|
||||
compute_idx = (tile_idx - 1) % num_pipeline_stages;
|
||||
// pipeline stage: compute WX
|
||||
pipe.consumer_wait();
|
||||
block.sync();
|
||||
x_vec.load(X_shared + X_shared_offset[compute_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
w_vec.load(W_shared + W_shared_offset[compute_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
float sum = 0.f;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
|
||||
}
|
||||
#pragma unroll
|
||||
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
|
||||
sum += __shfl_down_sync(0xffffffff, sum, offset);
|
||||
}
|
||||
y_warpwise[threadIdx.y] = sum;
|
||||
block.sync();
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < ty; ++i) {
|
||||
y += y_warpwise[i];
|
||||
}
|
||||
|
||||
block.sync();
|
||||
pipe.consumer_release();
|
||||
}
|
||||
|
||||
compute_idx = (tile_idx - 1) % num_pipeline_stages;
|
||||
// final pipeline stage
|
||||
pipe.consumer_wait();
|
||||
block.sync();
|
||||
x_vec.load(X_shared + X_shared_offset[compute_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
w_vec.load(W_shared + W_shared_offset[compute_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
float sum = 0.f;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
|
||||
}
|
||||
#pragma unroll
|
||||
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
|
||||
sum += __shfl_down_sync(0xffffffff, sum, offset);
|
||||
}
|
||||
y_warpwise[threadIdx.y] =
|
||||
((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
|
||||
? sum
|
||||
: 0.f;
|
||||
block.sync();
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < ty; ++i) {
|
||||
y += y_warpwise[i];
|
||||
}
|
||||
|
||||
block.sync();
|
||||
pipe.consumer_release();
|
||||
|
||||
// write Y;
|
||||
if (block.thread_rank() == 0) {
|
||||
Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y);
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
|
||||
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
|
||||
typename out_T, typename W_T>
|
||||
__global__ void
|
||||
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
|
||||
float scale) {
|
||||
size_t batch_idx = blockIdx.y;
|
||||
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
|
||||
if (idx < 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
size_t j = blockIdx.x;
|
||||
constexpr size_t tile_size = tx * ty * vec_size;
|
||||
constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size;
|
||||
__shared__ float y_warpwise[ty];
|
||||
|
||||
float y = 0;
|
||||
vec_t<in_T, vec_size> x_vec;
|
||||
vec_t<W_T, vec_size> w_vec;
|
||||
size_t tile_idx;
|
||||
|
||||
#pragma unroll
|
||||
for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
|
||||
if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
|
||||
x_vec.load(X + (batch_idx * feat_in) +
|
||||
tile_idx * tile_size +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
w_vec.load(W + (idx * feat_out + j) * feat_in +
|
||||
tile_idx * tile_size +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
}
|
||||
|
||||
float sum = 0.f;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
|
||||
}
|
||||
#pragma unroll
|
||||
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
|
||||
sum += VLLM_SHFL_DOWN_SYNC(sum, offset);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
|
||||
y += sum;
|
||||
}
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
y_warpwise[threadIdx.y] = y;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float y_write = 0.f;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < ty; ++i) {
|
||||
y_write += y_warpwise[i];
|
||||
}
|
||||
|
||||
// write Y;
|
||||
if (threadIdx.x == 0 && threadIdx.y == 0) {
|
||||
size_t y_idx = batch_idx * full_y_size + y_offset + j;
|
||||
Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(y_write));
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// nthrs = (2, 16, 4)
|
||||
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
|
||||
typename in_T, typename out_T, typename W_T>
|
||||
__global__ void
|
||||
bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
|
||||
float scale) {
|
||||
size_t batch_idx = blockIdx.y;
|
||||
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
|
||||
|
||||
if (idx < 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto block = cg::this_thread_block();
|
||||
size_t tile_idx = blockIdx.x;
|
||||
|
||||
// load X;
|
||||
vec_t<in_T, vec_size> x_vec;
|
||||
x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);
|
||||
|
||||
// load W;
|
||||
vec_t<W_T, vec_size> w_vec;
|
||||
w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in +
|
||||
block.thread_rank() * vec_size);
|
||||
|
||||
float sum = 0.f;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
#ifndef USE_ROCM
|
||||
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
|
||||
#else
|
||||
sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
|
||||
#endif
|
||||
}
|
||||
|
||||
cg::thread_block_tile g = cg::tiled_partition<tx>(block);
|
||||
#pragma unroll
|
||||
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
|
||||
sum += g.shfl_down(sum, offset);
|
||||
}
|
||||
sum = g.shfl(sum, 0);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
#ifndef USE_ROCM
|
||||
Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
|
||||
threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
|
||||
#else
|
||||
size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
|
||||
threadIdx.z * ty + threadIdx.y;
|
||||
Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(sum));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
template <int feat_in, int feat_out, typename in_T, typename out_T,
|
||||
typename W_T>
|
||||
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
|
||||
int64_t layer_idx, float scale) {
|
||||
constexpr size_t vec_size = 8;
|
||||
constexpr int tz = 4;
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
if constexpr (feat_in <= feat_out) {
|
||||
static_assert(feat_in % vec_size == 0);
|
||||
constexpr int tx = feat_in / vec_size;
|
||||
|
||||
static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
|
||||
(16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
|
||||
(8 % tx == 0 && feat_out % (8 / tx * tz) == 0));
|
||||
|
||||
if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
|
||||
constexpr int ty = 32 / tx;
|
||||
dim3 nblks(feat_out / (ty * tz), batch_size);
|
||||
dim3 nthrs(tx, ty, tz);
|
||||
|
||||
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
} else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
|
||||
constexpr int ty = 16 / tx;
|
||||
dim3 nblks(feat_out / (ty * tz), batch_size);
|
||||
dim3 nthrs(tx, ty, tz);
|
||||
|
||||
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
} else {
|
||||
constexpr int ty = 8 / tx;
|
||||
dim3 nblks(feat_out / (ty * tz), batch_size);
|
||||
dim3 nthrs(tx, ty, tz);
|
||||
|
||||
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
}
|
||||
} else {
|
||||
#ifndef USE_ROCM
|
||||
static_assert(feat_in % (vec_size * 32) == 0 ||
|
||||
feat_in % (vec_size * 16) == 0 ||
|
||||
feat_in % (vec_size * 8) == 0);
|
||||
|
||||
if constexpr (feat_in % (vec_size * 32) == 0) {
|
||||
constexpr int tx = 32;
|
||||
constexpr int ty = 4;
|
||||
|
||||
dim3 nblks(feat_out, batch_size);
|
||||
dim3 nthrs(tx, ty);
|
||||
|
||||
bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T),
|
||||
vec_size * sizeof(W_T), tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
} else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
|
||||
constexpr int tx = 32;
|
||||
constexpr int ty = 4;
|
||||
|
||||
dim3 nblks(feat_out, batch_size);
|
||||
dim3 nthrs(tx, ty);
|
||||
|
||||
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
|
||||
vec_size * sizeof(in_T) / 2,
|
||||
vec_size * sizeof(W_T) / 2, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
} else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
|
||||
constexpr int tx = 16;
|
||||
constexpr int ty = 4;
|
||||
|
||||
dim3 nblks(feat_out, batch_size);
|
||||
dim3 nthrs(tx, ty);
|
||||
|
||||
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
|
||||
vec_size * sizeof(in_T) / 2,
|
||||
vec_size * sizeof(W_T) / 2, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
}
|
||||
#else
|
||||
constexpr size_t rocm_warp_size = warpSize;
|
||||
|
||||
#define CHECK_INPUT_TILEABLE_BY(vec_size_) \
|
||||
feat_in % (rocm_warp_size * vec_size_) == 0
|
||||
|
||||
#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_) \
|
||||
if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) { \
|
||||
constexpr size_t vec_size_shrink = vec_size_; \
|
||||
constexpr int tx = tx_; \
|
||||
constexpr int ty = ty_; \
|
||||
dim3 nblks(feat_out, batch_size); \
|
||||
dim3 nthrs(tx, ty); \
|
||||
bgmv_shrink_kernel<feat_in, feat_out, vec_size_shrink, \
|
||||
vec_size_shrink * sizeof(in_T), \
|
||||
vec_size_shrink * sizeof(W_T), \
|
||||
tx, ty, tz> \
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset, \
|
||||
full_y_size, num_layers, layer_idx, \
|
||||
scale); \
|
||||
}
|
||||
|
||||
static_assert(CHECK_INPUT_TILEABLE_BY(32) ||
|
||||
CHECK_INPUT_TILEABLE_BY(16) ||
|
||||
CHECK_INPUT_TILEABLE_BY( 8) ||
|
||||
CHECK_INPUT_TILEABLE_BY( 4) ||
|
||||
CHECK_INPUT_TILEABLE_BY( 2) ||
|
||||
CHECK_INPUT_TILEABLE_BY( 1));
|
||||
|
||||
LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size)
|
||||
else
|
||||
LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size)
|
||||
else
|
||||
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size, 8/vec_size)
|
||||
else
|
||||
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4)
|
||||
else
|
||||
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2)
|
||||
else
|
||||
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1)
|
||||
|
||||
#undef CHECK_INPUT_TILEABLE_BY
|
||||
#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \
|
||||
template void bgmv_kernel<feat_in, feat_out>( \
|
||||
out_T * __restrict__ Y, const in_T *__restrict__ X, \
|
||||
const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \
|
||||
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
|
||||
int64_t num_layers, int64_t layer_idx, float scale);
|
||||
|
||||
#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \
|
||||
INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)
|
||||
|
||||
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
|
||||
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
|
||||
INST_BGMV(wide, narrow, in_T, out_T, W_T)
|
|
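The host launcher bgmv_kernel above hard-codes vec_size = 8 and tz = 4 and then picks the widest warp fraction that divides the problem. A minimal host-side sketch (illustration only, not part of this commit) of that selection for the feat_in <= feat_out "expand" branch:

// Assumes vec_size = 8 and tz = 4 as in the launcher; the real code does this
// at compile time with constexpr template parameters and a static_assert that
// guarantees the final 8/tx fallback is valid.
#include <cstdio>

struct LaunchShape { int tx, ty, tz; };

constexpr LaunchShape expand_shape(int feat_in, int feat_out,
                                   int vec_size = 8, int tz = 4) {
  const int tx = feat_in / vec_size;  // lanes cooperating over feat_in
  if (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) return {tx, 32 / tx, tz};
  if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) return {tx, 16 / tx, tz};
  return {tx, 8 / tx, tz};
}

int main() {
  // e.g. a rank-16 LoRA expanding into a 4096-wide layer
  constexpr LaunchShape s = expand_shape(/*feat_in=*/16, /*feat_out=*/4096);
  std::printf("nthrs = (%d, %d, %d)\n", s.tx, s.ty, s.tz);  // prints (2, 16, 4)
}

For this example the result is (2, 16, 4), matching the "// nthrs = (2, 16, 4)" comment on bgmv_expand_kernel above.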
@@ -1,48 +0,0 @@
DTYPES = ["fp16", "bf16", "fp32"]
DTYPE_MAP = {
    "fp16": "nv_half",
    "bf16": "nv_bfloat16",
    "fp32": "float",
}

TEMPLATE = """
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip()  # noqa: E501

for input_dtype in DTYPES:
    for output_dtype in DTYPES:
        for weight_dtype in DTYPES:
            if weight_dtype == "fp32":
                # FP32 weights are not supported.
                continue
            if output_dtype == "fp32":
                # LoRA A matrix.
                if input_dtype != weight_dtype:
                    # NOTE(woosuk): While Punica supports the case where the
                    # input and weight dtypes are different, we only generate
                    # the kernels with the same dtypes to reduce the binary size.
                    continue
            elif input_dtype == "fp32":
                # LoRA B matrix.
                if output_dtype != weight_dtype:
                    # NOTE(woosuk): While Punica supports the case where the
                    # output and weight dtypes are different, we only generate
                    # the kernels with the same dtypes to reduce the binary size.
                    continue
            elif not (input_dtype == output_dtype == weight_dtype):
                # NOTE(woosuk): While Punica supports mixed data types for
                # input, output, and weight, we only generate the kernels with
                # the same data types to reduce the binary size.
                continue

            kernel_definition = TEMPLATE.format(
                input_dtype=DTYPE_MAP[input_dtype],
                output_dtype=DTYPE_MAP[output_dtype],
                weight_dtype=DTYPE_MAP[weight_dtype])
            filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
            with open(filename, "w") as f:
                f.write(kernel_definition)
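For example, the (fp16, fp16, fp16) combination writes bgmv_fp16_fp16_fp16.cu, whose generated contents match the instantiation file removed earlier in this diff:

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half)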
File diff suppressed because it is too large
@ -1,569 +0,0 @@
|
|||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cstdint>
|
||||
|
||||
#include "type_convert.h"
|
||||
#include "../cuda_compat.h"
|
||||
#include "bgmv/bgmv_config.h"
|
||||
|
||||
|
||||
//====== utils ======
|
||||
|
||||
inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
|
||||
const char *a_name, const char *b_name) {
|
||||
TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ",
|
||||
a.dim(), " vs ", b.dim());
|
||||
for (int i = 0; i < a.dim(); ++i) {
|
||||
TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name,
|
||||
".size(", i, ")");
|
||||
}
|
||||
}
|
||||
|
||||
inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
|
||||
return (uint64_t(a) << 32) | uint64_t(b);
|
||||
}
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
||||
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
||||
#define CHECK_DIM(d, x) \
|
||||
TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
|
||||
|
||||
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
|
||||
|
||||
#define CHECK_EQ(a, b) \
|
||||
TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
|
||||
|
||||
//====== bgmv ======
|
||||
|
||||
template <typename in_T, typename out_T, typename W_T>
|
||||
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
|
||||
const int64_t *lora_indices,
|
||||
uint32_t in_features, uint32_t out_features,
|
||||
int64_t y_offset, int64_t full_y_size,
|
||||
int64_t batch_size, int64_t num_layers,
|
||||
int64_t layer_idx, float scale) {
|
||||
// NOTE(woosuk): While Punica supports various combinations of input/output
|
||||
// data types, we limit the supported data types to reduce the binary size.
|
||||
constexpr bool is_input_float = std::is_same<in_T, float>::value;
|
||||
constexpr bool is_output_float = std::is_same<out_T, float>::value;
|
||||
if (is_input_float) {
|
||||
if (!std::is_same<out_T, W_T>::value) {
|
||||
return false;
|
||||
}
|
||||
} else if (is_output_float) {
|
||||
if (!std::is_same<in_T, W_T>::value) {
|
||||
return false;
|
||||
}
|
||||
} else if (!(std::is_same<in_T, W_T>::value &&
|
||||
std::is_same<out_T, W_T>::value)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (pack_u32(in_features, out_features)) {
|
||||
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
|
||||
case pack_u32(feat_in, feat_out): \
|
||||
bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
|
||||
full_y_size, batch_size, num_layers, \
|
||||
layer_idx, scale); \
|
||||
break;
|
||||
#define CASE(_in_T, _out_T, _W_T, narrow, wide) \
|
||||
CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \
|
||||
CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
|
||||
FOR_INST_BGMV_WIDE_NARROW(CASE_ONESIDE, _, _, _)
|
||||
#undef CASE
|
||||
#undef CASE_ONESIDE
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
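A minimal sketch (illustration only, with hypothetical shapes) of the pack_u32 trick used by launch_bgmv_kernel above: the (in_features, out_features) pair is fused into a single 64-bit switch key, and each case forwards to a bgmv_kernel<feat_in, feat_out> specialization that bgmv_config.h pre-instantiated; the real CASE/CASE_ONESIDE macros generate the full list.

#include <cstdint>
#include <cstdio>

inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
  return (uint64_t(a) << 32) | uint64_t(b);
}

void dispatch_demo(uint32_t in_features, uint32_t out_features) {
  switch (pack_u32(in_features, out_features)) {
    case pack_u32(16, 4096):  // would call bgmv_kernel<16, 4096>(...): expand
      std::puts("rank-16 LoRA B: expand into a 4096-wide layer");
      break;
    case pack_u32(4096, 16):  // would call bgmv_kernel<4096, 16>(...): shrink
      std::puts("rank-16 LoRA A: shrink a 4096-wide activation to rank 16");
      break;
    default:  // unsupported shape: the real function returns false
      std::puts("unsupported (in_features, out_features) pair");
      break;
  }
}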
|
||||
|
||||
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||
torch::Tensor indicies, int64_t layer_idx, double scale) {
|
||||
CHECK_INPUT(y);
|
||||
CHECK_INPUT(x);
|
||||
CHECK_INPUT(w);
|
||||
CHECK_INPUT(indicies);
|
||||
|
||||
CHECK_DIM(2, y);
|
||||
CHECK_DIM(2, x);
|
||||
CHECK_DIM(4, w);
|
||||
CHECK_DIM(1, indicies);
|
||||
|
||||
int64_t B = x.size(0);
|
||||
int64_t h_in = x.size(1);
|
||||
int64_t h_out = y.size(1);
|
||||
int64_t num_layers = w.size(1);
|
||||
CHECK_EQ(w.size(3), h_in);
|
||||
CHECK_EQ(w.size(2), h_out);
|
||||
CHECK_EQ(indicies.size(0), x.size(0));
|
||||
CHECK_EQ(y.size(0), x.size(0));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
bool ok = false;
|
||||
if (h_in <= 128512 && h_out <= 128512) {
|
||||
// TODO: See if we can get rid of this massive nested switch
|
||||
switch (x.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
|
||||
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
|
||||
}
|
||||
|
||||
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||
torch::Tensor indicies, int64_t layer_idx,
|
||||
double scale, int64_t h_in, int64_t h_out,
|
||||
int64_t y_offset) {
|
||||
CHECK_INPUT(y);
|
||||
CHECK_INPUT(x);
|
||||
CHECK_INPUT(w);
|
||||
CHECK_INPUT(indicies);
|
||||
|
||||
CHECK_DIM(2, y);
|
||||
CHECK_DIM(2, x);
|
||||
CHECK_DIM(4, w);
|
||||
CHECK_DIM(1, indicies);
|
||||
|
||||
int64_t B = x.size(0);
|
||||
int64_t num_layers = w.size(1);
|
||||
int64_t full_y_size = y.size(1);
|
||||
CHECK_EQ(w.size(3), h_in);
|
||||
CHECK_EQ(w.size(2), h_out);
|
||||
CHECK_EQ(indicies.size(0), x.size(0));
|
||||
CHECK_EQ(y.size(0), x.size(0));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
bool ok = false;
|
||||
if (h_in <= 128512 && h_out <= 128512) {
|
||||
// TODO: See if we can get rid of this massive nested switch
|
||||
switch (x.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
|
||||
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
|
||||
}
|
|
@ -1,11 +0,0 @@
|
|||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||
torch::Tensor indicies, int64_t layer_idx, double scale);
|
||||
|
||||
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||
torch::Tensor indicies, int64_t layer_idx,
|
||||
double scale, int64_t h_in, int64_t h_out,
|
||||
int64_t y_offset);
|
|
@ -1,18 +0,0 @@
|
|||
#include "registration.h"
|
||||
#include "punica_ops.h"
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int "
|
||||
"layer_idx, float scale) -> ()");
|
||||
m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv);
|
||||
|
||||
m.def(
|
||||
"dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w,"
|
||||
"Tensor indicies, int layer_idx,"
|
||||
"float scale, int h_in, int h_out,"
|
||||
"int y_offset) -> ()");
|
||||
m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
|
@ -1,82 +0,0 @@
|
|||
#ifndef CSRC__PUNICA__TYPE_CONVERT_H__
|
||||
#define CSRC__PUNICA__TYPE_CONVERT_H__
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#else
|
||||
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__
|
||||
|
||||
typedef __half nv_half;
|
||||
typedef __hip_bfloat16 nv_bfloat16;
|
||||
typedef __hip_bfloat162 nv_bfloat162;
|
||||
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) {
|
||||
return __hip_bfloat162{val, val};
|
||||
}
|
||||
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) {
|
||||
return __hip_bfloat162{vall, valr};
|
||||
}
|
||||
|
||||
template <typename T_src, typename T_dst>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline T_dst convert_type(T_src val) {
|
||||
return static_cast<T_dst>(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline float convert_type<__half, float>(__half val) {
|
||||
return __half2float(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline __half convert_type<float, __half>(float val) {
|
||||
return __float2half(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) {
|
||||
return __bfloat162float(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline __hip_bfloat16 convert_type<float, __hip_bfloat16>(float val) {
|
||||
return __float2bfloat16(val);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline T vllm_add(T a, T b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline __half vllm_add<__half>(__half a, __half b) {
|
||||
return __hadd(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) {
|
||||
return __hadd(a, b);
|
||||
}
|
||||
|
||||
#undef __TYPE_CONVERT__HOST_DEVICE__
|
||||
|
||||
#endif // USE_ROCM
|
||||
|
||||
#endif // CSRC__PUNICA__TYPE_CONVERT_H__
|
|
@ -66,7 +66,6 @@ You can also build and install vLLM from source:
|
|||
|
||||
$ git clone https://github.com/vllm-project/vllm.git
|
||||
$ cd vllm
|
||||
$ # export VLLM_INSTALL_PUNICA_KERNELS=1 # optionally build for multi-LoRA capability
|
||||
$ pip install -e . # This may take 5-10 minutes.
|
||||
|
||||
.. tip::
|
||||
|
|
10
setup.py
10
setup.py
|
@ -181,9 +181,6 @@ class cmake_build_ext(build_ext):
|
|||
# match.
|
||||
cmake_args += ['-DVLLM_PYTHON_EXECUTABLE={}'.format(sys.executable)]
|
||||
|
||||
if _install_punica():
|
||||
cmake_args += ['-DVLLM_INSTALL_PUNICA_KERNELS=ON']
|
||||
|
||||
#
|
||||
# Setup parallelism and build tool
|
||||
#
|
||||
|
@ -274,10 +271,6 @@ def _build_custom_ops() -> bool:
|
|||
return _is_cuda() or _is_hip() or _is_cpu()
|
||||
|
||||
|
||||
def _install_punica() -> bool:
|
||||
return envs.VLLM_INSTALL_PUNICA_KERNELS
|
||||
|
||||
|
||||
def get_hipcc_rocm_version():
|
||||
# Run the hipcc --version command
|
||||
result = subprocess.run(['hipcc', '--version'],
|
||||
|
@ -446,9 +439,6 @@ if _is_cuda() or _is_hip():
|
|||
if _build_custom_ops():
|
||||
ext_modules.append(CMakeExtension(name="vllm._C"))
|
||||
|
||||
if _install_punica():
|
||||
ext_modules.append(CMakeExtension(name="vllm._punica_C"))
|
||||
|
||||
package_data = {
|
||||
"vllm": ["py.typed", "model_executor/layers/fused_moe/configs/*.json"]
|
||||
}
|
||||
|
|
|
@ -1,14 +1,17 @@
|
|||
import gc
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.model_executor.layers.ops.sample import (_uniform_to_exponential,
|
||||
from vllm.model_executor.layers.ops.sample import (_sample_triton,
|
||||
_uniform_to_exponential,
|
||||
sample)
|
||||
from vllm.model_executor.sampling_metadata import SamplingTensors
|
||||
from vllm.model_executor.utils import set_random_seed
|
||||
from vllm.triton_utils.libentry import LibEntry
|
||||
from vllm.triton_utils.sample import (MAX_TRITON_N_COLS,
|
||||
get_num_triton_sampler_splits)
|
||||
|
||||
|
@ -76,15 +79,20 @@ def test_sample_decoding_only(random_sampling, max_best_of,
|
|||
seeds = torch.randint(1,
|
||||
torch.iinfo(torch.long).max, (n_splits, bs),
|
||||
device="cuda").mul_(random_sampling_mask)
|
||||
sampled_tokens, sampled_logprobs, sampled_modified_probs = sample(
|
||||
probs=probs,
|
||||
logprobs=logprobs,
|
||||
sample_indices=sample_indices,
|
||||
seeds=seeds,
|
||||
max_best_of=max_best_of,
|
||||
modify_greedy_probs=modify_greedy_probs,
|
||||
save_logprobs=save_logprobs,
|
||||
_save_modified_probs=True)
|
||||
# The current _sample_triton kernel is not decorated with libentry;
# this patch wraps it with LibEntry to test the correctness of libentry.
|
||||
with patch("vllm.model_executor.layers.ops.sample._sample_triton",
|
||||
LibEntry(_sample_triton)):
|
||||
sampled_tokens, sampled_logprobs, sampled_modified_probs = sample(
|
||||
probs=probs,
|
||||
logprobs=logprobs,
|
||||
sample_indices=sample_indices,
|
||||
seeds=seeds,
|
||||
max_best_of=max_best_of,
|
||||
modify_greedy_probs=modify_greedy_probs,
|
||||
save_logprobs=save_logprobs,
|
||||
_save_modified_probs=True)
|
||||
assert sampled_tokens.shape == (bs, max_best_of)
|
||||
for i in range(bs):
|
||||
assert torch.all(sampled_tokens[i] == i * (vocab_size // bs))
|
||||
|
@ -130,6 +138,7 @@ def test_sample_decoding_only(random_sampling, max_best_of,
|
|||
[SINGLE_SPLIT_VOCAB_SIZE, MULTI_SPLIT_VOCAB_SIZE])
|
||||
def test_sample_prompt_logprobs(random_sampling, max_best_of,
|
||||
modify_greedy_probs, seed, vocab_size):
|
||||
|
||||
set_random_seed(seed)
|
||||
prompt_sizes = [16, 32, 64, 128] * 2
|
||||
samples = 8
|
||||
|
@ -157,14 +166,17 @@ def test_sample_prompt_logprobs(random_sampling, max_best_of,
|
|||
seeds = torch.randint(1,
|
||||
torch.iinfo(torch.long).max, (n_splits, samples),
|
||||
device="cuda").mul_(random_sampling_mask)
|
||||
sampled_tokens, sampled_logprobs, _ = sample(
|
||||
probs=probs,
|
||||
logprobs=logprobs,
|
||||
sample_indices=sample_indices,
|
||||
seeds=seeds,
|
||||
max_best_of=max_best_of,
|
||||
modify_greedy_probs=modify_greedy_probs,
|
||||
save_logprobs=True)
|
||||
# Ditto: wrap _sample_triton with LibEntry to test libentry correctness.
|
||||
with patch("vllm.model_executor.layers.ops.sample._sample_triton",
|
||||
LibEntry(_sample_triton)):
|
||||
sampled_tokens, sampled_logprobs, _ = sample(
|
||||
probs=probs,
|
||||
logprobs=logprobs,
|
||||
sample_indices=sample_indices,
|
||||
seeds=seeds,
|
||||
max_best_of=max_best_of,
|
||||
modify_greedy_probs=modify_greedy_probs,
|
||||
save_logprobs=True)
|
||||
assert sampled_tokens.shape == (samples, max_best_of)
|
||||
assert sampled_logprobs.shape == (samples, max_best_of)
|
||||
for i, t in enumerate(sample_indices):
|
||||
|
|
|
@ -37,7 +37,7 @@ def test_gemma_lora(gemma_lora_files):
|
|||
expected_lora_output = [
|
||||
"more important than knowledge.\nAuthor: Albert Einstein\n",
|
||||
"everyone else is already taken.\nAuthor: Oscar Wilde\n",
|
||||
"so little time\nAuthor: Frank Zappa\n",
|
||||
"so little time.\nAuthor: Frank Zappa\n",
|
||||
]
|
||||
|
||||
output1 = do_sample(llm, gemma_lora_files, lora_id=1)
|
||||
|
|
|
@ -26,7 +26,8 @@ from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA,
|
|||
VocabParallelEmbeddingWithLoRA)
|
||||
# yapf: enable
|
||||
from vllm.lora.models import (LongContextLoRAContext, LoRALayerWeights,
|
||||
PackedLoRALayerWeights, convert_mapping)
|
||||
PackedLoRALayerWeights)
|
||||
from vllm.lora.punica import PunicaWrapper
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
|
@ -47,6 +48,9 @@ TOLERANCES = {
|
|||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
# Different triton kernels are launched for the prefill and decode stages,
# so both must be verified: prefill stage (True) and decode stage (False).
STAGES = [True, False]
|
||||
|
||||
|
||||
def get_random_id_to_index(num_loras: int,
|
||||
|
@ -182,10 +186,12 @@ def create_random_inputs(
|
|||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
|
||||
def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
|
||||
@pytest.mark.parametrize("stage", STAGES)
|
||||
def test_embeddings(dist_init, num_loras, device, vocab_size, stage) -> None:
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
punica_wrapper = PunicaWrapper(8192, 256, device)
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
lora_dtype=torch.float16)
|
||||
|
@ -204,7 +210,7 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
|
|||
|
||||
id_to_index = get_random_id_to_index(num_loras, max_loras)
|
||||
embedding, lora_embedding = create_random_embedding_layer()
|
||||
|
||||
lora_embedding.set_mapping(punica_wrapper)
|
||||
lora_dict, _ = populate_loras(
|
||||
id_to_index,
|
||||
layer=lora_embedding,
|
||||
|
@ -217,12 +223,12 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
|
|||
input_size=(200, ),
|
||||
input_range=(1, vocab_size),
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
|
||||
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||
lora_mapping = LoRAMapping(index_mapping,
|
||||
prompt_mapping,
|
||||
is_prefill=stage)
|
||||
punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size)
|
||||
lora_embedding.set_mapping(*mapping_info)
|
||||
|
||||
lora_result = lora_embedding(torch.cat(inputs))
|
||||
|
||||
|
@ -255,12 +261,12 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
|
|||
input_size=(200, ),
|
||||
input_range=(1, vocab_size),
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
|
||||
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||
lora_mapping = LoRAMapping(index_mapping,
|
||||
prompt_mapping,
|
||||
is_prefill=stage)
|
||||
punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size)
|
||||
lora_embedding.set_mapping(*mapping_info, )
|
||||
|
||||
lora_result = lora_embedding(torch.cat(inputs))
|
||||
expected_result = embedding(torch.cat(inputs))
|
||||
|
@ -278,11 +284,13 @@ def test_embeddings(dist_init, num_loras, device, vocab_size) -> None:
|
|||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
|
||||
@pytest.mark.parametrize("stage", STAGES)
|
||||
def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
|
||||
vocab_size) -> None:
|
||||
vocab_size, stage) -> None:
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
punica_wrapper = PunicaWrapper(8192, 256, device)
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
lora_dtype=torch.float16)
|
||||
|
@ -318,6 +326,7 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
|
|||
generate_embeddings_tensor=256,
|
||||
)
|
||||
|
||||
lora_embedding.set_mapping(punica_wrapper)
|
||||
# All embeddings tensors have the same shape.
|
||||
embeddings_tensors = [
|
||||
lora_dict[id].embeddings_tensor for id in sorted(lora_dict.keys())
|
||||
|
@ -334,8 +343,12 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
|
|||
input_size=(200, ),
|
||||
input_range=(1, vocab_size),
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
|
||||
lora_mapping = LoRAMapping(index_mapping,
|
||||
prompt_mapping,
|
||||
is_prefill=stage)
|
||||
punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size)
|
||||
original_inputs = deepcopy(inputs)
|
||||
|
||||
# Force some of the inputs to be in the extended embeddings range
|
||||
|
@ -349,11 +362,6 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
|
|||
(embedding_id + 1) * embeddings_tensor_len - 1)
|
||||
original_input_[-2] = vocab_size + embeddings_tensor_len - 1
|
||||
|
||||
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size)
|
||||
lora_embedding.set_mapping(*mapping_info, )
|
||||
|
||||
expanded_embedding.weight[vocab_size:vocab_size +
|
||||
(embeddings_tensor_len *
|
||||
max_loras)] = torch.cat(embeddings_tensors)
|
||||
|
@ -390,15 +398,13 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
|
|||
input_size=(200, ),
|
||||
input_range=(1, vocab_size),
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
|
||||
original_inputs = deepcopy(inputs)
|
||||
|
||||
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||
lora_mapping = LoRAMapping(index_mapping,
|
||||
prompt_mapping,
|
||||
is_prefill=stage)
|
||||
punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size)
|
||||
lora_embedding.set_mapping(*mapping_info, )
|
||||
|
||||
lora_result = lora_embedding(torch.cat(original_inputs))
|
||||
expected_result = expanded_embedding(torch.cat(inputs))
|
||||
|
||||
|
@ -413,11 +419,13 @@ def test_embeddings_with_new_embeddings(dist_init, num_loras, device,
|
|||
@pytest.mark.parametrize("num_loras", [1, 2, 4, 8])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("vocab_size", [512, 32000, 64000, 128000])
|
||||
def test_lm_head_logits_processor(dist_init, num_loras, device,
|
||||
vocab_size) -> None:
|
||||
@pytest.mark.parametrize("stage", STAGES)
|
||||
def test_lm_head_logits_processor(dist_init, num_loras, device, vocab_size,
|
||||
stage) -> None:
|
||||
|
||||
torch.set_default_device(device)
|
||||
max_loras = 8
|
||||
punica_wrapper = PunicaWrapper(8192, 256, device)
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
lora_dtype=torch.float16)
|
||||
|
@ -443,7 +451,7 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
|
|||
|
||||
id_to_index = get_random_id_to_index(num_loras, max_loras)
|
||||
linear, logits_processor, lora_logits_processor = _pretest()
|
||||
|
||||
lora_logits_processor.set_mapping(punica_wrapper)
|
||||
# NOTE: all the generated loras share the same embeddings tensor.
|
||||
lora_dict, _ = populate_loras(
|
||||
id_to_index,
|
||||
|
@ -461,17 +469,17 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
|
|||
input_range=(0, 1),
|
||||
input_type=torch.float16,
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
|
||||
input_ = torch.rand(20, 1024)
|
||||
mapping_info = convert_mapping(
|
||||
lora_mapping = LoRAMapping(index_mapping,
|
||||
prompt_mapping,
|
||||
is_prefill=stage)
|
||||
punica_wrapper.update_metadata(
|
||||
lora_mapping,
|
||||
id_to_index,
|
||||
max_loras,
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size,
|
||||
)
|
||||
lora_logits_processor.set_mapping(*mapping_info, )
|
||||
input_ = torch.rand(20, 1024)
|
||||
|
||||
lora_result = lora_logits_processor._get_logits(
|
||||
hidden_states=torch.cat(inputs),
|
||||
|
@ -510,12 +518,16 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
|
|||
input_range=(0, 1),
|
||||
input_type=torch.float16,
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
|
||||
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size)
|
||||
lora_logits_processor.set_mapping(*mapping_info, )
|
||||
lora_mapping = LoRAMapping(index_mapping,
|
||||
prompt_mapping,
|
||||
is_prefill=stage)
|
||||
punica_wrapper.update_metadata(
|
||||
lora_mapping,
|
||||
id_to_index,
|
||||
max_loras,
|
||||
vocab_size,
|
||||
lora_config.lora_extra_vocab_size,
|
||||
)
|
||||
|
||||
lora_result = lora_logits_processor._get_logits(
|
||||
hidden_states=torch.cat(inputs),
|
||||
|
@ -538,10 +550,12 @@ def test_lm_head_logits_processor(dist_init, num_loras, device,
|
|||
@pytest.mark.parametrize("orientation", ["row", "column"])
|
||||
@pytest.mark.parametrize("fully_shard", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("stage", STAGES)
|
||||
def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
|
||||
device) -> None:
|
||||
device, stage) -> None:
|
||||
|
||||
torch.set_default_device(device)
|
||||
punica_wrapper = PunicaWrapper(8192, 256, device)
|
||||
max_loras = 8
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
|
@ -575,7 +589,7 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
|
|||
|
||||
id_to_index = get_random_id_to_index(num_loras, max_loras)
|
||||
linear, lora_linear = create_random_linear_parallel_layer()
|
||||
|
||||
lora_linear.set_mapping(punica_wrapper)
|
||||
lora_dict, _ = populate_loras(
|
||||
id_to_index,
|
||||
layer=lora_linear,
|
||||
|
@ -589,16 +603,16 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
|
|||
input_range=(0, 1),
|
||||
input_type=torch.float16,
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
|
||||
mapping_info = convert_mapping(
|
||||
lora_mapping = LoRAMapping(index_mapping,
|
||||
prompt_mapping,
|
||||
is_prefill=stage)
|
||||
punica_wrapper.update_metadata(
|
||||
lora_mapping,
|
||||
id_to_index,
|
||||
max_loras,
|
||||
512,
|
||||
lora_config.lora_extra_vocab_size,
|
||||
)
|
||||
lora_linear.set_mapping(*mapping_info, )
|
||||
|
||||
lora_result = lora_linear(torch.cat(inputs))[0]
|
||||
|
||||
|
@ -628,11 +642,12 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
|
|||
input_range=(0, 1),
|
||||
input_type=torch.float16,
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
lora_mapping = LoRAMapping(index_mapping,
|
||||
prompt_mapping,
|
||||
is_prefill=stage)
|
||||
|
||||
mapping_info = convert_mapping(lora_mapping, id_to_index, max_loras,
|
||||
punica_wrapper.update_metadata(lora_mapping, id_to_index, max_loras,
|
||||
512, lora_config.lora_extra_vocab_size)
|
||||
lora_linear.set_mapping(*mapping_info, )
|
||||
|
||||
lora_result = lora_linear(torch.cat(inputs))[0]
|
||||
expected_result = linear(torch.cat(inputs))[0]
|
||||
|
@ -649,10 +664,12 @@ def test_linear_parallel(dist_init, num_loras, orientation, fully_shard,
|
|||
@pytest.mark.parametrize("repeats", [1, 2, 3])
|
||||
@pytest.mark.parametrize("fully_shard", [True, False])
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@pytest.mark.parametrize("stage", STAGES)
|
||||
def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
|
||||
device) -> None:
|
||||
device, stage) -> None:
|
||||
|
||||
torch.set_default_device(device)
|
||||
punica_wrapper = PunicaWrapper(8192, 256, device)
|
||||
max_loras = 8
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
|
@ -707,7 +724,7 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
|
|||
id_to_index = get_random_id_to_index(num_loras, max_loras)
|
||||
|
||||
linear, lora_linear = create_column_parallel_packed_layer()
|
||||
|
||||
lora_linear.set_mapping(punica_wrapper)
|
||||
lora_dict, sublora_dict = populate_loras(
|
||||
id_to_index,
|
||||
layer=lora_linear,
|
||||
|
@ -722,16 +739,17 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
|
|||
input_range=(0, 1),
|
||||
input_type=torch.float16,
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
lora_mapping = LoRAMapping(index_mapping,
|
||||
prompt_mapping,
|
||||
is_prefill=stage)
|
||||
|
||||
mapping_info = convert_mapping(
|
||||
punica_wrapper.update_metadata(
|
||||
lora_mapping,
|
||||
id_to_index,
|
||||
max_loras,
|
||||
512,
|
||||
lora_config.lora_extra_vocab_size,
|
||||
)
|
||||
lora_linear.set_mapping(*mapping_info)
|
||||
|
||||
lora_result = lora_linear(torch.cat(inputs))[0]
|
||||
|
||||
|
@ -762,16 +780,18 @@ def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard,
|
|||
input_range=(0, 1),
|
||||
input_type=torch.float16,
|
||||
)
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
lora_mapping = LoRAMapping(index_mapping,
|
||||
prompt_mapping,
|
||||
is_prefill=stage)
|
||||
|
||||
mapping_info = convert_mapping(
|
||||
punica_wrapper.update_metadata(
|
||||
lora_mapping,
|
||||
id_to_index,
|
||||
max_loras,
|
||||
512,
|
||||
lora_config.lora_extra_vocab_size,
|
||||
)
|
||||
lora_linear.set_mapping(*mapping_info)
|
||||
# lora_linear.set_mapping(*mapping_info)
|
||||
|
||||
lora_result = lora_linear(torch.cat(inputs))[0]
|
||||
expected_result = linear(torch.cat(inputs))[0]
|
||||
|
@ -803,7 +823,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
|
|||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.set_default_device(device)
|
||||
|
||||
punica_wrapper = PunicaWrapper(8192, 256, device)
|
||||
max_loras = 8
|
||||
lora_config = LoRAConfig(max_loras=max_loras,
|
||||
max_lora_rank=8,
|
||||
|
@ -825,6 +845,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
|
|||
is_neox_style,
|
||||
)
|
||||
lora_rope = LinearScalingRotaryEmbeddingWithLora(rope)
|
||||
lora_rope.set_mapping(punica_wrapper)
|
||||
lora_rope.create_lora_weights(max_loras, lora_config)
|
||||
linear_rope = get_rope(head_size, rotary_dim, max_position, base,
|
||||
is_neox_style, {
|
||||
|
@ -840,6 +861,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
|
|||
input_range=(0, lora_config.lora_extra_vocab_size),
|
||||
input_type=torch.float16,
|
||||
)
|
||||
|
||||
lora_mapping = LoRAMapping(index_mapping, prompt_mapping)
|
||||
long_lora_context = LongContextLoRAContext(list(scaling_factors),
|
||||
rotary_dim)
|
||||
|
@ -854,7 +876,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
|
|||
for i in range(len(scaling_factors)):
|
||||
long_lora_context.offsets_by_lora_id[i] = scaling_factor_to_offset.get(
|
||||
scaling_factors[i], 0)
|
||||
mapping_info = convert_mapping(
|
||||
punica_wrapper.update_metadata(
|
||||
lora_mapping,
|
||||
id_to_index,
|
||||
max_loras,
|
||||
|
@ -862,7 +884,7 @@ def test_rotary_embedding_long_context(dist_init, num_loras, device,
|
|||
lora_config.lora_extra_vocab_size,
|
||||
long_lora_context=long_lora_context,
|
||||
)
|
||||
lora_rope.set_mapping(*mapping_info)
|
||||
# lora_rope.set_mapping(*mapping_info)
|
||||
|
||||
positions = torch.randint(0, max_position, (batch_size, seq_len))
|
||||
query = torch.randn(batch_size,
|
||||
|
|
|
@ -1,224 +0,0 @@
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.lora.layers import _apply_lora, _apply_lora_packed_nslice
|
||||
|
||||
from .utils import DummyLoRAManager
|
||||
|
||||
TENSOR_SIZES = [128, 1024, 2048, 4096, 8192, 11008, 11008 // 2, 11008 // 4]
|
||||
QKV_TENSOR_SIZES = [
|
||||
(8192, 1024, 1024),
|
||||
(8192 // 8, 1024 // 8, 1024 // 8),
|
||||
(4096, 4096, 4096),
|
||||
(4096 // 2, 4096 // 2, 4096 // 2),
|
||||
]
|
||||
BATCH_SIZES = [8, 32, 256]
|
||||
RANKS = [8]
|
||||
DTYPES = [torch.float16]
|
||||
TOLERANCES = {
|
||||
torch.float16: (5e-3, 5e-3),
|
||||
torch.bfloat16: (3e-2, 2e-2),
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", TENSOR_SIZES)
|
||||
@pytest.mark.parametrize("n", TENSOR_SIZES)
|
||||
@pytest.mark.parametrize("k", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("rank", RANKS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
def test_apply_lora(m, n, k, rank, dtype) -> None:
|
||||
manager = DummyLoRAManager()
|
||||
|
||||
module_name = "module"
|
||||
weight = torch.rand([m, n], device="cuda", dtype=dtype)
|
||||
|
||||
manager.init_random_lora(module_name, weight, rank=rank)
|
||||
lora = manager.get_module_lora(module_name)
|
||||
|
||||
input = torch.rand(k, n, device="cuda", dtype=dtype)
|
||||
expected = input @ lora.lora_a @ lora.lora_b * lora.scaling
|
||||
|
||||
lora_a_stack = torch.zeros(8,
|
||||
1,
|
||||
lora.lora_a.shape[1],
|
||||
lora.lora_a.shape[0],
|
||||
device="cuda",
|
||||
dtype=dtype)
|
||||
lora_b_stack = torch.zeros(8,
|
||||
1,
|
||||
lora.lora_b.shape[1],
|
||||
lora.lora_b.shape[0],
|
||||
device="cuda",
|
||||
dtype=dtype)
|
||||
for i in range(lora_a_stack.shape[0]):
|
||||
lora_a_stack[i][0] = lora.lora_a.T
|
||||
lora_b_stack[i][0] = (lora.lora_b * lora.scaling).T
|
||||
|
||||
output = torch.zeros(k, m, device="cuda", dtype=dtype)
|
||||
_apply_lora(
|
||||
input, lora_a_stack, lora_b_stack,
|
||||
torch.randint(0, lora_a_stack.shape[0], (len(input), ), device="cuda"),
|
||||
output)
|
||||
|
||||
rtol, atol = TOLERANCES[dtype]
|
||||
assert torch.allclose(expected, output, rtol=rtol, atol=atol)
|
||||
|
||||
output[:] = 0
|
||||
_apply_lora(input, lora_a_stack, lora_b_stack,
|
||||
torch.full((len(input), ), -1, device="cuda"), output)
|
||||
assert torch.allclose(torch.zeros_like(output), output)
|
||||
|
||||
manager.reset_lora()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("m", TENSOR_SIZES)
|
||||
@pytest.mark.parametrize("n", TENSOR_SIZES)
|
||||
@pytest.mark.parametrize("k", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("rank", RANKS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
def test_apply_lora_packed_2slice(m, n, k, rank, dtype) -> None:
|
||||
if m % 2 != 0:
|
||||
pytest.skip("m must be divisible by 2")
|
||||
if m // 2 not in TENSOR_SIZES:
|
||||
pytest.skip("m//2 must be in TENSOR_SIZES")
|
||||
|
||||
manager = DummyLoRAManager()
|
||||
|
||||
module_name = "module"
|
||||
weight = torch.rand([m // 2, n], device="cuda", dtype=dtype)
|
||||
|
||||
manager.init_random_lora(module_name + "1", weight, rank=rank)
|
||||
lora_1 = manager.get_module_lora(module_name + "1")
|
||||
manager.init_random_lora(module_name + "2", weight, rank=rank)
|
||||
lora_2 = manager.get_module_lora(module_name + "2")
|
||||
|
||||
input = torch.rand(k, n, device="cuda", dtype=dtype)
|
||||
expected = torch.cat([
|
||||
input @ lora_1.lora_a @ lora_1.lora_b * lora_1.scaling,
|
||||
input @ lora_2.lora_a @ lora_2.lora_b * lora_2.scaling
|
||||
],
|
||||
dim=1)
|
||||
|
||||
lora_a_stacks = [
|
||||
torch.zeros(8,
|
||||
1,
|
||||
lora_1.lora_a.shape[1],
|
||||
lora_1.lora_a.shape[0],
|
||||
device="cuda",
|
||||
dtype=dtype) for i in range(2)
|
||||
]
|
||||
lora_b_stacks = [
|
||||
torch.zeros(8,
|
||||
1,
|
||||
lora_1.lora_b.shape[1],
|
||||
lora_1.lora_b.shape[0],
|
||||
device="cuda",
|
||||
dtype=dtype) for i in range(2)
|
||||
]
|
||||
for i in range(lora_a_stacks[0].shape[0]):
|
||||
lora_a_stacks[0][i][0] = lora_1.lora_a.T
|
||||
lora_b_stacks[0][i][0] = (lora_1.lora_b * lora_1.scaling).T
|
||||
lora_a_stacks[1][i][0] = lora_2.lora_a.T
|
||||
lora_b_stacks[1][i][0] = (lora_2.lora_b * lora_2.scaling).T
|
||||
|
||||
output = torch.zeros(k, m, device="cuda", dtype=dtype)
|
||||
_apply_lora_packed_nslice(
|
||||
input, lora_a_stacks, lora_b_stacks,
|
||||
torch.randint(0,
|
||||
lora_a_stacks[0].shape[0], (len(input), ),
|
||||
device="cuda"), output, (m // 2, m // 2))
|
||||
|
||||
rtol, atol = TOLERANCES[dtype]
|
||||
assert torch.allclose(expected, output, rtol=rtol, atol=atol)
|
||||
|
||||
output[:] = 0
|
||||
_apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
|
||||
torch.full((len(input), ), -1, device="cuda"),
|
||||
output, (m // 2, m // 2))
|
||||
assert torch.allclose(torch.zeros_like(output), output)
|
||||
|
||||
manager.reset_lora()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("qkv", QKV_TENSOR_SIZES)
|
||||
@pytest.mark.parametrize("n", TENSOR_SIZES)
|
||||
@pytest.mark.parametrize("k", BATCH_SIZES)
|
||||
@pytest.mark.parametrize("rank", RANKS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
def test_apply_lora_packed_3slice(qkv, n, k, rank, dtype) -> None:
|
||||
manager = DummyLoRAManager()
|
||||
|
||||
module_name = "module"
|
||||
weight_q = torch.empty(qkv[0], n, device="cuda", dtype=dtype)
|
||||
weight_kv = torch.empty(qkv[1], n, device="cuda", dtype=dtype)
|
||||
|
||||
manager.init_random_lora(module_name + "q", weight_q, rank=rank)
|
||||
lora_q = manager.get_module_lora(module_name + "q")
|
||||
manager.init_random_lora(module_name + "k", weight_kv, rank=rank)
|
||||
lora_k = manager.get_module_lora(module_name + "k")
|
||||
manager.init_random_lora(module_name + "v", weight_kv, rank=rank)
|
||||
lora_v = manager.get_module_lora(module_name + "v")
|
||||
|
||||
input = torch.rand(k, n, device="cuda", dtype=dtype)
|
||||
expected = torch.cat([
|
||||
input @ lora_q.lora_a @ lora_q.lora_b * lora_q.scaling,
|
||||
input @ lora_k.lora_a @ lora_k.lora_b * lora_k.scaling,
|
||||
input @ lora_v.lora_a @ lora_v.lora_b * lora_v.scaling
|
||||
],
|
||||
dim=1)
|
||||
|
||||
lora_a_stacks = [
|
||||
torch.zeros(8,
|
||||
1,
|
||||
lora_q.lora_a.shape[1],
|
||||
lora_q.lora_a.shape[0],
|
||||
device="cuda",
|
||||
dtype=dtype)
|
||||
] + [
|
||||
torch.zeros(8,
|
||||
1,
|
||||
lora_k.lora_a.shape[1],
|
||||
lora_k.lora_a.shape[0],
|
||||
device="cuda",
|
||||
dtype=dtype) for i in range(2)
|
||||
]
|
||||
lora_b_stacks = [
|
||||
torch.zeros(8,
|
||||
1,
|
||||
lora_q.lora_b.shape[1],
|
||||
lora_q.lora_b.shape[0],
|
||||
device="cuda",
|
||||
dtype=dtype)
|
||||
] + [
|
||||
torch.zeros(8,
|
||||
1,
|
||||
lora_k.lora_b.shape[1],
|
||||
lora_k.lora_b.shape[0],
|
||||
device="cuda",
|
||||
dtype=dtype) for i in range(2)
|
||||
]
|
||||
for i in range(lora_a_stacks[0].shape[0]):
|
||||
lora_a_stacks[0][i][0] = lora_q.lora_a.T
|
||||
lora_b_stacks[0][i][0] = (lora_q.lora_b * lora_q.scaling).T
|
||||
lora_a_stacks[1][i][0] = lora_k.lora_a.T
|
||||
lora_b_stacks[1][i][0] = (lora_k.lora_b * lora_k.scaling).T
|
||||
lora_a_stacks[2][i][0] = lora_v.lora_a.T
|
||||
lora_b_stacks[2][i][0] = (lora_v.lora_b * lora_v.scaling).T
|
||||
|
||||
output = torch.zeros(k, sum(qkv), device="cuda", dtype=dtype)
|
||||
_apply_lora_packed_nslice(
|
||||
input, lora_a_stacks, lora_b_stacks,
|
||||
torch.randint(0,
|
||||
lora_a_stacks[0].shape[0], (len(input), ),
|
||||
device="cuda"), output, (qkv[0], qkv[1], qkv[2]))
|
||||
|
||||
rtol, atol = TOLERANCES[dtype]
|
||||
assert torch.allclose(expected, output, rtol=rtol, atol=atol)
|
||||
|
||||
output[:] = 0
|
||||
_apply_lora_packed_nslice(input, lora_a_stacks, lora_b_stacks,
|
||||
torch.full((len(input), ), -1, device="cuda"),
|
||||
output, (qkv[0], qkv[1], qkv[2]))
|
||||
assert torch.allclose(torch.zeros_like(output), output)
|
||||
|
||||
manager.reset_lora()
|
|
@ -1,258 +0,0 @@
|
|||
# Based on code from https://github.com/punica-ai/punica
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.lora.punica as punica
|
||||
|
||||
|
||||
def assert_close(a, b):
|
||||
rtol, atol = {
|
||||
torch.float16: (5e-3, 5e-3),
|
||||
torch.bfloat16: (3e-2, 2e-2),
|
||||
torch.float32: (None, None),
|
||||
}[a.dtype]
|
||||
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
def _lora_ref_impl(
|
||||
y_final: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
wa_T_all: torch.Tensor,
|
||||
wb_T_all: torch.Tensor,
|
||||
indicies: torch.LongTensor,
|
||||
layer_idx: int,
|
||||
scale: float,
|
||||
):
|
||||
y_stage_1 = torch.empty(
|
||||
(x.size(0), wa_T_all.size(-2)),
|
||||
dtype=torch.float32,
|
||||
device=x.device,
|
||||
)
|
||||
bs = x.shape[0]
|
||||
s = torch.tensor(scale, dtype=torch.float32, device=x.device)
|
||||
for i, lora_idx in zip(range(bs), indicies.cpu().tolist()):
|
||||
xi = x[i].unsqueeze(0).to(torch.float32)
|
||||
wa = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32)
|
||||
if wb_T_all is not None:
|
||||
wb = wb_T_all[lora_idx, layer_idx].transpose(-1,
|
||||
-2).to(torch.float32)
|
||||
|
||||
tmp = xi @ wa
|
||||
y_stage_1[i] = tmp.squeeze(0)
|
||||
y_final[i] += ((tmp @ wb).squeeze(0) *
|
||||
s if wb_T_all is not None else y_stage_1[i])
|
||||
return y_final, y_stage_1
|
||||
|
||||
|
||||
H1 = H2 = [
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
896,
|
||||
1024,
|
||||
1152,
|
||||
1216,
|
||||
1280,
|
||||
1536,
|
||||
1664,
|
||||
2048,
|
||||
2240,
|
||||
2304,
|
||||
2368,
|
||||
2432,
|
||||
2560,
|
||||
2752,
|
||||
3072,
|
||||
3328,
|
||||
3456,
|
||||
3584,
|
||||
3712,
|
||||
4096,
|
||||
4480,
|
||||
4608,
|
||||
4736,
|
||||
4864,
|
||||
5120,
|
||||
5504,
|
||||
5632,
|
||||
5888,
|
||||
6144,
|
||||
6400,
|
||||
6848,
|
||||
6912,
|
||||
7168,
|
||||
7424,
|
||||
8192,
|
||||
8960,
|
||||
9216,
|
||||
9472,
|
||||
10240,
|
||||
11008,
|
||||
11264,
|
||||
13824,
|
||||
14336,
|
||||
14784,
|
||||
14848,
|
||||
15360,
|
||||
18944,
|
||||
22016,
|
||||
22528,
|
||||
24576,
|
||||
27392,
|
||||
27648,
|
||||
29568,
|
||||
29696,
|
||||
32000,
|
||||
32256,
|
||||
32512,
|
||||
32768,
|
||||
33024,
|
||||
36864,
|
||||
43264,
|
||||
49152,
|
||||
49408,
|
||||
60544,
|
||||
60672,
|
||||
64000,
|
||||
64256,
|
||||
102400,
|
||||
102656,
|
||||
128000,
|
||||
128256,
|
||||
]
|
||||
H2 = [64] + H2
|
||||
R = [1, 2, 4]
|
||||
SEED = [0xabcdabcd987]
|
||||
CUDA_DEVICES = [
|
||||
f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2)
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
|
||||
@pytest.mark.parametrize("h1", H1)
|
||||
@pytest.mark.parametrize("r", R)
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@torch.inference_mode()
|
||||
def test_lora_a_extra_shapes(dtype_str, h1, r, seed):
|
||||
torch.manual_seed(seed)
|
||||
num_loras = 4
|
||||
num_layers = 1
|
||||
bs = 32
|
||||
dtype = getattr(torch, dtype_str)
|
||||
device = torch.device("cuda")
|
||||
|
||||
wa_T_all = torch.randn(num_loras,
|
||||
num_layers,
|
||||
r,
|
||||
h1,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device)
|
||||
|
||||
for layer_idx in range(num_layers):
|
||||
x = torch.randn(bs, h1, dtype=dtype, device=device)
|
||||
y = torch.randn(bs, r, dtype=dtype, device=device)
|
||||
|
||||
y_ref = y.clone()
|
||||
_lora_ref_impl(
|
||||
y_ref,
|
||||
x,
|
||||
wa_T_all,
|
||||
None,
|
||||
indices,
|
||||
layer_idx,
|
||||
1.0,
|
||||
)
|
||||
|
||||
y_our = y.clone()
|
||||
punica.bgmv(y_our, x, wa_T_all, indices, layer_idx, 1.0)
|
||||
|
||||
assert_close(y_ref, y_our)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
|
||||
@pytest.mark.parametrize("h1", H1)
|
||||
@pytest.mark.parametrize("h2", H2)
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_lora_correctness(dtype_str, h1, h2, seed, device):
|
||||
torch.manual_seed(seed)
|
||||
num_loras = 4
|
||||
num_layers = 1
|
||||
r = 8
|
||||
bs = 32
|
||||
scale = 0.123
|
||||
dtype = getattr(torch, dtype_str)
|
||||
torch.set_default_device(device)
|
||||
|
||||
wa_T_all = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
|
||||
wb_T_all = torch.randn(num_loras, num_layers, h2, r, dtype=dtype)
|
||||
indices = torch.randint(num_loras, (bs, ), dtype=torch.long)
|
||||
|
||||
for layer_idx in range(num_layers):
|
||||
x = torch.randn(bs, h1, dtype=dtype)
|
||||
y = torch.randn(bs, h2, dtype=dtype)
|
||||
|
||||
y_ref = y.clone()
|
||||
_lora_ref_impl(y_ref, x, wa_T_all, wb_T_all, indices, layer_idx, scale)
|
||||
|
||||
y_our = y.clone()
|
||||
punica.add_lora(y_our, x, wa_T_all, wb_T_all, indices, layer_idx,
|
||||
scale)
|
||||
|
||||
assert_close(y_ref, y_our)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"])
|
||||
@pytest.mark.parametrize("h1", H1)
|
||||
@pytest.mark.parametrize("h2", H2)
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_lora_correctness_slice(dtype_str, h1, h2, seed, device):
|
||||
if h2 % 3 != 0 or h2 // 3 not in H1:
|
||||
pytest.skip("h2 must be divisible by 3 and in supported shapes")
|
||||
torch.manual_seed(seed)
|
||||
num_loras = 4
|
||||
num_layers = 1
|
||||
r = 8
|
||||
bs = 32
|
||||
scale = 0.123
|
||||
dtype = getattr(torch, dtype_str)
|
||||
torch.set_default_device(device)
|
||||
|
||||
wa_T_all_0 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
|
||||
wa_T_all_1 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
|
||||
wa_T_all_2 = torch.randn(num_loras, num_layers, r, h1, dtype=dtype)
|
||||
wb_T_all_0 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype)
|
||||
wb_T_all_1 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype)
|
||||
wb_T_all_2 = torch.randn(num_loras, num_layers, h2 // 3, r, dtype=dtype)
|
||||
|
||||
indices = torch.randint(num_loras, (bs, ), dtype=torch.long)
|
||||
|
||||
for layer_idx in range(num_layers):
|
||||
x = torch.randn(bs, h1, dtype=dtype)
|
||||
y = torch.randn(bs, h2, dtype=dtype)
|
||||
s = h2 // 3
|
||||
|
||||
y_ref = y.clone()
|
||||
_lora_ref_impl(y_ref[:, :s], x, wa_T_all_0, wb_T_all_0, indices,
|
||||
layer_idx, scale)
|
||||
_lora_ref_impl(y_ref[:, s:s * 2], x, wa_T_all_1, wb_T_all_1, indices,
|
||||
layer_idx, scale)
|
||||
_lora_ref_impl(y_ref[:, s * 2:], x, wa_T_all_2, wb_T_all_2, indices,
|
||||
layer_idx, scale)
|
||||
|
||||
y_our = y.clone()
|
||||
punica.add_lora_slice(y_our, x, wa_T_all_0, wb_T_all_0, indices,
|
||||
layer_idx, scale, 0, s)
|
||||
punica.add_lora_slice(y_our, x, wa_T_all_1, wb_T_all_1, indices,
|
||||
layer_idx, scale, s, s)
|
||||
punica.add_lora_slice(y_our, x, wa_T_all_2, wb_T_all_2, indices,
|
||||
layer_idx, scale, s * 2, s)
|
||||
|
||||
assert_close(y_ref[:, :s], y_our[:, :s])
|
||||
assert_close(y_ref[:, s:s * 2], y_our[:, s:s * 2])
|
||||
assert_close(y_ref[:, s * 2:], y_our[:, s * 2:])
|
|
@ -0,0 +1,408 @@
|
|||
"""
|
||||
This script is mainly used to test various hidden_sizes. We have collected the
|
||||
hidden_sizes included in the LoRA models currently supported by vLLM. It tests
|
||||
whether the corresponding Triton kernel can run normally when tensor parallelism
|
||||
is set to [1, 2, 4, 8, 16, 32, 64].
|
||||
"""
|
||||
import random
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.lora.ops.bgmv_expand import bgmv_expand
|
||||
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
|
||||
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
|
||||
from vllm.lora.ops.sgmv_expand import sgmv_expand
|
||||
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
|
||||
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
|
||||
from vllm.triton_utils.libentry import LibEntry
|
||||
|
||||
from .utils import (generate_data, generate_data_for_expand_nslices,
|
||||
ref_torch_groupgemm)
|
||||
|
||||
HIDDEN_SIZES = [
|
||||
128,
|
||||
256,
|
||||
512,
|
||||
896,
|
||||
1024,
|
||||
1152,
|
||||
1216,
|
||||
1280,
|
||||
1536,
|
||||
1664,
|
||||
2048,
|
||||
2240,
|
||||
2304,
|
||||
2368,
|
||||
2432,
|
||||
2560,
|
||||
2752,
|
||||
3072,
|
||||
3328,
|
||||
3456,
|
||||
3584,
|
||||
3712,
|
||||
4096,
|
||||
4480,
|
||||
4608,
|
||||
4736,
|
||||
4864,
|
||||
5120,
|
||||
5504,
|
||||
5632,
|
||||
5888,
|
||||
6144,
|
||||
6400,
|
||||
6848,
|
||||
6912,
|
||||
7168,
|
||||
7424,
|
||||
8192,
|
||||
8960,
|
||||
9216,
|
||||
9472,
|
||||
10240,
|
||||
11008,
|
||||
11264,
|
||||
13824,
|
||||
14336,
|
||||
14784,
|
||||
14848,
|
||||
15360,
|
||||
18944,
|
||||
22016,
|
||||
22528,
|
||||
24576,
|
||||
27392,
|
||||
27648,
|
||||
29568,
|
||||
29696,
|
||||
32000,
|
||||
32256,
|
||||
32512,
|
||||
32768,
|
||||
33024,
|
||||
36864,
|
||||
43264,
|
||||
49152,
|
||||
49408,
|
||||
60544,
|
||||
60672,
|
||||
64000,
|
||||
64256,
|
||||
102400,
|
||||
102656,
|
||||
128000,
|
||||
128256,
|
||||
]
|
||||
# The tensor parallel (TP) sizes; each hidden size is divided by these so the
# sharded shapes seen under TP are covered as well.
|
||||
divisibility = [1, 2, 4, 8, 16, 32, 64]
|
||||
|
||||
all_hidden_size = []
|
||||
for div in divisibility:
|
||||
for hidden_size in HIDDEN_SIZES:
|
||||
all_hidden_size.append(hidden_size // div)
|
||||
|
||||
HIDDEN_SIZES = list(set(all_hidden_size))
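# For example, with divisibility values 1, 2 and 4, a base hidden size of 4096
# contributes 4096, 2048 and 1024 to the deduplicated HIDDEN_SIZES above.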
|
||||
|
||||
BATCHES = [4]
|
||||
NUM_LORA = [4]
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
MAX_RANKS = [32]
|
||||
SCALES = [0.5]
|
||||
SEED = [0]
|
||||
CUDA_DEVICES = [f"cuda:{0}"]
|
||||
|
||||
|
||||
def assert_close(a, b):
|
||||
rtol, atol = {
|
||||
torch.float16: (6e-2, 6e-2),
|
||||
torch.bfloat16: (6e-2, 6e-2),
|
||||
torch.float32: (1e-2, 1e-2),
|
||||
}[a.dtype]
|
||||
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", BATCHES)
|
||||
@pytest.mark.parametrize("num_loras", NUM_LORA)
|
||||
@pytest.mark.parametrize("rank", MAX_RANKS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("scaling", SCALES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_punica_sgmv(
|
||||
batches: int,
|
||||
num_loras: int,
|
||||
rank: int,
|
||||
hidden_size: int,
|
||||
scaling: float,
|
||||
dtype: torch.dtype,
|
||||
op_type: str,
|
||||
seed: int,
|
||||
device: str,
|
||||
):
|
||||
random.seed(seed)
|
||||
torch.set_default_device(device)
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
seq_length = 128
|
||||
(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_out_tensor,
|
||||
ref_out_tensor,
|
||||
b_seq_start_loc,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
indices,
|
||||
) = generate_data(
|
||||
batches,
|
||||
hidden_size,
|
||||
num_loras,
|
||||
rank,
|
||||
seq_length,
|
||||
dtype,
|
||||
op_type,
|
||||
device,
|
||||
)
|
||||
max_seq_length = seq_len_tensor.max()
|
||||
if isinstance(max_seq_length, tuple):
|
||||
max_seq_length = max_seq_length[0].item()
|
||||
else:
|
||||
max_seq_length = max_seq_length.item()
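# max_seq_length is reduced to a plain Python int because it is consumed on
# the host side (presumably when sizing the Triton launch grid).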
|
||||
if op_type == "shrink":
|
||||
sgmv_shrink(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_out_tensor,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
batches,
|
||||
max_seq_length,
|
||||
scaling,
|
||||
)
|
||||
else:
|
||||
sgmv_expand(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_out_tensor,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
batches,
|
||||
max_seq_length,
|
||||
add_inputs=True,
|
||||
)
|
||||
ref_torch_groupgemm(
|
||||
ref_out_tensor,
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
batches,
|
||||
scaling if op_type == "shrink" else 1.0,
|
||||
op_type,
|
||||
)
|
||||
if op_type == "shrink":
|
||||
ref_out_tensor = ref_out_tensor.to(torch.float32)
|
||||
assert_close(our_out_tensor, ref_out_tensor)
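# A minimal reference sketch of the per-token "bgmv" semantics exercised by
# these tests. _naive_bgmv_ref is a hypothetical helper shown only for
# illustration; the real reference is ref_torch_groupgemm from .utils, which
# additionally handles the per-sequence grouping needed for the sgmv case.
def _naive_bgmv_ref(out, x, lora_weights, indices, scaling, add_inputs=False):
    # Assumed shapes: x [batch, in_dim], lora_weights [num_loras, out_dim,
    # in_dim], indices [batch]; out has shape [batch, out_dim].
    for i, lora_idx in enumerate(indices.tolist()):
        # Compute in float32; the shrink outputs in these tests are float32.
        acc = (x[i].to(torch.float32)
               @ lora_weights[lora_idx].to(torch.float32).t()) * scaling
        if add_inputs:
            out[i] += acc.to(out.dtype)
        else:
            out[i] = acc.to(out.dtype)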
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", BATCHES)
|
||||
@pytest.mark.parametrize("num_loras", NUM_LORA)
|
||||
@pytest.mark.parametrize("rank", MAX_RANKS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("scaling", SCALES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_punica_bgmv(
|
||||
batches: int,
|
||||
num_loras: int,
|
||||
rank: int,
|
||||
hidden_size: int,
|
||||
scaling: float,
|
||||
dtype: torch.dtype,
|
||||
op_type: str,
|
||||
seed: int,
|
||||
device: str,
|
||||
):
|
||||
from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
|
||||
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
|
||||
|
||||
random.seed(seed)
|
||||
torch.set_default_device(device)
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
seq_length = 1
|
||||
(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_out_tensor,
|
||||
ref_out_tensor,
|
||||
b_seq_start_loc,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
indices,
|
||||
) = generate_data(
|
||||
batches,
|
||||
hidden_size,
|
||||
num_loras,
|
||||
rank,
|
||||
seq_length,
|
||||
dtype,
|
||||
op_type,
|
||||
device,
|
||||
)
|
||||
if op_type == "shrink":
|
||||
# The current _bgmv_shrink_kernel does not require the libentry
|
||||
# decoration. The purpose of adding this patch is to test the
|
||||
# correctness of libentry.
|
||||
with patch(
|
||||
"vllm.lora.ops.bgmv_shrink._bgmv_shrink_kernel",
|
||||
LibEntry(_bgmv_shrink_kernel),
|
||||
):
|
||||
bgmv_shrink(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_out_tensor,
|
||||
indices,
|
||||
scaling,
|
||||
)
|
||||
else:
|
||||
# ditto
|
||||
with patch(
|
||||
"vllm.lora.ops.bgmv_expand._bgmv_expand_kernel",
|
||||
LibEntry(_bgmv_expand_kernel),
|
||||
):
|
||||
bgmv_expand(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_out_tensor,
|
||||
indices,
|
||||
add_inputs=True,
|
||||
)
|
||||
ref_torch_groupgemm(
|
||||
ref_out_tensor,
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
batches,
|
||||
scaling if op_type == "shrink" else 1.0,
|
||||
op_type,
|
||||
)
|
||||
if op_type == "shrink":
|
||||
ref_out_tensor = ref_out_tensor.to(torch.float32)
|
||||
assert_close(our_out_tensor, ref_out_tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", BATCHES)
|
||||
@pytest.mark.parametrize("num_loras", NUM_LORA)
|
||||
@pytest.mark.parametrize("rank", MAX_RANKS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("nslices", [2, 3])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"])
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_punica_expand_nslices(
|
||||
batches: int,
|
||||
num_loras: int,
|
||||
rank: int,
|
||||
hidden_size: int,
|
||||
nslices: int,
|
||||
dtype: torch.dtype,
|
||||
op_type: str,
|
||||
seed: int,
|
||||
device: str,
|
||||
):
|
||||
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
|
||||
|
||||
random.seed(seed)
|
||||
torch.set_default_device(device)
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
seq_length = 128 if op_type == "sgmv" else 1
|
||||
(
|
||||
inputs_tensor,
|
||||
lora_weights_lst,
|
||||
our_outputs,
|
||||
ref_outputs,
|
||||
b_seq_start_loc,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
indices,
|
||||
) = generate_data_for_expand_nslices(
|
||||
batches,
|
||||
hidden_size,
|
||||
num_loras,
|
||||
rank,
|
||||
seq_length,
|
||||
dtype,
|
||||
nslices,
|
||||
device,
|
||||
)
|
||||
max_seq_length = seq_len_tensor.max()
|
||||
if isinstance(max_seq_length, tuple):
|
||||
max_seq_length = max_seq_length[0].item()
|
||||
else:
|
||||
max_seq_length = max_seq_length.item()
|
||||
slice_offset = 0
|
||||
for index in range(nslices):
|
||||
lora_weights = lora_weights_lst[index]
|
||||
if op_type == "sgmv":
|
||||
sgmv_expand_slice(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_outputs,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
batches,
|
||||
max_seq_length,
|
||||
slice_offset,
|
||||
hidden_size,
|
||||
add_inputs=True,
|
||||
)
|
||||
else:
|
||||
# The current _bgmv_expand_slice_kernel does not require the
# libentry decorator. The purpose of adding this patch is to test
# the correctness of libentry.
|
||||
with patch(
|
||||
"vllm.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel",
|
||||
LibEntry(_bgmv_expand_slice_kernel),
|
||||
):
|
||||
bgmv_expand_slice(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_outputs,
|
||||
indices,
|
||||
slice_offset,
|
||||
slice_size=hidden_size,
|
||||
add_inputs=True,
|
||||
)
|
||||
ref_torch_groupgemm(
|
||||
ref_outputs[:, slice_offset:slice_offset + hidden_size],
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
batches,
|
||||
1.0,
|
||||
op_type="expand",
|
||||
)
|
||||
|
||||
slice_offset += hidden_size
|
||||
assert_close(our_outputs, ref_outputs)
|
|
@ -0,0 +1,342 @@
|
|||
"""
|
||||
This script is mainly used to test whether Triton kernels can run normally
under different conditions, including various batches, numbers of LoRAs, and
maximum ranks.
|
||||
"""
|
||||
import random
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.lora.ops.bgmv_expand import bgmv_expand
|
||||
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
|
||||
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
|
||||
from vllm.lora.ops.sgmv_expand import sgmv_expand
|
||||
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
|
||||
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
|
||||
from vllm.triton_utils.libentry import LibEntry
|
||||
|
||||
from .utils import (generate_data, generate_data_for_expand_nslices,
|
||||
ref_torch_groupgemm)
|
||||
|
||||
HIDDEN_SIZES = [3424, 4096, 4097]
|
||||
|
||||
BATCHES = [1, 4, 16, 32]
|
||||
NUM_LORA = [1, 4, 8, 16, 32, 64, 128]
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
MAX_RANKS = [1, 4, 8, 16, 32, 64, 128]
|
||||
SCALES = [0.5]
|
||||
SEED = [0]
|
||||
CUDA_DEVICES = [f"cuda:{0}"]
|
||||
|
||||
|
||||
def assert_close(a, b):
|
||||
rtol, atol = {
|
||||
torch.float16: (6e-2, 6e-2),
|
||||
torch.bfloat16: (6e-2, 6e-2),
|
||||
torch.float32: (1e-2, 1e-2),
|
||||
}[a.dtype]
|
||||
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", BATCHES)
|
||||
@pytest.mark.parametrize("num_loras", NUM_LORA)
|
||||
@pytest.mark.parametrize("rank", MAX_RANKS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("scaling", SCALES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_punica_sgmv(
|
||||
batches: int,
|
||||
num_loras: int,
|
||||
rank: int,
|
||||
hidden_size: int,
|
||||
scaling: float,
|
||||
dtype: torch.dtype,
|
||||
op_type: str,
|
||||
seed: int,
|
||||
device: str,
|
||||
):
|
||||
random.seed(seed)
|
||||
torch.set_default_device(device)
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
seq_length = 128
|
||||
(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_out_tensor,
|
||||
ref_out_tensor,
|
||||
b_seq_start_loc,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
indices,
|
||||
) = generate_data(
|
||||
batches,
|
||||
hidden_size,
|
||||
num_loras,
|
||||
rank,
|
||||
seq_length,
|
||||
dtype,
|
||||
op_type,
|
||||
device,
|
||||
)
|
||||
max_seq_length = seq_len_tensor.max()
|
||||
if isinstance(max_seq_length, tuple):
|
||||
max_seq_length = max_seq_length[0].item()
|
||||
else:
|
||||
max_seq_length = max_seq_length.item()
|
||||
if op_type == "shrink":
|
||||
sgmv_shrink(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_out_tensor,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
batches,
|
||||
max_seq_length,
|
||||
scaling,
|
||||
)
|
||||
else:
|
||||
sgmv_expand(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_out_tensor,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
batches,
|
||||
max_seq_length,
|
||||
add_inputs=True,
|
||||
)
|
||||
ref_torch_groupgemm(
|
||||
ref_out_tensor,
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
batches,
|
||||
scaling if op_type == "shrink" else 1.0,
|
||||
op_type,
|
||||
)
|
||||
if op_type == "shrink":
|
||||
ref_out_tensor = ref_out_tensor.to(torch.float32)
|
||||
assert_close(our_out_tensor, ref_out_tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", BATCHES)
|
||||
@pytest.mark.parametrize("num_loras", NUM_LORA)
|
||||
@pytest.mark.parametrize("rank", MAX_RANKS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("scaling", SCALES)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("op_type", ["shrink", "expand"])
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_punica_bgmv(
|
||||
batches: int,
|
||||
num_loras: int,
|
||||
rank: int,
|
||||
hidden_size: int,
|
||||
scaling: float,
|
||||
dtype: torch.dtype,
|
||||
op_type: str,
|
||||
seed: int,
|
||||
device: str,
|
||||
):
|
||||
from vllm.lora.ops.bgmv_expand import _bgmv_expand_kernel
|
||||
from vllm.lora.ops.bgmv_shrink import _bgmv_shrink_kernel
|
||||
|
||||
random.seed(seed)
|
||||
torch.set_default_device(device)
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
|
||||
seq_length = 1
|
||||
(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_out_tensor,
|
||||
ref_out_tensor,
|
||||
b_seq_start_loc,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
indices,
|
||||
) = generate_data(
|
||||
batches,
|
||||
hidden_size,
|
||||
num_loras,
|
||||
rank,
|
||||
seq_length,
|
||||
dtype,
|
||||
op_type,
|
||||
device,
|
||||
)
|
||||
if op_type == "shrink":
|
||||
# The current _bgmv_shrink_kernel does not require the libentry
# decorator. The purpose of adding this patch is to test the
# correctness of libentry.
|
||||
with patch(
|
||||
"vllm.lora.ops.bgmv_shrink._bgmv_shrink_kernel",
|
||||
LibEntry(_bgmv_shrink_kernel),
|
||||
):
|
||||
bgmv_shrink(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_out_tensor,
|
||||
indices,
|
||||
scaling,
|
||||
)
|
||||
else:
|
||||
# ditto
|
||||
with patch(
|
||||
"vllm.lora.ops.bgmv_expand._bgmv_expand_kernel",
|
||||
LibEntry(_bgmv_expand_kernel),
|
||||
):
|
||||
bgmv_expand(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_out_tensor,
|
||||
indices,
|
||||
add_inputs=True,
|
||||
)
|
||||
ref_torch_groupgemm(
|
||||
ref_out_tensor,
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
batches,
|
||||
scaling if op_type == "shrink" else 1.0,
|
||||
op_type,
|
||||
)
|
||||
if op_type == "shrink":
|
||||
ref_out_tensor = ref_out_tensor.to(torch.float32)
|
||||
assert_close(our_out_tensor, ref_out_tensor)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batches", BATCHES)
|
||||
@pytest.mark.parametrize("num_loras", NUM_LORA)
|
||||
@pytest.mark.parametrize("rank", MAX_RANKS)
|
||||
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
|
||||
@pytest.mark.parametrize("nslices", [2, 3])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"])
|
||||
@pytest.mark.parametrize("seed", SEED)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
def test_punica_expand_nslices(
|
||||
batches: int,
|
||||
num_loras: int,
|
||||
rank: int,
|
||||
hidden_size: int,
|
||||
nslices: int,
|
||||
dtype: torch.dtype,
|
||||
op_type: str,
|
||||
seed: int,
|
||||
device: str,
|
||||
):
|
||||
from vllm.lora.ops.bgmv_expand_slice import _bgmv_expand_slice_kernel
|
||||
|
||||
random.seed(seed)
|
||||
torch.set_default_device(device)
|
||||
torch.random.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
seq_length = 128 if op_type == "sgmv" else 1
|
||||
(
|
||||
inputs_tensor,
|
||||
lora_weights_lst,
|
||||
our_outputs,
|
||||
ref_outputs,
|
||||
b_seq_start_loc,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
indices,
|
||||
) = generate_data_for_expand_nslices(
|
||||
batches,
|
||||
hidden_size,
|
||||
num_loras,
|
||||
rank,
|
||||
seq_length,
|
||||
dtype,
|
||||
nslices,
|
||||
device,
|
||||
)
|
||||
max_seq_length = seq_len_tensor.max()
|
||||
if isinstance(max_seq_length, tuple):
|
||||
max_seq_length = max_seq_length[0].item()
|
||||
else:
|
||||
max_seq_length = max_seq_length.item()
|
||||
slice_offset = 0
|
||||
for index in range(nslices):
|
||||
lora_weights = lora_weights_lst[index]
|
||||
if op_type == "sgmv":
|
||||
sgmv_expand_slice(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_outputs,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
batches,
|
||||
max_seq_length,
|
||||
slice_offset,
|
||||
hidden_size,
|
||||
add_inputs=True,
|
||||
)
|
||||
else:
|
||||
# The current _bgmv_expand_slice_kernel does not require the
# libentry decorator. The purpose of adding this patch is to test
# the correctness of libentry.
|
||||
with patch(
|
||||
"vllm.lora.ops.bgmv_expand_slice._bgmv_expand_slice_kernel",
|
||||
LibEntry(_bgmv_expand_slice_kernel),
|
||||
):
|
||||
bgmv_expand_slice(
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_outputs,
|
||||
indices,
|
||||
slice_offset,
|
||||
slice_size=hidden_size,
|
||||
add_inputs=True,
|
||||
)
|
||||
ref_torch_groupgemm(
|
||||
ref_outputs[:, slice_offset:slice_offset + hidden_size],
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
batches,
|
||||
1.0,
|
||||
op_type="expand",
|
||||
)
|
||||
|
||||
slice_offset += hidden_size
|
||||
assert_close(our_outputs, ref_outputs)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from itertools import product
|
||||
|
||||
lst = list(
|
||||
product(
|
||||
BATCHES,
|
||||
NUM_LORA,
|
||||
MAX_RANKS,
|
||||
[1.0],
|
||||
[torch.float16],
|
||||
["expand"],
|
||||
SEED,
|
||||
CUDA_DEVICES,
|
||||
))
|
||||
for ele in lst:
|
||||
test_punica_bgmv(*ele)
|
||||
print(f"{ele},pass")
|
|
@ -64,14 +64,16 @@ def test_quant_model_lora(tinyllama_lora_files, model, tp_size):
|
|||
# if torch.cuda.device_count() < tp_size:
|
||||
# pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
|
||||
|
||||
llm = vllm.LLM(model=model.model_path,
|
||||
enable_lora=True,
|
||||
max_num_seqs=16,
|
||||
max_loras=4,
|
||||
max_model_len=400,
|
||||
tensor_parallel_size=tp_size,
|
||||
quantization=model.quantization,
|
||||
trust_remote_code=True)
|
||||
llm = vllm.LLM(
|
||||
model=model.model_path,
|
||||
enable_lora=True,
|
||||
max_num_seqs=16,
|
||||
max_loras=4,
|
||||
max_model_len=400,
|
||||
tensor_parallel_size=tp_size,
|
||||
gpu_memory_utilization=0.2, #avoid OOM
|
||||
quantization=model.quantization,
|
||||
trust_remote_code=True)
|
||||
|
||||
if model.quantization is None:
|
||||
expected_no_lora_output = [
|
||||
|
@ -156,24 +158,28 @@ def test_quant_model_tp_equality(tinyllama_lora_files, model):
|
|||
# if torch.cuda.device_count() < 2:
|
||||
# pytest.skip(f"Not enough GPUs for tensor parallelism {2}")
|
||||
|
||||
llm_tp1 = vllm.LLM(model=model.model_path,
|
||||
enable_lora=True,
|
||||
max_num_seqs=16,
|
||||
max_loras=4,
|
||||
tensor_parallel_size=1,
|
||||
quantization=model.quantization,
|
||||
trust_remote_code=True)
|
||||
llm_tp1 = vllm.LLM(
|
||||
model=model.model_path,
|
||||
enable_lora=True,
|
||||
max_num_seqs=16,
|
||||
max_loras=4,
|
||||
tensor_parallel_size=1,
|
||||
gpu_memory_utilization=0.2, #avoid OOM
|
||||
quantization=model.quantization,
|
||||
trust_remote_code=True)
|
||||
output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1)
|
||||
|
||||
del llm_tp1
|
||||
cleanup()
|
||||
|
||||
llm_tp2 = vllm.LLM(model=model.model_path,
|
||||
enable_lora=True,
|
||||
max_num_seqs=16,
|
||||
max_loras=4,
|
||||
tensor_parallel_size=2,
|
||||
quantization=model.quantization)
|
||||
llm_tp2 = vllm.LLM(
|
||||
model=model.model_path,
|
||||
enable_lora=True,
|
||||
max_num_seqs=16,
|
||||
max_loras=4,
|
||||
tensor_parallel_size=2,
|
||||
gpu_memory_utilization=0.2, #avoid OOM
|
||||
quantization=model.quantization)
|
||||
output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1)
|
||||
|
||||
del llm_tp2
|
||||
|
|
|
@ -86,3 +86,151 @@ class DummyLoRAManager:
|
|||
packed_lora = PackedLoRALayerWeights.pack(base_loras)
|
||||
self.set_module_lora(module_name, packed_lora)
|
||||
return packed_lora
|
||||
|
||||
|
||||
def assert_close(a, b):
|
||||
rtol, atol = {
|
||||
torch.float16: (6e-2, 6e-2),
|
||||
torch.bfloat16: (6e-2, 6e-2),
|
||||
torch.float32: (1e-2, 1e-2),
|
||||
}[a.dtype]
|
||||
torch.testing.assert_close(a, b, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
def ref_torch_groupgemm(
|
||||
out_tensor,
|
||||
inputs,
|
||||
lora_weights,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
batches,
|
||||
scaling,
|
||||
op_type,
|
||||
) -> torch.Tensor:
|
||||
out_list = []
|
||||
current_offset = 0
|
||||
for lora_index, b_length in zip(range(batches), seq_len_tensor):
|
||||
input_weight = inputs[current_offset:b_length + current_offset, :]
|
||||
current_offset += b_length
|
||||
lora_weight = lora_weights[lora_indices_tensor[lora_index]]
|
||||
result = torch.nn.functional.linear(input_weight, lora_weight)
|
||||
result *= scaling
|
||||
out_list.append(result)
|
||||
cat_result = torch.cat(out_list, dim=0)
|
||||
if op_type == "expand":
|
||||
out_tensor += cat_result
|
||||
else:
|
||||
out_tensor.copy_(cat_result)
|
||||
return
|
||||
|
||||
|
||||
def generate_data(batches, hidden_size, lora_nums, max_rank, seq_length, dtype,
|
||||
op_type, device):
|
||||
seq_len_tensor = torch.randint(seq_length, seq_length + 1,
|
||||
(batches, )).to(device)
|
||||
b_seq_start_loc = torch.cumsum(
|
||||
torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
|
||||
dim=0,
|
||||
).to(device)
|
||||
total_tokens = seq_len_tensor.sum()
|
||||
if op_type == "shrink":
|
||||
inputs_tensor = torch.rand((total_tokens, hidden_size),
|
||||
dtype=dtype).to(device)
|
||||
lora_weights = torch.rand(
|
||||
(lora_nums, max_rank, hidden_size), # col-major
|
||||
dtype=dtype,
|
||||
).to(device)
|
||||
# shrink op needs atomic_add, so the output is initialized to 0
|
||||
ref_out_tensor = torch.zeros((total_tokens, max_rank),
|
||||
dtype=dtype,
|
||||
device=inputs_tensor.device)
|
||||
# NOTE: the shrink kernel uses torch.float32 as its output type
|
||||
our_out_tensor = torch.zeros((total_tokens, max_rank),
|
||||
dtype=torch.float32).to(device)
|
||||
else:
|
||||
inputs_tensor = torch.rand(
|
||||
(total_tokens, max_rank),
|
||||
dtype=dtype,
|
||||
).to(device)
|
||||
lora_weights = torch.rand(
|
||||
(lora_nums, hidden_size, max_rank), # col-major
|
||||
dtype=dtype,
|
||||
).to(device)
|
||||
# expand op needs to compute y += a @ lora_b, so the output is
# initialized randomly
|
||||
ref_out_tensor = torch.rand(
|
||||
(total_tokens, hidden_size),
|
||||
dtype=dtype,
|
||||
).to(device)
|
||||
# Ensure the same input.
|
||||
our_out_tensor = ref_out_tensor.clone()
|
||||
lora_indices_tensor = torch.randint(0,
|
||||
lora_nums - 1 if lora_nums > 1 else 1,
|
||||
(batches, )).to(device)
|
||||
indices = torch.zeros((total_tokens), dtype=torch.long).to(device)
|
||||
current_offset = 0
|
||||
for b_id in range(batches):
|
||||
lora_index = lora_indices_tensor[b_id]
|
||||
indices[current_offset:current_offset +
|
||||
seq_len_tensor[b_id]].copy_(lora_index)
|
||||
current_offset += seq_len_tensor[b_id].item()
|
||||
return (
|
||||
inputs_tensor,
|
||||
lora_weights,
|
||||
our_out_tensor,
|
||||
ref_out_tensor,
|
||||
b_seq_start_loc,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
indices,
|
||||
)
|
||||
|
||||
|
||||
def generate_data_for_expand_nslices(batches, hidden_size, lora_nums, max_rank,
|
||||
seq_length, dtype, nslices, device):
|
||||
seq_len_tensor = torch.randint(seq_length, seq_length + 1,
|
||||
(batches, )).to(device)
|
||||
b_seq_start_loc = torch.cumsum(
|
||||
torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long),
|
||||
dim=0,
|
||||
).to(device)
|
||||
total_tokens = seq_len_tensor.sum()
|
||||
inputs_tensor = torch.rand(
|
||||
(total_tokens, max_rank),
|
||||
dtype=dtype,
|
||||
).to(device)
|
||||
lora_weights_lst = []
|
||||
for _ in range(nslices):
|
||||
lora_weights_lst.append(
|
||||
torch.rand(
|
||||
(lora_nums, hidden_size, max_rank), # col-major
|
||||
dtype=dtype,
|
||||
).to(device))
|
||||
# expand op needs to compute y += a @ lora_b, so the output is
# initialized randomly
|
||||
ref_out_tensor = torch.rand((total_tokens, hidden_size * nslices),
|
||||
dtype=dtype).to(device)
|
||||
# Ensure the same input.
|
||||
our_out_tensor = ref_out_tensor.clone()
|
||||
lora_indices_tensor = torch.randint(0,
|
||||
lora_nums - 1 if lora_nums > 1 else 1,
|
||||
(batches, ))
|
||||
indices = torch.zeros((total_tokens), dtype=torch.long).to(device)
|
||||
current_offset = 0
|
||||
for b_id in range(batches):
|
||||
lora_index = lora_indices_tensor[b_id]
|
||||
indices[current_offset:current_offset +
|
||||
seq_len_tensor[b_id]] = lora_index.item()
|
||||
current_offset += seq_len_tensor[b_id].item()
|
||||
|
||||
lora_indices_tensor = lora_indices_tensor.to(device)
|
||||
return (
|
||||
inputs_tensor,
|
||||
lora_weights_lst,
|
||||
our_out_tensor,
|
||||
ref_out_tensor,
|
||||
b_seq_start_loc,
|
||||
lora_indices_tensor,
|
||||
seq_len_tensor,
|
||||
indices,
|
||||
)
|
||||
|
|
|
@ -13,12 +13,9 @@ try:
|
|||
except ImportError as e:
|
||||
logger.warning("Failed to import from vllm._C with %r", e)
|
||||
|
||||
with contextlib.suppress(ImportError):
|
||||
import vllm._moe_C
|
||||
|
||||
with contextlib.suppress(ImportError):
|
||||
# ruff: noqa: F401
|
||||
import vllm._punica_C
|
||||
import vllm._moe_C
|
||||
|
||||
|
||||
def is_custom_op_supported(op_name: str) -> bool:
|
||||
|
@ -519,43 +516,6 @@ def register_graph_buffers(fa: int, handles: List[str],
|
|||
torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
|
||||
|
||||
|
||||
# punica
|
||||
def dispatch_bgmv(
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
indicies: torch.Tensor,
|
||||
layer_idx: int,
|
||||
scale: float,
|
||||
) -> None:
|
||||
torch.ops._punica_C.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx,
|
||||
scale)
|
||||
|
||||
|
||||
def dispatch_bgmv_low_level(
|
||||
y: torch.Tensor,
|
||||
x: torch.Tensor,
|
||||
w_t_all: torch.Tensor,
|
||||
indicies: torch.Tensor,
|
||||
layer_idx: int,
|
||||
scale: float,
|
||||
h_in: int,
|
||||
h_out: int,
|
||||
y_offset: int,
|
||||
) -> None:
|
||||
torch.ops._punica_C.dispatch_bgmv_low_level(
|
||||
y,
|
||||
x,
|
||||
w_t_all,
|
||||
indicies,
|
||||
layer_idx,
|
||||
scale,
|
||||
h_in,
|
||||
h_out,
|
||||
y_offset,
|
||||
)
|
||||
|
||||
|
||||
# temporary fix for https://github.com/vllm-project/vllm/issues/5456
|
||||
# TODO: remove this in v0.6.0
|
||||
names_and_values = globals()
|
||||
|
|
|
@ -45,7 +45,6 @@ if TYPE_CHECKING:
|
|||
MAX_JOBS: Optional[str] = None
|
||||
NVCC_THREADS: Optional[str] = None
|
||||
VLLM_USE_PRECOMPILED: bool = False
|
||||
VLLM_INSTALL_PUNICA_KERNELS: bool = False
|
||||
VLLM_NO_DEPRECATION_WARNING: bool = False
|
||||
CMAKE_BUILD_TYPE: Optional[str] = None
|
||||
VERBOSE: bool = False
|
||||
|
@ -94,10 +93,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
|
|||
"VLLM_USE_PRECOMPILED":
|
||||
lambda: bool(os.environ.get("VLLM_USE_PRECOMPILED")),
|
||||
|
||||
# If set, vllm will install Punica kernels
|
||||
"VLLM_INSTALL_PUNICA_KERNELS":
|
||||
lambda: bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0"))),
|
||||
|
||||
# CMake build type
|
||||
# If not set, defaults to "Debug" or "RelWithDebInfo"
|
||||
# Available options: "Debug", "Release", "RelWithDebInfo"
|
||||
|
|
|
@ -14,7 +14,6 @@ from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
|
|||
MergedQKVParallelLinearWithLora,
|
||||
QKVParallelLinearWithLora,
|
||||
RowParallelLinearWithLoRA)
|
||||
from vllm.lora.punica import bgmv, dispatch_bgmv_low_level
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
@ -28,7 +27,7 @@ def _fully_sharded_can_replace(can_replace):
|
|||
|
||||
def dec(*args, **kwargs):
|
||||
return (can_replace(*args, **kwargs)
|
||||
and kwargs['lora_config'].fully_sharded_loras)
|
||||
and kwargs["lora_config"].fully_sharded_loras)
|
||||
|
||||
return dec
|
||||
|
||||
|
@ -59,25 +58,30 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
|
|||
x = x.view(-1, x.shape[-1])
|
||||
output, out_orig_shape = output.view(-1,
|
||||
output.shape[-1]), output.shape
|
||||
buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
|
||||
dtype=torch.float32,
|
||||
device=x.device)
|
||||
|
||||
bgmv(buffer, x, self.lora_a_stacked,
|
||||
self.indices[:self.indices_len[0]], 0, 1.0)
|
||||
buffer = torch.zeros(
|
||||
(x.shape[0], self.lora_a_stacked.shape[2]),
|
||||
dtype=torch.float32,
|
||||
device=x.device,
|
||||
)
|
||||
self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
|
||||
buffer = tensor_model_parallel_all_gather(buffer)
|
||||
bgmv(output, buffer, self.lora_b_stacked,
|
||||
self.indices[:self.indices_len[0]], 0, 1.0)
|
||||
self.punica_wrapper.add_expand(output,
|
||||
buffer,
|
||||
self.lora_b_stacked,
|
||||
add_input=True)
|
||||
# now have column partitioned output
|
||||
|
||||
output = output.view(*out_orig_shape)
|
||||
return output
|
||||
|
||||
@classmethod
|
||||
@_fully_sharded_can_replace
|
||||
def can_replace_layer(cls, source_layer: nn.Module,
|
||||
lora_config: LoRAConfig, packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig]) -> bool:
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
# specifying kwargs so they can be easily accessed in decorator
|
||||
return super().can_replace_layer(
|
||||
source_layer=source_layer,
|
||||
|
@ -88,14 +92,14 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
|
|||
)
|
||||
|
||||
|
||||
def _mcp_apply(x, bias, layer):
|
||||
def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora):
|
||||
"""
|
||||
MergedColumnParallelLinearWithShardedLoRA and
|
||||
MergedQKVParallelLinearWithShardedLora share the same
|
||||
MergedColumnParallelLinearWithShardedLoRA and
|
||||
MergedQKVParallelLinearWithShardedLora share the same
|
||||
LoRA weight application method.
|
||||
|
||||
The main difference is the step by shard_size for lora_b which can
|
||||
vary for MergedQKVParallelLinearWithShardedLora but is constant for
|
||||
vary for MergedQKVParallelLinearWithShardedLora but is constant for
|
||||
MergedColumnParallelLinearWithShardedLoRA.
|
||||
"""
|
||||
# expecting 2 for column parallel and 3 for qkv
|
||||
|
@ -104,21 +108,27 @@ def _mcp_apply(x, bias, layer):
|
|||
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
|
||||
buffers = torch.zeros((n, x.shape[0], layer.lora_a_stacked[0].shape[2]),
|
||||
dtype=torch.float32,
|
||||
device=x.device)
|
||||
buffers = torch.zeros(
|
||||
(n, x.shape[0], layer.lora_a_stacked[0].shape[2]),
|
||||
dtype=torch.float32,
|
||||
device=x.device,
|
||||
)
|
||||
for idx in range(n):
|
||||
bgmv(buffers[idx], x, layer.lora_a_stacked[idx],
|
||||
layer.indices[:layer.indices_len[0]], 0, 1.0)
|
||||
layer.punica_wrapper.add_shrink(buffers[idx], x,
|
||||
layer.lora_a_stacked[idx], 1.0)
|
||||
|
||||
buffers = tensor_model_parallel_all_gather(buffers)
|
||||
left_offset = 0
|
||||
for idx in range(n):
|
||||
shard_size = layer.lora_b_stacked[idx].shape[2]
|
||||
dispatch_bgmv_low_level(output, buffers[idx],
|
||||
layer.lora_b_stacked[idx],
|
||||
layer.indices[:layer.indices_len[0]], 0, 1.0,
|
||||
left_offset, shard_size)
|
||||
layer.punica_wrapper.add_expand_slice(
|
||||
output,
|
||||
buffers[idx],
|
||||
layer.lora_b_stacked[idx],
|
||||
left_offset,
|
||||
shard_size,
|
||||
add_input=True,
|
||||
)
|
||||
left_offset += shard_size
|
||||
|
||||
output = output.view(*out_orig_shape)
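
The loop above writes each slice's expand result at a running column offset. A hypothetical, shape-only PyTorch sketch of that accumulation pattern (adapter indices and the all-gather are omitted, and the shard widths are invented for illustration):

import torch

tokens, rank = 4, 8
shard_sizes = [16, 8, 8]  # e.g. q, k, v shard widths (illustrative values)
output = torch.zeros(tokens, sum(shard_sizes))
buffers = torch.randn(len(shard_sizes), tokens, rank)
lora_b = [torch.randn(size, rank) for size in shard_sizes]

left_offset = 0
for idx, shard_size in enumerate(shard_sizes):
    # each slice accumulates into its own column range of the packed output
    output[:, left_offset:left_offset + shard_size] += (
        buffers[idx] @ lora_b[idx].T)
    left_offset += shard_size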
|
||||
|
@ -129,7 +139,7 @@ def _mcp_apply(x, bias, layer):
|
|||
class MergedColumnParallelLinearWithShardedLoRA(
|
||||
MergedColumnParallelLinearWithLoRA):
|
||||
"""
|
||||
Differs from MergedColumnParallelLinearWithLoRA by slicing the
|
||||
Differs from MergedColumnParallelLinearWithLoRA by slicing the
|
||||
LoRA A's also.
|
||||
|
||||
Based on S-LoRA, slicing happens along the rank dim.
|
||||
|
@ -145,7 +155,8 @@ class MergedColumnParallelLinearWithShardedLoRA(
|
|||
lora_a = [
|
||||
lora_a[0][:,
|
||||
output_start_idx:output_start_idx + output_shard_size],
|
||||
lora_a[1][:, output_start_idx:output_start_idx + output_shard_size]
|
||||
lora_a[1][:,
|
||||
output_start_idx:output_start_idx + output_shard_size],
|
||||
]
|
||||
return lora_a
|
||||
|
||||
|
@ -155,9 +166,13 @@ class MergedColumnParallelLinearWithShardedLoRA(
|
|||
|
||||
@classmethod
|
||||
@_fully_sharded_can_replace
|
||||
def can_replace_layer(cls, source_layer: nn.Module,
|
||||
lora_config: LoRAConfig, packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig]) -> bool:
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
# specifying kwargs so they can be easily accessed in decorator
|
||||
return super().can_replace_layer(
|
||||
source_layer=source_layer,
|
||||
|
@ -170,7 +185,7 @@ class MergedColumnParallelLinearWithShardedLoRA(
|
|||
|
||||
class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
|
||||
"""
|
||||
Differs from QKVParallelLinearWithLora by slicing the
|
||||
Differs from QKVParallelLinearWithLora by slicing the
|
||||
LoRA A's also.
|
||||
|
||||
Based on S-LoRA, slicing happens along the rank dim.
|
||||
|
@ -193,14 +208,13 @@ class QKVParallelLinearWithShardedLora(QKVParallelLinearWithLora):
|
|||
buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
|
||||
dtype=torch.float32,
|
||||
device=x.device)
|
||||
|
||||
bgmv(buffer, x, self.lora_a_stacked,
|
||||
self.indices[:self.indices_len[0]], 0, 1.0)
|
||||
self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
|
||||
buffer = tensor_model_parallel_all_gather(buffer)
|
||||
bgmv(output, buffer, self.lora_b_stacked,
|
||||
self.indices[:self.indices_len[0]], 0, 1.0)
|
||||
self.punica_wrapper.add_expand(output,
|
||||
buffer,
|
||||
self.lora_b_stacked,
|
||||
add_input=True)
|
||||
# now have column partitioned output
|
||||
|
||||
output = output.view(*out_orig_shape)
|
||||
return output
|
||||
|
||||
|
@ -237,7 +251,7 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
|
|||
lora_a = [
|
||||
lora_a[0][:, start_idx[0]:start_idx[0] + shard_size[0]],
|
||||
lora_a[1][:, start_idx[1]:start_idx[1] + shard_size[1]],
|
||||
lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]]
|
||||
lora_a[2][:, start_idx[2]:start_idx[2] + shard_size[2]],
|
||||
]
|
||||
return lora_a
|
||||
|
||||
|
@ -247,9 +261,13 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
|
|||
|
||||
@classmethod
|
||||
@_fully_sharded_can_replace
|
||||
def can_replace_layer(cls, source_layer: nn.Module,
|
||||
lora_config: LoRAConfig, packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig]) -> bool:
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
# specifying kwargs so they can be easily accessed in decorator
|
||||
return super().can_replace_layer(
|
||||
source_layer=source_layer,
|
||||
|
@ -262,11 +280,11 @@ class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
|
|||
|
||||
class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
|
||||
"""
|
||||
Differs from RowParallelLinearWithLoRA by slicing the
|
||||
Differs from RowParallelLinearWithLoRA by slicing the
|
||||
LoRA B's also.
|
||||
|
||||
Based on S-LoRA, slicing happens along the output dim.
|
||||
This yields a combined partial sum from the row parallel base
|
||||
This yields a combined partial sum from the row parallel base
|
||||
layer and column partitioned output from the LoRA.
|
||||
"""
|
||||
|
||||
|
@ -283,11 +301,13 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
|
|||
x = x.view(-1, x.shape[-1])
|
||||
output, out_orig_shape = output.view(-1,
|
||||
output.shape[-1]), output.shape
|
||||
buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
|
||||
dtype=torch.float32,
|
||||
device=x.device)
|
||||
bgmv(buffer, x, self.lora_a_stacked,
|
||||
self.indices[:self.indices_len[0]], 0, 1.0)
|
||||
buffer = torch.zeros(
|
||||
(x.shape[0], self.lora_a_stacked.shape[2]),
|
||||
dtype=torch.float32,
|
||||
device=x.device,
|
||||
)
|
||||
|
||||
self.punica_wrapper.add_shrink(buffer, x, self.lora_a_stacked, 1.0)
|
||||
buffer = tensor_model_parallel_all_reduce(buffer)
|
||||
|
||||
# following S-LoRA, allows the fusing of all_gather and all_reduce
|
||||
|
@ -298,18 +318,21 @@ class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
|
|||
# reduced before being used
|
||||
shard_size = self.lora_b_stacked.shape[2]
|
||||
start_idx = self.tp_rank * shard_size
|
||||
dispatch_bgmv_low_level(output, buffer, self.lora_b_stacked,
|
||||
self.indices[:self.indices_len[0]], 0, 1.0,
|
||||
start_idx, shard_size)
|
||||
|
||||
self.punica_wrapper.add_expand_slice(output, buffer,
|
||||
self.lora_b_stacked, start_idx,
|
||||
shard_size)
|
||||
output = output.view(*out_orig_shape)
|
||||
return output
|
||||
|
||||
@classmethod
|
||||
@_fully_sharded_can_replace
|
||||
def can_replace_layer(cls, source_layer: nn.Module,
|
||||
lora_config: LoRAConfig, packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig]) -> bool:
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
# specifying kwargs so they can be easily accessed in decorator
|
||||
return super().can_replace_layer(
|
||||
source_layer=source_layer,
|
||||
|
|
|
@ -17,7 +17,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
|
|||
tensor_model_parallel_all_reduce,
|
||||
tensor_model_parallel_gather)
|
||||
from vllm.distributed.utils import divide
|
||||
from vllm.lora.punica import add_lora, add_lora_slice, bgmv
|
||||
from vllm.lora.punica import PunicaWrapper
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
QKVParallelLinear,
|
||||
|
@ -55,88 +55,17 @@ def _not_fully_sharded_can_replace(can_replace):
|
|||
"""
|
||||
|
||||
def dec(*args, **kwargs):
|
||||
decorate = kwargs.pop('decorate') if 'decorate' in kwargs else True
|
||||
condition = (not kwargs['lora_config'].fully_sharded_loras
|
||||
decorate = kwargs.pop("decorate") if "decorate" in kwargs else True
|
||||
condition = (not kwargs["lora_config"].fully_sharded_loras
|
||||
if decorate else True)
|
||||
return can_replace(*args, **kwargs) and condition
|
||||
|
||||
return dec
|
||||
|
||||
|
||||
def _apply_lora(
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: torch.Tensor,
|
||||
lora_b_stacked: torch.Tensor,
|
||||
indices: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
):
|
||||
"""Applies lora to each input.
|
||||
|
||||
This method applies all loras to each input. It uses the
|
||||
indices vector to determine which lora yields the
|
||||
correct output. An index of -1 means no lora should be
|
||||
applied. This method adds the final lora results to the
|
||||
output.
|
||||
|
||||
Input shapes:
|
||||
x: (batch_size, hidden_dim)
|
||||
lora_a_stacked: (num_loras, lora_rank, hidden_dim)
|
||||
lora_b_stacked: (num_loras, output_dim, lora_rank)
|
||||
indices: (batch_size)
|
||||
output: (batch_size, output_dim)
|
||||
"""
|
||||
org_output = output
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output = output.view(-1, output.shape[-1])
|
||||
indices = indices.view(-1)
|
||||
add_lora(output, x, lora_a_stacked, lora_b_stacked, indices, 0, 1.0)
|
||||
return output.view_as(org_output)
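
For reference, a minimal PyTorch-only sketch of the behaviour described in the docstring above, in particular the per-row adapter lookup where an index of -1 means no LoRA is applied (this is an illustration, not the punica/Triton implementation):

import torch

batch, hidden, rank, out_dim, num_loras = 4, 64, 8, 32, 2
x = torch.randn(batch, hidden)
lora_a = torch.randn(num_loras, rank, hidden)
lora_b = torch.randn(num_loras, out_dim, rank)
indices = torch.tensor([0, 1, -1, 0])
output = torch.zeros(batch, out_dim)

mask = (indices >= 0).to(x.dtype)
idx = indices.clamp(min=0)
delta = torch.einsum("br,bor->bo",
                     torch.einsum("bh,brh->br", x, lora_a[idx]),
                     lora_b[idx])
output += delta * mask.unsqueeze(1)  # rows with index -1 stay unchanged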
|
||||
|
||||
|
||||
def _apply_lora_packed_nslice(
|
||||
x: torch.Tensor,
|
||||
lora_a_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
lora_b_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor],
|
||||
indices: torch.Tensor,
|
||||
output: torch.Tensor,
|
||||
output_slices: Tuple[int, ...],
|
||||
):
|
||||
"""Applies lora to each input.
|
||||
|
||||
This method applies all loras to each input. It uses the
|
||||
indices vector to determine which lora yields the
|
||||
correct output. An index of -1 means no lora should be
|
||||
applied. This method adds the final lora results to the
|
||||
output.
|
||||
|
||||
This method is used for layers that are composed of multiple sublayers
|
||||
(slices) packed together.
|
||||
|
||||
Input shapes:
|
||||
x: (batch_size, hidden_dim)
|
||||
lora_a_stacked: 3 element tuple of (num_loras, lora_rank, hidden_dim)
|
||||
lora_b_stacked: 3 element tuple of (num_loras, output_dim, lora_rank)
|
||||
indices: (batch_size)
|
||||
output: (batch_size, q_slice_size + 2*kv_slice_size)
|
||||
output_slices: n-1 element tuple of (slice_size...),
|
||||
where n is number of slices
|
||||
"""
|
||||
org_output = output
|
||||
x = x.view(-1, x.shape[-1])
|
||||
output = output.view(-1, output.shape[-1])
|
||||
indices = indices.view(-1)
|
||||
offset_left = 0
|
||||
for slice_idx in range(len(output_slices)):
|
||||
add_lora_slice(output, x, lora_a_stacked[slice_idx],
|
||||
lora_b_stacked[slice_idx], indices, 0, 1.0, offset_left,
|
||||
output_slices[slice_idx])
|
||||
offset_left += output_slices[slice_idx]
|
||||
return output.view_as(org_output)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoRAMapping(AdapterMapping):
|
||||
pass
|
||||
is_prefill: bool = False
|
||||
|
||||
|
||||
class BaseLayerWithLoRA(nn.Module):
|
||||
|
@ -154,10 +83,11 @@ class BaseLayerWithLoRA(nn.Module):
|
|||
...
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None) -> None:
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None,
|
||||
) -> None:
|
||||
"""Initializes lora matrices."""
|
||||
...
|
||||
|
||||
|
@ -177,20 +107,18 @@ class BaseLayerWithLoRA(nn.Module):
|
|||
|
||||
def set_mapping(
|
||||
self,
|
||||
base_indices: torch.Tensor,
|
||||
sampler_indices: torch.Tensor,
|
||||
sampler_indices_padded: torch.Tensor,
|
||||
embeddings_indices: torch.Tensor,
|
||||
long_lora_indices: torch.Tensor,
|
||||
indices_len: List[int],
|
||||
punica_wrapper: PunicaWrapper,
|
||||
):
|
||||
"""Sets the mapping indices."""
|
||||
...
|
||||
self.punica_wrapper: PunicaWrapper = punica_wrapper
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(cls, source_layer: nn.Module,
|
||||
lora_config: LoRAConfig, packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig]) -> bool:
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
"""Returns True if the layer can be replaced by this LoRA layer."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -259,10 +187,6 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
|
|||
self.lora_a_stacked.shape[0] * self.lora_a_stacked.shape[1],
|
||||
self.lora_a_stacked.shape[2],
|
||||
)
|
||||
# Lazily initialized.
|
||||
self.indices: torch.Tensor
|
||||
self.indices_len: List[int]
|
||||
self.embeddings_indices: torch.Tensor
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
self.lora_a_stacked[index] = 0
|
||||
|
@ -285,40 +209,27 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
|
|||
if embeddings_tensor is not None:
|
||||
self.embeddings_tensors[
|
||||
index, :embeddings_tensor.shape[0], :embeddings_tensor.
|
||||
shape[1]].copy_(embeddings_tensor, non_blocking=True)
|
||||
shape[1], ].copy_(embeddings_tensor, non_blocking=True)
|
||||
if self.embeddings_slice is not None:
|
||||
# TODO(yard1): Optimize this copy, we don't need to copy
|
||||
# everything, just the modified part
|
||||
embeddings = self.embeddings_tensors.view(
|
||||
self.embeddings_tensors.shape[0] *
|
||||
self.embeddings_tensors.shape[1],
|
||||
self.embeddings_tensors.shape[2]
|
||||
self.embeddings_tensors.shape[2],
|
||||
)[self.embeddings_slice[0]:self.embeddings_slice[1]]
|
||||
assert self.embeddings_weights is not None
|
||||
self.embeddings_weights[:embeddings.shape[0]].copy_(embeddings)
|
||||
|
||||
def set_mapping(
|
||||
self,
|
||||
base_indices: torch.Tensor,
|
||||
sampler_indices: torch.Tensor,
|
||||
sampler_indices_padded: torch.Tensor,
|
||||
embeddings_indices: torch.Tensor,
|
||||
long_lora_indices: torch.Tensor,
|
||||
indices_len: List[int],
|
||||
):
|
||||
self.indices = base_indices
|
||||
self.embeddings_indices = embeddings_indices
|
||||
self.indices_len = indices_len
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
added_tokens_mask = x > self.base_layer.org_vocab_size - 1
|
||||
embedding_len = self.indices_len[3]
|
||||
indices = self.embeddings_indices[1][:embedding_len].view_as(x)
|
||||
embeddings_indices = self.punica_wrapper.embeddings_indices
|
||||
indices = embeddings_indices[1].view_as(x)
|
||||
full_lora_a_embeddings = F.embedding(
|
||||
x + indices,
|
||||
self.lora_a_stacked_2d,
|
||||
)
|
||||
indices = self.embeddings_indices[0][:embedding_len].view_as(x)
|
||||
indices = embeddings_indices[0].view_as(x)
|
||||
full_output = self.base_layer.forward(
|
||||
x.add_(indices * added_tokens_mask))
|
||||
|
||||
|
@ -329,22 +240,32 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
|
|||
if full_lora_a_embeddings.ndim == 3:
|
||||
full_lora_a_embeddings = full_lora_a_embeddings.view(
|
||||
full_lora_a_embeddings.shape[0] *
|
||||
full_lora_a_embeddings.shape[1], -1)
|
||||
bgmv(full_output, full_lora_a_embeddings, self.lora_b_stacked,
|
||||
self.indices[:self.indices_len[0]], 0, 1.0)
|
||||
full_lora_a_embeddings.shape[1],
|
||||
-1,
|
||||
)
|
||||
|
||||
# Embedding layer only needs the expand op
|
||||
self.punica_wrapper.add_expand(full_output,
|
||||
full_lora_a_embeddings,
|
||||
self.lora_b_stacked,
|
||||
add_input=True)
|
||||
return full_output.view_as(full_output_org)
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(cls, source_layer: nn.Module,
|
||||
lora_config: LoRAConfig, packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig]) -> bool:
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
return type(source_layer) is VocabParallelEmbedding
|
||||
|
||||
|
||||
class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
"""
|
||||
LoRA on top of ColumnParallelLinear layer.
|
||||
|
||||
|
||||
LoRA B is sliced for tensor parallelism.
|
||||
"""
|
||||
|
||||
|
@ -357,10 +278,11 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|||
self.device = _get_lora_device(self.base_layer)
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None) -> None:
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None,
|
||||
) -> None:
|
||||
self.lora_config = lora_config
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
lora_a_output_size_per_partition = (
|
||||
|
@ -384,10 +306,6 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|||
)
|
||||
self.output_dim = self.lora_b_stacked.shape[2]
|
||||
|
||||
# lazily initialized.
|
||||
self.indices: torch.Tensor
|
||||
self.indices_len: List[int]
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
self.lora_a_stacked[index] = 0
|
||||
self.lora_b_stacked[index] = 0
|
||||
|
@ -423,28 +341,11 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|||
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
|
||||
lora_b.T, non_blocking=True)
|
||||
|
||||
def set_mapping(
|
||||
self,
|
||||
base_indices: torch.Tensor,
|
||||
sampler_indices: torch.Tensor,
|
||||
sampler_indices_padded: torch.Tensor,
|
||||
embeddings_indices: torch.Tensor,
|
||||
long_lora_indices: torch.Tensor,
|
||||
indices_len: List[int],
|
||||
):
|
||||
self.indices = base_indices
|
||||
self.indices_len = indices_len
|
||||
|
||||
def apply(self, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
|
||||
_apply_lora(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
self.lora_b_stacked,
|
||||
self.indices[:self.indices_len[0]],
|
||||
output,
|
||||
)
|
||||
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
|
||||
self.lora_b_stacked, 1.0)
|
||||
return output
|
||||
|
||||
def forward(self, input_):
|
||||
|
@ -473,9 +374,13 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|||
|
||||
@classmethod
|
||||
@_not_fully_sharded_can_replace
|
||||
def can_replace_layer(cls, source_layer: nn.Module,
|
||||
lora_config: LoRAConfig, packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig]) -> bool:
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
return type(source_layer) is ColumnParallelLinear or (
|
||||
type(source_layer) is MergedColumnParallelLinear
|
||||
and len(packed_modules_list) == 1)
|
||||
|
@ -494,10 +399,11 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|||
super().__init__(base_layer)
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None) -> None:
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None,
|
||||
) -> None:
|
||||
self.lora_config = lora_config
|
||||
n_slices = 2
|
||||
if not (len(self.base_layer.output_sizes) == n_slices
|
||||
|
@ -533,8 +439,6 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|||
) for _ in range(n_slices))
|
||||
|
||||
self.output_dim = self.lora_b_stacked[0].shape[2]
|
||||
# Lazily initialized.
|
||||
self.indices: torch.Tensor
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
self.lora_a_stacked[0][index] = 0
|
||||
|
@ -556,7 +460,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|||
start_idx = self.tp_rank * shard_size
|
||||
end_idx = (self.tp_rank + 1) * shard_size
|
||||
lora_b = [
|
||||
lora_b[0][:, start_idx:end_idx], lora_b[1][:, start_idx:end_idx]
|
||||
lora_b[0][:, start_idx:end_idx],
|
||||
lora_b[1][:, start_idx:end_idx],
|
||||
]
|
||||
return lora_b
|
||||
|
||||
|
@ -591,34 +496,33 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
|
|||
def apply(self, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
|
||||
_apply_lora_packed_nslice(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
self.lora_b_stacked,
|
||||
self.indices[:self.indices_len[0]],
|
||||
output,
|
||||
(self.output_dim, self.output_dim),
|
||||
)
|
||||
self.punica_wrapper.add_lora_packed_nslice(
|
||||
output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0,
|
||||
(self.output_dim, self.output_dim))
|
||||
return output
|
||||
|
||||
@classmethod
|
||||
@_not_fully_sharded_can_replace
|
||||
def can_replace_layer(cls, source_layer: nn.Module,
|
||||
lora_config: LoRAConfig, packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig]) -> bool:
|
||||
return type(source_layer) is MergedColumnParallelLinear and len(
|
||||
packed_modules_list) == 2
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
return (type(source_layer) is MergedColumnParallelLinear
|
||||
and len(packed_modules_list) == 2)
|
||||
|
||||
|
||||
class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
||||
"""
|
||||
ColumnParallelLinear layer that is specifically designed for
|
||||
qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
only contain a single LoRA within their qkv_proj layer.
ColumnParallelLinear layer that is specifically designed for
qkv_proj. Certain models, such as chatglm3 and baichuan-7b,
only contain a single LoRA within their qkv_proj layer.
|
||||
|
||||
During inference with Tensor Parallel, the weights of lora_b
|
||||
During inference with Tensor Parallel, the weights of lora_b
|
||||
must be accurately partitioned according to the respective ranks.
|
||||
|
||||
|
||||
Q slice may have different shape than K and V slices (which both have
|
||||
the same shape).
|
||||
"""
|
||||
|
@ -696,10 +600,11 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
|||
super().__init__(base_layer)
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None) -> None:
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None,
|
||||
) -> None:
|
||||
self.lora_config = lora_config
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
|
@ -767,11 +672,15 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
|||
),
|
||||
)
|
||||
|
||||
self.output_slices = (self.q_proj_shard_size, self.kv_proj_shard_size,
|
||||
self.kv_proj_shard_size)
|
||||
self.output_slices = (
|
||||
self.q_proj_shard_size,
|
||||
self.kv_proj_shard_size,
|
||||
self.kv_proj_shard_size,
|
||||
)
|
||||
self.packed_indices: Optional[torch.Tensor] = None
|
||||
self.standard_indices: Optional[torch.Tensor] = None
|
||||
# lazily initialized.
|
||||
self.indices: torch.Tensor
|
||||
self.indices_len: List[int]
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
|
@ -794,15 +703,15 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
|||
if lora_b[0] is not None:
|
||||
lora_b_q = lora_b[0][:, self.q_proj_shard_size *
|
||||
self.q_shard_id:self.q_proj_shard_size *
|
||||
(self.q_shard_id + 1)]
|
||||
(self.q_shard_id + 1), ]
|
||||
if lora_b[1] is not None:
|
||||
lora_b_k = lora_b[1][:, self.kv_proj_shard_size *
|
||||
self.kv_shard_id:self.kv_proj_shard_size *
|
||||
(self.kv_shard_id + 1)]
|
||||
(self.kv_shard_id + 1), ]
|
||||
if lora_b[2] is not None:
|
||||
lora_b_v = lora_b[2][:, self.kv_proj_shard_size *
|
||||
self.kv_shard_id:self.kv_proj_shard_size *
|
||||
(self.kv_shard_id + 1)]
|
||||
(self.kv_shard_id + 1), ]
|
||||
lora_b = [lora_b_q, lora_b_k, lora_b_v]
|
||||
return lora_b
|
||||
|
||||
|
@ -851,23 +760,23 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
|
|||
def apply(self, x: torch.Tensor,
|
||||
bias: Optional[torch.Tensor]) -> torch.Tensor:
|
||||
output = self.base_layer.quant_method.apply(self.base_layer, x, bias)
|
||||
_apply_lora_packed_nslice(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
self.lora_b_stacked,
|
||||
self.indices[:self.indices_len[0]],
|
||||
output,
|
||||
self.output_slices,
|
||||
)
|
||||
self.punica_wrapper.add_lora_packed_nslice(output, x,
|
||||
self.lora_a_stacked,
|
||||
self.lora_b_stacked, 1.0,
|
||||
self.output_slices)
|
||||
return output
|
||||
|
||||
@classmethod
|
||||
@_not_fully_sharded_can_replace
|
||||
def can_replace_layer(cls, source_layer: nn.Module,
|
||||
lora_config: LoRAConfig, packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig]) -> bool:
|
||||
return type(source_layer) is QKVParallelLinear and len(
|
||||
packed_modules_list) == 3
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
return (type(source_layer) is QKVParallelLinear
|
||||
and len(packed_modules_list) == 3)
|
||||
|
||||
|
||||
class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
||||
|
@ -880,10 +789,11 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|||
self.device = _get_lora_device(self.base_layer)
|
||||
|
||||
def create_lora_weights(
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None) -> None:
|
||||
self,
|
||||
max_loras: int,
|
||||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None,
|
||||
) -> None:
|
||||
self.lora_config = lora_config
|
||||
self.tp_rank = get_tensor_model_parallel_rank()
|
||||
self.lora_a_stacked = torch.zeros(
|
||||
|
@ -911,9 +821,6 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|||
dtype=lora_config.lora_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
# Lazily initialized
|
||||
self.indices: torch.Tensor
|
||||
self.indices_len: List[int]
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
self.lora_a_stacked[index] = 0
|
||||
|
@ -950,27 +857,10 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|||
0, :lora_b.shape[1], :lora_b.shape[0]].copy_(
|
||||
lora_b.T, non_blocking=True)
|
||||
|
||||
def set_mapping(
|
||||
self,
|
||||
base_indices: torch.Tensor,
|
||||
sampler_indices: torch.Tensor,
|
||||
sampler_indices_padded: torch.Tensor,
|
||||
embeddings_indices: torch.Tensor,
|
||||
long_lora_indices: torch.Tensor,
|
||||
indices_len: List[int],
|
||||
):
|
||||
self.indices = base_indices
|
||||
self.indices_len = indices_len
|
||||
|
||||
def apply(self, x: torch.Tensor) -> torch.Tensor:
|
||||
output = self.base_layer.quant_method.apply(self.base_layer, x)
|
||||
_apply_lora(
|
||||
x,
|
||||
self.lora_a_stacked,
|
||||
self.lora_b_stacked,
|
||||
self.indices[:self.indices_len[0]],
|
||||
output,
|
||||
)
|
||||
self.punica_wrapper.add_lora(output, x, self.lora_a_stacked,
|
||||
self.lora_b_stacked, 1.0)
|
||||
return output
|
||||
|
||||
def forward(self, input_):
|
||||
|
@ -1013,14 +903,18 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
|
|||
|
||||
@property
|
||||
def weight(self):
|
||||
return self.base_layer.weight if hasattr(
|
||||
self.base_layer, "weight") else self.base_layer.qweight
|
||||
return (self.base_layer.weight if hasattr(self.base_layer, "weight")
|
||||
else self.base_layer.qweight)
|
||||
|
||||
@classmethod
|
||||
@_not_fully_sharded_can_replace
|
||||
def can_replace_layer(cls, source_layer: nn.Module,
|
||||
lora_config: LoRAConfig, packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig]) -> bool:
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
return type(source_layer) is RowParallelLinear
|
||||
|
||||
|
||||
|
@ -1125,10 +1019,6 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
|
|||
dtype=torch.long)
|
||||
else:
|
||||
self.sharded_to_full_mapping_gpu = None
|
||||
# Lazily initialized.
|
||||
self.indices: torch.Tensor
|
||||
self.indices_len: List[int]
|
||||
self.indices_padded: torch.Tensor
|
||||
|
||||
def reset_lora(self, index: int):
|
||||
self.lora_a_stacked[index] = 0
|
||||
|
@ -1154,19 +1044,6 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
|
|||
index, :embeddings_tensor.shape[0], :embeddings_tensor.
|
||||
shape[1], ] = embeddings_tensor
|
||||
|
||||
def set_mapping(
|
||||
self,
|
||||
base_indices: torch.Tensor,
|
||||
sampler_indices: torch.Tensor,
|
||||
sampler_indices_padded: torch.Tensor,
|
||||
embeddings_indices: torch.Tensor,
|
||||
long_lora_indices: torch.Tensor,
|
||||
indices_len: List[int],
|
||||
):
|
||||
self.indices = sampler_indices
|
||||
self.indices_padded = sampler_indices_padded
|
||||
self.indices_len = indices_len
|
||||
|
||||
def _get_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
|
@ -1212,38 +1089,37 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
|
|||
out=lora_logits[:-1])
|
||||
lora_logits[-1] = float("-inf")
|
||||
lora_logits = lora_logits.mT
|
||||
indices_padded = self.punica_wrapper.sampler_indices_padded
|
||||
lora_logits = (lora_logits.reshape(
|
||||
lora_logits.shape[0] * lora_logits.shape[1],
|
||||
lora_logits.shape[2],
|
||||
).index_select(0,
|
||||
self.indices_padded[:self.indices_len[2]]).nan_to_num_(
|
||||
nan=float("-inf"),
|
||||
posinf=float("inf"),
|
||||
neginf=float("-inf")))
|
||||
).index_select(0, indices_padded).nan_to_num_(nan=float("-inf"),
|
||||
posinf=float("inf"),
|
||||
neginf=float("-inf")))
|
||||
logits[:,
|
||||
self.base_layer.org_vocab_size:self.base_layer.org_vocab_size +
|
||||
lora_logits.shape[1]] = lora_logits
|
||||
lora_logits.shape[1], ] = lora_logits
|
||||
|
||||
_apply_lora(
|
||||
hidden_states,
|
||||
self.lora_a_stacked,
|
||||
self.lora_b_stacked,
|
||||
self.indices[:self.indices_len[1]],
|
||||
logits,
|
||||
)
|
||||
# LogitsProcessorWithLoRA always uses bgmv
|
||||
self.punica_wrapper.add_lora_logits(logits, hidden_states,
|
||||
self.lora_a_stacked,
|
||||
self.lora_b_stacked, 1.0)
|
||||
|
||||
# Remove paddings in vocab (if any).
|
||||
logits = logits[:, :self.base_layer.vocab_size]
|
||||
|
||||
return logits
|
||||
|
||||
def forward(self, *args, **kwargs):
|
||||
return type(self.base_layer).forward(self, *args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(cls, source_layer: nn.Module,
|
||||
lora_config: LoRAConfig, packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig]) -> bool:
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
# Special handling for the LogitsProcessor.
|
||||
return False
|
||||
|
||||
|
@ -1259,9 +1135,6 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
|
|||
def __init__(self, base_layer: RotaryEmbedding) -> None:
|
||||
super().__init__()
|
||||
self.base_layer = base_layer
|
||||
# Lazily initialized
|
||||
self.long_lora_indices: torch.Tensor
|
||||
self.indices_len: List[int]
|
||||
|
||||
@property
|
||||
def scaling_factors(self):
|
||||
|
@ -1277,9 +1150,8 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
|
|||
lora_config: LoRAConfig,
|
||||
model_config: Optional[PretrainedConfig] = None,
|
||||
) -> None:
|
||||
scaling_factors = list(
|
||||
lora_config.long_lora_scaling_factors
|
||||
) if lora_config.long_lora_scaling_factors else []
|
||||
scaling_factors = (list(lora_config.long_lora_scaling_factors)
|
||||
if lora_config.long_lora_scaling_factors else [])
|
||||
base_scaling_factor = (self.base_layer.scaling_factor if isinstance(
|
||||
self.base_layer, LinearScalingRotaryEmbedding) else 1.0)
|
||||
scaling_factors = sorted(
|
||||
|
@ -1306,18 +1178,6 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
|
|||
):
|
||||
...
|
||||
|
||||
def set_mapping(
|
||||
self,
|
||||
base_indices: torch.Tensor,
|
||||
sampler_indices: torch.Tensor,
|
||||
sampler_indices_padded: torch.Tensor,
|
||||
embeddings_indices: torch.Tensor,
|
||||
long_lora_indices: torch.Tensor,
|
||||
indices_len: List[int],
|
||||
):
|
||||
self.long_lora_indices = long_lora_indices
|
||||
self.indices_len = indices_len
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
|
@ -1328,19 +1188,24 @@ class LinearScalingRotaryEmbeddingWithLora(BaseLayerWithLoRA):
|
|||
positions,
|
||||
query,
|
||||
key,
|
||||
offsets=self.long_lora_indices[:self.indices_len[4]])
|
||||
offsets=self.punica_wrapper.long_lora_indices,
|
||||
)
|
||||
|
||||
@property
|
||||
def scaling_factor_to_offset(self) -> Dict[float, int]:
|
||||
return self.base_layer.scaling_factor_to_offset
|
||||
|
||||
@classmethod
|
||||
def can_replace_layer(cls, source_layer: nn.Module,
|
||||
lora_config: LoRAConfig, packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig]) -> bool:
|
||||
def can_replace_layer(
|
||||
cls,
|
||||
source_layer: nn.Module,
|
||||
lora_config: LoRAConfig,
|
||||
packed_modules_list: List,
|
||||
model_config: Optional[PretrainedConfig],
|
||||
) -> bool:
|
||||
"""Returns True if the layer can be replaced by this LoRA layer."""
|
||||
return type(source_layer) is LinearScalingRotaryEmbedding or type(
|
||||
source_layer) is RotaryEmbedding
|
||||
return (type(source_layer) is LinearScalingRotaryEmbedding
|
||||
or type(source_layer) is RotaryEmbedding)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return self.base_layer.extra_repr()
|
||||
|
|
|
@ -4,7 +4,7 @@ import math
|
|||
import os
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Type
|
||||
|
||||
import safetensors.torch
|
||||
import torch
|
||||
|
@ -21,6 +21,7 @@ from vllm.lora.layers import (BaseLayerWithLoRA,
|
|||
LinearScalingRotaryEmbeddingWithLora,
|
||||
LoRAMapping)
|
||||
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
|
||||
from vllm.lora.punica import PunicaWrapper
|
||||
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
|
||||
parse_fine_tuned_lora_name, replace_submodule)
|
||||
from vllm.model_executor.models.interfaces import SupportsLoRA
|
||||
|
@ -43,115 +44,6 @@ class LongContextLoRAContext:
|
|||
offsets_by_lora_id: Dict[int, int] = field(default_factory=dict)
|
||||
|
||||
|
||||
def convert_mapping(
|
||||
mapping: LoRAMapping,
|
||||
lora_index_to_id: List[Optional[int]],
|
||||
max_loras: int,
|
||||
vocab_size: int,
|
||||
extra_vocab_size: int,
|
||||
long_lora_context: Optional[LongContextLoRAContext] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
|
||||
Optional[torch.Tensor], List[int]]:
|
||||
"""Converts LoRAMapping to index tensors.
|
||||
|
||||
Args:
|
||||
mapping: LoRAMapping mapping rows in a batch to LoRA ids.
|
||||
lora_index_to_id: List mapping LoRA ids to LoRA indices.
|
||||
max_loras: Maximum number of LoRAs.
|
||||
vocab_size: Model vocab size.
|
||||
extra_vocab_size: Extra vocab size each LoRA can have.
|
||||
long_lora_context: Passed if there are long context lora in a batch.
|
||||
|
||||
Returns:
|
||||
A tuple of tensors:
|
||||
base_indices: Tensor of shape [batch_size] mapping batch rows to
|
||||
LoRA indices.
|
||||
sampler_indices: Tensor of shape [batch_size] mapping requests to
|
||||
LoRA indices for sampler. For generation, this will be the
|
||||
same as base_indicies. For prefill, this will map requests
|
||||
to LoRA indices.
|
||||
sampler_indices_padded: Tensor of shape [batch_size] mapping
|
||||
requests to LoRA indices for sampler with padding.
|
||||
Same as sampler_indicies, but -1 is replaced with
|
||||
max_loras.
|
||||
embeddings_indices: Tensor of shape [2, batch_size] mapping
|
||||
requests to embedding indices. First row is for embeddings
|
||||
added by the LoRAs, second row is for the LoRA.lora_a
|
||||
embeddings.
|
||||
long_lora_indices: Tensor of shape [batch_size] mapping
|
||||
requests to RoPE offsets and rot dims for long LoRAs.
|
||||
None if long context lora doesn't exist.
|
||||
indices_len: List of lengths of the above tensors.
|
||||
Used to index into each tensor. It contains length for
|
||||
(base_indices, sampler_indices, sampler_indices_padded,
|
||||
embeddings_indices, long_lora_indices). If long_lora doesn't
|
||||
exist, it only contains first 4 entries.
|
||||
"""
|
||||
index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
|
||||
embedding_indices = index_mapping_indices.copy()
|
||||
lora_indices = index_mapping_indices.copy()
|
||||
long_lora_offsets: Optional[torch.Tensor] = None
|
||||
if long_lora_context:
|
||||
long_lora_offsets = torch.zeros(len(index_mapping_indices),
|
||||
device="cuda",
|
||||
dtype=torch.long)
|
||||
prompt_mapping: List[int] = [
|
||||
lora_index_to_id.index(x) if x > 0 else -1
|
||||
for x in mapping.prompt_mapping
|
||||
]
|
||||
lora_idx = None
|
||||
for i in range(len(index_mapping_indices)):
|
||||
# TODO index can be slow. optimize
|
||||
lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
|
||||
if index_mapping_indices[i] > 0 else -1)
|
||||
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
|
||||
lora_indices[i] = lora_idx
|
||||
if long_lora_context:
|
||||
assert long_lora_offsets is not None
|
||||
lora_offset: int = long_lora_context.offsets_by_lora_id.get(
|
||||
index_mapping_indices[i], 0)
|
||||
long_lora_offsets[i] = lora_offset
|
||||
|
||||
indices_list: List[Union[List[int], torch.Tensor]] = [
|
||||
index_mapping_indices, lora_indices, embedding_indices
|
||||
]
|
||||
if long_lora_context:
|
||||
assert long_lora_offsets is not None
|
||||
indices_list.append(long_lora_offsets)
|
||||
indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
|
||||
prompt_mapping_tensor = torch.tensor(prompt_mapping,
|
||||
device="cuda",
|
||||
dtype=torch.long)
|
||||
embeddings_indices = torch.stack([
|
||||
indices[2] * extra_vocab_size,
|
||||
indices[2] * (vocab_size + extra_vocab_size)
|
||||
])
|
||||
embeddings_indices[embeddings_indices == -1] = max_loras - 1
|
||||
base_indices = indices[1]
|
||||
sampler_indices = prompt_mapping_tensor
|
||||
sampler_indices_padded = sampler_indices.clone()
|
||||
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
|
||||
sampler_indices_padded = (
|
||||
torch.arange(
|
||||
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) +
|
||||
(sampler_indices_padded * len(sampler_indices_padded)))
|
||||
long_lora_indices = None
|
||||
long_lora_indices_len: Optional[int] = None
|
||||
if long_lora_context:
|
||||
long_lora_indices = indices[3]
|
||||
long_lora_indices_len = long_lora_indices.shape[-1]
|
||||
# Contain length of indices tensors. Used to index into each tensor.
|
||||
indices_len = [
|
||||
base_indices.shape[-1], sampler_indices.shape[-1],
|
||||
sampler_indices_padded.shape[-1], embeddings_indices.shape[-1]
|
||||
]
|
||||
if long_lora_indices_len is not None:
|
||||
indices_len.append(long_lora_indices_len)
|
||||
|
||||
return (base_indices, sampler_indices, sampler_indices_padded,
|
||||
embeddings_indices, long_lora_indices, indices_len)
|
||||
|
||||
|
||||
def get_lora_id():
|
||||
global _GLOBAL_LORA_ID
|
||||
_GLOBAL_LORA_ID += 1
|
||||
|
@ -422,29 +314,12 @@ class LoRAModelManager(AdapterModelManager):
|
|||
self.lora_index_to_id: List[Optional[int]] = [None] * self.lora_slots
|
||||
self.vocab_size = vocab_size
|
||||
self.long_lora_context: Optional[LongContextLoRAContext] = None
|
||||
self.base_indices = torch.empty(self.max_num_batched_tokens,
|
||||
dtype=torch.long,
|
||||
device="cuda")
|
||||
self.sampler_indices = torch.empty(self.max_num_batched_tokens,
|
||||
dtype=torch.long,
|
||||
device="cuda")
|
||||
self.sampler_indices_padded = torch.empty(self.max_num_batched_tokens,
|
||||
dtype=torch.long,
|
||||
device="cuda")
|
||||
self.embeddings_indices = torch.empty(2,
|
||||
self.max_num_batched_tokens,
|
||||
dtype=torch.long,
|
||||
device="cuda")
|
||||
self.long_lora_indices = torch.empty(self.max_num_batched_tokens,
|
||||
dtype=torch.long,
|
||||
device="cuda")
|
||||
self.punica_wrapper = PunicaWrapper(max_num_batched_tokens,
|
||||
max_batches=self.max_num_seqs,
|
||||
device="cuda")
|
||||
# Scaling factor -> offset to the sin_cos_cache to it.
|
||||
# Used for long context lora.
|
||||
self.scaling_factor_to_offset: Dict[float, int] = {}
|
||||
# 4 is the number of indicies tensors defined above
|
||||
# base_indices, sampler_indices, sampler_indices_padded,
|
||||
# embeddings_indices
|
||||
self.indices_len: List[Optional[int]] = [None] * 4
|
||||
super().__init__(model)
|
||||
if hasattr(self.model, "supported_lora_modules"):
|
||||
self.supported_lora_modules = copy.deepcopy(
|
||||
|
@ -536,28 +411,16 @@ class LoRAModelManager(AdapterModelManager):
|
|||
"Pinning is not supported in LoRAModelManager."
|
||||
"Use LRUCacheLoRAModelManager for pinning") # type: ignore
|
||||
|
||||
# TODO see if this can be vectorized
|
||||
def _set_adapter_mapping(self, mapping: LoRAMapping) -> None:
|
||||
(base_indices, sampler_indices, sampler_indices_padded,
|
||||
embeddings_indices, long_lora_offsets_tensor,
|
||||
indices_len) = convert_mapping(mapping, self.lora_index_to_id,
|
||||
self.lora_slots + 1, self.vocab_size,
|
||||
self.lora_config.lora_extra_vocab_size,
|
||||
self.long_lora_context)
|
||||
self.base_indices[:base_indices.shape[0]].copy_(base_indices)
|
||||
self.sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
|
||||
self.sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
|
||||
sampler_indices_padded)
|
||||
self.embeddings_indices[:embeddings_indices.
|
||||
shape[0], :embeddings_indices.shape[1]].copy_(
|
||||
embeddings_indices)
|
||||
if long_lora_offsets_tensor is not None:
|
||||
self.long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
|
||||
long_lora_offsets_tensor)
|
||||
else:
|
||||
self.long_lora_indices.zero_()
|
||||
# Maintain the reference
|
||||
self.indices_len[:] = indices_len
|
||||
# update lora states
|
||||
self.punica_wrapper.update_metadata(
|
||||
mapping,
|
||||
self.lora_index_to_id,
|
||||
self.lora_slots + 1,
|
||||
self.vocab_size,
|
||||
self.lora_config.lora_extra_vocab_size,
|
||||
self.long_lora_context,
|
||||
)
|
||||
|
||||
def remove_all_adapters(self):
|
||||
"""Remove all LoRAModels from the manager."""
|
||||
|
@ -595,10 +458,8 @@ class LoRAModelManager(AdapterModelManager):
|
|||
self.model.config))
|
||||
self.register_module(module_name, new_module)
|
||||
self._register_packed_modules(module_name)
|
||||
new_module.set_mapping(self.base_indices, self.sampler_indices,
|
||||
self.sampler_indices_padded,
|
||||
self.embeddings_indices,
|
||||
self.long_lora_indices, self.indices_len)
|
||||
# All lora layers share the same punica_wrapper based on reference.
|
||||
new_module.set_mapping(self.punica_wrapper)
|
||||
|
||||
def register_module(self, module_name: str, module: "BaseLayerWithLoRA"):
|
||||
assert isinstance(module, BaseLayerWithLoRA)
|
||||
|
|
|
@ -0,0 +1,169 @@
|
|||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .utils import get_lora_op_configs
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bgmv_expand_kernel(
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
lora_indices,
|
||||
xm_stride,
|
||||
xk_stride,
|
||||
l0_stride,
|
||||
lora_k_stride,
|
||||
lora_n_stride,
|
||||
cm_stride,
|
||||
cn_stride,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
SPLIT_N: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
ADD_INPUTS: tl.constexpr,
|
||||
CAST_TYPE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
GroupGEMV; additionally, introducing SPLIT_N can improve performance for
|
||||
large hidden_size.
|
||||
"""
|
||||
pid_sn = tl.program_id(axis=0)
|
||||
cur_batch = tl.program_id(axis=1)
|
||||
lora_index = tl.load(lora_indices + cur_batch)
|
||||
if lora_index == -1:
|
||||
return
|
||||
offset_k = tl.arange(0, BLOCK_K)
|
||||
offset_n = tl.arange(0, BLOCK_N)
|
||||
if EVEN_K:
|
||||
tiled_a = tl.load(input_ptr + cur_batch * xm_stride +
|
||||
offset_k * xk_stride, ) # [BLOCK_K]
|
||||
else:
|
||||
tiled_a = tl.load(
|
||||
input_ptr + cur_batch * xm_stride + offset_k * xk_stride,
|
||||
mask=offset_k < K,
|
||||
other=0,
|
||||
) # [BLOCK_K]
|
||||
# N must be divisible by SPLIT_N
|
||||
split_n_length = tl.cdiv(N, SPLIT_N)
|
||||
if CAST_TYPE:
|
||||
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
|
||||
# sliding to next row-block
|
||||
b_ptr = (lora_ptr + l0_stride * lora_index +
|
||||
pid_sn * split_n_length * lora_k_stride)
|
||||
c_ptr = out_ptr + cur_batch * cm_stride + pid_sn * split_n_length
|
||||
for n in range(0, split_n_length, BLOCK_N):
|
||||
current_n = n + offset_n
|
||||
current_n_c = tl.max_contiguous(current_n, BLOCK_N)
|
||||
b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :]
|
||||
< K)
|
||||
c_mask = current_n < split_n_length
|
||||
tiled_b = tl.load(
|
||||
b_ptr + current_n_c[:, None] * lora_k_stride +
|
||||
offset_k[None, :] * lora_n_stride,
|
||||
mask=b_ptr_mask,
|
||||
other=0.0,
|
||||
) # [BLOCK_N,BLOCK_K]
|
||||
if ADD_INPUTS:
|
||||
tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask)
|
||||
accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
|
||||
else:
|
||||
accumulator = tl.sum(tiled_a * tiled_b, 1)
|
||||
|
||||
tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def bgmv_expand(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
add_inputs: bool = True,
|
||||
override_config: Optional[Dict[str, int]] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): input tensor
|
||||
lora_b_weights (torch.Tensor): lora_b weight
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
||||
corresponding to each batch. An index of -1 means no lora should be
|
||||
applied.
|
||||
|
||||
add_inputs (bool, optional): Defaults to True. Adds the final lora
|
||||
results to the output.
|
||||
override_config (Optional[Dict[str, int]], optional): Defaults to None.
|
||||
Triton grid config
|
||||
"""
|
||||
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
||||
assert lora_b_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
assert inputs.size(1) == lora_b_weights.size(-1)
|
||||
|
||||
assert inputs.is_contiguous()
|
||||
assert output_tensor.is_contiguous()
|
||||
|
||||
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
|
||||
assert lora_b_weights.size(1) == 1
|
||||
lora_b_weights = lora_b_weights.squeeze(dim=1)
|
||||
else:
|
||||
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
|
||||
assert lora_b_weights.is_contiguous()
|
||||
|
||||
# TODO tuning this config
|
||||
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
|
||||
BLOCK_K = triton.next_power_of_2(K)
|
||||
EVEN_K = K % BLOCK_K == 0
|
||||
ADD_INPUTS = add_inputs
|
||||
CAST_TYPE = False
|
||||
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]:
|
||||
CAST_TYPE = True
|
||||
batches = lora_indices_tensor.size(0)
|
||||
if override_config:
|
||||
config = override_config
|
||||
else:
|
||||
config = get_lora_op_configs("expand", batches, N)
|
||||
grid = lambda META: (
|
||||
META["SPLIT_N"],
|
||||
batches,
|
||||
)
|
||||
_bgmv_expand_kernel[grid](
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
output_tensor,
|
||||
N,
|
||||
K,
|
||||
lora_indices_tensor,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
lora_b_weights.stride(0),
|
||||
lora_b_weights.stride(1),
|
||||
lora_b_weights.stride(2),
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
BLOCK_K=BLOCK_K,
|
||||
EVEN_K=EVEN_K,
|
||||
ADD_INPUTS=ADD_INPUTS,
|
||||
CAST_TYPE=CAST_TYPE,
|
||||
**config,
|
||||
)
|
||||
return
|
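# Illustrative usage sketch, not part of this commit; names below are
# hypothetical and the shapes are assumptions inferred from the assertions
# above: `inputs` is (num_tokens, rank), `lora_b_weights` is
# (num_loras, hidden_size, rank), `output_tensor` is
# (num_tokens, hidden_size), and `lora_indices_tensor` holds one LoRA slot
# index per token, with -1 meaning "no LoRA".
def _example_bgmv_expand():
    num_tokens, rank, hidden_size, num_loras = 8, 16, 4096, 4
    x = torch.randn(num_tokens, rank, dtype=torch.float16, device="cuda")
    lora_b = torch.randn(num_loras, hidden_size, rank,
                         dtype=torch.float16, device="cuda")
    out = torch.zeros(num_tokens, hidden_size,
                      dtype=torch.float16, device="cuda")
    idx = torch.tensor([0, 0, 1, -1, 2, 2, 3, -1],
                       dtype=torch.long, device="cuda")
    # Accumulates x @ lora_b[idx].T into `out`; rows with idx == -1 are left
    # untouched by the kernel.
    bgmv_expand(x, lora_b, out, idx, add_inputs=True)
    return out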
|
@ -0,0 +1,182 @@
|
|||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .utils import get_lora_op_configs
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bgmv_expand_slice_kernel(
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
lora_indices,
|
||||
xm_stride,
|
||||
xk_stride,
|
||||
l0_stride,
|
||||
lora_k_stride,
|
||||
lora_n_stride,
|
||||
cm_stride,
|
||||
cn_stride,
|
||||
slice_offset,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
SPLIT_N: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
ADD_INPUTS: tl.constexpr,
|
||||
CAST_TYPE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
GroupGEMV; additionally, introducing SPLIT_N can improve performance for
|
||||
large hidden_size.
|
||||
"""
|
||||
pid_sn = tl.program_id(axis=0)
|
||||
cur_batch = tl.program_id(axis=1)
|
||||
lora_index = tl.load(lora_indices + cur_batch)
|
||||
if lora_index == -1:
|
||||
return
|
||||
offset_k = tl.arange(0, BLOCK_K)
|
||||
offset_n = tl.arange(0, BLOCK_N)
|
||||
if EVEN_K:
|
||||
tiled_a = tl.load(input_ptr + cur_batch * xm_stride +
|
||||
offset_k * xk_stride, ) # [BLOCK_K]
|
||||
else:
|
||||
tiled_a = tl.load(
|
||||
input_ptr + cur_batch * xm_stride + offset_k * xk_stride,
|
||||
mask=offset_k < K,
|
||||
other=0,
|
||||
) # [BLOCK_K]
|
||||
# N must be divisible by SPLIT_N
|
||||
split_n_length = tl.cdiv(N, SPLIT_N)
|
||||
if CAST_TYPE:
|
||||
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
|
||||
# sliding to next row-block
|
||||
b_ptr = (lora_ptr + l0_stride * lora_index +
|
||||
pid_sn * split_n_length * lora_k_stride)
|
||||
c_ptr = (out_ptr + cur_batch * cm_stride + pid_sn * split_n_length +
|
||||
slice_offset * cn_stride)
|
||||
|
||||
for n in range(0, split_n_length, BLOCK_N):
|
||||
current_n = n + offset_n
|
||||
b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :]
|
||||
< K)
|
||||
c_mask = current_n < split_n_length
|
||||
tiled_b = tl.load(
|
||||
b_ptr + current_n[:, None] * lora_k_stride +
|
||||
offset_k[None, :] * lora_n_stride,
|
||||
mask=b_ptr_mask,
|
||||
other=0.0,
|
||||
) # [BLOCK_N,BLOCK_K]
|
||||
|
||||
if ADD_INPUTS:
|
||||
tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask)
|
||||
accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out
|
||||
else:
|
||||
accumulator = tl.sum(tiled_a * tiled_b, 1)
|
||||
|
||||
tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def bgmv_expand_slice(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = True,
|
||||
override_config: Optional[Dict[str, int]] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): input tensor
|
||||
lora_b_weights (torch.Tensor): lora_b weight
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
||||
corresponding to each batch. An index of -1 means no lora should be
|
||||
applied.
|
||||
slice_offset (int): output_tensor's offset
|
||||
slice_size (int): current output_tensor's size
|
||||
|
||||
add_inputs (bool, optional): Defaults to True. Adds the final lora results to the output.
|
||||
override_config (Optional[Dict[str, int]], optional): Defaults to None.
|
||||
Triton grid config
|
||||
"""
|
||||
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
||||
assert lora_b_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
assert inputs.size(1) == lora_b_weights.size(-1)
|
||||
|
||||
assert slice_size == lora_b_weights.size(-2)
|
||||
assert inputs.is_contiguous()
|
||||
assert output_tensor.is_contiguous()
|
||||
|
||||
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
|
||||
assert lora_b_weights.size(1) == 1
|
||||
lora_b_weights = lora_b_weights.squeeze(dim=1)
|
||||
else:
|
||||
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
|
||||
|
||||
assert lora_b_weights.is_contiguous()
|
||||
|
||||
# TODO tuning this config
|
||||
|
||||
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
|
||||
BLOCK_K = triton.next_power_of_2(K)
|
||||
EVEN_K = K % BLOCK_K == 0
|
||||
ADD_INPUTS = add_inputs
|
||||
CAST_TYPE = False
|
||||
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]:
|
||||
CAST_TYPE = True
|
||||
|
||||
batches = lora_indices_tensor.size(0)
|
||||
|
||||
if override_config:
|
||||
config = override_config
|
||||
else:
|
||||
config = get_lora_op_configs("expand", batches, N)
|
||||
|
||||
grid = lambda META: (
|
||||
META["SPLIT_N"],
|
||||
batches,
|
||||
)
|
||||
_bgmv_expand_slice_kernel[grid](
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
output_tensor,
|
||||
N,
|
||||
K,
|
||||
lora_indices_tensor,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
lora_b_weights.stride(0),
|
||||
lora_b_weights.stride(1),
|
||||
lora_b_weights.stride(2),
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
slice_offset,
|
||||
BLOCK_K=BLOCK_K,
|
||||
EVEN_K=EVEN_K,
|
||||
ADD_INPUTS=ADD_INPUTS,
|
||||
CAST_TYPE=CAST_TYPE,
|
||||
**config,
|
||||
)
|
||||
return
|
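# Illustrative sketch, not part of this commit: how two LoRA-B slices could
# be expanded into disjoint column ranges of one packed output (e.g. a merged
# projection). All names are hypothetical.
def _example_bgmv_expand_slice():
    num_tokens, rank, out_size, num_loras = 4, 8, 1024, 2
    x = torch.randn(num_tokens, rank, dtype=torch.float16, device="cuda")
    lora_b_0 = torch.randn(num_loras, out_size, rank,
                           dtype=torch.float16, device="cuda")
    lora_b_1 = torch.randn(num_loras, out_size, rank,
                           dtype=torch.float16, device="cuda")
    packed_out = torch.zeros(num_tokens, 2 * out_size,
                             dtype=torch.float16, device="cuda")
    idx = torch.tensor([0, 1, -1, 0], dtype=torch.long, device="cuda")
    # The first slice occupies columns [0, out_size), the second
    # [out_size, 2 * out_size); slice_size must equal lora_b.size(-2).
    bgmv_expand_slice(x, lora_b_0, packed_out, idx,
                      slice_offset=0, slice_size=out_size, add_inputs=True)
    bgmv_expand_slice(x, lora_b_1, packed_out, idx,
                      slice_offset=out_size, slice_size=out_size,
                      add_inputs=True)
    return packed_out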
|
@ -0,0 +1,150 @@
|
|||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from .utils import get_lora_op_configs
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _bgmv_shrink_kernel(
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
lora_indices,
|
||||
scaling,
|
||||
xm_stride,
|
||||
xk_stride,
|
||||
l0_stride,
|
||||
lora_k_stride,
|
||||
lora_n_stride,
|
||||
cm_stride,
|
||||
cn_stride,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
SPLIT_K: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
GroupGEMV; additionally, introducing SPLIT-K can improve performance for
|
||||
large hidden_size.
|
||||
"""
|
||||
pid_sk = tl.program_id(axis=0)
|
||||
cur_batch = tl.program_id(axis=1)
|
||||
lora_index = tl.load(lora_indices + cur_batch)
|
||||
if lora_index == -1:
|
||||
return
|
||||
|
||||
offset_n = tl.arange(0, BLOCK_N)
|
||||
offset_k = tl.arange(0, BLOCK_K) + pid_sk * BLOCK_K
|
||||
a_ptr = input_ptr + cur_batch * xm_stride
|
||||
b_ptr = lora_ptr + l0_stride * lora_index
|
||||
accumulator = tl.zeros((BLOCK_N, ), dtype=tl.float32)
|
||||
for k in range(0, K, BLOCK_K * SPLIT_K):
|
||||
current_k = k + offset_k
|
||||
current_k_c = tl.max_contiguous(current_k, BLOCK_K)
|
||||
tiled_a = tl.load(
|
||||
a_ptr + current_k_c,
|
||||
mask=current_k < K,
|
||||
other=0.0,
|
||||
) # [BLOCK_K]
|
||||
b_ptr_mask = (offset_n[:, None] < N) & (current_k[None, :] < K)
|
||||
|
||||
tiled_b = tl.load(
|
||||
b_ptr + offset_n[:, None] * lora_k_stride +
|
||||
current_k[None, :] * lora_n_stride,
|
||||
mask=b_ptr_mask,
|
||||
other=0.0,
|
||||
) # [BLOCK_N,BLOCK_K]
|
||||
|
||||
accumulator += tl.sum(tiled_a * tiled_b, 1)
|
||||
accumulator *= scaling
|
||||
offset_cn = tl.arange(0, BLOCK_N)
|
||||
c_ptr = out_ptr + cur_batch * cm_stride + offset_cn * cn_stride
|
||||
c_mask = offset_cn < N
|
||||
if SPLIT_K == 1:
|
||||
tl.store(c_ptr, accumulator, mask=c_mask)
|
||||
else:
|
||||
tl.atomic_add(c_ptr, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def bgmv_shrink(
|
||||
inputs: torch.Tensor,
|
||||
lora_a_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
scaling: float = 1.0,
|
||||
override_config: Optional[Dict[str, int]] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): input tensor
|
||||
lora_a_weights (torch.Tensor): lora_a weight
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
||||
corresponding to each batch. An index of -1 means no lora should be
|
||||
applied.
|
||||
|
||||
scaling (float): Scaling factor.
|
||||
override_config (Optional[Dict[str, int]], optional): Defaults to None.
|
||||
Triton grid config
|
||||
"""
|
||||
assert inputs.dtype == lora_a_weights.dtype
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16]
|
||||
assert lora_a_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
assert inputs.size(1) == lora_a_weights.size(-1)
|
||||
assert inputs.is_contiguous()
|
||||
|
||||
if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)
|
||||
assert lora_a_weights.size(1) == 1
|
||||
lora_a_weights = lora_a_weights.squeeze(dim=1)
|
||||
else:
|
||||
assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)
|
||||
assert lora_a_weights.is_contiguous()
|
||||
assert output_tensor.is_contiguous()
|
||||
# TODO tuning this config
|
||||
batches = lora_indices_tensor.size(0)
|
||||
N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank
|
||||
BLOCK_N = triton.next_power_of_2(N)
|
||||
if override_config:
|
||||
config = override_config
|
||||
else:
|
||||
# First try to load optimal config from the file
|
||||
config = get_lora_op_configs("bgmv_shrink", batches, K)
|
||||
|
||||
grid = lambda META: (
|
||||
META["SPLIT_K"],
|
||||
batches,
|
||||
)
|
||||
_bgmv_shrink_kernel[grid](
|
||||
inputs,
|
||||
lora_a_weights,
|
||||
output_tensor,
|
||||
N,
|
||||
K,
|
||||
lora_indices_tensor,
|
||||
scaling,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
lora_a_weights.stride(0),
|
||||
lora_a_weights.stride(1),
|
||||
lora_a_weights.stride(2),
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
BLOCK_N=BLOCK_N,
|
||||
**config,
|
||||
)
|
||||
return
|
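# Illustrative sketch, not part of this commit: a full bgmv LoRA update is
# typically a shrink followed by an expand, with a float32 intermediate
# buffer for accuracy. All names are hypothetical.
def _example_bgmv_shrink_then_expand():
    from vllm.lora.ops.bgmv_expand import bgmv_expand
    num_tokens, hidden_size, rank, num_loras = 4, 4096, 16, 2
    x = torch.randn(num_tokens, hidden_size,
                    dtype=torch.float16, device="cuda")
    lora_a = torch.randn(num_loras, rank, hidden_size,
                         dtype=torch.float16, device="cuda")
    lora_b = torch.randn(num_loras, hidden_size, rank,
                         dtype=torch.float16, device="cuda")
    y = torch.zeros(num_tokens, hidden_size,
                    dtype=torch.float16, device="cuda")
    idx = torch.tensor([0, 1, -1, 0], dtype=torch.long, device="cuda")
    # Zero-initialized because the SPLIT_K path writes back via atomic adds.
    buffer = torch.zeros(num_tokens, rank,
                         dtype=torch.float32, device="cuda")
    bgmv_shrink(x, lora_a, buffer, idx, scaling=0.5)  # buffer = 0.5 * x @ A.T
    bgmv_expand(buffer, lora_b, y, idx, add_inputs=True)  # y += buffer @ B.T
    return y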
|
@ -0,0 +1,192 @@
|
|||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.triton_utils import libentry
|
||||
|
||||
|
||||
@libentry()
|
||||
@triton.jit
|
||||
def _sgmv_expand_kernel(
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
b_seq_start_loc,
|
||||
seq_lens,
|
||||
lora_indices,
|
||||
xm_stride,
|
||||
xk_stride, # 1
|
||||
l0_stride, # hidden_size*max_rank
|
||||
lora_k_stride,
|
||||
lora_n_stride,
|
||||
cm_stride,
|
||||
cn_stride,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
ADD_INPUTS: tl.constexpr,
|
||||
CAST_TYPE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
The sgmv's expand triton kernel is based on GroupGEMM.
|
||||
"""
|
||||
pid = tl.program_id(axis=0)
|
||||
cur_batch = tl.program_id(axis=1)
|
||||
cta_n_num = tl.cdiv(N, BLOCK_N)
|
||||
pid_m = pid // cta_n_num
|
||||
pid_n = pid % cta_n_num
|
||||
M = tl.load(seq_lens + cur_batch)
|
||||
if pid_m * BLOCK_M > M:
|
||||
return
|
||||
lora_index = tl.load(lora_indices + cur_batch)
|
||||
if lora_index == -1:
|
||||
return
|
||||
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
|
||||
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
||||
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||
offset_k = tl.arange(0, BLOCK_K)
|
||||
ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
|
||||
rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
|
||||
|
||||
a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +
|
||||
offset_k[None, :] * xk_stride, )
|
||||
b_ptr = (lora_ptr + l0_stride * lora_index +
|
||||
offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride)
|
||||
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||
for k in range(tl.cdiv(K, BLOCK_K)):
|
||||
if EVEN_K:
|
||||
tiled_a = tl.load(a_ptr)
|
||||
tiled_b = tl.load(b_ptr)
|
||||
else:
|
||||
tiled_a = tl.load(a_ptr,
|
||||
mask=offset_k[None, :] < K - k * BLOCK_K,
|
||||
other=0)
|
||||
tiled_b = tl.load(b_ptr,
|
||||
mask=offset_k[:, None] < K - k * BLOCK_K,
|
||||
other=0)
|
||||
if CAST_TYPE:
|
||||
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
|
||||
accumulator += tl.dot(
|
||||
tiled_a,
|
||||
tiled_b,
|
||||
)
|
||||
a_ptr += BLOCK_K * xk_stride
|
||||
b_ptr += BLOCK_K * lora_n_stride
|
||||
tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
|
||||
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
||||
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
|
||||
offset_cn[None, :] * cn_stride)
|
||||
M = tl.load(seq_lens + cur_batch)
|
||||
c_mask = (offset_cm[:, None] <
|
||||
(cur_seq_start + M)) & (offset_cn[None, :] < N)
|
||||
if ADD_INPUTS:
|
||||
tiled_out = tl.load(c_ptr, mask=c_mask)
|
||||
tiled_c += tiled_out
|
||||
tl.store(c_ptr, tiled_c, mask=c_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def sgmv_expand(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
add_inputs: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
inputs (torch.Tensor): input tensor
|
||||
lora_b_weights (torch.Tensor): lora_b weight
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
|
||||
sequence lengths of the sequences in the batch, used to index
|
||||
into the sequence. E.g., if the sequence lengths are [4, 6], it is
|
||||
[0, 4].
|
||||
seq_len_tensor (torch.Tensor): (batch_size,). Records the sequence
|
||||
length of the sequences in the batch
|
||||
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
||||
corresponding to each batch. An index of -1 means no lora should be
|
||||
applied.
|
||||
batches (int): batch size
|
||||
max_seq_length (int): The max sequence length of the sequences
|
||||
in the batch
|
||||
add_inputs (bool, optional): Defaults to False. Adds the final lora
|
||||
results to the output.
|
||||
"""
|
||||
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
||||
assert lora_b_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
assert inputs.size(1) == lora_b_weights.size(-1)
|
||||
assert b_seq_start_loc.size(0) == batches
|
||||
assert lora_indices_tensor.size(0) == batches
|
||||
assert inputs.is_contiguous()
|
||||
assert output_tensor.is_contiguous()
|
||||
|
||||
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
|
||||
assert lora_b_weights.size(1) == 1
|
||||
lora_b_weights = lora_b_weights.squeeze(dim=1)
|
||||
else:
|
||||
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
|
||||
|
||||
assert lora_b_weights.is_contiguous()
|
||||
|
||||
# TODO tuning this config
|
||||
|
||||
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
|
||||
BLOCK_M = 32
|
||||
BLOCK_N = 32
|
||||
BLOCK_K = 16
|
||||
EVEN_K = K % BLOCK_K == 0
|
||||
ADD_INPUTS = add_inputs
|
||||
CAST_TYPE = False
|
||||
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]:
|
||||
CAST_TYPE = True
|
||||
grid = (
|
||||
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
|
||||
batches,
|
||||
)
|
||||
_sgmv_expand_kernel[grid](
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
output_tensor,
|
||||
N,
|
||||
K,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
lora_b_weights.stride(0),
|
||||
lora_b_weights.stride(1),
|
||||
lora_b_weights.stride(2),
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
EVEN_K,
|
||||
ADD_INPUTS,
|
||||
CAST_TYPE,
|
||||
)
|
||||
return
|
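# Illustrative sketch, not part of this commit: how the sgmv metadata in the
# docstring above could be built from per-request prompt lengths. All names
# are hypothetical.
def _example_sgmv_metadata():
    seq_lens = torch.tensor([4, 6], dtype=torch.long, device="cuda")
    # Cumulative start offsets of each request inside the flattened token
    # dimension: [0, 4] for lengths [4, 6].
    b_seq_start_loc = torch.zeros_like(seq_lens)
    b_seq_start_loc[1:] = torch.cumsum(seq_lens, dim=0)[:-1]
    lora_indices = torch.tensor([0, 1], dtype=torch.long, device="cuda")
    batches = seq_lens.size(0)
    max_seq_length = int(seq_lens.max())
    return b_seq_start_loc, seq_lens, lora_indices, batches, max_seq_length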
|
@ -0,0 +1,205 @@
|
|||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.triton_utils import libentry
|
||||
|
||||
|
||||
@libentry()
|
||||
@triton.jit
|
||||
def _sgmv_expand_slice_kernel(
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
b_seq_start_loc,
|
||||
seq_lens,
|
||||
lora_indices,
|
||||
xm_stride,
|
||||
xk_stride, # 1
|
||||
l0_stride, # hidden_size*max_rank
|
||||
lora_k_stride,
|
||||
lora_n_stride,
|
||||
cm_stride,
|
||||
cn_stride,
|
||||
slice_offset,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
ADD_INPUTS: tl.constexpr,
|
||||
CAST_TYPE: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
|
||||
Similar to the 'sgmv_expand' operator, but with an added parameter
|
||||
'slice_offset'. We do not simply reuse the 'sgmv_expand' operator
|
||||
because, in the future, we may implement a fused operator that
|
||||
achieves the current functionality in a single call, instead of
|
||||
calling it multiple times.
|
||||
"""
|
||||
pid = tl.program_id(axis=0)
|
||||
cur_batch = tl.program_id(axis=1)
|
||||
cta_n_num = tl.cdiv(N, BLOCK_N)
|
||||
pid_m = pid // cta_n_num
|
||||
pid_n = pid % cta_n_num
|
||||
M = tl.load(seq_lens + cur_batch)
|
||||
if pid_m * BLOCK_M > M:
|
||||
return
|
||||
lora_index = tl.load(lora_indices + cur_batch)
|
||||
if lora_index == -1:
|
||||
return
|
||||
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
|
||||
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
||||
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||
offset_k = tl.arange(0, BLOCK_K)
|
||||
ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
|
||||
rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
|
||||
|
||||
a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +
|
||||
offset_k[None, :] * xk_stride, )
|
||||
b_ptr = (lora_ptr + l0_stride * lora_index +
|
||||
offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride)
|
||||
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||
for k in range(tl.cdiv(K, BLOCK_K)):
|
||||
if EVEN_K:
|
||||
tiled_a = tl.load(a_ptr)
|
||||
tiled_b = tl.load(b_ptr)
|
||||
else:
|
||||
tiled_a = tl.load(a_ptr,
|
||||
mask=offset_k[None, :] < K - k * BLOCK_K,
|
||||
other=0)
|
||||
tiled_b = tl.load(b_ptr,
|
||||
mask=offset_k[:, None] < K - k * BLOCK_K,
|
||||
other=0)
|
||||
if CAST_TYPE:
|
||||
tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)
|
||||
accumulator += tl.dot(
|
||||
tiled_a,
|
||||
tiled_b,
|
||||
)
|
||||
a_ptr += BLOCK_K * xk_stride
|
||||
b_ptr += BLOCK_K * lora_n_stride
|
||||
tiled_c = accumulator.to(lora_ptr.dtype.element_ty)
|
||||
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
||||
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset
|
||||
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
|
||||
offset_cn[None, :] * cn_stride)
|
||||
M = tl.load(seq_lens + cur_batch)
|
||||
c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] <
|
||||
(slice_offset + N))
|
||||
if ADD_INPUTS:
|
||||
tiled_out = tl.load(c_ptr, mask=c_mask)
|
||||
tiled_c += tiled_out
|
||||
tl.store(c_ptr, tiled_c, mask=c_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def sgmv_expand_slice(
|
||||
inputs: torch.Tensor,
|
||||
lora_b_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
slice_offset: int,
|
||||
slice_size: int,
|
||||
add_inputs: bool = False,
|
||||
):
|
||||
"""_summary_
|
||||
|
||||
Args:
|
||||
inputs (torch.Tensor): input tensor
|
||||
lora_b_weights (torch.Tensor): lora_b weight
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
|
||||
sequence lengths of the sequences in the batch, used to index
|
||||
into the sequence. E.g., if the sequence lengths are [4, 6], it is
|
||||
[0, 4].
|
||||
seq_len_tensor (torch.Tensor): (batch_size,). Records the sequence
|
||||
length of the sequences in the batch
|
||||
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
||||
corresponding to each batch. An index of -1 means no lora should be
|
||||
applied.
|
||||
batches (int): batch size
|
||||
max_seq_length (int): The max sequence length of the sequences
|
||||
in the batch
|
||||
slice_offset (int): output_tensor's offset
|
||||
slice_size (int): current output_tensor's size
|
||||
add_inputs (bool, optional): Defaults to False. Adds the final lora
|
||||
results to the output.
|
||||
"""
|
||||
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
|
||||
assert lora_b_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
assert inputs.size(1) == lora_b_weights.size(-1)
|
||||
assert b_seq_start_loc.size(0) == batches
|
||||
assert lora_indices_tensor.size(0) == batches
|
||||
assert slice_size == lora_b_weights.size(-2)
|
||||
assert inputs.is_contiguous()
|
||||
assert output_tensor.is_contiguous()
|
||||
|
||||
if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)
|
||||
assert lora_b_weights.size(1) == 1
|
||||
lora_b_weights = lora_b_weights.squeeze(dim=1)
|
||||
else:
|
||||
assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)
|
||||
|
||||
assert lora_b_weights.is_contiguous()
|
||||
|
||||
# TODO tuning this config
|
||||
N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size
|
||||
|
||||
BLOCK_M = 32
|
||||
BLOCK_N = 32
|
||||
BLOCK_K = 16
|
||||
EVEN_K = K % BLOCK_K == 0
|
||||
ADD_INPUTS = add_inputs
|
||||
CAST_TYPE = False
|
||||
if inputs.dtype == torch.float32 and lora_b_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]:
|
||||
CAST_TYPE = True
|
||||
grid = (
|
||||
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
|
||||
batches,
|
||||
)
|
||||
_sgmv_expand_slice_kernel[grid](
|
||||
inputs,
|
||||
lora_b_weights,
|
||||
output_tensor,
|
||||
N,
|
||||
K,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
lora_b_weights.stride(0),
|
||||
lora_b_weights.stride(1),
|
||||
lora_b_weights.stride(2),
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
slice_offset,
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
EVEN_K,
|
||||
ADD_INPUTS,
|
||||
CAST_TYPE,
|
||||
)
|
||||
return
|
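# Illustrative sketch, not part of this commit: expanding several LoRA-B
# slices of a packed layer into one output by advancing slice_offset. All
# names are hypothetical.
def _example_sgmv_expand_slices(inputs, lora_b_slices, output_tensor,
                                b_seq_start_loc, seq_len_tensor,
                                lora_indices_tensor, batches,
                                max_seq_length):
    offset = 0
    for lora_b in lora_b_slices:
        size = lora_b.size(-2)
        sgmv_expand_slice(inputs, lora_b, output_tensor, b_seq_start_loc,
                          seq_len_tensor, lora_indices_tensor, batches,
                          max_seq_length, offset, size, add_inputs=True)
        offset += size
    return output_tensor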
|
@ -0,0 +1,189 @@
|
|||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from vllm.triton_utils import libentry
|
||||
|
||||
|
||||
@libentry()
|
||||
@triton.jit
|
||||
def _sgmv_shrink_kernel(
|
||||
input_ptr,
|
||||
lora_ptr,
|
||||
out_ptr,
|
||||
N,
|
||||
K,
|
||||
b_seq_start_loc,
|
||||
seq_lens,
|
||||
lora_indices,
|
||||
scaling,
|
||||
xm_stride, # hidden_size
|
||||
xk_stride, # 1
|
||||
l0_stride, # hidden_size*max_rank
|
||||
lora_k_stride,
|
||||
lora_n_stride,
|
||||
cm_stride,
|
||||
cn_stride,
|
||||
BLOCK_M: tl.constexpr,
|
||||
BLOCK_N: tl.constexpr,
|
||||
BLOCK_K: tl.constexpr,
|
||||
EVEN_K: tl.constexpr,
|
||||
SPLIT_K: tl.constexpr,
|
||||
):
|
||||
"""
|
||||
The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K.
|
||||
The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally,
|
||||
introducing SPLIT-K can improve performance
|
||||
"""
|
||||
pid = tl.program_id(axis=0)
|
||||
pid_sk = tl.program_id(axis=1)
|
||||
cur_batch = tl.program_id(axis=2)
|
||||
cta_n_num = tl.cdiv(N, BLOCK_N)
|
||||
pid_m = pid // cta_n_num
|
||||
pid_n = pid % cta_n_num
|
||||
|
||||
M = tl.load(seq_lens + cur_batch)
|
||||
if pid_m * BLOCK_M > M:
|
||||
return
|
||||
lora_index = tl.load(lora_indices + cur_batch)
|
||||
if lora_index == -1:
|
||||
return
|
||||
cur_seq_start = tl.load(b_seq_start_loc + cur_batch)
|
||||
offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
||||
offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||
offset_k = pid_sk * BLOCK_K + tl.arange(0, BLOCK_K)
|
||||
|
||||
ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)
|
||||
rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)
|
||||
|
||||
a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +
|
||||
offset_k[None, :] * xk_stride)
|
||||
b_ptr = (lora_ptr + l0_stride * lora_index + rbn[None, :] * lora_k_stride +
|
||||
offset_k[:, None] * lora_n_stride)
|
||||
|
||||
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
|
||||
for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):
|
||||
if EVEN_K:
|
||||
tiled_a = tl.load(a_ptr)
|
||||
tiled_b = tl.load(b_ptr)
|
||||
else:
|
||||
k_remaining = K - k * (BLOCK_K * SPLIT_K)
|
||||
tiled_a = tl.load(a_ptr,
|
||||
mask=offset_k[None, :] < k_remaining,
|
||||
other=0.0)
|
||||
tiled_b = tl.load(b_ptr,
|
||||
mask=offset_k[:, None] < k_remaining,
|
||||
other=0.0)
|
||||
accumulator += tl.dot(tiled_a, tiled_b)
|
||||
|
||||
a_ptr += BLOCK_K * SPLIT_K * xk_stride
|
||||
b_ptr += BLOCK_K * SPLIT_K * lora_n_stride
|
||||
offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M
|
||||
|
||||
offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N
|
||||
c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +
|
||||
offset_cn[None, :] * cn_stride)
|
||||
c_mask = (offset_cm[:, None] <
|
||||
(cur_seq_start + M)) & (offset_cn[None, :] < N)
|
||||
accumulator *= scaling
|
||||
# handles write-back with reduction-splitting
|
||||
if SPLIT_K == 1:
|
||||
tl.store(c_ptr, accumulator, mask=c_mask)
|
||||
else:
|
||||
tl.atomic_add(c_ptr, accumulator, mask=c_mask)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def sgmv_shrink(
|
||||
inputs: torch.Tensor,
|
||||
lora_a_weights: torch.Tensor,
|
||||
output_tensor: torch.Tensor,
|
||||
b_seq_start_loc: torch.Tensor,
|
||||
seq_len_tensor: torch.Tensor,
|
||||
lora_indices_tensor: torch.Tensor,
|
||||
batches: int,
|
||||
max_seq_length: int,
|
||||
scaling: float,
|
||||
):
|
||||
"""
|
||||
|
||||
Args:
|
||||
inputs (torch.Tensor): input tensor
|
||||
lora_a_weights (torch.Tensor): lora_a weight
|
||||
output_tensor (torch.Tensor): output tensor
|
||||
b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative
|
||||
sequence lengths of the sequences in the batch, used to index
|
||||
into the sequence. E.g., if the sequence lengths are [4, 6], it is
|
||||
[0, 4].
|
||||
seq_len_tensor (torch.Tensor): (batch_size,). Records the sequence
|
||||
length of the sequences in the batch
|
||||
lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index
|
||||
corresponding to each batch. An index of -1 means no lora should be
|
||||
applied.
|
||||
batches (int): batch size
|
||||
max_seq_length (int): The max sequence length of the sequences
|
||||
in the batch
|
||||
scaling (float): Scaling factor.
|
||||
"""
|
||||
assert inputs.dtype == lora_a_weights.dtype
|
||||
assert inputs.dtype in [torch.float16, torch.bfloat16]
|
||||
assert lora_a_weights.dtype in [
|
||||
torch.float16,
|
||||
torch.bfloat16,
|
||||
]
|
||||
assert inputs.size(1) == lora_a_weights.size(-1)
|
||||
assert b_seq_start_loc.size(0) == batches
|
||||
assert lora_indices_tensor.size(0) == batches
|
||||
assert inputs.is_contiguous()
|
||||
|
||||
if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size)
|
||||
assert lora_a_weights.size(1) == 1
|
||||
lora_a_weights = lora_a_weights.squeeze(dim=1)
|
||||
else:
|
||||
assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size)
|
||||
assert lora_a_weights.is_contiguous()
|
||||
assert output_tensor.is_contiguous()
|
||||
# TODO tuning this config
|
||||
N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank
|
||||
BLOCK_M = 32
|
||||
BLOCK_N = 16
|
||||
BLOCK_K = 32
|
||||
SPLIT_K = 8
|
||||
EVEN_K = K % (BLOCK_K * SPLIT_K) == 0
|
||||
grid = (
|
||||
triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),
|
||||
SPLIT_K,
|
||||
batches,
|
||||
)
|
||||
|
||||
_sgmv_shrink_kernel[grid](
|
||||
inputs,
|
||||
lora_a_weights,
|
||||
output_tensor,
|
||||
N,
|
||||
K,
|
||||
b_seq_start_loc,
|
||||
seq_len_tensor,
|
||||
lora_indices_tensor,
|
||||
scaling,
|
||||
inputs.stride(0),
|
||||
inputs.stride(1),
|
||||
lora_a_weights.stride(0),
|
||||
lora_a_weights.stride(1),
|
||||
lora_a_weights.stride(2),
|
||||
output_tensor.stride(0),
|
||||
output_tensor.stride(1),
|
||||
BLOCK_M,
|
||||
BLOCK_N,
|
||||
BLOCK_K,
|
||||
EVEN_K,
|
||||
SPLIT_K,
|
||||
)
|
||||
return
|
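# Illustrative sketch, not part of this commit: because the SPLIT-K path
# writes back with tl.atomic_add, the output buffer handed to sgmv_shrink
# should start zeroed, and float32 limits accumulation error before the
# expand step. All names are hypothetical.
def _example_sgmv_shrink(x, lora_a, b_seq_start_loc, seq_lens, lora_idx,
                         batches, max_seq_length):
    num_tokens, rank = x.size(0), lora_a.size(-2)
    buffer = torch.zeros(num_tokens, rank,
                         dtype=torch.float32, device="cuda")
    sgmv_shrink(x, lora_a, buffer, b_seq_start_loc, seq_lens, lora_idx,
                batches, max_seq_length, scaling=1.0)
    return buffer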
|
@ -0,0 +1,46 @@
|
|||
import functools
|
||||
from typing import Dict
|
||||
|
||||
|
||||
@functools.lru_cache
|
||||
def _get_op_configs(op_type: str, batch: int, hidden_size: int):
|
||||
# TODO: add optimal configurations
|
||||
return None
|
||||
|
||||
|
||||
def _check_divisibility(hidden_size: int):
|
||||
# The bgmv_expand kernel requires that the hidden_size be divisible by
|
||||
# the number below.
|
||||
divisibility = [2, 4, 8, 16, 32, 64]
|
||||
divisibility.sort(reverse=True)
|
||||
for div in divisibility:
|
||||
if hidden_size % div == 0:
|
||||
return div
|
||||
# hidden_size is an odd number
|
||||
return 1
|
||||
|
||||
|
||||
def _get_default_config(op_type: str, batch: int, hidden_size: int):
|
||||
if op_type == "expand":
|
||||
return {
|
||||
"BLOCK_N": 256,
|
||||
"SPLIT_N": _check_divisibility(hidden_size),
|
||||
"num_warps": 8
|
||||
}
|
||||
else:
|
||||
return {"BLOCK_K": 256, "SPLIT_K": 64, "num_warps": 8}
|
||||
|
||||
|
||||
def get_lora_op_configs(op_type: str, batch: int,
|
||||
hidden_size: int) -> Dict[str, int]:
|
||||
"""Inspired by `fused_moe_kernel`
|
||||
The return value will be a dictionary mapping an irregular grid of batch
|
||||
sizes and hidden_size to configurations of the bgmv-related kernel.
|
||||
NOTE: It currently only supports the default configuration. We plan to
|
||||
generate optimal configurations for different hardware in the future using
|
||||
scripts similar to `benchmark_moe.py`.
|
||||
"""
|
||||
config = _get_op_configs(op_type, batch, hidden_size)
|
||||
if not config:
|
||||
config = _get_default_config(op_type, batch, hidden_size)
|
||||
return config
|
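# Illustrative sketch, not part of this commit: with no tuned entries in
# _get_op_configs, the helper falls back to the defaults above, e.g. for an
# "expand" op with hidden_size 4096:
#
#   >>> get_lora_op_configs("expand", batch=8, hidden_size=4096)
#   {'BLOCK_N': 256, 'SPLIT_N': 64, 'num_warps': 8}
#
# SPLIT_N is 64 because that is the largest divisor of 4096 that
# _check_divisibility tries.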
|
@ -1,207 +1,604 @@
|
|||
# Based on code from https://github.com/punica-ai/punica
|
||||
"""
|
||||
Based on:
|
||||
Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023).
|
||||
Punica: Multi-Tenant LoRA Serving.
|
||||
https://arxiv.org/abs/2310.18547
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
if HAS_TRITON:
|
||||
from vllm.lora.ops.bgmv_expand import bgmv_expand
|
||||
from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice
|
||||
from vllm.lora.ops.bgmv_shrink import bgmv_shrink
|
||||
from vllm.lora.ops.sgmv_expand import sgmv_expand
|
||||
from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice
|
||||
from vllm.lora.ops.sgmv_shrink import sgmv_shrink
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# avoid circular import
|
||||
from vllm.lora.layers import LoRAMapping
|
||||
from vllm.lora.models import LongContextLoRAContext
|
||||
|
||||
|
||||
def _check_punica_support():
|
||||
if ops.is_custom_op_supported("_punica_C::dispatch_bgmv"):
|
||||
return
|
||||
def compute_meta(
|
||||
token_lora_tensor: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, bool]:
|
||||
"""
|
||||
Get the information required for the sgmv kernel. Notable features:
|
||||
1. If consecutive requests in the batch use the same LoRA, this function
|
||||
will combine them into a single request, improving sgmv kernel inference
|
||||
performance.
|
||||
2. At the beginning of each prefill stage, this metadata is recomputed
|
||||
from the input, but only once per prefill.
|
||||
"""
|
||||
|
||||
if current_platform.get_device_capability() < (8, 0):
|
||||
raise ImportError(
|
||||
"punica LoRA kernels require compute capability >= 8.0")
|
||||
lora_indices_tensor, seq_length_tensor = torch.unique_consecutive(
|
||||
token_lora_tensor, return_counts=True)
|
||||
cum_result = torch.cumsum(seq_length_tensor, dim=0)
|
||||
b_seq_start_tensor = torch.zeros_like(seq_length_tensor)
|
||||
b_seq_start_tensor[1:].copy_(cum_result[:-1])
|
||||
max_length = seq_length_tensor.max().item()
|
||||
|
||||
batch_size = lora_indices_tensor.size(0)
|
||||
no_lora = False
|
||||
# -1 means no lora should be applied. Use `no_lora` to determine whether
|
||||
# the current step requires LoRA. If LoRA is not needed, the prefill stage
|
||||
# does not need to launch the triton kernel, which can improve performance
|
||||
if batch_size == 1 and lora_indices_tensor == -1:
|
||||
no_lora = True
|
||||
return (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
|
||||
batch_size, max_length, no_lora)
|
||||
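# Illustrative sketch, not part of this commit, of the grouping that
# compute_meta performs: consecutive tokens that share a LoRA index are
# merged into one sgmv "sequence".
#
#   >>> token_lora = torch.tensor([0, 0, 0, 1, 1, -1], device="cuda")
#   >>> (starts, lens, ids, bs, max_len, no_lora) = compute_meta(token_lora)
#   >>> starts.tolist(), lens.tolist(), ids.tolist()
#   ([0, 3, 5], [3, 2, 1], [0, 1, -1])
#   >>> bs, max_len, no_lora
#   (3, 3, False)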
|
||||
|
||||
# TODO see if this can be vectorized
|
||||
def convert_mapping(
|
||||
mapping: "LoRAMapping",
|
||||
lora_index_to_id: List[Optional[int]],
|
||||
max_loras: int,
|
||||
vocab_size: int,
|
||||
extra_vocab_size: int,
|
||||
long_lora_context: Optional["LongContextLoRAContext"] = None,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
|
||||
Optional[torch.Tensor], List[int]]:
|
||||
"""Converts LoRAMapping to index tensors.
|
||||
|
||||
Args:
|
||||
mapping: LoRAMapping mapping rows in a batch to LoRA ids.
|
||||
lora_index_to_id: List mapping LoRA ids to LoRA indices.
|
||||
max_loras: Maximum number of LoRAs.
|
||||
vocab_size: Model vocab size.
|
||||
extra_vocab_size: Extra vocab size each LoRA can have.
|
||||
long_lora_context: Passed if there are long context LoRAs in a batch.
|
||||
|
||||
Returns:
|
||||
A tuple of tensors:
|
||||
base_indices: Tensor of shape [batch_size] mapping batch rows to
|
||||
LoRA indices.
|
||||
sampler_indices: Tensor of shape [batch_size] mapping requests to
|
||||
LoRA indices for sampler. For generation, this will be the
|
||||
same as base_indices. For prefill, this will map requests
|
||||
to LoRA indices.
|
||||
sampler_indices_padded: Tensor of shape [batch_size] mapping
|
||||
requests to LoRA indices for sampler with padding.
|
||||
Same as sampler_indices, but -1 is replaced with
|
||||
max_loras.
|
||||
embeddings_indices: Tensor of shape [2, batch_size] mapping
|
||||
requests to embedding indices. First row is for embeddings
|
||||
added by the LoRAs, second row is for the LoRA.lora_a
|
||||
embeddings.
|
||||
long_lora_indices: Tensor of shape [batch_size] mapping
|
||||
requests to RoPE offsets and rot dims for long LoRAs.
|
||||
None if long context lora doesn't exist.
|
||||
indices_len: List of lengths of the above tensors. It contains
|
||||
(base_indices, sampler_indices, sampler_indices_padded,
|
||||
embeddings_indices, long_lora_indices).
|
||||
"""
|
||||
index_mapping_indices: List[int] = list(mapping.index_mapping).copy()
|
||||
embedding_indices = index_mapping_indices.copy()
|
||||
lora_indices = index_mapping_indices.copy()
|
||||
long_lora_offsets: Optional[torch.Tensor] = None
|
||||
if long_lora_context:
|
||||
long_lora_offsets = torch.zeros(len(index_mapping_indices),
|
||||
device="cuda",
|
||||
dtype=torch.long)
|
||||
prompt_mapping: List[int] = [
|
||||
lora_index_to_id.index(x) if x > 0 else -1
|
||||
for x in mapping.prompt_mapping
|
||||
]
|
||||
lora_idx = None
|
||||
for i in range(len(index_mapping_indices)):
|
||||
# TODO index can be slow. optimize
|
||||
lora_idx = (lora_index_to_id.index(index_mapping_indices[i])
|
||||
if index_mapping_indices[i] > 0 else -1)
|
||||
embedding_indices[i] = lora_idx if index_mapping_indices[i] > 0 else 0
|
||||
lora_indices[i] = lora_idx
|
||||
if long_lora_context:
|
||||
assert long_lora_offsets is not None
|
||||
lora_offset: int = long_lora_context.offsets_by_lora_id.get(
|
||||
index_mapping_indices[i], 0)
|
||||
long_lora_offsets[i] = lora_offset
|
||||
|
||||
indices_list: List[Union[List[int], torch.Tensor]] = [
|
||||
index_mapping_indices,
|
||||
lora_indices,
|
||||
embedding_indices,
|
||||
]
|
||||
if long_lora_context:
|
||||
assert long_lora_offsets is not None
|
||||
indices_list.append(long_lora_offsets)
|
||||
indices = torch.tensor(indices_list, dtype=torch.long, device="cuda")
|
||||
prompt_mapping_tensor = torch.tensor(prompt_mapping,
|
||||
device="cuda",
|
||||
dtype=torch.long)
|
||||
embeddings_indices = torch.stack([
|
||||
indices[2] * extra_vocab_size,
|
||||
indices[2] * (vocab_size + extra_vocab_size),
|
||||
])
|
||||
embeddings_indices[embeddings_indices == -1] = max_loras - 1
|
||||
base_indices = indices[1]
|
||||
sampler_indices = prompt_mapping_tensor
|
||||
sampler_indices_padded = sampler_indices.clone()
|
||||
sampler_indices_padded[sampler_indices_padded == -1] = max_loras - 1
|
||||
sampler_indices_padded = torch.arange(
|
||||
0, len(sampler_indices_padded), device="cuda", dtype=torch.long) + (
|
||||
sampler_indices_padded * len(sampler_indices_padded))
|
||||
long_lora_indices = None
|
||||
long_lora_indices_len: Optional[int] = None
|
||||
if long_lora_context:
|
||||
long_lora_indices = indices[3]
|
||||
long_lora_indices_len = long_lora_indices.shape[-1]
|
||||
# Contain length of indices tensors. Used to index into each tensor.
|
||||
indices_len = [
|
||||
base_indices.shape[-1],
|
||||
sampler_indices.shape[-1],
|
||||
sampler_indices_padded.shape[-1],
|
||||
embeddings_indices.shape[-1],
|
||||
]
|
||||
if long_lora_indices_len is not None:
|
||||
indices_len.append(long_lora_indices_len)
|
||||
else:
|
||||
raise ImportError(
|
||||
"punica LoRA kernels could not be imported. If you built vLLM "
|
||||
"from source, make sure VLLM_INSTALL_PUNICA_KERNELS=1 env var "
|
||||
"was set.")
|
||||
# If long_lora doesn't exist, append None
|
||||
indices_len.append(None)
|
||||
|
||||
|
||||
def bgmv(
    y: torch.Tensor,
    x: torch.Tensor,
    w_t_all: torch.Tensor,
    indicies: torch.LongTensor,
    layer_idx: int,
    scale: float,
):
    """
    Semantics:
      y[i] += (
          x[i].unsqueeze(0)
          @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)

    Args:
      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
      x: Shape: `[B, H1]`. Input vectors.
      w_t_all: Shape: `[None, L, H2, H1]`. All of the transposed weight
        matrices.
      indicies: Shape: `[B]`. Indices of the weight matrices.
      layer_idx: Layer index of the weight matrices.
      scale: Scaling factor.
    """
    _check_punica_support()

    ops.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale)
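

# Illustrative reference (not part of this diff): the semantics documented
# above, written as a plain PyTorch loop. It assumes `w_t_all` is laid out as
# [num_loras, num_layers, H2, H1] and treats an index of -1 as "no LoRA";
# useful only as a reference against the fused kernel.
def _bgmv_reference(y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor,
                    indicies: torch.LongTensor, layer_idx: int,
                    scale: float) -> None:
    for i in range(x.size(0)):
        idx = int(indicies[i])
        if idx < 0:
            continue  # no LoRA applied to this token
        w = w_t_all[idx, layer_idx, :, :]  # [H2, H1]
        y[i] += (x[i].unsqueeze(0) @ w.transpose(-1, -2) * scale).squeeze(0)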


def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor,
                            w_t_all: torch.Tensor, indicies: torch.LongTensor,
                            layer_idx: int, scale: float, y_offset: int,
                            y_slice_size: int):
    """
    Same as `bgmv` but you can operate on slices of y.
    Pass whole y, define y_offset and y_slice_size.

    Semantics:
      y[i] += (
          x[i].unsqueeze(0)
          @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)

    Args:
      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
      x: Shape: `[B, H1]`. Input vectors.
      w_t_all: Shape: `[None, L, y_slice_size, H1]`. Column partition of
        all of the transposed LoRA matrices.
      indicies: Shape: `[B]`. Indices of the LoRA weights.
      layer_idx: Layer index of LoRA weights.
      scale: Scaling factor.
      y_offset: Offset to apply to the starting column of y.
      y_slice_size: Size of the y column slice.
    """
    _check_punica_support()

    ops.dispatch_bgmv_low_level(
        y,
        x,
        w_t_all,
        indicies,
        layer_idx,
        scale,
        x.size(1),
        y_slice_size,
        y_offset,
    )
    return (
        base_indices,
        sampler_indices,
        sampler_indices_padded,
        embeddings_indices,
        long_lora_indices,
        indices_len,
    )
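
# Illustrative note (not part of this diff): callers such as
# PunicaWrapper._update_base_metadata below unpack this result as
#
#   (base_indices, sampler_indices, sampler_indices_padded,
#    embeddings_indices, long_lora_offsets_tensor,
#    indices_len) = convert_mapping(mapping, lora_index_to_id, max_loras,
#                                   vocab_size, extra_vocab_size,
#                                   long_lora_context)
#
# where indices_len holds the valid length of each tensor (None in the
# long-LoRA slot when no long-context LoRA is present).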


def add_lora(y: torch.Tensor,
             x: torch.Tensor,
             wa_t_all: torch.Tensor,
             wb_t_all: torch.Tensor,
             indicies: torch.LongTensor,
             layer_idx: int,
             scale: float,
             *,
             buffer: Optional[torch.Tensor] = None):
class PunicaWrapper:
    """
    Semantics:
      y[i] += (
          x[i].unsqueeze(0)
          @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)

    Args:
      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
      x: Shape: `[B, H1]`. Input vectors.
      wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
        LoRA A matrices.
      wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
        LoRA B matrices.
      indicies: Shape: `[B]`. Indices of the LoRA weights.
      layer_idx: Layer index of LoRA weights.
      scale: Scaling factor.
      buffer: Optional. Shape: `[B, R]`. Temporary buffer.
    PunicaWrapper is designed to manage and provide metadata for the punica
    kernel. The main function is to maintain the state information for
    Multi-LoRA, and to provide the interface for the punica kernel.
    """
    _check_punica_support()

    r = wb_t_all.size(-1)
    if buffer is None:
        # We set the buffer to be float32 by default to avoid
        # numerical inaccuracies that would otherwise happen
        # due to downcasting.
        buffer = torch.zeros((x.size(0), r),
                             dtype=torch.float32,
                             device=x.device)
    ops.dispatch_bgmv(buffer, x, wa_t_all, indicies, layer_idx, 1.0)
    ops.dispatch_bgmv(y, buffer, wb_t_all, indicies, layer_idx, scale)
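

# Illustrative reference (not part of this diff): the shrink/expand
# composition performed by the two dispatch_bgmv calls above, as a plain
# PyTorch loop. Assumes wa_t_all is [num_loras, L, R, H1] and wb_t_all is
# [num_loras, L, H2, R]; an index of -1 is treated as "no LoRA".
def _add_lora_reference(y: torch.Tensor, x: torch.Tensor,
                        wa_t_all: torch.Tensor, wb_t_all: torch.Tensor,
                        indicies: torch.LongTensor, layer_idx: int,
                        scale: float) -> None:
    for i in range(x.size(0)):
        idx = int(indicies[i])
        if idx < 0:
            continue
        a = wa_t_all[idx, layer_idx, :, :]  # [R, H1]
        b = wb_t_all[idx, layer_idx, :, :]  # [H2, R]
        # shrink to rank R, then expand back to H2
        y[i] += scale * (x[i] @ a.transpose(-1, -2) @ b.transpose(-1, -2))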

    def __init__(self, max_num_batched_tokens: int, max_batches: int,
                 device: str):
        self._token_lora_indices = torch.empty(max_num_batched_tokens,
                                               dtype=torch.long,
                                               device=device)
        self._sampler_indices = torch.empty(max_num_batched_tokens,
                                            dtype=torch.long,
                                            device=device)
        self._sampler_indices_padded = torch.empty(max_num_batched_tokens,
                                                   dtype=torch.long,
                                                   device=device)
        self._embeddings_indices = torch.empty(2,
                                               max_num_batched_tokens,
                                               dtype=torch.long,
                                               device=device)
        self._long_lora_indices = torch.empty(max_num_batched_tokens,
                                              dtype=torch.long,
                                              device=device)

        # 5 is the number of indices tensors:
        # base_indices, sampler_indices, sampler_indices_padded,
        # embeddings_indices, long_lora_indices
        self.indices_len: List[Optional[int]] = [None] * 5
        # these attributes are the information required for sgmv kernel
        self._seq_start_locs = torch.empty(max_batches,
                                           dtype=torch.long,
                                           device=device)
        self._seq_lengths = torch.empty(max_batches,
                                        dtype=torch.long,
                                        device=device)
        self._lora_indices_per_batch = torch.empty(max_batches,
                                                   dtype=torch.long,
                                                   device=device)
        self.max_length: int = 0
        self.batch_size: int = -1
        self.is_prefill = False
        self.no_lora = False
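
    # Usage sketch (illustrative, not part of this diff); the sizes below are
    # placeholders, the real values come from the scheduler/LoRA config:
    #
    #   punica_wrapper = PunicaWrapper(max_num_batched_tokens=8192,
    #                                  max_batches=256, device="cuda")
    #   punica_wrapper.update_metadata(lora_mapping, lora_index_to_id,
    #                                  max_loras, vocab_size, extra_vocab_size)
    #   punica_wrapper.add_lora(output, hidden_states, lora_a_stacked,
    #                           lora_b_stacked, scale=1.0)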

def add_lora_slice(y: torch.Tensor,
                   x: torch.Tensor,
                   wa_t_all: torch.Tensor,
                   wb_t_all: torch.Tensor,
                   indicies: torch.LongTensor,
                   layer_idx: int,
                   scale: float,
                   y_offset: int,
                   y_slice_size: int,
                   *,
                   buffer: Optional[torch.Tensor] = None):
    """
    Same as `add_lora` but you can operate on slices of y.
    Pass whole y, define y_offset and y_slice_size.
    def update_metadata(
        self,
        mapping: "LoRAMapping",
        lora_index_to_id: List[Optional[int]],
        max_loras: int,
        vocab_size: int,
        extra_vocab_size: int,
        long_lora_context: Optional["LongContextLoRAContext"] = None,
    ):

    Semantics:
      y[i] += (
          x[i].unsqueeze(0)
          @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
          * scale
        ).squeeze(0)
        self._update_base_metadata(mapping, lora_index_to_id, max_loras,
                                   vocab_size, extra_vocab_size,
                                   long_lora_context)
        if mapping.is_prefill:
            # Update metadata required for prefill-related operators.
            self._update_prefill_metada(self.token_lora_indices)
            self.is_prefill = True
        else:
            self.is_prefill = False

    Args:
      y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
      x: Shape: `[B, H1]`. Input vectors.
      wa_t_all: Shape: `[None, L, R, H1]`. All of the transposed
        LoRA A matrices.
      wb_t_all: Shape: `[None, L, H2, R]`. All of the transposed
        LoRA B matrices.
      indicies: Shape: `[B]`. Indices of the LoRA weights.
      layer_idx: Layer index of LoRA weights.
      scale: Scaling factor.
      y_offset: Offset to apply to the starting column of y.
      y_slice_size: Size of the y column slice.
    """
    _check_punica_support()
    def _update_base_metadata(
        self,
        mapping: "LoRAMapping",
        lora_index_to_id: List[Optional[int]],
        max_loras: int,
        vocab_size: int,
        extra_vocab_size: int,
        long_lora_context: Optional["LongContextLoRAContext"] = None,
    ):
        (
            base_indices,
            sampler_indices,
            sampler_indices_padded,
            embeddings_indices,
            long_lora_offsets_tensor,
            indices_len,
        ) = convert_mapping(
            mapping,
            lora_index_to_id,
            max_loras,
            vocab_size,
            extra_vocab_size,
            long_lora_context,
        )
        self._token_lora_indices[:base_indices.shape[0]].copy_(base_indices)
        self._sampler_indices[:sampler_indices.shape[0]].copy_(sampler_indices)
        self._sampler_indices_padded[:sampler_indices_padded.shape[0]].copy_(
            sampler_indices_padded)
        self._embeddings_indices[:embeddings_indices.
                                 shape[0], :embeddings_indices.shape[1]].copy_(
                                     embeddings_indices)
        if long_lora_offsets_tensor is not None:
            self._long_lora_indices[:long_lora_offsets_tensor.shape[0]].copy_(
                long_lora_offsets_tensor)
        else:
            self._long_lora_indices.zero_()

    r = wb_t_all.size(-1)
    if buffer is None:
        # We set the buffer to be float32 by default to avoid
        # numerical inaccuracies that would otherwise happen
        # due to downcasting.
        buffer = torch.zeros((x.size(0), r),
                             dtype=torch.float32,
                             device=x.device)
    ops.dispatch_bgmv_low_level(
        buffer,
        x,
        wa_t_all,
        indicies,
        layer_idx,
        1.0,
        x.size(1),
        buffer.size(1),
        0,
    )
    ops.dispatch_bgmv_low_level(
        y,
        buffer,
        wb_t_all,
        indicies,
        layer_idx,
        scale,
        buffer.size(1),
        y_slice_size,
        y_offset,
    )
        self.indices_len[:] = indices_len

    def _update_prefill_metada(self, token_lora_tensor: torch.Tensor) -> None:

        (b_seq_start_tensor, seq_length_tensor, lora_indices_tensor,
         batch_size, max_length, no_lora) = compute_meta(token_lora_tensor)

        self._seq_start_locs[:b_seq_start_tensor.shape[0]].copy_(
            b_seq_start_tensor)
        self._seq_lengths[:seq_length_tensor.shape[0]].copy_(seq_length_tensor)
        self._lora_indices_per_batch[:lora_indices_tensor.shape[0]].copy_(
            lora_indices_tensor)
        self.batch_size = batch_size
        self.max_length = max_length
        self.no_lora = no_lora

    @property
    def prefill_metadata(
            self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
        """
        This property provides a convenient way to access the necessary
        metadata for prefill-related kernel computations.
            1. seq_start_locs: Tensor of sequence start positions
            2. seq_lengths: Tensor of sequence lengths
            3. lora_indices_per_batch: Tensor of lora indices; an index of
               -1 means no lora should be applied.
            4. batch_size: batch size after clustering identical lora indices
            5. max_length: The maximum sequence length in the batch
        """
        return (self._seq_start_locs[:self.batch_size],
                self._seq_lengths[:self.batch_size],
                self._lora_indices_per_batch[:self.batch_size],
                self.batch_size, self.max_length)
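
        # Illustrative note (not part of this diff): the prefill paths below
        # splat this property into the sgmv kernels, which is equivalent to
        #
        #   (seq_start_locs, seq_lengths, lora_indices_per_batch,
        #    batch_size, max_length) = self.prefill_metadata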

    @property
    def token_lora_indices(self) -> torch.Tensor:
        """
        This property provides the lora indices corresponding to each token
        in the batch. An index of -1 means no lora should be applied.
        """
        token_lora_len = self.indices_len[0]
        return self._token_lora_indices[:token_lora_len]

    @property
    def sampler_indices(self) -> torch.Tensor:
        """
        This property is used to access the lora indices specifically for
        LogitsProcessorWithLoRA
        """
        sampler_indices_len = self.indices_len[1]
        return self._sampler_indices[:sampler_indices_len]

    @property
    def sampler_indices_padded(self) -> torch.Tensor:
        """
        This property provides access to padded sampler indices
        """
        indices_padded_len = self.indices_len[2]
        return self._sampler_indices_padded[:indices_padded_len]

    @property
    def embeddings_indices(self) -> torch.Tensor:
        """
        This property provides access to the indices used for lora embeddings,
        specifically for VocabParallelEmbeddingWithLoRA
        """
        embeddings_indices_len = self.indices_len[3]
        return self._embeddings_indices[:, :embeddings_indices_len]

    @property
    def long_lora_indices(self) -> torch.Tensor:
        """
        This property provides access to the indices used for long context
        lora, specifically for LinearScalingRotaryEmbeddingWithLora
        """
        long_lora_len = self.indices_len[4]
        return self._long_lora_indices[:long_lora_len]

    def shrink_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        # No LoRA request, so return directly
        if self.no_lora:
            return
        sgmv_shrink(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            scale,
        )

    def shrink_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale)

    def expand_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_input: bool,
    ):
        # No LoRA request, so return directly
        if self.no_lora:
            return
        sgmv_expand(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            add_input,
        )

    def expand_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_input: bool,
    ):
        bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_input)

    def expand_slice_prefill(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: Optional[int],
        y_slice_size: Optional[int],
        add_input: bool,
    ):
        # No LoRA request, so return directly
        if self.no_lora:
            return
        sgmv_expand_slice(
            x,
            w_t_all,
            y,
            *self.prefill_metadata,
            y_offset,
            y_slice_size,
            add_input,
        )

    def expand_slice_decode(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        y_offset: Optional[int],
        y_slice_size: Optional[int],
        add_input: bool,
    ):
        bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset,
                          y_slice_size, add_input)

    def add_shrink(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        scale: float,
    ):
        """
        Perform the `y += x @ w_t_all` computation, which is suitable for the
        GEMM of lora_a.
        When `is_prefill` is true, it indicates that it is currently the
        prefill stage, and the `shrink_prefill` function should be called.
        Otherwise, it is the decode stage, and the `shrink_decode` function
        should be called.
        """
        shrink_fun: Callable = (self.shrink_prefill
                                if self.is_prefill else self.shrink_decode)
        shrink_fun(y, x, w_t_all, scale)

    def add_expand(
        self,
        y: torch.Tensor,
        x: torch.Tensor,
        w_t_all: torch.Tensor,
        add_input: bool = True,
    ):
        """
        Perform the `y += x @ w_t_all` computation, which is suitable for the
        GEMM of lora_b.
        When `is_prefill` is true, it indicates that it is currently the
        prefill stage, and the `expand_prefill` function should be called.
        Otherwise, it is the decode stage, and the `expand_decode` function
        should be called.
        """

        expand_fun: Callable = (self.expand_prefill
                                if self.is_prefill else self.expand_decode)
        expand_fun(y, x, w_t_all, add_input)

    def add_expand_slice(self,
                         y: torch.Tensor,
                         x: torch.Tensor,
                         w_t_all: torch.Tensor,
                         y_offset: Optional[int],
                         y_slice_size: Optional[int],
                         add_input: bool = True):
        """
        Similar to `add_expand`
        """

        expand_slice_fun: Callable = (self.expand_slice_prefill
                                      if self.is_prefill else
                                      self.expand_slice_decode)
        expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input)

    def add_lora(self,
                 y: torch.Tensor,
                 x: torch.Tensor,
                 wa_t_all: torch.Tensor,
                 wb_t_all: torch.Tensor,
                 scale: float,
                 y_offset: Optional[int] = None,
                 y_slice_size: Optional[int] = None,
                 *,
                 buffer: Optional[torch.Tensor] = None) -> None:
        """
        Semantics:
          y[i] += (
              x[i].unsqueeze(0)
              @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
              @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
              * scale
            ).squeeze(0)
        Args:
            y (torch.Tensor): Output tensor. Will be changed in-place.
            x (torch.Tensor): Input tensor.
            wa_t_all (torch.Tensor): lora_a's weight.
            wb_t_all (torch.Tensor): lora_b's weight.
            scale (float): Scaling factor.
            y_offset (Optional[int], optional): Offset to apply to the starting
                column of y.
            y_slice_size (Optional[int], optional): Size of the y column slice.
            buffer (Optional[torch.Tensor], optional): Defaults to None.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        r = wb_t_all.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default, refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)

        self.add_shrink(buffer, x, wa_t_all, scale)
        if y_offset is None and y_slice_size is None:
            self.add_expand(y, buffer, wb_t_all, add_input=True)
        else:
            self.add_expand_slice(y,
                                  buffer,
                                  wb_t_all,
                                  y_offset,
                                  y_slice_size,
                                  add_input=True)
        y = y.view_as(y_org)

    def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor,
                               lora_a_stacked: Tuple[torch.Tensor,
                                                     torch.Tensor,
                                                     torch.Tensor],
                               lora_b_stacked: Tuple[torch.Tensor,
                                                     torch.Tensor,
                                                     torch.Tensor],
                               scale: float,
                               output_slices: Tuple[int, ...]) -> None:
        """
        Applies lora to each input. Similar to add_lora, this method is
        used for layers that are composed of multiple sublayers
        (slices) packed together.
        """
        y_org = y
        x = x.view(-1, x.shape[-1])
        y = y.view(-1, y.shape[-1])
        offset_left = 0
        # TODO fuse these kernels
        for slice_idx in range(len(output_slices)):
            self.add_lora(y, x, lora_a_stacked[slice_idx],
                          lora_b_stacked[slice_idx], scale, offset_left,
                          output_slices[slice_idx])
            offset_left += output_slices[slice_idx]

        y = y.view_as(y_org)
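
        # Illustrative example (not part of this diff): for a packed
        # projection such as QKV, output_slices would be the per-slice output
        # widths; the names below are placeholders for the model's sizes.
        #
        #   punica_wrapper.add_lora_packed_nslice(
        #       qkv_output, hidden_states, lora_a_stacked, lora_b_stacked,
        #       scale=1.0, output_slices=(q_size, kv_size, kv_size))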

    def add_lora_logits(self,
                        y: torch.Tensor,
                        x: torch.Tensor,
                        wa_t_all: torch.Tensor,
                        wb_t_all: torch.Tensor,
                        scale,
                        *,
                        buffer: Optional[torch.Tensor] = None) -> None:
        """
        LogitsProcessorWithLoRA always uses bgmv.
        """
        y_org = y
        y = y.view(-1, y.shape[-1])
        x = x.view(-1, x.shape[-1])
        r = wb_t_all.size(-1)
        if buffer is None:
            # We set the buffer to be float32 by default, refer to:
            # https://github.com/triton-lang/triton/issues/1387
            buffer = torch.zeros((x.size(0), r),
                                 dtype=torch.float32,
                                 device=x.device)

        bgmv_shrink(x, wa_t_all, buffer, self.sampler_indices, scale)
        bgmv_expand(buffer, wb_t_all, y, self.sampler_indices, add_inputs=True)
        y = y.view_as(y_org)

@ -6,5 +6,6 @@ if HAS_TRITON:
    from vllm.triton_utils.custom_cache_manager import (
        maybe_set_triton_cache_manager)
    from vllm.triton_utils.libentry import libentry

    __all__ += ["maybe_set_triton_cache_manager"]
    __all__ += ["maybe_set_triton_cache_manager", "libentry"]

@ -0,0 +1,167 @@
# Copied From https://github.com/FlagOpen/FlagGems

import inspect

import triton


class LibEntry(triton.KernelInterface):

    def __init__(
        self,
        fn,
    ):
        self.fn = fn
        self.arg_names = fn.arg_names
        self.divisibility = 16
        self.kernel_cache = dict()
        fn = self.fn
        while not isinstance(fn, triton.runtime.JITFunction):
            fn = fn.fn
        self.jit_function: triton.runtime.JITFunction = fn
        self.specialize_indices = [
            p.num for p in self.jit_function.params
            if not p.is_constexpr and not p.do_not_specialize
        ]
        self.do_not_specialize_indices = [
            p.num for p in self.jit_function.params
            if not p.is_constexpr and p.do_not_specialize
        ]

    def key(self, spec_args, dns_args, const_args):
        spec_key = [(arg.dtype, arg.data_ptr() %
                     self.divisibility == 0) if hasattr(arg, "data_ptr") else
                    (type(arg), arg) for arg in spec_args]
        dns_key = [
            arg.dtype if hasattr(
                arg, "data_ptr") else type(arg) if not isinstance(arg, int)
            else "i32" if -(2**31) <= arg and arg <= 2**31 -
            1 else "u64" if 2**63 <= arg and arg <= 2**64 - 1 else "i64"
            for arg in dns_args
        ]
        # const args passed by position
        return tuple(spec_key + dns_key + const_args)

    def run(self, *args, **kwargs):
        grid = kwargs["grid"]
        # collect all the arguments
        spec_args = []  # specialize arguments
        dns_args = []  # do not specialize arguments
        const_args = []  # constexpr arguments
        k_args = []  # kernel arguments
        for i, arg in enumerate(args):
            if i in self.specialize_indices:
                k_args.append(arg)
                spec_args.append(arg)
            elif i in self.do_not_specialize_indices:
                k_args.append(arg)
                dns_args.append(arg)
            else:
                const_args.append(arg)
        for p in self.jit_function.params[len(args):]:
            if p.name in kwargs:
                val = kwargs[p.name]
            elif p.default is inspect._empty:
                continue
            else:
                val = p.default

            if p.is_constexpr:
                const_args.append(val)
            elif p.do_not_specialize:
                dns_args.append(val)
                k_args.append(val)
            else:
                spec_args.append(val)
                k_args.append(val)

        entry_key = self.key(spec_args, dns_args, const_args)

        if entry_key not in self.kernel_cache:
            # compiling the kernel also completes the related computations
            kernel = self.fn.run(*args, **kwargs)
            fn = self.fn
            # collect constexpr arguments for grid computation
            constexprs = {}
            while not isinstance(fn, triton.runtime.JITFunction):
                if isinstance(fn, triton.runtime.Autotuner):
                    config = fn.best_config
                    constexprs["num_warps"] = config.num_warps
                    constexprs["num_stages"] = config.num_stages
                    constexprs["num_ctas"] = config.num_ctas
                    constexprs = {**constexprs, **config.kwargs}
                elif isinstance(fn, triton.runtime.Heuristics):
                    for v, heur in fn.values.items():
                        constexprs[v] = heur({
                            **dict(zip(fn.arg_names, args)),
                            **kwargs,
                            **constexprs,
                        })
                else:
                    raise RuntimeError("Invalid Runtime Function")
                fn = fn.fn
            # In vLLM, certain kernels like fused_moe_kernel get the
            # best_config (as kwargs) from a configuration json file, rather
            # than using Autotuner & Heuristics. Therefore, all their
            # constexprs (tl.constexpr) are assigned values through the
            # following loop.
            for p in self.jit_function.params:
                if p.is_constexpr and p.name not in constexprs:
                    constexprs[p.name] = p.default  # default is inspect._empty
            self.kernel_cache[entry_key] = (kernel, constexprs)
        else:
            # load kernel from cache directly
            kernel, constexprs = self.kernel_cache[entry_key]

        if callable(grid):
            # collect all arguments to the grid fn, i.e.:
            # 1. args
            # 2. kwargs
            # 3. all other captured arguments in CompiledKernel from
            #    Autotuner & Heuristics; when kwargs & captured args conflict,
            #    captured args have higher priority
            # 4. we must first filter out captured args that still have the
            #    default (empty) value
            constexprs = {
                k: v
                for k, v in constexprs.items() if v is not inspect._empty
            }
            meta = {
                **dict(zip(self.arg_names, args)),
                **kwargs,
                **constexprs,
            }
            grid = grid(meta)
        if isinstance(grid, tuple):
            grid = grid + (1, 1)
        elif isinstance(grid, list):
            grid = grid + [1, 1]
        kernel[grid[0:3]](*k_args)
        # maintain the same return type as JITFunction.run
        return kernel


def libentry():
    """
    Decorator for triton library entries.
    Motivation:
        The runtime overhead of Triton kernels is the reason for the lower
        performance of small kernels, which is particularly evident with
        smaller models. Using this decorator can reduce Triton runtime
        overhead.
    How:
        The `run` function of JITFunction needs to accomplish:
        - Parameter binding using inspect
        - KernelArg type wrapping
        - Cache key calculation
        For small kernels, these steps can become bottlenecks in the Triton
        runtime. LibEntry simplifies these steps to reduce runtime overhead,
        thereby lowering the launch cost of small kernels.
    NOTE:
        When Triton is upgraded to version 3.0.0, libentry can be removed,
        see: https://github.com/vllm-project/vllm/pull/5036#issuecomment-2243396245

    """

    def decorator(fn):
        return LibEntry(fn)

    return decorator
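

# Usage sketch (illustrative, not part of this diff): wrap a @triton.jit
# kernel so that repeated launches skip JITFunction.run's binding and
# cache-key work. The kernel below is a made-up placeholder, not one of
# vLLM's kernels, and assumes `import triton.language as tl`.
#
#   @libentry()
#   @triton.jit
#   def _copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
#       pid = tl.program_id(0)
#       offsets = pid * BLOCK + tl.arange(0, BLOCK)
#       mask = offsets < n_elements
#       tl.store(dst_ptr + offsets,
#                tl.load(src_ptr + offsets, mask=mask),
#                mask=mask)
#
#   grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK"]), )
#   _copy_kernel[grid](src, dst, n_elements, BLOCK=1024)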

@ -578,9 +578,9 @@ class ModelInputForGPUBuilder(ModelRunnerInputBuilderBase[ModelInputForGPU]):
            for inter_data in self.inter_data_list
        ])
        lora_mapping = LoRAMapping(
            lora_index_mapping,
            lora_prompt_mapping,
        )
            **dict(index_mapping=lora_index_mapping,
                   prompt_mapping=lora_prompt_mapping,
                   is_prefill=not self.decode_only))

        # Prompt adapter data.
        prompt_adapter_requests: Set[PromptAdapterRequest] = set()

@ -1152,9 +1152,9 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):

        if self.lora_config:
            lora_mapping = LoRAMapping(
                [0] * batch_size,
                [0] * batch_size,
            )
                **dict(index_mapping=[0] * batch_size,
                       prompt_mapping=[0] * batch_size,
                       is_prefill=False))
            self.set_active_loras(set(), lora_mapping)

        if self.prompt_adapter_config: