forked from OSchip/llvm-project
[MLIR][SPIRV] Added optional name to SPIR-V module
This patch adds an optional name to SPIR-V module. This will help with lowering from GPU dialect (so that we can pass the kernel module name) and will be more naturally aligned with `GPUModuleOp`/`ModuleOp`. Reviewed By: mravishankar Differential Revision: https://reviews.llvm.org/D86386
This commit is contained in:
parent
76b0f99ea8
commit
d7461b31e7
|
@ -361,7 +361,7 @@ def SPV_GlobalVariableOp : SPV_Op<"globalVariable", [InModuleScope, Symbol]> {
|
|||
def SPV_ModuleOp : SPV_Op<"module",
|
||||
[IsolatedFromAbove,
|
||||
SingleBlockImplicitTerminator<"ModuleEndOp">,
|
||||
SymbolTable]> {
|
||||
SymbolTable, Symbol]> {
|
||||
let summary = "The top-level op that defines a SPIR-V module";
|
||||
|
||||
let description = [{
|
||||
|
@ -409,7 +409,8 @@ def SPV_ModuleOp : SPV_Op<"module",
|
|||
let arguments = (ins
|
||||
SPV_AddressingModelAttr:$addressing_model,
|
||||
SPV_MemoryModelAttr:$memory_model,
|
||||
OptionalAttr<SPV_VerCapExtAttr>:$vce_triple
|
||||
OptionalAttr<SPV_VerCapExtAttr>:$vce_triple,
|
||||
OptionalAttr<StrAttr>:$sym_name
|
||||
);
|
||||
|
||||
let results = (outs);
|
||||
|
@ -417,10 +418,12 @@ def SPV_ModuleOp : SPV_Op<"module",
|
|||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<[{OpBuilder &, OperationState &state}]>,
|
||||
OpBuilder<[{OpBuilder &, OperationState &state,
|
||||
Optional<StringRef> name = llvm::None}]>,
|
||||
OpBuilder<[{OpBuilder &, OperationState &state,
|
||||
spirv::AddressingModel addressing_model,
|
||||
spirv::MemoryModel memory_model}]>
|
||||
spirv::MemoryModel memory_model,
|
||||
Optional<StringRef> name = llvm::None}]>
|
||||
];
|
||||
|
||||
// We need to ensure the block inside the region is properly terminated;
|
||||
|
@ -432,6 +435,11 @@ def SPV_ModuleOp : SPV_Op<"module",
|
|||
let autogenSerialization = 0;
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
|
||||
bool isOptionalSymbol() { return true; }
|
||||
|
||||
Optional<StringRef> getName() { return sym_name(); }
|
||||
|
||||
static StringRef getVCETripleAttrName() { return "vce_triple"; }
|
||||
|
||||
Block& getBlock() {
|
||||
|
|
|
@ -2282,24 +2282,39 @@ static LogicalResult verify(spirv::MergeOp mergeOp) {
|
|||
// spv.module
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state) {
|
||||
void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
|
||||
Optional<StringRef> name) {
|
||||
ensureTerminator(*state.addRegion(), builder, state.location);
|
||||
if (name) {
|
||||
state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
|
||||
builder.getStringAttr(*name));
|
||||
}
|
||||
}
|
||||
|
||||
void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
|
||||
spirv::AddressingModel addressing_model,
|
||||
spirv::MemoryModel memory_model) {
|
||||
spirv::AddressingModel addressingModel,
|
||||
spirv::MemoryModel memoryModel,
|
||||
Optional<StringRef> name) {
|
||||
state.addAttribute(
|
||||
"addressing_model",
|
||||
builder.getI32IntegerAttr(static_cast<int32_t>(addressing_model)));
|
||||
builder.getI32IntegerAttr(static_cast<int32_t>(addressingModel)));
|
||||
state.addAttribute("memory_model", builder.getI32IntegerAttr(
|
||||
static_cast<int32_t>(memory_model)));
|
||||
static_cast<int32_t>(memoryModel)));
|
||||
ensureTerminator(*state.addRegion(), builder, state.location);
|
||||
if (name) {
|
||||
state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
|
||||
builder.getStringAttr(*name));
|
||||
}
|
||||
}
|
||||
|
||||
static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
|
||||
Region *body = state.addRegion();
|
||||
|
||||
// If the name is present, parse it.
|
||||
StringAttr nameAttr;
|
||||
parser.parseOptionalSymbolName(nameAttr, SymbolTable::getSymbolAttrName(),
|
||||
state.attributes);
|
||||
|
||||
// Parse attributes
|
||||
spirv::AddressingModel addrModel;
|
||||
spirv::MemoryModel memoryModel;
|
||||
|
@ -2328,13 +2343,19 @@ static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
|
|||
static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
|
||||
printer << spirv::ModuleOp::getOperationName();
|
||||
|
||||
if (Optional<StringRef> name = moduleOp.getName()) {
|
||||
printer << ' ';
|
||||
printer.printSymbolName(*name);
|
||||
}
|
||||
|
||||
SmallVector<StringRef, 2> elidedAttrs;
|
||||
|
||||
printer << " " << spirv::stringifyAddressingModel(moduleOp.addressing_model())
|
||||
<< " " << spirv::stringifyMemoryModel(moduleOp.memory_model());
|
||||
auto addressingModelAttrName = spirv::attributeName<spirv::AddressingModel>();
|
||||
auto memoryModelAttrName = spirv::attributeName<spirv::MemoryModel>();
|
||||
elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName});
|
||||
elidedAttrs.assign({addressingModelAttrName, memoryModelAttrName,
|
||||
SymbolTable::getSymbolAttrName()});
|
||||
|
||||
if (Optional<spirv::VerCapExtAttr> triple = moduleOp.vce_triple()) {
|
||||
printer << " requires " << *triple;
|
||||
|
|
|
@ -372,6 +372,9 @@ spv.module Logical GLSL450 {
|
|||
// CHECK: spv.module Logical GLSL450
|
||||
spv.module Logical GLSL450 { }
|
||||
|
||||
// Module with a name
|
||||
// CHECK: spv.module @{{.*}} Logical GLSL450
|
||||
spv.module @name Logical GLSL450 { }
|
||||
|
||||
// Module with (version, capabilities, extensions) triple
|
||||
// CHECK: spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], [SPV_KHR_16bit_storage]>
|
||||
|
|
Loading…
Reference in New Issue