mirror of https://github.com/vllm-project/vllm
74 lines
2.3 KiB
Python
Executable File
74 lines
2.3 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
|
|
#
# A command line tool for running pytorch's hipify preprocessor on CUDA
# source files.
#
# See https://github.com/ROCm/hipify_torch
# and <torch install dir>/utils/hipify/hipify_python.py
#
|
|
|
|
import argparse
|
|
import os
|
|
import shutil
|
|
|
|
from torch.utils.hipify.hipify_python import hipify
|
|
|
|
def build_parser():
    """Build the argument parser for this command line tool.

    Returns:
        An ``argparse.ArgumentParser`` that accepts ``-p/--project_dir``,
        ``-o/--output_dir`` (both mandatory) and zero or more positional
        ``sources``.
    """
    parser = argparse.ArgumentParser()

    # Project directory where all the source + include files live.
    # Required: it is joined into a path and copied from below, so a missing
    # value would otherwise surface as an opaque TypeError on None.
    parser.add_argument(
        "-p",
        "--project_dir",
        required=True,
        help="The project directory.",
    )

    # Directory where hipified files are written.
    # Required for the same reason as --project_dir.
    parser.add_argument(
        "-o",
        "--output_dir",
        required=True,
        help="The output directory.",
    )

    # Source files to convert.
    parser.add_argument("sources",
                        help="Source files to hipify.",
                        nargs="*",
                        default=[])

    return parser


def resolve_hipified_paths(sources, hipify_result):
    """Return the path to use for each source after hipification.

    Args:
        sources: Source file paths as given on the command line.
        hipify_result: Mapping from absolute source path to an object that
            exposes a ``hipified_path`` attribute.

    Returns:
        A list with exactly one entry per element of ``sources``, in the same
        order. Each entry is the recorded ``hipified_path`` when the mapping
        has an entry for the source and that path is not ``None``; otherwise
        it is the absolute path of the source itself.
    """
    resolved = []
    for source in sources:
        s_abs = os.path.abspath(source)
        entry = hipify_result.get(s_abs)
        if entry is not None and entry.hipified_path is not None:
            resolved.append(entry.hipified_path)
        else:
            # No hipified counterpart was recorded; keep the original file.
            resolved.append(s_abs)
    return resolved


def main(argv=None):
    """Copy the project tree, hipify the requested sources, print the results.

    Args:
        argv: Argument list to parse. ``None`` makes argparse read
            ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)

    # Limit include scope to project_dir only
    includes = [os.path.join(args.project_dir, '*')]

    # Get absolute path for all source files.
    extra_files = [os.path.abspath(s) for s in args.sources]

    # Copy sources from project directory to output directory.
    # The directory might already exist to hold object files so we ignore that.
    shutil.copytree(args.project_dir, args.output_dir, dirs_exist_ok=True)

    hipify_result = hipify(project_directory=args.project_dir,
                           output_directory=args.output_dir,
                           header_include_dirs=[],
                           includes=includes,
                           extra_files=extra_files,
                           show_detailed=True,
                           is_pytorch_extension=True,
                           hipify_extra_files_only=True)

    hipified_sources = resolve_hipified_paths(args.sources, hipify_result)

    # Print hipified source files.
    print("\n".join(hipified_sources))


if __name__ == '__main__':
    main()
|