forked from OSchip/llvm-project
[mlir][spirv] Add mlir-vulkan-runner
Add an initial version of mlir-vulkan-runner execution driver. A command line utility that executes a MLIR file on the Vulkan by translating MLIR GPU module to SPIR-V and host part to LLVM IR before JIT-compiling and executing the latter. Differential Revision: https://reviews.llvm.org/D72696
This commit is contained in:
parent
bb61021a8f
commit
896ee361a6
|
@ -48,6 +48,7 @@ endif()
|
|||
add_definitions(-DMLIR_CUDA_CONVERSIONS_ENABLED=${MLIR_CUDA_CONVERSIONS_ENABLED})
|
||||
|
||||
set(MLIR_CUDA_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir CUDA runner")
|
||||
set(MLIR_VULKAN_RUNNER_ENABLED 0 CACHE BOOL "Enable building the mlir Vulkan runner")
|
||||
|
||||
include_directories( "include")
|
||||
include_directories( ${MLIR_INCLUDE_DIR})
|
||||
|
|
|
@ -27,7 +27,7 @@
|
|||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/ADT/SmallString.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
|
@ -80,7 +80,7 @@ private:
|
|||
/// populates the given `numWorkGroups`.
|
||||
LogicalResult createNumWorkGroups(Location loc, OpBuilder &builder,
|
||||
mlir::gpu::LaunchFuncOp launchOp,
|
||||
SmallVector<Value, 3> &numWorkGroups);
|
||||
SmallVectorImpl<Value> &numWorkGroups);
|
||||
|
||||
/// Declares all needed runtime functions.
|
||||
void declareVulkanFunctions(Location loc);
|
||||
|
@ -153,17 +153,15 @@ void GpuLaunchFuncToVulkanCalssPass::declareVulkanFunctions(Location loc) {
|
|||
|
||||
Value GpuLaunchFuncToVulkanCalssPass::createEntryPointNameConstant(
|
||||
StringRef name, Location loc, OpBuilder &builder) {
|
||||
std::vector<char> shaderName(name.begin(), name.end());
|
||||
SmallString<16> shaderName(name.begin(), name.end());
|
||||
// Append `\0` to follow C style string given that LLVM::createGlobalString()
|
||||
// won't handle this directly for us.
|
||||
shaderName.push_back('\0');
|
||||
|
||||
std::string entryPointGlobalName =
|
||||
std::string(llvm::formatv("{0}_spv_entry_point_name", name));
|
||||
return LLVM::createGlobalString(
|
||||
loc, builder, entryPointGlobalName,
|
||||
StringRef(shaderName.data(), shaderName.size()), LLVM::Linkage::Internal,
|
||||
getLLVMDialect());
|
||||
std::string entryPointGlobalName = (name + "_spv_entry_point_name").str();
|
||||
return LLVM::createGlobalString(loc, builder, entryPointGlobalName,
|
||||
shaderName, LLVM::Linkage::Internal,
|
||||
getLLVMDialect());
|
||||
}
|
||||
|
||||
LogicalResult GpuLaunchFuncToVulkanCalssPass::createBinaryShader(
|
||||
|
@ -171,14 +169,12 @@ LogicalResult GpuLaunchFuncToVulkanCalssPass::createBinaryShader(
|
|||
bool done = false;
|
||||
SmallVector<uint32_t, 0> binary;
|
||||
for (auto spirvModule : module.getOps<spirv::ModuleOp>()) {
|
||||
if (done) {
|
||||
spirvModule.emitError("should only contain one 'spv.module' op");
|
||||
return failure();
|
||||
}
|
||||
if (done)
|
||||
return spirvModule.emitError("should only contain one 'spv.module' op");
|
||||
done = true;
|
||||
if (failed(spirv::serialize(spirvModule, binary))) {
|
||||
|
||||
if (failed(spirv::serialize(spirvModule, binary)))
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
|
||||
binaryShader.resize(binary.size() * sizeof(uint32_t));
|
||||
|
@ -189,14 +185,13 @@ LogicalResult GpuLaunchFuncToVulkanCalssPass::createBinaryShader(
|
|||
|
||||
LogicalResult GpuLaunchFuncToVulkanCalssPass::createNumWorkGroups(
|
||||
Location loc, OpBuilder &builder, mlir::gpu::LaunchFuncOp launchOp,
|
||||
SmallVector<Value, 3> &numWorkGroups) {
|
||||
SmallVectorImpl<Value> &numWorkGroups) {
|
||||
for (auto index : llvm::seq(0, 3)) {
|
||||
auto numWorkGroupDimConstant = dyn_cast_or_null<ConstantOp>(
|
||||
launchOp.getOperand(index).getDefiningOp());
|
||||
|
||||
if (!numWorkGroupDimConstant) {
|
||||
if (!numWorkGroupDimConstant)
|
||||
return failure();
|
||||
}
|
||||
|
||||
auto numWorkGroupDimValue =
|
||||
numWorkGroupDimConstant.getValue().cast<IntegerAttr>().getInt();
|
||||
|
@ -207,7 +202,6 @@ LogicalResult GpuLaunchFuncToVulkanCalssPass::createNumWorkGroups(
|
|||
return success();
|
||||
}
|
||||
|
||||
// Translates gpu launch op to the sequence of Vulkan runtime calls.
|
||||
void GpuLaunchFuncToVulkanCalssPass::translateGpuLaunchCalls(
|
||||
mlir::gpu::LaunchFuncOp launchOp) {
|
||||
ModuleOp module = getModule();
|
||||
|
@ -217,9 +211,8 @@ void GpuLaunchFuncToVulkanCalssPass::translateGpuLaunchCalls(
|
|||
// Serialize `spirv::Module` into binary form.
|
||||
std::vector<char> binary;
|
||||
if (failed(
|
||||
GpuLaunchFuncToVulkanCalssPass::createBinaryShader(module, binary))) {
|
||||
GpuLaunchFuncToVulkanCalssPass::createBinaryShader(module, binary)))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
// Create LLVM global with SPIR-V binary data, so we can pass a pointer with
|
||||
// that data to runtime call.
|
||||
|
@ -246,9 +239,8 @@ void GpuLaunchFuncToVulkanCalssPass::translateGpuLaunchCalls(
|
|||
|
||||
// Create number of local workgroup for each dimension.
|
||||
SmallVector<Value, 3> numWorkGroups;
|
||||
if (failed(createNumWorkGroups(loc, builder, launchOp, numWorkGroups))) {
|
||||
if (failed(createNumWorkGroups(loc, builder, launchOp, numWorkGroups)))
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
||||
// Create call `setNumWorkGroups` runtime function with the given numbers of
|
||||
// local workgroup.
|
||||
|
|
|
@ -15,6 +15,7 @@ set(MLIR_DIALECT_LINALG_INTEGRATION_TEST_LIB_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTOR
|
|||
# Passed to lit.site.cfg.py.in to set up the path where to find the libraries
|
||||
# for the mlir cuda runner tests.
|
||||
set(MLIR_CUDA_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
|
||||
set(MLIR_VULKAN_WRAPPER_LIBRARY_DIR ${CMAKE_LIBRARY_OUTPUT_DIRECTORY})
|
||||
|
||||
configure_lit_site_cfg(
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/lit.site.cfg.py.in
|
||||
|
@ -61,6 +62,12 @@ if(MLIR_CUDA_RUNNER_ENABLED)
|
|||
)
|
||||
endif()
|
||||
|
||||
if(MLIR_VULKAN_RUNNER_ENABLED)
|
||||
list(APPEND MLIR_TEST_DEPENDS
|
||||
mlir-vulkan-runner
|
||||
)
|
||||
endif()
|
||||
|
||||
add_lit_testsuite(check-mlir "Running the MLIR regression tests"
|
||||
${CMAKE_CURRENT_BINARY_DIR}
|
||||
DEPENDS ${MLIR_TEST_DEPENDS}
|
||||
|
|
|
@ -67,7 +67,8 @@ tools.extend([
|
|||
ToolSubst('toy-ch4', unresolved='ignore'),
|
||||
ToolSubst('toy-ch5', unresolved='ignore'),
|
||||
ToolSubst('%linalg_test_lib_dir', config.linalg_test_lib_dir, unresolved='ignore'),
|
||||
ToolSubst('%cuda_wrapper_library_dir', config.cuda_wrapper_library_dir, unresolved='ignore')
|
||||
ToolSubst('%cuda_wrapper_library_dir', config.cuda_wrapper_library_dir, unresolved='ignore'),
|
||||
ToolSubst('%vulkan_wrapper_library_dir', config.vulkan_wrapper_library_dir, unresolved='ignore')
|
||||
])
|
||||
|
||||
llvm_config.add_tool_substitutions(tools, tool_dirs)
|
||||
|
|
|
@ -36,6 +36,8 @@ config.build_examples = @LLVM_BUILD_EXAMPLES@
|
|||
config.run_cuda_tests = @MLIR_CUDA_CONVERSIONS_ENABLED@
|
||||
config.cuda_wrapper_library_dir = "@MLIR_CUDA_WRAPPER_LIBRARY_DIR@"
|
||||
config.enable_cuda_runner = @MLIR_CUDA_RUNNER_ENABLED@
|
||||
config.vulkan_wrapper_library_dir = "@MLIR_VULKAN_WRAPPER_LIBRARY_DIR@"
|
||||
config.enable_vulkan_runner = @MLIR_VULKAN_RUNNER_ENABLED@
|
||||
|
||||
# Support substitution of the tools_dir with user parameters. This is
|
||||
# used when we can't determine the tool dir at configuration time.
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
// RUN: mlir-vulkan-runner %s --shared-libs=%vulkan_wrapper_library_dir/libvulkan-runtime-wrappers%shlibext,%linalg_test_lib_dir/libmlir_runner_utils%shlibext --entry-point-result=void | FileCheck %s
|
||||
|
||||
// CHECK: [3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3, 3.3]
|
||||
module attributes {gpu.container_module} {
|
||||
gpu.module @kernels {
|
||||
gpu.func @kernel_add(%arg0 : memref<8xf32>, %arg1 : memref<8xf32>, %arg2 : memref<8xf32>)
|
||||
attributes {gpu.kernel, spv.entry_point_abi = {local_size = dense<[1, 1, 1]>: vector<3xi32>}} {
|
||||
%0 = "gpu.block_id"() {dimension = "x"} : () -> index
|
||||
%1 = load %arg0[%0] : memref<8xf32>
|
||||
%2 = load %arg1[%0] : memref<8xf32>
|
||||
%3 = addf %1, %2 : f32
|
||||
store %3, %arg2[%0] : memref<8xf32>
|
||||
gpu.return
|
||||
}
|
||||
}
|
||||
|
||||
func @main() {
|
||||
%arg0 = alloc() : memref<8xf32>
|
||||
%arg1 = alloc() : memref<8xf32>
|
||||
%arg2 = alloc() : memref<8xf32>
|
||||
%0 = constant 0 : i32
|
||||
%1 = constant 1 : i32
|
||||
%2 = constant 2 : i32
|
||||
%value0 = constant 0.0 : f32
|
||||
%value1 = constant 1.1 : f32
|
||||
%value2 = constant 2.2 : f32
|
||||
%arg3 = memref_cast %arg0 : memref<8xf32> to memref<?xf32>
|
||||
%arg4 = memref_cast %arg1 : memref<8xf32> to memref<?xf32>
|
||||
%arg5 = memref_cast %arg2 : memref<8xf32> to memref<?xf32>
|
||||
call @setResourceData(%0, %0, %arg3, %value1) : (i32, i32, memref<?xf32>, f32) -> ()
|
||||
call @setResourceData(%0, %1, %arg4, %value2) : (i32, i32, memref<?xf32>, f32) -> ()
|
||||
call @setResourceData(%0, %2, %arg5, %value0) : (i32, i32, memref<?xf32>, f32) -> ()
|
||||
|
||||
%cst1 = constant 1 : index
|
||||
%cst8 = constant 8 : index
|
||||
"gpu.launch_func"(%cst8, %cst1, %cst1, %cst1, %cst1, %cst1, %arg0, %arg1, %arg2) { kernel = "kernel_add", kernel_module = @kernels }
|
||||
: (index, index, index, index, index, index, memref<8xf32>, memref<8xf32>, memref<8xf32>) -> ()
|
||||
%arg6 = memref_cast %arg5 : memref<?xf32> to memref<*xf32>
|
||||
call @print_memref_f32(%arg6) : (memref<*xf32>) -> ()
|
||||
return
|
||||
}
|
||||
func @setResourceData(%0 : i32, %1 : i32, %2 : memref<?xf32>, %4 : f32)
|
||||
func @print_memref_f32(%ptr : memref<*xf32>)
|
||||
}
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
if not config.enable_vulkan_runner:
|
||||
config.unsupported = True
|
|
@ -3,3 +3,4 @@ add_subdirectory(mlir-cpu-runner)
|
|||
add_subdirectory(mlir-opt)
|
||||
add_subdirectory(mlir-tblgen)
|
||||
add_subdirectory(mlir-translate)
|
||||
add_subdirectory(mlir-vulkan-runner)
|
||||
|
|
|
@ -0,0 +1,105 @@
|
|||
set(LLVM_OPTIONAL_SOURCES
|
||||
mlir-vulkan-runner.cpp
|
||||
vulkan-runtime-wrappers.cpp
|
||||
VulkanRuntime.cpp
|
||||
VulkanRuntime.h
|
||||
)
|
||||
|
||||
if (MLIR_VULKAN_RUNNER_ENABLED)
|
||||
message(STATUS "Building the Vulkan runner")
|
||||
|
||||
# At first try "FindVulkan" from:
|
||||
# https://cmake.org/cmake/help/v3.7/module/FindVulkan.html
|
||||
if (NOT CMAKE_VERSION VERSION_LESS 3.7.0)
|
||||
find_package(Vulkan)
|
||||
endif()
|
||||
|
||||
# If Vulkan is not found try a path specified by VULKAN_SDK.
|
||||
if (NOT Vulkan_FOUND)
|
||||
if ("$ENV{VULKAN_SDK}" STREQUAL "")
|
||||
message(FATAL_ERROR "Please use at least CMAKE 3.7.0 or provide "
|
||||
"VULKAN_SDK path as an environment variable")
|
||||
endif()
|
||||
|
||||
find_library(Vulkan_LIBRARY vulkan HINTS "$ENV{VULKAN_SDK}/lib" REQUIRED)
|
||||
if (Vulkan_LIBRARY)
|
||||
set(Vulkan_FOUND ON)
|
||||
set(Vulkan_INCLUDE_DIR "$ENV{VULKAN_SDK}/include")
|
||||
message(STATUS "Found Vulkan: " ${Vulkan_LIBRARY})
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if (NOT Vulkan_FOUND)
|
||||
message(FATAL_ERROR "Cannot find Vulkan library")
|
||||
endif()
|
||||
|
||||
add_llvm_library(vulkan-runtime-wrappers SHARED
|
||||
vulkan-runtime-wrappers.cpp
|
||||
VulkanRuntime.cpp
|
||||
)
|
||||
|
||||
target_include_directories(vulkan-runtime-wrappers
|
||||
PRIVATE ${Vulkan_INCLUDE_DIR}
|
||||
LLVMSupport
|
||||
)
|
||||
|
||||
target_link_libraries(vulkan-runtime-wrappers
|
||||
LLVMSupport
|
||||
MLIRSPIRVSerialization
|
||||
LLVMCore
|
||||
LLVMSupport
|
||||
${Vulkan_LIBRARY}
|
||||
)
|
||||
|
||||
set(LIBS
|
||||
LLVMCore
|
||||
LLVMSupport
|
||||
MLIRJitRunner
|
||||
MLIRAffineOps
|
||||
MLIRAnalysis
|
||||
MLIREDSC
|
||||
MLIRExecutionEngine
|
||||
MLIRFxpMathOps
|
||||
MLIRGPU
|
||||
MLIRGPUtoCUDATransforms
|
||||
MLIRGPUtoNVVMTransforms
|
||||
MLIRGPUtoSPIRVTransforms
|
||||
MLIRGPUtoVulkanTransforms
|
||||
MLIRIR
|
||||
MLIRLLVMIR
|
||||
MLIRLinalgOps
|
||||
MLIRLoopToStandard
|
||||
MLIROpenMP
|
||||
MLIRParser
|
||||
MLIRQuantOps
|
||||
MLIRROCDLIR
|
||||
MLIRSPIRV
|
||||
MLIRSPIRVTransforms
|
||||
MLIRStandardOps
|
||||
MLIRStandardToLLVM
|
||||
MLIRSupport
|
||||
MLIRTargetLLVMIR
|
||||
MLIRTransforms
|
||||
MLIRTranslation
|
||||
${Vulkan_LIBRARY}
|
||||
)
|
||||
|
||||
# Manually expand the target library, since our MLIR libraries
|
||||
# aren't plugged into the LLVM dependency tracking. If we don't
|
||||
# do this then we can't insert the CodeGen library after ourselves
|
||||
llvm_expand_pseudo_components(TARGET_LIBS AllTargetsCodeGens)
|
||||
# Prepend LLVM in front of every target, this is how the library
|
||||
# are named with CMake
|
||||
SET(targets_to_link)
|
||||
FOREACH(t ${TARGET_LIBS})
|
||||
LIST(APPEND targets_to_link "LLVM${t}")
|
||||
ENDFOREACH(t)
|
||||
|
||||
add_llvm_tool(mlir-vulkan-runner
|
||||
mlir-vulkan-runner.cpp
|
||||
)
|
||||
add_dependencies(mlir-vulkan-runner vulkan-runtime-wrappers)
|
||||
llvm_update_compile_flags(mlir-vulkan-runner)
|
||||
target_link_libraries(mlir-vulkan-runner PRIVATE ${FULL_LINK_LIBS} ${LIBS})
|
||||
|
||||
endif()
|
|
@ -0,0 +1,717 @@
|
|||
//===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file provides a library for running a module on a Vulkan device.
|
||||
// Implements a Vulkan runtime.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "VulkanRuntime.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
void VulkanRuntime::setNumWorkGroups(const NumWorkGroups &numberWorkGroups) {
|
||||
numWorkGroups = numberWorkGroups;
|
||||
}
|
||||
|
||||
void VulkanRuntime::setResourceStorageClassBindingMap(
|
||||
const ResourceStorageClassBindingMap &stClassData) {
|
||||
resourceStorageClassData = stClassData;
|
||||
}
|
||||
|
||||
void VulkanRuntime::setResourceData(
|
||||
const DescriptorSetIndex desIndex, const BindingIndex bindIndex,
|
||||
const VulkanHostMemoryBuffer &hostMemBuffer) {
|
||||
resourceData[desIndex][bindIndex] = hostMemBuffer;
|
||||
resourceStorageClassData[desIndex][bindIndex] =
|
||||
spirv::StorageClass::StorageBuffer;
|
||||
}
|
||||
|
||||
void VulkanRuntime::setEntryPoint(const char *entryPointName) {
|
||||
entryPoint = entryPointName;
|
||||
}
|
||||
|
||||
void VulkanRuntime::setResourceData(const ResourceData &resData) {
|
||||
resourceData = resData;
|
||||
}
|
||||
|
||||
void VulkanRuntime::setShaderModule(uint8_t *shader, uint32_t size) {
|
||||
binary = shader;
|
||||
binarySize = size;
|
||||
}
|
||||
|
||||
LogicalResult VulkanRuntime::mapStorageClassToDescriptorType(
|
||||
spirv::StorageClass storageClass, VkDescriptorType &descriptorType) {
|
||||
switch (storageClass) {
|
||||
case spirv::StorageClass::StorageBuffer:
|
||||
descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
|
||||
break;
|
||||
case spirv::StorageClass::Uniform:
|
||||
descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER;
|
||||
break;
|
||||
default:
|
||||
llvm::errs() << "unsupported storage class";
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult VulkanRuntime::mapStorageClassToBufferUsageFlag(
|
||||
spirv::StorageClass storageClass, VkBufferUsageFlagBits &bufferUsage) {
|
||||
switch (storageClass) {
|
||||
case spirv::StorageClass::StorageBuffer:
|
||||
bufferUsage = VK_BUFFER_USAGE_STORAGE_BUFFER_BIT;
|
||||
break;
|
||||
case spirv::StorageClass::Uniform:
|
||||
bufferUsage = VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT;
|
||||
break;
|
||||
default:
|
||||
llvm::errs() << "unsupported storage class";
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult VulkanRuntime::countDeviceMemorySize() {
|
||||
for (const auto &resourceDataMapPair : resourceData) {
|
||||
const auto &resourceDataMap = resourceDataMapPair.second;
|
||||
for (const auto &resourceDataBindingPair : resourceDataMap) {
|
||||
if (resourceDataBindingPair.second.size) {
|
||||
memorySize += resourceDataBindingPair.second.size;
|
||||
} else {
|
||||
llvm::errs()
|
||||
<< "expected buffer size greater than zero for resource data";
|
||||
return failure();
|
||||
}
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult VulkanRuntime::initRuntime() {
|
||||
if (!resourceData.size()) {
|
||||
llvm::errs() << "Vulkan runtime needs at least one resource";
|
||||
return failure();
|
||||
}
|
||||
if (!binarySize || !binary) {
|
||||
llvm::errs() << "binary shader size must be greater than zero";
|
||||
return failure();
|
||||
}
|
||||
if (failed(countDeviceMemorySize())) {
|
||||
return failure();
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult VulkanRuntime::destroy() {
|
||||
// According to Vulkan spec:
|
||||
// "To ensure that no work is active on the device, vkDeviceWaitIdle can be
|
||||
// used to gate the destruction of the device. Prior to destroying a device,
|
||||
// an application is responsible for destroying/freeing any Vulkan objects
|
||||
// that were created using that device as the first parameter of the
|
||||
// corresponding vkCreate* or vkAllocate* command."
|
||||
RETURN_ON_VULKAN_ERROR(vkDeviceWaitIdle(device), "vkDeviceWaitIdle");
|
||||
|
||||
// Free and destroy.
|
||||
vkFreeCommandBuffers(device, commandPool, commandBuffers.size(),
|
||||
commandBuffers.data());
|
||||
vkDestroyCommandPool(device, commandPool, nullptr);
|
||||
vkFreeDescriptorSets(device, descriptorPool, descriptorSets.size(),
|
||||
descriptorSets.data());
|
||||
vkDestroyDescriptorPool(device, descriptorPool, nullptr);
|
||||
vkDestroyPipeline(device, pipeline, nullptr);
|
||||
vkDestroyPipelineLayout(device, pipelineLayout, nullptr);
|
||||
for (auto &descriptorSetLayout: descriptorSetLayouts) {
|
||||
vkDestroyDescriptorSetLayout(device, descriptorSetLayout, nullptr);
|
||||
}
|
||||
vkDestroyShaderModule(device, shaderModule, nullptr);
|
||||
|
||||
// For each descriptor set.
|
||||
for (auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
|
||||
auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
|
||||
// For each descirptor binding.
|
||||
for (auto &memoryBuffer : deviceMemoryBuffers) {
|
||||
vkFreeMemory(device, memoryBuffer.deviceMemory, nullptr);
|
||||
vkDestroyBuffer(device, memoryBuffer.buffer, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
vkDestroyDevice(device, nullptr);
|
||||
vkDestroyInstance(instance, nullptr);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult VulkanRuntime::run() {
|
||||
// Create logical device, shader module and memory buffers.
|
||||
if (failed(createInstance()) || failed(createDevice()) ||
|
||||
failed(createMemoryBuffers()) || failed(createShaderModule())) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Descriptor bindings divided into sets. Each descriptor binding
|
||||
// must have a layout binding attached into a descriptor set layout.
|
||||
// Each layout set must be binded into a pipeline layout.
|
||||
initDescriptorSetLayoutBindingMap();
|
||||
if (failed(createDescriptorSetLayout()) || failed(createPipelineLayout()) ||
|
||||
// Each descriptor set must be allocated from a descriptor pool.
|
||||
failed(createComputePipeline()) || failed(createDescriptorPool()) ||
|
||||
failed(allocateDescriptorSets()) || failed(setWriteDescriptors()) ||
|
||||
// Create command buffer.
|
||||
failed(createCommandPool()) || failed(createComputeCommandBuffer())) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Get working queue.
|
||||
vkGetDeviceQueue(device, queueFamilyIndex, 0, &queue);
|
||||
|
||||
// Submit command buffer into the queue.
|
||||
if (failed(submitCommandBuffersToQueue())) {
|
||||
return failure();
|
||||
}
|
||||
|
||||
RETURN_ON_VULKAN_ERROR(vkQueueWaitIdle(queue), "vkQueueWaitIdle");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult VulkanRuntime::createInstance() {
|
||||
VkApplicationInfo applicationInfo = {};
|
||||
applicationInfo.sType = VK_STRUCTURE_TYPE_APPLICATION_INFO;
|
||||
applicationInfo.pNext = nullptr;
|
||||
applicationInfo.pApplicationName = "MLIR Vulkan runtime";
|
||||
applicationInfo.applicationVersion = 0;
|
||||
applicationInfo.pEngineName = "mlir";
|
||||
applicationInfo.engineVersion = 0;
|
||||
applicationInfo.apiVersion = VK_MAKE_VERSION(1, 0, 0);
|
||||
|
||||
VkInstanceCreateInfo instanceCreateInfo = {};
|
||||
instanceCreateInfo.sType = VK_STRUCTURE_TYPE_INSTANCE_CREATE_INFO;
|
||||
instanceCreateInfo.pNext = nullptr;
|
||||
instanceCreateInfo.flags = 0;
|
||||
instanceCreateInfo.pApplicationInfo = &applicationInfo;
|
||||
instanceCreateInfo.enabledLayerCount = 0;
|
||||
instanceCreateInfo.ppEnabledLayerNames = 0;
|
||||
instanceCreateInfo.enabledExtensionCount = 0;
|
||||
instanceCreateInfo.ppEnabledExtensionNames = 0;
|
||||
|
||||
RETURN_ON_VULKAN_ERROR(vkCreateInstance(&instanceCreateInfo, 0, &instance),
|
||||
"vkCreateInstance");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult VulkanRuntime::createDevice() {
|
||||
uint32_t physicalDeviceCount = 0;
|
||||
RETURN_ON_VULKAN_ERROR(
|
||||
vkEnumeratePhysicalDevices(instance, &physicalDeviceCount, 0),
|
||||
"vkEnumeratePhysicalDevices");
|
||||
|
||||
llvm::SmallVector<VkPhysicalDevice, 1> physicalDevices(physicalDeviceCount);
|
||||
RETURN_ON_VULKAN_ERROR(vkEnumeratePhysicalDevices(instance,
|
||||
&physicalDeviceCount,
|
||||
physicalDevices.data()),
|
||||
"vkEnumeratePhysicalDevices");
|
||||
|
||||
RETURN_ON_VULKAN_ERROR(physicalDeviceCount ? VK_SUCCESS : VK_INCOMPLETE,
|
||||
"physicalDeviceCount");
|
||||
|
||||
// TODO(denis0x0D): find the best device.
|
||||
const auto &physicalDevice = physicalDevices.front();
|
||||
getBestComputeQueue(physicalDevice);
|
||||
|
||||
const float queuePrioritory = 1.0f;
|
||||
VkDeviceQueueCreateInfo deviceQueueCreateInfo = {};
|
||||
deviceQueueCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_QUEUE_CREATE_INFO;
|
||||
deviceQueueCreateInfo.pNext = nullptr;
|
||||
deviceQueueCreateInfo.flags = 0;
|
||||
deviceQueueCreateInfo.queueFamilyIndex = queueFamilyIndex;
|
||||
deviceQueueCreateInfo.queueCount = 1;
|
||||
deviceQueueCreateInfo.pQueuePriorities = &queuePrioritory;
|
||||
|
||||
// Structure specifying parameters of a newly created device.
|
||||
VkDeviceCreateInfo deviceCreateInfo = {};
|
||||
deviceCreateInfo.sType = VK_STRUCTURE_TYPE_DEVICE_CREATE_INFO;
|
||||
deviceCreateInfo.pNext = nullptr;
|
||||
deviceCreateInfo.flags = 0;
|
||||
deviceCreateInfo.queueCreateInfoCount = 1;
|
||||
deviceCreateInfo.pQueueCreateInfos = &deviceQueueCreateInfo;
|
||||
deviceCreateInfo.enabledLayerCount = 0;
|
||||
deviceCreateInfo.ppEnabledLayerNames = nullptr;
|
||||
deviceCreateInfo.enabledExtensionCount = 0;
|
||||
deviceCreateInfo.ppEnabledExtensionNames = nullptr;
|
||||
deviceCreateInfo.pEnabledFeatures = nullptr;
|
||||
|
||||
RETURN_ON_VULKAN_ERROR(
|
||||
vkCreateDevice(physicalDevice, &deviceCreateInfo, 0, &device),
|
||||
"vkCreateDevice");
|
||||
|
||||
VkPhysicalDeviceMemoryProperties properties = {};
|
||||
vkGetPhysicalDeviceMemoryProperties(physicalDevice, &properties);
|
||||
|
||||
// Try to find memory type with following properties:
|
||||
// VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT bit specifies that memory allocated
|
||||
// with this type can be mapped for host access using vkMapMemory;
|
||||
// VK_MEMORY_PROPERTY_HOST_COHERENT_BIT bit specifies that the host cache
|
||||
// management commands vkFlushMappedMemoryRanges and
|
||||
// vkInvalidateMappedMemoryRanges are not needed to flush host writes to the
|
||||
// device or make device writes visible to the host, respectively.
|
||||
for (uint32_t i = 0, e = properties.memoryTypeCount; i < e; ++i) {
|
||||
if ((VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT &
|
||||
properties.memoryTypes[i].propertyFlags) &&
|
||||
(VK_MEMORY_PROPERTY_HOST_COHERENT_BIT &
|
||||
properties.memoryTypes[i].propertyFlags) &&
|
||||
(memorySize <=
|
||||
properties.memoryHeaps[properties.memoryTypes[i].heapIndex].size)) {
|
||||
memoryTypeIndex = i;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_ON_VULKAN_ERROR(memoryTypeIndex == VK_MAX_MEMORY_TYPES ? VK_INCOMPLETE
|
||||
: VK_SUCCESS,
|
||||
"invalid memoryTypeIndex");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
VulkanRuntime::getBestComputeQueue(const VkPhysicalDevice &physicalDevice) {
|
||||
uint32_t queueFamilyPropertiesCount = 0;
|
||||
vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice,
|
||||
&queueFamilyPropertiesCount, 0);
|
||||
SmallVector<VkQueueFamilyProperties, 1> queueFamilyProperties(
|
||||
queueFamilyPropertiesCount);
|
||||
|
||||
vkGetPhysicalDeviceQueueFamilyProperties(physicalDevice,
|
||||
&queueFamilyPropertiesCount,
|
||||
queueFamilyProperties.data());
|
||||
|
||||
// VK_QUEUE_COMPUTE_BIT specifies that queues in this queue family support
|
||||
// compute operations.
|
||||
for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
|
||||
const VkQueueFlags maskedFlags =
|
||||
(~(VK_QUEUE_TRANSFER_BIT | VK_QUEUE_SPARSE_BINDING_BIT) &
|
||||
queueFamilyProperties[i].queueFlags);
|
||||
|
||||
if (!(VK_QUEUE_GRAPHICS_BIT & maskedFlags) &&
|
||||
(VK_QUEUE_COMPUTE_BIT & maskedFlags)) {
|
||||
queueFamilyIndex = i;
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < queueFamilyPropertiesCount; ++i) {
|
||||
const VkQueueFlags maskedFlags =
|
||||
(~(VK_QUEUE_TRANSFER_BIT | VK_QUEUE_SPARSE_BINDING_BIT) &
|
||||
queueFamilyProperties[i].queueFlags);
|
||||
|
||||
if (VK_QUEUE_COMPUTE_BIT & maskedFlags) {
|
||||
queueFamilyIndex = i;
|
||||
return success();
|
||||
}
|
||||
}
|
||||
|
||||
llvm::errs() << "cannot find valid queue";
|
||||
return failure();
|
||||
}
|
||||
|
||||
LogicalResult VulkanRuntime::createMemoryBuffers() {
|
||||
// For each descriptor set.
|
||||
for (const auto &resourceDataMapPair : resourceData) {
|
||||
llvm::SmallVector<VulkanDeviceMemoryBuffer, 1> deviceMemoryBuffers;
|
||||
const auto descriptorSetIndex = resourceDataMapPair.first;
|
||||
const auto &resourceDataMap = resourceDataMapPair.second;
|
||||
|
||||
// For each descriptor binding.
|
||||
for (const auto &resourceDataBindingPair : resourceDataMap) {
|
||||
// Create device memory buffer.
|
||||
VulkanDeviceMemoryBuffer memoryBuffer;
|
||||
memoryBuffer.bindingIndex = resourceDataBindingPair.first;
|
||||
VkDescriptorType descriptorType = {};
|
||||
VkBufferUsageFlagBits bufferUsage = {};
|
||||
|
||||
// Check that descriptor set has storage class map.
|
||||
const auto resourceStorageClassMapIt =
|
||||
resourceStorageClassData.find(descriptorSetIndex);
|
||||
if (resourceStorageClassMapIt == resourceStorageClassData.end()) {
|
||||
llvm::errs()
|
||||
<< "cannot find storge class for resource in descriptor set: "
|
||||
<< descriptorSetIndex;
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Check that specific descriptor binding has storage class.
|
||||
const auto &resourceStorageClassMap = resourceStorageClassMapIt->second;
|
||||
const auto resourceStorageClassIt =
|
||||
resourceStorageClassMap.find(resourceDataBindingPair.first);
|
||||
if (resourceStorageClassIt == resourceStorageClassMap.end()) {
|
||||
llvm::errs()
|
||||
<< "cannot find storage class for resource with descriptor index: "
|
||||
<< resourceDataBindingPair.first;
|
||||
return failure();
|
||||
}
|
||||
|
||||
const auto resourceStorageClassBinding = resourceStorageClassIt->second;
|
||||
if (failed(mapStorageClassToDescriptorType(resourceStorageClassBinding,
|
||||
descriptorType)) ||
|
||||
failed(mapStorageClassToBufferUsageFlag(resourceStorageClassBinding,
|
||||
bufferUsage))) {
|
||||
llvm::errs() << "storage class for resource with descriptor binding: "
|
||||
<< resourceDataBindingPair.first
|
||||
<< " in the descriptor set: " << descriptorSetIndex
|
||||
<< " is not supported ";
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Set descriptor type for the specific device memory buffer.
|
||||
memoryBuffer.descriptorType = descriptorType;
|
||||
const auto bufferSize = resourceDataBindingPair.second.size;
|
||||
|
||||
// Specify memory allocation info.
|
||||
VkMemoryAllocateInfo memoryAllocateInfo = {};
|
||||
memoryAllocateInfo.sType = VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO;
|
||||
memoryAllocateInfo.pNext = nullptr;
|
||||
memoryAllocateInfo.allocationSize = bufferSize;
|
||||
memoryAllocateInfo.memoryTypeIndex = memoryTypeIndex;
|
||||
|
||||
// Allocate device memory.
|
||||
RETURN_ON_VULKAN_ERROR(vkAllocateMemory(device, &memoryAllocateInfo, 0,
|
||||
&memoryBuffer.deviceMemory),
|
||||
"vkAllocateMemory");
|
||||
void *payload;
|
||||
RETURN_ON_VULKAN_ERROR(vkMapMemory(device, memoryBuffer.deviceMemory, 0,
|
||||
bufferSize, 0,
|
||||
reinterpret_cast<void **>(&payload)),
|
||||
"vkMapMemory");
|
||||
|
||||
// Copy host memory into the mapped area.
|
||||
std::memcpy(payload, resourceDataBindingPair.second.ptr, bufferSize);
|
||||
vkUnmapMemory(device, memoryBuffer.deviceMemory);
|
||||
|
||||
VkBufferCreateInfo bufferCreateInfo = {};
|
||||
bufferCreateInfo.sType = VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO;
|
||||
bufferCreateInfo.pNext = nullptr;
|
||||
bufferCreateInfo.flags = 0;
|
||||
bufferCreateInfo.size = bufferSize;
|
||||
bufferCreateInfo.usage = bufferUsage;
|
||||
bufferCreateInfo.sharingMode = VK_SHARING_MODE_EXCLUSIVE;
|
||||
bufferCreateInfo.queueFamilyIndexCount = 1;
|
||||
bufferCreateInfo.pQueueFamilyIndices = &queueFamilyIndex;
|
||||
RETURN_ON_VULKAN_ERROR(
|
||||
vkCreateBuffer(device, &bufferCreateInfo, 0, &memoryBuffer.buffer),
|
||||
"vkCreateBuffer");
|
||||
|
||||
// Bind buffer and device memory.
|
||||
RETURN_ON_VULKAN_ERROR(vkBindBufferMemory(device, memoryBuffer.buffer,
|
||||
memoryBuffer.deviceMemory, 0),
|
||||
"vkBindBufferMemory");
|
||||
|
||||
// Update buffer info.
|
||||
memoryBuffer.bufferInfo.buffer = memoryBuffer.buffer;
|
||||
memoryBuffer.bufferInfo.offset = 0;
|
||||
memoryBuffer.bufferInfo.range = VK_WHOLE_SIZE;
|
||||
deviceMemoryBuffers.push_back(memoryBuffer);
|
||||
}
|
||||
|
||||
// Associate device memory buffers with a descriptor set.
|
||||
deviceMemoryBufferMap[descriptorSetIndex] = deviceMemoryBuffers;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult VulkanRuntime::createShaderModule() {
|
||||
VkShaderModuleCreateInfo shaderModuleCreateInfo = {};
|
||||
shaderModuleCreateInfo.sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO;
|
||||
shaderModuleCreateInfo.pNext = nullptr;
|
||||
shaderModuleCreateInfo.flags = 0;
|
||||
// Set size in bytes.
|
||||
shaderModuleCreateInfo.codeSize = binarySize;
|
||||
// Set pointer to the binary shader.
|
||||
shaderModuleCreateInfo.pCode = reinterpret_cast<uint32_t *>(binary);
|
||||
RETURN_ON_VULKAN_ERROR(
|
||||
vkCreateShaderModule(device, &shaderModuleCreateInfo, 0, &shaderModule),
|
||||
"vkCreateShaderModule");
|
||||
return success();
|
||||
}
|
||||
|
||||
void VulkanRuntime::initDescriptorSetLayoutBindingMap() {
|
||||
for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
|
||||
SmallVector<VkDescriptorSetLayoutBinding, 1> descriptorSetLayoutBindings;
|
||||
const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
|
||||
const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;
|
||||
|
||||
// Create a layout binding for each descriptor.
|
||||
for (const auto &memBuffer : deviceMemoryBuffers) {
|
||||
VkDescriptorSetLayoutBinding descriptorSetLayoutBinding = {};
|
||||
descriptorSetLayoutBinding.binding = memBuffer.bindingIndex;
|
||||
descriptorSetLayoutBinding.descriptorType = memBuffer.descriptorType;
|
||||
descriptorSetLayoutBinding.descriptorCount = 1;
|
||||
descriptorSetLayoutBinding.stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
|
||||
descriptorSetLayoutBinding.pImmutableSamplers = 0;
|
||||
descriptorSetLayoutBindings.push_back(descriptorSetLayoutBinding);
|
||||
}
|
||||
descriptorSetLayoutBindingMap[descriptorSetIndex] =
|
||||
descriptorSetLayoutBindings;
|
||||
}
|
||||
}
|
||||
|
||||
LogicalResult VulkanRuntime::createDescriptorSetLayout() {
|
||||
for (const auto &deviceMemoryBufferMapPair : deviceMemoryBufferMap) {
|
||||
const auto descriptorSetIndex = deviceMemoryBufferMapPair.first;
|
||||
const auto &deviceMemoryBuffers = deviceMemoryBufferMapPair.second;
|
||||
// Each descriptor in a descriptor set must be the same type.
|
||||
VkDescriptorType descriptorType =
|
||||
deviceMemoryBuffers.front().descriptorType;
|
||||
const uint32_t descriptorSize = deviceMemoryBuffers.size();
|
||||
const auto descriptorSetLayoutBindingIt =
|
||||
descriptorSetLayoutBindingMap.find(descriptorSetIndex);
|
||||
|
||||
if (descriptorSetLayoutBindingIt == descriptorSetLayoutBindingMap.end()) {
|
||||
llvm::errs() << "cannot find layout bindings for the set with number: "
|
||||
<< descriptorSetIndex;
|
||||
return failure();
|
||||
}
|
||||
|
||||
const auto &descriptorSetLayoutBindings =
|
||||
descriptorSetLayoutBindingIt->second;
|
||||
// Create descriptor set layout.
|
||||
VkDescriptorSetLayout descriptorSetLayout = {};
|
||||
VkDescriptorSetLayoutCreateInfo descriptorSetLayoutCreateInfo = {};
|
||||
|
||||
descriptorSetLayoutCreateInfo.sType =
|
||||
VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
|
||||
descriptorSetLayoutCreateInfo.pNext = nullptr;
|
||||
descriptorSetLayoutCreateInfo.flags = 0;
|
||||
// Amount of descriptor bindings in a layout set.
|
||||
descriptorSetLayoutCreateInfo.bindingCount =
|
||||
descriptorSetLayoutBindings.size();
|
||||
descriptorSetLayoutCreateInfo.pBindings =
|
||||
descriptorSetLayoutBindings.data();
|
||||
RETURN_ON_VULKAN_ERROR(
|
||||
vkCreateDescriptorSetLayout(device, &descriptorSetLayoutCreateInfo, 0,
|
||||
&descriptorSetLayout),
|
||||
"vkCreateDescriptorSetLayout");
|
||||
|
||||
descriptorSetLayouts.push_back(descriptorSetLayout);
|
||||
descriptorSetInfoPool.push_back(
|
||||
{descriptorSetIndex, descriptorSize, descriptorType});
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult VulkanRuntime::createPipelineLayout() {
|
||||
// Associate descriptor sets with a pipeline layout.
|
||||
VkPipelineLayoutCreateInfo pipelineLayoutCreateInfo = {};
|
||||
pipelineLayoutCreateInfo.sType =
|
||||
VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO;
|
||||
pipelineLayoutCreateInfo.pNext = nullptr;
|
||||
pipelineLayoutCreateInfo.flags = 0;
|
||||
pipelineLayoutCreateInfo.setLayoutCount = descriptorSetLayouts.size();
|
||||
pipelineLayoutCreateInfo.pSetLayouts = descriptorSetLayouts.data();
|
||||
pipelineLayoutCreateInfo.pushConstantRangeCount = 0;
|
||||
pipelineLayoutCreateInfo.pPushConstantRanges = 0;
|
||||
RETURN_ON_VULKAN_ERROR(vkCreatePipelineLayout(device,
|
||||
&pipelineLayoutCreateInfo, 0,
|
||||
&pipelineLayout),
|
||||
"vkCreatePipelineLayout");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult VulkanRuntime::createComputePipeline() {
|
||||
VkPipelineShaderStageCreateInfo stageInfo = {};
|
||||
stageInfo.sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO;
|
||||
stageInfo.pNext = nullptr;
|
||||
stageInfo.flags = 0;
|
||||
stageInfo.stage = VK_SHADER_STAGE_COMPUTE_BIT;
|
||||
stageInfo.module = shaderModule;
|
||||
// Set entry point.
|
||||
stageInfo.pName = entryPoint;
|
||||
stageInfo.pSpecializationInfo = 0;
|
||||
|
||||
VkComputePipelineCreateInfo computePipelineCreateInfo = {};
|
||||
computePipelineCreateInfo.sType =
|
||||
VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO;
|
||||
computePipelineCreateInfo.pNext = nullptr;
|
||||
computePipelineCreateInfo.flags = 0;
|
||||
computePipelineCreateInfo.stage = stageInfo;
|
||||
computePipelineCreateInfo.layout = pipelineLayout;
|
||||
computePipelineCreateInfo.basePipelineHandle = 0;
|
||||
computePipelineCreateInfo.basePipelineIndex = 0;
|
||||
RETURN_ON_VULKAN_ERROR(vkCreateComputePipelines(device, 0, 1,
|
||||
&computePipelineCreateInfo, 0,
|
||||
&pipeline),
|
||||
"vkCreateComputePipelines");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult VulkanRuntime::createDescriptorPool() {
|
||||
llvm::SmallVector<VkDescriptorPoolSize, 1> descriptorPoolSizes;
|
||||
for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
|
||||
// For each descriptor set populate descriptor pool size.
|
||||
VkDescriptorPoolSize descriptorPoolSize = {};
|
||||
descriptorPoolSize.type = descriptorSetInfo.descriptorType;
|
||||
descriptorPoolSize.descriptorCount = descriptorSetInfo.descriptorSize;
|
||||
descriptorPoolSizes.push_back(descriptorPoolSize);
|
||||
}
|
||||
|
||||
VkDescriptorPoolCreateInfo descriptorPoolCreateInfo = {};
|
||||
descriptorPoolCreateInfo.sType =
|
||||
VK_STRUCTURE_TYPE_DESCRIPTOR_POOL_CREATE_INFO;
|
||||
descriptorPoolCreateInfo.pNext = nullptr;
|
||||
descriptorPoolCreateInfo.flags = 0;
|
||||
descriptorPoolCreateInfo.maxSets = descriptorPoolSizes.size();
|
||||
descriptorPoolCreateInfo.poolSizeCount = descriptorPoolSizes.size();
|
||||
descriptorPoolCreateInfo.pPoolSizes = descriptorPoolSizes.data();
|
||||
RETURN_ON_VULKAN_ERROR(vkCreateDescriptorPool(device,
|
||||
&descriptorPoolCreateInfo, 0,
|
||||
&descriptorPool),
|
||||
"vkCreateDescriptorPool");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult VulkanRuntime::allocateDescriptorSets() {
|
||||
VkDescriptorSetAllocateInfo descriptorSetAllocateInfo = {};
|
||||
// Size of desciptor sets and descriptor layout sets is the same.
|
||||
descriptorSets.resize(descriptorSetLayouts.size());
|
||||
descriptorSetAllocateInfo.sType =
|
||||
VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO;
|
||||
descriptorSetAllocateInfo.pNext = nullptr;
|
||||
descriptorSetAllocateInfo.descriptorPool = descriptorPool;
|
||||
descriptorSetAllocateInfo.descriptorSetCount = descriptorSetLayouts.size();
|
||||
descriptorSetAllocateInfo.pSetLayouts = descriptorSetLayouts.data();
|
||||
RETURN_ON_VULKAN_ERROR(vkAllocateDescriptorSets(device,
|
||||
&descriptorSetAllocateInfo,
|
||||
descriptorSets.data()),
|
||||
"vkAllocateDescriptorSets");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult VulkanRuntime::setWriteDescriptors() {
|
||||
if (descriptorSets.size() != descriptorSetInfoPool.size()) {
|
||||
llvm::errs() << "Each descriptor set must have descriptor set information";
|
||||
return failure();
|
||||
}
|
||||
// For each descriptor set.
|
||||
auto descriptorSetIt = descriptorSets.begin();
|
||||
// Each descriptor set is associated with descriptor set info.
|
||||
for (const auto &descriptorSetInfo : descriptorSetInfoPool) {
|
||||
// For each device memory buffer in the descriptor set.
|
||||
const auto &deviceMemoryBuffers =
|
||||
deviceMemoryBufferMap[descriptorSetInfo.descriptorSet];
|
||||
for (const auto &memoryBuffer : deviceMemoryBuffers) {
|
||||
// Structure describing descriptor sets to write to.
|
||||
VkWriteDescriptorSet wSet = {};
|
||||
wSet.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET;
|
||||
wSet.pNext = nullptr;
|
||||
// Descirptor set.
|
||||
wSet.dstSet = *descriptorSetIt;
|
||||
wSet.dstBinding = memoryBuffer.bindingIndex;
|
||||
wSet.dstArrayElement = 0;
|
||||
wSet.descriptorCount = 1;
|
||||
wSet.descriptorType = memoryBuffer.descriptorType;
|
||||
wSet.pImageInfo = nullptr;
|
||||
wSet.pBufferInfo = &memoryBuffer.bufferInfo;
|
||||
wSet.pTexelBufferView = nullptr;
|
||||
vkUpdateDescriptorSets(device, 1, &wSet, 0, nullptr);
|
||||
}
|
||||
// Increment descriptor set iterator.
|
||||
++descriptorSetIt;
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult VulkanRuntime::createCommandPool() {
|
||||
VkCommandPoolCreateInfo commandPoolCreateInfo = {};
|
||||
commandPoolCreateInfo.sType = VK_STRUCTURE_TYPE_COMMAND_POOL_CREATE_INFO;
|
||||
commandPoolCreateInfo.pNext = nullptr;
|
||||
commandPoolCreateInfo.flags = 0;
|
||||
commandPoolCreateInfo.queueFamilyIndex = queueFamilyIndex;
|
||||
RETURN_ON_VULKAN_ERROR(
|
||||
vkCreateCommandPool(device, &commandPoolCreateInfo, 0, &commandPool),
|
||||
"vkCreateCommandPool");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult VulkanRuntime::createComputeCommandBuffer() {
|
||||
VkCommandBufferAllocateInfo commandBufferAllocateInfo = {};
|
||||
VkCommandBuffer commandBuffer;
|
||||
commandBufferAllocateInfo.sType =
|
||||
VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO;
|
||||
commandBufferAllocateInfo.pNext = nullptr;
|
||||
commandBufferAllocateInfo.commandPool = commandPool;
|
||||
commandBufferAllocateInfo.level = VK_COMMAND_BUFFER_LEVEL_PRIMARY;
|
||||
commandBufferAllocateInfo.commandBufferCount = 1;
|
||||
RETURN_ON_VULKAN_ERROR(vkAllocateCommandBuffers(device,
|
||||
&commandBufferAllocateInfo,
|
||||
&commandBuffer),
|
||||
"vkAllocateCommandBuffers");
|
||||
|
||||
VkCommandBufferBeginInfo commandBufferBeginInfo = {};
|
||||
commandBufferBeginInfo.sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO;
|
||||
commandBufferBeginInfo.pNext = nullptr;
|
||||
commandBufferBeginInfo.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT;
|
||||
commandBufferBeginInfo.pInheritanceInfo = nullptr;
|
||||
|
||||
// Commands begin.
|
||||
RETURN_ON_VULKAN_ERROR(
|
||||
vkBeginCommandBuffer(commandBuffer, &commandBufferBeginInfo),
|
||||
"vkBeginCommandBuffer");
|
||||
|
||||
// Commands.
|
||||
vkCmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
|
||||
vkCmdBindDescriptorSets(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
|
||||
pipelineLayout, 0, descriptorSets.size(),
|
||||
descriptorSets.data(), 0, 0);
|
||||
vkCmdDispatch(commandBuffer, numWorkGroups.x, numWorkGroups.y,
|
||||
numWorkGroups.z);
|
||||
|
||||
// Commands end.
|
||||
RETURN_ON_VULKAN_ERROR(vkEndCommandBuffer(commandBuffer),
|
||||
"vkEndCommandBuffer");
|
||||
|
||||
commandBuffers.push_back(commandBuffer);
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult VulkanRuntime::submitCommandBuffersToQueue() {
|
||||
VkSubmitInfo submitInfo = {};
|
||||
submitInfo.sType = VK_STRUCTURE_TYPE_SUBMIT_INFO;
|
||||
submitInfo.pNext = nullptr;
|
||||
submitInfo.waitSemaphoreCount = 0;
|
||||
submitInfo.pWaitSemaphores = 0;
|
||||
submitInfo.pWaitDstStageMask = 0;
|
||||
submitInfo.commandBufferCount = commandBuffers.size();
|
||||
submitInfo.pCommandBuffers = commandBuffers.data();
|
||||
submitInfo.signalSemaphoreCount = 0;
|
||||
submitInfo.pSignalSemaphores = nullptr;
|
||||
RETURN_ON_VULKAN_ERROR(vkQueueSubmit(queue, 1, &submitInfo, 0),
|
||||
"vkQueueSubmit");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult VulkanRuntime::updateHostMemoryBuffers() {
|
||||
// For each descriptor set.
|
||||
for (auto &resourceDataMapPair : resourceData) {
|
||||
auto &resourceDataMap = resourceDataMapPair.second;
|
||||
auto &deviceMemoryBuffers =
|
||||
deviceMemoryBufferMap[resourceDataMapPair.first];
|
||||
// For each device memory buffer in the set.
|
||||
for (auto &deviceMemoryBuffer : deviceMemoryBuffers) {
|
||||
if (resourceDataMap.count(deviceMemoryBuffer.bindingIndex)) {
|
||||
void *payload;
|
||||
auto &hostMemoryBuffer =
|
||||
resourceDataMap[deviceMemoryBuffer.bindingIndex];
|
||||
RETURN_ON_VULKAN_ERROR(vkMapMemory(device,
|
||||
deviceMemoryBuffer.deviceMemory, 0,
|
||||
hostMemoryBuffer.size, 0,
|
||||
reinterpret_cast<void **>(&payload)),
|
||||
"vkMapMemory");
|
||||
std::memcpy(hostMemoryBuffer.ptr, payload, hostMemoryBuffer.size);
|
||||
vkUnmapMemory(device, deviceMemoryBuffer.deviceMemory);
|
||||
}
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
|
@ -0,0 +1,225 @@
|
|||
//===- VulkanRuntime.cpp - MLIR Vulkan runtime ------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file declares Vulkan runtime API.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef VULKAN_RUNTIME_H
|
||||
#define VULKAN_RUNTIME_H
|
||||
|
||||
#include "mlir/Analysis/Passes.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||
#include "mlir/Dialect/SPIRV/Serialization.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Support/StringExtras.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/Support/ToolOutputFile.h"
|
||||
|
||||
#include <vulkan/vulkan.h>
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
using DescriptorSetIndex = uint32_t;
|
||||
using BindingIndex = uint32_t;
|
||||
|
||||
/// Struct containing information regarding to a device memory buffer.
|
||||
struct VulkanDeviceMemoryBuffer {
|
||||
BindingIndex bindingIndex{0};
|
||||
VkDescriptorType descriptorType{VK_DESCRIPTOR_TYPE_MAX_ENUM};
|
||||
VkDescriptorBufferInfo bufferInfo{};
|
||||
VkBuffer buffer{VK_NULL_HANDLE};
|
||||
VkDeviceMemory deviceMemory{VK_NULL_HANDLE};
|
||||
};
|
||||
|
||||
/// Struct containing information regarding to a host memory buffer.
|
||||
struct VulkanHostMemoryBuffer {
|
||||
/// Pointer to a host memory.
|
||||
void *ptr{nullptr};
|
||||
/// Size of a host memory in bytes.
|
||||
uint32_t size{0};
|
||||
};
|
||||
|
||||
/// Struct containing the number of local workgroups to dispatch for each
|
||||
/// dimension.
|
||||
struct NumWorkGroups {
|
||||
uint32_t x{1};
|
||||
uint32_t y{1};
|
||||
uint32_t z{1};
|
||||
};
|
||||
|
||||
/// Struct containing information regarding a descriptor set.
|
||||
struct DescriptorSetInfo {
|
||||
/// Index of a descriptor set in descriptor sets.
|
||||
DescriptorSetIndex descriptorSet{0};
|
||||
/// Number of desriptors in a set.
|
||||
uint32_t descriptorSize{0};
|
||||
/// Type of a descriptor set.
|
||||
VkDescriptorType descriptorType{VK_DESCRIPTOR_TYPE_MAX_ENUM};
|
||||
};
|
||||
|
||||
/// VulkanHostMemoryBuffer mapped into a descriptor set and a binding.
|
||||
using ResourceData =
|
||||
llvm::DenseMap<DescriptorSetIndex,
|
||||
llvm::DenseMap<BindingIndex, VulkanHostMemoryBuffer>>;
|
||||
|
||||
/// StorageClass mapped into a descriptor set and a binding.
|
||||
using ResourceStorageClassBindingMap =
|
||||
llvm::DenseMap<DescriptorSetIndex,
|
||||
llvm::DenseMap<BindingIndex, mlir::spirv::StorageClass>>;
|
||||
|
||||
inline void emitVulkanError(const llvm::Twine &message, VkResult error) {
|
||||
llvm::errs()
|
||||
<< message.concat(" failed with error code ").concat(llvm::Twine{error});
|
||||
}
|
||||
|
||||
#define RETURN_ON_VULKAN_ERROR(result, msg) \
|
||||
if ((result) != VK_SUCCESS) { \
|
||||
emitVulkanError(msg, (result)); \
|
||||
return failure(); \
|
||||
}
|
||||
|
||||
/// Vulkan runtime.
|
||||
/// The purpose of this class is to run SPIR-V compute shader on Vulkan
|
||||
/// device.
|
||||
/// Before the run, user must provide and set resource data with descriptors,
|
||||
/// SPIR-V shader, number of work groups and entry point. After the creation of
|
||||
/// VulkanRuntime, special methods must be called in the following
|
||||
/// sequence: initRuntime(), run(), updateHostMemoryBuffers(), destroy();
|
||||
/// each method in the sequence returns succes or failure depends on the Vulkan
|
||||
/// result code.
|
||||
class VulkanRuntime {
|
||||
public:
|
||||
explicit VulkanRuntime() = default;
|
||||
VulkanRuntime(const VulkanRuntime &) = delete;
|
||||
VulkanRuntime &operator=(const VulkanRuntime &) = delete;
|
||||
|
||||
/// Sets needed data for Vulkan runtime.
|
||||
void setResourceData(const ResourceData &resData);
|
||||
void setResourceData(const DescriptorSetIndex desIndex,
|
||||
const BindingIndex bindIndex,
|
||||
const VulkanHostMemoryBuffer &hostMemBuffer);
|
||||
void setShaderModule(uint8_t *shader, uint32_t size);
|
||||
void setNumWorkGroups(const NumWorkGroups &numberWorkGroups);
|
||||
void setResourceStorageClassBindingMap(
|
||||
const ResourceStorageClassBindingMap &stClassData);
|
||||
void setEntryPoint(const char *entryPointName);
|
||||
|
||||
/// Runtime initialization.
|
||||
LogicalResult initRuntime();
|
||||
|
||||
/// Runs runtime.
|
||||
LogicalResult run();
|
||||
|
||||
/// Updates host memory buffers.
|
||||
LogicalResult updateHostMemoryBuffers();
|
||||
|
||||
/// Destroys all created vulkan objects and resources.
|
||||
LogicalResult destroy();
|
||||
|
||||
private:
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Pipeline creation methods.
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult createInstance();
|
||||
LogicalResult createDevice();
|
||||
LogicalResult getBestComputeQueue(const VkPhysicalDevice &physicalDevice);
|
||||
LogicalResult createMemoryBuffers();
|
||||
LogicalResult createShaderModule();
|
||||
void initDescriptorSetLayoutBindingMap();
|
||||
LogicalResult createDescriptorSetLayout();
|
||||
LogicalResult createPipelineLayout();
|
||||
LogicalResult createComputePipeline();
|
||||
LogicalResult createDescriptorPool();
|
||||
LogicalResult allocateDescriptorSets();
|
||||
LogicalResult setWriteDescriptors();
|
||||
LogicalResult createCommandPool();
|
||||
LogicalResult createComputeCommandBuffer();
|
||||
LogicalResult submitCommandBuffersToQueue();
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Helper methods.
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
/// Maps storage class to a descriptor type.
|
||||
LogicalResult
|
||||
mapStorageClassToDescriptorType(spirv::StorageClass storageClass,
|
||||
VkDescriptorType &descriptorType);
|
||||
|
||||
/// Maps storage class to buffer usage flags.
|
||||
LogicalResult
|
||||
mapStorageClassToBufferUsageFlag(spirv::StorageClass storageClass,
|
||||
VkBufferUsageFlagBits &bufferUsage);
|
||||
|
||||
LogicalResult countDeviceMemorySize();
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Vulkan objects.
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
VkInstance instance;
|
||||
VkDevice device;
|
||||
VkQueue queue;
|
||||
|
||||
/// Specifies VulkanDeviceMemoryBuffers divided into sets.
|
||||
llvm::DenseMap<DescriptorSetIndex,
|
||||
llvm::SmallVector<VulkanDeviceMemoryBuffer, 1>>
|
||||
deviceMemoryBufferMap;
|
||||
|
||||
/// Specifies shader module.
|
||||
VkShaderModule shaderModule;
|
||||
|
||||
/// Specifies layout bindings.
|
||||
llvm::DenseMap<DescriptorSetIndex,
|
||||
llvm::SmallVector<VkDescriptorSetLayoutBinding, 1>>
|
||||
descriptorSetLayoutBindingMap;
|
||||
|
||||
/// Specifies layouts of descriptor sets.
|
||||
llvm::SmallVector<VkDescriptorSetLayout, 1> descriptorSetLayouts;
|
||||
VkPipelineLayout pipelineLayout;
|
||||
|
||||
/// Specifies descriptor sets.
|
||||
llvm::SmallVector<VkDescriptorSet, 1> descriptorSets;
|
||||
|
||||
/// Specifies a pool of descriptor set info, each descriptor set must have
|
||||
/// information such as type, index and amount of bindings.
|
||||
llvm::SmallVector<DescriptorSetInfo, 1> descriptorSetInfoPool;
|
||||
VkDescriptorPool descriptorPool;
|
||||
|
||||
/// Computation pipeline.
|
||||
VkPipeline pipeline;
|
||||
VkCommandPool commandPool;
|
||||
llvm::SmallVector<VkCommandBuffer, 1> commandBuffers;
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Vulkan memory context.
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
uint32_t queueFamilyIndex{0};
|
||||
uint32_t memoryTypeIndex{VK_MAX_MEMORY_TYPES};
|
||||
VkDeviceSize memorySize{0};
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Vulkan execution context.
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
NumWorkGroups numWorkGroups;
|
||||
const char *entryPoint{nullptr};
|
||||
uint8_t *binary{nullptr};
|
||||
uint32_t binarySize{0};
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Vulkan resource data and storage classes.
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
||||
ResourceData resourceData;
|
||||
ResourceStorageClassBindingMap resourceStorageClassData;
|
||||
};
|
||||
#endif
|
|
@ -0,0 +1,46 @@
|
|||
//===- mlir-vulkan-runner.cpp - MLIR Vulkan Execution Driver --------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This is a command line utility that executes an MLIR file on the Vulkan by
|
||||
// translating MLIR GPU module to SPIR-V and host part to LLVM IR before
|
||||
// JIT-compiling and executing the latter.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Conversion/GPUToSPIRV/ConvertGPUToSPIRVPass.h"
|
||||
#include "mlir/Conversion/GPUToVulkan/ConvertGPUToVulkanPass.h"
|
||||
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
|
||||
#include "mlir/Conversion/StandardToSPIRV/ConvertStandardToSPIRVPass.h"
|
||||
#include "mlir/Dialect/GPU/Passes.h"
|
||||
#include "mlir/Dialect/SPIRV/Passes.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||
#include "mlir/Pass/Pass.h"
|
||||
#include "mlir/Pass/PassManager.h"
|
||||
#include "mlir/Support/JitRunner.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
static LogicalResult runMLIRPasses(ModuleOp module) {
|
||||
PassManager passManager(module.getContext());
|
||||
applyPassManagerCLOptions(passManager);
|
||||
|
||||
passManager.addPass(createGpuKernelOutliningPass());
|
||||
passManager.addPass(createLegalizeStdOpsForSPIRVLoweringPass());
|
||||
passManager.addPass(createConvertGPUToSPIRVPass());
|
||||
OpPassManager &modulePM = passManager.nest<spirv::ModuleOp>();
|
||||
modulePM.addPass(spirv::createLowerABIAttributesPass());
|
||||
passManager.addPass(createConvertGpuLaunchFuncToVulkanCallsPass());
|
||||
passManager.addPass(createLowerToLLVMPass());
|
||||
return passManager.run(module);
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
llvm::llvm_shutdown_obj x;
|
||||
registerPassManagerCLOptions();
|
||||
return mlir::JitRunnerMain(argc, argv, &runMLIRPasses);
|
||||
}
|
|
@ -0,0 +1,97 @@
|
|||
//===- vulkan-runtime-wrappers.cpp - MLIR Vulkan runner wrapper library ---===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Implements C runtime wrappers around the VulkanRuntime.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include <mutex>
|
||||
#include <numeric>
|
||||
|
||||
#include "VulkanRuntime.h"
|
||||
#include "llvm/Support/ManagedStatic.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
namespace {
|
||||
|
||||
// TODO(denis0x0D): This static machinery should be replaced by `initVulkan` and
|
||||
// `deinitVulkan` to be more explicit and to avoid static initialization and
|
||||
// destruction.
|
||||
class VulkanRuntimeManager;
|
||||
static llvm::ManagedStatic<VulkanRuntimeManager> vkRuntimeManager;
|
||||
|
||||
class VulkanRuntimeManager {
|
||||
public:
|
||||
VulkanRuntimeManager() = default;
|
||||
VulkanRuntimeManager(const VulkanRuntimeManager &) = delete;
|
||||
VulkanRuntimeManager operator=(const VulkanRuntimeManager &) = delete;
|
||||
~VulkanRuntimeManager() = default;
|
||||
|
||||
void setResourceData(DescriptorSetIndex setIndex, BindingIndex bindIndex,
|
||||
const VulkanHostMemoryBuffer &memBuffer) {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
vulkanRuntime.setResourceData(setIndex, bindIndex, memBuffer);
|
||||
}
|
||||
|
||||
void setEntryPoint(const char *entryPoint) {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
vulkanRuntime.setEntryPoint(entryPoint);
|
||||
}
|
||||
|
||||
void setNumWorkGroups(NumWorkGroups numWorkGroups) {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
vulkanRuntime.setNumWorkGroups(numWorkGroups);
|
||||
}
|
||||
|
||||
void setShaderModule(uint8_t *shader, uint32_t size) {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
vulkanRuntime.setShaderModule(shader, size);
|
||||
}
|
||||
|
||||
void runOnVulkan() {
|
||||
std::lock_guard<std::mutex> lock(mutex);
|
||||
if (failed(vulkanRuntime.initRuntime()) || failed(vulkanRuntime.run()) ||
|
||||
failed(vulkanRuntime.updateHostMemoryBuffers()) ||
|
||||
failed(vulkanRuntime.destroy())) {
|
||||
llvm::errs() << "runOnVulkan failed";
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
VulkanRuntime vulkanRuntime;
|
||||
std::mutex mutex;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
extern "C" {
|
||||
/// Fills the given memref with the given value.
|
||||
/// Binds the given memref to the given descriptor set and descriptor index.
|
||||
void setResourceData(const DescriptorSetIndex setIndex, BindingIndex bindIndex,
|
||||
float *allocated, float *aligned, int64_t offset,
|
||||
int64_t size, int64_t stride, float value) {
|
||||
std::fill_n(allocated, size, value);
|
||||
VulkanHostMemoryBuffer memBuffer{allocated,
|
||||
static_cast<uint32_t>(size * sizeof(float))};
|
||||
vkRuntimeManager->setResourceData(setIndex, bindIndex, memBuffer);
|
||||
}
|
||||
|
||||
void setEntryPoint(const char *entryPoint) {
|
||||
vkRuntimeManager->setEntryPoint(entryPoint);
|
||||
}
|
||||
|
||||
void setNumWorkGroups(uint32_t x, uint32_t y, uint32_t z) {
|
||||
vkRuntimeManager->setNumWorkGroups({x, y, z});
|
||||
}
|
||||
|
||||
void setBinaryShader(uint8_t *shader, uint32_t size) {
|
||||
vkRuntimeManager->setShaderModule(shader, size);
|
||||
}
|
||||
|
||||
void runOnVulkan() { vkRuntimeManager->runOnVulkan(); }
|
||||
}
|
Loading…
Reference in New Issue