[MLIR][SPIRV] Added optional name to SPIR-V module

This patch adds an optional name to SPIR-V module.
This will help with lowering from GPU dialect (so that we
can pass the kernel module name) and will be more naturally
aligned with `GPUModuleOp`/`ModuleOp`.

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D86386
This commit is contained in:
George Mitenkov 2020-08-27 07:10:14 +03:00
parent 76b0f99ea8
commit d7461b31e7
3 changed files with 42 additions and 10 deletions

View File

@ -361,7 +361,7 @@ def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [InModuleScope, Symbol]> {
def SPV_ModuleOp : SPV_Op<"module",
[IsolatedFromAbove,
SingleBlockImplicitTerminator<"ModuleEndOp">,
SymbolTable]> {
SymbolTable, Symbol]> {
let summary = "The top-level op that defines a SPIR-V module";
let description = [{
@ -409,7 +409,8 @@ def SPV_ModuleOp : SPV_Op<"module",
let arguments = (ins
SPV_AddressingModelAttr:$addressing_model,
SPV_MemoryModelAttr:$memory_model,
OptionalAttr<SPV_VerCapExtAttr>:$vce_triple
OptionalAttr<SPV_VerCapExtAttr>:$vce_triple,
OptionalAttr<StrAttr>:$sym_name
);
let results = (outs);
@ -417,10 +418,12 @@ def SPV_ModuleOp : SPV_Op<"module",
let regions = (region SizedRegion<1>:$body);
let builders = [
OpBuilder<[{OpBuilder &, OperationState &state}]>,
OpBuilder<[{OpBuilder &, OperationState &state,
Optional<StringRef> name = llvm::None}]>,
OpBuilder<[{OpBuilder &, OperationState &state,
spirv::AddressingModel addressing_model,
spirv::MemoryModel memory_model}]>
spirv::MemoryModel memory_model,
Optional<StringRef> name = llvm::None}]>
];
// We need to ensure the block inside the region is properly terminated;
@ -432,6 +435,11 @@ def SPV_ModuleOp : SPV_Op<"module",
let autogenSerialization = 0;
let extraClassDeclaration = [{
bool isOptionalSymbol() { return true; }
Optional<StringRef> getName() { return sym_name(); }
static StringRef getVCETripleAttrName() { return "vce_triple"; }
Block& getBlock() {

View File

@ -2282,24 +2282,39 @@ static LogicalResult verify(spirv::MergeOp mergeOp) {
// spv.module
//===----------------------------------------------------------------------===//
void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state) {
void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
Optional<StringRef> name) {
ensureTerminator(*state.addRegion(), builder, state.location);
if (name) {
state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
builder.getStringAttr(*name));
}
}
void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
spirv::AddressingModel addressing_model,
spirv::MemoryModel memory_model) {
spirv::AddressingModel addressingModel,
spirv::MemoryModel memoryModel,
Optional<StringRef> name) {
state.addAttribute(
"addressing_model",
builder.getI32IntegerAttr(static_cast<int32_t>(addressing_model)));
builder.getI32IntegerAttr(static_cast<int32_t>(addressingModel)));
state.addAttribute("memory_model", builder.getI32IntegerAttr(
static_cast<int32_t>(memory_model)));
static_cast<int32_t>(memoryModel)));
ensureTerminator(*state.addRegion(), builder, state.location);
if (name) {
state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
builder.getStringAttr(*name));
}
}
static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
Region *body = state.addRegion();
// If the name is present, parse it.
StringAttr nameAttr;
parser.parseOptionalSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
state.attributes);
// Parse attributes
spirv::AddressingModel addrModel;
spirv::MemoryModel memoryModel;
@ -2328,13 +2343,19 @@ static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
printer << spirv::ModuleOp::getOperationName();
if (Optional<StringRef> name = moduleOp.getName()) {
printer << ' ';
printer.printSymbolName(*name);
}
SmallVector<StringRef, 2> elidedAttrs;
printer << " " << spirv::stringifyAddressingModel(moduleOp.addressing_model())
<< " " << spirv::stringifyMemoryModel(moduleOp.memory_model());
auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName});
elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
SymbolTable::getSymbolAttrName()});
if (Optional<spirv::VerCapExtAttr> triple = moduleOp.vce_triple()) {
printer << " requires " << *triple;

View File

@ -372,6 +372,9 @@ spv.module Logical GLSL450 {
// CHECK: spv.module Logical GLSL450
spv.module Logical GLSL450 { }
// Module with a name
// CHECK: spv.module @{{.*}} Logical GLSL450
spv.module @name Logical GLSL450 { }
// Module with (version, capabilities, extensions) triple
// CHECK: spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], [SPV_KHR_16bit_storage]>