forked from OSchip/llvm-project
318 lines
13 KiB
C++
318 lines
13 KiB
C++
//===- ConvertLaunchFuncToLLVMCalls.cpp - MLIR GPU launch to LLVM pass ----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file implements passes to convert `gpu.launch_func` op into a sequence
|
|
// of LLVM calls that emulate the host and device sides.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "../PassDetail.h"
|
|
#include "mlir/Conversion/ArithmeticToLLVM/ArithmeticToLLVM.h"
|
|
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
|
|
#include "mlir/Conversion/LLVMCommon/Pattern.h"
|
|
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
|
|
#include "mlir/Conversion/MemRefToLLVM/MemRefToLLVM.h"
|
|
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h"
|
|
#include "mlir/Conversion/SPIRVToLLVM/SPIRVToLLVMPass.h"
|
|
#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
|
|
#include "mlir/Dialect/GPU/GPUDialect.h"
|
|
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
|
|
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
|
|
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
|
#include "mlir/IR/BuiltinOps.h"
|
|
#include "mlir/IR/SymbolTable.h"
|
|
#include "mlir/Transforms/DialectConversion.h"
|
|
|
|
#include "llvm/ADT/DenseMap.h"
|
|
#include "llvm/ADT/StringExtras.h"
|
|
#include "llvm/Support/FormatVariadic.h"
|
|
|
|
using namespace mlir;
|
|
|
|
static constexpr const char kSPIRVModule[] = "__spv__";
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Utility functions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Returns the string name of the `DescriptorSet` decoration.
|
|
static std::string descriptorSetName() {
|
|
return llvm::convertToSnakeFromCamelCase(
|
|
stringifyDecoration(spirv::Decoration::DescriptorSet));
|
|
}
|
|
|
|
/// Returns the string name of the `Binding` decoration.
|
|
static std::string bindingName() {
|
|
return llvm::convertToSnakeFromCamelCase(
|
|
stringifyDecoration(spirv::Decoration::Binding));
|
|
}
|
|
|
|
/// Calculates the index of the kernel's operand that is represented by the
|
|
/// given global variable with the `bind` attribute. We assume that the index of
|
|
/// each kernel's operand is mapped to (descriptorSet, binding) by the map:
|
|
/// i -> (0, i)
|
|
/// which is implemented under `LowerABIAttributesPass`.
|
|
static unsigned calculateGlobalIndex(spirv::GlobalVariableOp op) {
|
|
IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
|
|
return binding.getInt();
|
|
}
|
|
|
|
/// Copies the given number of bytes from src to dst pointers.
|
|
static void copy(Location loc, Value dst, Value src, Value size,
|
|
OpBuilder &builder) {
|
|
MLIRContext *context = builder.getContext();
|
|
auto llvmI1Type = IntegerType::get(context, 1);
|
|
Value isVolatile = builder.create<LLVM::ConstantOp>(
|
|
loc, llvmI1Type, builder.getBoolAttr(false));
|
|
builder.create<LLVM::MemcpyOp>(loc, dst, src, size, isVolatile);
|
|
}
|
|
|
|
/// Encodes the binding and descriptor set numbers into a new symbolic name.
|
|
/// The name is specified by
|
|
/// {kernel_module_name}_{variable_name}_descriptor_set{ds}_binding{b}
|
|
/// to avoid symbolic conflicts, where 'ds' and 'b' are descriptor set and
|
|
/// binding numbers.
|
|
static std::string
|
|
createGlobalVariableWithBindName(spirv::GlobalVariableOp op,
|
|
StringRef kernelModuleName) {
|
|
IntegerAttr descriptorSet =
|
|
op->getAttrOfType<IntegerAttr>(descriptorSetName());
|
|
IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
|
|
return llvm::formatv("{0}_{1}_descriptor_set{2}_binding{3}",
|
|
kernelModuleName.str(), op.sym_name().str(),
|
|
std::to_string(descriptorSet.getInt()),
|
|
std::to_string(binding.getInt()));
|
|
}
|
|
|
|
/// Returns true if the given global variable has both a descriptor set number
|
|
/// and a binding number.
|
|
static bool hasDescriptorSetAndBinding(spirv::GlobalVariableOp op) {
|
|
IntegerAttr descriptorSet =
|
|
op->getAttrOfType<IntegerAttr>(descriptorSetName());
|
|
IntegerAttr binding = op->getAttrOfType<IntegerAttr>(bindingName());
|
|
return descriptorSet && binding;
|
|
}
|
|
|
|
/// Fills `globalVariableMap` with SPIR-V global variables that represent kernel
|
|
/// arguments from the given SPIR-V module. We assume that the module contains a
|
|
/// single entry point function. Hence, all `spv.GlobalVariable`s with a bind
|
|
/// attribute are kernel arguments.
|
|
static LogicalResult getKernelGlobalVariables(
|
|
spirv::ModuleOp module,
|
|
DenseMap<uint32_t, spirv::GlobalVariableOp> &globalVariableMap) {
|
|
auto entryPoints = module.getOps<spirv::EntryPointOp>();
|
|
if (!llvm::hasSingleElement(entryPoints)) {
|
|
return module.emitError(
|
|
"The module must contain exactly one entry point function");
|
|
}
|
|
auto globalVariables = module.getOps<spirv::GlobalVariableOp>();
|
|
for (auto globalOp : globalVariables) {
|
|
if (hasDescriptorSetAndBinding(globalOp))
|
|
globalVariableMap[calculateGlobalIndex(globalOp)] = globalOp;
|
|
}
|
|
return success();
|
|
}
|
|
|
|
/// Encodes the SPIR-V module's symbolic name into the name of the entry point
|
|
/// function.
|
|
static LogicalResult encodeKernelName(spirv::ModuleOp module) {
|
|
StringRef spvModuleName = module.sym_name().getValue();
|
|
// We already know that the module contains exactly one entry point function
|
|
// based on `getKernelGlobalVariables()` call. Update this function's name
|
|
// to:
|
|
// {spv_module_name}_{function_name}
|
|
auto entryPoint = *module.getOps<spirv::EntryPointOp>().begin();
|
|
StringRef funcName = entryPoint.fn();
|
|
auto funcOp = module.lookupSymbol<spirv::FuncOp>(entryPoint.fnAttr());
|
|
StringAttr newFuncName =
|
|
StringAttr::get(module->getContext(), spvModuleName + "_" + funcName);
|
|
if (failed(SymbolTable::replaceAllSymbolUses(funcOp, newFuncName, module)))
|
|
return failure();
|
|
SymbolTable::setSymbolName(funcOp, newFuncName);
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Conversion patterns
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
|
|
/// Structure to group information about the variables being copied.
|
|
struct CopyInfo {
|
|
Value dst;
|
|
Value src;
|
|
Value size;
|
|
};
|
|
|
|
/// This pattern emulates a call to the kernel in LLVM dialect. For that, we
|
|
/// copy the data to the global variable (emulating device side), call the
|
|
/// kernel as a normal void LLVM function, and copy the data back (emulating the
|
|
/// host side).
|
|
class GPULaunchLowering : public ConvertOpToLLVMPattern<gpu::LaunchFuncOp> {
|
|
using ConvertOpToLLVMPattern<gpu::LaunchFuncOp>::ConvertOpToLLVMPattern;
|
|
|
|
LogicalResult
|
|
matchAndRewrite(gpu::LaunchFuncOp launchOp, OpAdaptor adaptor,
|
|
ConversionPatternRewriter &rewriter) const override {
|
|
auto *op = launchOp.getOperation();
|
|
MLIRContext *context = rewriter.getContext();
|
|
auto module = launchOp->getParentOfType<ModuleOp>();
|
|
|
|
// Get the SPIR-V module that represents the gpu kernel module. The module
|
|
// is named:
|
|
// __spv__{kernel_module_name}
|
|
// based on GPU to SPIR-V conversion.
|
|
StringRef kernelModuleName = launchOp.getKernelModuleName().getValue();
|
|
std::string spvModuleName = kSPIRVModule + kernelModuleName.str();
|
|
auto spvModule = module.lookupSymbol<spirv::ModuleOp>(
|
|
StringAttr::get(context, spvModuleName));
|
|
if (!spvModule) {
|
|
return launchOp.emitOpError("SPIR-V kernel module '")
|
|
<< spvModuleName << "' is not found";
|
|
}
|
|
|
|
// Declare kernel function in the main module so that it later can be linked
|
|
// with its definition from the kernel module. We know that the kernel
|
|
// function would have no arguments and the data is passed via global
|
|
// variables. The name of the kernel will be
|
|
// {spv_module_name}_{kernel_function_name}
|
|
// to avoid symbolic name conflicts.
|
|
StringRef kernelFuncName = launchOp.getKernelName().getValue();
|
|
std::string newKernelFuncName = spvModuleName + "_" + kernelFuncName.str();
|
|
auto kernelFunc = module.lookupSymbol<LLVM::LLVMFuncOp>(
|
|
StringAttr::get(context, newKernelFuncName));
|
|
if (!kernelFunc) {
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
rewriter.setInsertionPointToStart(module.getBody());
|
|
kernelFunc = rewriter.create<LLVM::LLVMFuncOp>(
|
|
rewriter.getUnknownLoc(), newKernelFuncName,
|
|
LLVM::LLVMFunctionType::get(LLVM::LLVMVoidType::get(context),
|
|
ArrayRef<Type>()));
|
|
rewriter.setInsertionPoint(launchOp);
|
|
}
|
|
|
|
// Get all global variables associated with the kernel operands.
|
|
DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
|
|
if (failed(getKernelGlobalVariables(spvModule, globalVariableMap)))
|
|
return failure();
|
|
|
|
// Traverse kernel operands that were converted to MemRefDescriptors. For
|
|
// each operand, create a global variable and copy data from operand to it.
|
|
Location loc = launchOp.getLoc();
|
|
SmallVector<CopyInfo, 4> copyInfo;
|
|
auto numKernelOperands = launchOp.getNumKernelOperands();
|
|
auto kernelOperands = adaptor.getOperands().take_back(numKernelOperands);
|
|
for (const auto &operand : llvm::enumerate(kernelOperands)) {
|
|
// Check if the kernel's operand is a ranked memref.
|
|
auto memRefType = launchOp.getKernelOperand(operand.index())
|
|
.getType()
|
|
.dyn_cast<MemRefType>();
|
|
if (!memRefType)
|
|
return failure();
|
|
|
|
// Calculate the size of the memref and get the pointer to the allocated
|
|
// buffer.
|
|
SmallVector<Value, 4> sizes;
|
|
SmallVector<Value, 4> strides;
|
|
Value sizeBytes;
|
|
getMemRefDescriptorSizes(loc, memRefType, {}, rewriter, sizes, strides,
|
|
sizeBytes);
|
|
MemRefDescriptor descriptor(operand.value());
|
|
Value src = descriptor.allocatedPtr(rewriter, loc);
|
|
|
|
// Get the global variable in the SPIR-V module that is associated with
|
|
// the kernel operand. Construct its new name and create a corresponding
|
|
// LLVM dialect global variable.
|
|
spirv::GlobalVariableOp spirvGlobal = globalVariableMap[operand.index()];
|
|
auto pointeeType =
|
|
spirvGlobal.type().cast<spirv::PointerType>().getPointeeType();
|
|
auto dstGlobalType = typeConverter->convertType(pointeeType);
|
|
if (!dstGlobalType)
|
|
return failure();
|
|
std::string name =
|
|
createGlobalVariableWithBindName(spirvGlobal, spvModuleName);
|
|
// Check if this variable has already been created.
|
|
auto dstGlobal = module.lookupSymbol<LLVM::GlobalOp>(name);
|
|
if (!dstGlobal) {
|
|
OpBuilder::InsertionGuard guard(rewriter);
|
|
rewriter.setInsertionPointToStart(module.getBody());
|
|
dstGlobal = rewriter.create<LLVM::GlobalOp>(
|
|
loc, dstGlobalType,
|
|
/*isConstant=*/false, LLVM::Linkage::Linkonce, name, Attribute(),
|
|
/*alignment=*/0);
|
|
rewriter.setInsertionPoint(launchOp);
|
|
}
|
|
|
|
// Copy the data from src operand pointer to dst global variable. Save
|
|
// src, dst and size so that we can copy data back after emulating the
|
|
// kernel call.
|
|
Value dst = rewriter.create<LLVM::AddressOfOp>(loc, dstGlobal);
|
|
copy(loc, dst, src, sizeBytes, rewriter);
|
|
|
|
CopyInfo info;
|
|
info.dst = dst;
|
|
info.src = src;
|
|
info.size = sizeBytes;
|
|
copyInfo.push_back(info);
|
|
}
|
|
// Create a call to the kernel and copy the data back.
|
|
rewriter.replaceOpWithNewOp<LLVM::CallOp>(op, kernelFunc,
|
|
ArrayRef<Value>());
|
|
for (CopyInfo info : copyInfo)
|
|
copy(loc, info.src, info.dst, info.size, rewriter);
|
|
return success();
|
|
}
|
|
};
|
|
|
|
class LowerHostCodeToLLVM
|
|
: public LowerHostCodeToLLVMBase<LowerHostCodeToLLVM> {
|
|
public:
|
|
void runOnOperation() override {
|
|
ModuleOp module = getOperation();
|
|
|
|
// Erase the GPU module.
|
|
for (auto gpuModule :
|
|
llvm::make_early_inc_range(module.getOps<gpu::GPUModuleOp>()))
|
|
gpuModule.erase();
|
|
|
|
// Specify options to lower Standard to LLVM and pull in the conversion
|
|
// patterns.
|
|
LowerToLLVMOptions options(module.getContext());
|
|
options.emitCWrappers = true;
|
|
auto *context = module.getContext();
|
|
RewritePatternSet patterns(context);
|
|
LLVMTypeConverter typeConverter(context, options);
|
|
mlir::arith::populateArithmeticToLLVMConversionPatterns(typeConverter,
|
|
patterns);
|
|
populateMemRefToLLVMConversionPatterns(typeConverter, patterns);
|
|
populateStdToLLVMConversionPatterns(typeConverter, patterns);
|
|
patterns.add<GPULaunchLowering>(typeConverter);
|
|
|
|
// Pull in SPIR-V type conversion patterns to convert SPIR-V global
|
|
// variable's type to LLVM dialect type.
|
|
populateSPIRVToLLVMTypeConversion(typeConverter);
|
|
|
|
ConversionTarget target(*context);
|
|
target.addLegalDialect<LLVM::LLVMDialect>();
|
|
if (failed(applyPartialConversion(module, target, std::move(patterns))))
|
|
signalPassFailure();
|
|
|
|
// Finally, modify the kernel function in SPIR-V modules to avoid symbolic
|
|
// conflicts.
|
|
for (auto spvModule : module.getOps<spirv::ModuleOp>())
|
|
(void)encodeKernelName(spvModule);
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
|
|
mlir::createLowerHostCodeToLLVMPass() {
|
|
return std::make_unique<LowerHostCodeToLLVM>();
|
|
}
|