[MLIR][SPIRVToLLVM] Added a hook for descriptor set / binding encoding

This patch introduces a hook to encode descriptor set
and binding number into `spv.globalVariable`'s symbolic name. This
allows to preserve this information, and at the same time legalize
the global variable for the conversion to LLVM dialect.

This is required for `mlir-spirv-cpu-runner` to convert kernel
arguments into LLVM.

Also, a couple of some nits added:
- removed unused comment
- changed to a capital letter in the comment

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D86515
This commit is contained in:
George Mitenkov 2020-08-27 08:25:52 +03:00
parent 718e550cd0
commit e850558cdc
4 changed files with 71 additions and 4 deletions

View File

@ -32,6 +32,10 @@ protected:
LLVMTypeConverter &typeConverter;
};
/// Encodes global variable's descriptor set and binding into its name if they
/// both exist.
void encodeBindAttribute(ModuleOp module);
/// Populates type conversions with additional SPIR-V types.
void populateSPIRVToLLVMTypeConversion(LLVMTypeConverter &typeConverter);

View File

@ -23,6 +23,7 @@
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/FormatVariadic.h"
#define DEBUG_TYPE "spirv-to-llvm-pattern"
@ -1332,8 +1333,6 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
// TODO: Support EntryPoint/ExecutionMode properly.
ErasePattern<spirv::EntryPointOp>, ErasePattern<spirv::ExecutionModeOp>,
// Function Call op
// GLSL extended instruction set ops
DirectConversionPattern<spirv::GLSLCeilOp, LLVM::FCeilOp>,
DirectConversionPattern<spirv::GLSLCosOp, LLVM::CosOp>,
@ -1386,3 +1385,42 @@ void mlir::populateSPIRVToLLVMModuleConversionPatterns(
patterns.insert<ModuleConversionPattern, ModuleEndConversionPattern>(
context, typeConverter);
}
//===----------------------------------------------------------------------===//
// Pre-conversion hooks
//===----------------------------------------------------------------------===//
/// Hook for descriptor set and binding number encoding.
static constexpr StringRef kBinding = "binding";
static constexpr StringRef kDescriptorSet = "descriptor_set";
void mlir::encodeBindAttribute(ModuleOp module) {
auto spvModules = module.getOps<spirv::ModuleOp>();
for (auto spvModule : spvModules) {
spvModule.walk([&](spirv::GlobalVariableOp op) {
IntegerAttr descriptorSet = op.getAttrOfType<IntegerAttr>(kDescriptorSet);
IntegerAttr binding = op.getAttrOfType<IntegerAttr>(kBinding);
// For every global variable in the module, get the ones with descriptor
// set and binding numbers.
if (descriptorSet && binding) {
// Encode these numbers into the variable's symbolic name. If the
// SPIR-V module has a name, add it at the beginning.
auto moduleAndName = spvModule.getName().hasValue()
? spvModule.getName().getValue().str() + "_" +
op.sym_name().str()
: op.sym_name().str();
std::string name =
llvm::formatv("{0}_descriptor_set{1}_binding{2}", moduleAndName,
std::to_string(descriptorSet.getInt()),
std::to_string(binding.getInt()));
// Replace all symbol uses and set the new symbol name. Finally, remove
// descriptor set and binding attributes.
if (failed(SymbolTable::replaceAllSymbolUses(op, name, spvModule)))
op.emitError("unable to replace all symbol uses for ") << name;
SymbolTable::setSymbolName(op, name);
op.removeAttr(kDescriptorSet);
op.removeAttr(kBinding);
}
});
}
}

View File

@ -33,6 +33,9 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
ModuleOp module = getOperation();
LLVMTypeConverter converter(&getContext());
// Encode global variable's descriptor set and binding if they exist.
encodeBindAttribute(module);
OwningRewritePatternList patterns;
populateSPIRVToLLVMTypeConversion(converter);
@ -45,7 +48,7 @@ void ConvertSPIRVToLLVMPass::runOnOperation() {
target.addIllegalDialect<spirv::SPIRVDialect>();
target.addLegalDialect<LLVM::LLVMDialect>();
// set `ModuleOp` and `ModuleTerminatorOp` as legal for `spv.module`
// Set `ModuleOp` and `ModuleTerminatorOp` as legal for `spv.module`
// conversion.
target.addLegalOp<ModuleOp>();
target.addLegalOp<ModuleTerminatorOp>();

View File

@ -37,7 +37,7 @@ spv.module Logical GLSL450 {
spv.module Logical GLSL450 {
// CHECK: llvm.mlir.global private @struct() : !llvm.struct<packed (float, array<10 x float>)>
// CHECK-LABEL: @func
// CHECK: llvm.mlir.addressof @struct : !llvm.ptr<struct<packed (float, array<10 x float>)>>
// CHECK: llvm.mlir.addressof @struct : !llvm.ptr<struct<packed (float, array<10 x float>)>>
spv.globalVariable @struct : !spv.ptr<!spv.struct<f32, !spv.array<10xf32>>, Private>
spv.func @func() "None" {
%0 = spv._address_of @struct : !spv.ptr<!spv.struct<f32, !spv.array<10xf32>>, Private>
@ -45,6 +45,28 @@ spv.module Logical GLSL450 {
}
}
spv.module Logical GLSL450 {
// CHECK: llvm.mlir.global external @bar_descriptor_set0_binding0() : !llvm.i32
// CHECK-LABEL: @foo
// CHECK: llvm.mlir.addressof @bar_descriptor_set0_binding0 : !llvm.ptr<i32>
spv.globalVariable @bar bind(0, 0) : !spv.ptr<i32, StorageBuffer>
spv.func @foo() "None" {
%0 = spv._address_of @bar : !spv.ptr<i32, StorageBuffer>
spv.Return
}
}
spv.module @name Logical GLSL450 {
// CHECK: llvm.mlir.global external @name_bar_descriptor_set0_binding0() : !llvm.i32
// CHECK-LABEL: @foo
// CHECK: llvm.mlir.addressof @name_bar_descriptor_set0_binding0 : !llvm.ptr<i32>
spv.globalVariable @bar bind(0, 0) : !spv.ptr<i32, StorageBuffer>
spv.func @foo() "None" {
%0 = spv._address_of @bar : !spv.ptr<i32, StorageBuffer>
spv.Return
}
}
//===----------------------------------------------------------------------===//
// spv.Load
//===----------------------------------------------------------------------===//