forked from OSchip/llvm-project
[MLIR][SPIRVToLLVM] Added a hook for descriptor set / binding encoding
This patch introduces a hook to encode descriptor set and binding number into `spv.globalVariable`'s symbolic name. This allows to preserve this information, and at the same time legalize the global variable for the conversion to LLVM dialect. This is required for `mlir-spirv-cpu-runner` to convert kernel arguments into LLVM. Also, a couple of some nits added: - removed unused comment - changed to a capital letter in the comment Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D86515
This commit is contained in:
parent
718e550cd0
commit
e850558cdc
|
@ -32,6 +32,10 @@ protected:
|
|||
LLVMTypeConverter &typeConverter;
|
||||
};
|
||||
|
||||
/// Encodes global variable's descriptor set and binding into its name if they
|
||||
/// both exist.
|
||||
void encodeBindAttribute(ModuleOp module);
|
||||
|
||||
/// Populates type conversions with additional SPIR-V types.
|
||||
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter);
|
||||
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "mlir/Support/LogicalResult.h"
|
||||
#include "mlir/Transforms/DialectConversion.h"
|
||||
#include "llvm/Support/Debug.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
|
||||
#define DEBUG_TYPE "spirv-to-llvm-pattern"
|
||||
|
||||
|
@ -1332,8 +1333,6 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
|
|||
// TODO: Support EntryPoint/ExecutionMode properly.
|
||||
ErasePattern<spirv::EntryPointOp>, ErasePattern<spirv::ExecutionModeOp>,
|
||||
|
||||
// Function Call op
|
||||
|
||||
// GLSL extended instruction set ops
|
||||
DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>,
|
||||
DirectConversionPattern<spirv::GLSLCosOp, LLVM::CosOp>,
|
||||
|
@ -1386,3 +1385,42 @@ void mlir::populateSPIRVToLLVMModuleConversionPatterns(
|
|||
patterns.insert<ModuleConversionPattern, ModuleEndConversionPattern>(
|
||||
context, typeConverter);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Pre-conversion hooks
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Hook for descriptor set and binding number encoding.
|
||||
static constexpr StringRef kBinding = "binding";
|
||||
static constexpr StringRef kDescriptorSet = "descriptor_set";
|
||||
void mlir::encodeBindAttribute(ModuleOp module) {
|
||||
auto spvModules = module.getOps<spirv::ModuleOp>();
|
||||
for (auto spvModule : spvModules) {
|
||||
spvModule.walk([&](spirv::GlobalVariableOp op) {
|
||||
IntegerAttr descriptorSet = op.getAttrOfType<IntegerAttr>(kDescriptorSet);
|
||||
IntegerAttr binding = op.getAttrOfType<IntegerAttr>(kBinding);
|
||||
// For every global variable in the module, get the ones with descriptor
|
||||
// set and binding numbers.
|
||||
if (descriptorSet && binding) {
|
||||
// Encode these numbers into the variable's symbolic name. If the
|
||||
// SPIR-V module has a name, add it at the beginning.
|
||||
auto moduleAndName = spvModule.getName().hasValue()
|
||||
? spvModule.getName().getValue().str() + "_" +
|
||||
op.sym_name().str()
|
||||
: op.sym_name().str();
|
||||
std::string name =
|
||||
llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
|
||||
std::to_string(descriptorSet.getInt()),
|
||||
std::to_string(binding.getInt()));
|
||||
|
||||
// Replace all symbol uses and set the new symbol name. Finally, remove
|
||||
// descriptor set and binding attributes.
|
||||
if (failed(SymbolTable::replaceAllSymbolUses(op, name, spvModule)))
|
||||
op.emitError("unable to replace all symbol uses for ") << name;
|
||||
SymbolTable::setSymbolName(op, name);
|
||||
op.removeAttr(kDescriptorSet);
|
||||
op.removeAttr(kBinding);
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
|
|
@ -33,6 +33,9 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
|
|||
ModuleOp module = getOperation();
|
||||
LLVMTypeConverter converter(&getContext());
|
||||
|
||||
// Encode global variable's descriptor set and binding if they exist.
|
||||
encodeBindAttribute(module);
|
||||
|
||||
OwningRewritePatternList patterns;
|
||||
|
||||
populateSPIRVToLLVMTypeConversion(converter);
|
||||
|
@ -45,7 +48,7 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
|
|||
target.addIllegalDialect<spirv::SPIRVDialect>();
|
||||
target.addLegalDialect<LLVM::LLVMDialect>();
|
||||
|
||||
// set `ModuleOp` and `ModuleTerminatorOp` as legal for `spv.module`
|
||||
// Set `ModuleOp` and `ModuleTerminatorOp` as legal for `spv.module`
|
||||
// conversion.
|
||||
target.addLegalOp<ModuleOp>();
|
||||
target.addLegalOp<ModuleTerminatorOp>();
|
||||
|
|
|
@ -37,7 +37,7 @@ spv.module Logical GLSL450 {
|
|||
spv.module Logical GLSL450 {
|
||||
// CHECK: llvm.mlir.global private @struct() : !llvm.struct<packed (float, array<10 x float>)>
|
||||
// CHECK-LABEL: @func
|
||||
// CHECK: llvm.mlir.addressof @struct : !llvm.ptr<struct<packed (float, array<10 x float>)>>
|
||||
// CHECK: llvm.mlir.addressof @struct : !llvm.ptr<struct<packed (float, array<10 x float>)>>
|
||||
spv.globalVariable @struct : !spv.ptr<!spv.struct<f32, !spv.array<10xf32>>, Private>
|
||||
spv.func @func() "None" {
|
||||
%0 = spv._address_of @struct : !spv.ptr<!spv.struct<f32, !spv.array<10xf32>>, Private>
|
||||
|
@ -45,6 +45,28 @@ spv.module Logical GLSL450 {
|
|||
}
|
||||
}
|
||||
|
||||
spv.module Logical GLSL450 {
|
||||
// CHECK: llvm.mlir.global external @bar_descriptor_set0_binding0() : !llvm.i32
|
||||
// CHECK-LABEL: @foo
|
||||
// CHECK: llvm.mlir.addressof @bar_descriptor_set0_binding0 : !llvm.ptr<i32>
|
||||
spv.globalVariable @bar bind(0, 0) : !spv.ptr<i32, StorageBuffer>
|
||||
spv.func @foo() "None" {
|
||||
%0 = spv._address_of @bar : !spv.ptr<i32, StorageBuffer>
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
spv.module @name Logical GLSL450 {
|
||||
// CHECK: llvm.mlir.global external @name_bar_descriptor_set0_binding0() : !llvm.i32
|
||||
// CHECK-LABEL: @foo
|
||||
// CHECK: llvm.mlir.addressof @name_bar_descriptor_set0_binding0 : !llvm.ptr<i32>
|
||||
spv.globalVariable @bar bind(0, 0) : !spv.ptr<i32, StorageBuffer>
|
||||
spv.func @foo() "None" {
|
||||
%0 = spv._address_of @bar : !spv.ptr<i32, StorageBuffer>
|
||||
spv.Return
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.Load
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in New Issue